码迷,mamicode.com
首页 > 其他好文 > 详细

干货|10分钟入门PyTorch(2)~附源码

时间:2020-11-27 10:54:43      阅读:5      评论:0      收藏:0      [点我收藏+]

标签:类别   for   阅读   pre   就是   cto   随机梯度   函数   rect   

10分钟入门PyTorch(2)

上一节介绍了简单的线性回归10分钟快速入门PyTorch(1),如何在pytorch里面用最小二乘来拟合一些离散的点,这一节我们将开始简单的logistic回归,介绍图像分类问题,使用的数据是手写字体数据集MNIST。

1

logistic回归

logistic回归简单来说和线性回归是一样的,要做的运算同样是 y = w * x + b。
logistic回归简单的是做二分类问题,使用sigmoid函数将所有的正数和负数都变成0-1之间的数,这样就可以用这个数来确定到底属于哪一类,可以简单的认为概率大于0.5即为第二类,小于0.5为第一类。
技术图片

这就是sigmoid的图形
技术图片

而我们这里要做的是多分类问题,对于每一个数据,我们输出的维数是分类的总数,比如10分类,我们输出的就是一个10维的向量,然后我们使用另外一个激活函数,softmax
技术图片
这就是softmax函数作用的机制,其实简单的理解就是确定这10个数每个数对应的概率有多大,因为这10个数有正有负,所以通过指数函数将他们全部变成正数,然后求和,然后这10个数每个数都除以这个和,这样就得到了每个类别的概率。

data

首先导入torch里面专门做图形处理的一个库,torchvision,根据官方安装指南,你在安装pytorch的时候torchvision也会安装。
我们需要使用的是torchvision.transforms和torchvision.datasets以及torch.utils.data.DataLoader

首先DataLoader是导入图片的操作,里面有一些参数,比如batch_size和shuffle等,默认load进去的图片类型是PIL.Image.open的类型,如果你不知道PIL,简单来说就是一种读取图片的库

torchvision.transforms里面的操作是对导入的图片做处理,比如可以随机取(50, 50)这样的窗框大小,或者随机翻转,或者去中间的(50, 50)的窗框大小部分等等,但是里面必须要用的是transforms.ToTensor(),这可以将PIL的图片类型转换成tensor,这样pytorch才可以对其做处理

torchvision.datasets里面有很多数据类型,里面有官网处理好的数据,比如我们要使用的MNIST数据集,可以通过torchvision.datasets.MNIST()来得到,还有一个常使用的是torchvision.datasets.ImageFolder(),这个可以让我们按文件夹来取图片,和keras里面的flow_from_directory()类似,具体的可以去看看官方文档的介绍。
技术图片

以上就是我们对图片数据的读取操作

model

之前讲过模型定义的框架,废话不多说,直接上代码
技术图片
我们需要向这个模型传入参数,第一个参数定义为数据的维度,第二维数是我们分类的数目。

接着我们可以在gpu上跑模型,怎么做呢?
首先可以判断一下你是否能在gpu上跑

技术图片
如果返回True就说明有gpu支持
接着你只需要一个简单的命令就可以了

技术图片

或者

技术图片

都可以
然后需要定义loss和optimizer

技术图片

这里我们使用的loss是交叉熵,是一种处理分类问题的loss,optimizer我们还是使用随机梯度下降

train

接着就可以开始训练了

技术图片
技术图片

注意我们如果将模型放到了gpu上,相应的我们的Variable也要放到gpu上,也很简单

技术图片

然后可以测试模型,过程与训练类似,只是注意要将模型改成测试模式

技术图片

这是跑完100 epoch的结果

技术图片

具体的结果多久打印一次,如何打印可以自己在for循环里面去设计

这一部分我们就讲解了如何用logistic回归去做一个简单的图片分类问题,知道了如何在gpu上跑模型,下一节我们将介绍如何写简单的卷积神经网络,不了解卷积网络的同学可以先去我的专栏看看之前卷积网络的介绍。

本文代码已经上传到了github上
欢迎查看我的知乎专栏,深度炼丹
欢迎访问我的博客

推荐阅读文章:

10分钟快速入门PyTorch(1)
10分钟入门pytorch(0)
隐马尔科夫模型-前向算法

全是通俗易懂的硬货!只需置顶~欢迎关注交流~

技术图片

干货|10分钟入门PyTorch(2)~附源码

标签:类别   for   阅读   pre   就是   cto   随机梯度   函数   rect   

原文地址:https://blog.51cto.com/15009309/2553589

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!