• 文件 >
  • torch.nn >
  • torch.nn.utils.parametrize.register_parametrization
快捷方式

torch.nn.utils.parametrize.register_parametrization

torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[source][source]

在模組中註冊張量的參數化。

為了簡化說明,假設 tensor_name="weight"。當存取 module.weight 時,模組會回傳參數化的版本 parametrization(module.weight)。如果原始張量需要梯度,反向傳播會通過 parametrization 進行微分,而最佳化器會相應地更新張量。

模組第一次註冊參數化時,此函式會在模組中新增一個名為 parametrizations 的屬性,其類型為 ParametrizationList

張量 weight 上的參數化列表可通過 module.parametrizations.weight 存取。

原始張量可通過 module.parametrizations.weight.original 存取。

通過在同一個屬性上註冊多個參數化,可以將參數化連接起來。

已註冊參數化的訓練模式會在註冊時更新,以符合主機模組的訓練模式。

參數化的參數和緩衝區有一個內建的快取系統,可以使用上下文管理器 cached() 來啟動。

parametrization 可以選擇性地實現一個具有以下簽名的方法:

def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

當註冊第一個參數化時,會調用此方法於未參數化的張量上,以計算原始張量的初始值。如果未實現此方法,則原始張量將只是未參數化的張量。

如果註冊在張量上的所有參數化都實現了 right_inverse,則可以通過賦值給它來初始化參數化的張量,如下面的範例所示。

第一個參數化可以依賴於多個輸入。這可以通過從 right_inverse 回傳一個張量元組來實現 (請參閱下面 RankOne 參數化的範例實現)。

在這種情況下,未約束的張量也位於 module.parametrizations.weight 下,名稱為 original0original1 等等。

注意

如果 unsafe=False (預設值),則 forward 和 right_inverse 方法都會被調用一次,以執行一些一致性檢查。如果 unsafe=True,則僅當張量未參數化時,才會調用 right_inverse,否則不會調用任何方法。

注意

在大多數情況下,right_inverse 將是一個函數,使得 forward(right_inverse(X)) == X (請參閱 右逆函數)。有時,當參數化不是滿射時,放寬此要求可能是合理的。

警告

如果參數化依賴於多個輸入,register_parametrization() 將註冊許多新的參數。如果在建立最佳化器後註冊此類參數化,則需要手動將這些新參數添加到最佳化器。請參閱 torch.Optimizer.add_param_group()

參數
  • module (nn.Module) – 要在其上註冊參數化的模組

  • tensor_name (str) – 要在其上註冊參數化的參數或緩衝區的名稱

  • parametrization (nn.Module) – 要註冊的參數化

關鍵字參數

unsafe (bool) – 一個布林標誌,表示參數化是否可以更改張量的 dtype 和形狀。預設值:False 警告:在註冊時不會檢查參數化的一致性。啟用此標誌的風險自負。

引發

ValueError – 如果模組沒有名為 tensor_name 的參數或緩衝區

回傳類型

Module

範例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>>     def forward(self, X):
>>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
>>>
>>>     def right_inverse(self, A):
>>>         return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T   # A is now symmetric
>>> m.weight = A  # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>>     def forward(self, x, y):
>>>         # Form a rank 1 matrix multiplying two vectors
>>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>>     def right_inverse(self, Z):
>>>         # Project Z onto the rank 1 matrices
>>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>>         # Return rescaled singular vectors
>>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學

取得針對初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並取得問題解答

查看資源