码迷,mamicode.com
首页 > 其他好文 > 详细

年龄_性别识别

时间:2020-06-16 15:02:05      阅读:124      评论:0      收藏:0      [点我收藏+]

标签:als   app   inf   路径   constant   version   转换   tput   isp   

参考开源项目:年龄_性别识别

1.识别效果如下图

技术图片

2.keras模型转pb模型,方便模型的迁移和rknn平台的使用,代码1如下:

技术图片
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
from wide_resnet import WideResNet
import tensorflow as tf
from tensorflow.python.framework import graph_io

import platform

import keras as ks

# Report the TensorFlow, Keras and Python interpreter versions, in that order.
print(tf.__version__)
print(ks.__version__)
print(platform.python_version())

def freeze_graph(graph, session, output_node_names, model_name):
    """Freeze ``graph`` and write it as a binary GraphDef under ``tmp/``.

    Args:
        graph: TensorFlow graph to export; made the default graph while
            the export runs.
        session: session passed to ``convert_variables_to_constants`` so the
            current variable values can be turned into constants.
        output_node_names: names of the graph nodes to keep as outputs.
        model_name: path or name of the model; only its basename is used,
            with ``.pb`` appended, as the output file name.
    """
    with graph.as_default():
        # Strip training-only nodes first, then bake variables into constants.
        inference_graph_def = tf.graph_util.remove_training_nodes(
            graph.as_graph_def()
        )
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session,
            inference_graph_def,
            output_node_names,
        )
        graph_io.write_graph(
            frozen_graph_def,
            "tmp",
            os.path.basename(model_name) + ".pb",
            as_text=False,
        )
        print("done")




def pb_transfer():
    """Build the WideResNet model, load its weights and freeze it to a .pb.

    The weights are read from the hard-coded ``weight_file`` path. The frozen
    graph is written by ``freeze_graph``, which derives the output file name
    from the basename of ``weight_file``.
    """
    weight_file = "E:\\python_project\\age-gender-estimation-master\\pretrained_models\\weights.28-3.73.hdf5"

    # Put the backend in inference mode before the model graph is built.
    # NOTE(review): this uses tf.keras.backend while the file also imports the
    # standalone `keras` backend as K -- confirm which Keras flavour
    # WideResNet is built with so the same backend/session is used.
    tf.keras.backend.set_learning_phase(0)

    img_size = 64
    model = WideResNet(img_size, depth=16, k=8)()
    model.load_weights(weight_file)

    # Collect the output op names once; they are both logged and used as the
    # freeze targets.
    output_node_names = [out.op.name for out in model.outputs]
    for node_name in output_node_names:
        print(node_name)

    session = tf.keras.backend.get_session()
    freeze_graph(session.graph, session, output_node_names, weight_file)



# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    pb_transfer()
View Code

代码2如下:

