捷徑

使用 Flask 部署

建立於:2020 年 5 月 4 日 | 上次更新:2021 年 9 月 15 日 | 上次驗證:未驗證

在本食譜中,您將學習

  • 如何將您訓練好的 PyTorch 模型包裝在 Flask 容器中,以透過 Web API 公開它

  • 如何將傳入的 Web 請求轉換為您模型的 PyTorch 張量

  • 如何封裝您模型的輸出以用於 HTTP 回應

需求

您將需要一個 Python 3 環境,並安裝以下套件 (及其相依性)

  • PyTorch 1.5

  • TorchVision 0.6.0

  • Flask 1.1

選擇性地,要取得一些支援檔案,您需要 git。

安裝 PyTorch 和 TorchVision 的說明可在 pytorch.org 取得。安裝 Flask 的說明可在 Flask 網站 取得。

什麼是 Flask?

Flask 是一個用 Python 編寫的輕量級 Web 伺服器。它提供了一種方便的方式,讓您可以快速為您訓練好的 PyTorch 模型設定一個 Web API,用於直接使用,或作為較大系統中的 Web 服務。

設定和支援檔案

我們將建立一個 Web 服務,該服務接收圖片,並將它們對應到 ImageNet 資料集的 1000 個類別之一。為此,您需要一個用於測試的圖片檔案。或者,您也可以取得一個檔案,將模型輸出的類別索引對應到人類可讀的類別名稱。

選項 1:快速取得兩個檔案

您可以透過檢查 TorchServe 儲存庫並將它們複製到您的工作資料夾,快速提取兩個支援檔案。(注意:本教學課程不依賴 TorchServe - 這只是一種快速取得檔案的方式。) 從您的 shell 提示符發出以下指令

git clone https://github.com/pytorch/serve
cp serve/examples/image_classifier/kitten.jpg .
cp serve/examples/image_classifier/index_to_name.json .

您已經取得它們了!

選項 2:帶上您自己的圖片

在下面的 Flask 服務中,index_to_name.json 檔案是可選的。您可以使用您自己的圖片測試您的服務 - 只需確保它是 3 色 JPEG。

建立您的 Flask 服務

Flask 服務的完整 Python 腳本顯示在本食譜的結尾;您可以將其複製並貼到您自己的 app.py 檔案中。下面我們將查看各個部分,以使它們的功能清晰。

匯入

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

依序

  • 我們將使用來自 torchvision.models 的預訓練 DenseNet 模型

  • torchvision.transforms 包含用於操作您的圖片資料的工具

  • Pillow (PIL) 是我們最初用於載入圖片檔案的工具

  • 當然,我們還需要來自 flask 的類別

預處理

def transform_image(infile):
    input_transforms = [transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)
    timg = my_transforms(image)
    timg.unsqueeze_(0)
    return timg

Web 請求給我們一個圖片檔案,但我們的模型期望一個形狀為 (N, 3, 224, 224) 的 PyTorch 張量,其中 N 是輸入批次中的項目數。(我們將只有 1 的批次大小。) 我們要做的第一件事是組成一組 TorchVision 轉換,它們調整圖片大小並裁剪圖片、將其轉換為張量,然後正規化張量中的值。(有關此正規化的更多資訊,請參閱 torchvision.models_ 的文件。)

之後,我們開啟檔案並應用轉換。轉換傳回一個形狀為 (3, 224, 224) 的張量 - 一個 224x224 圖片的 3 個顏色通道。因為我們需要將這個單一圖片做成一個批次,我們使用 unsqueeze_(0) 呼叫來就地修改張量,方法是新增一個新的第一維度。張量包含相同的資料,但現在的形狀為 (1, 3, 224, 224)。

一般來說,即使您不使用圖片資料,您也需要將 HTTP 請求中的輸入轉換為 PyTorch 可以使用的張量。

推論

def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)
    _, y_hat = outputs.max(1)
    prediction = y_hat.item()
    return prediction

推論本身是最簡單的部分:當我們將輸入張量傳遞給模型時,我們會取回一個值張量,該張量代表模型估計的圖片屬於特定類別的可能性。 max() 呼叫找到具有最大可能性的類別,並傳回該值以及 ImageNet 類別索引。最後,我們使用 item() 呼叫從包含它的張量中提取該類別索引,並將其傳回。

後處理

def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name

render_prediction() 方法將預測的類別索引對應到人類可讀的類別標籤。通常,在從您的模型取得預測後,執行後處理以使預測準備好供人類使用或供另一段軟體使用。

執行完整的 Flask 應用程式

將以下內容貼到名為 app.py 的檔案中

import io
import json
import os

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
model = models.densenet121(pretrained=True)               # Trained on 1000 classes from ImageNet
model.eval()                                              # Turns off autograd



img_class_map = None
mapping_file_path = 'index_to_name.json'                  # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
    with open (mapping_file_path) as f:
        img_class_map = json.load(f)



# Transform input into the form our model expects
def transform_image(infile):
    input_transforms = [transforms.Resize(255),           # We use multiple TorchVision transforms to ready the image
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],       # Standard normalization for ImageNet model input
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)                            # Open the image file
    timg = my_transforms(image)                           # Transform PIL image to appropriately-shaped PyTorch tensor
    timg.unsqueeze_(0)                                    # PyTorch models expect batched input; create a batch of 1
    return timg


# Get a prediction
def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)                 # Get likelihoods for all ImageNet classes
    _, y_hat = outputs.max(1)                             # Extract the most likely class
    prediction = y_hat.item()                             # Extract the int value from the PyTorch tensor
    return prediction

# Make the prediction human-readable
def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name


@app.route('/', methods=['GET'])
def root():
    return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        if file is not None:
            input_tensor = transform_image(file)
            prediction_idx = get_prediction(input_tensor)
            class_id, class_name = render_prediction(prediction_idx)
            return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

要從您的 shell 提示符啟動伺服器,請發出以下指令

FLASK_APP=app.py flask run

預設情況下,您的 Flask 伺服器正在監聽埠 5000。伺服器執行後,開啟另一個終端機視窗,並測試您的新推論伺服器

curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "file=@kitten.jpg"

如果一切設置正確,您應該會收到類似以下的響應

{"class_id":285,"class_name":"Egyptian_cat"}

重要資源

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得您問題的解答

查看資源