First impressions of chinese_wwm_ext_pytorch


Background

chinese_wwm_ext_pytorch is the PyTorch checkpoint of BERT-wwm-ext, a Chinese BERT model pre-trained with whole-word masking and released by the HFL lab (Joint Laboratory of HIT and iFLYTEK Research). This post tries it out by using the masked language model to score sentence fluency.

Usage

Download the model package

https://www.kaggle.com/terrychanorg/chinese-wwm-ext-pytorch
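
Before loading anything, it can help to confirm the extracted archive contains the files pytorch_pretrained_bert looks for when loading from a local directory; to my understanding these are bert_config.json, pytorch_model.bin, and vocab.txt, but verify against your own download. A minimal check, with the path left as a placeholder:

import os

# Placeholder path: point this at the extracted Kaggle archive
MODEL_DIR = 'path/to/chinese_wwm_ext_pytorch/'

# Expected files (assumption based on the library's local-loading convention)
for name in ('bert_config.json', 'pytorch_model.bin', 'vocab.txt'):
    print(name, os.path.isfile(os.path.join(MODEL_DIR, name)))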

Install dependencies

pip3 install torch pytorch_pretrained_bert
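
Note that pytorch_pretrained_bert is the legacy predecessor of today's transformers library, so its API differs from modern examples. A quick sanity check that the install worked (prints the torch version if both packages import cleanly):

import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM

# If this runs without ImportError, the environment is ready
print(torch.__version__)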

Compute a score

The script below masks each token in turn, accumulates the masked-LM loss, and returns exp(total_loss / N), a pseudo-perplexity: lower values mean the model considers the sentence more fluent.

import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM

MODEL_PATH = 'path/to/chinese_wwm_ext_pytorch/'

# Load pre-trained model (weights) and tokenizer (vocabulary)
model = BertForMaskedLM.from_pretrained(MODEL_PATH)
model.eval()
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

def score(sentence):
    tokenize_input = tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
    sentence_loss = 0.
    for i, word in enumerate(tokenize_input):
        # Mask the i-th token and predict it from the rest of the sentence
        tokenize_input[i] = '[MASK]'
        mask_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
        with torch.no_grad():
            # With masked_lm_labels supplied, the forward pass returns the LM loss
            word_loss = model(mask_input, masked_lm_labels=tensor_input).item()
        sentence_loss += word_loss
        # Restore the original token before masking the next position
        tokenize_input[i] = word
        # print("Word: %s : %f" % (word, np.exp(-word_loss)))

    return np.exp(sentence_loss / len(tokenize_input))

if __name__ == '__main__':
    print(score("你好 世界"))

Result

% python3 a.py
3706.2062184104116
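
The absolute value is a pseudo-perplexity, so it is mainly meaningful relative to other sentences scored by the same model: lower means more fluent. As a quick illustration (the test strings below are my own, not from the original post, and the numbers will vary), a scrambled sentence should score higher than its fluent counterpart, assuming the score function from the script above:

# Illustrative comparison; reuses score() defined above.
# The fluent sentence should get the lower pseudo-perplexity.
for s in ["今天天气很好", "好很气天天今"]:
    print(s, score(s))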