技术图片
# coding=utf-8

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
# Path parameters
weight_file_path = "E:\\python_project\\age-gender-estimation-master\\pretrained_models\\weights.28-3.73.hdf5"
output_graph_name = "ttt.pb"
# Conversion function
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a Keras model into a binary TensorFlow GraphDef (.pb) file.

    Every model output is wrapped in an identity op named
    ``<out_prefix><n>`` (``n`` counts from 1); those identity ops are the
    output nodes kept when the graph is frozen.

    Args:
        h5_model: loaded Keras model exposing an ``outputs`` list of tensors.
        output_dir: directory the ``.pb`` file is written to; created if it
            does not exist.
        model_name: file name of the frozen graph inside ``output_dir``.
        out_prefix: name prefix for the generated identity output nodes.
        log_tensorboard: when true, also import the written ``.pb`` into
            TensorBoard logs stored in ``output_dir``.
    """
    from tensorflow.python.framework import graph_util, graph_io

    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not fail when the directory is already there.
    os.makedirs(output_dir, exist_ok=True)

    out_nodes = []
    # Iterate `outputs` (always a list) rather than indexing `output`, which
    # is a bare tensor for single-output models and would be sliced instead.
    for index, output_tensor in enumerate(h5_model.outputs, start=1):
        node_name = out_prefix + str(index)
        out_nodes.append(node_name)
        # NOTE(review): assumes the default graph is the Keras session graph,
        # so the identity op lands in the graph frozen below -- confirm.
        tf.identity(output_tensor, node_name)

    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir, model_name), output_dir)
# Output directory
output_dir = "./"
# Load the model
#h5_model = load_model(weight_file_path)
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
from wide_resnet import WideResNet
import tensorflow as tf
weight_file = "E:\\python_project\\age-gender-estimation-master\\pretrained_models\\weights.28-3.73.hdf5"

output_fld = "./"
# Put the backend in inference mode before the model graph is built.
# NOTE(review): this sets the learning phase on tf.keras.backend while
# h5_to_pb reads its session from the standalone keras backend (K) --
# confirm both refer to the backend WideResNet is built with.
tf.keras.backend.set_learning_phase(0)
img_size = 64
# Rebuild the architecture and load the weights into it, then freeze the
# result to output_dir/output_graph_name.
model = WideResNet(img_size, depth=16, k=8)()
model.load_weights(weight_file)
h5_to_pb(model, output_dir=output_dir, model_name=output_graph_name)
print("model saved")
View Code

3.推理代码如下:

技术图片
import tensorflow as tf
from tensorflow.python.platform import gfile
import os
import cv2
import numpy as np
import time

from keras.layers import Input, Activation, add, Dense, Flatten, Dropout

#facenet_model_checkpoint ="E:\\python_project\\age-gender-estimation-master\\tmp\\weights.28-3.73.hdf5.pb"
facenet_model_checkpoint ="E:\\python_project\\age-gender-estimation-master\\ttt.pb"


def load_model(model, input_map=None):
    """Import a frozen TensorFlow graph (.pb file) into the default graph.

    Args:
        model: path to a protobuf file holding a frozen graph; ``~`` is
            expanded.
        input_map: optional mapping forwarded to ``tf.import_graph_def``.

    Raises:
        FileNotFoundError: if ``model`` is not an existing file. Only frozen
            ``.pb`` files are handled here; model directories are not.
    """
    model_exp = os.path.expanduser(model)
    if not os.path.isfile(model_exp):
        # Fail here with a clear message instead of returning silently and
        # letting a later tensor lookup fail on an empty graph.
        raise FileNotFoundError("No frozen graph (.pb) file at: %s" % model_exp)

    print("Model filename: %s" % model_exp)
    with gfile.FastGFile(model_exp, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # An empty name keeps the original node names (no import prefix).
        tf.import_graph_def(graph_def, input_map=input_map, name="")



def main():
    """Load the frozen graph and repeatedly run age/gender inference.

    The same hard-coded image is read, resized to ``img_size`` x ``img_size``
    and fed to the network in an endless loop. Each iteration prints the
    time spent in the timed section, the predicted age and ``m``/``f``.

    Raises:
        IOError: if the test image cannot be read.
    """
    img_size = 64
    image_path = "E:\\python_project\\age-gender-estimation-master\\0036.jpg"
    # Column vector 0..100 used to weight the age output; loop-invariant, so
    # it is built once here.
    # NOTE(review): presumably the age output is a distribution over 101 age
    # bins, making the dot product an expected age -- confirm.
    ages = np.arange(0, 101).reshape(101, 1)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            print("load model:" + facenet_model_checkpoint)
            load_model(facenet_model_checkpoint)
            print("load over.")

            graph = tf.get_default_graph()
            images_placeholder = graph.get_tensor_by_name("input_1:0")
            gender = graph.get_tensor_by_name("pred_gender/Softmax:0")
            age = graph.get_tensor_by_name("pred_age/Softmax:0")
            while True:
                img = cv2.imread(image_path)
                if img is None:
                    # cv2.imread returns None instead of raising when the
                    # file is missing or unreadable.
                    raise IOError("Cannot read image: %s" % image_path)
                faces = cv2.resize(img, (img_size, img_size))
                # Add a leading batch axis.
                faces = faces[np.newaxis, :, :, :]

                start_time = time.time()
                feed_dict = {images_placeholder: faces}
                results = sess.run([gender, age], feed_dict=feed_dict)
                predicted_genders = results[0]
                predicted_ages = results[1].dot(ages).flatten()
                print("spend_time is", time.time() - start_time)

                print(int(predicted_ages[0]))
                # NOTE(review): presumably index 0 of the gender output is
                # the "female" score -- confirm against the training code.
                if predicted_genders[0][0] < 0.5:
                    print("m")
                else:
                    print("f")

# Run the inference loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
View Code

4.推理时间在tx2上为:70ms

 

年龄_性别识别

标签:als   app   inf   路径   constant   version   转换   tput   isp   

原文地址:https://www.cnblogs.com/liuwenhua/p/13140763.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!