快捷方式

torch.hub

Pytorch Hub 是一個預訓練模型儲存庫,旨在促進研究的可重現性。

發布模型

Pytorch Hub 支援透過新增簡單的 hubconf.py 檔案,將預訓練模型(模型定義和預訓練權重)發布到 GitHub 儲存庫;

hubconf.py 可以有多個進入點。每個進入點都定義為 Python 函式(範例:您要發布的預訓練模型)。

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

如何實作進入點?

如果我們在 pytorch/vision/hubconf.py 中擴展實作,以下程式碼片段指定 resnet18 模型的進入點。在大多數情況下,在 hubconf.py 中匯入正確的函式就足夠了。在這裡,我們只想使用擴展版本作為範例來說明它的工作原理。您可以在pytorch/vision repo中看到完整腳本

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
  • dependencies 變數是一個需要載入模型的套件名稱列表。請注意,這可能與訓練模型所需的依賴項略有不同。

  • argskwargs 會傳遞給真正的可呼叫函式。

  • 函式的 Docstring 就像是說明訊息。它解釋了模型的作用,以及允許使用的位置參數/關鍵字參數。強烈建議在此處新增一些範例。

  • Entrypoint 函式可以回傳一個模型 (nn.module),或者回傳輔助工具來使使用者工作流程更順暢,例如 tokenizer。

  • 以底線開頭的可呼叫物件會被視為輔助函式,不會出現在 torch.hub.list() 中。

  • 預訓練權重可以儲存在 GitHub 儲存庫的本地端,或者透過 torch.hub.load_state_dict_from_url() 載入。如果小於 2GB,建議將其附加到 專案發布 並使用發布的 URL。在上面的範例中,torchvision.models.resnet.resnet18 處理 pretrained,或者您可以將以下邏輯放入 entrypoint 定義中。

if pretrained:
    # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要聲明

  • 發布的模型至少應在一個分支/標籤中。它不能是一個隨機的 commit。

從 Hub 載入模型

Pytorch Hub 提供了方便的 API,可透過 torch.hub.list() 探索 hub 中所有可用的模型,透過 torch.hub.help() 顯示 docstring 和範例,並使用 torch.hub.load() 載入預訓練模型。

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source][source]

列出 github 指定的儲存庫中所有可呼叫的 entrypoint。

參數
  • github (str) – 格式為 "repo_owner/repo_name[:ref]" 的字串,帶有可選的 ref(標籤或分支)。如果未指定 ref,則預設分支假定為 main(如果存在),否則為 master。範例:'pytorch/vision:0.10'

  • force_reload (bool, optional) – 是否丟棄現有快取並強制重新下載。預設值為 False

  • skip_validation (bool, optional) – 如果 False,torchhub 將檢查 github 參數指定的分支或 commit 是否正確屬於儲存庫擁有者。這會向 GitHub API 發出請求;您可以透過設定 GITHUB_TOKEN 環境變數來指定非預設的 GitHub 令牌。預設值為 False

  • trust_repo (bool, str or None) –

    "check", True, FalseNone。此參數於 v1.12 中引入,有助於確保使用者僅執行來自他們信任的儲存庫的程式碼。

    • 如果 False,將會提示使用者是否應該信任該儲存庫。

    • 如果 True,該儲存庫將被新增到信任清單中,並且無需明確確認即可載入。

    • 如果 "check",將會針對快取中的信任儲存庫清單檢查該儲存庫。如果它不在該清單中,則行為將回退到 trust_repo=False 選項。

    • 如果 None:這將引發警告,邀請使用者將 trust_repo 設定為 FalseTrue"check"。這僅用於向後相容性,並將在 v2.0 中移除。

    預設值為 None,最終將在 v2.0 中變更為 "check"

  • verbose (bool, optional) – 如果 False,則靜音關於命中本地快取的訊息。請注意,關於首次下載的訊息無法靜音。預設值為 True

回傳值

可用的可呼叫 entrypoint

回傳值型別

list

範例

>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[原始碼][原始碼]

顯示 entrypoint model 的 docstring。

