Distributed TensorFlow Test Code Example

Dataset: MNIST (read locally in my setup)

Dataset link: https://pan.baidu.com/s/1o2faz60YLaba3q7hn_JWqg       Extraction code: yv3y

Put the code and the dataset in the same directory.
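Before going distributed, the local read can be checked on its own. The following is a minimal sketch, assuming the four MNIST .gz files sit in ./mnist next to the script (if they are missing, read_data_sets will try to download them instead):

from tensorflow.examples.tutorials.mnist import input_data

# Reads the local train/t10k .gz files from ./mnist
mnist = input_data.read_data_sets("./mnist", one_hot=True)
print(mnist.train.num_examples)   # 55000 for the standard MNIST training split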


Purpose: verify that CUDA and cuDNN are installed correctly on the server.

Environment: Ubuntu 16.04, Python 3.6, tensorflow-gpu 1.10, CUDA 9.0, cuDNN 7.4
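A quick standalone check with the TF 1.x API can also confirm that TensorFlow actually sees the GPU before launching the distributed job (a minimal sketch; the device names listed will differ per machine):

import tensorflow as tf
from tensorflow.python.client import device_lib

# True only when a CUDA-enabled GPU is visible and usable.
print(tf.test.is_gpu_available())
# Lists local devices, e.g. "/device:GPU:0" when the install is good.
print([d.name for d in device_lib.list_local_devices()])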


import math
import os
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Command-line flags
flags = tf.app.flags
flags.DEFINE_string("data_dir", r"./mnist", "the directory of mnist_data")
flags.DEFINE_integer("train_step", 1000, "the step of train")
flags.DEFINE_integer("batch_size", 128, "the number of batch")
flags.DEFINE_integer("image_size", 28, "the size of image")
flags.DEFINE_integer("hid_num", 100, "the size of hid layer")
flags.DEFINE_float("learning_rate", 0.01, "the learning rate")
# flags.DEFINE_string("checkpoint_dir", r"./temp/checkpoint", "the directory of checkpoint")
# flags.DEFINE_string("log_dir", r"./temp/log", "the directory of log")
flags.DEFINE_string("summary_dir", r"./temp/summary", "the directory of summary")
flags.DEFINE_integer("task_index", 0, "the index of task")
flags.DEFINE_string("job_name", "ps", "ps or worker")
flags.DEFINE_string("ps_host", "localhost:22333", "the ip and port in ps host")
flags.DEFINE_string("worker_host", "localhost:21333", "the ip and port in worker host")
flags.DEFINE_string("cuda", "", "specify gpu")
FLAGS = flags.FLAGS

# Restrict the visible GPUs if --cuda is given (e.g. "0", or "-1" to hide all GPUs).
if FLAGS.cuda:
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.cuda

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)


def main(_):
    # train_step_list = [50]
    ps_spc = FLAGS.ps_host.split(",")
    worker_spc = FLAGS.worker_host.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_spc, "worker": worker_spc})
    server = tf.train.Server(cluster, job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    # A ps node only hosts the variables and waits for the workers.
    if FLAGS.job_name == "ps":
        server.join()

    is_chief = (FLAGS.task_index == 0)

    # Variables are placed on the ps, ops on the worker.
    with tf.device(tf.train.replica_device_setter(cluster=cluster)):
        start = time.time()
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # A two-layer fully connected network for MNIST.
        hid_w = tf.Variable(
            tf.truncated_normal(shape=[FLAGS.image_size * FLAGS.image_size, FLAGS.hid_num],
                                stddev=1.0 / FLAGS.image_size),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros(shape=[FLAGS.hid_num]), name="hid_b")
        sm_w = tf.Variable(
            tf.truncated_normal(shape=[FLAGS.hid_num, 10],
                                stddev=1.0 / math.sqrt(FLAGS.hid_num)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros(shape=[10]), name="sm_b")

        x = tf.placeholder(tf.float32, [None, FLAGS.image_size * FLAGS.image_size])
        y_ = tf.placeholder(tf.float32, [None, 10])
        hid_lay = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid_act = tf.nn.relu(hid_lay)
        y = tf.nn.softmax(tf.nn.xw_plus_b(hid_act, sm_w, sm_b))
        cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-4, 1.0)))
        train_op = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
            cross_entropy, global_step=global_step)

        # Stop the session automatically once global_step reaches --train_step.
        hooks = [tf.train.StopAtStepHook(last_step=FLAGS.train_step)]
        #          tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
        #                                       save_steps=1000)]

        # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
        # sess_config = tf.ConfigProto(gpu_options=gpu_options,
        #                              log_device_placement=False,
        #                              allow_soft_placement=True)
        # sess_config.gpu_options.allow_growth = True
        sess_config = tf.ConfigProto(log_device_placement=False)

        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chief,
                                               # checkpoint_dir=FLAGS.checkpoint_dir,
                                               hooks=hooks,
                                               config=sess_config) as mon_sess:
            step = 0
            while True:
                step += 1
                batch_x, batch_y = mnist.train.next_batch(FLAGS.batch_size)
                train_feed = {x: batch_x, y_: batch_y}
                _, loss_v, g_step = mon_sess.run([train_op, cross_entropy, global_step],
                                                 feed_dict=train_feed)
                print("step: %d, cross_entropy: %f, global_step: %d" % (step, loss_v, g_step))
                if mon_sess.should_stop():
                    end = time.time()
                    # print("step_size=", last_step)
                    print("time costing:", end - start)
                    break


if __name__ == "__main__":
    tf.app.run()

The code runs one ps and one worker. The ps_host and worker_host flags defined near the top of the script both point to the local IP; if you need multi-machine distributed training, change them yourself (an example follows below).
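For example, since the script splits both flags on commas, a one-ps/two-worker cluster across machines could be launched roughly like this (the IP addresses and ports below are placeholders, not from the original post):

# on the parameter-server machine
python mnist_monite.py --job_name=ps --task_index=0 \
    --ps_host=192.168.1.10:22333 --worker_host=192.168.1.11:21333,192.168.1.12:21333

# on the two worker machines, with task_index 0 and 1 respectively
python mnist_monite.py --job_name=worker --task_index=0 \
    --ps_host=192.168.1.10:22333 --worker_host=192.168.1.11:21333,192.168.1.12:21333

Every process must be started with the same --ps_host and --worker_host values so they all build the same ClusterSpec.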

If the run fails with a gRPC error (usually because a port from a previous run is still occupied), kill the leftover Python processes and try again.
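For example, assuming the script is saved as mnist_monite.py as in the run commands below, the stale processes can be found and killed with:

ps aux | grep mnist_monite.py     # find the process IDs of old runs
kill -9 <pid>                     # replace <pid> with each listed ID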

Run the code (start the ps first; --cuda=-1 sets CUDA_VISIBLE_DEVICES=-1 so the ps process does not occupy a GPU):


python mnist_monite.py --job_name=ps --task_index=0 --cuda=-1

Then open another terminal and enter (--cuda=0 assigns GPU 0 to the worker):


python mnist_monite.py --job_name=worker --task_index=0 --cuda=0

The screenshot below shows the ps output:

[Screenshot: ps run output]

And here is the worker screenshot:

[Screenshot: worker run output]

That's it.

               


