# 标签 (tags): super print self 大小 oss worker ide list false
import csv
import datetime
import inspect
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data

try:
    # pymysql is not used anywhere in this script; keeping the import optional
    # lets the script run on machines where it is not installed.
    import pymysql  # noqa: F401
except ImportError:
    pymysql = None


EPOCH = 100
BATCH_SIZE = 50
# Each CSV row holds NUM_FEATURES numeric feature columns followed by one
# integer class label that get_data() expects in the range 1..NUM_CLASSES.
NUM_FEATURES = 6
NUM_CLASSES = 6


def _read_csv_rows(path, start, stop):
    """Return the non-empty CSV rows ``rows[start:stop]`` of *path*.

    Row 0 of the file is the header, so callers pass ``start >= 1`` to skip it.
    Slicing is done on the raw row list (so row indices match the file), and
    blank lines inside the slice are dropped instead of crashing the parser.
    """
    # newline='' is what the csv module documents for files given to reader().
    with open(path, 'r', newline='') as f:
        rows = list(csv.reader(f))
    return [row for row in rows[start:stop] if row]


class MyNet(nn.Module):
    """Small 1-D CNN mapping NUM_FEATURES inputs to NUM_CLASSES scores.

    Input shape is (batch, 1, NUM_FEATURES); output is (batch, NUM_CLASSES)
    raw scores (no softmax is applied). The loss and the Adam optimizer are
    stored on the module itself, so ``torch.save(net, ...)`` pickles them too.
    """

    def __init__(self):
        super(MyNet, self).__init__()
        # MaxPool1d(kernel_size=1) is an identity op; it is kept so the module
        # layout matches models pickled by earlier versions of this script.
        self.con1 = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(kernel_size=1),
            nn.ReLU(),
        )
        self.con2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(kernel_size=1),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            # Linear classifier. kernel_size=3 with padding=1 preserves the
            # sequence length, so the flattened size is 128 * NUM_FEATURES;
            # recompute this if the input length changes.
            nn.Linear(128 * NUM_FEATURES, 128),
            nn.ReLU(),
            nn.Linear(128, NUM_CLASSES),
        )
        # NOTE(review): MSE against one-hot targets works, but CrossEntropyLoss
        # on class indices is the conventional loss for classification.
        self.mls = nn.MSELoss()
        self.opt = torch.optim.Adam(params=self.parameters(), lr=1e-3)
        self.start = datetime.datetime.now()

    def forward(self, inputs):
        """Return (batch, NUM_CLASSES) scores for (batch, 1, NUM_FEATURES) inputs."""
        out = self.con1(inputs)
        out = self.con2(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, 128 * NUM_FEATURES)
        return self.fc(out)

    def train(self, x=True, y=None):
        """Run one optimisation step on (x, y), or switch train/eval mode.

        This name overrides ``nn.Module.train(mode)``, which ``eval()`` invokes
        as ``self.train(False)``. When *y* is omitted the call is forwarded to
        the base class, so ``net.train()`` / ``net.eval()`` keep working.
        Otherwise *x* is an input batch and *y* its one-hot target batch.

        Returns:
            The module itself for a mode switch, or the batch loss as a float
            for a training step.
        """
        if y is None:
            return super(MyNet, self).train(x)
        out = self.forward(x)
        loss = self.mls(out, y)
        print('loss: ', loss.item())
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()

    def test(self, x):
        """Return model scores for *x* without building an autograd graph."""
        with torch.no_grad():
            return self.forward(x)

    def get_data(self, path='aaa.csv', start=1, stop=1500):
        """Load training rows ``[start:stop]`` of the CSV at *path*.

        Returns:
            inputs: float32 tensor of shape (n, 1, NUM_FEATURES).
            labels: float32 one-hot tensor of shape (n, NUM_CLASSES).

        Raises:
            ValueError: if a label is not an integer in 1..NUM_CLASSES.
        """
        inputs, labels = [], []
        for row in _read_csv_rows(path, start, stop):
            class_id = int(row[NUM_FEATURES])
            # Without this check a label of 0 would index -1 and silently
            # become the last class.
            if not 1 <= class_id <= NUM_CLASSES:
                raise ValueError('label {!r} out of range 1..{}'.format(
                    row[NUM_FEATURES], NUM_CLASSES))
            # Manual one-hot encoding of the 1-based class id.
            one_hot = [0] * NUM_CLASSES
            one_hot[class_id - 1] = 1
            labels.append(one_hot)
            inputs.append([float(v) for v in row[:NUM_FEATURES]])
        # reshape (instead of unsqueeze) also yields the right shape when empty.
        x = torch.tensor(inputs, dtype=torch.float32).reshape(-1, 1, NUM_FEATURES)
        y = torch.tensor(labels, dtype=torch.float32).reshape(-1, NUM_CLASSES)
        return x, y

    def get_test_data(self, path='aaa.csv', start=1500, stop=1817):
        """Load held-out rows ``[start:stop]`` of the CSV at *path*.

        Returns:
            inputs: float32 tensor of shape (n, 1, NUM_FEATURES).
            labels: float32 tensor of shape (n, 1) holding the raw class id.
        """
        rows = _read_csv_rows(path, start, stop)
        inputs = [[float(v) for v in row[:NUM_FEATURES]] for row in rows]
        labels = [[float(row[NUM_FEATURES])] for row in rows]
        x = torch.tensor(inputs, dtype=torch.float32).reshape(-1, 1, NUM_FEATURES)
        y = torch.tensor(labels, dtype=torch.float32).reshape(-1, 1)
        return x, y


def _load_model(model_path):
    """Unpickle a whole MyNet saved with ``torch.save(net, model_path)``."""
    # SECURITY: this unpickles arbitrary Python objects - only load trusted files.
    # Recent torch releases default to weights_only=True, which rejects
    # full-module pickles, so opt out explicitly where the argument exists.
    kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(model_path, **kwargs)


def run_training(data_path='aaa.csv', model_path='net.pkl',
                 epochs=EPOCH, batch_size=BATCH_SIZE):
    """Train a fresh MyNet on the training rows and pickle it to *model_path*."""
    net = MyNet()
    x_data, y_data = net.get_data(path=data_path)
    loader = Data.DataLoader(
        dataset=Data.TensorDataset(x_data, y_data),
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    for _ in range(epochs):
        for batch_x, batch_y in loader:
            net.train(batch_x, batch_y)
    torch.save(net, model_path)
    return net


def run_evaluation(data_path='aaa.csv', model_path='net.pkl', batch_size=100):
    """Score a saved model on the held-out rows and print per-sample results.

    Returns:
        Accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if the test slice of *data_path* contains no rows.
    """
    net = _load_model(model_path)
    x_data, y_data = net.get_test_data(path=data_path)
    dataset = Data.TensorDataset(x_data, y_data)
    num_sum = len(dataset)  # count the rows instead of assuming a fixed total
    if num_sum == 0:
        raise ValueError('no test rows found in {}'.format(data_path))
    loader = Data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
    )
    num_success = 0
    for batch_x, batch_y in loader:
        # argmax returns the first maximal index, like list.index(max(...));
        # +1 converts the 0-based index back to the 1-based class id.
        predictions = (net.test(batch_x).argmax(dim=1) + 1).tolist()
        for pred, label in zip(predictions, batch_y[:, 0].tolist()):
            print('输出为{}标签为{}'.format(pred, label))
            if pred == label:
                num_success += 1
    accuracy = num_success / num_sum
    print('正确率为{}'.format(accuracy))
    return accuracy


if __name__ == '__main__':
    # Default run evaluates the saved model; pass "train" on the command line
    # to train a new model and save it to net.pkl instead.
    if sys.argv[1:2] == ['train']:
        run_training()
    else:
        run_evaluation()
# 标签 (tags): super print self 大小 oss worker ide list false
# 原文地址 (original article): https://www.cnblogs.com/MC-Curry/p/10529566.html