torch.bmm¶
- torch.bmm(input, mat2, *, out=None) Tensor ¶
對儲存在
input
和mat2
中的矩陣執行批次矩陣相乘。input
和mat2
必須是 3 維張量,且各自包含相同數量的矩陣。如果
input
是 張量,mat2
是 張量,則out
將會是 張量。此運算子支援 TensorFloat32。
在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將使用 不同的精度 進行反向傳播。
註記
此函數不進行 廣播。如需廣播矩陣乘法,請參閱
torch.matmul()
。範例
>>> input = torch.randn(10, 3, 4) >>> mat2 = torch.randn(10, 4, 5) >>> res = torch.bmm(input, mat2) >>> res.size() torch.Size([10, 3, 5])