CudaGraphModule¶
- class tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None)¶
PyTorch 可呼叫物件的 cudagraph 包裝器。
CudaGraphModule
是一個包裝類別,為 PyTorch 可呼叫物件提供 CUDA 圖形的友善介面。警告
CudaGraphModule
是一個原型功能,其 API 限制未來可能會變更。此類別為 cudagraphs 提供了一個使用者友善的介面,允許在 GPU 上快速執行操作,且沒有 CPU 開銷。 它會對函式的輸入執行必要的檢查,並提供類似 nn.Module 的 API 來執行
警告
此模組要求包裝的函式符合一些要求。 使用者有責任確保滿足所有這些要求。
函式不能有動態控制流程。 例如,以下程式碼片段將無法包裝在 CudaGraphModule 中
>>> def func(x): ... if x.norm() > 1: ... return x + 1 ... else: ... return x - 1
幸運的是,PyTorch 在大多數情況下都提供了解決方案
>>> def func(x): ... return torch.where(x.norm() > 1, x + 1, x - 1)
函式必須執行可以使用相同緩衝區精確重新執行的程式碼。 這表示不支援動態形狀 (輸入或程式碼執行期間的形狀變更)。 換句話說,輸入必須具有恆定的形狀。
函式的輸出必須是分離的 (detached)。如果需要呼叫優化器,請將其放入輸入函式中。例如,以下函式是一個有效的運算子
>>> def func(x, y): ... optim.zero_grad() ... loss_val = loss_fn(x, y) ... loss_val.backward() ... optim.step() ... return loss_val.detach()
輸入不應該是可微分的。如果需要使用 nn.Parameters(或一般的可微分張量),只需編寫一個將它們用作全域值的函式,而不是將它們作為輸入傳遞
>>> x = nn.Parameter(torch.randn(())) >>> optim = Adam([x], lr=1) >>> def func(): # right ... optim.zero_grad() ... (x+1).backward() ... optim.step() >>> def func(x): # wrong ... optim.zero_grad() ... (x+1).backward() ... optim.step()
是張量或 tensordict 的 Args 和 kwargs 可能會更改(前提是裝置和形狀匹配),但非張量的 args 和 kwargs 不應更改。例如,如果函式接收字串輸入,並且在任何時候輸入發生更改,則模組將會靜默執行程式碼,使用的字串是在 cudagraph 捕獲期間使用的字串。唯一支援的關鍵字引數是 tensordict_out,以防輸入是 tensordict。
如果模組是一個
TensorDictModuleBase
實例,並且輸出 id 與輸入 id 匹配,那麼在呼叫CudaGraphModule
期間,此恆等式將會被保留。在所有其他情況下,輸出將會被複製 (cloned),無論其元素是否與其中一個輸入匹配。
警告
CudaGraphModule
刻意設計成不是一個Module
,以避免收集輸入模組的參數並將其傳遞給優化器。- 參數:
module (Callable) – 一個接收張量 (或 tensordict) 作為輸入並輸出 PyTreeable 張量集合的函式。如果提供了 tensordict,該模組也可以使用關鍵字引數執行(參見下面的範例)。
warmup (int, optional) – 在模組被編譯的情況下,預熱步驟的數量(編譯後的模組在被 cudagraph 捕獲之前,應該執行幾次)。對於所有模組,預設值為
2
。in_keys (list of NestedKeys) –
輸入鍵 (input keys),如果該模組將 TensorDict 作為輸入。如果存在此值,則預設為
module.in_keys
,否則為None
。注意
如果提供了
in_keys
但為空,則假定該模組接收 tensordict 作為輸入。這足以讓CudaGraphModule
知道該函式應被視為 TensorDictModule,但關鍵字引數將不會被分派。有關一些範例,請參見下文。out_keys (list of NestedKeys) – 輸出鍵 (output keys),如果該模組接收和輸出 TensorDict 作為輸出。如果存在此值,則預設為
module.out_keys
,否則為None
。
範例
>>> # Wrap a simple function >>> def func(x): ... return x + 1 >>> func = CudaGraphModule(func) >>> x = torch.rand((), device='cuda') >>> out = func(x) >>> assert isinstance(out, torch.Tensor) >>> assert out == x+1 >>> # Wrap a tensordict module >>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]) >>> func = CudaGraphModule(func) >>> # This can be called either with a TensorDict or regular keyword arguments alike >>> y = func(x=x) >>> td = TensorDict(x=x) >>> td = func(td)