isin¶
- class tensordict.utils.isin(input: TensorDictBase, reference: TensorDictBase, key: NestedKey, dim: int = 0)¶
測試 input 中 `key` 在 `dim` 維度上的每個元素是否也存在於 reference 中。
此函式會傳回一個長度為 `input.batch_size[dim]` 的布林張量,若 `key` 條目中的元素也存在於 `reference` 中,則該元素為 `True`。此函式假設 `input` 和 `reference` 具有相同的批次大小並包含指定的條目,否則會引發錯誤。
- 參數:
input (TensorDictBase) – 輸入 TensorDict。
reference (TensorDictBase) – 用於測試的目標 TensorDict。
key (Nestedkey) – 要測試的鍵。
dim (int, optional) – 要測試的維度。預設為 `0`。
- 傳回:
- 長度為 `input.batch_size[dim]` 的布林張量,若
`input` 中 `key` 張量的元素也存在於 `reference` 中,則該元素為 `True`。
- 傳回型別:
out (Tensor)
範例
>>> td = TensorDict( ... { ... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]), ... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]), ... }, ... batch_size=[4], ... ) >>> td_ref = TensorDict( ... { ... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]), ... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]), ... }, ... batch_size=[3], ... ) >>> in_reference = isin(td, td_ref, key="tensor1") >>> expected_in_reference = torch.tensor([True, True, True, False]) >>> torch.testing.assert_close(in_reference, expected_in_reference)