The core of this algorithm is gensim.models.doc2vec: the article information is organized into a dictionary for modeling, a Doc2Vec model is trained on the corpus, and the parsed file contents are also written to the database so that other code can use them. The comments are fairly clear, so I won't repeat the logic here; just read the code. The database access class was already covered in an earlier algorithm post, so it is not pasted again here; readers who need it can refer back to that article. The algorithm code is as follows:
import tensorflow as tf
import os
import gensim
import re
import jieba.posseg as pseg
from gensim.models.doc2vec import Doc2Vec
from loadData import loadData
tf.flags.DEFINE_string("base_dir", ".", "files base_dir")
tf.flags.DEFINE_string("train_dir", ".\\train", "trainning files base_dir")
tf.flags.DEFINE_string("test_dir", ".\\test", "test files base_dir")
tf.flags.DEFINE_string("model_dir", "./doc2vecmodel", "Model directory from training run")
tf.flags.DEFINE_integer('vector_dim', 500, 'dimensionality of the document vectors')
tf.flags.DEFINE_integer('epoch_num', 70, 'number of training epochs')
tf.flags.DEFINE_integer('min_count', 1, 'ignore words whose frequency is lower than min_count')
tf.flags.DEFINE_integer('window', 3, 'max distance between the predicted word and context words')
tf.flags.DEFINE_integer('negative', 5, 'number of negative samples')
tf.flags.DEFINE_integer('workers', 4, 'number of worker threads')
FLAGS = tf.flags.FLAGS
FLAGS.is_parsed()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
print("{}={}".format(attr.upper(), value))
print("")
class Singleton(object):
def __new__(cls, *args, **kw):
if not hasattr(cls, '_instance'):
orig = super(Singleton, cls)
cls._instance = orig.__new__(cls, *args, **kw)
return cls._instance
class retrieve(Singleton):
    doc_dict = {}  # maps the database id (key) to the corresponding article ID (value)
    model_dm = None  # the trained Doc2Vec model
def __init__(self):
self.load_doc_index()
    # Step 1: before training, organize the corpus into the required form using TaggedDocument
def get_trainset(self):
x_train = []
        list_name = os.listdir(FLAGS.train_dir)  # the corpus used for training is preprocessed first
        TaggededDocument = gensim.models.doc2vec.TaggedDocument  # each item is a word bag plus a tag list; it captures the gist of an article and assigns it an index
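        # e.g. (illustrative) TaggededDocument(words=['软件', '开发', '咨询'], tags=[3])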
load = loadData()
for name in list_name:
user_file = os.path.join(FLAGS.train_dir, name)
            # corpus preprocessing
            if not os.path.isdir(user_file):
                with open(user_file, mode='rb') as f:
                    data = f.read()
                item = self.getInfoDetail(data.decode('utf-8'))
                doc_id = os.path.splitext(name)[0]
                index = self.getIndexFromDoct(doc_id)  # if the article is already in the dictionary it is assumed to be in the database, so the insert is skipped
                if index == -1:
                    index = load.insertInfo(item)  # each article needs a corresponding index (the database id)
                    self.doc_dict[index] = doc_id
line = ''
if '公司名称' in item:
line = line + item['公司名称']
if '经营范围' in item:
                    line = line + '\t' + item['经营范围']
if line == '':
line = data.decode('utf-8')
words = self.seperate_line(self.clean_str(line))
x_train.append(TaggededDocument(words, tags=[index]))
return x_train
    # Step 2: initialize the training parameters, train the model, then save the result to free memory
def train(self, x_train, size=500, epoch_num=1):
        self.model_dm = gensim.models.Doc2Vec(x_train, min_count=FLAGS.min_count, window=FLAGS.window, size=size, sample=1e-3, negative=FLAGS.negative, workers=FLAGS.workers)  # initialize the model with its parameters
        # passing x_train builds the vocabulary; min_count ignores words whose total frequency is below this value; window is the max distance between the predicted word and its context words; size is the dimensionality of the feature vectors; negative is the number of negative samples; workers is the number of worker threads
        self.model_dm.train(x_train, total_examples=self.model_dm.corpus_count, epochs=epoch_num)  # corpus_count is the number of documents, epochs the number of training passes
        self.model_dm.save(FLAGS.model_dir)  # save the trained model to disk to free memory; it can be reloaded later with load
return self.model_dm
    # Step 3: use the trained model to find the articles most similar to a given piece of text
def getMatchInfos(self, text):
matchInfos = []
load = loadData()
        self.load_doc_index()  # load index_file
        self.model_dm = Doc2Vec.load(FLAGS.model_dir)  # load the trained model; printing model_dm gives something like Doc2Vec(dm/m,d500,n5,w3,s0.001,t4)
test_text = self.seperate_line(self.clean_str(text))
inferred_vector_dm = self.model_dm.infer_vector(test_text)
sims = self.model_dm.docvecs.most_similar([inferred_vector_dm], topn=5)
for index, sim in sims:
print(self.doc_dict[index])
print(sim)
# doc = x_train[int(index)]
            # doc = doc[0]  # doc contains both the word bag and the tag; only the word bag is needed here
# for word in doc:
# print(word)
            doc = load.getMatchInfo(index)  # read the matching article from the database
matchInfos.append(doc)
return matchInfos
    # Step 4: write the dictionary to a file for easy inspection; the next time the program starts it can be loaded with load_doc_index instead of being rebuilt
def save_doc_index(self):
index_file = os.path.join(FLAGS.base_dir, "index_file.txt")
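        # each line of index_file.txt has the form "<database id> <article ID>"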
lines = ""
for index in self.doc_dict:
lines += str(index) + ' ' + self.doc_dict[index] + '\n'
        with open(index_file, 'w', encoding='utf-8') as f:
            f.write(lines)
def load_doc_index(self):
self.doc_dict = {}
index_file = os.path.join(FLAGS.base_dir, "index_file.txt")
if os.path.exists(index_file):
            with open(index_file, encoding='utf-8') as f:
                lines = f.readlines()  # read the whole file into memory, then close it right away
for line in lines:
line = line.strip()
tokens = line.split(" ")
self.doc_dict[int(tokens[0])] = tokens[1]
return self.doc_dict
def setModel(self):
        x_train = self.get_trainset()  # get the preprocessed corpus
        self.save_doc_index()  # persist index_file
        self.model_dm = self.train(x_train, epoch_num=FLAGS.epoch_num)  # train the model; this step can be skipped if a model has already been trained
return("set model success!")
def clean_str(self, string):
        string = re.sub(r'\s+', "", string)
r1 = u'[A-Za-z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
string = re.sub(r1, ' ', string)
return string.strip()
def seperate_line(self, line):
line = pseg.cut(line)
new_line = []
        for words, flag in line:
            if flag == 'nr' or flag == 'ns':
                continue  # drop person and place names
            if len(flag) == 0:
                continue
            if flag[0:1] != 'n' and flag != 'v':
                continue  # keep only nouns and verbs
            new_line.append(words)
#return ''.join([word + " " for word in new_line])
return new_line
def getInfoDetail(self, text):
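        # Parse one raw article file into a dict. Illustrative input format, as implied by the splits below:
        #   "公司名称:Example Co.$#&#$经营范围:software development$#&#$..."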
new_text = {}
t_arr = text.split("$#&#$")
for item in t_arr:
i_arr = item.split(":",1)
if len(i_arr) > 1:
new_text[i_arr[0]] = i_arr[1]
return new_text
def getIndexFromDoct(self, name):
if len(self.doc_dict) > 0:
for key, value in self.doc_dict.items():
if value == name:
return key
return -1
if __name__ == '__main__':
    # this block only runs when the file is executed directly as a script; it is skipped when the module is imported
    r = retrieve()
    x_train = r.get_trainset()  # get the preprocessed corpus (this also fills doc_dict)
    r.save_doc_index()  # persist index_file
    doc_dict = r.load_doc_index()  # load index_file
    model_dm = r.train(x_train, epoch_num=FLAGS.epoch_num)  # train the model; skip this step if a model has already been trained
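For reference, here is a minimal retrieval sketch built on the retrieve class above. It assumes the train/ corpus and the loadData database layer from the earlier post are in place; the query string is only an illustrative placeholder:
r = retrieve()
r.setModel()  # one-off step: preprocess the corpus, write index_file.txt, train and save the Doc2Vec model
matches = r.getMatchInfos("软件开发 技术服务 信息咨询")  # illustrative query text
for info in matches:
    print(info)  # each item is a record returned by loadData.getMatchInfo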
With that, basic information retrieval is done. If you need more complex functionality, you can build on and extend this code, and you are welcome to get in touch to exchange ideas so we can improve together~