标签:str sample 处理 closed sha rgba tuple spatial ati
# config
# data parameters
dataset_name: paris
data_with_subfolder: False  # whether the images are stored in subfolders of the data path
# Single-quoted with single backslashes: YAML only processes backslash escapes
# inside double quotes, so a doubled "\\" here would stay two literal backslashes.
train_data_path: 'F:\pycharm\Dataset\paris\paris_eval_gt'
val_data_path:  # empty -> null
resume:  # empty -> null
batch_size: 5
image_shape: [256, 256, 3]  # network input size after resizing: [height, width, channels]
mask_shape: [128, 128]
mask_batch_same: True
max_delta_shape: [32, 32]
margin: [0, 0]
discounted_mask: True
spatial_discounting_gamma: 0.9
random_crop: True
mask_type: hole  # hole | mosaic
mosaic_unit_size: 12
# 加载处理数据
# Build the training dataset and its DataLoader from the parsed config.
# NOTE(review): `config` is presumably the dict loaded from the YAML config;
# `num_workers` is not in the data-parameters section shown, so it must be
# defined elsewhere in that config or this raises KeyError.
train_dataset = Dataset(data_path=config['train_data_path'],
                        with_subfolder=config['data_with_subfolder'],
                        image_shape=config['image_shape'],
                        random_crop=config['random_crop'])
# Validation set (disabled). Dataset takes `image_shape`, not `image_size`.
# val_dataset = Dataset(data_path=config['val_data_path'],
#                       with_subfolder=config['data_with_subfolder'],
#                       image_shape=config['image_shape'],
#                       random_crop=config['random_crop'])
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=config['batch_size'],
                                           shuffle=True,
                                           num_workers=config['num_workers'])
# val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
#                                          batch_size=config['batch_size'],
#                                          shuffle=False,
#                                          num_workers=config['num_workers'])
# Dataset
class Dataset(data.Dataset):
    """Image dataset that loads, resizes/crops and normalizes images from a directory.

    Each item is ``normalize(ToTensor(img))`` for the loaded image after it has
    been brought to ``image_shape[:-1]`` (height, width). The comment in the
    original listing expects a ``(3, height, width)`` tensor, which assumes
    ``default_loader`` yields RGB images (TODO confirm). When ``return_name``
    is true the item is ``(sample_name, tensor)`` instead.
    """

    def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
        """
        Args:
            data_path (str): Root directory of the images.
            image_shape (sequence): ``[height, width, channels]``; only height
                and width are used.
            with_subfolder (bool): If True, collect images recursively from the
                sub-directories of ``data_path``; otherwise use only the image
                files directly inside ``data_path``.
            random_crop (bool): If True, randomly crop to ``image_shape``
                (upscaling first when the image is too small); otherwise resize
                straight to ``image_shape``.
            return_name (bool): If True, ``__getitem__`` returns
                ``(sample_name, tensor)``.
        """
        super(Dataset, self).__init__()
        self.data_path = data_path
        self.with_subfolder = with_subfolder
        self.image_shape = image_shape[:-1]  # drop the channel entry -> [height, width]
        self.random_crop = random_crop
        self.return_name = return_name
        if with_subfolder:
            # Paths already prefixed with data_path, in sorted order.
            self.samples = self._find_samples_in_subfolders(data_path)
        else:
            # Bare file names; sorted so the index -> file mapping is
            # deterministic, matching the sorted order of subfolder mode.
            self.samples = sorted(x for x in listdir(data_path) if is_image_file(x))

    def __getitem__(self, index):
        """Load the image at ``index``, bring it to ``image_shape`` and normalize it."""
        sample = self.samples[index]
        # Subfolder samples already contain data_path; joining it again would
        # duplicate the prefix whenever data_path is a relative path.
        path = sample if self.with_subfolder else os.path.join(self.data_path, sample)
        img = default_loader(path)  # presumably a PIL image (Paris eval images are reportedly 227x227)
        height, width = self.image_shape
        if self.random_crop:
            imgw, imgh = img.size  # PIL reports (width, height)
            if imgh < height or imgw < width:
                # Scale the shorter edge up to the largest target side so both
                # sides are at least the crop size. For a square image_shape
                # this equals the smallest target side.
                img = transforms.Resize(max(self.image_shape))(img)
            img = transforms.RandomCrop(self.image_shape)(img)
        else:
            # Resize exactly to (height, width); a same-size crop would be a no-op.
            img = transforms.Resize(self.image_shape)(img)

        img = transforms.ToTensor()(img)
        img = normalize(img)

        if self.return_name:
            return sample, img
        return img

    def _find_samples_in_subfolders(self, dir):
        """Collect image files from every sub-directory of ``dir``.

        Each immediate sub-directory is walked recursively. Image files that
        sit directly inside ``dir`` itself are ignored.

        Args:
            dir (str): Root directory path.

        Returns:
            list[str]: ``os.path.join(root, fname)`` for each image file, so
            every entry is prefixed with ``dir``. Sorted by sub-directory name,
            then walk root, then file name.
        """
        samples = []
        subdirs = sorted(entry.name for entry in os.scandir(dir) if entry.is_dir())
        for subdir in subdirs:
            for root, _, fnames in sorted(os.walk(os.path.join(dir, subdir))):
                samples.extend(os.path.join(root, fname) for fname in sorted(fnames) if is_image_file(fname))
        return samples

    def __len__(self):
        """Number of collected samples."""
        return len(self.samples)
标签:str sample 处理 closed sha rgba tuple spatial ati
原文地址:https://www.cnblogs.com/Overture/p/14587127.html