torch.nn.utils.parametrizations.orthogonal

torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)

Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
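Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, the parametrized matrix $Q \in \mathbb{K}^{m \times n}$ is orthogonal as

$$
\begin{aligned}
Q^{\mathrm{H}} Q &= I_n \qquad \text{if } m \geq n \\
Q Q^{\mathrm{H}} &= I_m \qquad \text{if } m < n
\end{aligned}
$$

where $Q^{\mathrm{H}}$ is the conjugate transpose when $Q$ is complex and the transpose when $Q$ is real-valued, and $I_n$ is the n-dimensional identity matrix. In plain words, $Q$ will have orthonormal columns whenever $m \geq n$ and orthonormal rows otherwise.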
If the tensor has more than two dimensions, we consider it as a batch of matrices of shape (…, m, n).
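The tall, wide, and batched cases described above can be checked numerically. The following is a minimal sketch, not part of the official documentation: the BatchedWeight module is a hypothetical helper introduced only to hold a 3-dimensional weight, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)

# Tall weight: nn.Linear(20, 40).weight has shape (40, 20), so m >= n
# and the columns should be orthonormal: Q^T Q = I_20.
tall = orthogonal(nn.Linear(20, 40)).weight
tall_err = torch.dist(tall.T @ tall, torch.eye(20)).item()

# Wide weight: nn.Linear(40, 20).weight has shape (20, 40), so m < n
# and the rows should be orthonormal: Q Q^T = I_20.
wide = orthogonal(nn.Linear(40, 20)).weight
wide_err = torch.dist(wide @ wide.T, torch.eye(20)).item()


class BatchedWeight(nn.Module):
    """Hypothetical module holding a batch of 3 matrices of shape (5, 2)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3, 5, 2))


# A tensor of shape (3, 5, 2) is treated as a batch of three 5 x 2 matrices,
# each of which should have orthonormal columns.
batched = orthogonal(BatchedWeight()).weight
gram = batched.transpose(-2, -1) @ batched  # shape (3, 2, 2)
batch_err = torch.dist(gram, torch.eye(2).expand(3, 2, 2)).item()

print(tall_err, wide_err, batch_err)
```

All three printed distances should be close to zero, up to floating-point error.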
The matrix $Q$ may be parametrized via three different orthogonal_map options in terms of the original tensor:

"matrix_exp"/"cayley": the matrix_exp() $Q = \exp(A)$ and the Cayley map $Q = (I_n + A/2)(I_n - A/2)^{-1}$ are applied to a skew-symmetric $A$ to give an orthogonal matrix.
"householder": computes a product of Householder reflectors (householder_product()).

"matrix_exp"/"cayley" often make the parametrized weight converge faster than "householder", but they are slower to compute for very thin or very wide matrices.

If use_trivialization=True (default), the parametrization implements the “Dynamic Trivialization Framework”, where an extra $n \times n$ matrix $B$ is stored under module.parametrizations.weight[0].base. This helps the convergence of the parametrized layer at the expense of some extra memory use. See Trivializations for Gradient-Based Optimization on Manifolds.

Initial value of $Q$: If the original tensor is not parametrized and use_trivialization=True (default), the initial value of $Q$ is that of the original tensor if it is orthogonal (or unitary in the complex case), and it is orthogonalized via the QR decomposition otherwise (see torch.linalg.qr()). The same happens when it is not parametrized and orthogonal_map="householder", even when use_trivialization=False. Otherwise, the initial value is the result of the composition of all the registered parametrizations applied to the original tensor.

Note
This function is implemented using the parametrization functionality in register_parametrization().

- Parameters
module (nn.Module) – module on which to register the parametrization.
name (str, optional) – name of the tensor to make orthogonal. Default: "weight".
orthogonal_map (str, optional) – One of the following: "matrix_exp", "cayley", "householder". Default: "matrix_exp" if the matrix is square or complex, "householder" otherwise.
use_trivialization (bool, optional) – whether to use the dynamic trivialization framework. Default: True.
- Returns
The original module with an orthogonal parametrization registered to the specified weight
Example:

>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _Orthogonal()
    )
  )
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)
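As a further illustration, the sketch below registers each of the three orthogonal_map options, takes one optimizer step, and checks that the weight is still orthogonal afterwards. It also inspects the extra matrix stored under base when use_trivialization=True. The orth_error helper, the toy loss, the learning rate, and the layer sizes are assumptions made for this example only.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)


def orth_error(layer):
    """Distance between Q^T Q and the identity for a tall weight Q."""
    Q = layer.weight
    return torch.dist(Q.T @ Q, torch.eye(Q.size(-1))).item()


errors = {}
for map_name in ("matrix_exp", "cayley", "householder"):
    layer = orthogonal(nn.Linear(20, 40), orthogonal_map=map_name)
    # One SGD step on an arbitrary toy loss. The optimizer updates the
    # unconstrained tensor behind the parametrization, not the weight itself.
    opt = torch.optim.SGD(layer.parameters(), lr=0.01)
    loss = layer(torch.randn(8, 20)).pow(2).mean()
    loss.backward()
    opt.step()
    # The weight is recomputed from the updated tensor on access.
    errors[map_name] = orth_error(layer)

# With use_trivialization=True (default), an extra matrix is stored as `base`.
triv = orthogonal(nn.Linear(20, 20))
base = triv.parametrizations.weight[0].base

print(errors)
print(tuple(base.shape))
```

Each reported error should remain close to zero after the update, since orthogonality is enforced by the parametrization rather than by the optimizer.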