next_state_value¶
- class torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: Optional[TensorDictModule] = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, **kwargs)[source]¶
計算下一個狀態值(沒有梯度)以計算目標值。
- 目標值通常用於計算距離損失(例如 MSE)
L = Sum[ (q_value - target_value)^2 ]
- 目標值計算為
r + gamma ** n_steps_to_next * value_next_state
如果獎勵是立即獎勵,則 n_steps_to_next=1。如果使用 N 步獎勵,則 n_steps_to_next 從輸入 tensordict 收集。
- 參數:
tensordict (TensorDictBase) – 包含獎勵和 done 鍵的 Tensordict(以及用於 n 步獎勵的 n_steps_to_next 鍵)。
operator (ProbabilisticTDModule, optional) – 值函數運算符。調用時應在輸入 tensordict 中寫入 ‘next_val_key’ 鍵值。如果給出 pred_next_val,則不需要提供它。
next_val_key (str, optional) – 將寫入下一個值的鍵。預設值:‘state_action_value’
gamma (float, optional) – 回傳折扣率。預設值:0.99
pred_next_val (Tensor, optional) – 如果未使用運算子計算下一個狀態值,則可以提供該值。
- 回傳值:
一個 Tensor,其大小與包含預測狀態值的輸入 tensordict 相同。