Articles introducing CRFs always look impenetrable. This post does not cover the mathematics of CRFs; instead, it aims to build a hands-on, qualitative understanding of how to use one.
Implementing NER with BERT+Bi-LSTM+CRF
Reference article: https://zhuanlan.zhihu.com/p/453350271
Reference code: https://github.com/XavierWww/Chinese-Medical-Entity-Recognition
First, the annotated NER data needs to be preprocessed into character/label pairs. There are several labeling schemes for NER; this post does not go into them in detail, since its goal is to show how the CRF is used.
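To make "character/label pairs" concrete, here is a toy example under the widely used BIO scheme (the sentence and the `DISEASE` tag name are illustrative, not taken from any particular dataset):

```python
# One training sentence as parallel character/label sequences (BIO scheme).
# B-X marks the first character of an entity of type X, I-X its continuation,
# and O marks characters outside any entity.
words = ['患', '者', '有', '高', '血', '压', '病', '史']  # "the patient has a history of hypertension"
tags  = ['O', 'O', 'O', 'B-DISEASE', 'I-DISEASE', 'I-DISEASE', 'O', 'O']
```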
Let's go straight to the code for the BERT+Bi-LSTM+CRF model. The comments annotate what each variable means and what shape it has.
```python
import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF

class Bert_BiLSTM_CRF(nn.Module):
    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256):
        super(Bert_BiLSTM_CRF, self).__init__()
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained("bert-base-chinese")
        # Bidirectional LSTM: each direction has hidden_dim // 2 units, so the
        # concatenated output size is hidden_dim
        self.lstm = nn.LSTM(input_size=embedding_dim,
                            hidden_size=hidden_dim // 2,
                            num_layers=2,
                            bidirectional=True,
                            batch_first=True)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def _get_features(self, sentence):
        # sentence: (batch_size, seq_len) token ids
        with torch.no_grad():  # BERT is frozen; only the layers above it are trained
            embeds_out = self.bert(sentence)
        # last_hidden_state: (batch_size, seq_len, embedding_dim)
        enc, _ = self.lstm(embeds_out['last_hidden_state'])
        enc = self.dropout(enc)   # (batch_size, seq_len, hidden_dim)
        feats = self.linear(enc)  # emission scores: (batch_size, seq_len, tagset_size)
        return feats

    def forward(self, sentence, tags, mask, is_test=False):
        emissions = self._get_features(sentence)
        if not is_test:
            # Training: CRF.forward returns the log-likelihood, so negate it for a loss
            loss = -self.crf.forward(emissions, tags, mask, reduction='mean')
            return loss
        else:
            # Testing: Viterbi-decode the best tag sequence for each sentence
            decode = self.crf.decode(emissions, mask)
            return decode
```
As the code above shows, all of the CRF's training details and inference steps are encapsulated in the CRF class: call its forward function during training (negating the log-likelihood it returns to get a loss), and call its decode function at test time.
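To make the two call paths concrete, here is a minimal sketch of how the model might be driven during training and evaluation. The optimizer choice, learning rate, and the `test_iter` loader are assumptions for illustration, not the reference repo's exact training script:

```python
import torch

# tag_to_ix maps label strings to indices; the pad label is assumed to be
# index 0 so padded label positions line up with padded token positions.
model = Bert_BiLSTM_CRF(tag_to_ix)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for token_ids, label_ids, mask in train_iter:
    # Training path: forward returns the negative log-likelihood as the loss
    loss = model(token_ids, label_ids, mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    for token_ids, label_ids, mask in test_iter:  # test_iter: assumed eval loader
        # Test path: decode returns one list of best tag indices per sentence,
        # already truncated to each sentence's true length by the mask
        pred_tag_seqs = model(token_ids, label_ids, mask, is_test=True)
```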
Another point worth noting: sentences vary in length. How do we train efficiently while minimizing the impact of padding? Here is the author's dataloader:
```python
import torch
from torch.utils import data
from torch.utils.data import Dataset

class NerDataset(Dataset):
    ''' Generate our dataset '''
    # __init__ (which reads the corpus into self.sents / self.tags_li)
    # and __len__ are omitted here for brevity
    def __getitem__(self, idx):
        words, tags = self.sents[idx], self.tags_li[idx]
        # tokenizer and tag2idx are module-level globals in the reference code
        token_ids = tokenizer.convert_tokens_to_ids(words)
        label_ids = [tag2idx[tag] for tag in tags]
        seqlen = len(label_ids)
        return token_ids, label_ids, seqlen

def PadBatch(batch):
    # Pad every item in the batch to the length of its longest sequence
    maxlen = max([i[2] for i in batch])
    token_tensors = torch.LongTensor([i[0] + [0] * (maxlen - len(i[0])) for i in batch])
    label_tensors = torch.LongTensor([i[1] + [0] * (maxlen - len(i[1])) for i in batch])
    # True for real tokens, False for padding; this is the mask fed to the CRF
    mask = (token_tensors > 0)
    return token_tensors, label_tensors, mask

train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=ner.batch_size,
                             shuffle=True,
                             num_workers=4,
                             collate_fn=PadBatch)
```
As the code above shows, the key is the collate_fn argument of the DataLoader, which effectively tells the DataLoader how to assemble the individual items into a batch: pad everything to the batch's longest sequence and build a mask so the CRF can ignore the padded positions.
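As a quick sanity check, here is PadBatch applied by hand to a toy batch of two items. The token and label ids are made up for illustration (101 and 102 happen to be BERT's [CLS] and [SEP] ids):

```python
import torch

# Two (token_ids, label_ids, seqlen) items of different lengths
batch = [
    ([101, 2345, 3456, 102], [0, 1, 2, 0], 4),
    ([101, 4567, 102],       [0, 3, 0],    3),
]
tokens, labels, mask = PadBatch(batch)
print(tokens)
# tensor([[ 101, 2345, 3456,  102],
#         [ 101, 4567,  102,    0]])
print(mask)
# tensor([[ True,  True,  True,  True],
#         [ True,  True,  True, False]])
```

Because the mask is derived from `token_tensors > 0`, this scheme relies on token id 0 being reserved for padding, which is the case for the bert-base-chinese vocabulary ([PAD] has id 0).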