快捷鍵

get_graph_node_names

torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) Tuple[List[str], List[str]][原始碼]

開發人員公用程式,用於依執行順序傳回節點名稱。請參閱 create_feature_extractor() 底下的節點名稱注意事項。適用於查看哪些節點名稱可用於特徵提取。節點名稱無法輕易地直接從模型的程式碼中讀取,原因有二

  1. 並非所有子模組都會被追蹤。來自 torch.nn 的模組都屬於此類。

  2. 代表相同運算或葉模組重複應用的節點會取得 _{counter} 後綴。

模型會被追蹤兩次:一次在訓練模式,另一次在評估模式。訓練模式和評估模式的節點名稱集都會傳回。

如需此處使用的節點命名慣例的更多詳細資訊,請參閱相關子標題,位於文件中。

參數:
  • model (nn.Module) – 我們想要列印節點名稱的模型

  • tracer_kwargs (dict, optional) – 用於 NodePathTracer 的關鍵字引數字典(它們最終會傳遞到 torch.fx.Tracer)。預設情況下,它會設定為包裝所有 torchvision 運算並將其設為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供 tracer_kwargs,則上述預設引數將附加到使用者提供的字典。

  • suppress_diff_warning (bool, optional) – 是否在圖形的訓練版本和評估版本之間存在差異時,抑制警告。預設值為 False。

  • concrete_args (Optional[Dict[str, any]]) – 不應視為 Proxies 的具體引數。根據 Pytorch 文件,此參數的 API 可能無法保證。

傳回:

從在訓練模式下追蹤模型取得的節點名稱列表,以及從在評估模式下追蹤模型取得的另一個節點名稱列表。

傳回類型:

tuple(list, list)

範例

>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)

文件

取得 PyTorch 的全面開發人員文件

檢視文件

教學

取得適用於初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源