graph¶
- class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[原始碼][原始碼]¶
Context-manager,將 CUDA 工作擷取到
torch.cuda.CUDAGraph
物件中,以便稍後重播。請參閱 CUDA 圖形,以取得一般介紹、詳細使用方式和限制。
- 參數
cuda_graph (torch.cuda.CUDAGraph) – 用於捕獲的 Graph 物件。
pool (optional) – 不透明的 Token(由呼叫
graph_pool_handle()
或other_Graph_instance.pool()
傳回),暗示此 graph 的捕獲可能會共享來自指定 pool 的記憶體。請參閱 Graph 記憶體管理。stream (torch.cuda.Stream, optional) – 如果提供,將在上下文中設定為目前的 stream。如果未提供,
graph
會將其自身的內部 side stream 設定為上下文中目前的 stream。capture_error_mode (str, optional) – 指定 graph 捕獲 stream 的 cudaStreamCaptureMode。可以是 "global"、"thread_local" 或 "relaxed"。在 cuda graph 捕獲期間,某些操作(例如 cudaMalloc)可能不安全。"global" 將會在其他執行緒中的操作上出錯,"thread_local" 將僅對目前執行緒中的操作出錯,而 "relaxed" 將不會對操作出錯。除非您熟悉 cudaStreamCaptureMode,否則請勿更改此設定
注意
為了有效的記憶體共享,如果您傳遞先前捕獲所使用的
pool
,並且先前捕獲使用了明確的stream
參數,則您應該將相同的stream
參數傳遞給此捕獲。警告
此 API 處於 beta 階段,未來版本可能會更改。