标签:enc flat function type dia 扩展 const ext UNC
一、配置类(BertConfig)
class BertConfig(object):
    """Configuration for `BertModel`."""

    def __init__(self,
                 vocab_size,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02):
        """Constructs BertConfig.

        Args:
          vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
          hidden_size: Size of the encoder layers and the pooler layer.
          num_hidden_layers: Number of hidden layers in the Transformer encoder.
          num_attention_heads: Number of attention heads for each attention
            layer in the Transformer encoder.
          intermediate_size: The size of the "intermediate" (i.e.,
            feed-forward) layer in the Transformer encoder.
          hidden_act: The non-linear activation function (function or string)
            in the encoder and pooler.
          hidden_dropout_prob: The dropout probability for all fully connected
            layers in the embeddings, encoder, and pooler.
          attention_probs_dropout_prob: The dropout ratio for the attention
            probabilities.
          max_position_embeddings: The maximum sequence length that this model
            might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
          type_vocab_size: The vocabulary size of the `token_type_ids` passed
            into `BertModel`.
          initializer_range: The stdev of the truncated_normal_initializer for
            initializing all weight matrices.
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        Every key/value pair in `json_object` is copied onto the instance,
        including keys that `__init__` does not declare. Fields missing from
        the dict keep their `__init__` defaults.

        The instance is built with `cls`, so calling this on a subclass
        returns an instance of that subclass rather than a plain BertConfig.
        """
        # vocab_size has no default, so start from a None placeholder; the
        # dict is expected to overwrite it along with any other fields.
        config = cls(vocab_size=None)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a JSON file of parameters.

        The file is read in full through `tf.gfile.GFile`, parsed with
        `json.loads`, and handed to `from_dict`.
        """
        with tf.gfile.GFile(json_file, "r") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def to_dict(self):
        """Serializes this instance to a Python dictionary.

        Returns a deep copy of the instance attributes, so mutating the result
        does not affect the config.
        """
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
二、获取词向量(embedding_lookup)
对于输入的词 id 张量 input_ids,返回对应的词向量以及 embedding table。查表方式可以选用 one-hot 矩阵乘法或者 tf.gather()。
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
    """Look up word embeddings for a tensor of token ids.

    Args:
      input_ids: tensor of integer token ids, expected to be
        [batch_size, seq_length, num_inputs]; a rank-2
        [batch_size, seq_length] input is expanded to num_inputs == 1.
      vocab_size: number of rows in the embedding table (also the one-hot
        depth when `use_one_hot_embeddings` is True).
      embedding_size: width of each embedding vector.
      initializer_range: value passed to `create_initializer` for the table.
      word_embedding_name: variable name of the embedding table.
      use_one_hot_embeddings: if True, compute the lookup as a one-hot matrix
        multiplied by the table; otherwise index the table with `tf.gather`.

    Returns:
      A tuple (output, embedding_table). `output` has shape
      [batch_size, seq_length, num_inputs * embedding_size];
      `embedding_table` has shape [vocab_size, embedding_size].
    """
    # Bring a rank-2 input up to rank 3 so the final reshape covers both cases.
    ids = input_ids
    if ids.shape.ndims == 2:
        ids = tf.expand_dims(ids, axis=[-1])

    table = tf.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range))

    # One flat vector of ids: [batch_size * seq_length * num_inputs].
    flat_ids = tf.reshape(ids, [-1])

    if not use_one_hot_embeddings:
        # Plain row indexing into the table.
        looked_up = tf.gather(table, flat_ids)
    else:
        # [N, vocab_size] one-hot rows times [vocab_size, embedding_size].
        looked_up = tf.matmul(tf.one_hot(flat_ids, depth=vocab_size), table)

    # looked_up is [batch*seq*num_inputs, embedding_size]; concatenate the
    # num_inputs embeddings of each position into one vector.
    dims = get_shape_list(ids)
    target_shape = dims[0:-1] + [dims[-1] * embedding_size]
    return tf.reshape(looked_up, target_shape), table
1) tf.gather 用法
# Demo of tf.gather using the TF1 graph/session API
# (tf.Session, tf.global_variables_initializer).
import tensorflow as tf

a = tf.Variable([[1, 2, 3, 4, 5],
                 [6, 7, 8, 9, 10],
                 [11, 12, 13, 14, 15]])
index_a = tf.Variable([0, 2])
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Take the entries of `a` at indices 0 and 2. `a` is 2-D and gather works
    # along axis 0, so each index selects a whole row (two rows in total).
    print(sess.run(tf.gather(a, index_a)))
    # Take the elements of `b` at index positions 2, 4, 6 and 8.
    print(sess.run(tf.gather(b, index_b)))

# Output:
# [[ 1  2  3  4  5]
#  [11 12 13 14 15]]
# [3 5 7 9]
标签:enc flat function type dia 扩展 const ext UNC
原文地址:https://www.cnblogs.com/gczr/p/12382240.html