torch.load¶
- torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[原始碼][原始碼]¶
從檔案載入使用
torch.save()
儲存的物件。torch.load()
使用 Python 的 unpickling 功能,但會特別處理儲存空間(storages),這是 tensors 的基礎。它們首先會在 CPU 上反序列化,然後移動到儲存它們的裝置上。如果此操作失敗(例如,因為執行時系統沒有某些裝置),則會引發例外。但是,可以使用map_location
引數將儲存空間動態地重新對應到另一組裝置。如果
map_location
是一個可呼叫的物件(callable),它將為每個序列化的儲存空間呼叫一次,並帶有兩個引數:storage 和 location。 storage 引數將是儲存空間的初始反序列化,位於 CPU 上。每個序列化的儲存空間都有一個與之關聯的位置標籤,用於識別它儲存的裝置,而此標籤是傳遞給map_location
的第二個引數。內建的位置標籤是 CPU tensors 的'cpu'
和 CUDA tensors 的'cuda:device_id'
(例如'cuda:2'
)。map_location
應返回None
或一個儲存空間。如果map_location
返回一個儲存空間,它將被用作最終的反序列化物件,並且已經移動到正確的裝置。否則,torch.load()
將退回到預設行為,就像未指定map_location
一樣。如果
map_location
是一個torch.device
物件或包含裝置標籤的字串,則它表示所有 tensors 應載入的位置。否則,如果
map_location
是一個 dict,它將用於重新對應檔案中出現的位置標籤(鍵),以指定將儲存空間放置在哪裡(值)。使用者擴充可以使用
torch.serialization.register_package()
註冊它們自己的位置標籤以及標記和反序列化方法。- 引數
f (Union[str, PathLike, BinaryIO, IO[bytes]]) – 類檔案物件(必須實作
read()
、readline()
、tell()
和seek()
),或包含檔名的字串或 os.PathLike 物件map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, Dict[str, str]]]) – 一個函式、
torch.device
、字串或 dict,用於指定如何重新對應儲存位置pickle_module (Optional[Any]) – 用於 unpickling metadata 和物件的模組(必須與用於序列化檔案的
pickle_module
相符)weights_only (Optional[bool]) – 指示 unpickler 是否應僅限於載入 tensors、基本類型、字典以及透過
torch.serialization.add_safe_globals()
新增的任何類型。 詳情請參閱 torch.load with weights_only=True 。mmap (Optional[bool]) – 指示是否應該 mmap 檔案,而不是將所有儲存空間載入到記憶體中。 通常,檔案中的 tensor 儲存空間將首先從磁碟移動到 CPU 記憶體,然後移動到它們在儲存時標記的位置,或由
map_location
指定。 如果最終位置是 CPU,則第二步是 no-op。 設定mmap
旗標時,不是在第一步將 tensor 儲存空間從磁碟複製到 CPU 記憶體,而是 mmapf
。pickle_load_args (Any) – (僅限 Python 3) 傳遞給
pickle_module.load()
和pickle_module.Unpickler()
的可選關鍵字引數,例如,errors=...
。
- 傳回類型
警告
除非將 weights_only 參數設為 True,否則
torch.load()
會隱式使用pickle
模組,而此模組已知不安全。 有可能構造惡意的 pickle 資料,這些資料在反序列化過程中會執行任意程式碼。 永遠不要在不安全的模式下載入可能來自不受信任來源或可能被篡改的資料。只載入您信任的資料。注意
當您對包含 GPU tensors 的檔案呼叫
torch.load()
時,預設情況下這些 tensors 將會被載入到 GPU 中。 您可以呼叫torch.load(.., map_location='cpu')
,然後呼叫load_state_dict()
,以避免在載入模型檢查點時 GPU RAM 激增。注意
預設情況下,我們會將位元組字串解碼為
utf-8
。 這是為了避免在 Python 3 中載入由 Python 2 儲存的檔案時,出現常見的錯誤案例UnicodeDecodeError: 'ascii' codec can't decode byte 0x...
。 如果此預設值不正確,您可以使用額外的encoding
關鍵字參數來指定應如何載入這些物件,例如,encoding='latin1'
會使用latin1
編碼將它們解碼為字串,而encoding='bytes'
會將它們保留為位元組陣列,這些位元組陣列可以稍後使用byte_array.decode(...)
解碼。範例
>>> torch.load("tensors.pt", weights_only=True) # Load all tensors onto the CPU >>> torch.load("tensors.pt", map_location=torch.device("cpu"), weights_only=True) # Load all tensors onto the CPU, using a function >>> torch.load( ... "tensors.pt", map_location=lambda storage, loc: storage, weights_only=True ... ) # Load all tensors onto GPU 1 >>> torch.load( ... "tensors.pt", ... map_location=lambda storage, loc: storage.cuda(1), ... weights_only=True, ... ) # type: ignore[attr-defined] # Map tensors from GPU 1 to GPU 0 >>> torch.load("tensors.pt", map_location={"cuda:1": "cuda:0"}, weights_only=True) # Load tensor from io.BytesIO object # Loading from a buffer setting weights_only=False, warning this can be unsafe >>> with open("tensor.pt", "rb") as f: ... buffer = io.BytesIO(f.read()) >>> torch.load(buffer, weights_only=False) # Load a module with 'ascii' encoding for unpickling # Loading from a module setting weights_only=False, warning this can be unsafe >>> torch.load("module.pt", encoding="ascii", weights_only=False)