快捷方式

ValueEstimatorBase

class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[source]

值函數模組的抽象父類別。

它的 ValueFunctionBase.forward() 方法將計算值(由值網路給定)和值估計(由值估計器給定),以及優勢,並將這些值寫入輸出 tensordict。

如果只需要值估計,則應改用 ValueFunctionBase.value_estimate()

abstract forward(tensordict: TensorDictBase, *, params: Optional[TensorDictBase] = None, target_params: Optional[TensorDictBase] = None) TensorDictBase[source]

計算 tensordict 中資料給定的優勢估計。

如果提供了函數式模組,則可以將包含參數(如果相關,則包含目標參數)的巢狀 TensorDict 傳遞給該模組。

參數:

tensordict (TensorDictBase) – 一個包含資料的 TensorDict(觀察鍵,"action", ("next", "reward"), ("next", "done"), ("next", "terminated")"next" tensordict 狀態,如環境所返回),這些資料對於計算值估計和 TDEstimate 是必要的。傳遞給此模組的資料應結構化為 [*B, T, *F],其中 B 是批次大小,T 是時間維度,F 是特徵維度。 tensordict 必須具有形狀 [*B, T]

關鍵字參數:
  • params (TensorDictBase, optional) – 一個巢狀 TensorDict,包含要傳遞到函數式值網路模組的參數。

  • target_params (TensorDictBase, 選擇性) – 包含目標參數的巢狀 TensorDict,這些目標參數將被傳遞到函數式價值網路模組。

回傳值:

一個更新後的 TensorDict,包含建構子中定義的 advantage 和 value_error 鍵。

set_keys(**kwargs) None[原始碼]

設定 tensordict 的鍵名。

value_estimate(tensordict, target_params: Optional[TensorDictBase] = None, next_value: Optional[Tensor] = None, **kwargs)[原始碼]

取得價值估計,通常用作價值網路的目標價值。

如果狀態價值鍵存在於 tensordict.get(("next", self.tensor_keys.value)) 下,則將使用此價值,而無需再次使用價值網路。

參數:
  • tensordict (TensorDictBase) – 包含要讀取資料的 tensordict。

  • target_params (TensorDictBase, 選擇性) – 包含目標參數的巢狀 TensorDict,這些目標參數將被傳遞到函數式價值網路模組。

  • next_value (torch.Tensor, 選擇性) – 下一個狀態或狀態-動作對的價值。與 target_params 互斥。

  • **kwargs – 要傳遞給價值網路的關鍵字引數。

回傳值: 對應於狀態價值的張量。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深度教學

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源