Skip to content

Commit

Permalink
Convert existing mplot3d functions to use transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
AnsonTran committed May 10, 2024
1 parent 2a5b1e3 commit 1c1eb22
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 98 deletions.
122 changes: 53 additions & 69 deletions lib/mpl_toolkits/mplot3d/art3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Collection, LineCollection, PolyCollection, PatchCollection, PathCollection)
from matplotlib.colors import Normalize
from matplotlib.patches import Patch
from . import proj3d


def _norm_angle(a):
Expand Down Expand Up @@ -148,12 +147,11 @@ def set_3d_properties(self, z=0, zdir='z'):
@artist.allow_rasterization
def draw(self, renderer):
position3d = np.array((self._x, self._y, self._z))
proj = proj3d._proj_trans_points(
[position3d, position3d + self._dir_vec], self.axes.M)
dx = proj[0][1] - proj[0][0]
dy = proj[1][1] - proj[1][0]
proj = self.axes.M.transform([position3d, position3d + self._dir_vec])
dx = proj[1][0] - proj[0][0]
dy = proj[1][1] - proj[0][1]
angle = math.degrees(math.atan2(dy, dx))
with cbook._setattr_cm(self, _x=proj[0][0], _y=proj[1][0],
with cbook._setattr_cm(self, _x=proj[0][0], _y=proj[0][1],
_rotation=_norm_text_angle(angle)):
mtext.Text.draw(self, renderer)
self.stale = False
Expand Down Expand Up @@ -267,8 +265,8 @@ def get_data_3d(self):
@artist.allow_rasterization
def draw(self, renderer):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_data(xs, ys)
points = self.axes.M.transform(np.column_stack((xs3d, ys3d, zs3d)))
self.set_data(points[:, 0], points[:, 1])
super().draw(renderer)
self.stale = False

Expand Down Expand Up @@ -349,11 +347,11 @@ class Collection3D(Collection):

def do_3d_projection(self):
"""Project the points according to renderer matrix."""
xyzs_list = [proj3d.proj_transform(*vs.T, self.axes.M)
for vs, _ in self._3dverts_codes]
self._paths = [mpath.Path(np.column_stack([xs, ys]), cs)
for (xs, ys, _), (_, cs) in zip(xyzs_list, self._3dverts_codes)]
zs = np.concatenate([zs for _, _, zs in xyzs_list])
path_vertices = [self.axes.M.transform(vs) for vs, _ in self._3dverts_codes]
self._paths = [mpath.Path(vertices[:, :2], codes)
for (vertices, (_, codes))
in zip(path_vertices, self._3dverts_codes)]
zs = np.concatenate(path_vertices)[:, 2]
return zs.min() if len(zs) else 1e9


Expand Down Expand Up @@ -390,15 +388,14 @@ def do_3d_projection(self):
"""
Project the points according to renderer matrix.
"""
xyslist = [proj3d._proj_trans_points(points, self.axes.M)
for points in self._segments3d]
segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist]
segments_3d = [self.axes.M.transform(segment) for segment in self._segments3d]
segments_2d = [segment[:, :2] for segment in segments_3d]
LineCollection.set_segments(self, segments_2d)

# FIXME
minz = 1e9
for xs, ys, zs in xyslist:
minz = min(minz, min(zs))
for segment in segments_3d:
minz = min(minz, segment[0][2], segment[1][2])
return minz


Expand Down Expand Up @@ -456,12 +453,10 @@ def get_path(self):
return self._path2d

def do_3d_projection(self):
s = self._segment3d
xs, ys, zs = zip(*s)
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
self._path2d = mpath.Path(np.column_stack([vxs, vys]))
return min(vzs)
segments = self.axes.M.transform(self._segment3d)
self._path2d = mpath.Path(segments[:, :2])

return min(segments[:, 2])


class PathPatch3D(Patch3D):
Expand Down Expand Up @@ -503,12 +498,10 @@ def set_3d_properties(self, path, zs=0, zdir='z'):
self._code3d = path.codes

def do_3d_projection(self):
s = self._segment3d
xs, ys, zs = zip(*s)
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
self._path2d = mpath.Path(np.column_stack([vxs, vys]), self._code3d)
return min(vzs)
segments = self.axes.M.transform(self._segment3d)
self._path2d = mpath.Path(segments[:, :2], self._code3d)

