Note
Go to the end to download the full example code.
TorchRL trainer: A DQN example¶
TorchRL provides a generic Trainer
class to handle your training loop. The trainer executes a nested loop where the outer loop is the data collection and the inner loop consumes this data, or some data retrieved from the replay buffer, to train the model. At various points in this training loop, hooks can be attached and executed at given intervals.
In this tutorial, we will be using the trainer class to train a DQN algorithm to solve the CartPole task from scratch.
Main takeaways:
Building a trainer with its essential components: data collector, loss module, replay buffer and optimizer.
Adding hooks to a trainer, such as loggers, target network updaters and such.
The trainer is fully customisable and offers a large set of functionalities. This tutorial is organised around its construction: we will first detail how to build each component of the library, and then put the pieces together using the Trainer
class.
Along the road, we will also focus on some other aspects of the library:
how to build an environment in TorchRL, including transforms (e.g. data normalization, frame concatenation, resizing and turning to grayscale) and parallel execution. Unlike what we did in the DDPG tutorial, we will normalize the pixels and not the state vector;
how to design a
QValueActor
object, i.e. an actor that estimates the action values and picks the action with the highest estimated return; how to collect data from your environment efficiently and store it in a replay buffer;
how to use multi-step, a simple preprocessing step for off-policy algorithms;
and finally how to evaluate your model.
Prerequisites: We encourage you to get familiar with torchrl through the PPO tutorial first.
DQN¶
DQN (Deep Q-Learning) was the founding work in deep reinforcement learning.
On a high level, the algorithm is quite simple: Q-learning consists in learning a table of state-action values such that, when encountering any particular state, we know which action to pick just by looking up the action with the highest value. This simple setting requires the actions and states to be discrete, since otherwise a lookup table cannot be built.
DQN uses a neural network that maps the state-action space to a value (scalar) space, which amortizes the cost of storing and exploring all possible state-action combinations: if a state has not been seen in the past, we can still pass it through our neural network together with each of the available actions and get an interpolated value for each of them.
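The tabular setting described above can be sketched in a few lines (a toy illustration with made-up states and numbers; TorchRL is not involved):

```python
from collections import defaultdict

# Toy 2-action problem; states, rewards and constants are made up for illustration.
q_table = defaultdict(float)  # maps (state, action) -> estimated value
alpha, gamma = 0.5, 0.9

def update(state, action, reward, next_state, actions=(0, 1)):
    # Classic Q-learning update: move Q(s, a) toward r + gamma * max_a' Q(s', a')
    best_next = max(q_table[(next_state, a)] for a in actions)
    q_table[(state, action)] += alpha * (reward + gamma * best_next - q_table[(state, action)])

update("s0", 0, reward=1.0, next_state="s1")
print(q_table[("s0", 0)])  # 0.5
```

DQN replaces the table lookup `q_table[(state, action)]` with a forward pass through a network, which is what the rest of this tutorial builds.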
We will solve the classic control problem of the cart-pole. This environment is taken from the Gymnasium documentation.

Our goal is not to give a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL's features in the context of this algorithm.
import os
import tempfile
import uuid
import warnings

import torch
from tensordict.nn import TensorDictSequential
from torch import multiprocessing, nn
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    ParallelEnv,
    RewardScaling,
    StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    CatFrames,
    Compose,
    GrayScale,
    ObservationNorm,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
    LogReward,
    Recorder,
    ReplayBufferTrainer,
    Trainer,
    UpdateWeights,
)
def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False  # Probably standard Python interpreter
Let us start with the introduction of the various pieces needed for our algorithm:
an environment;
a policy (and related modules that we group under the "model" umbrella);
a data collector, which makes the policy play in the environment and delivers training data;
a replay buffer to store the training data;
a loss module, which computes the objective function to train our policy to maximise the return;
an optimizer, which performs the parameter updates according to our loss.
Additional modules include a logger, a recorder (which executes the policy in "eval" mode) and a target network updater. With all these components at hand, it is easy to see how one could misplace or misuse a component in a training script. The trainer is there to orchestrate everything for you!
Building the environment¶
First let's write a helper function that will output an environment. As usual, the "raw" environment may be too simple to be used in practice and we'll need some data transformations to expose its output to the policy.
We will be using a handful of transforms:
StepCounter
to count the number of steps in each trajectory;
ToTensorImage
will convert a [W, H, C]
uint8 tensor into a floating point tensor in the [0, 1]
space with shape [C, W, H]
;
RewardScaling
to reduce the scale of the return;
GrayScale
will turn our image into grayscale;
Resize
will resize the image to a 64x64 format;
CatFrames
will concatenate an arbitrary number of successive frames (N=4
) into a single tensor along the channel dimension. This is useful as a single image does not carry information about the motion of the cart-pole. Some memory about past observations and actions is needed, either via a recurrent neural network or by using a stack of frames;
ObservationNorm
will normalize our observations given some custom summary statistics.
In practice, our environment builder has two arguments:
parallel
: determines whether multiple environments have to be run in parallel. We stack the transforms after the ParallelEnv
to take advantage of vectorization of the operations on device, although this would technically work with each environment attached to its own set of transforms.
obs_norm_sd
will contain the normalizing constants for the ObservationNorm
transform.
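Before building the real pipeline, the effect of this chain of transforms on tensor shapes can be sketched with plain torch ops (a rough stand-in only: for instance, the actual GrayScale transform uses weighted luma coefficients rather than a plain channel mean):

```python
import torch
import torch.nn.functional as F

frame = torch.randint(0, 256, (400, 600, 3), dtype=torch.uint8)  # raw [H, W, C] pixels

# ToTensorImage: [H, W, C] uint8 -> [C, H, W] float in [0, 1]
x = frame.permute(2, 0, 1).float() / 255.0
# GrayScale: collapse the RGB channels (the real transform uses weighted luma)
x = x.mean(0, keepdim=True)
# Resize: down to 64x64
x = F.interpolate(x.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
# CatFrames(4, dim=-3): concatenate 4 successive frames along the channel dim
stacked = torch.cat([x] * 4, dim=-3)
print(stacked.shape)  # torch.Size([4, 64, 64])
```

The [4, 64, 64] shape is what the policy network below will consume as its "pixels" input.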
def make_env(
    parallel=False,
    obs_norm_sd=None,
    num_workers=1,
):
    if obs_norm_sd is None:
        obs_norm_sd = {"standard_normal": True}
    if parallel:

        def maker():
            return GymEnv(
                "CartPole-v1",
                from_pixels=True,
                pixels_only=True,
                device=device,
            )

        base_env = ParallelEnv(
            num_workers,
            EnvCreator(maker),
            # Don't create a sub-process if we have only one worker
            serial_for_single=True,
            mp_start_method=mp_context,
        )
    else:
        base_env = GymEnv(
            "CartPole-v1",
            from_pixels=True,
            pixels_only=True,
            device=device,
        )
    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(),  # to count the steps of each trajectory
            ToTensorImage(),
            RewardScaling(loc=0.0, scale=0.1),
            GrayScale(),
            Resize(64, 64),
            CatFrames(4, in_keys=["pixels"], dim=-3),
            ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
        ),
    )
    return env
Compute normalizing constants¶
To normalize images, we do not want to normalize each pixel independently with a full [C, W, H]
normalizing mask, but with a simpler set of [C, 1, 1]
-shaped normalizing constants (loc and scale parameters). We will be using the reduce_dim
argument of init_stats()
to instruct which dimensions must be reduced, and the keep_dims
parameter to ensure that not all dimensions disappear in the process:
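A minimal sketch of the reduction this performs, using plain torch (the tensor and its sizes are made up for illustration; only the dimension bookkeeping mirrors the call below):

```python
import torch

# Pretend batch of 100 stacked-frame observations, each [C=4, 16, 16]
obs = torch.rand(100, 4, 16, 16)
# Reduce over the batch, height and width dims (reduce_dim=[-1, -2, -4]) while
# keeping H and W as singleton dims (keep_dims=(-1, -2)), so loc/scale end up [C, 1, 1]
loc = obs.mean(dim=(0, -1, -2), keepdim=True).squeeze(0)
scale = obs.std(dim=(0, -1, -2), keepdim=True).squeeze(0)
print(loc.shape, scale.shape)  # torch.Size([4, 1, 1]) torch.Size([4, 1, 1])
```

Keeping the singleton H/W dims lets the constants broadcast against any image size at normalization time.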
def get_norm_stats():
    test_env = make_env()
    test_env.transform[-1].init_stats(
        num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
    )
    obs_norm_sd = test_env.transform[-1].state_dict()
    # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
    # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
    print("state dict of the observation norm:", obs_norm_sd)
    test_env.close()
    del test_env
    return obs_norm_sd
Building the model (Deep Q-network)¶
The following function builds a DuelingCnnDQNet
object, which is a simple CNN followed by a two-layer MLP. The only trick used here is that the action values (i.e. the left and right action values) are computed using
\[\mathbb{v} = b(obs) + v(obs) - \mathbb{E}[v(obs)]\]
where \(\mathbb{v}\) is our vector of action values, \(b\) is a \(\mathbb{R}^n \rightarrow 1\) function and \(v\) is a \(\mathbb{R}^n \rightarrow \mathbb{R}^m\) function, with \(n = \# obs\) and \(m = \# actions\).
Our network is wrapped in a
QValueActor
, which will read the state-action values, pick the one with the maximum value and write all those results in the input tensordict.TensorDict
.
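The dueling aggregation above can be checked numerically in isolation (the value and advantage tensors are made-up stand-ins for the two network heads):

```python
import torch

# b(obs): one scalar per state; v(obs): one entry per action (made-up numbers)
value = torch.tensor([[1.5]])
advantage = torch.tensor([[0.2, -0.2]])
# Dueling aggregation: subtracting the mean advantage makes the decomposition unique
q_values = value + advantage - advantage.mean(dim=-1, keepdim=True)
print(q_values)  # tensor([[1.7000, 1.3000]])
```

The action with the largest entry (here the first one) is what a greedy actor such as QValueActor would pick.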
def make_model(dummy_env):
    cnn_kwargs = {
        "num_cells": [32, 64, 64],
        "kernel_sizes": [6, 4, 3],
        "strides": [2, 2, 1],
        "activation_class": nn.ELU,
        # This can be used to reduce the size of the last layer of the CNN
        # "squeeze_output": True,
        # "aggregator_class": nn.AdaptiveAvgPool2d,
        # "aggregator_kwargs": {"output_size": (1, 1)},
    }
    mlp_kwargs = {
        "depth": 2,
        "num_cells": [
            64,
            64,
        ],
        "activation_class": nn.ELU,
    }
    net = DuelingCnnDQNet(
        dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
    ).to(device)
    net.value[-1].bias.data.fill_(init_bias)
    actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)
    # init actor: because the model is composed of lazy conv/linear layers,
    # we must pass a fake batch of data through it to instantiate them.
    tensordict = dummy_env.fake_tensordict()
    actor(tensordict)
    # we join our actor with an EGreedyModule for data collection
    exploration_module = EGreedyModule(
        spec=dummy_env.action_spec,
        annealing_num_steps=total_frames,
        eps_init=eps_greedy_val,
        eps_end=eps_greedy_val_env,
    )
    actor_explore = TensorDictSequential(actor, exploration_module)
    return actor, actor_explore
Collecting and storing data¶
Replay buffers¶
Replay buffers play a central role in off-policy RL algorithms such as DQN. They constitute the dataset we will be sampling from during training.
Here, we will use a regular sampling strategy, although a prioritized replay buffer could improve the performance significantly.
We place the storage on disk using the LazyMemmapStorage
class. This storage is created in a lazy manner: it will only be instantiated once the first batch of data is passed to it.
The only requirement of this storage is that the data passed to it at write time must always have the same shape.
def get_replay_buffer(buffer_size, n_optim, batch_size):
    replay_buffer = TensorDictReplayBuffer(
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )
    return replay_buffer
Data collector¶
As in PPO and DDPG, we will be using a data collector as a dataloader in the outer loop.
We choose the following configuration: we will be running a series of parallel environments synchronously in parallel in different collectors, themselves running in parallel but asynchronously.
Note
This feature is only available when running the code within the "spawn" start method of the python multiprocessing library. If this tutorial is run directly as a script (thereby using the "fork" method) we will be using a regular SyncDataCollector
.
The advantage of this configuration is that we can balance the amount of compute that is executed in batch with what we want to be executed asynchronously. We encourage the reader to experiment with how the collection speed is impacted by modifying the number of collectors (i.e. the number of environment constructors passed to the collector) and the number of environments executed in parallel in each collector (controlled by the num_workers
hyperparameter).
Collectors' devices are fully parametrizable through the device
(general), policy_device
, env_device
and storing_device
arguments. The storing_device
argument will modify the location of the data being collected: if the batches we are gathering have a considerable size, we may want to store them in a different location than the device where the computation happens. For asynchronous data collectors such as ours, different storing devices mean that the data we collect will not sit on the same device each time, which is something our training loop must account for. For simplicity, we set the devices to the same value for all sub-collectors.
def get_collector(
    stats,
    num_collectors,
    actor_explore,
    frames_per_batch,
    total_frames,
    device,
):
    # We can't use nested child processes with mp_start_method="fork"
    if is_fork:
        cls = SyncDataCollector
        env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
    else:
        cls = MultiaSyncDataCollector
        env_arg = [
            make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
        ] * num_collectors
    data_collector = cls(
        env_arg,
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        # this is the default behavior: the collector runs in ``"random"`` (or explorative) mode
        exploration_type=ExplorationType.RANDOM,
        # We set all the devices to be identical. Below is an example of
        # heterogeneous devices
        device=device,
        storing_device=device,
        split_trajs=False,
        postproc=MultiStep(gamma=gamma, n_steps=5),
    )
    return data_collector
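The MultiStep post-processing attached above replaces single-step targets with n-step returns; the quantity it accumulates per transition can be sketched as follows (a plain-Python illustration, not the TorchRL implementation, with a made-up reward sequence):

```python
# gamma matches the hyperparameter used elsewhere in this tutorial
gamma = 0.99

def n_step_return(rewards, n=5):
    # Discounted sum of the next n rewards starting at the transition
    return sum(gamma**k * r for k, r in enumerate(rewards[:n]))

print(round(n_step_return([1.0] * 10), 4))  # 4.901
```

The real transform also rewires the "next" state to the one n steps ahead so that bootstrapping uses gamma**n.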
Loss function¶
Building our loss function is straightforward: we only need to provide the model and a bunch of hyperparameters to the DQNLoss class.
Target parameters¶
Many off-policy RL algorithms use the concept of "target parameters" when it comes to estimating the value of the next state or state-action pair. The target parameters are lagged copies of the model parameters. Because their predictions mismatch those of the current model configuration, they help learning by putting a pessimistic bound on the value being estimated. This is a powerful trick (known as "Double Q-Learning") that is ubiquitous in similar algorithms.
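The SoftUpdate hook used below maintains these lagged copies via Polyak averaging; a one-step sketch with made-up weight vectors:

```python
import torch

eps = 0.995  # matches SoftUpdate(loss_module, eps=0.995) below
online = torch.ones(3)   # stand-in for the trained parameters
target = torch.zeros(3)  # stand-in for the lagged target parameters
# One soft update step: target <- eps * target + (1 - eps) * online
target = eps * target + (1 - eps) * online
print(target)  # each entry moved 0.5% of the way toward the online weights
```

Because eps is close to 1, the target network trails the online network slowly, which stabilises the bootstrapped value targets.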
def get_loss_module(actor, gamma):
    loss_module = DQNLoss(actor, delay_value=True)
    loss_module.make_value_estimator(gamma=gamma)
    target_updater = SoftUpdate(loss_module, eps=0.995)
    return loss_module, target_updater
Hyperparameters¶
Let's start with our hyperparameters. The following setting should work well in practice, and the performance of the algorithm should hopefully not be too sensitive to slight variations of these.
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
Optimizer¶
# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8
DQN parameters¶
gamma decay factor
gamma = 0.99
Smooth target network update decay parameter. This loosely corresponds to a 1/tau interval with hard target network updates.
tau = 0.02
Data collection and replay buffer¶
Note
Values to be used for proper training have been commented.
Total frames collected in the environment. In other implementations, the user defines a maximum number of episodes. This is harder to do with our data collectors, since they return batches of N collected frames, where N is a constant. However, one can easily get the same restriction on the number of episodes by breaking the training loop when a certain number of episodes has been collected.
total_frames = 5_000  # 500000
Random frames used to initialize the replay buffer.
init_random_frames = 100  # 1000
Frames in each batch collected.
frames_per_batch = 32  # 128
Frames sampled from the replay buffer at each optimization step.
batch_size = 32  # 256
Size of the replay buffer in terms of frames.
buffer_size = min(total_frames, 100000)
Number of environments run in parallel in each data collector.
num_workers = 2  # 8
num_collectors = 2  # 4
Environment and exploration¶
We set the initial and final value of the epsilon factor in the epsilon-greedy exploration. Since our policy is deterministic, exploration is crucial: without it, the only source of randomness would be the environment reset.
eps_greedy_val = 0.1
eps_greedy_val_env = 0.005
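The EGreedyModule that consumes these values anneals epsilon over annealing_num_steps calls to its step() method; a sketch of the schedule (the linear interpolation rule is an assumption of this sketch, with values mirroring the hyperparameters above):

```python
# Values mirror eps_greedy_val, eps_greedy_val_env and total_frames above
eps_init, eps_end, annealing_num_steps = 0.1, 0.005, 5_000

def eps_at(step):
    # Linear interpolation from eps_init down to eps_end, then held constant
    frac = min(step / annealing_num_steps, 1.0)
    return eps_init + frac * (eps_end - eps_init)

print(eps_at(0), eps_at(2_500), eps_at(5_000))
```

Early on, roughly one action in ten is random; by the end of training almost all actions are greedy.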
To speed up learning, we set the bias of the last layer of our value network to a predefined value (this is not mandatory).
init_bias = 2.0
Note
For fast rendering of the tutorial, the total_frames
hyperparameter was set to a very low number. To get reasonable performance, use a greater value, e.g. 500000.
Building a Trainer¶
TorchRL's Trainer
class constructor takes the following keyword-only arguments:
collector
loss_module
optimizer
logger
: an optional logger.total_frames
: this parameter defines the lifespan of the trainer.frame_skip
: when a frame-skip is used, the collector must be made aware of it in order to accurately count the number of frames collected, etc. Making the trainer aware of this parameter is not mandatory, but it helps to have a fairer comparison between settings where the total number of frames (budget) is fixed but the frame-skip is variable.
stats = get_norm_stats()
test_env = make_env(parallel=False, obs_norm_sd=stats)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)
collector = get_collector(
    stats=stats,
    num_collectors=num_collectors,
    actor_explore=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)
optimizer = torch.optim.Adam(
    loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")
state dict of the observation norm: OrderedDict([('standard_normal', tensor(True)), ('loc', tensor([[[0.9895]],
[[0.9895]],
[[0.9895]],
[[0.9895]]])), ('scale', tensor([[[0.0737]],
[[0.0737]],
[[0.0737]],
[[0.0737]]]))])
We can control how often the scalars should be logged. Here we set this to a low value as our training loop is short:
log_interval = 500
With all components in place, the trainer itself can be assembled from them (the optim_steps_per_batch argument sets how many optimization steps run per collected batch, and log_interval how often scalars are logged):
trainer = Trainer(
    collector=collector,
    total_frames=total_frames,
    frame_skip=1,
    loss_module=loss_module,
    optimizer=optimizer,
    logger=logger,
    optim_steps_per_batch=n_optim,
    log_interval=log_interval,
)
Registering hooks¶
Registering hooks can be achieved in two separate ways:
If the hook has one, the
register()
method is the first choice. One just needs to provide the trainer as input and the hook will be registered with a default name at a default location. For some hooks, the registration can be quite complex: ReplayBufferTrainer
requires 3 hooks (extend
, sample
and update_priority
), which can be cumbersome to implement.
buffer_hook = ReplayBufferTrainer(
    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
    flatten_tensordicts=True,
)
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
    record_interval=100,  # log every 100 optimization steps
    record_frames=1000,  # maximum number of frames in the record
    frame_skip=1,
    policy_exploration=actor_explore,
    environment=test_env,
    exploration_type=ExplorationType.DETERMINISTIC,
    log_keys=[("next", "reward")],
    out_keys={("next", "reward"): "rewards"},
    log_pbar=True,
)
recorder.register(trainer)
The exploration module's epsilon factor is also annealed:
trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)
Any callable (including
TrainerHookBase
subclasses) can be registered using register_op()
. In this case, a location must be explicitly passed. This method gives more control over the location of the hook, but it also requires more understanding of the Trainer mechanism. Check the trainer documentation for a detailed description of the trainer hooks.
trainer.register_op("post_optim", target_net_updater.step)
We can log the training rewards too. Note that this is of limited interest with CartPole, as rewards are always 1. The discounted sum of rewards is maximised not by getting higher rewards but by keeping the cart-pole alive for longer. This will be reflected in the total_rewards value displayed in the progress bar.
log_reward = LogReward(log_pbar=True)
log_reward.register(trainer)
Note
It is possible to link multiple optimizers to the trainer if needed. In this case, each optimizer will be tied to a field in the loss dictionary. Check the OptimizerHook
to learn more.
Here is our algorithm ready to be trained! A simple call to trainer.train()
and we'll be getting our results logged in.
trainer.train()
0%| | 0/5000 [00:00<?, ?it/s]
1%| | 32/5000 [00:07<20:39, 4.01it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 0.9434: 1%| | 32/5000 [00:07<20:39, 4.01it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 68%|██████▊ | 3424/5000 [00:52<00:30, 51.78it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 68%|██████▊ | 3424/5000 [00:52<00:30, 51.78it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 69%|██████▉ | 3456/5000 [00:52<00:26, 58.47it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 69%|██████▉ | 3456/5000 [00:52<00:26, 58.47it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 70%|██████▉ | 3488/5000 [00:53<00:23, 64.95it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 70%|██████▉ | 3488/5000 [00:53<00:23, 64.95it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 70%|███████ | 3520/5000 [00:53<00:21, 69.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 70%|███████ | 3520/5000 [00:53<00:21, 69.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 71%|███████ | 3552/5000 [00:54<00:19, 72.87it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 71%|███████ | 3552/5000 [00:54<00:19, 72.87it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 72%|███████▏ | 3584/5000 [00:54<00:18, 75.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 72%|███████▏ | 3584/5000 [00:54<00:18, 75.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 72%|███████▏ | 3616/5000 [00:54<00:17, 78.72it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 72%|███████▏ | 3616/5000 [00:54<00:17, 78.72it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 73%|███████▎ | 3648/5000 [00:55<00:16, 80.90it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 73%|███████▎ | 3648/5000 [00:55<00:16, 80.90it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 74%|███████▎ | 3680/5000 [00:55<00:15, 82.97it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 74%|███████▎ | 3680/5000 [00:55<00:15, 82.97it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 74%|███████▍ | 3712/5000 [00:55<00:15, 84.10it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 74%|███████▍ | 3712/5000 [00:55<00:15, 84.10it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 75%|███████▍ | 3744/5000 [00:56<00:14, 83.85it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 75%|███████▍ | 3744/5000 [00:56<00:14, 83.85it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 76%|███████▌ | 3776/5000 [00:56<00:14, 84.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 76%|███████▌ | 3776/5000 [00:56<00:14, 84.22it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 76%|███████▌ | 3808/5000 [00:56<00:14, 84.97it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 76%|███████▌ | 3808/5000 [00:56<00:14, 84.97it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 77%|███████▋ | 3840/5000 [00:57<00:13, 85.76it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 77%|███████▋ | 3840/5000 [00:57<00:13, 85.76it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 77%|███████▋ | 3872/5000 [00:57<00:12, 88.48it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 77%|███████▋ | 3872/5000 [00:57<00:12, 88.48it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 78%|███████▊ | 3904/5000 [00:58<00:12, 89.26it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 78%|███████▊ | 3904/5000 [00:58<00:12, 89.26it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 79%|███████▊ | 3936/5000 [00:58<00:11, 90.04it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 79%|███████▊ | 3936/5000 [00:58<00:11, 90.04it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 79%|███████▉ | 3968/5000 [00:58<00:11, 89.77it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 79%|███████▉ | 3968/5000 [00:58<00:11, 89.77it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 80%|████████ | 4000/5000 [00:59<00:11, 89.92it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556: 80%|████████ | 4000/5000 [00:59<00:11, 89.92it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556: 81%|████████ | 4032/5000 [00:59<00:10, 88.32it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 81%|████████ | 4032/5000 [00:59<00:10, 88.32it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 81%|████████▏ | 4064/5000 [00:59<00:10, 86.08it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 81%|████████▏ | 4064/5000 [00:59<00:10, 86.08it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 82%|████████▏ | 4096/5000 [01:00<00:10, 84.34it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 82%|████████▏ | 4096/5000 [01:00<00:10, 84.34it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 83%|████████▎ | 4128/5000 [01:00<00:10, 85.10it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556: 83%|████████▎ | 4128/5000 [01:00<00:10, 85.10it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556: 83%|████████▎ | 4160/5000 [01:00<00:09, 87.36it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 83%|████████▎ | 4160/5000 [01:00<00:09, 87.36it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 84%|████████▍ | 4192/5000 [01:01<00:08, 90.26it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 84%|████████▍ | 4192/5000 [01:01<00:08, 90.26it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 84%|████████▍ | 4224/5000 [01:01<00:08, 92.19it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 84%|████████▍ | 4224/5000 [01:01<00:08, 92.19it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 85%|████████▌ | 4256/5000 [01:01<00:08, 92.86it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556: 85%|████████▌ | 4256/5000 [01:01<00:08, 92.86it/s]
r_training: 0.3688, rewards: 0.1000, total_rewards: 5.5556: 86%|████████▌ | 4288/5000 [01:02<00:07, 90.60it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 86%|████████▌ | 4288/5000 [01:02<00:07, 90.60it/s]
r_training: 0.3991, rewards: 0.1000, total_rewards: 5.5556: 86%|████████▋ | 4320/5000 [01:02<00:07, 92.06it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 86%|████████▋ | 4320/5000 [01:02<00:07, 92.06it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 87%|████████▋ | 4352/5000 [01:03<00:06, 93.60it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 87%|████████▋ | 4352/5000 [01:03<00:06, 93.60it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 88%|████████▊ | 4384/5000 [01:03<00:06, 92.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 88%|████████▊ | 4384/5000 [01:03<00:06, 92.15it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 88%|████████▊ | 4416/5000 [01:03<00:06, 92.21it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 88%|████████▊ | 4416/5000 [01:03<00:06, 92.21it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 89%|████████▉ | 4448/5000 [01:04<00:05, 92.90it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 89%|████████▉ | 4448/5000 [01:04<00:05, 92.90it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 90%|████████▉ | 4480/5000 [01:04<00:05, 89.09it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 90%|████████▉ | 4480/5000 [01:04<00:05, 89.09it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 90%|█████████ | 4512/5000 [01:04<00:05, 87.57it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: 90%|█████████ | 4512/5000 [01:04<00:05, 87.57it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: 91%|█████████ | 4544/5000 [01:05<00:05, 87.44it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: 91%|█████████ | 4544/5000 [01:05<00:05, 87.44it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: 92%|█████████▏| 4576/5000 [01:05<00:04, 86.47it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 92%|█████████▏| 4576/5000 [01:05<00:04, 86.47it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 92%|█████████▏| 4608/5000 [01:05<00:04, 85.84it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 92%|█████████▏| 4608/5000 [01:05<00:04, 85.84it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 93%|█████████▎| 4640/5000 [01:06<00:04, 85.98it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 93%|█████████▎| 4640/5000 [01:06<00:04, 85.98it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 93%|█████████▎| 4672/5000 [01:06<00:03, 87.97it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: 93%|█████████▎| 4672/5000 [01:06<00:03, 87.97it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: 94%|█████████▍| 4704/5000 [01:07<00:03, 87.37it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556: 94%|█████████▍| 4704/5000 [01:07<00:03, 87.37it/s]
r_training: 0.3869, rewards: 0.1000, total_rewards: 5.5556: 95%|█████████▍| 4736/5000 [01:07<00:02, 90.46it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 95%|█████████▍| 4736/5000 [01:07<00:02, 90.46it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 95%|█████████▌| 4768/5000 [01:07<00:02, 93.67it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 95%|█████████▌| 4768/5000 [01:07<00:02, 93.67it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 96%|█████████▌| 4800/5000 [01:08<00:02, 93.89it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 96%|█████████▌| 4800/5000 [01:08<00:02, 93.89it/s]
r_training: 0.4082, rewards: 0.1000, total_rewards: 5.5556: 97%|█████████▋| 4832/5000 [01:08<00:01, 91.12it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 97%|█████████▋| 4832/5000 [01:08<00:01, 91.12it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 97%|█████████▋| 4864/5000 [01:08<00:01, 89.91it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 97%|█████████▋| 4864/5000 [01:08<00:01, 89.91it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 98%|█████████▊| 4896/5000 [01:09<00:01, 88.75it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 98%|█████████▊| 4896/5000 [01:09<00:01, 88.75it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 99%|█████████▊| 4928/5000 [01:09<00:00, 90.36it/s]
r_training: 0.4173, rewards: 0.1000, total_rewards: 5.5556: 99%|█████████▊| 4928/5000 [01:09<00:00, 90.36it/s]
r_training: 0.4173, rewards: 0.1000, total_rewards: 5.5556: 99%|█████████▉| 4960/5000 [01:09<00:00, 93.68it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 99%|█████████▉| 4960/5000 [01:09<00:00, 93.68it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 100%|█████████▉| 4992/5000 [01:10<00:00, 91.37it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: 100%|█████████▉| 4992/5000 [01:10<00:00, 91.37it/s]
r_training: 0.4295, rewards: 0.1000, total_rewards: 5.5556: : 5024it [01:10, 90.51it/s]
r_training: 0.4021, rewards: 0.1000, total_rewards: 5.5556: : 5024it [01:10, 90.51it/s]
We can now quickly inspect the CSV files containing the results.
def print_csv_files_in_folder(folder_path):
    """
    Find all CSV files in a folder and print the first 10 lines of each file.

    Args:
        folder_path (str): The relative path to the folder.
    """
    csv_files = []
    output_str = ""
    for dirpath, _, filenames in os.walk(folder_path):
        for file in filenames:
            if file.endswith(".csv"):
                csv_files.append(os.path.join(dirpath, file))
    for csv_file in csv_files:
        output_str += f"File: {csv_file}\n"
        with open(csv_file, "r") as f:
            for i, line in enumerate(f):
                if i == 10:
                    break
                output_str += line.strip() + "\n"
        output_str += "\n"
    print(output_str)


print_csv_files_in_folder(logger.experiment.log_dir)
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/r_training.csv
512,0.3566153347492218
1024,0.39912936091423035
1536,0.39912936091423035
2048,0.39912936091423035
2560,0.42945271730422974
3072,0.40213119983673096
3584,0.39912933111190796
4096,0.42945271730422974
4608,0.42945271730422974
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/optim_steps.csv
512,128.0
1024,256.0
1536,384.0
2048,512.0
2560,640.0
3072,768.0
3584,896.0
4096,1024.0
4608,1152.0
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/loss.csv
512,0.47876793146133423
1024,0.18667784333229065
1536,0.1948033571243286
2048,0.22345909476280212
2560,0.2145865112543106
3072,0.47586697340011597
3584,0.28343674540519714
4096,0.3203103542327881
4608,0.3053428530693054
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/grad_norm_0.csv
512,5.5816755294799805
1024,2.9089717864990234
1536,3.4687838554382324
2048,2.8756051063537598
2560,2.7815587520599365
3072,6.685841083526611
3584,3.793360948562622
4096,3.469670295715332
4608,3.317387104034424
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/rewards.csv
3232,0.10000000894069672
File: /tmp/tmp16y9hbib/dqn_exp_a0ee01f2-90c8-11ef-a49b-0242ac110002/scalars/total_rewards.csv
3232,5.555555820465088
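The scalar files above are plain two-column CSVs (collected frames, logged value), so they can be parsed with the standard library alone. Below is a minimal sketch of such a parser; the `load_scalar_csv` helper and the inline sample (copied from the `r_training.csv` output above) are illustrative, not part of TorchRL's API:

```python
import csv
import io


def load_scalar_csv(fileobj):
    """Parse a two-column scalar CSV into (steps, values) lists."""
    steps, values = [], []
    for row in csv.reader(fileobj):
        steps.append(int(row[0]))
        values.append(float(row[1]))
    return steps, values


# Inline sample mirroring the first rows of r_training.csv above.
sample = """512,0.3566153347492218
1024,0.39912936091423035
1536,0.39912936091423035
"""

steps, values = load_scalar_csv(io.StringIO(sample))
print(steps[0], round(values[-1], 4))  # prints: 512 0.3991
```

In practice the same helper could be pointed at any of the files under `logger.experiment.log_dir` to plot the curves, e.g. with matplotlib.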
Conclusion and possible improvements¶
In this tutorial, we have learned:
how to write a Trainer, including building its components and registering them in the trainer;
how to code a DQN algorithm, including how to create a policy that uses
QValueNetwork
to pick the action with the highest value;how to build a multiprocessed data collector;
Possible improvements to this tutorial could include:
A prioritized replay buffer could also be used. This will give a higher priority to the samples with the worst value accuracy. Learn more in the replay buffer section of the documentation.
A distributional loss (see
DistributionalDQNLoss
for more information).More fancy exploration techniques, such as
NoisyLinear
layers and the like.
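To illustrate the prioritized-replay idea mentioned above, here is a toy, pure-Python sketch of proportional prioritization. This is a conceptual illustration only, not TorchRL's actual implementation (which lives in replay-buffer classes such as the prioritized `TensorDictReplayBuffer` variants); the class name and its methods are invented for this example:

```python
import random


class TinyPrioritizedBuffer:
    """Toy proportional prioritized replay: transitions are sampled with
    probability proportional to priority ** alpha, and priorities are
    refreshed from TD errors after each loss computation."""

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.data = []
        self.priorities = []

    def add(self, transition, priority=1.0):
        # Evict the oldest transition once the buffer is full.
        if len(self.data) >= self.capacity:
            self.data.pop(0)
            self.priorities.pop(0)
        self.data.append(transition)
        self.priorities.append(priority ** self.alpha)

    def sample(self, batch_size):
        # Sampling weight is proportional to the stored priority.
        idx = random.choices(
            range(len(self.data)), weights=self.priorities, k=batch_size
        )
        return idx, [self.data[i] for i in idx]

    def update_priority(self, indices, td_errors):
        # Samples with larger TD error get a higher chance of being replayed.
        for i, err in zip(indices, td_errors):
            self.priorities[i] = (abs(err) + 1e-6) ** self.alpha


buf = TinyPrioritizedBuffer(capacity=100)
for t in range(10):
    buf.add({"step": t})
idx, batch = buf.sample(4)
buf.update_priority(idx, [0.5] * len(idx))
print(len(batch))  # prints: 4
```

In the trainer above, the `ReplayBufferTrainer` hook would be the natural place to swap in a prioritized buffer, since it already owns the extend/sample/update cycle.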
Total running time of the script: (2 minutes 40.957 seconds)
Estimated memory usage: 1267 MB