1. Write a custom data-reading function and wire up its configuration. Note that the data_path parameter must line up exactly: its name has to match the keyword argument passed to load_dataset.
from paddlenlp.datasets import load_dataset

def read_out(data_path):
    """Yield examples as {"tokens": [...], "labels": [...]} from a local file.

    Each line is expected to be "tokens<TAB>tags", with the tokens and the
    tags themselves joined by the "\002" separator.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines (a `break` here would stop reading early)
            line_stripped = line.split('\t')
            if len(line_stripped) == 2:
                tokens = line_stripped[0].split("\002")
                tags = line_stripped[1].split("\002")
            else:
                # Unlabeled line: only the token column is present.
                tokens = line_stripped[0].split("\002")
                tags = []
            yield {"tokens": tokens, "labels": tags}
train_ds = load_dataset(read_out, data_path=path_out_train_, lazy=False)
test_ds = load_dataset(read_out, data_path=path_out_test_, lazy=False)
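For reference, a labeled line in such a file looks like the following (a made-up example with hypothetical B-LOC/I-LOC tags; \002 and \t stand for the actual control characters):

海\002淀\002区\tB-LOC\002I-LOC\002I-LOC

A line without a tab column is treated as unlabeled, and read_out yields it with an empty tags list.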
2. Labels are no longer loaded for you automatically; remember to load them explicitly from your local label file.
label_vocab = load_dict(dict_path_)
label_num = len(label_vocab)
no_entity_id = label_num - 1
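load_dict is not shown in this snippet; here is a minimal sketch, assuming the label dictionary file (dict_path_) holds one label per line, with the line order defining the ids:

def load_dict(dict_path):
    """Map each label to an id based on its line number in the dict file."""
    vocab = {}
    with open(dict_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            label = line.strip()
            if label:
                vocab[label] = i
    return vocab

Since no_entity_id is set to label_num - 1 above, this layout assumes the no-entity label (e.g. "O") is the last line of the file.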
3. Pay special attention: the built-in loading functions convert labels to ids automatically. When you load your own local data, you must pass the relevant arguments into the feature-conversion step and do the label-to-id mapping yourself.
def tokenize_and_align_labels(example, tokenizer, no_entity_id, label_vocab, max_seq_len=512):
    """Tokenize one example and align its string labels with the token ids.

    Labels are mapped to ids via label_vocab; no_entity_id is used for the
    special tokens and for padding the label sequence to the input length.
    """
    labels = example['labels']
    example = example['tokens']
    tokenized_input = tokenizer(
        example,
        return_length=True,
        is_split_into_words=True,
        max_seq_len=max_seq_len)
    # -2 for [CLS] and [SEP]: drop label tails that the tokenizer truncated away.
    if len(tokenized_input['input_ids']) - 2 < len(labels):
        labels = labels[:len(tokenized_input['input_ids']) - 2]
    # With locally read custom data nothing converts labels automatically,
    # so map each string label to its id here, using no_entity_id for the
    # [CLS]/[SEP] positions and for the padding tail.
    tokenized_input['labels'] = [no_entity_id] + [label_vocab[x] for x in labels] + [no_entity_id]
    tokenized_input['labels'] += [no_entity_id] * (
        len(tokenized_input['input_ids']) - len(tokenized_input['labels']))
    return tokenized_input
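With the function above, the conversion is typically wired in via functools.partial and Dataset.map; a minimal sketch, assuming a PaddleNLP tokenizer (the model name and max_seq_len=128 are just example choices):

from functools import partial
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")
trans_func = partial(
    tokenize_and_align_labels,
    tokenizer=tokenizer,
    no_entity_id=no_entity_id,
    label_vocab=label_vocab,
    max_seq_len=128)
train_ds = train_ds.map(trans_func)
test_ds = test_ds.map(trans_func)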
4. Inputs at prediction time must not be paddle Tensors; convert tensors to numpy arrays, or feed plain Python data directly.
def do_predict(self, title, text):
    """Entry function: run batched inference over the text and decode the results."""
    # Create dataset, tokenizer and dataloader.
    predict_ds, raw_data = self.pre_data(text)
    pred_list = []
    len_list = []
    num_of_examples = len(predict_ds)
    start_idx = 0
    while start_idx < num_of_examples:
        end_idx = min(start_idx + args.batch_size, num_of_examples)
        batch_data = [
            self.trans_func(example) for example in predict_ds[start_idx:end_idx]
        ]
        start_idx += args.batch_size
        input_ids, token_type_ids, length = self.batchify_fn(batch_data)
        # copy_from_cpu expects numpy arrays (or plain Python data), never paddle Tensors.
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(token_type_ids)
        self.predictor.run()
        logits = self.output_handle.copy_to_cpu()  # returned as a numpy array
        pred = np.argmax(logits, axis=-1)
        pred_list.append(pred)
        len_list.append(length)
    preds = self.parse_decodes(predict_ds, self.id2label, pred_list, len_list)
    result = self.post_data(preds, title, raw_data)
    return result
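The input_handles and output_handle above come from the Paddle Inference API; a minimal sketch of how they are usually created (the model file names are placeholders), which also shows why the inputs must already be numpy arrays:

import numpy as np
import paddle.inference as paddle_infer

config = paddle_infer.Config("inference.pdmodel", "inference.pdiparams")
predictor = paddle_infer.create_predictor(config)
input_handles = [
    predictor.get_input_handle(name) for name in predictor.get_input_names()
]
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])

# copy_from_cpu accepts numpy arrays (or plain Python data cast to numpy),
# not paddle.Tensor; convert with tensor.numpy() first if needed.
input_handles[0].copy_from_cpu(np.array([[1, 2, 3]], dtype="int64"))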
Copyright notice: this is an original article by qq_15821487, released under the CC 4.0 BY-SA license; please include the original source link and this notice when reposting.