(beta) Building a Simple CPU Performance Profiler with FX¶
Created On: Mar 04, 2021 | Last Updated: Jan 16, 2024 | Last Verified: Not Verified
Author: James Reed
In this tutorial, we are going to use FX to do the following:

1. Capture PyTorch Python code in a way that lets us inspect and gather statistics about the structure and execution of the code
2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.

For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.
import torch
import torch.fx
import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model are taking the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.

That technique is certainly applicable to PyTorch code, but it would be nicer if we didn't have to copy over model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.
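For instance, a hand-instrumented version of just the first two layers of the model might look like the following sketch (purely illustrative; this is exactly the boilerplate FX will let us avoid writing):

import time

x = torch.randn(5, 3, 224, 224)
with torch.no_grad():
    t0 = time.time()
    out = rn18.conv1(x)   # time the first convolution...
    t1 = time.time()
    out = rn18.bn1(out)   # ...then the first batch norm
    t2 = time.time()
print(f"conv1: {t1 - t0:.6f}s  bn1: {t2 - t1:.6f}s")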
First, let's get some imports out of the way (we will be using all of these later in the code).
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
Note

tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source.
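As a quick sanity check that the installation worked, one might print a toy table (the row of data here is made up purely for illustration):

import tabulate

rows = [['call_module', 'conv1', 0.0065, 6.1]]
print(tabulate.tabulate(rows, headers=['Op type', 'Op', 'Average runtime (s)', 'Pct total runtime']))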
Capturing the Model with Symbolic Tracing¶
Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
%x : torch.Tensor [num_users=1] = placeholder[target=x]
%conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
%bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
%relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
%maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
%layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
%layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
%layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
%layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
%layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
%add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
%layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
%layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
%layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
%layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
%layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
%layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
%layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
%layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
%layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
%layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
%layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
%layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
%layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
%layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
%layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
%layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
%layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
%layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
%layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
%layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
%layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
%layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
%layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
%layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
%layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
%layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
%layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
%layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
%layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
%layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
%add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
%layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
%layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
%layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
%layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
%layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
%layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
%add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
%layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
%layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
%layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
%layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
%layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
%layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
%add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
%layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
%avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
%flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
%fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
return fc
This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between those call-sites. More information about the Graph representation and the rest of the FX APIs can be found in the FX documentation: https://pytorch.dev.org.tw/docs/master/fx.html.
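To make this concrete, we can walk the nodes of the traced graph directly; each Node carries its opcode, its target, and the args/kwargs edges feeding into it (a small sketch, reusing traced_rn18 from above):

# Print the opcode, target, and input edges of the first few nodes
for node in list(traced_rn18.graph.nodes)[:5]:
    print(node.op, node.target, node.args, node.kwargs)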
Creating a Profiling Interpreter¶
Next, we are going to create a class that inherits from torch.fx.Interpreter. Though the GraphModule that symbolic_trace produces compiles Python code that is run when you call the GraphModule, an alternative way to run a GraphModule is to execute each Node in the Graph one by one. That is the functionality Interpreter provides: it interprets the graph node-by-node.
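As a quick illustration of that equivalence (a minimal sketch, again reusing traced_rn18 from above), interpreting the graph node-by-node with a stock Interpreter produces the same result as calling the GraphModule directly:

# Direct execution and node-by-node interpretation should agree
inp = torch.randn(5, 3, 224, 224)
direct_out = traced_rn18(inp)
interp_out = Interpreter(traced_rn18).run(inp)
print(torch.allclose(direct_out, interp_out))  # expected: True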
By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each of its parts took during those runs.
Let's define our ProfilingInterpreter class:
class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
Note

We use Python's time.time function here to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance, and only gives us a first-order approximation. We use this simple technique for demonstration purposes in this tutorial.
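If slightly better timer resolution is desired, Python's time.perf_counter is a natural drop-in alternative (a sketch; the profiler code above deliberately sticks with time.time):

import time

# ``perf_counter`` is monotonic and uses the highest-resolution clock
# available, which makes it generally preferable for short durations.
t_start = time.perf_counter()
_ = rn18(torch.randn(5, 3, 224, 224))
t_end = time.perf_counter()
print(f"elapsed: {t_end - t_start:.6f}s")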
Investigating the Performance of ResNet18¶
We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type Op Average runtime (s) Pct total runtime
------------- --------------------- --------------------- -------------------
call_module maxpool 0.00859714 8.02576
call_module conv1 0.00653529 6.10095
call_module layer4_0_conv2 0.00636888 5.94559
call_module layer4_1_conv1 0.00583315 5.44547
call_module layer4_1_conv2 0.00581217 5.42588
call_module layer1_0_conv1 0.00575733 5.37469
call_module layer1_1_conv2 0.00560737 5.23469
call_module layer2_1_conv1 0.00524116 4.89282
call_module layer1_1_conv1 0.00520563 4.85966
call_module layer3_1_conv2 0.00516272 4.81959
call_module layer3_1_conv1 0.00511885 4.77864
call_module layer3_0_conv2 0.0050416 4.70653
call_module layer2_1_conv2 0.00477362 4.45635
call_module layer1_0_conv2 0.00472641 4.41229
call_module layer2_0_conv2 0.00456476 4.26138
call_module layer4_0_conv1 0.00383401 3.57919
call_module layer3_0_conv1 0.00297451 2.77682
call_module layer2_0_conv1 0.00259256 2.42026
call_module bn1 0.00228763 2.13559
call_module layer2_0_downsample_0 0.00136423 1.27356
call_function add_1 0.000645161 0.602282
call_module layer3_0_downsample_0 0.000638962 0.596495
call_function add 0.000619888 0.57869
call_module layer4_0_downsample_0 0.000584602 0.545749
call_module relu 0.000487328 0.454939
call_function add_3 0.000329733 0.307818
call_module layer1_1_bn2 0.000267506 0.249727
call_module layer1_0_bn1 0.000254631 0.237708
call_module layer1_1_bn1 0.000223398 0.208551
call_module layer1_0_bn2 0.000217915 0.203432
call_module fc 0.000207663 0.193861
call_module layer2_0_downsample_1 0.00019908 0.185848
call_module layer2_1_bn1 0.000183105 0.170936
call_module layer2_0_bn2 0.000172615 0.161143
call_module layer2_1_bn2 0.000161886 0.151127
call_module avgpool 0.000159979 0.149346
call_module layer2_0_bn1 0.000159264 0.148679
call_module layer3_1_bn2 0.000156403 0.146008
call_module layer3_1_bn1 0.000155449 0.145118
call_module layer4_1_bn2 0.000144243 0.134657
call_module layer3_0_downsample_1 0.000143051 0.133544
call_module layer4_0_bn2 0.000140667 0.131318
call_module layer4_1_bn1 0.000140429 0.131095
call_module layer3_0_bn2 0.000138998 0.12976
call_module layer1_0_relu 0.00013876 0.129537
call_module layer1_1_relu_1 0.000132561 0.123751
call_module layer4_0_bn1 0.000131369 0.122638
call_module layer3_0_bn1 0.000128984 0.120412
call_module layer1_0_relu_1 0.000127077 0.118631
call_module layer1_1_relu 0.000126839 0.118409
call_module layer4_0_downsample_1 0.000124693 0.116406
call_function add_2 0.000115156 0.107503
call_function add_7 0.000108719 0.101493
call_function add_6 9.94205e-05 0.0928129
call_module layer2_0_relu_1 9.77516e-05 0.0912549
call_module layer2_1_relu 9.75132e-05 0.0910323
call_module layer2_0_relu 9.53674e-05 0.0890292
call_module layer2_1_relu_1 9.48906e-05 0.088584
call_module layer4_1_relu 9.27448e-05 0.0865809
call_module layer4_0_relu 9.2268e-05 0.0861357
call_function add_5 8.79765e-05 0.0821294
call_module layer3_1_relu 8.08239e-05 0.0754522
call_module layer3_0_relu 8.05855e-05 0.0752296
call_module layer4_1_relu_1 7.89165e-05 0.0736716
call_module layer4_0_relu_1 7.82013e-05 0.0730039
call_function add_4 7.79629e-05 0.0727813
call_module layer3_0_relu_1 7.77245e-05 0.0725588
call_module layer3_1_relu_1 7.72476e-05 0.0721136
call_function flatten 4.24385e-05 0.039618
placeholder x 2.6226e-05 0.024483
output output 1.90735e-05 0.0178058
There are two things we should call out here:

1. MaxPool2d takes up the most time. This is a known issue: https://github.com/pytorch/pytorch/issues/51393
2. BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize it in the Conv-BN Fusion with FX tutorial; a small sketch of that fusion follows below.
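As a small taste of that optimization, here is a sketch that folds the model's first Conv2d/BatchNorm2d pair into a single convolution using PyTorch's built-in torch.nn.utils.fusion.fuse_conv_bn_eval helper (the full FX-based graph rewrite is what the Conv-BN Fusion tutorial develops; this helper fuses only one explicitly chosen pair, and both modules must be in eval mode):

from torch.nn.utils.fusion import fuse_conv_bn_eval

# Fold conv1 + bn1 into a single Conv2d with adjusted weights and bias
fused = fuse_conv_bn_eval(rn18.conv1, rn18.bn1)

x = torch.randn(5, 3, 224, 224)
with torch.no_grad():
    before = rn18.bn1(rn18.conv1(x))
    after = fused(x)
print(torch.allclose(before, after, atol=1e-5))  # expected: True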
Conclusion¶
As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.

Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you may have.
Total running time of the script: (0 minutes 0.486 seconds)