模型並行¶
DistributedModelParallel
是使用 TorchRec 優化進行分散式訓練的主要 API。
- class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)¶
模型並行的入口點。
- 參數:
module (nn.Module) – 要包裝的模組。
env (Optional[ShardingEnv]) – 具有 process group 的 sharding 環境。
device (Optional[torch.device]) – 計算設備,預設為 cpu。
plan (Optional[ShardingPlan]) – sharding 時要使用的 plan,預設為 EmbeddingShardingPlanner.collective_plan()。
sharders (Optional[List[ModuleSharder[nn.Module]]]) – 可用於 sharding 的 ModuleSharders,預設為 EmbeddingBagCollectionSharder()。
init_data_parallel (bool) – data-parallel 模組可以是 lazy 的,也就是說,它們會延遲參數初始化,直到第一次 forward pass。傳遞 True 以延遲 data parallel 模組的初始化。先執行第一次 forward pass,然後呼叫 DistributedModelParallel.init_data_parallel()。
init_parameters (bool) – 初始化仍然在 meta 設備上的模組的參數。
data_parallel_wrapper (Optional[DataParallelWrapper]) – data parallel 模組的自訂包裝器。
範例
@torch.no_grad() def init_weights(m): if isinstance(m, nn.Linear): m.weight.fill_(1.0) elif isinstance(m, EmbeddingBagCollection): for param in m.parameters(): init.kaiming_normal_(param) m = MyModel(device='meta') m = DistributedModelParallel(m) m.apply(init_weights)
- copy(device: device) DistributedModelParallel ¶
透過呼叫每個模組的自訂複製過程,以遞迴方式將子模組複製到新的設備,因為某些模組需要使用原始參考(例如用於推論的 ShardedModule)。
- forward(*args, **kwargs) Any ¶
定義每次呼叫時執行的計算。
應由所有子類別覆寫。
注意
雖然 forward pass 的配方需要在這個函數中定義,但應該在之後呼叫
Module
實例,而不是這個函數,因為前者會處理執行註冊的 hooks,而後者會靜默地忽略它們。
- init_data_parallel() None ¶
請參閱 init_data_parallel 建構子引數以了解用法。可以安全地多次呼叫此方法。
- load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys ¶
將參數和緩衝區從
state_dict
複製到此模組及其後代中。如果
strict
為True
,則state_dict
的鍵必須與此模組的state_dict()
函數傳回的鍵完全匹配。警告
如果
assign
為True
,除非get_swap_module_params_on_conversion()
為True
,否則優化器必須在呼叫load_state_dict
之後建立。- 參數:
state_dict (dict) – 包含參數和持久緩衝區的字典。
strict (bool, optional) – 是否嚴格要求
state_dict
中的鍵與此模組的state_dict()
函數返回的鍵匹配。預設值:True
assign (bool, optional) – 設定為
False
時,會保留目前模組中 tensors 的屬性,而設定為True
則會保留 state dict 中 tensors 的屬性。唯一的例外是requires_grad
欄位。Default: ``False`
- 傳回:
- missing_keys 是一個字串列表,包含此模組預期的任何鍵
但在提供的
state_dict
中遺失。
- unexpected_keys 是一個字串列表,包含此模組未預期的鍵
但在提供的
state_dict
中存在。
- 傳回類型:
具有
missing_keys
和unexpected_keys
欄位的NamedTuple
注意
如果參數或緩衝區註冊為
None
,且其對應的鍵存在於state_dict
中,load_state_dict()
將會引發RuntimeError
。
- property module: Module¶
直接存取分片模組的屬性,該模組不會被包裝在 DDP、FSDP、DMP 或任何其他並行包裝器中。
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
傳回模組緩衝區上的迭代器,產生緩衝區的名稱以及緩衝區本身。
- 參數:
prefix (str) – 要前置於所有緩衝區名稱的前綴。
recurse (bool, optional) – 如果為 True,則產生此模組和所有子模組的緩衝區。 否則,僅產生作為此模組直接成員的緩衝區。 預設為 True。
remove_duplicate (bool, optional) – 是否移除結果中重複的緩衝區。 預設為 True。
- 產生:
(str, torch.Tensor) – 包含名稱和緩衝區的 Tuple
範例
>>> # xdoctest: +SKIP("undefined vars") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
傳回模組參數上的迭代器,產生參數的名稱以及參數本身。
- 參數:
prefix (str) – 要前置於所有參數名稱的前綴。
recurse (bool) – 如果為 True,則產生此模組和所有子模組的參數。 否則,僅產生作為此模組直接成員的參數。
remove_duplicate (bool, optional) – 是否移除結果中重複的參數。 預設為 True。
- 產生:
(str, Parameter) – 包含名稱和參數的 Tuple
範例
>>> # xdoctest: +SKIP("undefined vars") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
傳回一個字典,其中包含對模組整體狀態的參考。
包含參數和持久緩衝區 (例如,running averages)。 鍵是對應的參數和緩衝區名稱。 設置為
None
的參數和緩衝區不包含在內。注意
返回的物件是淺拷貝。 它包含對模組的參數和緩衝區的參考。
警告
目前,
state_dict()
也接受位置引數,依序對應destination
、prefix
和keep_vars
。 但是,這種做法已被棄用,並且在未來的版本中將強制使用關鍵字引數。警告
請避免使用引數
destination
,因為它不是為終端使用者設計的。- 參數:
destination (dict, optional) – 如果提供,模組的狀態將更新到字典中,並傳回相同的物件。 否則,將建立並傳回一個
OrderedDict
。 預設值:None
。prefix (str, optional) – 添加到參數和緩衝區名稱的前綴,用於組成 state_dict 中的鍵。 預設值:
''
。keep_vars (bool, optional) – 預設情況下,state dict 中傳回的
Tensor
會與自動梯度分離。 如果將其設置為True
,則不會執行分離。 預設值:False
。
- 傳回:
一個包含模組完整狀態的字典
- 傳回類型:
dict
範例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']