diff --git a/cv/distiller/RKD/pytorch/README.md b/cv/distiller/RKD/pytorch/README.md new file mode 100755 index 0000000000000000000000000000000000000000..e82b0298580af2a5767d446c11fa26ae747725a2 --- /dev/null +++ b/cv/distiller/RKD/pytorch/README.md @@ -0,0 +1,62 @@ +# Relational Knowledge Distillation + +Official implementation of [Relational Knowledge Distillation](https://arxiv.org/abs/1904.05068?context=cs.LG), CVPR 2019\ +This repository contains the source code of the metric learning experiments. + +```bash +# If the build reports 'ZLIB_1.2.9' not found, install zlib 1.2.9 as below. +wget http://www.zlib.net/fossils/zlib-1.2.9.tar.gz +tar xvf zlib-1.2.9.tar.gz +cd zlib-1.2.9/ +./configure && make install +cd .. +rm -rf zlib-1.2.9.tar.gz zlib-1.2.9/ +``` + +## Quick Start + +```bash +# Train a teacher embedding network of resnet50 (d=512) +# using triplet loss (margin=0.2) with distance weighted sampling. +python3 run.py --mode train \ + --dataset cub200 \ + --base resnet50 \ + --sample distance \ + --margin 0.2 \ + --embedding_size 512 \ + --save_dir teacher + +# Evaluate the teacher embedding network +python3 run.py --mode eval \ + --dataset cub200 \ + --base resnet50 \ + --embedding_size 512 \ + --load teacher/best.pth + +# Distill the teacher to the student embedding network +python3 run_distill.py --dataset cub200 \ + --base resnet18 \ + --embedding_size 64 \ + --l2normalize false \ + --teacher_base resnet50 \ + --teacher_embedding_size 512 \ + --teacher_load teacher/best.pth \ + --dist_ratio 1 \ + --angle_ratio 2 \ + --save_dir student + +# Evaluate the distilled student embedding network +python3 run.py --mode eval \ + --dataset cub200 \ + --base resnet18 \ + --l2normalize false \ + --embedding_size 64 \ + --load student/best.pth + +``` +## Results +| Model | Best Train Recall | Best Eval Recall | +|:----------:|:--------:|:--------:| +| RKD | 0.7940 | 0.5763 | + + diff --git a/cv/distiller/RKD/pytorch/dataset/__init__.py b/cv/distiller/RKD/pytorch/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e8674c8df94b0f0afcb17168800957667446bfc --- /dev/null +++ b/cv/distiller/RKD/pytorch/dataset/__init__.py @@ -0,0 +1,3 @@ +from .cars196 import Cars196Metric +from .cub200 import CUB2011Metric +from .stanford import StanfordOnlineProductsMetric diff --git a/cv/distiller/RKD/pytorch/dataset/cars196.py b/cv/distiller/RKD/pytorch/dataset/cars196.py new file mode 100644 index 0000000000000000000000000000000000000000..d15c3d4947ff466f45d3e43888d9c5a4ba34091b --- /dev/null +++ b/cv/distiller/RKD/pytorch/dataset/cars196.py @@ -0,0 +1,83 @@ +import os +import tarfile +import scipy.io as io + +from torchvision.datasets import ImageFolder +from torchvision.datasets.folder import default_loader +from torchvision.datasets.utils import download_url, check_integrity + +__all__ = ['Cars196Metric'] + + +class Cars196Metric(ImageFolder): + base_folder = 'car_ims' + img_url = 'http://imagenet.stanford.edu/internal/car196/car_ims.tgz' + img_filename = 'car_ims.tgz' + img_md5 = 'd5c8f0aa497503f355e17dc7886c3f14' + + anno_url = 'http://imagenet.stanford.edu/internal/car196/cars_annos.mat' + anno_filename = 'cars_annos.mat' + anno_md5 = 'b407c6086d669747186bd1d764ff9dbc' + + checklist = [ + ['016185.jpg', 'bab296d5e4b2290d024920bf4dc23d07'], + ['000001.jpg', '2d44a28f071aeaac9c0802fddcde452e'], + ] + + test_list = [] + num_training_classes = 98 + + def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs): + self.root = root + "/Cars196" + self.transform = 
transform + self.target_transform = target_transform + self.loader = default_loader + + if download: + download_url(self.img_url, self.root, self.img_filename, self.img_md5) + download_url(self.anno_url, self.root, self.anno_filename, self.anno_md5) + + if not self._check_integrity(): + cwd = os.getcwd() + tar = tarfile.open(os.path.join(self.root, self.img_filename), "r:gz") + os.chdir(self.root) + tar.extractall() + tar.close() + os.chdir(cwd) + + if not self._check_integrity() or \ + not check_integrity(os.path.join(self.root, self.anno_filename), self.anno_md5): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + ImageFolder.__init__(self, os.path.join(self.root), + transform=transform, target_transform=target_transform, **kwargs) + self.root = root + "/Cars196" + + labels = io.loadmat(os.path.join(self.root, self.anno_filename))['annotations'][0] + class_names = io.loadmat(os.path.join(self.root, self.anno_filename))['class_names'][0] + + if train: + self.classes = [str(c[0]) for c in class_names[:98]] + self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} + else: + self.classes = [str(c[0]) for c in class_names[98:]] + self.class_to_idx = {cls: i+98 for i, cls in enumerate(self.classes)} + + class_idx = list(self.class_to_idx.values()) + samples = [] + for l in labels: + cls = int(l[5][0, 0]) - 1 + p = l[0][0] + if cls in class_idx: + samples.append((os.path.join(self.root, p), int(cls))) + + self.samples = samples + self.imgs = self.samples + + def _check_integrity(self): + for f, md5 in self.checklist: + fpath = os.path.join(self.root, self.base_folder, f) + if not check_integrity(fpath, md5): + return False + return True diff --git a/cv/distiller/RKD/pytorch/dataset/cub200.py b/cv/distiller/RKD/pytorch/dataset/cub200.py new file mode 100644 index 0000000000000000000000000000000000000000..6abf9a9f8be18b71116923358b6fdca6cc27de47 --- /dev/null +++ b/cv/distiller/RKD/pytorch/dataset/cub200.py @@ -0,0 +1,98 @@ +import os +import tarfile + +from torchvision.datasets import ImageFolder +from torchvision.datasets.utils import download_url, check_integrity +import ssl +ssl._create_default_https_context = ssl._create_unverified_context + + +__all__ = ['CUB2011Metric'] + + +class CUB2011(ImageFolder): + image_folder = 'CUB_200_2011/images' + base_folder = 'CUB_200_2011/' + url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz' + filename = 'CUB_200_2011.tgz' + tgz_md5 = '97eceeb196236b17998738112f37df78' + + checklist = [ + ['001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg', '4c84da568f89519f84640c54b7fba7c2'], + ['002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg', 'e7db63424d0e384dba02aacaf298cdc0'], + ['198.Rock_Wren/Rock_Wren_0001_189289.jpg', '487d082f1fbd58faa7b08aa5ede3cc00'], + ['200.Common_Yellowthroat/Common_Yellowthroat_0003_190521.jpg', '96fd60ce4b4805e64368efc32bf5c6fe'] + ] + + def __init__(self, root, transform=None, target_transform=None, download=False): + self.root = root + if download: + download_url(self.url, root, self.filename, self.tgz_md5) + + if not self._check_integrity(): + cwd = os.getcwd() + tar = tarfile.open(os.path.join(root, self.filename), "r:gz") + os.chdir(root) + tar.extractall() + tar.close() + os.chdir(cwd) + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' 
+ + ' You can use download=True to download it') + + super(CUB2011, self).__init__(os.path.join(root, self.image_folder), + transform=transform, + target_transform=target_transform) + + def _check_integrity(self): + for f, md5 in self.checklist: + fpath = os.path.join(self.root, self.image_folder, f) + if not check_integrity(fpath, md5): + return False + return True + + +class CUB2011Classification(CUB2011): + def __init__(self, root, train=False, transform=None, target_transform=None, download=False): + CUB2011.__init__(self, root, transform=transform, target_transform=target_transform, download=download) + + with open(os.path.join(root, self.base_folder, 'images.txt'), 'r') as f: + id_to_image = [l.split(' ')[1].strip() for l in f.readlines()] + + with open(os.path.join(root, self.base_folder, 'train_test_split.txt'), 'r') as f: + id_to_istrain = [int(l.split(' ')[1]) == 1 for l in f.readlines()] + + train_list = [os.path.join(root, self.image_folder, id_to_image[idx]) for idx in range(len(id_to_image)) if id_to_istrain[idx]] + test_list = [os.path.join(root, self.image_folder, id_to_image[idx]) for idx in range(len(id_to_image)) if not id_to_istrain[idx]] + + if train: + self.samples = [(img_file_pth, cls_ind) for img_file_pth, cls_ind in self.imgs + if img_file_pth in train_list] + else: + self.samples = [(img_file_pth, cls_ind) for img_file_pth, cls_ind in self.imgs + if img_file_pth in test_list] + self.imgs = self.samples + + +class CUB2011Metric(CUB2011): + num_training_classes = 100 + + def __init__(self, root, train=False, split='none', transform=None, target_transform=None, download=False): + CUB2011.__init__(self, root, transform=transform, target_transform=target_transform, download=download) + + if train: + if split == 'train': + self.classes = self.classes[:(self.num_training_classes-20)] + elif split == 'val': + self.classes = self.classes[(self.num_training_classes-20):self.num_training_classes] + else: + self.classes = self.classes[:self.num_training_classes] + else: + self.classes = self.classes[self.num_training_classes:] + + self.class_to_idx = {cls_name: cls_ind for cls_name, cls_ind in self.class_to_idx.items() + if cls_name in self.classes} + self.samples = [(img_file_pth, cls_ind) for img_file_pth, cls_ind in self.imgs + if cls_ind in self.class_to_idx.values()] + self.imgs = self.samples diff --git a/cv/distiller/RKD/pytorch/dataset/stanford.py b/cv/distiller/RKD/pytorch/dataset/stanford.py new file mode 100644 index 0000000000000000000000000000000000000000..5acbc01c8ab7812efac5bd007b48742d8081f18b --- /dev/null +++ b/cv/distiller/RKD/pytorch/dataset/stanford.py @@ -0,0 +1,70 @@ +import os +import zipfile + +from torchvision.datasets import ImageFolder +from torchvision.datasets.folder import default_loader +from torchvision.datasets.utils import download_url, check_integrity + + +__all__ = ['StanfordOnlineProductsMetric'] + + +class StanfordOnlineProductsMetric(ImageFolder): + base_folder = 'Stanford_Online_Products' + url = 'ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip' + filename = 'Stanford_Online_Products.zip' + zip_md5 = '7f73d41a2f44250d4779881525aea32e' + + checklist = [ + ['bicycle_final/111265328556_0.JPG', '77420a4db9dd9284378d7287a0729edb'], + ['chair_final/111182689872_0.JPG', 'ce78d10ed68560f4ea5fa1bec90206ba'], + ['table_final/111194782300_0.JPG', '8203e079b5c134161bbfa7ee2a43a0a1'], + ['toaster_final/111157129195_0.JPG', 'd6c24ee8c05d986cafffa6af82ae224e'] + ] + num_training_classes = 11318 + + def __init__(self, root, train=False, 
transform=None, target_transform=None, download=False, **kwargs): + self.root = root + self.transform = transform + self.target_transform = target_transform + self.loader = default_loader + + if download: + download_url(self.url, self.root, self.filename, self.zip_md5) + + if not self._check_integrity(): + # extract file + cwd = os.getcwd() + os.chdir(root) + with zipfile.ZipFile(self.filename, "r") as zip: + zip.extractall() + os.chdir(cwd) + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + ImageFolder.__init__(self, os.path.join(root, self.base_folder), + transform=transform, target_transform=target_transform, **kwargs) + + self.super_classes = self.classes + samples = [] + classes = set() + f = open(os.path.join(root, self.base_folder, 'Ebay_{}.txt'.format('train' if train else 'test'))) + f.readline() + for (image_id, class_id, super_class_id, path) in map(str.split, f): + samples.append((os.path.join(root, self.base_folder, path), int(class_id)-1)) + classes.add("%s.%s" % (class_id, self.super_classes[int(super_class_id)-1])) + + self.samples = samples + self.classes = list(classes) + self.classes.sort(key=lambda x: int(x.split(".")[0])) + self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} + self.imgs = self.samples + + def _check_integrity(self): + for f, md5 in self.checklist: + fpath = os.path.join(self.root, self.base_folder, f) + if not check_integrity(fpath, md5): + return False + return True diff --git a/cv/distiller/RKD/pytorch/examples/cub200.sh b/cv/distiller/RKD/pytorch/examples/cub200.sh new file mode 100644 index 0000000000000000000000000000000000000000..55b95216f36c187064032aeb8ee19da11e1c733c --- /dev/null +++ b/cv/distiller/RKD/pytorch/examples/cub200.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +# Teacher Network +python3 run.py --dataset cub200 --epochs 40 --lr_decay_epochs 25 30 35 --lr_decay_gamma 0.5 --batch 128\ + --base resnet50 --sample distance --margin 0.2 --embedding_size 512 --save_dir cub200_resnet50_512 + +# Student with small embedding +python3 run_distill.py --dataset cub200 --epochs 80 --lr_decay_epochs 40 60 --lr_decay_gamma 0.1 --batch 128\ + --base resnet18 --embedding_size 128 --l2normalize false --dist_ratio 1 --angle_ratio 2 \ + --teacher_base resnet50 --teacher_embedding_size 512 --teacher_load cub200_resnet50_512/best.pth \ + --save_dir cub200_student_resnet18_128 + +# Self-Distillation +python3 run_distill.py --dataset cub200 --epochs 80 --lr_decay_epochs 40 60 --lr_decay_gamma 0.1 --batch 128\ + --base resnet50 --embedding_size 512 --l2normalize false --dist_ratio 1 --angle_ratio 2 \ + --teacher_base resnet50 --teacher_embedding_size 512 --teacher_load cub200_resnet50_512/best.pth \ + --save_dir cub200_student_resnet50_512 diff --git a/cv/distiller/RKD/pytorch/metric/__init__.py b/cv/distiller/RKD/pytorch/metric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cv/distiller/RKD/pytorch/metric/batchsampler.py b/cv/distiller/RKD/pytorch/metric/batchsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..03c9edb94f0406028666c1b1b26c19e8b005bd86 --- /dev/null +++ b/cv/distiller/RKD/pytorch/metric/batchsampler.py @@ -0,0 +1,45 @@ +import random + +from torchvision.datasets import ImageFolder +from torch.utils.data.sampler import Sampler + + +def index_dataset(dataset: ImageFolder): + kv = [(cls_ind, idx) for idx, (_, cls_ind) in 
enumerate(dataset.imgs)] + cls_to_ind = {} + + for k, v in kv: + if k in cls_to_ind: + cls_to_ind[k].append(v) + else: + cls_to_ind[k] = [v] + + return cls_to_ind + + +class NPairs(Sampler): + def __init__(self, data_source: ImageFolder, batch_size, m=5, iter_per_epoch=200): + super(Sampler, self).__init__() + self.m = m + self.batch_size = batch_size + self.n_batch = iter_per_epoch + self.class_idx = list(data_source.class_to_idx.values()) + self.images_by_class = index_dataset(data_source) + + def __len__(self): + return self.n_batch + + def __iter__(self): + for _ in range(self.n_batch): + selected_class = random.sample(self.class_idx, k=len(self.class_idx)) + example_indices = [] + + for c in selected_class: + img_ind_of_cls = self.images_by_class[c] + new_ind = random.sample(img_ind_of_cls, k=min(self.m, len(img_ind_of_cls))) + example_indices += new_ind + + if len(example_indices) >= self.batch_size: + break + + yield example_indices[:self.batch_size] diff --git a/cv/distiller/RKD/pytorch/metric/loss.py b/cv/distiller/RKD/pytorch/metric/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1e408d001441a3e5c8528da9deea98cb1676b579 --- /dev/null +++ b/cv/distiller/RKD/pytorch/metric/loss.py @@ -0,0 +1,148 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from metric.utils import pdist + +__all__ = ['L1Triplet', 'L2Triplet', 'ContrastiveLoss', 'RkdDistance', 'RKdAngle', 'HardDarkRank'] + + +class _Triplet(nn.Module): + def __init__(self, p=2, margin=0.2, sampler=None, reduce=True, size_average=True): + super().__init__() + self.p = p + self.margin = margin + + # update distance function accordingly + self.sampler = sampler + self.sampler.dist_func = lambda e: pdist(e, squared=(p==2)) + + self.reduce = reduce + self.size_average = size_average + + def forward(self, embeddings, labels): + anchor_idx, pos_idx, neg_idx = self.sampler(embeddings, labels) + + anchor_embed = embeddings[anchor_idx] + positive_embed = embeddings[pos_idx] + negative_embed = embeddings[neg_idx] + + loss = F.triplet_margin_loss(anchor_embed, positive_embed, negative_embed, + margin=self.margin, p=self.p, reduction='none') + + if not self.reduce: + return loss + + if self.size_average: + return loss.mean() + else: + return loss.sum() + + +class L2Triplet(_Triplet): + def __init__(self, margin=0.2, sampler=None): + super().__init__(p=2, margin=margin, sampler=sampler) + + +class L1Triplet(_Triplet): + def __init__(self, margin=0.2, sampler=None): + super().__init__(p=1, margin=margin, sampler=sampler) + + +class ContrastiveLoss(nn.Module): + def __init__(self, margin=0.2, sampler=None): + super().__init__() + self.margin = margin + self.sampler = sampler + + def forward(self, embeddings, labels): + anchor_idx, pos_idx, neg_idx = self.sampler(embeddings, labels) + + anchor_embed = embeddings[anchor_idx] + positive_embed = embeddings[pos_idx] + negative_embed = embeddings[neg_idx] + + pos_loss = (F.pairwise_distance(anchor_embed, positive_embed, p=2)).pow(2) + neg_loss = (self.margin - F.pairwise_distance(anchor_embed, negative_embed, p=2)).clamp(min=0).pow(2) + + loss = torch.cat((pos_loss, neg_loss)) + return loss.mean() + + +class HardDarkRank(nn.Module): + def __init__(self, alpha=3, beta=3, permute_len=4): + super().__init__() + self.alpha = alpha + self.beta = beta + self.permute_len = permute_len + + def forward(self, student, teacher): + score_teacher = -1 * self.alpha * pdist(teacher, squared=False).pow(self.beta) + score_student = -1 * self.alpha * pdist(student, 
squared=False).pow(self.beta) + + permute_idx = score_teacher.sort(dim=1, descending=True)[1][:, 1:(self.permute_len+1)] + ordered_student = torch.gather(score_student, 1, permute_idx) + + log_prob = (ordered_student - torch.stack([torch.logsumexp(ordered_student[:, i:], dim=1) for i in range(permute_idx.size(1))], dim=1)).sum(dim=1) + loss = (-1 * log_prob).mean() + + return loss + + +class FitNet(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + self.in_feature = in_feature + self.out_feature = out_feature + + self.transform = nn.Conv2d(in_feature, out_feature, 1, bias=False) + self.transform.weight.data.uniform_(-0.005, 0.005) + + def forward(self, student, teacher): + if student.dim() == 2: + student = student.unsqueeze(2).unsqueeze(3) + teacher = teacher.unsqueeze(2).unsqueeze(3) + + return (self.transform(student) - teacher).pow(2).mean() + + +class AttentionTransfer(nn.Module): + def forward(self, student, teacher): + s_attention = F.normalize(student.pow(2).mean(1).view(student.size(0), -1)) + + with torch.no_grad(): + t_attention = F.normalize(teacher.pow(2).mean(1).view(teacher.size(0), -1)) + + return (s_attention - t_attention).pow(2).mean() + + +class RKdAngle(nn.Module): + def forward(self, student, teacher): + # embeddings: N x C + # pairwise differences: N x N x C + + with torch.no_grad(): + td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) + norm_td = F.normalize(td, p=2, dim=2) + t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) + + sd = (student.unsqueeze(0) - student.unsqueeze(1)) + norm_sd = F.normalize(sd, p=2, dim=2) + s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) + + loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean') + return loss + + +class RkdDistance(nn.Module): + def forward(self, student, teacher): + with torch.no_grad(): + t_d = pdist(teacher, squared=False) + mean_td = t_d[t_d>0].mean() + t_d = t_d / mean_td + + d = pdist(student, squared=False) + mean_d = d[d>0].mean() + d = d / mean_d + + loss = F.smooth_l1_loss(d, t_d, reduction='mean') + return loss diff --git a/cv/distiller/RKD/pytorch/metric/pairsampler.py b/cv/distiller/RKD/pytorch/metric/pairsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ef442352d6997418055620048ea911ba7845ca9a --- /dev/null +++ b/cv/distiller/RKD/pytorch/metric/pairsampler.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from metric.utils import pdist + +BIG_NUMBER = 1e12 +__all__ = ['AllPairs', 'HardNegative', 'SemiHardNegative', 'DistanceWeighted', 'RandomNegative'] + + +def pos_neg_mask(labels): + pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) * \ + (1 - torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device)) + neg_mask = (labels.unsqueeze(0) != labels.unsqueeze(1)) * \ + (1 - torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device)) + + return pos_mask, neg_mask + + +class _Sampler(nn.Module): + def __init__(self, dist_func=pdist): + self.dist_func = dist_func + super().__init__() + + def forward(self, embeddings, labels): + raise NotImplementedError + + +class AllPairs(_Sampler): + def forward(self, embeddings, labels): + with torch.no_grad(): + pos_mask, neg_mask = pos_neg_mask(labels) + pos_pair_idx = pos_mask.nonzero() + + apns = [] + for pair_idx in pos_pair_idx: + anchor_idx = pair_idx[0] + neg_indices = neg_mask[anchor_idx].nonzero() + + apn = torch.cat((pair_idx.unsqueeze(0).repeat(len(neg_indices), 1), neg_indices), dim=1) + apns.append(apn) + apns = 
torch.cat(apns, dim=0) + anchor_idx = apns[:, 0] + pos_idx = apns[:, 1] + neg_idx = apns[:, 2] + + return anchor_idx, pos_idx, neg_idx + + +class RandomNegative(_Sampler): + def forward(self, embeddings, labels): + with torch.no_grad(): + pos_mask, neg_mask = pos_neg_mask(labels) + + pos_pair_index = pos_mask.nonzero() + anchor_idx = pos_pair_index[:, 0] + pos_idx = pos_pair_index[:, 1] + neg_index = torch.multinomial(neg_mask.float()[anchor_idx], 1).squeeze(1) + + return anchor_idx, pos_idx, neg_index + + +class HardNegative(_Sampler): + def forward(self, embeddings, labels): + with torch.no_grad(): + pos_mask, neg_mask = pos_neg_mask(labels) + dist = self.dist_func(embeddings) + + pos_pair_index = pos_mask.nonzero() + anchor_idx = pos_pair_index[:, 0] + pos_idx = pos_pair_index[:, 1] + + neg_dist = (neg_mask.float() * dist) + neg_dist[neg_dist <= 0] = BIG_NUMBER + neg_idx = neg_dist.argmin(dim=1)[anchor_idx] + + return anchor_idx, pos_idx, neg_idx + + +class SemiHardNegative(_Sampler): + def forward(self, embeddings, labels): + with torch.no_grad(): + dist = self.dist_func(embeddings) + pos_mask, neg_mask = pos_neg_mask(labels) + neg_dist = dist * neg_mask.float() + + pos_pair_idx = pos_mask.nonzero() + anchor_idx = pos_pair_idx[:, 0] + pos_idx = pos_pair_idx[:, 1] + + tiled_negative = neg_dist[anchor_idx] + satisfied_neg = (tiled_negative > dist[pos_mask].unsqueeze(1)) * neg_mask[anchor_idx] + """ + If no negative pair has a larger distance than the positive pair, + fall back to the negative pair with the largest distance. + """ + unsatisfied_neg = (satisfied_neg.sum(dim=1) == 0).unsqueeze(1) * neg_mask[anchor_idx] + + tiled_negative = (satisfied_neg.float() * tiled_negative) - (unsatisfied_neg.float() * tiled_negative) + tiled_negative[tiled_negative == 0] = BIG_NUMBER + neg_idx = tiled_negative.argmin(dim=1) + + return anchor_idx, pos_idx, neg_idx + + +class DistanceWeighted(_Sampler): + cut_off = 0.5 + nonzero_loss_cutoff = 1.4 + """ + Distance-weighted sampling assumes that embeddings are normalized by the 2-norm. 
+ """ + + def forward(self, embeddings, labels): + with torch.no_grad(): + embeddings = F.normalize(embeddings, dim=1, p=2) + pos_mask, neg_mask = pos_neg_mask(labels) + pos_pair_idx = pos_mask.nonzero() + anchor_idx = pos_pair_idx[:, 0] + pos_idx = pos_pair_idx[:, 1] + + d = embeddings.size(1) + dist = (pdist(embeddings, squared=True) + torch.eye(embeddings.size(0), device=embeddings.device, dtype=torch.float32)).sqrt() + dist = dist.clamp(min=self.cut_off) + + log_weight = ((2.0 - d) * dist.log() - ((d - 3.0)/2.0) * (1.0 - 0.25 * (dist * dist)).log()) + weight = (log_weight - log_weight.max(dim=1, keepdim=True)[0]).exp() + weight = weight * (neg_mask * (dist < self.nonzero_loss_cutoff)).float() + + weight = weight + ((weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float() + weight = weight / (weight.sum(dim=1, keepdim=True)) + weight = weight[anchor_idx] + neg_idx = torch.multinomial(weight, 1).squeeze(1) + + return anchor_idx, pos_idx, neg_idx diff --git a/cv/distiller/RKD/pytorch/metric/utils.py b/cv/distiller/RKD/pytorch/metric/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c8b7cc03668c8184db350c8be15b532314212f --- /dev/null +++ b/cv/distiller/RKD/pytorch/metric/utils.py @@ -0,0 +1,36 @@ +import torch + +__all__ = ['pdist'] + + +def pdist(e, squared=False, eps=1e-12): + e_square = e.pow(2).sum(dim=1) + prod = e @ e.t() + res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) + + if not squared: + res = res.sqrt() + + res = res.clone() + res[range(len(e)), range(len(e))] = 0 + return res + + +def recall(embeddings, labels, K=[]): + D = pdist(embeddings, squared=True) + knn_inds = D.topk(1 + max(K), dim=1, largest=False, sorted=True)[1][:, 1:] + + """ + Check that knn_inds does not contain the index of the query image itself. 
+ """ + assert ((knn_inds == torch.arange(0, len(labels), device=knn_inds.device).unsqueeze(1)).sum().item() == 0) + + selected_labels = labels[knn_inds.contiguous().view(-1)].view_as(knn_inds) + correct_labels = labels.unsqueeze(1) == selected_labels + + recall_k = [] + + for k in K: + correct_k = (correct_labels[:, :k].sum(dim=1) > 0).float().mean().item() + recall_k.append(correct_k) + return recall_k diff --git a/cv/distiller/RKD/pytorch/model/__init__.py b/cv/distiller/RKD/pytorch/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cv/distiller/RKD/pytorch/model/backbone/__init__.py b/cv/distiller/RKD/pytorch/model/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89e664e54bf990cb8ba972572a52a0b45fbc7a7f --- /dev/null +++ b/cv/distiller/RKD/pytorch/model/backbone/__init__.py @@ -0,0 +1,3 @@ +from .inception.google import GoogleNet +from .inception.v1bn import InceptionV1BN +from .resnet import ResNet50, ResNet18 \ No newline at end of file diff --git a/cv/distiller/RKD/pytorch/model/backbone/inception/__init__.py b/cv/distiller/RKD/pytorch/model/backbone/inception/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cv/distiller/RKD/pytorch/model/backbone/inception/google.py b/cv/distiller/RKD/pytorch/model/backbone/inception/google.py new file mode 100644 index 0000000000000000000000000000000000000000..0850ad5c16f09a802c024ef0b2cf4bf34228a5cc --- /dev/null +++ b/cv/distiller/RKD/pytorch/model/backbone/inception/google.py @@ -0,0 +1,105 @@ +import os +import torch +import torch.nn as nn +import h5py + +from collections import OrderedDict +from torchvision.datasets.utils import download_url + +__all__ = ["GoogleNet"] + + +class GoogleNet(nn.Sequential): + output_size = 1024 + input_side = 227 + rescale = 255.0 + rgb_mean = [122.7717, 115.9465, 102.9801] + rgb_std = [1, 1, 1] + url = "https://github.com/vadimkantorov/metriclearningbench/releases/download/data/googlenet.h5" + md5hash = 'c7d7856bd1ab5cb02618b3f7f564e3c6' + model_filename = 'googlenet.h5' + + def __init__(self, pretrained=True, root='data'): + super(GoogleNet, self).__init__(OrderedDict([ + ('conv1', nn.Sequential(OrderedDict([ + ('7x7_s2', nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3))), + ('relu1', nn.ReLU(True)), + ('pool1', nn.MaxPool2d((3, 3), (2, 2), ceil_mode=True)), + ('lrn1', nn.CrossMapLRN2d(5, 0.0001, 0.75, 1)) + ]))), + + ('conv2', nn.Sequential(OrderedDict([ + ('3x3_reduce', nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0))), + ('relu1', nn.ReLU(True)), + ('3x3', nn.Conv2d(64, 192, (3, 3), (1, 1), (1, 1))), + ('relu2', nn.ReLU(True)), + ('lrn2', nn.CrossMapLRN2d(5, 0.0001, 0.75, 1)), + ('pool2', nn.MaxPool2d((3, 3), (2, 2), ceil_mode=True)) + ]))), + + ('inception_3a', InceptionModule(192, 64, 96, 128, 16, 32, 32)), + ('inception_3b', InceptionModule(256, 128, 128, 192, 32, 96, 64)), + + ('pool3', nn.MaxPool2d((3, 3), (2, 2), ceil_mode=True)), + + ('inception_4a', InceptionModule(480, 192, 96, 208, 16, 48, 64)), + ('inception_4b', InceptionModule(512, 160, 112, 224, 24, 64, 64)), + ('inception_4c', InceptionModule(512, 128, 128, 256, 24, 64, 64)), + ('inception_4d', InceptionModule(512, 112, 144, 288, 32, 64, 64)), + ('inception_4e', InceptionModule(528, 256, 160, 320, 32, 128, 128)), + + ('pool4', nn.MaxPool2d((3, 3), (2, 2), ceil_mode=True)), + + ('inception_5a', InceptionModule(832, 256, 160, 320, 32, 128, 128)), 
+ ('inception_5b', InceptionModule(832, 384, 192, 384, 48, 128, 128)), + + ('pool5', nn.AvgPool2d((7, 7), (1, 1), ceil_mode=True)), + ])) + + if pretrained: + self.load(root) + + def load(self, root): + download_url(self.url, root, self.model_filename, self.md5hash) + h5_file = h5py.File(os.path.join(root, self.model_filename), 'r') + group_key = list(h5_file.keys())[0] + self.load_state_dict({k: torch.from_numpy(v[group_key].value) for k, v in h5_file[group_key].items()}) + + +class InceptionModule(nn.Module): + def __init__(self, inplane, outplane_a1x1, outplane_b3x3_reduce, outplane_b3x3, outplane_c5x5_reduce, outplane_c5x5, + outplane_pool_proj): + super(InceptionModule, self).__init__() + a = nn.Sequential(OrderedDict([ + ('1x1', nn.Conv2d(inplane, outplane_a1x1, (1, 1), (1, 1), (0, 0))), + ('1x1_relu', nn.ReLU(True)) + ])) + + b = nn.Sequential(OrderedDict([ + ('3x3_reduce', nn.Conv2d(inplane, outplane_b3x3_reduce, (1, 1), (1, 1), (0, 0))), + ('3x3_relu1', nn.ReLU(True)), + ('3x3', nn.Conv2d(outplane_b3x3_reduce, outplane_b3x3, (3, 3), (1, 1), (1, 1))), + ('3x3_relu2', nn.ReLU(True)) + ])) + + c = nn.Sequential(OrderedDict([ + ('5x5_reduce', nn.Conv2d(inplane, outplane_c5x5_reduce, (1, 1), (1, 1), (0, 0))), + ('5x5_relu1', nn.ReLU(True)), + ('5x5', nn.Conv2d(outplane_c5x5_reduce, outplane_c5x5, (5, 5), (1, 1), (2, 2))), + ('5x5_relu2', nn.ReLU(True)) + ])) + + d = nn.Sequential(OrderedDict([ + ('pool_pool', nn.MaxPool2d((3, 3), (1, 1), (1, 1))), + ('pool_proj', nn.Conv2d(inplane, outplane_pool_proj, (1, 1), (1, 1), (0, 0))), + ('pool_relu', nn.ReLU(True)) + ])) + + for container in [a, b, c, d]: + for name, module in container.named_children(): + self.add_module(name, module) + + self.branches = [a, b, c, d] + + def forward(self, input): + return torch.cat([branch(input) for branch in self.branches], 1) diff --git a/cv/distiller/RKD/pytorch/model/backbone/inception/v1bn.py b/cv/distiller/RKD/pytorch/model/backbone/inception/v1bn.py new file mode 100644 index 0000000000000000000000000000000000000000..e40f6da8840b6dffbb7e7ae9f4c4a95301a05991 --- /dev/null +++ b/cv/distiller/RKD/pytorch/model/backbone/inception/v1bn.py @@ -0,0 +1,531 @@ +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + + +__all__ = ['InceptionV1BN'] + + +pretrained_settings = { + 'bninception': { + 'imagenet': { + # Was ported using python2 (may trigger warning) + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-239d2248.pth', + # 'url': 'http://yjxiong.me/others/bn_inception-9f5701afb96c8044.pth', + 'input_space': 'BGR', + 'input_size': [3, 224, 224], + 'input_range': [0, 255], + 'mean': [104, 117, 128], + 'std': [1, 1, 1], + 'num_classes': 1000 + } + } +} + + +class InceptionV1BN(nn.Module): + output_size = 1024 + + def __init__(self, pretrained=True): + super(InceptionV1BN, self).__init__() + self.model = bninception(num_classes=1000, pretrained='imagenet' if pretrained else None) + + def forward(self, input): + x = self.model.features(input) + x = self.model.global_pool(x) + return x.view(x.size(0), -1) + + +class BNInception(nn.Module): + + def __init__(self, num_classes=1000): + super(BNInception, self).__init__() + inplace = True + self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) + self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.conv1_relu_7x7 = nn.ReLU (inplace) + self.pool1_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.conv2_3x3_reduce = 
nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) + self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.conv2_relu_3x3_reduce = nn.ReLU (inplace) + self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.conv2_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.conv2_relu_3x3 = nn.ReLU (inplace) + self.pool2_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_1x1 = nn.ReLU (inplace) + self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_3x3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_3x3 = nn.ReLU (inplace) + self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, eps=1e-05, momentum=0.9, affine=True) + self.inception_3a_relu_pool_proj = nn.ReLU (inplace) + self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_1x1 = nn.ReLU (inplace) + self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_3x3 = nn.ReLU (inplace) + self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1)) + self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3b_relu_pool_proj = nn.ReLU (inplace) + self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_3c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_3x3 = nn.ReLU (inplace) + self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_3c_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_3c_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_1x1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_1x1 = nn.ReLU (inplace) + self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_3x3 = nn.ReLU (inplace) + self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4a_pool_proj_bn = 
nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4a_relu_pool_proj = nn.ReLU (inplace) + self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_1x1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_1x1 = nn.ReLU (inplace) + self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_3x3_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_3x3 = nn.ReLU (inplace) + self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4b_relu_pool_proj = nn.ReLU (inplace) + self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_1x1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_1x1 = nn.ReLU (inplace) + self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_3x3 = nn.ReLU (inplace) + self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4c_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 
1), stride=(1, 1)) + self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4c_relu_pool_proj = nn.ReLU (inplace) + self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_1x1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_1x1 = nn.ReLU (inplace) + self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_3x3 = nn.ReLU (inplace) + self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4d_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4d_relu_pool_proj = nn.ReLU (inplace) + self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_4e_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_3x3 = nn.ReLU (inplace) + self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) + self.inception_4e_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_4e_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) + self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_1x1 = nn.ReLU (inplace) + 
self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_3x3 = nn.ReLU (inplace) + self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_5a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) + self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_5a_relu_pool_proj = nn.ReLU (inplace) + self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_1x1 = nn.ReLU (inplace) + self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_3x3_reduce = nn.ReLU (inplace) + self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_3x3 = nn.ReLU (inplace) + self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_double_3x3_reduce = nn.ReLU (inplace) + self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_double_3x3_1 = nn.ReLU (inplace) + self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_double_3x3_2 = nn.ReLU (inplace) + self.inception_5b_pool = nn.MaxPool2d ((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) + self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) + self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) + self.inception_5b_relu_pool_proj = nn.ReLU (inplace) + self.global_pool = nn.AvgPool2d(7, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + self.last_linear = nn.Linear(1024, 
num_classes) + + def features(self, input): + conv1_7x7_s2_out = self.conv1_7x7_s2(input) + conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) + conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) + pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_7x7_s2_bn_out) + conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out) + conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) + conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) + conv2_3x3_out = self.conv2_3x3(conv2_3x3_reduce_bn_out) + conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) + conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) + pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_3x3_bn_out) + inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) + inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) + inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) + inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out) + inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) + inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) + inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_3x3_reduce_bn_out) + inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out) + inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) + inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) + inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(inception_3a_double_3x3_reduce_out) + inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(inception_3a_double_3x3_reduce_bn_out) + inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_double_3x3_reduce_bn_out) + inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) + inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) + inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_double_3x3_1_bn_out) + inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) + inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) + inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) + inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) + inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) + inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) + inception_3a_output_out = torch.cat([inception_3a_1x1_bn_out,inception_3a_3x3_bn_out,inception_3a_double_3x3_2_bn_out,inception_3a_pool_proj_bn_out], 1) + inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) + inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) + inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) + inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) + inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) + inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) + inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_3x3_reduce_bn_out) + inception_3b_3x3_bn_out = 
self.inception_3b_3x3_bn(inception_3b_3x3_out) + inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) + inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) + inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(inception_3b_double_3x3_reduce_out) + inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(inception_3b_double_3x3_reduce_bn_out) + inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_double_3x3_reduce_bn_out) + inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) + inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) + inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_double_3x3_1_bn_out) + inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) + inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) + inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out) + inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) + inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) + inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) + inception_3b_output_out = torch.cat([inception_3b_1x1_bn_out,inception_3b_3x3_bn_out,inception_3b_double_3x3_2_bn_out,inception_3b_pool_proj_bn_out], 1) + inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) + inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) + inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) + inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_3x3_reduce_bn_out) + inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) + inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) + inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) + inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(inception_3c_double_3x3_reduce_out) + inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(inception_3c_double_3x3_reduce_bn_out) + inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_double_3x3_reduce_bn_out) + inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) + inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) + inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_double_3x3_1_bn_out) + inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) + inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) + inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) + inception_3c_output_out = torch.cat([inception_3c_3x3_bn_out,inception_3c_double_3x3_2_bn_out,inception_3c_pool_out], 1) + inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) + inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) + inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) + inception_4a_3x3_reduce_out = 
self.inception_4a_3x3_reduce(inception_3c_output_out) + inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) + inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) + inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_3x3_reduce_bn_out) + inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out) + inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) + inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) + inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(inception_4a_double_3x3_reduce_out) + inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(inception_4a_double_3x3_reduce_bn_out) + inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_double_3x3_reduce_bn_out) + inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) + inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) + inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_double_3x3_1_bn_out) + inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) + inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) + inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) + inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) + inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) + inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) + inception_4a_output_out = torch.cat([inception_4a_1x1_bn_out,inception_4a_3x3_bn_out,inception_4a_double_3x3_2_bn_out,inception_4a_pool_proj_bn_out], 1) + inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) + inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) + inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) + inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) + inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) + inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) + inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_3x3_reduce_bn_out) + inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) + inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) + inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out) + inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(inception_4b_double_3x3_reduce_out) + inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(inception_4b_double_3x3_reduce_bn_out) + inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_double_3x3_reduce_bn_out) + inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) + inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) + inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_double_3x3_1_bn_out) + inception_4b_double_3x3_2_bn_out = 
self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) + inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) + inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) + inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) + inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) + inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) + inception_4b_output_out = torch.cat([inception_4b_1x1_bn_out,inception_4b_3x3_bn_out,inception_4b_double_3x3_2_bn_out,inception_4b_pool_proj_bn_out], 1) + inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) + inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) + inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) + inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) + inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) + inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) + inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_3x3_reduce_bn_out) + inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) + inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) + inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) + inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(inception_4c_double_3x3_reduce_out) + inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(inception_4c_double_3x3_reduce_bn_out) + inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_double_3x3_reduce_bn_out) + inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) + inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) + inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_double_3x3_1_bn_out) + inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) + inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) + inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) + inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) + inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) + inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) + inception_4c_output_out = torch.cat([inception_4c_1x1_bn_out,inception_4c_3x3_bn_out,inception_4c_double_3x3_2_bn_out,inception_4c_pool_proj_bn_out], 1) + inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) + inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) + inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) + inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) + inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) + inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) + inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_3x3_reduce_bn_out) + inception_4d_3x3_bn_out = 
self.inception_4d_3x3_bn(inception_4d_3x3_out) + inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) + inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) + inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(inception_4d_double_3x3_reduce_out) + inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(inception_4d_double_3x3_reduce_bn_out) + inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_double_3x3_reduce_bn_out) + inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) + inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) + inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_double_3x3_1_bn_out) + inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) + inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) + inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) + inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) + inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) + inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) + inception_4d_output_out = torch.cat([inception_4d_1x1_bn_out,inception_4d_3x3_bn_out,inception_4d_double_3x3_2_bn_out,inception_4d_pool_proj_bn_out], 1) + inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) + inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) + inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) + inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_3x3_reduce_bn_out) + inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) + inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) + inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out) + inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(inception_4e_double_3x3_reduce_out) + inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(inception_4e_double_3x3_reduce_bn_out) + inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_double_3x3_reduce_bn_out) + inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) + inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) + inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_double_3x3_1_bn_out) + inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) + inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) + inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) + inception_4e_output_out = torch.cat([inception_4e_3x3_bn_out,inception_4e_double_3x3_2_bn_out,inception_4e_pool_out], 1) + inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) + inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) + inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) + inception_5a_3x3_reduce_out = 
self.inception_5a_3x3_reduce(inception_4e_output_out) + inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) + inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) + inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_3x3_reduce_bn_out) + inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) + inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) + inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) + inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(inception_5a_double_3x3_reduce_out) + inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(inception_5a_double_3x3_reduce_bn_out) + inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_double_3x3_reduce_bn_out) + inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) + inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) + inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_double_3x3_1_bn_out) + inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) + inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) + inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) + inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) + inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) + inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) + inception_5a_output_out = torch.cat([inception_5a_1x1_bn_out,inception_5a_3x3_bn_out,inception_5a_double_3x3_2_bn_out,inception_5a_pool_proj_bn_out], 1) + inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out) + inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out) + inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out) + inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out) + inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out) + inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out) + inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_3x3_reduce_bn_out) + inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out) + inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out) + inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out) + inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(inception_5b_double_3x3_reduce_out) + inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(inception_5b_double_3x3_reduce_bn_out) + inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_double_3x3_reduce_bn_out) + inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out) + inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out) + inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_double_3x3_1_bn_out) + inception_5b_double_3x3_2_bn_out = 
self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out) + inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out) + inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out) + inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out) + inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out) + inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out) + inception_5b_output_out = torch.cat([inception_5b_1x1_bn_out,inception_5b_3x3_bn_out,inception_5b_double_3x3_2_bn_out,inception_5b_pool_proj_bn_out], 1) + return inception_5b_output_out + + def logits(self, features): + x = self.global_pool(features) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +def bninception(num_classes=1000, pretrained='imagenet'): + r"""BNInception model architecture from `_ paper. + """ + model = BNInception(num_classes=num_classes) + if pretrained is not None: + settings = pretrained_settings['bninception'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + weight = model_zoo.load_url(settings['url']) + weight = {k: v.squeeze(0) if v.size(0) == 1 else v for k, v in weight.items()} + model.load_state_dict(weight) + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + model.mean = settings['mean'] + model.std = settings['std'] + return model + + +if __name__ == '__main__': + + model = bninception() \ No newline at end of file diff --git a/cv/distiller/RKD/pytorch/model/backbone/resnet.py b/cv/distiller/RKD/pytorch/model/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..140c78b1c11558d24512eba827b0513dc26ed1c2 --- /dev/null +++ b/cv/distiller/RKD/pytorch/model/backbone/resnet.py @@ -0,0 +1,52 @@ +import torchvision +import torch.nn as nn + +__all__ = ['ResNet18', 'ResNet50'] + + +class ResNet18(nn.Module): + output_size = 512 + + def __init__(self, pretrained=True): + super(ResNet18, self).__init__() + pretrained = torchvision.models.resnet18(pretrained=pretrained) + + for module_name in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']: + self.add_module(module_name, getattr(pretrained, module_name)) + + def forward(self, x, get_ha=False): + x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) + b1 = self.layer1(x) + b2 = self.layer2(b1) + b3 = self.layer3(b2) + b4 = self.layer4(b3) + pool = self.avgpool(b4) + + if get_ha: + return b1, b2, b3, b4, pool + + return pool + + +class ResNet50(nn.Module): + output_size = 2048 + + def __init__(self, pretrained=True): + super(ResNet50, self).__init__() + pretrained = torchvision.models.resnet50(pretrained=pretrained) + + for module_name in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']: + self.add_module(module_name, getattr(pretrained, module_name)) + + def forward(self, x, get_ha=False): + x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) + b1 = self.layer1(x) + b2 = self.layer2(b1) + b3 = self.layer3(b2) + b4 = self.layer4(b3) + pool = self.avgpool(b4) + + if get_ha: + return b1, b2, b3, b4, pool + + return pool diff --git a/cv/distiller/RKD/pytorch/model/embedding.py 
b/cv/distiller/RKD/pytorch/model/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9a09f491b7ccd01507d400fd1c321ab1e5caa6 --- /dev/null +++ b/cv/distiller/RKD/pytorch/model/embedding.py @@ -0,0 +1,29 @@ +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["LinearEmbedding"] + + +class LinearEmbedding(nn.Module): + def __init__(self, base, output_size=512, embedding_size=128, normalize=True): + super(LinearEmbedding, self).__init__() + self.base = base + self.linear = nn.Linear(output_size, embedding_size) + self.normalize = normalize + + def forward(self, x, get_ha=False): + if get_ha: + b1, b2, b3, b4, pool = self.base(x, True) + else: + pool = self.base(x) + + pool = pool.view(x.size(0), -1) + embedding = self.linear(pool) + + if self.normalize: + embedding = F.normalize(embedding, p=2, dim=1) + + if get_ha: + return b1, b2, b3, b4, pool, embedding + + return embedding diff --git a/cv/distiller/RKD/pytorch/run.py b/cv/distiller/RKD/pytorch/run.py new file mode 100644 index 0000000000000000000000000000000000000000..35fbc5a3f1f81ae98ea5965d3061968b9445e432 --- /dev/null +++ b/cv/distiller/RKD/pytorch/run.py @@ -0,0 +1,215 @@ +import os +import argparse +import random + +import torch +import torch.optim as optim +import torchvision.transforms as transforms + +import dataset +import model.backbone as backbone + +import metric.loss as loss +import metric.pairsampler as pair + + +from tqdm import tqdm +from torch.utils.data import DataLoader + +from metric.utils import recall +from metric.batchsampler import NPairs +from model.embedding import LinearEmbedding + + +parser = argparse.ArgumentParser() +LookupChoices = type('', (argparse.Action, ), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + +parser.add_argument('--mode', + choices=["train", "eval"], + default="train") + +parser.add_argument('--load', + default=None) + +parser.add_argument('--dataset', + choices=dict(cub200=dataset.CUB2011Metric, + cars196=dataset.Cars196Metric, + stanford=dataset.StanfordOnlineProductsMetric), + default=dataset.CUB2011Metric, + action=LookupChoices) + +parser.add_argument('--base', + choices=dict(googlenet=backbone.GoogleNet, + inception_v1bn=backbone.InceptionV1BN, + resnet18=backbone.ResNet18, + resnet50=backbone.ResNet50), + default=backbone.ResNet50, + action=LookupChoices) + +parser.add_argument('--sample', + choices=dict(random=pair.RandomNegative, + hard=pair.HardNegative, + all=pair.AllPairs, + semihard=pair.SemiHardNegative, + distance=pair.DistanceWeighted), + default=pair.AllPairs, + action=LookupChoices) + +parser.add_argument('--loss', + choices=dict(l1_triplet=loss.L1Triplet, + l2_triplet=loss.L2Triplet, + contrastive=loss.ContrastiveLoss), + default=loss.L2Triplet, + action=LookupChoices) + +parser.add_argument('--margin', type=float, default=0.2) +parser.add_argument('--embedding_size', type=int, default=128) +parser.add_argument('--l2normalize', choices=['true', 'false'], default='true') + +parser.add_argument('--lr', default=1e-5, type=float) +parser.add_argument('--lr_decay_epochs', type=int, default=[25, 30, 35], nargs='+') +parser.add_argument('--lr_decay_gamma', default=0.5, type=float) + +parser.add_argument('--batch', default=64, type=int) +parser.add_argument('--num_image_per_class', default=5, type=int) + +parser.add_argument('--epochs', default=40, type=int) +parser.add_argument('--iter_per_epoch', type=int, default=100) +parser.add_argument('--recall', default=[1], type=int, nargs='+') + 
+parser.add_argument('--seed', default=random.randint(1, 1000), type=int) +parser.add_argument('--data', default='data') +parser.add_argument('--save_dir', default=None) +opts = parser.parse_args() + + +for set_random_seed in [random.seed, torch.manual_seed, torch.cuda.manual_seed_all]: + set_random_seed(opts.seed) + +base_model = opts.base(pretrained=True) +if isinstance(base_model, backbone.InceptionV1BN) or isinstance(base_model, backbone.GoogleNet): + normalize = transforms.Compose([ + transforms.Lambda(lambda x: x[[2, 1, 0], ...] * 255.0), + transforms.Normalize(mean=[104, 117, 128], std=[1, 1, 1]), + ]) +else: + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + +train_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize +]) + + +test_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize +]) + +dataset_train = opts.dataset(opts.data, train=True, transform=train_transform, download=True) +dataset_train_eval = opts.dataset(opts.data, train=True, transform=test_transform, download=True) +dataset_eval = opts.dataset(opts.data, train=False, transform=test_transform, download=True) + +print("Number of images in Training Set: %d" % len(dataset_train)) +print("Number of images in Test set: %d" % len(dataset_eval)) + +loader_train_sample = DataLoader(dataset_train, batch_sampler=NPairs(dataset_train, + opts.batch, + m=opts.num_image_per_class, + iter_per_epoch=opts.iter_per_epoch), + pin_memory=True, num_workers=8) +loader_train_eval = DataLoader(dataset_train_eval, shuffle=False, batch_size=opts.batch, drop_last=False, + pin_memory=False, num_workers=8) +loader_eval = DataLoader(dataset_eval, shuffle=False, batch_size=opts.batch, drop_last=False, + pin_memory=True, num_workers=8) +model = LinearEmbedding(base_model, + output_size=base_model.output_size, + embedding_size=opts.embedding_size, + normalize=opts.l2normalize == 'true').cuda() + +if opts.load is not None: + model.load_state_dict(torch.load(opts.load)) + print("Loaded Model from %s" % opts.load) + +criterion = opts.loss(sampler=opts.sample(), margin=opts.margin) +optimizer = optim.Adam(model.parameters(), lr=opts.lr, weight_decay=1e-5) +lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=opts.lr_decay_epochs, gamma=opts.lr_decay_gamma) + + +def train(net, loader, ep): + lr_scheduler.step() + + net.train() + loss_all, norm_all = [], [] + train_iter = tqdm(loader, ncols=80) + for images, labels in train_iter: + images, labels = images.cuda(), labels.cuda() + embedding = net(images) + loss = criterion(embedding, labels) + loss_all.append(loss.item()) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + train_iter.set_description("[Train][Epoch %d] Loss: %.5f" % (ep, loss.item())) + print('[Epoch %d] Loss: %.5f\n' % (ep, torch.Tensor(loss_all).mean())) + + +def eval(net, loader, ep): + K = opts.recall + net.eval() + test_iter = tqdm(loader, ncols=80) + embeddings_all, labels_all = [], [] + + test_iter.set_description("[Eval][Epoch %d]" % ep) + with torch.no_grad(): + for images, labels in test_iter: + images, labels = images.cuda(), labels.cuda() + embedding = net(images) + embeddings_all.append(embedding.data) + labels_all.append(labels.data) + + embeddings_all = torch.cat(embeddings_all).cpu() + labels_all = torch.cat(labels_all).cpu() + rec = recall(embeddings_all, 
labels_all, K=K) + + for k, r in zip(K, rec): + print('[Epoch %d] Recall@%d: [%.4f]\n' % (ep, k, 100 * r)) + + return rec[0] + + +if opts.mode == "eval": + eval(model, loader_train_eval, 0) + eval(model, loader_eval, 0) +else: + train_recall = eval(model, loader_train_eval, 0) + val_recall = eval(model, loader_eval, 0) + best_rec = val_recall + + for epoch in range(1, opts.epochs+1): + train(model, loader_train_sample, epoch) + train_recall = eval(model, loader_train_eval, epoch) + val_recall = eval(model, loader_eval, epoch) + + if best_rec < val_recall: + best_rec = val_recall + if opts.save_dir is not None: + if not os.path.isdir(opts.save_dir): + os.mkdir(opts.save_dir) + torch.save(model.state_dict(), "%s/%s"%(opts.save_dir, "best.pth")) + if opts.save_dir is not None: + if not os.path.isdir(opts.save_dir): + os.mkdir(opts.save_dir) + torch.save(model.state_dict(), "%s/%s"%(opts.save_dir, "last.pth")) + with open("%s/result.txt"%opts.save_dir, 'w') as f: + f.write("Best Recall@1: %.4f\n" % (best_rec * 100)) + f.write("Final Recall@1: %.4f\n" % (val_recall * 100)) + + print("Best Recall@1: %.4f" % best_rec) diff --git a/cv/distiller/RKD/pytorch/run_distill.py b/cv/distiller/RKD/pytorch/run_distill.py new file mode 100644 index 0000000000000000000000000000000000000000..35df20855dc56acaa44eb290f045b67686252823 --- /dev/null +++ b/cv/distiller/RKD/pytorch/run_distill.py @@ -0,0 +1,275 @@ +import os +import argparse + +import dataset +import model.backbone as backbone +import metric.pairsampler as pair + +import torch +import torch.optim as optim +import torchvision.transforms as transforms + +from tqdm import tqdm +from torch.utils.data import DataLoader + +from metric.utils import recall +from metric.batchsampler import NPairs +from metric.loss import HardDarkRank, RkdDistance, RKdAngle, L2Triplet, AttentionTransfer +from model.embedding import LinearEmbedding + + +parser = argparse.ArgumentParser() +LookupChoices = type('', (argparse.Action, ), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + +parser.add_argument('--dataset', + choices=dict(cub200=dataset.CUB2011Metric, + cars196=dataset.Cars196Metric, + stanford=dataset.StanfordOnlineProductsMetric), + default=dataset.CUB2011Metric, + action=LookupChoices) + +parser.add_argument('--base', + choices=dict(googlenet=backbone.GoogleNet, + inception_v1bn=backbone.InceptionV1BN, + resnet18=backbone.ResNet18, + resnet50=backbone.ResNet50), + default=backbone.ResNet50, + action=LookupChoices) + +parser.add_argument('--teacher_base', + choices=dict(googlenet=backbone.GoogleNet, + inception_v1bn=backbone.InceptionV1BN, + resnet18=backbone.ResNet18, + resnet50=backbone.ResNet50), + default=backbone.ResNet50, + action=LookupChoices) + +parser.add_argument('--triplet_ratio', default=0, type=float) +parser.add_argument('--dist_ratio', default=0, type=float) +parser.add_argument('--angle_ratio', default=0, type=float) + +parser.add_argument('--dark_ratio', default=0, type=float) +parser.add_argument('--dark_alpha', default=2, type=float) +parser.add_argument('--dark_beta', default=3, type=float) + +parser.add_argument('--at_ratio', default=0, type=float) + +parser.add_argument('--triplet_sample', + choices=dict(random=pair.RandomNegative, + hard=pair.HardNegative, + all=pair.AllPairs, + semihard=pair.SemiHardNegative, + distance=pair.DistanceWeighted), + default=pair.DistanceWeighted, + action=LookupChoices) + +parser.add_argument('--triplet_margin', type=float, default=0.2) +parser.add_argument('--l2normalize', 
choices=['true', 'false'], default='true') +parser.add_argument('--embedding_size', default=128, type=int) + +parser.add_argument('--teacher_load', default=None, required=True) +parser.add_argument('--teacher_l2normalize', choices=['true', 'false'], default='true') +parser.add_argument('--teacher_embedding_size', default=128, type=int) + +parser.add_argument('--lr', default=1e-4, type=float) +parser.add_argument('--data', default='data') +parser.add_argument('--epochs', default=80, type=int) +parser.add_argument('--batch', default=64, type=int) +parser.add_argument('--iter_per_epoch', default=100, type=int) +parser.add_argument('--lr_decay_epochs', type=int, default=[40, 60], nargs='+') +parser.add_argument('--lr_decay_gamma', type=float, default=0.1) +parser.add_argument('--save_dir', default=None) +parser.add_argument('--load', default=None) + +opts = parser.parse_args() +student_base = opts.base(pretrained=True) +teacher_base = opts.teacher_base(pretrained=False) + + +def get_normalize(net): + google_mean = torch.Tensor([104, 117, 128]).view(1, -1, 1, 1).cuda() + google_std = torch.Tensor([1, 1, 1]).view(1, -1, 1, 1).cuda() + other_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).cuda() + other_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).cuda() + + def googlenorm(x): + x = x[:, [2, 1, 0]] * 255 + x = (x - google_mean) / google_std + return x + + def othernorm(x): + x = (x - other_mean) / other_std + return x + + if isinstance(net, backbone.InceptionV1BN) or isinstance(net, backbone.GoogleNet): + return googlenorm + else: + return othernorm + + +teacher_normalize = get_normalize(teacher_base) +student_normalize = get_normalize(student_base) + +train_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), +]) + +test_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.CenterCrop(224), + transforms.ToTensor(), +]) + +dataset_train = opts.dataset(opts.data, train=True, transform=train_transform, download=True) +dataset_train_eval = opts.dataset(opts.data, train=True, transform=test_transform, download=True) +dataset_eval = opts.dataset(opts.data, train=False, transform=test_transform, download=True) + +print("Number of images in Training Set: %d" % len(dataset_train)) +print("Number of images in Test set: %d" % len(dataset_eval)) + +loader_train_sample = DataLoader(dataset_train, batch_sampler=NPairs(dataset_train, opts.batch, m=5, + iter_per_epoch=opts.iter_per_epoch), + pin_memory=True, num_workers=8) +loader_train_eval = DataLoader(dataset_train_eval, shuffle=False, batch_size=opts.batch, drop_last=False, + pin_memory=False, num_workers=8) +loader_eval = DataLoader(dataset_eval, shuffle=False, batch_size=opts.batch, drop_last=False, + pin_memory=True, num_workers=8) + +student = LinearEmbedding(student_base, + output_size=student_base.output_size, + embedding_size=opts.embedding_size, + normalize=opts.l2normalize == 'true') + +if opts.load is not None: + student.load_state_dict(torch.load(opts.load)) + print("Loaded Model from %s" % opts.load) + +teacher = LinearEmbedding(teacher_base, + output_size=teacher_base.output_size, + embedding_size=opts.teacher_embedding_size, + normalize=opts.teacher_l2normalize == 'true') + +teacher.load_state_dict(torch.load(opts.teacher_load)) +student = student.cuda() +teacher = teacher.cuda() + +optimizer = optim.Adam(student.parameters(), lr=opts.lr, weight_decay=1e-5) +lr_scheduler = 
optim.lr_scheduler.MultiStepLR(optimizer, milestones=opts.lr_decay_epochs, gamma=opts.lr_decay_gamma) + +dist_criterion = RkdDistance() +angle_criterion = RKdAngle() +dark_criterion = HardDarkRank(alpha=opts.dark_alpha, beta=opts.dark_beta) +triplet_criterion = L2Triplet(sampler=opts.triplet_sample(), margin=opts.triplet_margin) +at_criterion = AttentionTransfer() + + +def train(loader, ep): + lr_scheduler.step() + student.train() + teacher.eval() + + dist_loss_all = [] + angle_loss_all = [] + dark_loss_all = [] + triplet_loss_all = [] + at_loss_all = [] + loss_all = [] + + train_iter = tqdm(loader) + for images, labels in train_iter: + images, labels = images.cuda(), labels.cuda() + + with torch.no_grad(): + t_b1, t_b2, t_b3, t_b4, t_pool, t_e = teacher(teacher_normalize(images), True) + + if isinstance(student.base, backbone.GoogleNet): + assert (opts.at_ratio == 0), "AttentionTransfer cannot be applied on GoogleNet at current implementation." + e = student(student_normalize(images)) + at_loss = torch.zeros(1, device=e.device) + else: + b1, b2, b3, b4, pool, e = student(student_normalize(images), True) + at_loss = opts.at_ratio * (at_criterion(b2, t_b2) + at_criterion(b3, t_b3) + at_criterion(b4, t_b4)) + + triplet_loss = opts.triplet_ratio * triplet_criterion(e, labels) + dist_loss = opts.dist_ratio * dist_criterion(e, t_e) + angle_loss = opts.angle_ratio * angle_criterion(e, t_e) + dark_loss = opts.dark_ratio * dark_criterion(e, t_e) + + loss = triplet_loss + dist_loss + angle_loss + dark_loss + at_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + triplet_loss_all.append(triplet_loss.item()) + dist_loss_all.append(dist_loss.item()) + angle_loss_all.append(angle_loss.item()) + dark_loss_all.append(dark_loss.item()) + at_loss_all.append(at_loss.item()) + loss_all.append(loss.item()) + + train_iter.set_description("[Train][Epoch %d] Triplet: %.5f, Dist: %.5f, Angle: %.5f, Dark: %5f, At: %5f" % + (ep, triplet_loss.item(), dist_loss.item(), angle_loss.item(), dark_loss.item(), at_loss.item())) + print('[Epoch %d] Loss: %.5f, Triplet: %.5f, Dist: %.5f, Angle: %.5f, Dark: %.5f At: %.5f\n' %\ + (ep, torch.Tensor(loss_all).mean(), torch.Tensor(triplet_loss_all).mean(), + torch.Tensor(dist_loss_all).mean(), torch.Tensor(angle_loss_all).mean(), torch.Tensor(dark_loss_all).mean(), + torch.Tensor(at_loss_all).mean())) + + +def eval(net, normalize, loader, ep): + K = [1] + net.eval() + test_iter = tqdm(loader) + embeddings_all, labels_all = [], [] + + with torch.no_grad(): + for images, labels in test_iter: + images, labels = images.cuda(), labels.cuda() + output = net(normalize(images)) + embeddings_all.append(output.data) + labels_all.append(labels.data) + test_iter.set_description("[Eval][Epoch %d]" % ep) + + embeddings_all = torch.cat(embeddings_all).cpu() + labels_all = torch.cat(labels_all).cpu() + rec = recall(embeddings_all, labels_all, K=K) + + for k, r in zip(K, rec): + print('[Epoch %d] Recall@%d: [%.4f]\n' % (ep, k, 100 * r)) + return rec[0] + + +eval(teacher, teacher_normalize, loader_train_eval, 0) +eval(teacher, teacher_normalize, loader_eval, 0) +best_train_rec = eval(student, student_normalize, loader_train_eval, 0) +best_val_rec = eval(student, student_normalize, loader_eval, 0) + +for epoch in range(1, opts.epochs+1): + train(loader_train_sample, epoch) + train_recall = eval(student, student_normalize, loader_train_eval, epoch) + val_recall = eval(student, student_normalize, loader_eval, epoch) + + if best_train_rec < train_recall: + best_train_rec = 
train_recall + + if best_val_rec < val_recall: + best_val_rec = val_recall + if opts.save_dir is not None: + if not os.path.isdir(opts.save_dir): + os.mkdir(opts.save_dir) + torch.save(student.state_dict(), "%s/%s" % (opts.save_dir, "best.pth")) + + if opts.save_dir is not None: + if not os.path.isdir(opts.save_dir): + os.mkdir(opts.save_dir) + torch.save(student.state_dict(), "%s/%s" % (opts.save_dir, "last.pth")) + with open("%s/result.txt" % opts.save_dir, 'w') as f: + f.write('Best Train Recall@1: %.4f\n' % (best_train_rec * 100)) + f.write("Best Test Recall@1: %.4f\n" % (best_val_rec * 100)) + f.write("Final Recall@1: %.4f\n" % (val_recall * 100)) + + print("Best Train Recall: %.4f" % best_train_rec) + print("Best Eval Recall: %.4f" % best_val_rec) diff --git a/cv/distiller/RKD/pytorch/run_distill_fitnet.py b/cv/distiller/RKD/pytorch/run_distill_fitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..77993f548e3f61bdf694db467d9dd1ae4e2962d6 --- /dev/null +++ b/cv/distiller/RKD/pytorch/run_distill_fitnet.py @@ -0,0 +1,287 @@ +import os +import argparse + +import dataset +import model.backbone as backbone +import metric.pairsampler as pair + +import torch +import torch.optim as optim +import torchvision.transforms as transforms + +from tqdm import tqdm +from torch.utils.data import DataLoader + +from metric.utils import recall +from metric.batchsampler import NPairs +from metric.loss import HardDarkRank, RkdDistance, RKdAngle, L2Triplet, FitNet +from model.embedding import LinearEmbedding + + +parser = argparse.ArgumentParser() +LookupChoices = type('', (argparse.Action, ), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) + +parser.add_argument('--dataset', + choices=dict(cub200=dataset.CUB2011Metric, + cars196=dataset.Cars196Metric, + stanford=dataset.StanfordOnlineProductsMetric), + default=dataset.CUB2011Metric, + action=LookupChoices) + +parser.add_argument('--base', + choices=dict(resnet18=backbone.ResNet18, + resnet50=backbone.ResNet50), + default=backbone.ResNet50, + action=LookupChoices) + +parser.add_argument('--teacher_base', + choices=dict(resnet18=backbone.ResNet18, + resnet50=backbone.ResNet50), + default=backbone.ResNet50, + action=LookupChoices) + +parser.add_argument('--triplet_ratio', default=0, type=float) + +parser.add_argument('--dist_ratio', default=0, type=float) +parser.add_argument('--angle_ratio', default=0, type=float) + +parser.add_argument('--dark_ratio', default=0, type=float) +parser.add_argument('--dark_alpha', default=2, type=float) +parser.add_argument('--dark_beta', default=3, type=float) + +parser.add_argument('--fitnet_ratio', default=1, type=float) + +parser.add_argument('--triplet_sample', + choices=dict(random=pair.RandomNegative, + hard=pair.HardNegative, + all=pair.AllPairs, + semihard=pair.SemiHardNegative, + distance=pair.DistanceWeighted), + default=pair.AllPairs, + action=LookupChoices) + +parser.add_argument('--triplet_margin', type=float, default=0.2) +parser.add_argument('--l2normalize', choices=['true', 'false'], default='true') +parser.add_argument('--embedding_size', default=128, type=int) + +parser.add_argument('--teacher_load', default=None, required=True) +parser.add_argument('--teacher_l2normalize', choices=['true', 'false'], default='true') +parser.add_argument('--teacher_embedding_size', default=128, type=int) + +parser.add_argument('--lr', default=1e-5, type=float) +parser.add_argument('--data', default='data') +parser.add_argument('--epochs', default=80, type=int) 
+parser.add_argument('--batch', default=64, type=int) +parser.add_argument('--iter_per_epoch', default=100, type=int) +parser.add_argument('--lr_decay_epochs', type=int, default=[40, 60], nargs='+') +parser.add_argument('--lr_decay_gamma', type=float, default=0.1) +parser.add_argument('--save_dir', default=None) + +opts = parser.parse_args() +student_base = opts.base(pretrained=True) +teacher_base = opts.teacher_base(pretrained=False) + + +def get_normalize(net): + google_mean = torch.Tensor([104, 117, 128]).view(1, -1, 1, 1).cuda() + google_std = torch.Tensor([1, 1, 1]).view(1, -1, 1, 1).cuda() + other_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).cuda() + other_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).cuda() + + def googlenorm(x): + x = x[:, [2, 1, 0]] * 255 + x = (x - google_mean) / google_std + return x + + def othernorm(x): + x = (x - other_mean) / other_std + return x + + if isinstance(net, backbone.InceptionV1BN) or isinstance(net, backbone.GoogleNet): + return googlenorm + else: + return othernorm + + +teacher_normalize = get_normalize(teacher_base) +student_normalize = get_normalize(student_base) + +train_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), +]) + +test_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.CenterCrop(224), + transforms.ToTensor(), +]) + +dataset_train = opts.dataset(opts.data, train=True, transform=train_transform, download=True) +dataset_train_eval = opts.dataset(opts.data, train=True, transform=test_transform, download=True) +dataset_eval = opts.dataset(opts.data, train=False, transform=test_transform, download=True) + +print("Number of images in Training Set: %d" % len(dataset_train)) +print("Number of images in Test set: %d" % len(dataset_eval)) + +loader_train_sample = DataLoader(dataset_train, batch_sampler=NPairs(dataset_train, opts.batch, m=5, + iter_per_epoch=opts.iter_per_epoch), + pin_memory=True, num_workers=8) +loader_train_eval = DataLoader(dataset_train_eval, shuffle=False, batch_size=opts.batch, drop_last=False, + pin_memory=False, num_workers=8) +loader_eval = DataLoader(dataset_eval, shuffle=False, batch_size=opts.batch, drop_last=False, + pin_memory=True, num_workers=8) + +student = LinearEmbedding(student_base, + output_size=student_base.output_size, + embedding_size=opts.embedding_size, + normalize=opts.l2normalize == 'true') + +teacher = LinearEmbedding(teacher_base, + output_size=teacher_base.output_size, + embedding_size=opts.teacher_embedding_size, + normalize=opts.teacher_l2normalize == 'true') + +teacher.load_state_dict(torch.load(opts.teacher_load)) +student = student.cuda() +teacher = teacher.cuda() + + +optimizer = optim.Adam(student.parameters(), lr=opts.lr, weight_decay=1e-5) +lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=opts.lr_decay_epochs, gamma=opts.lr_decay_gamma) + +dist_criterion = RkdDistance() +angle_criterion = RKdAngle() +dark_criterion = HardDarkRank(alpha=opts.dark_alpha, beta=opts.dark_beta) +triplet_criterion = L2Triplet(sampler=opts.triplet_sample(), margin=opts.triplet_margin) +fitnet_criterion = [FitNet(64, 256), FitNet(128, 512), FitNet(256, 1024), FitNet(512, 2048), FitNet(opts.embedding_size, 512)] +[f.cuda() for f in fitnet_criterion] + + +def train_fitnet(loader, ep): + lr_scheduler.step() + student.train() + teacher.eval() + loss_all = [] + + train_iter = tqdm(loader) + for images, labels in train_iter: + 
images, labels = images.cuda(), labels.cuda() + + b1, b2, b3, b4, pool, e = student(student_normalize(images), True) + with torch.no_grad(): + t_b1, t_b2, t_b3, t_b4, t_pool, t_e = teacher(teacher_normalize(images), True) + + loss = opts.fitnet_ratio * (fitnet_criterion[1](b2, t_b2) + + fitnet_criterion[2](b3, t_b3) + + fitnet_criterion[3](b4, t_b4) + + fitnet_criterion[4](e, t_e)) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_all.append(loss.item()) + + train_iter.set_description("[Train][Epoch %d] FitNet: %.5f" % (ep, loss.item())) + print('[Epoch %d] Loss: %.5f \n' % (ep, torch.Tensor(loss_all).mean())) + + +def train(loader, ep): + lr_scheduler.step() + student.train() + teacher.eval() + + dist_loss_all = [] + angle_loss_all = [] + dark_loss_all = [] + triplet_loss_all = [] + loss_all = [] + + train_iter = tqdm(loader) + for images, labels in train_iter: + images, labels = images.cuda(), labels.cuda() + + e = student(student_normalize(images)) + with torch.no_grad(): + t_e = teacher(teacher_normalize(images)) + + triplet_loss = opts.triplet_ratio * triplet_criterion(e, labels) + dist_loss = opts.dist_ratio * dist_criterion(e, t_e) + angle_loss = opts.angle_ratio * angle_criterion(e, t_e) + dark_loss = opts.dark_ratio * dark_criterion(e, t_e) + + loss = triplet_loss + dist_loss + angle_loss + dark_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + triplet_loss_all.append(triplet_loss.item()) + dist_loss_all.append(dist_loss.item()) + angle_loss_all.append(angle_loss.item()) + dark_loss_all.append(dark_loss.item()) + loss_all.append(loss.item()) + + train_iter.set_description("[Train][Epoch %d] Triplet: %.5f, Dist: %.5f, Angle: %.5f, Dark: %5f" % + (ep, triplet_loss.item(), dist_loss.item(), angle_loss.item(), dark_loss.item())) + print('[Epoch %d] Loss: %.5f, Triplet: %.5f, Dist: %.5f, Angle: %.5f, Dark: %.5f \n' %\ + (ep, torch.Tensor(loss_all).mean(), torch.Tensor(triplet_loss_all).mean(), + torch.Tensor(dist_loss_all).mean(), torch.Tensor(angle_loss_all).mean(), torch.Tensor(dark_loss_all).mean())) + + +def eval(net, normalize, loader, ep): + K = [1] + net.eval() + test_iter = tqdm(loader) + embeddings_all, labels_all = [], [] + + with torch.no_grad(): + for images, labels in test_iter: + images, labels = images.cuda(), labels.cuda() + output = net(normalize(images)) + embeddings_all.append(output.data) + labels_all.append(labels.data) + test_iter.set_description("[Eval][Epoch %d]" % ep) + + embeddings_all = torch.cat(embeddings_all).cpu() + labels_all = torch.cat(labels_all).cpu() + rec = recall(embeddings_all, labels_all, K=K) + + for k, r in zip(K, rec): + print('[Epoch %d] Recall@%d: [%.4f]\n' % (ep, k, 100 * r)) + return rec[0] + + +eval(teacher, teacher_normalize, loader_train_eval, 0) +eval(teacher, teacher_normalize, loader_eval, 0) +best_train_rec = eval(student, student_normalize, loader_train_eval, 0) +best_val_rec = eval(student, student_normalize, loader_eval, 0) + +for epoch in range(1, opts.epochs+1): + train_fitnet(loader_train_sample, epoch) + train_recall = eval(student, student_normalize, loader_train_eval, epoch) + val_recall = eval(student, student_normalize, loader_eval, epoch) + + if best_train_rec < train_recall: + best_train_rec = train_recall + + if best_val_rec < val_recall: + best_val_rec = val_recall + if opts.save_dir is not None: + if not os.path.isdir(opts.save_dir): + os.mkdir(opts.save_dir) + torch.save(student.state_dict(), "%s/%s" % (opts.save_dir, "best.pth")) + + if opts.save_dir is not None: + if not 
os.path.isdir(opts.save_dir): + os.mkdir(opts.save_dir) + torch.save(student.state_dict(), "%s/%s" % (opts.save_dir, "last.pth")) + with open("%s/result.txt" % opts.save_dir, 'w') as f: + f.write('Best Train Recall@1: %.4f\n' % (best_train_rec * 100)) + f.write("Best Test Recall@1: %.4f\n" % (best_val_rec * 100)) + f.write("Final Recall@1: %.4f\n" % (val_recall * 100)) + + print("Best Train Recall: %.4f" % best_train_rec) + print("Best Eval Recall: %.4f" % best_val_rec)
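
The distillation scripts above instantiate `RkdDistance()` and `RKdAngle()` from `metric/loss.py`, which is not included in this excerpt. For orientation, the following is a minimal sketch of what those two relational losses compute, written from the definitions in the RKD paper; the function names, helper, and numerical details below are illustrative assumptions, not the repository's actual implementation.

```python
# Sketch of the distance-wise and angle-wise relational losses used by
# run_distill.py (scaled there by --dist_ratio and --angle_ratio).
# Illustrative only; the repository's metric/loss.py may differ in detail.
import torch
import torch.nn.functional as F


def pairwise_distances(e):
    # Euclidean distance between every pair of embeddings in the batch.
    prod = e @ e.t()
    sq = prod.diag()
    dist2 = (sq.unsqueeze(0) + sq.unsqueeze(1) - 2 * prod).clamp(min=0)
    return dist2.sqrt()


def rkd_distance_loss(student_e, teacher_e):
    # Distance-wise loss: match the pairwise-distance structure of student
    # and teacher embeddings, each normalized by its mean non-zero distance.
    with torch.no_grad():
        t_d = pairwise_distances(teacher_e)
        t_d = t_d / (t_d[t_d > 0].mean() + 1e-12)
    s_d = pairwise_distances(student_e)
    s_d = s_d / (s_d[s_d > 0].mean() + 1e-12)
    return F.smooth_l1_loss(s_d, t_d)


def rkd_angle_loss(student_e, teacher_e):
    # Angle-wise loss: for every triplet of samples, compare the cosine of
    # the angle formed by the two difference vectors in each embedding space.
    with torch.no_grad():
        t_diff = F.normalize(teacher_e.unsqueeze(0) - teacher_e.unsqueeze(1), p=2, dim=2)
        t_angle = torch.bmm(t_diff, t_diff.transpose(1, 2)).view(-1)
    s_diff = F.normalize(student_e.unsqueeze(0) - student_e.unsqueeze(1), p=2, dim=2)
    s_angle = torch.bmm(s_diff, s_diff.transpose(1, 2)).view(-1)
    return F.smooth_l1_loss(s_angle, t_angle)
```

In `run_distill.py` these two terms are simply weighted by `--dist_ratio` and `--angle_ratio` and summed with the optional triplet, dark-rank, and attention-transfer terms before the backward pass.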