diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/.keep" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/.keep" new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/band_mean/bandmean.m" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/band_mean/bandmean.m" new file mode 100644 index 0000000000000000000000000000000000000000..f3dd1a4b267fa70b0f059b0ed68e4eeebdf753c8 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/band_mean/bandmean.m" @@ -0,0 +1,23 @@ +clear +clc +close all + +dataset = 'CAVE'; + +%% obtian the original hyperspectral image +src_path = ['/data2/cys/data/',dataset,'/process_train/2/']; +fileFolder=fullfile(src_path); +dirOutput=dir(fullfile(fileFolder,'*.mat')); +fileNames={dirOutput.name}'; +length(fileNames) + +for i = 1:length(fileNames) + name = char(fileNames(i)); + disp(['-----deal with:',num2str(i),'----name:',name]); + data_path = [src_path, '/', name]; + load(data_path) + sizeLR = size(hsi); + band_mean(i,:) = mean(reshape(hsi,[sizeLR(1)*sizeLR(2), sizeLR(3)])); +end + +band_mean = mean(band_mean) \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/testset_pre-processing/generate_testdata_CAVE.m" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/testset_pre-processing/generate_testdata_CAVE.m" new file mode 100644 index 0000000000000000000000000000000000000000..ea074e1c0d8a0f84002351654b1d57f732553dcc --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/testset_pre-processing/generate_testdata_CAVE.m" @@ -0,0 +1,63 @@ +clear +clc +close all + +dataset = 'CAVE'; +upscale = 2; +savePath = ['/data2/cys/data/',dataset,'/process_test/',num2str(upscale)]; % save test set to "savePath" +if ~exist(savePath, 'dir') + mkdir(savePath) +end + +%% obtian all the original hyperspectral image +srPath = '/data2/cys/data/CAVE/test/'; +srFile=fullfile(srPath); +srdirOutput=dir(fullfile(srFile)); +srfileNames={srdirOutput.name}'; +number = length(srfileNames) + +for index = 1 : number + name = char(srfileNames(index)); + if(isequal(name,'.')||... 
% remove the two hidden folders that come with the system + isequal(name,'..')) + continue; + end + disp(['-----deal with:',num2str(index),'----name:',name]); + + singlePath= [srPath, name, '/', name]; + disp(['path:',singlePath]); + singleFile=fullfile(singlePath); + srdirOutput=dir(fullfile(singleFile,'/*.png')); + singlefileNames={srdirOutput.name}'; + Band = length(singlefileNames); + source = zeros(512*512, Band); + for i = 1:Band + srName = char(singlefileNames(i)); + srImage = imread([singlePath,'/',srName]); + if i == 1 + width = size(srImage,1); + height = size(srImage,2); + end + %try + source(:,i) = srImage(:); + %catch TODO: one watermelon image errors out here, inexplicably + % disp([num2str(i),' size: ',num2str(size(srImage,1)),' ',num2str(size(srImage,2))]); + % end + end + + %% normalization + imgz=double(source(:)); + imgz=imgz./65535; + img=reshape(imgz,width*height, Band); + + %% obtain HR and LR hyperspectral images + hrImage = reshape(img, width, height, Band); + + HR = modcrop(hrImage, upscale); + LR = imresize(HR,1/upscale,'bicubic'); %LR + save([savePath,'/',name,'.mat'], 'HR', 'LR') + + clear source + clear HR + clear LR +end \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/testset_pre-processing/modcrop.m" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/testset_pre-processing/modcrop.m" new file mode 100644 index 0000000000000000000000000000000000000000..1093b19fdab29864452e2b0c87a9b79bf2fd070c --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/testset_pre-processing/modcrop.m" @@ -0,0 +1,11 @@ +function imgs = modcrop(imgs, modulo) +if size(imgs,3)==1 + sz = size(imgs); + sz = sz - mod(sz, modulo); + imgs = imgs(1:sz(1), 1:sz(2)); +else + tmpsz = size(imgs); + sz = tmpsz(1:2); + sz = sz - mod(sz, modulo); + imgs = imgs(1:sz(1), 1:sz(2),:); +end diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/train_set_augment/data_CAVE.m" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/train_set_augment/data_CAVE.m" new file mode 100644 index 0000000000000000000000000000000000000000..38f8a404bdb310199368cc86a8e4e2a6cc697578 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/train_set_augment/data_CAVE.m" @@ -0,0 +1,90 @@ +clc +clear +close all + +%% define hyperparameters +Band = 31; +patchSize = 32; +randomNumber = 24; +upscale_factor = 4; +data_type = 'CAVE'; +global count +count = 0; +imagePatch = patchSize*upscale_factor; +scales = [1.0, 0.75, 0.5]; +%% build upscale folder +savePath=['/data2/cys/data/',data_type,'/process_train/',num2str(upscale_factor),'/']; +if ~exist(savePath, 'dir') + mkdir(savePath) +end + +%% +srPath = '/data2/cys/data/CAVE/train/'; %source data downloaded from website +srFile=fullfile(srPath); +srdirOutput=dir(fullfile(srFile)); +srfileNames={srdirOutput.name}'; +number = length(srfileNames)-2 +% disp(srfileNames) + +for index = 1:length(srfileNames) + name = char(srfileNames(index)); + if(isequal(name,'.')||...
% remove the two hidden folders that come with the system + isequal(name,'..')) + continue; + end + disp(['----:',data_type,'----upscale_factor:',num2str(upscale_factor),'----deal with:',num2str(index-2),'----name:',name]); + + singlePath= [srPath, name, '/', name]; + singleFile=fullfile(singlePath); + srdirOutput=dir(fullfile(singleFile,'/*.png')); + singlefileNames={srdirOutput.name}'; + Band = length(singlefileNames); + % disp(['SinglePath:',singlePath,'Band Num: ',num2str(Band)]) + source = zeros(512*512, Band); + for i = 1:Band + srName = char(singlefileNames(i)); + srImage = imread([singlePath,'/',srName]); + if i == 1 + width = size(srImage,1); + height = size(srImage,2); + end + source(:,i) = srImage(:); + end + + %% normalization + imgz=double(source(:)); + img=imgz./65535; + t = reshape(img, width, height, Band); + + %% + for sc = 1:length(scales) + newt = imresize(t, scales(sc)); + x_random = randperm(size(newt,1) - imagePatch, randomNumber); + y_random = randperm(size(newt,2) - imagePatch, randomNumber); + + for j = 1:randomNumber + hrImage = newt(x_random(j):x_random(j)+imagePatch-1, y_random(j):y_random(j)+imagePatch-1, :); + + label = hrImage; + data_augment(label, upscale_factor, savePath); + + label = imrotate(hrImage,180); + data_augment(label, upscale_factor, savePath); + + label = imrotate(hrImage,90); + data_augment(label, upscale_factor, savePath); + + label = imrotate(hrImage,270); + data_augment(label, upscale_factor, savePath); + + label = flip(hrImage,1); % flip, not the deprecated flipdim + data_augment(label, upscale_factor, savePath); + + end + clear x_random; + clear y_random; + clear newt; + + end + clear t; +end \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/train_set_augment/data_augment.m" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/train_set_augment/data_augment.m" new file mode 100644 index 0000000000000000000000000000000000000000..6b2421789a98eeb64138df36e62d633a2873db09 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_pre_processing/train_set_augment/data_augment.m" @@ -0,0 +1,11 @@ +function data_augment(label, upscale_factor, savePath) + global count + input = imresize(label, 1/upscale_factor, 'bicubic'); + count = count+1; + count_name = num2str(count, '%05d'); + lr = permute(input, [3 1 2]); + hr = permute(label, [3 1 2]); + lr = single(lr); + hr = single(hr); + save([savePath,count_name,'.mat'],'lr','hr'); % save augmented hyperspectral image to "savePath" +end \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_utils.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_utils.py" new file mode 100644 index 0000000000000000000000000000000000000000..425b2ca7187176cb90d7e96ae86d2c7f9a8e6d5f --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/data_utils.py" @@ -0,0 +1,239 @@ +import torch +import numpy as np +import torch.utils.data as data + +from os import listdir +from os.path import join +import scipy.io as scio + + +def
is_image_file(filename): + return any(filename.endswith(extension) for extension in [".mat"]) + + +def shuffle(x, g=3, mode="origin"): + """Band shuffle""" + C, H, W = x.size() + # g = 3 # group number + # after permuting, .contiguous() is required so the tensor is contiguous in memory before view() can be called + if mode == "origin": + if C // g * g < C: + # C is not divisible by g: leave the last few bands unshuffled + split_x = x[: C // g * g, :, :] + x[: C // g * g, :, :] = ( + split_x.view(g, int(C / g), H, W) + .permute(1, 0, 2, 3) + .contiguous() + .view(C // g * g, H, W) + ) + else: + x = x.view(g, int(C / g), H, W).permute(1, 0, 2, 3).contiguous().view(C, H, W) + elif mode == "odd-even": + # split the bands into even- and odd-indexed sequences + odd_idx = [i*2+1 for i in range(C//2+1) if i*2+1 < C] + even_idx = [i*2 for i in range(C//2+1)] + idx = even_idx + odd_idx + # print(idx) + x = x[idx,:,:].view(x.size()) + return x + + +def bandshuffle(x, g=3): + # assert(len(x.size) == 4) + B, C, H, W = x.size() + if C // g * g < C: + # C is not divisible by g: leave the last few bands unshuffled + split_x = x[:, : C // g * g, :, :] + x[:, : C // g * g, :, :] = ( + split_x.view(B, g, int(C / g), H, W) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(B, C // g * g, H, W) + ) + else: + x = ( + x.view(B, g, int(C / g), H, W) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(B, C, H, W) + ) + return x + +def choose_x(input, i, mode): + """ + Select the three bands that differ most from the current band. + For now the band count is divided by three and bands are picked with that stride. + """ + Num = input.shape[1] + if mode == "jump": + x = [] + x.append(input[:, i:i+1, :, :]) + second = (i + Num//3) % Num + third = (i+2*Num//3) % Num + x.append(input[:, second:second+1, :, :]) + x.append(input[:, third:third+1, :, :]) + x = torch.cat(x, 1) + elif mode == "origin": + if i == 0: + x = input[:, 0:3, :, :] + elif i == Num - 1: + x = input[:, i-2:i+1, :, :] + else: + x = input[:, i-1:i+2, :, :] + return x + +class TrainsetFromFolder(data.Dataset): + def __init__(self, dataset_dir, shuffle, shufflemode="origin", g=3): + super(TrainsetFromFolder, self).__init__() + self.image_filenames = [ + join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x) + ] + self.shuffle = shuffle + self.g = g + self.shufflemode = shufflemode + + def __getitem__(self, index): + mat = scio.loadmat( + self.image_filenames[index], verify_compressed_data_integrity=False + ) + # mat = h5py.File(self.image_filenames[index], "r") + input = mat["lr"].astype(np.float32) + label = mat["hr"].astype(np.float32) + + # print(input.shape, label.shape) + input = torch.from_numpy(input) + label = torch.from_numpy(label) + # input shape:C,H,W + # assert(not torch.equal(input, shuffle(input))) + if self.shuffle and "random" not in self.shufflemode: + # print("shuffle") + input = shuffle(input,self.g,self.shufflemode) + label = shuffle(label,self.g,self.shufflemode) + elif self.shuffle and self.shufflemode == "random": + idx = torch.randperm(input.shape[0]) + input = input[idx,:,:].view(input.size()) + label = label[idx,:,:].view(label.size()) + elif self.shufflemode == "grouprandom": + idx = [] + C = input.shape[0] // self.g + for i in range(self.g): + idx.append(torch.randperm(C) + i * C) + if(C * self.g < input.shape[0]): + idx.append(torch.randperm(input.shape[0] - C * self.g)+input.shape[0] // self.g * self.g) + idxs = torch.cat(idx,dim=0) + input = input[idxs,:,:].view(input.size()) + label = label[idxs,:,:].view(label.size()) + input = shuffle(input,self.g) + label = shuffle(label,self.g) + return input, label + + def __len__(self): + return len(self.image_filenames) + + +class ValsetFromFolder(data.Dataset): + def __init__(self, dataset_dir, shuffle, shufflemode="origin", g=3): + super(ValsetFromFolder,
self).__init__() + self.image_filenames = [ + join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x) + ] + self.shuffle = shuffle + self.g = g + self.shufflemode = shufflemode + + def __getitem__(self, index): + mat = scio.loadmat(self.image_filenames[index]) + # mat = h5py.File(self.image_filenames[index], "r") + input = mat["LR"].astype(np.float32).transpose(2, 0, 1) + label = mat["HR"].astype(np.float32).transpose(2, 0, 1) + # print(input.shape) + input = torch.from_numpy(input).float() + label = torch.from_numpy(label).float() + if self.shuffle and "random" not in self.shufflemode: + # print("shuffle") + input = shuffle(input,self.g,self.shufflemode) + label = shuffle(label,self.g,self.shufflemode) + elif self.shuffle and self.shufflemode == "random": + idx = torch.randperm(input.shape[0]) + input = input[idx,:,:].view(input.size()) + label = label[idx,:,:].view(label.size()) + elif self.shufflemode == "grouprandom": + idx = [] + C = input.shape[0] // self.g + for i in range(self.g): + idx.append(torch.randperm(C) + i * C) + if(C * self.g < input.shape[0]): + idx.append(torch.randperm(input.shape[0] - C * self.g)+input.shape[0] // self.g * self.g) + idxs = torch.cat(idx,dim=0) + input = input[idxs,:,:].view(input.size()) + label = label[idxs,:,:].view(label.size()) + # grouprandom + input = shuffle(input,self.g) + label = shuffle(label,self.g) # group sampling + + return input, label + + def __len__(self): + return len(self.image_filenames) + + +def chop_forward(x, model, scale, shave=16): + b, c, h, w = x.size() + h_half, w_half = h // 2, w // 2 + h_size, w_size = h_half + shave, w_half + shave + inputlist = [ + x[:,:, 0:h_size, 0:w_size], + x[:,:, 0:h_size, (w - w_size):w], + x[:,:, (h - h_size):h, 0:w_size], + x[:,:, (h - h_size):h, (w - w_size):w]] + outputlist = [] + for i in range(4): + input_batch = inputlist[i] + output_batch = model(input_batch) + # print("patch shape:",output_batch.shape) + outputlist.append(output_batch) + + output = np.zeros((c, h*scale, w*scale)).astype(np.float32) + # print("output shape: ",output.shape) + output[:, 0:h_half*scale, 0:w_half*scale] = outputlist[0][0, :, 0:h_half*scale, 0:w_half*scale].cpu().numpy() + output[:, 0:h_half*scale, w_half*scale:w*scale] = outputlist[1][0, :, 0:h_half*scale, (w_size - w + w_half)*scale:w_size*scale].cpu().numpy() + output[:, h_half*scale:h*scale, 0:w_half*scale] = outputlist[2][0, :, (h_size - h + h_half)*scale:h_size*scale, 0:w_half*scale].cpu().numpy() + output[:, h_half*scale:h*scale, w_half*scale:w*scale] = outputlist[3][0, :, (h_size - h + h_half)*scale:h_size*scale, (w_size - w + w_half)*scale:w_size*scale].cpu().numpy() + + return output + + +def rand_bbox(size,lam): + W = size[2] + H = size[3] + + cut_rat = np.sqrt(1.
- lam) + cut_w = int(W * cut_rat) + cut_h = int(H * cut_rat) + + cx = np.random.randint(W) + cy = np.random.randint(H) + + bbx1 = np.clip(cx - cut_w // 2, 0, W) + bby1 = np.clip(cy - cut_h // 2, 0, H) + bbx2 = np.clip(cx + cut_w // 2, 0, W) + bby2 = np.clip(cy + cut_h // 2, 0, H) + + return bbx1,bby1,bbx2,bby2 + +def CutMix(input,label,beta = 0.5): + lam = np.random.beta(beta,beta) + rand_index = torch.randperm(input.size()[0]).cuda() + bbx1,bby1,bbx2,bby2 = rand_bbox(input.size(),lam) + + s = int(label.size()[2] / input.size()[2]) + # print(s) + + # B,C,H,W + input[:,:,bbx1:bbx2,bby1:bby2] = input[rand_index,:,bbx1:bbx2,bby1:bby2] + label[:,:,bbx1*s:bbx2*s,bby1*s:bby2*s] = label[rand_index,:,bbx1*s:bbx2*s,bby1*s:bby2*s] + return input,label \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/eval.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/eval.py" new file mode 100644 index 0000000000000000000000000000000000000000..b47cea81633835099e312110240be2ad280dc613 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/eval.py" @@ -0,0 +1,147 @@ +import math +import cv2 +import numpy as np +from scipy.signal import convolve2d +import torch + + +def PSNR(pred, gt): + valid = gt - pred + rmse = math.sqrt(np.mean(valid ** 2)) + + if rmse == 0: + return 100 + psnr = 20 * math.log10(1.0 / rmse) + return psnr + + +def SSIM(pred, gt): + ssim = 0 + for i in range(gt.shape[0]): + ssim = ssim + compute_ssim(pred[i, :, :], gt[i, :, :]) + return ssim / gt.shape[0] + + +def SAM(pred, gt): + # Shape N,H,W + eps = 2.2204e-16 + pred[np.where(pred == 0)] = eps + gt[np.where(gt == 0)] = eps + + nom = sum(pred*gt) + denom1 = sum(pred*pred)**0.5 + denom2 = sum(gt*gt)**0.5 + sam = np.real(np.arccos(nom.astype(np.float32)/(denom1*denom2+eps))) + sam[np.isnan(sam)] = 0 + sam_sum = np.mean(sam)*180/np.pi + + return sam_sum + + +def cal_sam(Itrue, Ifake): + if len(Itrue.shape) == 3: + Itrue = Itrue.unsqueeze(0) + Ifake = Ifake.unsqueeze(0) + # print(Itrue.shape) B,N,H,W + esp = 2.2204e-16 + InnerPro = torch.sum(Itrue*Ifake, 1, keepdim=True) + len1 = torch.norm(Itrue, p=2, dim=1, keepdim=True) + len2 = torch.norm(Ifake, p=2, dim=1, keepdim=True) + divisor = len1*len2 + mask = torch.eq(divisor, 0) + divisor = divisor + (mask.float())*esp + cosA = torch.sum(InnerPro/divisor, 1).clamp(-1+esp, 1-esp) + sam = torch.acos(cosA) + sam = torch.mean(sam)*180 / np.pi + return sam + + +def matlab_style_gauss2D(shape=np.array([11, 11]), sigma=1.5): + """ + 2D gaussian mask - should give the same result as MATLAB's + fspecial('gaussian',[shape],[sigma]) + """ + siz = (shape-np.array([1, 1]))/2 + std = sigma + eps = 2.2204e-16 + x = np.arange(-siz[1], siz[1]+1, 1) + y = np.arange(-siz[0], siz[0]+1, 1) + m, n = np.meshgrid(x, y) + + h = np.exp(-(m*m + n*n).astype(np.float32) / (2.*sigma*sigma)) + h[h < eps*h.max()] = 0 + sumh = h.sum() + + if sumh != 0: + h = h.astype(np.float32) / sumh + return h + + +def filter2(x, kernel, mode='same'): + return convolve2d(x, np.rot90(kernel, 2), mode=mode) + + +def compute_ssim(im1, im2, k1=0.01, k2=0.03, win_size=11, L=1): + + if not im1.shape == im2.shape: + raise ValueError("Input Images must have the same dimensions") + if len(im1.shape) > 2: + raise ValueError("Please input the images with 1 channel") + + M, N = im1.shape + C1 =
(k1*L)**2 + C2 = (k2*L)**2 + window = matlab_style_gauss2D( + shape=np.array([win_size, win_size]), sigma=1.5) + window = window.astype(np.float32)/np.sum(np.sum(window)) + + mu1 = filter2(im1, window, 'valid') + mu2 = filter2(im2, window, 'valid') + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = filter2(im1*im1, window, 'valid') - mu1_sq + sigma2_sq = filter2(im2*im2, window, 'valid') - mu2_sq + sigmal2 = filter2(im1*im2, window, 'valid') - mu1_mu2 + + ssim_map = ((2*mu1_mu2+C1) * (2*sigmal2+C2)).astype(np.float32) / \ + ((mu1_sq+mu2_sq+C1) * (sigma1_sq+sigma2_sq+C2)) + + return np.mean(np.mean(ssim_map)) + + +def constrast(img): + # expects a grayscale image + m,n = img.shape + b = np.sum(np.power(img[:,1:] - img[:,:n-1],2)) + np.sum(np.power(img[1:,:]-img[:m-1,:],2)) + return b/((m-1)*n+(n-1)*m) + +def Bconstrast(img): + # img: B,N,H,W -- > B,N,1 + B,N,H,W = img.size() + # print(img.shape) + b = torch.sum((img[:,:,:,1:]-img[:,:,:,:W-1])*(img[:,:,:,1:]-img[:,:,:,:W-1]),dim=(2,3),keepdim=True) \ + + torch.sum((img[:,:,1:,:]-img[:,:,:H-1,:])*(img[:,:,1:,:]-img[:,:,:H-1,:]),dim=(2,3),keepdim=True) + b = b / ((H-1)*W + (W-1)*H) + # print(b.shape) + return b + + + # img_ext = cv2.copyMakeBorder(img,1,1,1,1,cv2.BORDER_REPLICATE)/1.0 + # rows_ext,cols_ext = img_ext.shape + # b = 0.0 + # for i in range(1,rows_ext- 1): + # for j in range(1,cols_ext-1): + # b += ((img_ext[i,j]-img_ext[i,j+1])**2 + (img_ext[i,j]-img_ext[i,j-1])**2+ + # (img_ext[i,j]-img_ext[i+1,j])**2 + (img_ext[i,j]-img_ext[i-1,j])**2) + + # cg = b/(4*(m-2)*(n-2)+3*(2*(m-2)+2*(n-2))+2*4) + # return cg + # return b + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/HLoss.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/HLoss.py" new file mode 100644 index 0000000000000000000000000000000000000000..d5cab1928e5bd0f1baf616f9bb984779c9a4c23b --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/HLoss.py" @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# coding: utf-8 + + +import numpy as np +from scipy.signal import convolve2d +import torch +import torch.nn.functional as F + +def cal_gradient_c(x): + c_x = x.size(1) + g = x[:, 1:, 1:, 1:] - x[:, :c_x - 1, 1:, 1:] + return g + +def cal_gradient_x(x): + c_x = x.size(2) + g = x[:, 1:, 1:, 1:] - x[:, 1:, :c_x - 1, 1:] + return g + +def cal_gradient_y(x): + c_x = x.size(3) + g = x[:, 1:, 1:, 1:] - x[:, 1:, 1:, :c_x - 1] + return g + +def cal_gradient(inp): + x = cal_gradient_x(inp) + y = cal_gradient_y(inp) + c = cal_gradient_c(inp) + g = torch.sqrt(torch.pow(x, 2) + torch.pow(y,2) + torch.pow(c,2)+1e-6) + return g + +def cal_sam(Itrue, Ifake): + esp = 1e-6 + InnerPro = torch.sum(Itrue*Ifake,1,keepdim=True) + len1 = torch.norm(Itrue, p=2,dim=1,keepdim=True) + len2 = torch.norm(Ifake, p=2,dim=1,keepdim=True) + divisor = len1*len2 + mask = torch.eq(divisor,0) + divisor = divisor + (mask.float())*esp + cosA = torch.sum(InnerPro/divisor,1).clamp(-1+esp, 1-esp) + sam = torch.acos(cosA) + sam = torch.mean(sam) / np.pi + return sam + +class HLoss(torch.nn.Module): + def __init__(self, la1=0.5, la2=0.1, sam=True, gra=True): + super(HLoss,self).__init__() + self.lamd1 = la1 + self.lamd2 = la2 + self.use_sam = sam + self.use_gra = gra + + self.fidelity = torch.nn.L1Loss() + self.gra = torch.nn.L1Loss() + + def forward(self, y, gt): + loss1
= self.fidelity(y, gt) + loss2 = self.lamd1*cal_sam(y, gt) + loss3 = self.lamd2*self.gra(cal_gradient(y),cal_gradient(gt)) + loss = loss1+loss2+loss3 + return loss + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/HybridLoss.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/HybridLoss.py" new file mode 100644 index 0000000000000000000000000000000000000000..5e720b0075dce480feabfdfd12f42982d9b4bab3 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/HybridLoss.py" @@ -0,0 +1,62 @@ +import torch + + +class HybridLoss(torch.nn.Module): + def __init__(self, lamd=1e-1, spatial_tv=False, spectral_tv=False): + super(HybridLoss, self).__init__() + self.lamd = lamd + self.use_spatial_TV = spatial_tv + self.use_spectral_TV = spectral_tv + self.fidelity = torch.nn.L1Loss() + self.spatial = TVLoss(weight=1e-3) + self.spectral = TVLossSpectral(weight=1e-3) + + def forward(self, y, gt): + loss = self.fidelity(y, gt) + spatial_TV = 0.0 + spectral_TV = 0.0 + if self.use_spatial_TV: + spatial_TV = self.spatial(y) + if self.use_spectral_TV: + spectral_TV = self.spectral(y) + total_loss = loss + spatial_TV + spectral_TV + return total_loss + + +# from https://github.com/jxgu1016/Total_Variation_Loss.pytorch with slight modifications +class TVLoss(torch.nn.Module): + def __init__(self, weight=1.0): + super(TVLoss, self).__init__() + self.TVLoss_weight = weight + + def forward(self, x): + batch_size = x.size()[0] + h_x = x.size()[2] + w_x = x.size()[3] + count_h = self._tensor_size(x[:, :, 1:, :]) + count_w = self._tensor_size(x[:, :, :, 1:]) + # h_tv = torch.abs(x[:, :, 1:, :] - x[:, :, :h_x - 1, :]).sum() + # w_tv = torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x - 1]).sum() + h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() + w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() + return self.TVLoss_weight * (h_tv / count_h + w_tv / count_w) / batch_size + + def _tensor_size(self, t): + return t.size()[1] * t.size()[2] * t.size()[3] + + +class TVLossSpectral(torch.nn.Module): + def __init__(self, weight=1.0): + super(TVLossSpectral, self).__init__() + self.TVLoss_weight = weight + + def forward(self, x): + batch_size = x.size()[0] + c_x = x.size()[1] + count_c = self._tensor_size(x[:, 1:, :, :]) + # c_tv = torch.abs((x[:, 1:, :, :] - x[:, :c_x - 1, :, :])).sum() + c_tv = torch.pow((x[:, 1:, :, :] - x[:, :c_x - 1, :, :]), 2).sum() + return self.TVLoss_weight * 2 * (c_tv / count_c) / batch_size + + def _tensor_size(self, t): + return t.size()[1] * t.size()[2] * t.size()[3] \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/PWL1Loss.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/PWL1Loss.py" new file mode 100644 index 0000000000000000000000000000000000000000..25d9428a8b9e567e0779cc4281ddca1d402e97b4 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/PWL1Loss.py" @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + + +class PWL1Loss(nn.Module): + def __init__(self, ): + super(PWL1Loss, self).__init__() + self.gra = torch.nn.L1Loss() + + def forward(self, SR, 
GT): + """ + SR shape: B,band,H,W + """ + diff = self.gra(SR*GT,GT*GT) + return diff + # diff = torch.abs(SR - GT) + # if len(GT.shape) == 4: + # Batch, Band, H, W = GT.shape + # Weight = nn.Softmax2d()(GT) + # # Weight = nn.Softmax(dim=2)(GT.view(Batch, Band, -1)) + # # Weight = Weight.view(Batch, Band, H, W) + # elif len(GT.shape) == 3: + # Batch, H, W = GT.shape + # Weight = nn.Softmax(dim=1)(GT.view(Batch, -1)) + # Weight = Weight.view(Batch, H, W) + # else: + # raise Exception("shape error!") + # # Weight = GT + # Loss = torch.sum(Weight.mul(diff)) + # return Loss diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/SAMLoss.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/SAMLoss.py" new file mode 100644 index 0000000000000000000000000000000000000000..a83b328b08c59921b852c7c4bf89acc1c9478b48 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/SAMLoss.py" @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn +import numpy as np + + +def cal_sam(Itrue, Ifake): + esp = 2.2204e-16 + InnerPro = torch.sum(Itrue*Ifake, 1, keepdim=True) + len1 = torch.norm(Itrue, p=2, dim=1, keepdim=True) + len2 = torch.norm(Ifake, p=2, dim=1, keepdim=True) + divisor = len1*len2 + mask = torch.eq(divisor, 0) + divisor = divisor + (mask.float())*esp + cosA = torch.sum(InnerPro/divisor, 1).clamp(-1+esp, 1-esp) + sam = torch.acos(cosA) + sam = torch.mean(sam)*180 / np.pi + return sam + + +class SAMLoss(nn.Module): + def __init__(self, ): + super(SAMLoss, self).__init__() + self.lam = 0.001 + + def forward(self, SR, GT): + """ + Spectral angle mapper loss: the mean SAM value over the batch. + SR shape: B,C,H,W + """ + # eps = 2.2204e-16 + # # pred[np.where(pred==0)] = eps + # # gt[np.where(gt==0)] = eps + + # nom = torch.sum(GT * SR, dim=0) + # denom1 = torch.sum(SR*SR, dim=0)**0.5 + # denom2 = torch.sum(GT*GT, dim=0)**0.5 + # sam = torch.acos(nom/(denom1*denom2+eps)) + # # sam[.isnan(sam)] = 0 + # Loss = torch.mean(sam)*180/np.pi + return self.lam * cal_sam(GT, SR) diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__init__.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__init__.py" new file mode 100644 index 0000000000000000000000000000000000000000..06312ad97c260873f05e43b914b57ed5dc24d3b9 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__init__.py" @@ -0,0 +1,44 @@ +import torch.nn as nn +from loss.PWL1Loss import PWL1Loss +from loss.SAMLoss import SAMLoss +from loss.HLoss import HLoss +from loss.HybridLoss import HybridLoss +from loss.gradient import GLoss + +class Loss(): + def __init__(self, opt): + loss_list = opt.loss + self.criterions = {} + for loss_name in loss_list: + if loss_name == "L1": + self.criterions["L1"] = nn.L1Loss() + elif loss_name == "SAM": + self.criterions["SAM"] = SAMLoss() + elif loss_name == "PWL1": + self.criterions["PWL1"] = PWL1Loss() + elif loss_name == "HLoss": + self.criterions["HLoss"] = HLoss() + elif loss_name == "HybridLoss": + self.criterions["HybridLoss"] = HybridLoss() + elif loss_name == "GLoss": + self.criterions["GLoss"] = GLoss() + self.Num = len(loss_list) + + if opt.cuda: + for key, value in self.criterions.items(): +
self.criterions[key] = self.criterions[key].cuda() + + def loss(self, SR, GT): + """ + Aggregate multiple losses. + """ + losses = [] + for _, value in self.criterions.items(): + # if key != "SAM" or epoch >= self.SAMepoch: + loss_i = value(SR, GT) + # print(loss_i) + losses.append(loss_i) + final_Loss = losses[0] + for i in range(1, self.Num): + final_Loss += losses[i] + return final_Loss diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/HLoss.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/HLoss.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..26886b4d239fa08dad1372681d982486df5ccd97 Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/HLoss.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/HybridLoss.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/HybridLoss.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..c7fa641115cc8233372020675fdd2bd28a80317a Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/HybridLoss.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/PWL1Loss.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/PWL1Loss.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..5b4126c9dfbbb3b1191b3a121cc520132c00724b Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/PWL1Loss.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/PWL1Loss.cpython-38.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/PWL1Loss.cpython-38.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..40492d163627c1e9af43e63eab596e2258191bb5 Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/PWL1Loss.cpython-38.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/SAMLoss.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/SAMLoss.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..7938ba27a5d5b9c51230722c962b046e2c5d7d6f Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/SAMLoss.cpython-36.pyc" differ diff
--git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/SAMLoss.cpython-38.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/SAMLoss.cpython-38.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..b965235cfd69d549cbfe9c40750c280ba4dc540d Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/SAMLoss.cpython-38.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/__init__.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/__init__.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..61cdce66678ac8ff94b1d9d47a6d95d2e5151913 Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/__init__.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/__init__.cpython-38.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/__init__.cpython-38.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..83c6a0c0d9188aa6c499b1d92916ee8d8deeaa15 Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/__init__.cpython-38.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/gradient.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/gradient.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..7da92fea9c0bc4b5fc14feda730c81fcd5e0dd71 Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/__pycache__/gradient.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/gradient.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/gradient.py" new file mode 100644 index 0000000000000000000000000000000000000000..fe981d9591fbf6f193a835f593826594933cac89 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/loss/gradient.py" @@ -0,0 +1,39 @@ +import numpy as np +from scipy.signal import convolve2d +import torch +import torch.nn.functional as F + + +def cal_gradient_c(x): + c_x = x.size(1) + g = x[:, 1:, 1:, 1:] - x[:, :c_x - 1, 1:, 1:] + return g + +def cal_gradient_x(x): + c_x = x.size(2) + g = x[:, 1:, 1:, 1:] - x[:, 1:, :c_x - 1, 1:] + return g + +def cal_gradient_y(x): + c_x = x.size(3) + g = x[:, 1:, 1:, 1:] - x[:, 1:, 1:, :c_x - 1] + return g + +def cal_gradient(inp): + x = 
cal_gradient_x(inp) + y = cal_gradient_y(inp) + c = cal_gradient_c(inp) + g = torch.sqrt(torch.pow(x, 2) + torch.pow(y,2) + torch.pow(c,2)+1e-6) + return g + + + +class GLoss(torch.nn.Module): + def __init__(self,lamd2=0.5): + super(GLoss,self).__init__() + self.gra = torch.nn.L1Loss() + self.lamd2 = lamd2 + + def forward(self, y, gt): + loss = self.lamd2*self.gra(cal_gradient(y),cal_gradient(gt)) + return loss \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/Bicubic.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/Bicubic.py" new file mode 100644 index 0000000000000000000000000000000000000000..bb7a904f815f93163d519be084ba6403481c27dd --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/Bicubic.py" @@ -0,0 +1,16 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class Bicubic(nn.Module): + def __init__(self, args): + super(Bicubic, self).__init__() + self.scale = args.upscale_factor + + def forward(self, x, h=None, i=None): + # x shape: B,3,H,W --->B,1,3,H,W --> B,N,3,H,W + x = F.interpolate(x, scale_factor=self.scale, mode="bicubic").clamp( + min=0, max=1 + ) + return x + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/EDSR.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/EDSR.py" new file mode 100644 index 0000000000000000000000000000000000000000..70ff9ca159fcc8c54c0e6d9b2bf7b07b1de4b403 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/EDSR.py" @@ -0,0 +1,69 @@ +from model import common2 + +import torch.nn as nn + + +class EDSR(nn.Module): + def __init__(self, args, conv=common2.default_conv): + super(EDSR, self).__init__() + + n_resblocks = 16 + n_feats = 128 + kernel_size = 3 + scale = args.upscale_factor + act = nn.ReLU(True) + # self.sub_mean = common2.MeanShift(255) + # self.add_mean = common2.MeanShift(255, sign=1) + + # define head module + m_head = [conv(args.band, n_feats, kernel_size)] + + # define body module + m_body = [ + common2.ResBlock( + conv, n_feats, kernel_size, act=act, res_scale=0.1 + ) for _ in range(n_resblocks) + ] + m_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + m_tail = [ + common2.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats, args.band, kernel_size) + ] + + self.head = nn.Sequential(*m_head) + self.body = nn.Sequential(*m_body) + self.tail = nn.Sequential(*m_tail) + + def forward(self, x): + # x = self.sub_mean(x) + x = self.head(x) + + res = self.body(x) + res += x + + x = self.tail(res) + # x = self.add_mean(x) + + return x + + def load_state_dict(self, state_dict, strict=True): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') == -1: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/ERCSR.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/ERCSR.py" new file mode 100644 index 0000000000000000000000000000000000000000..d675557908e372c4c212879138628f68a5ba6e16 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/ERCSR.py" @@ -0,0 +1,254 @@ +import torch +import torch.nn as nn + + +def _to_4d_tensor(x, depth_stride=None): + """Converts a 5d tensor to 4d by stacking + the batch and depth dimensions.""" + x = x.transpose(0, 2) # swap batch and depth dimensions: NxCxDxHxW => DxCxNxHxW + if depth_stride: + x = x[::depth_stride] # downsample feature maps along depth dimension + depth = x.size()[0] + x = x.permute(2, 0, 1, 3, 4) # DxCxNxHxW => NxDxCxHxW + x = torch.split( + x, 1, dim=0 + ) # split along batch dimension: NxDxCxHxW => N*[1xDxCxHxW] + x = torch.cat( + x, 1 + ) # concatenate along depth dimension: N*[1xDxCxHxW] => 1x(N*D)xCxHxW + x = x.squeeze(0) # 1x(N*D)xCxHxW => (N*D)xCxHxW + return x, depth + + +def _to_5d_tensor(x, depth): + """Converts a 4d tensor back to 5d by splitting + the batch dimension to restore the depth dimension.""" + x = torch.split(x, depth) # (N*D)xCxHxW => N*[DxCxHxW] + x = torch.stack(x, dim=0) # re-instate the batch dimension: NxDxCxHxW + x = x.transpose( + 1, 2 + ) # swap back depth and channel dimensions: NxDxCxHxW => NxCxDxHxW + return x + + +class twoUint(nn.Module): + def __init__(self, wn, n_feats): + super(twoUint, self).__init__() + self.relu = nn.ReLU(inplace=True) + + self.conv1 = wn( + nn.Conv2d(n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1)) + ) + self.conv2 = wn( + nn.Conv2d(n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1)) + ) + + def forward(self, x): + + out = self.conv1(x) + out = self.relu(out) + out = self.conv2(out) + out = torch.add(x, out) + + return out + + +class E_HCM(nn.Module): + def __init__(self, wn, n_feats, n_twoUint): + super(E_HCM, self).__init__() + self.relu = nn.ReLU(inplace=True) + + self.conv1 = wn( + nn.Conv3d( + n_feats, n_feats, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1) + ) + ) + self.conv2 = wn( + nn.Conv3d( + n_feats, n_feats, kernel_size=(3, 3, 1), stride=1, padding=(1, 1, 0) + ) + ) + self.conv3 = wn( + nn.Conv3d( + n_feats, n_feats, kernel_size=(3, 1, 3), stride=1, padding=(1, 0, 1) + ) + ) + + twoD_body = [twoUint(wn, n_feats) for _ in range(n_twoUint)] + + self.twoD_body = nn.Sequential(*twoD_body) + + def forward(self, x): + + out = self.conv1(x) + out = self.relu(out) + out = torch.add(self.conv2(out), self.conv3(out)) + + t = out + out, depth = _to_4d_tensor(out, depth_stride=1) + + out = self.twoD_body(out) + + out = _to_5d_tensor(out, depth) + out = torch.add(out, t) + out = torch.add(out, x) + + return out + + +class ERCSR(nn.Module): + def __init__(self, args): + super(ERCSR, self).__init__() + + scale = args.upscale_factor + n_colors = args.band + n_feats = args.n_feats + self.n_E_HCM = 4 + n_twoUint = 2 + + band_mean = ( + 0.0939, + 0.0950, + 0.0869, + 0.0839, + 0.0850, + 0.0809, + 0.0769, + 0.0762, + 0.0788, + 0.0790, + 0.0834, + 0.0894, + 0.0944, + 0.0956, + 0.0939, + 0.1187, + 
0.0903, + 0.0928, + 0.0985, + 0.1046, + 0.1121, + 0.1194, + 0.1240, + 0.1256, + 0.1259, + 0.1272, + 0.1291, + 0.1300, + 0.1352, + 0.1428, + 0.1541, + ) # CAVE + # band_mean = (0.0100, 0.0137, 0.0219, 0.0285, 0.0376, 0.0424, 0.0512, 0.0651, 0.0694, 0.0723, 0.0816, + # 0.0950, 0.1338, 0.1525, 0.1217, 0.1187, 0.1337, 0.1481, 0.1601, 0.1817, 0.1752, 0.1445, + # 0.1450, 0.1378, 0.1343, 0.1328, 0.1303, 0.1299, 0.1456, 0.1433, 0.1303) #Hararvd + + # band_mean = (0.0483, 0.0400, 0.0363, 0.0373, 0.0425, 0.0520, 0.0559, 0.0539, 0.0568, 0.0564, 0.0591, + # 0.0678, 0.0797, 0.0927, 0.0986, 0.1086, 0.1086, 0.1015, 0.0994, 0.0947, 0.0980, 0.0973, + # 0.0925, 0.0873, 0.0887, 0.0854, 0.0844, 0.0833, 0.0823, 0.0866, 0.1171, 0.1538, 0.1535) #Foster + + wn = lambda x: torch.nn.utils.weight_norm(x) + # self.band_mean = torch.autograd.Variable(torch.FloatTensor(band_mean)).view( + # [1, n_colors, 1, 1] + # ) + self.relu = nn.ReLU(inplace=True) + + head = [] + head.append( + wn( + nn.Conv3d( + 1, n_feats, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1) + ) + ) + ) + head.append( + wn( + nn.Conv3d( + n_feats, n_feats, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0) + ) + ) + ) + self.head = nn.Sequential(*head) + + self.nearest = nn.Upsample(scale_factor=scale, mode="nearest") + + body = [E_HCM(wn, n_feats, n_twoUint) for _ in range(self.n_E_HCM)] + + self.body = nn.Sequential(*body) + + self.reduceD = wn( + nn.Conv3d(n_feats * self.n_E_HCM, n_feats, kernel_size=(1, 1, 1), stride=1) + ) + self.gamma = nn.Parameter(torch.ones(self.n_E_HCM)) + + end = [] + end.append( + wn( + nn.Conv3d( + n_feats, n_feats, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1) + ) + ) + ) + end.append( + wn( + nn.Conv3d( + n_feats, n_feats, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0) + ) + ) + ) + self.end = nn.Sequential(*end) + + tail = [] + tail.append( + wn( + nn.ConvTranspose3d( + n_feats, + n_feats, + kernel_size=(3, 2 + scale, 2 + scale), + stride=(1, scale, scale), + padding=(1, 1, 1), + ) + ) + ) + tail.append( + wn( + nn.Conv3d( + n_feats, n_feats, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1) + ) + ) + ) + tail.append( + wn( + nn.Conv3d( + n_feats, 1, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0) + ) + ) + ) + self.tail = nn.Sequential(*tail) + + def forward(self, x): + + # x = x - self.band_mean.cuda() + CSKC = self.nearest(x) + x = x.unsqueeze(1) + x = self.head(x) + LSC = x + + H = [] + for i in range(self.n_E_HCM): + x = self.body[i](x) + H.append(x * self.gamma[i]) + + x = torch.cat(H, 1) + + x = self.reduceD(x) + x = self.end(x) + x = torch.add(x, LSC) + + x = self.tail(x) + x = x.squeeze(1) + + x = torch.add(x, CSKC) + # x = x + self.band_mean.cuda() + return x + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/GDRRN.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/GDRRN.py" new file mode 100644 index 0000000000000000000000000000000000000000..a53f448679553d3f7dae22a9e19d7c53744dd5d9 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/GDRRN.py" @@ -0,0 +1,179 @@ +import torch +import numpy as np +from math import sqrt +import torch.nn as nn +import torch.nn.functional as F + +class Conv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, groups=1, bias=False): + super(Conv, self).__init__() + if 
padding == None: + if stride == 1: + padding = (kernel_size-1)//2 + else: + padding = 0 + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups, bias=bias) + + def forward(self, x): + return self.conv(x) + +class Conv_ReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, groups=1, bias=True): + super(Conv_ReLU, self).__init__() + if padding == None: + if stride == 1: + padding = (kernel_size-1)//2 + else: + padding = 0 + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups, bias=bias) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + return self.relu(x) + +class Conv_BN(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, groups=1, bias=False): + super(Conv_BN, self).__init__() + if padding == None: + if stride == 1: + padding = (kernel_size-1)//2 + else: + padding = 0 + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + return self.bn(self.conv(x)) + +class Conv_BN_ReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, groups=1, bias=False): + super(Conv_BN_ReLU, self).__init__() + if padding == None: + if stride == 1: + padding = (kernel_size-1)//2 + else: + padding = 0 + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + + +class Denoise_Block_BN(nn.Module): + def __init__(self, input_chnl, output_chnl=None,inner_chnl=64, padding=1, num_of_layers=15, groups=1): + super(Denoise_Block_BN, self).__init__() + kernel_size = 3 + num_chnl = inner_chnl + if output_chnl is None: + output_chnl = input_chnl + self.conv_input = nn.Sequential(Conv_BN_ReLU(in_channels=input_chnl, out_channels=num_chnl, kernel_size=kernel_size, padding=padding, groups=groups)) + self.conv_layers = self._make_layers(Conv_BN_ReLU,num_chnl=num_chnl, kernel_size=kernel_size, padding=padding, num_of_layers=num_of_layers-2, groups=groups) + self.conv_out = nn.Sequential(Conv_BN_ReLU(in_channels=num_chnl, out_channels=output_chnl, kernel_size=kernel_size, padding=padding, groups=groups)) + + def _make_layers(self, block, num_chnl, kernel_size, padding, num_of_layers, groups=1): + layers = [] + for _ in range(num_of_layers): + layers.append(block(in_channels=num_chnl, out_channels=num_chnl, kernel_size=kernel_size, padding=padding, groups=groups)) + return nn.Sequential(*layers) + + def forward(self, x): + return self.conv_out(self.conv_layers(self.conv_input(x))) + +class DnCNN(nn.Module): + def __init__(self, input_chnl, groups=1): + super(DnCNN, self).__init__() + kernel_size = 3 + num_chnl = 64 + self.conv1 = nn.Sequential(nn.Conv2d(in_channels=input_chnl, out_channels=num_chnl, + kernel_size=kernel_size, stride=1, padding=1, + groups=1, bias=True), + nn.ReLU(inplace=True)) + self.dn_block = self._make_layers(Conv_BN_ReLU, kernel_size, num_chnl, num_of_layers=15, bias=False) + # self.output = nn.Sequential(nn.Conv2d(in_channels=num_chnl, 
out_channels=input_chnl, + # kernel_size=kernel_size, stride=1, padding=1, + # groups=groups, bias=True), + # nn.BatchNorm2d(input_chnl)) + self.output = nn.Conv2d(in_channels=num_chnl, out_channels=input_chnl, + kernel_size=kernel_size, stride=1, padding=1, + groups=groups, bias=True) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5) + if isinstance(m, nn.BatchNorm2d): + m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5) + clip_b = 0.025 + w = m.weight.data.shape[0] + for j in range(w): + if m.weight.data[j] >= 0 and m.weight.data[j] < clip_b: + m.weight.data[j] = clip_b + elif m.weight.data[j] > -clip_b and m.weight.data[j] < 0: + m.weight.data[j] = -clip_b + m.running_var.fill_(0.01) + + def _make_layers(self, block, kernel_size, num_chnl, num_of_layers, padding=1, groups=1, bias=False): + layers = [] + for _ in range(num_of_layers): + layers.append(block(in_channels=num_chnl, out_channels=num_chnl, kernel_size=kernel_size, padding=padding, groups=groups, bias=bias)) + return nn.Sequential(*layers) + + def forward(self, x): + residual = x + x = self.conv1(x) + # x = self.nl(x) + x = self.dn_block(x) + return self.output(x) #+ residual + +class GDRRN(nn.Module): + def __init__(self, opt): + super(GDRRN, self).__init__() + input_chnl_hsi = opt.band + self.scale = opt.upscale_factor + num_chnl = 128 + group = 2 + self.input = nn.Conv2d(in_channels=input_chnl_hsi, out_channels=num_chnl, kernel_size=3, stride=1, padding=1, bias=False, groups=1) + self.conv1 = nn.Conv2d(in_channels=num_chnl, out_channels=num_chnl, kernel_size=3, stride=1, padding=1, bias=False, groups=group) + self.conv2 = nn.Conv2d(in_channels=num_chnl, out_channels=num_chnl, kernel_size=3, stride=1, padding=1, bias=False, groups=group) + self.output = nn.Conv2d(in_channels=num_chnl, out_channels=input_chnl_hsi, kernel_size=3, stride=1, padding=1, bias=False, groups=1) + self.relu = nn.ReLU(inplace=True) + + # weights initialization + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, sqrt(2. 
/ n)) + + def forward(self, x): + x = F.interpolate(x, scale_factor=self.scale, mode="bicubic").clamp( + min=0, max=1 + ) + residual = x + inputs = self.input(self.relu(x)) + out = inputs + # rnd = np.array([j for j in range(self.conv1.out_channels)]) + # for i in range(self.conv1.groups): + # rnd[i * (self.conv1.out_channels // self.conv1.groups):(i + 1) * ( + # self.conv1.out_channels // self.conv1.groups)] = \ + # np.arange(i, self.conv1.out_channels, self.conv1.groups) + for _ in range(9): + skip_x = out + out = self.conv1(self.relu(out)) + # out.data = out.data[:, rnd, :, :] + out = self.conv2(self.relu(out)) + out = out + skip_x + out = torch.add(out, inputs) + + out = self.output(self.relu(out)) + out = torch.add(out, residual) + return out \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/MCNet.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/MCNet.py" new file mode 100644 index 0000000000000000000000000000000000000000..80899562d4df5645eaee2980a2f6a9e5aeb7b0c8 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/MCNet.py" @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn +import pdb + + +class BasicConv3d(nn.Module): + def __init__( + self, wn, in_channel, out_channel, kernel_size, stride, padding=(0, 0, 0) + ): + super(BasicConv3d, self).__init__() + self.conv = wn( + nn.Conv3d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + + x = self.conv(x) + x = self.relu(x) + return x + + +class S3Dblock(nn.Module): + def __init__(self, wn, n_feats): + super(S3Dblock, self).__init__() + + self.conv = nn.Sequential( + BasicConv3d( + wn, n_feats, n_feats, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1) + ), + BasicConv3d( + wn, n_feats, n_feats, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0) + ), + ) + + def forward(self, x): + + return self.conv(x) + + +def _to_4d_tensor(x, depth_stride=None): + """Converts a 5d tensor to 4d by stacking + the batch and depth dimensions.""" + x = x.transpose(0, 2) # swap batch and depth dimensions: NxCxDxHxW => DxCxNxHxW + if depth_stride: + x = x[::depth_stride] # downsample feature maps along depth dimension + depth = x.size()[0] + x = x.permute(2, 0, 1, 3, 4) # DxCxNxHxW => NxDxCxHxW + x = torch.split( + x, 1, dim=0 + ) # split along batch dimension: NxDxCxHxW => N*[1xDxCxHxW] + x = torch.cat( + x, 1 + ) # concatenate along depth dimension: N*[1xDxCxHxW] => 1x(N*D)xCxHxW + x = x.squeeze(0) # 1x(N*D)xCxHxW => (N*D)xCxHxW + return x, depth + + +def _to_5d_tensor(x, depth): + """Converts a 4d tensor back to 5d by splitting + the batch dimension to restore the depth dimension.""" + x = torch.split(x, depth) # (N*D)xCxHxW => N*[DxCxHxW] + x = torch.stack(x, dim=0) # re-instate the batch dimension: NxDxCxHxW + x = x.transpose( + 1, 2 + ) # swap back depth and channel dimensions: NxDxCxHxW => NxCxDxHxW + return x + + +class Block(nn.Module): + def __init__(self, wn, n_feats, n_conv): + super(Block, self).__init__() + + # self.relu = nn.ReLU(inplace=True) + self.relu = nn.ReLU() + + Block1 = [] + for i in range(n_conv): + Block1.append(S3Dblock(wn, n_feats)) + self.Block1 = nn.Sequential(*Block1) + + Block2 = [] + for i in range(n_conv): +
Block2.append(S3Dblock(wn, n_feats)) + self.Block2 = nn.Sequential(*Block2) + + Block3 = [] + for i in range(n_conv): + Block3.append(S3Dblock(wn, n_feats)) + self.Block3 = nn.Sequential(*Block3) + + self.reduceF = BasicConv3d(wn, n_feats * 3, n_feats, kernel_size=1, stride=1) + self.Conv = S3Dblock(wn, n_feats) + self.gamma = nn.Parameter(torch.ones(3)) + + conv1 = [] + conv1.append( + wn( + nn.Conv2d( + n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1) + ) + ) + ) + conv1.append(self.relu) + conv1.append( + wn( + nn.Conv2d( + n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1) + ) + ) + ) + self.conv1 = nn.Sequential(*conv1) + + conv2 = [] + conv2.append( + wn( + nn.Conv2d( + n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1) + ) + ) + ) + conv2.append(self.relu) + conv2.append( + wn( + nn.Conv2d( + n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1) + ) + ) + ) + self.conv2 = nn.Sequential(*conv2) + + conv3 = [] + conv3.append( + wn( + nn.Conv2d( + n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1) + ) + ) + ) + conv3.append(self.relu) + conv3.append( + wn( + nn.Conv2d( + n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1) + ) + ) + ) + self.conv3 = nn.Sequential(*conv3) + + def forward(self, x): + + res = x + x1 = self.Block1(x) + x + x2 = self.Block2(x1) + x1 + x3 = self.Block3(x2) + x2 + + x1, depth = _to_4d_tensor(x1, depth_stride=1) + x1 = self.conv1(x1) + x1 = _to_5d_tensor(x1, depth) + + x2, depth = _to_4d_tensor(x2, depth_stride=1) + x2 = self.conv2(x2) + x2 = _to_5d_tensor(x2, depth) + + x3, depth = _to_4d_tensor(x3, depth_stride=1) + x3 = self.conv3(x3) + x3 = _to_5d_tensor(x3, depth) + + x = torch.cat([self.gamma[0] * x1, self.gamma[1] * x2, self.gamma[2] * x3], 1) + x = self.reduceF(x) + x = self.relu(x) + x = x + res + + x = self.Conv(x) + return x + + +class MCNet(nn.Module): + def __init__(self, args): + super(MCNet, self).__init__() + + scale = args.upscale_factor + n_colors = args.band + n_feats = args.n_feats + n_conv = 1 + kernel_size = 3 + + band_mean = ( + 0.0939, + 0.0950, + 0.0869, + 0.0839, + 0.0850, + 0.0809, + 0.0769, + 0.0762, + 0.0788, + 0.0790, + 0.0834, + 0.0894, + 0.0944, + 0.0956, + 0.0939, + 0.1187, + 0.0903, + 0.0928, + 0.0985, + 0.1046, + 0.1121, + 0.1194, + 0.1240, + 0.1256, + 0.1259, + 0.1272, + 0.1291, + 0.1300, + 0.1352, + 0.1428, + 0.1541, + ) # CAVE + # band_mean = (0.0100, 0.0137, 0.0219, 0.0285, 0.0376, 0.0424, 0.0512, 0.0651, 0.0694, 0.0723, 0.0816, + # 0.0950, 0.1338, 0.1525, 0.1217, 0.1187, 0.1337, 0.1481, 0.1601, 0.1817, 0.1752, 0.1445, + # 0.1450, 0.1378, 0.1343, 0.1328, 0.1303, 0.1299, 0.1456, 0.1433, 0.1303) #Hararvd + + # band_mean = (0.0944, 0.1143, 0.1297, 0.1368, 0.1599, 0.1853, 0.2029, 0.2149, 0.2278, 0.2275, 0.2311, + # 0.2331, 0.2265, 0.2347, 0.2384, 0.1187, 0.2425, 0.2441, 0.2471, 0.2453, 0.2494, 0.2584, + # 0.2597, 0.2547, 0.2552, 0.2434, 0.2386, 0.2385, 0.2326, 0.2112, 0.2227) #ICVL + + # band_mean = (0.0483, 0.0400, 0.0363, 0.0373, 0.0425, 0.0520, 0.0559, 0.0539, 0.0568, 0.0564, 0.0591, + # 0.0678, 0.0797, 0.0927, 0.0986, 0.1086, 0.1086, 0.1015, 0.0994, 0.0947, 0.0980, 0.0973, + # 0.0925, 0.0873, 0.0887, 0.0854, 0.0844, 0.0833, 0.0823, 0.0866, 0.1171, 0.1538, 0.1535) #Foster + + # band_mean = (0.0595, 0.0600, 0.0651, 0.0639, 0.0641, 0.0637, 0.0646, 0.0618, 0.0679, 0.0641, 0.0677, + # 0.0650, 0.0671, 0.0687, 0.0693, 0.0687, 0.0688, 0.0677, 0.0689, 0.0736, 0.0735, 0.0728, 0.0713, 0.0734, + # 0.0726, 0.0722, 0.074, 0.0742, 0.0794, 0.0892, 0.1005) #Foster2002 + 
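# Illustrative shape walk for forward() (assuming CAVE: band=31): x (B, 31, H, W) -> unsqueeze(1) -> (B, 1, 31, H, W) -> head -> (B, n_feats, 31, H, W); +        # the tail upsamples to (B, 1, 31, scale*H, scale*W) and squeeze(1) returns (B, 31, scale*H, scale*W). +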
wn = lambda x: torch.nn.utils.weight_norm(x) + # self.band_mean = torch.autograd.Variable(torch.FloatTensor(band_mean)).view( + # [1, n_colors, 1, 1] + # ) + + self.head = wn(nn.Conv3d(1, n_feats, kernel_size, padding=kernel_size // 2)) + + self.SSRM1 = Block(wn, n_feats, n_conv) + self.SSRM2 = Block(wn, n_feats, n_conv) + self.SSRM3 = Block(wn, n_feats, n_conv) + self.SSRM4 = Block(wn, n_feats, n_conv) + + tail = [] + tail.append( + wn( + nn.ConvTranspose3d( + n_feats, + n_feats, + kernel_size=(3, 2 + scale, 2 + scale), + stride=(1, scale, scale), + padding=(1, 1, 1), + ) + ) + ) + tail.append(wn(nn.Conv3d(n_feats, 1, kernel_size, padding=kernel_size // 2))) + self.tail = nn.Sequential(*tail) + + def forward(self, x): + + # x = x - self.band_mean.cuda() + x = x.unsqueeze(1) + T = self.head(x) + + x = self.SSRM1(T) + x = torch.add(x, T) + + x = self.SSRM2(x) + x = torch.add(x, T) + + x = self.SSRM3(x) + x = torch.add(x, T) + + x = self.SSRM4(x) + x = torch.add(x, T) + + x = self.tail(x) + x = x.squeeze(1) + # x = x + self.band_mean.cuda() + return x + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/RFSR.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/RFSR.py" new file mode 100644 index 0000000000000000000000000000000000000000..cf1c43d80e3ddd00d330931dcfe5c6df30592f24 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/RFSR.py" @@ -0,0 +1,178 @@ +#!/usr/bin/env python +# coding: utf-8 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class Res3DBlock(nn.Module): + def __init__(self, n_feats, bias=True, act=nn.ReLU(True), res_scale=1): + super(Res3DBlock, self).__init__() + + self.body = nn.Sequential(nn.Conv3d(1, n_feats, (3,1,1),1,(1,0,0), bias=bias), + act, + nn.Conv3d(n_feats, 1, (1,3,3),1,(0,1,1), bias=bias) + ) + self.res_scale = res_scale + + def forward(self, x): + x = self.body(x.unsqueeze(1))+x.unsqueeze(1) + return x.squeeze(1) + +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + +class RCAB(nn.Module): + def __init__( + self, n_feat,reduction=16, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(nn.Conv2d(n_feat, n_feat, 3,1,1, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +class Upsampler(nn.Sequential): + def 
__init__(self, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(n_feats, 4 * n_feats, 3, 1, 1, bias=bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(nn.Conv2d(n_feats, 9 * n_feats, 3, 1, 1, bias=bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + super(Upsampler, self).__init__(*m) + +class ShuffleDown(nn.Module): + def __init__(self, scale): + super(ShuffleDown, self).__init__() + self.scale = scale + + def forward(self, x): + b, cin, hin, win = x.size() + cout = cin * self.scale ** 2 + hout = hin // self.scale + wout = win // self.scale + output = x.view(b, cin, hout, self.scale, wout, self.scale) + output = output.permute(0, 1, 5, 3, 2, 4).contiguous() + output = output.view(b, cout, hout, wout) + return output + +class Net(nn.Module): + def __init__(self, opt): + super(Net, self).__init__() + self.n_feats = 64 + self.kernel_size = 3 + self.devices = torch.device("cuda") + + self.scale = opt.upscale_factor + self.band = opt.band + if opt.band == 31: + self.g = 8 + self.sub = 4 + elif opt.band == 128: + self.g = 8 + self.sub = 16 + + self.layer1 = default_conv(self.sub+self.n_feats+self.sub*self.scale ** 2, self.n_feats, self.kernel_size) + + self.out_layer1 = default_conv(self.n_feats, self.sub, self.kernel_size) + self.out_layer2 = default_conv(self.n_feats, self.n_feats, self.kernel_size) + + if self.band == 31: + n_a = 16 + else: + n_a = 6 + body1 = [RCAB(self.n_feats) for _ in range(n_a)] + self.RB1 = nn.Sequential(*body1) + self.up = Upsampler(self.scale, self.n_feats) + self.down = ShuffleDown(self.scale) + + self.act = nn.ReLU(True) + # if self.band == 31: + # n_b = 3 + # else: + # n_b = 1 + # body2 = [Res3DBlock(opt.band) for _ in range(n_b)] + # self.body2 = nn.Sequential(*body2) + + def forward(self, x): + out = [] + B, C, h, w = x.shape + + if self.band == 31: + p = self.sub - C % self.sub + ini = torch.zeros(B, p, h, w).to(self.devices) + x = torch.cat([x, ini], 1) + + h1 = torch.zeros(B, self.n_feats, h, w).to(self.devices) + sr = torch.zeros(B, self.sub*self.scale ** 2, h, w).to(self.devices) + + for x_ilr in torch.chunk(x, self.g, 1): + h1 = self.act(self.layer1(torch.cat([h1, sr, x_ilr], dim=1))) + h1 = self.RB1(h1) + sr = self.out_layer1(self.up(h1)) + F.interpolate(x_ilr, (h*self.scale, w*self.scale)) + h1 = self.out_layer2(h1) + out.append(sr) + sr = self.down(sr) + + out = torch.cat(out, 1)[:, 0:C, :, :] + # out = self.body2(out) + return out + + + + + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SFCSR.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SFCSR.py" new file mode 100644 index 0000000000000000000000000000000000000000..f981ccef30e03ccaf0bd67dad67a930f24172c68 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SFCSR.py" @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import math + + +class TwoCNN(nn.Module): + def __init__(self, wn, n_feats=64): + super(TwoCNN, self).__init__() + + self.body = 
wn(nn.Conv2d(n_feats, n_feats, + kernel_size=(3, 3), stride=1, padding=(1, 1))) + + def forward(self, x): + + out = self.body(x) + out = torch.add(out, x) + + return out + + +class ThreeCNN(nn.Module): + def __init__(self, wn, n_feats=64): + super(ThreeCNN, self).__init__() + self.act = nn.ReLU(inplace=True) + + body_spatial = [] + for i in range(2): + body_spatial.append(wn(nn.Conv3d(n_feats, n_feats, kernel_size=( + 1, 3, 3), stride=1, padding=(0, 1, 1)))) + + body_spectral = [] + for i in range(2): + body_spectral.append(wn(nn.Conv3d( + n_feats, n_feats, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)))) + + self.body_spatial = nn.Sequential(*body_spatial) + self.body_spectral = nn.Sequential(*body_spectral) + + def forward(self, x): + out = x + for i in range(2): + + out = torch.add(self.body_spatial[i]( + out), self.body_spectral[i](out)) + if i == 0: + out = self.act(out) + + out = torch.add(out, x) + return out + + +class SFCSR(nn.Module): + def __init__(self, args): + super(SFCSR, self).__init__() + + scale = args.upscale_factor + n_feats = args.n_feats + # self.n_module = 5 + self.n_module = args.n_module + + def wn(x): return torch.nn.utils.weight_norm(x) + + self.gamma_X = nn.Parameter(torch.ones(self.n_module)) + self.gamma_Y = nn.Parameter(torch.ones(self.n_module)) + self.gamma_DFF = nn.Parameter(torch.ones(4)) + self.gamma_FCF = nn.Parameter(torch.ones(2)) + + ThreeHead = [] + ThreeHead.append( + wn(nn.Conv3d(1, n_feats, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1)))) + ThreeHead.append(wn(nn.Conv3d(n_feats, n_feats, kernel_size=( + 3, 1, 1), stride=1, padding=(1, 0, 0)))) + self.ThreeHead = nn.Sequential(*ThreeHead) + + TwoHead = [] + TwoHead.append( + wn(nn.Conv2d(1, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1)))) + self.TwoHead = nn.Sequential(*TwoHead) + + TwoTail = [] + if (scale & (scale - 1)) == 0: + for _ in range(int(math.log(scale, 2))): + TwoTail.append( + wn(nn.Conv2d(n_feats, n_feats*4, kernel_size=(3, 3), stride=1, padding=(1, 1)))) + TwoTail.append(nn.PixelShuffle(2)) + else: + TwoTail.append( + wn(nn.Conv2d(n_feats, n_feats*9, kernel_size=(3, 3), stride=1, padding=(1, 1)))) + TwoTail.append(nn.PixelShuffle(3)) + + TwoTail.append( + wn(nn.Conv2d(n_feats, 1, kernel_size=(3, 3), stride=1, padding=(1, 1)))) + self.TwoTail = nn.Sequential(*TwoTail) + + twoCNN = [] + for _ in range(self.n_module): + twoCNN.append(TwoCNN(wn, n_feats)) + self.twoCNN = nn.Sequential(*twoCNN) + + self.reduceD_Y = wn(nn.Conv2d(n_feats*self.n_module, + n_feats, kernel_size=(1, 1), stride=1)) + self.twofusion = wn( + nn.Conv2d(n_feats, n_feats, kernel_size=(3, 3), stride=1, padding=(1, 1))) + + threeCNN = [] + for _ in range(self.n_module): + threeCNN.append(ThreeCNN(wn, n_feats)) + self.threeCNN = nn.Sequential(*threeCNN) + + reduceD = [] + for _ in range(self.n_module): + reduceD.append( + wn(nn.Conv2d(n_feats*4, n_feats, kernel_size=(1, 1), stride=1))) + self.reduceD = nn.Sequential(*reduceD) + + self.reduceD_X = wn(nn.Conv3d(n_feats*self.n_module, + n_feats, kernel_size=(1, 1, 1), stride=1)) + + threefusion = [] + threefusion.append(wn(nn.Conv3d(n_feats, n_feats, kernel_size=( + 1, 3, 3), stride=1, padding=(0, 1, 1)))) + threefusion.append(wn(nn.Conv3d(n_feats, n_feats, kernel_size=( + 3, 1, 1), stride=1, padding=(1, 0, 0)))) + self.threefusion = nn.Sequential(*threefusion) + + self.reduceD_DFF = wn( + nn.Conv2d(n_feats*4, n_feats, kernel_size=(1, 1), stride=1)) + self.conv_DFF = wn( + nn.Conv2d(n_feats, n_feats, kernel_size=(1, 1), stride=1)) + + self.reduceD_FCF = 
wn( + nn.Conv2d(n_feats*2, n_feats, kernel_size=(1, 1), stride=1)) + self.conv_FCF = wn( + nn.Conv2d(n_feats, n_feats, kernel_size=(1, 1), stride=1)) + + def forward(self, x, y, localFeats, i): + x = x.unsqueeze(1) + x = self.ThreeHead(x) + skip_x = x + + y = y.unsqueeze(1) + y = self.TwoHead(y) + skip_y = y + + channelX = [] + channelY = [] + + for j in range(self.n_module): + x = self.threeCNN[j](x) + x = torch.add(skip_x, x) + channelX.append(self.gamma_X[j]*x) + + y = self.twoCNN[j](y) + y = torch.cat( + [y, x[:, :, 0, :, :], x[:, :, 1, :, :], x[:, :, 2, :, :]], 1) + y = self.reduceD[j](y) + y = torch.add(skip_y, y) + channelY.append(self.gamma_Y[j]*y) + + x = torch.cat(channelX, 1) + x = self.reduceD_X(x) + x = self.threefusion(x) + + y = torch.cat(channelY, 1) + y = self.reduceD_Y(y) + y = self.twofusion(y) + + y = torch.cat([self.gamma_DFF[0]*x[:, :, 0, :, :], self.gamma_DFF[1]*x[:, + :, 1, :, :], self.gamma_DFF[2]*x[:, :, 2, :, :], self.gamma_DFF[3]*y], 1) + + y = self.reduceD_DFF(y) + y = self.conv_DFF(y) + + if i == 0: + localFeats = y + else: + y = torch.cat( + [self.gamma_FCF[0]*y, self.gamma_FCF[1]*localFeats], 1) + y = self.reduceD_FCF(y) + y = self.conv_FCF(y) + localFeats = y + y = torch.add(y, skip_y) + y = self.TwoTail(y) + y = y.squeeze(1) + + return y, localFeats diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SGSR.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SGSR.py" new file mode 100644 index 0000000000000000000000000000000000000000..ceed4c30ee7c7dd18c00f101d5e45fa16540cb82 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SGSR.py" @@ -0,0 +1,91 @@ +from os import replace +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.common import * + + +# class BranchUnit(nn.Module): +# def __init__(self, args): +# super(BranchUnit, self).__init__() +# padding (kernel+1)//2 -1 +class SGSR(nn.Module): + def __init__(self, args): + super(SGSR, self).__init__() + + self.scale = args.upscale_factor + self.n_feats = args.n_feats + self.n_module = args.n_module + self.windSize = args.window_size + + self.gamma_FCF = nn.Parameter(torch.ones(2)) + + ThreeHead = [] + ThreeHead.append( + nn.Conv3d( + 1, self.n_feats, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1) + ) + ) + ThreeHead.append( + nn.Conv3d( + self.n_feats, + self.n_feats, + kernel_size=(3, 1, 1), + stride=1, + padding=(1, 0, 0), + ) + ) + self.ThreeHead = nn.Sequential(*ThreeHead) + # SSFFBs = [SSFFB(self.n_feats) for _ in range(self.n_module)] + SSFFBs = [Res3DBlock(self.n_feats) for _ in range(self.n_module)] + # SSFFBs = [NL_FFC() for _ in range(self.n_module)] + # SSFFBs = [GEB() for _ in range(self.n_module)] + self.SSFFBs = nn.Sequential(*SSFFBs) + + self.reduceD_DFF = nn.Conv2d( + self.n_feats * self.windSize, self.n_feats, kernel_size=(1, 1), stride=1 + ) + self.conv_DFF = nn.Conv2d( + self.n_feats, self.n_feats, kernel_size=(1, 1), stride=1 + ) + + self.reduceD_FCF = nn.Conv2d( + self.n_feats * 2, self.n_feats, kernel_size=(1, 1), stride=1 + ) + + self.conv_FCF = nn.Conv2d( + self.n_feats, self.n_feats, kernel_size=(1, 1), stride=1 + ) + self.Up = Upsampler(self.scale, args.n_feats) + self.final = nn.Conv2d( + self.n_feats, self.windSize, kernel_size=(3, 3), stride=1, padding=1 + ) + # self.act = nn.ReLU() + + def 
forward(self, x, h=None, i=None): + # x shape: B,3,H,W --->B,1,3,H,W --> B,N,3,H,W + y = F.interpolate(x, scale_factor=self.scale, mode="bicubic").clamp( + min=0, max=1 + ) + x = x.unsqueeze(1) + x = self.ThreeHead(x) + + skip_x = x + for j in range(self.n_module): + x = self.SSFFBs[j](x) + + x = x + skip_x + + x = x.view(x.shape[0], -1, x.shape[3], x.shape[4]) + x = self.reduceD_DFF(x) + x = self.conv_DFF(x) + + if i != 0: # B,N,H,W + x = torch.cat([self.gamma_FCF[0] * x, self.gamma_FCF[1] * h], 1) + x = self.reduceD_FCF(x) + x = self.conv_FCF(x) + h = x + + + x = self.final(self.Up(x)) + y + return x, h diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SSPSR.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SSPSR.py" new file mode 100644 index 0000000000000000000000000000000000000000..38b5cf4aa7ef78c1582f40bb805b3cc3918c69c3 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/SSPSR.py" @@ -0,0 +1,242 @@ +import torch +import math +import torch.nn as nn +import torch.nn.functional as F + + +def default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1): + if dilation==1: + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + elif dilation==2: + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=2, bias=bias, dilation=dilation) + + else: + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=3, bias=bias, dilation=dilation) + +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + +class ResBlock(nn.Module): + def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class ResAttentionBlock(nn.Module): + def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + super(ResAttentionBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + m.append(CALayer(n_feats, 16)) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
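+            # e.g. scale=4 -> two rounds of conv(n_feats, 4*n_feats) + PixelShuffle(2); scale=2 -> a single round (illustrative note)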
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + super(Upsampler, self).__init__(*m) + + +class SSB(nn.Module): + def __init__(self, n_feats, kernel_size, act, res_scale, conv=default_conv): + super(SSB, self).__init__() + self.spa = ResBlock(conv, n_feats, kernel_size, act=act, res_scale=res_scale) + self.spc = ResAttentionBlock(conv, n_feats, 1, act=act, res_scale=res_scale) + + def forward(self, x): + return self.spc(self.spa(x)) + + +class SSPN(nn.Module): + def __init__(self, n_feats, n_blocks, act, res_scale): + super(SSPN, self).__init__() + + kernel_size = 3 + m = [] + + for i in range(n_blocks): + m.append(SSB(n_feats, kernel_size, act=act, res_scale=res_scale)) + + self.net = nn.Sequential(*m) + + def forward(self, x): + res = self.net(x) + res += x + + return res + + +# a single branch of proposed SSPSR +class BranchUnit(nn.Module): + def __init__(self, n_colors, n_feats, n_blocks, act, res_scale, up_scale, use_tail=True, conv=default_conv): + super(BranchUnit, self).__init__() + kernel_size = 3 + self.head = conv(n_colors, n_feats, kernel_size) + self.body = SSPN(n_feats, n_blocks, act, res_scale) + self.upsample = Upsampler(conv, up_scale, n_feats) + self.tail = None + + if use_tail: + self.tail = conv(n_feats, n_colors, kernel_size) + + def forward(self, x): + y = self.head(x) + y = self.body(y) + y = self.upsample(y) + if self.tail is not None: + y = self.tail(y) + + return y + + +class SSPSR(nn.Module): + def __init__(self, args,res_scale=0.1, use_share=True, conv=default_conv): + super(SSPSR, self).__init__() + kernel_size = 3 + self.shared = use_share + act = nn.ReLU(True) + + + n_subs = 4 + n_colors = args.band + n_ovls = 1 + n_blocks = 3 + n_feats = 256 + n_scale = args.upscale_factor + self.scale = n_scale + + + # calculate the group number (the number of branch networks) + self.G = math.ceil((n_colors - n_ovls) / (n_subs - n_ovls)) + # calculate group indices + self.start_idx = [] + self.end_idx = [] + + for g in range(self.G): + sta_ind = (n_subs - n_ovls) * g + end_ind = sta_ind + n_subs + if end_ind > n_colors: + end_ind = n_colors + sta_ind = n_colors - n_subs + self.start_idx.append(sta_ind) + self.end_idx.append(end_ind) + + if self.shared: + self.branch = BranchUnit(n_subs, n_feats, n_blocks, act, res_scale, up_scale=n_scale//2, conv=default_conv) + # up_scale=n_scale//2 means that we upsample the LR input n_scale//2 at the branch network, and then conduct 2 times upsampleing at the global network + else: + self.branch = nn.ModuleList() + for i in range(self.G): + self.branch.append(BranchUnit(n_subs, n_feats, n_blocks, act, res_scale, up_scale=2, conv=default_conv)) + + self.trunk = BranchUnit(n_colors, n_feats, n_blocks, act, res_scale, up_scale=2, use_tail=False, conv=default_conv) + self.skip_conv = conv(n_colors, n_feats, kernel_size) + self.final = conv(n_feats, n_colors, kernel_size) + self.sca = n_scale//2 + + def forward(self, x): + b, c, h, w = x.shape + lms = F.interpolate(x, scale_factor=self.scale, mode="bicubic").clamp( + min=0, max=1 + ) + + # Initialize 
intermediate “result”, which is upsampled by a factor of n_scale//2 + y = torch.zeros(b, c, self.sca * h, self.sca * w).cuda() + + channel_counter = torch.zeros(c).cuda() + + for g in range(self.G): + sta_ind = self.start_idx[g] + end_ind = self.end_idx[g] + + xi = x[:, sta_ind:end_ind, :, :] + if self.shared: + xi = self.branch(xi) + else: + xi = self.branch[g](xi) + + y[:, sta_ind:end_ind, :, :] += xi + channel_counter[sta_ind:end_ind] = channel_counter[sta_ind:end_ind] + 1 + + # the intermediate “result” is averaged according to how often each spectral index was covered + y = y / channel_counter.unsqueeze(1).unsqueeze(2) + + y = self.trunk(y) + y = y + self.skip_conv(lms) + y = self.final(y) + + return y \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/THreeDFCNN.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/THreeDFCNN.py" new file mode 100644 index 0000000000000000000000000000000000000000..aa80a7a8ff281d1fdd165de94692fa6493c135ec --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/THreeDFCNN.py" @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F + + +class ThreeDFCNN(nn.Module): + def __init__(self, args): + super(ThreeDFCNN, self).__init__() + self.scale = args.upscale_factor + + self.f3d_1 = nn.Conv3d( + 1, 64, kernel_size=(7, 9, 9), stride=1, padding=(3, 4, 4) + ) # 64 is the filter number + # For 3D convolution, the filter number can no longer be read as the usual channel count: + # the input's channel dimension is taken over by the number of spectral bands. + # The biggest advantage of 3D over 2D convolution is that, if desired, the band number + # can be kept unchanged between input and output. + # The extra dimension of the 3D convolution can model the relations between spectral bands. + + self.f3d_2 = nn.Conv3d( + 64, 32, kernel_size=(1, 1, 1), stride=1, padding=(0, 0, 0) + ) + self.f3d_3 = nn.Conv3d( + 32, 9, kernel_size=(1, 1, 1), stride=1, padding=(0, 0, 0) + ) + self.f3d_4 = nn.Conv3d(9, 1, kernel_size=(3, 5, 5), stride=1, padding=(1, 2, 2)) + + self.relu = nn.ReLU() + + def forward(self, x): + x = F.interpolate(x, scale_factor=self.scale, mode="bicubic").clamp( + min=0, max=1 + ) + # print(x.shape) + x = x.unsqueeze(1) + x = self.relu(self.f3d_1(x)) + x = self.relu(self.f3d_2(x)) + x = self.relu(self.f3d_3(x)) + x = self.f3d_4(x) + x = x.squeeze(1) + # print(x.shape) + return x diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__init__.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__init__.py" new file mode 100644 index 0000000000000000000000000000000000000000..649e9d9a2a7341f79904f9c0fe4e540b88ac0b47 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__init__.py" @@ -0,0 +1,37 @@ +from model.SFCSR import SFCSR +# from model.Branch import BranchUnit +from model.SGSR import SGSR +from model.THreeDFCNN import ThreeDFCNN +from model.Bicubic import Bicubic +from model.ERCSR import ERCSR +from model.MCNet import MCNet +from model.GDRRN import GDRRN +from model.EDSR import EDSR +from model.RFSR import Net +from model.SSPSR import SSPSR +from model.baseline import Baseline + +def Model(opt): + if opt.method == "SFCSR": + model = SFCSR(opt) + elif opt.method == "baseline": + model = Baseline(opt) + elif opt.method == "SGSR": + model = SGSR(opt) + elif opt.method == "Bicubic": 
+ model = Bicubic(opt) + elif opt.method == "3DFCNN": + model = ThreeDFCNN(opt) + elif opt.method == "ERCSR": + model = ERCSR(opt) + elif opt.method == "MCNet": + model = MCNet(opt) + elif opt.method == "GDRRN": + model = GDRRN(opt) + elif opt.method == "EDSR": + model = EDSR(opt) + elif opt.method == "RFSR": + model = Net(opt) + elif opt.method == "SSPSR": + model = SSPSR(opt) + return model diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/AttModel.cpython-38.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/AttModel.cpython-38.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..f0e1e93c79bae0bdc730dac650169a180d81072e Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/AttModel.cpython-38.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/Bicubic.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/Bicubic.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..8ebe22a8ad20cc6153131ec0c99b149d7473ebf0 Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/Bicubic.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/Branch.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/Branch.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..ed8707f758916599c589b8f2c2434e4c5cfc6df3 Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/Branch.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/baseline.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/baseline.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..2cd639b27b02b238ef37afbc47b6281708a13a9c Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/baseline.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/common.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/common.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..f9c3001e05d50ff5379494b72ee7e727a62190d5 Binary files /dev/null and 
"b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/common.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/common2.cpython-36.pyc" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/common2.cpython-36.pyc" new file mode 100644 index 0000000000000000000000000000000000000000..b624f308deadc89fbf48123d30232a81828b888b Binary files /dev/null and "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/__pycache__/common2.cpython-36.pyc" differ diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/baseline.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/baseline.py" new file mode 100644 index 0000000000000000000000000000000000000000..78419786e20a20b4fb3fc43d3cc9243c203a0dd8 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/baseline.py" @@ -0,0 +1,113 @@ +from os import replace +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.common import * + +class Res3DBlock(nn.Module): + def __init__(self, n_feats, bias=True, act=nn.ReLU(True), res_scale=1): + super(Res3DBlock, self).__init__() + + self.body = nn.Sequential(nn.Conv3d(1, n_feats, (3,1,1),1,(1,0,0), bias=bias), + act, + nn.Conv3d(n_feats, 1, (1,3,3),1,(0,1,1), bias=bias) + ) + self.res_scale = res_scale + + def forward(self, x): + x = self.body(x.unsqueeze(1))+x.unsqueeze(1) + return x.squeeze(1) + +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + +class RCAB(nn.Module): + def __init__( + self, n_feat,reduction=16, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(nn.Conv2d(n_feat, n_feat, 3,1,1, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +class Baseline(nn.Module): + def __init__(self, args): + super(Baseline, self).__init__() + + self.scale = args.upscale_factor + self.n_feats = args.n_feats + self.n_module = args.n_module + self.windSize = args.window_size + self.c = args.window_size + + self.gamma_FCF = nn.Parameter(torch.ones(2)) + + self.TwoHead = nn.Conv2d(self.c,self.n_feats,kernel_size=3,padding=1) + FEs = [RCAB(self.n_feats) for _ in range(self.n_module)] + self.FEs = nn.Sequential(*FEs) + + self.reduceD_FCF = 
nn.Conv2d( + self.n_feats * 2, self.n_feats, kernel_size=(1, 1), stride=1 + ) + + self.conv_FCF = nn.Conv2d( + self.n_feats, self.n_feats, kernel_size=(1, 1), stride=1 + ) + self.Up = Upsampler(self.scale, args.n_feats) + self.final = nn.Conv2d( + self.n_feats, self.windSize, kernel_size=(3, 3), stride=1, padding=1 + ) + # self.act = nn.ReLU() + + def forward(self, x, h=None, i=None): + # x shape: B,3,H,W ---> B,N,H,W + y = F.interpolate(x, scale_factor=self.scale, mode="bicubic").clamp( + min=0, max=1 + ) + x = self.TwoHead(x) + + # feature extractor + skip_x = x + for j in range(self.n_module): + x = self.FEs[j](x) + + x = x + skip_x + + # group fusion + if i != 0: # B,N,H,W + x = torch.cat([self.gamma_FCF[0] * x, self.gamma_FCF[1] * h], 1) + x = self.reduceD_FCF(x) + x = self.conv_FCF(x) + h = x + + + x = self.final(self.Up(x)) + y + return x, h diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/common.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/common.py" new file mode 100644 index 0000000000000000000000000000000000000000..43b8a2db9006e98fd7853021ce548cf4673703c1 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/common.py" @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +from eval import Bconstrast +from data_utils import shuffle +# import torch.fft + +class SSFFB(nn.Module): + def __init__(self, n_feats=64): + super(SSFFB, self).__init__() + self.act = nn.ReLU(inplace=True) + + body_spatial = [] + for i in range(1): + body_spatial.append( + nn.Conv3d( + n_feats, + n_feats, + kernel_size=(1, 3, 3), + stride=1, + padding=(0, 1, 1), + ) + ) + + body_spectral = [] + for i in range(1): + body_spectral.append( + nn.Conv3d( + n_feats, + n_feats, + kernel_size=(3, 1, 1), + stride=1, + padding=(1, 0, 0), + ) + ) + + self.body_spatial = nn.Sequential(*body_spatial) + self.body_spectral = nn.Sequential(*body_spectral) + self.reduce = nn.Conv3d(n_feats * 2, n_feats, kernel_size=(1, 1, 1), stride=1) + + def forward(self, x): + out = x + spe = x + spa = x + for i in range(1): + spa = self.body_spatial[i](spa) + spe = self.body_spectral[i](spe) + if i == 0: + spe = self.act(spe) + spa = self.act(spa) + out = torch.cat([spe, spa], dim=1) + out = self.reduce(out) + # out = spe + out = out + x + return out + + +class Res3DBlock(nn.Module): + def __init__(self, n_feats=64, bias=True, act=nn.ReLU(True), res_scale=1): + super(Res3DBlock, self).__init__() + + self.body = nn.Sequential(nn.Conv3d(n_feats, n_feats, (3,1,1),1,(1,0,0), bias=bias), + act, + nn.Conv3d(n_feats, n_feats, (1,3,3),1,(0,1,1), bias=bias) + ) + + def forward(self, x): + x = self.body((x))+x + return x + + +class Upsampler(nn.Sequential): + def __init__(self, scale, n_feats): + TwoTail = [] + if (scale & (scale - 1)) == 0: + for _ in range(int(math.log(scale, 2))): + TwoTail.append( + nn.Conv2d( + n_feats, + n_feats * 4, + kernel_size=(3, 3), + stride=1, + padding=(1, 1), + ) + ) + TwoTail.append(nn.PixelShuffle(2)) + elif scale == 3: + TwoTail.append( + nn.Conv2d( + n_feats, n_feats * 9, kernel_size=(3, 3), stride=1, padding=(1, 1), + ) + ) + TwoTail.append(nn.PixelShuffle(3)) + else: + raise NotImplementedError + super(Upsampler, self).__init__(*TwoTail) + +def _to_4d_tensor(x): + # B,N,C,H,W + x = x.permute(0,2,1,3,4) + x = 
torch.split(x,1,dim=0) + x = torch.cat(x,1).squeeze(0) + return x + +def _to_5d_tensor(x,C): + x = torch.split(x,C) + x = torch.stack(x,dim=0) + x = x.transpose(1,2) + return x + +class FFC(nn.Module): + """ + Fourier transform + """ + def __init__(self,n_feats): + super(FFC,self).__init__() + self.fft_norm = "ortho" + # self.spatial = Res3DBlock() + self.frequency = Res3DBlock(n_feats) + + def forward(self,x): + # x.shape: B,N,C,H,W + B,N,C,H,W = x.shape + # x = self.spatial(x) + # fft_dim = (-2,-1) # by default the FFT is taken over the last two dims + ffted = torch.rfft(x,signal_ndim=2,onesided=False) + # print(ffted.shape) + ffted = torch.cat([ffted[...,0],ffted[...,1]],dim=-1) # stack the real and imaginary parts + # print(ffted.shape) # B,N,C,H,W+2 + ffted = self.frequency(ffted) + ffted = torch.stack((ffted[...,:W],ffted[...,W:]),dim=-1) + # print(ffted.shape) + ffted = torch.irfft(ffted,signal_ndim=2,onesided=False) + # print(ffted.shape) + x = x + ffted + return x + +class NLA(nn.Module): + def __init__(self,n_feats,reduction=16): + super(NLA,self).__init__() + """ + Non-Local Attention + """ + self.k = nn.Conv3d(n_feats,n_feats//reduction,1,padding=0,bias=True) + self.v = nn.Conv3d(n_feats,n_feats//reduction,1,padding=0,bias=True) + self.q = nn.Conv3d(n_feats,n_feats//reduction,1,padding=0,bias=True) + self.unsqueeze = nn.Conv3d(n_feats//reduction,n_feats,1,padding=0,bias=True) + self.softmax = nn.Softmax2d() + + def forward(self,x): + # B,N,C,H,W + B,N,C,H,W = x.shape + scale = C*H*W + k = self.k(x) + v = self.v(x) + q = self.q(x) + k = k.view(B,k.shape[1],-1) # B,N1,CHW + v = v.view(B,v.shape[1],-1).transpose(1,2) # B,CHW,N1 + q = q.view(B,q.shape[1],-1).transpose(1,2) + M = torch.bmm(v,k)/scale + M = self.softmax(M.unsqueeze(1)).squeeze(1) + q = torch.bmm(M,q)/scale + q = q.transpose(1,2).view(B,-1,C,H,W) + x = self.unsqueeze(q) + x + return x + +class NL_FFC(nn.Module): + def __init__(self,n_feats=32): + super(NL_FFC,self).__init__() + self.fll = Res3DBlock(n_feats) + # self.flg = nn.Sequential( + # nn.Conv3d(n_feats,n_feats,kernel_size=1), + # NLA(n_feats), + # ) + self.flg = Res3DBlock(n_feats) + self.fgl = Res3DBlock(n_feats) + self.fgg = FFC(n_feats) + self.conv = Res3DBlock(n_feats*2) + + + def forward(self,x): + # B,N,C,H,W + skip_x = x + x_l,x_g = torch.chunk(x,2,1) + x_l,x_g = (self.fll(x_l) + self.fgl(x_g))/2,(self.flg(x_l) + self.fgg(x_g))/2 + x = torch.cat([x_l,x_g],1) + x = self.conv(x) + skip_x + return x + + + + +class GEB(nn.Module): + def __init__(self, channel=64,reduction=16): + super(GEB, self).__init__() + self.act = nn.ReLU(inplace=True) + + self.conv1 = nn.Sequential(nn.Conv3d(channel, channel, (3,1,1),1,(1,0,0), bias=True), + self.act, + nn.Conv3d(channel, channel, (1,3,3),1,(0,1,1), bias=True) + ) + + self.conv2 = nn.Sequential(nn.Conv3d(channel, channel, (3,1,1),1,(1,0,0), bias=True), + self.act, + nn.Conv3d(channel, channel, (1,3,3),1,(0,1,1), bias=True) + ) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + + def forward(self, x): + out = x + x = self.conv1(x) + # x = self.act(x) + x = self.conv2(x) + skip_x = x + # B,N,C,H,W ---> B*C, N, H, W + C = x.shape[2] + x = _to_4d_tensor(x) + ffted = torch.rfft(x,signal_ndim=2,onesided=True) # B*C, N,H,W/2+1,2 + ffted = torch.cat([ffted[...,0],ffted[...,1]],dim=-1) # stack the real and imaginary parts + s = self.conv_du(self.avg_pool(ffted)) # gradient-pooling channel attention mechanism + # s = self.conv_du(self.avg_pool(x)) # gradient-pooling channel attention mechanism + x = s * x + x = _to_5d_tensor(x,C) + x = x + skip_x + out = x + out + return x diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/common2.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/common2.py" new file mode 100644 index 0000000000000000000000000000000000000000..bcd6f1a052072abe32af185573af8481147a3ad0 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/model/common2.py" @@ -0,0 +1,131 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + + +class ResAttentionBlock(nn.Module): + def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + super(ResAttentionBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + m.append(CALayer(n_feats, 16)) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
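+            # e.g. n_feats=64, scale=2: the conv outputs (B, 256, H, W) and PixelShuffle(2) rearranges it to (B, 64, 2H, 2W) (illustrative)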
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + super(Upsampler, self).__init__(*m) + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/option.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/option.py" new file mode 100644 index 0000000000000000000000000000000000000000..7767754a78c4b3aa38c22b79dfc74a14d9eb20e0 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/option.py" @@ -0,0 +1,68 @@ +import argparse + +# Training settings +parser = argparse.ArgumentParser(description="Super-Resolution") +parser.add_argument( + "--upscale_factor", default=2, type=int, help="super resolution upscale factor" +) +parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)") +parser.add_argument("--batchSize", type=int, default=32, help="training batch size") +parser.add_argument( + "--nEpochs", type=int, default=100, help="maximum number of epochs to train" +) +parser.add_argument("--band", type=int, default=31) +parser.add_argument("--show", action="store_true", help="show Tensorboard") + +parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") +parser.add_argument("--cuda", action="store_true", help="Use cuda") +parser.add_argument("--gpus", default="0,1,2,3", type=str, help="gpu ids (default: 0,1,2,3)") +parser.add_argument("--dist", action="store_true", help="use dist") +parser.add_argument( + "--threads", type=int, default=12, help="number of threads for dataloader to use" +) +parser.add_argument( + "--resume", + default="", + type=str, + help="Path to checkpoint (default: none) checkpoint/model_epoch_95.pth", +) +# parser.add_argument( +# "--branch", +# default="", +# type=str, +# help="Path to Branch checkpoint (default: none) checkpoint/model_epoch_95.pth", +# ) +parser.add_argument( + "--start-epoch", + default=1, + type=int, + help="Manual epoch number (useful on restarts)", +) + +parser.add_argument("--datasetName", default="CAVE", type=str, help="data name") + +parser.add_argument("--shuffleMode", type=str, default="origin") +parser.add_argument("--shuffle", type=int, default=1) +parser.add_argument("--shufflegroup", type=int, default=10) + +# Network settings +parser.add_argument("--n_module", type=int, default=8, help="number of modules") +parser.add_argument("--n_feats", type=int, default=64, help="number of feature maps") +parser.add_argument("--loss", type=str, default="L1") +parser.add_argument("--window_size", type=int, default=3) + +# Test image +parser.add_argument( + "--model_name", default="", type=str, help="super resolution model name ", +) +parser.add_argument( + "--method", default="SGSR", type=str, help="super resolution method name" +) +parser.add_argument("--ex", type=str, default="origin", help="experiment save name") +parser.add_argument("--exgroup", type=str, default="") +opt = parser.parse_args() +opt.loss = 
opt.loss.split("+") +if opt.datasetName == "Foster": + opt.band = 33 +elif opt.datasetName == "Chikusei": + opt.band = 128 diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/param.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/param.py" new file mode 100644 index 0000000000000000000000000000000000000000..06d3398e1ac9c8bde3ffe144a4c413c40c997494 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/param.py" @@ -0,0 +1,28 @@ +from model import Model +from option import opt + + +Models = ["3DFCNN","GDRRN","EDSR","RFSR","SSPSR","MCNet","ERCSR","SFCSR","SGSR",] +Scale = [2,3,4] +Dataset = ["CAVE","Chikusei"] + +for d in Dataset: + print("datasetName:",d) + for s in Scale: + for m in Models: + if m=="SSPSR" and s in [2,3]: + continue + if m=="SGSR" and d == "Chikusei": + opt.n_module = 10 + if m=="SFCSR": + opt.n_module = 5 + + opt.method = m + opt.datasetName = d + opt.upscale_factor = s + if d == "Chikusei": + opt.band = 128 + opt.window_size = 8 + + model = Model(opt) + print("Model:",m," Scale: ",s," parameters:", sum(param.numel() for param in model.parameters())) diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/test.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/test.py" new file mode 100644 index 0000000000000000000000000000000000000000..6e3773d940da2c3a6879e10d151dd7baa5ac0cf1 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/test.py" @@ -0,0 +1,185 @@ +import os +import sys +import cv2 +import numpy as np +import torch +from os import listdir +import torch.nn as nn +from torch.autograd import Variable +from torch.utils.data.dataloader import DataLoader +from option import opt +from data_utils import ValsetFromFolder, choose_x, chop_forward, is_image_file, shuffle +import scipy.io as scio +from eval import PSNR, SSIM, SAM, cal_sam,constrast +from model import Model + +def main(): + + input_path = ( + "/data2/cys/data/" + + opt.datasetName + + "/process_test/" + + str(opt.upscale_factor) + + "/" + ) + out_path = ( "result/" + + opt.datasetName+"/" + + str(opt.upscale_factor) + + "/" + + opt.method + + "/" + ) + + val_set = ValsetFromFolder(input_path, opt.shuffle,g = opt.window_size) + val_loader = DataLoader( + dataset=val_set, num_workers=opt.threads, batch_size=1, shuffle=False + ) + + if not os.path.exists(out_path): + os.makedirs(out_path) + PSNRs = [] + SSIMs = [] + SAMs = [] + + if opt.cuda: + print("=> use gpu id: '{}'".format(opt.gpus)) + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus + if not torch.cuda.is_available(): + raise Exception("No GPU found or Wrong gpu id, please run without --cuda") + + model = Model(opt) + + if opt.cuda and opt.dist: + model = nn.DataParallel(model).cuda() + else: + model = model.cuda() + + if opt.model_name: + checkpoint = torch.load(opt.model_name) + Branch_dict = checkpoint["model"] + model_dict = model.state_dict() + # print(model_dict.keys()) + # print(Branch_dict.keys()) + pretrained_dict = {k: v for k, v in Branch_dict.items() if k in model_dict} + miss_param = {k for k in Branch_dict.keys() if k not in model_dict} + print("miss_param:",len(miss_param)) + print(miss_param) + 
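# note: only checkpoint keys that also exist in the current model are restored below; any missing keys keep their fresh initialization +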
# print(model_dict) + # print(len(pretrained_dict)) + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + # model.load_state_dict(checkpoint) + model.eval() + # sys.exit(0) + + images_name = [x for x in listdir(input_path) if is_image_file(x)] + T = 0 + for index, batch in enumerate(val_loader): + with torch.no_grad(): + input, HR = Variable(batch[0]), Variable(batch[1]) + # print(input.shape, HR.shape) + SR = np.zeros((HR.shape[1], HR.shape[2], HR.shape[3])).astype(np.float32) + + HR = HR.data[0].numpy() + B, C, h, w = input.shape + g = (C + opt.window_size - 1) // opt.window_size + + if opt.cuda: + input = input.cuda() + + # start = time.time() + + if opt.method == "SGSR": + h1 = [] + channel_count = torch.zeros(C) + for i in range(g): + start = i * opt.window_size + end = (i + 1) * opt.window_size + if end > C: + end = C + start = end - opt.window_size + x = input[:, start:end, :, :] + y, h1 = model(x, h1, i) + SR[start:end, :, :] += y.cpu().data[0].numpy() + channel_count[start:end] += 1 + SR = SR / channel_count.reshape(-1, 1, 1).numpy() + elif opt.method == "SFCSR": + localFeats = [] + for i in range(input.shape[1]): + x = choose_x(input, i, opt.shuffleMode) + y = input[:, i, :, :] + output, localFeats = model(x, y, localFeats, i) + # output, localFeats = model(y, localFeats, i) + SR[i, :, :] = output.cpu().data[0].numpy() + elif opt.method in ["RFSR","3DFCNN","ERCSR","MCNet"] and opt.datasetName == "Chikusei": + SR = chop_forward(input,model,opt.upscale_factor) + else: + SR = model(input).cpu().data[0].numpy() + # end = time.time() + # print("test time:", (end - start)/60, "min") + # T = T + (end - start) + SR[SR < 0] = 0 + SR[SR > 1.0] = 1.0 + # print(SR.shape, HR.shape) 31,512,512 31,512,512 + + psnr = PSNR(SR, HR) + ssim = SSIM(SR, HR) + sam = SAM(SR, HR) + + PSNRs.append(psnr) + SSIMs.append(ssim) + SAMs.append(sam) + + if opt.method == "SGSR": + SR = shuffle(torch.from_numpy(SR),opt.band // opt.window_size).numpy() + HR = shuffle(torch.from_numpy(HR),opt.band // opt.window_size).numpy() + + + + SR = SR.transpose(1, 2, 0) + HR = HR.transpose(1, 2, 0) + # scio.savemat(out_path + images_name[index], {'HR': HR, 'SR': SR}) + + # shape 512,512,31 + # save the SR result: write each band as a PNG, together with an error map + img_out_path = out_path + images_name[index][:-4] + if not os.path.exists(img_out_path): + os.mkdir(img_out_path) + os.mkdir(img_out_path + "/image") + os.mkdir(img_out_path + "/error_map") + # HR_con = 0.0 + # SR_con = 0.0 + for i in range(SR.shape[2]): + img = cv2.cvtColor( + SR[:, :, i] * 255, cv2.COLOR_GRAY2BGR) + cv2.imwrite(img_out_path+"/image/"+str(i)+".png", img) + # HR_con += constrast(HR[:,:,i]*255) + # SR_con += constrast(SR[:,:,i]*255) + # build the error map + error_map = abs(HR[:, :, i] - SR[:, :, i]) + # print(np.max(error_map)) + error_map = error_map/np.max(error_map)*255 + error_map = cv2.cvtColor(error_map, cv2.COLOR_GRAY2BGR) + cv2.imwrite(img_out_path+"/error_map/" + + str(i)+".png", error_map) + + print( + "===The {}-th picture=====PSNR:{:.3f}=====SSIM:{:.4f}=====SAM:{:.3f}====Name:{}".format( + index + 1, psnr, ssim, sam, images_name[index] + ) + ) + # print("Contrast: HR: ",HR_con/SR.shape[2],"SR: ",SR_con/SR.shape[2]) + + print( + "=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format( + np.mean(PSNRs), np.mean(SSIMs), np.mean(SAMs) + ) + ) + # print(T / len(images_name)) + + +if __name__ == "__main__": + main() + + + diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train.bash" 
"b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train.bash" new file mode 100644 index 0000000000000000000000000000000000000000..3f6a7156131328f403e1c3f8e25b6aa070005ed8 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train.bash" @@ -0,0 +1,5 @@ +# processing data +matlab -nodesktop -nosplash -r data_CAVE + +# GDRRN +python train_G.py --cuda --gpus 0 --method GDRRN --ex GDRRN --datasetName Chikusei --show --shuffle 0 \ No newline at end of file diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train.py" new file mode 100644 index 0000000000000000000000000000000000000000..b02fb7dbd27a9185ac0b8a37df183d263d9c87e7 --- /dev/null +++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train.py" @@ -0,0 +1,286 @@ +import os +import sys +from numpy.core.numeric import False_ +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable +from torch.utils.data import DataLoader +from tensorboardX import SummaryWriter +from model import Model +from loss import Loss + +from option import opt +from data_utils import TrainsetFromFolder, ValsetFromFolder,CutMix +from eval import PSNR, SAM, SSIM +from torch.optim.lr_scheduler import MultiStepLR +import numpy as np +import time + +psnr = [] +out_path = "result/" + opt.datasetName + "/" + + +def main(): + print(opt) + best_psnr = 0 + best_sam = 1e6 + + if opt.show: + global writer + writer = SummaryWriter( + log_dir="logs/" + + opt.datasetName + + "/" + + str(opt.upscale_factor) + + "/" + + opt.exgroup + + "/" + + opt.ex + ) + + if opt.cuda: + print("=> Use GPU ID: '{}'".format(opt.gpus)) + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus + if not torch.cuda.is_available(): + raise Exception("No GPU found or Wrong gpu id, please run without --cuda") + + torch.manual_seed(opt.seed) + if opt.cuda: + torch.cuda.manual_seed(opt.seed) + cudnn.benchmark = True + g = (opt.band + opt.window_size - 1) // opt.window_size + + # Loading datasets + train_set = TrainsetFromFolder( + "/data2/cys/data/" + + opt.datasetName + + "/process_train/" + + str(opt.upscale_factor) + + "/", + opt.shuffle, + opt.shuffleMode, + opt.band // opt.shufflegroup, + # opt.window_size + ) + train_loader = DataLoader( + dataset=train_set, + num_workers=opt.threads, + batch_size=opt.batchSize, + shuffle=True, + ) + val_set = ValsetFromFolder( + "/data2/cys/data/" + + opt.datasetName + + "/process_test/" + + str(opt.upscale_factor) + + "/", + opt.shuffle, + opt.shuffleMode, + opt.band // opt.shufflegroup, + # opt.window_size + ) + val_loader = DataLoader( + dataset=val_set, num_workers=opt.threads, batch_size=1, shuffle=False + ) + + # Buliding model + model = Model(opt) + # print(model) + + # choose Loss + criterion = Loss(opt) + + if opt.cuda and opt.dist: + model = nn.DataParallel(model).cuda() + elif opt.cuda: + model = model.cuda() + else: + model = model.cpu() + print("# parameters:", sum(param.numel() for param in model.parameters())) + + # Setting Optimizer + optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-08) + + # optionally resuming from a checkpoint + 
+    if opt.resume:
+        if os.path.isfile(opt.resume):
+            print("=> loading checkpoint '{}'".format(opt.resume))
+            checkpoint = torch.load(opt.resume)
+            opt.start_epoch = checkpoint["epoch"] + 1
+            Branch_dict = checkpoint["model"]
+            model_dict = model.state_dict()
+            pretrained_dict = {k: v for k, v in Branch_dict.items() if k in model_dict}
+            model_dict.update(pretrained_dict)
+            model.load_state_dict(model_dict)
+            # model.load_state_dict(checkpoint["model"])
+            optimizer.load_state_dict(checkpoint["optimizer"])
+        else:
+            print("=> no checkpoint found at '{}'".format(opt.resume))
+
+    # Setting the learning-rate schedule: halve the lr at the listed epochs
+    scheduler = MultiStepLR(
+        optimizer, milestones=[35, 70, 105, 140, 175], gamma=0.5, last_epoch=-1
+    )
+
+    # Training
+    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
+        print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"]))
+        start = time.time()
+        train(train_loader, optimizer, model, criterion, epoch)
+        end = time.time()
+        print("epoch cost:", (end - start) / 60, "min")
+        scheduler.step()
+        best_psnr, best_sam = val(
+            val_loader, model, epoch, optimizer, best_psnr, best_sam
+        )
+        # save_model(model, epoch, optimizer, "last")
+
+
+def train(train_loader, optimizer, model, criterion, epoch):
+
+    model.train()
+    for iteration, batch in enumerate(train_loader, 1):
+        input, label = Variable(batch[0]), Variable(batch[1], requires_grad=False)
+        if opt.cuda:
+            input = input.cuda()
+            label = label.cuda()
+
+        B, N, h, w = input.shape
+
+        # CutMix
+        # if np.random.rand(1) < 0.5:
+        #     input, label = CutMix(input, label)
+
+        h1 = []
+        c = opt.window_size
+        g = (N + c - 1) // c
+        last_group = []
+        for i in range(g):
+            start = i * c
+            end = (i + 1) * c
+            if end > input.shape[1]:
+                end = input.shape[1]
+                start = end - c
+            x = input[:, start:end, :, :]
+            new_label = label[:, start:end, :, :]
+            SR, h1 = model(x, h1, i)
+            # cut the gradient flow through the hidden state between groups
+            h1 = Variable(h1.detach().data, requires_grad=False)
+            # loss = criterion.loss(SR, new_label)
+            if 0 < i < g - 1:  # group spectral difference constraint
+                diff_label = new_label - label[:, start - c:start, :, :]
+                loss = nn.L1Loss()(SR - last_group, diff_label)
+            else:
+                loss = criterion.loss(SR, new_label)
+            last_group = Variable(SR.detach().data, requires_grad=False)  # output of the previous group
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            # sys.exit(0)
+
+        if iteration % 100 == 0:
+            print(
+                "===> Epoch[{}]({}/{}): Loss: {:.10f}".format(
+                    epoch, iteration, len(train_loader), loss.item()
+                )
+            )
+
+        if opt.show:
+            niter = epoch * len(train_loader) + iteration
+            if niter % 500 == 0:
+                writer.add_scalar("Train/Loss", loss.item(), niter)
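+
+
+# Editor's note on the group spectral difference constraint used in train():
+# for the interior groups (0 < i < g - 1) the loss is not applied to the SR
+# output directly but to the difference between consecutive groups,
+# L1(SR_i - SR_{i-1}, label_i - label_{i-1}), which pushes the network to
+# keep the spectral profile consistent across group boundaries; the first
+# and last groups fall back to the plain reconstruction loss.
+
+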
+def val(val_loader, model, epoch, optimizer, best_psnr, best_sam):
+
+    model.eval()
+    val_psnr = 0
+    val_sam = 0
+    val_SSIM = 0
+
+    for iteration, batch in enumerate(val_loader, 1):
+        with torch.no_grad():
+            input, label = Variable(batch[0]), Variable(batch[1])
+
+            if opt.cuda:
+                input = input.cuda()
+
+            B, C, h, w = input.shape
+            g = (C + opt.window_size - 1) // opt.window_size
+            # p = opt.windSize - C % opt.windSize
+            SR = np.zeros((label.shape[1], label.shape[2], label.shape[3])).astype(
+                np.float32
+            )
+
+            h1 = []
+            channel_count = torch.zeros(C)
+            for i in range(g):
+                start = i * opt.window_size
+                end = (i + 1) * opt.window_size
+                if end > C:
+                    end = C
+                    start = end - opt.window_size
+                x = input[:, start:end, :, :]
+                y, h1 = model(x, h1, i)
+                SR[start:end, :, :] += y.cpu().data[0].numpy()
+                channel_count[start:end] += 1
+
+            SR = SR / channel_count.reshape(-1, 1, 1).numpy()
+            SR[SR < 0] = 0
+            SR[SR > 1.0] = 1.0
+            val_psnr += PSNR(SR, label.data[0].numpy())
+            val_sam += SAM(SR, label.data[0].numpy())
+            val_SSIM += SSIM(SR, label.data[0].numpy())
+            # print("PSNR", val_psnr)
+    val_psnr = val_psnr / len(val_loader)
+    val_sam = val_sam / len(val_loader)
+    val_SSIM = val_SSIM / len(val_loader)
+    if val_psnr > best_psnr:
+        save_model(model, epoch, optimizer, "psnr_best")
+        best_psnr = val_psnr
+    if val_sam < best_sam:
+        save_model(model, epoch, optimizer, "sam_best")
+        best_sam = val_sam
+
+    print(
+        "PSNR = {:.3f}, best_PSNR = {:.3f}, SSIM = {:.4f}, SAM = {:.3f}, best_SAM = {:.3f}".format(
+            val_psnr, best_psnr, val_SSIM, val_sam, best_sam
+        )
+    )
+
+    if opt.show:
+        writer.add_scalar("Val/PSNR", val_psnr, epoch)
+        writer.add_scalar("Val/SAM", val_sam, epoch)
+
+    # save_model(
+    #     model, epoch, optimizer, "epoch:" + str(epoch),
+    # )
+
+    return best_psnr, best_sam
+
+
+def save_model(model, epoch, optimizer, name):
+    model_out_dir = (
+        "checkpoint/" + opt.datasetName + "/" + str(opt.upscale_factor) + "/" + opt.ex + "/"
+    )
+    model_out_path = model_out_dir + name + ".pth"
+    state = {
+        "epoch": epoch,
+        "model": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+    }
+    if not os.path.exists(model_out_dir):
+        os.makedirs(model_out_dir)
+    torch.save(state, model_out_path)
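+
+
+# Editor's sketch (not part of the original file): SAM above comes from
+# eval.py, which is not shown here. The standard Spectral Angle Mapper it is
+# assumed to implement is the mean per-pixel angle between SR and HR spectra;
+# for (C, H, W) float arrays that is (here in degrees -- the project's
+# eval.SAM may report radians instead):
+def _sam_sketch(sr, hr, eps=1e-8):
+    sr = sr.reshape(sr.shape[0], -1)
+    hr = hr.reshape(hr.shape[0], -1)
+    cos = (sr * hr).sum(0) / (np.linalg.norm(sr, axis=0) * np.linalg.norm(hr, axis=0) + eps)
+    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))).mean()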
+
+
+if __name__ == "__main__":
+    main()
diff --git "a/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train_G.py" "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train_G.py"
new file mode 100644
index 0000000000000000000000000000000000000000..481bd7696a062cd310350d69006dea32e1c4f7a7
--- /dev/null
+++ "b/code/2022_autumn/\347\250\213\350\277\216\346\235\276-\351\253\230\345\205\211\350\260\261\350\266\205\345\210\206/train_G.py"
@@ -0,0 +1,293 @@
+import os
+import sys
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn as nn
+import torch.optim as optim
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from tensorboardX import SummaryWriter
+from model import Model
+from loss import Loss
+
+from option import opt
+from data_utils import TrainsetFromFolder, ValsetFromFolder, choose_x, chop_forward
+from eval import PSNR, SAM, SSIM
+from torch.optim.lr_scheduler import MultiStepLR
+import numpy as np
+import time
+
+import scipy.io as scio
+
+psnr = []
+out_path = "result/" + opt.datasetName + "/"
+
+
+def main():
+    print(opt)
+    best_psnr = 0
+    best_sam = 1e6
+
+    if opt.show:
+        global writer
+        writer = SummaryWriter(
+            log_dir="logs/" + opt.datasetName + "/" + str(opt.upscale_factor)
+            + "/" + opt.exgroup + "/" + opt.ex
+        )
+
+    if opt.cuda:
+        print("=> Use GPU ID: '{}'".format(opt.gpus))
+        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
+        if not torch.cuda.is_available():
+            raise Exception("No GPU found or wrong GPU id, please run without --cuda")
+
+    torch.manual_seed(opt.seed)
+    if opt.cuda:
+        torch.cuda.manual_seed(opt.seed)
+    cudnn.benchmark = True
+
+    # Loading datasets
+    train_set = TrainsetFromFolder(
+        "/data2/cys/data/" + opt.datasetName + "/process_train/" + str(opt.upscale_factor) + "/",
+        opt.shuffle,
+        opt.shuffleMode,
+    )
+    train_loader = DataLoader(
+        dataset=train_set,
+        num_workers=opt.threads,
+        batch_size=opt.batchSize,
+        shuffle=True,
+    )
+    val_set = ValsetFromFolder(
+        "/data2/cys/data/" + opt.datasetName + "/process_test/" + str(opt.upscale_factor) + "/",
+        opt.shuffle,
+        opt.shuffleMode,
+    )
+    val_loader = DataLoader(
+        dataset=val_set, num_workers=opt.threads, batch_size=1, shuffle=False
+    )
+
+    # Building the model
+    model = Model(opt)
+    # print(model)
+
+    # choose the loss
+    criterion = Loss(opt)
+
+    if opt.cuda and opt.dist:
+        model = nn.DataParallel(model).cuda()
+    elif opt.cuda:
+        model = model.cuda()
+    else:
+        model = model.cpu()
+    print("# parameters:", sum(param.numel() for param in model.parameters()))
+
+    # # load the pretrained GroupBranch parameters and freeze the encoder
+    # if opt.branch:
+    #     checkpoint = torch.load(opt.branch)
+    #     Branch_dict = checkpoint["model"]
+    #     model_dict = model.state_dict()
+    #     pretrained_dict = {k: v for k, v in Branch_dict.items() if k in model_dict}
+    #     model_dict.update(pretrained_dict)
+    #     model.load_state_dict(model_dict)
+    #     for para in model.Encoder.parameters():
+    #         para.requires_grad = False  # freeze the parameter
+
+    # Setting the optimizer (only parameters that still require gradients)
+    optimizer = optim.Adam(
+        filter(lambda p: p.requires_grad, model.parameters()),
+        lr=opt.lr, betas=(0.9, 0.999), eps=1e-08,
+    )
+
+    # optionally resume from a checkpoint
+    if opt.resume:
+        if os.path.isfile(opt.resume):
+            print("=> loading checkpoint '{}'".format(opt.resume))
+            checkpoint = torch.load(opt.resume)
+            opt.start_epoch = checkpoint["epoch"] + 1
+            Branch_dict = checkpoint["model"]
+            model_dict = model.state_dict()
+            pretrained_dict = {k: v for k, v in Branch_dict.items() if k in model_dict}
+            model_dict.update(pretrained_dict)
+            model.load_state_dict(model_dict)
+            # model.load_state_dict(checkpoint["model"])
+            optimizer.load_state_dict(checkpoint["optimizer"])
+        else:
+            print("=> no checkpoint found at '{}'".format(opt.resume))
+
+    # Setting the learning-rate schedule
+    scheduler = MultiStepLR(
+        optimizer, milestones=[35, 70, 105, 140, 175], gamma=0.5, last_epoch=-1
+    )
+
+    # Training
+    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
+        print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"]))
+        start = time.time()
+        train(train_loader, optimizer, model, criterion, epoch)
+        end = time.time()
+        print("epoch cost:", (end - start) / 60, "min")
+        scheduler.step()
+        # torch.cuda.empty_cache()
+        best_psnr, best_sam = val(
+            val_loader, model, epoch, optimizer, best_psnr, best_sam
+        )
+        # save_model(model, epoch, optimizer, "last")
+
+
+def train(train_loader, optimizer, model, criterion, epoch):
+
+    model.train()
+    for iteration, batch in enumerate(train_loader, 1):
+
+        input, label = Variable(batch[0]), Variable(batch[1], requires_grad=False)
+        if opt.cuda:
+            input = input.cuda()
+            label = label.cuda()
+
+        if opt.method == "SFCSR":
+            localFeats = []
+            for i in range(input.shape[1]):
+                x = choose_x(input, i, opt.shuffleMode)
+                y = input[:, i, :, :]
+                new_label = label[:, i, :, :]
+
+                SR, localFeats = model(x, y, localFeats, i)
+                # SR, localFeats = model(y, localFeats, i)
+                # detach so no gradient flows across bands
+                localFeats = Variable(localFeats.detach().data, requires_grad=False)
+                # print(SR.shape, new_label.shape)
+                loss = criterion.loss(SR, new_label)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+        else:
+            output = model(input)
+            # shape: B, N, H, W
+            # print(output.shape, label.shape)
+            loss = criterion.loss(output, label)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        if iteration % 100 == 0:
+            print(
+                "===> Epoch[{}]({}/{}): Loss: {:.10f}".format(
+                    epoch, iteration, len(train_loader), loss.item()
+                )
+            )
+
+        if opt.show:
+            niter = epoch * len(train_loader) + iteration
+            if niter % 500 == 0:
+                writer.add_scalar("Train/Loss", loss.item(), niter)
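+
+
+# Editor's note: in the SFCSR branch each band i is super-resolved in turn --
+# choose_x() (from data_utils) gathers a small neighborhood of bands around
+# band i whose exact layout depends on opt.shuffleMode, y is the band itself,
+# and localFeats is a feature state threaded from band to band, detached at
+# every step so no gradient flows between bands.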
+
+
+def val(val_loader, model, epoch, optimizer, best_psnr, best_sam):
+
+    model.eval()
+    val_psnr = 0
+    val_sam = 0
+    val_SSIM = 0
+
+    for iteration, batch in enumerate(val_loader, 1):
+        with torch.no_grad():
+            input, label = Variable(batch[0]), Variable(batch[1])
+            SR = np.zeros((label.shape[1], label.shape[2], label.shape[3])).astype(
+                np.float32
+            )
+
+            if opt.cuda:
+                input = input.cuda()
+            # print(input.shape)
+            if opt.method == "SFCSR":
+                localFeats = []
+                for i in range(input.shape[1]):
+                    x = choose_x(input, i, opt.shuffleMode)
+                    y = input[:, i, :, :]
+                    output, localFeats = model(x, y, localFeats, i)
+                    # output, localFeats = model(y, localFeats, i)
+                    SR[i, :, :] = output.cpu().data[0].numpy()
+            elif opt.method in ["RFSR", "3DFCNN", "ERCSR", "MCNet"] and opt.datasetName == "Chikusei":
+                SR = chop_forward(input, model, opt.upscale_factor)
+            else:
+                SR = model(input).cpu().data[0].numpy()
+            SR[SR < 0] = 0
+            SR[SR > 1.0] = 1.0
+            val_psnr += PSNR(SR, label.data[0].numpy())
+            val_sam += SAM(SR, label.data[0].numpy())
+            val_SSIM += SSIM(SR, label.data[0].numpy())
+            # print("PSNR", val_psnr)
+    val_psnr = val_psnr / len(val_loader)
+    val_sam = val_sam / len(val_loader)
+    val_SSIM = val_SSIM / len(val_loader)
+    if val_psnr > best_psnr:
+        save_model(model, epoch, optimizer, "psnr_best")
+        best_psnr = val_psnr
+    if val_sam < best_sam:
+        save_model(model, epoch, optimizer, "sam_best")
+        best_sam = val_sam
+
+    print(
+        "PSNR = {:.3f}, best_PSNR = {:.3f}, SSIM = {:.4f}, SAM = {:.3f}, best_SAM = {:.3f}".format(
+            val_psnr, best_psnr, val_SSIM, val_sam, best_sam
+        )
+    )
+    if opt.show:
+        writer.add_scalar("Val/PSNR", val_psnr, epoch)
+        writer.add_scalar("Val/SAM", val_sam, epoch)
+
+    # save_model(
+    #     model, epoch, optimizer, "epoch:" + str(epoch),
+    # )
+
+    return best_psnr, best_sam
+
+
+def save_model(model, epoch, optimizer, name):
+    model_out_dir = (
+        "checkpoint/" + opt.datasetName + "/" + str(opt.upscale_factor) + "/" + opt.ex + "/"
+    )
+    model_out_path = model_out_dir + name + ".pth"
+    state = {
+        "epoch": epoch,
+        "model": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+    }
+    if not os.path.exists(model_out_dir):
+        os.makedirs(model_out_dir)
+    torch.save(state, model_out_path)
+
+
+if __name__ == "__main__":
+    main()
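+
+# Editor's usage sketch, mirroring train.bash. Only --cuda, --gpus, --method,
+# --ex, --datasetName, --show and --shuffle are confirmed by that script; the
+# experiment names and dataset below are illustrative assumptions:
+#
+#     python train.py   --cuda --gpus 0 --method SGSR  --ex SGSR_CAVE  --datasetName CAVE --show --shuffle 0
+#     python train_G.py --cuda --gpus 0 --method SFCSR --ex SFCSR_CAVE --datasetName CAVE --show --shuffle 0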