快速鍵

DecisionTransformerInferenceWrapper

class torchrl.modules.tensordict_module.DecisionTransformerInferenceWrapper(*args, **kwargs)[原始碼]

Decision Transformer 的推論動作封裝器。

專為 Decision Transformer 設計的封裝器,它會將輸入 tensordict 序列遮罩到推論環境中。 輸出將是一個 TensorDict,其鍵與輸入相同,但只有預測動作序列的最後一個動作和最後一個返回值。

此模組會建立 tensordict 的修改副本,即它不會就地修改 tensordict。

注意

如果動作、觀察或獎勵值的鍵不是標準的,則應使用方法 set_tensor_keys(),例如:

>>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz")

in_keys 是觀察、動作和返回值的鍵。 out_keys 與 in_keys 匹配,並新增策略中的任何其他 out_key (例如,分佈的參數或隱藏值)。

參數:

policy (TensorDictModule) – 輸入觀察值並產生動作值的策略模組

關鍵字引數:
  • inference_context (int) – 環境中不會被遮罩的先前動作數。 例如,對於形狀為 [batch_size, context, obs_dim] 的觀察輸入,其中 context=20 且 inference_context=5,則環境中的前 15 個條目將被遮罩。 預設值為 5。

  • spec (Optional[TensorSpec]) – 輸入 TensorDict 的規格。 如果為 None,則將從策略模組推斷。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import (
...      ProbabilisticActor,
...      TanhDelta,
...      DTActor,
...      DecisionTransformerInferenceWrapper,
...  )
>>> dtactor = DTActor(state_dim=4, action_dim=2,
...             transformer_config=DTActor.default_config()
... )
>>> actor_module = TensorDictModule(
...         dtactor,
...         in_keys=["observation", "action", "return_to_go"],
...         out_keys=["param"])
>>> dist_class = TanhDelta
>>> dist_kwargs = {
...     "low": -1.0,
...     "high": 1.0,
... }
>>> actor = ProbabilisticActor(
...     in_keys=["param"],
...     out_keys=["action"],
...     module=actor_module,
...     distribution_class=dist_class,
...     distribution_kwargs=dist_kwargs)
>>> inference_actor = DecisionTransformerInferenceWrapper(actor)
>>> sequence_length = 20
>>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4),
...                 "action": torch.randn(1, sequence_length, 2),
...                 "return_to_go": torch.randn(1, sequence_length, 1)}, [1,])
>>> result = inference_actor(td)
>>> print(result)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        param: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase = None) TensorDictBase[source]

定義每次呼叫時執行的計算。

應該被所有子類別覆寫。

注意

雖然正向傳遞 (forward pass) 的步驟需要在這個函式中定義,但應該呼叫 Module 實例,而不是直接呼叫這個函式,因為前者會處理已註冊的 hooks,而後者會直接忽略它們。

mask_context(tensordict: TensorDictBase) TensorDictBase[source]

遮罩輸入序列的上下文 (context)。

set_tensor_keys(**kwargs)[source]

設定模組的輸入鍵 (keys)。

關鍵字引數:
  • observation (NestedKey, optional) – 觀測 (observation) 鍵。

  • action (NestedKey, optional) – 動作 (action) 鍵 (網路的輸入)。

  • return_to_go (NestedKey, optional) – return_to_go 鍵。

  • out_action (NestedKey, optional) – 動作 (action) 鍵 (網路的輸出)。

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources