捷徑

模型可解釋性範例

這是一個範例 TorchX 應用程式,它使用 captum 來分析輸入,以達到模型可解釋性的目的。它會使用來自訓練器應用程式範例的已訓練模型和來自資料預處理應用程式範例的預處理範例。輸出是一系列圖像,其中疊加了整合梯度歸因。

如需使用 captum 的詳細資訊,請參閱 https://captum.ai/tutorials/CIFAR_TorchVision_Interpret

用法

將此主要模組作為 Python 程序在本地執行。以下執行假設已使用 torchx/examples/apps/lightning/train.py 中的用法說明訓練模型。

$ torchx run -s local_cwd utils.python
    --script ./lightning/interpret.py
    --
    --load_path /tmp/torchx/train/last.ckpt
    --output_path /tmp/torchx/interpret

使用圖像檢視器來視覺化 output_path 下產生的 *.png 檔案。

備註

對於使用 TorchX 的 utils.python 內建函式庫的本地執行,實際上等於直接執行主要模組(例如 python ./interpret.py)。使用 TorchX 啟動簡單的單程序 Python 程式的優點是可以透過將 -s local_cwd 替換為遠端排程器(例如 kubernetes)來在遠端排程器上啟動,方法是指定 -s kubernetes

import argparse
import itertools
import os.path
import sys
import tempfile
from typing import List

import fsspec
import torch
from torchx.examples.apps.lightning.data import (
    create_random_data,
    download_data,
    TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import TinyImageNetModel


# ensure data and module are on the path
sys.path.append(".")


# FIXME: captum must be imported after torch otherwise it causes python to crash
if True:
    import numpy as np
    from captum.attr import IntegratedGradients, visualization as viz


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="example TorchX captum app")
    parser.add_argument(
        "--load_path",
        type=str,
        help="checkpoint path to load model weights from",
        required=True,
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="path to load the training data from, if not provided, random dataset will be created",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="path to place analysis results",
        required=True,
    )

    return parser.parse_args(argv)


def convert_to_rgb(arr: torch.Tensor) -> np.ndarray:  # pyre-ignore[24]
    """
    This converts the image from a torch tensor with size (1, 1, 64, 64) to
    numpy array with size (64, 64, 3).
    """
    out = arr.squeeze().swapaxes(0, 2)
    assert out.shape == (64, 64, 3), "invalid shape produced"
    return out.numpy()


def main(argv: List[str]) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model
        model = TinyImageNetModel()

        print(f"loading checkpoint: {args.load_path}...")
        model.load_from_checkpoint(checkpoint_path=args.load_path)

        # Download and setup the data module
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)
        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=1,
        )

        ig = IntegratedGradients(model)

        data.setup("test")
        dataloader = data.test_dataloader()

        # process first 5 images
        for i, (input, label) in enumerate(itertools.islice(dataloader, 5)):
            print(f"analyzing example {i}")
            # input = input.unsqueeze(dim=0)
            model.zero_grad()
            attr_ig, delta = ig.attribute(
                input,
                target=label,
                baselines=input * 0,
                return_convergence_delta=True,
            )

            if attr_ig.count_nonzero() == 0:
                # Our toy model sometimes has no IG results.
                print("skipping due to zero gradients")
                continue

            fig, axis = viz.visualize_image_attr(
                convert_to_rgb(attr_ig),
                convert_to_rgb(input),
                method="blended_heat_map",
                sign="all",
                show_colorbar=True,
                title="Overlayed Integrated Gradients",
            )
            out_path = os.path.join(args.output_path, f"ig_{i}.png")
            print(f"saving heatmap to {out_path}")
            with fsspec.open(out_path, "wb") as f:
                fig.savefig(f)


if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])


# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'

腳本的總執行時間:(0 分鐘 0.000 秒)

由 Sphinx-Gallery 生成的圖庫

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適用於初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源