Pytorch模型通过Tensorflow Serving进行部署与TensorFlow Lite部署实践

最近一个项目需要使用Tensorflow lite, 官网上的解释又特别简单,主要给了一个例子,但是这个例子和官网的解释又不一样。。。。这里简单记录下操作方法。

添加依赖

某些加载的方法,依赖并不支持。

在自己的build.grandle的依赖中添加:

登录后复制

implementation 'org.tensorflow:tensorflow-lite:1.15.0'
    implementation 'org.tensorflow:proto:1.15.0'
    

模型转化

对于keras或者用model建立的模型,有直接的转化函数,对于sess建立的模型(计算图),有很多方法不再适用。我尝试了一种是可以使用的:

  • 先转化为saved model
  • 通过saved model 转化为tflite file

转化为saved model:

登录后复制

tf.saved_model.simple_save(
            sess,
            "%s/trained model/flows" % (savedir),
            inputs={"input": input},
            outputs={"flows": flows}
        )
        

注意:其中的inputs和outputs必须是张量,否则会提醒你array没有name属性。

inputs就是feed_dict中那些place holder,outputs可以填输出的张量,也可以直接填操作本身。

转为tflite file:

登录后复制

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(path)
tflite_model = converter.convert()
open(path+"/converted_model.tflite", "wb").write(tflite_model)

至此得到了tflite文件

模型加载

注意点:
1: 读取文件时需要申请权限
2: 对于Interpreter的构造函数,只有MappedByteBuffer那个可以使用,用File的那个会报错。

直接把生成的文件放在“我的手机”中,如果是虚拟机,则是storage/emulated/0文件夹下,然后使用

登录后复制

Environment.getExternalStorageDirectory().getPath();

后面加上文件名称,可以得到该文件的路径。

文件加载函数:

登录后复制

public MappedByteBuffer getFile(String fileName) throws IOException {
        File f = new File(fileName);
        FileInputStream in = new FileInputStream(f);
        FileChannel channel = in.getChannel();
        return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
    }
    

把该函数返回的作为Interpreter的输入,可以生成一个解释器。

权限申请:

在manifest中增加:

登录后复制

<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />

对于Android11以下的版本,需要额外增加:

登录后复制

android:requestLegacyExternalStorage="true"

申请权限:

登录后复制

// Storage Permissions
private static final int REQUEST_EXTERNAL_STORAGE = 1;
private static String[] PERMISSIONS_STORAGE = {
        Manifest.permission.READ_EXTERNAL_STORAGE,
        Manifest.permission.WRITE_EXTERNAL_STORAGE
};

public static void verifyStoragePermissions(Activity activity) {
    // Check if we have write permission
    int permission = ActivityCompat.checkSelfPermission(activity, Manifest.permission.WRITE_EXTERNAL_STORAGE);

    if (permission != PackageManager.PERMISSION_GRANTED) {
        // We don't have permission so prompt the user
        ActivityCompat.requestPermissions(
                activity,
                PERMISSIONS_STORAGE,
                REQUEST_EXTERNAL_STORAGE
        );
    }
}

推理

有两个函数可以使用:

单输入输出

登录后复制

interpreter.run(inputs, outputs);

inputs和ouputs都是高维数组,比如正常的都是:

登录后复制

float[][][][] inputs = new float[1][h][w][3];

注意不管是输入还是输出,数组申请空间时候必须把每一个维度都确定,否则可能导致拷贝出错或者维度出错。

比如:

Caused by: java.lang.IllegalStateException: Internal error: Unexpected failure when preparing tensor allocations: tensorflow/lite/kernels/concatenation.cc:74 t->dims->data[d] != t0->dims->data[d] (1 != 2)
Node number 9 (CONCATENATION) failed to prepare.

输入维度错误,或者没有给定输入的维度

ouputs也是多维数组,和inputs相同。

多输入

登录后复制

interpreter.runForMultipleInputsOutputs([Objects], outputs)

第一个参数是一个数组,里面是输入的张量。我猜测顺序和convert函数中命名的字典序有关。如果不放心的话,可以使用:

登录后复制

Tensor a = interpreter.getInputTensor(0);

然后看一下copyShape是否符合。

因此,如果张量是4维的,那么第一个参数是一个5维的数组。

outputs是一个Map,可以声明如下:

登录后复制

Map outputs = HashMap<Interger,Objects>();
outputs.put(0,Objects);

其他

将Bitmap转为float数组:

登录后复制

public Bitmap toBitmap(float[][][] a){
        int[] colors = new int[h_res*w_res];
        for(int i = 0; i < w_res; i++){
            for(int j = 0; j < h_res; j++){
                int r = (int)(a[j][i][0]*255) & 0xff;
                int g = (int)(a[j][i][1]*255) & 0xff;
                int b = (int)(a[j][i][2]*255) & 0xff;
                colors[i*h_res+j] = 0xff000000 | (r<<16) | (g<<8) | b;
            }
        }
        return Bitmap.createBitmap(colors, 0, w_res, w_res, h_res, bitmaps.get(0).getConfig());
    }
    

Bitmap转float

登录后复制

public float[][][] toFloatArray(Bitmap a){
        // 获取图片宽度和高度
        int width = a.getWidth();   // 图片宽度
        int height = a.getHeight();  //图片高度
        int channel = 3; // 3个通道
        int r = 0;
        int g = 1;
        int b = 2;
        int[] data = new int[width*height];
        a.getPixels(data, 0,width,0,0,width,height);
        // 将二维数组转换为三维数组
        float[][][] rgb4DArray = new float[width][height][channel]; // 图像数量*宽度*高度*通道数


        for(int i = 0; i < width; i++) {
            for (int j = 0; j < height; j++) {
                rgb4DArray[j][i][r] = ((data[i*height + j] & 0xff0000) >> 16)/255.0f;
                rgb4DArray[j][i][g] = ((data[i*height + j] & 0xff00) >> 8)/255.0f;
                rgb4DArray[j][i][b] = (data[i*height + j] & 0xff)/255.0f;
            }
        }
        return rgb4DArray;
    }
    


免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空