TensorFlow 1.x 实现 BiLSTM+CRF

前面几章简单介绍了CRF层的作用以及CRF层的损失函数,详见:

BiLSTM中的CRF层(一)简介

BiLSTM中的CRF层(二)CRF层

BiLSTM中的CRF层(三)CRF损失函数

下面使用tensorflow1.x版本实现BiLstm+CRF模型,并基于“万创杯”中医药天池大数据竞赛—中药说明书实体识别挑战的比赛数据实现中药NER任务。

1.bilstm+crf模型

该文件定义了 embedding 层、BiLSTM 层、全连接层和 CRF 层等模型组件。

# -*- coding: utf-8 -*-
# @Time    : 2020-10-09 21:15
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    : 
# @File    : bilstm_crf.py
# @Software: PyCharm

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.rnn import MultiRNNCell


class Linear:
    """
    Fully connected (dense) layer applied independently to every time step.

    Owns a weight ``W`` of shape ``[input_size, output_size]`` and a bias ``b``
    of shape ``[output_size]`` created under ``scope_name``; dropout is applied
    to the projected output.
    """

    def __init__(self, scope_name, input_size, output_size,
                 drop_out=0., trainable=True):
        uniform_init = tf.random_uniform_initializer(-0.25, 0.25)
        with tf.variable_scope(scope_name):
            self.W = tf.get_variable('W', [input_size, output_size],
                                     initializer=uniform_init,
                                     trainable=trainable)
            self.b = tf.get_variable('b', [output_size],
                                     initializer=tf.zeros_initializer(),
                                     trainable=trainable)

        # `drop_out` is the drop rate passed to tf.layers.Dropout
        self.drop_out = tf.layers.Dropout(drop_out)
        self.output_size = output_size

    def __call__(self, inputs, training):
        """
        Project ``inputs`` to ``output_size`` features per position.

        :param inputs: rank-3 tensor; dims 0 and 1 are kept and the last dim
            is projected (assumed [batch, time, input_size] — confirm)
        :param training: whether dropout is active
        :return: tensor of shape [-1, inputs.shape[1], output_size]
        """
        in_shape = tf.shape(inputs)
        # collapse all leading dims so a single matmul handles every position
        flat = tf.reshape(inputs, [-1, in_shape[-1]])
        projected = self.drop_out(tf.nn.xw_plus_b(flat, self.W, self.b),
                                  training=training)
        return tf.reshape(projected, [-1, in_shape[1], self.output_size])


class LookupTable:
    """
    Embedding layer backed by a ``[vocab_size, embed_size]`` variable.

    Ids that are ``>= vocab_size`` are redirected to id 1 before lookup.
    """

    def __init__(self, scope_name, vocab_size, embed_size, reuse=False, trainable=True):
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        uniform_init = tf.random_uniform_initializer(-0.25, 0.25)
        with tf.variable_scope(scope_name, reuse=bool(reuse)):
            self.embedding = tf.get_variable(
                'embedding', [vocab_size, embed_size],
                initializer=uniform_init,
                trainable=trainable,
            )

    def __call__(self, input):
        """
        Look up embeddings for an integer id tensor.

        :param input: integer id tensor
        :return: embeddings with a trailing ``embed_size`` dimension
        """
        in_vocab = tf.less(input, self.vocab_size)
        # out-of-range ids fall back to 1 (the UNK id used by preprocessing)
        safe_ids = tf.where(in_vocab, input, tf.ones_like(input))
        return tf.nn.embedding_lookup(self.embedding, safe_ids)


class LstmBase:
    """
    Mixin that builds (stacked) LSTM cells.
    """

    def build_rnn(self, hidden_size, num_layes):
        """
        Build a MultiRNNCell made of ``num_layes`` LSTMCell layers.

        :param hidden_size: number of units in each LSTMCell
        :param num_layes: number of stacked layers
        :return: a MultiRNNCell with tuple states
        """
        stacked = [
            LSTMCell(num_units=hidden_size,
                     state_is_tuple=True,
                     initializer=tf.random_uniform_initializer(-0.25, 0.25))
            for _ in range(num_layes)
        ]
        return MultiRNNCell(stacked, state_is_tuple=True)


