码迷,mamicode.com
首页 > 其他好文 > 详细

SRGAN 学习心得

时间:2019-07-28 19:24:09      阅读:570      评论:0      收藏:0      [点我收藏+]

标签:图片   else   sar   event   cme   article   ram   eal   jpg   

一、理论

SRGAN论文的中文翻译网上有很多,可以直接参考:先大致了解网络模型结构,再理解loss的构成,然后就可以跑代码了。

生成器的总损失 loss = 图像MSE损失 + 对抗损失 + 感知损失(下面 loss.py 的实现中还加了一个权重很小的TV正则项,即 loss = MSE + 0.001·对抗损失 + 0.006·感知损失 + 2e-8·TV损失),参考:https://blog.csdn.net/DuinoDu/article/details/78819344

技术图片


 

 

 二、代码及其理解

(1)文件结构

技术图片

 

 (2)train.py

技术图片
import argparse
import os
from math import log10

import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator

# Command-line options for training. The string literals were lost in the
# original paste (all quotes stripped), which made this a syntax error.
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')

opt = parser.parse_args()

# Module-level constants read by the training script below.
CROP_SIZE = opt.crop_size
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs
if __name__ == '__main__':
    # ---- Data ---------------------------------------------------------------
    # NOTE: dataset and output paths are hard-coded for the author's Google Drive (Colab) layout.
    train_set = TrainDatasetFromFolder('/content/drive/My Drive/app/RBB/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('/content/drive/My Drive/app/RBB/test', upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    # ---- Networks, loss, optimisers -----------------------------------------
    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    # Per-epoch history; dumped to CSV every 10 epochs.
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    # Create every output directory once, up front, so torch.save / to_csv
    # cannot fail later on a missing folder.
    out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
    checkpoint_dir = '/content/drive/My Drive/app/SRGAN_master/epochs_RBB/'
    statistics_dir = '/content/drive/My Drive/app/SRGAN_master/statistics/'
    for directory in (out_path, checkpoint_dir, statistics_dir):
        os.makedirs(directory, exist_ok=True)

    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        # Running sums weighted by batch size, used for the progress bar and history.
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        netG.train()
        netD.train()
        for data, target in train_bar:
            # data: LR input (downsampled from target); target: HR ground truth.
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            real_img = target.cuda() if use_cuda else target
            z = data.cuda() if use_cuda else data

            # (1) Update D: minimise d_loss = 1 - D(x) + D(G(z)).
            # fake_img is detached so this backward pass stays inside D and
            # G's graph is kept intact for step (2) without retain_graph.
            fake_img = netG(z)
            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img.detach()).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward()
            optimizerD.step()

            # (2) Update G: image MSE + adversarial + perceptual + TV loss.
            # D(G(z)) is recomputed with the just-updated D. Reusing the old
            # fake_out would back-propagate through D weights that
            # optimizerD.step() modified in place, which autograd rejects.
            netG.zero_grad()
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            # Re-evaluate with the updated G purely for logging; no graph needed.
            with torch.no_grad():
                fake_img = netG(z)
                fake_out = netD(fake_img).mean()
                g_loss = generator_criterion(fake_out, fake_img, real_img)
                d_loss = 1 - real_out + fake_out
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

            seen = running_results['batch_sizes']
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / seen,
                running_results['g_loss'] / seen,
                running_results['d_score'] / seen,
                running_results['g_score'] / seen))

        # ---- Validation -----------------------------------------------------
        netG.eval()
        val_bar = tqdm(val_loader)
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        # The whole loop (including netG) runs without autograd to save memory.
        with torch.no_grad():
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                lr = val_lr.cuda() if use_cuda else val_lr
                hr = val_hr.cuda() if use_cuda else val_hr
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).mean().item()
                valing_results['mse'] += batch_mse * batch_size
                batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size
                mean_mse = valing_results['mse'] / valing_results['batch_sizes']
                # PSNR for images in [0, 1]; guard the perfect-reconstruction case.
                valing_results['psnr'] = 10 * log10(1 / mean_mse) if mean_mse > 0 else float('inf')
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))

        # ---- Checkpoint and statistics --------------------------------------
        torch.save(netG.state_dict(), checkpoint_dir + 'RBB_netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

        n_train = running_results['batch_sizes']
        for key in ('d_loss', 'g_loss', 'd_score', 'g_score'):
            results[key].append(running_results[key] / n_train)
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0:
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(statistics_dir + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
View Code

 (3)data_utils.py

技术图片
from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize


def is_image_file(filename):
    """Return True if *filename* ends in an image extension this loader accepts.

    The comparison is case-insensitive, so ``.TIF`` / ``.Jpg`` are accepted just
    like their lowercase forms (the original list accepted ``.PNG`` but not
    ``.TIF``).
    """
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round *crop_size* down to the nearest multiple of *upscale_factor*.

    A crop that divides evenly by the scale factor gives an integer LR size
    of exactly ``crop_size // upscale_factor``.
    """
    leftover = crop_size % upscale_factor
    return crop_size - leftover


def train_hr_transform(crop_size):
    """Build the HR training pipeline: a random ``crop_size`` crop, then a tensor."""
    steps = [RandomCrop(crop_size), ToTensor()]
    return Compose(steps)


def train_lr_transform(crop_size, upscale_factor):
    """Build the LR pipeline applied to an HR tensor crop.

    The tensor is turned back into a PIL image, bicubically resized to
    ``crop_size // upscale_factor`` (torchvision's int-size Resize acts on the
    shorter side), then converted to a tensor again.
    """
    lr_side = crop_size // upscale_factor
    downscale = Resize(lr_side, interpolation=Image.BICUBIC)
    return Compose([ToPILImage(), downscale, ToTensor()])


def display_transform():
    """Build a pipeline that turns an image tensor into a 400x400 tensor for display.

    Resize(400) scales the shorter side to 400; CenterCrop(400) trims the rest.
    """
    display_size = 400
    pipeline = [
        ToPILImage(),
        Resize(display_size),
        CenterCrop(display_size),
        ToTensor(),
    ]
    return Compose(pipeline)


class TrainDatasetFromFolder(Dataset):
    """Training pairs generated on the fly from the HR images in one directory.

    Each item is ``(lr_image, hr_image)``. ``hr_image`` is a random square crop of
    ``calculate_valid_crop_size(crop_size, upscale_factor)`` pixels, and
    ``lr_image`` is its bicubic downscale by ``upscale_factor``. Both are tensors.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        # Sorted so the item order does not depend on the file system's listdir order.
        self.image_filenames = [join(dataset_dir, x) for x in sorted(listdir(dataset_dir)) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        # Close the file promptly (one handle per sample otherwise stays open)
        # and force 3 channels: the generator's first conv takes 3 input
        # channels, so grayscale/RGBA/palette files would break the batch.
        with Image.open(self.image_filenames[index]) as raw:
            rgb = raw.convert('RGB')
        hr_image = self.hr_transform(rgb)
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    """Validation triples built from full HR images in one directory.

    Each HR image is center-cropped to the largest square whose side is a multiple
    of ``upscale_factor``. Each item is ``(lr, hr_restore, hr)`` as tensors:
    - ``lr``: the bicubic downscale of the crop;
    - ``hr_restore``: ``lr`` bicubically upscaled back (a baseline);
    - ``hr``: the crop itself.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        # Sorted so validation order (and per-image logs) is reproducible.
        self.image_filenames = [join(dataset_dir, x) for x in sorted(listdir(dataset_dir)) if is_image_file(x)]
        self.upscale_factor = upscale_factor

    def __getitem__(self, index):
        # Close the file handle and normalise to 3-channel RGB (see model input).
        with Image.open(self.image_filenames[index]) as raw:
            hr_image = raw.convert('RGB')
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)


class TestDatasetFromFolder(Dataset):
    """Pre-made LR/HR test pairs under ``<dataset_dir>/SRF_<k>/data/`` and ``.../target/``.

    LR and HR files are paired by position after sorting each directory by name,
    so a pair must have names that sort to the same position in both directories.
    Each item is ``(image_name, lr, hr_restore, hr)``, where ``hr_restore`` is the LR
    image bicubically upscaled by ``upscale_factor``.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        # listdir() order is arbitrary and may differ between the two
        # directories; sorting keeps lr_filenames[i] and hr_filenames[i] paired.
        self.lr_filenames = [join(self.lr_path, x) for x in sorted(listdir(self.lr_path)) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in sorted(listdir(self.hr_path)) if is_image_file(x)]

    def __getitem__(self, index):
        lr_file = self.lr_filenames[index]
        image_name = lr_file.split('/')[-1]
        # Close file handles promptly and normalise both images to RGB.
        with Image.open(lr_file) as raw_lr:
            lr_image = raw_lr.convert('RGB')
        with Image.open(self.hr_filenames[index]) as raw_hr:
            hr_image = raw_hr.convert('RGB')
        w, h = lr_image.size
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.lr_filenames)
View Code

 (4)loss.py

技术图片
import torch
from torch import nn
from torchvision.models.vgg import vgg16


class GeneratorLoss(nn.Module):
    """Weighted SRGAN generator objective.

    total = MSE(SR, HR)
          + 0.001 * mean(1 - D(SR))           (adversarial term)
          + 0.006 * MSE(VGG(SR), VGG(HR))     (perceptual term)
          + 2e-8  * TV(SR)                    (smoothness term)

    The VGG16 feature extractor is frozen (``requires_grad = False``) and put in
    eval mode; it is only used to compare feature maps.
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        backbone = vgg16(pretrained=True)
        # First 31 modules of vgg16.features serve as the perceptual feature extractor.
        feature_extractor = nn.Sequential(*list(backbone.features)[:31]).eval()
        for weight in feature_extractor.parameters():
            weight.requires_grad = False
        self.loss_network = feature_extractor
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Return the scalar generator loss.

        Args:
            out_labels: discriminator output(s) for the generated images.
            out_images: generator output (SR images).
            target_images: ground-truth HR images.
        """
        adversarial_term = torch.mean(1 - out_labels)
        # NOTE(review): images are fed to VGG as-is (the generator emits [0, 1]),
        # without ImageNet mean/std normalisation — confirm this is intended.
        sr_features = self.loss_network(out_images)
        hr_features = self.loss_network(target_images)
        perception_term = self.mse_loss(sr_features, hr_features)
        pixel_term = self.mse_loss(out_images, target_images)
        smoothness_term = self.tv_loss(out_images)
        return pixel_term + 0.001 * adversarial_term + 0.006 * perception_term + 2e-8 * smoothness_term


class TVLoss(nn.Module):
    """Total-variation penalty on a 4-D tensor ``x`` indexed as (N, C, H, W).

    Returns ``weight * 2 * (sum(dH**2) / (C*(H-1)*W) + sum(dW**2) / (C*H*(W-1))) / N``,
    where dH / dW are differences between vertically / horizontally adjacent pixels.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        n = x.size()[0]
        height, width = x.size()[2], x.size()[3]
        below, above = x[:, :, 1:, :], x[:, :, :height - 1, :]
        right, left = x[:, :, :, 1:], x[:, :, :, :width - 1]
        vertical = torch.pow(below - above, 2).sum()
        horizontal = torch.pow(right - left, 2).sum()
        per_pixel = vertical / self.tensor_size(below) + horizontal / self.tensor_size(right)
        return self.tv_loss_weight * 2 * per_pixel / n

    @staticmethod
    def tensor_size(t):
        """Elements per sample of a 4-D tensor: C * H * W."""
        channels, rows, cols = t.size()[1], t.size()[2], t.size()[3]
        return channels * rows * cols


if __name__ == "__main__":
    # Smoke test: instantiate the loss (builds a pretrained VGG16) and print its modules.
    print(GeneratorLoss())
View Code

 (5)model.py

技术图片
import math
import torch
# import torch.nn.functional as F
from torch import nn


class Generator(nn.Module):
    """SRGAN generator: a residual trunk followed by x2 PixelShuffle upsampling stages.

    Args:
        scale_factor: overall upscaling factor. It must be a power of two
            (1, 2, 4, 8, ...) because every upsampling stage doubles H and W.

    Raises:
        ValueError: if ``scale_factor`` is not a positive power of two.
    """

    def __init__(self, scale_factor):
        # math.log2 is exact for powers of two. Any other factor used to be
        # silently truncated (e.g. 3 -> a single x2 stage), giving a wrong output size.
        if scale_factor <= 0 or not math.log2(scale_factor).is_integer():
            raise ValueError('scale_factor must be a positive power of two, got %r' % (scale_factor,))
        upsample_block_num = int(math.log2(scale_factor))

        super(Generator, self).__init__()
        # Attribute names block1..block8 are the state_dict keys of saved checkpoints.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # Long skip connection from the first feature map into the upsampler.
        block8 = self.block8(block1 + block7)

        # Map tanh's (-1, 1) output to (0, 1), the range ToTensor() images use.
        return (torch.tanh(block8) + 1) / 2


class Discriminator(nn.Module):
    """SRGAN discriminator: eight 3x3 conv stages, global pooling, and a 1x1-conv head.

    Returns one sigmoid probability per input image, shape ``(batch,)``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        # (in_channels, out_channels, stride) of each Conv-BN-LeakyReLU stage;
        # the stride-2 stages halve the spatial resolution.
        stages = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        for in_ch, out_ch, stride in stages:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        # Global average pooling followed by 1x1 convs plays the role of the dense layers.
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x).view(x.size(0))
        return torch.sigmoid(logits)


class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN residual unit with an identity skip; preserves the input shape."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        # Attribute names and creation order match saved checkpoints / RNG init order.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = self.bn2(self.conv2(self.prelu(self.bn1(self.conv1(x)))))
        return x + branch


class UpsampleBLock(nn.Module):
    """Sub-pixel upsampler: conv to ``in_channels * up_scale**2`` maps, PixelShuffle, PReLU.

    The output has the same channel count as the input and ``up_scale`` times its H and W.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))
View Code

 (6)test_image.py

技术图片
import argparse
import time

import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage

from model import Generator

# Super-resolve a single image with a trained generator checkpoint.
parser = argparse.ArgumentParser(description='Test Single Image')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
parser.add_argument('--image_name', type=str, required=True, help='test low resolution image name')
parser.add_argument('--model_name', default='netG_epoch_2_100.pth', type=str, help='generator model epoch name')
# The output path used to be hard-coded; the default keeps the old location.
parser.add_argument('--output', default='/content/drive/My Drive/app/SRGAN_master/result/_out_srf_2.tif',
                    type=str, help='where to save the super-resolved image')
opt = parser.parse_args()

UPSCALE_FACTOR = opt.upscale_factor
TEST_MODE = opt.test_mode == 'GPU'
IMAGE_NAME = opt.image_name
MODEL_NAME = opt.model_name
MODEL_DIR = '/content/drive/My Drive/app/SRGAN_master/'

model = Generator(UPSCALE_FACTOR).eval()
if TEST_MODE:
    model.cuda()
    model.load_state_dict(torch.load(MODEL_DIR + MODEL_NAME))
else:
    # Remap tensors saved from a GPU run onto the CPU.
    model.load_state_dict(torch.load(MODEL_DIR + MODEL_NAME, map_location=lambda storage, loc: storage))

# The generator expects 3 input channels, so normalise the image to RGB.
with Image.open(IMAGE_NAME) as raw:
    image = raw.convert('RGB')
image = ToTensor()(image).unsqueeze(0)
if TEST_MODE:
    image = image.cuda()

# Inference needs no autograd graph. time.clock() was removed in Python 3.8,
# and CUDA kernels run asynchronously, so synchronise around the timed region.
with torch.no_grad():
    if TEST_MODE:
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = model(image)
    if TEST_MODE:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
print('cost' + str(elapsed) + 's')
out_img = ToPILImage()(out[0].cpu())
out_img.save(opt.output)
View Code

 

SRGAN 学习心得

标签:图片   else   sar   event   cme   article   ram   eal   jpg   

原文地址:https://www.cnblogs.com/zgqcn/p/11260343.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!