捷徑

get_dataloader

class torchrl.data.get_dataloader(batch_size: int, block_size: int, tensorclass_type: Type, device: device, dataset_name: Optional[str] = None, infinite: bool = True, prefetch: int = 0, split: str = 'train', root_dir: Optional[str] = None, from_disk: bool = False, num_workers: Optional[int] = None)[原始碼]

建立一個資料集並從中返回一個資料載入器。

參數:
  • batch_size (int) – 資料載入器樣本的批次大小。

  • block_size (int) – 資料載入器中序列的最大長度。

  • tensorclass_type (tensorclass class) – 一個 tensorclass 類別,必須具有 from_dataset() 方法,且該方法必須接受三個關鍵字引數:split(見下方)、用於訓練的區塊大小 max_length,以及指示資料集的字串 dataset_name。 此外,也應支援 root_dirfrom_disk 引數。

  • device (torch.device等效物件) – 樣本應轉換到的裝置。

  • dataset_name (str, optional) – 資料集名稱。 如果未提供,且 tensorclass 支援,將會為正在使用的 tensorclass 收集預設資料集名稱。

  • infinite (bool, optional) – 如果為 True,則迭代將是無限的,因此 next(iterator) 將永遠返回一個值。 預設為 True

  • prefetch (int, optional) – 如果使用多執行緒資料載入,要預先提取的項目數量。

  • split (str, optional) – 資料分割。 可以是 "train""valid"。 預設為 "train"

  • root_dir (path, optional) – 儲存資料集的路徑。 預設為 "$HOME/.cache/torchrl/data"

  • from_disk (bool, optional) – 如果為 True,將使用 datasets.load_from_disk()。 否則,將使用 datasets.load_dataset()。 預設為 False

  • num_workers (int, optional) – 在 Token 化期間呼叫的 datasets.dataset.map() 的 worker 數量。 預設為 max(os.cpu_count() // 2, 1)

範例

>>> from torchrl.data.rlhf.reward import PairwiseDataset
>>> dataloader = get_dataloader(
...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
>>> for d in dataloader:
...     print(d)
...     break
PairwiseDataset(
    chosen_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    rejected_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    batch_size=torch.Size([256]),
    device=cpu,
    is_shared=False)

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深度教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源