torch.hub¶
Pytorch Hub 是一個預訓練模型儲存庫,旨在促進研究的可重現性。
發布模型¶
Pytorch Hub 支援透過新增簡單的 hubconf.py
檔案,將預訓練模型(模型定義和預訓練權重)發布到 GitHub 儲存庫;
hubconf.py
可以有多個進入點。每個進入點都定義為 Python 函式(範例:您要發布的預訓練模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何實作進入點?¶
如果我們在 pytorch/vision/hubconf.py
中擴展實作,以下程式碼片段指定 resnet18
模型的進入點。在大多數情況下,在 hubconf.py
中匯入正確的函式就足夠了。在這裡,我們只想使用擴展版本作為範例來說明它的工作原理。您可以在pytorch/vision repo中看到完整腳本
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
變數是一個需要載入模型的套件名稱列表。請注意,這可能與訓練模型所需的依賴項略有不同。args
和kwargs
會傳遞給真正的可呼叫函式。函式的 Docstring 就像是說明訊息。它解釋了模型的作用,以及允許使用的位置參數/關鍵字參數。強烈建議在此處新增一些範例。
Entrypoint 函式可以回傳一個模型 (nn.module),或者回傳輔助工具來使使用者工作流程更順暢,例如 tokenizer。
以底線開頭的可呼叫物件會被視為輔助函式,不會出現在
torch.hub.list()
中。預訓練權重可以儲存在 GitHub 儲存庫的本地端,或者透過
torch.hub.load_state_dict_from_url()
載入。如果小於 2GB,建議將其附加到 專案發布 並使用發布的 URL。在上面的範例中,torchvision.models.resnet.resnet18
處理pretrained
,或者您可以將以下邏輯放入 entrypoint 定義中。
if pretrained:
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要聲明¶
發布的模型至少應在一個分支/標籤中。它不能是一個隨機的 commit。
從 Hub 載入模型¶
Pytorch Hub 提供了方便的 API,可透過 torch.hub.list()
探索 hub 中所有可用的模型,透過 torch.hub.help()
顯示 docstring 和範例,並使用 torch.hub.load()
載入預訓練模型。
- torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source][source]¶
列出
github
指定的儲存庫中所有可呼叫的 entrypoint。- 參數
github (str) – 格式為 "repo_owner/repo_name[:ref]" 的字串,帶有可選的 ref(標籤或分支)。如果未指定
ref
,則預設分支假定為main
(如果存在),否則為master
。範例:'pytorch/vision:0.10'force_reload (bool, optional) – 是否丟棄現有快取並強制重新下載。預設值為
False
。skip_validation (bool, optional) – 如果
False
,torchhub 將檢查github
參數指定的分支或 commit 是否正確屬於儲存庫擁有者。這會向 GitHub API 發出請求;您可以透過設定GITHUB_TOKEN
環境變數來指定非預設的 GitHub 令牌。預設值為False
。trust_repo (bool, str or None) –
"check"
,True
,False
或None
。此參數於 v1.12 中引入,有助於確保使用者僅執行來自他們信任的儲存庫的程式碼。如果
False
,將會提示使用者是否應該信任該儲存庫。如果
True
,該儲存庫將被新增到信任清單中,並且無需明確確認即可載入。如果
"check"
,將會針對快取中的信任儲存庫清單檢查該儲存庫。如果它不在該清單中,則行為將回退到trust_repo=False
選項。如果
None
:這將引發警告,邀請使用者將trust_repo
設定為False
、True
或"check"
。這僅用於向後相容性,並將在 v2.0 中移除。
預設值為
None
,最終將在 v2.0 中變更為"check"
。verbose (bool, optional) – 如果
False
,則靜音關於命中本地快取的訊息。請注意,關於首次下載的訊息無法靜音。預設值為True
。
- 回傳值
可用的可呼叫 entrypoint
- 回傳值型別
範例
>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
- torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[原始碼][原始碼]¶
顯示 entrypoint
model
的 docstring。- 參數
github (str) – 格式為 <repo_owner/repo_name[:ref]> 的字串,其中 ref 為選填(可以是標籤或分支)。如果未指定
ref
,預設分支會假設為main
(如果存在),否則為master
。範例:‘pytorch/vision:0.10’model (str) – 在 repo 的
hubconf.py
中定義的 entrypoint 名稱字串force_reload (bool, optional) – 是否丟棄現有快取並強制重新下載。預設值為
False
。skip_validation (bool, optional) – 如果
False
,torchhub 將會檢查由github
參數指定的 ref 是否正確屬於 repo owner。這會向 GitHub API 發出請求;您可以透過設定GITHUB_TOKEN
環境變數來指定非預設的 GitHub 令牌。預設值為False
。trust_repo (bool, str or None) –
"check"
,True
,False
或None
。此參數於 v1.12 中引入,有助於確保使用者僅執行來自他們信任的儲存庫的程式碼。如果
False
,將會提示使用者是否應該信任該儲存庫。如果
True
,該儲存庫將被新增到信任清單中,並且無需明確確認即可載入。如果
"check"
,將會針對快取中的信任儲存庫清單檢查該儲存庫。如果它不在該清單中,則行為將回退到trust_repo=False
選項。如果
None
:這將引發警告,邀請使用者將trust_repo
設定為False
、True
或"check"
。這僅用於向後相容性,並將在 v2.0 中移除。
預設值為
None
,最終將在 v2.0 中變更為"check"
。
範例
>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
- torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[原始碼][原始碼]¶
從 github repo 或本機目錄載入模型。
注意:載入模型是典型的用例,但這也可以用於載入其他物件,例如 tokenizers、loss functions 等。
如果
source
是 ‘github’,則repo_or_dir
應該是repo_owner/repo_name[:ref]
的形式,其中 ref 為選填(可以是標籤或分支)。如果
source
是 ‘local’,則repo_or_dir
應該是本機目錄的路徑。- 參數
repo_or_dir (str) – 如果
source
是 ‘github’,這應該對應到一個 github repo,格式為repo_owner/repo_name[:ref]
,其中 ref 為選填(標籤或分支),例如 ‘pytorch/vision:0.10’。如果未指定ref
,預設分支會假設為main
(如果存在),否則為master
。 如果source
是 ‘local’,則它應該是本機目錄的路徑。model (str) – 在 repo/dir 的
hubconf.py
中定義的可呼叫(entrypoint)的名稱。*args (optional) – 可呼叫
model
的對應 args。source (str, optional) – ‘github’ 或 ‘local’。指定如何解釋
repo_or_dir
。預設值為 ‘github’。trust_repo (bool, str or None) –
"check"
,True
,False
或None
。此參數於 v1.12 中引入,有助於確保使用者僅執行來自他們信任的儲存庫的程式碼。如果
False
,將會提示使用者是否應該信任該儲存庫。如果
True
,該儲存庫將被新增到信任清單中,並且無需明確確認即可載入。如果
"check"
,將會針對快取中的信任儲存庫清單檢查該儲存庫。如果它不在該清單中,則行為將回退到trust_repo=False
選項。如果
None
:這將引發警告,邀請使用者將trust_repo
設定為False
、True
或"check"
。這僅用於向後相容性,並將在 v2.0 中移除。
預設值為
None
,最終將在 v2.0 中變更為"check"
。force_reload (bool, optional) – 是否強制無條件地重新下載 github repo。 如果
source = 'local'
,則無效。預設值為False
。verbose (bool, optional) – 如果
False
,則靜音關於命中本機快取的訊息。 請注意,關於首次下載的訊息無法靜音。 如果source = 'local'
,則無效。預設值為True
。skip_validation (bool, optional) – 如果
False
,torchhub 將檢查github
參數指定的分支或 commit 是否正確屬於儲存庫擁有者。這會向 GitHub API 發出請求;您可以透過設定GITHUB_TOKEN
環境變數來指定非預設的 GitHub 令牌。預設值為False
。**kwargs (optional) – 可呼叫
model
的對應 kwargs。
- 回傳值
使用給定的
*args
和**kwargs
呼叫model
可呼叫物件的輸出。
範例
>>> # from a github repo >>> repo = "pytorch/vision" >>> model = torch.hub.load( ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" ... ) >>> # from a local directory >>> path = "/some/local/path/pytorch/vision" >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
- torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[原始碼][原始碼]¶
將給定 URL 的物件下載到本機路徑。
- 參數
範例
>>> torch.hub.download_url_to_file( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", ... "/tmp/temporary_file", ... )
- torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source][source]¶
載入指定 URL 的 Torch 序列化物件。
如果下載的檔案是 zip 檔,將會自動解壓縮。
如果物件已存在於 model_dir 中,則會將其反序列化並傳回。
model_dir
的預設值為<hub_dir>/checkpoints
,其中hub_dir
是由get_dir()
傳回的目錄。- 參數
url (str) – 要下載的物件的 URL
model_dir (str, optional) – 要儲存物件的目錄
map_location (optional) – 一個函數或字典,用於指定如何重新映射儲存位置(請參閱 torch.load)
progress (bool, optional) – 是否將進度條顯示於 stderr。預設值:True
check_hash (bool, optional) – 如果為 True,則 URL 的檔名部分應遵循命名慣例
filename-<sha256>.ext
,其中<sha256>
是檔案內容的 SHA256 雜湊值的前八位或更多位數。該雜湊值用於確保名稱的唯一性並驗證檔案的內容。預設值:Falsefile_name (str, optional) – 下載檔案的名稱。如果未設定,將使用
url
中的檔名。weights_only (bool, optional) – 如果為 True,則只會載入權重,而不會載入複雜的 pickled 物件。建議用於不信任的來源。請參閱
load()
獲取更多詳細資訊。
- 回傳值型別
範例
>>> state_dict = torch.hub.load_state_dict_from_url( ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" ... )
執行載入的模型:¶
請注意,torch.hub.load()
中的 *args
和 **kwargs
用於實例化模型。在載入模型後,如何找出你可以對模型做什麼?建議的工作流程是
dir(model)
查看模型的所有可用方法。help(model.foo)
檢查執行model.foo
需要哪些參數
為了幫助使用者探索,而無需來回查閱文件,我們強烈建議 repo 所有者使函數說明訊息清晰簡潔。包含最小工作範例也很有幫助。
我的下載模型儲存在哪裡?¶
位置按以下順序使用
呼叫
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
,如果設定了環境變數TORCH_HOME
。$XDG_CACHE_HOME/torch/hub
,如果設定了環境變數XDG_CACHE_HOME
。~/.cache/torch/hub
快取邏輯¶
預設情況下,我們不會在載入檔案後清理它們。如果 Hub 已經存在於 get_dir()
返回的目錄中,Hub 預設會使用快取。
使用者可以透過呼叫 hub.load(..., force_reload=True)
強制重新載入。這將刪除現有的 GitHub 資料夾和下載的權重,並重新初始化全新的下載。當更新發佈到同一個分支時,這很有用,使用者可以隨時掌握最新的版本。
已知限制:¶
Torch hub 的運作方式是將套件匯入,就好像它是已安裝的一樣。在 Python 中匯入會引入一些副作用。例如,您可以在 Python 快取 sys.modules
和 sys.path_importer_cache
中看到新的項目,這是正常的 Python 行為。這也意味著,如果不同的儲存庫具有相同的子套件名稱(通常是 model
子套件),則從不同的儲存庫匯入不同的模型時,可能會發生匯入錯誤。解決這些匯入錯誤的一種方法是從 sys.modules
字典中移除有問題的子套件;更多詳細資訊可以在 這個 GitHub issue 中找到。
這裡值得一提的一個已知限制:使用者無法在同一個 Python 程序中載入同一個儲存庫的兩個不同分支。這就像在 Python 中安裝兩個同名的套件一樣,這是不好的。如果您真的嘗試這樣做,快取可能會加入並給您帶來驚喜。當然,在單獨的程序中載入它們是完全沒有問題的。