torch.chain_matmul¶
- torch.chain_matmul(*matrices, out=None)[source][source]¶
傳回 個 2 維張量的矩陣乘積。此乘積使用矩陣鏈順序演算法有效率地計算,該演算法選擇算術運算成本最低的順序 ([CLRS])。請注意,由於這是一個計算乘積的函數, 需要大於或等於 2;如果等於 2,則會傳回一個簡單的矩陣-矩陣乘積。如果 為 1,則這是一個無操作 - 原始矩陣會原樣傳回。
警告
torch.chain_matmul()
已棄用,將在未來的 PyTorch 版本中移除。請改用torch.linalg.multi_dot()
,它接受兩個或更多張量的列表,而不是多個引數。- 參數
matrices (Tensors...) – 2 個或更多 2 維張量的序列,將決定其乘積。
out (Tensor, optional) – 輸出張量。如果
out
=None
,則忽略。
- 傳回
如果第 張量的維度為 ,則乘積的維度將為 。
- 傳回類型
範例
>>> a = torch.randn(3, 4) >>> b = torch.randn(4, 5) >>> c = torch.randn(5, 6) >>> d = torch.randn(6, 7) >>> # will raise a deprecation warning >>> torch.chain_matmul(a, b, c, d) tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614], [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163], [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]])