捷徑

TanhModule

class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)[來源]

用於具有邊界動作空間之確定性策略的 Tanh 模組。

此轉換將用作 TensorDictModule 層,以將網路輸出映射到邊界空間。

參數:
  • in_keys (字串列表字串元組) – 模組的輸入鍵。

  • out_keys (字串列表字串元組選用) – 模組的輸出鍵。如果未提供,則假定與 in_keys 相同的鍵。

關鍵字引數:
  • spec (TensorSpec選用) – 如果提供,則為輸出的規格。如果提供 Composite,則其鍵必須與 out_keys 中的鍵匹配。否則,假定 out_keys 的鍵,並且所有輸出都使用相同的規格。

  • low (floatnp.ndarraytorch.Tensor) – 空間的下限。如果未提供且未提供規格,則假定為 -1。如果提供規格,則將檢索規格的最小值。

  • high (floatnp.ndarraytorch.Tensor) – 空間的上限。如果未提供且未提供規格,則假定為 1。如果提供規格,則將檢索規格的最大值。

  • clamp (bool選用) – 如果 True,則輸出將被限制在邊界內,但與邊界至少有一定的解析度。預設為 False

範例

>>> from tensordict import TensorDict
>>> # simplest use case: -1 - 1 boundaries
>>> torch.manual_seed(0)
>>> in_keys = ["action"]
>>> mod = TanhModule(
...     in_keys=in_keys,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([ 1.0000, -0.9944, -1.0000,  1.0000, -1.0000])
>>> # low and high can be customized
>>> low = -2
>>> high = 1
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([-2.0000,  0.9991,  1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
>>> from torchrl.data import Bounded
>>> spec = Bounded(low, high, shape=())
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
...     spec=spec,
...     clamp=False,
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
>>> spec = Composite(
...     a=Bounded(-3, 0, shape=()),
...     b=Bounded(0, 3, shape=()))
>>> mod = TanhModule(
...     in_keys=in_keys,
...     spec=spec,
... )
>>> data = TensorDict(
...     {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
>>> data = mod(data)
>>> data['a']
tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
        -2.5686, -2.8602])
>>> data['b']
tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
        0.5760])
forward(tensordict=None)[來源]

定義在每次呼叫時執行的計算。

應被所有子類別覆寫。

注意

雖然正向傳遞的配方需要在這個函式中定義,但應該在之後呼叫 Module 實例,而不是這個函式,因為前者會處理執行註冊的 hooks,而後者會靜默地忽略它們。

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源