捷徑

追蹤

概觀與用法

注意

實驗性,使用風險自負,API 可能會有變動

在 TorchX 中,應用程式是二進位檔案(可執行檔),因此沒有內建的方法可以從應用程式「返回」結果。 torchx.runtime.tracking 模組允許應用程式返回簡單的結果(請注意關鍵字「簡單」)。追蹤器模組支援的返回類型是刻意限制的。例如,不允許嘗試返回訓練好的模型權重,因為它們的大小可能達到數百 GB。此模組的設計目的並非用於傳遞大量的資料或二進位 Blob。

當應用程式作為更高階協調工作(例如,工作流程、管道、超參數最佳化)的一部分啟動時,通常需要協調器或工作流程中的其他應用程式能夠存取應用程式的結果。

假設應用程式 1 和應用程式 2 依序啟動,應用程式 1 的輸出作為應用程式 2 的輸入。由於這些是二進位檔案,因此在應用程式之間鏈接輸入/輸出的典型方法是將應用程式 1 的輸出檔案路徑作為應用程式 2 的輸入檔案路徑傳遞。

$ app1 --output-file s3://foo/out/app1.out
$ app2 --input-file s3://foo/out/app1.out

儘管這看起來很簡單,但仍有一些需要注意的事項

  1. 檔案 app1.out 的格式(應用程式 1 需要以應用程式 2 能夠理解的格式寫入檔案)

  2. 實際解析 URL 並寫入/讀取輸出檔案

因此,應用程式的主要部分最終看起來像這樣(偽代碼,僅供說明)

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   s3client = ...
   out = {"accuracy": accuracy}

   with open("/tmp/out", "w") as f:
       f = json.dumps(out).encode("utf-8")

   s3client.put(args.output_file, f)

# in app2.py
if __name__ == "__main__":
   s3client = ...
   with open("/tmp/out", "w") as f:
       s3client.get(args.input_file, f)

   with open("/tmp/out", "r") as f:
       in = json.loads(f.read().decode("utf-8"))

   do_something_else(in["accuracy"])

相反,透過追蹤器,可以使用具有相同 tracker_base 的追蹤器跨應用程式使用,使一個應用程式的返回值可供另一個應用程式使用,而無需將一個應用程式的輸出檔案路徑與另一個應用程式的輸入檔案路徑鏈接起來,也無需處理自訂序列化和檔案寫入。

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   tracker = FsspecResultTracker(args.tracker_base)
   tracker["app1_out"] = {"accuracy": accuracy}

# in app2.py
if __name__ == "__main__":
   tracker = FsspecResultTracker(args.tracker_base)
   app1_accuracy = tracker["app1_out"]
   do_something_else(app1_accuracy)

ResultTracker

Base

類別 torchx.runtime.tracking.ResultTracker[來源]

基礎結果追蹤器,應該將其子類化以實作追蹤器。通常,每個後端儲存體都有一個追蹤器實作。

用法

# get and put APIs can be used directly or in map-like API
# the following are equivalent
tracker.put("foo", l2norm=1.2)
tracker["foo"] = {"l2norm": 1.2}

# so are these
tracker.get("foo")["l2norm"] == 1.2
tracker["foo"]["l2norm"] == 1.2

有效的 result 類型為

  1. 數值:int、float

  2. 字元:str(以 UTF-8 編碼時大小限制為 1 KB)

有效的 key 類型為

  1. int

  2. str

按照慣例,可以在金鑰中使用「斜線」來儲存統計結果。例如,要儲存 l2norm 的均值和標準誤差

tracker[key] = {"l2norm/mean" : 1.2, "l2norm/sem": 3.4}
tracker[key]["l2norm/mean"] # returns 1.2
tracker[key]["l2norm/sem"] # returns 3.4

假設金鑰在追蹤器後端儲存體的範圍內是唯一的。例如,如果追蹤器由本地目錄支援,並且 key 是儲存結果的目錄中的檔案,則

# same key, different backing directory -> results are not overwritten
FsspecResultTracker("/tmp/foo")["1"] = {"l2norm":1.2}
FsspecResultTracker("/tmp/bar")["1"] = {"l2norm":3.4}

追蹤器不是一個中心實體,因此在同一個金鑰上執行的 putget 操作之間不會做出強一致性保證(超出後端儲存體提供的保證)。同樣,在同一個金鑰上執行的兩個連續 putget 操作之間也不會做出強一致性保證。

例如

tracker[1] = {"l2norm":1.2}
tracker[1] = {"l2norm":3.4}
tracker[1] # NOT GUARANTEED TO BE 3.4!

sleep(1*MIN)
tracker[1] # more likely to be 3.4 but still not guaranteed!

強烈建議使用唯一 ID 作為金鑰。對於簡單的工作,此 ID 通常是工作 ID,或者對於迭代應用程式(如超參數最佳化),可以是(實驗 ID、試驗次數)或(工作 ID、副本/工作器排名)的串聯。

Fsspec

類別 torchx.runtime.tracking.FsspecResultTracker(tracker_base: str)[來源]

在底層使用 fsspec 來儲存結果的追蹤器。

用法

from torchx.runtime.tracking import FsspecResultTracker

# PUT: in trainer.py
tracker_base = "/tmp/foobar" # also supports URIs (e.g. "s3://bucket/trainer/123")
tracker = FsspecResultTracker(tracker_base)
tracker["attempt_1/out"] = {"accuracy": 0.233}

# GET: anywhere outside trainer.py
tracker = FsspecResultTracker(tracker_base)
print(tracker["attempt_1/out"]["accuracy"])
0.233

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源