Skip to content

Commit

Permalink
Convert functions in art3d to use transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
AnsonTran committed May 8, 2024
1 parent 148f1b7 commit c99765f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 50 deletions.
85 changes: 36 additions & 49 deletions lib/mpl_toolkits/mplot3d/art3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,11 @@ def set_3d_properties(self, z=0, zdir='z'):
@artist.allow_rasterization
def draw(self, renderer):
position3d = np.array((self._x, self._y, self._z))
proj = proj3d._proj_trans_points(
[position3d, position3d + self._dir_vec], self.axes.M)
dx = proj[0][1] - proj[0][0]
dy = proj[1][1] - proj[1][0]
proj = self.axes.M.transform([position3d, position3d + self._dir_vec])
dx = proj[1][0] - proj[0][0]
dy = proj[1][1] - proj[0][1]
angle = math.degrees(math.atan2(dy, dx))
with cbook._setattr_cm(self, _x=proj[0][0], _y=proj[1][0],
with cbook._setattr_cm(self, _x=proj[0][0], _y=proj[0][1],
_rotation=_norm_text_angle(angle)):
mtext.Text.draw(self, renderer)
self.stale = False
Expand Down Expand Up @@ -267,8 +266,8 @@ def get_data_3d(self):
@artist.allow_rasterization
def draw(self, renderer):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_data(xs, ys)
points = self.axes.M.transform(np.column_stack((xs3d, ys3d, zs3d)))
self.set_data(points[:, 0], points[:, 1])
super().draw(renderer)
self.stale = False

Expand Down Expand Up @@ -349,11 +348,11 @@ class Collection3D(Collection):

def do_3d_projection(self):
"""Project the points according to renderer matrix."""
xyzs_list = [proj3d.proj_transform(*vs.T, self.axes.M)
for vs, _ in self._3dverts_codes]
self._paths = [mpath.Path(np.column_stack([xs, ys]), cs)
for (xs, ys, _), (_, cs) in zip(xyzs_list, self._3dverts_codes)]
zs = np.concatenate([zs for _, _, zs in xyzs_list])
path_vertices = [self.axes.M.transform(vs) for vs, _ in self._3dverts_codes]
self._paths = [mpath.Path(vertices[:, :2], codes)
for (vertices, (_, codes))
in zip(path_vertices, self._3dverts_codes)]
zs = np.concatenate(path_vertices)[:, 2]
return zs.min() if len(zs) else 1e9


Expand Down Expand Up @@ -390,15 +389,14 @@ def do_3d_projection(self):
"""
Project the points according to renderer matrix.
"""
xyslist = [proj3d._proj_trans_points(points, self.axes.M)
for points in self._segments3d]
segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist]
segments_3d = [self.axes.M.transform(segment) for segment in self._segments3d]
segments_2d = [segment[:, :2] for segment in segments_3d]
LineCollection.set_segments(self, segments_2d)

# FIXME
minz = 1e9
for xs, ys, zs in xyslist:
minz = min(minz, min(zs))
for segment in segments_3d:
minz = min(minz, segment[0][2], segment[1][2])
return minz


Expand Down Expand Up @@ -456,12 +454,10 @@ def get_path(self):
return self._path2d

def do_3d_projection(self):
s = self._segment3d
xs, ys, zs = zip(*s)
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
self._path2d = mpath.Path(np.column_stack([vxs, vys]))
return min(vzs)
segments = self.axes.M.transform(self._segment3d)
self._path2d = mpath.Path(segments[:, :2])

return min(segments[:, 2])


class PathPatch3D(Patch3D):
Expand Down Expand Up @@ -503,12 +499,10 @@ def set_3d_properties(self, path, zs=0, zdir='z'):
self._code3d = path.codes

def do_3d_projection(self):
s = self._segment3d
xs, ys, zs = zip(*s)
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
self._path2d = mpath.Path(np.column_stack([vxs, vys]), self._code3d)
return min(vzs)
segments = self.axes.M.transform(self._segment3d)
self._path2d = mpath.Path(segments[:, :2], self._code3d)

return min(segments[:, 2])


def _get_patch_verts(patch):
Expand Down Expand Up @@ -610,14 +604,13 @@ def set_3d_properties(self, zs, zdir):
self.stale = True

def do_3d_projection(self):
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
self._vzs = vzs
super().set_offsets(np.column_stack([vxs, vys]))
points = self.axes.M.transform(np.column_stack(self._offsets3d))
super().set_offsets(points[:, :2])

self._vzs = points[:, 2]

if vzs.size > 0:
return min(vzs)
if self._vzs.size > 0:
return min(self._vzs)
else:
return np.nan

Expand Down Expand Up @@ -751,37 +744,31 @@ def set_depthshade(self, depthshade):
self.stale = True

def do_3d_projection(self):
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
# Sort the points based on z coordinates
# Performance optimization: Create a sorted index array and reorder
# points and point properties according to the index array
z_markers_idx = self._z_markers_idx = np.argsort(vzs)[::-1]
self._vzs = vzs
points = self.axes.M.transform(np.column_stack(self._offsets3d))
z_markers_idx = self._z_markers_idx = np.argsort(points[:, 2])[::-1]
self._vzs = points[:, 2]

# we have to special case the sizes because of code in collections.py
# as the draw method does
# self.set_sizes(self._sizes, self.figure.dpi)
# so we cannot rely on doing the sorting on the way out via get_*

if len(self._sizes3d) > 1:
self._sizes = self._sizes3d[z_markers_idx]

if len(self._linewidths3d) > 1:
self._linewidths = self._linewidths3d[z_markers_idx]

PathCollection.set_offsets(self, np.column_stack((vxs, vys)))
PathCollection.set_offsets(self, points[:, :2])

# Re-order items
vzs = vzs[z_markers_idx]
vxs = vxs[z_markers_idx]
vys = vys[z_markers_idx]
points = points[z_markers_idx]

# Store ordered offset for drawing purpose
self._offset_zordered = np.column_stack((vxs, vys))

return np.min(vzs) if vzs.size else np.nan
self._offset_zordered = points[:, :2]
return np.min(self._vzs) if self._vzs.size else np.nan

@contextmanager
def _use_zordered_offset(self):
Expand Down
3 changes: 2 additions & 1 deletion lib/mpl_toolkits/mplot3d/tests/test_axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def test_axes3d_repr():
"title={'center': 'title'}, xlabel='x', ylabel='y', zlabel='z'>")


@mpl3d_image_comparison(['axes3d_primary_views.png'], style='mpl20')
@mpl3d_image_comparison(['axes3d_primary_views.png'], style='mpl20', tol=0 if
platform.machine() == 'x86_64' else 0.045)
def test_axes3d_primary_views():
# (elev, azim, roll)
views = [(90, -90, 0), # XY
Expand Down

0 comments on commit c99765f

Please sign in to comment.