捷徑

tensordict.nn.set_skip_existing

class tensordict.nn.set_skip_existing(mode: bool | None = True, in_key_attr='in_keys', out_key_attr='out_keys')

用於跳過 TensorDict 圖形中現有節點的上下文管理器。

當用作上下文管理器時,它會將 skip_existing() 值設定為指示的 mode,讓使用者能夠編寫檢查全域值並相應地執行程式碼的方法。

當用作方法裝飾器時,它會檢查 tensordict 輸入鍵,如果 skip_existing() 呼叫傳回 True,如果所有輸出鍵都已存在,它將跳過該方法。預期不會將其用作不遵守以下簽名的方法的裝飾器:def fun(self, tensordict, *args, **kwargs)

參數:
  • mode (bool, optional) – 如果 True,則表示圖形中現有的項目不會被覆寫,除非它們僅部分存在。skip_existing() 將傳回 True。如果 False,則不會執行任何檢查。如果 None,則不會變更 skip_existing() 的值。這旨在專門用於裝飾方法,並允許它們的行為在用作上下文管理器時取決於同一個類別(請參閱下面的範例)。預設值為 True

  • in_key_attr (str, optional) – 被裝飾的模組方法的輸入鍵清單屬性的名稱。預設值為 in_keys

  • out_key_attr (str, optional) – 被裝飾的模組方法的輸出鍵清單屬性的名稱。預設值為 out_keys

範例

>>> with set_skip_existing():
...     if skip_existing():
...         print("True")
...     else:
...         print("False")
...
True
>>> print("calling from outside:", skip_existing())
calling from outside: False

此類別也可以用作裝飾器

範例

>>> from tensordict import TensorDict
>>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
>>> class MyModule(TensorDictModuleBase):
...     in_keys = []
...     out_keys = ["out"]
...     @set_skip_existing()
...     def forward(self, tensordict):
...         print("hello")
...         tensordict.set("out", torch.zeros(()))
...         return tensordict
>>> module = MyModule()
>>> module(TensorDict({"out": torch.zeros(())}, []))  # does not print anything
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> module(TensorDict())  # prints hello
hello
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

每當想要讓上下文管理器從外部處理跳過事項時,使用設定為 None 的模式裝飾方法會很有用

範例

>>> from tensordict import TensorDict
>>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
>>> class MyModule(TensorDictModuleBase):
...     in_keys = []
...     out_keys = ["out"]
...     @set_skip_existing(None)
...     def forward(self, tensordict):
...         print("hello")
...         tensordict.set("out", torch.zeros(()))
...         return tensordict
>>> module = MyModule()
>>> _ = module(TensorDict({"out": torch.zeros(())}, []))  # prints "hello"
hello
>>> with set_skip_existing(True):
...     _ = module(TensorDict({"out": torch.zeros(())}, []))  # no print

注意

為了允許模組擁有相同的輸入和輸出鍵,且不會錯誤地忽略子圖,只要輸出鍵同時也是輸入鍵,@set_skip_existing(True) 將會被停用。

>>> class MyModule(TensorDictModuleBase):
...     in_keys = ["out"]
...     out_keys = ["out"]
...     @set_skip_existing()
...     def forward(self, tensordict):
...         print("calling the method!")
...         return tensordict
...
>>> module = MyModule()
>>> module(TensorDict({"out": torch.zeros(())}, []))  # does not print anything
calling the method!
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

文件

取得 PyTorch 的全面開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並取得問題解答

檢視資源