
Converting a Transformer Model to TorchScript Format


from transformers import BertModel, BertTokenizer, BertConfig
import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Mask one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Create dummy inputs
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Set the torchscript flag to True when building the config
config = BertConfig(vocab_size=32000, hidden_size=768,
    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True)

# Instantiate the model
model = BertModel(config)

# Put the model in evaluation mode
model.eval()

# The model can also be instantiated from pretrained weights with the flag set
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Create the trace and save it
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")

# Load the saved model
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

all_encoder_layers, pooled_output = loaded_model(*dummy_input)

# Run inference with the traced model
traced_model(tokens_tensor, segments_tensors)
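
Not part of the original post, but a useful sanity check: with torchscript=True the forward pass returns plain tuples, so the traced model's outputs can be compared element-wise against the eager model's outputs on the same inputs. A minimal sketch, assuming model, traced_model, tokens_tensor and segments_tensors are defined as in the script above:

# Sanity check (assumes `model`, `traced_model`, `tokens_tensor` and
# `segments_tensors` from the script above): the traced model should
# reproduce the eager model's outputs on the same inputs.
with torch.no_grad():
    eager_out = model(tokens_tensor, segments_tensors)
    traced_out = traced_model(tokens_tensor, segments_tensors)

# With torchscript=True the outputs are tuples, so compare positionally:
# element 0 is the last hidden state, element 1 the pooled output.
print(torch.allclose(eager_out[0], traced_out[0], atol=1e-5))  # expected: True
print(torch.allclose(eager_out[1], traced_out[1], atol=1e-5))  # expected: True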

 


Original article: https://www.cnblogs.com/zcsh/p/14627734.html
