create_feature_extractor¶
- torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) GraphModule [原始碼]¶
建立新的圖形模組,從給定模型傳回中繼節點,作為字典,其中使用者指定的鍵為字串,而請求的輸出為值。這是透過 FX 重新撰寫模型的運算圖來實現,以傳回所需的節點作為輸出。所有未使用的節點及其對應的參數都會被移除。
所需的輸出節點必須指定為
.
分隔的路徑,從頂層模組向下走到葉操作或葉模組。如需關於此處使用的節點命名慣例的更多詳細資訊,請參閱 相關子標題,位於 文件 中。並非所有模型都可透過 FX 追蹤,但經過一些調整,它們可以協同運作。以下是一些提示(非詳盡列表)
如果您不需要追蹤特定的、有問題的子模組,請將其轉換為「葉模組」,方法是將
leaf_modules
列表作為tracer_kwargs
之一傳遞(請參閱以下範例)。它將不會被追蹤,而是產生的圖形將保留對該模組前向方法的參考。同樣地,您可以透過將
autowrap_functions
列表作為tracer_kwargs
之一傳遞,將函式轉換為葉函式(請參閱以下範例)。某些內建的 Python 函式可能會產生問題。例如,
int
將在追蹤期間引發錯誤。您可以將它們包裝在您自己的函式中,然後將其在autowrap_functions
中作為tracer_kwargs
之一傳遞。
如需關於 FX 的更多資訊,請參閱 torch.fx 文件。
- 參數:
model (nn.Module) – 我們將在其上提取特徵的模型
return_nodes (list 或 dict, optional) – 包含節點名稱(或部分名稱 - 請參閱上方註釋)的
List
或Dict
,將為這些節點傳回激活值。如果是Dict
,則鍵是節點名稱,值是圖形模組傳回字典的使用者指定鍵。如果是List
,則將其視為Dict
,將節點規格字串直接映射到輸出名稱。在指定train_return_nodes
和eval_return_nodes
的情況下,不應指定此項。train_return_nodes (list 或 dict, optional) – 類似於
return_nodes
。如果訓練模式的傳回節點與評估模式的傳回節點不同,則可以使用此項。如果指定了此項,則也必須指定eval_return_nodes
,且不應指定return_nodes
。eval_return_nodes (list 或 dict, optional) – 類似於
return_nodes
。如果訓練模式的傳回節點與評估模式的傳回節點不同,則可以使用此項。如果指定了此項,則也必須指定train_return_nodes
,且不應指定 return_nodes。tracer_kwargs (dict, optional) –
NodePathTracer
的關鍵字引數字典(它將它們傳遞給其父類別 torch.fx.Tracer)。預設情況下,它將設定為包裝並將所有 torchvision 運算設為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供 tracer_kwargs,上述預設引數將附加至使用者提供的字典。suppress_diff_warning (bool, optional) – 是否在圖形的訓練版本和評估版本之間存在差異時,抑制警告。預設值為 False。
concrete_args (Optional[Dict[str, any]]) – 不應視為 Proxies 的具體引數。根據 Pytorch 文件,此參數的 API 可能無法保證。
範例
>>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> model = create_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] >>> # Specifying leaf modules and leaf functions >>> def leaf_function(x): >>> # This would raise a TypeError if traced through >>> return int(x) >>> >>> class LeafModule(torch.nn.Module): >>> def forward(self, x): >>> # This would raise a TypeError if traced through >>> int(x.shape[0]) >>> return torch.nn.functional.relu(x + 4) >>> >>> class MyModule(torch.nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.conv = torch.nn.Conv2d(3, 1, 3) >>> self.leaf_module = LeafModule() >>> >>> def forward(self, x): >>> leaf_function(x.shape[0]) >>> x = self.conv(x) >>> return self.leaf_module(x) >>> >>> model = create_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]})