CommDebugMode
入門¶
建立時間:2024 年 8 月 19 日 | 最後更新:2024 年 10 月 08 日 | 最後驗證:2024 年 11 月 05 日
作者:Anshul Sinha
在本教學中,我們將探索如何將 CommDebugMode
與 PyTorch 的 DistributedTensor (DTensor) 一起使用,透過追蹤分散式訓練環境中的集體運算來進行除錯。
先決條件¶
Python 3.8 - 3.11
PyTorch 2.2 或更高版本
什麼是 CommDebugMode
以及它有什麼用¶
隨著模型尺寸持續增加,使用者正尋求利用各種平行策略的組合來擴展分散式訓練。然而,現有解決方案之間缺乏互通性構成了一項重大挑戰,主要是由於缺乏可以橋接這些不同平行策略的統一抽象。為了應對這個問題,PyTorch 提出了 DistributedTensor(DTensor),它抽象了分散式訓練中張量通訊的複雜性,提供無縫的使用者體驗。然而,當處理現有的平行解決方案以及使用像 DTensor 這樣的統一抽象開發平行解決方案時,缺乏關於底層發生了什麼以及何時發生集體通訊的透明度,可能會使進階使用者難以識別和解決問題。為了應對這個挑戰,CommDebugMode
(一個 Python 上下文管理器) 將作為 DTensor 的主要除錯工具之一,使使用者能夠在使用 DTensor 時查看何時以及為何發生集體運算,從而有效地解決這個問題。
使用 CommDebugMode
¶
以下是如何使用 CommDebugMode
# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
noise_level=1, file_name="transformer_operation_log.txt"
)
# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)
這是 noise level 為 0 時 MLPModule 的輸出外觀
Expected Output:
Global
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule.net1
MLPModule.relu
MLPModule.net2
FORWARD PASS
*c10d_functional.all_reduce: 1
要使用 CommDebugMode
,您必須將執行模型的程式碼包裝在 CommDebugMode
中,並呼叫您想要用來顯示資料的 API。您還可以使用 noise_level
參數來控制顯示資訊的詳細程度。以下是每個 noise level 顯示的內容
在上面的範例中,您可以看到集體運算 all_reduce 在 MLPModule
的前向傳遞中發生一次。此外,您可以使用 CommDebugMode
來精確指出 all-reduce 運算發生在 MLPModule
的第二個線性層中。
以下是您可以使用的互動式模組樹狀結構視覺化,可以用來上傳您自己的 JSON 轉儲
結論¶
在本食譜中,我們學習了如何使用 CommDebugMode
來除錯 Distributed Tensor 和使用 PyTorch 通訊集體的平行解決方案。您可以在嵌入式視覺瀏覽器中使用您自己的 JSON 輸出。
有關 CommDebugMode
的更詳細資訊,請參閱comm_mode_features_example.py