注意
前往結尾以下載完整的範例程式碼。
操作 TensorDict 的形狀¶
作者: Tom Begley
在本教學中,您將學習如何操作 TensorDict
及其內容的形狀。
當我們建立 TensorDict
時,我們會指定一個 batch_size
,它必須與 TensorDict
中所有條目的前導維度一致。由於我們保證所有條目都具有共同的維度,因此 TensorDict
能夠公開許多方法,我們可以使用這些方法來操作 TensorDict
及其內容的形狀。
import torch
from tensordict.tensordict import TensorDict
索引 TensorDict
¶
由於批次維度保證存在於所有條目上,我們可以隨意索引它們,並且 TensorDict
的每個條目都將以相同的方式進行索引。
a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
indexed_tensordict = tensordict[:2, 1]
assert indexed_tensordict["a"].shape == torch.Size([2])
assert indexed_tensordict["b"].shape == torch.Size([2, 5])
重塑 TensorDict
¶
TensorDict.reshape
的工作方式與 torch.Tensor.reshape()
完全相同。它沿著批次維度應用於 TensorDict
的所有內容 - 請注意下面範例中 b
的形狀。它也會更新 batch_size
屬性。
reshaped_tensordict = tensordict.reshape(-1)
assert reshaped_tensordict.batch_size == torch.Size([12])
assert reshaped_tensordict["a"].shape == torch.Size([12])
assert reshaped_tensordict["b"].shape == torch.Size([12, 5])
分割 TensorDict
¶
TensorDict.split
類似於 torch.Tensor.split()
。它會將 TensorDict
分割成多個區塊。每個區塊都是一個 TensorDict
,其結構與原始 TensorDict
相同,但其中的條目是原始 TensorDict
中相應條目的視圖。
chunks = tensordict.split([3, 1], dim=1)
assert chunks[0].batch_size == torch.Size([3, 3])
assert chunks[1].batch_size == torch.Size([3, 1])
torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])
注意
每當函數或方法接受 dim
參數時,負數維度會根據調用該函數或方法的 TensorDict
的 batch_size
進行相對解釋。特別是,如果存在具有不同批次大小的巢狀 TensorDict
值,則負數維度始終根據根節點的批次維度進行相對解釋。
>>> tensordict = TensorDict(
... {
... "a": torch.rand(3, 4),
... "nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5])
... },
... [3, 4],
... )
>>> # dim = -2 will be interpreted as the first dimension throughout, as the root
>>> # TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
>>> chunks = tensordict.split([2, 1], dim=-2)
>>> assert chunks[0].batch_size == torch.Size([2, 4])
>>> assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])
從這個範例中可以看到,TensorDict.split
方法的行為就像我們在呼叫之前,將 dim=-2
替換為 dim=tensordict.batch_dims - 2
一樣。
Unbind¶
TensorDict.unbind
類似於 torch.Tensor.unbind()
,並且在概念上類似於 TensorDict.split
。它會移除指定的維度,並傳回沿該維度的所有切片的 tuple
。
slices = tensordict.unbind(dim=1)
assert len(slices) == 4
assert all(s.batch_size == torch.Size([3]) for s in slices)
torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])
堆疊與串聯¶
TensorDict
可以與 torch.cat
和 torch.stack
結合使用。
堆疊 TensorDict
¶
堆疊可以延遲或連續地完成。延遲堆疊只是將 tensordict 列表呈現為 tensordict 堆疊。它允許使用者攜帶具有不同內容形狀、裝置或金鑰集的 tensordict 包。另一個優點是堆疊操作可能非常昂貴,如果只需要一小部分金鑰,延遲堆疊會比適當的堆疊快得多。它依賴於 LazyStackedTensorDict
類別。在這種情況下,只有在存取值時才會按需堆疊它們。
from tensordict import LazyStackedTensorDict
cloned_tensordict = tensordict.clone()
stacked_tensordict = LazyStackedTensorDict.lazy_stack(
[tensordict, cloned_tensordict], dim=0
)
print(stacked_tensordict)
# Previously, torch.stack was always returning a lazy stack. For consistency with
# the regular PyTorch API, this behaviour will soon be adapted to deliver only
# dense tensordicts. To control which behaviour you are relying on, you can use
# the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:
from tensordict.utils import set_lazy_legacy
with set_lazy_legacy(True): # old behaviour
lazy_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(lazy_stack, LazyStackedTensorDict)
with set_lazy_legacy(False): # new behaviour
dense_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(dense_stack, TensorDict)
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 3, 4]),
device=None,
is_shared=False,
stack_dim=0)
如果我們沿著堆疊維度索引 LazyStackedTensorDict
,我們會恢復原始 TensorDict
。
assert stacked_tensordict[0] is tensordict
assert stacked_tensordict[1] is cloned_tensordict
存取 LazyStackedTensorDict
中的金鑰會導致這些值被堆疊。如果金鑰對應於巢狀 TensorDict
,那麼我們將恢復另一個 LazyStackedTensorDict
。
assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])
注意
由於值是按需堆疊的,因此多次存取一個項目意味著它會被多次堆疊,這是低效率的。如果您需要多次存取堆疊的 TensorDict
中的值,您可能需要考慮將 LazyStackedTensorDict
轉換為連續的 TensorDict
,這可以使用 LazyStackedTensorDict.to_tensordict
或 LazyStackedTensorDict.contiguous
方法來完成。
>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)
>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)
在呼叫其中任一方法後,我們將擁有一個包含堆疊值的常規 TensorDict
,並且在存取值時不會執行額外的計算。
串聯 TensorDict
¶
串聯不是延遲執行的,而是直接在 torch.cat()
中呼叫 TensorDict
實例的列表,它只會回傳一個 TensorDict
,其條目是列表元素的串聯條目。
concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
assert isinstance(concatenated_tensordict, TensorDict)
assert concatenated_tensordict.batch_size == torch.Size([6, 4])
assert concatenated_tensordict["b"].shape == torch.Size([6, 4, 5])
擴展 TensorDict
¶
我們可以使用 TensorDict.expand
來擴展 TensorDict
的所有條目。
exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])
壓縮和取消壓縮 TensorDict
¶
我們可以使用 squeeze()
和 unsqueeze()
方法來壓縮或取消壓縮 TensorDict
的內容。
tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
squeezed_tensordict = tensordict.squeeze()
assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
print(squeezed_tensordict, end="\n\n")
unsqueezed_tensordict = tensordict.unsqueeze(-1)
assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
print(unsqueezed_tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 1, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 1, 4, 1]),
device=None,
is_shared=False)
注意
到目前為止,像 unsqueeze()
、squeeze()
、view()
、permute()
、transpose()
這樣的操作都會回傳這些操作的延遲版本(也就是,一個儲存原始 tensordict 的容器,並且每次存取 key 時都會套用操作)。這種行為在未來將會被棄用,並且可以透過 set_lazy_legacy()
函式來控制。
>>> with set_lazy_legacy(True):
... lazy_unsqueeze = tensordict.unsqueeze(0)
>>> with set_lazy_legacy(False):
... dense_unsqueeze = tensordict.unsqueeze(0)
請記住,這些方法始終只適用於批次維度。條目的任何非批次維度都不會受到影響。
tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
squeezed_tensordict = tensordict.squeeze()
# only one of the singleton dimensions is dropped as the other
# is not a batch dimension
assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])
檢視 TensorDict¶
TensorDict
也支援 view
。這會建立一個 _ViewedTensorDict
,當存取其內容時,會延遲地建立檢視。
tensordict = TensorDict({"a": torch.arange(12)}, [12])
# no views are created at this step
viewed_tensordict = tensordict.view((2, 3, 2))
# the view of "a" is created on-demand when we access it
assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])
置換批次維度¶
TensorDict.permute
方法可以用來置換批次維度,很像 torch.permute()
。非批次維度保持不變。
這個操作是延遲的,所以只有在我們嘗試存取條目時,才會置換批次維度。一如既往,如果您可能需要多次存取特定條目,請考慮轉換為 TensorDict
。
tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
# swap the batch dimensions
permuted_tensordict = tensordict.permute([1, 0])
assert permuted_tensordict["a"].shape == torch.Size([4, 3])
assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])
將 tensordict 用作裝飾器¶
對於一堆可逆的操作,tensordict 可以用作裝飾器。這些操作包括用於函數呼叫的 to_module()
、unlock_()
和 lock_()
或形狀操作,例如 view()
、permute()
transpose()
、squeeze()
和 unsqueeze()
。以下是一個使用 transpose
函式的快速範例。
tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
with tensordict.transpose(1, 0) as tdt:
tdt.set("c", torch.ones(4, 3)) # we have permuted the dims
# the ``"c"`` entry is now in the tensordict we used as decorator:
#
assert (tensordict.get("c") == 1).all()
在 TensorDict
中收集數值¶
TensorDict.gather
方法可以用來沿著批次維度進行索引,並將結果收集到單一維度中,很像 torch.gather()
。
index = torch.randint(4, (3, 4))
gathered_tensordict = tensordict.gather(dim=1, index=index)
print("index:\n", index, end="\n\n")
print("tensordict['a']:\n", tensordict["a"], end="\n\n")
print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")
index:
tensor([[2, 3, 2, 1],
[3, 3, 0, 0],
[3, 1, 1, 2]])
tensordict['a']:
tensor([[0.1814, 0.2808, 0.2381, 0.4003],
[0.1536, 0.0138, 0.4464, 0.6981],
[0.9308, 0.0727, 0.3552, 0.4791]])
gathered_tensordict['a']:
tensor([[0.2381, 0.4003, 0.2381, 0.2808],
[0.6981, 0.6981, 0.1536, 0.1536],
[0.4791, 0.0727, 0.0727, 0.3552]])
腳本的總運行時間: (0 分鐘 0.008 秒)