快捷方式

tensordict.nn.dispatch

class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)

允許使用 kwargs 呼叫期望 TensorDict 的函式。

dispatch() 必須在具有 in_keys(或由 source 關鍵字引數指示的另一個鍵的來源)和 out_keys(或另一個 dest 鍵列表)屬性的模組中使用,這些屬性指示要從 tensordict 讀取和寫入哪些鍵。 包裝的函式也應該有一個 tensordict 前導引數。

產生的函式將返回一個單個張量(如果 out_keys 中只有一個元素),否則它將返回一個元組,該元組的排序方式與模組的 out_keys 相同。

dispatch() 可以作為方法或類別使用,當需要傳遞額外的引數時。

參數:
  • separator (str, optional) – 將子鍵組合在一起的分隔符,用於作為字串元組的 in_keys。 預設值為 "_"

  • source (str鍵的列表, optional) – 如果提供字串,則它指向包含要使用的輸入鍵列表的模組屬性。 如果提供列表,它將包含用作模組輸入的鍵。 預設值為 "in_keys",它是 TensorDictModule 輸入鍵列表的屬性名稱。

  • dest (str鍵的列表, optional) – 如果提供字串,則它指向包含要使用的輸出鍵列表的模組屬性。 如果提供列表,它將包含用作模組輸出的鍵。 預設值為 "out_keys",它是 TensorDictModule 輸出鍵列表的屬性名稱。

  • auto_batch_size (bool, optional) – 若為 True,則輸入 tensordict 的 batch size 將自動判定為所有輸入張量中最多的共同維度數量。預設值為 True

範例

>>> class MyModule(nn.Module):
...     in_keys = ["a"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # equivalently
>>> class MyModule(nn.Module):
...     keys_in = ["a"]
...     keys_out = ["b"]
...
...     @dispatch(source="keys_in", dest="keys_out")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # or this
>>> class MyModule(nn.Module):
...     @dispatch(source=["a"], dest=["b"])
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()

dispatch_kwargs() 也能以預設的 "_" 分隔符號搭配巢狀鍵使用。

範例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(a_c=torch.zeros(1, 2))
>>> assert (b == 1).all()

若要使用其他分隔符號,可以在建構子中使用 separator 參數指定。

範例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch(separator="sep")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(asepc=torch.zeros(1, 2))
>>> assert (b == 1).all()

由於輸入鍵是經過排序的字串序列,因此 dispatch() 也能搭配未命名的引數使用,但順序必須與輸入鍵的順序相符。

注意

如果第一個引數是 TensorDictBase 實例,則會假定 __不__ 使用 dispatch,且此 tensordict 包含所有執行模組所需的資訊。換句話說,無法將模組輸入的第一個鍵指向 tensordict 實例的 tensordict 分解。一般而言,最好只搭配 tensordict 葉節點使用 dispatch()

範例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c"), "d"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + tensordict["d"]
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> b, = module(torch.zeros(1, 2), torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> try:
...     b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2))  # fails
... except:
...     print("oopsy!")
...

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您的問題解答

檢視資源