捷徑

torch.set_float32_matmul_precision

torch.set_float32_matmul_precision(precision)[原始碼][原始碼]

設定 float32 矩陣乘法的內部精確度。

以較低的精確度執行 float32 矩陣乘法可以顯著提高效能,並且在某些程式中,精確度的損失影響可以忽略不計。

支援三種設定:

  • “highest”(最高):float32 矩陣乘法使用 float32 資料類型(24 位尾數,其中 23 位顯式儲存)進行內部計算。

  • “high”(高):如果可以使用適當的快速矩陣乘法演算法,float32 矩陣乘法可以使用 TensorFloat32 資料類型(10 位尾數顯式儲存),或將每個 float32 數字視為兩個 bfloat16 數字之和(大約 16 位尾數,其中 14 位顯式儲存)。 否則,float32 矩陣乘法的計算方式如同精確度為“highest”。 有關 bfloat16 方法的更多資訊,請參閱下文。

  • “medium”(中):如果可以使用內部使用該資料類型的快速矩陣乘法演算法,float32 矩陣乘法使用 bfloat16 資料類型(8 位尾數,其中 7 位顯式儲存)進行內部計算。 否則,float32 矩陣乘法的計算方式如同精確度為“high”。

當使用 “high” 精確度時,float32 乘法可能會使用基於 bfloat16 的演算法,該演算法比簡單地截斷為一些較小的尾數位數(例如,TensorFloat32 為 10,bfloat16 顯式儲存為 7)更複雜。 有關此演算法的完整描述,請參閱 [Henry2019]。 簡要解釋一下,第一步是意識到我們可以完美地將單個 float32 數字編碼為三個 bfloat16 數字之和(因為 float32 有 23 個尾數位數,而 bfloat16 有 7 個顯式儲存,並且兩者具有相同數量的指數位)。 這意味著兩個 float32 數字的乘積可以完全由九個 bfloat16 數字的乘積之和給出。 然後,我們可以透過刪除其中一些產品來交換準確性以獲得速度。 “high” 精確度演算法專門僅保留三個最重要的產品,這方便地排除了涉及任一輸入的最後 8 個尾數位數的所有產品。 這意味著我們可以將輸入表示為兩個 bfloat16 數字之和,而不是三個。 由於 bfloat16 融合乘加 (FMA) 指令通常比 float32 指令快 >10 倍,因此使用 bfloat16 精確度進行三次乘法和 2 次加法比使用 float32 精確度進行單次乘法更快。

Henry2019

http://arxiv.org/abs/1904.06376

注意

這不會更改 float32 矩陣乘法的輸出 dtype,它控制矩陣乘法的內部計算方式。

注意

這不會更改卷積運算的精確度。 其他標誌,例如 torch.backends.cudnn.allow_tf32,可能會控制卷積運算的精確度。

注意

此標誌目前僅影響一種原生裝置類型:CUDA。 如果設定了 “high” 或 “medium”,則在計算 float32 矩陣乘法時將使用 TensorFloat32 資料類型,相當於設定 torch.backends.cuda.matmul.allow_tf32 = True。 當設定 “highest”(預設)時,float32 資料類型用於內部計算,相當於設定 torch.backends.cuda.matmul.allow_tf32 = False

參數

precision (str) – 可以設定為 “highest”(預設)、“high” 或 “medium”(請參閱上文)。

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources