码迷,mamicode.com
首页 > 其他好文 > 详细

FizzBuzz

时间:2020-05-19 18:51:33      阅读:69      评论:0      收藏:0      [点我收藏+]

标签:max   elf   col   help   dig   start   rom   tensor   from   

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2
import numpy as np
# Training settings. In this script the optimizer is built with lr=0.05 and
# the loop slices batches with BATCH_SIZE; these two names are declared here
# but not read by the visible code.
batch_size, learning_rate = 64, 1e-4

def fize_buss_encode(i):
    """Return the FizzBuzz class index for ``i``.

    3 for multiples of 15, 2 for (other) multiples of 5, 1 for (other)
    multiples of 3, and 0 for everything else.
    """
    by_three = i % 3 == 0
    by_five = i % 5 == 0
    # 3 and 5 are coprime, so "divisible by both" is "divisible by 15".
    if by_three and by_five:
        return 3
    if by_five:
        return 2
    if by_three:
        return 1
    return 0
def fize_buss_decode(i, prediction):
    """Return the FizzBuzz output string for ``i`` given a class index.

    ``prediction`` uses the same encoding as ``fize_buss_encode``:
    0 -> ``str(i)``, 1 -> "fizz", 2 -> "buzz", 3 -> "fizzbuzz".

    Raises:
        IndexError: if ``prediction`` is outside 0..3. Negative values are
            rejected explicitly instead of silently wrapping around the list.
    """
    labels = [str(i), "fizz", "buzz", "fizzbuzz"]
    if not 0 <= prediction < len(labels):
        raise IndexError(f"prediction must be in 0..3, got {prediction}")
    return labels[prediction]
def helper(i):
    """Print the ground-truth FizzBuzz output for ``i``."""
    label = fize_buss_decode(i, fize_buss_encode(i))
    print(label)

# Number of bits in the binary input encoding (integers below 2**13 = 8192).
NUM_DIGHT: int = 13

def Transform(i, num_dight):
    """Return the low ``num_dight`` bits of ``i`` as a NumPy array, MSB first."""
    bits = [(i >> bit) & 1 for bit in reversed(range(num_dight))]
    return np.array(bits)

# Training data covers 101 .. 2**NUM_DIGHT - 1, so the 0..100 range used for
# evaluation is never seen during training.
_train_range = range(101, 2 ** NUM_DIGHT)
# Stack the per-sample bit vectors into one 2-D array first: building a tensor
# from a Python list of ndarrays is PyTorch's slow path.
trainX = torch.tensor(
    np.array([Transform(i, NUM_DIGHT) for i in _train_range]),
    dtype=torch.float32,
)
# Integer class labels, as required by CrossEntropyLoss.
trainY = torch.tensor([fize_buss_encode(i) for i in _train_range], dtype=torch.long)
class twolayer(torch.nn.Module):
    """Two-layer MLP: Linear(D_in, H) -> ReLU -> Linear(H, D_out).

    Hand-written equivalent of an ``nn.Sequential`` Linear/ReLU/Linear stack.
    (Previously this sat inside curly ‘‘‘ quotes, which is a SyntaxError.)
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.linear2 = nn.Linear(H, D_out)

    def forward(self, x):
        # clamp(min=0) applies ReLU between the two linear layers.
        return self.linear2(self.linear1(x).clamp(min=0))
# Model and optimisation setup.
H = 100  # hidden-layer width
model = nn.Sequential(
    nn.Linear(NUM_DIGHT, H),
    nn.ReLU(),
    nn.Linear(H, 4),  # one logit per class: number, fizz, buzz, fizzbuzz
)
loss_fc = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)
BATCH_SIZE = 128

# Mini-batch SGD for 1000 epochs, visiting batches in a fixed (unshuffled)
# order. The dead `loss = 0` initialisation was removed: `loss` is always
# reassigned inside the batch loop and never read before that.
for epoch in range(1000):
    for start in range(0, len(trainX), BATCH_SIZE):
        end = start + BATCH_SIZE  # slicing clamps the final partial batch
        train_x = trainX[start:end]
        train_y = trainY[start:end]
        loss = loss_fc(model(train_x), train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate on 0..100, a range the training data (101 and up) excludes.
# Stack to one ndarray first to avoid the slow list-of-ndarrays tensor path.
textX = torch.tensor(
    np.array([Transform(i, NUM_DIGHT) for i in range(101)]),
    dtype=torch.float32,
)
# Integer class labels, matching trainY's dtype (not read below).
textY = torch.tensor([fize_buss_encode(i) for i in range(101)], dtype=torch.long)

with torch.no_grad():
    textY_pred = model(textX)

# Print "<predicted> <expected>" for every number the model gets wrong.
predictions = zip(range(0, 101), textY_pred.argmax(dim=1).tolist())
for i, j in predictions:
    predicted = fize_buss_decode(i, j)
    expected = fize_buss_decode(i, fize_buss_encode(i))
    if predicted != expected:
        print(predicted, expected)

 

FizzBuzz

标签:max   elf   col   help   dig   start   rom   tensor   from   

原文地址:https://www.cnblogs.com/yin101/p/12918354.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!