捷徑

next_state_value

class torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: Optional[TensorDictModule] = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, **kwargs)[source]

計算下一個狀態值(沒有梯度)以計算目標值。

目標值通常用於計算距離損失(例如 MSE)

L = Sum[ (q_value - target_value)^2 ]

目標值計算為

r + gamma ** n_steps_to_next * value_next_state

如果獎勵是立即獎勵,則 n_steps_to_next=1。如果使用 N 步獎勵,則 n_steps_to_next 從輸入 tensordict 收集。

參數:
  • tensordict (TensorDictBase) – 包含獎勵和 done 鍵的 Tensordict(以及用於 n 步獎勵的 n_steps_to_next 鍵)。

  • operator (ProbabilisticTDModule, optional) – 值函數運算符。調用時應在輸入 tensordict 中寫入 ‘next_val_key’ 鍵值。如果給出 pred_next_val,則不需要提供它。

  • next_val_key (str, optional) – 將寫入下一個值的鍵。預設值:‘state_action_value’

  • gamma (float, optional) – 回傳折扣率。預設值:0.99

  • pred_next_val (Tensor, optional) – 如果未使用運算子計算下一個狀態值,則可以提供該值。

回傳值:

一個 Tensor,其大小與包含預測狀態值的輸入 tensordict 相同。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源