• 文件 >
  • 張量並行處理 - torch.distributed.tensor.parallel
捷徑

張量並行處理 - torch.distributed.tensor.parallel

張量並行處理 (TP) 建構於 PyTorch DistributedTensor 之上 ( DTensor),並提供不同的並行處理樣式:Colwise、Rowwise 和序列並行處理。

警告

張量並行處理 API 為實驗性質,可能會變更。

使用張量並行處理來平行化 nn.Module 的入口點是

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None)[source][source]

透過根據使用者指定的計畫平行化模組或子模組,在 PyTorch 中應用張量並行處理。

我們根據 parallelize_plan 平行化模組或子模組。 parallelize_plan 包含 ParallelStyle,其指示使用者希望如何平行化模組或子模組。

使用者也可以針對每個模組的完整限定名稱 (FQN) 指定不同的平行化樣式。

請注意,parallelize_module 僅接受一維的 DeviceMesh。如果您有二維或 N 維的 DeviceMesh,請先將 DeviceMesh 切片成一維的子 DeviceMesh,然後再傳遞給此 API (例如 device_mesh["tp"])。

參數
  • module (nn.Module) – 要平行化的模組。

  • device_mesh (DeviceMesh, optional) – 描述 DTensor 的裝置網格拓撲結構的物件。如果未指定,則呼叫必須在 DeviceMesh 上下文中進行。

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]], optional) – 用於平行化模組的計畫。它可以是一個包含我們如何準備 Tensor Parallelism 的輸入/輸出的 ParallelStyle 物件,也可以是一個模組 FQN 及其對應的 ParallelStyle 物件的字典。如果未指定,則目前呼叫不會執行任何動作。

回傳

已平行化的 nn.Module 物件。

回傳類型

Module

範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>

注意

對於像 Attention、MLP 層等複雜的模組架構,我們建議將不同的 ParallelStyles 組合在一起 (例如 ColwiseParallelRowwiseParallel) 並作為 parallelize_plan 傳遞,以實現所需的 shards 計算。

Tensor Parallelism 支援以下平行化樣式

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[原始碼][原始碼]

以列方向 (column-wise) 分割相容的 nn.Module。目前支援 nn.Linear 和 nn.Embedding。使用者可以將其與 RowwiseParallel 組合在一起,以實現更複雜模組的 shards (例如 MLP、Attention)。

關鍵字引數
  • input_layouts (Placement, optional) – nn.Module 的輸入張量的 DTensor 佈局,這用於註釋輸入張量以成為 DTensor。如果未指定,我們假設輸入張量已被複製。

  • output_layouts (Placement, optional) – nn.Module 輸出的 DTensor 佈局,這用於確保 nn.Module 的輸出具有使用者所需的佈局。如果未指定,則輸出張量會在最後一個維度上進行 shards。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作為模組輸出,預設值:True。

回傳

一個 ParallelStyle 物件,表示 nn.Module 的 Colwise shards。

範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...

注意

預設情況下,如果未指定 output_layouts,則 ColwiseParallel 輸出會在最後一個維度上進行 shards。如果存在需要特定張量形狀的運算符 (例如,在配對的 RowwiseParallel 之前),請記住,如果輸出已 shards,則可能需要將運算符調整為 shards 的大小。

class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[原始碼][原始碼]

以行方向 (row-wise) 分割相容的 nn.Module。目前支援 nn.Linear 和 nn.Embedding。使用者可以將其與 ColwiseParallel 組合在一起,以實現更複雜模組的 shards (例如 MLP、Attention)。

關鍵字引數
  • input_layouts (Placement, optional) – nn.Module 的輸入張量的 DTensor 佈局,這用於註釋輸入張量以成為 DTensor。如果未指定,我們假設輸入張量在最後一個維度上進行 shards。

  • output_layouts (Placement, optional) – nn.Module 輸出的 DTensor 佈局,這用於確保 nn.Module 的輸出具有使用者所需的佈局。如果未指定,則輸出張量會被複製。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作為模組輸出,預設值:True。

回傳

一個 ParallelStyle 物件,表示 nn.Module 的 Rowwise shards。

範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}),
>>> ...
class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[原始碼][原始碼]

SequenceParallel 複製相容的 nn.Module 參數,並執行序列維度上分片的輸入計算。目前支援 nn.LayerNormnn.DropoutRMSNorm python 實作

此樣式實作論文 Reducing Activation Recomputation in Large Transformer Models 中描述的操作。

如果傳遞到此 nn.Module 的輸入是一個 torch.Tensor,則假定輸入已在序列維度上分片,並將輸入轉換為在序列維度上分片的 DTensor。如果傳遞到此 nn.Module 的輸入已經是一個 DTensor,但未在序列維度上分片,則它會重新分配輸入,使其在序列維度上分片。

nn.Module 的輸出將在序列維度上分片。

關鍵字引數
  • sequence_dim (int, optional) – nn.Module 的輸入張量的序列維度,用於註解輸入張量以成為在序列維度上分片的 DTensor,預設值:1。

  • use_local_output (bool, optional) – 是否對模組輸出使用本機 torch.Tensor 而不是 DTensor,預設值:False。

回傳

一個 ParallelStyle 物件,表示 nn.Module 的序列平行。

