diff --git "a/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/Hexiao.py" "b/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/Hexiao.py" new file mode 100644 index 0000000000000000000000000000000000000000..f0a11caaf31f9232804de2a58676e7c858f5d1bb --- /dev/null +++ "b/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/Hexiao.py" @@ -0,0 +1,166 @@ +import pandas as pd +import numpy as np +import mindspore.nn as nn +from mindspore import ops +from mindspore.dataset import ds +from mindspore.nn import Accuracy +from mindspore.train.callback import LossMonitor +from mindspore import Model +from matplotlib import pyplot as plt + +# 定义超参数 +epochs = 5000 +batch_size = 16 +feature = ['Tmaincontact', 'Tarccontact', 'Rmaincontact', 'Rarccontact', 'maxsplitRes'] +trainfile = 'train_data.csv' +testfile = 'test_data.csv' + + +# 读取文件 +def get_data(inputfile): + data = pd.read_csv(inputfile, header=0, delimiter=';', index_col=None) # .csv文件分隔符为';' + # data_test = pd.read_csv('./test_data.csv', header=0, delimiter=';', index_col=None) + data_train = data.sample(frac=1) + # data_train_feat = data_train[feature] + # data_train_feat_mean = data_train_feat.mean() + # data_train_feat_std = data_train_feat.std() + a = np.array(data_train[feature].values, dtype=np.float32) + #for i in range(5): + #a[:, i] = (a[:, i] - np.mean(a[:, i]))/np.std(a[:, i]) + b = np.array(data_train['class'].values, dtype=np.int32) + q = [] + for i in range(len(a)): + q.append((a[i], np.asarray(b[i]))) + # c = np.eye(4)[b - 1] + return q + + +def create_dataset(inputfile, batch_size, repeat_size=1): + """定义数据集""" + q = list(get_data(inputfile)) + input_data = ds.GeneratorDataset(list(get_data(inputfile)), column_names=['data', 'label']) + input_data = input_data.batch(batch_size) + input_data = input_data.repeat(repeat_size) + return input_data + + +# def get_data(data_num, data_size): +# for _ in range(data_num): +# data = np.random.randn(data_size) +# p = np.array([1, 0, -3, 5]) +# label = np.polyval(p, data).sum() +# yield data.astype(np.float32), np.array([label]).astype(np.float32) +# +# def create_dataset(data_num, data_size, batch_size=32, repeat_size=1): +# """定义数据集""" +# q = list(get_data(data_num, data_size)) +# input_data = ds.GeneratorDataset(list(get_data(data_num, data_size)), column_names=['data', 'label']) +# input_data = input_data.batch(batch_size) +# input_data = input_data.repeat(repeat_size) +# return input_data + + +class BPNet(nn.Cell): + """ + BP神经网络结构 + """ + + def __init__(self, num_input=5, num_output=4): + super(BPNet, self).__init__() + self.fc1 = nn.Dense(num_input, 8) + self.fc2 = nn.Dense(8, 8) + self.fc3 = nn.Dense(8, num_output) + self.ReLU = nn.ReLU() + self.softmax = nn.Softmax() + + def construct(self, x1): + x = self.fc1(x1) + x = self.ReLU(x) + x = self.fc2(x) + x = self.ReLU(x) + x = self.fc3(x) + x = self.softmax(x) + return x + + +net = BPNet() + + +class MyWithLossCell(nn.Cell): + """定义损失网络""" + + def __init__(self, backbone, loss_fn): + super(MyWithLossCell, self).__init__(auto_prefix=False) + self.backbone = backbone + self.loss_fn = loss_fn + + def construct(self, data, label): + out = self.backbone(data) + return self.loss_fn(out, label) + + def backbone_network(self): + return self.backbone + + def acc(self, data): + out = self.backbone(data) + return out + + +class MyTrainStep(nn.TrainOneStepCell): + """定义训练流程""" + + def __init__(self, network, optimizer): + """参数初始化""" + super(MyTrainStep, self).__init__(network, optimizer) + self.grad = ops.GradOperation(get_by_list=True) + + def construct(self, data, label): + """构建训练过程""" + weights = self.weights + loss = self.network(data, label) + grads = self.grad(self.network, weights)(data, label) + return loss, self.optimizer(grads) + + +# 生成多项式分布的训练数据 + +ds_train = create_dataset(trainfile, batch_size, repeat_size=1) +ds_test = create_dataset(testfile, 16, repeat_size=1) + +# ds_train = create_dataset(64, 128, repeat_size=1) +# 定义损失函数 +net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + +# 定义优化器 +net_opt = nn.Adam(net.trainable_params(), learning_rate=0.0001) +# 构建损失网络 +net_with_criterion = MyWithLossCell(net, net_loss) +# 构建训练网络 +train_net = MyTrainStep(net_with_criterion, net_opt) +# 执行训练,每个epoch打印一次损失值 + +scores = [] +loss_ = [] +for epoch in range(epochs): + for train_x, train_y in ds_train: + train_net(train_x, train_y) + loss_val = net_with_criterion(train_x, train_y) + loss_test = 0 + for test_x, test_y in ds_test: + loss_test += net_with_criterion(test_x, test_y) + out_y = net_with_criterion.acc(test_x) + out_y = np.argmax(out_y, axis=-1) + cnt_array = np.where(out_y - test_y.asnumpy(), 0, 1) + acc = np.sum(cnt_array) / 15 + loss_.append(loss_test.asnumpy())#转换数组类型为numpy + scores.append(acc) + if epoch %100 == 0: + plt.clf() + plt.plot(scores,label="Accuracy") + plt.plot(loss_,label="Loss") + plt.legend(loc="best", fontsize=6) + plt.show() + plt.pause(0.01) + + +#参考了mindspore的教程https://mindspore.cn/tutorials/zh-CN/r1.5/intermediate/mid_low_level_api.html中的部分代码 \ No newline at end of file diff --git "a/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/test_data.csv" "b/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/test_data.csv" new file mode 100644 index 0000000000000000000000000000000000000000..b93dd4a3f51941fa94d2bbcde8ba2b9b907db781 --- /dev/null +++ "b/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/test_data.csv" @@ -0,0 +1,16 @@ +Tmaincontact;Tarccontact;Rmaincontact;Rarccontact;maxsplitRes;class;servicelife +9.8;18.6;151.9540816;422.3629032;586;0;1 +9.8;15.7;205.7653061;410.8375796;569;0;0.882352941 +9.8;16.2;200.0357143;496.3641975;726;0;0.823529412 +10.9;14.5;226.0412844;568.7103448;623;0;0.764705882 +9.7;14;233.7216495;554.3785714;607;1;0.705882353 +10.1;11.6;255.9405941;544.0517241;692;1;0.647058824 +9.4;12.7;279.4148936;604.4685039;715;1;0.588235294 +9.4;11.9;252.2659574;556.5042017;627;1;0.529411765 +9.5;10.2;262.5315789;620.1372549;764;2;0.470588235 +9.2;9.6;308.5108696;575.8958333;632;2;0.411764706 +9.5;9.8;301.6052632;526.244898;601;2;0.352941176 +8.1;9.5;308.4197531;572.0263158;644;2;0.294117647 +8.8;7.4;296.9090909;628.1013514;851;3;0.235294118 +8.6;7.6;315.9825581;632.7105263;759;3;0.176470588 +8.6;6.7;336.627907;638.4701493;759;3;0.117647059 diff --git "a/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/train_data.csv" "b/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/train_data.csv" new file mode 100644 index 0000000000000000000000000000000000000000..e237fbeb911ee51971a3e5e35f52d9d1769d6fac --- /dev/null +++ "b/code/2021_autumn/\344\275\225\346\275\207-\345\237\272\344\272\216BP\347\245\236\347\273\217\347\275\221\347\273\234\347\232\204SF6\346\226\255\350\267\257\345\231\250\347\224\265\345\257\277\345\221\275\350\257\204\344\274\260/train_data.csv" @@ -0,0 +1,65 @@ +Tmaincontact;Tarccontact;Rmaincontact;Rarccontact;maxsplitRes;class;servicelife +9.5;18.7;157.9368421;425.3475936;592;0;1 +10;18.3;153.75;427.510929;572;0;1 +9.6;18.4;150.2447917;425.1820652;523;0;1 +9.6;19;150.6041667;429.2842105;583;0;1 +9.7;16.8;205.8350515;416.7083333;498;0;0.882353 +9.6;17;203.8333333;420.5411765;490;0;0.882352941 +10;16.6;195.36;366.4879518;446;0;0.882352941 +10;16.8;196.955;412.422619;468;0;0.882352941 +9.5;16.3;208.3842105;482.3527607;626;0;0.823529412 +9.4;16.3;202.2393617;484.8496933;677;0;0.823529412 +9.6;16.2;199.8020833;484.5679012;617;0;0.823529412 +9.8;15.8;204.0510204;476.5949367;754;0;0.823529412 +10.7;14.7;226.5514019;543.3265306;606;0;0.764705882 +10.5;15.2;228.7428571;540.75;632;0;0.764705882 +10.5;15.2;225.3285714;532.8125;621;0;0.764705882 +10.4;14.8;229.9134615;534.5;594;0;0.764705882 +10;13.5;233.835;559.5111111;680;1;0.705882353 +9.8;14;237.0867347;594.6607143;642;1;0.705882353 +9.4;14;232.0585106;566.6821429;633;1;0.705882353 +9.7;13.8;238.0618557;564.6086957;645;1;0.705882353 +10;12.8;256.365;564.03125;703;1;0.647058824 +9.7;13.3;262.1443299;578.7518797;730;1;0.647058824 +9.8;12.7;262.755102;550.9212598;620;1;0.647058824 +9.6;13.4;262.8541667;569.6119403;635;1;0.647058824 +9.4;12.5;260.5425532;540.072;694;1;0.588235294 +9;13;271.0388889;572.4538462;691;1;0.588235294 +9.9;12.1;277.3585859;588.9380165;658;1;0.588235294 +9.4;12.9;280.7180851;578.9806202;704;1;0.588235294 +9.2;11.9;259.3695652;567.5756303;759;1;0.529411765 +9;12.4;251.3166667;559.3266129;695;1;0.529411765 +9.2;12.1;258.701087;573.946281;692;1;0.529411765 +9.4;11.7;263.1968085;572.3931624;690;1;0.529411765 +9.9;9.8;272.6868687;632.2091837;760;2;0.470588235 +9.4;10;271.6968085;621.26;870;2;0.470588235 +8.9;10.9;262.011236;604.6192661;722;2;0.470588235 +9.5;9.9;270.2421053;610.4242424;727;2;0.470588235 +9.2;9.7;302.5380435;570.4072165;721;2;0.411764706 +8.4;9;293.0297619;526.9;644;2;0.411764706 +9.1;9.9;299.7032967;529.4343434;590;2;0.411764706 +9.1;9.9;302.1868132;550.459596;609;2;0.411764706 +9.5;9.6;310.5210526;539.3958333;625;2;0.352941176 +9.4;9.8;310.4468085;552.1887755;623;2;0.352941176 +9.1;10.3;301.956044;527.7912621;591;2;0.352941176 +9.6;9.6;308.5885417;539.6145833;623;2;0.352941176 +8.8;8.4;288.2045455;580.1904762;644;2;0.294117647 +9;8.9;316.9833333;566.5674157;646;2;0.294117647 +8.6;8.8;289.4651163;543.3352273;590;2;0.294117647 +8.5;9.2;298.7058824;563.9293478;602;2;0.294117647 +8.4;7.1;290.1309524;606.8450704;825;3;0.235294118 +8.5;7.7;312.7823529;624.3441558;830;3;0.235294118 +8.7;7.8;299.7988506;620.2884615;854;3;0.235294118 +8.8;7.5;297.9886364;626.7;804;3;0.235294118 +8.7;5.4;313.8735632;603.5092593;956;3;0.176470588 +8.4;7.9;326.0952381;623.3987342;827;3;0.176470588 +9.6;6;334.0833333;635.7916667;910;3;0.176470588 +8.6;6.7;328.6337209;655.2537313;780;3;0.117647059 +8.4;7.1;350.797619;675.0704225;768;3;0.117647059 +8.2;7.3;370;672.2123288;816;3;0.117647059 +8.5;7;345.3941176;653.45;739;3;0.117647059 +7.4;4.6;280.1418919;679.3695652;996;3;0.058823529 +8;3.8;269.63125;682.9078947;910;3;0.058823529 +7.7;4.1;280.7662338;679.1097561;947;3;0.058823529 +8.1;3.9;280.7716049;694.9230769;940;3;0.058823529 +7.2;4;323.0347222;764.7125;1042;3;0.058823529