get_dataloader¶
- class torchrl.data.get_dataloader(batch_size: int, block_size: int, tensorclass_type: Type, device: device, dataset_name: Optional[str] = None, infinite: bool = True, prefetch: int = 0, split: str = 'train', root_dir: Optional[str] = None, from_disk: bool = False, num_workers: Optional[int] = None)[原始碼]¶
建立一個資料集並從中返回一個資料載入器。
- 參數:
batch_size (int) – 資料載入器樣本的批次大小。
block_size (int) – 資料載入器中序列的最大長度。
tensorclass_type (tensorclass class) – 一個 tensorclass 類別,必須具有
from_dataset()
方法,且該方法必須接受三個關鍵字引數:split
(見下方)、用於訓練的區塊大小max_length
,以及指示資料集的字串dataset_name
。 此外,也應支援root_dir
和from_disk
引數。device (torch.device 或 等效物件) – 樣本應轉換到的裝置。
dataset_name (str, optional) – 資料集名稱。 如果未提供,且 tensorclass 支援,將會為正在使用的 tensorclass 收集預設資料集名稱。
infinite (bool, optional) – 如果為
True
,則迭代將是無限的,因此next(iterator)
將永遠返回一個值。 預設為True
。prefetch (int, optional) – 如果使用多執行緒資料載入,要預先提取的項目數量。
split (str, optional) – 資料分割。 可以是
"train"
或"valid"
。 預設為"train"
。root_dir (path, optional) – 儲存資料集的路徑。 預設為
"$HOME/.cache/torchrl/data"
。from_disk (bool, optional) – 如果為
True
,將使用datasets.load_from_disk()
。 否則,將使用datasets.load_dataset()
。 預設為False
。num_workers (int, optional) – 在 Token 化期間呼叫的
datasets.dataset.map()
的 worker 數量。 預設為max(os.cpu_count() // 2, 1)
。
範例
>>> from torchrl.data.rlhf.reward import PairwiseDataset >>> dataloader = get_dataloader( ... batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu") >>> for d in dataloader: ... print(d) ... break PairwiseDataset( chosen_data=RewardData( attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([256]), device=cpu, is_shared=False), rejected_data=RewardData( attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([256]), device=cpu, is_shared=False), batch_size=torch.Size([256]), device=cpu, is_shared=False)