• 教學 >
  • (beta) PyTorch 中的 Channels Last Memory Format
捷徑

(beta) PyTorch 中的 Channels Last Memory Format

建立於:2020 年 4 月 20 日 | 最後更新:2023 年 10 月 4 日 | 最後驗證:2024 年 11 月 5 日

作者Vitaly Fedyunin

什麼是 Channels Last?

Channels last memory format 是一種在記憶體中排序 NCHW 張量的替代方法,同時保留維度順序。 Channels last 張量以通道成為最密集維度的方式排序(又名逐像素儲存圖像)。

例如,NCHW 張量的經典(連續)儲存(在我們的案例中,它是兩個具有 3 個顏色通道的 4x4 圖像)看起來像這樣

classic_memory_format

Channels last memory format 以不同的方式排序資料

channels_last_memory_format

Pytorch 透過利用現有的 strides 結構,支援 memory format (並提供與現有模型的向後相容性,包括 eager, JIT, 和 TorchScript)。 例如,Channels last 格式中的 10x3x16x16 批次將具有等於 (768, 1, 48, 3) 的 strides。

Channels last memory format 僅針對 4D NCHW 張量實作。

Memory Format API

以下是如何在連續和 channels last memory format 之間轉換張量。

經典 PyTorch 連續張量

import torch

N, C, H, W = 10, 3, 32, 32
x = torch.empty(N, C, H, W)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)

轉換運算符

x = x.to(memory_format=torch.channels_last)
print(x.shape)  # Outputs: (10, 3, 32, 32) as dimensions order preserved
print(x.stride())  # Outputs: (3072, 1, 96, 3)
torch.Size([10, 3, 32, 32])
(3072, 1, 96, 3)

回到連續

x = x.to(memory_format=torch.contiguous_format)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)

替代選項

x = x.contiguous(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

格式檢查

print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
True

兩個 API tocontiguous 之間存在細微差異。 我們建議在明確轉換張量的 memory format 時,堅持使用 to

對於一般情況,這兩個 API 的行為相同。 然而,在大小為 NCHW 的 4D 張量的特殊情況下,當:C==1H==1 && W==1 時,只有 to 會產生適當的 stride 來表示 channels last memory format。

這是因為在上述兩種情況的任一種情況下,張量的 memory format 都是模糊的,即大小為 N1HW 的連續張量在記憶體儲存中既是 contiguous 也是 channels last。 因此,對於給定的 memory format,它們已經被認為是 is_contiguous,因此 contiguous 呼叫會變成 no-op 並且不會更新 stride。 相反,to 會重新stride張量,並在尺寸大小為 1 的維度上使用有意義的 stride,以便正確表示預期的 memory format。

special_x = torch.empty(4, 1, 4, 4)
print(special_x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
print(special_x.is_contiguous(memory_format=torch.contiguous_format))  # Outputs: True
True
True

同樣的事情適用於顯式置換 API permute。 在可能發生歧義的特殊情況下,permute 不能保證產生正確攜帶預期 memory format 的 stride。 我們建議使用具有顯式 memory format 的 to 來避免意外行為。

另外,在極端情況下,其中三個非批次維度都等於 1 (C==1 && H==1 && W==1),目前的實作無法將張量標記為 channels last memory format。

建立為 channels last

x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

clone 保留 memory format

y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

tocudafloat ... 保留 memory format

if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

empty_like*_like 運算符保留 memory format

y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

逐點運算符保留 memory format

z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)

使用 cudnn 後端的 ConvBatchnorm 模組支援 channels last(僅適用於 cuDNN >= 7.6)。 與二元逐點運算符不同,卷積模組將 channels last 作為主要的 memory format。 如果所有輸入都採用連續 memory format,則運算符會產生採用連續 memory format 的輸出。 否則,輸出將採用 channels last memory format。

if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
    model = torch.nn.Conv2d(8, 4, 3).cuda().half()
    model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last

    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
    input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
True

當輸入張量到達沒有 channels last 支援的運算符時,應該在核心中自動套用置換,以在輸入張量上恢復連續性。 這會引入 overhead 並停止 channels last memory format 傳播。 然而,它可以保證正確的輸出。

效能提升

Channels last memory format 優化在 GPU 和 CPU 上都可用。 在 GPU 上,在 NVIDIA 的硬體上觀察到最顯著的效能提升,該硬體支援在降低的精度 (torch.float16) 上運行的 Tensor Cores。 透過 channels last 比較連續格式,我們能夠使用 'AMP (Automated Mixed Precision)' 訓練腳本來實現超過 22% 的效能提升。 我們的腳本使用 NVIDIA 提供的 AMP https://github.com/NVIDIA/apex

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2  ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)

