# Tags: class transform OLE config rac rtm json rgba attention
# Export a BERT model to TorchScript by tracing it, save the traced module to
# disk, reload it, and run inference with both the reloaded and in-memory copies.
import torch
from transformers import BertConfig, BertModel, BertTokenizer

# Single path used for both saving and reloading so the two cannot drift apart
# (previously the model was saved to one filename and loaded from another).
TRACED_MODEL_PATH = "traced_bert.pt"

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the input text (two sentences separated by [SEP]).
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Replace one token with the mask token.
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)

# Segment ids: 0 for every token up to and including the first [SEP], 1 for the
# rest. Deriving them keeps the list the same length as the token list even if
# the text above is edited (for this text: 7 zeros followed by 7 ones).
first_sep = tokenized_text.index("[SEP]")
segments_ids = [0] * (first_sep + 1) + [1] * (len(tokenized_text) - first_sep - 1)

# Build the example inputs (a batch of one sequence).
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = (tokens_tensor, segments_tensors)

# When building a model from a config, set torchscript=True on the config so the
# model is prepared for TorchScript export (see the transformers TorchScript docs).
config = BertConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Initialise the model from the config and switch it to eval mode.
model = BertModel(config)
model.eval()

# Alternatively, load pretrained weights; this replaces the randomly
# initialised model above. eval() is called explicitly for clarity.
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# Trace the model with the example inputs and save the TorchScript module.
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, TRACED_MODEL_PATH)

# Reload the traced model from the same file it was saved to.
loaded_model = torch.jit.load(TRACED_MODEL_PATH)
loaded_model.eval()

# The traced forward takes the inputs as separate positional arguments, so the
# tuple must be unpacked. NOTE(review): the first output is presumably the last
# layer's hidden states rather than all encoder layers, despite the variable
# name; confirm for the installed transformers version.
all_encoder_layers, pooled_output = loaded_model(*dummy_input)

# The in-memory traced model can also be used for inference directly.
traced_model(*dummy_input)
# Tags: class transform OLE config rac rtm json rgba attention
# Original article: https://www.cnblogs.com/zcsh/p/14627734.html