class BiLstm(LstmBase):
    """
    Stack of bidirectional LSTM layers.

    Each direction gets ``hidden_size // 2`` units, so after concatenating the
    forward and backward outputs every layer emits ``hidden_size`` features.
    """

    def __init__(self, scope_name, hidden_size, num_layers):
        """
        :param scope_name: prefix for the per-layer variable scopes
        :param hidden_size: total output width per layer; must be even
        :param num_layers: number of stacked bidirectional layers
        :raises ValueError: if ``hidden_size`` is odd
        """
        super(BiLstm, self).__init__()
        if hidden_size % 2 != 0:
            raise ValueError(
                'hidden_size must be even, got {}'.format(hidden_size))
        # integer division: '/' would hand LSTMCell a float unit count
        half_size = hidden_size // 2

        self.fw_rnns = []
        self.bw_rnns = []
        for _ in range(num_layers):
            self.fw_rnns.append(self.build_rnn(half_size, 1))
            self.bw_rnns.append(self.build_rnn(half_size, 1))

        self.scope_name = scope_name

    def __call__(self, input, input_len):
        """
        Run all bidirectional layers over ``input``.

        :param input: batch-major float tensor (time_major=False), assumed
            [batch, time, depth] — confirm against caller
        :param input_len: per-example valid lengths
        :return: output of the last layer, fw/bw outputs concatenated on the
            last axis
        """
        for idx, (fw_rnn, bw_rnn) in enumerate(zip(self.fw_rnns, self.bw_rnns)):
            scope_name = '{}_{}'.format(self.scope_name, idx)
            ctx, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_rnn, bw_rnn, input, sequence_length=input_len,
                dtype=tf.float32, time_major=False,
                scope=scope_name
            )
            # (fw_out, bw_out) -> single tensor feeding the next layer
            input = tf.concat(ctx, -1)
        return input


class BiLstm_Crf:
    """
    BiLSTM-CRF tagger: embedding -> BiLSTM -> linear projection -> CRF decode.
    """

    def __init__(self, args, vocab_size, emb_size):
        """
        :param args: namespace providing ``hidden_dim``, ``num_tags`` and
            ``drop_out``
        :param vocab_size: rows of the embedding table
        :param emb_size: embedding dimension
        """
        self.lookuptables = LookupTable('look_up', vocab_size, emb_size)
        # one bidirectional layer whose concatenated width is hidden_dim
        self.rnn = BiLstm('bi_lstm', args.hidden_dim, 1)
        self.linear = Linear('linear', args.hidden_dim, args.num_tags,
                             drop_out=args.drop_out)

        # CRF transition matrix, created with get_variable's default initializer
        tag_count = args.num_tags
        self.crf_param = tf.get_variable('crf_param', [tag_count, tag_count],
                                         dtype=tf.float32)

    def __call__(self, inputs, training):
        """
        :param inputs: padded id matrix; id 0 is treated as padding
        :param training: whether dropout is active
        :return: (per-tag logits, Viterbi-decoded tag ids, transition matrix)
        """
        # non-zero ids count toward each sentence's length
        sent_len = tf.reduce_sum(tf.sign(inputs), axis=1)

        rnn_out = self.rnn(self.lookuptables(inputs), sent_len)
        logits = self.linear(rnn_out, training)

        pred_ids, _ = tf.contrib.crf.crf_decode(logits, self.crf_param, sent_len)
        return logits, pred_ids, self.crf_param

2.数据预处理

该文件是数据预处理,具体是将文本数据转换成id形式,并保存为pkl文件。

# -*- coding: utf-8 -*-
# @Time    : 2020-10-11 18:52
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    : 
# @File    : preprocess.py
# @Software: PyCharm
import os
import _pickle as cPickle
import pandas as pd
import random

"""
数据前处理
将数据处理成id,并封装成pkl形式
"""

# Entity types of the competition; each gets a B- (begin) and I- (inside) tag.
tag_list = ['DRUG', 'DRUG_INGREDIENT',
            'DISEASE', 'SYMPTOM',
            'SYNDROME', 'DISEASE_GROUP',
            'FOOD_GROUP', 'FOOD',
            'PERSON_GROUP', 'DRUG_GROUP',
            'DRUG_DOSAGE', 'DRUG_TASTE',
            'DRUG_EFFICACY']

# BIO tag -> id; 'O' is 0, then B-/I- pairs in tag_list order.
tag_dict = {'O': 0}
for tag in tag_list:
    for prefix in ('B-', 'I-'):
        tag_dict[prefix + tag] = len(tag_dict)

print(tag_dict)


