# Tags: max elf col help dig start rom tensor from
"""Train a small MLP to play Fizz Buzz.

Each integer is encoded as a fixed-width big-endian binary vector and the
network classifies it into one of four classes: 0 -> print the number,
1 -> "fizz", 2 -> "buzz", 3 -> "fizzbuzz". The model is trained on the
integers 101 .. 2**NUM_DIGHT - 1 and evaluated on 0 .. 100, which it never
sees during training.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Training settings (these are the values the training loop actually uses).
BATCH_SIZE = 128
LEARNING_RATE = 0.05
NUM_EPOCHS = 1000
NUM_DIGHT = 13  # number of binary digits used to encode each integer
H = 100  # hidden-layer width
NUM_CLASSES = 4  # number / fizz / buzz / fizzbuzz


def fize_buss_encode(i):
    """Return the Fizz Buzz class index for integer ``i``.

    3 if divisible by 15, 2 if divisible by 5, 1 if divisible by 3, else 0.
    """
    if i % 15 == 0:
        return 3
    if i % 5 == 0:
        return 2
    if i % 3 == 0:
        return 1
    return 0


def fize_buss_decode(i, prediction):
    """Return the Fizz Buzz text for ``i`` given its class index ``prediction``."""
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]


def helper(i):
    """Print the ground-truth Fizz Buzz answer for ``i``."""
    print(fize_buss_decode(i, fize_buss_encode(i)))


def Transform(i, num_dight):
    """Return ``i`` as a big-endian 0/1 numpy vector of length ``num_dight``.

    Bits at or above position ``num_dight`` are dropped.
    """
    # Parenthesized for readability; ``>>`` already binds tighter than ``&``.
    return np.array([(i >> d) & 1 for d in range(num_dight)][::-1])


def encode_inputs(numbers, num_dight=NUM_DIGHT):
    """Return a float32 tensor of shape (len(numbers), num_dight) of binary encodings."""
    # One numpy allocation + zero-copy conversion; building a tensor from a
    # Python list of ndarrays is very slow. reshape keeps empty input 2-D.
    encoded = np.array(
        [Transform(i, num_dight) for i in numbers], dtype=np.float32
    ).reshape(-1, num_dight)
    return torch.from_numpy(encoded)


def encode_labels(numbers):
    """Return a long tensor with the Fizz Buzz class index of each number."""
    return torch.tensor([fize_buss_encode(i) for i in numbers], dtype=torch.long)


def build_model(num_dight=NUM_DIGHT, hidden=H, num_classes=NUM_CLASSES):
    """Return a one-hidden-layer ReLU MLP mapping num_dight inputs to class logits."""
    return nn.Sequential(
        nn.Linear(num_dight, hidden),
        nn.ReLU(),
        nn.Linear(hidden, num_classes),
    )


def train_model(model, inputs, labels, *, epochs=NUM_EPOCHS,
                batch_size=BATCH_SIZE, lr=LEARNING_RATE):
    """Train ``model`` in place with mini-batch SGD and cross-entropy loss.

    Batches are taken in order from ``inputs``/``labels``; the final batch
    may be smaller than ``batch_size``.

    Returns:
        The loss of the last processed batch as a float, or NaN if no batch
        was processed (zero epochs or empty inputs).
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    model.train()
    loss = None
    for _ in range(epochs):
        for start in range(0, len(inputs), batch_size):
            end = start + batch_size
            loss = loss_fn(model(inputs[start:end]), labels[start:end])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return float("nan") if loss is None else loss.item()


def evaluate(model, numbers, num_dight=NUM_DIGHT):
    """Run ``model`` on ``numbers`` and compare against the true Fizz Buzz answers.

    Returns:
        ``(mismatches, accuracy)`` where ``mismatches`` is a list of
        ``(number, predicted_text, expected_text)`` tuples and ``accuracy``
        is the fraction predicted correctly (NaN for empty ``numbers``).
    """
    numbers = list(numbers)
    if not numbers:
        return [], float("nan")
    model.eval()
    with torch.no_grad():
        predicted = model(encode_inputs(numbers, num_dight)).argmax(dim=1).tolist()
    mismatches = []
    for i, pred in zip(numbers, predicted):
        truth = fize_buss_encode(i)
        if pred != truth:
            mismatches.append(
                (i, fize_buss_decode(i, pred), fize_buss_decode(i, truth))
            )
    accuracy = 1.0 - len(mismatches) / len(numbers)
    return mismatches, accuracy


def main():
    """Train on 101 .. 2**NUM_DIGHT - 1, then report errors on 0 .. 100."""
    train_numbers = range(101, 2 ** NUM_DIGHT)
    train_x = encode_inputs(train_numbers)
    train_y = encode_labels(train_numbers)

    model = build_model()
    train_model(model, train_x, train_y)

    test_numbers = range(101)
    mismatches, accuracy = evaluate(model, test_numbers)
    for _, predicted, expected in mismatches:
        print(predicted, expected)
    print(f"accuracy: {accuracy:.2%} ({len(mismatches)} wrong of {len(test_numbers)})")


if __name__ == "__main__":
    main()
# Tags: max elf col help dig start rom tensor from
# Original article: https://www.cnblogs.com/yin101/p/12918354.html