码迷,mamicode.com
首页 > 其他好文 > 详细

Pytorch之数据处理

时间:2020-07-10 19:23:20      阅读:71      评论:0      收藏:0      [点我收藏+]

标签:ber   ack   dem   bsp   sum   数据   rop   ring   drop   

使用TensorDataset和DataLoader来简化

 
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
?
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
?
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
 
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
 
 
 
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout

 

 

import numpy as np
?
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
?
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print(‘当前step:‘+str(step), ‘验证集损失:‘+str(val_loss))
 
 
from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
 
 
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)
?
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
?
    return loss.item(), len(xb)
 
 

 

 

三行搞定!

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
 
 
 
 
 
?

 

Pytorch之数据处理

标签:ber   ack   dem   bsp   sum   数据   rop   ring   drop   

原文地址:https://www.cnblogs.com/BetterThanEver_Victor/p/13280586.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!