ValueEstimatorBase¶
- class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[source]¶
值函數模組的抽象父類別。
它的
ValueFunctionBase.forward()
方法將計算值(由值網路給定)和值估計(由值估計器給定),以及優勢,並將這些值寫入輸出 tensordict。如果只需要值估計,則應改用
ValueFunctionBase.value_estimate()
。- abstract forward(tensordict: TensorDictBase, *, params: Optional[TensorDictBase] = None, target_params: Optional[TensorDictBase] = None) TensorDictBase [source]¶
計算 tensordict 中資料給定的優勢估計。
如果提供了函數式模組,則可以將包含參數(如果相關,則包含目標參數)的巢狀 TensorDict 傳遞給該模組。
- 參數:
tensordict (TensorDictBase) – 一個包含資料的 TensorDict(觀察鍵,
"action"
,("next", "reward")
,("next", "done")
,("next", "terminated")
和"next"
tensordict 狀態,如環境所返回),這些資料對於計算值估計和 TDEstimate 是必要的。傳遞給此模組的資料應結構化為[*B, T, *F]
,其中B
是批次大小,T
是時間維度,F
是特徵維度。 tensordict 必須具有形狀[*B, T]
。- 關鍵字參數:
params (TensorDictBase, optional) – 一個巢狀 TensorDict,包含要傳遞到函數式值網路模組的參數。
target_params (TensorDictBase, 選擇性) – 包含目標參數的巢狀 TensorDict,這些目標參數將被傳遞到函數式價值網路模組。
- 回傳值:
一個更新後的 TensorDict,包含建構子中定義的 advantage 和 value_error 鍵。
- value_estimate(tensordict, target_params: Optional[TensorDictBase] = None, next_value: Optional[Tensor] = None, **kwargs)[原始碼]¶
取得價值估計,通常用作價值網路的目標價值。
如果狀態價值鍵存在於
tensordict.get(("next", self.tensor_keys.value))
下,則將使用此價值,而無需再次使用價值網路。- 參數:
tensordict (TensorDictBase) – 包含要讀取資料的 tensordict。
target_params (TensorDictBase, 選擇性) – 包含目標參數的巢狀 TensorDict,這些目標參數將被傳遞到函數式價值網路模組。
next_value (torch.Tensor, 選擇性) – 下一個狀態或狀態-動作對的價值。與
target_params
互斥。**kwargs – 要傳遞給價值網路的關鍵字引數。
回傳值: 對應於狀態價值的張量。