捷徑

PromptTensorDictTokenizer

class torchrl.data.PromptTensorDictTokenizer(tokenizer, max_length, key='prompt', padding='max_length', truncation=True, return_tensordict=True, device=None)[source]

提示資料集的 Tokenization 配方。

傳回一個 tokenizer 函數,該函數讀取包含提示和標籤的範例並對其進行 tokenization。

參數:
  • tokenizer (transformers 函式庫中的 tokenizer) – 要使用的 tokenizer。

  • max_length (int) – 序列的最大長度。

  • key (str, 可選) – 尋找文字的鍵。預設為 "prompt"

  • padding (str, 可選) – 填充類型。預設為 "max_length"

  • truncation (bool, 可選) – 序列是否應截斷為 max_length。

  • return_tensordict (bool, 可選) – 如果 True,則傳回 TensoDict。 否則,將傳回原始資料。

  • device (torch.device, 可選) – 儲存資料的裝置。 如果 return_tensordict=False,則忽略此選項。

這個類別的 __call__() 方法將會執行以下操作:

  • 讀取 prompt 字串,並將其與 label 字串連接,然後對它們進行 Token 化。結果會儲存在 "input_ids" TensorDict 條目中。

  • 寫入一個 "prompt_rindex" 條目,其值為 prompt 中最後一個有效 token 的索引。

  • 寫入一個 "valid_sample" 條目,用於識別 tensordict 中哪個條目有足夠的 tokens 來滿足 max_length 標準。

  • 返回一個帶有 Token 化輸入的 tensordict.TensorDict 實例。

tensordict 的批次大小將與輸入的批次大小相符。

範例

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token
>>> example = {
...     "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
...     "label": ["Indeed it is.", 'It might as well be.'],
... }
>>> fn = PromptTensorDictTokenizer(tokenizer, 50)
>>> print(fn(example))
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
        valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和高級開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源