diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/WheatRustClassification-master.iml" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/WheatRustClassification-master.iml"
new file mode 100644
index 0000000000000000000000000000000000000000..2490720877fab004e42b2431305a420215af8d9a
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/WheatRustClassification-master.iml"
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/misc.xml" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/misc.xml"
new file mode 100644
index 0000000000000000000000000000000000000000..404edc08e8ba122bb2b28a3b89db4e00d58dc31a
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/misc.xml"
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/modules.xml" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/modules.xml"
new file mode 100644
index 0000000000000000000000000000000000000000..d2ee608f9a9832fa7f1bed90606a09698e1fd947
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/modules.xml"
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/workspace.xml" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/workspace.xml"
new file mode 100644
index 0000000000000000000000000000000000000000..cf80099519390a7f5ffd56b166aec217d89e96ba
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.idea/workspace.xml"
@@ -0,0 +1,156 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1606914352714
+
+
+ 1606914352714
+
+
+
+
\ No newline at end of file
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.ipynb_checkpoints/Untitled-checkpoint(1).ipynb" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.ipynb_checkpoints/Untitled-checkpoint(1).ipynb"
new file mode 100644
index 0000000000000000000000000000000000000000..7b2cb97d7937660aba53503ca55d65e40663c3ec
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.ipynb_checkpoints/Untitled-checkpoint(1).ipynb"
@@ -0,0 +1,331 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %load dataset.py\n",
+ "from torch.utils.data import Dataset\n",
+ "from torchvision import transforms, utils\n",
+ "import numpy as np\n",
+ "from scipy import ndimage\n",
+ "import torch\n",
+ "from PIL import Image # image-processing library\n",
+ "\n",
+ "class ICLRDataset(Dataset):\n",
+ " def __init__(self, imgs, gts, split_type, index, transform, img_mix_enable = True):\n",
+ " if index is None:\n",
+ " self.imgs = imgs\n",
+ " self.gts = gts\n",
+ " else:\n",
+ "            self.imgs = [imgs[i] for i in index] # images selected by the split index\n",
+ " self.gts = [gts[i] for i in index] \n",
+ " \n",
+ " self.split_type = split_type\n",
+ " self.transform = transform\n",
+ " self.img_mix_enable = img_mix_enable\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.imgs)\n",
+ " \n",
+ " def augment(self, img, y): \n",
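+ "        # With probability 0.5, blend this image with a randomly chosen\n",
+ "        # image of a *different* class (mixup-style), using a fixed weight\n",
+ "        # d = 0.8; the label is left unchanged.\n",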
+ "        p = np.random.random(1) # one uniform random number in [0, 1)\n",
+ "        if p[0] > 0.5: # mix with probability 0.5\n",
+ "            while True:\n",
+ "                rnd_idx = np.random.randint(0, len(self.imgs)) # random index over all images (high is exclusive)\n",
+ "                if self.gts[rnd_idx] != y: # stop once the partner image has a different label\n",
+ "                    break\n",
+ "            rnd_crop = self.transform(Image.fromarray(self.imgs[rnd_idx])) # transformed partner image\n",
+ "            d = 0.8\n",
+ "            img = img * d + rnd_crop * (1 - d) # blend the two transformed images\n",
+ " return img\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " img = self.imgs[idx]\n",
+ " y = self.gts[idx]\n",
+ " img = Image.fromarray(img)\n",
+ " img = self.transform(img)\n",
+ "        if (self.split_type == 'train') and self.img_mix_enable:\n",
+ "            img = self.augment(img, y)\n",
+ "        return img, y # return the (possibly augmented) image and its label\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %load utils.py\n",
+ "#utility functions for training, testing, and reading the dataset images\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.optim import lr_scheduler\n",
+ "import numpy as np\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "#import matplotlib.pyplot as plt\n",
+ "import time\n",
+ "import os\n",
+ "import copy\n",
+ "import torch.nn.functional as F\n",
+ "from PIL import Image, ExifTags\n",
+ "\n",
+ "def train_model_snapshot(model, criterion, lr, dataloaders, dataset_sizes, device, num_cycles, num_epochs_per_cycle):\n",
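+ "    # Snapshot ensembling: run num_cycles short training cycles, each with its\n",
+ "    # own cosine-annealed SGD schedule; the weights at the end of every cycle\n",
+ "    # become one ensemble member, and the members' softmax outputs on the\n",
+ "    # validation set are averaged into an ensemble prediction.\n",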
+ "    since = time.time() # record the training start time\n",
+ "\n",
+ "    best_model_wts = copy.deepcopy(model.state_dict()) # best weights seen so far\n",
+ "    best_acc = 0.0\n",
+ "    best_loss = 1000000.0\n",
+ "    model_w_arr = []\n",
+ "    prob = torch.zeros((dataset_sizes['val'], 3), dtype = torch.float32).to(device) # accumulated validation probabilities\n",
+ "    lbl = torch.zeros((dataset_sizes['val'],), dtype = torch.long).to(device) # validation labels\n",
+ " for cycle in range(num_cycles):\n",
+ " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)#, weight_decay = 0.0005)\n",
+ " scheduler = lr_scheduler.CosineAnnealingLR(optimizer, num_epochs_per_cycle*len(dataloaders['train'])) \n",
+ "        # cosine annealing of the learning rate over one cycle: T_max = epochs per cycle * batches per epoch\n",
+ " for epoch in range(num_epochs_per_cycle):\n",
+ " #print('Cycle {}: Epoch {}/{}'.format(cycle, epoch, num_epochs_per_cycle - 1))\n",
+ " #print('-' * 10)\n",
+ "\n",
+ " # Each epoch has a training and validation phase\n",
+ " for phase in ['train', 'val']:\n",
+ " if phase == 'train':\n",
+ " model.train() # Set model to training mode\n",
+ " else:\n",
+ " model.eval() # Set model to evaluate mode\n",
+ "\n",
+ " running_loss = 0.0\n",
+ " running_corrects = 0\n",
+ " idx = 0\n",
+ "                # Iterate over data.\n",
+ " for inputs, labels in dataloaders[phase]:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ "\n",
+ " # zero the parameter gradients\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " # forward\n",
+ "                    # track gradient history only in the training phase\n",
+ " with torch.set_grad_enabled(phase == 'train'):\n",
+ " outputs = model(inputs)\n",
+ " _, preds = torch.max(outputs, 1)\n",
+ " if (epoch == num_epochs_per_cycle-1) and (phase == 'val'):\n",
+ " prob[idx:idx+inputs.shape[0]] += F.softmax(outputs, dim = 1)\n",
+ " lbl[idx:idx+inputs.shape[0]] = labels\n",
+ " idx += inputs.shape[0]\n",
+ " loss = criterion(outputs, labels)\n",
+ " # backward + optimize only if in training phase\n",
+ " if phase == 'train':\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " scheduler.step()\n",
+ " #print(optimizer.param_groups[0]['lr'])\n",
+ " \n",
+ " # statistics\n",
+ " running_loss += loss.item() * inputs.size(0)\n",
+ " running_corrects += torch.sum(preds == labels.data)\n",
+ "\n",
+ " epoch_loss = running_loss / dataset_sizes[phase]\n",
+ " epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
+ "\n",
+ " #print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
+ " # phase, epoch_loss, epoch_acc))\n",
+ "\n",
+ " # deep copy the model\n",
+ " if phase == 'val' and epoch_loss < best_loss:\n",
+ " best_loss = epoch_loss\n",
+ " best_model_wts = copy.deepcopy(model.state_dict())\n",
+ " #print()\n",
+ " model_w_arr.append(copy.deepcopy(model.state_dict()))\n",
+ "\n",
+ " prob /= num_cycles\n",
+ " ensemble_loss = F.nll_loss(torch.log(prob), lbl) \n",
+ " ensemble_loss = ensemble_loss.item()\n",
+ " time_elapsed = time.time() - since\n",
+ " #print('Training complete in {:.0f}m {:.0f}s'.format(\n",
+ " # time_elapsed // 60, time_elapsed % 60))\n",
+ " #print('Ensemble Loss : {:4f}, Best val Loss: {:4f}'.format(ensemble_loss, best_loss))\n",
+ "\n",
+ "    # rebuild one model per snapshot from the saved weights\n",
+ "    model_arr = []\n",
+ "    for weights in model_w_arr:\n",
+ "        model.load_state_dict(weights)\n",
+ "        model_arr.append(copy.deepcopy(model)) # deepcopy: appending `model` directly would make every entry alias the last-loaded snapshot\n",
+ " return model_arr, ensemble_loss, best_loss, prob\n",
+ "\n",
+ "def test(models_arr, loader, device):\n",
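+ "    # average the softmax predictions of all snapshot models over the test\n",
+ "    # loader; the (610, 3) shape is hardcoded to 610 test images x 3 classes\n",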
+ " res = np.zeros((610, 3), dtype = np.float32)\n",
+ " for model in models_arr:\n",
+ " model.eval()\n",
+ " res_arr = []\n",
+ " for inputs, _ in loader:\n",
+ " inputs = inputs.to(device)\n",
+ " # forward\n",
+ "            # track gradient history only in the training phase\n",
+ " with torch.set_grad_enabled(False):\n",
+ " outputs = F.softmax(model(inputs), dim = 1) \n",
+ " res_arr.append(outputs.detach().cpu().numpy())\n",
+ " res_arr = np.concatenate(res_arr, axis = 0)\n",
+ " res += res_arr\n",
+ " return res / len(models_arr)\n",
+ "\n",
+ "def read_train_data(p):\n",
+ " imgs = []\n",
+ " labels = []\n",
+ " for i, lbl in enumerate(os.listdir(p)):\n",
+ " for fname in os.listdir(os.path.join(p, lbl)):\n",
+ " #read image\n",
+ " img = Image.open(os.path.join(p, lbl, fname))\n",
+ "            #rotate image back to its original view\n",
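+ "            # EXIF Orientation values 3 / 6 / 8 are undone by rotating the\n",
+ "            # image 180 / 270 / 90 degrees counter-clockwise, respectively\n",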
+ " try:\n",
+ " exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)\n",
+ " if exif['Orientation'] == 3:\n",
+ " img=img.rotate(180, expand=True)\n",
+ " elif exif['Orientation'] == 6:\n",
+ " img=img.rotate(270, expand=True)\n",
+ " elif exif['Orientation'] == 8:\n",
+ " img=img.rotate(90, expand=True)\n",
+ " except:\n",
+ " pass\n",
+ "            #resize all images to the same size (512x512)\n",
+ " img = np.array(img.convert('RGB').resize((512,512), Image.ANTIALIAS))\n",
+ " imgs.append(img)\n",
+ " labels.append(i)\n",
+ " return imgs, labels\n",
+ "\n",
+ "def read_test_data(p):\n",
+ " imgs = []\n",
+ " labels = []\n",
+ " ids = []\n",
+ " for fname in os.listdir(p):\n",
+ " #read image\n",
+ " img = Image.open(os.path.join(p, fname))\n",
+ " #rotate image to original view\n",
+ " try:\n",
+ "            if 'DMWVNR' not in fname:\n",
+ " exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)\n",
+ " if exif['Orientation'] == 3:\n",
+ " img=img.rotate(180, expand=True)\n",
+ " elif exif['Orientation'] == 6:\n",
+ " img=img.rotate(270, expand=True)\n",
+ " elif exif['Orientation'] == 8:\n",
+ " img=img.rotate(90, expand=True)\n",
+ " except:\n",
+ " pass\n",
+ " #resize all images to the same size\n",
+ " img = img.convert('RGB').resize((512,512), Image.ANTIALIAS)\n",
+ " imgs.append(np.array(img.copy()))\n",
+ " labels.append(0)\n",
+ " ids.append(fname.split('.')[0])\n",
+ " img.close()\n",
+ " return imgs, labels, ids\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "ename": "AssertionError",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[0mparser\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'D:\\datasets\\test'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhelp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'path to test data folder'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'test_data'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[0mparser\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'D:\\datasets\\savepath'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhelp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'save path for training and test numpy matrices of images'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'.'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0margs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mparser\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparse_args\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m#获取参数,调用上面的属性\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[1;31m#read training data\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36mparse_args\u001b[1;34m(self, args, namespace)\u001b[0m\n\u001b[0;32m 1747\u001b[0m \u001b[1;31m# =====================================\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1748\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mparse_args\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnamespace\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1749\u001b[1;33m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparse_known_args\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnamespace\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1750\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0margv\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1751\u001b[0m \u001b[0mmsg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'unrecognized arguments: %s'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36mparse_known_args\u001b[1;34m(self, args, namespace)\u001b[0m\n\u001b[0;32m 1779\u001b[0m \u001b[1;31m# parse the arguments and exit if there are any errors\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1780\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1781\u001b[1;33m \u001b[0mnamespace\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_parse_known_args\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnamespace\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1782\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_UNRECOGNIZED_ARGS_ATTR\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1783\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnamespace\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_UNRECOGNIZED_ARGS_ATTR\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36m_parse_known_args\u001b[1;34m(self, arg_strings, namespace)\u001b[0m\n\u001b[0;32m 2014\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrequired_actions\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2015\u001b[0m self.error(_('the following arguments are required: %s') %\n\u001b[1;32m-> 2016\u001b[1;33m ', '.join(required_actions))\n\u001b[0m\u001b[0;32m 2017\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2018\u001b[0m \u001b[1;31m# make sure all required groups had one option present\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36merror\u001b[1;34m(self, message)\u001b[0m\n\u001b[0;32m 2497\u001b[0m \u001b[0mshould\u001b[0m \u001b[0meither\u001b[0m \u001b[0mexit\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0man\u001b[0m \u001b[0mexception\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2498\u001b[0m \"\"\"\n\u001b[1;32m-> 2499\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprint_usage\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_sys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstderr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2500\u001b[0m \u001b[0margs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m{\u001b[0m\u001b[1;34m'prog'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprog\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'message'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mmessage\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2501\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'%(prog)s: error: %(message)s\\n'\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m%\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36mprint_usage\u001b[1;34m(self, file)\u001b[0m\n\u001b[0;32m 2467\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfile\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2468\u001b[0m \u001b[0mfile\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_sys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstdout\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2469\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_print_message\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat_usage\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfile\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2470\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2471\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mprint_help\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfile\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36mformat_usage\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 2433\u001b[0m formatter.add_usage(self.usage, self._actions,\n\u001b[0;32m 2434\u001b[0m self._mutually_exclusive_groups)\n\u001b[1;32m-> 2435\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mformatter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat_help\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2436\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2437\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mformat_help\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36mformat_help\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 282\u001b[0m \u001b[1;31m# =======================\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 283\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mformat_help\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 284\u001b[1;33m \u001b[0mhelp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_root_section\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat_help\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 285\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhelp\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 286\u001b[0m \u001b[0mhelp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_long_break_matcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msub\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'\\n\\n'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhelp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36mformat_help\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 213\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_indent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 214\u001b[0m \u001b[0mjoin\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_join_parts\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 215\u001b[1;33m \u001b[0mitem_help\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 216\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparent\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 217\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dedent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 213\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_indent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 214\u001b[0m \u001b[0mjoin\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_join_parts\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 215\u001b[1;33m \u001b[0mitem_help\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 216\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparent\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 217\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dedent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\argparse.py\u001b[0m in \u001b[0;36m_format_usage\u001b[1;34m(self, usage, actions, groups, prefix)\u001b[0m\n\u001b[0;32m 338\u001b[0m \u001b[0mpos_parts\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_re\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfindall\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpart_regexp\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpos_usage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 339\u001b[0m \u001b[1;32massert\u001b[0m \u001b[1;34m' '\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopt_parts\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mopt_usage\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 340\u001b[1;33m \u001b[1;32massert\u001b[0m \u001b[1;34m' '\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpos_parts\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mpos_usage\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 341\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 342\u001b[0m \u001b[1;31m# helper for wrapping lines\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;31mAssertionError\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "# %load prepare_dataset.py\n",
+ "#read the training and test data, drop duplicated training images, and save everything as numpy matrices\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import argparse\n",
+ "from utils import read_train_data, read_test_data\n",
+ "\n",
+ "parser = argparse.ArgumentParser(description='Data preparation')\n",
+ "parser.add_argument('D:\\datasets\\train', help='path to training data folder', default='train_data', type=str)\n",
+ "parser.add_argument('D:\\datasets\\test', help='path to test data folder', default='test_data', type=str)\n",
+ "parser.add_argument('D:\\datasets\\savepath', help='save path for training and test numpy matrices of images', default='.', type=str)\n",
+ "args = parser.parse_args() # parse the arguments defined above\n",
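+ "# NOTE: the add_argument calls above pass Windows paths as the argument names,\n",
+ "# which argparse cannot render into a usage string; hence the AssertionError\n",
+ "# shown below. The --flag form used in the other checkpoint notebook is correct.\n",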
+ "\n",
+ "#read training data\n",
+ "train_imgs, train_gts = read_train_data(args.train_data_path)\n",
+ "\n",
+ "#remove duplicate training images\n",
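+ "# O(n^2) pixel-exact comparison: the earlier copy of each duplicate pair is\n",
+ "# dropped, and when the two copies disagree on the label both are dropped\n",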
+ "idx_to_rmv = []\n",
+ "for i in range(len(train_imgs)-1):\n",
+ " for j in range(i+1, len(train_imgs)):\n",
+ " if np.all(train_imgs[i] == train_imgs[j]):\n",
+ " idx_to_rmv.append(i)\n",
+ " if train_gts[i] != train_gts[j]:\n",
+ " idx_to_rmv.append(j)\n",
+ "\n",
+ "idx = [i for i in range(len(train_imgs)) if not(i in idx_to_rmv)]\n",
+ "print('unique train imgs:',len(idx))\n",
+ "\n",
+ "#save unique training imgs\n",
+ "np.save(os.path.join(args.save_path, 'unique_train_imgs_rot_fixed'), np.array(train_imgs)[idx])\n",
+ "np.save(os.path.join(args.save_path, 'unique_train_gts_rot_fixed'), np.array(train_gts)[idx])\n",
+ "\n",
+ "#read test data\n",
+ "test_imgs, test_gts, ids = read_test_data(args.test_data_path)\n",
+ "\n",
+ "#save test data\n",
+ "np.save(os.path.join(args.save_path, 'test_imgs_rot_fixed'), np.array(test_imgs))\n",
+ "np.save(os.path.join(args.save_path, 'test_gts'), np.array(test_gts))\n",
+ "np.save(os.path.join(args.save_path, 'ids'), np.array(ids))\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.ipynb_checkpoints/Untitled-checkpoint.ipynb" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.ipynb_checkpoints/Untitled-checkpoint.ipynb"
new file mode 100644
index 0000000000000000000000000000000000000000..acb10783124581994607a556fe508afc08c6f8e8
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.ipynb_checkpoints/Untitled-checkpoint.ipynb"
@@ -0,0 +1,620 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %load dataset.py\n",
+ "from torch.utils.data import Dataset\n",
+ "from torchvision import transforms, utils\n",
+ "import numpy as np\n",
+ "from scipy import ndimage\n",
+ "import torch\n",
+ "from PIL import Image # image-processing library\n",
+ "\n",
+ "class ICLRDataset(Dataset):\n",
+ " def __init__(self, imgs, gts, split_type, index, transform, img_mix_enable = True):\n",
+ " if index is None:\n",
+ " self.imgs = imgs\n",
+ " self.gts = gts\n",
+ " else:\n",
+ "            self.imgs = [imgs[i] for i in index] # images selected by the split index\n",
+ " self.gts = [gts[i] for i in index] \n",
+ " \n",
+ " self.split_type = split_type\n",
+ " self.transform = transform\n",
+ " self.img_mix_enable = img_mix_enable\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.imgs)\n",
+ " \n",
+ " def augment(self, img, y): \n",
+ "        p = np.random.random(1) # one uniform random number in [0, 1)\n",
+ "        if p[0] > 0.5: # mix with probability 0.5\n",
+ "            while True:\n",
+ "                rnd_idx = np.random.randint(0, len(self.imgs)) # random index over all images (high is exclusive)\n",
+ "                if self.gts[rnd_idx] != y: # stop once the partner image has a different label\n",
+ "                    break\n",
+ "            rnd_crop = self.transform(Image.fromarray(self.imgs[rnd_idx])) # transformed partner image\n",
+ "            d = 0.8\n",
+ "            img = img * d + rnd_crop * (1 - d) # blend the two transformed images\n",
+ " return img\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " img = self.imgs[idx]\n",
+ " y = self.gts[idx]\n",
+ " img = Image.fromarray(img)\n",
+ " img = self.transform(img)\n",
+ "        if (self.split_type == 'train') and self.img_mix_enable:\n",
+ "            img = self.augment(img, y)\n",
+ "        return img, y # return the (possibly augmented) image and its label\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %load utils.py\n",
+ "#utility functions for training, testing, and reading the dataset images\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.optim import lr_scheduler\n",
+ "import numpy as np\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "#import matplotlib.pyplot as plt\n",
+ "import time\n",
+ "import os\n",
+ "import copy\n",
+ "import torch.nn.functional as F\n",
+ "from PIL import Image, ExifTags\n",
+ "\n",
+ "def train_model_snapshot(model, criterion, lr, dataloaders, dataset_sizes, device, num_cycles, num_epochs_per_cycle):\n",
+ "    since = time.time() # record the training start time\n",
+ "\n",
+ "    best_model_wts = copy.deepcopy(model.state_dict()) # best weights seen so far\n",
+ "    best_acc = 0.0\n",
+ "    best_loss = 1000000.0\n",
+ "    model_w_arr = []\n",
+ "    prob = torch.zeros((dataset_sizes['val'], 3), dtype = torch.float32).to(device) # accumulated validation probabilities\n",
+ "    lbl = torch.zeros((dataset_sizes['val'],), dtype = torch.long).to(device) # validation labels\n",
+ " for cycle in range(num_cycles):\n",
+ " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)#, weight_decay = 0.0005)\n",
+ " scheduler = lr_scheduler.CosineAnnealingLR(optimizer, num_epochs_per_cycle*len(dataloaders['train'])) \n",
+ "        # cosine annealing of the learning rate over one cycle: T_max = epochs per cycle * batches per epoch\n",
+ " for epoch in range(num_epochs_per_cycle):\n",
+ " #print('Cycle {}: Epoch {}/{}'.format(cycle, epoch, num_epochs_per_cycle - 1))\n",
+ " #print('-' * 10)\n",
+ "\n",
+ " # Each epoch has a training and validation phase\n",
+ " for phase in ['train', 'val']:\n",
+ " if phase == 'train':\n",
+ " model.train() # Set model to training mode\n",
+ " else:\n",
+ " model.eval() # Set model to evaluate mode\n",
+ "\n",
+ " running_loss = 0.0\n",
+ " running_corrects = 0\n",
+ " idx = 0\n",
+ "                # Iterate over data.\n",
+ " for inputs, labels in dataloaders[phase]:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ "\n",
+ " # zero the parameter gradients\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " # forward\n",
+ "                    # track gradient history only in the training phase\n",
+ " with torch.set_grad_enabled(phase == 'train'):\n",
+ " outputs = model(inputs)\n",
+ " _, preds = torch.max(outputs, 1)\n",
+ " if (epoch == num_epochs_per_cycle-1) and (phase == 'val'):\n",
+ " prob[idx:idx+inputs.shape[0]] += F.softmax(outputs, dim = 1)\n",
+ " lbl[idx:idx+inputs.shape[0]] = labels\n",
+ " idx += inputs.shape[0]\n",
+ " loss = criterion(outputs, labels)\n",
+ " # backward + optimize only if in training phase\n",
+ " if phase == 'train':\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " scheduler.step()\n",
+ " #print(optimizer.param_groups[0]['lr'])\n",
+ " \n",
+ " # statistics\n",
+ " running_loss += loss.item() * inputs.size(0)\n",
+ " running_corrects += torch.sum(preds == labels.data)\n",
+ "\n",
+ " epoch_loss = running_loss / dataset_sizes[phase]\n",
+ " epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
+ "\n",
+ " #print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
+ " # phase, epoch_loss, epoch_acc))\n",
+ "\n",
+ " # deep copy the model\n",
+ " if phase == 'val' and epoch_loss < best_loss:\n",
+ " best_loss = epoch_loss\n",
+ " best_model_wts = copy.deepcopy(model.state_dict())\n",
+ " #print()\n",
+ " model_w_arr.append(copy.deepcopy(model.state_dict()))\n",
+ "\n",
+ " prob /= num_cycles\n",
+ " ensemble_loss = F.nll_loss(torch.log(prob), lbl) \n",
+ " ensemble_loss = ensemble_loss.item()\n",
+ " time_elapsed = time.time() - since\n",
+ " #print('Training complete in {:.0f}m {:.0f}s'.format(\n",
+ " # time_elapsed // 60, time_elapsed % 60))\n",
+ " #print('Ensemble Loss : {:4f}, Best val Loss: {:4f}'.format(ensemble_loss, best_loss))\n",
+ "\n",
+ "    # rebuild one model per snapshot from the saved weights\n",
+ "    model_arr = []\n",
+ "    for weights in model_w_arr:\n",
+ "        model.load_state_dict(weights)\n",
+ "        model_arr.append(copy.deepcopy(model)) # deepcopy: appending `model` directly would make every entry alias the last-loaded snapshot\n",
+ " return model_arr, ensemble_loss, best_loss, prob\n",
+ "\n",
+ "def test(models_arr, loader, device):\n",
+ " res = np.zeros((610, 3), dtype = np.float32)\n",
+ " for model in models_arr:\n",
+ " model.eval()\n",
+ " res_arr = []\n",
+ " for inputs, _ in loader:\n",
+ " inputs = inputs.to(device)\n",
+ " # forward\n",
+ "            # track gradient history only in the training phase\n",
+ " with torch.set_grad_enabled(False):\n",
+ " outputs = F.softmax(model(inputs), dim = 1) \n",
+ " res_arr.append(outputs.detach().cpu().numpy())\n",
+ " res_arr = np.concatenate(res_arr, axis = 0)\n",
+ " res += res_arr\n",
+ " return res / len(models_arr)\n",
+ "\n",
+ "def read_train_data(p):\n",
+ " imgs = []\n",
+ " labels = []\n",
+ " for i, lbl in enumerate(os.listdir(p)):\n",
+ " for fname in os.listdir(os.path.join(p, lbl)):\n",
+ " #read image\n",
+ " img = Image.open(os.path.join(p, lbl, fname))\n",
+ "            #rotate image back to its original view\n",
+ " try:\n",
+ " exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)\n",
+ " if exif['Orientation'] == 3:\n",
+ " img=img.rotate(180, expand=True)\n",
+ " elif exif['Orientation'] == 6:\n",
+ " img=img.rotate(270, expand=True)\n",
+ " elif exif['Orientation'] == 8:\n",
+ " img=img.rotate(90, expand=True)\n",
+ " except:\n",
+ " pass\n",
+ "            #resize all images to the same size (512x512)\n",
+ " img = np.array(img.convert('RGB').resize((512,512), Image.ANTIALIAS))\n",
+ " imgs.append(img)\n",
+ " labels.append(i)\n",
+ " return imgs, labels\n",
+ "\n",
+ "def read_test_data(p):\n",
+ " imgs = []\n",
+ " labels = []\n",
+ " ids = []\n",
+ " for fname in os.listdir(p):\n",
+ " #read image\n",
+ " img = Image.open(os.path.join(p, fname))\n",
+ " #rotate image to original view\n",
+ " try:\n",
+ "            if 'DMWVNR' not in fname:\n",
+ " exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)\n",
+ " if exif['Orientation'] == 3:\n",
+ " img=img.rotate(180, expand=True)\n",
+ " elif exif['Orientation'] == 6:\n",
+ " img=img.rotate(270, expand=True)\n",
+ " elif exif['Orientation'] == 8:\n",
+ " img=img.rotate(90, expand=True)\n",
+ " except:\n",
+ " pass\n",
+ " #resize all images to the same size\n",
+ " img = img.convert('RGB').resize((512,512), Image.ANTIALIAS)\n",
+ " imgs.append(np.array(img.copy()))\n",
+ " labels.append(0)\n",
+ " ids.append(fname.split('.')[0])\n",
+ " img.close()\n",
+ " return imgs, labels, ids\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "unique train imgs: 732\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "usage: ipykernel_launcher.py [-h] [--train_data_path TRAIN_DATA_PATH]\n",
+ " [--test_data_path TEST_DATA_PATH]\n",
+ " [--save_path SAVE_PATH]\n",
+ "ipykernel_launcher.py: error: unrecognized arguments: -f C:\\Users\\Administrator\\AppData\\Roaming\\jupyter\\runtime\\kernel-60e74fca-82ff-42d7-afc4-1d27b752461b.json\n"
+ ]
+ },
+ {
+ "ename": "SystemExit",
+ "evalue": "2",
+ "output_type": "error",
+ "traceback": [
+ "An exception has occurred, use %tb to see the full traceback.\n",
+ "\u001b[1;31mSystemExit\u001b[0m\u001b[1;31m:\u001b[0m 2\n"
+ ]
+ }
+ ],
+ "source": [
+ "# %load prepare_dataset.py\n",
+ "#read the training and test data, drop duplicated training images, and save everything as numpy matrices\n",
+ "%run prepare_dataset.py\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import argparse\n",
+ "from utils import read_train_data, read_test_data\n",
+ "\n",
+ "parser = argparse.ArgumentParser(description='Data preparation')\n",
+ "parser.add_argument('--train_data_path', help='path', default='D:/datasets/train', type=str)\n",
+ "parser.add_argument('--test_data_path', help='path', default='D:/datasets/test', type=str)\n",
+ "parser.add_argument('--save_path', help='save', default='D:/datasets/savepath', type=str)\n",
+ "args = parser.parse_args() # parse the arguments defined above\n",
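+ "# NOTE: inside Jupyter, parse_args() also receives the kernel's own '-f ...'\n",
+ "# argument, which triggers the 'unrecognized arguments' error shown below;\n",
+ "# parse_known_args() or running the script from the command line avoids this.\n",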
+ "\n",
+ "#read training data\n",
+ "train_imgs, train_gts = read_train_data(args.train_data_path)\n",
+ "\n",
+ "#remove duplicate training images\n",
+ "idx_to_rmv = []\n",
+ "for i in range(len(train_imgs)-1):\n",
+ " for j in range(i+1, len(train_imgs)):\n",
+ " if np.all(train_imgs[i] == train_imgs[j]):\n",
+ " idx_to_rmv.append(i)\n",
+ " if train_gts[i] != train_gts[j]:\n",
+ " idx_to_rmv.append(j)\n",
+ "\n",
+ "idx = [i for i in range(len(train_imgs)) if not(i in idx_to_rmv)]\n",
+ "print('unique train imgs:',len(idx))\n",
+ "\n",
+ "#save unique training imgs\n",
+ "np.save(os.path.join(args.save_path, 'unique_train_imgs_rot_fixed'), np.array(train_imgs)[idx])\n",
+ "np.save(os.path.join(args.save_path, 'unique_train_gts_rot_fixed'), np.array(train_gts)[idx])\n",
+ "\n",
+ "#read test data\n",
+ "test_imgs, test_gts, ids = read_test_data(args.test_data_path)\n",
+ "\n",
+ "#save test data\n",
+ "np.save(os.path.join(args.save_path, 'test_imgs_rot_fixed'), np.array(test_imgs))\n",
+ "np.save(os.path.join(args.save_path, 'test_gts'), np.array(test_gts))\n",
+ "np.save(os.path.join(args.save_path, 'ids'), np.array(ids))\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "unique train imgs: 732\n"
+ ]
+ }
+ ],
+ "source": [
+ "%run prepare_dataset.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "usage: ipykernel_launcher.py [-h] [--data_path DATA_PATH]\n",
+ " [--library_size LIBRARY_SIZE]\n",
+ " [--library_path LIBRARY_PATH]\n",
+ "ipykernel_launcher.py: error: unrecognized arguments: -f C:\\Users\\Administrator\\AppData\\Roaming\\jupyter\\runtime\\kernel-9ba0d6cf-6ce4-4517-8c53-84c6b0e19712.json\n"
+ ]
+ },
+ {
+ "ename": "SystemExit",
+ "evalue": "2",
+ "output_type": "error",
+ "traceback": [
+ "An exception has occurred, use %tb to see the full traceback.\n",
+ "\u001b[1;31mSystemExit\u001b[0m\u001b[1;31m:\u001b[0m 2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py:3333: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
+ " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# %load generate_library_of_models.py\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.optim import lr_scheduler\n",
+ "import numpy as np\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import time\n",
+ "import os\n",
+ "import argparse\n",
+ "import copy\n",
+ "from sklearn.model_selection import StratifiedKFold\n",
+ "import datetime\n",
+ "from PIL import Image\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from dataset import ICLRDataset\n",
+ "from utils import train_model_snapshot, test\n",
+ "from sklearn.metrics import confusion_matrix\n",
+ "from hyperopt import hp, tpe, fmin, Trials\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "\n",
+ "def score(params):\n",
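+ "    # hyperopt objective: train a k-fold snapshot ensemble with the sampled\n",
+ "    # hyperparameters and return the mean validation ensemble loss (lower is\n",
+ "    # better); per-trial validation/test probabilities are saved as a side effect\n",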
+ "    global test_prob, val_prob, trails_sc_arr, idx # state shared across trials\n",
+ " print(params)\n",
+ " k = 5 \n",
+ "    sss = StratifiedKFold(n_splits=k, shuffle = True, random_state=seed_arr[idx]) # yields train/validation indices with class proportions preserved\n",
+ "    # (n_splits: number of folds, at least 2; shuffle the samples before splitting; random_state fixes the shuffling for this trial)\n",
+ " \n",
+ "    #define trial data augmentations for the training set\n",
+ " data_transforms = {\n",
+ " 'train': transforms.Compose([\n",
+ " transforms.ColorJitter(contrast = params['contrast'], hue = params['hue'], brightness = params['brightness']),\n",
+ " transforms.RandomAffine(degrees = params['degrees']),\n",
+ " transforms.RandomResizedCrop(224),\n",
+ "            transforms.RandomHorizontalFlip(p = 0.5 if params['h_flip'] else 0.0), # flip horizontally with probability p\n",
+ "            transforms.RandomVerticalFlip(p = 0.5 if params['v_flip'] else 0.0), # flip vertically with probability p\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ " 'val': transforms.Compose([\n",
+ " transforms.Resize((params['val_img_size'], params['val_img_size'])),\n",
+ " transforms.CenterCrop(224),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ " }\n",
+ "\n",
+ " trail_test_prob = np.zeros((test_imgs.shape[0], 3), dtype = np.float32)\n",
+ " trail_val_prob = torch.zeros((train_imgs.shape[0], 3), dtype = torch.float32).to(device)\n",
+ " \n",
+ " sc_arr = []\n",
+ " models_arr = []\n",
+ " fold = 0\n",
+ " #train a model for each split\n",
+ " for train_index, val_index in sss.split(train_imgs, train_gts):\n",
+ " #define dataset and loader for training and validation\n",
+ " image_datasets = {'train': ICLRDataset(train_imgs, train_gts, 'train', train_index, data_transforms['train'], params['img_mix_enable']),\n",
+ "\t 'val': ICLRDataset(train_imgs, train_gts, 'val', val_index, data_transforms['val'])}\n",
+ "\n",
+ " dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=16, shuffle=True, num_workers=2),\n",
+ " 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=16, shuffle=False, num_workers=2)}\n",
+ "\n",
+ " #create model instance\n",
+ " model_ft = params['arch'](pretrained=True)\n",
+ " try:\n",
+ " num_ftrs = model_ft.fc.in_features\n",
+ " model_ft.fc = nn.Linear(num_ftrs, 3)\n",
+ " except:\n",
+ " num_ftrs = model_ft.classifier.in_features\n",
+ " model_ft.classifier = nn.Linear(num_ftrs, 3)\n",
+ " model_ft = model_ft.to(device)\n",
+ "\n",
+ " criterion = nn.CrossEntropyLoss()\n",
+ "\n",
+ " dataset_sizes = {x:len(image_datasets[x]) for x in ['train', 'val']}\n",
+ " \n",
+ " model_ft_arr, ensemble_loss, _, fold_val_prob = train_model_snapshot(model_ft, criterion, params['lr'], dataloaders, dataset_sizes, device,\n",
+ " num_cycles=params['num_cycles'], num_epochs_per_cycle=params['num_epochs_per_cycle'])\n",
+ " models_arr.extend(model_ft_arr)\n",
+ " fold += 1\n",
+ " sc_arr.append(ensemble_loss)\n",
+ " trail_val_prob[val_index] = fold_val_prob\n",
+ " \n",
+ " #predict on test data using average of kfold models\n",
+ " image_datasets['test'] = ICLRDataset(test_imgs, test_gts, 'test', None, data_transforms['val'])\n",
+ " test_loader = torch.utils.data.DataLoader(image_datasets['test'], batch_size=4,shuffle=False, num_workers=16)\n",
+ " trail_test_prob = test(models_arr, test_loader, device)\n",
+ "\n",
+ " print('mean val loss:', np.mean(sc_arr))\n",
+ "\n",
+ " test_prob.append(trail_test_prob)\n",
+ " val_prob.append(trail_val_prob)\n",
+ "\n",
+ " #save validation and test results for further processing \n",
+ " np.save(os.path.join(args.library_path, 'val_prob_trail_%d'%(idx)), trail_val_prob.detach().cpu().numpy())\n",
+ " np.save(os.path.join(args.library_path, 'test_prob_trail_%d'%(idx)), trail_test_prob)\n",
+ " idx += 1\n",
+ " \n",
+ " trails_sc_arr.append(np.mean(sc_arr))\n",
+ "\n",
+ " torch.cuda.empty_cache()\n",
+ " del models_arr\n",
+ "\n",
+ " return np.mean(sc_arr)\n",
+ "\n",
+ "parser = argparse.ArgumentParser(description='Data preparation')\n",
+ "parser.add_argument('--data_path', help='path to training and test numpy matrices of images', default='.', type=str)\n",
+ "parser.add_argument('--library_size', help='number of models to be trained in the library of models', default=50, type=int)\n",
+ "parser.add_argument('--library_path', help='save path for validation and test predictions of the library of models', default='trails', type=str)\n",
+ "args = parser.parse_args()\n",
+ "\n",
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "torch.manual_seed(0)\n",
+ "np.random.seed(0)\n",
+ "\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "\n",
+ "#read train data\n",
+ "train_imgs = np.load(os.path.join(args.data_path, 'D:/datasets/savepath/unique_train_imgs_rot_fixed.npy'))\n",
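+ "# note: on Windows os.path.join drops data_path when the second argument\n",
+ "# already carries a drive letter, so this always loads from D:/datasets/savepath\n",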
+ "train_gts = np.load(os.path.join(args.data_path, 'unique_train_gts_rot_fixed.npy'))\n",
+ "\n",
+ "#read test data\n",
+ "test_imgs = np.load(os.path.join(args.data_path, 'test_imgs_rot_fixed.npy'))\n",
+ "test_gts = np.load(os.path.join(args.data_path, 'test_gts.npy'))\n",
+ "ids = np.load(os.path.join(args.data_path, 'ids.npy')).tolist()\n",
+ "\n",
+ "test_prob = []\n",
+ "val_prob = []\n",
+ "trails_sc_arr = []\n",
+ "\n",
+ "n_trails = args.library_size\n",
+ "seed_arr = np.random.randint(low=0, high=1000000, size=n_trails)\n",
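+ "# one random seed per trial so each trial's cross-validation split differs\n",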
+ "\n",
+ "#create search space for hyperparameter optimization\n",
+ "space = OrderedDict([('lr', hp.choice('lr', [i*0.001 for i in range(1,4)])),\n",
+ " ('num_cycles', hp.choice('num_cycles', range(3, 6))),\n",
+ " ('num_epochs_per_cycle', hp.choice('num_epochs_per_cycle', range(3, 6))),\n",
+ " ('arch', hp.choice('arch', [models.densenet201, models.densenet121, models.densenet169,\n",
+ " models.wide_resnet50_2, models.resnet152, \n",
+ " models.resnet101, models.resnet50, models.resnet34, models.resnet18])),\n",
+ " ('img_mix_enable', hp.choice('img_mix_enable', [True, False])),\n",
+ " ('v_flip', hp.choice('v_flip', [True, False])),\n",
+ " ('h_flip', hp.choice('h_flip', [True, False])),\n",
+ " ('degrees', hp.choice('degrees', range(1, 90))),\n",
+ " ('contrast', hp.uniform('contrast', 0.0, 0.3)),\n",
+ " ('hue', hp.uniform('hue', 0.0, 0.3)),\n",
+ " ('brightness', hp.uniform('brightness', 0.0, 0.3)),\n",
+ " ('val_img_size', hp.choice('val_img_size', range(224, 512, 24))),\n",
+ " ])\n",
+ "\n",
+ "trials = Trials()\n",
+ "\n",
+ "idx = 0\n",
+ "if not os.path.exists(args.library_path):\n",
+ " os.mkdir(args.library_path)\n",
+ "\n",
+ "#use the TPE algorithm in hyperopt to generate a library of different models\n",
+ "best = fmin(fn=score,space=space,algo=tpe.suggest,max_evals=n_trails,trials=trials)\n",
+ "#fmin iterates over hyperparameter configurations proposed by the search algorithm to minimize the objective\n",
+ "#(objective to minimize; search space; search algorithm that proposes the next configuration; maximum number of evaluations; Trials object)\n",
+ "print(best)\n",
+ "\n",
+ "np.save(os.path.join(args.library_path, 'scores.npy'), np.array(trails_sc_arr))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'arch': , 'brightness': 0.20568033410096712, 'contrast': 0.047925566574344027, 'degrees': 87, 'h_flip': True, 'hue': 0.2804254868966057, 'img_mix_enable': True, 'lr': 0.001, 'num_cycles': 5, 'num_epochs_per_cycle': 5, 'v_flip': True, 'val_img_size': 248}\n",
+ " 0%| | 0/50 [00:00, ?trial/s, best loss=?]"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "job exception: module '__main__' has no attribute '__spec__'\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "ename": "AttributeError",
+ "evalue": "module '__main__' has no attribute '__spec__'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m~\\Desktop\\WheatRustClassification-master\\generate_library_of_models.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;31m#use tpe algorithm in hyperopt to generate a library of differnet models\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 162\u001b[1;33m \u001b[0mbest\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfmin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mscore\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mspace\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mspace\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0malgo\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtpe\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msuggest\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mmax_evals\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mn_trails\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 163\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbest\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 164\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mfmin\u001b[1;34m(fn, space, algo, max_evals, timeout, loss_threshold, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)\u001b[0m\n\u001b[0;32m 480\u001b[0m \u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 481\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_argmin\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 482\u001b[1;33m \u001b[0mshow_progressbar\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mshow_progressbar\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 483\u001b[0m )\n\u001b[0;32m 484\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\base.py\u001b[0m in \u001b[0;36mfmin\u001b[1;34m(self, fn, space, algo, max_evals, timeout, loss_threshold, max_queue_len, rstate, verbose, pass_expr_memo_ctrl, catch_eval_exceptions, return_argmin, show_progressbar)\u001b[0m\n\u001b[0;32m 684\u001b[0m \u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 685\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_argmin\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 686\u001b[1;33m \u001b[0mshow_progressbar\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mshow_progressbar\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 687\u001b[0m )\n\u001b[0;32m 688\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mfmin\u001b[1;34m(fn, space, algo, max_evals, timeout, loss_threshold, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)\u001b[0m\n\u001b[0;32m 507\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 508\u001b[0m \u001b[1;31m# next line is where the fmin is actually executed\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 509\u001b[1;33m \u001b[0mrval\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexhaust\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 510\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 511\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mexhaust\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 328\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mexhaust\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 329\u001b[0m \u001b[0mn_done\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 330\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmax_evals\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mn_done\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mblock_until_done\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0masynchronous\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 331\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrefresh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 332\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, N, block_until_done)\u001b[0m\n\u001b[0;32m 284\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 285\u001b[0m \u001b[1;31m# -- loop over trials and do the jobs directly\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 286\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mserial_evaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 287\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 288\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrefresh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mserial_evaluate\u001b[1;34m(self, N)\u001b[0m\n\u001b[0;32m 163\u001b[0m \u001b[0mctrl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbase\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mCtrl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcurrent_trial\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrial\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 164\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 165\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdomain\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mctrl\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 166\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 167\u001b[0m \u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0merror\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"job exception: %s\"\u001b[0m \u001b[1;33m%\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\base.py\u001b[0m in \u001b[0;36mevaluate\u001b[1;34m(self, config, ctrl, attach_attachments)\u001b[0m\n\u001b[0;32m 892\u001b[0m \u001b[0mprint_node_on_error\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrec_eval_print_node_on_error\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 893\u001b[0m )\n\u001b[1;32m--> 894\u001b[1;33m \u001b[0mrval\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpyll_rval\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 895\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 896\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrval\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumber\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\Desktop\\WheatRustClassification-master\\generate_library_of_models.py\u001b[0m in \u001b[0;36mscore\u001b[1;34m(params)\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 79\u001b[0m model_ft_arr, ensemble_loss, _, fold_val_prob = train_model_snapshot(model_ft, criterion, params['lr'], dataloaders, dataset_sizes, device,\n\u001b[1;32m---> 80\u001b[1;33m num_cycles=params['num_cycles'], num_epochs_per_cycle=params['num_epochs_per_cycle'])\n\u001b[0m\u001b[0;32m 81\u001b[0m \u001b[0mmodels_arr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_ft_arr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 82\u001b[0m \u001b[0mfold\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\Desktop\\WheatRustClassification-master\\utils.py\u001b[0m in \u001b[0;36mtrain_model_snapshot\u001b[1;34m(model, criterion, lr, dataloaders, dataset_sizes, device, num_cycles, num_epochs_per_cycle)\u001b[0m\n\u001b[0;32m 40\u001b[0m \u001b[0midx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[1;31m# Iterate over data.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 42\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdataloaders\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 43\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 44\u001b[0m \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 277\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_SingleProcessDataLoaderIter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 278\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 279\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_MultiProcessingDataLoaderIter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 280\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 281\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, loader)\u001b[0m\n\u001b[0;32m 717\u001b[0m \u001b[1;31m# before it starts, and __del__ tries to join but will get:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 718\u001b[0m \u001b[1;31m# AssertionError: can only join a started process.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 719\u001b[1;33m \u001b[0mw\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 720\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_index_queues\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex_queue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 721\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_workers\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\process.py\u001b[0m in \u001b[0;36mstart\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 110\u001b[0m \u001b[1;34m'daemonic processes are not allowed to have children'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0m_cleanup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 112\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_popen\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 113\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sentinel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_popen\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msentinel\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 114\u001b[0m \u001b[1;31m# Avoid a refcycle if the target function holds an indirect\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\context.py\u001b[0m in \u001b[0;36m_Popen\u001b[1;34m(process_obj)\u001b[0m\n\u001b[0;32m 221\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 222\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 223\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_default_context\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_context\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 224\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 225\u001b[0m \u001b[1;32mclass\u001b[0m \u001b[0mDefaultContext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mBaseContext\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\context.py\u001b[0m in \u001b[0;36m_Popen\u001b[1;34m(process_obj)\u001b[0m\n\u001b[0;32m 320\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 321\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[1;33m.\u001b[0m\u001b[0mpopen_spawn_win32\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mPopen\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 322\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mPopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 323\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 324\u001b[0m \u001b[1;32mclass\u001b[0m \u001b[0mSpawnContext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mBaseContext\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\popen_spawn_win32.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, process_obj)\u001b[0m\n\u001b[0;32m 44\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 46\u001b[1;33m \u001b[0mprep_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mspawn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_preparation_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 47\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 48\u001b[0m \u001b[1;31m# read end of pipe will be \"stolen\" by the child process\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\spawn.py\u001b[0m in \u001b[0;36mget_preparation_data\u001b[1;34m(name)\u001b[0m\n\u001b[0;32m 170\u001b[0m \u001b[1;31m# or through direct execution (or to leave it alone entirely)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 171\u001b[0m \u001b[0mmain_module\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'__main__'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 172\u001b[1;33m \u001b[0mmain_mod_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmain_module\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__spec__\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"name\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 173\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mmain_mod_name\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 174\u001b[0m \u001b[0md\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'init_main_from_name'\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmain_mod_name\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;31mAttributeError\u001b[0m: module '__main__' has no attribute '__spec__'"
+ ]
+ }
+ ],
+ "source": [
+ "%run generate_library_of_models.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.keep" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.keep"
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.pytest_cache/CACHEDIR.TAG" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.pytest_cache/CACHEDIR.TAG"
new file mode 100644
index 0000000000000000000000000000000000000000..381f03a5958d0ecbc936e9d4e3863147757787e9
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.pytest_cache/CACHEDIR.TAG"
@@ -0,0 +1,4 @@
+Signature: 8a477f597d28d172789f06886806bc55
+# This file is a cache directory tag created by pytest.
+# For information about cache directory tags, see:
+# http://www.bford.info/cachedir/spec.html
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.pytest_cache/README.md" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.pytest_cache/README.md"
new file mode 100644
index 0000000000000000000000000000000000000000..b10f023dc7730f954decb50f7fdb61fe6028fc4b
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/.pytest_cache/README.md"
@@ -0,0 +1,8 @@
+# pytest cache directory #
+
+This directory contains data from the pytest's cache plugin,
+which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
+
+**Do not** commit this to version control.
+
+See [the docs](https://docs.pytest.org/en/latest/cache.html) for more information.
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/README.md" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/README.md"
new file mode 100644
index 0000000000000000000000000000000000000000..def227c3fc754b4b83425ab155659b177b933086
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/README.md"
@@ -0,0 +1,127 @@
+# Wheat Rust Classification from Ensemble Selection of CNNs
+
+Fourth place solution for the CGIAR Computer Vision for Crop Disease competition organized by the CV4A workshop at ICLR 2020. The main objective of the competition is to classify a given image of wheat as healthy, stem rust, or leaf rust.
+
+## Summary of Approach
+
+Create an ensemble from a library of diverse models with different architectures and augmentations. All models are pre-trained on ImageNet and fine-tuned on the competition dataset. The models and augmentations are chosen automatically using hyperparameter optimization.
+
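+A minimal sketch of how this search is driven, assuming hyperopt; `train_and_evaluate` is a hypothetical stand-in for the k-fold snapshot training done in generate_library_of_models.py:
+
+```
+from hyperopt import hp, tpe, fmin, Trials
+
+# illustrative subset of the real search space
+space = {
+    'lr': hp.choice('lr', [0.001, 0.002, 0.003]),
+    'degrees': hp.choice('degrees', range(1, 90)),
+    'contrast': hp.uniform('contrast', 0.0, 0.3),
+}
+
+def objective(params):
+    # hypothetical helper: train a k-fold snapshot ensemble and return its mean validation loss
+    return train_and_evaluate(params)
+
+trials = Trials()
+best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=50, trials=trials)
+```
+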
+### Model Architectures
+
+The following architectures are included in the library of models:
+
+* ResNet [1]
+* ResNext [2]
+* WideResNet [3]
+* DenseNet [4]
+
+### Data Augmentations
+
+The following augmentations are included in the search space of the hyperparameter optimization (a sketch of the resulting pipeline follows the list):
+
+* Rotation
+* Random cropping and resizing
+* Horizontal flipping
+* Vertical flipping
+* Brightness augmentation
+* Hue augmentation
+* Contrast augmentation
+* Mixup augmentation [5]
+
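+A minimal sketch of such a training pipeline with torchvision; the parameter values below are placeholders (in the actual runs they are sampled by the hyperparameter search), and mixup is applied inside the dataset class rather than as a transform:
+
+```
+from torchvision import transforms
+
+train_transform = transforms.Compose([
+    transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
+    transforms.RandomAffine(degrees=30),    # rotation
+    transforms.RandomResizedCrop(224),      # random cropping and resizing
+    transforms.RandomHorizontalFlip(p=0.5),
+    transforms.RandomVerticalFlip(p=0.5),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+])
+```
+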
+### Common Configuration
+
+The following configuration is applied to all trials in the hyperparameter optimization process (a simplified sketch of the snapshot-ensemble loop follows the list):
+
+* Stochastic Gradient Descent (SGD) optimizer
+* Snapshot ensemble [6]
+* 5-Fold training
+
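+A simplified sketch of the snapshot-ensemble training loop, condensed from utils.py; `train_one_epoch` is a hypothetical stand-in for the epoch loop there, and `model`, `lr`, `num_cycles`, `num_epochs_per_cycle`, and `train_loader` are as defined in that file:
+
+```
+import copy
+import torch.optim as optim
+from torch.optim import lr_scheduler
+
+snapshots = []
+for cycle in range(num_cycles):
+    # restart SGD at the full learning rate at the start of every cycle
+    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+    # anneal the learning rate towards zero over the cycle
+    scheduler = lr_scheduler.CosineAnnealingLR(
+        optimizer, T_max=num_epochs_per_cycle * len(train_loader))
+    for epoch in range(num_epochs_per_cycle):
+        train_one_epoch(model, optimizer, scheduler, train_loader)
+    # keep the weights at the end of every cycle; averaging the
+    # snapshot predictions gives the ensemble
+    snapshots.append(copy.deepcopy(model.state_dict()))
+```
+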
+## Getting Started
+
+### Prerequisites
+
+Firstly, you need to have:
+
+* Ubuntu 18.04
+* Python 3
+* At least 11 GB of GPU memory
+
+Secondly, you need to download the challenge data and sample submission file by following the instructions [here](https://zindi.africa/competitions/iclr-workshop-challenge-1-cgiar-computer-vision-for-crop-disease/data).
+
+Thirdly, you need to install the dependencies by running:
+
+```
+pip3 install -r requirements.txt
+```
+
+### Project files
+
+* prepare_dataset.py: reads training and test data, removes duplicates from training data and saves them in numpy matrices. It has the following arguments:
+
+ --train_data_path: path to training data folder
+
+ --test_data_path: path to test data folder
+
+ --save_path: save path for training and test numpy matrices of images
+
+* generate_library_of_models.py: generates a library of models with different architectures and augmentations through hyperparameter optimization search. It has the following arguments:
+
+ --data_path: path to training and test numpy matrices of images
+
+ --library_size: number of models to be trained in the library of models
+
+ --library_path: save path for validation and test predictions of the library of models
+
+* ensemble_selection.py: applies the Ensemble Selection [7] algorithm on the generated library of models to find the ensemble with the lowest validation error and uses it to create the final submission (a sketch of the greedy procedure follows this list). It has the following arguments:
+
+ --train_data_path: path to training data folder
+
+ --data_path: path to training and test numpy matrices of images
+
+ --sample_sub_file_path: path to sample submission file
+
+ --library_size: number of models to be trained in the library of models
+
+ --library_path: save path for validation and test predictions of the library of models
+
+ --final_sub_file_save_path: save path for final submission file
+
+* dataset.py: has the dataset class for training and test data.
+
+* utils.py: utility functions for training, testing and reading dataset images.
+
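+Ensemble Selection greedily adds models (with replacement) whenever doing so lowers the validation loss. A minimal sketch of the procedure, assuming each model's validation predictions are softmax probability arrays; the repository's own implementation lives in ensemble_selection.py and may differ in details such as the stopping rule:
+
+```
+import numpy as np
+
+def log_loss(prob, labels):
+    # mean negative log-likelihood of the true class
+    return -np.mean(np.log(prob[np.arange(len(labels)), labels] + 1e-12))
+
+def ensemble_selection(val_probs, labels, n_steps=20):
+    # val_probs: list of (n_samples, n_classes) arrays, one per model in the library
+    chosen, ens_sum = [], np.zeros_like(val_probs[0])
+    for _ in range(n_steps):
+        # pick the model whose addition minimizes the validation loss
+        scores = [log_loss((ens_sum + p) / (len(chosen) + 1), labels) for p in val_probs]
+        best = int(np.argmin(scores))
+        chosen.append(best)
+        ens_sum += val_probs[best]
+    return chosen, ens_sum / len(chosen)
+```
+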
+## Running
+
+### 1- Prepare dataset
+
+```
+python3 prepare_dataset.py
+```
+
+### 2- Generate the library of models
+
+```
+python3 generate_library_of_models.py
+```
+
+### 3- Create ensemble and generate submission file
+
+```
+python3 ensemble_selection.py
+```
+
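+Each script also accepts the arguments listed under Project files. For example, assuming the challenge data was extracted under data/ (all paths below are placeholders):
+
+```
+python3 prepare_dataset.py --train_data_path data/train --test_data_path data/test --save_path data/npy
+python3 generate_library_of_models.py --data_path data/npy --library_size 50 --library_path trails
+python3 ensemble_selection.py --train_data_path data/train --data_path data/npy --sample_sub_file_path sample_submission.csv --library_size 50 --library_path trails --final_sub_file_save_path submission.csv
+```
+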
+## References
+[1] He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+
+[2] Xie, Saining, et al. "Aggregated residual transformations for deep neural networks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
+
+[3] Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).
+
+[4] Huang, Gao, et al. "Densely connected convolutional networks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
+
+[5] Zhang, Hongyi, et al. "mixup: Beyond empirical risk minimization." arXiv preprint arXiv:1710.09412 (2017).
+
+[6] Huang, Gao, et al. "Snapshot ensembles: Train 1, get m for free." arXiv preprint arXiv:1704.00109 (2017).
+
+[7] Caruana, Rich, et al. "Ensemble selection from libraries of models." Proceedings of the twenty-first international conference on Machine learning. 2004.
+
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/Untitled.ipynb" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/Untitled.ipynb"
new file mode 100644
index 0000000000000000000000000000000000000000..c8c0a43b6b86ff082fc7a7509de99de367205866
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/Untitled.ipynb"
@@ -0,0 +1,620 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %load dataset.py\n",
+ "from torch.utils.data import Dataset\n",
+ "from torchvision import transforms, utils\n",
+ "import numpy as np\n",
+ "from scipy import ndimage\n",
+ "import torch\n",
+ "from PIL import Image #图像处理库\n",
+ "\n",
+ "class ICLRDataset(Dataset):\n",
+ " def __init__(self, imgs, gts, split_type, index, transform, img_mix_enable = True):\n",
+ " if index is None:\n",
+ " self.imgs = imgs\n",
+ " self.gts = gts\n",
+ " else:\n",
+ " self.imgs = [imgs[i] for i in index] #图片集\n",
+ " self.gts = [gts[i] for i in index] \n",
+ " \n",
+ " self.split_type = split_type\n",
+ " self.transform = transform\n",
+ " self.img_mix_enable = img_mix_enable\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.imgs)\n",
+ " \n",
+ " def augment(self, img, y): \n",
+ " p = np.random.random(1) #生成0-1之间的一个1维数组\n",
+ " if p[0] > 0.5: #取出数组里的那个数跟0.5作比较\n",
+ " while True:\n",
+ " rnd_idx = np.random.randint(0, len(self.imgs)) #前闭后开,其实就是所有图片索引\n",
+ " if self.gts[rnd_idx] != y: #如果图片标签不是y就跳出---检查是不是有分错类的图片\n",
+ " break\n",
+ " rnd_crop = self.transform(Image.fromarray(self.imgs[rnd_idx])) #用于变换的图片集\n",
+ " d = 0.8\n",
+ " img = img * d + rnd_crop * (1 - d) #对图像进行混合和随机裁剪\n",
+ " return img\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " img = self.imgs[idx]\n",
+ " y = self.gts[idx]\n",
+ " img = Image.fromarray(img)\n",
+ " img = self.transform(img)\n",
+ " if (self.split_type == 'train') & self.img_mix_enable:\n",
+ " img = self.augment(img, y) \n",
+ " return img, y #增强训练集数据,返回增强后的图片和对应标签\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %load utils.py\n",
+ "#用于训练、测试和读取数据集图像的使用函数\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.optim import lr_scheduler\n",
+ "import numpy as np\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "#import matplotlib.pyplot as plt\n",
+ "import time\n",
+ "import os\n",
+ "import copy\n",
+ "import torch.nn.functional as F\n",
+ "from PIL import Image, ExifTags\n",
+ "\n",
+ "def train_model_snapshot(model, criterion, lr, dataloaders, dataset_sizes, device, num_cycles, num_epochs_per_cycle):\n",
+ " since = time.time() #记录训练时间\n",
+ "\n",
+ " best_model_wts = copy.deepcopy(model.state_dict()) #从预训练的模型中复制权重并初始化模型\n",
+ " best_acc = 0.0\n",
+ " best_loss = 1000000.0\n",
+ " model_w_arr = []\n",
+ " prob = torch.zeros((dataset_sizes['val'], 3), dtype = torch.float32).to(device) #预测\n",
+ " lbl = torch.zeros((dataset_sizes['val'],), dtype = torch.long).to(device) #标签\n",
+ " for cycle in range(num_cycles):\n",
+ " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)#, weight_decay = 0.0005)\n",
+ " scheduler = lr_scheduler.CosineAnnealingLR(optimizer, num_epochs_per_cycle*len(dataloaders['train'])) \n",
+ " #余弦周期调整学习率,修改优化器中的学习率,(优化器,调整间隔数,调整系数)\n",
+ " for epoch in range(num_epochs_per_cycle):\n",
+ " #print('Cycle {}: Epoch {}/{}'.format(cycle, epoch, num_epochs_per_cycle - 1))\n",
+ " #print('-' * 10)\n",
+ "\n",
+ " # Each epoch has a training and validation phase\n",
+ " for phase in ['train', 'val']:\n",
+ " if phase == 'train':\n",
+ " model.train() # Set model to training mode\n",
+ " else:\n",
+ " model.eval() # Set model to evaluate mode\n",
+ "\n",
+ " running_loss = 0.0\n",
+ " running_corrects = 0\n",
+ " idx = 0\n",
+ " # Iterate over data.迭代数据\n",
+ " for inputs, labels in dataloaders[phase]:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ "\n",
+ " # zero the parameter gradients\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " # forward\n",
+ " # track history if only in train\n",
+ " with torch.set_grad_enabled(phase == 'train'):\n",
+ " outputs = model(inputs)\n",
+ " _, preds = torch.max(outputs, 1)\n",
+ " if (epoch == num_epochs_per_cycle-1) and (phase == 'val'):\n",
+ " prob[idx:idx+inputs.shape[0]] += F.softmax(outputs, dim = 1)\n",
+ " lbl[idx:idx+inputs.shape[0]] = labels\n",
+ " idx += inputs.shape[0]\n",
+ " loss = criterion(outputs, labels)\n",
+ " # backward + optimize only if in training phase\n",
+ " if phase == 'train':\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " scheduler.step()\n",
+ " #print(optimizer.param_groups[0]['lr'])\n",
+ " \n",
+ " # statistics\n",
+ " running_loss += loss.item() * inputs.size(0)\n",
+ " running_corrects += torch.sum(preds == labels.data)\n",
+ "\n",
+ " epoch_loss = running_loss / dataset_sizes[phase]\n",
+ " epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
+ "\n",
+ " #print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
+ " # phase, epoch_loss, epoch_acc))\n",
+ "\n",
+ " # deep copy the model\n",
+ " if phase == 'val' and epoch_loss < best_loss:\n",
+ " best_loss = epoch_loss\n",
+ " best_model_wts = copy.deepcopy(model.state_dict())\n",
+ " #print()\n",
+ " model_w_arr.append(copy.deepcopy(model.state_dict()))\n",
+ "\n",
+ " prob /= num_cycles\n",
+ " ensemble_loss = F.nll_loss(torch.log(prob), lbl) \n",
+ " ensemble_loss = ensemble_loss.item()\n",
+ " time_elapsed = time.time() - since\n",
+ " #print('Training complete in {:.0f}m {:.0f}s'.format(\n",
+ " # time_elapsed // 60, time_elapsed % 60))\n",
+ " #print('Ensemble Loss : {:4f}, Best val Loss: {:4f}'.format(ensemble_loss, best_loss))\n",
+ "\n",
+ " # load best model weights加载最佳模型权重\n",
+ " model_arr =[]\n",
+ " for weights in model_w_arr:\n",
+ " model.load_state_dict(weights) \n",
+ " model_arr.append(model) \n",
+ " return model_arr, ensemble_loss, best_loss, prob\n",
+ "\n",
+ "def test(models_arr, loader, device):\n",
+ " res = np.zeros((610, 3), dtype = np.float32)\n",
+ " for model in models_arr:\n",
+ " model.eval()\n",
+ " res_arr = []\n",
+ " for inputs, _ in loader:\n",
+ " inputs = inputs.to(device)\n",
+ " # forward\n",
+ " # track history if only in train\n",
+ " with torch.set_grad_enabled(False):\n",
+ " outputs = F.softmax(model(inputs), dim = 1) \n",
+ " res_arr.append(outputs.detach().cpu().numpy())\n",
+ " res_arr = np.concatenate(res_arr, axis = 0)\n",
+ " res += res_arr\n",
+ " return res / len(models_arr)\n",
+ "\n",
+ "def read_train_data(p):\n",
+ " imgs = []\n",
+ " labels = []\n",
+ " for i, lbl in enumerate(os.listdir(p)):\n",
+ " for fname in os.listdir(os.path.join(p, lbl)):\n",
+ " #read image\n",
+ " img = Image.open(os.path.join(p, lbl, fname))\n",
+ " #rotate image to original view旋转图像到原始视图\n",
+ " try:\n",
+ " exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)\n",
+ " if exif['Orientation'] == 3:\n",
+ " img=img.rotate(180, expand=True)\n",
+ " elif exif['Orientation'] == 6:\n",
+ " img=img.rotate(270, expand=True)\n",
+ " elif exif['Orientation'] == 8:\n",
+ " img=img.rotate(90, expand=True)\n",
+ " except:\n",
+ " pass\n",
+ " #resize all images to the same size将所有图像调整为相同的大小\n",
+ " img = np.array(img.convert('RGB').resize((512,512), Image.ANTIALIAS))\n",
+ " imgs.append(img)\n",
+ " labels.append(i)\n",
+ " return imgs, labels\n",
+ "\n",
+ "def read_test_data(p):\n",
+ " imgs = []\n",
+ " labels = []\n",
+ " ids = []\n",
+ " for fname in os.listdir(p):\n",
+ " #read image\n",
+ " img = Image.open(os.path.join(p, fname))\n",
+ " #rotate image to original view\n",
+ " try:\n",
+ " if not('DMWVNR' in fname):\n",
+ " exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)\n",
+ " if exif['Orientation'] == 3:\n",
+ " img=img.rotate(180, expand=True)\n",
+ " elif exif['Orientation'] == 6:\n",
+ " img=img.rotate(270, expand=True)\n",
+ " elif exif['Orientation'] == 8:\n",
+ " img=img.rotate(90, expand=True)\n",
+ " except:\n",
+ " pass\n",
+ " #resize all images to the same size\n",
+ " img = img.convert('RGB').resize((512,512), Image.ANTIALIAS)\n",
+ " imgs.append(np.array(img.copy()))\n",
+ " labels.append(0)\n",
+ " ids.append(fname.split('.')[0])\n",
+ " img.close()\n",
+ " return imgs, labels, ids\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "unique train imgs: 732\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "usage: ipykernel_launcher.py [-h] [--train_data_path TRAIN_DATA_PATH]\n",
+ " [--test_data_path TEST_DATA_PATH]\n",
+ " [--save_path SAVE_PATH]\n",
+ "ipykernel_launcher.py: error: unrecognized arguments: -f C:\\Users\\Administrator\\AppData\\Roaming\\jupyter\\runtime\\kernel-60e74fca-82ff-42d7-afc4-1d27b752461b.json\n"
+ ]
+ },
+ {
+ "ename": "SystemExit",
+ "evalue": "2",
+ "output_type": "error",
+ "traceback": [
+ "An exception has occurred, use %tb to see the full traceback.\n",
+ "\u001b[1;31mSystemExit\u001b[0m\u001b[1;31m:\u001b[0m 2\n"
+ ]
+ }
+ ],
+ "source": [
+ "# %load prepare_dataset.py\n",
+ "#读取训练数据和测试数据,从训练数据中删除重复的数据并保存在numpy矩阵中\n",
+ "%run prepare_dataset.py\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import argparse\n",
+ "from utils import read_train_data, read_test_data\n",
+ "\n",
+ "parser = argparse.ArgumentParser(description='Data preperation')\n",
+ "parser.add_argument('--train_data_path', help='path', default='D:/datasets/train', type=str)\n",
+ "parser.add_argument('--test_data_path', help='path', default='D:/datasets/test', type=str)\n",
+ "parser.add_argument('--save_path', help='save', default='D:/datasets/savepath', type=str)\n",
+ "args = parser.parse_args() #获取参数,调用上面的属性\n",
+ "\n",
+ "#read training data\n",
+ "train_imgs, train_gts = read_train_data(args.train_data_path)\n",
+ "\n",
+ "#remove dublicate training imgs\n",
+ "idx_to_rmv = []\n",
+ "for i in range(len(train_imgs)-1):\n",
+ " for j in range(i+1, len(train_imgs)):\n",
+ " if np.all(train_imgs[i] == train_imgs[j]):\n",
+ " idx_to_rmv.append(i)\n",
+ " if train_gts[i] != train_gts[j]:\n",
+ " idx_to_rmv.append(j)\n",
+ "\n",
+ "idx = [i for i in range(len(train_imgs)) if not(i in idx_to_rmv)]\n",
+ "print('unique train imgs:',len(idx))\n",
+ "\n",
+ "#save unique training imgs\n",
+ "np.save(os.path.join(args.save_path, 'unique_train_imgs_rot_fixed'), np.array(train_imgs)[idx])\n",
+ "np.save(os.path.join(args.save_path, 'unique_train_gts_rot_fixed'), np.array(train_gts)[idx])\n",
+ "\n",
+ "#read test data\n",
+ "test_imgs, test_gts, ids = read_test_data(args.test_data_path)\n",
+ "\n",
+ "#save test data\n",
+ "np.save(os.path.join(args.save_path, 'test_imgs_rot_fixed'), np.array(test_imgs))\n",
+ "np.save(os.path.join(args.save_path, 'test_gts'), np.array(test_gts))\n",
+ "np.save(os.path.join(args.save_path, 'ids'), np.array(ids))\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "unique train imgs: 732\n"
+ ]
+ }
+ ],
+ "source": [
+ "%run prepare_dataset.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "usage: ipykernel_launcher.py [-h] [--data_path DATA_PATH]\n",
+ " [--library_size LIBRARY_SIZE]\n",
+ " [--library_path LIBRARY_PATH]\n",
+ "ipykernel_launcher.py: error: unrecognized arguments: -f C:\\Users\\Administrator\\AppData\\Roaming\\jupyter\\runtime\\kernel-9ba0d6cf-6ce4-4517-8c53-84c6b0e19712.json\n"
+ ]
+ },
+ {
+ "ename": "SystemExit",
+ "evalue": "2",
+ "output_type": "error",
+ "traceback": [
+ "An exception has occurred, use %tb to see the full traceback.\n",
+ "\u001b[1;31mSystemExit\u001b[0m\u001b[1;31m:\u001b[0m 2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py:3333: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
+ " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# %load generate_library_of_models.py\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.optim import lr_scheduler\n",
+ "import numpy as np\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import time\n",
+ "import os\n",
+ "import argparse\n",
+ "import copy\n",
+ "from sklearn.model_selection import StratifiedKFold\n",
+ "import datetime\n",
+ "from PIL import Image\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from dataset import ICLRDataset\n",
+ "from utils import train_model_snapshot, test\n",
+ "from sklearn.metrics import confusion_matrix\n",
+ "from hyperopt import hp, tpe, fmin, Trials\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "\n",
+ "def score(params):\n",
+ " global test_prob, val_prob, trails_sc_arr,idx # 全局变量 \n",
+ " print(params)\n",
+ " k = 5 \n",
+ " sss = StratifiedKFold(n_splits=k, shuffle = True, random_state=seed_arr[idx]) #提供训练/测试索引来分割训练/测试集中的数据\n",
+ " #(折叠的数量,至少是2;在分组前是否对每个类的样本进行洗牌;当shuffle为真时,random_state将影响索引的排序)\n",
+ " \n",
+ " #define trail data augmentations 训练集数据增强和归一化、验证集归一化\n",
+ " data_transforms = {\n",
+ " 'train': transforms.Compose([\n",
+ " transforms.ColorJitter(contrast = params['contrast'], hue = params['hue'], brightness = params['brightness']),\n",
+ " transforms.RandomAffine(degrees = params['degrees']),\n",
+ " transforms.RandomResizedCrop(224),\n",
+ " transforms.RandomHorizontalFlip(p = 0.5 if params['h_flip'] else 0.0), #以概率P水平翻转图像\n",
+ " transforms.RandomVerticalFlip(p = 0.5 if params['v_flip'] else 0.0),#以概率P垂直翻转图像\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ " 'val': transforms.Compose([\n",
+ " transforms.Resize((params['val_img_size'], params['val_img_size'])),\n",
+ " transforms.CenterCrop(224),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ " }\n",
+ "\n",
+ " trail_test_prob = np.zeros((test_imgs.shape[0], 3), dtype = np.float32)\n",
+ " trail_val_prob = torch.zeros((train_imgs.shape[0], 3), dtype = torch.float32).to(device)\n",
+ " \n",
+ " sc_arr = []\n",
+ " models_arr = []\n",
+ " fold = 0\n",
+ " #train a model for each split 为每一个分割训练一个模型\n",
+ " for train_index, val_index in sss.split(train_imgs, train_gts):\n",
+ " #define dataset and loader for training and validation 确定数据集,载入训练集、验证集\n",
+ " image_datasets = {'train': ICLRDataset(train_imgs, train_gts, 'train', train_index, data_transforms['train'], params['img_mix_enable']),\n",
+ "\t 'val': ICLRDataset(train_imgs, train_gts, 'val', val_index, data_transforms['val'])}\n",
+ "\n",
+ " dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=16, shuffle=True, num_workers=2),\n",
+ " 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=16, shuffle=False, num_workers=2)}\n",
+ "\n",
+ " #create model instance创建模型实例\n",
+ " model_ft = params['arch'](pretrained=True) #预训练\n",
+ " try:\n",
+ " num_ftrs = model_ft.fc.in_features\n",
+ " model_ft.fc = nn.Linear(num_ftrs, 3)\n",
+ " except:\n",
+ " num_ftrs = model_ft.classifier.in_features\n",
+ " model_ft.classifier = nn.Linear(num_ftrs, 3)\n",
+ " model_ft = model_ft.to(device)\n",
+ "\n",
+ " criterion = nn.CrossEntropyLoss()\n",
+ "\n",
+ " dataset_sizes = {x:len(image_datasets[x]) for x in ['train', 'val']}\n",
+ " \n",
+ " model_ft_arr, ensemble_loss, _, fold_val_prob = train_model_snapshot(model_ft, criterion, params['lr'], dataloaders, dataset_sizes, device,\n",
+ " num_cycles=params['num_cycles'], num_epochs_per_cycle=params['num_epochs_per_cycle'])\n",
+ " models_arr.extend(model_ft_arr)\n",
+ " fold += 1\n",
+ " sc_arr.append(ensemble_loss)\n",
+ " trail_val_prob[val_index] = fold_val_prob\n",
+ " \n",
+ " #predict on test data using average of kfold models\n",
+ " image_datasets['test'] = ICLRDataset(test_imgs, test_gts, 'test', None, data_transforms['val'])\n",
+ " test_loader = torch.utils.data.DataLoader(image_datasets['test'], batch_size=4,shuffle=False, num_workers=16)\n",
+ " trail_test_prob = test(models_arr, test_loader, device)\n",
+ "\n",
+ " print('mean val loss:', np.mean(sc_arr))\n",
+ "\n",
+ " test_prob.append(trail_test_prob)\n",
+ " val_prob.append(trail_val_prob)\n",
+ "\n",
+ " #save validation and test results for further processing \n",
+ " np.save(os.path.join(args.library_path, 'val_prob_trail_%d'%(idx)), trail_val_prob.detach().cpu().numpy())\n",
+ " np.save(os.path.join(args.library_path, 'test_prob_trail_%d'%(idx)), trail_test_prob)\n",
+ " idx += 1\n",
+ " \n",
+ " trails_sc_arr.append(np.mean(sc_arr))\n",
+ "\n",
+ " torch.cuda.empty_cache()\n",
+ " del models_arr\n",
+ "\n",
+ " return np.mean(sc_arr)\n",
+ "\n",
+ "parser = argparse.ArgumentParser(description='Data preperation')\n",
+ "parser.add_argument('--data_path', help='path to training and test numpy matrices of images', default='.', type=str)\n",
+ "parser.add_argument('--library_size', help='number of models to be trained in the library of models', default=50, type=int)\n",
+ "parser.add_argument('--library_path', help='save path for validation and test predictions of the library of models', default='trails', type=str)\n",
+ "args = parser.parse_args()\n",
+ "\n",
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "torch.manual_seed(0)\n",
+ "np.random.seed(0)\n",
+ "\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "\n",
+ "#read train data\n",
+ "train_imgs = np.load(os.path.join(args.data_path, 'D:/datasets/savepath/unique_train_imgs_rot_fixed.npy'))\n",
+ "train_gts = np.load(os.path.join(args.data_path, 'unique_train_gts_rot_fixed.npy'))\n",
+ "\n",
+ "#read test data\n",
+ "test_imgs = np.load(os.path.join(args.data_path, 'test_imgs_rot_fixed.npy'))\n",
+ "test_gts = np.load(os.path.join(args.data_path, 'test_gts.npy'))\n",
+ "ids = np.load(os.path.join(args.data_path, 'ids.npy')).tolist()\n",
+ "\n",
+ "test_prob = []\n",
+ "val_prob = []\n",
+ "trails_sc_arr = []\n",
+ "\n",
+ "n_trails = args.library_size\n",
+ "seed_arr = np.random.randint(low=0, high=1000000, size=n_trails)\n",
+ "\n",
+ "#create search space for hyperparameter optimization\n",
+ "space = OrderedDict([('lr', hp.choice('lr', [i*0.001 for i in range(1,4)])),\n",
+ " ('num_cycles', hp.choice('num_cycles', range(3, 6))),\n",
+ " ('num_epochs_per_cycle', hp.choice('num_epochs_per_cycle', range(3, 6))),\n",
+ " ('arch', hp.choice('arch', [models.densenet201, models.densenet121, models.densenet169,\n",
+ " models.wide_resnet50_2, models.resnet152, \n",
+ " models.resnet101, models.resnet50, models.resnet34, models.resnet18])),\n",
+ " ('img_mix_enable', hp.choice('img_mix_enable', [True, False])),\n",
+ " ('v_flip', hp.choice('v_flip', [True, False])),\n",
+ " ('h_flip', hp.choice('h_flip', [True, False])),\n",
+ " ('degrees', hp.choice('degrees', range(1, 90))),\n",
+ " ('contrast', hp.uniform('contrast', 0.0, 0.3)),\n",
+ " ('hue', hp.uniform('hue', 0.0, 0.3)),\n",
+ " ('brightness', hp.uniform('brightness', 0.0, 0.3)),\n",
+ " ('val_img_size', hp.choice('val_img_size', range(224, 512, 24))),\n",
+ " ])\n",
+ "\n",
+ "trials = Trials()\n",
+ "\n",
+ "idx = 0\n",
+ "if not os.path.exists(args.library_path):\n",
+ " os.mkdir(args.library_path)\n",
+ "\n",
+ "#use tpe algorithm in hyperopt to generate a library of differnet models 利用hyperopt中的tpe算法生成不同模型库\n",
+ "best = fmin(fn=score,space=space,algo=tpe.suggest,max_evals=n_trails,trials=trials)\n",
+ "#fmin是对不用的算法集及其超参数进行迭代,使目标函数最小化的优化函数\n",
+ "#(最小化的目标函数;定义的搜索空间;搜索算法-为超参数空间的顺序搜索提供逻辑;最大评估数;trials对象)\n",
+ "print(best)\n",
+ "\n",
+ "np.save(os.path.join(args.library_path, 'scores.npy'), np.array(trails_sc_arr))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'arch': , 'brightness': 0.006903174616102503, 'contrast': 0.10395561286019817, 'degrees': 14, 'h_flip': True, 'hue': 0.12930405533670436, 'img_mix_enable': False, 'lr': 0.001, 'num_cycles': 5, 'num_epochs_per_cycle': 5, 'v_flip': True, 'val_img_size': 416}\n",
+ " 0%| | 0/100 [00:00, ?trial/s, best loss=?]"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "job exception: module '__main__' has no attribute '__spec__'\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "ename": "AttributeError",
+ "evalue": "module '__main__' has no attribute '__spec__'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m~\\Desktop\\WheatRustClassification-master\\generate_library_of_models.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;31m#use tpe algorithm in hyperopt to generate a library of differnet models\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 162\u001b[0m \u001b[1;31m#best = fmin(fn=score,space=space,algo=tpe.suggest,max_evals=n_trails,trials=trials)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 163\u001b[1;33m \u001b[0mbest\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfmin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mscore\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mspace\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mspace\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0malgo\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtpe\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msuggest\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mmax_evals\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m100\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 164\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbest\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 165\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mfmin\u001b[1;34m(fn, space, algo, max_evals, timeout, loss_threshold, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)\u001b[0m\n\u001b[0;32m 480\u001b[0m \u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 481\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_argmin\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 482\u001b[1;33m \u001b[0mshow_progressbar\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mshow_progressbar\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 483\u001b[0m )\n\u001b[0;32m 484\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\base.py\u001b[0m in \u001b[0;36mfmin\u001b[1;34m(self, fn, space, algo, max_evals, timeout, loss_threshold, max_queue_len, rstate, verbose, pass_expr_memo_ctrl, catch_eval_exceptions, return_argmin, show_progressbar)\u001b[0m\n\u001b[0;32m 684\u001b[0m \u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcatch_eval_exceptions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 685\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_argmin\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 686\u001b[1;33m \u001b[0mshow_progressbar\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mshow_progressbar\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 687\u001b[0m )\n\u001b[0;32m 688\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mfmin\u001b[1;34m(fn, space, algo, max_evals, timeout, loss_threshold, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)\u001b[0m\n\u001b[0;32m 507\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 508\u001b[0m \u001b[1;31m# next line is where the fmin is actually executed\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 509\u001b[1;33m \u001b[0mrval\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexhaust\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 510\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 511\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mexhaust\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 328\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mexhaust\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 329\u001b[0m \u001b[0mn_done\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 330\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmax_evals\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mn_done\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mblock_until_done\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0masynchronous\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 331\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrefresh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 332\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, N, block_until_done)\u001b[0m\n\u001b[0;32m 284\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 285\u001b[0m \u001b[1;31m# -- loop over trials and do the jobs directly\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 286\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mserial_evaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 287\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 288\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrefresh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\fmin.py\u001b[0m in \u001b[0;36mserial_evaluate\u001b[1;34m(self, N)\u001b[0m\n\u001b[0;32m 163\u001b[0m \u001b[0mctrl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbase\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mCtrl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcurrent_trial\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrial\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 164\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 165\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdomain\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mctrl\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 166\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 167\u001b[0m \u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0merror\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"job exception: %s\"\u001b[0m \u001b[1;33m%\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\hyperopt\\base.py\u001b[0m in \u001b[0;36mevaluate\u001b[1;34m(self, config, ctrl, attach_attachments)\u001b[0m\n\u001b[0;32m 892\u001b[0m \u001b[0mprint_node_on_error\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrec_eval_print_node_on_error\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 893\u001b[0m )\n\u001b[1;32m--> 894\u001b[1;33m \u001b[0mrval\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpyll_rval\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 895\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 896\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrval\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumber\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\Desktop\\WheatRustClassification-master\\generate_library_of_models.py\u001b[0m in \u001b[0;36mscore\u001b[1;34m(params)\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 79\u001b[0m model_ft_arr, ensemble_loss, _, fold_val_prob = train_model_snapshot(model_ft, criterion, params['lr'], dataloaders, dataset_sizes, device,\n\u001b[1;32m---> 80\u001b[1;33m num_cycles=params['num_cycles'], num_epochs_per_cycle=params['num_epochs_per_cycle'])\n\u001b[0m\u001b[0;32m 81\u001b[0m \u001b[0mmodels_arr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_ft_arr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 82\u001b[0m \u001b[0mfold\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\Desktop\\WheatRustClassification-master\\utils.py\u001b[0m in \u001b[0;36mtrain_model_snapshot\u001b[1;34m(model, criterion, lr, dataloaders, dataset_sizes, device, num_cycles, num_epochs_per_cycle)\u001b[0m\n\u001b[0;32m 40\u001b[0m \u001b[0midx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[1;31m# Iterate over data.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 42\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdataloaders\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 43\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 44\u001b[0m \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 277\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_SingleProcessDataLoaderIter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 278\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 279\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_MultiProcessingDataLoaderIter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 280\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 281\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, loader)\u001b[0m\n\u001b[0;32m 717\u001b[0m \u001b[1;31m# before it starts, and __del__ tries to join but will get:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 718\u001b[0m \u001b[1;31m# AssertionError: can only join a started process.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 719\u001b[1;33m \u001b[0mw\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 720\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_index_queues\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex_queue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 721\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_workers\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\process.py\u001b[0m in \u001b[0;36mstart\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 110\u001b[0m \u001b[1;34m'daemonic processes are not allowed to have children'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0m_cleanup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 112\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_popen\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 113\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sentinel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_popen\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msentinel\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 114\u001b[0m \u001b[1;31m# Avoid a refcycle if the target function holds an indirect\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\context.py\u001b[0m in \u001b[0;36m_Popen\u001b[1;34m(process_obj)\u001b[0m\n\u001b[0;32m 221\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 222\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 223\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_default_context\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_context\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 224\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 225\u001b[0m \u001b[1;32mclass\u001b[0m \u001b[0mDefaultContext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mBaseContext\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\context.py\u001b[0m in \u001b[0;36m_Popen\u001b[1;34m(process_obj)\u001b[0m\n\u001b[0;32m 320\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_Popen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 321\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[1;33m.\u001b[0m\u001b[0mpopen_spawn_win32\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mPopen\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 322\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mPopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 323\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 324\u001b[0m \u001b[1;32mclass\u001b[0m \u001b[0mSpawnContext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mBaseContext\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\popen_spawn_win32.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, process_obj)\u001b[0m\n\u001b[0;32m 44\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mprocess_obj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 46\u001b[1;33m \u001b[0mprep_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mspawn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_preparation_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprocess_obj\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 47\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 48\u001b[0m \u001b[1;31m# read end of pipe will be \"stolen\" by the child process\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\multiprocessing\\spawn.py\u001b[0m in \u001b[0;36mget_preparation_data\u001b[1;34m(name)\u001b[0m\n\u001b[0;32m 170\u001b[0m \u001b[1;31m# or through direct execution (or to leave it alone entirely)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 171\u001b[0m \u001b[0mmain_module\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'__main__'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 172\u001b[1;33m \u001b[0mmain_mod_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmain_module\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__spec__\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"name\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 173\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mmain_mod_name\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 174\u001b[0m \u001b[0md\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'init_main_from_name'\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmain_mod_name\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;31mAttributeError\u001b[0m: module '__main__' has no attribute '__spec__'"
+ ]
+ }
+ ],
+ "source": [
+ "%run generate_library_of_models.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-36.pyc" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-36.pyc"
new file mode 100644
index 0000000000000000000000000000000000000000..fb3841946c3e952ba38f0a23c969491afe094e13
Binary files /dev/null and "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-36.pyc" differ
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-37(1).pyc" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-37(1).pyc"
new file mode 100644
index 0000000000000000000000000000000000000000..0fbc904d20541e8000ff00da46a22a48026ce44a
Binary files /dev/null and "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-37(1).pyc" differ
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-37.pyc" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-37.pyc"
new file mode 100644
index 0000000000000000000000000000000000000000..0fbc904d20541e8000ff00da46a22a48026ce44a
Binary files /dev/null and "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/__pycache__/dataset.cpython-37.pyc" differ
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/dataset.py" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/dataset.py"
new file mode 100644
index 0000000000000000000000000000000000000000..b68a127da255af8c35aa64f7efd31b414750360d
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/dataset.py"
@@ -0,0 +1,43 @@
+from torch.utils.data import Dataset
+from torchvision import transforms, utils
+import numpy as np
+from scipy import ndimage
+import torch
+from PIL import Image
+
+class ICLRDataset(Dataset):
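+    # Dataset over pre-loaded numpy images; `index` selects a CV-fold subset
+    # (None keeps everything) and `img_mix_enable` toggles the cross-class
+    # blending augmentation implemented in `augment` below.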
+ def __init__(self, imgs, gts, split_type, index, transform, img_mix_enable = True):
+ if index is None:
+ self.imgs = imgs
+ self.gts = gts
+ else:
+ self.imgs = [imgs[i] for i in index]
+ self.gts = [gts[i] for i in index]
+
+ self.split_type = split_type
+ self.transform = transform
+ self.img_mix_enable = img_mix_enable
+
+ def __len__(self):
+ return len(self.imgs)
+
+    def augment(self, img, y):
+        # mixup-style augmentation: with probability 0.5, blend the transformed
+        # image with a transformed image drawn from a *different* class
+        p = np.random.random(1)
+        if p[0] > 0.5:
+            # resample until the random image's label differs from y
+            while True:
+                rnd_idx = np.random.randint(0, len(self.imgs))
+                if self.gts[rnd_idx] != y:
+                    break
+            rnd_crop = self.transform(Image.fromarray(self.imgs[rnd_idx]))
+            d = 0.8  # weight of the original image in the blend
+            img = img * d + rnd_crop * (1 - d)
+        return img
+
+ def __getitem__(self, idx):
+ img = self.imgs[idx]
+ y = self.gts[idx]
+ img = Image.fromarray(img)
+ img = self.transform(img)
+        if (self.split_type == 'train') and self.img_mix_enable:
+ img = self.augment(img, y)
+ return img, y
\ No newline at end of file
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/ensemble_selection.py" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/ensemble_selection.py"
new file mode 100644
index 0000000000000000000000000000000000000000..9da4b64cd208c1d820e9bd2a692afa4aa37be881
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/ensemble_selection.py"
@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+import numpy as np
+import torchvision
+from torchvision import datasets, models, transforms
+import pandas as pd
+#import matplotlib.pyplot as plt
+import time
+import os
+import copy
+import argparse
+from sklearn.model_selection import StratifiedKFold
+import datetime
+from PIL import Image
+
+import torch.nn.functional as F
+
+def cross_entropy(y, p):
+    # row-normalize the probabilities, then compute the NLL of the true labels
+    p /= p.sum(1).reshape(-1,1)
+    return F.nll_loss(torch.log(torch.tensor(p)), torch.tensor(y)).numpy()
+
+def weighted_cross_entropy(y, p):
+    # per-class weighted NLL; class 2 (healthy wheat) is deliberately zero-weighted
+    p /= p.sum(1).reshape(-1,1)
+    w_arr = np.array([0.53, 0.3, 0.0])
+    return np.sum([F.nll_loss(torch.log(torch.tensor(p[y==c])), torch.tensor(y[y==c])).numpy()*w_arr[c] for c in range(3)])
+
+class ensembleSelection:
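+    # Ensemble selection (Caruana-style): greedily pick models, with
+    # replacement, whose geometric-mean combination improves a validation
+    # metric; `es_with_bagging` repeats this over random model subsets.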
+
+ def __init__(self, metric):
+ self.metric = metric
+
+    def _compare(self, sc1, sc2):
+        # lower is better: the metrics used here are losses
+        return sc1 < sc2
+
+    def _initialize(self, X_p, y):
+        """
+        Find the index of the model with the best validation probability.
+        """
+        current_sc = self.metric(y, X_p[0])
+        ind = 0
+        for i in range(1, X_p.shape[0]):
+            sc = self.metric(y, X_p[i])
+            print('model: %d, sc: %f' % (i, sc))
+            if self._compare(sc, current_sc):
+                current_sc = sc
+                ind = i
+        return ind, current_sc
+
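+    # Greedy forward selection with replacement: keep a running geometric mean
+    # of the chosen models' probabilities and stop once no candidate model
+    # lowers the metric any further.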
+ def es_with_replacement(self, X_p, Xtest_p, y):
+ best_ind, best_sc = self._initialize(X_p, y)
+ current_sc = best_sc
+ sumP = np.copy(X_p[best_ind])
+ sumP_test = np.copy(Xtest_p[best_ind])
+ i = 1
+        # greedily find the best combination of the input models' results
+ while True:
+ i += 1
+ ind = -1
+ for m in range(X_p.shape[0]):
+                # check whether adding model m (running geometric mean of i
+                # probability matrices) improves the current score
+                sc = self.metric(y, (sumP*X_p[m])**(1/i))
+ if self._compare(sc, current_sc):
+ current_sc = sc
+ ind = m
+ if ind>-1:
+ sumP *= X_p[ind]
+ sumP_test *= Xtest_p[ind]
+ else:
+ break
+ sumP = sumP**(1/(i-1))
+ sumP_test = sumP_test**(1/(i-1))
+
+ sumP /= sumP.sum(1).reshape(-1,1)
+ sumP_test /= sumP_test.sum(1).reshape(-1,1)
+
+ return current_sc, sumP, sumP_test
+
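+    # Bagged selection: run es_with_replacement on n_bags random subsets
+    # (a fraction f of the library each) and combine the per-bag ensembles
+    # by geometric mean for extra robustness.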
+    def es_with_bagging(self, X_p, Xtest_p, y, f = 0.5, n_bags = 20):
+        list_of_indices = [i for i in range(X_p.shape[0])]
+        bag_size = int(f*X_p.shape[0])
+        sumP = None
+        sumP_test = None
+        for i in range(n_bags):
+            # create a random subset (bag) of models
+            rng = np.copy(list_of_indices)
+            np.random.shuffle(rng)
+            rng = rng[:bag_size]
+            # find the best combination within this bag
+            sc, p, ptest = self.es_with_replacement(X_p[rng], Xtest_p[rng], y)
+ print('bag: %d, sc: %f'%(i, sc))
+ if sumP is None:
+ sumP = p
+ sumP_test = ptest
+ else:
+ sumP *= p
+ sumP_test *= ptest
+
+        # combine the results of all bags by geometric mean
+ sumP = sumP**(1/n_bags)
+ sumP_test = sumP_test**(1/n_bags)
+
+ sumP /= sumP.sum(1).reshape(-1,1)
+ sumP_test /= sumP_test.sum(1).reshape(-1,1)
+
+ sumP[sumP < 1e-6] = 1e-6
+ sumP_test[sumP_test < 1e-6] = 1e-6
+
+ final_sc = self.metric(y, sumP)
+ print('avg sc: %f'%(final_sc))
+ return (final_sc, sumP, sumP_test)
+
+parser = argparse.ArgumentParser(description='Data preparation')
+parser.add_argument('--train_data_path', help='path to training data folder', default='train_data', type=str)
+parser.add_argument('--data_path', help='path to training and test numpy matrices of images', default='.', type=str)
+parser.add_argument('--sample_sub_file_path', help='path to sample submission file', default='.', type=str)
+parser.add_argument('--library_size', help='number of models to be trained in the library of models', default=50, type=int)
+parser.add_argument('--library_path', help='save path for validation and test predictions of the library of models', default='trails', type=str)
+parser.add_argument('--final_sub_file_save_path', help='save path for final submission file', default='.', type=str)
+args = parser.parse_args()
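+# usage sketch (default paths assumed):
+#   python ensemble_selection.py --library_size 50 --library_path trails --sample_sub_file_path .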
+
+np.random.seed(4321)
+
+n = args.library_size
+
+#read training ground truth
+#(note: an absolute second argument makes os.path.join ignore args.data_path)
+train_gts = np.load(os.path.join(args.data_path, 'C:/Users/x2/Desktop/pq/wheat rust/savepath/unique_train_gts_rot_fixed.npy'))
+#train_gts = np.load(os.path.join(args.data_path, 'unique_train_gts_rot_fixed.npy'))
+#read the validation probabilities on training data generated by the automatic hyperopt trials
+#and build a matrix of shape (N,D,3) where N is the number of models and D is the data size
+train_prob = np.array([np.load(os.path.join(args.library_path, 'C:/Users/x2/Desktop/pq/wheat rust/trails/val_prob_trail_%d.npy'%(i))) for i in range(n)])
+#train_prob = np.array([np.load(os.path.join(args.library_path, 'val_prob_trail_%d.npy'%(i))) for i in range(n)])
+
+#read the test probabilities generated by the hyperopt trials
+#and build a matrix of shape (N,D,3) where N is the number of models and D is the data size
+test_prob = np.array([np.load(os.path.join(args.library_path, 'C:/Users/x2/Desktop/pq/wheat rust/trails/test_prob_trail_%d.npy'%(i))) for i in range(n)])
+#test_prob = np.array([np.load(os.path.join(args.library_path, 'test_prob_trail_%d.npy'%(i))) for i in range(n)])
+
+ids = np.load('C:/Users/x2/Desktop/pq/wheat rust/savepath/ids.npy').tolist()
+
+#use the ensemble selection algorithm to find the best combination of models by geometric averaging
+es_obj = ensembleSelection(cross_entropy)
+sc, es_train_prob, es_test_prob = es_obj.es_with_bagging(train_prob, test_prob, train_gts, n_bags = 10, f = 0.65)
+
+#detect samples with high confidence for healthy wheat
+idx = (np.max(es_test_prob, 1) > 0.7) & (np.argmax(es_test_prob, 1) == 2)
+
+#create another ensemble with more weights for leaf and stem classes
+es_obj = ensembleSelection(weighted_cross_entropy)
+sc, es_train_prob, es_test_prob = es_obj.es_with_bagging(train_prob, test_prob, train_gts, n_bags = 10, f = 0.65)
+
+#snap the confident healthy-wheat samples to a (near) one-hot healthy prediction
+es_test_prob[idx, 0] = 1e-6
+es_test_prob[idx, 1] = 1e-6
+es_test_prob[idx, 2] = 1.0
+
+#create submission
+sub = pd.read_csv(os.path.join(args.sample_sub_file_path, 'sample_submission.csv'))
+sub['ID'] = ids
+lbl_names = os.listdir(args.train_data_path)
+for i, name in enumerate(lbl_names):
+ sub[name] = es_test_prob[:,i].tolist()
+sub.to_csv(os.path.join(args.final_sub_file_save_path, 'final_sub.csv'), index = False)
+
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/generate_library_of_models.py" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/generate_library_of_models.py"
new file mode 100644
index 0000000000000000000000000000000000000000..d397095b4f1534993c3f250464de8a7ed13b7cfd
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/generate_library_of_models.py"
@@ -0,0 +1,191 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+import numpy as np
+from torchvision import datasets, models, transforms
+
+import pandas as pd
+#import matplotlib.pyplot as plt
+import time
+import os
+import argparse
+import copy
+from sklearn.model_selection import StratifiedKFold
+import datetime
+from PIL import Image
+import torch.nn.functional as F
+
+from dataset import ICLRDataset
+from utils import train_model_snapshot, test
+from sklearn.metrics import confusion_matrix
+from hyperopt import hp, tpe, fmin, Trials
+from collections import OrderedDict
+
+
+
+
+
+def score(params):
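+    # One hyperopt trial: sample a config (augmentations, architecture, LR
+    # schedule), train a snapshot ensemble per CV fold, store the fold/test
+    # probabilities, and return the mean validation loss to the optimizer.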
+ global test_prob, val_prob, trails_sc_arr,idx
+ print(params)
+ k = 5
+ sss = StratifiedKFold(n_splits=k, shuffle = True, random_state=seed_arr[idx])
+    #define this trial's data augmentations
+ data_transforms = {
+ 'train': transforms.Compose([
+ transforms.ColorJitter(contrast = params['contrast'], hue = params['hue'], brightness = params['brightness']),
+ transforms.RandomAffine(degrees = params['degrees']),
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(p = 0.5 if params['h_flip'] else 0.0),
+ transforms.RandomVerticalFlip(p = 0.5 if params['v_flip'] else 0.0),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ]),
+ 'val': transforms.Compose([
+ transforms.Resize((params['val_img_size'], params['val_img_size'])),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ]),
+ }
+
+ trail_test_prob = np.zeros((test_imgs.shape[0], 3), dtype = np.float32)
+ trail_val_prob = torch.zeros((train_imgs.shape[0], 3), dtype = torch.float32).to(device)
+
+ sc_arr = []
+ models_arr = []
+    fold = 0  # fold counter (restored: `fold += 1` below needs it defined)
+ #train a model for each split
+ for train_index, val_index in sss.split(train_imgs, train_gts):
+ #define dataset and loader for training and validation
+ image_datasets = {'train': ICLRDataset(train_imgs, train_gts, 'train', train_index, data_transforms['train'], params['img_mix_enable']),
+ 'val': ICLRDataset(train_imgs, train_gts, 'val', val_index, data_transforms['val'])}
+
+ dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=16, shuffle=True, num_workers=0),
+ 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=16, shuffle=False, num_workers=0)}
+
+        #create the model instance and replace its classification head with a 3-way layer
+        model_ft = params['arch'](pretrained=True)
+        try:
+            num_ftrs = model_ft.fc.in_features                  # e.g. ResNet
+            model_ft.fc = nn.Linear(num_ftrs, 3)
+        except AttributeError:
+            if isinstance(model_ft.classifier, nn.Sequential):  # e.g. VGG
+                num_ftrs = model_ft.classifier[-1].in_features
+                model_ft.classifier[-1] = nn.Linear(num_ftrs, 3)
+            else:                                               # e.g. DenseNet
+                num_ftrs = model_ft.classifier.in_features
+                model_ft.classifier = nn.Linear(num_ftrs, 3)
+        model_ft = model_ft.to(device)
+
+ criterion = nn.CrossEntropyLoss()
+
+ dataset_sizes = {x:len(image_datasets[x]) for x in ['train', 'val']}
+
+ model_ft_arr, ensemble_loss, _, fold_val_prob = train_model_snapshot(model_ft, criterion, params['lr'], dataloaders, dataset_sizes, device,
+ num_cycles=params['num_cycles'], num_epochs_per_cycle=params['num_epochs_per_cycle'])
+ models_arr.extend(model_ft_arr)
+        fold += 1  # advance to the next cross-validation fold
+ sc_arr.append(ensemble_loss)
+ trail_val_prob[val_index] = fold_val_prob
+
+ #predict on test data using average of kfold models
+ image_datasets['test'] = ICLRDataset(test_imgs, test_gts, 'test', None, data_transforms['val'])
+ test_loader = torch.utils.data.DataLoader(image_datasets['test'], batch_size=4,shuffle=False, num_workers=0)
+ trail_test_prob = test(models_arr, test_loader, device)
+
+ print('mean val loss:', np.mean(sc_arr))
+
+ test_prob.append(trail_test_prob)
+ val_prob.append(trail_val_prob)
+
+ #save validation and test results for further processing
+ np.save(os.path.join(args.library_path, 'val_prob_trail_%d'%(idx)), trail_val_prob.detach().cpu().numpy())
+ np.save(os.path.join(args.library_path, 'test_prob_trail_%d'%(idx)), trail_test_prob)
+ idx += 1
+
+ trails_sc_arr.append(np.mean(sc_arr))
+
+ torch.cuda.empty_cache()
+ del models_arr
+
+ return np.mean(sc_arr)
+
+parser = argparse.ArgumentParser(description='Data preparation')
+parser.add_argument('--data_path', help='path to training and test numpy matrices of images', default='.', type=str)
+parser.add_argument('--library_size', help='number of models to be trained in the library of models', default=50, type=int)
+parser.add_argument('--library_path', help='save path for validation and test predictions of the library of models', default='trails', type=str)
+args = parser.parse_args()
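+# usage sketch (default values assumed):
+#   python generate_library_of_models.py --data_path . --library_size 50 --library_path trails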
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+torch.manual_seed(0)
+np.random.seed(0)
+
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+#read train data (note: these hard-coded paths bypass args.data_path)
+train_imgs = np.load(os.path.join('D:/datasets/savepath', 'unique_train_imgs_rot_fixed.npy'))
+train_gts = np.load(os.path.join('D:/datasets/savepath', 'unique_train_gts_rot_fixed.npy'))
+
+#read test data
+test_imgs = np.load(os.path.join('D:/datasets/savepath', 'test_imgs_rot_fixed.npy'))
+test_gts = np.load(os.path.join('D:/datasets/savepath', 'test_gts.npy'))
+ids = np.load(os.path.join('D:/datasets/savepath', 'ids.npy')).tolist()
+
+test_prob = []
+val_prob = []
+trails_sc_arr = []
+
+n_trails = args.library_size
+seed_arr = np.random.randint(low=0, high=1000000, size=n_trails)
+
+#create search space for hyperparameter optimization
+space = OrderedDict([('lr', hp.choice('lr', [i*0.001 for i in range(1,4)])),
+ ('num_cycles', hp.choice('num_cycles', range(3, 6))),
+ ('num_epochs_per_cycle', hp.choice('num_epochs_per_cycle', range(3, 6))),
+                     ('arch', hp.choice('arch', [models.densenet201, models.densenet169,
+                                                 models.resnet152, models.resnet101,
+                                                 # models.vgg is a submodule, not a constructor; vgg16 assumed here
+                                                 models.vgg16])),
+ ('img_mix_enable', hp.choice('img_mix_enable', [True, False])),
+ ('v_flip', hp.choice('v_flip', [True, False])),
+ ('h_flip', hp.choice('h_flip', [True, False])),
+ ('degrees', hp.choice('degrees', range(1, 90))),
+ ('contrast', hp.uniform('contrast', 0.0, 0.3)),
+ ('hue', hp.uniform('hue', 0.0, 0.3)),
+ ('brightness', hp.uniform('brightness', 0.0, 0.3)),
+ ('val_img_size', hp.choice('val_img_size', range(224, 512, 24))),
+ ])
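+# a sampled configuration looks like, e.g.:
+# {'lr': 0.002, 'num_cycles': 4, 'num_epochs_per_cycle': 3, 'arch': models.resnet101,
+#  'img_mix_enable': True, 'v_flip': False, 'h_flip': True, 'degrees': 30,
+#  'contrast': 0.1, 'hue': 0.05, 'brightness': 0.2, 'val_img_size': 248}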
+
+trials = Trials()
+
+idx = 0
+if not os.path.exists(args.library_path):
+ os.mkdir(args.library_path)
+
+#use the TPE algorithm in hyperopt to generate a library of different models
+best = fmin(fn=score, space=space, algo=tpe.suggest, max_evals=n_trails, trials=trials)
+
+
+np.save(os.path.join(args.library_path, 'scores.npy'), np.array(trails_sc_arr))
+#np.save(os.path.join('D:/datasets', 'scores.npy'), np.array(trails_sc_arr))
+
+#note: launching this script via %run (as in the notebook above) with DataLoader
+#num_workers > 0 on Windows can raise "AttributeError: module '__main__' has no
+#attribute '__spec__'" when worker processes are spawned; keeping num_workers=0
+#(as above) or guarding the entry point avoids it:
+#if __name__ == "__main__":
+#    best = fmin(fn=score, space=space, algo=tpe.suggest, max_evals=n_trails, trials=trials)
+
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/prepare_dataset.py" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/prepare_dataset.py"
new file mode 100644
index 0000000000000000000000000000000000000000..f88d16b3d08ecda417430efb5aee9d188d88ae62
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/prepare_dataset.py"
@@ -0,0 +1,38 @@
+import numpy as np
+import os
+import argparse
+from utils import read_train_data, read_test_data
+
+parser = argparse.ArgumentParser(description='Data preparation')
+parser.add_argument('--train_data_path', help='path to training data folder', default='C:/Users/x2/Desktop/pq/wheat rust/train', type=str)
+parser.add_argument('--test_data_path', help='path to test data folder', default='C:/Users/x2/Desktop/pq/wheat rust/test', type=str)
+parser.add_argument('--save_path', help='save path for the numpy matrices', default='C:/Users/x2/Desktop/pq/wheat rust/savepath', type=str)
+args = parser.parse_args()
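+# usage sketch (default paths assumed):
+#   python prepare_dataset.py --train_data_path <train> --test_data_path <test> --save_path <out>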
+
+#read training data
+train_imgs, train_gts = read_train_data(args.train_data_path)
+
+#remove duplicate training images: keep the later copy, and if two identical
+#images carry conflicting labels, drop both
+idx_to_rmv = []
+for i in range(len(train_imgs)-1):
+    for j in range(i+1, len(train_imgs)):
+        if np.all(train_imgs[i] == train_imgs[j]):
+            idx_to_rmv.append(i)
+            if train_gts[i] != train_gts[j]:
+                idx_to_rmv.append(j)
+
+idx = [i for i in range(len(train_imgs)) if i not in idx_to_rmv]
+print('unique train imgs:',len(idx))
+
+#save unique training imgs
+np.save(os.path.join(args.save_path, 'unique_train_imgs_rot_fixed'), np.array(train_imgs)[idx])
+np.save(os.path.join(args.save_path, 'unique_train_gts_rot_fixed'), np.array(train_gts)[idx])
+
+#read test data
+test_imgs, test_gts, ids = read_test_data(args.test_data_path)
+
+#save test data
+np.save(os.path.join(args.save_path, 'test_imgs_rot_fixed'), np.array(test_imgs))
+np.save(os.path.join(args.save_path, 'test_gts'), np.array(test_gts))
+np.save(os.path.join(args.save_path, 'ids'), np.array(ids))
+
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/requirements.txt" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/requirements.txt"
new file mode 100644
index 0000000000000000000000000000000000000000..3ada6af12789b2b45ae5df9abc2ead5228f2ce0d
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/requirements.txt"
@@ -0,0 +1,8 @@
+hyperopt==0.2.3
+pandas==1.0.1
+torch==1.4.0
+scipy==1.4.1
+torchvision==0.5.0
+numpy==1.18.1
+Pillow==7.0.0
+scikit_learn==0.22.2.post1
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/utils.py" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/utils.py"
new file mode 100644
index 0000000000000000000000000000000000000000..ec9737d0fe472ffd9268c88520d4067385d1e896
--- /dev/null
+++ "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/utils.py"
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+import numpy as np
+import torchvision
+from torchvision import datasets, models, transforms
+#import matplotlib.pyplot as plt
+import time
+import os
+import copy
+import torch.nn.functional as F
+from PIL import Image, ExifTags
+
+def train_model_snapshot(model, criterion, lr, dataloaders, dataset_sizes, device, num_cycles, num_epochs_per_cycle):
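+    # Snapshot ensembling: each cycle restarts SGD and anneals the LR with a
+    # cosine schedule; the weights at the end of every cycle become one
+    # ensemble member, and `prob` accumulates the cycles' validation softmax.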
+ since = time.time()
+
+ best_model_wts = copy.deepcopy(model.state_dict())
+ best_acc = 0.0
+ best_loss = 1000000.0
+ model_w_arr = []
+ prob = torch.zeros((dataset_sizes['val'], 3), dtype = torch.float32).to(device)
+ lbl = torch.zeros((dataset_sizes['val'],), dtype = torch.long).to(device)
+ for cycle in range(num_cycles):
+ optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)#, weight_decay = 0.0005)
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, num_epochs_per_cycle*len(dataloaders['train']))
+ for epoch in range(num_epochs_per_cycle):
+ #print('Cycle {}: Epoch {}/{}'.format(cycle, epoch, num_epochs_per_cycle - 1))
+ #print('-' * 10)
+
+ # Each epoch has a training and validation phase
+ for phase in ['train', 'val']:
+ if phase == 'train':
+ model.train() # Set model to training mode
+ else:
+ model.eval() # Set model to evaluate mode
+
+ running_loss = 0.0
+ running_corrects = 0
+ idx = 0
+ # Iterate over data.
+ for inputs, labels in dataloaders[phase]:
+ inputs = inputs.to(device)
+ labels = labels.to(device)
+
+ # zero the parameter gradients
+ optimizer.zero_grad()
+
+ # forward
+ # track history if only in train
+ with torch.set_grad_enabled(phase == 'train'):
+ outputs = model(inputs)
+ _, preds = torch.max(outputs, 1)
+ if (epoch == num_epochs_per_cycle-1) and (phase == 'val'):
+ prob[idx:idx+inputs.shape[0]] += F.softmax(outputs, dim = 1)
+ lbl[idx:idx+inputs.shape[0]] = labels
+ idx += inputs.shape[0]
+ loss = criterion(outputs, labels)
+ # backward + optimize only if in training phase
+ if phase == 'train':
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+ #print(optimizer.param_groups[0]['lr'])
+
+ # statistics
+ running_loss += loss.item() * inputs.size(0)
+ running_corrects += torch.sum(preds == labels.data)
+
+ epoch_loss = running_loss / dataset_sizes[phase]
+ epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+ #print('{} Loss: {:.4f} Acc: {:.4f}'.format(
+ # phase, epoch_loss, epoch_acc))
+
+ # deep copy the model
+ if phase == 'val' and epoch_loss < best_loss:
+ best_loss = epoch_loss
+ best_model_wts = copy.deepcopy(model.state_dict())
+ #print()
+ model_w_arr.append(copy.deepcopy(model.state_dict()))
+
+ prob /= num_cycles
+ ensemble_loss = F.nll_loss(torch.log(prob), lbl)
+ ensemble_loss = ensemble_loss.item()
+ time_elapsed = time.time() - since
+ #print('Training complete in {:.0f}m {:.0f}s'.format(
+ # time_elapsed // 60, time_elapsed % 60))
+ #print('Ensemble Loss : {:4f}, Best val Loss: {:4f}'.format(ensemble_loss, best_loss))
+
+    # materialize one model per snapshot; deepcopy so each entry keeps its own
+    # weights (appending the same instance would leave all entries sharing the
+    # last-loaded snapshot)
+    model_arr = []
+    for weights in model_w_arr:
+        snapshot = copy.deepcopy(model)
+        snapshot.load_state_dict(weights)
+        model_arr.append(snapshot)
+    return model_arr, ensemble_loss, best_loss, prob
+
+def test(models_arr, loader, device):
+    # average the softmax predictions of all snapshot models over the loader
+    # (the original hard-coded the test-set size, 610; derive it instead)
+    res = np.zeros((len(loader.dataset), 3), dtype = np.float32)
+ for model in models_arr:
+ model.eval()
+ res_arr = []
+ for inputs, _ in loader:
+ inputs = inputs.to(device)
+ # forward
+ # track history if only in train
+ with torch.set_grad_enabled(False):
+ outputs = F.softmax(model(inputs), dim = 1)
+ res_arr.append(outputs.detach().cpu().numpy())
+ res_arr = np.concatenate(res_arr, axis = 0)
+ res += res_arr
+ return res / len(models_arr)
+
+def read_train_data(p):
+ imgs = []
+ labels = []
+ for i, lbl in enumerate(os.listdir(p)):
+ for fname in os.listdir(os.path.join(p, lbl)):
+ #read image
+ img = Image.open(os.path.join(p, lbl, fname))
+ #rotate image to original view
+ try:
+ exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)
+ if exif['Orientation'] == 3:
+ img=img.rotate(180, expand=True)
+ elif exif['Orientation'] == 6:
+ img=img.rotate(270, expand=True)
+ elif exif['Orientation'] == 8:
+ img=img.rotate(90, expand=True)
+ except:
+ pass
+ #resize all images to the same size
+ img = np.array(img.convert('RGB').resize((512,512), Image.ANTIALIAS))
+ imgs.append(img)
+ labels.append(i)
+ return imgs, labels
+
+def read_test_data(p):
+ imgs = []
+ labels = []
+ ids = []
+ for fname in os.listdir(p):
+ #read image
+ img = Image.open(os.path.join(p, fname))
+ #rotate image to original view
+ try:
+            if 'DMWVNR' not in fname:
+ exif=dict((ExifTags.TAGS[k], v) for k, v in img._getexif().items() if k in ExifTags.TAGS)
+ if exif['Orientation'] == 3:
+ img=img.rotate(180, expand=True)
+ elif exif['Orientation'] == 6:
+ img=img.rotate(270, expand=True)
+ elif exif['Orientation'] == 8:
+ img=img.rotate(90, expand=True)
+ except:
+ pass
+ #resize all images to the same size
+ img = img.convert('RGB').resize((512,512), Image.ANTIALIAS)
+ imgs.append(np.array(img.copy()))
+ labels.append(0)
+ ids.append(fname.split('.')[0])
+ img.close()
+ return imgs, labels, ids
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/utils.pyc" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/utils.pyc"
new file mode 100644
index 0000000000000000000000000000000000000000..863636974968572fee27f88e1d9430217988b249
Binary files /dev/null and "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/utils.pyc" differ
diff --git "a/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266.rar" "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266.rar"
new file mode 100644
index 0000000000000000000000000000000000000000..32d43afa2f08fadb2be8f172278a1396c8205715
Binary files /dev/null and "b/code/2022_autumn/\346\275\230\345\200\251-\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266/\346\275\230\345\200\251-\345\237\272\344\272\216WR-EL\346\250\241\345\236\213\347\232\204\345\260\217\351\272\246\351\224\210\347\227\205\350\257\206\345\210\253\347\240\224\347\251\266.rar" differ