捷徑

資料預處理應用程式範例

這是一個簡單的 TorchX 應用程式,它透過 HTTP 下載一些資料,透過 torchvision 將影像正規化,然後透過 fsspec 將其重新上傳。

使用方法

注意

datapreproc 應用程式是一個單一程序 Python 程式,因此對於本機執行,您可以將其作為一般的 Python 程式執行:python ./datapreproc.py。TorchX 可讓您在遠端叢集上執行此應用程式。

若要在本機使用 TorchX 啟動(請參閱上述注意事項),請執行

$ torchx run -s local_cwd utils.python       --script ./datapreproc/datapreproc.py       --       --input_path="http://cs231n.stanford.edu/tiny-imagenet-200.zip"       --output_path=/tmp/torchx/datapreproc

若要將此應用程式啟動到遠端叢集上,只需在 -s 選項中指定不同的排程器。

$ torchx run -s kubernetes -cfg queue=foo,namespace=bar utils.python       --script ./datapreproc/datapreproc.py       --       --input_path="http://cs231n.stanford.edu/tiny-imagenet-200.zip"       --output_path=/tmp/torchx/datapreproc
import argparse
import os
import sys
import tarfile
import tempfile
import zipfile
from typing import List

import fsspec
from PIL import Image
from torchvision import transforms
from torchvision.datasets.folder import is_image_file
from tqdm import tqdm


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="example data preprocessing",
    )
    parser.add_argument(
        "--input_path",
        type=str,
        help="dataset to download",
        default="http://cs231n.stanford.edu/tiny-imagenet-200.zip",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="remote path to save the .tar.gz data to",
        required=True,
    )
    parser.add_argument(
        "--limit",
        type=int,
        help="limit number of processed examples",
    )
    return parser.parse_args(argv)


def download_and_extract_zip_archive(url: str, path: str) -> None:
    with fsspec.open(url, "rb") as f:
        with zipfile.ZipFile(f, "r") as zip_ref:
            zip_ref.extractall(path)


def main(argv: List[str]) -> None:
    args = parse_args(argv)
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"downloading {args.input_path} to {tmpdir}...")
        download_and_extract_zip_archive(args.input_path, tmpdir)

        img_root = os.path.join(
            tmpdir,
            os.path.splitext(os.path.basename(args.input_path))[0],
        )
        print(f"img_root={img_root}")

        print("transforming images...")
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
                transforms.ToPILImage(),
            ]
        )

        image_files = []
        for root, _, fnames in os.walk(img_root):
            for fname in fnames:
                path = os.path.join(root, fname)
                if not is_image_file(path):
                    continue
                image_files.append(path)

                if args.limit and len(image_files) > args.limit:
                    break
        for path in tqdm(image_files, miniters=int(len(image_files) / 2000)):
            f = Image.open(path)
            f = transform(f)
            f.save(path)

        tar_path = os.path.join(tmpdir, "out.tar.gz")
        print(f"packing images into {tar_path}...")
        with tarfile.open(tar_path, mode="w:gz") as f:
            f.add(img_root, arcname="")

        print(f"uploading dataset to {args.output_path}...")
        fs, _, rpaths = fsspec.get_fs_token_paths(args.output_path)
        assert len(rpaths) == 1, "must have single output path"
        if fs.exists(rpaths[0]):
            fs.rm(rpaths[0])
        fs.put(tar_path, rpaths[0])


if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])


# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'

腳本總執行時間: (0 分 0.000 秒)

圖庫由 Sphinx-Gallery 產生

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適合初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源