儲存 TensorDict 與 tensorclass 物件¶
雖然我們可以使用 save()
來儲存 tensordict,但這會建立一個包含資料結構完整內容的單一檔案。很容易想像在某些情況下,這種方式並非最佳選擇!
TensorDict 的序列化 API 主要依賴 MemoryMappedTensor
,它用於將 tensors 獨立寫入磁碟,並以模仿 TensorDict 的資料結構呈現。
與 PyTorch 使用 save()
對 pickle 的依賴相比,TensorDict 的序列化速度可能快上一個數量級。本文檔將說明如何使用 TensorDict 建立和操作儲存在磁碟上的資料。
儲存記憶體映射的 TensorDict¶
當 tensordict 被傾印為 mmap 資料結構時,每個條目對應到一個 *.memmap
檔案,而目錄結構則由鍵結構決定:通常,巢狀鍵對應於子目錄。
將資料結構儲存為結構化的記憶體映射 tensors 集合具有以下優點
可以部分載入已儲存的資料。如果一個大型模型被儲存在磁碟上,但只需要將其部分權重載入到在單獨腳本中建立的模組中,那麼只有這些權重會被載入到記憶體中。
儲存資料是安全的:使用 pickle 函式庫序列化大型資料結構可能是不安全的,因為還原序列化可能會執行任何任意程式碼。 TensorDict 的載入 API 僅從儲存的 json 檔案和儲存在磁碟上的記憶體緩衝區中讀取預先選定的欄位。
儲存速度很快:因為資料被寫入多個獨立檔案中,我們可以通過啟動多個並發執行緒,讓每個執行緒各自存取專用檔案,來分攤 IO 成本。
已儲存資料的結構是顯而易見的:目錄樹指示資料內容。
但是,這種方法也有一些缺點
並非所有資料類型都可以儲存。
tensorclass
允許儲存任何非 tensor 資料:如果這些資料可以用 json 檔案表示,將使用 json 格式。否則,非 tensor 資料將使用save()
作為備用方案獨立儲存。NonTensorData
類別可用於在常規TensorDict
實例中表示非 tensor 資料。
tensordict 的記憶體映射 API 依賴於四個核心方法:memmap_()
、memmap()
、memmap_like()
和 load_memmap()
。
memmap_()
和 memmap()
方法會將資料寫入磁碟,無論是否修改包含該資料的 tensordict 實例。 這些方法可用於將模型序列化到磁碟上(我們使用多個執行緒來加速序列化)
>>> model = nn.Transformer()
>>> weights = TensorDict.from_module(model)
>>> weights_disk = weights.memmap("/path/to/saved/dir", num_threads=32)
>>> new_weights = TensorDict.load_memmap("/path/to/saved/dir")
>>> assert (weights_disk == new_weights).all()
memmap_like()
用於需要在磁碟上預先分配資料集的情況,典型的用法是
>>> def make_datum(): # used for illustration purposes
... return TensorDict({"image": torch.randint(255, (3, 64, 64)), "label": 0}, batch_size=[])
>>> dataset_size = 1_000_000
>>> datum = make_datum() # creates a single instance of a TensorDict datapoint
>>> data = datum.expand(dataset_size) # does NOT require more memory usage than datum, since it's only a view on datum!
>>> data_disk = data.memmap_like("/path/to/data") # creates the two memory-mapped tensors on disk
>>> del data # data is not needed anymore
如上所示,當將 TensorDict`
的條目轉換為 MemoryMappedTensor
時,可以控制記憶體映射保存在磁碟上的位置,以便它們可以持久存在並在稍後載入。 另一方面,也可以使用檔案系統。 要使用此功能,只需在上述三個序列化方法中捨棄 prefix
參數即可。
當指定了 prefix
時,資料結構遵循 TensorDict 的結構
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
>>> td.memmap_(prefix="tensordict")
產生以下目錄結構
tensordict
├── a.memmap
├── b
│ ├── c.memmap
│ └── meta.json
└── meta.json
meta.json
檔案包含重建 tensordict 的所有相關信息,例如裝置、批次大小以及 tensordict 子類型。 這意味著 load_memmap()
將能夠重建複雜的巢狀結構,其中子 tensordict 具有與父 tensordict 不同的類型
>>> from tensordict import TensorDict, tensorclass, TensorDictBase
>>> from tensordict.utils import print_directory_tree
>>> import torch
>>> import tempfile
>>> td_list = [TensorDict({"item": i}, batch_size=[]) for i in range(4)]
>>> @tensorclass
... class MyClass:
... data: torch.Tensor
... metadata: str
>>> tc = MyClass(torch.randn(3), metadata="some text", batch_size=[])
>>> data = TensorDict({"td_list": torch.stack(td_list), "tensorclass": tc}, [])
>>> with tempfile.TemporaryDirectory() as tempdir:
... data.memmap_(tempdir)
...
... loaded_data = TensorDictBase.load_memmap(tempdir)
... assert (loaded_data == data).all()
... print_directory_tree(tempdir)
tmpzy1jcaoq/
tensorclass/
_tensordict/
data.memmap
meta.json
meta.json
td_list/
0/
item.memmap
meta.json
1/
item.memmap
meta.json
3/
item.memmap
meta.json
2/
item.memmap
meta.json
meta.json
meta.json
處理現有的 MemoryMappedTensor
¶
如果 TensorDict`
已經包含 MemoryMappedTensor
條目,則有幾種可能的行為。
如果未指定
prefix
且memmap()
被呼叫兩次,則產生的 TensorDict 將包含與原始 TensorDict 相同的資料。>>> td = TensorDict({"a": 1}, []) >>> td0 = td.memmap() >>> td1 = td0.memmap() >>> td0["a"] is td1["a"] True
如果指定了
prefix
且與現有MemoryMappedTensor
實例的前綴不同,則會引發異常,除非傳遞 copy_existing=True>>> with tempfile.TemporaryDirectory() as tmpdir_0: ... td0 = td.memmap(tmpdir_0) ... td0 = td.memmap(tmpdir_0) # works, results are just overwritten ... with tempfile.TemporaryDirectory() as tmpdir_1: ... td1 = td0.memmap(tmpdir_1) ... td_load = TensorDict.load_memmap(tmpdir_1) # works! ... assert (td_load == td).all() ... with tempfile.TemporaryDirectory() as tmpdir_1: ... td_load = TensorDict.load_memmap(tmpdir_1) # breaks!
實作此功能是為了防止使用者不小心將記憶體映射張量從一個位置複製到另一個位置。
TorchSnapshot 相容性¶
警告
由於 torchsnapshot 的維護已停止,因此我們不會為 tensordict 與此庫的相容性實作新功能。
TensorDict 與 torchsnapshot(一個 PyTorch 檢查點庫)相容。 TorchSnapshot 將獨立儲存您的每個張量,其資料結構模仿您的 tensordict 或 tensorclass 的資料結構。 此外,TensorDict 自然內建了在磁碟上儲存和載入巨型資料集的必要工具,而無需將完整的張量載入記憶體:換句話說,tensordict + torchsnapshot 的組合可以將數百 Gb 的張量載入到預先分配的 MemmapTensor
上,而無需以一個塊的形式傳遞到 RAM 中。
主要有兩個用例:儲存和載入適合記憶體的 tensordict,以及使用 MemmapTensor
儲存和載入儲存在磁碟上的 tensordict。
一般用例:記憶體中載入¶
如果您的目標 tensordict 未預先分配,則此方法適用。 這提供了靈活性(您可以將任何 tensordict 載入到您的 tensordict 中,您無需事先知道其內容),並且此方法在編碼上稍微容易一些。 但是,如果您的張量非常大且不適合記憶體,則此方法可能會失敗。 此外,它不允許您直接載入到您選擇的裝置上。
儲存操作要記住的兩個主要命令是
>>> state = {"state": tensordict_source}
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path="/path/to/my/snapshot")
要載入到目標 tensordict,您可以簡單地載入快照並更新 tensordict。 在幕後,此方法將呼叫 tensordict_target.load_state_dict(state_dict)
,這意味著 state_dict
將首先完全載入到記憶體中,然後載入到目標 tensordict 上
>>> snapshot = Snapshot(path="/path/to/my/snapshot")
>>> state_target = {"state": tensordict_target}
>>> snapshot.restore(app_state=state_target)
這是一個完整的範例
>>> import uuid
>>> import torchsnapshot
>>> from tensordict import TensorDict
>>> import torch
>>>
>>> tensordict_source = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, [])
>>> state = {"state": tensordict}
>>> path = f"/tmp/{uuid.uuid4()}"
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path=path)
>>> # later
>>> snapshot = torchsnapshot.Snapshot(path=path)
>>> tensordict2 = TensorDict()
>>> target_state = {
>>> "state": tensordict2
>>> }
>>> snapshot.restore(app_state=target_state)
>>> assert (tensordict == tensordict2).all()
儲存和載入大型資料集¶
如果資料集太大而無法放入記憶體,則上述方法很容易失敗。 我們利用 torchsnapshot 的功能,將張量以小塊的形式載入到它們預先分配的目標位置。 這需要您知道您的目標資料將擁有的形狀、裝置等,但這是能夠檢查點您的模型或資料載入所必須付出的很小的代價!
與前一個範例相反,我們這次不會使用 load_state_dict()
方法,而是使用從目標物件取得的 state_dict
,並用儲存的資料重新填充目標物件。
同樣地,只需要兩行程式碼即可儲存資料
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tensordict_source.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path="/path/to/my/snapshot")
我們一直在使用 torchsnapshot.StateDict
,並且明確地呼叫了 my_tensordict_source.state_dict(keep_vars=True)
,這與之前的範例不同。現在,將其載入到目標 tensordict 上
>>> snapshot = Snapshot(path="/path/to/my/snapshot")
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tensordict_target.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
在這個範例中,載入完全由 torchsnapshot 處理,也就是說,沒有呼叫 TensorDict.load_state_dict()
。
注意
這有兩個重要的含義
由於
LazyStackedTensorDict.state_dict()
(和其他 lazy tensordict 類別) 在執行某些操作後會傳回資料的副本,因此載入到 state-dict 上不會更新原始類別。但是,由於支援 state_dict() 操作,因此不會引發錯誤。同樣地,由於 state-dict 是就地更新的,但 tensordict 沒有使用
TensorDict.update()
或TensorDict.set()
更新,因此目標 tensordict 中缺少的鍵將不會被注意到。
這是一個完整的範例
>>> td = TensorDict({"a": torch.randn(3), "b": TensorDict({"c": torch.randn(3, 1)}, [3, 1])}, [3])
>>> td.memmap_()
>>> assert isinstance(td["b", "c"], MemmapTensor)
>>>
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=td.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
>>>
>>> td_dest = TensorDict({"a": torch.zeros(3), "b": TensorDict({"c": torch.zeros(3, 1)}, [3, 1])}, [3])
>>> td_dest.memmap_()
>>> assert isinstance(td_dest["b", "c"], MemmapTensor)
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=td_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>> # sanity check
>>> assert (td_dest == td).all()
>>> assert (td_dest["b"].batch_size == td["b"].batch_size)
>>> assert isinstance(td_dest["b", "c"], MemmapTensor)
最後,tensorclass 也支援此功能。程式碼與上面的程式碼非常相似
>>> from __future__ import annotations
>>> import uuid
>>> from typing import Union, Optional
>>>
>>> import torchsnapshot
>>> from tensordict import TensorDict, MemmapTensor
>>> import torch
>>> from tensordict.prototype import tensorclass
>>>
>>> @tensorclass
>>> class MyClass:
... x: torch.Tensor
... y: Optional[MyClass]=None
...
>>> tc = MyClass(x=torch.randn(3), y=MyClass(x=torch.randn(3), batch_size=[]), batch_size=[])
>>> tc.memmap_()
>>> assert isinstance(tc.y.x, MemmapTensor)
>>>
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
>>> tc_dest = MyClass(x=torch.randn(3), y=MyClass(x=torch.randn(3), batch_size=[]), batch_size=[])
>>> tc_dest.memmap_()
>>> assert isinstance(tc_dest.y.x, MemmapTensor)
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>>
>>> assert (tc_dest == tc).all()
>>> assert (tc_dest.y.batch_size == tc.y.batch_size)
>>> assert isinstance(tc_dest.y.x, MemmapTensor)