Skip to content

Commit

Permalink
Convert axes3D to use transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
AnsonTran committed May 7, 2024
1 parent 5f6d8e3 commit 711ce8a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 20 deletions.
30 changes: 12 additions & 18 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from . import art3d
from . import proj3d
from . import axis3d
from . import transform3d


@_docstring.interpd
Expand Down Expand Up @@ -236,7 +237,7 @@ def _transformed_cube(self, vals):
(maxx, miny, maxz),
(maxx, maxy, maxz),
(minx, maxy, maxz)]
return proj3d._proj_points(xyzs, self.M)
return self.M.transform(xyzs)

def set_aspect(self, aspect, adjustable=None, anchor=None, share=False):
"""
Expand Down Expand Up @@ -423,7 +424,7 @@ def draw(self, renderer):

# add the projection matrix to the renderer
self.M = self.get_proj()
self.invM = np.linalg.inv(self.M)
self.invM = self.M.inverted()

collections_and_patches = (
artist for artist in self._children
Expand Down Expand Up @@ -1200,12 +1201,8 @@ def get_proj(self):

# Transform to uniform world coordinates 0-1, 0-1, 0-1
box_aspect = self._roll_to_vertical(self._box_aspect)
worldM = proj3d.world_transformation(
*self.get_xlim3d(),
*self.get_ylim3d(),
*self.get_zlim3d(),
pb_aspect=box_aspect,
)
worldM = transform3d.WorldTransform(*self.get_xlim3d(), *self.get_ylim3d(),
*self.get_zlim3d(), pb_aspect=box_aspect)

# Look into the middle of the world coordinates:
R = 0.5 * box_aspect
Expand Down Expand Up @@ -1238,21 +1235,18 @@ def get_proj(self):
# Generate the view and projection transformation matrices
if self._focal_length == np.inf:
# Orthographic projection
viewM = proj3d._view_transformation_uvw(u, v, w, eye)
projM = proj3d._ortho_transformation(-self._dist, self._dist)
viewM = transform3d.ViewTransform(u, v, w, eye)
projM = transform3d.OrthographicTransform(-self._dist, self._dist)
else:
# Perspective projection
# Scale the eye dist to compensate for the focal length zoom effect
eye_focal = R + self._dist * ps * self._focal_length
viewM = proj3d._view_transformation_uvw(u, v, w, eye_focal)
projM = proj3d._persp_transformation(-self._dist,
self._dist,
self._focal_length)
viewM = transform3d.ViewTransform(u, v, w, eye_focal)
projM = transform3d.PerspectiveTransform(-self._dist, self._dist,
self._focal_length)

# Combine all the transformation matrices to get the final projection
M0 = np.dot(viewM, worldM)
M = np.dot(projM, M0)
return M
return worldM + viewM + projM

def mouse_init(self, rotate_btn=1, pan_btn=2, zoom_btn=3):
"""
Expand Down Expand Up @@ -1459,7 +1453,7 @@ def _calc_coord(self, xv, yv, renderer=None):
zv = -1 / self._focal_length

# Convert point on view plane to data coordinates
p1 = np.array(proj3d.inv_transform(xv, yv, zv, self.invM)).ravel()
p1 = self.invM.transform([xv, yv, zv])

# Get the vector from the camera to the point on the view plane
vec = self._get_camera_loc() - p1
Expand Down
3 changes: 1 addition & 2 deletions lib/mpl_toolkits/mplot3d/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
import matplotlib.transforms as mtransforms


class NonAffine3D(mtransforms.Transform):
class NonAffine3D(mtransforms.Affine3D):
input_dims = output_dims = 3
is_affine = False

def __init__(self, *args, matrix=None, **kwargs):
super().__init__(*args, **kwargs)
Expand Down

0 comments on commit 711ce8a

Please sign in to comment.