如果是在嵌入式设备上运行tensorflow模型,最好的方式是转化为tflite模型,然后在使用tensorflow官方提供的tflite c++接口进行推理,这样的好处是,运行速度非常的快,并且几乎无损失,当然tflite可以使用量化版的,也可以使用非量化版的。量化版的速度提升和模型大小缩小。 模型转化: 根据官方网站,有各种训练后转化为tflite模型的接口,这里我使用的保存模型方式为 tf.saved_model.save(model, config.save_model_dir)
保存模型后,对模型进行转换,代码如下 import tensorflow as tf
import config
if __name__ == '__main__':
# GPU settings
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
# load model
converter = tf.lite.TFLiteConverter.from_saved_model(config.save_model_dir)
tflite_model = converter.convert()
#model = tf.saved_model.load(config.save_model_dir)
#concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
#concrete_func.inputs[0].set_shape([1, 128, 256, 3])
#converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
#tflite_model = converter.convert()
open(config.TFLite_model_dir, "wb").write(tflite_model)
上述gpu的配置需要根据tensorflow的版本进行调整,版本不同,接口也不同,注释掉的部分为修改输入的尺寸或者输入的batch的大小。 模型推断 首先是读取模型: #读取模型
std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile("resnet.tflite");
构建模型: tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model.get(), resolver)(&interpreter);
设置模型的一些参数,包括(数据类型,线程数等): interpreter->AllocateTensors();
interpreter->SetAllowFp16PrecisionForFp32(true);
interpreter->SetNumThreads(4); //quad core 线程数
模型输入: In = interpreter->inputs()[0];
model_height = interpreter->tensor(In)->dims->data[1];
model_width = interpreter->tensor(In)->dims->data[2];
model_channels = interpreter->tensor(In)->dims->data[3];
运行模型: interpreter->Invoke(); // run your model 输出结果: int output = interpreter->outputs()[0];
TfLiteIntArray* output_dims = interpreter->tensor(output)->dims;
auto output_size = output_dims->data[output_dims->size - 1];
cout << "output_size: " << output_size << "\n";
float* output_a = interpreter->typed_output_tensor(output);//输出结果
总结:对于tflite,整个流程还是挺简单的,但是对于效果,我跑到是resnet18,在服务器上进行训练的模型,验证集和测试集,以及训练集都达到了99%, 跑resnet.tflite时,使用图片进行跑也能达到这个结果,但是如果是在摄像头上直接获取图片,然后输入tflite模型中,就会发生错误率挺高,但是使用同样的数据,然后使用caffe的resnet进行训练,模型直接使用opencv进行调用,就不会发生这样的事,图片判断正确的情况很高。 文件如下 tflite.zip
|