快捷方式

TokenizedDatasetLoader

class torchrl.data.TokenizedDatasetLoader(split, max_length, dataset_name, tokenizer_fn: Type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, valid_size: int = 2000, num_workers: Optional[int] = None, tokenizer_class=None, tokenizer_model_name=None)[來源]

載入已 Token 化的資料集,並快取其記憶體映射副本。

參數:
  • split (str) – "train""valid" 之一。

  • max_length (int) – 最大序列長度。

  • dataset_name (str) – 資料集的名稱。

  • tokenizer_fn (callable) – Token 化方法建構子,例如 torchrl.data.rlhf.TensorDictTokenizer。呼叫時,它應該傳回一個 tensordict.TensorDict 實例或具有 Token 化資料的類似字典的結構。

  • pre_tokenization_hook (callable, optional) – 在 Token 化之前對 Dataset 執行的函式。它應該回傳一個修改過的 Dataset 物件。其用途是用於執行需要修改整個資料集(而不是修改個別資料點)的任務,例如根據特定條件捨棄某些資料點。資料的 Token 化和其他「逐元素」操作由 process 函式執行,該函式會映射到資料集。

  • root_dir (path, optional) – 儲存資料集的路徑。預設為 "$HOME/.cache/torchrl/data"

  • from_disk (bool, optional) – 如果為 True,將會使用 datasets.load_from_disk()。否則,將會使用 datasets.load_dataset()。預設為 False

  • valid_size (int, optional) – 驗證資料集的大小(如果 split 以 "valid" 開頭),將會截斷為此值。預設為 2000 個項目。

  • num_workers (int, optional) – datasets.dataset.map() 的 worker 數量,在 Token 化期間會呼叫此函式。預設為 max(os.cpu_count() // 2, 1)

  • tokenizer_class (type, optional) – Tokenizer 類別,例如 AutoTokenizer(預設)。

  • tokenizer_model_name (str, optional) – 應該從哪個模型收集詞彙表。預設為 "gpt2"

資料集將儲存在 <root_dir>/<split>/<max_length>/ 中。

範例

>>> from torchrl.data.rlhf import TensorDictTokenizer
>>> from torchrl.data.rlhf.reward import  pre_tokenization_hook
>>> split = "train"
>>> max_length = 550
>>> dataset_name = "CarperAI/openai_summarize_comparisons"
>>> loader = TokenizedDatasetLoader(
...     split,
...     max_length,
...     dataset_name,
...     TensorDictTokenizer,
...     pre_tokenization_hook=pre_tokenization_hook,
... )
>>> dataset = loader.load()
>>> print(dataset)
TensorDict(
    fields={
        attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([185068]),
    device=None,
    is_shared=False)
static dataset_to_tensordict(dataset: 'datasets.Dataset' | TensorDict, data_dir: Path, prefix: NestedKey = None, features: Sequence[str] = None, batch_dims=1, valid_mask_key=None)[source]

將資料集轉換為記憶體映射的 TensorDict。

如果資料集已經是 TensorDict 實例,則會簡單地將其轉換為記憶體映射的 TensorDict。 否則,預期資料集具有 features 屬性,該屬性是一個字串序列,指示可以在資料集中找到的特徵。 如果沒有,則必須將 features 明確傳遞給此函式。

參數:
  • dataset (datasets.Dataset, TensorDict or equivalent) – 要轉換為記憶體映射的 TensorDict 的資料集。如果 featuresNone,則它必須具有 features 屬性,其中包含要在 tensordict 中寫入的索引鍵清單。

  • data_dir (Path or equivalent) – 應該寫入資料的目錄。

  • prefix (NestedKey, optional) – 資料集位置的前綴。這可以用於區分已進行不同預處理的相同資料集的多個副本。

  • features (sequence of str, optional) – 一個字串序列,指示可以在資料集中找到的特徵。

  • batch_dims (int, optional) – 資料的批次維度數量(即可以索引 tensordict 的維度數量)。預設為 1。

  • valid_mask_key (NestedKey, optional) – 如果提供,將嘗試收集此項目並將其用於過濾資料。預設為 None (即,沒有過濾器鍵)。

回傳:一個包含具有資料集的記憶體映射張量的 TensorDict。

範例

>>> from datasets import Dataset
>>> import tempfile
>>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
...     )
...     print(data_memmap)
TensorDict(
    fields={
        some: TensorDict(
            fields={
                prefix: TensorDict(
                    fields={
                        labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                        tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
load()[source]

如果存在,則載入預先處理過的記憶體映射資料集,否則建立它。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源