參數
  • github (str) – 格式為 <repo_owner/repo_name[:ref]> 的字串,其中 ref 為選填(可以是標籤或分支)。如果未指定 ref,預設分支會假設為 main (如果存在),否則為 master。範例:‘pytorch/vision:0.10’

  • model (str) – 在 repo 的 hubconf.py 中定義的 entrypoint 名稱字串

  • force_reload (bool, optional) – 是否丟棄現有快取並強制重新下載。預設值為 False

  • skip_validation (bool, optional) – 如果 False,torchhub 將會檢查由 github 參數指定的 ref 是否正確屬於 repo owner。這會向 GitHub API 發出請求;您可以透過設定 GITHUB_TOKEN 環境變數來指定非預設的 GitHub 令牌。預設值為 False

  • trust_repo (bool, str or None) –

    "check", True, FalseNone。此參數於 v1.12 中引入,有助於確保使用者僅執行來自他們信任的儲存庫的程式碼。

    • 如果 False,將會提示使用者是否應該信任該儲存庫。

    • 如果 True,該儲存庫將被新增到信任清單中,並且無需明確確認即可載入。

    • 如果 "check",將會針對快取中的信任儲存庫清單檢查該儲存庫。如果它不在該清單中,則行為將回退到 trust_repo=False 選項。

    • 如果 None:這將引發警告,邀請使用者將 trust_repo 設定為 FalseTrue"check"。這僅用於向後相容性,並將在 v2.0 中移除。

    預設值為 None,最終將在 v2.0 中變更為 "check"

範例

>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[原始碼][原始碼]

從 github repo 或本機目錄載入模型。

注意:載入模型是典型的用例,但這也可以用於載入其他物件,例如 tokenizers、loss functions 等。

如果 source 是 ‘github’,則 repo_or_dir 應該是 repo_owner/repo_name[:ref] 的形式,其中 ref 為選填(可以是標籤或分支)。

如果 source 是 ‘local’,則 repo_or_dir 應該是本機目錄的路徑。

參數
  • repo_or_dir (str) – 如果 source 是 ‘github’,這應該對應到一個 github repo,格式為 repo_owner/repo_name[:ref],其中 ref 為選填(標籤或分支),例如 ‘pytorch/vision:0.10’。如果未指定 ref,預設分支會假設為 main (如果存在),否則為 master。 如果 source 是 ‘local’,則它應該是本機目錄的路徑。

  • model (str) – 在 repo/dir 的 hubconf.py 中定義的可呼叫(entrypoint)的名稱。

  • *args (optional) – 可呼叫 model 的對應 args。

  • source (str, optional) – ‘github’ 或 ‘local’。指定如何解釋 repo_or_dir。預設值為 ‘github’。

  • trust_repo (bool, str or None) –

    "check", True, FalseNone。此參數於 v1.12 中引入,有助於確保使用者僅執行來自他們信任的儲存庫的程式碼。

    • 如果 False,將會提示使用者是否應該信任該儲存庫。

    • 如果 True,該儲存庫將被新增到信任清單中,並且無需明確確認即可載入。

    • 如果 "check",將會針對快取中的信任儲存庫清單檢查該儲存庫。如果它不在該清單中,則行為將回退到 trust_repo=False 選項。

    • 如果 None:這將引發警告,邀請使用者將 trust_repo 設定為 FalseTrue"check"。這僅用於向後相容性,並將在 v2.0 中移除。

    預設值為 None,最終將在 v2.0 中變更為 "check"

  • force_reload (bool, optional) – 是否強制無條件地重新下載 github repo。 如果 source = 'local',則無效。預設值為 False

  • verbose (bool, optional) – 如果 False,則靜音關於命中本機快取的訊息。 請注意,關於首次下載的訊息無法靜音。 如果 source = 'local',則無效。預設值為 True

  • skip_validation (bool, optional) – 如果 False,torchhub 將檢查 github 參數指定的分支或 commit 是否正確屬於儲存庫擁有者。這會向 GitHub API 發出請求;您可以透過設定 GITHUB_TOKEN 環境變數來指定非預設的 GitHub 令牌。預設值為 False

  • **kwargs (optional) – 可呼叫 model 的對應 kwargs。

回傳值

使用給定的 *args**kwargs 呼叫 model 可呼叫物件的輸出。

範例

>>> # from a github repo
>>> repo = "pytorch/vision"
>>> model = torch.hub.load(
...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
... )
>>> # from a local directory
>>> path = "/some/local/path/pytorch/vision"
>>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[原始碼][原始碼]

將給定 URL 的物件下載到本機路徑。

參數
  • url (str) – 要下載的物件的 URL

  • dst (str) – 將儲存物件的完整路徑,例如 /tmp/temporary_file

  • hash_prefix (str, optional) – 如果不是 None,則下載的 SHA256 檔案應該以 hash_prefix 開頭。預設值:None

  • progress (bool, optional) – 是否將進度條顯示於 stderr。預設值:True

範例

>>> torch.hub.download_url_to_file(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
...     "/tmp/temporary_file",
... )
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source][source]

