快捷鍵

torch.mm

torch.mm(input, mat2, *, out=None) Tensor

執行矩陣 inputmat2 的矩陣乘法。

如果 input 是一個 (n×m)(n \times m) 張量,mat2 是一個 (m×p)(m \times p) 張量,則 out 將會是一個 (n×p)(n \times p) 張量。

注意

此函式不進行廣播 (broadcast)。若要進行廣播矩陣乘積,請參閱 torch.matmul()

支援跨步 (strided) 和稀疏 (sparse) 的 2 維張量作為輸入,並支援對跨步輸入進行自動微分 (autograd)。

此操作支援具有稀疏佈局 (sparse layouts)的引數。 如果提供了 out,則將使用其佈局。 否則,結果佈局將從 input 的佈局中推導得出。

警告

稀疏支援是一項 beta 功能,某些佈局/dtype/裝置組合可能不受支援,或者可能不支援自動微分。 如果您發現缺少功能,請開啟一個功能請求。

此運算子支援 TensorFloat32

在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將使用不同的精度進行反向傳播 (backward)。

參數
  • input (Tensor) – 作為矩陣乘法的第一個矩陣

  • mat2 (Tensor) – 作為矩陣乘法的第二個矩陣

關鍵字引數

out (Tensor, optional) – 輸出張量。

範例

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources