标签:cluster hid 预测概率 for https host %s join soft
为了方便查阅,我将完整的BERT源码分析整理成了PDF版本,可以在微信公众号NewBeeNLP后台直接下载。
继续之前没有介绍完的Pre-training部分,在上一篇中我们已经完成了对输入数据的处理,接下来看看BERT是怎么完成Masked LM和Next Sentence Prediction两个任务的训练的。
get_masked_lm_output
函数用于计算任务#1的训练loss。输入为BertModel的最后一层sequence_output输出([batch_size, seq_length, hidden_size]),因为对一个序列的MASK标记的预测属于标注问题,需要整个sequence的输出状态。
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
label_ids, label_weights):
"""Get loss and log probs for the masked LM."""
# 获取mask词的encode
input_tensor = gather_indexes(input_tensor, positions)
with tf.variable_scope("cls/predictions"):
# 在输出之前添加一个非线性变换,只在预训练阶段起作用
with tf.variable_scope("transform"):
input_tensor = tf.layers.dense(
input_tensor,
units=bert_config.hidden_size,
activation=modeling.get_activation(bert_config.hidden_act),
kernel_initializer=modeling.create_initializer(
bert_config.initializer_range))
input_tensor = modeling.layer_norm(input_tensor)
# output_weights是和传入的word embedding一样的
# 这里再添加一个bias
output_bias = tf.get_variable(
"output_bias",
shape=[bert_config.vocab_size],
initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
log_probs = tf.nn.log_softmax(logits, axis=-1)
# label_ids表示mask掉的Token的id
label_ids = tf.reshape(label_ids, [-1])
label_weights = tf.reshape(label_weights, [-1])
one_hot_labels = tf.one_hot(
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
# 但是由于实际MASK的可能不到20,比如只MASK18,那么label_ids有2个0(padding)
# 而label_weights=[1, 1, ...., 0, 0],说明后面两个label_id是padding的,计算loss要去掉。
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
numerator = tf.reduce_sum(label_weights * per_example_loss)
denominator = tf.reduce_sum(label_weights) + 1e-5
loss = numerator / denominator
return (loss, per_example_loss, log_probs)
get_next_sentence_output
函数用于计算任务#2的训练loss。输入为BertModel的最后一层pooled_output输出([batch_size, hidden_size]),因为该任务属于二分类问题,所以只需要每个序列的第一个token【CLS】即可。
def get_next_sentence_output(bert_config, input_tensor, labels):
"""Get loss and log probs for the next sentence prediction."""
# 标签0表示 下一个句子关系成立; 标签1表示 下一个句子关系不成立。
# 这个分类器的参数在实际Fine-tuning阶段会丢弃掉
with tf.variable_scope("cls/seq_relationship"):
output_weights = tf.get_variable(
"output_weights",
shape=[2, bert_config.hidden_size],
initializer=modeling.create_initializer(bert_config.initializer_range))
output_bias = tf.get_variable(
"output_bias", shape=[2], initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
log_probs = tf.nn.log_softmax(logits, axis=-1)
labels = tf.reshape(labels, [-1])
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)
return (loss, per_example_loss, log_probs)
module_fn_builder
函数,用于构造Estimator使用的model_fn
。定义好了上述两个训练任务,就可以写出训练过程,之后将训练集传入自动训练。
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps, use_tpu,
use_one_hot_embeddings):
def model_fn(features, labels, mode, params):
tf.logging.info("*** Features ***")
for name in sorted(features.keys()):
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
input_ids = features["input_ids"]
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
masked_lm_positions = features["masked_lm_positions"]
masked_lm_ids = features["masked_lm_ids"]
masked_lm_weights = features["masked_lm_weights"]
next_sentence_labels = features["next_sentence_labels"]
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
# 创建Transformer实例对象
model = modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
# 获得MASK LM任务的批损失,平均损失以及预测概率矩阵
(masked_lm_loss,
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
bert_config, model.get_sequence_output(), model.get_embedding_table(),
masked_lm_positions, masked_lm_ids, masked_lm_weights)
# 获得NEXT SENTENCE PREDICTION任务的批损失,平均损失以及预测概率矩阵
(next_sentence_loss, next_sentence_example_loss,
next_sentence_log_probs) = get_next_sentence_output(
bert_config, model.get_pooled_output(), next_sentence_labels)
# 总的损失定义为两者之和
total_loss = masked_lm_loss + next_sentence_loss
# 获取所有变量
tvars = tf.trainable_variables()
initialized_variable_names = {}
scaffold_fn = None
# 如果有之前保存的模型,则进行恢复
if init_checkpoint:
(assignment_map, initialized_variable_names
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
if use_tpu:
def tpu_scaffold():
tf.train