torch.unravel_index¶
- torch.unravel_index(indices, shape)[原始碼][原始碼]¶
將扁平索引的張量轉換為座標張量的元組,這些座標張量用於索引到指定形狀的任意張量中。
- 參數
indices (Tensor) – 一個整數張量,包含任意張量(形狀為
shape
)展平後的索引。所有元素必須在範圍[0, prod(shape) - 1]
內。shape (int, 整數序列, 或 torch.Size) – 任意張量的形狀。所有元素必須為非負數。
- 返回
輸出中的每個第
i
個張量對應於shape
的第i
個維度。每個張量都具有與indices
相同的形狀,並包含每個由indices
給出的扁平索引在第i
個維度中的一個索引。- 返回類型
tuple of Tensors(張量元組)
範例
>>> import torch >>> torch.unravel_index(torch.tensor(4), (3, 2)) (tensor(2), tensor(0)) >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2)) (tensor([2, 0]), tensor([0, 1])) >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2)) (tensor([0, 0, 1, 1, 2, 2]), tensor([0, 1, 0, 1, 0, 1])) >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10)) (tensor([1, 5]), tensor([2, 6]), tensor([3, 7]), tensor([4, 8])) >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10)) (tensor([[1], [5]]), tensor([[2], [6]]), tensor([[3], [7]]), tensor([[4], [8]])) >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100)) (tensor([[12], [56]]), tensor([[34], [78]]))