捷徑

create_feature_extractor

torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) GraphModule[原始碼]

建立新的圖形模組,從給定模型傳回中繼節點,作為字典,其中使用者指定的鍵為字串,而請求的輸出為值。這是透過 FX 重新撰寫模型的運算圖來實現,以傳回所需的節點作為輸出。所有未使用的節點及其對應的參數都會被移除。

所需的輸出節點必須指定為 . 分隔的路徑,從頂層模組向下走到葉操作或葉模組。如需關於此處使用的節點命名慣例的更多詳細資訊,請參閱 相關子標題,位於 文件 中。

並非所有模型都可透過 FX 追蹤,但經過一些調整,它們可以協同運作。以下是一些提示(非詳盡列表)

  • 如果您不需要追蹤特定的、有問題的子模組,請將其轉換為「葉模組」,方法是將 leaf_modules 列表作為 tracer_kwargs 之一傳遞(請參閱以下範例)。它將不會被追蹤,而是產生的圖形將保留對該模組前向方法的參考。

  • 同樣地,您可以透過將 autowrap_functions 列表作為 tracer_kwargs 之一傳遞,將函式轉換為葉函式(請參閱以下範例)。

  • 某些內建的 Python 函式可能會產生問題。例如,int 將在追蹤期間引發錯誤。您可以將它們包裝在您自己的函式中,然後將其在 autowrap_functions 中作為 tracer_kwargs 之一傳遞。

如需關於 FX 的更多資訊,請參閱 torch.fx 文件

參數:
  • model (nn.Module) – 我們將在其上提取特徵的模型

  • return_nodes (listdict, optional) – 包含節點名稱(或部分名稱 - 請參閱上方註釋)的 ListDict,將為這些節點傳回激活值。如果是 Dict,則鍵是節點名稱,值是圖形模組傳回字典的使用者指定鍵。如果是 List,則將其視為 Dict,將節點規格字串直接映射到輸出名稱。在指定 train_return_nodeseval_return_nodes 的情況下,不應指定此項。

  • train_return_nodes (listdict, optional) – 類似於 return_nodes。如果訓練模式的傳回節點與評估模式的傳回節點不同,則可以使用此項。如果指定了此項,則也必須指定 eval_return_nodes,且不應指定 return_nodes

  • eval_return_nodes (listdict, optional) – 類似於 return_nodes。如果訓練模式的傳回節點與評估模式的傳回節點不同,則可以使用此項。如果指定了此項,則也必須指定 train_return_nodes,且不應指定 return_nodes

  • tracer_kwargs (dict, optional) – NodePathTracer 的關鍵字引數字典(它將它們傳遞給其父類別 torch.fx.Tracer)。預設情況下,它將設定為包裝並將所有 torchvision 運算設為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供 tracer_kwargs,上述預設引數將附加至使用者提供的字典。

  • suppress_diff_warning (bool, optional) – 是否在圖形的訓練版本和評估版本之間存在差異時,抑制警告。預設值為 False。

  • concrete_args (Optional[Dict[str, any]]) – 不應視為 Proxies 的具體引數。根據 Pytorch 文件,此參數的 API 可能無法保證。

範例

>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> model = create_feature_extractor(
>>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>>     [('feat1', torch.Size([1, 64, 56, 56])),
>>>      ('feat2', torch.Size([1, 256, 14, 14]))]

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
>>>     # This would raise a TypeError if traced through
>>>     return int(x)
>>>
>>> class LeafModule(torch.nn.Module):
>>>     def forward(self, x):
>>>         # This would raise a TypeError if traced through
>>>         int(x.shape[0])
>>>         return torch.nn.functional.relu(x + 4)
>>>
>>> class MyModule(torch.nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.conv = torch.nn.Conv2d(3, 1, 3)
>>>         self.leaf_module = LeafModule()
>>>
>>>     def forward(self, x):
>>>         leaf_function(x.shape[0])
>>>         x = self.conv(x)
>>>         return self.leaf_module(x)
>>>
>>> model = create_feature_extractor(
>>>     MyModule(), return_nodes=['leaf_module'],
>>>     tracer_kwargs={'leaf_modules': [LeafModule],
>>>                    'autowrap_functions': [leaf_function]})

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源