Skip to content

Commit

Permalink
Fix incorrect dims in CompositeAffine [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
AnsonTran committed Apr 19, 2024
1 parent 2804271 commit 25e479a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
12 changes: 10 additions & 2 deletions lib/matplotlib/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,6 @@ def test_rotate_around(self):
assert_array_almost_equal(r90[2].transform(self.multiple_points), [
[2, 2, 0], [-1, 0, 0], [2, 0, 4], [-3, 5, 0], [-4, 6, 6]])


r_pi = [Affine3D().rotate_around(*self.pivot, np.pi, dim) for dim in range(3)]
r180 = [Affine3D().rotate_deg_around(*self.pivot, 180, dim) for dim in range(3)]

Expand All @@ -448,7 +447,6 @@ def test_rotate_around(self):
assert_array_almost_equal(r180[2].transform(self.multiple_points), [
[0, 2, 0], [2, -1, 0], [2, 2, 4], [-3, -3, 0], [-4, -4, 6]])


r_pi_3_2 = [Affine3D().rotate_around(*self.pivot, 3 * np.pi / 2, dim)
for dim in range(3)]
r270 = [Affine3D().rotate_deg_around(*self.pivot, 270, dim) for dim in range(3)]
Expand All @@ -472,6 +470,16 @@ def test_rotate_around(self):
assert_array_almost_equal(
(r90[dim] + r180[dim]).get_matrix(), r270[dim].get_matrix())

def test_scale(self):
sx = Affine3D().scale(3, 1, 1)
sy = Affine3D().scale(1, -2, 1)
sz = Affine3D().scale(1, 1, 4)
trans = Affine3D().scale(3, -2, 4)
assert_array_equal((sx + sy + sz).get_matrix(), trans.get_matrix())
assert_array_equal(trans.transform(self.single_point), [3, -2, 4])
assert_array_equal(trans.transform(self.multiple_points), [
[6, 0, 0], [0, -6, 0], [0, 0, 16], [15, -10, 0], [18, -12, 24]])


def test_non_affine_caching():
class AssertingNonAffineTransform(mtransforms.Transform):
Expand Down
4 changes: 1 addition & 3 deletions lib/matplotlib/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2833,10 +2833,8 @@ def __init__(self, a, b, **kwargs):
if a.output_dims != b.input_dims:
raise ValueError("The output dimension of 'a' must be equal to "
"the input dimensions of 'b'")
self.input_dims = a.input_dims
self.output_dims = b.output_dims
super().__init__(dims=a.output_dims, **kwargs)

super().__init__(**kwargs)
self._a = a
self._b = b
self.set_children(a, b)
Expand Down

0 comments on commit 25e479a

Please sign in to comment.