注意
點擊這裡下載完整的範例程式碼
使用 torch.compiler.set_stance
進行動態編譯控制¶
作者: William Wen
torch.compiler.set_stance
是一個 torch.compiler
API,使您能夠在對模型的不同呼叫中更改 torch.compile
的行為,而無需將 torch.compile
重新應用於您的模型。
此食譜提供了一些關於如何使用 torch.compiler.set_stance
的範例。
描述¶
torch.compile.set_stance
可以用作裝飾器、上下文管理器或原始函數,以更改 torch.compile
在對模型的不同呼叫中的行為。
在下面的範例中,"force_eager"
stance 忽略所有 torch.compile
指令。
import torch
@torch.compile
def foo(x):
if torch.compiler.is_compiling():
# torch.compile is active
return x + 1
else:
# torch.compile is not active
return x - 1
inp = torch.zeros(3)
print(foo(inp)) # compiled, prints 1
tensor([1., 1., 1.])
範例裝飾器用法
@torch.compiler.set_stance("force_eager")
def bar(x):
# force disable the compiler
return foo(x)
print(bar(inp)) # not compiled, prints -1
tensor([-1., -1., -1.])
範例上下文管理器用法
with torch.compiler.set_stance("force_eager"):
print(foo(inp)) # not compiled, prints -1
tensor([-1., -1., -1.])
範例原始函數用法
torch.compiler.set_stance("force_eager")
print(foo(inp)) # not compiled, prints -1
torch.compiler.set_stance("default")
print(foo(inp)) # compiled, prints 1
tensor([-1., -1., -1.])
tensor([1., 1., 1.])
torch.compile
stance 只能在任何 torch.compile
區域之外更改。 嘗試這樣做將導致錯誤。
@torch.compile
def baz(x):
# error!
with torch.compiler.set_stance("force_eager"):
return x + 1
try:
baz(inp)
except Exception as e:
print(e)
@torch.compiler.set_stance("force_eager")
def inner(x):
return x + 1
@torch.compile
def outer(x):
# error!
return inner(x)
try:
outer(inp)
except Exception as e:
print(e)
Attempt to trace forbidden callable <function set_stance at 0x7fd5e807d870>
from user code:
File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
with torch.compiler.set_stance("force_eager"):
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Attempt to trace forbidden callable <function inner at 0x7fd4a27bb0a0>
from user code:
File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
return inner(x)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
- 其他 stances 包括
"default"
:預設 stance,用於正常編譯。"eager_on_recompile"
:在需要重新編譯時,立即執行代碼。 如果存在對輸入有效的快取編譯代碼,則仍將使用它。"fail_on_recompile"
:重新編譯函數時引發錯誤。
有關更多 stances 和選項,請參閱 torch.compiler.set_stance
文檔頁面。 將來也可能會添加更多 stances/選項。
範例¶
防止重新編譯¶
有些模型不希望進行任何重新編譯 - 例如,您可能始終具有相同形狀的輸入。 由於重新編譯可能很昂貴,因此我們可能希望在嘗試重新編譯時出錯,以便我們可以檢測並修復重新編譯案例。 "fail_on_recompilation"
stance 可以用於此。
@torch.compile
def my_big_model(x):
return torch.relu(x)
# first compilation
my_big_model(torch.randn(3))
with torch.compiler.set_stance("fail_on_recompile"):
my_big_model(torch.randn(3)) # no recompilation - OK
try:
my_big_model(torch.randn(4)) # recompilation - error
except Exception as e:
print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'
如果出錯太具破壞性,我們可以改用 "eager_on_recompile"
,這將導致 torch.compile
回退到 eager 而不是出錯。 如果我們不希望重新編譯經常發生,但當需要重新編譯時,我們寧願付出積極運行的代價,而不是重新編譯的代價,這可能很有用。
@torch.compile
def my_huge_model(x):
if torch.compiler.is_compiling():
return x + 1
else:
return x - 1
# first compilation
print(my_huge_model(torch.zeros(3))) # 1
with torch.compiler.set_stance("eager_on_recompile"):
print(my_huge_model(torch.zeros(3))) # 1
print(my_huge_model(torch.zeros(4))) # -1
print(my_huge_model(torch.zeros(3))) # 1
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])
測量性能提升¶
torch.compiler.set_stance
可用於比較 eager 與編譯後的性能,而無需定義單獨的 eager 模型。
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000
@torch.compile
def my_gigantic_model(x, y):
x = x @ y
x = x @ y
x = x @ y
return x
inps = torch.randn(5, 5), torch.randn(5, 5)
with torch.compiler.set_stance("force_eager"):
print("eager:", timed(lambda: my_gigantic_model(*inps))[1])
# warmups
for _ in range(3):
my_gigantic_model(*inps)
print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 0.00016115200519561766
compiled: 0.00016368000209331514
更快崩潰¶
在使用 "force_eager"
stance 編譯的迭代之前,先運行 eager 迭代,可以幫助我們在嘗試進行非常長的編譯之前,捕獲與 torch.compile
無關的錯誤。
@torch.compile
def my_humongous_model(x):
return torch.sin(x, x)
try:
with torch.compiler.set_stance("force_eager"):
print(my_humongous_model(torch.randn(3)))
# this call to the compiled model won't run
print(my_humongous_model(torch.randn(3)))
except Exception as e:
print(e)
sin() takes 1 positional argument but 2 were given
結論¶
在本食譜中,我們學習了如何使用 torch.compiler.set_stance
API 來修改 torch.compile
在對模型的不同呼叫中的行為,而無需重新應用它。 該食譜演示了如何使用 torch.compiler.set_stance
作為裝飾器、上下文管理器或原始函數來控制編譯 stances,例如 force_eager
、default
、eager_on_recompile
和 "fail_on_recompile"。
有關更多資訊,請參閱:torch.compiler.set_stance API 文檔。
腳本的總運行時間:(0 分鐘 13.738 秒)