## 训练集高分辨率图预处理函数
def train_hr_transform(crop_size):
    """Return the preprocessing pipeline for high-resolution training images.

    The pipeline takes a random crop of the input image and converts that
    crop to a tensor.

    Args:
        crop_size: Crop size forwarded unchanged to ``RandomCrop``.

    Returns:
        A ``Compose`` applying ``RandomCrop(crop_size)`` then ``ToTensor()``.
    """
    pipeline = [RandomCrop(crop_size), ToTensor()]
    return Compose(pipeline)
## 训练集低分辨率图预处理函数
def train_lr_transform(crop_size, upscale_factor):
    """Return the pipeline that derives a low-resolution image from an HR tensor.

    The input tensor is converted back to a PIL image, resized with bicubic
    interpolation to ``crop_size // upscale_factor``, and converted to a
    tensor again.

    Args:
        crop_size: High-resolution crop size.
        upscale_factor: Factor by which ``crop_size`` is divided to obtain
            the low-resolution size.

    Returns:
        A ``Compose`` of ``ToPILImage``, bicubic ``Resize`` and ``ToTensor``.
    """
    # Floor division keeps the target size an integer.
    lr_size = crop_size // upscale_factor
    downscale = Resize(lr_size, interpolation=Image.BICUBIC)
    return Compose([ToPILImage(), downscale, ToTensor()])
## 训练数据集类
class TrainDatasetFromFolder(Dataset):
    """Training dataset yielding ``(lr_image, hr_image)`` pairs.

    Every file in ``dataset_dir`` accepted by ``is_image_file`` is one sample.
    On access, ``hr_transform`` produces the high-resolution tensor, and
    ``lr_transform`` derives the low-resolution tensor from it.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        """Index the image files and build the HR/LR transforms.

        Args:
            dataset_dir: Directory containing the training images.
            crop_size: Requested HR crop size; adjusted through
                ``calculate_valid_crop_size`` for ``upscale_factor``.
            upscale_factor: Scale factor between HR and LR images.
        """
        super(TrainDatasetFromFolder, self).__init__()
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping deterministic across runs and platforms.
        self.image_filenames = [
            join(dataset_dir, x)
            for x in sorted(listdir(dataset_dir))
            if is_image_file(x)
        ]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        """Return the ``(lr_image, hr_image)`` pair for the image at ``index``."""
        # The context manager releases the file handle deterministically.
        # Converting to RGB gives every sample three channels whatever the
        # source mode (L, P, RGBA, CMYK), so samples stack into one batch.
        with Image.open(self.image_filenames[index]) as img:
            hr_image = self.hr_transform(img.convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Return the number of image files found in ``dataset_dir``."""
        return len(self.image_filenames)
## 验证数据集类
class ValDatasetFromFolder(Dataset):
def __init__(self, dataset_dir, upscale_factor):
super(ValDatasetFromFolder, self).__init__()
self.upscale_factor = upscale_factor
self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]