torchrl.objectives 封包¶
TorchRL 提供一系列用於訓練腳本中的損失。 目的是擁有易於重複使用/交換且具有簡單簽名的損失。
TorchRL 損失的主要特徵是
它們是有狀態的物件:它們包含可訓練參數的副本,因此
loss_module.parameters()
提供訓練演算法所需的任何內容。它們遵循
tensordict
慣例:torch.nn.Module.forward()
方法將接收一個 tensordict 作為輸入,其中包含傳回損失值所需的所有資訊。它們輸出一個
tensordict.TensorDict
實例,損失值寫在"loss_<smth>"
下,其中smth
是一個描述損失的字串。 tensordict 中的其他鍵可能是訓練期間記錄的有用指標。
注意
我們傳回獨立損失的原因是讓使用者可以為不同的參數集使用不同的最佳化器。 損失的總和可以透過簡單地透過以下方式完成
>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
注意
損失中的初始化參數可以透過查詢 get_stateful_net()
來完成,它將傳回網路的有狀態版本,可以像任何其他模組一樣進行初始化。 如果修改是就地完成的,它將被向下傳輸到使用相同參數集的任何其他模組(在損失內部和外部):例如,從損失中修改 actor_network
參數也將修改收集器中的 actor。 如果參數被異地修改,可以使用 from_stateful_net()
將損失中的參數重設為新值。
torch.vmap 和隨機性¶
TorchRL 的損失模組大量呼叫 vmap()
以分攤在迴圈中呼叫多個相似模型的成本,並改為向量化這些操作。vmap 需要被明確告知在呼叫中需要生成隨機數時該怎麼做。為此,需要設定一個隨機性模式,並且必須是 “error” (預設,處理偽隨機函數時會出錯)、“same” (在批次中複製結果) 或 “different” (批次中的每個元素都被單獨處理) 其中之一。依賴預設值通常會導致類似以下的錯誤
>>> RuntimeError: vmap: called random operation while in randomness error mode.
由於對 vmap 的呼叫被埋藏在損失模組中,TorchRL 提供了一個介面,可以從外部透過 loss.vmap_randomness = str_value 設定該 vmap 模式,詳情請參閱 vmap_randomness()
。
如果未檢測到隨機模組,則 LossModule.vmap_randomness
預設為 “error”,否則預設為 “different”。預設情況下,只有有限數量的模組被列為隨機模組,但可以使用 add_random_module()
函數擴展該列表。
訓練價值函數¶
TorchRL 提供了一系列的**價值估計器**,例如 TD(0)、TD(1)、TD(\(\lambda\)) 和 GAE。簡而言之,價值估計器是數據(主要是獎勵和完成狀態)和狀態價值(即,由擬合用於估計狀態價值的函數返回的值)的函數。要了解更多關於價值估計器的信息,請查看 Sutton 和 Barto 的 RL 介紹,特別是關於價值迭代和 TD 學習的章節。它根據數據和代理映射,對狀態或狀態-動作對之後的折扣回報進行一些有偏差的估計。這些估計器在以下兩種情況中使用
為了訓練價值網路以學習“真實”的狀態價值(或狀態-動作價值)映射,需要一個目標價值來擬合它。估計器越好(偏差越小,方差越小),價值網路就會越好,這反過來可以顯著加快策略訓練。通常,價值網路損失看起來像
>>> value = value_network(states) >>> target_value = value_estimator(rewards, done, value_network(next_state)) >>> value_net_loss = (value - target_value).pow(2).mean()
計算用於策略優化的“優勢”信號。優勢是價值估計(來自估計器,即來自“真實”數據)和價值網路的輸出(即該價值的代理)之間的差值。一個正優勢可以被看作是一個信號,表明該策略實際上比預期的表現更好,從而表明如果將該軌跡作為範例,則有改進的空間。相反,負優勢表示該策略的表現不如預期。
事情並不像上面的例子那麼容易,並且計算價值估計器或優勢的公式可能比這稍微複雜一些。為了幫助用戶靈活地使用一個或另一個價值估計器,我們提供了一個簡單的 API 來隨時更改它。這是一個 DQN 的例子,但所有模組都將遵循類似的結構
>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> loss_module = DQNLoss(actor)
>>> kwargs = {"gamma": 0.9, "lmbda": 0.9}
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs)
ValueEstimators
類別列舉了可供選擇的價值估計器。這使得用戶可以輕鬆地依靠自動完成來做出選擇。
|
RL 損失的父類別。 |
DQN¶
|
DQN 損失類別。 |
|
一個分配 DQN 損失類別。 |
DDPG¶
|
DDPG 損失類別。 |
SAC¶
|
SAC 損失的 TorchRL 實現。 |
|
離散 SAC 損失模組。 |
REDQ¶
|
REDQ 損失模組。 |
CrossQ¶
|
CrossQ 損失的 TorchRL 實現。 |
IQL¶
|
IQL 損失的 TorchRL 實現。 |
|
離散 IQL 損失的 TorchRL 實現。 |
CQL¶
|
連續 CQL 損失的 TorchRL 實現。 |
|
離散 CQL 損失的 TorchRL 實現。 |
GAIL¶
|
Generative Adversarial Imitation Learning (GAIL) 損失函數的 TorchRL 實作。 |
DT¶
|
Online Decision Transformer 損失函數的 TorchRL 實作。 |
|
Online Decision Transformer 損失函數的 TorchRL 實作。 |
TD3¶
|
TD3 損失模組。 |
TD3+BC¶
|
TD3+BC 損失模組。 |
PPO¶
|
PPO 損失函數的父類別。 |
|
Clipped PPO 損失函數。 |
|
KL Penalty PPO 損失函數。 |
A2C¶
|
A2C 損失函數的 TorchRL 實作。 |
Reinforce¶
|
Reinforce 損失模組。 |
Dreamer¶
|
Dreamer Actor 損失函數。 |
|
Dreamer Model 損失函數。 |
|
Dreamer Value 損失函數。 |
Multi-agent objectives¶
這些目標函數專屬於多代理人演算法。
QMixer¶
|
QMixer 損失函數類別。 |
Returns¶
|
數值函數模組的抽象父類別。 |
|
優勢函數的時間差分 (TD(0)) 估計。 |
|
\(\infty\)-優勢函數的時間差分 (TD(1)) 估計。 |
|
優勢函數的 TD(\(\lambda\)) 估計。 |
|
廣義優勢估計函數的類別包裝器。 |
|
軌跡的 TD(0) 折扣回報估計。 |
|
軌跡的 TD(0) 優勢估計。 |
|
TD(1) 回報估計。 |
|
向量化 TD(1) 回報估計。 |
|
TD(1) 優勢估計。 |
|
向量化 TD(1) 優勢估計。 |
|
TD(\(\lambda\)) 回報估計。 |
向量化 TD(\(\lambda\)) 回報估計。 |
|
TD(\(\lambda\)) 優勢估計。 |
|
向量化 TD(\(\lambda\)) 優勢估計。 |
|
軌跡的廣義優勢估計。 |
|
向量化軌跡的廣義優勢估計。 |
|
|
計算給定多個軌跡和 episode 結束時的折扣累積獎勵總和。 |
工具程式¶
|
計算兩個 tensors 之間的距離損失。 |
|
將網路排除在計算圖之外的上下文管理器。 |
|
將參數列表排除在計算圖之外的上下文管理器。 |
|
計算下一個狀態值(無梯度)以計算目標值。 |
|
Double DQN/DDPG 中用於目標網路更新的軟更新類別。 |
|
Double DQN/DDPG 中用於目標網路更新的硬更新類別(與軟更新相比)。 |
|
自訂建構估算器的值函數列舉器。 |
|
預設值函數關鍵字參數產生器。 |