捷徑

fuse_modules

class torch.ao.quantization.fuse_modules.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)[來源][來源]

將模組列表融合為單一模組。

僅融合以下模組序列:conv, bn conv, bn, relu conv, relu linear, relu bn, relu 所有其他序列保持不變。對於這些序列,將列表中的第一個項目替換為融合模組,並將其餘模組替換為恆等式。

參數
  • model – 包含要融合的模組的模型

  • modules_to_fuse – 要融合的模組名稱列表的列表。如果只有一個要融合的模組列表,也可以是字串列表。

  • inplace – 布林值,指定是否在模型上進行就地融合,預設會傳回新模型

  • fuser_func – 函數,接收模組列表作為輸入,並輸出相同長度的融合模組列表。例如,fuser_func([convModule, BNModule]) 返回列表 [ConvBNModule, nn.Identity()]。預設值為 torch.ao.quantization.fuse_known_modules

  • fuse_custom_config_dict – 用於融合的自定義配置

# Example of fuse_custom_config_dict
fuse_custom_config_dict = {
    # Additional fuser_method mapping
    "additional_fuser_method_mapping": {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
    },
}
返回值

具有融合模組的模型。如果 inplace=True,則會創建一個新的副本。

範例

>>> m = M().eval()
>>> # m is a module containing the sub-modules below
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = M().eval()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

文件

獲取 PyTorch 的完整開發者文檔

查看文檔

教程

獲取適用於初學者和高級開發者的深度教程

查看教程

資源

查找開發資源並獲得您問題的解答

查看資源