快捷方式

SelectKeys

class torchrl.trainers.SelectKeys(keys: Sequence[str])[原始碼]

在 TensorDict 批次中選擇鍵。

參數:

keys (字串的可迭代物件) – 要在 tensordict 中選擇的鍵。

範例

>>> trainer = make_trainer()
>>> key1 = "first key"
>>> key2 = "second key"
>>> td = TensorDict(
...     {
...         key1: torch.randn(3),
...         key2: torch.randn(3),
...     },
...     [],
... )
>>> trainer.register_op("batch_process", SelectKeys([key1]))
>>> td_out = trainer._process_batch_hook(td)
>>> assert key1 in td_out.keys()
>>> assert key2 not in td_out.keys()
register(trainer, name='select_keys') None[原始碼]

在預設位置向訓練器註冊 hook。

參數:
  • trainer (Trainer) – 必須註冊 hook 的訓練器。

  • name (str) – hook 的名稱。

注意

若要在預設位置以外的其他位置註冊 hook,請使用 register_op()

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得適合初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源