快捷鍵

torch.jit.save

torch.jit.save(m, f, _extra_files=None)[source][source]

儲存此模組的離線版本,以便在個別的程序中使用。

儲存的模組會序列化此模組的所有方法、子模組、參數和屬性。可以使用 torch::jit::load(filename) 載入到 C++ API 中,或者使用 torch.jit.load 載入到 Python API 中。

為了能夠儲存模組,它不能呼叫任何原生 Python 函數。這表示所有子模組都必須是 ScriptModule 的子類別。

危險

無論其設備為何,所有模組在載入時都會載入到 CPU 上。這與 torch.load() 的語義不同,並且未來可能會更改。

參數
  • m – 要儲存的 ScriptModule

  • f – 類檔案物件(必須實現寫入和刷新)或包含檔案名稱的字串。

  • _extra_files – 從檔案名稱到內容的映射,這些內容將作為 f 的一部分儲存。

注意

torch.jit.save 嘗試在不同版本之間保留某些運算符的行為。例如,在 PyTorch 1.5 中將兩個整數張量相除會執行 floor division,如果包含該程式碼的模組儲存在 PyTorch 1.5 中並載入到 PyTorch 1.6 中,則其除法行為將被保留。但是,在 PyTorch 1.6 中儲存的相同模組將無法在 PyTorch 1.5 中載入,因為除法的行為在 1.6 中發生了變化,並且 1.5 不知道如何複製 1.6 的行為。

範例: .. testcode

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源