使用 Flask 部署¶
建立於:2020 年 5 月 4 日 | 上次更新:2021 年 9 月 15 日 | 上次驗證:未驗證
在本食譜中,您將學習
如何將您訓練好的 PyTorch 模型包裝在 Flask 容器中,以透過 Web API 公開它
如何將傳入的 Web 請求轉換為您模型的 PyTorch 張量
如何封裝您模型的輸出以用於 HTTP 回應
需求¶
您將需要一個 Python 3 環境,並安裝以下套件 (及其相依性)
PyTorch 1.5
TorchVision 0.6.0
Flask 1.1
選擇性地,要取得一些支援檔案,您需要 git。
安裝 PyTorch 和 TorchVision 的說明可在 pytorch.org 取得。安裝 Flask 的說明可在 Flask 網站 取得。
什麼是 Flask?¶
Flask 是一個用 Python 編寫的輕量級 Web 伺服器。它提供了一種方便的方式,讓您可以快速為您訓練好的 PyTorch 模型設定一個 Web API,用於直接使用,或作為較大系統中的 Web 服務。
設定和支援檔案¶
我們將建立一個 Web 服務,該服務接收圖片,並將它們對應到 ImageNet 資料集的 1000 個類別之一。為此,您需要一個用於測試的圖片檔案。或者,您也可以取得一個檔案,將模型輸出的類別索引對應到人類可讀的類別名稱。
選項 1:快速取得兩個檔案¶
您可以透過檢查 TorchServe 儲存庫並將它們複製到您的工作資料夾,快速提取兩個支援檔案。(注意:本教學課程不依賴 TorchServe - 這只是一種快速取得檔案的方式。) 從您的 shell 提示符發出以下指令
git clone https://github.com/pytorch/serve
cp serve/examples/image_classifier/kitten.jpg .
cp serve/examples/image_classifier/index_to_name.json .
您已經取得它們了!
選項 2:帶上您自己的圖片¶
在下面的 Flask 服務中,index_to_name.json
檔案是可選的。您可以使用您自己的圖片測試您的服務 - 只需確保它是 3 色 JPEG。
建立您的 Flask 服務¶
Flask 服務的完整 Python 腳本顯示在本食譜的結尾;您可以將其複製並貼到您自己的 app.py
檔案中。下面我們將查看各個部分,以使它們的功能清晰。
匯入¶
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
依序
我們將使用來自
torchvision.models
的預訓練 DenseNet 模型torchvision.transforms
包含用於操作您的圖片資料的工具Pillow (
PIL
) 是我們最初用於載入圖片檔案的工具當然,我們還需要來自
flask
的類別
預處理¶
def transform_image(infile):
input_transforms = [transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])]
my_transforms = transforms.Compose(input_transforms)
image = Image.open(infile)
timg = my_transforms(image)
timg.unsqueeze_(0)
return timg
Web 請求給我們一個圖片檔案,但我們的模型期望一個形狀為 (N, 3, 224, 224) 的 PyTorch 張量,其中 N 是輸入批次中的項目數。(我們將只有 1 的批次大小。) 我們要做的第一件事是組成一組 TorchVision 轉換,它們調整圖片大小並裁剪圖片、將其轉換為張量,然後正規化張量中的值。(有關此正規化的更多資訊,請參閱 torchvision.models_
的文件。)
之後,我們開啟檔案並應用轉換。轉換傳回一個形狀為 (3, 224, 224) 的張量 - 一個 224x224 圖片的 3 個顏色通道。因為我們需要將這個單一圖片做成一個批次,我們使用 unsqueeze_(0)
呼叫來就地修改張量,方法是新增一個新的第一維度。張量包含相同的資料,但現在的形狀為 (1, 3, 224, 224)。
一般來說,即使您不使用圖片資料,您也需要將 HTTP 請求中的輸入轉換為 PyTorch 可以使用的張量。
推論¶
def get_prediction(input_tensor):
outputs = model.forward(input_tensor)
_, y_hat = outputs.max(1)
prediction = y_hat.item()
return prediction
推論本身是最簡單的部分:當我們將輸入張量傳遞給模型時,我們會取回一個值張量,該張量代表模型估計的圖片屬於特定類別的可能性。 max()
呼叫找到具有最大可能性的類別,並傳回該值以及 ImageNet 類別索引。最後,我們使用 item()
呼叫從包含它的張量中提取該類別索引,並將其傳回。
後處理¶
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if img_class_map is not None:
if stridx in img_class_map is not None:
class_name = img_class_map[stridx][1]
return prediction_idx, class_name
render_prediction()
方法將預測的類別索引對應到人類可讀的類別標籤。通常,在從您的模型取得預測後,執行後處理以使預測準備好供人類使用或供另一段軟體使用。
執行完整的 Flask 應用程式¶
將以下內容貼到名為 app.py
的檔案中
import io
import json
import os
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
model = models.densenet121(pretrained=True) # Trained on 1000 classes from ImageNet
model.eval() # Turns off autograd
img_class_map = None
mapping_file_path = 'index_to_name.json' # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
with open (mapping_file_path) as f:
img_class_map = json.load(f)
# Transform input into the form our model expects
def transform_image(infile):
input_transforms = [transforms.Resize(255), # We use multiple TorchVision transforms to ready the image
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], # Standard normalization for ImageNet model input
[0.229, 0.224, 0.225])]
my_transforms = transforms.Compose(input_transforms)
image = Image.open(infile) # Open the image file
timg = my_transforms(image) # Transform PIL image to appropriately-shaped PyTorch tensor
timg.unsqueeze_(0) # PyTorch models expect batched input; create a batch of 1
return timg
# Get a prediction
def get_prediction(input_tensor):
outputs = model.forward(input_tensor) # Get likelihoods for all ImageNet classes
_, y_hat = outputs.max(1) # Extract the most likely class
prediction = y_hat.item() # Extract the int value from the PyTorch tensor
return prediction
# Make the prediction human-readable
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if img_class_map is not None:
if stridx in img_class_map is not None:
class_name = img_class_map[stridx][1]
return prediction_idx, class_name
@app.route('/', methods=['GET'])
def root():
return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
if file is not None:
input_tensor = transform_image(file)
prediction_idx = get_prediction(input_tensor)
class_id, class_name = render_prediction(prediction_idx)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
要從您的 shell 提示符啟動伺服器,請發出以下指令
FLASK_APP=app.py flask run
預設情況下,您的 Flask 伺服器正在監聽埠 5000。伺服器執行後,開啟另一個終端機視窗,並測試您的新推論伺服器
curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "file=@kitten.jpg"
如果一切設置正確,您應該會收到類似以下的響應
{"class_id":285,"class_name":"Egyptian_cat"}
重要資源¶
pytorch.org 網站提供安裝說明、更多文件和教學