DataParallel¶
- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source][source]¶
在模組層級實作資料平行處理。
這個容器藉由將輸入資料分割到指定的裝置上,以批次維度進行分塊 (其他物件將會在每個裝置上複製一次),來平行化給定的
module
的應用。在正向傳播過程中,模組會在每個裝置上複製,並且每個複製品會處理一部分的輸入。在反向傳播過程中,每個複製品的梯度會被加總到原始模組中。批次大小應該要大於所使用的 GPU 數量。
警告
建議使用
DistributedDataParallel
,而不是這個類別,來進行多 GPU 訓練,即使只有單一節點也是一樣。請參閱:使用 nn.parallel.DistributedDataParallel 取代 multiprocessing 或 nn.DataParallel 和 分散式資料平行處理。允許將任意位置和關鍵字輸入傳遞到 DataParallel,但某些類型會被特別處理。張量將在指定的維度(預設為 0)上進行**分散 (scattered)**。 tuple、list 和 dict 類型將會進行淺複製。其他類型將會在不同的執行緒之間共享,如果在模型的正向傳播中寫入,可能會損壞。
平行化的
module
在執行此DataParallel
模組之前,其參數和緩衝區必須位於device_ids[0]
上。警告
在每次正向傳播中,
module
會在每個裝置上被**複製 (replicated)**,因此在forward
中對正在執行的模組所做的任何更新都會遺失。例如,如果module
有一個計數器屬性,該屬性在每次forward
中都會遞增,則它將始終保持在初始值,因為更新是在複製品上完成的,這些複製品在forward
之後會被銷毀。然而,DataParallel
保證在device[0]
上的複製品,會使其參數和緩衝區與基礎平行化的module
共享儲存空間。因此,對device[0]
上的參數或緩衝區進行**原地 (in-place)** 更新將會被記錄。例如,BatchNorm2d
和spectral_norm()
依賴此行為來更新緩衝區。警告
在
module
及其子模組上定義的正向和反向鉤子 (hooks) 將會被調用len(device_ids)
次,每次都帶有位於特定裝置上的輸入。特別是,只能保證鉤子會以相對於在相應裝置上運算的正確順序執行。例如,不能保證透過register_forward_pre_hook()
設置的鉤子會在 所有len(device_ids)
個forward()
呼叫之前執行,但可以保證每個這樣的鉤子會在該裝置的相應forward()
呼叫之前執行。警告
當
module
在forward()
中回傳一個純量(即 0 維張量)時,這個包裝器將會回傳一個長度等於資料平行處理中所使用裝置數量的向量,其中包含來自每個裝置的結果。注意
在
DataParallel
包裝的Module
中使用pack sequence -> recurrent network -> unpack sequence
模式時,存在一個微妙之處。有關詳細資訊,請參閱常見問題解答中的 我的遞迴網路無法與資料平行處理一起運作 區段。- 參數
module (Module) – 要平行化的模組
device_ids (list of int or torch.device) – CUDA 裝置 (預設:所有裝置)
output_device (int or torch.device) – 輸出的裝置位置 (預設:device_ids[0])
- 變數
module (Module) – 要平行化的模組
範例
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU