标签:sig iter pre backward net int 建立 ogr back
-- Minimal Torch7 `nn` example: repeatedly feeds a constant 10-element
-- observation through a small MLP and takes one gradient step per iteration
-- towards a fixed scalar target of 5 under an MSE loss.
require('nn')

-- Builds the network used below:
-- Reshape(10) -> Linear(10, 32) -> Sigmoid -> Linear(32, 1).
-- Returns the nn.Sequential container.
local function createQNetwork()
   local mlp = nn.Sequential()
   mlp:add(nn.Reshape(10))
   mlp:add(nn.Linear(10, 32))
   mlp:add(nn.Sigmoid())
   mlp:add(nn.Linear(32, 1))
   return mlp
end

-- Converts a Lua array of numbers into a tensor viewed as (#obs_table x 1).
--   obs_table: array-style table of numbers.
-- Returns the resulting tensor.
local function qinput(obs_table)
   -- torch.Tensor(table) copies the array contents directly, replacing the
   -- manual allocate-and-fill loop.
   return torch.Tensor(obs_table):view(#obs_table, -1)
end

-- NOTE: the names below are intentionally left global, exactly as in the
-- original script, so the script's top-level side effects are unchanged.
qnn = createQNetwork()

-- Loop-invariant objects: the regression target and the loss criterion do not
-- change between iterations, so they are created once instead of 30 times.
bs = torch.Tensor(1):fill(5)
cri = nn.MSECriterion()

for i = 1, 30 do
   -- Iterations 1..9 use an observation of all 2s; iterations 10..30 use all 3s.
   local fill_value = (i < 10) and 2 or 3
   obs_table = {}
   for j = 1, 10 do
      obs_table[j] = fill_value
   end

   obs = qinput(obs_table)
   print(obs)

   -- Forward pass through the network.
   q = qnn:forward(obs)
   print("q", q)

   -- Loss and its gradient with respect to the prediction.
   qloss = cri:forward(q, bs)
   dloss_dpredict = cri:backward(q, bs)

   -- Clear accumulated gradients, backpropagate, then take one gradient step
   -- with learning rate 0.3.
   qnn:zeroGradParameters()
   qgradInput = qnn:backward(obs, dloss_dpredict)
   qnn:updateParameters(0.3)
end
标签:sig iter pre backward net int 建立 ogr back
原文地址:http://www.cnblogs.com/WegZumHimmel/p/7753833.html