码迷,mamicode.com
首页 > 其他好文 > 详细

PyTorch使用中需要注意的地方

时间:2020-06-08 19:21:24      阅读:92      评论:0      收藏:0      [点我收藏+]

标签:a long   ddb   ade   预测   required   numpy   sed   cuda   需要   

参考博客:

https://blog.csdn.net/u011276025/article/details/73826562/

 

1. 把Label要转成LongTensor格式

self.y = torch.LongTensor(y)

完整使用代码如下:

技术图片
 1 class ImgDataset(Dataset):
 2     def __init__(self, x, y=None, transform=None):
 3         self.x = x
 4         # label is required to be a LongTensor
 5         self.y = y
 6         if y is not None:
 7             self.y = torch.LongTensor(y)
 8         self.transform = transform
 9     def __len__(self):
10         return len(self.x)
11     def __getitem__(self, index):
12         X = self.x[index]
13         if self.transform is not None:
14             X = self.transform(X)
15         if self.y is not None:
16             Y = self.y[index]
17             return X, Y
18         else:
19             return X
View Code
技术图片
 1 class ImgDataset(Dataset):
 2     def __init__(self, x, y=None, transform=None):
 3         self.x = x
 4         # label is required to be a LongTensor
 5         self.y = y
 6         if y is not None:
 7             self.y = torch.LongTensor(y)
 8         self.transform = transform
 9     def __len__(self):
10         return len(self.x)
11     def __getitem__(self, index):
12         X = self.x[index]
13         if self.transform is not None:
14             X = self.transform(X)
15         if self.y is not None:
16             Y = self.y[index]
17             return X, Y
18         else:
19             return X
View Code

需要保证target类型为torch.cuda.LongTensor,需要在数据读取的迭代其中把target的类型转换为int64位的:target = target.astype(np.int64),这样,输出的target类型为torch.cuda.LongTensor。(或者在使用前使用Tensor.type(torch.LongTensor)进行转换)。

*LongTensor其实就是int64,有符号整型

 

 

2. 做预测时,没有y值,从dataloader中传入给model的直接是data,而不再是data[0]了

model_best.eval()
prediction = []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        #print(data[0].size())
        # 特别要注意的是,这里直接传入data,因为已经没有y值了,所以无需data[0]。
        # 如果传了data[0]反而导致没有传入整个batch,计算错误
        test_pred = model_best(data.cuda())
        test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)
        for y in test_label:
            prediction.append(y)

 

未完待续。。。

PyTorch使用中需要注意的地方

标签:a long   ddb   ade   预测   required   numpy   sed   cuda   需要   

原文地址:https://www.cnblogs.com/YeZzz/p/13067470.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!