torch.nn.utils.convert_conv2d_weight_memory_format¶
- torch.nn.utils.convert_conv2d_weight_memory_format(module, memory_format)[原始碼][原始碼]¶
將
nn.Conv2d.weight
的memory_format
轉換為指定的memory_format
。此轉換會遞迴地應用於巢狀的
nn.Module
,包括module
本身。請注意,它僅更改 memory_format,而不改變每個維度的語義。此函數用於促進採用 NHWC kernels 的計算,這為 CUDA 設備上計算能力 >= 7.0 的 fp16 資料提供了相當大的加速。注意
呼叫
model.to(memory_format=torch.channels_last)
比實用函數convert_conv2d_weight_memory_format
更具侵略性。 任何具有 4d 權重的層都會受到model.to
的影響,這不一定能從轉換為指定的memory_format
中受益。 我們可以確信的一點是,cuDNN 中的卷積 (convolution) 的 NHWC(channels_last) 轉換,有利於在 NHWC 中執行卷積,即使在我們必須對輸入張量應用置換 (permutation) 的情況下也是如此。因此,我們在這裡的策略是僅將卷積的權重轉換為 channels_last。 這確保了:1. 將使用快速卷積 kernels,其好處可能超過置換的開銷(如果輸入格式不同)。 2. 不會在不需要記憶體格式轉換的層上應用不必要的置換。
最佳情況是,卷積層之間的層與 channels last 相容。 輸入張量會在遇到第一個卷積層時被置換為 channels last,並保持在該記憶體格式中。 因此,後續的卷積將不需要置換其輸入張量。
如果卷積層之間存在與 channels last 不相容的層,則需要將輸入張量置換回連續格式以供該層使用。 輸入張量將以連續格式經過剩餘層,並在遇到另一個卷積層時被置換為 channels last。 將該置換傳播到更早的層沒有意義,因為大多數層對於
memory_format
相當不敏感。當 PyTorch 支援置換融合時,此聲明可能會改變,因為可能存在比緊鄰卷積之前更好的融合置換的位置。
- 參數
module (nn.Module) –
nn.Conv2d
&nn.ConvTranspose2d
或容器nn.Module
memory_format – 使用者指定的
memory_format
,例如torch.channels_last
或torch.contiguous_format
- 返回
具有更新後的
nn.Conv2d
的原始模組
範例
>>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda") >>> model = nn.Sequential( >>> nn.Conv2d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) >>> out = model(input)