标签:hidden dataset 参数 zed sof config its sig input
1.定义Estimator
# Construct the Estimator around a model_fn (a function, e.g. the one returned
# by model_fn_builder below).
# NOTE(review): the constructed Estimator is discarded here, while predict()
# below calls self.estimator.predict; presumably the real code assigns this to
# self.estimator — confirm.
tf.estimator.Estimator(model_fn=model_fn)  # model_fn is a function, not a class
2.定义model_fn:
def model_fn_builder(self, bert_config, num_labels, init_checkpoint):
    """Build the ``model_fn`` closure passed to ``tf.estimator.Estimator``.

    The returned function builds the classifier graph through
    ``self.creat_model`` and, when ``init_checkpoint`` is given, initializes
    every trainable variable whose name also appears in that checkpoint from
    the checkpoint values.

    :param bert_config: model configuration forwarded to ``self.creat_model``.
    :param num_labels: number of output classes forwarded to ``self.creat_model``.
    :param init_checkpoint: checkpoint path to restore weights from; if falsy,
        no weights are restored.
    :return: a ``model_fn(features, labels, mode, params)`` function.
    """

    def model_fn(features, labels, mode, params):
        """Build the prediction graph and return an ``EstimatorSpec``.

        The argument names follow the signature the Estimator expects;
        ``labels`` and ``params`` are accepted for that reason but unused here.

        :param features: dict of batched tensors produced by the input_fn;
            must contain 'input_ids', 'input_mask' and 'segment_ids'.
        :param labels: label tensor (unused; this model_fn only predicts).
        :param mode: one of tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
        :param params: parameter dict supplied by the Estimator (unused).
        :return: tf.estimator.EstimatorSpec whose predictions are the tensor
            returned by ``self.creat_model``.
        """
        input_ids = features['input_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']

        # Build the model; the returned tensor is what the Estimator will predict.
        probabilities = self.creat_model(
            bert_config, input_ids, input_mask, segment_ids, num_labels)

        if init_checkpoint:
            # Map each trainable variable to the same-named checkpoint variable
            # and replace its initializer with the checkpoint value.
            tvars = tf.trainable_variables()
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Only predictions are supplied (no loss / train_op), so this spec is
        # meant for PREDICT mode.
        return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities)

    return model_fn
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables.

    Pairs every variable in ``tvars`` with the checkpoint variable of the same
    name, so the result can be passed to ``tf.train.init_from_checkpoint``.

    :param tvars: graph variables (e.g. ``tf.trainable_variables()``).
    :param init_checkpoint: checkpoint path passed to ``tf.train.list_variables``.
    :return: ``(assignment_map, initialized_variable_names)``.
        ``assignment_map`` is an OrderedDict mapping each matched checkpoint
        name to the identical graph name, in checkpoint order.
        ``initialized_variable_names`` maps each matched name, both without
        and with the ``:0`` suffix, to 1.
    """
    # Graph variable names carry an output-index suffix ("foo:0"); strip it so
    # they compare equal to checkpoint names ("foo").
    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match(r"^(.*):\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    assignment_map = collections.OrderedDict()
    initialized_variable_names = {}
    # list_variables yields (name, shape) pairs; only the name is needed.
    for name, _ in tf.train.list_variables(init_checkpoint):
        if name not in name_to_variable:
            continue
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1

    return assignment_map, initialized_variable_names
def creat_model(self, bert_config, input_ids, input_mask, segment_ids, num_labels):
    """Build the BERT encoder plus a softmax classification head.

    :param bert_config: config passed to ``modeling.BertModel``.
    :param input_ids: token-id tensor.
    :param input_mask: attention-mask tensor.
    :param segment_ids: segment (token type) id tensor.
    :param num_labels: number of output classes.
    :return: softmax of the logits over the last axis (presumably shaped
        [batch_size, num_labels] — confirm against the pooled-output rank).
    """
    model = modeling.BertModel(
        config=bert_config,
        is_training=False,  # inference only
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=False)

    output_layer = model.get_pooled_output()
    hidden_size = output_layer.shape[-1].value

    # Classifier-head variables. When weights are restored from a checkpoint
    # by name, these names must match the trained head so that the restored
    # values replace the initializers below.
    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    probabilities = tf.nn.softmax(logits, axis=-1)
    return probabilities
3.使用estimator.predict
def predict(self, text_a, text_b):
    """Run the estimator on a single (text_a, text_b) pair.

    The pair is converted to model input features, serialized as a
    ``tf.train.Example`` and written through ``self.writer`` (presumably a
    TFRecord writer for the file that ``self.predict_input_fn`` reads), after
    which ``self.estimator.predict`` is run.

    NOTE(review): ``self.writer`` is never reset here; if it appends across
    calls, the file holds every pair written so far and this call's pair is
    the last prediction — confirm.

    :param text_a: first text of the pair.
    :param text_b: second text of the pair.
    :return: list of the predictions yielded by ``self.estimator.predict``.
    """
    def create_int_feature(values):
        # Wrap a sequence of ints as an int64 tf.train.Feature.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

    input_ids, input_mask, segment_ids = self.convert_single_example(text_a, text_b)

    features = collections.OrderedDict()
    features['input_ids'] = create_int_feature(input_ids)
    features['input_mask'] = create_int_feature(input_mask)
    features['segment_ids'] = create_int_feature(segment_ids)

    # Pack the features into an Example and write the serialized bytes.
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    self.writer.write(tf_example.SerializeToString())
    # Flush so the record reaches the file before the input_fn opens it;
    # otherwise it may still be sitting in the writer's buffer.
    self.writer.flush()

    # estimator.predict returns a lazy generator; consume it now, while the
    # file contents match what was just written.
    result = self.estimator.predict(input_fn=self.predict_input_fn)
    return list(result)
def file_based_input_fn_builder(self):
    """Build an ``input_fn`` that reads serialized Examples from ``self.predict_file``.

    :return: a zero-argument ``input_fn`` suitable for ``Estimator.predict``.
    """
    # Each feature must hold exactly MAX_SEQ_LENGTH int64 values.
    name_to_features = {
        "input_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
        "input_mask": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
        "segment_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    }

    def decode_record(_examples, _name_to_feature):
        """Parse one serialized tf.train.Example into a dict of tensors.

        :param _examples: string tensor holding one serialized Example.
        :param _name_to_feature: feature spec mapping names to FixedLenFeature.
        :return: dict mapping each feature name to its parsed tensor.
        """
        return tf.parse_single_example(_examples, _name_to_feature)

    def input_fn():
        """Return a dataset of parsed feature dicts, batched one record at a time."""
        d = tf.data.TFRecordDataset(self.predict_file)  # read the TFRecord file
        # Parse each serialized record into a feature dict, then batch by 1.
        d = d.map(lambda record: decode_record(record, name_to_features))
        d = d.batch(batch_size=1, drop_remainder=False)
        # Each element of d is passed to the Estimator's model_fn as `features`.
        return d

    return input_fn
1
标签:hidden dataset 参数 zed sof config its sig input
原文地址:https://www.cnblogs.com/callyblog/p/10216058.html