From ea1cac79da4a0236b82c7830ac7a44dd12a006e1 Mon Sep 17 00:00:00 2001 From: linjingming0103 <1023339599@qq.com> Date: Sun, 21 Nov 2021 14:51:28 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BD=9C=E4=B8=9A=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Mycode" | 219 ++++++++++++++++++ .../data" | 36 +++ .../data_test" | 2 + .../lable" | 36 +++ .../lable_test" | 2 + 5 files changed, 295 insertions(+) create mode 100644 "code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/Mycode" create mode 100644 "code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data" create mode 100644 "code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data_test" create mode 100644 "code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable" create mode 100644 "code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable_test" diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/Mycode" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/Mycode" new file mode 100644 index 0000000..e73b825 --- /dev/null +++ "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/Mycode" @@ -0,0 +1,219 @@ +import numpy as np +import mindspore as ms +from mindspore import ops, nn +import mindspore.dataset as ds +import mindspore.common.initializer as init #库 + + +"""各函数定义""" +def get_data(): + ds_train = create_dataset() + file = open("data.txt", "r") + file2 = open("lable.txt", "r") + datas=[] + lables=[] + data = file.readlines() + for fields in data: + fields = fields.strip() + fields = fields.split(" ") + if fields!="": + fields = list(map(float,fields)) + datas.append(fields) + lable = file2.readlines() + for fields in lable: + fields = fields.strip() + fields = fields.split(" ") + if fields!="": + fields = list(map(float,fields)) + lables.append(fields) + + datas = np.array(datas,dtype=np.float32) + lables=np.array(lables,dtype=np.float32) + + for i in range(len(lables)): + yield datas[i].astype(np.float32), lables[i].astype(np.float32) + +def get_data2(): + """读取验证数据""" + file = open("data_test.txt", "r") + file2 = open("lable_test.txt", "r") + datas=[] + lables=[] + data = file.readlines() + for fields in data: + fields = fields.strip() + fields = fields.split(" ") + if fields!="": + fields = list(map(float,fields)) + datas.append(fields) + lable = file2.readlines() + for fields in lable: + fields = fields.strip() + fields = fields.split(" ") + if fields!="": + fields = list(map(float,fields)) + lables.append(fields) + + datas = np.array(datas,dtype=np.float32) + lables=np.array(lables,dtype=np.float32) + + for i in range(len(lables)): + yield datas[i].astype(np.float32), lables[i].astype(np.float32) + +def create_dataset(batch_size=12, repeat_size=1): + """定义数据集""" + input_data = ds.GeneratorDataset(list(get_data()), column_names=['data', 'label']) + input_data = input_data.batch(batch_size) + input_data = input_data.repeat(repeat_size) + return input_data + +def create_dataset_2(batch_size=1, repeat_size=1): + """验证数据集""" + input_data = ds.GeneratorDataset(list(get_data2()), column_names=['data', 'label']) + input_data = input_data.batch(batch_size) + input_data = input_data.repeat(repeat_size) + return input_data + +class MyNet(nn.Cell): + """定义网络""" + def __init__(self, input_size=15): + super(MyNet, self).__init__() + self.fc1 = nn.Dense(input_size, 120, weight_init=init.Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=init.Normal(0.02)) + self.fc3 = nn.Dense(84, 20, weight_init=init.Normal(0.02)) + self.fc4 = nn.Dense(20, 1, weight_init=init.Normal(0.02)) + self.relu = nn.ReLU() + + def construct(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.relu(self.fc3(x)) + x = self.fc4(x) + return x + +class MyL1Loss(nn.LossBase): + """定义损失""" + def __init__(self, reduction="mean"): + super(MyL1Loss, self).__init__(reduction) + self.abs = ops.Abs() + self.loss=nn.MSELoss() + + def construct(self, base, target): + x = self.abs(base - target) + return self.get_loss(x) + + +class MyMomentum(nn.Optimizer): + """Momentum优化器""" + def __init__(self, params, learning_rate, momentum=0.9, use_nesterov=False): + super(MyMomentum, self).__init__(learning_rate, params) + self.moments = self.parameters.clone(prefix="moments", init="zeros") + self.momentum = momentum + self.opt = ops.ApplyMomentum(use_nesterov=use_nesterov) + + def construct(self, gradients): + params = self.parameters + success = None + for param, mom, grad in zip(params, self.moments, gradients): + success = self.opt(param, mom, self.learning_rate, grad, self.momentum) + return success + +class MyWithLossCell(nn.Cell): + """定义损失网络""" + def __init__(self, backbone, loss_fn): + super(MyWithLossCell, self).__init__(auto_prefix=False) + self.backbone = backbone + self.loss_fn = loss_fn + + def construct(self, data, label): + out = self.backbone(data) + return self.loss_fn(out, label) + + def backbone_network(self): + return self.backbone + +class MyTrainStep(nn.TrainOneStepCell): + """定义训练流程""" + def __init__(self, network, optimizer): + """参数初始化""" + super(MyTrainStep, self).__init__(network, optimizer) + self.grad = ops.GradOperation(get_by_list=True) + + def construct(self, data, label): + """构建训练过程""" + weights = self.weights + loss = self.network(data, label) + grads = self.grad(self.network, weights)(data, label) + return loss, self.optimizer(grads) + + +"""搭建网络 训练 """ +ds_train = create_dataset() +# 网络 +net = MyNet() +# 损失函数 +loss_func = MyL1Loss() +# 优化器 +opt = MyMomentum(net.trainable_params(), 0.1) +# 构建损失网络 +net_with_criterion = MyWithLossCell(net, loss_func) +# 构建训练网络 +train_net = MyTrainStep(net_with_criterion, opt) +# 执行训练,每个epoch打印一次损失值 +epochs = 5000 +arr=[] +for epoch in range(epochs): + for train_x, train_y in ds_train: + train_net(train_x, train_y) + loss_val = net_with_criterion(train_x, train_y) + print(loss_val) + +""" 模型评估部分""" +class MyMAE(nn.Metric): + """定义metric""" + def __init__(self): + super(MyMAE, self).__init__() + self.clear() + + def clear(self): + self.abs_error_sum = 0 + self.samples_num = 0 + + def update(self, *inputs): + y_pred = inputs[0].asnumpy() + y = inputs[1].asnumpy() + error_abs = np.abs(y.reshape(y_pred.shape) - y_pred) + self.abs_error_sum += error_abs.sum() + self.samples_num += y.shape[0] + + def eval(self): + return self.abs_error_sum / self.samples_num + + +class MyWithEvalCell(nn.Cell): + """定义验证流程""" + def __init__(self, network): + super(MyWithEvalCell, self).__init__(auto_prefix=False) + self.network = network + + def construct(self, data, label): + outputs = self.network(data) + return outputs, label + +""" 计算MAE""" +# 获取验证数据 +ds_eval = create_dataset_2() +# 定义评估网络 +eval_net = MyWithEvalCell(net) +eval_net.set_train(False) +# 定义评估指标 +mae = MyMAE() +# 执行推理过程 +for eval_x, eval_y in ds_eval: + output, eval_y = eval_net(eval_x, eval_y) + mae.update(output, eval_y) + +mae_result = mae.eval() +print("mae: ", mae_result) + +ms.save_checkpoint(net, "./MyNet.ckpt") \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data" new file mode 100644 index 0000000..fa77863 --- /dev/null +++ "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data" @@ -0,0 +1,36 @@ +-266.875 4931.05 276920.9 2087.5 274682.15 2758.925 24.95 461230.7692 15314.10256 0.20640852 0.031791654 -0.075472142 61099.4 274318.95 2166.7 +-266.875 4737.225 279008.4 2087.5 277441.075 2758.925 24.975 497461.5385 36230.76923 0.194195932 0.103107395 0.071315742 67399.2 277004.825 2302.5 +-266.875 4543.4 281095.9 2087.5 280200 2758.925 25 590769.2308 93307.69231 0.181983344 -0.115118577 -0.218225973 59640.3 279690.7 2393.9 +2137.475 3967.45 284780.2 3684.3 284475 4275 25.025 701153.8462 110384.6154 0.159411232 0.10160244 0.216721017 65699.9 286679.125 2282.7 +2137.475 3391.5 288464.5 3684.3 288750 4275 25.05 893636.3636 192482.5175 0.136839119 0.018561672 -0.083040768 66919.4 293667.55 2362.7 +2137.475 2815.55 292148.8 3684.3 293025 4275 25.075 718583.3333 -175053.0303 0.114267007 0.079108599 0.060546927 72213.3 300655.975 2396.6 +2137.475 2239.6 295833.1 3684.3 297300 4275 25.1 572181.8182 -146401.5152 0.091694895 -0.121603638 -0.200712237 63431.9 307644.4 2134.3 +1430.725 2157.35 309126.9 13293.8 308686.175 11386.175 25.65 578750 6568.181818 0.114394105 0.120735781 0.242339419 71090.4 320543.55 2218 +1430.725 2075.1 322420.7 13293.8 320072.35 11386.175 26.2 572750 -6000 0.137093315 0.041244669 -0.079491112 74022.5 333442.7 2265.2 +1430.725 1910.6 349008.3 13293.8 342844.7 11386.175 27.3 742083.3333 65750 0.182491736 -0.119045468 -0.210404176 71168 359241 2783.7 +-24.275 1799.4 347493.775 -1514.525 345033.525 2188.825 27.125 738181.8182 -3901.515152 0.182863311 0.106493087 0.225538554 78746.9 361516.75 2591 +-24.275 1688.2 345979.25 -1514.525 347222.35 2188.825 26.95 741875 3693.181818 0.183234887 0.033268611 -0.073224476 81366.7 363792.5 2696.7 +-24.275 1577 344464.725 -1514.525 349411.175 2188.825 26.775 783800 41925 0.183606462 0.091614874 0.058346263 88821.1 366068.25 2787.1 +-24.275 1465.8 342950.2 -1514.525 351600 2188.825 26.6 770909.0909 -12890.90909 0.183978037 0.113505687 0.021890813 98902.8 368344 2797.7 +2657.825 1331.225 360144.375 17194.175 362328.2 10728.2 26.9 823833.3333 52924.24242 0.163929228 0.115379949 0.001874262 110314.2 381864.6 2928.2 +2657.825 1196.65 377338.55 17194.175 373056.4 10728.2 27.2 837000 13166.66667 0.143880419 0.044502884 -0.070877065 115223.5 395385.2 3095 +2657.825 1062.075 394532.725 17194.175 383784.6 10728.2 27.5 841500 4500 0.123831609 0.102549393 0.05804651 127039.6 408905.8 3332 +2657.825 927.5 411726.9 17194.175 394512.8 10728.2 27.8 775090.9091 -66409.09091 0.1037828 -0.158227041 -0.260776434 106938.5 422426.4 3159.5 +965.175 883.325 414901.65 3174.75 395242.65 729.85 28.025 762454.5455 -12636.36364 0.103080889 0.110520533 0.268747574 118757.4 424165.6 3294.6 +965.175 839.15 418076.4 3174.75 395972.5 729.85 28.25 635461.5385 -126993.007 0.102378978 0.043446556 -0.067073977 123917 425904.8 3005.7 +965.175 794.975 421251.15 3174.75 396702.35 729.85 28.475 634538.4615 -923.0769231 0.101677068 0.108567832 0.065121276 137370.4 427644 3091.4 +965.175 750.8 424425.9 3174.75 397432.2 729.85 28.7 623454.5455 -11083.91608 0.100975157 -0.16035405 -0.268921882 115342.5 429383.2 3224.6 +-894.95 706.65 421222.8 -3203.1 394922.125 -2510.075 29.075 609272.7273 -14181.81818 0.097064859 0.107518044 0.267872094 127743.9 426022.325 3551.6 +-894.95 662.5 418019.7 -3203.1 392412.05 -2510.075 29.45 559071.4286 -50201.2987 0.09315456 0.047029251 -0.060488793 133751.6 422661.45 3268 +-894.95 618.35 414816.6 -3203.1 389901.975 -2510.075 29.825 571000 11928.57143 0.089244262 0.106268635 0.059239384 147965.2 419300.575 3675.3 +-894.95 574.2 411613.5 -3203.1 387391.9 -2510.075 30.2 554166.6667 -16833.33333 0.085333963 -0.161744113 -0.268012748 124032.7 415939.7 3663.7 +-2178.875 564.1 407963.65 -3649.85 384207.475 -3184.425 30.625 531615.3846 -22551.28205 0.081595911 0.108395609 0.270139721 137477.3 410586.5 3490 +-2178.875 554 404313.8 -3649.85 381023.05 -3184.425 31.05 482307.6923 -49307.69231 0.077857859 0.04343117 -0.064964438 143448.1 405233.3 3425.1 +-2178.875 543.9 400663.95 -3649.85 377838.625 -3184.425 31.475 502333.3333 20025.64103 0.074119807 0.107130035 0.063698865 158815.7 399880.1 3446.1 +-2178.875 533.8 397014.1 -3649.85 374654.2 -3184.425 31.9 490909.0909 -11424.24242 0.070381754 -0.163282975 -0.27041301 132883.8 394526.9 3546.9 +1284.125 523.7 393364.25 -3649.85 371469.775 -3184.425 32.25 421100 -69809.09091 0.066643702 0.108302141 0.271585116 147275.4 392636.7 3205.8 +1284.125 513.6 389714.4 -3649.85 368285.35 -3184.425 32.6 403076.9231 -18023.07692 0.06290565 0.042121766 -0.066180374 153478.9 390746.5 3405.4 +1284.125 503.5 386064.55 -3649.85 365100.925 -3184.425 32.95 370833.3333 -32243.58974 0.059167598 0.106555364 0.064433597 169832.9 388856.3 3527.2 +1284.125 493.4 382414.7 -3649.85 361916.5 -3184.425 33.3 377500 6666.666667 0.055429545 -0.047531426 -0.15408679 161760.5 386966.1 3498.3 +387.575 323.2075 395533.63 13118.93 369353.3475 7436.8475 33.675 390000 12500 0.058261228 0.112682021 0.160213447 179988 394960.715 3674.5 +387.575 289.265 394109.86 -1423.77 367612.945 -1740.4025 34.05 477692.3077 87692.30769 0.05536512 0.047887637 -0.064794384 188607.2 393641.83 3348.7 \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data_test" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data_test" new file mode 100644 index 0000000..09c06a3 --- /dev/null +++ "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/data_test" @@ -0,0 +1,2 @@ +387.575 255.3225 392686.09 -1423.77 365872.5425 -1740.4025 34.425 653250 175557.6923 0.052469012 0.090524646 0.042637009 205680.8 392322.945 3879.5 +387.575 221.38 391262.32 -1423.77 364132.14 -1740.4025 34.8 618958.3333 -34291.66667 0.049572903 -0.15811539 -0.248640036 173159.5 391004.06 3862.7 \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable" new file mode 100644 index 0000000..f9d6050 --- /dev/null +++ "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable" @@ -0,0 +1,36 @@ +461.2307692 +497.4615385 +590.7692308 +701.1538462 +893.6363636 +718.5833333 +572.1818182 +578.75 +572.75 +742.0833333 +738.1818182 +741.875 +783.8 +770.9090909 +823.8333333 +837 +841.5 +775.0909091 +762.4545455 +635.4615385 +634.5384615 +623.4545455 +609.2727273 +559.0714286 +571 +554.1666667 +531.6153846 +482.3076923 +502.3333333 +490.9090909 +421.1 +403.0769231 +370.8333333 +377.5 +390 +477.6923077 \ No newline at end of file diff --git "a/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable_test" "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable_test" new file mode 100644 index 0000000..ca6d05b --- /dev/null +++ "b/code/2021_autumn/\346\236\227\351\235\231\351\270\243-\346\267\261\345\272\246\345\255\246\344\271\240\345\234\250\347\205\244\347\202\255\345\273\272\346\250\241\347\232\204\345\272\224\347\224\250/lable_test" @@ -0,0 +1,2 @@ +653.25 +618.9583333 \ No newline at end of file -- Gitee