get_primers_from_module
- torchrl.modules.utils.get_primers_from_module(module)[source]
Get all tensordict primers from all submodules of a module.
This method is useful for retrieving primers from modules that are contained within a parent module.
- Parameters:
module (torch.nn.Module) – The parent module.
- Returns:
A TensorDictPrimer transform.
- Return type:
TensorDictPrimer
Examples
>>> from torchrl.modules.utils import get_primers_from_module
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> # Define a GRU module
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     num_layers=1,
...     in_keys=["input", "recurrent_state", "is_init"],
...     out_keys=["features", ("next", "recurrent_state")],
... )
>>> # Define a head module
>>> head = TensorDictModule(
...     MLP(
...         in_features=10,
...         out_features=10,
...         num_cells=[],
...     ),
...     in_keys=["features"],
...     out_keys=["output"],
... )
>>> # Create a sequential model
>>> model = TensorDictSequential(gru_module, head)
>>> # Retrieve primers from the model
>>> primers = get_primers_from_module(model)
>>> print(primers)
TensorDictPrimer(primers=Composite(
    recurrent_state: UnboundedContinuous(
        shape=torch.Size([1, 10]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
    device=None, shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
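The returned transform is typically appended to an environment so that the recurrent-state entries are created (primed) in the tensordicts produced at reset and during rollouts. The following is a minimal sketch, not part of the original documentation: it assumes gymnasium is installed and that the "CartPole-v1" task is available, and it only attaches the primer to the environment without reproducing the model's input-key wiring from the example above.

>>> # Hedged usage sketch: attach the primer returned above to an environment.
>>> # Assumption: GymEnv("CartPole-v1") is just an illustration; any torchrl EnvBase works.
>>> from torchrl.envs import Compose, GymEnv, InitTracker, TransformedEnv
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1"),
...     Compose(
...         InitTracker(),  # writes the "is_init" flag that GRUModule expects
...         primers,        # zero-initializes "recurrent_state" in reset/rollout tensordicts
...     ),
... )
>>> td = env.reset()
>>> "recurrent_state" in td.keys()
True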