diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/Mycode" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/Mycode" deleted file mode 100644 index e73b82551b37b917db94236d42d3d7cb3f9ee9de..0000000000000000000000000000000000000000 --- "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/Mycode" +++ /dev/null @@ -1,219 +0,0 @@ -import numpy as np -import mindspore as ms -from mindspore import ops, nn -import mindspore.dataset as ds -import mindspore.common.initializer as init #库 - - -"""各函数定义""" -def get_data(): - ds_train = create_dataset() - file = open("data.txt", "r") - file2 = open("lable.txt", "r") - datas=[] - lables=[] - data = file.readlines() - for fields in data: - fields = fields.strip() - fields = fields.split(" ") - if fields!="": - fields = list(map(float,fields)) - datas.append(fields) - lable = file2.readlines() - for fields in lable: - fields = fields.strip() - fields = fields.split(" ") - if fields!="": - fields = list(map(float,fields)) - lables.append(fields) - - datas = np.array(datas,dtype=np.float32) - lables=np.array(lables,dtype=np.float32) - - for i in range(len(lables)): - yield datas[i].astype(np.float32), lables[i].astype(np.float32) - -def get_data2(): - """读取验证数据""" - file = open("data_test.txt", "r") - file2 = open("lable_test.txt", "r") - datas=[] - lables=[] - data = file.readlines() - for fields in data: - fields = fields.strip() - fields = fields.split(" ") - if fields!="": - fields = list(map(float,fields)) - datas.append(fields) - lable = file2.readlines() - for fields in lable: - fields = fields.strip() - fields = fields.split(" ") - if fields!="": - fields = list(map(float,fields)) - lables.append(fields) - - datas = np.array(datas,dtype=np.float32) - lables=np.array(lables,dtype=np.float32) - - for i in range(len(lables)): - yield datas[i].astype(np.float32), lables[i].astype(np.float32) - -def create_dataset(batch_size=12, repeat_size=1): - """定义数据集""" - input_data = ds.GeneratorDataset(list(get_data()), column_names=['data', 'label']) - input_data = input_data.batch(batch_size) - input_data = input_data.repeat(repeat_size) - return input_data - -def create_dataset_2(batch_size=1, repeat_size=1): - """验证数据集""" - input_data = ds.GeneratorDataset(list(get_data2()), column_names=['data', 'label']) - input_data = input_data.batch(batch_size) - input_data = input_data.repeat(repeat_size) - return input_data - -class MyNet(nn.Cell): - """定义网络""" - def __init__(self, input_size=15): - super(MyNet, self).__init__() - self.fc1 = nn.Dense(input_size, 120, weight_init=init.Normal(0.02)) - self.fc2 = nn.Dense(120, 84, weight_init=init.Normal(0.02)) - self.fc3 = nn.Dense(84, 20, weight_init=init.Normal(0.02)) - self.fc4 = nn.Dense(20, 1, weight_init=init.Normal(0.02)) - self.relu = nn.ReLU() - - def construct(self, x): - x = self.relu(self.fc1(x)) - x = self.relu(self.fc2(x)) - x = self.relu(self.fc3(x)) - x = self.fc4(x) - return x - -class MyL1Loss(nn.LossBase): - """定义损失""" - def __init__(self, reduction="mean"): - super(MyL1Loss, self).__init__(reduction) - self.abs = ops.Abs() - self.loss=nn.MSELoss() - - def construct(self, base, target): - x = self.abs(base - target) - return self.get_loss(x) - - -class MyMomentum(nn.Optimizer): - """Momentum优化器""" - def __init__(self, params, learning_rate, momentum=0.9, use_nesterov=False): - super(MyMomentum, self).__init__(learning_rate, params) - self.moments = self.parameters.clone(prefix="moments", init="zeros") - self.momentum = momentum - self.opt = ops.ApplyMomentum(use_nesterov=use_nesterov) - - def construct(self, gradients): - params = self.parameters - success = None - for param, mom, grad in zip(params, self.moments, gradients): - success = self.opt(param, mom, self.learning_rate, grad, self.momentum) - return success - -class MyWithLossCell(nn.Cell): - """定义损失网络""" - def __init__(self, backbone, loss_fn): - super(MyWithLossCell, self).__init__(auto_prefix=False) - self.backbone = backbone - self.loss_fn = loss_fn - - def construct(self, data, label): - out = self.backbone(data) - return self.loss_fn(out, label) - - def backbone_network(self): - return self.backbone - -class MyTrainStep(nn.TrainOneStepCell): - """定义训练流程""" - def __init__(self, network, optimizer): - """参数初始化""" - super(MyTrainStep, self).__init__(network, optimizer) - self.grad = ops.GradOperation(get_by_list=True) - - def construct(self, data, label): - """构建训练过程""" - weights = self.weights - loss = self.network(data, label) - grads = self.grad(self.network, weights)(data, label) - return loss, self.optimizer(grads) - - -"""搭建网络 训练 """ -ds_train = create_dataset() -# 网络 -net = MyNet() -# 损失函数 -loss_func = MyL1Loss() -# 优化器 -opt = MyMomentum(net.trainable_params(), 0.1) -# 构建损失网络 -net_with_criterion = MyWithLossCell(net, loss_func) -# 构建训练网络 -train_net = MyTrainStep(net_with_criterion, opt) -# 执行训练,每个epoch打印一次损失值 -epochs = 5000 -arr=[] -for epoch in range(epochs): - for train_x, train_y in ds_train: - train_net(train_x, train_y) - loss_val = net_with_criterion(train_x, train_y) - print(loss_val) - -""" 模型评估部分""" -class MyMAE(nn.Metric): - """定义metric""" - def __init__(self): - super(MyMAE, self).__init__() - self.clear() - - def clear(self): - self.abs_error_sum = 0 - self.samples_num = 0 - - def update(self, *inputs): - y_pred = inputs[0].asnumpy() - y = inputs[1].asnumpy() - error_abs = np.abs(y.reshape(y_pred.shape) - y_pred) - self.abs_error_sum += error_abs.sum() - self.samples_num += y.shape[0] - - def eval(self): - return self.abs_error_sum / self.samples_num - - -class MyWithEvalCell(nn.Cell): - """定义验证流程""" - def __init__(self, network): - super(MyWithEvalCell, self).__init__(auto_prefix=False) - self.network = network - - def construct(self, data, label): - outputs = self.network(data) - return outputs, label - -""" 计算MAE""" -# 获取验证数据 -ds_eval = create_dataset_2() -# 定义评估网络 -eval_net = MyWithEvalCell(net) -eval_net.set_train(False) -# 定义评估指标 -mae = MyMAE() -# 执行推理过程 -for eval_x, eval_y in ds_eval: - output, eval_y = eval_net(eval_x, eval_y) - mae.update(output, eval_y) - -mae_result = mae.eval() -print("mae: ", mae_result) - -ms.save_checkpoint(net, "./MyNet.ckpt") \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data" deleted file mode 100644 index fa7786338ae7c3a4774f6c087e13ae0c31179861..0000000000000000000000000000000000000000 --- "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data" +++ /dev/null @@ -1,36 +0,0 @@ --266.875 4931.05 276920.9 2087.5 274682.15 2758.925 24.95 461230.7692 15314.10256 0.20640852 0.031791654 -0.075472142 61099.4 274318.95 2166.7 --266.875 4737.225 279008.4 2087.5 277441.075 2758.925 24.975 497461.5385 36230.76923 0.194195932 0.103107395 0.071315742 67399.2 277004.825 2302.5 --266.875 4543.4 281095.9 2087.5 280200 2758.925 25 590769.2308 93307.69231 0.181983344 -0.115118577 -0.218225973 59640.3 279690.7 2393.9 -2137.475 3967.45 284780.2 3684.3 284475 4275 25.025 701153.8462 110384.6154 0.159411232 0.10160244 0.216721017 65699.9 286679.125 2282.7 -2137.475 3391.5 288464.5 3684.3 288750 4275 25.05 893636.3636 192482.5175 0.136839119 0.018561672 -0.083040768 66919.4 293667.55 2362.7 -2137.475 2815.55 292148.8 3684.3 293025 4275 25.075 718583.3333 -175053.0303 0.114267007 0.079108599 0.060546927 72213.3 300655.975 2396.6 -2137.475 2239.6 295833.1 3684.3 297300 4275 25.1 572181.8182 -146401.5152 0.091694895 -0.121603638 -0.200712237 63431.9 307644.4 2134.3 -1430.725 2157.35 309126.9 13293.8 308686.175 11386.175 25.65 578750 6568.181818 0.114394105 0.120735781 0.242339419 71090.4 320543.55 2218 -1430.725 2075.1 322420.7 13293.8 320072.35 11386.175 26.2 572750 -6000 0.137093315 0.041244669 -0.079491112 74022.5 333442.7 2265.2 -1430.725 1910.6 349008.3 13293.8 342844.7 11386.175 27.3 742083.3333 65750 0.182491736 -0.119045468 -0.210404176 71168 359241 2783.7 --24.275 1799.4 347493.775 -1514.525 345033.525 2188.825 27.125 738181.8182 -3901.515152 0.182863311 0.106493087 0.225538554 78746.9 361516.75 2591 --24.275 1688.2 345979.25 -1514.525 347222.35 2188.825 26.95 741875 3693.181818 0.183234887 0.033268611 -0.073224476 81366.7 363792.5 2696.7 --24.275 1577 344464.725 -1514.525 349411.175 2188.825 26.775 783800 41925 0.183606462 0.091614874 0.058346263 88821.1 366068.25 2787.1 --24.275 1465.8 342950.2 -1514.525 351600 2188.825 26.6 770909.0909 -12890.90909 0.183978037 0.113505687 0.021890813 98902.8 368344 2797.7 -2657.825 1331.225 360144.375 17194.175 362328.2 10728.2 26.9 823833.3333 52924.24242 0.163929228 0.115379949 0.001874262 110314.2 381864.6 2928.2 -2657.825 1196.65 377338.55 17194.175 373056.4 10728.2 27.2 837000 13166.66667 0.143880419 0.044502884 -0.070877065 115223.5 395385.2 3095 -2657.825 1062.075 394532.725 17194.175 383784.6 10728.2 27.5 841500 4500 0.123831609 0.102549393 0.05804651 127039.6 408905.8 3332 -2657.825 927.5 411726.9 17194.175 394512.8 10728.2 27.8 775090.9091 -66409.09091 0.1037828 -0.158227041 -0.260776434 106938.5 422426.4 3159.5 -965.175 883.325 414901.65 3174.75 395242.65 729.85 28.025 762454.5455 -12636.36364 0.103080889 0.110520533 0.268747574 118757.4 424165.6 3294.6 -965.175 839.15 418076.4 3174.75 395972.5 729.85 28.25 635461.5385 -126993.007 0.102378978 0.043446556 -0.067073977 123917 425904.8 3005.7 -965.175 794.975 421251.15 3174.75 396702.35 729.85 28.475 634538.4615 -923.0769231 0.101677068 0.108567832 0.065121276 137370.4 427644 3091.4 -965.175 750.8 424425.9 3174.75 397432.2 729.85 28.7 623454.5455 -11083.91608 0.100975157 -0.16035405 -0.268921882 115342.5 429383.2 3224.6 --894.95 706.65 421222.8 -3203.1 394922.125 -2510.075 29.075 609272.7273 -14181.81818 0.097064859 0.107518044 0.267872094 127743.9 426022.325 3551.6 --894.95 662.5 418019.7 -3203.1 392412.05 -2510.075 29.45 559071.4286 -50201.2987 0.09315456 0.047029251 -0.060488793 133751.6 422661.45 3268 --894.95 618.35 414816.6 -3203.1 389901.975 -2510.075 29.825 571000 11928.57143 0.089244262 0.106268635 0.059239384 147965.2 419300.575 3675.3 --894.95 574.2 411613.5 -3203.1 387391.9 -2510.075 30.2 554166.6667 -16833.33333 0.085333963 -0.161744113 -0.268012748 124032.7 415939.7 3663.7 --2178.875 564.1 407963.65 -3649.85 384207.475 -3184.425 30.625 531615.3846 -22551.28205 0.081595911 0.108395609 0.270139721 137477.3 410586.5 3490 --2178.875 554 404313.8 -3649.85 381023.05 -3184.425 31.05 482307.6923 -49307.69231 0.077857859 0.04343117 -0.064964438 143448.1 405233.3 3425.1 --2178.875 543.9 400663.95 -3649.85 377838.625 -3184.425 31.475 502333.3333 20025.64103 0.074119807 0.107130035 0.063698865 158815.7 399880.1 3446.1 --2178.875 533.8 397014.1 -3649.85 374654.2 -3184.425 31.9 490909.0909 -11424.24242 0.070381754 -0.163282975 -0.27041301 132883.8 394526.9 3546.9 -1284.125 523.7 393364.25 -3649.85 371469.775 -3184.425 32.25 421100 -69809.09091 0.066643702 0.108302141 0.271585116 147275.4 392636.7 3205.8 -1284.125 513.6 389714.4 -3649.85 368285.35 -3184.425 32.6 403076.9231 -18023.07692 0.06290565 0.042121766 -0.066180374 153478.9 390746.5 3405.4 -1284.125 503.5 386064.55 -3649.85 365100.925 -3184.425 32.95 370833.3333 -32243.58974 0.059167598 0.106555364 0.064433597 169832.9 388856.3 3527.2 -1284.125 493.4 382414.7 -3649.85 361916.5 -3184.425 33.3 377500 6666.666667 0.055429545 -0.047531426 -0.15408679 161760.5 386966.1 3498.3 -387.575 323.2075 395533.63 13118.93 369353.3475 7436.8475 33.675 390000 12500 0.058261228 0.112682021 0.160213447 179988 394960.715 3674.5 -387.575 289.265 394109.86 -1423.77 367612.945 -1740.4025 34.05 477692.3077 87692.30769 0.05536512 0.047887637 -0.064794384 188607.2 393641.83 3348.7 \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data_test" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data_test" deleted file mode 100644 index 09c06a351613d1a51e372fb592ef4a0ccd7b2701..0000000000000000000000000000000000000000 --- "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data_test" +++ /dev/null @@ -1,2 +0,0 @@ -387.575 255.3225 392686.09 -1423.77 365872.5425 -1740.4025 34.425 653250 175557.6923 0.052469012 0.090524646 0.042637009 205680.8 392322.945 3879.5 -387.575 221.38 391262.32 -1423.77 364132.14 -1740.4025 34.8 618958.3333 -34291.66667 0.049572903 -0.15811539 -0.248640036 173159.5 391004.06 3862.7 \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable" deleted file mode 100644 index f9d605025fd00fc7e743904d51d5c249c1840b9b..0000000000000000000000000000000000000000 --- "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable" +++ /dev/null @@ -1,36 +0,0 @@ -461.2307692 -497.4615385 -590.7692308 -701.1538462 -893.6363636 -718.5833333 -572.1818182 -578.75 -572.75 -742.0833333 -738.1818182 -741.875 -783.8 -770.9090909 -823.8333333 -837 -841.5 -775.0909091 -762.4545455 -635.4615385 -634.5384615 -623.4545455 -609.2727273 -559.0714286 -571 -554.1666667 -531.6153846 -482.3076923 -502.3333333 -490.9090909 -421.1 -403.0769231 -370.8333333 -377.5 -390 -477.6923077 \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable_test" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable_test" deleted file mode 100644 index ca6d05bba5c62a5872bb22688ec4f9619f16b2b1..0000000000000000000000000000000000000000 --- "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable_test" +++ /dev/null @@ -1,2 +0,0 @@ -653.25 -618.9583333 \ No newline at end of file diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/.keep" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/.keep" new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/cells.py" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/cells.py" new file mode 100644 index 0000000000000000000000000000000000000000..5243e1c8b2394dbab74d20ec21cf1477cd175523 --- /dev/null +++ "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/cells.py" @@ -0,0 +1,132 @@ +from mindspore import nn +import mindspore.ops.operations as P +import mindspore.ops.functional as F +import mindspore.ops.composite as C +from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, + _get_parallel_mode) +from mindspore.context import ParallelMode +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer + + +class Reshape(nn.Cell): + def __init__(self, shape, auto_prefix=True): + super().__init__(auto_prefix=auto_prefix) + self.shape = shape + self.reshape = P.Reshape() + + def construct(self, x): + return self.reshape(x, self.shape) + + +class SigmoidCrossEntropyWithLogits(nn.loss.loss._Loss): + def __init__(self): + super().__init__() + self.cross_entropy = P.SigmoidCrossEntropyWithLogits() + + def construct(self, data, label): + x = self.cross_entropy(data, label) + return self.get_loss(x) + + +class GenWithLossCell(nn.Cell): + def __init__(self, netG, netD, loss_fn, auto_prefix=True): + super(GenWithLossCell, self).__init__(auto_prefix=auto_prefix) + self.netG = netG + self.netD = netD + self.loss_fn = loss_fn + + def construct(self, latent_code): + fake_data = self.netG(latent_code) + fake_out = self.netD(fake_data) + loss_G = self.loss_fn(fake_out, F.ones_like(fake_out)) + + return loss_G + + +class DisWithLossCell(nn.Cell): + def __init__(self, netG, netD, loss_fn, auto_prefix=True): + super(DisWithLossCell, self).__init__(auto_prefix=auto_prefix) + self.netG = netG + self.netD = netD + self.loss_fn = loss_fn + + def construct(self, real_data, latent_code): + fake_data = self.netG(latent_code) + + fake_out = self.netD(fake_data) + fake_loss = self.loss_fn(fake_out, F.zeros_like(fake_out)) + + real_out = self.netD(real_data) + real_loss = self.loss_fn(real_out, F.ones_like(real_out)) + loss_D = real_loss + fake_loss + + return loss_D + + +class TrainOneStepCell(nn.Cell): + def __init__( + self, + netG, + netD, + optimizerG: nn.Optimizer, + optimizerD: nn.Optimizer, + sens=1.0, + auto_prefix=True, + ): + super(TrainOneStepCell, self).__init__(auto_prefix=auto_prefix) + self.netG = netG + self.netG.set_grad() + self.netG.add_flags(defer_inline=True) + + self.netD = netD + self.netD.set_grad() + self.netD.add_flags(defer_inline=True) + + self.weights_G = optimizerG.parameters + self.optimizerG = optimizerG + self.weights_D = optimizerD.parameters + self.optimizerD = optimizerD + + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + + self.sens = sens + self.reducer_flag = False + self.grad_reducer_G = F.identity + self.grad_reducer_D = F.identity + self.parallel_mode = _get_parallel_mode() + if self.parallel_mode in (ParallelMode.DATA_PARALLEL, + ParallelMode.HYBRID_PARALLEL): + self.reducer_flag = True + if self.reducer_flag: + mean = _get_gradients_mean() + degree = _get_device_num() + self.grad_reducer_G = DistributedGradReducer( + self.weights_G, mean, degree) + self.grad_reducer_D = DistributedGradReducer( + self.weights_D, mean, degree) + + def trainD(self, real_data, latent_code, loss, loss_net, grad, optimizer, + weights, grad_reducer): + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = grad(loss_net, weights)(real_data, latent_code, sens) + grads = grad_reducer(grads) + return F.depend(loss, optimizer(grads)) + + def trainG(self, latent_code, loss, loss_net, grad, optimizer, weights, + grad_reducer): + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = grad(loss_net, weights)(latent_code, sens) + grads = grad_reducer(grads) + return F.depend(loss, optimizer(grads)) + + def construct(self, real_data, latent_code): + loss_D = self.netD(real_data, latent_code) + loss_G = self.netG(latent_code) + d_out = self.trainD(real_data, latent_code, loss_D, self.netD, + self.grad, self.optimizerD, self.weights_D, + self.grad_reducer_D) + g_out = self.trainG(latent_code, loss_G, self.netG, self.grad, + self.optimizerG, self.weights_G, + self.grad_reducer_G) + + return d_out, g_out,loss_D,loss_G diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/.keep" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/.keep" new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (1).png" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (1).png" new file mode 100644 index 0000000000000000000000000000000000000000..ee85a02a76d9180f63e272167333c68dfa918d27 Binary files /dev/null and "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (1).png" differ diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (2).png" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (2).png" new file mode 100644 index 0000000000000000000000000000000000000000..8a851d835b6c1f3edf1745f91bbf46334aec04e2 Binary files /dev/null and "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (2).png" differ diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (3).png" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (3).png" new file mode 100644 index 0000000000000000000000000000000000000000..aab1796afd596ac5000dd3d3f8b9fa49feb4cfe3 Binary files /dev/null and "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (3).png" differ diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (4).png" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (4).png" new file mode 100644 index 0000000000000000000000000000000000000000..74881105bddce1aca39eb2e4fba72507e9f1d059 Binary files /dev/null and "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (4).png" differ diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (5).png" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (5).png" new file mode 100644 index 0000000000000000000000000000000000000000..43783efd8b41713f81367b9ffb2ed9fe0c7f34b5 Binary files /dev/null and "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/images/result (5).png" differ diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/loss.png" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/loss.png" new file mode 100644 index 0000000000000000000000000000000000000000..fc99fc51ab7cfc06c93b380fa19d1bb99e3cdc11 Binary files /dev/null and "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/loss.png" differ diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/model.py" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/model.py" new file mode 100644 index 0000000000000000000000000000000000000000..a0c09309731ecd81207bbaab57912427712625c3 --- /dev/null +++ "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/model.py" @@ -0,0 +1,56 @@ +from mindspore import nn +from cells import Reshape + +class Generator(nn.Cell): + """定义生成器结构""" + def __init__(self, latent_size, auto_prefix=True): + super(Generator, self).__init__(auto_prefix=auto_prefix) + self.network = nn.SequentialCell() + """ + self.network.append(nn.Dense(latent_size, 512 * 12 * 12, has_bias=False)) + self.network.append(Reshape((-1, 512, 12, 12)))""" + self.network.append(nn.Conv2dTranspose(100, 512, 4, 1, padding=0)) + self.network.append(nn.BatchNorm2d(512)) + self.network.append(nn.ReLU()) + + self.network.append(nn.Conv2dTranspose(512, 256, 4, 2,pad_mode='pad',padding=1)) + self.network.append(nn.BatchNorm2d(256)) + self.network.append(nn.ReLU()) + + self.network.append(nn.Conv2dTranspose(256, 128, 4, 2,pad_mode='pad',padding=1)) + self.network.append(nn.BatchNorm2d(128)) + self.network.append(nn.ReLU()) + + self.network.append(nn.Conv2dTranspose(128, 3, 4, 2, pad_mode='pad',padding=1)) + self.network.append(nn.Tanh()) + + def construct(self, x): + return self.network(x) + + + + +class Discriminator(nn.Cell): + '''定义判别器结构''' + def __init__(self, auto_prefix=True): + super().__init__(auto_prefix=auto_prefix) + + self.network = nn.SequentialCell() + + self.network.append(nn.Conv2d(3, 128, 4, 2, pad_mode='pad', padding=1)) + self.network.append(nn.BatchNorm2d(128)) + self.network.append(nn.LeakyReLU(0.2)) + + self.network.append(nn.Conv2d(128, 256, 4, 2, pad_mode='pad',padding=1)) + self.network.append(nn.BatchNorm2d(256)) + self.network.append(nn.LeakyReLU(0.2)) + + self.network.append(nn.Conv2d(256, 512, 4, 2, pad_mode='pad',padding=1)) + self.network.append(nn.BatchNorm2d(512)) + self.network.append(nn.LeakyReLU(0.2)) + + self.network.append(nn.Conv2d(512, 1, 4, 1)) + self.network.append(nn.Sigmoid()) + + def construct(self, x): + return self.network(x) \ No newline at end of file diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/preprocess.py" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/preprocess.py" new file mode 100644 index 0000000000000000000000000000000000000000..cb3f3ad05aaeaf7f272261b256211df7f5da4a7c --- /dev/null +++ "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/preprocess.py" @@ -0,0 +1,105 @@ +import cv2 +import glob +import math +import sys + +def rotateCoords(coords, center, angleRadians): + # Positive y is down so reverse the angle, too. + angleRadians = -angleRadians + xs, ys = coords[::2], coords[1::2] + newCoords = [] + n = min(len(xs), len(ys)) + i = 0 + centerX = center[0] + centerY = center[1] + cosAngle = math.cos(angleRadians) + sinAngle = math.sin(angleRadians) + while i < n: + xOffset = xs[i] - centerX + yOffset = ys[i] - centerY + newX = xOffset * cosAngle - yOffset * sinAngle + centerX + newY = xOffset * sinAngle + yOffset * cosAngle + centerY + newCoords += [newX, newY] + i += 1 + return newCoords + +def preprocessCatFace(coords, image): + + leftEyeX, leftEyeY = coords[0], coords[1] + rightEyeX, rightEyeY = coords[2], coords[3] + mouthX = coords[4] + if leftEyeX > rightEyeX and leftEyeY < rightEyeY and \ + mouthX > rightEyeX: + # The "right eye" is in the second quadrant of the face, + # while the "left eye" is in the fourth quadrant (from the + # viewer's perspective.) Swap the eyes' labels in order to + # simplify the rotation logic. + leftEyeX, rightEyeX = rightEyeX, leftEyeX + leftEyeY, rightEyeY = rightEyeY, leftEyeY + + eyesCenter = (0.5 * (leftEyeX + rightEyeX), + 0.5 * (leftEyeY + rightEyeY)) + + eyesDeltaX = rightEyeX - leftEyeX + eyesDeltaY = rightEyeY - leftEyeY + eyesAngleRadians = math.atan2(eyesDeltaY, eyesDeltaX) + eyesAngleDegrees = eyesAngleRadians * 180.0 / math.pi + + # Straighten the image and fill in gray for blank borders. + rotation = cv2.getRotationMatrix2D( + eyesCenter, eyesAngleDegrees, 1.0) + imageSize = image.shape[1::-1] + straight = cv2.warpAffine(image, rotation, imageSize, + borderValue=(128, 128, 128)) + + # Straighten the coordinates of the features. + newCoords = rotateCoords( + coords, eyesCenter, eyesAngleRadians) + + # Make the face as wide as the space between the ear bases. + w = abs(newCoords[16] - newCoords[6]) + # Make the face square. + h = w + # Put the center point between the eyes at (0.5, 0.4) in + # proportion to the entire face. + minX = eyesCenter[0] - w/2 + if minX < 0: + w += minX + minX = 0 + minY = eyesCenter[1] - h*2/5 + if minY < 0: + h += minY + minY = 0 + + # Crop the face. + crop = straight[int(minY):int(minY+h), int(minX):int(minX+w)] + # Return the crop. + return crop + +def describePositive(): + for imagePath in glob.glob('CAT*/*.jpg'): + # Open the '.cat' annotation file associated with this + # image. + input = open('%s.cat' % imagePath, 'r') + # Read the coordinates of the cat features from the + # file. Discard the first number, which is the number + # of features. + coords = [int(i) for i in input.readline().split()[1:]] + # Read the image. + image = cv2.imread(imagePath) + # Straighten and crop the cat face. + crop = preprocessCatFace(coords, image) + if crop is None: + print(f'Failed to preprocess image at {imagePath}.', file=sys.stderr) + continue + # Save the crop to folders based on size + h, w, colors = crop.shape + if min(h,w) >= 64: + Path1 = imagePath.replace("cat_dataset","cats_bigger_than_64x64") + cv2.imwrite(Path1, crop) + if min(h,w) >= 128: + Path2 = imagePath.replace("cat_dataset","cats_bigger_than_128x128") + cv2.imwrite(Path2, crop)\ + +if __name__ == '__main__': + describePositive() \ No newline at end of file diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/test.py" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/test.py" new file mode 100644 index 0000000000000000000000000000000000000000..1d9b1ece41e22ad8a3ca5b892b34a2611aadb117 --- /dev/null +++ "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/test.py" @@ -0,0 +1,52 @@ +import mindspore.dataset as ds +import mindspore as ms +import mindspore.dataset.vision.c_transforms as CV +from mindspore.common import dtype as mstype +import os +import numpy as np +from cells import SigmoidCrossEntropyWithLogits, GenWithLossCell, DisWithLossCell, TrainOneStepCell, Reshape +import matplotlib.pyplot as plt +from mindspore import context, Tensor, nn, load_checkpoint +from model import Generator, Discriminator + +batch_size = 16 +epochs = 300 +input_dim = 100 +lr = 0.0002 +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + +def save_imgs(gen_imgs): + for i in range(gen_imgs.shape[0]): + plt.subplot(4, 4, i + 1) + # print(gen_imgs.shape) + img = gen_imgs[i, :, :, :] + img = np.transpose(img, (1, 2, 0)) + + img = ((img * 127.5) + 127.5).astype(np.uint8) + plt.imshow(img) + + plt.axis("off") + plt.savefig("./image/{}.png".format("result")) + + + + +netG = Generator(input_dim) +netD = Discriminator() +loss = SigmoidCrossEntropyWithLogits() +netG_with_loss = GenWithLossCell(netG, netD, loss) +netD_with_loss = DisWithLossCell(netG, netD, loss) +optimizerG = nn.Adam(netG.trainable_params(), lr, beta1=0.5, beta2=0.999) +optimizerD = nn.Adam(netD.trainable_params(), lr, beta1=0.5, beta2=0.999) +net_train = TrainOneStepCell(netG_with_loss, netD_with_loss, optimizerG, + optimizerD) + +load_checkpoint('./out/G_300.ckpt', netG) + +netG.set_train() +netD.set_train() +test_latent_code = Tensor(np.random.normal(size=(16, input_dim)), + dtype=mstype.float32) +gen_imgs = netG(test_latent_code) +save_imgs(gen_imgs.asnumpy()) diff --git "a/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/train.py" "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/train.py" new file mode 100644 index 0000000000000000000000000000000000000000..fcf5c41b92fea390dd634d62afc5710ff4b82d13 --- /dev/null +++ "b/code/2021_autumn/\351\202\242\350\201\252\351\242\226-\345\237\272\344\272\216DCGAN\347\232\204\347\214\253\345\222\252\345\233\276\345\203\217\347\224\237\346\210\220\347\256\227\346\263\225/train.py" @@ -0,0 +1,129 @@ +from mindspore import nn +import mindspore.dataset as ds +import mindspore as ms +import mindspore.dataset.vision.c_transforms as CV +from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore import context +import os +import numpy as np +from cells import SigmoidCrossEntropyWithLogits, GenWithLossCell, DisWithLossCell, TrainOneStepCell, Reshape +import matplotlib.pyplot as plt +import time +from mindspore.train.serialization import save_checkpoint + +from model import Generator, Discriminator +from mindspore.dataset.vision import c_transforms as vision + +batch_size = 32 +epochs = 300 +input_dim = 100 +lr = 0.0002 +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + +def save_imgs(gen_imgs, idx): + for i in range(gen_imgs.shape[0]): + plt.subplot(4, 4, i + 1) + # print(gen_imgs.shape) + gen_imgs[i] = gen_imgs[i] * 127.5 + 127.5 + + img = np.transpose(gen_imgs[i], (1, 2, 0)) + + img = img.astype(int) + plt.imshow(img) + + plt.axis("off") + plt.savefig("./image/{}.png".format(idx)) + + +def create_dataset(data_path, latent_size): + transform_img = [ + vision.Decode(), + vision.Resize((96, 96)), + vision.CenterCrop((96, 96)), + vision.HWC2CHW() + ] + + dataset = ms.dataset.ImageFolderDataset(data_path) + + dataset = dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=4) + + dataset = dataset.map( + input_columns="image", + operations=lambda x: ((x - 127.5) / 127.5).astype("float32"), + + ) + + dataset = dataset.map( + input_columns="image", + operations=lambda x: ( + x, + np.random.normal(size=(latent_size,1,1)).astype("float32"), + ), + output_columns=["image", "latent_code"], + column_order=["image", "latent_code"], + num_parallel_workers=4, + ) + + dataset = dataset.shuffle(buffer_size=10000) # 10000 as in imageNet train script + dataset = dataset.batch(batch_size, drop_remainder=True) + + return dataset + + + +netG = Generator(input_dim) +netD = Discriminator() +loss = SigmoidCrossEntropyWithLogits() +netG_with_loss = GenWithLossCell(netG, netD, loss) +netD_with_loss = DisWithLossCell(netG, netD, loss) +optimizerG = nn.Adam(netG.trainable_params(), lr, beta1=0.5, beta2=0.999) +optimizerD = nn.Adam(netD.trainable_params(), lr, beta1=0.5, beta2=0.999) +net_train = TrainOneStepCell(netG_with_loss, netD_with_loss, optimizerG, + optimizerD) + +ds = create_dataset(os.path.join('data/cat'), + latent_size=input_dim, + ) + +# dataset_helper = DatasetHelper(ds, epoch_num=epochs, dataset_sink_mode=False) +# net_train = connect_network_with_dataset(net_train, dataset_helper ) +netG.set_train() +netD.set_train() +test_latent_code = Tensor(np.random.normal(size=(16, input_dim,1, 1)).astype("float32")) +Loss_list_g = np.zeros([epochs]) +Loss_list_d = np.zeros([epochs]) +for epoch in range(epochs): + start = time.time() + + for data in ds: + imgs = data[0] + latent_code = data[1] + fake_data = netG(latent_code) + fake_out = netD(fake_data) + d_out, g_out, loss_D, loss_G = net_train(imgs, latent_code) + lss = loss_G.asnumpy() + Loss_list_g[epoch] = lss + Loss_list_d[epoch] = loss_D.asnumpy() + if ((epoch + 1) % 20 == 0): + save_checkpoint(netG, os.path.join('./out', f"G_{epoch+1}.ckpt")) + save_checkpoint(netD, os.path.join('./out', f"D_{epoch+1}.ckpt")) + + t = time.time() - start + + print("time of epoch {} is {:.2f}s".format(epoch, t)) + gen_imgs = netG(test_latent_code) + + save_imgs(gen_imgs.asnumpy(), epoch) + + +plt.figure(figsize=(10, 5)) +plt.title("Generator and Discriminator Loss During Training") +plt.plot(Loss_list_g, label="G") +plt.plot(Loss_list_d, label="D") +plt.xlabel("iterations") +plt.ylabel("Loss") +plt.legend() +plt.savefig("./out/{}.png".format(20))