序列化語義¶
本筆記說明如何在 Python 中儲存和載入 PyTorch tensor 和模組狀態,以及如何序列化 Python 模組,以便可以在 C++ 中載入它們。
目錄
儲存和載入 tensor¶
torch.save()
和 torch.load()
讓您可以輕鬆儲存和載入 tensor
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照慣例,PyTorch 檔案通常以 '.pt' 或 '.pth' 擴展名編寫。
torch.save()
和 torch.load()
預設使用 Python 的 pickle,因此您也可以將多個張量儲存為 Python 物件的一部分,例如元組、列表和字典。
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
如果資料結構是可 pickle 的,也可以儲存包含 PyTorch 張量的自訂資料結構。
儲存和載入張量會保留視圖¶
儲存張量會保留它們的視圖關係。
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在底層,這些張量共享相同的「儲存空間」。有關視圖和儲存空間的更多資訊,請參閱 張量視圖。
當 PyTorch 儲存張量時,它會分別儲存它們的儲存物件和張量元數據。這是一個實現細節,將來可能會更改,但通常可以節省空間,並讓 PyTorch 輕鬆重建載入的張量之間的視圖關係。例如,在上面的程式碼片段中,只有一個儲存空間被寫入 'tensors.pt'。
然而,在某些情況下,儲存目前的儲存物件可能是不必要的,並且會產生過大的檔案。在下面的程式碼片段中,一個比儲存的張量大得多的儲存空間被寫入檔案
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
與其只將 small 張量中的五個值儲存到 'small.pt',不如將它與 large 共享的儲存空間中的 999 個值儲存並載入。
當儲存的張量擁有的元素少於它們的儲存物件時,可以先克隆張量來減小儲存檔案的大小。克隆張量會產生一個新的張量,該張量具有一個新的儲存物件,其中僅包含張量中的值。
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
然而,由於克隆的張量彼此獨立,因此它們沒有原始張量擁有的任何視圖關係。如果在儲存張量時,檔案大小和視圖關係都很重要,那麼必須小心建構新的張量,以最大程度地減少其儲存物件的大小,但仍具有所需的視圖關係,然後再儲存。
儲存和載入 torch.nn.Modules¶
另請參閱:教學:儲存和載入模組
在 PyTorch 中,模組的狀態經常使用「狀態字典 (state dict)」進行序列化。模組的狀態字典包含其所有參數和持久性緩衝區。
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
建議不要直接儲存模組,而是僅儲存其狀態字典,以實現相容性。Python 模組甚至有一個函數 load_state_dict()
,可以從狀態字典中恢復它們的狀態。
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
請注意,狀態字典首先使用 torch.load()
從其檔案中載入,然後使用 load_state_dict()
恢復狀態。
即使是自訂模組和包含其他模組的模組也具有狀態字典,並且可以使用這種模式。
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
torch.save
的序列化檔案格式¶
自 PyTorch 1.6.0 起,除非使用者設定 _use_new_zipfile_serialization=False
,否則 torch.save
預設返回未壓縮的 ZIP64 封存檔。
在這個封存檔中,檔案的順序如下
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
- 這些條目如下
data.pkl
是對傳遞給torch.save
的物件進行 pickling 的結果,但不包括它包含的torch.Storage
物件。byteorder
包含一個字串,其中包含儲存時的sys.byteorder
(“little”或“big”)。data/
包含物件中的所有儲存空間,其中每個儲存空間都是一個單獨的檔案。version
包含儲存時的版本號,該版本號可用於載入時。
儲存時,PyTorch 將確保每個檔案的本地檔案標頭都填充到一個偏移量,該偏移量是 64 位元組的倍數,從而確保每個檔案的偏移量與 64 位元組對齊。
注意
某些裝置(例如 XLA)上的張量會序列化為已 pickle 的 numpy 陣列。因此,它們的儲存空間不會被序列化。在這些情況下,checkpoint 中可能不存在 data/
。
torch.load
with weights_only=True
¶
從 2.6 版開始,如果未傳遞 pickle_module
參數,torch.load
將使用 weights_only=True
。
如 torch.load()
的文件所述,weights_only=True
將 torch.load
中使用的 unpickler 限制為僅執行 state_dicts
的純 torch.Tensors
以及一些其他原始類型所需的功能/構建類別。此外,與 pickle
模組提供的預設 Unpickler
不同,weights_only
Unpickler 不允許在 unpickling 期間動態匯入任何內容。
如上所述,儲存模組的 state_dict
是使用 torch.save
時的最佳實務。如果載入包含 nn.Module
的舊 checkpoint,我們建議使用 weights_only=False
。當載入包含張量子類的 checkpoint 時,可能需要將一些函數/類別列入允許清單,請參閱以下詳細資訊。
如果 weights_only
Unpickler 遇到未在 pickle 檔案中預設允許清單中的函數或類別,您應該會看到類似這樣的可操作錯誤
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
2. Alternatively, to load with `weights_only=True` please check the recommended
steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
`torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
if you trust this class/function.
請按照錯誤訊息中的步驟操作,並且僅在您信任這些函數或類別時才將它們列入允許清單。
若要取得檢查點中所有尚未加入允許清單的 GLOBAL(函式/類別),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint()
,它會傳回一個字串清單,格式為 {__module__}.{__name__}
。如果您信任這些函式/類別,您可以匯入它們,並根據錯誤訊息,透過 torch.serialization.add_safe_globals()
或上下文管理器 torch.serialization.safe_globals
將它們加入允許清單。
若要存取使用者允許清單中的函式/類別,您可以使用 torch.serialization.get_safe_globals()
,而要清除目前的清單,請參閱 torch.serialization.clear_safe_globals()
。
疑難排解 weights_only
¶
取得不安全的全域變數¶
一個需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint()
會靜態分析檢查點,某些類型可能會在反序列化過程中動態建構,因此不會被 torch.serialization.get_unsafe_globals_in_checkpoint()
回報。其中一個例子是 numpy 中的 dtypes
。在 numpy < 1.25
中,在將 torch.serialization.get_unsafe_globals_in_checkpoint()
回報的所有函式/類別加入允許清單後,您可能會看到類似以下的錯誤:
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>
可以使用 {add_}safe_globals([type(np.dtype(np.float32))])
將其加入允許清單。
在 numpy >=1.25
中,您會看到
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>
可以使用 {add_}safe_globals([np.dtypes.Float32DType])
將其加入允許清單。
序列化 torch.nn.Modules 並在 C++ 中載入它們¶
另請參閱:教學課程:在 C++ 中載入 TorchScript 模型
ScriptModules 可以序列化為 TorchScript 程式,並使用 torch.jit.load()
載入。這種序列化會編碼所有模組的方法、子模組、參數和屬性,並允許在 C++ 中載入序列化的程式(即,無需 Python)。
torch.jit.save()
和 torch.save()
之間的區別可能不明顯。torch.save()
使用 pickle 儲存 Python 物件。這對於原型設計、研究和訓練特別有用。torch.jit.save()
另一方面,將 ScriptModules 序列化為可以在 Python 或 C++ 中載入的格式。這在儲存和載入 C++ 模組,或使用 C++ 運行在 Python 中訓練的模組時很有用,這是部署 PyTorch 模型時的常見做法。
要在 Python 中編寫腳本、序列化和載入模組
>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
(l0): RecursiveScriptModule(original_name=Linear)
(l1): RecursiveScriptModule(original_name=Linear) )
追蹤的模組也可以使用 torch.jit.save()
儲存,但需要注意的是,只有追蹤的程式碼路徑會被序列化。以下範例示範了這一點
# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
if input.dim() > 1:
return torch.tensor(0)
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)
>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)
上面的模組有一個 if 語句,它沒有被追蹤的輸入觸發,因此不屬於追蹤的模組,也不會與它一起序列化。然而,腳本模組包含 if 語句,並與它一起序列化。有關腳本編寫和追蹤的更多資訊,請參閱 TorchScript 文件。
最後,要在 C++ 中載入模組
>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');
有關如何在 C++ 中使用 PyTorch 模組的詳細資訊,請參閱 PyTorch C++ API 文件。
跨 PyTorch 版本儲存和載入 ScriptModules¶
PyTorch 團隊建議使用相同版本的 PyTorch 儲存和載入模組。較舊版本的 PyTorch 可能不支援較新的模組,而較新版本可能已移除或修改了較舊的行為。這些變更已在 PyTorch 的發布說明中明確說明,而依賴已變更功能的模組可能需要更新才能繼續正常運作。在少數情況下(如下詳述),PyTorch 將保留序列化的 ScriptModules 的歷史行為,因此它們不需要更新。
torch.div 執行整數除法¶
在 PyTorch 1.5 及更早版本中,torch.div()
在給定兩個整數輸入時會執行向下取整除法
# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)
但在 PyTorch 1.7 中,torch.div()
將始終對其輸入執行真除法,就像 Python 3 中的除法一樣
# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)
torch.div()
的行為保留在序列化的 ScriptModules 中。也就是說,使用 PyTorch 1.6 之前的版本序列化的 ScriptModules 將繼續看到 torch.div()
在給定兩個整數輸入時執行向下取整除法,即使使用較新版本的 PyTorch 載入也是如此。但是,使用 torch.div()
且在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 無法在較早版本的 PyTorch 中載入,因為那些較早版本不理解新的行為。
torch.full 總是推斷為浮點數 dtype¶
在 PyTorch 1.5 及更早版本中,無論給定的填充值是什麼,torch.full()
總是回傳一個浮點數張量
# PyTorch 1.5 and earlier
>>> torch.full((3,), 1) # Note the integer fill value...
tensor([1., 1., 1.]) # ...but float tensor!
但在 PyTorch 1.7 中,torch.full()
將從填充值推斷回傳張量的 dtype
# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])
>>> torch.full((3,), True)
tensor([True, True, True])
>>> torch.full((3,), 1.)
tensor([1., 1., 1.])
>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])
torch.full()
的行為保留在序列化的 ScriptModules 中。也就是說,使用 PyTorch 1.6 之前的版本序列化的 ScriptModules 將預設繼續看到 torch.full 回傳浮點數張量,即使給定布林值或整數填充值也是如此。但是,使用 torch.full()
且在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 無法在較早版本的 PyTorch 中載入,因為那些較早版本不理解新的行為。
實用函式¶
以下實用函式與序列化相關
- torch.serialization.register_package(priority, tagger, deserializer)[原始碼][原始碼]¶
註冊可呼叫物件,用於標記和反序列化具有相關優先順序的儲存物件。標記將裝置與儲存物件在儲存時關聯,而反序列化在載入時將儲存物件移動到適當的裝置。
tagger
和deserializer
按照它們的priority
給定的順序運行,直到 tagger/deserializer 回傳一個不是 None 的值。要覆蓋全域註冊表中裝置的反序列化行為,可以註冊一個具有比現有 tagger 更高優先順序的 tagger。
此函式也可用於註冊新裝置的 tagger 和反序列化器。
- 參數
priority (int) – 指示與標記器和反序列化器相關聯的優先順序,其中較低的值表示較高的優先順序。
tagger (Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 可呼叫物件,它接受一個儲存物件,並將其標記的裝置作為字串或 None 回傳。
deserializer (Callable[[Union[Storage, TypedStorage, UntypedStorage], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 可呼叫物件,它接受一個儲存物件和一個裝置字串,並在適當的裝置上回傳一個儲存物件或 None。
- 回傳
None
範例
>>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- torch.serialization.get_crc32_options()[原始碼][原始碼]¶
取得
torch.save()
是否計算並為每個記錄寫入 crc32。預設為
True
。- 回傳型別
- torch.serialization.set_crc32_options(compute_crc32)[原始碼][原始碼]¶
設定
torch.save()
是否計算並寫入每個記錄的 crc32 檢測碼。注意
將此選項設定為
False
可能會導致解壓縮torch.save
的輸出檔案失敗或警告,因為 CRC32 檢測碼已損壞。但是,torch.load
仍然可以載入該檔案。- 參數
compute_crc32 (bool) – 設定 crc32 計算旗標
- torch.serialization.get_default_load_endianness()[原始碼][原始碼]¶
取得載入檔案時使用的預設位元組序
如果儲存的檢查點中沒有位元組序標記,則使用此位元組序作為後備方案。 預設情況下,它是「原生」位元組序。
- 回傳
Optional[LoadEndianness]
- 回傳型別
default_load_endian
- torch.serialization.set_default_load_endianness(endianness)[原始碼][原始碼]¶
設定載入檔案時使用的後備位元組序
如果儲存的檢查點中沒有位元組序標記,則使用此位元組序作為後備方案。 預設情況下,它是「原生」位元組序。
- 參數
endianness – 新的後備位元組序
- torch.serialization.get_default_mmap_options()[原始碼][原始碼]¶
取得
torch.load()
在mmap=True
時使用的預設 mmap 選項。預設為
mmap.MAP_PRIVATE
。- 回傳
int
- 回傳型別
default_mmap_options
- torch.serialization.set_default_mmap_options(flags)[原始碼][原始碼]¶
設定
torch.load()
在mmap=True
時使用的預設 mmap 選項為 flags 的上下文管理器或函式。目前僅支援
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
。 如果您需要在此處新增任何其他選項,請提出 issue。注意
此功能目前不支援 Windows。
- 參數
flags (int) –
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
- torch.serialization.add_safe_globals(safe_globals)[原始碼][原始碼]¶
將給定的全域變數標記為
weights_only
載入的安全對象。 例如,可以在還原序列化期間呼叫新增到此清單的函式,可以實例化類別並設定狀態。清單中的每個項目可以是函式/類別,或 (函式/類別, 字串) 形式的元組,其中字串是函式/類別的完整路徑。
在序列化格式中,每個函式都使用其完整路徑識別為
{__module__}.{__name__}
。 呼叫此 API 時,您可以提供此完整路徑,該路徑應與檢查點中的路徑相符,否則將使用預設的{fn.__module__}.{fn.__name__}
。- 參數
safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要標記為安全的全域變數清單
範例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])
- torch.serialization.get_unsafe_globals_in_checkpoint(f)[原始碼][原始碼]¶
傳回
torch.save
物件中不適合用於weights_only
的函式/類別字串列表。對於給定的函式或類別
f
,對應的字串格式為{f.__module__}.{f.__name__}
。此函式將傳回檢查點中不在
weights_only
安全集合中的任何 GLOBAL(透過add_safe_globals()
或safe_globals
上下文,或預設情況下由torch
列入白名單)。注意
此函式將靜態反組譯檢查點中的 pickle 檔案。 這表示在 unpickling 期間動態推送到堆疊上的任何類別都不會包含在輸出中。
- class torch.serialization.safe_globals(safe_globals)[原始碼][原始碼]¶
context-manager,用於將某些全域變數新增為可安全用於
weights_only
載入。範例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == []
- class torch.serialization.skip_data(materialize_fake_tensors=False)[原始碼][原始碼]¶
context-manager,用於跳過
torch.save
呼叫的儲存體位元組寫入。儲存體仍然會被儲存,但它們的位元組通常會被寫入的空間將會是空白空間。 然後可以在單獨的過程中填充儲存體位元組。
警告
skip_data
context manager 是一個早期原型,可能會發生變更。- 參數
materialize_fake_tensors (bool) – 是否實現 FakeTensors。
範例
>>> import tempfile >>> t = torch.randn(2, 3) >>> with tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], [0., 0., 0.]])