捷徑

torchrec.quant

Torchrec 量化

Torchrec 提供用於推論的 EmbeddingBagCollection 量化版本。 它依賴於 fbgemm 量化運算。 這減少了模型權重的規模並加快了模型執行速度。

範例

>>> import torch.quantization as quant
>>> import torchrec.quant as trec_quant
>>> import torchrec as trec
>>> qconfig = quant.QConfig(
>>>     activation=quant.PlaceholderObserver,
>>>     weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
>>> )
>>> quantized = quant.quantize_dynamic(
>>>     module,
>>>     qconfig_spec={
>>>         trec.EmbeddingBagCollection: qconfig,
>>>     },
>>>     mapping={
>>>         trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
>>>     },
>>>     inplace=inplace,
>>> )

torchrec.quant.embedding_modules

class torchrec.quant.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)

基底: EmbeddingBagCollectionInterface, ModuleNoCopyMixin

EmbeddingBagCollection 表示一個池化嵌入(EmbeddingBags)的集合。 此 EmbeddingBagCollection 已量化為較低精度。 它依賴於 fbgemm 量化運算並提供表批次處理。

注意

EmbeddingBagCollection 是一個未分片的模組,並未針對效能進行最佳化。 對於效能敏感的情況,請考慮使用分片版本 ShardedEmbeddingBagCollection。

它以 KeyedJaggedTensor 的形式處理稀疏資料,其值的格式為 [F X B X L] F:特徵(鍵) B:批次大小 L:稀疏特徵的長度(參差不齊)

並輸出一個 KeyedTensor,其值的格式為 [B * (F * D)],其中 F:特徵(鍵) D:每個特徵(鍵)的嵌入維度 B:批次大小

參數:
  • table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]) – 表與量化權重的映射

  • embedding_configs (List[EmbeddingBagConfig]) – 嵌入表清單

  • is_weighted – (bool):輸入 KeyedJaggedTensor 是否為加權的

  • device – (Optional[torch.device]):預設計算裝置

呼叫參數

features: KeyedJaggedTensor,

回傳值:

KeyedTensor

範例

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(
        dtype=torch.qint8
    ),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qebc = QuantEmbeddingBagCollection.from_float(ebc)
quantized_embeddings = qebc(features)
property device: device
embedding_bag_configs() List[EmbeddingBagConfig]
forward(features: KeyedJaggedTensor) KeyedTensor
參數:

features (KeyedJaggedTensor) – 格式為 [F X B X L] 的 KJT。

回傳值:

KeyedTensor

classmethod from_float(module: EmbeddingBagCollection, use_precomputed_fake_quant: bool = False) EmbeddingBagCollection
is_weighted() bool
output_dtype() dtype
training: bool
class torchrec.quant.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: device, need_indices: bool = False, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)

基底: EmbeddingCollectionInterface, ModuleNoCopyMixin

EmbeddingCollection 代表非池化嵌入的集合。

注意

EmbeddingCollection 是一個未分片的模組,並且未針對效能進行最佳化。對於效能要求嚴格的情況,請考慮使用分片版本 ShardedEmbeddingCollection。

它以 [F X B X L] 格式處理 KeyedJaggedTensor 形式的稀疏資料,其中

  • F:特徵(鍵值)

  • B:批次大小

  • L:稀疏特徵的長度(可變)

並輸出 Dict[feature (key), JaggedTensor]。每個 JaggedTensor 包含 (B * L) X D 格式的值,其中

  • B:批次大小

  • L:稀疏特徵的長度(參差不齊)

  • D:每個特徵(鍵值)的嵌入維度,長度為 L 格式

參數:
  • tables (List[EmbeddingConfig]) – 嵌入表的清單。

  • device (Optional[torch.device]) – 預設計算裝置。

  • need_indices (bool) – 如果需要將索引傳遞到最終的查找結果字典

範例

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
[ 0.7352,  0.3210, -3.0399],
[ 0.1279, -0.1756, -0.4130],
[ 0.7519, -0.4341, -0.0499],
[ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
property device: device
embedding_configs() List[EmbeddingConfig]
embedding_dim() int
embedding_names_by_table() List[List[str]]
forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]
參數:

features (KeyedJaggedTensor) – 格式為 [F X B X L] 的 KJT。

回傳值:

Dict[str, JaggedTensor]

classmethod from_float(module: EmbeddingCollection, use_precomputed_fake_quant: bool = False) EmbeddingCollection
need_indices() bool
output_dtype() dtype
training: bool
class torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16, feature_processor: Optional[FeatureProcessorsCollection] = None)

基底: EmbeddingBagCollection

embedding_bags: nn.ModuleDict
forward(features: KeyedJaggedTensor) KeyedTensor
參數:

features (KeyedJaggedTensor) – 格式為 [F X B X L] 的 KJT。

回傳值:

KeyedTensor

classmethod from_float(module: FeatureProcessedEmbeddingBagCollection, use_precomputed_fake_quant: bool = False) FeatureProcessedEmbeddingBagCollection
tbes: torch.nn.ModuleList
training: bool
torchrec.quant.embedding_modules.for_each_module_of_type_do(module: Module, module_types: List[Type[Module]], op: Callable[[Module], None]) None
torchrec.quant.embedding_modules.pruned_num_embeddings(pruning_indices_mapping: Tensor) int
torchrec.quant.embedding_modules.quant_prep_customize_row_alignment(module: Module, module_types: List[Type[Module]], row_alignment: int) None
torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias(module: Module) None
torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias_for_types(module: Module, module_types: List[Type[Module]]) None
torchrec.quant.embedding_modules.quant_prep_enable_register_tbes(module: Module, module_types: List[Type[Module]]) None
torchrec.quant.embedding_modules.quantize_state_dict(module: Module, table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]], table_name_to_data_type: Dict[str, DataType], table_name_to_pruning_indices_mapping: Optional[Dict[str, Tensor]] = None) device

模組內容

Torchrec 量化

Torchrec 提供用於推論的 EmbeddingBagCollection 量化版本。 它依賴於 fbgemm 量化運算。 這減少了模型權重的規模並加快了模型執行速度。

範例

>>> import torch.quantization as quant
>>> import torchrec.quant as trec_quant
>>> import torchrec as trec
>>> qconfig = quant.QConfig(
>>>     activation=quant.PlaceholderObserver,
>>>     weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
>>> )
>>> quantized = quant.quantize_dynamic(
>>>     module,
>>>     qconfig_spec={
>>>         trec.EmbeddingBagCollection: qconfig,
>>>     },
>>>     mapping={
>>>         trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
>>>     },
>>>     inplace=inplace,
>>> )

文件

取得 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並取得問題解答

查看資源