捷徑

ProbabilisticTensorDictModule

class tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)

一個機率 TD 模組。

ProbabilisticTensorDictModule 是一個非參數模組,代表一個機率分佈。它使用指定的 in_keys 從輸入 TensorDict 讀取分佈參數。輸出會根據某些規則進行採樣,這些規則由輸入的 default_interaction_type 參數和 interaction_type() 全域函式指定。

ProbabilisticTensorDictModule 可用於建構分佈(透過 get_dist() 方法)和/或從此分佈取樣(透過對模組進行常規 __call__())。

一個 ProbabilisticTensorDictModule 實例有兩個主要特點: - 它讀取和寫入 TensorDict 物件 - 它使用真實映射 R^n -> R^m 來建立 R^d 中的分佈,可以從中取樣或計算值。

__call__ / forward 方法被呼叫時,會建立一個分佈,並計算一個值(使用 'mean'、'mode'、'median' 屬性或 'rsample'、'sample' 方法)。如果提供的 TensorDict 已經擁有所有需要的鍵值對,則會跳過取樣步驟。

預設情況下,ProbabilisticTensorDictModule 分佈類別是一個 Delta 分佈,這使得 ProbabilisticTensorDictModule 成為一個簡單的確定性映射函式包裝器。

參數:
  • in_keys (NestedKeyNestedKey 列表dict) – 將從輸入 TensorDict 讀取並用於建構分佈的鍵。重要的是,如果它是一個 NestedKey 列表或 NestedKey,則這些鍵的葉節點(最後一個元素)必須與感興趣的分佈類別使用的關鍵字匹配,例如,Normal 分佈的 "loc""scale" 等。如果 in_keys 是一個字典,則鍵是分佈的鍵,值是在 tensordict 中與相應分佈鍵匹配的鍵。

  • out_keys (NestedKeyNestedKey 列表) – 樣本值將被寫入的鍵。重要的是,如果在輸入 TensorDict 中找到這些鍵,則會跳過取樣步驟。

  • default_interaction_mode (str, 可選) – 已棄用 僅限關鍵字的參數。請改用 default_interaction_type。

  • default_interaction_type (InteractionType, 可選) –

    僅限關鍵字的參數。用於檢索輸出值的預設方法。應該是 InteractionType 之一:MODE、MEDIAN、MEAN 或 RANDOM(在這種情況下,值是從分佈中隨機取樣)。預設值為 MODE。

    注意

    當繪製樣本時,ProbabilisticTensorDictModule 實例將首先尋找由 interaction_type() 全域函式指定的交互模式。如果這返回 None(其預設值),則將使用 ProbabilisticTDModule 實例的 default_interaction_type。請注意,DataCollectorBase 實例將預設使用 set_interaction_type 設為 tensordict.nn.InteractionType.RANDOM

    注意

    在某些情況下,可能無法透過對應的屬性直接取得眾數、中位數或平均數值。為了改善這種情況,ProbabilisticTensorDictModule 會先嘗試呼叫 get_mode()get_median()get_mean() 來取得數值(如果該方法存在)。

  • distribution_class (Type, optional) –

    僅限關鍵字參數。用於抽樣的 torch.distributions.Distribution 類別。預設值為 Delta

    注意

    如果 distribution 類別屬於 CompositeDistribution 類型,則可以從此類別的 distribution_kwargs 關鍵字參數提供的 "distribution_map""name_map" 關鍵字參數中直接推斷出 out_keys,在這種情況下,out_keys 是可選的。

  • distribution_kwargs (dict, optional) – 僅限關鍵字參數。要傳遞給 distribution 的關鍵字參數對。

  • return_log_prob (bool, optional) – 僅限關鍵字參數。如果 True,則 distribution 樣本的 log-probability 將會使用索引鍵 log_prob_key 寫入 tensordict 中。預設值為 False

  • log_prob_key (NestedKey, optional) – 如果 return_log_prob = True,則 log_prob 的寫入索引鍵。預設值為 ‘sample_log_prob’

  • cache_dist (bool, optional) – 僅限關鍵字參數。實驗性功能:如果 True,distribution 的參數(即 module 的輸出)將會與樣本一起寫入 tensordict 中。這些參數可用於稍後重新計算原始 distribution(例如,計算用於抽樣動作的 distribution 與 PPO 中更新的 distribution 之間的 divergence)。預設值為 False

  • n_empirical_estimate (int, optional) – 僅限關鍵字參數。計算經驗平均值時使用的樣本數,當無法直接取得平均值時使用。預設值為 1000。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch.distributions import Normal, Independent
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
...     net, in_keys=["input", "hidden"], out_keys=["params"]
... )
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> def IndepNormal(**kwargs):
...     return Independent(Normal(**kwargs), 1)
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=IndepNormal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(
...     module, normal_params, prob_module
... )
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> with params.to_module(td_module):
...     dist = td_module.get_dist(td)
>>> print(dist)
Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1)
>>> # we can also apply the module to the TensorDict with vmap
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, _requires_sample: bool = True) TensorDictBase

定義每次呼叫時執行的計算。

應該由所有子類別覆寫。

注意

雖然正向傳遞的配方需要在這個函式中定義,但應該在這個函式之後呼叫 Module 實例,而不是這個函式,因為前者會處理執行註冊的 hooks,而後者會靜默地忽略它們。

get_dist(tensordict: TensorDictBase) Distribution

使用輸入 tensordict 中提供的參數建立 torch.distribution.Distribution 實例。

log_prob(tensordict)

寫入 distribution 樣本的 log-probability。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源