dense_stack_tds¶
- class tensordict.dense_stack_tds(td_list: Union[Sequence[TensorDictBase], LazyStackedTensorDict], dim: Optional[int] = None)¶
密集地堆疊
TensorDictBase
物件的列表 (或LazyStackedTensorDict
),前提是它們具有相同的結構。此函數使用
TensorDictBase
的列表 (直接傳遞或從LazyStackedTensorDict
取得) 呼叫。 此函數不會呼叫torch.stack(td_list)
,而是會傳回LazyStackedTensorDict
,此函數會展開輸入列表的第一個元素,並將輸入列表堆疊到該元素上。 只有當輸入列表的所有元素都具有相同的結構時,此方法才有效。 傳回的TensorDictBase
將具有與輸入列表元素相同的類型。當需要堆疊的某些
TensorDictBase
物件是LazyStackedTensorDict
或在條目 (或巢狀條目) 中包含LazyStackedTensorDict
時,此函數很有用。 在這些情況下,呼叫torch.stack(td_list).to_tensordict()
是不可行的。 因此,此函數提供了一種密集堆疊所提供列表的替代方案。- 參數:
td_list (TensorDictBase 的列表 或 LazyStackedTensorDict) – 要堆疊的 tds。
dim (int, 選用) – 堆疊它們的維度。 如果 td_list 是 LazyStackedTensorDict,它將自動被檢索。
範例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict import dense_stack_tds >>> from tensordict.tensordict import assert_allclose_td >>> td0 = TensorDict({"a": torch.zeros(3)},[]) >>> td1 = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)},[]) >>> td_lazy = torch.stack([td0, td1], dim=0) >>> td_container = TensorDict({"lazy": td_lazy}, []) >>> td_container_clone = td_container.clone() >>> td_stack = torch.stack([td_container, td_container_clone], dim=0) >>> td_stack LazyStackedTensorDict( fields={ lazy: LazyStackedTensorDict( fields={ a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([2, 2]), device=None, is_shared=False, stack_dim=0)}, exclusive_fields={ }, batch_size=torch.Size([2]), device=None, is_shared=False, stack_dim=0) >>> td_stack = dense_stack_tds(td_stack) # Automatically use the LazyStackedTensorDict stack_dim TensorDict( fields={ lazy: LazyStackedTensorDict( fields={ a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)}, exclusive_fields={ 1 -> b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2, 2]), device=None, is_shared=False, stack_dim=1)}, batch_size=torch.Size([2]), device=None, is_shared=False) # Note that # (1) td_stack is now a TensorDict # (2) this has pushed the stack_dim of "lazy" (0 -> 1) # (3) this has revealed the exclusive keys. >>> assert_allclose_td(td_stack, dense_stack_tds([td_container, td_container_clone], dim=0)) # This shows it is the same to pass a list or a LazyStackedTensorDict