快捷方式

NormalParamWrapper

class torchrl.modules.NormalParamWrapper(operator: Module, scale_mapping: str = 'biased_softplus_1.0', scale_lb: Number = 0.0001)[原始碼]

Normal 分佈參數的包裝器。

參數:
  • operator (nn.Module) – 其輸出將轉換為位置和尺度參數的運算子

  • scale_mapping (str, optional) – 與 std 一起使用的正向映射函式。預設值 = “biased_softplus_1.0”(即,帶有偏差的 softplus 映射,使得 fn(0.0) = 1.0)選項:“softplus”、“exp”、“relu”、“biased_softplus_1”;

  • scale_lb (Number, optional) – 變異數可以採取的最小值。預設值為 1e-4。

範例

>>> from torch import nn
>>> import torch
>>> module = nn.Linear(3, 4)
>>> module_normal = NormalParamWrapper(module)
>>> tensor = torch.randn(3)
>>> loc, scale = module_normal(tensor)
>>> print(loc.shape, scale.shape)
torch.Size([2]) torch.Size([2])
>>> assert (scale > 0).all()
>>> # with modules that return more than one tensor
>>> module = nn.LSTM(3, 4)
>>> module_normal = NormalParamWrapper(module)
>>> tensor = torch.randn(4, 2, 3)
>>> loc, scale, others = module_normal(tensor)
>>> print(loc.shape, scale.shape)
torch.Size([4, 2, 2]) torch.Size([4, 2, 2])
>>> assert (scale > 0).all()
forward(*tensors: Tensor) Tuple[Tensor][原始碼]

定義每次呼叫時執行的計算。

應該被所有子類別覆寫。

注意

雖然正向傳遞的配方需要在這個函式中定義,但應該在之後呼叫 Module 實例,而不是這個函式,因為前者會處理執行註冊的鉤子,而後者會靜默地忽略它們。

文件

取得 PyTorch 的全面開發人員文件

查看文件

教學

取得適合初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

檢視資源