merge_tensordicts¶
- class tensordict.merge_tensordicts(*tensordicts: T, callback_exist: Optional[Union[Callable[[Any], Any], Dict[NestedKey, Callable[[Any], Any]]]] = None)¶
將 tensordicts 合併在一起。
- 參數:
*tensordicts (TensorDict 序列 或 等效項目) – 要合併在一起的 tensordicts 列表。
- 關鍵字參數:
callback_exist (callable 或 Dict[str, callable], optional) – 在每個 tensordict 中都存在一個條目的情況下,使用可呼叫物件。 如果該條目存在於某些而非所有 tensordicts 中,或者如果未傳遞
callback_exist
,則使用 update,並且將使用 tensordict 序列中的第一個非None
值。 如果傳遞可呼叫物件的字典,則它將包含與傳遞給函數的 tensordicts 中某些巢狀鍵相關聯的回呼函數。
範例
>>> from tensordict import merge_tensordicts, TensorDict >>> td0 = TensorDict({"a": {"b0": 0}, "c": {"d": {"e": 0}}, "common": 0}) >>> td1 = TensorDict({"a": {"b1": 1}, "f": {"g": {"h": 1}}, "common": 1}) >>> td2 = TensorDict({"a": {"b2": 2}, "f": {"g": {"h": 2}}, "common": 2}) >>> td = merge_tensordicts(td0, td1, td2, callback_exist=lambda *v: torch.stack(list(v))) >>> print(td) TensorDict( fields={ a: TensorDict( fields={ b0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b2: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), c: TensorDict( fields={ d: TensorDict( fields={ e: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), common: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int64, is_shared=False), f: TensorDict( fields={ g: TensorDict( fields={ h: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(td["common"]) tensor([0, 1, 2])