載入指定 URL 的 Torch 序列化物件。

如果下載的檔案是 zip 檔,將會自動解壓縮。

如果物件已存在於 model_dir 中,則會將其反序列化並傳回。model_dir 的預設值為 <hub_dir>/checkpoints,其中 hub_dir 是由 get_dir() 傳回的目錄。

參數
  • url (str) – 要下載的物件的 URL

  • model_dir (str, optional) – 要儲存物件的目錄

  • map_location (optional) – 一個函數或字典,用於指定如何重新映射儲存位置(請參閱 torch.load)

  • progress (bool, optional) – 是否將進度條顯示於 stderr。預設值:True

  • check_hash (bool, optional) – 如果為 True,則 URL 的檔名部分應遵循命名慣例 filename-<sha256>.ext,其中 <sha256> 是檔案內容的 SHA256 雜湊值的前八位或更多位數。該雜湊值用於確保名稱的唯一性並驗證檔案的內容。預設值:False

  • file_name (str, optional) – 下載檔案的名稱。如果未設定,將使用 url 中的檔名。

  • weights_only (bool, optional) – 如果為 True,則只會載入權重,而不會載入複雜的 pickled 物件。建議用於不信任的來源。請參閱 load() 獲取更多詳細資訊。

回傳值型別

Dict[str, Any]

範例

>>> state_dict = torch.hub.load_state_dict_from_url(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )

執行載入的模型:

請注意,torch.hub.load() 中的 *args**kwargs 用於實例化模型。在載入模型後,如何找出你可以對模型做什麼?建議的工作流程是

  • dir(model) 查看模型的所有可用方法。

  • help(model.foo) 檢查執行 model.foo 需要哪些參數

為了幫助使用者探索,而無需來回查閱文件,我們強烈建議 repo 所有者使函數說明訊息清晰簡潔。包含最小工作範例也很有幫助。

我的下載模型儲存在哪裡?

位置按以下順序使用

  • 呼叫 hub.set_dir(<PATH_TO_HUB_DIR>)

  • $TORCH_HOME/hub,如果設定了環境變數 TORCH_HOME

  • $XDG_CACHE_HOME/torch/hub,如果設定了環境變數 XDG_CACHE_HOME

  • ~/.cache/torch/hub

torch.hub.get_dir()[source][source]

取得用於儲存下載模型和權重的 Torch Hub 快取目錄。

如果未呼叫 set_dir(),則預設路徑為 $TORCH_HOME/hub,其中環境變數 $TORCH_HOME 預設為 $XDG_CACHE_HOME/torch$XDG_CACHE_HOME 遵循 Linux 檔案系統佈局的 X Design Group 規範,如果未設定環境變數,則預設值為 ~/.cache

回傳值型別

str

torch.hub.set_dir(d)[原始碼][原始碼]

選擇性地設定 Torch Hub 目錄,用於儲存下載的模型和權重。

參數

d (str) – 用於儲存下載模型和權重的本機資料夾路徑。

快取邏輯

預設情況下,我們不會在載入檔案後清理它們。如果 Hub 已經存在於 get_dir() 返回的目錄中,Hub 預設會使用快取。

使用者可以透過呼叫 hub.load(..., force_reload=True) 強制重新載入。這將刪除現有的 GitHub 資料夾和下載的權重,並重新初始化全新的下載。當更新發佈到同一個分支時,這很有用,使用者可以隨時掌握最新的版本。

已知限制:

Torch hub 的運作方式是將套件匯入,就好像它是已安裝的一樣。在 Python 中匯入會引入一些副作用。例如,您可以在 Python 快取 sys.modulessys.path_importer_cache 中看到新的項目,這是正常的 Python 行為。這也意味著,如果不同的儲存庫具有相同的子套件名稱(通常是 model 子套件),則從不同的儲存庫匯入不同的模型時,可能會發生匯入錯誤。解決這些匯入錯誤的一種方法是從 sys.modules 字典中移除有問題的子套件;更多詳細資訊可以在 這個 GitHub issue 中找到。

這裡值得一提的一個已知限制:使用者無法同一個 Python 程序中載入同一個儲存庫的兩個不同分支。這就像在 Python 中安裝兩個同名的套件一樣,這是不好的。如果您真的嘗試這樣做,快取可能會加入並給您帶來驚喜。當然,在單獨的程序中載入它們是完全沒有問題的。

文件

訪問 PyTorch 的全面開發者文檔

查看文檔

教學課程

取得初學者和高級開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得您問題的解答

查看資源