PromptTensorDictTokenizer¶
- class torchrl.data.PromptTensorDictTokenizer(tokenizer, max_length, key='prompt', padding='max_length', truncation=True, return_tensordict=True, device=None)[source]¶
提示資料集的 Tokenization 配方。
傳回一個 tokenizer 函數,該函數讀取包含提示和標籤的範例並對其進行 tokenization。
- 參數:
tokenizer (transformers 函式庫中的 tokenizer) – 要使用的 tokenizer。
max_length (int) – 序列的最大長度。
key (str, 可選) – 尋找文字的鍵。預設為
"prompt"
。padding (str, 可選) – 填充類型。預設為
"max_length"
。truncation (bool, 可選) – 序列是否應截斷為 max_length。
return_tensordict (bool, 可選) – 如果
True
,則傳回 TensoDict。 否則,將傳回原始資料。device (torch.device, 可選) – 儲存資料的裝置。 如果
return_tensordict=False
,則忽略此選項。
這個類別的
__call__()
方法將會執行以下操作:讀取
prompt
字串,並將其與label
字串連接,然後對它們進行 Token 化。結果會儲存在"input_ids"
TensorDict 條目中。寫入一個
"prompt_rindex"
條目,其值為 prompt 中最後一個有效 token 的索引。寫入一個
"valid_sample"
條目,用於識別 tensordict 中哪個條目有足夠的 tokens 來滿足max_length
標準。返回一個帶有 Token 化輸入的
tensordict.TensorDict
實例。
tensordict 的批次大小將與輸入的批次大小相符。
範例
>>> from transformers import AutoTokenizer >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> tokenizer.pad_token = tokenizer.eos_token >>> example = { ... "prompt": ["This prompt is long enough to be tokenized.", "this one too!"], ... "label": ["Indeed it is.", 'It might as well be.'], ... } >>> fn = PromptTensorDictTokenizer(tokenizer, 50) >>> print(fn(example)) TensorDict( fields={ attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False), input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False), prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False), valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False)