捷徑

torch.load

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[原始碼][原始碼]

從檔案載入使用 torch.save() 儲存的物件。

torch.load() 使用 Python 的 unpickling 功能,但會特別處理儲存空間(storages),這是 tensors 的基礎。它們首先會在 CPU 上反序列化,然後移動到儲存它們的裝置上。如果此操作失敗(例如,因為執行時系統沒有某些裝置),則會引發例外。但是,可以使用 map_location 引數將儲存空間動態地重新對應到另一組裝置。

如果 map_location 是一個可呼叫的物件(callable),它將為每個序列化的儲存空間呼叫一次,並帶有兩個引數:storage 和 location。 storage 引數將是儲存空間的初始反序列化,位於 CPU 上。每個序列化的儲存空間都有一個與之關聯的位置標籤,用於識別它儲存的裝置,而此標籤是傳遞給 map_location 的第二個引數。內建的位置標籤是 CPU tensors 的 'cpu' 和 CUDA tensors 的 'cuda:device_id'(例如 'cuda:2')。 map_location 應返回 None 或一個儲存空間。如果 map_location 返回一個儲存空間,它將被用作最終的反序列化物件,並且已經移動到正確的裝置。否則,torch.load() 將退回到預設行為,就像未指定 map_location 一樣。

如果 map_location 是一個 torch.device 物件或包含裝置標籤的字串,則它表示所有 tensors 應載入的位置。

否則,如果 map_location 是一個 dict,它將用於重新對應檔案中出現的位置標籤(鍵),以指定將儲存空間放置在哪裡(值)。

使用者擴充可以使用 torch.serialization.register_package() 註冊它們自己的位置標籤以及標記和反序列化方法。

引數
  • f (Union[str, PathLike, BinaryIO, IO[bytes]]) – 類檔案物件(必須實作 read()readline()tell()seek()),或包含檔名的字串或 os.PathLike 物件

  • map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, Dict[str, str]]]) – 一個函式、torch.device、字串或 dict,用於指定如何重新對應儲存位置

  • pickle_module (Optional[Any]) – 用於 unpickling metadata 和物件的模組(必須與用於序列化檔案的 pickle_module 相符)

  • weights_only (Optional[bool]) – 指示 unpickler 是否應僅限於載入 tensors、基本類型、字典以及透過 torch.serialization.add_safe_globals() 新增的任何類型。 詳情請參閱 torch.load with weights_only=True

  • mmap (Optional[bool]) – 指示是否應該 mmap 檔案,而不是將所有儲存空間載入到記憶體中。 通常,檔案中的 tensor 儲存空間將首先從磁碟移動到 CPU 記憶體,然後移動到它們在儲存時標記的位置,或由 map_location 指定。 如果最終位置是 CPU,則第二步是 no-op。 設定 mmap 旗標時,不是在第一步將 tensor 儲存空間從磁碟複製到 CPU 記憶體,而是 mmap f

  • pickle_load_args (Any) – (僅限 Python 3) 傳遞給 pickle_module.load()pickle_module.Unpickler() 的可選關鍵字引數,例如,errors=...

傳回類型

Any

警告

除非將 weights_only 參數設為 True,否則torch.load() 會隱式使用 pickle 模組,而此模組已知不安全。 有可能構造惡意的 pickle 資料,這些資料在反序列化過程中會執行任意程式碼。 永遠不要在不安全的模式下載入可能來自不受信任來源或可能被篡改的資料。只載入您信任的資料

注意

當您對包含 GPU tensors 的檔案呼叫 torch.load() 時,預設情況下這些 tensors 將會被載入到 GPU 中。 您可以呼叫 torch.load(.., map_location='cpu'),然後呼叫 load_state_dict(),以避免在載入模型檢查點時 GPU RAM 激增。

注意

預設情況下,我們會將位元組字串解碼為 utf-8。 這是為了避免在 Python 3 中載入由 Python 2 儲存的檔案時,出現常見的錯誤案例 UnicodeDecodeError: 'ascii' codec can't decode byte 0x...。 如果此預設值不正確,您可以使用額外的 encoding 關鍵字參數來指定應如何載入這些物件,例如,encoding='latin1' 會使用 latin1 編碼將它們解碼為字串,而 encoding='bytes' 會將它們保留為位元組陣列,這些位元組陣列可以稍後使用 byte_array.decode(...) 解碼。

範例

>>> torch.load("tensors.pt", weights_only=True)
# Load all tensors onto the CPU
>>> torch.load("tensors.pt", map_location=torch.device("cpu"), weights_only=True)
# Load all tensors onto the CPU, using a function
>>> torch.load(
...     "tensors.pt", map_location=lambda storage, loc: storage, weights_only=True
... )
# Load all tensors onto GPU 1
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage.cuda(1),
...     weights_only=True,
... )  # type: ignore[attr-defined]
# Map tensors from GPU 1 to GPU 0
>>> torch.load("tensors.pt", map_location={"cuda:1": "cuda:0"}, weights_only=True)
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open("tensor.pt", "rb") as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load("module.pt", encoding="ascii", weights_only=False)

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學

取得初學者和高級開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源