码迷,mamicode.com
首页 > Web开发 > 详细

MxNet 模型转Tensorflow pb模型

时间:2019-07-04 19:08:52      阅读:696      评论:0      收藏:0      [点我收藏+]

标签:spl   nis   install   init   with   node   rom   coding   http   

用mmdnn实现模型转换

参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af

  1. 安装mmdnn
    pip install mmdnn

     

  2. 准备好mxnet模型的.json文件和.params文件, 以InsightFace MxNet r50为例        https://github.com/deepinsight/insightface
  3. 用mmdnn运行命令行
    python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d resnet50 --inputShape 3,112,112 

     

     会生成resnet50.json(可视化文件) resnet50.npy(权重参数) resnet50.pb(网络结构)三个文件。

  4. 用mmdnn运行命令行
    python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath resnet50.pb --IRWeightPath resnet50.npy --dstModelPath tf_resnet50.py 

     

     生成tf_resnet50.py文件,可以调用tf_resnet50.py中的KitModel函数加载npy权重参数重新生成原网络框架。

  5. 打开tf_resnet.py文件,修改load_weights()中的代码 (tensorflow=1.14.0报错) 

     try:
            weights_dict = np.load(weight_file).item()
        except:
            weights_dict = np.load(weight_file, encoding=bytes).item()

    改为

     try:
            weights_dict = np.load(weight_file, allow_pickle=True).item()
    except:
            weights_dict = np.load(weight_file, allow_pickle=True, encoding=bytes).item()

     

  6. 基于resnet50.npy和tf_resnet50.py文??件,固化参数,生成PB文件:

    import tensorflow as tf
    import tf_resnet50 as tf_fun
    def netWork():
        model=tf_fun.KitModel("./resnet50.npy")
        return model
    def freeze_graph(output_graph):
        output_node_names = "output"
        data,fc1=netWork()
        fc1=tf.identity(fc1,name="output")
    
        graph = tf.get_default_graph()  # 獲得默認的圖
        input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
                sess=sess,
                input_graph_def=input_graph_def,  # 等於:sess.graph_def
                output_node_names=output_node_names.split(","))  # 如果有多個輸出節點,以逗號隔開
    
            with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
                f.write(output_graph_def.SerializeToString())  # 序列化輸出
    
    if __name__ == __main__:
        freeze_graph("frozen_insightface_r50.pb")
        print("finish!")

     

  7. 采用tensorflow的post-train quantization离线量化方法(有一定的精度损失)转换成tflite模型,从而完成端侧的模型部署:
    import tensorflow as tf
    
    convert=tf.lite.TFLiteConverter.from_frozen_graph("frozen_insightface_r50.pb",input_arrays=["data"],output_arrays=["output"],
                                                      input_shapes={"data":[1,112,112,3]})
    convert.post_training_quantize=True
    tflite_model=convert.convert()
    open("quantized_insightface_r50.tflite","wb").write(tflite_model)
    print("finish!")

     

MxNet 模型转Tensorflow pb模型

标签:spl   nis   install   init   with   node   rom   coding   http   

原文地址:https://www.cnblogs.com/qiangz/p/11134240.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!