torch.nn.utils.convert_conv3d_weight_memory_format¶
- torch.nn.utils.convert_conv3d_weight_memory_format(module, memory_format)[原始碼][原始碼]¶
轉換
nn.Conv3d.weight
的memory_format
到指定的memory_format
。 此轉換會遞迴地應用於巢狀的nn.Module
,包括module
本身。 請注意,此函數僅更改 memory_format,但不會更改每個維度的語義。 此函數用於方便計算以採用 NHWC 核心,這為在具有計算能力 >= 7.0 的 CUDA 裝置上的 fp16 資料提供了相當大的加速。注意
呼叫
model.to(memory_format=torch.channels_last_3d)
比實用函數convert_conv3d_weight_memory_format
更具侵略性。 任何具有 4d 權重的層都會受到model.to
的影響,這不一定能從轉換為指定的memory_format
中受益。 我們有信心的其中一點是,cuDNN 中卷積的 NDHWC(channels_last_3d) 轉換,因為即使在必須對輸入張量應用排列的情況下,在 NDHWC 中執行卷積也是有益的。因此,我們這裡的策略是僅將卷積的權重轉換為 channels_last_3d。 這確保了:1. 將使用快速卷積核心,其好處可能超過排列的開銷(如果輸入格式不同)。 2. 不會在不從 memory_format 轉換中受益的層上應用不必要的排列。
最佳情況是卷積層之間的層與 channels last 相容。 輸入張量會在遇到第一個卷積層時被排列為 channels last,並保持在該記憶體格式中。 因此,後續的卷積將不需要排列其輸入張量。
如果卷積層之間存在不相容於 channels last 的層,我們需要將輸入張量排列回連續格式以供該層使用。 輸入張量將以連續格式通過剩餘層,並在遇到另一個卷積層時被排列為 channels last。 將該排列傳播到較早的層是沒有意義的,因為大多數層對
memory_format
的變化不太敏感。當 PyTorch 支援排列融合時,此聲明可能會改變,因為可能會有一個比緊接在卷積之前融合排列更好的位置。
- 參數
module (nn.Module) –
nn.Conv3d
&nn.ConvTranspose3d
或容器nn.Module
memory_format – 使用者指定的
memory_format
,例如torch.channels_last
或torch.contiguous_format
- 返回
具有更新的
nn.Conv3d
的原始模組
範例
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") >>> model = nn.Sequential( >>> nn.Conv3d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> out = model(input)