• 文件 >
  • Pooled Embedding 模組
捷徑

Pooled Embedding 模組

穩定版 API

class fbgemm_gpu.permute_pooled_embedding_modules.PermutePooledEmbeddings(embs_dims: List[int], permute: List[int], device: device | None = None)[source]

用於沿著特徵維度置換 Embedding 輸出的模組

Embedding 輸出張量包含批次中所有特徵的 Embedding 輸出。它以 2D 格式表示,其中列是批次大小維度,而欄是特徵 * Embedding 維度。沿著特徵維度置換本質上是沿著第二個維度 (dim 1) 置換。

範例

>>> import torch
>>> import fbgemm_gpu
>>> from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
>>>
>>> # Suppose batch size = 3 and there are 3 features
>>> batch_size = 3
>>>
>>> # Embedding dimensions for each feature
>>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
>>>
>>> # Permute list, i.e., move feature 2 to position 0, move feature 0
>>> # to position 1, so on
>>> permute = [2, 0, 1]
>>>
>>> # Instantiate the module
>>> perm = PermutePooledEmbeddings(embs_dims, permute)
>>>
>>> # Generate an example input
>>> pooled_embs = torch.arange(
>>>     embs_dims.sum().item() * batch_size,
>>>     dtype=torch.float32, device="cuda"
>>> ).reshape(batch_size, -1)
>>> print(pooled_embs)
>>>
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15.],
        [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
         30., 31.],
        [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
         46., 47.]], device='cuda:0')
>>>
>>> # Invoke
>>> perm(pooled_embs)
>>>
tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
          6.,  7.],
        [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
         22., 23.],
        [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
         38., 39.]], device='cuda:0')
參數:
  • embs_dims (List[int]) – 所有特徵的 Embedding 維度列表。長度 = 特徵數量

  • permute (List[int]) – 描述如何置換每個特徵的列表。permute[i] 是將特徵 permute[i] 置換到位置 i

  • device (Optional[torch.device] = None) – 在其上執行此模組的裝置

__call__(pooled_embs: Tensor) Tensor[source]

沿著特徵維度執行 pooled embedding 輸出置換

參數:

pooled_embs (Tensor) – 要置換的 Embedding 輸出。形狀為 (B_local, total_global_D),其中 B_local = 本機批次大小,而 total_global_D 是所有特徵(全域)的總 Embedding 維度

傳回:

已置換的 Embedding 輸出 (Tensor)。與 pooled_embs 形狀相同

其他 API

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得適合初學者和進階開發者的深度教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源