torch.combinations¶
- torch.combinations(input: Tensor, r: int = 2, with_replacement: bool = False) seq ¶
計算給定張量長度為 的組合。其行為類似於 Python 的 itertools.combinations,當 with_replacement 設定為 False 時;而當 with_replacement 設定為 True 時,則類似於 itertools.combinations_with_replacement。
- 參數
- 返回
一個張量,等同於將所有輸入張量轉換為列表,在這些列表上執行 itertools.combinations 或 itertools.combinations_with_replacement,最後將結果列表轉換為張量。
- 返回類型
範例
>>> a = [1, 2, 3] >>> list(itertools.combinations(a, r=2)) [(1, 2), (1, 3), (2, 3)] >>> list(itertools.combinations(a, r=3)) [(1, 2, 3)] >>> list(itertools.combinations_with_replacement(a, r=2)) [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] >>> tensor_a = torch.tensor(a) >>> torch.combinations(tensor_a) tensor([[1, 2], [1, 3], [2, 3]]) >>> torch.combinations(tensor_a, r=3) tensor([[1, 2, 3]]) >>> torch.combinations(tensor_a, with_replacement=True) tensor([[1, 1], [1, 2], [1, 3], [2, 2], [2, 3], [3, 3]])