快捷方式

torch.bmm

torch.bmm(input, mat2, *, out=None) Tensor

對儲存在 inputmat2 中的矩陣執行批次矩陣相乘。

inputmat2 必須是 3 維張量,且各自包含相同數量的矩陣。

如果 input(b×n×m)(b \times n \times m) 張量,mat2(b×m×p)(b \times m \times p) 張量,則 out 將會是 (b×n×p)(b \times n \times p) 張量。

outi=inputi@mat2i\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i

此運算子支援 TensorFloat32

在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將使用 不同的精度 進行反向傳播。

註記

此函數不進行 廣播。如需廣播矩陣乘法,請參閱 torch.matmul()

參數
  • input (Tensor) – 要相乘的第一批矩陣

  • mat2 (Tensor) – 要相乘的第二批矩陣

關鍵字參數

out (Tensor, optional) – 輸出張量。

範例

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源