• 文件 >
  • 開始您的第一個訓練迴圈
快捷方式

開始您的第一個訓練迴圈

作者Vincent Moens

注意

若要在筆記本中執行本教學,請在開頭新增一個包含以下內容的安裝儲存格:

!pip install tensordict
!pip install torchrl

是時候總結我們在本入門系列中學到的一切了!

在本教學中,我們將僅使用先前課程中介紹的元件來編寫最基本的訓練迴圈。

我們將使用帶有 CartPole 環境的 DQN 作為原型範例。

我們將自願地將詳細程度保持在最低限度,僅將每個章節連結到相關的教學。

建構環境

我們將使用帶有 StepCounter 轉換的 gym 環境。如果您需要複習,請查看 環境教學 中介紹的這些功能。

import torch

torch.manual_seed(0)

import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

設計策略

下一步是建構我們的策略。我們將建立一個常規的、確定性的 actor 版本,以便在 損失模組 中以及在 評估 期間使用。接下來,我們將使用探索模組來增強它,以便進行 推論

from torchrl.modules import EGreedyModule, MLP, QValueModule

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

資料收集器和重播緩衝區

接下來是資料部分:我們需要一個 資料收集器,以便輕鬆取得批次資料,以及一個 重播緩衝區,以便儲存該資料以進行訓練。

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

from torch.optim import Adam

損失模組和最佳化器

我們按照專門的教學文件中所述的方式建立損失函數(loss),並包含其最佳化器(optimizer)和目標參數更新器。

from torchrl.objectives import DQNLoss, SoftUpdate

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)

記錄器 (Logger)

我們將使用 CSV 記錄器來記錄我們的結果,並儲存渲染後的影片。

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)

訓練迴圈 (Training loop)

我們不會固定執行特定次數的迭代,而是會持續訓練網路,直到它達到一定的效能(任意定義為環境中的 200 個步驟 – 對於 CartPole 來說,成功被定義為擁有更長的軌跡)。

total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

渲染 (Rendering)

最後,我們會盡可能地執行環境,並將影片儲存在本地(請注意,我們沒有進行探索)。

record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

這就是經過完整的訓練迴圈後,你所渲染的 CartPole 影片看起來的樣子

../_images/cartpole.gif

這就結束了我們的「TorchRL 入門」系列教學!歡迎在 GitHub 上分享您對此系列的意見。

腳本總執行時間:(0 分鐘 22.853 秒)

預估記憶體使用量: 323 MB

由 Sphinx-Gallery 產生圖庫

文件

存取 PyTorch 的完整開發者文件

查看文件

教學文件

取得適合初學者和進階開發人員的深入教學文件

查看教學文件

資源

尋找開發資源並獲得問題解答

查看資源