scan 和 scan_layers 使用指南¶
本指南說明如何在 PyTorch/XLA 中使用 scan
和 scan_layers
。
何時應使用此功能¶
如果您的模型具有許多同質(形狀相同、邏輯相同)層,例如 LLM,則應考慮使用 ``scan_layers` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py>`_。這些模型的編譯速度可能很慢。scan_layers
是同質層迴圈的直接替代品,例如一批解碼器層。scan_layers
會追蹤第一層,並將編譯結果重複用於所有後續層,從而大幅縮短模型編譯時間。
``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 另一方面,是仿照 ``jax.lax.scan` <https://jax.dev.org.tw/en/latest/_autosummary/jax.lax.scan.html>`_ 建模的較低階高階運算。其主要目的是協助在底層實作 scan_layers
。不過,如果您想要編寫某種類型的迴圈邏輯,其中迴圈本身在編譯器中具有第一類表示法(特別是 XLA While
運算),您可能會發現它很有用。
scan_layers
範例¶
一般而言,Transformer 模型會將輸入嵌入傳遞通過一系列同質解碼器層,如下所示
def run_decoder_layers(self, hidden_states):
for decoder_layer in self.layers:
hidden_states = decoder_layer(hidden_states)
return hidden_states
當此函式降階為 HLO 圖形時,for 迴圈會展開為一連串扁平運算,導致編譯時間過長。為了縮短編譯時間,您可以將 for 迴圈替換為對 scan_layers
的呼叫,如 ``decoder_with_scan.py` </examples/scan/decoder_with_scan.py>`_ 中所示
def run_decoder_layers(self, hidden_states):
from torch_xla.experimental.scan_layers import scan_layers
return scan_layers(self.layers, hidden_states)
您可以從 pytorch/xla
原始碼簽出的根目錄執行下列命令,來訓練此解碼器模型。
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
scan
範例¶
``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 接受一個組合函式,並在張量的開頭維度上套用該函式,同時攜帶狀態
def scan(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
) -> tuple[Carry, Y]:
...
您可以使用它來有效率地在張量的開頭維度上進行迴圈。如果 xs
是單一張量,則此函式大致等於下列 Python 程式碼
def scan(fn, init, xs):
ys = []
carry = init
for i in len(range(xs.size(0))):
carry, y = fn(carry, xs[i])
ys.append(y)
return carry, torch.stack(ys, dim=0)
在底層,scan
的實作效率更高,方法是將迴圈降階為 XLA While
運算。這可確保 XLA 只編譯迴圈的一次迭代。
``scan_examples.py` </examples/scan/scan_examples.py>`_ 包含一些範例程式碼,示範如何使用 scan
。在該檔案中,scan_example_cumsum
使用 scan
實作累加總和。scan_example_pytree
示範如何將 PyTree 傳遞至 scan
。
您可以使用以下命令執行範例
python3 examples/scan/scan_examples.py
輸出應如下所示
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
[3.],
[6.]], device='xla:0')
Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
[1.5000],
[2.0000],
[2.5000],
[3.0000]], device='xla:0')
限制¶
AOTAutograd 相容性要求¶
傳遞至 scan
和 scan_layers
的函式/模組必須可追蹤 AOTAutograd。特別是,截至 PyTorch/XLA 2.6,scan
和 scan_layers
無法追蹤具有自訂 Pallas 核心的函式。這表示如果您的解碼器使用快閃注意力機制等功能,則它與 scan
不相容。我們正在努力在 nightly 和後續版本中支援此重要用例。
編譯時間實驗¶
為了示範編譯時間的節省,我們將在單一 TPU 晶片上使用 for 迴圈與 scan_layers
訓練具有多層的簡單解碼器。
執行 for 迴圈實作
❯ python3 examples/train_decoder_only_base.py \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 02m57s694ms418.595us
ValueRate: 02s112ms586.097us / second
Rate: 0.054285 / second
Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
99%=01m03s028ms571.841us
執行
scan_layers
實作
❯ python3 examples/train_decoder_only_base.py \
scan.decoder_with_scan.DecoderWithScan \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 29s996ms941.409us
ValueRate: 02s529ms591.388us / second
Rate: 0.158152 / second
Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
99%=18s995ms301.667us
我們可以發現,透過切換到 scan_layers
,最長編譯時間從 1m03s
降至 19s
。
參考資料¶
請參閱 https://github.com/pytorch/xla/issues/7253,了解 scan
和 scan_layers
本身的設計。
請參閱 ``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 和 ``scan_layers` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py>`_ 的函式文件註解,以取得關於如何使用它們的詳細資訊。