捷徑

torch.nn.utils.parametrizations.orthogonal

torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)[原始碼][原始碼]

對矩陣或一批矩陣應用正交或么正參數化。

K\mathbb{K}R\mathbb{R}C\mathbb{C},參數化的矩陣 QKm×nQ \in \mathbb{K}^{m \times n}正交的,如下所示:

QHQ=Inif mnQQH=Imif m<n\begin{align*} Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} \end{align*}

其中 QHQ^{\text{H}}QQ 是複數時表示共軛轉置,而在 QQ 是實數值時表示轉置,且 In\mathrm{I}_nn 維的單位矩陣。簡單來說,當 mnm \geq n 時,QQ 將具有正交歸一化的列,否則將具有正交歸一化的行。

如果張量具有兩個以上的維度,我們會將其視為形狀為 (…, m, n) 的矩陣批次。

矩陣 QQ 可以透過原始張量的三個不同的 orthogonal_map 進行參數化

  • "matrix_exp"/"cayley"matrix_exp() Q=exp(A)Q = \exp(A)Cayley 映射 Q=(In+A/2)(InA/2)1Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1} 應用於斜對稱矩陣 AA 以產生正交矩陣。

  • "householder":計算 Householder 反射器的乘積 (householder_product())。

"matrix_exp"/ "cayley" 通常會比 "householder" 更快地使參數化的權重收斂,但對於非常細長或非常寬的矩陣,它們的計算速度會較慢。

如果 use_trivialization=True (預設值),參數化會實作「動態平凡化框架」(Dynamic Trivialization Framework),其中一個額外的矩陣 BKn×nB \in \mathbb{K}^{n \times n} 儲存在 module.parametrizations.weight[0].base 下。這有助於參數化層的收斂,但會犧牲一些額外的記憶體使用量。請參閱 Trivializations for Gradient-Based Optimization on Manifolds

QQ 的初始值:如果原始張量未被參數化且 use_trivialization=True (預設值),QQ 的初始值會與原始張量的初始值相同(如果它是正交的,或在複數情況下是么正的),否則會透過 QR 分解進行正交化(請參閱 torch.linalg.qr())。當它未被參數化且 orthogonal_map="householder" 時,即使 use_trivialization=False,也會發生相同情況。否則,初始值是應用於原始張量的所有已註冊參數化組合的結果。

注意

此函數使用 register_parametrization() 中的參數化功能來實現。

參數
  • module (nn.Module) – 要在其上註冊參數化的模組。

  • name (str, optional) – 要使其正交的張量的名稱。預設值: "weight"

  • orthogonal_map (str, optional) – 以下其中之一: "matrix_exp", "cayley", "householder"。預設值:如果矩陣是正方形或複數,則為 "matrix_exp",否則為 "householder"

  • use_trivialization (bool, optional) – 是否使用動態平凡化框架。預設值: True

回傳

將正交參數化註冊到指定權重的原始模組

回傳類型

Module

範例

>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
    (weight): ParametrizationList(
    (0): _Orthogonal()
    )
)
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和高級開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源