在Python中部署pytorch(带有flask的REST API)

湛鸿
2023-12-01

这是在生产中部署 PyTorch 模型的系列教程中的第一篇。以这种方式使用 Flask 是迄今为止开始为PyTorch模型提供服务的最简单方法,但它不适用于具有高性能要求的用例。

API 定义

我们将首先定义我们的 API 端点、请求和响应类型。我们的 API 端点将位于 /predict,它使用包含图像的文件参数接收 HTTP POST 请求。响应将是包含预测的 JSON 响应:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

依赖

安装依赖:

$ pip install Flask==2.0.1 torchvision==0.10.0

简单的web服务

下面是一个 简单的web服务:

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

将上述代码段保存在名为 app.py 的文件中,您现在可以通过键入以下内容来运行 Flask 开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当您在 Web 浏览器中访问 http://localhost:5000/ 时,您将看到 Hello World!文本。
我们将对上面的代码片段稍作修改,使其适合我们的 API 定义。首先,我们将重命名要预测的方法。我们将端点路径更新为 /predict。由于图像文件将通过 HTTP POST 请求发送,我们将对其进行更新,使其也仅接受 POST 请求:

@app.route('/predict', methods=['POST'])
def predict():
    return 'Hello World!'

我们还将更改响应类型,使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。更新后的 app.py 文件现在是:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

预测

在接下来的部分中,将专注于编写推理代码。这将涉及两部分,一是准备图像,以便将其馈送到 DenseNet,接下来,将编写代码以从模型中获得实际预测。

准备图像

DenseNet 模型要求图像是大小为 224 x 224 的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化。我们将使用来自 torchvision 库的转换并构建一个转换管道,它根据需要转换我们的图像。

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法以字节为单位获取图像数据,应用一系列变换并返回一个张量。要测试上述方法,请以字节模式读取图像文件,然后查看是否返回张量。

with open("img.jpg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

预测

现在将使用预训练的 DenseNet 121 模型来预测图像类别。我们将使用来自 torchvision 库的一个,加载模型并进行推理。虽然我们将在本示例中使用预训练模型,但您可以对自己的模型使用相同的方法。

from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

张量 y_hat 将包含预测类 id 的索引。但是,我们需要一个人类可读的类名。为此,我们需要一个类 id 来命名映射。将此文件下载为 imagenet_class_index.json。该文件包含 ImageNet 类 id 到 ImageNet 类名的映射。我们将加载这个 JSON 文件并获取预测索引的类名。

import json

imagenet_class_index = json.load(open('/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用 imagenet_class_index 字典之前,首先我们将张量值转换为字符串值,因为 imagenet_class_index 字典中的键是字符串。我们将测试我们的上述方法:

with open("img.jpg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

你会得到一个像这样的响应:

['n02124075', 'Egyptian_cat']

数组中的第一项是 ImageNet 类 ID,第二项是可读的名称。

将模型集成到我们的 API 服务器中

在这最后一部分中,将把模型添加到 Flask API 服务器中。由于 API 服务器应该获取图像文件,将更新 predict 方法以从请求中读取文件:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

app.py 文件现已完成。以下是完整版;用保存文件的路径替换路径:

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

测试服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

可以使用 requests 库向应用程序发送 POST 请求:

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印 resp.json() 将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

出自:https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html

 类似资料: