为什么要选择mxnet呢,主要是被tensorflow弄烦了!不是说tensorflow很难上手,就是来来回回折腾了好多次,换个口味。其次,仰慕李沐大神,国人小骄傲。再其次,接触了一点mxnet之后确实比tensorflow友好些,以后做实验可能mxnet更方便的一点。当然,前期tensorflow踩过的坑让我上手mxnet快多了。
做简单的实验
用mxnet的NDARRAY ,AUTOGRAD就足够了
NDARRAY,AUTOGRAD的使用自己看
官方教程,中文文档实在写的很厉害,再次崇拜李沐大神。而我只是记录自己学习使用的一些心得和体会
还是想简单说下自动求导
- 关于求变量x的梯度,先用attach_grad函数来申请存储梯度需要的内存
- record函数记录有关变量x的计算
- backward函数对record中进行求梯度,backward的该为标量,如果非标量,可以求和转为标量
深度学习最好的入门是线性回归
构造数据集
1 | n_examples = 1000 |
构造训练的迭代器,指定batchsize
1 | import random |
定参数,并初始化
1 | w = nd.random_normal(shape=[n_inputs,1]) |
网络计算
1 | def net(x): |
定义损失函数
1 | def squre_loss(yhat,y): |
优化器,如何更新参数,这里选梯度下降方法
1 | def SGD(params,lr): |
开始训练,设置迭代次数
1 | epoch = 5 |
真诚的附上运行截图
给出gluon版本
gluon给出了 构造数据迭代器的api
给出了搭建层的Sequential容器,串联各个层,给定输入,容器中的每一层输出将作为下一层的输入,里面的Dense相当于全连接层
不需要自己定参数,只需给输出维度
给出loss方法
给出迭代param的trainer,但是trainer.step传了一个batchsize,使得太小的学习率无法快速收敛,与上一个版本不同。
训练部分几乎一下,整体思路就是这些
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36 from mxnet import ndarray as nd
from mxnet import autograd as ag
#构造数据集
n_examples = 1000
n_inputs = 2
x = nd.random_normal(shape=[n_examples,n_inputs])
true_w = [2,-3.4]
true_b = 4.2
y = x[:,0]*true_w[0]+x[:,1]*true_w[1]+true_b
#给出噪音
y+=.01*nd.random_normal(shape=y.shape)
#用gluon的好处不用考虑参数,对复杂网络特别友好
from mxnet import gluon
#构造迭代器
batch_size = 10
dataset = gluon.data.ArrayDataset(x,y)
data_iter = gluon.data.DataLoader(dataset,batch_size = batch_size,shuffle=True)
#网络用sequential容器,只需指定输出的维度,初始网络中的参数
net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(1))
net.initialize()
squre_loss = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1})
#开始训练
epoch = 5
dense = net[0]
for e in range(epoch):
total_loss = 0
for data,label in data_iter:
with ag.record():
yhat = net(data)
loss = squre_loss(yhat,label)
loss.backward()
trainer.step(batch_size)
total_loss += nd.sum(loss).asscalar()
print('After %d epoch,loss id %.10f params is'%(e+1,total_loss),dense.weight.data(),dense.bias.data())
总结
1.构造数据集(迭代器)
2.模型的构造
3.定义损失函数
4.定义优化器
5.迭代训练,更新参数