将TensorFlow模型部署到服务器的教程

前言

当一个TensorFlow模型训练出来的时候,为了投入到实际应用,所以就需要部署到服务器上。由于我本次所做的项目是一个javaweb的图像识别项目。所有我就想去寻找一下java调用TensorFlow训练模型的办法。

TensorFlow模型部署到服务器---TensorFlow2.0_h5

由于TensorFlow很久没更新的缘故,网上的博客大都是18/19年的,并且是基于TensorFlow1.0的,对于现在使用的TensorFlow2.0不太友好。


下面我简述一下TensorFlow1.0时期的方法:

1.动态模型生成不便

需要将训练的.h5模型转换成.pb模型,并且需要自己定义.pb模型的输入输出参数。(pb模型是一种基于动态图的模型)

pb的生成代码冗长、而且对初学者真滴不太友好

TensorFlow模型部署到服务器---TensorFlow2.0_tensorflow_02

相比之下.h5模型的生成代码就一行

TensorFlow模型部署到服务器---TensorFlow2.0_tensorflow_03

此外,这个生成pb模型的代码是否能照搬使用,还是一个问题,并且还可能报一些奇奇怪怪的错误。

2.maven导包不便

查阅资料发现java上的TensorFlow的jar包都是TensorFlow1.0的

TensorFlow模型部署到服务器---TensorFlow2.0_h5_04

现状:

TensorFlow模型部署到服务器---TensorFlow2.0_java_05

并且maven官网上的TensorFlow2.0的api已经改名成了tensorflow-core-api,并且网上相关方面的教程十分难找。由于网上都是导入的1.0的包,自己导入2.0的包之后,详细的调用教程可以说是没有。从上面也可以看出来TensorFlow对java的调用也不怎么重视了。所以这又给学习的途中徒增了很多困难。全新思路

思路一

用java直接调用训练好的模型很困难,那么我们想办法让java调用python脚本,让python脚本去调用.h5模型会不会更简单呢?

代码如下


登录后复制

package com.guard.service;import java.io.BufferedReader;import java.io.IOException;
import java.io.InputStreamReader;public class api_service {public String recognize(String path)
{//此处的path是图片路径Process proc;String res = null;
try {System.out.println("接受到的参数"+path);String[] cmd = new String[] 
{ "python", "E:\\machine_learning\\predict.py", path};proc = Runtime.getRuntime().exec(cmd);
BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
String line = null;while ((line = in.readLine()) != null) {System.out.println(line);
res = line;            }in.close();proc.waitFor();        } catch (IOException e) 
{e.printStackTrace();        } catch (InterruptedException e) {e.printStackTrace();        
}System.out.println(res+">>>>>>>>>>>");return res;    }}1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.
17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.



但是我们可以看出,这个其实是用java在win上跑了这样一个指令

TensorFlow模型部署到服务器---TensorFlow2.0_flask_06

虽然这个确实是一个好办法,但是这个路径参数需要事先知道服务器上的路径,并且在协作开发的时候,每个人的路径和环境就不同,虽然该方法能用,但是我认为还不够好。



思路二

我们可以直接用python的flask框架,直接生成一个api接口,就可以远程直接调用TensorFlow训练好的模型进行结果预测。

TensorFlow模型部署到服务器---TensorFlow2.0_h5_07

TensorFlow模型部署到服务器---TensorFlow2.0_flask_08

个人认为,这种方法相较于用java调用命令行,这种方法还是更加直观的

并且flask仅仅需要加个@app.route的注解就能实现,可谓是十分方便


下面是模型调用代码

model.py

登录后复制

