• 文件 >
  • 使用 while_loop 優化記憶體使用率
快速鍵

使用 while_loop 優化記憶體使用率

while_loop

while_loop 取代純 Python while 迴圈,PyTorch 透過 torch._higher_order_ops.while_loop 支援 while_loop。PyTorch/XLA 透過 XLA::Whiletorch._higher_order_ops.while_loop 提供實驗性的 XLA 後端支援。

用法:

import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
  • cond_fn:使用者定義的條件函數。

  • body_fn:使用者定義的迴圈主體函數。

  • init:初始值 (tuple 或 list)。

使用 while_loop 的簡單範例:

# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))

控制組測試案例

為了更好比較 Python while 迴圈while_loop 之間的差異,有一個稱為純 Python while 迴圈的測試案例,其邏輯類似:累加加 1 十次

使用純 Python while 迴圈的控制組範例

# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
...   init_val = init_val + 1
...   iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')

PyTorch/XLA 將在 2.4 版本中包含 while_loop 支援以及測試案例,對 fori_loop 的支援將在 2.4 版本之後新增。對於 while_loop,目前我們只應強制定義具有相同 inputoutput(return args) 形狀的 body_fn

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源