使用 while_loop
優化記憶體使用率¶
while_loop
¶
while_loop
取代純 Python while
迴圈,PyTorch 透過 torch._higher_order_ops.while_loop 支援 while_loop
。PyTorch/XLA 透過 XLA::While
為 torch._higher_order_ops.while_loop
提供實驗性的 XLA 後端支援。
用法:¶
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
cond_fn
:使用者定義的條件函數。body_fn
:使用者定義的迴圈主體函數。init
:初始值 (tuple 或 list)。
使用 while_loop
的簡單範例:¶
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
... return iteri > 0
...
>>> def body_fn(iteri, x):
... return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))
控制組測試案例¶
為了更好比較 純 Python while 迴圈
和 while_loop
之間的差異,有一個稱為純 Python while
迴圈的測試案例,其邏輯類似:累加加 1 十次
使用純 Python while
迴圈的控制組範例¶
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
... init_val = init_val + 1
... iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')
PyTorch/XLA 將在 2.4 版本中包含 while_loop
支援以及測試案例,對 fori_loop
的支援將在 2.4 版本之後新增。對於 while_loop
,目前我們只應強制定義具有相同 input
和 output(return args)
形狀的 body_fn