快捷鍵

動態形狀

動態形狀指的是張量形狀的可變性質,其形狀取決於另一個上游張量的值。例如

>>> import torch, torch_xla
>>> in_tensor  = torch.randint(low=0, high=2, size=(5,5), device='xla:0')
>>> out_tensor = torch.nonzero(in_tensor)

out_tensor 的形狀取決於 in_tensor 的值,並受 in_tensor 的形狀限制。換句話說,如果您執行

>>> print(out_tensor.shape)
torch.Size([<=25, 2])

您可以看到第一個維度取決於 in_tensor 的值,其最大值為 25。我們稱第一個維度為動態維度。第二個維度不取決於任何上游張量,因此我們稱之為靜態維度。

動態形狀可以進一步分為有界動態形狀和無界動態形狀。

  • 有界動態形狀:指的是其動態維度受靜態值限制的形狀。它適用於需要靜態記憶體配置的加速器(例如 TPU)。

  • 無界動態形狀:指的是其動態維度可以無限大的形狀。它適用於不需要靜態記憶體配置的加速器(例如 GPU)。

目前,僅支援有界動態形狀,且處於實驗階段。

有界動態形狀

目前,我們在 TPU 上支援具有動態大小輸入的多層感知器模型(MLP)。

此功能由標誌 XLA_EXPERIMENTAL="nonzero:masked_select" 控制。若要啟用此功能執行模型,您可以執行

XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python your_scripts.py

以下是我們在執行 MLP 模型 100 次迭代時獲得的一些數字

無動態形狀

具有動態形狀

端對端訓練時間

29.49

20.03

編譯次數

102

49

編譯快取命中率

198

1953

Performance comparison (a) without dynamic shape  (b) with dynamic shape

動態形狀的動機之一是減少當形狀在迭代之間不斷變化時,過度重新編譯的次數。從上圖您可以看到編譯次數減少了一半,從而減少了訓練時間。

若要試用,請執行

XLA_EXPERIMENTAL="nonzero:masked_select" PJRT_DEVICE=TPU python3 pytorch/xla/test/ds/test_dynamic_shape_models.py TestDynamicShapeModels.test_backward_pass_with_dynamic_input

如需更多關於我們未來計劃如何在 PyTorch/XLA 上擴展動態形狀支援的詳細資訊,請隨時查看我們的 RFC

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源