自訂¶
本節說明如何自訂 TorchElastic 以符合您的需求。
啟動器¶
TorchElastic 隨附的啟動器程式應足以應付大多數使用案例 (請參閱torchrun (彈性啟動))。您可以透過程式設計方式建立代理程式並將工作者的規格傳遞給它來實作自訂啟動器,如下所示。
# my_launcher.py
if __name__ == "__main__":
args = parse_args(sys.argv[1:])
rdzv_handler = RendezvousHandler(...)
spec = WorkerSpec(
local_world_size=args.nproc_per_node,
fn=trainer_entrypoint_fn,
args=(trainer_entrypoint_fn args.fn_args,...),
rdzv_handler=rdzv_handler,
max_restarts=args.max_restarts,
monitor_interval=args.monitor_interval,
)
agent = LocalElasticAgent(spec, start_method="spawn")
try:
run_result = agent.run()
if run_result.is_failed():
print(f"worker 0 failed with: run_result.failures[0]")
else:
print(f"worker 0 return value is: run_result.return_values[0]")
except Exception ex:
# handle exception
Rendezvous 處理常式¶
若要實作您自己的 rendezvous,請擴充 torch.distributed.elastic.rendezvous.RendezvousHandler
並實作其方法。
警告
Rendezvous 處理常式很難實作。在您開始之前,請確保您完全了解 rendezvous 的屬性。請參閱Rendezvous 以取得更多資訊。
實作後,您可以在建立代理程式時將您的自訂 rendezvous 處理常式傳遞給工作者規格。
spec = WorkerSpec(
rdzv_handler=MyRendezvousHandler(params),
...
)
elastic_agent = LocalElasticAgent(spec, start_method=start_method)
elastic_agent.run(spec.role)
指標處理常式¶
TorchElastic 會發出平台層級的指標 (請參閱指標)。預設情況下,指標會發送到 /dev/null,因此您不會看到它們。若要將指標推送到基礎架構中的指標處理服務,請實作 torch.distributed.elastic.metrics.MetricHandler 並在您的自訂啟動器中 configure 它。
# my_launcher.py
import torch.distributed.elastic.metrics as metrics
class MyMetricHandler(metrics.MetricHandler):
def emit(self, metric_data: metrics.MetricData):
# push metric_data to your metric sink
def main():
metrics.configure(MyMetricHandler())
spec = WorkerSpec(...)
agent = LocalElasticAgent(spec)
agent.run()
事件處理器¶
TorchElastic 支援事件記錄 (請參閱事件)。 事件模組定義了 API,可讓您記錄事件並實作自訂的 EventHandler。EventHandler 用於將 torchelastic 執行期間產生的事件發佈到不同的來源,例如 AWS CloudWatch。 預設情況下,它使用 torch.distributed.elastic.events.NullEventHandler,該處理器會忽略事件。 若要設定自訂事件處理器,您需要實作 torch.distributed.elastic.events.EventHandler 介面,並在您的自訂啟動器中設定它。
# my_launcher.py
import torch.distributed.elastic.events as events
class MyEventHandler(events.EventHandler):
def record(self, event: events.Event):
# process event
def main():
events.configure(MyEventHandler())
spec = WorkerSpec(...)
agent = LocalElasticAgent(spec)
agent.run()