[BUG] Issues with metric_tensor and trainable params causing errors with jax #5642

Open

albi3ro opened this issue May 3, 2024 · 0 comments

Labels
bug 🐛 Something isn't working

albi3ro commented May 3, 2024

Expected behavior

I expect the same behaviour as with the other interfaces.

Actual behavior

See the traceback below.

In short, because the device decomposes an operation, trainable_params goes stale, and the number of trainable parameters no longer matches between the classical Jacobian component and the quantum metric tensor component.
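As a rough illustration of the bookkeeping involved (the tape and indices below are made up for this sketch, not the exact tapes the transform builds): trainable_params stores positions into the tape's flat parameter list, so any step that rewrites the operation list can leave those positions pointing at the wrong gates, or leave a freshly built tape defaulting back to "everything trainable".

import pennylane as qml

tape = qml.tape.QuantumScript(
    [qml.RY(0.5, wires=0), qml.RX(0.1, wires=0)],
    [qml.expval(qml.Z(0))],
)
tape.trainable_params = [1]  # only the RX angle is marked trainable
print(tape.get_parameters(trainable_only=True))  # [0.1]

# After a decomposition produces a new tape, nothing guarantees that index 1
# still refers to the RX angle -- that staleness is what this issue is about.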

Additional information

Fixing this is going to take some thought into the nature of trainable_params and when we can rely on it.

Source code

To reproduce the error, you need:

  1. An operation that is not supported on the device, so the device decomposes it
  2. A parameter that is not trainable
  3. A parameter that is trainable
  4. The JAX interface
import jax
import pennylane as qml

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(x):
    qml.Permute([0, 1], wires=(0, 1))  # not supported natively, so the device decomposes it
    qml.RY(0.5, wires=0)               # non-trainable parameter
    qml.RX(x, wires=0)                 # trainable parameter
    return qml.expval(qml.Z(0))

qml.metric_tensor(circuit)(jax.numpy.array(0.5))
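The final error is just a shape mismatch in the contraction between the quantum and classical pieces. A toy snippet (not PennyLane internals; the shapes are chosen only to mirror the traceback below) reproduces the same TypeError:

import jax.numpy as jnp

mt = jnp.ones((1, 1))  # metric tensor built over 1 trainable parameter
cjac = jnp.ones((2,))  # classical Jacobian built over 2 gate parameters
jnp.tensordot(mt, cjac, axes=[[-1], [0]])
# TypeError: dot_general requires contracting dimensions to have the same
# shape, got (1,) and (2,).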

Tracebacks

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 9
      6     qml.RX(x, wires=0)
      7     return qml.expval(qml.Z(0))
----> 9 qml.metric_tensor(circuit)(jax.numpy.array(0.5))

File ~/Prog/pennylane/pennylane/workflow/qnode.py:1098, in QNode.__call__(self, *args, **kwargs)
   1095 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1097 try:
-> 1098     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1099 finally:
   1100     if old_interface == "auto":

File ~/Prog/pennylane/pennylane/workflow/qnode.py:1052, in QNode._execution_component(self, args, kwargs, override_shots)
   1049 full_transform_program.prune_dynamic_transform()
   1051 # pylint: disable=unexpected-keyword-arg
-> 1052 res = qml.execute(
   1053     (self._tape,),
   1054     device=self.device,
   1055     gradient_fn=self.gradient_fn,
   1056     interface=self.interface,
   1057     transform_program=full_transform_program,
   1058     config=config,
   1059     gradient_kwargs=self.gradient_kwargs,
   1060     override_shots=override_shots,
   1061     **self.execute_kwargs,
   1062 )
   1063 res = res[0]
   1065 # convert result to the interface in case the qfunc has no parameters

File ~/Prog/pennylane/pennylane/workflow/execution.py:617, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp)
    615 if no_interface_boundary_required:
    616     results = inner_execute(tapes)
--> 617     return post_processing(results)
    619 _grad_on_execution = False
    621 if (
    622     device_vjp
    623     and getattr(device, "short_name", "") in ("lightning.gpu", "lightning.kokkos")
    624     and interface in jpc_interfaces
    625 ):

File ~/Prog/pennylane/pennylane/transforms/core/transform_program.py:88, in _apply_postprocessing_stack(results, postprocessing_stack)
     65 """Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
     66 
     67 Args:
   (...)
     85 
     86 """
     87 for postprocessing in reversed(postprocessing_stack):
---> 88     results = postprocessing(results)
     89 return results

File ~/Prog/pennylane/pennylane/transforms/core/transform_program.py:58, in _batch_postprocessing(results, individual_fns, slices)
     32 def _batch_postprocessing(
     33     results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
     34 ) -> ResultBatch:
     35     """Broadcast individual post processing functions onto their respective tapes.
     36 
     37     Args:
   (...)
     56 
     57     """
---> 58     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/Prog/pennylane/pennylane/transforms/core/transform_program.py:58, in <genexpr>(.0)
     32 def _batch_postprocessing(
     33     results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
     34 ) -> ResultBatch:
     35     """Broadcast individual post processing functions onto their respective tapes.
     36 
     37     Args:
   (...)
     56 
     57     """
---> 58     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/Prog/pennylane/pennylane/gradients/metric_tensor.py:63, in _contract_metric_tensor_with_cjac(mt, cjac, tape)
     59 if is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0])):
     60     # Classical Jacobian is the identity. No classical processing
     61     # is present inside the QNode.
     62     return mt
---> 63 mt_cjac = qml.math.tensordot(mt, cjac, axes=[[-1], [0]])
     64 mt = qml.math.tensordot(cjac, mt_cjac, axes=[[0], [0]])
     66 return mt

File ~/Prog/pennylane/pennylane/math/multi_dispatch.py:151, in multi_dispatch.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    148 interface = interface or get_interface(*dispatch_args)
    149 kwargs["like"] = interface
--> 151 return fn(*args, **kwargs)

File ~/Prog/pennylane/pennylane/math/multi_dispatch.py:398, in tensordot(tensor1, tensor2, axes, like)
    372 """Returns the tensor product of two tensors.
    373 In general ``axes`` specifies either the set of axes for both
    374 tensors that are contracted (with the first/second entry of ``axes``
   (...)
    395     tensor_like: the tensor product of the two input tensors
    396 """
    397 tensor1, tensor2 = np.coerce([tensor1, tensor2], like=like)
--> 398 return np.tensordot(tensor1, tensor2, axes=axes, like=like)

File ~/Prog/pl/lib/python3.11/site-packages/autoray/autoray.py:80, in do(fn, like, *args, **kwargs)
     31 """Do function named ``fn`` on ``(*args, **kwargs)``, peforming single
     32 dispatch to retrieve ``fn`` based on whichever library defines the class of
     33 the ``args[0]``, or the ``like`` keyword argument if specified.
   (...)
     77     <tf.Tensor: id=91, shape=(3, 3), dtype=float32>
     78 """
     79 backend = choose_backend(fn, *args, like=like, **kwargs)
---> 80 return get_lib_fn(backend, fn)(*args, **kwargs)

File ~/Prog/pl/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3422, in tensordot(a, b, axes, precision, preferred_element_type)
   3419   msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
   3420          "of lists/tuples of ints.")
   3421   raise TypeError(msg)
-> 3422 result = lax.dot_general(a, b, (contracting_dims, ((), ())),
   3423                          precision=precision, preferred_element_type=preferred_element_type)
   3424 return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)

File ~/Prog/pl/lib/python3.11/site-packages/jax/_src/lax/lax.py:762, in dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type)
    757 bdims = (api_util._ensure_index_tuple(lhs_batch),
    758          api_util._ensure_index_tuple(rhs_batch))
    759 preferred_element_type = (
    760     None if preferred_element_type is None else
    761     dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
--> 762 return dot_general_p.bind(lhs, rhs,
    763                           dimension_numbers=(cdims, bdims),
    764                           precision=canonicalize_precision(precision),
    765                           preferred_element_type=preferred_element_type)

File ~/Prog/pl/lib/python3.11/site-packages/jax/_src/core.py:444, in Primitive.bind(self, *args, **params)
    441 def bind(self, *args, **params):
    442   assert (not config.enable_checks.value or
    443           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 444   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/Prog/pl/lib/python3.11/site-packages/jax/_src/core.py:447, in Primitive.bind_with_trace(self, trace, args, params)
    446 def bind_with_trace(self, trace, args, params):
--> 447   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    448   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/Prog/pl/lib/python3.11/site-packages/jax/_src/core.py:935, in EvalTrace.process_primitive(self, primitive, tracers, params)
    934 def process_primitive(self, primitive, tracers, params):
--> 935   return primitive.impl(*tracers, **params)

File ~/Prog/pl/lib/python3.11/site-packages/jax/_src/dispatch.py:87, in apply_primitive(prim, *args, **params)
     85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86 try:
---> 87   outs = fun(*args)
     88 finally:
     89   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 19 frame]

File ~/Prog/pl/lib/python3.11/site-packages/jax/_src/lax/lax.py:2580, in _dot_general_shape_rule(lhs, rhs, dimension_numbers, precision, preferred_element_type)
   2577 if not core.definitely_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
   2578   msg = ("dot_general requires contracting dimensions to have the same "
   2579          "shape, got {} and {}.")
-> 2580   raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
   2582 return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)

TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (2,).

System information

Master

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.