torch.jit.annotate¶
- torch.jit.annotate(the_type, the_value)[source][source]¶
用於在 TorchScript 編譯器中指定 the_value 的類型。
此方法是一個直通函式,會傳回 the_value,用於提示 TorchScript 編譯器 the_value 的類型。在 TorchScript 之外執行時,它是一個空操作。
儘管 TorchScript 可以推斷大多數 Python 表達式的正確類型,但在某些情況下,類型推斷可能會出錯,包括
空容器,例如 [] 和 {},TorchScript 假設它們是 Tensor 的容器
可選類型,例如 Optional[T],但已分配類型 T 的有效值,TorchScript 會假設它是類型 T 而不是 Optional[T]
請注意,annotate() 在 torch.nn.Module 子類別的 __init__ 方法中不起作用,因為它是在 eager 模式下執行的。若要註釋 torch.nn.Module 屬性的類型,請改用
Attribute()
。範例
import torch from typing import Dict @torch.jit.script def fn(): # Telling TorchScript that this empty dictionary is a (str -> int) dictionary # instead of default dictionary type of (str -> Tensor). d = torch.jit.annotate(Dict[str, int], {}) # Without `torch.jit.annotate` above, following statement would fail because of # type mismatch. d["name"] = 20
- 參數
the_type – 應該作為類型提示傳遞給 TorchScript 編譯器的 Python 類型,用於 the_value
the_value – 用於提示類型的值或表達式。
- 回傳值
the_value 作為回傳值傳回。