码迷,mamicode.com
首页 > 其他好文 > 详细

原来你是这样的BERT,i了i了! —— 超详细BERT介绍(三)BERT下游任务

时间:2020-06-21 15:37:35      阅读:474      评论:0      收藏:0      [点我收藏+]

标签:The   sel   选项   ons   active   输入   独立   cti   student   

原来你是这样的BERT,i了i了! —— 超详细BERT介绍(三)BERT下游任务

BERTBidirectional Encoder Representations from Transformers)是谷歌在2018年10月推出的深度语言表示模型。

一经推出便席卷整个NLP领域,带来了革命性的进步。
从此,无数英雄好汉竞相投身于这场追剧(芝麻街)运动。
只听得这边G家110亿,那边M家又1750亿,真是好不热闹!

然而大家真的了解BERT的具体构造,以及使用细节吗?
本文就带大家来细品一下。


前言

本系列文章分成三篇介绍BERT,上两篇分别介绍了BERT主模型的结构及其组件相关和BERT预训练相关,这一篇是最终话,介绍如何将BERT应用到不同的下游任务。

文章中的一些缩写:NLP(natural language processing)自然语言处理;CV(computer vision)计算机视觉;DL(deep learning)深度学习;NLP&DL 自然语言处理和深度学习的交叉领域;CV&DL 计算机视觉和深度学习的交叉领域。

文章公式中的向量均为行向量,矩阵或张量的形状均按照PyTorch的方式描述。
向量、矩阵或张量后的括号表示其形状。

本系列文章的代码均是基于transformers库(v2.11.0)的代码(基于Python语言、PyTorch框架)。
为便于理解,简化了原代码中不必要的部分,并保持主要功能等价。

阅读本系列文章需要一些背景知识,包括Word2VecLSTMTransformer-BaseELMoGPT等,由于本文不想过于冗长(其实是懒),以及相信来看本文的读者们也都是冲着BERT来的,所以这部分内容还请读者们自行学习。
本文假设读者们均已有相关背景知识。


目录


3、序列分类

序列分类任务就是输入一个序列,输出整个序列的标签。
输入的序列可以是单句也可以是双句。
单句序列分类任务就是文本分类(text classification)任务,包括主题(topic)、情感(sentiment)、垃圾邮件(spam)等的分类任务;双句序列分类任务包括相似度(similarity)、释义(paraphrase)、蕴含(entailment)等的分类任务。
根据标签数量分,可以分成单标签和多标签(multi-label)的分类任务。
根据标签的类别数量分,可以分成二分类或三分类、五分类等多分类任务。

BERT中的序列分类任务包括单句和双句的单标签回归或分类任务,涉及到语言可接受性(linguistic acceptability)、情感、相似度、释义、蕴含等特征的分类,即GLUEGeneral Language Understanding Evaluation)中的任务。

如下为一个相似度回归任务的例子(来自transformers库的示例):

5.000	A plane is taking off. ||| An air plane is taking off.
3.800	A man is playing a large flute. ||| A man is playing a flute.
3.800	A man is spreading shreded cheese on a pizza. ||| A man is spreading shredded cheese on an uncooked pizza.

其中,最左边的是标签,表示两句话的相似度分数,分数越高,相似度越高,分数的取值范围是\([0, 5]\)

再如下为一个双句释义二分类任务的例子(来自transformers库的示例):

1	He said the foodservice pie business ... ||| The foodservice pie business ...
0	Magnarelli said Racicot hated ... ||| His wife said he was ...
0	The dollar was at 116.92 yen against the yen ... ||| The dollar was at 116.78 yen JPY ...

其中,最左边的是标签,如果后句是前句的释义,即解释说明,那么标签为1,否则为0。

序列分类代码如下:

代码
# BERT之序列分类
class BertForSeqCls(BertPreTrainedModel):
	def __init__(self, config):
		super().__init__(config)
		self.config = config
		# 标签的类别数量
		self.num_labels = config.num_labels

		# 主模型
		self.bert = BertModel(config)
		self.dropout = nn.Dropout(config.hidden_dropout_prob)
		# 线性回归或分类器
		self.cls = nn.Linear(config.hidden_size, config.num_labels)
		# 回归或分类损失函数
		self.loss_fct = LossRgrsCls(config.num_labels)

		self.init_weights()
	def forward(self,
			tok_ids,  # 标记编码(batch_size * seq_length)
			pos_ids=None,  # 位置编码(batch_size * seq_length)
			sent_pos_ids=None,  # 句子位置编码(batch_size * seq_length)
			att_masks=None,  # 注意力掩码(batch_size * seq_length)
			labels=None,  # 标签(batch_size)
	):
		_, pooled_outputs = self.bert(
			tok_ids,
			pos_ids=pos_ids,
			sent_pos_ids=sent_pos_ids,
			att_masks=att_masks,
		)

		pooled_outputs = self.dropout(pooled_outputs)
		logits = self.cls(pooled_outputs)

		if labels is None:
			return logits  # 对数几率(batch_size * num_labels)

		loss = self.loss_fct(logits, labels)
		return loss