import globimport sysimport osimport cv2import numpy as npimport tensorflow as tfimport 
image_processingdef model_ues(path):# 缩放图片大小为100*100w = 100h = 100# 测试图像的地址 (改为自己的)
# path_test = "resource/test24.jpg"api_token = 
"fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
path_test = image_processing.download_img(path,api_token)
# 创建保存图像的空列表imgs = []img = cv2.imread(path_test)img = cv2.resize(img, (w, h))
# 将每张经过处理的图像数据保存在之前创建的imgs空列表当中imgs.append(img)imgs = np.asarray(imgs, np
.float32)# print("shape of data:",imgs.shape)
# 导入模型model = tf.keras.models.load_model(r"resource/rice_0.93.h5")
# 创建图像标签列表rice_dict = {0: 'Rice blast', 1: 'Rice fleck',2: 'Rice koji disease',
3: 'Sheath blight'}# 将图像导入模型进行预测prediction = model.predict_classes(imgs)
# prediction = np.argmax(model.predict(imgs), 
axis=-1)# 绘制预测图像for i in range(np.size(prediction)):
# 打印每张图像的预测结果print(rice_dict[prediction[i]])return rice_dict[prediction[0]]
1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.
32.33.34.35.36.37.38.39.40.41.42.43.44.


为了实现图片外链接受,下面是图片下载脚本

image_processing.py

登录后复制

# coding: utf8import requestsimport randomdef 
download_img(img_url, api_token):print (img_url)header = {"Authorization": "Bearer " + api_token}
# 设置http header,视情况加需要的条目,这里的token是用来鉴权的一种方式r = requests
.get(img_url, headers=header, stream=True)print(r.status_code)
# 返回状态码file_img = 'resource/img.png'
# file_img = 'resource/'print(file_img)if r.status_code == 200:open(file_img, 'wb')
.write(r.content) # 将内容写入图片print("done")del rreturn file_img# if __name__ == '__main__':
#     # 下载要的图片#     
img_url = "https://z3.ax1x.com/2021/07/27/W5l6Qe.png"#     
api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
#     download_img(img_url, api_token)1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19
.20.21.22.23.24.



主程序脚本

app.py

登录后复制

from flask import Flask,render_template, url_for, 
request, json,jsonifyimport modelapp = Flask(__name__)
#设置编码app.config['JSON_AS_ASCII'] = False@app
.route('/test')def hello_world():return "hello world"@app.route('/predict', 
methods=['GET', 'POST'])def form_data():my_path = request
.form['path']print(my_path)str = model
.model_ues(my_path)print("http://127.0.0.1:5000/predict")return jsonify({'result':str,
'msg':'200'})if __name__ == '__main__':app.run()1.2.3.4.5.6.7.8.9.10.11.12.13.14.
15.16.17.18.19.20.21.22.


数据解析

虽然我们能够通过postman进行测试接受到回传的结果,但是我们要怎么用java实现呢??

1.使用postman生成大致代码框架(postman生成的代码可能不能直接运行)

TensorFlow模型部署到服务器---TensorFlow2.0_h5_09

这里我选用的是java-okhttp的方法,但其实使用Unirest写出来的代码更加简洁易懂。

登录后复制

public class Get_result {public  String getResult(String path) throws IOException 
{//        String path = "https://i.loli.net/2021/07/29/badDNR2OCironUf.jpg";
OkHttpClient client = new OkHttpClient().newBuilder()                .build();
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
RequestBody body = RequestBody.create(mediaType, "path="+path);
Request request = new Request.Builder()
.url("http://127.0.0.1:8000/predict")                
.method("POST", body)                
.addHeader("Content-Type", "application/x-www-form-urlencoded")                
.build();Response response = client.newCall(request).execute();String result = response
.body().string();System.out.println(result);            
}}1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.




登录后复制
{"msg": "200","result": "Rice fleck"}1.2.3.4.



获取到json数据之后,就需要对json数据进行解析

java上的解析原理是,先按照json编写一个类,之后用Gson对接受到的数据按照这个类进行规范化

(这里可以用GsonFormatPlus插件来自动生成这个实体类)

登录后复制

//Rice_result.java---为该json的实体类package com.guard.tool;import lombok.Data;
import lombok.NoArgsConstructor;@NoArgsConstructor@Datapublic class Rice_result 
{private String msg;private String result;}1.2.3.4.5.6.7.8.9.10.11.12.13.


下面是数据解析代码(和上面的okhttp获取json数据的代码连起来看)

登录后复制

//json数据解析Gson gson = new Gson();java.lang.reflect.
Type type = new TypeToken<Rice_result>(){}.getType();
Rice_result rice_result = gson.fromJson(result, type);
System.out.println(rice_result);if("200".equals(rice_result
.getMsg())){//            System.out.println(rice_result.getResult());
return Rice_result.convertdata(rice_result.getResult());        }else
 {//            System.out.println("获取结果出错!!");return "获取结果出错!!";        
 }1.2.3.4.5.6.7.8.9.10.11.12.



这样的话就可以进行json数据的解析了。图链制作

由于需要使用java发送post请求给flask的预测端口,那么就需要把本地上传的数据做成图链,把图链作为数据传给flask的预测端口,从而来接收结果。

由于前端js的知识大多遗忘,这里就选用了用java来发送一个post请求,获得回传的信息。

这里我使用的是sm.ms的图床(该图床无需登录,且速度快,算得上是一个好的选择)

登录后复制

//sm.ms的使用方法,建议看官方文档package com.guard.tool;import com.google
.gson.Gson;import com.google.gson.reflect.TypeToken;import okhttp3.*;
import java.io.File;import java.io.IOException;public class CloudUpload 
{public String toUrl(String path) throws IOException 
{//    String file_path = "E:/machine_learning/test8.jpg";
String file_path = path;OkHttpClient client = new OkHttpClient().newBuilder()            
.build();MediaType mediaType = MediaType.parse("multipart/form-data");
RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM)            
.addFormDataPart("smfile",file_path,RequestBody
.create(MediaType.parse("application/octet-stream"),new File(file_path)))            
.addFormDataPart("format","json")            
.build();Request request = new Request.Builder()            
.url("https://sm.ms/api/v2/upload")            
.method("POST", body)            
.addHeader("Content-Type", "multipart/form-data")            
.addHeader("Authorization", "TlxzRSaVJj0o7HFZOd9sgdf4Jl60RA00")//
这里的user-agent和Cookie需要自己打开网站,到网站的页面去拿取            
.addHeader("user-agent","Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537
.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36")            
.addHeader("Cookie", "SMMSrememberme=42417%3A10e8e9cb5281082b493fdee73381aeb2dca0bd3d; 
PHPSESSID=1gjog2em3ogof23vrqi79vd41m; SM_FC=runWNk3mPIiL8mzl%2FrlEfzM940LRKjLm182cm2qDrm4%3D")
.build();Response response = client.newCall(request)
.execute();String result = response.body().string();System.out.println(result);//   
String result = response.body().string();Gson gson = new Gson();
java.lang.reflect.Type type = new TypeToken<Image_data>(){}.getType();
Image_data imge_data = gson.fromJson(result, type);System.out.println(imge_data);
if (imge_data.getSuccess()){System.out.println(imge_data.getData().getUrl());
return imge_data.getData().getUrl();    }else{System.out.println("图片已经上传过一次!!");
System.out.println(imge_data.getImages());return imge_data.getImages();}  }}
1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.
34.35.36.37.38.39.40.41.42.43.44.45.46.47.48.49.50.51.52.53.54.55.56.



回传的json结果--这个就需要使用上面的插件来进行处理

登录后复制

{"success": true,"code": "success","message": "Upload success.",
"data": {"file_id": 0,"width": 192,"height": 454,"filename": "test25.jpg","storename": 
"xICPNzFsfth5uJk.png","size": 124993,"path": "/2021/08/01/xICPNzFsfth5uJk.png",
"hash": "2exIdQGvBru46RKMyNjg3DhCTO","url": 
"https://i.loli.net/2021/08/01/xICPNzFsfth5uJk.png",
"delete": "https://sm.ms/delete/2exIdQGvBru46RKMyNjg3DhCTO",
"page": "https://sm.ms/image/xICPNzFsfth5uJk"    },
"RequestId": "9BFE9DEB-8370-44C8-A8AF-AAB2DB753A18"}1.2.3.4.5.6.7.8.9.10.11.12.13.14.
15.16.17.18.19.



总结

以上就是我这次在小组编写<基于CNN图像分类的水稻病虫害识别>这个项目中的收获。在此记录下学习路上踩过的一些坑和一些解决方法。



   


免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空