快捷方式

torch.chain_matmul

torch.chain_matmul(*matrices, out=None)[source][source]

傳回 NN 個 2 維張量的矩陣乘積。此乘積使用矩陣鏈順序演算法有效率地計算,該演算法選擇算術運算成本最低的順序 ([CLRS])。請注意,由於這是一個計算乘積的函數,NN 需要大於或等於 2;如果等於 2,則會傳回一個簡單的矩陣-矩陣乘積。如果 NN 為 1,則這是一個無操作 - 原始矩陣會原樣傳回。

警告

torch.chain_matmul() 已棄用,將在未來的 PyTorch 版本中移除。請改用 torch.linalg.multi_dot(),它接受兩個或更多張量的列表,而不是多個引數。

參數
  • matrices (Tensors...) – 2 個或更多 2 維張量的序列,將決定其乘積。

  • out (Tensor, optional) – 輸出張量。如果 out = None,則忽略。

傳回

如果第 ithi^{th} 張量的維度為 pi×pi+1p_{i} \times p_{i + 1},則乘積的維度將為 p1×pN+1p_{1} \times p_{N + 1}

傳回類型

Tensor

範例

>>> a = torch.randn(3, 4)
>>> b = torch.randn(4, 5)
>>> c = torch.randn(5, 6)
>>> d = torch.randn(6, 7)
>>> # will raise a deprecation warning
>>> torch.chain_matmul(a, b, c, d)
tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
        [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
        [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源