快捷方式

模型並行

DistributedModelParallel 是使用 TorchRec 優化進行分散式訓練的主要 API。

class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)

模型並行的入口點。

參數:
  • module (nn.Module) – 要包裝的模組。

  • env (Optional[ShardingEnv]) – 具有 process group 的 sharding 環境。

  • device (Optional[torch.device]) – 計算設備,預設為 cpu。

  • plan (Optional[ShardingPlan]) – sharding 時要使用的 plan,預設為 EmbeddingShardingPlanner.collective_plan()

  • sharders (Optional[List[ModuleSharder[nn.Module]]]) – 可用於 sharding 的 ModuleSharders,預設為 EmbeddingBagCollectionSharder()

  • init_data_parallel (bool) – data-parallel 模組可以是 lazy 的,也就是說,它們會延遲參數初始化,直到第一次 forward pass。傳遞 True 以延遲 data parallel 模組的初始化。先執行第一次 forward pass,然後呼叫 DistributedModelParallel.init_data_parallel()。

  • init_parameters (bool) – 初始化仍然在 meta 設備上的模組的參數。

  • data_parallel_wrapper (Optional[DataParallelWrapper]) – data parallel 模組的自訂包裝器。

範例

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
copy(device: device) DistributedModelParallel

透過呼叫每個模組的自訂複製過程,以遞迴方式將子模組複製到新的設備,因為某些模組需要使用原始參考(例如用於推論的 ShardedModule)。

forward(*args, **kwargs) Any

定義每次呼叫時執行的計算。

應由所有子類別覆寫。

注意

雖然 forward pass 的配方需要在這個函數中定義,但應該在之後呼叫 Module 實例,而不是這個函數,因為前者會處理執行註冊的 hooks,而後者會靜默地忽略它們。

init_data_parallel() None

請參閱 init_data_parallel 建構子引數以了解用法。可以安全地多次呼叫此方法。

load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys

將參數和緩衝區從 state_dict 複製到此模組及其後代中。

如果 strictTrue,則 state_dict 的鍵必須與此模組的 state_dict() 函數傳回的鍵完全匹配。

警告

如果 assignTrue,除非 get_swap_module_params_on_conversion()True,否則優化器必須在呼叫 load_state_dict 之後建立。

參數:
  • state_dict (dict) – 包含參數和持久緩衝區的字典。

  • strict (bool, optional) – 是否嚴格要求 state_dict 中的鍵與此模組的 state_dict() 函數返回的鍵匹配。預設值: True

  • assign (bool, optional) – 設定為 False 時,會保留目前模組中 tensors 的屬性,而設定為 True 則會保留 state dict 中 tensors 的屬性。唯一的例外是 requires_grad 欄位。 Default: ``False`

傳回:

  • missing_keys 是一個字串列表,包含此模組預期的任何鍵

    但在提供的 state_dict 中遺失。

  • unexpected_keys 是一個字串列表,包含此模組未預期的鍵

    但在提供的 state_dict 中存在。

傳回類型:

具有 missing_keysunexpected_keys 欄位的 NamedTuple

注意

如果參數或緩衝區註冊為 None,且其對應的鍵存在於 state_dict 中,load_state_dict() 將會引發 RuntimeError

property module: Module

直接存取分片模組的屬性,該模組不會被包裝在 DDP、FSDP、DMP 或任何其他並行包裝器中。

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

傳回模組緩衝區上的迭代器,產生緩衝區的名稱以及緩衝區本身。

參數:
  • prefix (str) – 要前置於所有緩衝區名稱的前綴。

  • recurse (bool, optional) – 如果為 True,則產生此模組和所有子模組的緩衝區。 否則,僅產生作為此模組直接成員的緩衝區。 預設為 True。

  • remove_duplicate (bool, optional) – 是否移除結果中重複的緩衝區。 預設為 True。

產生:

(str, torch.Tensor) – 包含名稱和緩衝區的 Tuple

範例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

傳回模組參數上的迭代器,產生參數的名稱以及參數本身。

參數:
  • prefix (str) – 要前置於所有參數名稱的前綴。

  • recurse (bool) – 如果為 True,則產生此模組和所有子模組的參數。 否則,僅產生作為此模組直接成員的參數。

  • remove_duplicate (bool, optional) – 是否移除結果中重複的參數。 預設為 True。

產生:

(str, Parameter) – 包含名稱和參數的 Tuple

範例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

傳回一個字典,其中包含對模組整體狀態的參考。

包含參數和持久緩衝區 (例如,running averages)。 鍵是對應的參數和緩衝區名稱。 設置為 None 的參數和緩衝區不包含在內。

注意

返回的物件是淺拷貝。 它包含對模組的參數和緩衝區的參考。

警告

目前,state_dict() 也接受位置引數,依序對應 destinationprefixkeep_vars。 但是,這種做法已被棄用,並且在未來的版本中將強制使用關鍵字引數。

警告

請避免使用引數 destination,因為它不是為終端使用者設計的。

參數:
  • destination (dict, optional) – 如果提供,模組的狀態將更新到字典中,並傳回相同的物件。 否則,將建立並傳回一個 OrderedDict。 預設值:None

  • prefix (str, optional) – 添加到參數和緩衝區名稱的前綴,用於組成 state_dict 中的鍵。 預設值:''

  • keep_vars (bool, optional) – 預設情況下,state dict 中傳回的 Tensor 會與自動梯度分離。 如果將其設置為 True,則不會執行分離。 預設值:False

傳回:

一個包含模組完整狀態的字典

傳回類型:

dict

範例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources