
BERT (PyTorch): tracing how cls.predictions.bias goes from all zeros to its loaded pretrained values


A simple main entry point looks like this:

import sys
sys.path.append("..")

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 6
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']

# Convert tokens to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
# segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
segments_tensors = torch.tensor([segments_ids]).to('cuda')

# ========================= BertForMaskedLM ==============================
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.to('cuda')
model.eval()

The entry point we care about is the `BertForMaskedLM.from_pretrained(...)` call near the end of the script.
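For completeness, before diving into the loading path, this is roughly how the script above is usually finished, following the library's standard `BertForMaskedLM` usage (a minimal sketch; when no labels are passed the model returns raw prediction scores of shape [batch, seq_len, vocab_size]):

# Predict the masked token back (sketch; reuses the tensors defined above)
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)  # torch.Size([1, 11, 30522])

predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)  # the pretrained model should recover something like 'henson'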

Stepping into this from_pretrained method, the code works through a fairly clear sequence:

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
        else:
            archive_file = pretrained_model_name
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name ‘{}‘ was not found in model name list ({}). "
                "We assumed ‘{}‘ was a path or url but couldn‘t find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    , .join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(weights_path)

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # TODO: this is where model.cls.predictions.bias is filled in: the bias goes from all zeros to the pretrained values
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)
        return model

The method is a bit long, but all it does is instantiate the model and then load every pretrained parameter into it.
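As a side note, you can inspect the raw checkpoint to see what the key renaming and loading above operate on. A minimal sketch, assuming pytorch_model.bin has already been extracted from the downloaded archive (the path is an assumption):

import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # path is an assumption

# Older checkpoints name the LayerNorm parameters gamma/beta; from_pretrained
# renames those to weight/bias before loading (the list may be empty for newer dumps).
print([k for k in state_dict if "gamma" in k or "beta" in k][:3])

# The pretraining-head parameters are the ones whose names start with "cls."
print([k for k in state_dict if k.startswith("cls.")])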

Now pay attention to the inner load function:

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # TODO: this is where model.cls.predictions.bias is filled in: the bias goes from all zeros to the pretrained values

This load function recursively loads all of the pretrained parameters. So which bias does cls.predictions.bias actually refer to? It belongs to this class:

class BertLMPredictionHead(nn.Module):
    """
    Arch:
        - BertPredictionHeadTransform (Input=torch.Size([1, 11, 768]), Output=torch.Size([1, 11, 768]))
            - Dense (768, 768)
            - Activation (gelu)
            - LayerNorm
        - Linear (768, 30522)

    y = x @ W^T + b
    y = self.decoder(x) + self.bias
    i.e., [1, 11, 768] @ [768, 30522] + [30522] -> [1, 11, 30522]

    Input:
        torch.Size([1, 11, 768])
    Output:
        torch.Size([1, 11, 30522])

    The purpose is to Decode.
    """
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        """
        bert_model_embedding_weights.size():
            torch.Size([30522, 768])
        """
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)  # nn.Linear(768, 30522); its weight has shape torch.Size([30522, 768])
        self.decoder.weight = bert_model_embedding_weights  # torch.Size([30522, 768])
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])

    def forward(self, hidden_states):
        """
        hidden_states:
            torch.Size([1, 11, 768])

        torch.Size([1, 11, 768]) --> torch.Size([1, 11, 768])
        """
        hidden_states = self.transform(hidden_states)
        """
        To predict the corresponding word in the vocab.

        Each of the 11 positions gets a tensor of size [30522], the same as the vocab size.
        """
        hidden_states = self.decoder(hidden_states) + self.bias  # torch.Size([1, 11, 30522])
        return hidden_states

It is exactly this bias:

        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])
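To see that this bias really does start out as all zeros before any loading happens, here is a minimal sketch that constructs the head in isolation (the BertConfig defaults and the random stand-in embedding weight are just for illustration):

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertLMPredictionHead

config = BertConfig(vocab_size_or_config_json_file=30522)       # defaults: hidden_size=768, etc.
fake_embeddings = torch.nn.Parameter(torch.randn(30522, 768))   # stand-in for the real embedding matrix

head = BertLMPredictionHead(config, fake_embeddings)
print(torch.all(head.bias == 0).item())  # True: all zeros until a state_dict is loaded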

Why did this feel odd to me? Because this class is not part of the BERT model proper; it is an add-on head used to predict the ids of the [MASK] tokens. Then I dug up the big state_dict of pretrained weights, which contains entries like these:

cls.predictions.bias                    = {Tensor: [30522]}      tensor([-0.4191, -0.4202, -0.4191,  ..., -0.7900, -0.7822, -0.4965])
cls.predictions.transform.dense.weight  = {Tensor: [768, 768]}   tensor([[ 0.3681,  0.0147,  0.0430,  ...,  0.0384, -0.0296,  0.0227], ...])
cls.predictions.transform.dense.bias    = {Tensor: [768]}        tensor([ 5.3890e-02,  1.0068e-01,  4.5532e-02,  ...])
cls.predictions.decoder.weight          = {Tensor: [30522, 768]} tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098], ...])
cls.seq_relationship.weight             = {Tensor: [2, 768]}     tensor([[-0.0154, -0.0062, -0.0137,  ...], [ 0.0058,  0.0120,  0.0128,  ...]])
cls.seq_relationship.bias               = {Tensor: [2]}          tensor([ 0.0211, -0.0021])

There are over a hundred differently named weights in this state_dict, and these are the few whose names start with cls.

Looking at the loading logic again, parameters are matched and loaded by name, so the model's cls.predictions.bias, which starts out all zeros, gets replaced by the pretrained values.
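A quick check after from_pretrained confirms it (a minimal sketch, reusing the model loaded in the entry script):

bias = model.cls.predictions.bias
print(bias.shape)                   # torch.Size([30522])
print(torch.all(bias == 0).item())  # False: the all-zero init has been overwritten
print(bias[:3])                     # matches the start of cls.predictions.bias in the checkpoint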

At first this puzzled me, because I didn't expect the dict to contain such an entry. Thinking it over, the [MASK]-prediction head was presumably also used during pretraining, so its weights were saved along with everything else.

At the same time, cls.predictions.decoder.weight also seems to get overwritten, which makes initializing that weight from the Embedding layer's weight at construction time look unnecessary. From the code you can see that the weight passed straight over from BERT's embedding layer is this:

Parameter containing:
tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
       requires_grad=True)
The leading numbers -0.0102, -0.0615, ... match the start of cls.predictions.decoder.weight in the state_dict listing above, so we can safely say the two weights are identical:
it is simply the Embedding layer's weight matrix.
As for a conclusion... well, this pretrained checkpoint could probably be slimmed down a little (half joking).
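The tying can also be checked directly (a minimal sketch; the attribute paths follow the module layout shown above):

emb = model.bert.embeddings.word_embeddings.weight
dec = model.cls.predictions.decoder.weight

print(dec is emb)             # True: the same Parameter object, tied at construction time
print(torch.equal(dec, emb))  # True: so the checkpoint effectively stores this matrix twice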

Original post: https://www.cnblogs.com/DDBD/p/14470519.html
