make_trainer¶
- torchrl.trainers.helpers.make_trainer(collector: DataCollectorBase, loss_module: LossModule, recorder: Optional[EnvBase] = None, target_net_updater: Optional[TargetNetUpdater] = None, policy_exploration: Optional[Union[TensorDictModuleWrapper, TensorDictModule]] = None, replay_buffer: Optional[ReplayBuffer] = None, logger: Optional[Logger] = None, cfg: DictConfig = None) → Trainer [source]¶
Creates a Trainer instance given its constituents.
- Parameters:
collector (DataCollectorBase) – A data collector used to collect the data.
loss_module (LossModule) – A TorchRL loss module.
recorder (EnvBase, optional) – a recorder environment. If None, the trainer will train the policy without testing it.
target_net_updater (TargetNetUpdater, optional) – A target network update object.
policy_exploration (TDModule or TensorDictModuleWrapper, optional) – a policy used for recording and exploration updates (should be kept in sync with the learnt policy).
replay_buffer (ReplayBuffer, optional) – a replay buffer used to store the collected data.
logger (Logger, optional) – a Logger used for logging.
cfg (DictConfig, optional) – a DictConfig containing the arguments of the script. If None, the default arguments are used.
- Returns:
A trainer built with the input objects. The optimizer is built by this helper function using the provided cfg.
Examples
>>> import torch
>>> import tempfile
>>> from torchrl.trainers.loggers import TensorboardLogger
>>> from torchrl.trainers import Trainer
>>> from torchrl.trainers.helpers import make_trainer
>>> from torchrl.envs import EnvCreator
>>> from torchrl.collectors.collectors import SyncDataCollector
>>> from torchrl.data import TensorDictReplayBuffer
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper
>>> from torchrl.objectives.common import LossModule
>>> from torchrl.objectives.utils import TargetNetUpdater
>>> from torchrl.objectives import DDPGLoss
>>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0"))
>>> env_proof = env_maker()
>>> obs_spec = env_proof.observation_spec
>>> action_spec = env_proof.action_spec
>>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1])
>>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1)  # for the purpose of testing
>>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"])
>>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
>>> collector = SyncDataCollector(env_maker, policy, total_frames=100)
>>> loss_module = DDPGLoss(policy, value, gamma=0.99)
>>> recorder = env_proof
>>> target_net_updater = None
>>> policy_exploration = EGreedyWrapper(policy)
>>> replay_buffer = TensorDictReplayBuffer()
>>> dir = tempfile.gettempdir()
>>> logger = TensorboardLogger(exp_name=dir)
>>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
...     replay_buffer, logger)
>>> print(trainer)
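When a cfg is supplied, the optimizer hyperparameters are read from it instead of the defaults. A minimal sketch of passing one is shown below; the field names used here (optimizer, lr) are illustrative assumptions only, and the authoritative keys are those actually read by make_trainer's implementation.
>>> from omegaconf import OmegaConf
>>> # Hypothetical configuration values; the actual key names must match those expected by make_trainer.
>>> cfg = OmegaConf.create({"optimizer": "adam", "lr": 3e-4})
>>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater,
...     policy_exploration, replay_buffer, logger, cfg=cfg)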