
Trainer Example

This is an example TorchX app that uses PyTorch Lightning to train a model.

This app only uses standard OSS libraries and has no runtime torchx dependencies. For saving and loading data and models, it uses fsspec, which makes the app agnostic to the environment it runs in.
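As a minimal sketch (not part of the app itself), this is roughly how fsspec decouples I/O from the storage backend; the memory:// path below is purely illustrative and could just as well be a local path or an s3:// URL:

import fsspec

# Hypothetical location; fsspec maps the URL scheme to a filesystem backend,
# so local paths, s3://, gs://, memory://, etc. all use the same API.
output_path = "memory://example/model.txt"

with fsspec.open(output_path, "w") as f:
    f.write("model state would be written here")

with fsspec.open(output_path, "r") as f:
    print(f.read())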

Usage

To run the trainer locally as a ddp application with 1 node and 2 workers per node (world size = 2):

$ torchx run -s local_cwd dist.ddp \
   -j 1x2 \
   --script ./lightning/train.py \
   -- \
   --epochs=1 \
   --output_path=/tmp/torchx/train \
   --log_path=/tmp/torchx/logs \
   --skip_export

Notes

-- is used to separate the component (dist.ddp) arguments from the application arguments.

Use the --help option to see the full list of application options:

$ torchx run -s local_cwd dist.ddp -j 1x1 --script ./lightning/train.py -- --help

This is effectively the same as running ./train.py --help. To run on a remote scheduler, specify the scheduler with the -s option. Depending on the type of remote scheduler, you may have to pass additional scheduler config via the -cfg option. See Remote Schedulers for more details.
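For instance, a Kubernetes run might look roughly like the sketch below; the queue and namespace values are placeholders and depend on how your cluster is configured:

$ torchx run -s kubernetes -cfg queue=default,namespace=default dist.ddp \
   -j 1x2 \
   --script ./lightning/train.py \
   -- \
   --epochs=1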

import argparse
import os
import sys
import tempfile
from typing import List, Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchx.examples.apps.lightning.data import (
    create_random_data,
    download_data,
    TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import (
    export_inference_model,
    TinyImageNetModel,
)
from torchx.examples.apps.lightning.profiler import SimpleLoggingProfiler


# ensure data and module are on the path
sys.path.append(".")


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="pytorch lightning TorchX example app")
    parser.add_argument(
        "--epochs", type=int, default=3, help="number of epochs to train"
    )
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size to use for training"
    )
    parser.add_argument("--num_samples", type=int, default=10, help="num_samples")
    parser.add_argument(
        "--data_path",
        type=str,
        help="path to load the training data from, if not provided, random data will be generated",
    )
    parser.add_argument("--skip_export", action="store_true")
    parser.add_argument("--load_path", type=str, help="checkpoint path to load from")
    parser.add_argument(
        "--output_path",
        type=str,
        help="path to place checkpoints and model outputs, if not specified, checkpoints are not saved",
    )
    parser.add_argument(
        "--log_path",
        type=str,
        help="path to place the tensorboard logs",
        default="/tmp",
    )
    parser.add_argument(
        "--layers",
        nargs="+",
        type=int,
        help="the MLP hidden layers and sizes, used for neural architecture search",
    )
    return parser.parse_args(argv)


def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
    if not args.output_path:
        return None
    # Note: It is important that each rank behaves the same.
    # All of the ranks, or none of them should return ModelCheckpoint
    # Otherwise, there will be deadlock for distributed training
    return ModelCheckpoint(
        monitor="train_loss",
        dirpath=args.output_path,
        save_last=True,
    )


def main(argv: List[str]) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model
        model = TinyImageNetModel(args.layers)
        print(model)

        # Download and setup the data module
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)

        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=args.batch_size,
            num_samples=args.num_samples,
        )

        # Setup model checkpointing
        checkpoint_callback = get_model_checkpoint(args)
        callbacks = []
        if checkpoint_callback:
            callbacks.append(checkpoint_callback)
        if args.load_path:
            print(f"loading checkpoint: {args.load_path}...")
            # load_from_checkpoint is a classmethod that returns a new instance.
            model = TinyImageNetModel.load_from_checkpoint(checkpoint_path=args.load_path)

        logger = TensorBoardLogger(
            save_dir=args.log_path, version=1, name="lightning_logs"
        )
        # Initialize a trainer
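        # GROUP_WORLD_SIZE and LOCAL_WORLD_SIZE are set by torchelastic when
        # launched via dist.ddp; when the script is run directly they are
        # unset and we fall back to a single node with a single device.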
        trainer = pl.Trainer(
            num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
            strategy="ddp",
            logger=logger,
            max_epochs=args.epochs,
            callbacks=callbacks,
            profiler=SimpleLoggingProfiler(logger),
        )

        # Train the model ⚡
        trainer.fit(model, data)
        print(
            f"train acc: {model.train_acc.compute()}, val acc: {model.val_acc.compute()}"
        )

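        # RANK is also set by torchelastic; export from rank 0 only so the
        # workers don't all write the same artifact.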
        rank = int(os.environ.get("RANK", 0))
        if rank == 0 and not args.skip_export and args.output_path:
            # Export the inference model
            export_inference_model(model, args.output_path, tmpdir)


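# Skip the automatic run when a NOTEBOOK global is defined (e.g. when this
# file is executed inside a notebook/gallery environment).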
if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])


# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'
