码迷,mamicode.com
首页 > 其他好文 > 详细

【增强学习】Recurrent Visual Attention源码解读

时间:2016-06-21 07:40:30      阅读:1012      评论:0      收藏:0      [点我收藏+]

标签:

Mnih, Volodymyr, Nicolas Heess, and Alex Graves. “Recurrent models of visual attention.” Advances in Neural Information Processing Systems. 2014.

这里下载训练代码,戳这里下载测试代码。

这篇文章处理的任务非常简单:MNIST手写数字分类。但使用了聚焦机制(Visual Attention),不是一次看一张大图进行估计,而是分多次观察小部分图像,根据每次查看结果移动观察位置,最后估计结果

Yoshua Bengio的高徒,先后供职于LISA和Element Research的Nicolas Leonard用Torch实现了这篇文章的算法。Torch官方cheetsheet的demo中,就包含这篇源码,作者自己的讲解也刊登在Torch的博客中,足见其重要性。

通过这篇源码,我们可以
- 理解聚焦机制中较简单的hard attention
- 了解增强学习的基本流程
- 复习Torch和扩展包dp的相关语法

本文解读训练源码,分三大部分:参数设置,网络构造,训练设置。以下逐次介绍其中重要的语句。

参数设置

除了Torch之外,还需要包含Nicholas Leonard自己编写的两个包。dp:能够简化DL流程,训练过程更“面向对象”;rnn:实现Recurrent网络。

require ‘dp‘
require ‘rnn‘

首先使用Torch的CmdLine类设定一系列参数,存储在opt中。这是Torch的标准写法。

cmd = torch.CmdLine()
cmd:option(‘--learningRate‘, 0.01, ‘learning rate at t=0‘)    -- 参数名,参数值,说明
local opt = cmd:parse(arg or {})    --把cmd中的参数传入opt

把数据载入到数据集ds中,数据是dp包中已经下载好的:

ds = dp[opt.dataset]()

网络构造

这篇源码中模型的写法遵循:由底到顶,先细节后整体。和CNN不同,Recurrent网络带有反馈,呈现较为复杂的多级嵌套结构。请着重关注每个模块的输入输出作用部分。

Glimpse网络

输入:图像I和观察位置l
输出:观察结果x

蓝色输入,橙色输出,菱形表示串接:
技术分享

首先用locationSensor(左半)提取位置信息l中的特征:

locationSensor:add(nn.SelectTable(2))    --选择两个输入中的第二个,位置l
locationSensor:add(nn.Linear(2, opt.locatorHiddenSize))    --Torch中的Linear指全连层
locationSensor:add(nn[opt.transfer]())    --opt.transfer定义一种非线性运算,本文中是ReLU

之后用glimpseSensor(右半)提取图像I位置l的特征。
其中SpacialGlimpse是dp中定义的层,提取尺寸为PatchSize的Depth层图像,相邻层比例为Scale。

glimpseSensor:add(nn.SpatialGlimpse(opt.glimpsePatchSize, opt.glimpseDepth, opt.glimpseScale):float())    --SpatialGlimpse提取小块金字塔
glimpseSensor:add(nn.Collapse(3))    --压缩第三维
glimpseSensor:add(nn.Linear(ds:imageSize(‘c‘)*(opt.glimpsePatchSize^2)*opt.glimpseDepth, opt.glimpseHiddenSize))
glimpseSensor:add(nn[opt.transfer]())

两者结果串接为glimpse,输出包含位置和纹理信息的x,尺寸为hiddenSize:

glimpse:add(nn.ConcatTable():add(locationSensor):add(glimpseSensor))
glimpse:add(nn.JoinTable(1,1))    --把串接数据合并成一个Tensor
glimpse:add(nn.Linear(opt.glimpseHiddenSize+opt.locatorHiddenSize, opt.imageHiddenSize))
glimpse:add(nn[opt.transfer]())
glimpse:add(nn.Linear(opt.imageHiddenSize, opt.hiddenSize))    --从imageHiddenSize到hiddenSize的全连层

