Video link: https://www.bilibili.com/video/BV12741177Cu?from=search&seid=17209581732555565064
The video implements everything in a Jupyter notebook; I wrote the code in PyCharm instead.
The fizz/buzz/fizzbuzz game works like this: if a number is divisible by 3, print fizz; if divisible by 5, print buzz; if divisible by 15, print fizzbuzz. This could be done with a single function, but since we are here to learn neural networks, we implement it with a two-layer network and let it learn the game on its own (no UI, of course).
There are three .py files: utils.py holds the helper functions, model.py trains the model, and paragraph2.py uses the model for prediction.
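For reference, the "single function" version the network is trying to learn is only a few lines (a minimal sketch; the function name is my own):

```python
def fizz_buzz(i):
    # The plain rule-based game the network will be trained to imitate.
    if i % 15 == 0:
        return 'fizzbuzz'
    elif i % 5 == 0:
        return 'buzz'
    elif i % 3 == 0:
        return 'fizz'
    else:
        return str(i)

print([fizz_buzz(i) for i in range(1, 16)])
```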
utils.py
import numpy as np

def binary_encode(i, num_digits):  # encode i in binary
    # [::-1] reverses the array, because the bits come out LSB-first
    return np.array([i >> d & 1 for d in range(num_digits)])[::-1]

def fizz_buzz_encode(i):
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    # A fun indexing trick -- first time I'd seen it; try printing it yourself
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][prediction]
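The comment above invites you to try the helpers yourself; here is a quick self-contained check of both (copies of the two utils.py functions, needing only numpy):

```python
import numpy as np

def binary_encode(i, num_digits):
    # Same as utils.py: bits LSB-first, then reversed to MSB-first.
    return np.array([i >> d & 1 for d in range(num_digits)])[::-1]

def fizz_buzz_decode(i, prediction):
    # The predicted class (0..3) indexes into a list of possible outputs.
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][prediction]

print(binary_encode(4, 4))       # -> [0 1 0 0], highest bit first
print(fizz_buzz_decode(9, 1))    # class 1 selects 'fizz'
print(fizz_buzz_decode(7, 0))    # class 0 falls back to the number itself
```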
model.py:
import torch
from p2.utils import binary_encode, fizz_buzz_encode

NUM_DIGITS = 10
# Training data: the numbers 101 through 1023 (2**10 - 1), i.e. 923 samples
trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
# x can be float, but y holds class labels, so it must be a LongTensor
trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])

NUM_HIDDEN = 100
model = torch.nn.Sequential(  # model definition; the activation is ReLU
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)
if torch.cuda.is_available():  # move the model to the GPU
    model = model.cuda()

loss_fn = torch.nn.CrossEntropyLoss()  # cross-entropy loss
# SGD is stochastic gradient descent; torch ships several optimizers, try the others too
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
BATCH_SIZE = 128

if __name__ == '__main__':
    # 1000 epochs here; the video's instructor uses 10000, which I found too slow,
    # but the results really are better at 10000 -- try it yourself
    for epoch in range(1000):
        for start in range(0, len(trX), BATCH_SIZE):  # minibatches of size BATCH_SIZE
            end = start + BATCH_SIZE
            batchX = trX[start:end]
            batchY = trY[start:end]
            if torch.cuda.is_available():  # move the batch to the GPU
                batchX = batchX.cuda()
                batchY = batchY.cuda()
            y_pred = model(batchX)
            loss = loss_fn(y_pred, batchY)
            print("Epoch", epoch, loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.save(model, 'fbmodel.pkl')  # pickles the whole model object
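The start/end slicing in the training loop is a standard minibatching pattern. Stripped of the model and tensors, it amounts to this (a toy illustration with a plain list; note the last batch can be shorter):

```python
data = list(range(10))
BATCH_SIZE = 3
# Same pattern as the training loop: step through the data in BATCH_SIZE strides;
# slicing past the end of a list is safe, so the final batch is just shorter.
batches = [data[start:start + BATCH_SIZE]
           for start in range(0, len(data), BATCH_SIZE)]
print(batches)  # -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```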
paragraph2.py:
import torch
from p2.utils import binary_encode, fizz_buzz_decode

model = torch.load('p2/fbmodel.pkl')
NUM_DIGITS = 10
testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
if torch.cuda.is_available():
    testX = testX.cuda()
with torch.no_grad():
    testY = model(testX)
# A very neat trick: testY.max(1)[1].cpu().data.tolist() -- try printing it yourself
predictions = zip(range(1, 101), testY.max(1)[1].cpu().data.tolist())
print([fizz_buzz_decode(i, x) for i, x in predictions])
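What `testY.max(1)[1]` does: `Tensor.max(1)` returns a pair (max values, their indices) along dimension 1, so `[1]` picks the index of the highest class score per row, i.e. the predicted class. The numpy analogue (a sketch with made-up scores, not the torch call itself):

```python
import numpy as np

# Two rows of hypothetical class scores, one per input number, four classes each.
scores = np.array([[0.1, 2.0, -1.0, 0.3],
                   [1.5, 0.0,  0.2, 0.1]])
preds = scores.argmax(axis=1)  # index of the max score per row, like testY.max(1)[1]
print(preds.tolist())  # -> [1, 0]
```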
Result with 1000 training epochs:
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', '10', '11', 'fizz', '13', '14', 'fizz', '16', '17', 'fizz', '19', '20', 'fizz', '22', '23', 'fizz', '25', '26', 'fizz', '28', '29', 'fizz', '31', '32', 'fizz', '34', '35', 'fizz', '37', '38', 'fizz', '40', '41', '42', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', '50', 'fizz', '52', '53', 'fizz', '55', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', 'fizz', 'fizz', '70', '71', 'fizz', '73', '74', 'fizzbuzz', '76', 'buzz', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'fizz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', '100']
Result with 10000 training epochs:
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', 'fizz', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', '66', '67', '68', 'fizz', '70', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', '78', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', 'fizz', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
The amount of training really does make a difference.
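Rather than eyeballing the two lists, one could score the predicted classes against the true rule. A small helper (my own sketch, reusing fizz_buzz_encode from utils.py; the sample predictions here are hypothetical):

```python
def fizz_buzz_encode(i):
    # Same as utils.py: the true class for number i.
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def accuracy(predicted_labels, start=1):
    # Fraction of predicted classes that match the true fizzbuzz class,
    # for the consecutive numbers start, start+1, ...
    hits = sum(p == fizz_buzz_encode(i)
               for i, p in enumerate(predicted_labels, start))
    return hits / len(predicted_labels)

# Perfect predictions for 1..10 score 1.0; all-zero predictions score 0.5.
print(accuracy([0, 0, 1, 0, 2, 1, 0, 0, 1, 2]))  # -> 1.0
print(accuracy([0] * 10))                        # -> 0.5
```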
I'm still a beginner here. Let's learn and explore this together -- keep at it!
PyTorch Introduction and Practice -> Lesson 1 review (the fizz/buzz/fizzbuzz game)
Original post: https://www.cnblogs.com/JadenFK3326/p/13113421.html