捷徑

CommDebugMode 入門

建立時間:2024 年 8 月 19 日 | 最後更新:2024 年 10 月 08 日 | 最後驗證:2024 年 11 月 05 日

作者Anshul Sinha

在本教學中,我們將探索如何將 CommDebugMode 與 PyTorch 的 DistributedTensor (DTensor) 一起使用,透過追蹤分散式訓練環境中的集體運算來進行除錯。

先決條件

  • Python 3.8 - 3.11

  • PyTorch 2.2 或更高版本

什麼是 CommDebugMode 以及它有什麼用

隨著模型尺寸持續增加,使用者正尋求利用各種平行策略的組合來擴展分散式訓練。然而,現有解決方案之間缺乏互通性構成了一項重大挑戰,主要是由於缺乏可以橋接這些不同平行策略的統一抽象。為了應對這個問題,PyTorch 提出了 DistributedTensor(DTensor),它抽象了分散式訓練中張量通訊的複雜性,提供無縫的使用者體驗。然而,當處理現有的平行解決方案以及使用像 DTensor 這樣的統一抽象開發平行解決方案時,缺乏關於底層發生了什麼以及何時發生集體通訊的透明度,可能會使進階使用者難以識別和解決問題。為了應對這個挑戰,CommDebugMode (一個 Python 上下文管理器) 將作為 DTensor 的主要除錯工具之一,使使用者能夠在使用 DTensor 時查看何時以及為何發生集體運算,從而有效地解決這個問題。

使用 CommDebugMode

以下是如何使用 CommDebugMode

# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
    noise_level=1, file_name="transformer_operation_log.txt"
)

# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)

這是 noise level 為 0 時 MLPModule 的輸出外觀

Expected Output:
    Global
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1

要使用 CommDebugMode,您必須將執行模型的程式碼包裝在 CommDebugMode 中,並呼叫您想要用來顯示資料的 API。您還可以使用 noise_level 參數來控制顯示資訊的詳細程度。以下是每個 noise level 顯示的內容

0. 打印模組層級的集體計數
1. 打印 DTensor 運算(不包括瑣碎運算)、模組分片資訊
2. 打印張量運算(不包括瑣碎運算)
3. 打印所有運算

在上面的範例中,您可以看到集體運算 all_reduce 在 MLPModule 的前向傳遞中發生一次。此外,您可以使用 CommDebugMode 來精確指出 all-reduce 運算發生在 MLPModule 的第二個線性層中。

以下是您可以使用的互動式模組樹狀結構視覺化,可以用來上傳您自己的 JSON 轉儲

CommDebugMode 模組樹
在此處拖曳檔案

結論

在本食譜中,我們學習了如何使用 CommDebugMode 來除錯 Distributed Tensor 和使用 PyTorch 通訊集體的平行解決方案。您可以在嵌入式視覺瀏覽器中使用您自己的 JSON 輸出。

有關 CommDebugMode 的更詳細資訊,請參閱comm_mode_features_example.py

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學課程

取得初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並取得您問題的解答

查看資源