def make_vocab(file_path):
    """
    Build the token vocabulary from a training file.

    Each line of ``file_path`` is ``text<TAB>tag`` where ``text`` is a
    sequence of tokens joined by ``<#>``. Ids 0 and 1 are reserved for
    ``PAD`` and ``UNK``; the remaining tokens get ids in order of decreasing
    frequency (ties broken by first appearance). The assignment is therefore
    reproducible across runs. If the model's embedding table is smaller than
    the vocabulary, the ids that fall out of range belong to the rarest tokens.

    :param file_path: path to the tab-separated training file
    :return: dict mapping token -> integer id
    """
    from collections import Counter

    data = pd.read_csv(file_path, sep='\t', header=None)
    data.columns = ['text', 'tag']

    counts = Counter()
    for text in data['text']:
        counts.update(text.split('<#>'))

    vocab = {'PAD': 0, 'UNK': 1}
    # most_common() orders by count desc, ties by first-encountered order
    for word, _ in counts.most_common():
        if word not in vocab:
            vocab[word] = len(vocab)
    return vocab


def make_data(file_path, vocab):
    """
    Convert a tab-separated ``text<TAB>tag`` file into id sequences.

    Each row is split into sentences at the ``<#>。<#>`` separator. Every
    sentence becomes one example, and the ``。`` token between sentences is
    dropped together with its tag. Tokens missing from ``vocab`` map to 1
    (UNK); tags are looked up with ``tag_dict.get``, so an unknown tag yields
    ``None``.

    :param file_path: path to the tab-separated data file
    :param vocab: dict mapping token -> id
    :return: ``{'words': [[id, ...], ...], 'tags': [[tag_id, ...], ...]}``
    """
    frame = pd.read_csv(file_path, sep='\t', header=None)
    frame.columns = ['text', 'tag']

    word_ids, tag_ids = [], []
    for text, tag_str in zip(frame['text'], frame['tag']):
        tags = tag_str.split('<#>')
        # TODO: also split on commas
        offset = 0
        for sentence in text.split('<#>。<#>'):
            words = sentence.split('<#>')
            end = offset + len(words)
            word_ids.append([vocab.get(word, 1) for word in words])
            tag_ids.append([tag_dict.get(tag) for tag in tags[offset:end]])
            # +1 skips the tag of the '。' token removed by the split
            offset = end + 1

    return {'words': word_ids, 'tags': tag_ids}


def save_vocab(vocab, output):
    """
    Write the vocabulary as UTF-8 text, one ``word<TAB>id`` line per entry.

    :param vocab: dict mapping token -> id; written in iteration order
    :param output: destination file path (overwritten)
    :return: None
    """
    with open(output, 'w', encoding='utf-8') as fr:
        fr.writelines(word + '\t' + str(idx) + '\n'
                      for word, idx in vocab.items())
    print('save vocab is ok.')


def main(output_path, train_path='./data_path/train.txt',
         test_path='./data_path/test.txt'):
    """
    Build the vocabulary and id-encoded train/test data, then save them.

    Writes ``ner_data.pkl`` (``{'train': ..., 'test': ...}``, pickle protocol
    2) and ``ner_vocab.txt`` into ``output_path``. The vocabulary is built
    from the training file only, so unseen test tokens map to UNK.

    :param output_path: directory to write the outputs into
    :param train_path: tab-separated training file
    :param test_path: tab-separated test file
    :return: None
    """
    vocab = make_vocab(train_path)
    data = {
        'train': make_data(train_path, vocab),
        'test': make_data(test_path, vocab),
    }

    data_path = os.path.join(output_path, 'ner_data.pkl')
    # close (and flush) the file deterministically instead of relying on GC
    with open(data_path, 'wb') as f:
        cPickle.dump(data, f, protocol=2)
    print('save data to pkl ok.')

    vocab_path = os.path.join(output_path, 'ner_vocab.txt')
    save_vocab(vocab, vocab_path)


if __name__ == '__main__':
    output = './data_path/'
    main(output)

    # reload the pickle and print the first training example as a sanity check
    with open(os.path.join(output, 'ner_data.pkl'), 'rb') as f:
        data = cPickle.load(f)

    print(data['train']['words'][0])
    print(data['train']['tags'][0])

3.数据集构建类

构建模型训练所需的数据格式

# -*- coding: utf-8 -*-
# @Time    : 2020-10-09 21:18
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    : 
# @File    : datasets.py
# @Software: PyCharm
import numpy as np
import tensorflow as tf

"""
数据集构建类
"""


class DataBuilder:
    """
    Wrap id-encoded NER data and expose it as a ``tf.data`` pipeline.

    ``data`` is a dict with keys ``'words'`` and ``'tags'``. Each value is a
    sequence of per-sentence id sequences, and those sequences may differ in
    length.
    """

    def __init__(self, data):
        self.words = self._to_ragged_array(data['words'])
        self.tags = self._to_ragged_array(data['tags'])

    @staticmethod
    def _to_ragged_array(seqs):
        """
        Pack variable-length id sequences into a 1-D object array of int64 arrays.

        ``np.asarray`` on ragged nested lists raises ValueError on NumPy >= 1.24
        and yields a 2-D array when all lengths happen to match, so the object
        array is filled element by element instead.
        """
        packed = np.empty(len(seqs), dtype=object)
        for i, seq in enumerate(seqs):
            packed[i] = np.asarray(seq, dtype=np.int64)
        return packed

    @property
    def size(self):
        """Number of examples."""
        return len(self.words)

    def build_generator(self):
        """
        Yield ``((word_ids, length), tag_ids)`` for every example.

        :return: generator used by ``build_dataset``
        """
        for word, tag in zip(self.words, self.tags):
            yield (word, len(word)), tag

    def build_dataset(self):
        """
        Build an unbatched ``tf.data.Dataset`` from ``build_generator``.

        :return: dataset of ``((int64[None], int64[]), int64[None])``
        """
        dataset = tf.data.Dataset.from_generator(
            self.build_generator,
            ((tf.int64, tf.int64), tf.int64),
            ((tf.TensorShape([None]), tf.TensorShape([])), tf.TensorShape([None]))
        )
        return dataset

    def get_train_batch(self, dataset, batch_size, epoch):
        """
        Cache, shuffle (buffer 10000), pad-batch and repeat the dataset.

        :param dataset: dataset from ``build_dataset``
        :param batch_size: examples per batch
        :param epoch: number of repetitions over the data
        :return: the next-element tensors of a one-shot iterator
        """
        dataset = dataset.cache()\
            .shuffle(buffer_size=10000)\
            .padded_batch(batch_size, padded_shapes=(([None], []), [None]))\
            .repeat(epoch)
        return dataset.make_one_shot_iterator().get_next()

    def get_test_batch(self, dataset, batch_size):
        """
        Pad-batch the dataset once, in order, for evaluation.

        :param dataset: dataset from ``build_dataset``
        :param batch_size: examples per batch
        :return: the next-element tensors of a one-shot iterator
        """
        dataset = dataset.padded_batch(batch_size,
                                       padded_shapes=(([None], []), [None]))
        return dataset.make_one_shot_iterator().get_next()

4.模型训练

定义模型参数和训练过程:训练过程中的检查点保存为 ckpt 格式,训练完成后模型导出为 SavedModel(pb)格式。

# -*- coding: utf-8 -*-
# @Time    : 2020-10-09 23:07
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    : 
# @File    : ner_main.py
# @Software: PyCharm
import sys
import os
import time
import tensorflow as tf
from data_utils import datasets

import _pickle as cPickle

from argparse import ArgumentParser
from models.bilstm_crf import BiLstm_Crf

parser = ArgumentParser()

# (flag, type, default, help) for every command-line option.
_CLI_OPTIONS = (
    ("--vocab_size", int, 2500, 'vocab size'),
    ("--emb_size", int, 300, 'emb size'),
    ("--train_path", str, './data_path/ner_data.pkl', None),
    ("--test_path", str, './data_path/ner_data.pkl', None),
    ("--model_dir", str, './model_ckpt/', None),
    ("--model_export", str, './model_pb', None),
    ("--hidden_dim", int, 300, None),
    ("--num_tags", int, 27, None),
    ("--drop_out", float, 0.1, None),
    ("--batch_size", int, 16, None),
    ("--epoch", int, 1, None),
    ("--type", str, 'lstm', '[lstm,textcnn...]'),
)
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)


tf.logging.set_verbosity(tf.logging.INFO)
# unknown flags are kept in `unparsed` and forwarded to tf.app.run
ARGS, unparsed = parser.parse_known_args()
print(ARGS)

sys.stdout.flush()


def init_data(file_name, type=None):
    """
    Load one split of the pickled dataset and return an Estimator input_fn.

    :param file_name: path to the pickle written by the preprocessing script
    :param type: split key inside the pickle ('train' or 'test')
    :return: a zero-argument input_fn. For 'train' it yields shuffled batches
        repeated ``ARGS.epoch`` times; otherwise it yields padded batches once.
    """
    # NOTE(review): pickle can execute arbitrary code on load; only use
    # files produced locally by the preprocessing step.
    with open(file_name, 'rb') as f:
        data = cPickle.load(f)[type]

    data_builder = datasets.DataBuilder(data)
    dataset = data_builder.build_dataset()

    def train_input():
        return data_builder.get_train_batch(dataset, ARGS.batch_size, ARGS.epoch)

    def test_input():
        return data_builder.get_test_batch(dataset, ARGS.batch_size)

    return train_input if type == 'train' else test_input


def make_model():
    """
    Build the model selected by ``--type``.

    :return: a model instance (currently only ``BiLstm_Crf`` for 'lstm')
    :raises ValueError: if ``ARGS.type`` names an unsupported model; the
        original fell through to ``return model`` with ``model`` unbound
    """
    vocab_size = ARGS.vocab_size
    emb_size = ARGS.emb_size

    if ARGS.type == 'lstm':
        return BiLstm_Crf(ARGS, vocab_size, emb_size)

    raise ValueError('unsupported model type: {!r}'.format(ARGS.type))


def model_fn(features, labels, mode, params):
    """
    Estimator model_fn for the BiLSTM-CRF tagger.

    :param features: either a ``(words, words_len)`` tuple from the input_fn,
        or a dict with keys ``'words'`` and ``'words_len'`` from the serving
        input receiver
    :param labels: padded tag-id matrix; unused in PREDICT mode
    :param mode: a ``tf.estimator.ModeKeys`` value
    :param params: Estimator params; ``params['learning_rate']`` overrides the
        default Adam learning rate of 1e-4 when present
    :return: a ``tf.estimator.EstimatorSpec`` for ``mode``
    """
    model = make_model()

    if isinstance(features, dict):
        features = features['words'], features['words_len']

    words, words_len = features

    if mode == tf.estimator.ModeKeys.PREDICT:
        _, pred_ids, _ = model(words, training=False)

        prediction = {'tag_ids': tf.identity(pred_ids, name='tag_ids')}

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=prediction,
            export_outputs={'classify': tf.estimator.export.PredictOutput(prediction)}
        )

    tags = labels
    # padding positions get weight 0 in the token-level accuracy
    weights = tf.sequence_mask(words_len)
    training = mode == tf.estimator.ModeKeys.TRAIN

    logits, pred_ids, crf_params = model(words, training=training)

    # loss = negative mean CRF log-likelihood; computed for EVAL too so the
    # reported eval loss is real instead of a constant 0
    log_like_lihood, _ = tf.contrib.crf.crf_log_likelihood(
        logits, tags, words_len, crf_params
    )
    loss = -tf.reduce_mean(log_like_lihood)
    accuracy = tf.metrics.accuracy(tags, pred_ids, weights)

    if training:
        # named tensor read by the LoggingTensorHook in main_es
        tf.identity(accuracy[1], name='train_accuracy')
        tf.summary.scalar('train_accuracy', accuracy[1])
        learning_rate = (params or {}).get('learning_rate', 1e-4)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step())
        )

    metrics = {
        'accuracy': accuracy
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops=metrics
    )


def main_es(unparsed):
    """
    Train, evaluate and optionally export the BiLSTM-CRF Estimator.

    Checkpoints are written to ``<model_dir>/<timestamp>``. When
    ``--model_export`` is non-empty, a SavedModel is exported under
    ``<model_export>/<timestamp>``.

    :param unparsed: argv list forwarded by ``tf.app.run`` (unused)
    :return: None
    """
    cur_time = time.time()
    stamp = str(int(cur_time))
    # join (not concatenate) so a --model_dir without trailing '/' still works
    model_dir = os.path.join(ARGS.model_dir, stamp)

    classifer = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        params={}
    )

    # train
    train_input = init_data(ARGS.train_path, 'train')
    tensors_to_log = {'train_accuracy': 'train_accuracy'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)
    classifer.train(input_fn=train_input, hooks=[logging_hook])

    # eval
    test_input = init_data(ARGS.test_path, 'test')
    eval_res = classifer.evaluate(input_fn=test_input)
    print(f'Evaluation res is : \n\t{eval_res}')

    # export: the parser defines --model_export; ARGS has no `export_dir`
    # attribute, so reading it would raise AttributeError here
    export_dir = ARGS.model_export
    if export_dir:
        # these placeholders only serve as dtype/shape templates for the
        # serving receiver's own placeholders
        words = tf.placeholder(tf.int64, [None, None], name='input_words')
        words_len = tf.placeholder(tf.int64, [None], name='input_len')
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'words': words,
            'words_len': words_len
        })
        path = os.path.join(export_dir, stamp)
        classifer.export_savedmodel(path, input_fn)


if __name__ == '__main__':
    # forward only the flags argparse did not consume to tf.app.run
    tf.app.run(main=main_es, argv=[sys.argv[0], *unparsed])

上述代码和数据已打包上传到 CSDN,传送门:《BiLstm+CRF实现命名实体识别代码》

有空的话会上传到github,也可以留下邮箱,私发到位~~~~


版权声明:本文为u013963380原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。