其中,
num_labels是标签的类别数量(注意:并不是标签数量,BERT的序列分类任务均为单标签分类任务),=1时为回归任务。


4、标记分类

标记分类任务就是输入一个序列,输出序列中每个标记的标签。
输入的序列一般是单句。
标记分类任务就是序列标注(sequence tagging)任务,包括中文分词(Chinese word segmentation)、词性标注(Part-of-Speech tagging,POS tagging)、命名实体识别(named entity recognition,NER)等。

序列标注任务常规的做法是BIO标注,B表示需要标注的片段的开头标记,I表示非开头标记,O表示不需要标注的标记。

如下为一个NER任务的例子(来自transformers库的示例):

例子
Schartau B-PER
sagte O
dem O
" O
Tagesspiegel B-ORG
" O
vom O
Freitag O
, O
Fischer B-PER
sei O
" O
in O
einer O
Weise O
aufgetreten O
, O
die O
alles O
andere O
als O
überzeugend O
war O
" O
. O

Firmengründer O
Wolf B-PER
Peter I-PER
Bree I-PER
arbeitete O
Anfang O
der O
siebziger O
Jahre O
als O
M?belvertreter O
, O
als O
er O
einen O
fliegenden O
H?ndler O
aus O
dem O
Libanon B-LOC
traf O
. O

Ob O
sie O
dabei O
nach O
dem O
Runden O
Tisch O
am O
23. O
April O
in O
Berlin B-LOC
durch O
ein O
p?dagogisches O
Konzept O
unterstützt O
wird O
, O
ist O
allerdings O
zu O
bezweifeln O
. O

其中,每一行为一个标记和其标签,空行分隔不同的句子;PER是人名、ORG是组织名、LOC是地名。

标记分类代码如下:

代码
# BERT之标记分类
class BertForTokCls(BertPreTrainedModel):
	def __init__(self, config):
		super().__init__(config)
		self.config = config
		# 标签的类别数量
		self.num_labels = config.num_labels

		# 主模型
		self.bert = BertModel(config)
		self.dropout = nn.Dropout(config.hidden_dropout_prob)
		# 线性分类器
		self.cls = nn.Linear(config.hidden_size, config.num_labels)
		# 分类损失函数
		self.loss_fct = LossCls(config.num_labels)

		self.init_weights()

	def forward(self,
			tok_ids,  # 标记编码(batch_size * seq_length)
			pos_ids=None,  # 位置编码(batch_size * seq_length)
			sent_pos_ids=None,  # 句子位置编码(batch_size * seq_length)
			att_masks=None,  # 注意力掩码(batch_size * seq_length)
			labels=None,  # 标签(batch_size * seq_length)
	):
		outputs, _ = self.bert(
			tok_ids,
			pos_ids=pos_ids,
			sent_pos_ids=sent_pos_ids,
			att_masks=att_masks,
		)

		outputs = self.dropout(outputs)
		logits = self.cls(outputs)

		if labels is None:
			return logits  # 对数几率(batch_size * seq_length * num_labels)

		# 只计算非填充标记的损失
		if att_masks is not None:
			active = att_masks.view(-1)>0
			logits = logits.view(-1, self.num_labels)[active]
			labels = labels.view(-1)[active]
		loss = self.loss_fct(logits, labels)
		return loss

5、选择题

BERT中的选择题是给出前句以及num_choices个后句,选择最优的后句。
如下(来自SWAG数据集):

2
Students lower their eyes nervously. She
pats her shoulder, then saunters toward someone.
turns with two students.
walks slowly towards someone.
wheels around as her dog thunders out.

其中,第一行是标签,第二行是前句,第三行到最后是四个后句;标签数字从0开始计数,即标签为2表示第三个(walks slowly towards someone.)为正确选项。

BERT将每个样本转换成num_choices个双句:

Students lower their eyes nervously. ||| She pats her shoulder, then saunters toward someone.
Students lower their eyes nervously. ||| She turns with two students.
Students lower their eyes nervously. ||| She walks slowly towards someone.
Students lower their eyes nervously. ||| She wheels around as her dog thunders out.

然后每个双句的序列表示产生一个对数几率,num_choices个双句就得到一个长度为num_choices的对数几率向量,最后将这个向量作为这个样本的输出,计算损失即可。

选择题代码如下:

代码
# BERT之选择题
class BertForMultiChoice(BertPreTrainedModel):
	def __init__(self, config):
		super().__init__(config)
		self.config = config
		# 选项个数
		self.num_choices = config.num_choices

		# 主模型
		self.bert = BertModel(config)
		self.dropout = nn.Dropout(config.hidden_dropout_prob)
		# 线性分类器
		self.cls = nn.Linear(config.hidden_size, 1)
		# 分类损失函数
		self.loss_fct = LossCls(1)

		self.init_weights()
	def forward(self,
			tok_ids,  # 标记编码(batch_size * num_choices * seq_length)
			pos_ids=None,  # 位置编码(batch_size * num_choices * seq_length)
			sent_pos_ids=None,  # 句子位置编码(batch_size * num_choices * seq_length)
			att_masks=None,  # 注意力掩码(batch_size * num_choices * seq_length)
			labels=None,  # 标签(batch_size)
	):
		seq_length = tok_ids.shape[-1]

		# 调整形状,每个前句-后句选项对看作一个双句输入
		tok_ids = tok_ids.view(-1, seq_length)
		if pos_ids is not None: pos_ids = pos_ids.view(-1, seq_length)
		if sent_pos_ids is not None: sent_pos_ids = sent_pos_ids.view(-1, seq_length)
		if att_masks is not None: att_masks = att_masks.view(-1, seq_length)

		_, pooled_outputs = self.bert(
			tok_ids,
			pos_ids=pos_ids,
			sent_pos_ids=sent_pos_ids,
			att_masks=att_masks,
		)

		pooled_outputs = self.dropout(pooled_outputs)
		logits = self.cls(pooled_outputs)
		# 调整形状,每num_choices个对数几率看作一个样本的输出
		logits = logits.view(-1, self.num_choices)

		if labels is None:
			return logits  # 对数几率(batch_size * num_choices)

		loss = self.loss_fct(logits, labels)
		return loss

其中,
num_choices是选项个数。


6、问答

BERT中的问答任务其实是抽取式的机器阅读理解(machine reading comprehension)任务,即给定一段话,给定一个问题,问题的答案来自这段话的某个连续的片段。
如下(来自transformers库的示例):

0	Computational complexity theory
What branch of theoretical computer science deals with broadly classifying computational problems by difficulty and class of relationship?
Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty ...

其中,第一行是答案,答案左边的数字表示这个答案在给定的这段话的起始位置(从0开始计数),第二行是问题,第三行是给定的一段话。

BERT将这个抽取式任务转化为一个预测答案起始和结束位置的分类任务,标签的类别数量是seq_length,起始位置和结束位置分别预测,即相当于两个标签。
注意:这个起始和结束位置是标记化等预处理后答案在输入的编码向量里的位置。

BERT将所有的标记表示转化成两个对数几率,然后横向切片,得到两个长度为seq_length的对数几率向量,分别作为起始和结束位置的预测,最后计算损失即可。

问答代码如下:

代码
# BERT之问答
class BertForQustAns(BertPreTrainedModel):
	def __init__(self, config):
		super().__init__(config)
		self.config = config

		# 主模型
		self.bert = BertModel(config)
		# 线性分类器
		self.cls = nn.Linear(config.hidden_size, 2)

		self.init_weights()
	def forward(self,
			tok_ids,  # 标记编码(batch_size * seq_length)
			pos_ids=None,  # 位置编码(batch_size * seq_length)
			sent_pos_ids=None,  # 句子位置编码(batch_size * seq_length)
			att_masks=None,  # 注意力掩码(batch_size * seq_length)
			start_pos=None,  # 起始位置标签(batch_size)
			end_pos=None,  # 结束位置标签(batch_size)
	):

		seq_length = tok_ids.shape[-1]

		outputs, _ = self.bert(
			tok_ids,
			pos_ids=pos_ids,
			sent_pos_ids=sent_pos_ids,
			att_masks=att_masks,
		)

		logits = self.cls(outputs)
		# 拆分起始和结束位置对数几率
		start_logits, end_logits = logits.split(1, dim=-1)
		start_logits = start_logits.view(-1, seq_length)
		end_logits = end_logits.view(-1, seq_length)

		if start_pos is None or end_pos is None:
			return (
				start_logits,  # 起始位置对数几率(batch_size * seq_length)
				end_logits,  # 结束位置对数几率(batch_size * seq_length)
			)

		# 标签值裁剪,使值 (- [0, seq_length],
		# 其中合法值 (- [0, seq_length-1],非法值 = seq_length
		start_pos = start_pos.clamp(0, seq_length)
		end_pos = end_pos.clamp(0, seq_length)

		# ignore_index=seq_length:忽略标签值 = seq_length对应的损失
		loss_fct = LossCls(seq_length, ignore_index=seq_length)
		start_loss = loss_fct(start_logits, start_pos)
		end_loss = loss_fct(end_logits, end_pos)
		loss = (start_loss + end_loss) / 2
		return loss

后记

本文作为系列的最后一篇文章,详细地介绍了BERT下游任务,BERT的通用性就体现在只需要添加少量模块就能应用到各种不同的下游任务。

BERT充分地利用了主模型输出的标记表示和序列表示,并对其进行一定地修改,从而可以应用到各种不同的下游任务中。
其中应用到选择题和问答任务的方式特别巧妙,分别活用了序列和标记表示。

然而,如同预训练,标记分类任务每个标记的标签是独立产生的,以及问答任务的起始和结束位置也是独立产生的,这其实不是非常合理。


原来你是这样的BERT,i了i了! —— 超详细BERT介绍(三)BERT下游任务

标签:The   sel   选项   ons   active   输入   独立   cti   student   

原文地址:https://www.cnblogs.com/wangzb96/p/bert_downstream_tasks.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!