ONNX Model Conversion



  • ONNX (Open Neural Network Exchange) is an open format for exchanging models between deep-learning frameworks. This post focuses on exporting PyTorch models to ONNX and converting them onward to other frameworks.



  • Convert the PyTorch model into the ONNX intermediate format so it can be converted to other frameworks. The argument dummy_input is an example input to the PyTorch network.

    import torch

    # `net` is the network instance (MobileNetV2 here) and `image` a preprocessed input batch.
    state_dict = torch.load('./ONNX/best_model.dat', map_location='cpu')['state_dict']
    net.load_state_dict(state_dict)
    net.eval()                              # Switch to inference mode before exporting.
    
    dummy_input = image                     # An example of args to input to the network.
    input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(314)]
    output_names = ["output1"]
    torch.onnx.export(net, dummy_input, "./ONNX/MobilenetV2.onnx",
                      input_names=input_names,
                      output_names=output_names,
                      # verbose=True        # Prints the structure of the exported graph.
                      )
    


  • Convert the ONNX model into a TensorFlow representation (tf_rep) and check that its prediction is correct.

    import onnx
    import numpy as np
    import tensorflow as tf
    import onnx_tf.backend as tf_backend
    
    model = onnx.load('./ONNX/MobilenetV2.onnx')
    tf_rep = tf_backend.prepare(model)
    np_onnx_image = np.array(image)
    
    with tf.Session() as persisted_sess:
        persisted_sess.graph.as_default()
        tf.import_graph_def(tf_rep.graph.as_graph_def(), name='')
        inp = persisted_sess.graph.get_tensor_by_name(
            tf_rep.tensor_dict[tf_rep.inputs[0]].name
        )
        out = persisted_sess.graph.get_tensor_by_name(
            tf_rep.tensor_dict[tf_rep.outputs[0]].name
        )
        print("inputs: {}, outputs: {}".format(tf_rep.tensor_dict[tf_rep.inputs[0]].name,
                                               tf_rep.tensor_dict[tf_rep.outputs[0]].name))
        res = persisted_sess.run(out, {inp: np_onnx_image})
    
        same = id2class[np.argmax(res)] == label
        print(same)
    


  • Run inference on tf_rep directly.

    output = tf_rep.run(np_onnx_image)
    same = id2class[np.argmax(output)] == label
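
  • Checking only the top-1 class can hide numerical drift introduced by the conversion. A small numpy-only helper (the `outputs_match` name is ours, not part of any library) for comparing the raw outputs of the original and converted models:

```python
import numpy as np

def outputs_match(a, b, rtol=1e-3, atol=1e-5):
    """True when two output arrays agree within floating-point tolerance."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return a.shape == b.shape and np.allclose(a, b, rtol=rtol, atol=atol)

# Usage idea: outputs_match(net(image).detach().numpy(), tf_rep.run(np_onnx_image)[0])
logits = np.array([[0.1, 2.3, -1.0]])
print(outputs_match(logits, logits + 1e-7))   # tiny drift: True
print(outputs_match(logits, logits + 0.5))    # real disagreement: False
```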
    


  • Export tf_rep as a TensorFlow pb file.

    tf_rep.export_graph('./ONNX/MobilenetV2.pb')
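
  • The pb test below hard-codes the tensor names 'actual_input_1:0' and 'add_10:0'; the output name in particular depends on how the graph was exported, so it pays to list the node names first. A small helper (our own name, demonstrated on a stand-in object so it runs without a .pb file):

```python
def list_nodes(graph_def):
    """Return (op, name) pairs for every node in a frozen GraphDef."""
    return [(n.op, n.name) for n in graph_def.node]

# With TensorFlow this would be called on the real graph, e.g.:
#   graph_def = tf.GraphDef()
#   graph_def.ParseFromString(open('./ONNX/MobilenetV2.pb', 'rb').read())
#   for op, name in list_nodes(graph_def): print(op, name)

# Stand-in demo so the helper can be exercised here:
from types import SimpleNamespace
fake = SimpleNamespace(node=[SimpleNamespace(op='Placeholder', name='actual_input_1'),
                             SimpleNamespace(op='Add', name='add_10')])
print(list_nodes(fake))
```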
    


  • Test the exported pb file.

    with tf.Session() as persisted_sess:
        with tf.gfile.GFile('./ONNX/MobilenetV2.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
    
        persisted_sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    
        inp = persisted_sess.graph.get_tensor_by_name('actual_input_1:0')
        out = persisted_sess.graph.get_tensor_by_name('add_10:0')   # Output tensor name depends on the exported graph.
    
        feed_dict = {inp: np.array(image)}
    
        res = persisted_sess.run(out, feed_dict)
        same = id2class[np.argmax(res)] == label
        print(same)
    


  • Convert the pb file to tflite for deployment on mobile devices.

    graph_def_file = "./ONNX/MobilenetV2.pb"
    input_arrays = ["actual_input_1"]
    output_arrays = ["add_10"]
    # In TF >= 1.13 the converter lives at tf.lite.TFLiteConverter instead of tf.contrib.lite.
    converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()
    with open("./ONNX/MobilenetV2.tflite", "wb") as f:
        f.write(tflite_model)
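
  • To sanity-check a .tflite file before shipping it, run it through the TFLite interpreter. A self-contained sketch, using a throwaway Keras model in place of MobilenetV2.tflite and the TF 2.x `tf.lite` API (since `tf.contrib` no longer exists there); the interpreter calls are the same for any .tflite file:

```python
import numpy as np
import tensorflow as tf

# Throwaway model standing in for the real converted network.
keras_model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                                   tf.keras.layers.Dense(2)])
tflite_bytes = tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

# For a file on disk, use tf.lite.Interpreter(model_path=...) instead.
interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

interpreter.set_tensor(inp['index'], np.zeros(inp['shape'], dtype=np.float32))
interpreter.invoke()
res = interpreter.get_tensor(out['index'])
print(res.shape)
```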
    


  • Convert the ONNX model to Caffe2 pb files.

    import caffe2.python.onnx.backend as cf2_backend
    # Caffe2Backend used to ship as the standalone onnx_caffe2 package
    # ("from onnx_caffe2.backend import Caffe2Backend"); in recent builds
    # it lives inside caffe2 itself.
    from caffe2.python.onnx.backend import Caffe2Backend as cf2_backend2
    
    cf2_rep = cf2_backend.prepare(model)
    output = cf2_rep.run(np_onnx_image.astype(np.float32))
    
    same = id2class[np.argmax(output)] == label
    print(same)
    
    init_net, predict_net = cf2_backend2.onnx_graph_to_caffe2_net(model.graph)
    with open("./ONNX/Mobilenet_init_net.pb", "wb") as f:
        f.write(init_net.SerializeToString())
    with open("./ONNX/Mobilenet_predict_net.pb", "wb") as f:
        f.write(predict_net.SerializeToString())
    

 

Copyright © 2018 bbs.dian.org.cn All rights reserved.
