CUDAGraph¶
- class torch.cuda.CUDAGraph[原始碼][原始碼]¶
CUDA 圖形的封裝器。
警告
此 API 處於 beta 階段,未來版本可能會變更。
- capture_begin(pool=None, capture_error_mode='global')[原始碼][原始碼]¶
開始在目前串流上擷取 CUDA 工作。
通常,您不應該自己呼叫
capture_begin
。請使用graph
或make_graphed_callables()
,它們會在內部呼叫capture_begin
。- 參數
pool (optional) – 由
graph_pool_handle()
或other_Graph_instance.pool()
傳回的 Token,暗示此圖形可能會與指示的 pool 共用記憶體。請參閱 圖形記憶體管理。capture_error_mode (str, optional) – 指定圖形擷取串流的 cudaStreamCaptureMode。可以是 "global"、"thread_local" 或 "relaxed"。在 cuda 圖形擷取期間,某些動作(例如 cudaMalloc)可能不安全。"global" 將會對其他執行緒中的動作產生錯誤,"thread_local" 只會對目前執行緒中的動作產生錯誤,而 "relaxed" 不會對這些動作產生錯誤。除非您熟悉 cudaStreamCaptureMode,否則請勿變更此設定
- capture_end()[原始碼][原始碼]¶
結束目前串流上的 CUDA 圖形擷取。
在
capture_end
之後,可以在此實例上呼叫replay
。通常,您不應該自己呼叫
capture_end
。請使用graph
或make_graphed_callables()
,它們會在內部呼叫capture_end
。
- debug_dump(debug_path)[原始碼][原始碼]¶
- 參數
debug_path (required) – 要將圖形傾印到的路徑。
如果透過 CUDAGraph.enable_debug_mode() 啟用偵錯,則呼叫偵錯函式以傾印圖形