ONNX 模型转换
-
ONNX是Open Neural Network Exchange的缩写,它的作用是在不同框架之间实现模型互相转换,本帖将会聚焦于把PyTorch模型转出到其它框架
-
将pytorch的模型转化为onnx的中间格式,以便于转化到其它模型,参数dummy_input是pytorch的一个输入范例
# Restore the trained weights into the network and switch to inference mode,
# then export the graph to ONNX using one sample input to trace the model.
checkpoint = torch.load('./ONNX/best_model.dat', map_location='cpu')
net.load_state_dict(checkpoint['state_dict'])
net.eval()

# A concrete example input; torch.onnx.export traces the network with it.
dummy_input = image

# One name for the real input tensor, plus one per learned parameter tensor.
input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(314)]
output_names = ["output1"]

torch.onnx.export(
    net,
    dummy_input,
    "./ONNX/MobilenetV2.onnx",
    input_names=input_names,
    output_names=output_names,
    # verbose=True  # It can show the structure of your network.
)
-
将onnx格式的模型转化为tf_rep,并测试预测是否正确
import onnx_tf.backend as tf_backend

# Load the ONNX model and convert it into a TensorFlow backend representation.
model = onnx.load('./ONNX/MobilenetV2.onnx')
tf_rep = tf_backend.prepare(model)

np_onnx_image = np.array(image)
with tf.Session() as persisted_sess:
    persisted_sess.graph.as_default()
    # Import the converted graph into this session's default graph.
    tf.import_graph_def(tf_rep.graph.as_graph_def(), name='')
    # Look up the input/output tensors by the names recorded during conversion.
    inp = persisted_sess.graph.get_tensor_by_name(
        tf_rep.tensor_dict[tf_rep.inputs[0]].name
    )
    out = persisted_sess.graph.get_tensor_by_name(
        tf_rep.tensor_dict[tf_rep.outputs[0]].name
    )
    print("inputs: {}, outputs: {}".format(
        tf_rep.tensor_dict[tf_rep.inputs[0]].name,
        tf_rep.tensor_dict[tf_rep.outputs[0]].name))
    res = persisted_sess.run(out, {inp: np_onnx_image})
    # Fixed: the comparison already yields a bool; the `True if ... else False`
    # ternary was redundant.
    same = id2class[np.argmax(res)] == label
    print(same)
-
直接对tf_rep进行测试
# Run the prediction directly through the onnx-tf representation.
output = tf_rep.run(np_onnx_image)
# Fixed: drop the redundant `True if ... else False` — the comparison is a bool.
same = id2class[np.argmax(output)] == label
-
将tf_rep转化为tensorflow的pb文件
# Serialize the TensorFlow representation to a frozen .pb graph file on disk.
tf_rep.export_graph('./ONNX/MobilenetV2.pb')
-
tf.pb测试
with tf.Session() as persisted_sess:
    # Read the frozen graph back from disk and parse it into a GraphDef.
    with tf.gfile.GFile('./ONNX/MobilenetV2.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    persisted_sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
    # Tensor names come from the ONNX export step ("actual_input_1") and the
    # converted graph's final op ("add_10").
    inp = persisted_sess.graph.get_tensor_by_name('actual_input_1:0')
    out = persisted_sess.graph.get_tensor_by_name('add_10:0')
    feed_dict = {inp: np.array(image)}
    res = persisted_sess.run(out, feed_dict)
    # Fixed: the comparison already yields a bool; no ternary needed.
    same = id2class[np.argmax(res)] == label
    print(same)
-
tf.pb转化为tflite,实现移动端的部署
# Convert the frozen .pb graph to a TFLite flatbuffer for mobile deployment.
graph_def_file = "./ONNX/MobilenetV2.pb"
input_arrays = ["actual_input_1"]
output_arrays = ["add_10"]
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
# Fixed: the original `open(...).write(...)` never closed the file handle;
# a context manager guarantees the flatbuffer is flushed and closed.
with open("./ONNX/MobilenetV2.tflite", "wb") as f:
    f.write(tflite_model)
-
onnx转化为caffe2 pb文件
import caffe2.python.onnx.backend as cf2_backend
from onnx_caffe2.backend import Caffe2Backend as cf2_backend2

# Run the ONNX model through the Caffe2 backend and check the prediction.
cf2_rep = cf2_backend.prepare(model)
output = cf2_rep.run(np_onnx_image.astype(np.float32))
# Fixed: drop the redundant `True if ... else False` ternary.
same = id2class[np.argmax(output)] == label
print(same)

# Split the graph into Caffe2's init net (weights) and predict net (ops)
# and serialize both as protobuf files.
init_net, predict_net = cf2_backend2.onnx_graph_to_caffe2_net(model.graph)
with open("./ONNX/Mobilenet_init_net.pb", "wb") as f:
    f.write(init_net.SerializeToString())
with open("./ONNX/Mobilenet_predict_net.pb", "wb") as f:
    f.write(predict_net.SerializeToString())