
Loading an embedding model in TensorFlow for visualization

Date: 2019-01-04 19:41:34


1. Purpose

Train a word2vec model with Python's gensim module, then load the model in TensorFlow and visualize the embedding vectors with TensorBoard's Embedding Projector.

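PS: a w2v model trained with the C++ version cannot be read by gensim's Word2Vec.load() used below, which expects gensim's own saved format. Files in the word2vec format can only be loaded as plain vectors, with gensim's KeyedVectors.load_word2vec_format().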
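If the C++-trained model was written in the standard word2vec text format (what the original word2vec tool produces with -binary 0), its words and vectors can also be read with a few lines of Python and then fed into the same embedding-matrix and metadata.tsv steps used in section 3. The following is a minimal sketch that assumes that text layout: a header line with the vocabulary size and dimension, then one word followed by its floats per line. read_word2vec_text is a hypothetical helper, and binary files are not handled.

```python
import os
import tempfile

import numpy as np


def read_word2vec_text(path):
    """Hypothetical helper: parse the word2vec *text* format.

    Line 1: "<vocab_size> <dim>"; every following line: a word, then <dim> floats.
    Binary files (-binary 1) and malformed lines are NOT handled.
    """
    with open(path, encoding="utf-8") as f:
        vocab_size, dim = (int(x) for x in f.readline().split())
        words = []
        vectors = np.empty((vocab_size, dim), dtype=np.float32)
        for i in range(vocab_size):
            # split() also tolerates a trailing space before the newline
            parts = f.readline().split()
            words.append(parts[0])
            vectors[i] = [float(x) for x in parts[1:1 + dim]]
    return words, vectors


# Demo on a tiny hand-written file (toy data, not real model output)
demo_path = os.path.join(tempfile.mkdtemp(), "vectors.txt")
with open(demo_path, "w", encoding="utf-8") as f:
    f.write("2 3\ncat 0.1 0.2 0.3 \ndog 1 2 3 \n")
demo_words, demo_vectors = read_word2vec_text(demo_path)
```

For real use, gensim's KeyedVectors.load_word2vec_format() is the more robust choice, and it also reads the binary variant.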

2. Python code for training the word2vec model
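LineSentence expects the training file (here /tmp/train_data) to be plain text with one already-tokenized sentence per line and tokens separated by whitespace. Chinese text therefore needs word segmentation before training. The following is a minimal sketch of preparing and checking such a file. The toy corpus and both helper names are hypothetical, and read_back only approximates how LineSentence splits lines.

```python
import os
import tempfile


def write_line_sentence_file(sentences, path):
    """Hypothetical helper: one sentence per line, tokens joined by single spaces."""
    with open(path, "w", encoding="utf-8") as f:
        for tokens in sentences:
            # a token containing whitespace would be split into several tokens on reading
            if any(not t or any(c.isspace() for c in t) for t in tokens):
                raise ValueError("tokens must be non-empty and contain no whitespace")
            f.write(" ".join(tokens) + "\n")


def read_back(path):
    """Rough stand-in for iterating over LineSentence(path): whitespace-split each line.
    (The real class also decodes input and chunks very long lines.)"""
    with open(path, encoding="utf-8") as f:
        return [line.split() for line in f]


# Toy, already-tokenized corpus (hypothetical data)
corpus = [["the", "cat", "sat"], ["word2vec", "learns", "word", "vectors"]]
demo_path = os.path.join(tempfile.mkdtemp(), "train_data")
write_line_sentence_file(corpus, demo_path)
round_trip = read_back(demo_path)
```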

import multiprocessing

from gensim.models.word2vec import Word2Vec, LineSentence

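print("Start training")
train_file = "/tmp/train_data"

# LineSentence reads one sentence per line, with tokens separated by whitespace.
# size/iter are gensim 3.x names; gensim 4.x renamed them to vector_size/epochs.
model = Word2Vec(LineSentence(train_file), size=128, workers=multiprocessing.cpu_count(), iter=10)
print("Done")
# L2-normalize the vectors in place; replace=True drops the raw vectors, so training cannot continue afterwards
model.init_sims(replace=True)
# gensim's native format: load it with Word2Vec.load, not load_word2vec_format
model.save("/tmp/emb.bin")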
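init_sims(replace=True) L2-normalizes every word vector, so all the vectors saved to /tmp/emb.bin (and later shown in the projector) have length 1. The following NumPy sketch performs the same operation on made-up vectors, using a hypothetical helper name. It also shows why the normalization is convenient: on unit vectors, a plain dot product equals cosine similarity.

```python
import numpy as np


def l2_normalize_rows(vectors):
    """Scale every row to unit L2 length (the effect of init_sims on the word vectors).
    Assumes no all-zero rows, which would divide by zero."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / norms


# Hypothetical toy vectors: 3 words, 2 dimensions
raw = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
unit = l2_normalize_rows(raw)

# On unit vectors a dot product is the cosine similarity
cos_dot = float(unit[0] @ unit[1])
cos_formula = float(raw[0] @ raw[1] / (np.linalg.norm(raw[0]) * np.linalg.norm(raw[1])))
```

Here row [3, 4] becomes [0.6, 0.8], and its cosine with [1, 0] is 0.6 whichever way it is computed.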

3. Loading the model in TensorFlow for visualization

import numpy as np
import tensorflow as tf
import os
from gensim.models.word2vec import Word2Vec
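# tensorflow.contrib exists only in TensorFlow 1.x (it was removed in 2.0), so this script needs TF 1.x
from tensorflow.contrib.tensorboard.plugins import projector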

log_dir = "/tmp/embedding_log"
if not os.path.exists(log_dir):
    os.mkdir(log_dir)


# load model
model_file = "/tmp/emb.bin"
word2vec = Word2Vec.load(model_file)

# create a matrix of vectors; row i belongs to the i-th vocabulary word
# (gensim 3.x API; gensim 4.x replaced wv.vocab with wv.index_to_key)
embedding = np.empty((len(word2vec.wv.vocab), word2vec.vector_size), dtype=np.float32)
for i, word in enumerate(word2vec.wv.vocab.keys()):
    embedding[i] = word2vec.wv[word]

# setup a TensorFlow session
tf.reset_default_graph()
sess = tf.InteractiveSession()
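# start with a dummy variable, then assign the real matrix through a placeholder so the
# (possibly large) array is fed at run time instead of being stored in the graph as a constant
X = tf.Variable([0.0], name="embedding")
place = tf.placeholder(tf.float32, shape=embedding.shape)
set_x = tf.assign(X, place, validate_shape=False)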
sess.run(tf.global_variables_initializer())
sess.run(set_x, feed_dict={place: embedding})

# write labels
with open(os.path.join(log_dir, "metadata.tsv"), "w", encoding="utf-8") as f:
    for word in word2vec.wv.vocab.keys():  # same order as the rows of the embedding matrix
        f.write(word + "\n")

# create a TensorFlow summary writer
summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
config = projector.ProjectorConfig()
embedding_conf = config.embeddings.add()
embedding_conf.tensor_name = "embedding:0"
embedding_conf.metadata_path = os.path.join(log_dir, "metadata.tsv")
projector.visualize_embeddings(summary_writer, config)

# save the model
saver = tf.train.Saver()
saver.save(sess, os.path.join(log_dir, "model.ckpt"))

print("Finished!")
# to view: tensorboard --logdir /tmp/embedding_log, then open the Projector tab
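The Projector matches labels to vectors purely by position: line i of metadata.tsv labels row i of the embedding tensor. This is why the script iterates over the same vocabulary dict for both outputs. The Projector's documented metadata convention is that a single-column file has no header row, while a multi-column file uses its first row for column names. The following is a minimal NumPy sketch of that invariant. It assumes the vectors are available as a plain {word: vector} dict (a hypothetical input format) and uses a hypothetical helper name.

```python
import os
import tempfile

import numpy as np


def write_projector_inputs(word_vectors, log_dir):
    """Hypothetical helper: build the embedding matrix and metadata.tsv from a
    {word: vector} mapping so that row i and line i describe the same word.

    Single-column metadata is written without a header row.
    """
    words = list(word_vectors)  # fix the order once and reuse it for both outputs
    for w in words:
        if "\t" in w or "\n" in w:
            raise ValueError(f"word {w!r} would break the TSV layout")
    matrix = np.stack([np.asarray(word_vectors[w], dtype=np.float32) for w in words])
    with open(os.path.join(log_dir, "metadata.tsv"), "w", encoding="utf-8") as f:
        for w in words:
            f.write(w + "\n")
    return matrix, words


# Toy vectors (hypothetical data)
demo_dir = tempfile.mkdtemp()
demo_matrix, demo_words = write_projector_inputs({"cat": [1.0, 0.0], "dog": [0.0, 1.0]}, demo_dir)
```

Fixing the word order in one list, instead of iterating the source twice, keeps the two outputs aligned even if the source mapping were rebuilt between the two loops.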

 

Original article: https://www.cnblogs.com/aijianiula/p/10221970.html
