Model Interpretability App Example

This is an example TorchX app that uses captum to analyze inputs for model interpretability purposes. It uses the trained model from the trainer app example and the preprocessed examples from the datapreproc app example. The output is a series of images with the Integrated Gradients attributions overlaid on them.

See https://captum.ai/tutorials/CIFAR_TorchVision_Interpret for more details on how to use captum.
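Before diving into the app, it helps to see what captum's `IntegratedGradients` actually computes: the attribution for each input feature is the feature's offset from a baseline times the average gradient along the straight-line path from baseline to input. The sketch below (an illustration, not captum's implementation) approximates that path integral with a midpoint Riemann sum on a toy function whose gradient is known analytically; for `f(x) = sum(x**2)` with a zero baseline, the exact attributions are `x**2`, and they satisfy the completeness axiom (they sum to `f(x) - f(baseline)`):

```python
import numpy as np


def integrated_gradients(grad_f, x, baseline, steps=200):
    # Midpoint Riemann-sum approximation of the path integral
    # IG_i = (x_i - baseline_i) * integral_0^1 dF/dx_i(baseline + a*(x - baseline)) da
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack([grad_f(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


f = lambda x: np.sum(x ** 2)   # toy model with a known gradient
grad_f = lambda x: 2 * x

x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(grad_f, x, baseline)
# attr is approximately [1., 4., 9.] and sums to f(x) - f(baseline)
```

captum performs the same computation, but obtains the gradients via autograd on the model instead of an analytic `grad_f`.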
Usage

Run this main module as a local Python process. The run below assumes that a model has been trained following the usage instructions in torchx/examples/apps/lightning/train.py.

$ torchx run -s local_cwd utils.python \
    --script ./lightning/interpret.py \
    -- \
    --load_path /tmp/torchx/train/last.ckpt \
    --output_path /tmp/torchx/interpret

Use an image viewer to visualize the resulting *.png files produced under output_path.
Note

A local run using TorchX's utils.python builtin is effectively the same as running the main module directly (e.g. python ./interpret.py). The advantage of launching a simple single-process Python program with TorchX is that it can then be launched on a remote scheduler (e.g. Kubernetes) simply by replacing -s local_cwd with -s kubernetes.
import argparse
import itertools
import os.path
import sys
import tempfile
from typing import List

import fsspec
import torch
from torchx.examples.apps.lightning.data import (
    create_random_data,
    download_data,
    TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import TinyImageNetModel

# ensure data and module are on the path
sys.path.append(".")

# FIXME: captum must be imported after torch otherwise it causes python to crash
if True:
    import numpy as np
    from captum.attr import IntegratedGradients, visualization as viz
def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="example TorchX captum app")
    parser.add_argument(
        "--load_path",
        type=str,
        help="checkpoint path to load model weights from",
        required=True,
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="path to load the training data from, if not provided, random dataset will be created",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="path to place analysis results",
        required=True,
    )
    return parser.parse_args(argv)
def convert_to_rgb(arr: torch.Tensor) -> np.ndarray:  # pyre-ignore[24]
    """
    This converts the image from a torch tensor with size (1, 3, 64, 64) to a
    numpy array with size (64, 64, 3).
    """
    out = arr.squeeze().swapaxes(0, 2)
    assert out.shape == (64, 64, 3), "invalid shape produced"
    return out.numpy()
def main(argv: List[str]) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model
        model = TinyImageNetModel()
        print(f"loading checkpoint: {args.load_path}...")
        model.load_from_checkpoint(checkpoint_path=args.load_path)

        # Download and setup the data module
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)
        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=1,
        )

        ig = IntegratedGradients(model)
        data.setup("test")
        dataloader = data.test_dataloader()

        # process first 5 images
        for i, (input, label) in enumerate(itertools.islice(dataloader, 5)):
            print(f"analyzing example {i}")
            # input = input.unsqueeze(dim=0)
            model.zero_grad()
            attr_ig, delta = ig.attribute(
                input,
                target=label,
                baselines=input * 0,
                return_convergence_delta=True,
            )

            if attr_ig.count_nonzero() == 0:
                # Our toy model sometimes has no IG results.
                print("skipping due to zero gradients")
                continue

            fig, axis = viz.visualize_image_attr(
                convert_to_rgb(attr_ig),
                convert_to_rgb(input),
                method="blended_heat_map",
                sign="all",
                show_colorbar=True,
                title="Overlayed Integrated Gradients",
            )

            out_path = os.path.join(args.output_path, f"ig_{i}.png")
            print(f"saving heatmap to {out_path}")
            with fsspec.open(out_path, "wb") as f:
                fig.savefig(f)
if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])
# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'
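As a quick sanity check on the shape logic in convert_to_rgb above, the same squeeze/swapaxes steps can be reproduced with plain numpy, which mirrors torch's semantics for these two calls. The array below stands in for one (batch, channels, height, width) image as the dataloader yields it:

```python
import numpy as np

# A CHW image batch like the test dataloader yields: (batch=1, channels=3, H=64, W=64)
arr = np.random.rand(1, 3, 64, 64).astype(np.float32)

# Same steps as convert_to_rgb: squeeze drops the batch dim -> (3, 64, 64),
# then swapaxes(0, 2) moves the channel axis last -> (64, 64, 3)
out = arr.squeeze().swapaxes(0, 2)
```

This channels-last (H, W, 3) layout is what viz.visualize_image_attr expects for its image inputs.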
Total running time of the script: (0 minutes 0.000 seconds)