範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
>>> ...

注意

SequenceParallel 樣式假定如果 nn.Module 中存在權重(即 nn.LayerNormRMSNorm,並且它們預設具有 1 的初始化),則進行 1 的初始化。如果您對這些模組上的權重進行了自訂初始化,則需要在平行化之前/之後廣播權重,以確保它們被複製。

為了簡單地使用 DTensor 佈局配置 nn.Module 的輸入和輸出,並執行必要的佈局重新分配,而無需將模組參數分發到 DTensor,可以在呼叫 parallelize_module 時,在 parallelize_plan 中使用以下 ParallelStyle

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[原始碼][原始碼]

配置 nn.Module 的輸入,以便在運行時根據 input_layouts 將 nn.Module 的輸入張量轉換為 DTensor,並根據 desired_input_layouts 執行佈局重新分配。

關鍵字引數
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 的輸入張量的 DTensor 佈局,用於將輸入張量轉換為 DTensor。如果某些輸入不是 torch.Tensor 或不需要轉換為 DTensor,則需要將 None 指定為佔位符。預設值:None。

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 的輸入張量的所需 DTensor 佈局,用於確保 nn.Module 的輸入具有所需的 DTensor 佈局。此引數需要與 input_layouts 具有相同的長度。預設值:None。

  • input_kwarg_layouts (Dict[str, Placement]) – nn.Module 的輸入 kwargs 的 DTensor 佈局,用於將輸入 kwarg 張量轉換為 DTensor。預設值:None

  • desired_input_kwarg_layouts – (Dict[str, Placement]): nn.Module 的輸入 kwargs 的所需 DTensor 佈局,用於確保 nn.Module 的輸入具有所需的 DTensor 佈局。預設值:None。

  • use_local_output (bool,選用) – 是否對模組輸入使用本地端的 torch.Tensor 而非 DTensor,預設值為 False。

回傳

一個 ParallelStyle 物件,用於準備 nn.Module 輸入的分片佈局。

範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...)
>>>         ),
>>>     }
>>> )
class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[原始碼][原始碼]

設定 nn.Module 的輸出,以便在運行時根據 output_layouts 將 nn.Module 的輸出張量轉換為 DTensors,並根據 desired_output_layouts 執行佈局重新分配。

關鍵字引數
  • output_layouts (Union[Placement, Tuple[Placement]]) – nn.Module 輸出張量的 DTensor 佈局,如果輸出張量是 torch.Tensor,則使用此參數將輸出張量轉換為 DTensors。 如果某些輸出不是 torch.Tensor 或不需要轉換為 DTensors,則需要將 None 指定為佔位符。

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – nn.Module 輸出張量的期望 DTensor 佈局,用於確保 nn.Module 的輸出具有期望的 DTensor 佈局。

  • use_local_output (bool,選用) – 是否對模組輸出使用本地端的 torch.Tensor 而非 DTensor,預設值為 True。

回傳

一個 ParallelStyle 物件,用於準備 nn.Module 輸出的分片佈局。

範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan = PrepareModuleOutput(
>>>         output_layouts=Replicate(),
>>>         desired_output_layouts=Shard(0)
>>>     )
>>> )

注意

當使用 Shard(dim) 作為上述 ParallelStyle 的輸入/輸出佈局時,我們假設輸入/輸出激活張量在 TP 操作的 DeviceMesh 上的張量維度 dim 上均勻分片。 例如,由於 RowwiseParallel 接受在最後一個維度上分片的輸入,因此它假設輸入張量已經在最後一個維度上均勻分片。 對於不均勻分片激活張量的情況,可以直接將 DTensor 傳遞到分區模組,並使用 use_local_output=False 在每個 ParallelStyle 之後返回 DTensor,其中 DTensor 可以追蹤不均勻的分片資訊。

對於像 Transformer 這樣的模型,我們建議使用者在 parallelize_plan 中一起使用 ColwiseParallelRowwiseParallel,以實現整個模型 (即 Attention 和 MLP) 的期望分片。

透過以下上下文管理器支援並行化交叉熵損失計算(損失並行性)

torch.distributed.tensor.parallel.loss_parallel()[原始碼][原始碼]

一個啟用損失並行性的上下文管理器,當輸入在類別維度上分片時,可以執行有效的並行化損失計算。 目前僅支援交叉熵損失。

在此上下文管理器中,可以像往常一樣使用 cross_entropy()CrossEntropyLoss,但對輸入參數有以下假設。 如果有的話,相應的 backward() 呼叫也需要在這個上下文管理器下進行。

參數
  • input (DTensor) – 輸入 logits。 假設在類別維度上分片。

  • target (Union[torch.Tensor, DTensor]) – 必須是真實類別索引 (目前不支援類別機率)。 假設在 DeviceMesh 上複製。

  • weight (Union[torch.Tensor, DTensor], 選用) – 如果給定,假設在 DeviceMesh 上複製。

  • label_smoothing – 目前不支援。

回傳

一個複製的 DTensor

範例

此處手動建立一個分片的 DTensor 來展示其用法。 在實務中,它通常是 TP 模組的輸出。

>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
>>>     loss.backward()
>>> ...

警告

loss_parallel API 是實驗性的,可能會有所變動。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源