作用:通过小范围观测,提取纹理和位置信息。

说明
Torch的基础数据是Tensor,而lua中用Table实现类似数组的功能。nn库中专门有一系列Table层,用于处理涉及这两者的运算。例如:
ConcatTable - 把若干个输出Tensor放置在一个Table中。
SelectTable - 从输入的Table中选择一个Tensor。
JoinTable - 把输入Table中的所有Tensor合并成一个Tensor。

Recurrent网络

输入:和Glimpse网络相同,图像I,观察位置l
输出:系统循环状态r
技术分享

使用Recurrent类创建一个包含Glimpse子网络的rnn框架。Recurrent类的第二个参数(glimpse)指出如何处理输入,第三个参数(recurrent)指出如何处理前一时刻的循环状态。

recurrent = nn.Linear(opt.hiddenSize, opt.hiddenSize)
rnn = nn.Recurrent(opt.hiddenSize, glimpse, recurrent, nn[opt.transfer](), 99999)

作用:通过小范围观测,更新网络循环状态。

Locator网络

输入:系统循环状态r,也就是Recurrent网络的输出
输出:观测位置l
技术分享

这部分核心是dp库中的ReinforceNormal层:正态分布的强化学习层。dp库中还有其他分布的强化学习层。

locator:add(nn.Linear(opt.hiddenSize, 2))
locator:add(nn.HardTanh()) -- bounds mean between -1 and 1
locator:add(nn.ReinforceNormal(2*opt.locatorStd, opt.stochastic)) -- sample from normal, uses REINFORCE learning rule
locator:add(nn.HardTanh()) -- bounds sample between -1 and 1
locator:add(nn.MulConstant(opt.unitPixels*2/ds:imageSize("h")))    --对位置l做了归一化:相对图像中心的最大偏移为unitPixel。

ReinforceNormal层在训练状态下,会以前一层输入为均值,以第一个参数(2*opt.locatorStd)为方差,产生符合高斯分布采样结果;
在训练状态下,如果第二个参数(opt.stochastic)为真,则以相同方式采样,否则直接传递前一层结果。

简单来说,Reinforce层的作用是:在训练时,围绕当前策略(前层输出),探索一些新策略(高斯采样)。具体怎么训练在下篇再说。

作用:利用系统循环状态,决定观测位置。

Attention网络

输入:图像I
输出:系统循环状态r
技术分享

直接使用rnn包中的RecurrentAttention层进行定义。第一个参数(rnn)指明如何处理循环状态r的记忆,第二个参数(locator)指明利用循环状态执行何种动作(action)。

attention = nn.RecurrentAttention(rnn, locator, opt.rho, {opt.hiddenSize})

作用:利用图像更新系统循环状态。

Agent网络

输入:图像I
输出:字符属于各类的概率向量p
技术分享

在前面attention网络的基础上,只对系统循环变量做简单非线性变换,即得到图像属于各类字符的概率p

agent:add(attention)
agent:add(nn.SelectTable(-1))
agent:add(nn.Linear(opt.hiddenSize, #ds:classes()))
agent:add(nn.LogSoftMax())    -- 这里输出分类结果

由于系统中存在强化学习层ReinforceNormal,所以需要一个baseline变量b。这里利用ConcatTableb和分类结果合并到一个Table里输出。

seq:add(nn.Constant(1,1))
seq:add(nn.Add(1))
concat = nn.ConcatTable():add(nn.Identity()):add(seq)
concat2 = nn.ConcatTable():add(nn.Identity()):add(concat)
agent:add(concat2)

整个系有两组输出:分类结果p,以及分类结果+baseline对{p,b}

作用:把系统隐变量转化成估计结果,并且输出一个baseline,便于后续优化。

训练设置

在dp库中,训练过程是分层定义的,为了说明清晰,倒序讲解。
首先(在代码里是最后),定义实验xp,使用的模型就是前述网络agent

xp = dp.Experiment{
   model = agent,       -- nn.Sequential, 待优化模型
   optimizer = train,   -- dp.Optimizer,训练
   validator = valid,   -- dp.Evaluator,验证
   tester = tester,     -- dp.Evaluator,测试
   observer = {         -- 设定log
      ad,
      dp.FileLogger(),
      dp.EarlyStopper{
         max_epochs = opt.maxTries,
         error_report={‘validator‘,‘feedback‘,‘confusion‘,‘accuracy‘},
         maximize = true
      }
   },
   random_seed = os.time(),
   max_epoch = opt.maxEpoch   -- 最大迭代次数
}

训练

train是一个dp.Optimizer类型对象,这个类继承自抽象类dp.propogator,需要指明6个参数:

train = dp.Optimizer{
    loss=..., epoch_callback=..., callback = ..., feedback - ...,sampler = ..., progress = ...
}

loss定义了损失层。用ParallelCriterion把监督学习的ClassNLLCriterion和增强学习的VRClassReward并列优化。

loss = nn.ParallelCriterion(true)
    :add(nn.ModuleCriterion(nn.ClassNLLCriterion(), nil,nn.Convert())) --  监督学习:negative log-likelihood
    :add(nn.ModuleCriterion(nn.VRClassReward(agent, opt.rewardScale), nil, nn.Convert())) -- 增强学习:得分最高类与标定相同反馈1,否则反馈-1

epoch_callback函数设定每个epoch结束时执行的动作,一般用来调整opt中的学习率。

epoch_callback = function(model, report) -- called every epoch
  if report.epoch > 0 then
     opt.learningRate = opt.learningRate + opt.decayFactor
     opt.learningRate = math.max(opt.minLR, opt.learningRate)
     if not opt.silent then
        print("learningRate", opt.learningRate)
     end
  end
end

callback是核心函数,更新模型参数:

callback = function(model, report)
    if opt.cutoffNorm > 0 then
        local norm = model:gradParamClip(opt.cutoffNorm) -- dpnn扩展,约束梯度,有益于RNN
        opt.meanNorm = opt.meanNorm and (opt.meanNorm*0.9 + norm*0.1) or norm;
        if opt.lastEpoch < report.epoch and not opt.silent then
            print("mean gradParam norm", opt.meanNorm)
        end
    end
    model:updateGradParameters(opt.momentum) -- dpnn扩展,根据momentum更新梯度
    model:updateParameters(opt.learningRate) -- 根据学习率更新参数
    model:maxParamNorm(opt.maxOutNorm) -- dpnn扩展,约束参数范围
    model:zeroGradParameters() -- 梯度置零
end

feedback提供I/O用来生成报告,这里输出分类结果与真值比较的confusion matrix。回忆一下:网络的输出是{p,{p,b}},所以真正的输出用SelectTable(1)获得。

feedback = dp.Confusion{output_module=nn.SelectTable(1)}

sampler决定如何从训练集中采样:设定epoch和batch大小。

sampler = dp.ShuffleSampler{
    epoch_size = opt.trainEpochSize, batch_size = opt.batchSize
   }

progress是个布尔型,控制是否显示进度条。

progress = opt.progress

验证与测试

valid是一个dp.Evaluator类成员变量,同样继承自dp.propogator。只需要指明feedbacksamplerprogress这三个参数即可。

valid = dp.Evaluator{
   feedback = dp.Confusion{output_module=nn.SelectTable(1)},
   sampler = dp.Sampler{epoch_size = opt.validEpochSize, batch_size = opt.batchSize},
   progress = opt.progress
}

testvalid类似,连进度条都不用打了

tester = dp.Evaluator{
  feedback = dp.Confusion{output_module=nn.SelectTable(1)},
  sampler = dp.Sampler{batch_size = opt.batchSize}
}

执行

在这一步,把已经读取好的数据集ds输入到实验xp中去:

xp:run(ds)

【增强学习】Recurrent Visual Attention源码解读

标签:

原文地址:http://blog.csdn.net/shenxiaolu1984/article/details/51582185

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!