歡迎來到 TensorDict 文件!¶
TensorDict 是一個類似字典的類別,繼承了 tensors 的屬性,例如索引、形狀操作、轉換為裝置等等。
您可以直接從 PyPI 安裝 tensordict (請參閱以下專用章節以取得更多安裝說明)
$ pip install tensordict
TensorDict 的主要目的是透過抽離客製化的操作,使程式碼庫更具可讀性和模組化
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
透過這個抽象層級,可以為高度異質的任務重複使用訓練迴圈。 訓練迴圈的每個步驟 (資料收集和轉換、模型預測、損失計算等) 都可以針對手邊的用例進行客製化,而不會影響其他步驟。 例如,上述範例可以輕鬆地跨分類和分割任務使用,以及許多其他任務。
安裝¶
Tensordict 版本與 PyTorch 同步,因此請確保始終使用 最新版本的 PyTorch (儘管核心功能保證與 pytorch>=1.13 向後相容) 享受該庫的最新功能。 可以透過以下方式安裝 Nightly 版本
$ pip install tensordict-nightly
或者,如果您願意為該庫做出貢獻,可以透過 git clone
$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ cd tensordict
$ python setup.py develop