get_graph_node_names¶
- torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) Tuple[List[str], List[str]] [原始碼]¶
開發人員公用程式,用於依執行順序傳回節點名稱。請參閱
create_feature_extractor()
底下的節點名稱注意事項。適用於查看哪些節點名稱可用於特徵提取。節點名稱無法輕易地直接從模型的程式碼中讀取,原因有二並非所有子模組都會被追蹤。來自
torch.nn
的模組都屬於此類。代表相同運算或葉模組重複應用的節點會取得
_{counter}
後綴。
模型會被追蹤兩次:一次在訓練模式,另一次在評估模式。訓練模式和評估模式的節點名稱集都會傳回。
如需此處使用的節點命名慣例的更多詳細資訊,請參閱相關子標題,位於文件中。
- 參數:
model (nn.Module) – 我們想要列印節點名稱的模型
tracer_kwargs (dict, optional) – 用於
NodePathTracer
的關鍵字引數字典(它們最終會傳遞到 torch.fx.Tracer)。預設情況下,它會設定為包裝所有 torchvision 運算並將其設為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供 tracer_kwargs,則上述預設引數將附加到使用者提供的字典。suppress_diff_warning (bool, optional) – 是否在圖形的訓練版本和評估版本之間存在差異時,抑制警告。預設值為 False。
concrete_args (Optional[Dict[str, any]]) – 不應視為 Proxies 的具體引數。根據 Pytorch 文件,此參數的 API 可能無法保證。
- 傳回:
從在訓練模式下追蹤模型取得的節點名稱列表,以及從在評估模式下追蹤模型取得的另一個節點名稱列表。
- 傳回類型:
範例
>>> model = torchvision.models.resnet18() >>> train_nodes, eval_nodes = get_graph_node_names(model)