A Summary of CRF

Articles introducing CRF always look impenetrable. This post does not go into the mathematics of CRF; instead, it aims to build a practical, qualitative understanding of CRF through hands-on use.

Implementing NER with BERT+Bi-LSTM+CRF

Reference article: https://zhuanlan.zhihu.com/p/453350271
Reference code: https://github.com/XavierWww/Chinese-Medical-Entity-Recognition

First, the NER-annotated data has to be preprocessed into character/tag pairs. There are several tagging schemes for NER; this post does not elaborate on any particular scheme, since the focus here is on how to use the CRF.
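
Purely for intuition, here is a hypothetical fragment of what such character/tag pairs might look like under a BIO-style scheme (the label set and scheme are illustrative assumptions, not necessarily those used in the referenced repo):

# Hypothetical character/tag pairs under a BIO-style scheme (labels are illustrative)
words = ['阿', '司', '匹', '林', '可', '缓', '解', '头', '痛']
tags  = ['B-DRUG', 'I-DRUG', 'I-DRUG', 'I-DRUG', 'O', 'O', 'O', 'B-SYM', 'I-SYM']
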
Let's go straight to the code of the BERT+Bi-LSTM+CRF model. The comments annotate in detail what each variable means and what its shape is.

import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import BertModel


class Bert_BiLSTM_CRF(nn.Module):

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256):
        super(Bert_BiLSTM_CRF, self).__init__()
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained("bert-base-chinese")
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim // 2,
                            num_layers=2, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def _get_features(self, sentence):
        # BERT is used as a frozen feature extractor (no gradients flow into it)
        with torch.no_grad():
            embeds_out = self.bert(sentence)                  # last_hidden_state: [64, 228, 768]
        enc, _ = self.lstm(embeds_out['last_hidden_state'])  # [64, 228, 256]
        enc = self.dropout(enc)
        feats = self.linear(enc)                              # emission scores: [64, 228, 16]
        return feats

    def forward(self, sentence, tags, mask, is_test=False):
        # sentence: [64, 228], tags: [64, 228], mask: [64, 228]
        # 64 is the batch size; 228 is the length of the longest sentence in this batch
        emissions = self._get_features(sentence)
        if not is_test:  # training: return the loss
            loss = -self.crf.forward(emissions, tags, mask, reduction='mean')
            return loss
        else:  # testing: return the decoded tag sequences
            decode = self.crf.decode(emissions, mask)
            return decode

As the code above shows, both the training details and the inference procedure of the CRF are fully encapsulated in the CRF class: during training you call its forward function (negating its output log-likelihood gives the loss), and at test time you call its decode function to get the best tag sequence.
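
To make the calling convention concrete, here is a minimal, hypothetical sketch of one training step and one decoding call. The tag set, toy tensors, and optimizer settings are illustrative assumptions and are not taken from the referenced repo:

# Hypothetical usage sketch; tag set, shapes, and learning rate are made up for illustration
tag2idx = {'<PAD>': 0, 'O': 1, 'B-DRUG': 2, 'I-DRUG': 3}
model = Bert_BiLSTM_CRF(tag2idx)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# token_ids, label_ids, mask would normally come from the DataLoader shown below
token_ids = torch.randint(1, 21128, (2, 10))           # [batch, seq_len]
label_ids = torch.randint(1, len(tag2idx), (2, 10))    # [batch, seq_len]
mask = torch.ones(2, 10, dtype=torch.bool)             # all positions valid in this toy batch

# Training step: the model returns the loss (negative mean log-likelihood)
loss = model(token_ids, label_ids, mask)
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Decoding: returns a list of tag-index sequences, one per sentence
with torch.no_grad():
    best_paths = model(token_ids, label_ids, mask, is_test=True)
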
Another point worth noting: sentences vary in length, so how can we train efficiently while limiting the impact of padding? The author's dataloader handles this as follows.

import torch
from torch.utils import data
from torch.utils.data import Dataset


class NerDataset(Dataset):
    ''' Generate our dataset '''
    def __getitem__(self, idx):
        words, tags = self.sents[idx], self.tags_li[idx]
        token_ids = tokenizer.convert_tokens_to_ids(words)
        label_ids = [tag2idx[tag] for tag in tags]
        seqlen = len(label_ids)
        return token_ids, label_ids, seqlen


def PadBatch(batch):
    # Pad every sequence in the batch to the length of its longest sentence
    maxlen = max([i[2] for i in batch])
    token_tensors = torch.LongTensor([i[0] + [0] * (maxlen - len(i[0])) for i in batch])
    label_tensors = torch.LongTensor([i[1] + [0] * (maxlen - len(i[1])) for i in batch])
    mask = (token_tensors > 0)  # True for real tokens, False for padding
    return token_tensors, label_tensors, mask


train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=ner.batch_size,
                             shuffle=True,
                             num_workers=4,
                             collate_fn=PadBatch)

As the code above shows, the key piece is the collate_fn argument of the DataLoader: it tells the DataLoader how to assemble the individual samples into a batch. Here, PadBatch pads every sequence only up to the longest sentence in its batch and builds a boolean mask marking the real tokens, so the padded positions can be ignored by the CRF during both training and decoding.
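
For intuition, here is a small sketch of what PadBatch produces for a toy batch of (token_ids, label_ids, seqlen) triples; the id values are made up, and the printed output is approximate:

# Hypothetical toy batch with sentences of unequal length
batch = [
    ([101, 2769, 102], [1, 2, 1], 3),
    ([101, 102], [1, 1], 2),
]

token_tensors, label_tensors, mask = PadBatch(batch)
print(token_tensors)
# tensor([[ 101, 2769,  102],
#         [ 101,  102,    0]])    <- second sentence padded with 0
print(mask)
# tensor([[ True,  True,  True],
#         [ True,  True, False]])  <- padding position masked out for the CRF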