傳遞 --channels-last true 允許以 Channels Last 格式執行模型,觀察到效能提升 22%。

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)

以下模型列表完全支援 Channels Last,並且在 Volta 裝置上顯示出 8%-35% 的效能提升:alexnetmnasnet0_5mnasnet0_75mnasnet1_0mnasnet1_3mobilenet_v2resnet101resnet152resnet18resnet34resnet50resnext50_32x4dshufflenet_v2_x0_5shufflenet_v2_x1_0shufflenet_v2_x1_5shufflenet_v2_x2_0squeezenet1_0squeezenet1_1vgg11vgg11_bnvgg13vgg13_bnvgg16vgg16_bnvgg19vgg19_bnwide_resnet101_2wide_resnet50_2

以下模型列表完全支援 Channels Last,並且在 Intel(R) Xeon(R) Ice Lake(或更新)CPU 上顯示出 26%-76% 的效能提升:alexnetdensenet121densenet161densenet169googlenetinception_v3mnasnet0_5mnasnet1_0resnet101resnet152resnet18resnet34resnet50resnext101_32x8dresnext50_32x4dshufflenet_v2_x0_5shufflenet_v2_x1_0squeezenet1_0squeezenet1_1vgg11vgg11_bnvgg13vgg13_bnvgg16vgg16_bnvgg19vgg19_bnwide_resnet101_2wide_resnet50_2

轉換現有模型

Channels Last 的支援不限於現有模型,因為任何模型都可以轉換為 Channels Last,並在輸入(或某些權重)被正確格式化後,透過圖形傳播格式。

# Need to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last)  # Replace with your model

# Need to be done for every input
input = input.to(memory_format=torch.channels_last)  # Replace with your input
output = model(input)

然而,並非所有運算子都完全轉換為支援 Channels Last(通常會改為返回連續的輸出)。在上面發布的範例中,不支援 Channels Last 的層會停止記憶體格式的傳播。儘管如此,由於我們已將模型轉換為 Channels Last 格式,這意味著每個卷積層(其 4 維權重以 Channels Last 記憶體格式儲存)將恢復 Channels Last 記憶體格式,並受益於更快的核心。

但是,不支援 Channels Last 的運算子確實會引入通過排列造成的額外負擔。您可以選擇性地調查並識別模型中不支援 Channels Last 的運算子,如果您想要提高已轉換模型的效能。

這意味著您需要根據支援的運算子列表驗證使用的運算子列表 https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,或將記憶體格式檢查引入到 eager execution 模式並運行您的模型。

運行以下程式碼後,如果運算子的輸出與輸入的記憶體格式不符,運算子將引發異常。

def contains_cl(args):
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False


def print_inputs(args, indent=""):
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + "    ")
        else:
            print(indent, t)


def check_wrapper(fn):
    name = fn.__name__

    def check_cl(*args, **kwargs):
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print("-------------------")
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print(
                        "`{}` got channels_last input, but output is not channels_last:".format(name),
                        result.shape,
                        result.stride(),
                        result.device,
                        result.dtype,
                    )
                    failed = True
        if failed and True:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            raise Exception("Operator `{}` lost channels_last property".format(name))
        return result

    return check_cl


old_attrs = dict()


def attribute(m):
    old_attrs[m] = dict()
    for i in dir(m):
        e = getattr(m, i)
        exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
        if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
            try:
                old_attrs[m][i] = e
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)


attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)

如果您發現有運算子不支援 Channels Last 張量,並且想要貢獻,請隨時使用以下開發人員指南 https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators

下面的程式碼是用來恢復 torch 的屬性。

for (m, attrs) in old_attrs.items():
    for (k, v) in attrs.items():
        setattr(m, k, v)

待辦事項

還有許多事情要做,例如

  • 解決 N1HWNC11 張量的模糊性;

  • 測試分散式訓練支援;

  • 改善運算子涵蓋範圍。

如果您有回饋和/或改進建議,請透過建立 issue 讓我們知道。

腳本總執行時間:(0 分鐘 0.041 秒)

由 Sphinx-Gallery 產生的圖庫

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得適合初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源