chinese_wwm_ext_pytorch初体验
背景
使用
下载资源包
https://www.kaggle.com/terrychanorg/chinese-wwm-ext-pytorch
安装依赖
pip3 install torch pytorch_pretrained_bert
计算得分
import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer,BertForMaskedLM
# Load pre-trained model (weights)
# Load the pre-trained model (weights) and its tokenizer (vocabulary).
# Both MUST come from the same directory; the original snippet used two
# different placeholder paths ('path/to/...' vs '/path/to/...').
MODEL_DIR = '/path/to/chinese_wwm_ext_pytorch/'

# NOTE: torch.no_grad() around *loading* has no effect — it only disables
# autograd during forward passes, so the guard belongs around inference
# (see score()), not here.
model = BertForMaskedLM.from_pretrained(MODEL_DIR)
model.eval()  # disable dropout for deterministic scoring
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
def score(sentence):
    """Return the pseudo-perplexity of *sentence* under the masked LM.

    Masks each token one at a time, accumulates the masked-LM loss for
    each position, and exponentiates the mean loss. Lower is more fluent.

    Raises:
        ValueError: if the tokenizer produces no tokens for *sentence*.
    """
    tokens = tokenizer.tokenize(sentence)
    if not tokens:
        # Guard: the original code divided by len(tokens) == 0 here.
        raise ValueError("sentence produced no tokens")
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
    sentence_loss = 0.0
    with torch.no_grad():  # inference only — no gradients needed
        for i, word in enumerate(tokens):
            tokens[i] = '[MASK]'
            mask_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
            # With masked_lm_labels supplied, the model returns the LM loss.
            word_loss = model(mask_input, masked_lm_labels=tensor_input).data.numpy()
            sentence_loss += word_loss
            # BUG FIX: restore the original token so only ONE position is
            # masked per step; the original left every prior token masked.
            tokens[i] = word
            # print("Word: %s : %f" % (word, np.exp(-word_loss)))
    return np.exp(sentence_loss / len(tokens))
if __name__ == '__main__':
    # Demo: print the pseudo-perplexity of a short sample sentence.
    result = score("你好 世界")
    print(result)
结果
% python3 a.py
3706.2062184104116