自動混合精度¶
Pytorch/XLA 的 AMP 擴展了 Pytorch 的 AMP 套件,支援在 XLA:GPU
和 XLA:TPU
裝置上使用自動混合精度。AMP 通過以 float32
執行某些運算,並以較低精度資料類型 (float16
或 bfloat16
,取決於硬體支援) 執行其他運算,來加速訓練和推論。本文檔介紹如何在 XLA 裝置上使用 AMP 以及最佳實踐。
XLA:TPU 的 AMP¶
TPU 上的 AMP 會自動轉換運算,使其以 float32
或 bfloat16
執行,因為 TPU 原生支援 bfloat16。以下是一個簡單的 TPU AMP 範例
# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass
with autocast(xm.xla_device()):
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward()
loss.backward()
xm.optimizer_step.(optimizer)
autocast(xm.xla_device())
是 torch.autocast('xla')
的別名,當 XLA 裝置是 TPU 時。或者,如果腳本僅與 TPU 一起使用,則可以直接使用 torch.autocast('xla', dtype=torch.bfloat16)
。
如果有應自動轉換但未包含在內的運算符,請提交 issue 或 pull request。
XLA:TPU 的 AMP 最佳實踐¶
autocast
應僅包裝網路的前向傳遞和損失計算。反向運算會以 autocast 用於相應前向運算的相同類型執行。由於 TPU 使用 bfloat16 混合精度,因此不需要梯度縮放。
Pytorch/XLA 提供了修改版本的最佳化器,避免了裝置和主機之間的額外同步。
支援的運算符¶
TPU 上的 AMP 運作方式與 Pytorch 的 AMP 類似。自動轉換的應用規則總結如下
只有異地運算和 Tensor 方法才有資格自動轉換。原地變體和明確提供 out=… Tensor 的呼叫允許在啟用自動轉換的區域中使用,但不會經過自動轉換。例如,在啟用自動轉換的區域中,a.addmm(b, c) 可以自動轉換,但 a.addmm_(b, c) 和 a.addmm(b, c, out=d) 不能。為了獲得最佳效能和穩定性,請在啟用自動轉換的區域中優先使用異地運算。
以 float64 或非浮點 dtype 執行的運算不符合資格,無論是否啟用自動轉換,都將以這些類型執行。此外,使用明確 dtype=… 參數呼叫的運算也不符合資格,並且將產生符合 dtype 參數的輸出。
未在下面列出的運算不會經過自動轉換。它們以輸入定義的類型執行。如果未列出的運算位於自動轉換運算的下游,則自動轉換仍可能更改這些運算執行的類型。
自動轉換為 ``bfloat16`` 的運算
__matmul__
, addbmm
, addmm
, addmv
, addr
, baddbmm
,bmm
, conv1d
, conv2d
, conv3d
, conv_transpose1d
, conv_transpose2d
, conv_transpose3d
, linear
, matmul
, mm
, relu
, prelu
, max_pool2d
自動轉換為 ``float32`` 的運算
batch_norm
, log_softmax
, binary_cross_entropy
, binary_cross_entropy_with_logits
, prod
, cdist
, trace
, chloesky
,inverse
, reflection_pad
, replication_pad
, mse_loss
, cosine_embbeding_loss
, nll_loss
, multilabel_margin_loss
, qr
, svd
, triangular_solve
, linalg_svd
, linalg_inv_ex
自動轉換為最寬輸入類型的運算
stack
, cat
, index_copy
XLA:GPU 的 AMP¶
XLA:GPU 裝置上的 AMP 重複使用 Pytorch 的 AMP 規則。有關 CUDA 特定行為,請參閱 Pytorch 的 AMP 文件。以下是一個簡單的 CUDA AMP 範例
# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
scaler = GradScaler()
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass
with autocast(xm.xla_device()):
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward pass
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size())
scaler.step(optimizer)
scaler.update()
autocast(xm.xla_device())
是 torch.cuda.amp.autocast()
的別名,當 XLA 裝置是 CUDA 裝置 (XLA:GPU) 時。或者,如果腳本僅與 CUDA 裝置一起使用,則可以直接使用 torch.cuda.amp.autocast
,但需要 torch
編譯時具有 cuda
對 torch.bfloat16
資料類型的支援。我們建議在 XLA:GPU 上使用 autocast(xm.xla_device())
,因為它不需要 torch.cuda
支援任何資料類型,包括 torch.bfloat16
。
範例¶
我們的 mnist 訓練腳本和 imagenet 訓練腳本示範了如何在 TPU 和 GPU 上使用 AMP。