机器学习demo之Flask框架开发鸢尾花模型API

整体步骤:选择鸢尾花数据集进行训练模型,然后根据生成的模型来开发模型API,通过API进行预测(Postman)

1.鸢尾花数据集准备(网上找呗)
2.鸢尾花数据训练GBDT.pkl模型,数据集、训练脚本、预测脚本同一级目录下即可
各包版本(建议使用anaconda,然后在pip install xgboost就行了):
matplotlib3.2.2
joblib
0.16.0
xgboost1.2.1
Flask
1.1.2
numpy1.18.5
pandas
1.0.5
scikit_learn==0.23.2

# -*- coding:utf-8 -*-

from sklearn.model_selection  import train_test_split
from xgboost import XGBClassifier
import numpy as np
from joblib import dump,load
from matplotlib import pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号


def load_data():
    import pandas as pd 
    iris = pd.read_csv("Iris.csv")
    iris.drop(axis=0, columns="Id", inplace=True)   #取除多余行
    iris.drop_duplicates(inplace=True)              #取除重复的行
    labels = iris.iloc[::,4:].values                #花的类型标签
    species = set(labels.flatten())                 #flatten()展平数组                      
    label_type = {}
    for i,j in enumerate(species):
        label_type[j]=i
    #print(label_type)
    #用数字代替花的类别,对应的关系为label_type = {'Iris-versicolor': 0, 'Iris-virginica': 1, 'Iris-setosa': 2}
    iris["Species"] = iris["Species"].map(label_type)
    labels = iris.iloc[::,4:].values                #数值类型标签
    datas = iris.iloc[::,:4].values                 #样本
    return datas,labels.flatten()

if __name__ == "__main__":
    X,Y = load_data()
    X_train,X_test,Y_train,Y_test = train_test_split(X,Y,test_size=0.3,random_state=1)
    clf=XGBClassifier(base_score=0.5, booster='gbtree', learning_rate=0.05, max_depth=8, n_estimators=50)
    #训练
    clf.fit(X_train, Y_train)
    #保存模型
    dump(clf,"GDBT.pkl")
    #预测
    model = load("GDBT.pkl")
    #打分
    print('GDBT the accuracy:\t',model.score(X_test, Y_test))
    plt.scatter(range(len(X_test)),model.predict(X_test),marker="x",c="r",label="预测类别") 
    plt.scatter(range(len(X_test)),Y_test,marker="o",c="b",label="真实类别") #局部汉字 fontproperties="SimHei"
    plt.title("预测与真实效果图")
    plt.legend()
    plt.savefig("预测与真实效果图.png")
    #plt.show()
    

3.Flask框架开发鸢尾花模型API

# -*- coding:utf-8 -*-
#模型API服务参考:https://www.imooc.com/article/259009
"""
输入:
{
    "SepalLengthCm": 5.7,
    "SepalWidthCm": 3.0,
    "PetalLengthCm": 4.2,
    "PetalWidthCm": 1.2
}
输出:
{
    "species": [
        0
    ]
}

"""

from joblib  import load
import numpy as np
import traceback
import sys
import pandas as pd
from flask import request
from flask import Flask
from flask import jsonify

app = Flask(__name__)
#http://192.168.1.34:8000/api/GBDT_iris/predict
name = "GBDT_iris"
@app.route(f'/api/{name}/predict', methods=['POST']) 
def predict():    
    if model:        
        try:
            json = request.json 
            data =  pd.DataFrame([json])
            predict_data = data.values
            #return jsonify({"data":predict_data.tolist()})   #predict_data为numpy的array类型,json不识别,所以转换
            prediction = model.predict(predict_data)
            #return jsonify({"species":prediction.tolist()})   #输入预测类别
            for i in list(label_type.keys()):
                if label_type[i] == prediction[0]:           
                    return jsonify({"鸢尾花的预测类型为":[i]})     
        except:           
            return jsonify({'trace': traceback.format_exc()}) 
        
    else:             
        return '模型不可用'

if __name__ == '__main__': 
    label_type = {'Iris-versicolor': 0, 'Iris-virginica': 1, 'Iris-setosa': 2}
    try:
        port = int(sys.argv[1])  
    except:
        port = 8888  
    model = load('GBDT.pkl') 
    #predict()
    app.run(host='192.168.1.34', port=port, debug=True)

4.使用postman调用预测API
注意:在使用postman调API时,预测基本必须处于运行中
在这里插入图片描述在这里插入图片描述
模型API参考:https://www.imooc.com/article/259009

注:纸上得来终觉浅,绝知此事要躬行。


版权声明:本文为weixin_43509698原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。