tensordict.nn.set_skip_existing¶
- class tensordict.nn.set_skip_existing(mode: bool | None = True, in_key_attr='in_keys', out_key_attr='out_keys')¶
用於跳過 TensorDict 圖形中現有節點的上下文管理器。
當用作上下文管理器時,它會將 skip_existing() 值設定為指示的
mode
,讓使用者能夠編寫檢查全域值並相應地執行程式碼的方法。當用作方法裝飾器時,它會檢查 tensordict 輸入鍵,如果
skip_existing()
呼叫傳回True
,如果所有輸出鍵都已存在,它將跳過該方法。預期不會將其用作不遵守以下簽名的方法的裝飾器:def fun(self, tensordict, *args, **kwargs)
。- 參數:
mode (bool, optional) – 如果
True
,則表示圖形中現有的項目不會被覆寫,除非它們僅部分存在。skip_existing()
將傳回True
。如果False
,則不會執行任何檢查。如果None
,則不會變更skip_existing()
的值。這旨在專門用於裝飾方法,並允許它們的行為在用作上下文管理器時取決於同一個類別(請參閱下面的範例)。預設值為True
。in_key_attr (str, optional) – 被裝飾的模組方法的輸入鍵清單屬性的名稱。預設值為
in_keys
。out_key_attr (str, optional) – 被裝飾的模組方法的輸出鍵清單屬性的名稱。預設值為
out_keys
。
範例
>>> with set_skip_existing(): ... if skip_existing(): ... print("True") ... else: ... print("False") ... True >>> print("calling from outside:", skip_existing()) calling from outside: False
此類別也可以用作裝飾器
範例
>>> from tensordict import TensorDict >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase >>> class MyModule(TensorDictModuleBase): ... in_keys = [] ... out_keys = ["out"] ... @set_skip_existing() ... def forward(self, tensordict): ... print("hello") ... tensordict.set("out", torch.zeros(())) ... return tensordict >>> module = MyModule() >>> module(TensorDict({"out": torch.zeros(())}, [])) # does not print anything TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> module(TensorDict()) # prints hello hello TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
每當想要讓上下文管理器從外部處理跳過事項時,使用設定為
None
的模式裝飾方法會很有用範例
>>> from tensordict import TensorDict >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase >>> class MyModule(TensorDictModuleBase): ... in_keys = [] ... out_keys = ["out"] ... @set_skip_existing(None) ... def forward(self, tensordict): ... print("hello") ... tensordict.set("out", torch.zeros(())) ... return tensordict >>> module = MyModule() >>> _ = module(TensorDict({"out": torch.zeros(())}, [])) # prints "hello" hello >>> with set_skip_existing(True): ... _ = module(TensorDict({"out": torch.zeros(())}, [])) # no print
注意
為了允許模組擁有相同的輸入和輸出鍵,且不會錯誤地忽略子圖,只要輸出鍵同時也是輸入鍵,
@set_skip_existing(True)
將會被停用。>>> class MyModule(TensorDictModuleBase): ... in_keys = ["out"] ... out_keys = ["out"] ... @set_skip_existing() ... def forward(self, tensordict): ... print("calling the method!") ... return tensordict ... >>> module = MyModule() >>> module(TensorDict({"out": torch.zeros(())}, [])) # does not print anything calling the method! TensorDict( fields={ out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)