torch.autograd.functional.jacobian

torch.autograd.functional.jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, strategy='reverse-mode')

Function that computes the Jacobian of a given function.
Parameters

- func (function) – a Python function that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
- inputs (tuple of Tensors or Tensor) – inputs to the function func.
- create_graph (bool, optional) – If True, the Jacobian will be computed in a differentiable manner. Note that when strict is False, the result cannot require gradients or be disconnected from the inputs. Defaults to False.
- strict (bool, optional) – If True, an error will be raised when we detect that there exists an input such that all the outputs are independent of it. If False, we return a Tensor of zeros as the Jacobian for said inputs, which is the expected mathematical value. Defaults to False.
- vectorize (bool, optional) – This feature is experimental. Please consider using functorch's jacrev or jacfwd instead if you are looking for something less experimental and more performant. When computing the Jacobian, we usually invoke autograd.grad once per row of the Jacobian. If this flag is True, we perform only a single autograd.grad call with batched_grad=True, which uses the vmap prototype feature. Though this should lead to performance improvements in many cases, because this feature is still experimental, there may be performance cliffs. See torch.autograd.grad()'s batched_grad parameter for more information.
- strategy (str, optional) – Set to "forward-mode" or "reverse-mode" to determine whether the Jacobian will be computed with forward- or reverse-mode AD. Currently, "forward-mode" requires vectorize=True (see the sketch after this list). Defaults to "reverse-mode". If func has more outputs than inputs, "forward-mode" tends to be more performant. Otherwise, prefer to use "reverse-mode".
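A minimal sketch of the two strategies, assuming a PyTorch version that supports strategy="forward-mode" (the function wide_output below is a hypothetical example, not from this page):

import torch
from torch.autograd.functional import jacobian

# Hypothetical function with more outputs (6) than inputs (3) -- the
# case where "forward-mode" tends to be more performant.
def wide_output(x):
    return torch.cat([x.sin(), x.cos()])

x = torch.rand(3)

# Forward-mode AD; currently requires vectorize=True.
J_fwd = jacobian(wide_output, x, strategy="forward-mode", vectorize=True)

# Reverse-mode AD (the default); computes the same values.
J_rev = jacobian(wide_output, x)

assert torch.allclose(J_fwd, J_rev)
print(J_fwd.shape)  # torch.Size([6, 3])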
Returns

If there is a single input and output, this will be a single Tensor containing the Jacobian for the linearized inputs and output. If one of the two is a tuple, then the Jacobian will be a tuple of Tensors. If both of them are tuples, then the Jacobian will be a tuple of tuples of Tensors, where Jacobian[i][j] contains the Jacobian of the i-th output with respect to the j-th input. Its size is the concatenation of the sizes of the corresponding output and the corresponding input, and it has the same dtype and device as the corresponding input. If strategy is "forward-mode", the dtype will be that of the output; otherwise, the input. The sketch below illustrates this structure.

Return type

Jacobian (Tensor or nested tuple of Tensors)
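A minimal sketch of the nested-tuple return structure, using a hypothetical two-input, two-output function f (not from this page):

import torch
from torch.autograd.functional import jacobian

# Hypothetical function: two Tensor inputs, two Tensor outputs.
def f(x, y):
    return x.pow(2).sum() * y, x + y.sum()

x = torch.rand(2)   # input 0, size (2,)
y = torch.rand(3)   # input 1, size (3,)
J = jacobian(f, (x, y))

# J[i][j] is the Jacobian of output i w.r.t. input j; its size is the
# concatenation of the output's and the input's sizes.
print(J[0][0].shape)  # torch.Size([3, 2]): output 0 is size (3,), input 0 is size (2,)
print(J[0][1].shape)  # torch.Size([3, 3])
print(J[1][0].shape)  # torch.Size([2, 2])
print(J[1][1].shape)  # torch.Size([2, 3])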
Example
>>> def exp_reducer(x):
...     return x.exp().sum(dim=1)
>>> inputs = torch.rand(2, 2)
>>> jacobian(exp_reducer, inputs)
tensor([[[1.4917, 2.4352],
         [0.0000, 0.0000]],
        [[0.0000, 0.0000],
         [2.4369, 2.3799]]])

>>> jacobian(exp_reducer, inputs, create_graph=True)
tensor([[[1.4917, 2.4352],
         [0.0000, 0.0000]],
        [[0.0000, 0.0000],
         [2.4369, 2.3799]]], grad_fn=<ViewBackward>)

>>> def exp_adder(x, y):
...     return 2 * x.exp() + 3 * y
>>> inputs = (torch.rand(2), torch.rand(2))
>>> jacobian(exp_adder, inputs)
(tensor([[2.8052, 0.0000],
         [0.0000, 3.3963]]),
 tensor([[3., 0.],
         [0., 3.]]))
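A further hedged sketch of the strict flag described under Parameters (ignore_y is a hypothetical function, not from the original examples): when an output is independent of an input, the default strict=False fills in zeros, while strict=True raises an error.

import torch
from torch.autograd.functional import jacobian

# Hypothetical function whose output does not depend on its second input.
def ignore_y(x, y):
    return x ** 2

inputs = (torch.rand(2), torch.rand(2))

# strict=False (default): the Jacobian w.r.t. y is a Tensor of zeros,
# the expected mathematical value.
J = jacobian(ignore_y, inputs)
print(J[1])  # 2x2 tensor of zeros

# strict=True: the same call raises an error because an input was
# detected on which no output depends (a RuntimeError in current PyTorch).
try:
    jacobian(ignore_y, inputs, strict=True)
except RuntimeError as err:
    print("strict=True raised:", err)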