快捷方式

torch.unravel_index

torch.unravel_index(indices, shape)[原始碼][原始碼]

將扁平索引的張量轉換為座標張量的元組,這些座標張量用於索引到指定形狀的任意張量中。

參數
  • indices (Tensor) – 一個整數張量,包含任意張量(形狀為 shape)展平後的索引。所有元素必須在範圍 [0, prod(shape) - 1] 內。

  • shape (int, 整數序列, 或 torch.Size) – 任意張量的形狀。所有元素必須為非負數。

返回

輸出中的每個第 i 個張量對應於 shape 的第 i 個維度。每個張量都具有與 indices 相同的形狀,並包含每個由 indices 給出的扁平索引在第 i 個維度中的一個索引。

返回類型

tuple of Tensors(張量元組)

範例

>>> import torch
>>> torch.unravel_index(torch.tensor(4), (3, 2))
(tensor(2),
 tensor(0))

>>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
(tensor([2, 0]),
 tensor([0, 1]))

>>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
(tensor([0, 0, 1, 1, 2, 2]),
 tensor([0, 1, 0, 1, 0, 1]))

>>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
(tensor([1, 5]),
 tensor([2, 6]),
 tensor([3, 7]),
 tensor([4, 8]))

>>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
(tensor([[1], [5]]),
 tensor([[2], [6]]),
 tensor([[3], [7]]),
 tensor([[4], [8]]))

>>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
(tensor([[12], [56]]),
 tensor([[34], [78]]))

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源