码迷,mamicode.com
首页 > 其他好文 > 详细

天气预测(CNN)

时间:2019-03-14 13:19:24      阅读:247      评论:0      收藏:0      [点我收藏+]

标签:super   print   self   大小   oss   worker   ide   list   false   

  1 import torch
  2 import torch.nn as nn
  3 import torch.utils.data as Data
  4 import numpy as np
  5 import pymysql
  6 import datetime
  7 import csv
  8 import time
  9 
 10 
 11 EPOCH = 100
 12 BATCH_SIZE = 50
 13 
 14 
 15 class MyNet(nn.Module):
 16     def __init__(self):
 17         super(MyNet, self).__init__()
 18         self.con1 = nn.Sequential(
 19             nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
 20             nn.MaxPool1d(kernel_size=1),
 21             nn.ReLU(),
 22         )
 23         self.con2 = nn.Sequential(
 24             nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
 25             nn.MaxPool1d(kernel_size=1),
 26             nn.ReLU(),
 27         )
 28         self.fc = nn.Sequential(
 29             # 线性分类器
 30             nn.Linear(128*6*1, 128),  # 修改大小后要重新计算
 31             nn.ReLU(),
 32             nn.Linear(128, 6),
 33             # nn.Softmax(dim=1),
 34         )
 35         self.mls = nn.MSELoss()
 36         self.opt = torch.optim.Adam(params=self.parameters(), lr=1e-3)
 37         self.start = datetime.datetime.now()
 38 
 39     def forward(self, inputs):
 40         out = self.con1(inputs)
 41         out = self.con2(out)
 42         out = out.view(out.size(0), -1)  # 展开成一维
 43         out = self.fc(out)
 44         # out = F.log_softmax(out, dim=1)
 45         return out
 46 
 47     def train(self, x, y):
 48         out = self.forward(x)
 49         loss = self.mls(out, y)
 50         print(loss: , loss)
 51         self.opt.zero_grad()
 52         loss.backward()
 53         self.opt.step()
 54 
 55     def test(self, x):
 56         out = self.forward(x)
 57         return out
 58 
 59     def get_data(self):
 60         with open(aaa.csv, r) as f:
 61             results = csv.reader(f)
 62             results = [row for row in results]
 63             results = results[1:1500]
 64         inputs = []
 65         labels = []
 66         for result in results:
 67             # 手动独热编码
 68             one_hot = [0 for i in range(6)]
 69             index = int(result[6])-1
 70             one_hot[index] = 1
 71             # labels.append(label)
 72             # one_hot = []
 73             # label = result[6]
 74             # for i in range(6):
 75             #     if str(i) == label:
 76             #         one_hot.append(1)
 77             #     else:
 78             #         one_hot.append(0)
 79             labels.append(one_hot)
 80             input = result[:6]
 81             input = [float(x) for x in input]
 82             # label = [float(y) for y in label]
 83             inputs.append(input)
 84         # print(labels)  # [[0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1],
 85         time.sleep(10)
 86         inputs = np.array(inputs)
 87         labels = np.array(labels)
 88         inputs = torch.from_numpy(inputs).float()
 89         inputs = torch.unsqueeze(inputs, 1)
 90 
 91         labels = torch.from_numpy(labels).float()
 92         return inputs, labels
 93 
 94     def get_test_data(self):
 95         with open(aaa.csv, r) as f:
 96             results = csv.reader(f)
 97             results = [row for row in results]
 98             results = results[1500: 1817]
 99         inputs = []
100         labels = []
101         for result in results:
102             label = [result[6]]
103             input = result[:6]
104             input = [float(x) for x in input]
105             label = [float(y) for y in label]
106             inputs.append(input)
107             labels.append(label)
108         inputs = np.array(inputs)
109         # labels = np.array(labels)
110         inputs = torch.from_numpy(inputs).float()
111         inputs = torch.unsqueeze(inputs, 1)
112         labels = np.array(labels)
113         labels = torch.from_numpy(labels).float()
114         return inputs, labels
115 
116 
if __name__ == '__main__':
    # Set to True to (re)train the model on the training rows and save it to
    # net.pkl before evaluating.
    retrain = False
    if retrain:
        net = MyNet()
        x_data, y_data = net.get_data()
        train_loader = Data.DataLoader(
            dataset=Data.TensorDataset(x_data, y_data),
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=2,
        )
        for epoch in range(EPOCH):
            for step, (batch_x, batch_y) in enumerate(train_loader):
                print(step)
                net.train(batch_x, batch_y)
        # Saves the whole pickled module, not just its state_dict.
        torch.save(net, 'net.pkl')

    # Evaluate on the held-out rows.
    # NOTE(security): torch.load unpickles arbitrary objects. Only load net.pkl
    # files you created yourself. weights_only=False is required on PyTorch
    # >= 2.6 to load a whole pickled module (the argument exists since 1.13).
    net = torch.load('net.pkl', weights_only=False)
    x_data, y_data = net.get_test_data()
    torch_dataset = Data.TensorDataset(x_data, y_data)
    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=100,
        shuffle=False,
        num_workers=1,
    )
    num_success = 0
    # Previously hard-coded to 317; derive it from the data actually loaded.
    num_sum = len(torch_dataset)
    for batch_x, batch_y in loader:
        output = net.test(batch_x)
        # Classes are numbered from 1; argmax returns the first maximal index.
        predictions = (output.argmax(dim=1) + 1).tolist()
        for pred, label in zip(predictions, batch_y[:, 0].tolist()):
            print('输出为{}标签为{}'.format(pred, label))
            if pred == label:
                num_success += 1
    print('正确率为{}'.format(num_success / num_sum if num_sum else 0.0))

 

天气预测(CNN)

标签:super   print   self   大小   oss   worker   ide   list   false   

原文地址:https://www.cnblogs.com/MC-Curry/p/10529566.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!