捷徑

CudaGraphModule

class tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None)

PyTorch 可呼叫物件的 cudagraph 包裝器。

CudaGraphModule 是一個包裝類別,為 PyTorch 可呼叫物件提供 CUDA 圖形的友善介面。

警告

CudaGraphModule 是一個原型功能,其 API 限制未來可能會變更。

此類別為 cudagraphs 提供了一個使用者友善的介面,允許在 GPU 上快速執行操作,且沒有 CPU 開銷。 它會對函式的輸入執行必要的檢查,並提供類似 nn.Module 的 API 來執行

警告

此模組要求包裝的函式符合一些要求。 使用者有責任確保滿足所有這些要求。

  • 函式不能有動態控制流程。 例如,以下程式碼片段將無法包裝在 CudaGraphModule

    >>> def func(x):
    ...     if x.norm() > 1:
    ...         return x + 1
    ...     else:
    ...         return x - 1
    

    幸運的是,PyTorch 在大多數情況下都提供了解決方案

    >>> def func(x):
    ...     return torch.where(x.norm() > 1, x + 1, x - 1)
    
  • 函式必須執行可以使用相同緩衝區精確重新執行的程式碼。 這表示不支援動態形狀 (輸入或程式碼執行期間的形狀變更)。 換句話說,輸入必須具有恆定的形狀。

  • 函式的輸出必須是分離的 (detached)。如果需要呼叫優化器,請將其放入輸入函式中。例如,以下函式是一個有效的運算子

    >>> def func(x, y):
    ...     optim.zero_grad()
    ...     loss_val = loss_fn(x, y)
    ...     loss_val.backward()
    ...     optim.step()
    ...     return loss_val.detach()
    
  • 輸入不應該是可微分的。如果需要使用 nn.Parameters(或一般的可微分張量),只需編寫一個將它們用作全域值的函式,而不是將它們作為輸入傳遞

    >>> x = nn.Parameter(torch.randn(()))
    >>> optim = Adam([x], lr=1)
    >>> def func(): # right
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    >>> def func(x): # wrong
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    
  • 是張量或 tensordict 的 Args 和 kwargs 可能會更改(前提是裝置和形狀匹配),但非張量的 args 和 kwargs 不應更改。例如,如果函式接收字串輸入,並且在任何時候輸入發生更改,則模組將會靜默執行程式碼,使用的字串是在 cudagraph 捕獲期間使用的字串。唯一支援的關鍵字引數是 tensordict_out,以防輸入是 tensordict。

  • 如果模組是一個 TensorDictModuleBase 實例,並且輸出 id 與輸入 id 匹配,那麼在呼叫 CudaGraphModule 期間,此恆等式將會被保留。在所有其他情況下,輸出將會被複製 (cloned),無論其元素是否與其中一個輸入匹配。

警告

CudaGraphModule 刻意設計成不是一個 Module,以避免收集輸入模組的參數並將其傳遞給優化器。

參數:
  • module (Callable) – 一個接收張量 (或 tensordict) 作為輸入並輸出 PyTreeable 張量集合的函式。如果提供了 tensordict,該模組也可以使用關鍵字引數執行(參見下面的範例)。

  • warmup (int, optional) – 在模組被編譯的情況下,預熱步驟的數量(編譯後的模組在被 cudagraph 捕獲之前,應該執行幾次)。對於所有模組,預設值為 2

  • in_keys (list of NestedKeys) –

    輸入鍵 (input keys),如果該模組將 TensorDict 作為輸入。如果存在此值,則預設為 module.in_keys,否則為 None

    注意

    如果提供了 in_keys 但為空,則假定該模組接收 tensordict 作為輸入。這足以讓 CudaGraphModule 知道該函式應被視為 TensorDictModule,但關鍵字引數將不會被分派。有關一些範例,請參見下文。

  • out_keys (list of NestedKeys) – 輸出鍵 (output keys),如果該模組接收和輸出 TensorDict 作為輸出。如果存在此值,則預設為 module.out_keys,否則為 None

範例

>>> # Wrap a simple function
>>> def func(x):
...     return x + 1
>>> func = CudaGraphModule(func)
>>> x = torch.rand((), device='cuda')
>>> out = func(x)
>>> assert isinstance(out, torch.Tensor)
>>> assert out == x+1
>>> # Wrap a tensordict module
>>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"])
>>> func = CudaGraphModule(func)
>>> # This can be called either with a TensorDict or regular keyword arguments alike
>>> y = func(x=x)
>>> td = TensorDict(x=x)
>>> td = func(td)

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和高級開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源