注意
前往結尾以下載完整的範例程式碼。
操作 TensorDict 的鍵值¶
作者: Tom Begley
在本教學中,您將學習如何使用和操作 TensorDict
中的鍵值,包括取得和設定鍵值、迭代鍵值、操作巢狀值以及展平鍵值。
設定和取得鍵值¶
我們可以使用與 Python dict
相同的語法來設定和取得鍵值。
import torch
from tensordict.tensordict import TensorDict
tensordict = TensorDict()
# set a key
a = torch.rand(10)
tensordict["a"] = a
# retrieve the value stored under "a"
assert tensordict["a"] is a
注意
與 Python dict
不同,TensorDict
中的所有鍵值都必須是字串。但是,正如我們將看到的,也可以使用字串的元組來操作巢狀值。
我們也可以使用 .get()
和 .set
方法來完成相同的事情。
tensordict = TensorDict()
# set a key
a = torch.rand(10)
tensordict.set("a", a)
# retrieve the value stored under "a"
assert tensordict.get("a") is a
與 dict
一樣,我們可以為 get
提供一個預設值,如果在找不到請求的鍵值時應傳回該值。
同樣地,與 dict
一樣,我們可以利用 TensorDict.setdefault()
取得特定鍵的值,如果在找不到該鍵時傳回預設值,並在 TensorDict
中設定該值。
刪除鍵值的方法也與 Python dict
相同,使用 del
語句和選取的鍵。等效地,我們可以使用 TensorDict.del_
方法。
del tensordict["banana"]
此外,當使用 .set()
設定鍵值時,我們可以使用關鍵字引數 inplace=True
進行原地更新,或等效地使用 .set_()
方法。
tensordict.set("a", torch.zeros(10), inplace=True)
# all the entries of the "a" tensor are now zero
assert (tensordict.get("a") == 0).all()
# but it's still the same tensor as before
assert tensordict.get("a") is a
# we can achieve the same with set_
tensordict.set_("a", torch.ones(10))
assert (tensordict.get("a") == 1).all()
assert tensordict.get("a") is a
重新命名鍵值¶
若要重新命名鍵值,只需使用 TensorDict.rename_key_
方法。儲存在原始鍵值下的值將保留在 TensorDict
中,但鍵值將變更為指定的新鍵值。
tensordict.rename_key_("a", "b")
assert tensordict.get("b") is a
print(tensordict)
TensorDict(
fields={
b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
更新多個值¶
TensorDict.update
方法可用於使用另一個 TensorDict`
或 dict
來更新 TensorDict`
。已經存在的鍵會被覆寫,而尚未存在的鍵則會被建立。
tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
assert (tensordict["a"] == 0).all()
assert (tensordict["b"] != 0).all()
assert (tensordict["c"] == 0).all()
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10]),
device=None,
is_shared=False)
巢狀值¶
TensorDict
的值本身可以是 TensorDict
。 我們可以在實例化期間添加巢狀值,方法是直接添加 TensorDict
,或者使用巢狀字典
# creating nested values with a nested dict
nested_tensordict = TensorDict(
{"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
)
# creating nested values with a TensorDict
tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
要存取這些巢狀值,我們可以使用字串的元組。例如:
double_nested_a = tensordict["nested", "double_nested", "a"]
nested_a = tensordict.get(("nested", "a"))
類似地,我們可以使用字串的元組來設定巢狀值
tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
tensordict.set(("nested", "b"), torch.rand(2, 3))
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
迭代 TensorDict 的內容¶
我們可以使用 .keys()
方法來迭代 TensorDict
的鍵。
a
nested
預設情況下,這只會迭代 TensorDict
中的頂層鍵,但是可以使用關鍵字參數 include_nested=True
遞迴地迭代 TensorDict
中的所有鍵。 這將遞迴地迭代任何巢狀 TensorDict 中的所有鍵,並將巢狀鍵作為字串的元組返回。
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'double_nested')
('nested', 'b')
nested
如果您只想迭代對應於 Tensor
值的鍵,您可以額外指定 leaves_only=True
。
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'b')
很像 dict
,也有 .values
和 .items
方法,它們接受相同的關鍵字參數。
a is a Tensor
nested is a TensorDict
('nested', 'a') is a Tensor
('nested', 'double_nested') is a TensorDict
('nested', 'double_nested', 'a') is a Tensor
('nested', 'double_nested', 'b') is a Tensor
('nested', 'b') is a Tensor
檢查鍵是否存在¶
要檢查 TensorDict
中是否存在鍵,請結合 .keys()
使用 in
運算符。
注意
執行 key in tensordict.keys()
可有效地查找鍵(在巢狀情況下,遞迴地在每個層級),因此當 TensorDict
中存在大量鍵時,效能不會受到負面影響。
assert "a" in tensordict.keys()
# to check for nested keys, set include_nested=True
assert ("nested", "a") in tensordict.keys(include_nested=True)
assert ("nested", "banana") not in tensordict.keys(include_nested=True)
扁平化和取消扁平化巢狀鍵¶
我們可以使用 .flatten_keys()
方法來扁平化具有巢狀值的 TensorDict
。
print(tensordict, end="\n\n")
print(tensordict.flatten_keys(separator="."))
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
給定一個已被扁平化的 TensorDict
,可以使用 .unflatten_keys()
方法再次將其取消扁平化。
flattened_tensordict = tensordict.flatten_keys(separator=".")
print(flattened_tensordict, end="\n\n")
print(flattened_tensordict.unflatten_keys(separator="."))
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
當操作 torch.nn.Module
的參數時,這特別有用,因為我們最終可能會得到一個 TensorDict
,其結構模仿模組結構。
import torch.nn as nn
module = nn.Sequential(
nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
nn.Linear(10, 1),
)
params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()
print(params)
TensorDict(
fields={
0: TensorDict(
fields={
0: TensorDict(
fields={
bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
1: TensorDict(
fields={
bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
1: TensorDict(
fields={
bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
選擇和排除鍵¶
我們可以使用 TensorDict.select
獲得一個具有鍵子集的新 TensorDict
,它返回一個僅包含指定鍵的新 TensorDict
,或者使用 :meth:`TensorDict.exclude <tensordict.TensorDict.exclude>`,它返回一個省略指定鍵的新 TensorDict
。
print("Select:")
print(tensordict.select("a", ("nested", "a")), end="\n\n")
print("Exclude:")
print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))
Select:
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Exclude:
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
腳本的總執行時間:(0 分鐘 0.009 秒)