RolloutFromModel¶
- class torchrl.data.RolloutFromModel(model, ref_model, reward_model, kl_coef=0.1, max_new_tokens=50, score_clip=10.0, kl_scheduler: Optional[KLControllerBase] = None, num_steps: Optional[int] = None)[原始碼]¶
一個用於使用因果語言模型執行 rollout 的類別。
假設此類別包裝的模型將 tokenized 文字作為輸入,且其任務是在閱讀了 n 個先前的單字後預測句子中的下一個單字。
- 參數:
model (transformers.Transformer) – 要使用的模型。 應該有一個
generate()
方法。ref_model (transformers.Transformer) –
model
的凍結版本,其中的參數處於其初始配置。 這用於計算獎勵的 KL 懲罰,以防止模型在訓練期間偏離參考模型太遠。reward_model – (nn.Module, tensordict.nn.TensorDictModule): 給定
input_ids
和attention_mask
的模型,用於計算每個 token 的獎勵和 end_scores(每個序列中最終 token 的獎勵)。kl_coef – (float, optional): 初始 kl 係數。
max_new_tokens (int, optional) – 序列的最大長度。預設值為 50。
score_clip (float, optional) – 獎勵模型中的分數被裁剪到範圍
(-score_clip, score_clip)
。預設值為 10。kl_scheduler (KLControllerBase, optional) – KL 係數排程器。
num_steps (int, optional) – 兩次最佳化之間的步數。
範例
>>> from tensordict.nn import TensorDictModule >>> from torchrl.modules.models.rlhf import GPT2RewardModel >>> from torchrl.data.rlhf.utils import RolloutFromModel >>> from torchrl.data.rlhf.dataset import get_dataloader >>> from torchrl.data.rlhf.prompt import PromptData >>> from transformers import GPT2LMHeadModel >>> >>> dl = get_dataloader( ... batch_size=4, ... block_size=550, ... tensorclass_type=PromptData, ... device="cpu", ... dataset_name="CarperAI/openai_summarize_tldr", ... ) >>> model = GPT2LMHeadModel.from_pretrained("gpt2") >>> # we load ref_model with random weights so it differs from model >>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class()) >>> reward_model = GPT2RewardModel(model_path="gpt2") >>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model) >>> >>> batch = next(dl) >>> rollout = rollout_from_model.rollout_from_data(batch) >>> rollout TensorDict( fields={ action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False), attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False), input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 50]), device=cpu, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 50]), device=cpu, is_shared=False)
- create_rollout_td(batch, generated, log_probs, log_ratio)[source]¶
用於生成資料的 TensorDict 包裝器。
此函式接收一個批次以及產生的 tokens,並複製從 TorchRL 環境 roll out 取得的 tensordict 結構,該環境在每個時間步採樣一個 token。
- 參數:
batch (TensorDict) – 包含原始 prompt 和一個指示 prompt 正確索引的 "rindex" 欄位的資料批次。
generated (torch.Tensor) – Tokenized prompt 後面跟著產生的 tokens。 這可以透過呼叫
generate
方法來獲得。log_probs (torch.Tensor) – 產生 tokens 的 log 機率。 可以透過呼叫
generate
方法來獲得。log_ratio (torch.Tensor) – 根據生成模型和參考模型,產生 tokens 的機率的 log 比率。 可以透過呼叫
generate
方法來獲得。
- 傳回:
"action"
: 動作序列 (產生 tokens)"input_ids"
: 在每個時間步傳遞給生成模型的 input_ids。"attention_mask"
: 在每個時間步傳遞給生成模型的 attention_masks。"sample_log_prob"
: 生成期間每個 token 的 log 機率("next", "input_ids")
: 生成後的 tokens 序列。 組成了將用於產生下一個 token 的輸入的一部分。("next", "attention_mask")
: token 生成後更新的 attention_mask。 在下一個時間步傳遞給生成模型("next", "terminated")
: 布林陣列,指示我們是否已達到終端狀態 (因為我們產生了 EOS token,或者因為我們已達到 token 限制)("next", "done")
: 布林陣列,指示我們是否已達到最終狀態。 目前是"terminated"
的副本。("next", "reward")
: 在每個時間步收到的獎勵("next", "reward_raw")
: 來自獎勵模型的原始獎勵,沒有 KL 項。 這主要用於除錯和記錄,不使用於訓練("next", "reward_kl")
: 來自獎勵的 KL 項。 這主要用於除錯和記錄,不使用於訓練。
- 傳回類型:
一個
TensorDict
具有以下鍵
- generate(batch: PromptData, generation_config=None)[source]¶
從資料收集器採樣的資料批次產生 tokens 序列。
- 參數:
batch (PromptData) – 要使用的資料。 必須具有
input_ids
和prompt_rindex
欄位。generation_config (GenerationConfig, optional) – 呼叫產生 (generate) 的配置。
- 傳回:
- 一個 [B x (Ti +To)] 整數 (tokens) 序列,
其中 Ti 是輸入序列的長度,To 是產生的序列的長度。
log_probs_gen: 產生的 token 的 log 機率。 log_ratio: 生成模型下機率之間的 log 比率
模型和凍結版本。
- 傳回類型:
generated (torch.Tensor)