return min(segments[:, 2])


def _get_patch_verts(patch):
Expand Down Expand Up @@ -610,14 +603,13 @@ def set_3d_properties(self, zs, zdir):
self.stale = True

def do_3d_projection(self):
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
self._vzs = vzs
super().set_offsets(np.column_stack([vxs, vys]))
points = self.axes.M.transform(np.column_stack(self._offsets3d))
super().set_offsets(points[:, :2])

if vzs.size > 0:
return min(vzs)
self._vzs = points[:, 2]

if self._vzs.size > 0:
return min(self._vzs)
else:
return np.nan

Expand Down Expand Up @@ -751,37 +743,31 @@ def set_depthshade(self, depthshade):
self.stale = True

def do_3d_projection(self):
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
# Sort the points based on z coordinates
# Performance optimization: Create a sorted index array and reorder
# points and point properties according to the index array
z_markers_idx = self._z_markers_idx = np.argsort(vzs)[::-1]
self._vzs = vzs
points = self.axes.M.transform(np.column_stack(self._offsets3d))
z_markers_idx = self._z_markers_idx = np.argsort(points[:, 2])[::-1]
self._vzs = points[:, 2]

# we have to special case the sizes because of code in collections.py
# as the draw method does
# self.set_sizes(self._sizes, self.figure.dpi)
# so we cannot rely on doing the sorting on the way out via get_*

if len(self._sizes3d) > 1:
self._sizes = self._sizes3d[z_markers_idx]

if len(self._linewidths3d) > 1:
self._linewidths = self._linewidths3d[z_markers_idx]

PathCollection.set_offsets(self, np.column_stack((vxs, vys)))
PathCollection.set_offsets(self, points[:, :2])

# Re-order items
vzs = vzs[z_markers_idx]
vxs = vxs[z_markers_idx]
vys = vys[z_markers_idx]
points = points[z_markers_idx]

# Store ordered offset for drawing purpose
self._offset_zordered = np.column_stack((vxs, vys))

return np.min(vzs) if vzs.size else np.nan
self._offset_zordered = points[:, :2]
return np.min(self._vzs) if self._vzs.size else np.nan

@contextmanager
def _use_zordered_offset(self):
Expand Down Expand Up @@ -954,8 +940,7 @@ def get_vector(self, segments3d):
xs, ys, zs = np.vstack(segments3d).T
else: # vstack can't stack zero arrays.
xs, ys, zs = [], [], []
ones = np.ones(len(xs))
self._vec = np.array([xs, ys, zs, ones])
self._vec = np.array([xs, ys, zs])

indices = [0, *np.cumsum([len(segment) for segment in segments3d])]
self._segslices = [*map(slice, indices[:-1], indices[1:])]
Expand Down Expand Up @@ -1020,27 +1005,28 @@ def do_3d_projection(self):
self._facecolor3d = self._facecolors
if self._edge_is_mapped:
self._edgecolor3d = self._edgecolors
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, self.axes.M)
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]

verts = self.axes.M.transform(np.column_stack(self._vec))
verts_slices = [verts[sl] for sl in self._segslices]

# This extra fuss is to re-order face / edge colors
cface = self._facecolor3d
cedge = self._edgecolor3d
if len(cface) != len(xyzlist):
cface = cface.repeat(len(xyzlist), axis=0)
if len(cedge) != len(xyzlist):

if len(cface) != len(verts_slices):
cface = cface.repeat(len(verts_slices), axis=0)
if len(cedge) != len(verts_slices):
if len(cedge) == 0:
cedge = cface
else:
cedge = cedge.repeat(len(xyzlist), axis=0)
cedge = cedge.repeat(len(verts_slices), axis=0)

if xyzlist:
# sort by depth (furthest drawn first)
if verts_slices:
z_segments_2d = sorted(
((self._zsortfunc(zs), np.column_stack([xs, ys]), fc, ec, idx)
for idx, ((xs, ys, zs), fc, ec)
in enumerate(zip(xyzlist, cface, cedge))),
key=lambda x: x[0], reverse=True)
((self._zsortfunc(verts[:, 2]), verts[:, :2], fc, ec, idx)
for idx, (verts, fc, ec)
in enumerate(zip(verts_slices, cface, cedge))),
key=lambda x: x[0], reverse=True)

