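Trainer Example
This is an example TorchX app that uses PyTorch Lightning to train a model.
This app only uses standard OSS libraries and has no runtime torchx dependencies. For saving and loading data and models it uses fsspec, which makes the app agnostic to the environment it runs in.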
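As a rough illustration of what "agnostic to the environment" can mean in practice, the sketch below writes and reads the same payload through fsspec using two different kinds of path. The helper names save_bytes, load_bytes and roundtrip are hypothetical and are not part of the example app. The sketch assumes the fsspec package is installed.

```python
import os
import tempfile

import fsspec


def save_bytes(url: str, payload: bytes) -> None:
    # fsspec selects the filesystem implementation from the URL scheme
    with fsspec.open(url, "wb") as f:
        f.write(payload)


def load_bytes(url: str) -> bytes:
    with fsspec.open(url, "rb") as f:
        return f.read()


def roundtrip(url: str, payload: bytes) -> bytes:
    # Write the payload, then read it back through the same URL
    save_bytes(url, payload)
    return load_bytes(url)


if __name__ == "__main__":
    # A plain local path inside a temporary directory
    with tempfile.TemporaryDirectory() as tmpdir:
        print(roundtrip(os.path.join(tmpdir, "weights.bin"), b"local"))
    # fsspec's built-in in-memory filesystem, addressed by URL scheme
    print(roundtrip("memory://torchx-demo/weights.bin", b"in-memory"))
```

The calling code stays the same for both targets. Only the path string changes, which is the property the app relies on.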
Usage
To run the trainer locally as a ddp application with 1 node and 2 workers per node (world size = 2):
$ torchx run -s local_cwd dist.ddp \
    -j 1x2 \
    --script ./lightning/train.py \
    -- \
    --epochs=1 \
    --output_path=/tmp/torchx/train \
    --log_path=/tmp/torchx/logs \
    --skip_export
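The relationship between the -j value and the world size can be sketched as follows. This assumes the simple "nodes x workers per node" form used in the command above. The helpers parse_nodes_by_procs and world_size are hypothetical illustrations, not TorchX's own parsing code.

```python
from typing import Tuple


def parse_nodes_by_procs(spec: str) -> Tuple[int, int]:
    """Split a simple '{nodes}x{workers_per_node}' string into two integers."""
    nodes_str, sep, procs_str = spec.partition("x")
    if not sep:
        raise ValueError(f"expected '<nodes>x<workers>', got {spec!r}")
    return int(nodes_str), int(procs_str)


def world_size(spec: str) -> int:
    # World size is the total number of workers across all nodes
    nodes, procs = parse_nodes_by_procs(spec)
    return nodes * procs


if __name__ == "__main__":
    print(parse_nodes_by_procs("1x2"))
    print(world_size("1x2"))
```

With -j 1x2 this gives 1 node, 2 workers per node, and a world size of 2, matching the description above.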
Note
-- separates the arguments to the component (dist.ddp) from the arguments to the app.
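The effect of the -- separator can be sketched with a small helper that splits an argument list at the first --. The function split_at_separator is hypothetical and only illustrates the idea. TorchX's actual argument handling may differ.

```python
from typing import List, Tuple


def split_at_separator(argv: List[str]) -> Tuple[List[str], List[str]]:
    """Return (component_args, app_args), split at the first '--'."""
    if "--" not in argv:
        # No separator: everything is treated as component arguments
        return list(argv), []
    idx = argv.index("--")
    return argv[:idx], argv[idx + 1:]


if __name__ == "__main__":
    component_args, app_args = split_at_separator(
        ["-j", "1x2", "--script", "./lightning/train.py", "--", "--epochs=1", "--skip_export"]
    )
    print(component_args)
    print(app_args)
```

In this sketch, everything before -- belongs to the component and everything after it is handed to the training script.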
Use the --help option to see the full list of application options:
$ torchx run -s local_cwd dist.ddp -j 1x1 --script ./lightning/train.py -- --help
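This is effectively the same as running ./train.py --help. To run on a remote scheduler, specify the scheduler with the -s option. Depending on the type of remote scheduler, you may have to pass additional scheduler configuration with the -cfg option. See Remote Schedulers for more details.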
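The application options are handled by an ordinary argparse parser inside the training script, which is why --help behaves the same either way. The sketch below is a trimmed-down parser that copies only three of the script's options (--epochs, --skip_export, --log_path). The names build_parser and parse_app_args are hypothetical and introduced only for illustration.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    # Reduced copy of three options from the training script's parser
    parser = argparse.ArgumentParser(description="pytorch lightning TorchX example app")
    parser.add_argument("--epochs", type=int, default=3, help="number of epochs to train")
    parser.add_argument("--skip_export", action="store_true")
    parser.add_argument(
        "--log_path", type=str, default="/tmp", help="path to place the tensorboard logs"
    )
    return parser


def parse_app_args(argv: List[str]) -> argparse.Namespace:
    return build_parser().parse_args(argv)


if __name__ == "__main__":
    args = parse_app_args(["--epochs=1", "--log_path=/tmp/torchx/logs", "--skip_export"])
    print(args.epochs, args.log_path, args.skip_export)
    # format_help() returns the same text that --help would print before exiting
    print(build_parser().format_help())
```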
import argparse
import os
import sys
import tempfile
from typing import List, Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchx.examples.apps.lightning.data import (
    create_random_data,
    download_data,
    TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import (
    export_inference_model,
    TinyImageNetModel,
)
from torchx.examples.apps.lightning.profiler import SimpleLoggingProfiler


# ensure data and module are on the path
sys.path.append(".")


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="pytorch lightning TorchX example app")
    parser.add_argument(
        "--epochs", type=int, default=3, help="number of epochs to train"
    )
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size to use for training"
    )
    parser.add_argument("--num_samples", type=int, default=10, help="num_samples")
    parser.add_argument(
        "--data_path",
        type=str,
        help="path to load the training data from, if not provided, random data will be generated",
    )
    parser.add_argument("--skip_export", action="store_true")
    parser.add_argument("--load_path", type=str, help="checkpoint path to load from")
    parser.add_argument(
        "--output_path",
        type=str,
        help="path to place checkpoints and model outputs, if not specified, checkpoints are not saved",
    )
    parser.add_argument(
        "--log_path",
        type=str,
        help="path to place the tensorboard logs",
        default="/tmp",
    )
    parser.add_argument(
        "--layers",
        nargs="+",
        type=int,
        help="the MLP hidden layers and sizes, used for neural architecture search",
    )
    return parser.parse_args(argv)


def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
    if not args.output_path:
        return None
    # Note: It is important that each rank behaves the same.
    # All of the ranks, or none of them should return ModelCheckpoint
    # Otherwise, there will be deadlock for distributed training
    return ModelCheckpoint(
        monitor="train_loss",
        dirpath=args.output_path,
        save_last=True,
    )


def main(argv: List[str]) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model
        model = TinyImageNetModel(args.layers)
        print(model)

        # Download and setup the data module
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)

        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=args.batch_size,
            num_samples=args.num_samples,
        )

        # Setup model checkpointing
        checkpoint_callback = get_model_checkpoint(args)
        callbacks = []
        if checkpoint_callback:
            callbacks.append(checkpoint_callback)
        if args.load_path:
            print(f"loading checkpoint: {args.load_path}...")
            model.load_from_checkpoint(checkpoint_path=args.load_path)

        logger = TensorBoardLogger(
            save_dir=args.log_path, version=1, name="lightning_logs"
        )

        # Initialize a trainer
        trainer = pl.Trainer(
            num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
            strategy="ddp",
            logger=logger,
            max_epochs=args.epochs,
            callbacks=callbacks,
            profiler=SimpleLoggingProfiler(logger),
        )

        # Train the model ⚡
        trainer.fit(model, data)
        print(
            f"train acc: {model.train_acc.compute()}, val acc: {model.val_acc.compute()}"
        )

        rank = int(os.environ.get("RANK", 0))
        if rank == 0 and not args.skip_export and args.output_path:
            # Export the inference model
            export_inference_model(model, args.output_path, tmpdir)


if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])

# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'
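The script takes its distributed topology from the GROUP_WORLD_SIZE, LOCAL_WORLD_SIZE and RANK environment variables, falling back to single-process defaults when they are unset, and exports the inference model only on rank 0. The sketch below isolates that logic so it can be examined without PyTorch. The helpers trainer_topology and should_export are hypothetical, and the environment values shown for a 1x2 job are illustrative assumptions.

```python
import os
from typing import Dict, Mapping, Optional


def trainer_topology(env: Mapping[str, str]) -> Dict[str, int]:
    """Read the same variables the script reads, with the same defaults."""
    return {
        "num_nodes": int(env.get("GROUP_WORLD_SIZE", 1)),
        "devices": int(env.get("LOCAL_WORLD_SIZE", 1)),
        "rank": int(env.get("RANK", 0)),
    }


def should_export(rank: int, skip_export: bool, output_path: Optional[str]) -> bool:
    # Mirrors the script's guard: only rank 0 exports, and only when
    # export is not skipped and an output path was given
    return rank == 0 and not skip_export and bool(output_path)


if __name__ == "__main__":
    # Whatever the current process environment provides
    print(trainer_topology(os.environ))
    # Assumed values for the second worker of a 1x2 job
    print(trainer_topology({"GROUP_WORLD_SIZE": "1", "LOCAL_WORLD_SIZE": "2", "RANK": "1"}))
    print(should_export(0, False, "/tmp/torchx/train"))
```

Keeping the export on a single rank avoids several workers writing the same output files, while the checkpoint callback is created identically on every rank, as the comment in get_model_checkpoint requires.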