# 深度学习 | 实战5-用slim 定义Lenet网络，并训练测试

Posted by JoselynZhao on July 11, 2019

Github源码

## 要求

Slim Lenet

1. 将Lenet 单独定义到Lenet.py 文件

   函数签名：`def lenet(images):`

2. 用 `with slim.arg_scope(...):` 管理 lenet 中所有操作的默认参数，例如 `activation_fn`、`weights_initializer` 等。

3. 编写 `mnist_train.py` 脚本，训练用 slim 定义的 lenet 完成 MNIST 手写数字分类。

## 代码展示

### LENET

def lenet(image):
    """Build a LeNet-5 style classifier with TF-Slim and return its logits.

    Every ``slim.conv2d`` / ``slim.fully_connected`` layer inherits the defaults
    set by the ``slim.arg_scope`` below: ReLU activation, truncated-normal
    weight initialisation and an L2 weight regulariser.

    Args:
        image: float image tensor in NHWC layout. The training script feeds
            zero-padded MNIST digits of shape ``[None, 32, 32, 1]``; the shape
            comments below assume that input size.

    Returns:
        Logits of shape ``[batch, 10]``. The final layer has no activation, so
        the result is meant to be fed to a softmax cross-entropy loss.
    """
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        # Truncated normal with mean 0.0, stddev 0.1.
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.1),
                        # The regulariser only records per-layer L2 penalties in
                        # the regularization-loss collection; they influence
                        # training only if the caller adds them to the objective.
                        weights_regularizer=slim.l2_regularizer(0.1)):
        # 32x32x1 -> 28x28x6
        net = slim.conv2d(image, 6, [5, 5], stride=1, padding="VALID", scope="conv1")
        # 28x28x6 -> 14x14x6
        net = slim.max_pool2d(net, [2, 2], stride=2, padding="VALID", scope="pool1")
        # 14x14x6 -> 10x10x16
        net = slim.conv2d(net, 16, [5, 5], stride=1, padding="VALID", scope="conv2")
        # 10x10x16 -> 5x5x16. LeNet-5 subsamples again after the second conv
        # layer; without this pool the flattened vector is 1600 wide, not 400.
        net = slim.max_pool2d(net, [2, 2], stride=2, padding="VALID", scope="pool2")
        # 5x5x16 -> 400
        net = slim.flatten(net, scope="flatten")
        net = slim.fully_connected(net, 120, scope="fc1")
        net = slim.fully_connected(net, 84, scope="fc2")
        # No activation: raw logits for the 10 digit classes.
        net = slim.fully_connected(net, 10, activation_fn=None, scope="fc3")
    return net



### mnist_train

if __name__ == "__main__":
    # NOTE(review): `mnist` and `lenet` are not defined in this snippet. This
    # assumes they are provided at module level (MNIST loaded with one-hot
    # labels, since `y_` is [None, 10] and compared via argmax; `lenet`
    # imported from Lenet.py) — confirm against the full script.

    def _pad_to_32(images):
        """Reshape flat 28x28 MNIST rows to NHWC and zero-pad them to 32x32."""
        images = np.reshape(images, [-1, 28, 28, 1])
        return np.pad(images, ((0, 0), (2, 2), (2, 2), (0, 0)), "constant")

    x_test = _pad_to_32(mnist.test.images)
    y_test = mnist.test.labels

    iterations = 10000
    batch_size = 64
    lr = 0.1
    eval_every = 500  # evaluate on the test set every this many steps

    x = tf.placeholder(tf.float32, [None, 32, 32, 1])
    y_ = tf.placeholder(tf.float32, [None, 10])
    y = lenet(x)
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))

    # Plain SGD on the cross-entropy using the learning rate declared above.
    # NOTE: the L2 penalties collected by lenet's weights_regularizer are not
    # included in this objective.
    train_step = tf.train.GradientDescentOptimizer(lr).minimize(cross_entropy)

    # Accuracy: fraction of samples whose arg-max logit matches the label.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    init = tf.global_variables_initializer()

    test_feed = {x: x_test, y_: y_test}

    # The context manager releases the session's resources when training ends.
    with tf.Session() as sess:
        sess.run(init)
        for i in range(iterations):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,
                     feed_dict={x: _pad_to_32(batch_xs), y_: batch_ys})
            if i % eval_every == 0:
                acc = sess.run(accuracy, feed_dict=test_feed)
                print("%5d: accuracy is: %.4f" % (i, acc))

        print('[accuracy,loss]:',
              sess.run([accuracy, cross_entropy], feed_dict=test_feed))