_, segments_2d, self._facecolors2d, self._edgecolors2d, idxs = \
zip(*z_segments_2d)
Expand All @@ -1061,14 +1047,12 @@ def do_3d_projection(self):

# Return zorder value
if self._sort_zpos is not None:
zvec = np.array([[0], [0], [self._sort_zpos], [1]])
ztrans = proj3d._proj_transform_vec(zvec, self.axes.M)
return ztrans[2][0]
elif tzs.size > 0:
return self.axes.M.transform([0, 0, self._sort_zpos])[2]
elif len(verts) > 0:
# FIXME: Some results still don't look quite right.
# In particular, examine contourf3d_demo2.py
# with az = -54 and elev = -45.
return np.min(tzs)
return np.min(verts[:, 2])
else:
return np.nan

Expand Down
34 changes: 14 additions & 20 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from . import art3d
from . import proj3d
from . import axis3d
from . import transform3d


@_docstring.interpd
Expand Down Expand Up @@ -158,8 +159,8 @@ def __init__(
super().set_axis_off()
# Enable drawing of axes by Axes3D class
self.set_axis_on()
self.M = None
self.invM = None
self.M = mtransforms.IdentityTransform(dims=3)
self.invM = mtransforms.IdentityTransform(dims=3)

self._view_margin = 1/48 # default value to match mpl3.8
self.autoscale_view()
Expand Down Expand Up @@ -236,7 +237,7 @@ def _transformed_cube(self, vals):
(maxx, miny, maxz),
(maxx, maxy, maxz),
(minx, maxy, maxz)]
return proj3d._proj_points(xyzs, self.M)
return self.M.transform(xyzs)

def set_aspect(self, aspect, adjustable=None, anchor=None, share=False):
"""
Expand Down Expand Up @@ -423,7 +424,7 @@ def draw(self, renderer):

# add the projection matrix to the renderer
self.M = self.get_proj()
self.invM = np.linalg.inv(self.M)
self.invM = self.M.inverted()

collections_and_patches = (
artist for artist in self._children
Expand Down Expand Up @@ -1200,12 +1201,8 @@ def get_proj(self):

# Transform to uniform world coordinates 0-1, 0-1, 0-1
box_aspect = self._roll_to_vertical(self._box_aspect)
worldM = proj3d.world_transformation(
*self.get_xlim3d(),
*self.get_ylim3d(),
*self.get_zlim3d(),
pb_aspect=box_aspect,
)
worldM = transform3d.WorldTransform(*self.get_xlim3d(), *self.get_ylim3d(),
*self.get_zlim3d(), pb_aspect=box_aspect)

# Look into the middle of the world coordinates:
R = 0.5 * box_aspect
Expand Down Expand Up @@ -1238,21 +1235,18 @@ def get_proj(self):
# Generate the view and projection transformation matrices
if self._focal_length == np.inf:
# Orthographic projection
viewM = proj3d._view_transformation_uvw(u, v, w, eye)
projM = proj3d._ortho_transformation(-self._dist, self._dist)
viewM = transform3d.ViewTransform(u, v, w, eye)
projM = transform3d.OrthographicTransform(-self._dist, self._dist)
else:
# Perspective projection
# Scale the eye dist to compensate for the focal length zoom effect
eye_focal = R + self._dist * ps * self._focal_length
viewM = proj3d._view_transformation_uvw(u, v, w, eye_focal)
projM = proj3d._persp_transformation(-self._dist,
self._dist,
self._focal_length)
viewM = transform3d.ViewTransform(u, v, w, eye_focal)
projM = transform3d.PerspectiveTransform(-self._dist, self._dist,
self._focal_length)

# Combine all the transformation matrices to get the final projection
M0 = np.dot(viewM, worldM)
M = np.dot(projM, M0)
return M
return worldM + viewM + projM

def mouse_init(self, rotate_btn=1, pan_btn=2, zoom_btn=3):
"""
Expand Down Expand Up @@ -1459,7 +1453,7 @@ def _calc_coord(self, xv, yv, renderer=None):
zv = -1 / self._focal_length

# Convert point on view plane to data coordinates
p1 = np.array(proj3d.inv_transform(xv, yv, zv, self.invM)).ravel()
p1 = self.invM.transform([xv, yv, zv])

# Get the vector from the camera to the point on the view plane
vec = self._get_camera_loc() - p1
Expand Down

0 comments on commit 1c1eb22

Please sign in to comment.