tensordict.nn.dispatch¶
- class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)¶
允許使用 kwargs 呼叫期望 TensorDict 的函式。
dispatch()
必須在具有in_keys
(或由source
關鍵字引數指示的另一個鍵的來源)和out_keys
(或另一個dest
鍵列表)屬性的模組中使用,這些屬性指示要從 tensordict 讀取和寫入哪些鍵。 包裝的函式也應該有一個tensordict
前導引數。產生的函式將返回一個單個張量(如果 out_keys 中只有一個元素),否則它將返回一個元組,該元組的排序方式與模組的
out_keys
相同。dispatch()
可以作為方法或類別使用,當需要傳遞額外的引數時。- 參數:
separator (str, optional) – 將子鍵組合在一起的分隔符,用於作為字串元組的
in_keys
。 預設值為"_"
。source (str 或 鍵的列表, optional) – 如果提供字串,則它指向包含要使用的輸入鍵列表的模組屬性。 如果提供列表,它將包含用作模組輸入的鍵。 預設值為
"in_keys"
,它是TensorDictModule
輸入鍵列表的屬性名稱。dest (str 或 鍵的列表, optional) – 如果提供字串,則它指向包含要使用的輸出鍵列表的模組屬性。 如果提供列表,它將包含用作模組輸出的鍵。 預設值為
"out_keys"
,它是TensorDictModule
輸出鍵列表的屬性名稱。auto_batch_size (bool, optional) – 若為
True
,則輸入 tensordict 的 batch size 將自動判定為所有輸入張量中最多的共同維度數量。預設值為True
。
範例
>>> class MyModule(nn.Module): ... in_keys = ["a"] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all() >>> # equivalently >>> class MyModule(nn.Module): ... keys_in = ["a"] ... keys_out = ["b"] ... ... @dispatch(source="keys_in", dest="keys_out") ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all() >>> # or this >>> class MyModule(nn.Module): ... @dispatch(source=["a"], dest=["b"]) ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all()
dispatch_kwargs()
也能以預設的"_"
分隔符號搭配巢狀鍵使用。範例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c")] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + 1 ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(a_c=torch.zeros(1, 2)) >>> assert (b == 1).all()
若要使用其他分隔符號,可以在建構子中使用
separator
參數指定。範例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c")] ... out_keys = ["b"] ... ... @dispatch(separator="sep") ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + 1 ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(asepc=torch.zeros(1, 2)) >>> assert (b == 1).all()
由於輸入鍵是經過排序的字串序列,因此
dispatch()
也能搭配未命名的引數使用,但順序必須與輸入鍵的順序相符。注意
如果第一個引數是
TensorDictBase
實例,則會假定 __不__ 使用 dispatch,且此 tensordict 包含所有執行模組所需的資訊。換句話說,無法將模組輸入的第一個鍵指向 tensordict 實例的 tensordict 分解。一般而言,最好只搭配 tensordict 葉節點使用dispatch()
。範例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c"), "d"] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + tensordict["d"] ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2)) # works >>> assert (b == 1).all() >>> b, = module(torch.zeros(1, 2), torch.ones(1, 2)) # works >>> assert (b == 1).all() >>> try: ... b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2)) # fails ... except: ... print("oopsy!") ...