diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/.keep" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/.keep" new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Evaluator.py" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Evaluator.py" new file mode 100644 index 0000000000000000000000000000000000000000..7ee6bea73753997ca4633f55096d089f71827163 --- /dev/null +++ "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Evaluator.py" @@ -0,0 +1,200 @@ +import random +from surprise.model_selection import train_test_split +from collections import defaultdict +import numpy as np + +#构建数据集 +class EvalData: + def __init__(self, data, popularity): + + # 流行度 + self.popularity = popularity + + #构建85/15的train/test数据集 + self.trainSet, self.testSet = train_test_split(data, test_size=.15, random_state=1) + + #构建全数据集,进行更精确的协同滤波,和相关指标的评测 + self.fullTrainSet = data.build_full_trainset() + self.fullAntiTestSet = self.fullTrainSet.build_anti_testset() + + def GetFullTrainSet(self): + return self.fullTrainSet + + def GetFullAntiTestSet(self): + return self.fullAntiTestSet + + # 从所有电影中除去uid评分也就是看过的 + def GetAntiTestSetForUser(self, uid): + trainset = self.fullTrainSet + fill = trainset.global_mean #所有rating的均值 + anti_testset = [] + + #转换到内部uid + u = trainset.to_inner_uid(str(uid)) + user_items = set([j for (j, _) in trainset.ur[u]]) + # 不含rating过的movie + anti_testset += [(trainset.to_raw_uid(u), trainset.to_raw_iid(i), fill) for + i in trainset.all_items() if + i not in user_items] + return anti_testset + + def GetTrainSet(self): + return self.trainSet + + def GetTestSet(self): + return self.testSet + + def GetPopularity(self): + return self.popularity + + +class EvalAlgo: + + def __init__(self, algorithm, name): + #algorithm可以是多种类型的算法对象 + self.algorithm = algorithm + self.name = name + + def Measure(self, evalData): + #evalData是一个EvalData类的实例 + metrics = {} + + #print("Measuring Metrics ...") + + print("Measuring RMSE&MAE ...") + # 使用train_test_split()分出的训练集进行fit,以便评估精度 + self.algorithm.fit(evalData.GetTrainSet()) + # 先调用的父类AlgoBase的test(),然后调用自己的estimate() + # surprise生成的测试集是用的rawID,但是test等函数内部会转换为innerID进行内部计算,并将结果转换为rawID返回 + # predictions是List,单元为 surprise.Prediction(uid, iid, r_ui, est, details),类型为(str,str,float,float,dict) + predictions = self.algorithm.test(evalData.GetTestSet()) + metrics["RMSE"] = self.RMSE(predictions) + metrics["MAE"] = self.MAE(predictions) + + print("Measuring LongTail ...") + # 使用全数据集进行评测 + self.algorithm.fit(evalData.GetFullTrainSet()) + predictions2 = self.algorithm.test(evalData.GetFullAntiTestSet()) + # 评估TopN序列的长尾效应 + topNList = self.GetTopN(predictions2, n=10) + metrics["LongTail"] = self.LongTail(topNList, evalData.GetPopularity()) + + return metrics + + def GetName(self): + return self.name + + def GetAlgorithm(self): + return self.algorithm + + #计算MAE + def MAE(self, predictions): + if not predictions: + raise ValueError('Prediction list is empty.') + mae = np.mean([float(abs(true_r - est)) + for (_, _, true_r, est, _) in predictions]) + return mae + + #计算RMSE + def RMSE(self,predictions): + if not predictions: + raise ValueError('Prediction list is empty.') + mse = np.mean([float((true_r - est)**2) + for (_, _, true_r, est, _) in predictions]) + rmse = np.sqrt(mse) + return rmse + + # 获取topN序列 + def GetTopN(self, predictions, n=10): + # 使用value为list的字典形式储存 + topN = defaultdict(list) + + for uid, iid, actualRating, estimatedRating, _ in predictions: + topN[int(uid)].append((int(iid), estimatedRating)) + + #先打乱,再逆序,再截断后更新 + for userID, ratings in topN.items(): + random.shuffle(ratings) + ratings.sort(key=lambda x: x[1], reverse=True) + topN[int(userID)] = ratings[:n] + + return topN + + # 计算LongTail值 + def LongTail(self,topN, rankings): + n = 0 + total = 0 + for userID in topN.keys(): + for rating in topN[userID]: + movieID = rating[0] + rank = rankings[movieID] + total += rank + n += 1 + return total / n + + +class Evaluator: + + algorithms = [] + + def __init__(self, dataset, rankings): + ed = EvalData(dataset, rankings) + self.dataset = ed + + def AddAlgorithm(self, algorithm, name): + # algorithm可以是多种类型的算法对象 + alg = EvalAlgo(algorithm, name) + self.algorithms.append(alg) + + def Evaluate(self): + results = {} + #self.algorithms是装载了EvalAlgo对象的list + for algorithm in self.algorithms: + print("Evaluating ", algorithm.GetName(), "...") + #self.dataset是EvalData的对象 + results[algorithm.GetName()] = algorithm.Measure(self.dataset) + + print("{:<10} {:<10} {:<10} {:<10}".format("Algorithm", "RMSE", "MAE", "LongTail")) + for (name, metrics) in results.items(): + print("{:<10} {:<10.4f} {:<10.4f} {:<10.4f}".format(name, metrics["RMSE"], metrics["MAE"],metrics["LongTail"])) + + + def EvalTopNRecs(self, ml, n= 10, uidRec = 0): + print("\nFor user: ",uidRec) + + for algo in self.algorithms: + print("\nAlgorithm: ", algo.GetName()) + + #利用全数据集进行协同滤波 + trainSet = self.dataset.GetFullTrainSet() + if(algo.GetName() == "RBM"): + print("building...") + algo.GetAlgorithm().fit(trainSet,getRec = 1) + else: + algo.GetAlgorithm().fit(trainSet) + + #获取不含用户rated的数据集进行预测 + testSet = self.dataset.GetAntiTestSetForUser(uidRec) + predictions = algo.GetAlgorithm().test(testSet) + + recommendations = [] + + print ("\nRecommend:") + # 提取电影ID和预测rating + for userID, movieID, actualRating, estimatedRating, _ in predictions: + intMovieID = int(movieID) + recommendations.append((intMovieID, estimatedRating)) + + # 防止出现很多最高分,打乱让后面的5分也能被推荐 + random.shuffle(recommendations) + recommendations.sort(key=lambda x: x[1], reverse=True) + + #将id换成名称后输出 + for ratings in recommendations[:n]: + print(ml.getMovieName(ratings[0]), round(ratings[1],3)) + + + + + + \ No newline at end of file diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Main.ipynb" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Main.ipynb" new file mode 100644 index 0000000000000000000000000000000000000000..9e0d4f28dce058e62b6bb6cbc2f8de75a4fbffbb --- /dev/null +++ "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Main.ipynb" @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bff29d3d", + "metadata": {}, + "outputs": [], + "source": [ + "from MovieLens import MovieLens\n", + "from RBMAlgo import RBMAlgo\n", + "from Evaluator import Evaluator\n", + "\n", + "from surprise import NormalPredictor" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "69117c2f", + "metadata": {}, + "outputs": [], + "source": [ + "#本地用\n", + "ratingsPath = 'J:/DeepLearning/ml-latest-small/ratings.csv'\n", + "moviesPath = 'J:/DeepLearning/ml-latest-small/movies.csv'\n", + "\n", + "#ModelArts上用\n", + "#ratingsPath = '/home/ma-user/work/ml-latest-small/ratings.csv'\n", + "#moviesPath = '/home/ma-user/work/ml-latest-small/movies.csv'\n", + "\n", + "\n", + "ml = MovieLens(ratingsPath,moviesPath)\n", + "evalData = ml.loadMovieLensData()\n", + "rankings = ml.getPopularityRanks()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5d27d9f4", + "metadata": {}, + "outputs": [], + "source": [ + "# 创建evaluator,装载数据和不同算法\n", + "#print(\"initiating Evaluator....\")\n", + "evaluator = Evaluator(evalData, rankings)\n", + "\n", + "# 创建RBM实例\n", + "rbmAlgo = RBMAlgo(batchSize = 100, epochs=10,hDims=100,lr=0.002,momentum = 0.95)\n", + "evaluator.AddAlgorithm(rbmAlgo, \"RBM\")\n", + "\n", + "# 使用随机推荐系统\n", + "# 随机评估打分是用global_mean和sigma的正态分布\n", + "RandomAlgo = NormalPredictor()\n", + "evaluator.AddAlgorithm(RandomAlgo, \"Random\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fcdc24f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating RBM ...\n", + "Measuring RMSE&MAE ...\n", + "hDims: 100 lr: 0.002 momentum: 0.95\n", + "Trained epoch 0\tloss = 0.68656 \tt = 7.7\n", + "Trained epoch 1\tloss = 0.48377 \tt = 15.1\n", + "Trained epoch 2\tloss = 0.31079 \tt = 22.2\n", + "Trained epoch 3\tloss = 0.2462 \tt = 29.4\n", + "Trained epoch 4\tloss = 0.24946 \tt = 36.6\n", + "Trained epoch 5\tloss = 0.19326 \tt = 43.7\n", + "Trained epoch 6\tloss = 0.21461 \tt = 50.9\n", + "Trained epoch 7\tloss = 0.22422 \tt = 58.1\n", + "Trained epoch 8\tloss = 0.19242 \tt = 65.3\n", + "Trained epoch 9\tloss = 0.23234 \tt = 72.6\n", + "Trained Total: 72.575510837\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjqklEQVR4nO3deXxV9bnv8c+TmSGEKYyBBAhCURkjEnFAnKsHJ6xiRW3rsdaiPW1Prb2395zenr5upzPc46xVT52qVWyPVq22V0UFQQkKKCAKAUICSAgkTJnz3D/2RgMNyJCVlb3X9/167VfYa/2y97O3Zn33b621n2XujoiIRFdK2AWIiEi4FAQiIhGnIBARiTgFgYhIxCkIREQiLi3sAo5U3759vaCgIOwyREQSypIlS7a5e25b6xIuCAoKCigpKQm7DBGRhGJmGw62LtBdQ2Z2vpmtNrM1ZnZ7G+v/w8yWxm8fm1l1kPWIiMjfCmxGYGapwN3AOUA5sNjMnnf3lfvGuPt3W42/BZgQVD0iItK2IGcEk4E17l7q7g3AU8DFhxg/C3gywHpERKQNQQbBYGBjq/vl8WV/w8zygWHAawdZf6OZlZhZSWVlZbsXKiISZZ3l9NGrgLnu3tzWSnd/wN2L3L0oN7fNg94iInKUggyCCmBIq/t58WVtuQrtFhIRCUWQQbAYGGlmw8wsg9jG/vkDB5nZaKAXsDDAWkRE5CACCwJ3bwLmAK8Aq4Cn3X2Fmf3UzGa0GnoV8JQH3A972cZqfvnyR0E+hYhIQgr0C2Xu/hLw0gHL/umA+z8JsoZ9lpdXc++8tZx3/ADGD+nZEU8pIpIQOsvB4sBdMmEw3TJSeXTh+rBLERHpVCITBNlZ6Vw2MY8Xlm9m+56GsMsREek0IhMEALOL82loauHpko1fPFhEJCIiFQTH9c/m5GG9eXzRBppbdK1mERGIWBAAXFtcQPmOWuat3hp2KSIinULkguDc4/vTv0cmjy06aEdWEZFIiVwQpKemMGvyUN74uJINVXvCLkdEJHSRCwKAWZOHkmrG45oViIhEMwj698jivOMH8HRJObUNbfa5ExGJjEgGAcROJa2pbeRPyzeFXYqISKgiGwQnD+vNcf2789jCDQTc5khEpFOLbBCYGbOn5PNBRQ1LN1aHXY6ISGgiGwQAl07Mo3tmGo8t1EFjEYmuSAdB98w0Lps4WP2HRCTSIh0EANdMyaehuYXfL1b/IRGJpsgHwXH9s5kyXP2HRCS6Ih8EEOs/VFGt/kMiEk0KAuCcMbH+Q4/qoLGIRJCCgP37D63fpv5DIhItCoK4qycPJS1F/YdEJHoUBHH9emRx3gkDeGaJ+g+JSLQoCFqZPSXef2iZ+g+JSHQoCFrZ13/o0UXr1X9IRCJDQdCKmTG7uIAPK3byvvoPiUhEKAgOcOmEwXTPTONxnUoqIhGhIDhA6/5DVbvrwy5HRCRwCoI2zN7Xf6hE/YdEJPkpCNowsn82xcP78MSiMvUfEpGkpyA4iNnF+VRU1/L6R+o/JCLJTUFwEJ/1H9I3jUUkySkIDiI9NYWrJ+fz5seVrFP/IRFJYoEGgZmdb2arzWyNmd1+kDFfMbOVZrbCzH4XZD1HatbkIaSlGE9oViAiSSywIDCzVOBu4AJgDDDLzMYcMGYk8CNgqrsfD/xDUPUcjX39h54u2aj+QyKStIKcEUwG1rh7qbs3AE8BFx8w5u+Bu919B4C7d7ojs9dOyWdnXRPPL6sIuxQRkUAEGQSDgdYn4pfHl7V2HHCcmS0ws0Vmdn5bD2RmN5pZiZmVVFZWBlRu2yYP682o/tk8unCD+g+JSFIK+2BxGjASmAbMAn5jZj0PHOTuD7h7kbsX5ebmdmiBZsY1xfms2KT+QyKSnIIMggpgSKv7efFlrZUDz7t7o7uvAz4mFgydyr7+Q4+p/5CIJKEgg2AxMNLMhplZBnAV8PwBY/6b2GwAM+tLbFdRaYA1HZXumWlcPnEwLy7fzDb1HxKRJBNYELh7EzAHeAVYBTzt7ivM7KdmNiM+7BWgysxWAq8DP3D3qqBqOhbX7Os/tFj9h0QkuViiHQAtKirykpKSUJ571gOLKNu+lzdvO5PUFAulBhGRo2FmS9y9qK11YR8sTijXxvsPvab+QyKSRBQER+Cz/kML14ddiohIu1EQHIG0eP+htz7ZRmnl7rDLERFpFwqCI/RZ/6F3ysIuRUSkXSgIjlC/Hlmcf8IAnlH/IRFJEgqCozA73n/ouaXqPyQiiU9BcBTUf0hEkomC4CiYGbOL81m5eSfvlVWHXY6IyDFREBylz/sPrQ+7FBGRY6IgOErd4v2HXvpgi/oPiUhCUxAcg9nF6j8kIolPQXAMCvtlc8qIPvzunTKaW3TQWEQSk4LgGM2eEus/9OqqT8MuRUTkqCgIjtE5Y/ozoEcWjy3SRWtEJDEpCI5RWmoKV588VP2HRCRhKQjawVXx/kOPL1L/IRFJPAqCdtAvO95/aMlG9jY0hV2OiMgRURC0k2uLC9hV18RzSzeFXYqIyBFRELSTkwp6MXpANo+p/5CIJBgFQTsxM66Zsq//0I6wyxEROWwKgnZ06YTBZGem8ehCnUoqIolDQdCOumWmcfmkPF76YLP6D4lIwlAQtLNrpuTT2OzqPyQiCUNB0M4K+3XnlBF9eGLRBpqaW8IuR0TkCykIAnBtcT6baup49aOtYZciIvKFFAQBOPtLsf5Dj6v/kIgkAAVBANR/SEQSiYIgIFdNHkJ6qqkrqYh0egqCgMT6Dw1k7pJy9R8SkU5NQRCga4vz1X9IRDo9BUGAivJj/YceVf8hEenEAg0CMzvfzFab2Rozu72N9debWaWZLY3fbgiyno5mZswuzmfV5p0s2aD+QyLSOQUWBGaWCtwNXACMAWaZ2Zg2hv7e3cfHbw8GVU9YLhmv/kMi0rkFOSOYDKxx91J3bwCeAi4O8Pk6pX39h/784WYqd6n/kIh0PkEGwWCgdcOd8viyA11uZsvNbK6ZDWnrgczsRjMrMbOSysrKIGoN1Of9h3QpSxHpfMI+WPwnoMDdxwJ/BR5pa5C7P+DuRe5elJub26EFtofCft2ZWtiHJ94pU/8hEel0ggyCCqD1J/y8+LLPuHuVu+/bX/IgMCnAekI1e0oBm9V/SEQ6oSCDYDEw0syGmVkGcBXwfOsBZjaw1d0ZwKoA6wnV2V/qx8CcLB7TQWMR6WQCCwJ3bwLmAK8Q28A/7e4rzOynZjYjPuxWM1thZsuAW4Hrg6onbGmpKVw9eSjz12xjrfoPiUgnYon2RaeioiIvKSkJu4yjsnVXHVN/8RpfPTmfn8w4PuxyRCRCzGyJuxe1tS7sg8WR0i87iwtOGMgzJRup3tsQdjkiIoCCoMN9a9oI9jQ08/CC9WGXIiICKAg63JcG9uCcMf35rwXr2FnXGHY5IiIKgjDcOn0ku+qaePTt9WGXIiKiIAjDiXk5nDkql4fmr2NPva5VICLhUhCE5JazRrJjb6OuaywioVMQhGTi0F6cWtiX37xVSm1Dc9jliEiEKQhCdMv0QrbtbuDJd9WMTkTCoyAI0cnD+zB5WG/uf3MtdY2aFYhIOBQEIbt1+kg+3VnPMyUbv3iwiEgAFAQhm1rYhwlDe3LvvLU0NKlFtYh0PAVByMyMW6ePZFNNHX94rzzsckQkgg4rCMzsO2bWw2IeMrP3zOzcoIuLimmjcjlxcA73zFurC9eISIc73BnB1919J3Au0AuYDfwisKoixsy4ZXohZdv38tzSTWGXIyIRc7hBYPGfXwYec/cVrZZJOzhnTH9GD8jm7tfX0NySWK3BRSSxHW4QLDGzvxALglfMLBvQPox2FJsVjKR02x5e/GBz2OWISIQcbhB8A7gdOMnd9wLpwNcCqyqiLjhhAIX9unPXa5/QolmBiHSQww2CYmC1u1eb2TXAj4Ga4MqKppQUY86ZhXz86W7+snJL2OWISEQcbhDcC+w1s3HA94G1wKOBVRVhF40dSEGfrtzx6hoS7TKiIpKYDjcImjy2VboYuMvd7waygysrutJSU7j5zEJWbt7Jq6u2hl2OiETA4QbBLjP7EbHTRl80sxRixwkkAJdOGExery7c+donmhWISOAONwiuBOqJfZ9gC5AH/DqwqiIuPTWFb00bwbLyGt78ZFvY5YhIkjusIIhv/J8AcszsIqDO3XWMIEAzJ+UxMCeLO1/VrEBEgnW4LSa+ArwLXAF8BXjHzGYGWVjUZaalctMZIyjZsIOFpVVhlyMiSexwdw39T2LfIbjO3a8FJgP/K7iyBODKk4aQm53Jna+uCbsUEUlihxsEKe7e+hSWqiP4XTlKWempfPP04SwsraJk/fawyxGRJHW4G/OXzewVM7vezK4HXgReCq4s2efqk4fSu1sGd7ymWYGIBONwDxb/AHgAGBu/PeDuPwyyMInpmpHGDacN482PK1m6sTrsckQkCR327h13f9bdvxe//THIomR/1xYXkNMlnTtf/STsUkQkCR0yCMxsl5ntbOO2y8x2dlSRUdc9M42vTx3Gqx9t5cMKtXgSkfZ1yCBw92x379HGLdvde3RUkQLXTy0gOzONu3SsQETaWaBn/pjZ+Wa22szWmNnthxh3uZm5mRUFWU8iy+mSznWnFPDyii2s3rIr7HJEJIkEFgRmlgrcDVwAjAFmmdmYNsZlA98B3gmqlmTxjVOH0TUjlbte16xARNpPkDOCycAady919wbgKWLdSw/0L8AvgboAa0kKvbplMLs4nxeWb2Jt5e6wyxGRJBFkEAwGNra6Xx5f9hkzmwgMcfcXD/VAZnajmZWYWUllZWX7V5pA/v604WSmpXC3ZgUi0k5C+3ZwvJX1vxO70M0hufsD7l7k7kW5ubnBF9eJ9e2eydWT83lu6SbKqvaGXY6IJIEgg6ACGNLqfl582T7ZwAnAPDNbD0wBntcB4y/2zTOGk5pi3DNPswIROXZBBsFiYKSZDTOzDOAq4Pl9K929xt37unuBuxcAi4AZ7l4SYE1JoX+PLK4sGsKz75VTvkOzAhE5NoEFgbs3AXOAV4BVwNPuvsLMfmpmM4J63qi4adoIAO57Y23IlYhIoksL8sHd/SUOaE7n7v90kLHTgqwl2Qzu2YXLJ+bx9OJy5pw5kgE5WWGXJCIJSq2kE9jN0wppduf+NzUrEJGjpyBIYEP7dOWS8YP53TtlVO6qD7scEUlQCoIE9+0zR9DY3MKDb5WGXYqIJCgFQYIbntudi8YO4rFFG9i+pyHsckQkASkIksCc6YXsbWjm4fnrwi5FRBKQgiAJHNc/mwtOGMAjb6+nprYx7HJEJMEoCJLEnOmF7Kpv4rcL1oddiogkGAVBkjh+UA5nf6kfDy9Yx646zQpE5PApCJLILdNHUlPbyKMLN4RdiogkEAVBEhk3pCenH5fLQ/PXsbehKexyRCRBKAiSzHfOKmT7ngaeWFQWdikikiAUBElmUn5vThnRh/vfLKWusTnsckQkASgIktAt00eybXc9T72rWYGIfDEFQRKaMrw3JxX04r43Sqlv0qxARA5NQZCEzIxbpo9ky8465i4pD7scEenkFARJ6rSRfRk3pCf3zltLY3NL2OWISCemIEhSZsat0wsp31HLH9+v+OJfEJHIUhAksemj+3H8oB7c8/oamjQrEJGDUBAksdixgkLWV+3lT8s3hV2OiHRSCoIkd+6YAYzqn81dr62hucXDLkdEOiEFQZJLSTHmTC9kbeUe/vzh5rDLEZFOSEEQAV8+cSDDc7tx12traNGsQEQOoCCIgNQUY86ZhXy0ZRd/XfVp2OWISCejIIiIGeMGkd+nK3e+9gnumhWIyOcUBBGRlprCzdNG8GHFTuatrgy7HBHpRBQEEXLphDwG9+zCHZoViEgrCoIIyUhL4aZpI3i/rJoFa6rCLkdEOgkFQcRcMSmP/j0yuePVT8IuRUQ6CQVBxGSlp3LTGSN4d/12FpVqViAiCoJImjV5KH27Z/L9p5exbGN12OWISMgUBBGUlZ7Kw9cXAXDFfQt5bOF6HTwWibBAg8DMzjez1Wa2xsxub2P9TWb2gZktNbP5ZjYmyHrkc2PzevLCLacytbAP/+u5Fdz61FL21DeFXZaIhCCwIDCzVOBu4AJgDDCrjQ3979z9RHcfD/wK+Peg6pG/1atbBg9ddxI/OG8ULy7fxIy75vPxp7vCLktEOliQM4LJwBp3L3X3BuAp4OLWA9x9Z6u73QDtn+hgKSnGt88s5PEbTqamtomL71rAH97T5S1FoiTIIBgMbGx1vzy+bD9m9m0zW0tsRnBrWw9kZjeaWYmZlVRW6luxQThlRF9euvVUTszL4XtPL+NHf/iAukZd+F4kCkI/WOzud7v7COCHwI8PMuYBdy9y96Lc3NyOLTBC+vXI4nc3nMxNZ4zgyXfLuPzetymr2ht2WSISsCCDoAIY0up+XnzZwTwFXBJgPXIY0lJTuP2C0Tx4bREbt+/lwjvf4i8rtoRdlogEKMggWAyMNLNhZpYBXAU833qAmY1sdfdCQF937STOHtOfF289jYI+3bjxsSX8/KVVNOq6xyJJKbAgcPcmYA7wCrAKeNrdV5jZT81sRnzYHDNbYWZLge8B1wVVjxy5Ib278sxNxVwzZSj3v1nK1b9ZxKc768IuS0TamSXaF4mKioq8pKQk7DIi57mlFdz+7Ad0y0zlP6+awNTCvmGXJCJHwMyWuHtRW+tCP1gsieHi8YN5fs5UenbNYPZD73Dnq5/ospciSUJBIIdtZP9snvv2VGaMG8S//fVjvv7IYnbsaQi7LBE5RgoCOSLdMtP4jyvH87NLTuDtNVVceMdbvF+2I+yyROQYKAjkiJkZ10zJ59lvnUJKivGV+xfy2wXr1LhOJEEpCOSonZiXw4u3nMYZx+Xykz+tZM7v3mdXXWPYZYnIEVIQyDHJ6ZrOA7OLuP2C0by8YgsX37WAj7bs/OJfFJFOQ0EgxywlxbjpjBH87oaT2VXfxCV3L2DuEjWuE0kUCgJpNycP78OLt57KhCG9+MdnlnH7s8vVuE4kASgIpF31y87isW9M5ttnjuCpxRu57J632VC1J+yyROQQFATS7tJSU/jBeaN5+PoiKqprueiO+bz8oRrXiXRWCgIJzPTR/Xnx1lMZntuNmx5fws9eWKnGdSKdkIJAApXXqytP31TMdcX5PDh/HbMeWMSWGjWuE+lMFAQSuMy0VP73xSdw56wJrNq8ky/f8RZvfaIrzYl0FmlhFyDR8XfjBvGlgT24+YklXPvwu3znrJHcMn0kqSl2TI/r7tQ3tVDb0ExtY+xWF7/VNrR8vqzV+tqGVmMam3GH2cX5jM3r2T4vViSBqA21dLi9DU38+I8f8of3KzhtZF/+btygv9lw18U31q037LUNzdQ17duIt3y2Qa89ylNUM9NS6JKRSpf0VHbXN7G3oZmbzhjOrWeNJDMttZ1ftUi4DtWGWjMC6XBdM9L4t6+M46Rhvfnn51fw1ifb9lufkZZCl/TYBrpLRipZ6al0SU8hKz2VnC7pZMU33vuv/3zMfsviY7P2+3cKWWmppLSaidTUNvIvL6zk7tfX8teVn/KvV4zT7EAiQzMCCVX13gZ21zd9tuHOSk895l1Fx+L1j7byoz98QOXuer55+nC+c7ZmB5IcdGEa6bR6ds0gr1dX+nbPpFtmWqghAHDm6H688t3TuXziYO6Zt5aL7pjP0o3VodYkEjQFgcgBcrqk86uZ4/jt105id30Tl92zgF/8+SO1y5CkpSAQOYhpo2Kzg68UDeG+N9Zy0Z3zdREeCY27B3bNDwWByCH0yErnF5eP5ZGvT2ZvfROX3/s2P//zKs0OpMM0tzh/WraJC++Yz5sHnFjRXhQEIofhjONyefm7p3PlSUO4/41SLrzjLd7T7EACVN/UzJPvlnHWv83jliffp66pObAZgc4aEjlCb35cye3PLmfLzjpuOG043zvnOLLSdWaRtI/d9U08+U4ZD84v5dOd9YzNy+HmaSM4Z8yAYzqZ4lBnDSkIRI7CrrpG/s9LH/Hku2UMz+3Gr2eOY1J+r7DLSmi765vYXF3Lppq6z3726ZbB340bRO9uGWGXF7jtexr47YJ1PLJwAzW1jZwyog83TytkamEfzI79bDoFgUhA5n+yjR8+u5xNNbXccOowvn/uKM0O2lDX2MyWmjo21dSyubqOzTW1VMR/bq6OLd9V17Tf75iBO6SnGmeN7s/MSXmcMSqX9NTk2qO9qbqW37xVylPvbqS2sZlzx/Tn5jMLGT+kZ7s+j4JAJEC765v4+UureOKdMob37cavZo6lqKB32GV1mKbmFj7dVf/Zp/hN1bWff7KPb+ir9jT8ze/17pbBwJwsBuZ0YVDPz38O6tmFgTlZ9O+RxZqtu3l2STn/vbSCbbsb6Ns9k0snDGLmpCGMGpAdwqttP2srd3PfvLX899IK3GHG+EF864wRjOwfzOtSEIh0gAVrtnHb3Njs4OtTh/GP546iS0Zizw5aWpxte+rZVF23326bza0+3W/dVUfLAZuR7Mw0Bn62ce/CoJwsBrb6OTAn64hmTo3NLcxbXcncJRt5ddVWmlqcEwfncEVRHjPGDaJn18TZdbS8vJp7563l5RVbyExL4aqThnLDacPI69U10OdVEIh0kN31Tfziz6t4fFEZw+Kzg5MSZHZQs7eRd9ZVsah0Ox9uqmFzTS1baupobN5/G5GZlsLgnl0+39C32rjv+zSfnZUeWJ1Vu+t5bukm5i4pZ+XmnWSkpnD2mH5cMWkIp43sS1on3HXk7ixcW8U989Yyf802srPSuK64gOunFtC3e2aH1KAgEOlgb6/Zxm3PLqeiupavnTKMH5zX+WYHO+saWbxuOwvXVrGwtIqVm3fiHtvQnzg4h8G9uuy322bfhr5X1/R2OXjZHlZsqmHuknKeW7qJ7XsayM3O5LIJg5k5KS+wXSxHoqXF+euqT7ln3lqWbaymb/dMbjhtGF89eWigYdkWBYFICPbUN/HLlz/i0YUbKOjTlV/NHMfkYeHNDnbXN7F43XYWlcY2/B9W1NDisW6vE4f2pHh4X6YM7834oT0TrtFeQ1MLr6/eyjMl5by+eivNLc64IT2ZOSmPGWMHkdO1Yze6jc0tPL90E/e9sZZPtu5maO+u3Hj6cGZOygvtZAIFgUiIFq6t4rZnl1G+o5brigu47fxRdM0IvgP8nvomSjbsYOHaKhaVVvFBRQ3NLU5Gagrjh/ZkyvA+FA/vw4ShPZPqTKfKXfU8t7SCuUvK+WjLLjLSUjh3TOyso9NG5gba2LC2oZnfLy7jN2+to6K6ltEDsvnWtBFceOLA0HdZhRYEZnY+8J9AKvCgu//igPXfA24AmoBK4OvuvuFQj6kgkES0p76JX738EY8s3EB+n6786vKxnDy8T7s+R21DM0s27GBh6TYWrq1ieXkNTS1OWooxfkh8wz+iDxOH9up0u6mC4O6s2LSTufGzjqr3NtK/RyaXTczj8ol5FPbr3m7PVVPbyGML1/NfC9ZTtaeBovxe3HzmCM4c1a/T7EYLJQjMLBX4GDgHKAcWA7PcfWWrMWcC77j7XjP7FjDN3a881OMqCCSRLSqt4ra5yynbvpfrTzm22UFdYzPvbdjBwtLYJ/6lG6tpbHZSU4yxeTkUD+/DlOF9KCro1SEzkM6svqmZ11ZtZe6ScuZ9XElzizNhaE+umDSEC8cOJKfL0e062rqzjocWrOOJRWXsrm9i2qhcbp5WGOouwIMJKwiKgZ+4+3nx+z8CcPefH2T8BOAud596qMdVEEii29vQxK9eXs1v317P0N5d+eXlYyke8cWzg7rGZpZurP7s4O7SsmoamltIMThxcA5TRsR29RQV9KZ7ZrQ3/IeydVcdz72/iWeWbOTjT3eTmZbCeccPYOakPKYW9j2sXUdlVXu57821zF1STlNzC18+cSDfmjaC4wfldMArODphBcFM4Hx3vyF+fzZwsrvPOcj4u4At7v6zNtbdCNwIMHTo0EkbNhxy75FIQnintIrbnl3Ohqq9XFuczw/PH023Vhvw+qZmlm2s+Wwf/3tlO6hvasEMThiUw5ThvSke0YeTCnp3+BkoycDd+aDi87OOamobGZiTxWUTB3P5xDyG5/7trqNVm3dy77y1vLB8E2kpKVw+KY9vnj6cgr7dQngFR6bTB4GZXQPMAc5w9/pDPa5mBJJM9jY08etXYrODvF5d+MdzR7Fx+14WllaxZMMO6hpjG/4vDehB8YjYrp7Jw3of9a4MaVtdYzOvrtrK3CUbeePjSlocivJ7MXNSHheOHcjqLbu4Z95aXvtoK90yUvnqlHy+ceow+vfICrv0w9apdw2Z2dnAncRCYOsXPa6CQJLRu+u2c9vcZayv2gvA6AHZnx3cPXlY74T65myi+3RnHX98v4JnSjaytnIP6alGY7PTq2s6X5s6jGuL8xPyv0dYQZBG7GDxWUAFsYPFV7v7ilZjJgBzic0cPjmcx1UQSLKqbWjm/bIdjB7YIxLdNjs7d2dZeQ0vLNvE4F5duPKkIQl90P1QQRDYq3L3JjObA7xC7PTRh919hZn9FChx9+eBXwPdgWfip1iVufuMoGoS6cy6ZKRySmHfsMuQOLPYabft3QW0Mwo03tz9JeClA5b9U6t/nx3k84uIyBfrfN2ZRESkQykIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRl3AXpjGzSuBou871Bba1YzmJTu/H/vR+fE7vxf6S4f3Id/fctlYkXBAcCzMrOdhXrKNI78f+9H58Tu/F/pL9/dCuIRGRiFMQiIhEXNSC4IGwC+hk9H7sT+/H5/Re7C+p349IHSMQEZG/FbUZgYiIHEBBICIScZEJAjM738xWm9kaM7s97HrCYmZDzOx1M1tpZivM7Dth19QZmFmqmb1vZi+EXUvYzKynmc01s4/MbFX8srORZGbfjf+dfGhmT5pZ4lyk+AhEIgjMLBW4G7gAGAPMMrMx4VYVmibg++4+BpgCfDvC70Vr3wFWhV1EJ/GfwMvuPhoYR0TfFzMbDNwKFLn7CcSutHhVuFUFIxJBAEwG1rh7qbs3AE8BF4dcUyjcfbO7vxf/9y5if+SDw60qXGaWB1wIPBh2LWEzsxzgdOAhAHdvcPfqUIsKVxrQJX4N9q7AppDrCURUgmAwsLHV/XIivvEDMLMCYALwTsilhO3/ArcBLSHX0RkMAyqB/4rvKnvQzLqFXVQY3L0C+FegDNgM1Lj7X8KtKhhRCQI5gJl1B54F/sHdd4ZdT1jM7CJgq7svCbuWTiINmAjc6+4TgD1AJI+pmVkvYnsOhgGDgG5mdk24VQUjKkFQAQxpdT8vviySzCydWAg84e5/CLuekE0FZpjZemK7DKeb2ePhlhSqcqDc3ffNEucSC4YoOhtY5+6V7t4I/AE4JeSaAhGVIFgMjDSzYWaWQeyAz/Mh1xQKMzNi+39Xufu/h11P2Nz9R+6e5+4FxP6/eM3dk/JT3+Fw9y3ARjMbFV90FrAyxJLCVAZMMbOu8b+bs0jSA+dpYRfQEdy9yczmAK8QO/L/sLuvCLmssEwFZgMfmNnS+LL/4e4vhVeSdDK3AE/EPzSVAl8LuZ5QuPs7ZjYXeI/Y2Xbvk6StJtRiQkQk4qKya0hERA5CQSAiEnEKAhGRiFMQiIhEnIJARCTiFAQiATOzaepqKp2ZgkBEJOIUBCJxZnaNmb1rZkvN7P74NQp2m9l/xHvSv2pmufGx481skZktN7M/xvvSYGaFZvb/zGyZmb1nZiPiD9+9VY//J+LfVMXMfhG/NsRyM/vXkF66RJyCQAQwsy8BVwJT3X080Ax8FegGlLj78cAbwD/Hf+VR4IfuPhb4oNXyJ4C73X0csb40m+PLJwD/QOx6GMOBqWbWB7gUOD7+OD8L8jWKHIyCQCTmLGASsDjeeuMsYhvsFuD38TGPA6fGe/b3dPc34ssfAU43s2xgsLv/EcDd69x9b3zMu+5e7u4twFKgAKgB6oCHzOwyYN9YkQ6lIBCJMeARdx8fv41y95+0Me5oe7LUt/p3M5Dm7k3ELpo0F7gIePkoH1vkmCgIRGJeBWaaWT8AM+ttZvnE/kZmxsdcDcx39xpgh5mdFl8+G3gjfsW3cjO7JP4YmWbW9WBPGL8mRE684d93iV0WUqTDRaL7qMgXcfeVZvZj4C9mlgI0At8mdmGWyfF1W4kdRwC4DrgvvqFv3aFzNnC/mf00/hhXHOJps4Hn4hdEN+B77fyyRA6Luo+KHIKZ7Xb37mHXIRIk7RoSEYk4zQhERCJOMwIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYm4/w+MFd/qJEvsjAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Measuring LongTail ...\n", + "hDims: 100 lr: 0.002 momentum: 0.95\n", + "Trained epoch 0\tloss = 0.69055 \tt = 8.3\n", + "Trained epoch 1\tloss = 0.45243 \tt = 16.0\n", + "Trained epoch 2\tloss = 0.29425 \tt = 23.6\n", + "Trained epoch 3\tloss = 0.25288 \tt = 31.2\n", + "Trained epoch 4\tloss = 0.22337 \tt = 38.8\n", + "Trained epoch 5\tloss = 0.21395 \tt = 46.5\n", + "Trained epoch 6\tloss = 0.24403 \tt = 54.1\n", + "Trained epoch 7\tloss = 0.17864 \tt = 61.6\n", + "Trained epoch 8\tloss = 0.26049 \tt = 69.2\n", + "Trained epoch 9\tloss = 0.2356 \tt = 76.8\n", + "Trained Total: 76.78940287500001\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlQUlEQVR4nO3deXhV5bn38e+diTCEMCWQgSHMgxCIiCgOWD2KIqhoj9pqa6u1PdZqa2urbd+eXr5vPfZ0PkdqtbWtbVWOA/WgoFgHqLOEuQGCYYYQEqYQhsz3+0d2NWLEAFlZ2Xv/PteVi6y11177ztbsX9bzrOd5zN0REZH4lRB2ASIiEi4FgYhInFMQiIjEOQWBiEicUxCIiMS5pLALOF59+vTxQYMGhV2GiEhUWbp06W53z2jpsagLgkGDBlFYWBh2GSIiUcXMtnzcY2oaEhGJc4EGgZlNM7NiMysxs7taePwXZrYi8rXezPYHWY+IiHxUYE1DZpYIzAb+BdgOLDGzee6+5p/HuPs3mh3/NWBCUPWIiEjLgrwimASUuPtGd68F5gCXHeP4a4HHA6xHRERaEGQQ5ADbmm1vj+z7CDMbCOQBrwRYj4iItKCjdBZfAzzl7g0tPWhmN5tZoZkVVlRUtHNpIiKxLcgg2AH0b7adG9nXkms4RrOQuz/k7hPdfWJGRou3wYqIyAkKMgiWAMPMLM/MUmj6sJ939EFmNhLoCbwVYC2s2LafH7+wLsiXEBGJSoEFgbvXA7cCC4G1wBPuXmRm95jZzGaHXgPM8YAXRli9o5IHFm1gTemBIF9GRCTqBNpH4O4L3H24uw9x9x9F9v3A3ec1O+aH7v6RMQZtbca4LJITjbnLtgf9UiIiUaWjdBYHrkeXFM4f2ZdnVpRS39AYdjkiIh1G3AQBwKyCHHYfrOG193aHXYqISIcRV0EwdUQmPbsk87Sah0RE3hdXQZCSlMDM/GxeXLOLyiN1YZcjItIhxFUQAFx5ai619Y08v3pn2KWIiHQIcRcEY3PSGZrZjbnLPm5sm4hIfIm7IDAzZhXk8O7mvWzdczjsckREQhd3QQBw+fgczGDucnUai4jEZRBk9+jMmUN6M3fZDgIe0Cwi0uHFZRAAzJqQy9a9h1m6ZV/YpYiIhCpug2DaKf3okpLI0+o0FpE4F7dB0LVTEtNO6cdzq0qprmtxGQQRkbgQt0EAcGVBLlXV9by0dlfYpYiIhCaug2Dy4N5kpadqTIGIxLW4DoLEBOPyCTksXl9BRVVN2OWIiIQiroMAYNaEHBoanXkrS8MuRUQkFHEfBMP6pjEuN10L1ohI3Ir7IICmq4Ki0gOsK9MyliISfxQEwMzxOSQlmDqNRSQuKQiAXl1TOG9kJn9dvkPLWIpI3FEQRFxZkENFVQ1vbNgTdikiIu1KQRBx3shM0jsnq9NYROKOgiCiU1IiM/KzWFhURlW1lrEUkfihIGhmVkEu1XWNPP+PsrBLERFpNwqCZib070Fen65qHhKRuKIgaMbMuLIgh7c37mXbXi1jKSLxQUFwlMsn5ADwzHKNKRCR+KAgOEpuzy5MHtyLucu1jKWIxAcFQQtmFeSyafchlm/bH3YpIiKBUxC04OJT+pGanKBOYxGJC4EGgZlNM7NiMysxs7s+5ph/NbM1ZlZkZo8FWU9rpaUmc9GYfjy7cic19VrGUkRiW2BBYGaJwGzgYmA0cK2ZjT7qmGHA3cAUdx8DfD2oeo7XrIJcKo/U8cra8rBLEREJVJBXBJOAEnff6O61wBzgsqOO+RIw2933Abh7h/nUPWtoHzLTOvG0ZiQVkRgXZBDkANuabW+P7GtuODDczN4ws7fNbFpLJzKzm82s0MwKKyoqAir3wxITjCsm5LCouJw9B7WMpYjErrA7i5OAYcBU4Frgt2bW4+iD3P0hd5/o7hMzMjLarbhZBbnUNzrPahlLEYlhQQbBDqB/s+3cyL7mtgPz3L3O3TcB62kKhg5hRL80xmR3Z64Gl4lIDAsyCJYAw8wsz8xSgGuAeUcd8wxNVwOYWR+amoo2BljTcZtVkMuq7ZW8t6sq7FJERAIRWBC4ez1wK7AQWAs84e5FZnaPmc2MHLYQ2GNma4BXgTvdvUOtDDMzP5vEBFOnsYjELIu2aRQmTpzohYWF7fqaN/5xCUWlB3jjrk+RmGDt+toiIm3BzJa6+8SWHgu7szgqzCrIpexANW9pGUsRiUEKglY4f1QmaalJmnJCRGKSgqAVUpMTuXRcNs//o4xDNfVhlyMi0qYUBK10ZUEOR+oaeEHLWIpIjFEQtNKpA3syoFcXnlbzkIjEGAVBK5kZswpyeGvjHnbsPxJ2OSIibUZBcBxmTcjFXctYikhsURAchwG9uzBpUC/mLtuuZSxFJGYoCI7TrIIcNlQcYtX2yrBLERFpEwqC43TJuCxSkrSMpYjEDgXBceqemsyFo/syb2UptfWNYZcjInLSFAQn4MqCXPYdruPV4g6zoJqIyAlTEJyAs4f1oU+3TmoeEpGYoCA4AUmJCVw+PptX1pWz71Bt2OWIiJwUBcEJmlWQS12D89wqLWMpItFNQXCCRmd3Z2S/NC1YIyJRT0FwEq4syGXFtv1sqDgYdikiIidMQXASLhufTYKhTmMRiWoKgpOQ2T2Vs4dl8NdlO2hs1JQTIhKdFAQnaVZBDqWV1by9SctYikh0UhCcpIvG9COtUxJz1WksIlFKQXCSUpMTuWRsFs+v3snhWi1jKSLRR0HQBmYV5HCotoEXi3aFXYqIyHFTELSB0wb1IrdnZy1jKSJRSUHQBhISjFkTcni9ZDdlldVhlyMiclwUBG3kioLIMpYr1GksItFFQdBG8vp05dSBPXl6qZaxFJHooiBoQ7MKcniv/CBFpQfCLkVEpNUUBG3o0rHZpCQmqNNYRKKKgqANpXdJ5oLRmcxbUUpdg5axFJHoEGgQmNk0Mys2sxIzu6uFx28wswozWxH5uinIetrDrAm57DlUy+LiirBLERFplcCCwMwSgdnAxcBo4FozG93Cof/j7uMjX78Lqp72cu6IDHp1TWHucjUPiUh0CPKKYBJQ4u4b3b0WmANcFuDrdQjJiQnMzM/mpTXlVB6uC7scEZFPFGQQ5ADbmm1vj+w72pVmtsrMnjKz/i2dyMxuNrNCMyusqOj4TS5XnZpLbUMjz63WMpYi0vGF3Vn8LDDI3ccBfwMeaekgd3/I3Se6+8SMjIx2LfBEjMnuzvC+3TQjqYhEhSCDYAfQ/C/83Mi+97n7HneviWz+Djg1wHrajZkxqyCXpVv2sWn3obDLERE5piCDYAkwzMzyzCwFuAaY1/wAM8tqtjkTWBtgPe3q8vE5mMFfNaZARDq4wILA3euBW4GFNH3AP+HuRWZ2j5nNjBx2m5kVmdlK4DbghqDqaW/90lM5a2gf5i7XMpYi0rEF2kfg7gvcfbi7D3H3H0X2/cDd50W+v9vdx7h7vruf5+7rgqynvc0qyGH7viMs2bw37FJERD5W2J3FMe2iMf3ompKoTmMR6dAUBAHqkpLExWOzmL96J9V1DWGXIyLSIgVBwGYV5HCwpp4X12gZSxHpmBQEAZuc15vs9FSeXqq7h0SkY1IQBCwhwbiiIIfX3qug/ICWsRSRjkdB0A6umJBLo8P/rtCUEyLS8SgI2sHQzG7k9++hBWtEpENSELSTKwtyWFdWxRotYykiHYyCoJ3MGJdNcqIxV1cFItLBKAjaSc+uKXxqZCbPrCilXstYikgHoiBoR7MKctl9sIbX3tsddikiIu9TELSj80Zk0qNLMk+peUhEOhAFQTtKSUrgyoJcXvhHGSXlB8MuR0QEaGUQmNntZtbdmjxsZsvM7MKgi4tFt0wdQpfkRO57PqYmWhWRKNbaK4IvuvsB4EKgJ3A9cF9gVcWw3t068W/nDeGltbt4a8OesMsREWl1EFjk30uAP7t7UbN9cpy+OCWP7PRU7l2wVovWiEjoWhsES83sRZqCYKGZpQG6B/IEpSYncue0EazeUcm8lZp2QkTC1doguBG4CzjN3Q8DycAXAqsqDlyWn8MpOd35ycJirVUgIqFqbRCcARS7+34zuw74PlAZXFmxLyHB+O7Fo9ix/wh/fHNz2OWISBxrbRA8ABw2s3zgm8AG4E+BVRUnzhzah/NHZjL7lRL2HqoNuxwRiVOtDYJ6d3fgMuB+d58NpAVXVvy46+KRHKqt579efi/sUkQkTrU2CKrM7G6abhudb2YJNPUTyEka1jeNayYN4C9vb2HT7kNhlyMicai1QXA1UEPTeIIyIBf4SWBVxZmvXzCMTkkJ/FiDzEQkBK0KgsiH/6NAupldClS7u/oI2khmWipfOXcILxSVsWTz3rDLEZE409opJv4VeBf4NPCvwDtmdlWQhcWbm84eTN/unfjR/LU0dceIiLSP1jYNfY+mMQSfd/fPAZOA/xNcWfGnc0oi37xwBCu27Wf+6p1hlyMicaS1QZDg7uXNtvccx3Olla4syGVkvzR+/MI6auo1yExE2kdrP8xfMLOFZnaDmd0AzAcWBFdWfEpMML43fRTb9h7hz29tCbscEYkTre0svhN4CBgX+XrI3b8TZGHx6uxhGZwzPIP/fqWE/Yc1yExEgtfq5h13f9rd74h8/bU1zzGzaWZWbGYlZnbXMY670szczCa2tp5Y9t1LRlJVXcf9r5SEXYqIxIFjBoGZVZnZgRa+qszswCc8NxGYDVwMjAauNbPRLRyXBtwOvHPiP0ZsGdmvO58+tT+PvLWZrXsOh12OiMS4YwaBu6e5e/cWvtLcvfsnnHsSUOLuG929FphD0xQVR/u/wI+B6hP6CWLUHRcOJykhgR8v1CAzEQlWkHf+5ADbmm1vj+x7n5kVAP3dff6xTmRmN5tZoZkVVlRUtH2lHVDf7ql86ZzBzF+1k2Vb94VdjojEsNBuAY3MV/RzmmYzPSZ3f8jdJ7r7xIyMjOCL6yC+fM5g+nTrxL0aZCYiAQoyCHYA/Ztt50b2/VMacAqwyMw2A5OBeeow/kDXTkl888LhFG7Zx8KisrDLEZEYFWQQLAGGmVmemaUA1wDz/vmgu1e6ex93H+Tug4C3gZnuXhhgTVHn06fmMiyzG/c9v47aeq0OKiJtL7AgcPd64FZgIbAWeMLdi8zsHjObGdTrxpqkxAS+e8koNu85zGPvaJCZiLS9pCBP7u4LOGoEsrv/4GOOnRpkLdFs6ogMpgztza9efo8rCnJJ76ylIESk7Wi+oChgZnz3klHsP1LHrxdpkJmItC0FQZQYk53OFRNy+MMbm9m+T4PMRKTtKAiiyLcuHIEBP11YHHYpIhJDFARRJLtHZ246O49nVpSyavv+sMsRkRihIIgyXzl3CL27pmglMxFpMwqCKJOWmszXLxjGO5v28vLa8k9+gojIJ1AQRKFrJg1gcEZX7n1+LXUNGmQmIidHQRCFkhMTuPviUWysOMScJds++QkiIsegIIhSF4zKZFJeL3710nqqquvCLkdEopiCIEqZGd+7ZBS7D9by4OKNYZcjIlFMQRDF8vv34LLx2fz2tY3srDwSdjkiEqUUBFHuWxeOwIGfLlwfdikiEqUUBFGuf68ufOHMQcxdvp2i0sqwyxGRKKQgiAG3nDeU9M7J3LtAg8xE5PgpCGJAeudkbj9/GG+U7GHR+vhY01lE2o6CIEZ89vSBDOrdhf9YsJZ6DTITkeOgIIgRKUkJfGfaSNbvOshTS7eHXY6IRBEFQQyZdko/Th3Yk5/9bT2HaurDLkdEooSCIIaYGd+bPoqKqhoe+rsGmYlI6ygIYkzBgJ5MH5vFQ3/fSPmB6rDLEZEooCCIQd+eNoL6xkZ+/jcNMhORT6YgiEEDe3flc2cM4onCbawrOxB2OSLSwSkIYtTXPjWUbp2S+I8F68IuRUQ6OAVBjOrRJYWvfWoYi9dX8Np7GmQmIh9PQRDDPnfmQHJ7duZH89fS0KipJ0SkZQqCGNYpKZHvTBvJurIq5i7TIDMRaZmCIMZdOi6L/P49+OmLxRypbQi7HBHpgBQEMe6fK5ntOlDDw69rkJmIfJSCIA5MyuvFRWP68sCiDVRU1YRdjoh0MAqCOPGdaSOpqW/kly9pkJmIfFigQWBm08ys2MxKzOyuFh7/ipmtNrMVZva6mY0Osp54NjijG589fQBzlmyjpLwq7HJEpAMJLAjMLBGYDVwMjAaubeGD/jF3H+vu44H/BH4eVD0Ct50/jC7Jidz3vAaZicgHgrwimASUuPtGd68F5gCXNT/A3ZvPf9AV0M3uAerdrRO3nDeUl9aW8+aG3WGXIyIdRJBBkANsa7a9PbLvQ8zsq2a2gaYrgttaOpGZ3WxmhWZWWFGhUbIn4wtTBpHTozP3LlhLowaZiQgdoLPY3We7+xDgO8D3P+aYh9x9ortPzMjIaN8CY0xqciLfumg4/9hxgD+9tVmL3YtIoEGwA+jfbDs3su/jzAEuD7AeibgsP4fTBvXkh8+u4ZL/ep15K0s1BYVIHAsyCJYAw8wsz8xSgGuAec0PMLNhzTanA+8FWI9EJCQYj31pMj/9dD51DY3c9vhyPvWzRTz2zlZq6jX6WCTeWJBNA2Z2CfBLIBH4vbv/yMzuAQrdfZ6Z/Qq4AKgD9gG3unvRsc45ceJELywsDKzmeNPY6Ly4ZhcPLCph5fZKMtM6cdPZeXzm9IF065QUdnki0kbMbKm7T2zxsWhrI1YQBMPdeXPDHh5YtIHXS3aT3jmZz58xkBum5NGra0rY5YnISVIQyHFZuW0/v15UwsKiXaQmJ3DNaQO4+ZzBZPfoHHZpInKCFARyQkrKq3hg0Ub+d0VTH//lE3L4yrlDGJrZLeTKROR4KQjkpOzYf4Tf/n0jc5Zspaa+kYtG9+OW84YwLrdH2KWJSCspCKRN7DlYwx/f3Mwjb27mQHU9Zw3twy1Th3DGkN6YWdjlicgxKAikTVVV1/HYO1v53eubqKiqIb9/D26ZOoR/GdWXhAQFgkhHpCCQQFTXNfD0su08uHgjW/ceZlhmN75y7hBmjs8mOTH0Qesi0oyCQAJV39DI/NU7eWDRBtaVVZHTozNfOjuPq08bQOeUxLDLExEUBNJO3J1Xi8v59asbKNyyj95dU/jClEFcf8Yg0jsnh12eSFxTEEi7e3fTXn69qIRFxRV065TEZycP4Maz8shMSw27NJHj8vbGPVTXNTB5cG9Sk6P3CldBIKEpKq3kgUUbWLB6J0mJCXz61Fy+fM4QBvTuEnZpIp/o14tK+M8XigHolJTA5MG9mToig3OHZ5DXp2tU3S2nIJDQbd59iAf/vpGnl26nvrGRGfnZ/NvUIYzs1z3s0kQ+wt2574V1PLh4IzPzs5lVkMPi9RUsLq5g4+5DAAzo1YWpIzKYOiKDyYN70yWlY8/NpSCQDmPXgWoefn0Tj769hUO1DXxqZCbXTR7AWUMzSEnSnUYSvoZG5/vP/IPH393KdZMHcM/MUz50W/TWPYdZvL6cRcUVvLlhD0fqGkhJSuD0vF6cOzyDqSMyGZLR8a4WFATS4ew/XMuf3trCH97YxL7DdXRPTeKiMf2YPi6LKUP76PZTCUVtfSN3PLGC51bt5JapQ7jzohHH/ECvrmtgyea9LC6uYNH6CkrKDwKQ27Pz+6Fw5pDedO0AM/kqCKTDqqlv4I2S3Ty3cid/W7OLqpp6enRJZlokFM4Y3JskhYK0gyO1Dfzbo0tZVFzB3ReP5MvnDjnuc2zbe7ipCWl9BW+W7OZQbQPJicZpg3pFmpEyGZbZLZSrBQWBRIXqugZee28381eV8rc1uzhU20CvrilcNKYfM8ZlMSmvl0KhHfxzLet4GiV+oLqOm/5YyJIte7n3irFcO2nASZ+ztr6Rws17WRTpWyjeVQVAdnoq547I4NzhmUwZ2pu01Pa5tVpBIFGnuq6BRcUVzF+9k5fX7uJwbQN9uqUw7ZR+TB+bzaS8XiTG0QdV0Nyd1TsqeXZlKc+t2klNfSP3zRrLhWP6hV1a4PYcrOFzv3+X9buq+MXV47l0XHYgr1O6/8j7Hc6vl+zmYE09SQnGqQN7MnVEJlNHZDCyX1pgVwsKAolqR2obWFRcznOrdvLyul1U1zWSkdaJS07px/Rx2Uwc2DOu/nptS8VlVTy7spRnV5WyZc9hkhONc4ZlUHagmqLSA9x4Vh7fmTYyZjvyS/cf4bqH36F0/xEeuO5UzhuR2S6vW9fQyNIt+1hUXMGi4nLWlTVdLfTt3un9voUpQ/u06UBMBYHEjMO19byyrpznVu7k1eJyauob6du9E5eMzeLScVlM6K9Q+CSbdh/iuciH//pdB0kwmDK0DzPGZXPRmH6kd0mmpr6B/1iwjj++uZn8/j24/9oJ9O8VW2M/NlYc5PqH3+XAkToevuE0JuX1Cq2WXQeqIx3O5bz23m6qqutJTDAKBvRg6ohMzh2ewZjs7id1taAgkJh0sKael9fuYv6qnSxaX0FtfSPZ6alcMjaL6eOyGN+/R4e7hS8sO/Yf4blIs8/qHZUATBrUixn5WUw7JYuMtE4tPu/51Tv59tOrAPjJVflMOyU2moqKSiv5/O/fxR0e+eIkTslJD7uk99U3NLJ8234WFTfdolpUegCAjLROfH/6KC4bn3NC51UQSMyrqq7jpUgoLF5fQV2Dk9OjM9PHNV0pjM1Jj7tQKK+qZsGqnTy7aidLt+wDID83nRn52VwyNqvVS49u3XOYrz2+jJXbK7nhzEHcfclIOiVF71QLhZv38oU/LiGtUxJ/vul0hmR07BX3yquq+fv63SwqLue6yQOZPLj3CZ1HQSBxpfJIHX9bs4v5q0p57b3d1Dc6/Xt1ZvrYbC4dl3XSl9gd2b5DtbxQVMazK0t5e+MeGh1G9ktjRn7Tzz6wd9cTOm9tfSP3Pb+O37+xibE56dz/mQknfK4wLSou5yt/WUp2emf+fNPp5MTROtwKAolb+w/X8uKaXTy3aidvlOymodEZ1LsL08dlMX1sNqOygrtLo71UVTcF37MrPwi+vD5dmZGfzYxxWQzrm9Zmr/ViURnfenIl7nDfleOYPi6rzc4dtPmrdvL1/1nOsMw0/nTjJPp0a7k5LFYpCERo+mt5YVEZ81fv5M0Ne2hodAZndOXSsVlcPDaLIRndoubumCO1Dby8runD/9Xipv6RnB6duTQ/ixnjsgO96tm+7zC3PracFdv2c/3kgXxv+qgOPyvnnHe38t2/rqZgQE8evuG0uJwWXUEgcpQ9B2t4oaiM+at2vt+EYga9u3YiKz2Vvt1TyUpPpV96Kv26R/6NfB/WdAE19Q38ff1unl1ZykuRsRUZaZ2YPjaLGfnZFAxov87x2vpGfrJwHb99bRNjsrtz/2cKyOvTMZuKHvr7Bu5dsI5zh2fwm+tOjdvFkhQEIsdQUVXDouJytu87wq4D1eysrH7/38ojdR85Pi016YNwiARG3/TUZgHSmZ5dktvkQ7m+oZE3N+zh2ZWlvFBURlV1PT27JDPtlCxm5Gdxel7vUAfWvbx2F998ciX1Dc69s8YyMz+YwVgnwt356YvFzH51A9PHZvGLq8dHzRVfEBQEIifoSG0DZQeqKauspuzAkaaQqPxwWFQcrOHoX6OUpISPhsVRoZHRrVOLU2Y0NjpLNu/l2VWlLFhdxt5DtaR1SuLCMf2Ykd/xJuUr3X+Erz2+nKVb9vGZ0wfwg0tHh95U1Njo/Pu8Iv789hauOa0/P7pibNyPRFcQiASorqGRiqqaDwKjsrrF72sbGj/0vARruje8eWA4sLCojF0HakhNTuCCUX2ZkZ/NucMzQv9wPZa6hkZ+9uJ6frN4AyP7pTH7swWh3ZZZ19DInU+u5JkVpXz5nMHcdfHIqL8hoC0oCERC5u7sO1zHzsoPmp9aCo2a+kbOHZHBjPxszh+Z2SGmLz4er64r544nVlBT38i9V4zl8gknNvjpRFXXNXDrY8t4aW05d140glumDlEIRCgIRKJEY6NH/RQZOyuPcNvjy1myeR9XT+zPD2eOaZcO2oM19dz0yBLe2bSXey47hesnDwz8NaPJsYKg4zQ0ikjUhwBAVnpnHv/SZL563hCeWLqNy2e/QUl5VaCvufdQLZ/57dss2byPX149XiFwnAINAjObZmbFZlZiZne18PgdZrbGzFaZ2ctmpv96IjEgKTGBOy8aySNfmMTugzXM+O83eHrp9kBeq6yymqsffIvisioeuv7UE56LJ54FFgRmlgjMBi4GRgPXmtnoow5bDkx093HAU8B/BlWPiLS/c4ZnsOD2s8nvn843n1zJt55cyeHa+jY7/5Y9h7jqN2+ys7KaR744ifNH9W2zc8eTIK8IJgEl7r7R3WuBOcBlzQ9w91fd/XBk820gN8B6RCQEfbun8uhNk7nt/GE8vWw7l93/But3nXxT0bqyA1z1m7c4VFPPY186/YQnY5NggyAH2NZse3tk38e5EXi+pQfM7GYzKzSzwoqKijYsUUTaQ2KCcce/DOcvN57OvsN1zLz/dZ5Yso0TvVll2dZ9XP3g2ySa8cSXz2Bcbo+2LTjOdIjOYjO7DpgI/KSlx939IXef6O4TMzIy2rc4EWkzU4b2YcHtZ1EwoCfffnoV33xiJYdqjq+p6PX3dnPd796hR5dknvzKGW06qV68CjIIdgD9m23nRvZ9iJldAHwPmOnuNQHWIyIdQGZaKn++8XS+ccFwnlmxg5n3v866sgOteu4L/yjji39cwoBeXXjyy2fE3KppYQkyCJYAw8wsz8xSgGuAec0PMLMJwIM0hUB5gLWISAeSmGDcfsEwHr1pMgeq67ns/jd4/N2tx2wqerJwG7c8upQxOd2Zc/NkMruntmPFsS2wIHD3euBWYCGwFnjC3YvM7B4zmxk57CdAN+BJM1thZvM+5nQiEoPOGNKb528/m0l5vbh77mpun7OCgy00Ff3+9U3c+dQqzhzSh7/ceDo9uqSEUG3s0shiEQldY6PzwOIN/OzFYgb27sr9n5nAmOx03J1fvvQev3r5PaaN6cevrh0f1ctkhulYI4ujayITEYlJCQnGV88bysSBPbltznKu+PWb/J9LR7Ox4iB/eGMzV52ay32zxrY4W6ucPF0RiEiHsudgDXc8sZLF65tuFf/ilDy+P31UTEy/ESZdEYhI1OjdrRN/uOE0/vTWZsyMz50xUDOIBkxBICIdTkKCccOUvLDLiBtqcBMRiXMKAhGROKcgEBGJcwoCEZE4pyAQEYlzCgIRkTinIBARiXMKAhGROBd1U0yYWQWw5QSf3gfY3YblRDu9Hx+m9+MDei8+LBbej4Hu3uLKXlEXBCfDzAo/bq6NeKT348P0fnxA78WHxfr7oaYhEZE4pyAQEYlz8RYED4VdQAej9+PD9H58QO/Fh8X0+xFXfQQiIvJR8XZFICIiR1EQiIjEubgJAjObZmbFZlZiZneFXU9YzKy/mb1qZmvMrMjMbg+7po7AzBLNbLmZPRd2LWEzsx5m9pSZrTOztWZ2Rtg1hcXMvhH5PfmHmT1uZqlh1xSEuAgCM0sEZgMXA6OBa81sdLhVhaYe+Ka7jwYmA1+N4/eiuduBtWEX0UH8CnjB3UcC+cTp+2JmOcBtwER3PwVIBK4Jt6pgxEUQAJOAEnff6O61wBzgspBrCoW773T3ZZHvq2j6Jc8Jt6pwmVkuMB34Xdi1hM3M0oFzgIcB3L3W3feHWlS4koDOZpYEdAFKQ64nEPESBDnAtmbb24nzDz8AMxsETADeCbmUsP0S+DbQGHIdHUEeUAH8IdJU9jsz6xp2UWFw9x3AT4GtwE6g0t1fDLeqYMRLEMhRzKwb8DTwdXc/EHY9YTGzS4Fyd18adi0dRBJQADzg7hOAQ0Bc9qmZWU+aWg7ygGygq5ldF25VwYiXINgB9G+2nRvZF5fMLJmmEHjU3eeGXU/IpgAzzWwzTU2GnzKzv4RbUqi2A9vd/Z9XiU/RFAzx6AJgk7tXuHsdMBc4M+SaAhEvQbAEGGZmeWaWQlOHz7yQawqFmRlN7b9r3f3nYdcTNne/291z3X0QTf9fvOLuMflXX2u4exmwzcxGRHadD6wJsaQwbQUmm1mXyO/N+cRox3lS2AW0B3evN7NbgYU09fz/3t2LQi4rLFOA64HVZrYisu+77r4gvJKkg/ka8Gjkj6aNwBdCricU7v6OmT0FLKPpbrvlxOhUE5piQkQkzsVL05CIiHwMBYGISJxTEIiIxDkFgYhInFMQiIjEOQWBSMDMbKpmNZWOTEEgIhLnFAQiEWZ2nZm9a2YrzOzByBoFB83sF5E56V82s4zIsePN7G0zW2Vmf43MS4OZDTWzl8xspZktM7MhkdN3azbH/6ORkaqY2X2RtSFWmdlPQ/rRJc4pCEQAMxsFXA1McffxQAPwWaArUOjuY4DFwL9HnvIn4DvuPg5Y3Wz/o8Bsd8+naV6anZH9E4Cv07QexmBgipn1Bq4AxkTO8/+C/BlFPo6CQKTJ+cCpwJLI1Bvn0/SB3Qj8T+SYvwBnRebs7+HuiyP7HwHOMbM0IMfd/wrg7tXufjhyzLvuvt3dG4EVwCCgEqgGHjazWcA/jxVpVwoCkSYGPOLu4yNfI9z9hy0cd6JzstQ0+74BSHL3epoWTXoKuBR44QTPLXJSFAQiTV4GrjKzTAAz62VmA2n6HbkqcsxngNfdvRLYZ2ZnR/ZfDyyOrPi23cwuj5yjk5l1+bgXjKwJkR6Z8O8bNC0LKdLu4mL2UZFP4u5rzOz7wItmlgDUAV+laWGWSZHHymnqRwD4PPCbyAd98xk6rwceNLN7Iuf49DFeNg3438iC6Abc0cY/lkiraPZRkWMws4Pu3i3sOkSCpKYhEZE4pysCEZE4pysCEZE4pyAQEYlzCgIRkTinIBARiXMKAhGROPf/AbQIt5HFLc2iAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating Random ...\n", + "Measuring RMSE&MAE ...\n", + "Measuring LongTail ...\n", + "Algorithm RMSE MAE LongTail \n", + "RBM 1.1696 0.9774 848.7648 \n", + "Random 1.4419 1.1493 4587.0864 \n" + ] + } + ], + "source": [ + "#评估两个推荐系统的表现\n", + "evaluator.Evaluate()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1bb9b8a2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "For user: 352\n", + "\n", + "Algorithm: RBM\n", + "building...\n", + "\n", + "Recommend:\n", + "Moon (2009) 4.423\n", + "How to Train Your Dragon (2010) 4.411\n", + "Smoke (1995) 4.41\n", + "Harry Potter and the Half-Blood Prince (2009) 4.4\n", + "Harry Potter and the Deathly Hallows: Part 2 (2011) 4.4\n", + "Sherlock Holmes (2009) 4.39\n", + "Band of Brothers (2001) 4.382\n", + "Blood Diamond (2006) 4.369\n", + "Sherlock Holmes: A Game of Shadows (2011) 4.363\n", + "Road Warrior, The (Mad Max 2) (1981) 4.362\n", + "\n", + "Algorithm: Random\n", + "\n", + "Recommend:\n", + "Ciao, Professore! (Io speriamo che me la cavo) (1992) 5\n", + "Last Action Hero (1993) 5\n", + "Picnic at Hanging Rock (1975) 5\n", + "Reno 911!: Miami (2007) 5\n", + "Shining, The (1980) 5\n", + "Midnight Run (1988) 5\n", + "Judas Kiss (2011) 5\n", + "Shining Through (1992) 5\n", + "Passengers (2008) 5\n", + "Twilight Saga: New Moon, The (2009) 5\n" + ] + } + ], + "source": [ + "#推荐topN电影\n", + "evaluator.EvalTopNRecs(ml,uidRec = 352)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97c96a0f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Main.py" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Main.py" new file mode 100644 index 0000000000000000000000000000000000000000..2f6075990088feced188f1c072f128e4d65a2b58 --- /dev/null +++ "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/Main.py" @@ -0,0 +1,41 @@ +from MovieLens import MovieLens +from RBMAlgo import RBMAlgo +from Evaluator import Evaluator + +from surprise import NormalPredictor + + + +#本地用 +ratingsPath = 'J:/DeepLearning/ml-latest-small/ratings.csv' +moviesPath = 'J:/DeepLearning/ml-latest-small/movies.csv' + +#ModelArts上用 +#ratingsPath = '/home/ma-user/work/ml-latest-small/ratings.csv' +#moviesPath = '/home/ma-user/work/ml-latest-small/movies.csv' + + +ml = MovieLens(ratingsPath,moviesPath) +evalData = ml.loadMovieLensData() +rankings = ml.getPopularityRanks() + + +# 创建evaluator,装载数据和不同算法 +#print("initiating Evaluator....") +evaluator = Evaluator(evalData, rankings) + +# 创建RBM实例 +rbmAlgo = RBMAlgo(batchSize = 100, epochs=10,hDims=100,lr=0.002,momentum = 0.95) +evaluator.AddAlgorithm(rbmAlgo, "RBM") + +# 使用随机推荐系统 +# 随机评估打分是用global_mean和sigma的正态分布 +RandomAlgo = NormalPredictor() +evaluator.AddAlgorithm(RandomAlgo, "Random") + + +#评估两个推荐系统的表现 +evaluator.Evaluate() + +#推荐topN电影 +evaluator.EvalTopNRecs(ml,uidRec = 352) diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/MovieLens.py" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/MovieLens.py" new file mode 100644 index 0000000000000000000000000000000000000000..4e5339e7eba33b83d5730b89caa961f6646ba7eb --- /dev/null +++ "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/MovieLens.py" @@ -0,0 +1,61 @@ +import csv +from surprise import Dataset +from surprise import Reader +from collections import defaultdict + +class MovieLens: + + def __init__(self,ratingsPath,moviesPath): + self.movieID_to_name = {} + self.name_to_movieID = {} + self.ratingsPath = ratingsPath + self.moviesPath = moviesPath + + + #读取并整理数据 + def loadMovieLensData(self): + ratingsDataset = 0 + self.movieID_to_name = {} + self.name_to_movieID = {} + + #设置reader参数 + reader = Reader(line_format='user item rating timestamp', sep=',',rating_scale=(0.5, 5), skip_lines=1) + ratingsDataset = Dataset.load_from_file(self.ratingsPath, reader=reader) + + with open(self.moviesPath, newline='', encoding='ISO-8859-1') as csvfile: + movieReader = csv.reader(csvfile) + next(movieReader) #跳过第一行 + for row in movieReader: + movieID = int(row[0]) + movieName = row[1] + # 建立电影名和电影ID之间的相互映射 + self.movieID_to_name[movieID] = movieName + self.name_to_movieID[movieName] = movieID + + return ratingsDataset + + # 获取movie的流行度排名 + def getPopularityRanks(self): + ratings = defaultdict(int) + rankings = defaultdict(int) + with open(self.ratingsPath, newline='') as csvfile: + ratingReader = csv.reader(csvfile) + next(ratingReader) + for row in ratingReader: + movieID = int(row[1]) + ratings[movieID] += 1 # 出现一次就记一次,不管rating为多少 + + rank = 1 + # 将count排序后赋值rank + for movieID, ratingCount in sorted(ratings.items(), key=lambda x: x[1], reverse=True): + rankings[movieID] = rank + rank += 1 + return rankings + + + def getMovieName(self, movieID): + if movieID in self.movieID_to_name: + return self.movieID_to_name[movieID] + else: + return "" + diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/RBM.py" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/RBM.py" new file mode 100644 index 0000000000000000000000000000000000000000..ed7f5a6e113ddb7c8a889b09d309b0fcafc5038a --- /dev/null +++ "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/RBM.py" @@ -0,0 +1,150 @@ +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor +import mindspore as ms +import time +import matplotlib.pyplot as plt + +class RBM(object): + + def __init__(self, vDims, hDims=50, rLevels=10, lr=0.001, momentum = 0.9, batchSize=100, epochs=20): + + self.vDims = vDims + self.epochs = epochs + self.hDims = round(hDims + 10e-5) + self.rLevels = rLevels + self.lr = lr + self.momentum = momentum + self.batchSize = batchSize + self.loss = [] + self.loss_mean = np.zeros(self.epochs) + + + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU() + self.softmax = nn.Softmax() + + + def Train(self, XbinR, verbose = False): + #XbinR为nUsers*(nItems*rLevels)大小的2D二值矩阵 + start_time = time.perf_counter() + + + # 初始化 + maxWeight = 4.0 * np.sqrt(6.0 / (self.hDims + self.vDims)) + self.weights = np.random.uniform(low =-maxWeight, high =maxWeight,size = (self.vDims, self.hDims)) + self.hBias = np.zeros([self.hDims]) + self.vBias = np.zeros([self.vDims]) + + trXbinR = np.array(XbinR) + self.iter_per_epoch = max(trXbinR.shape[0] // self.batchSize, 1) + self.nItems = trXbinR.shape[1]/self.rLevels + + for epoch in range(self.epochs): + #防止每次都用同样一堆的batch去train + np.random.shuffle(trXbinR) + for i in range(self.iter_per_epoch): + batch = trXbinR[i:i+self.batchSize] # 大小为batchSize*(nItems*rLevels) + self.doCD_1(batch) + + recs = self.OnePass(batch) + recs = recs.reshape(self.batchSize,-1,self.rLevels) + rated_sum = np.sum(batch) + + #求损失函数值 + loss = -np.sum((np.log(recs.reshape(1,-1))*batch.reshape(1,-1)))/rated_sum + self.loss.append(loss) + #print("loss = ",self.loss[epoch*self.iter_per_epoch+i],'\t',end='') + self.loss_mean[epoch] += loss/self.iter_per_epoch + if(verbose): + print("Trained epoch ", epoch, end='') + print("\tloss = ",round(self.loss_mean[epoch],5),end='') + print(" \tt = "+ str(round(time.perf_counter()-start_time,1))) + + + if(verbose): + print("Trained Total: "+str(time.perf_counter()-start_time)) + + x = np.arange(len(self.loss_mean)) + plt.plot(x, self.loss_mean) + plt.xlabel("epochs") + plt.ylabel("loss") + plt.show() + + # + def doCD_1(self, v0, method = 1): + # 采用CD-1算法 + + # 获取隐含层h0的激活概率 + h0Prob = self.sigmoid(Tensor((np.matmul(v0, self.weights) + self.hBias),ms.float32)) + h0Prob = h0Prob.asnumpy() + # 采样隐藏层 + h0GS = self.relu(Tensor(np.sign(h0Prob - np.random.uniform(size = (np.shape(h0Prob)))),ms.float32)) + h0GS = h0GS.asnumpy() + + + if (method == 0): + #''' one Way + # 计算data + posCD = np.matmul(np.transpose(v0), h0Prob) + + # 重构可见层 + v1Prob = self.sigmoid(Tensor(np.matmul(h0GS, np.transpose(self.weights)) + self.vBias,ms.float32)) + v1Prob = v1Prob.asnumpy() + v1GS = self.relu(Tensor(np.sign(v1Prob - np.random.uniform(size = (np.shape(v1Prob)))),ms.float32)) + v1GS = v1GS.asnumpy() + + # 获取recon隐含层h1的激活概率 + h1Prob = self.sigmoid(Tensor(np.matmul(v1GS, self.weights) + self.hBias,ms.float32)) + h1Prob = h1Prob.asnumpy() + # 计算recon + negCD = np.matmul(np.transpose(v1GS), h1Prob) + + # 更新参数 + # 更新权重 + self.weights = self.momentum*self.weights + self.lr * (posCD - negCD) + # 更新可见层偏置 + self.vBias = self.momentum*self.vBias + self.lr * np.mean(v0 - v1GS, axis=0) + # 更新隐藏层偏置 + self.hBias = self.momentum*self.hBias + self.lr * np.mean(h0Prob - h1Prob, axis=0) + #''' + + else: + #''' another way + #计算data + posCD = np.matmul(np.transpose(v0), h0GS) + + v1_ = np.matmul(h0GS, np.transpose(self.weights)) + self.vBias + #获取有rating存在的mask + vMask = np.sign(v0) + vMask3D = np.reshape(vMask, (np.shape(v1_)[0], -1, self.rLevels)) + vMask3D = np.max(vMask3D, axis=2, keepdims=True) + #对v1进行mask + v1_ = np.reshape(v1_, (np.shape(v1_)[0], -1, self.rLevels)) + v_masked = v1_ * vMask3D + #利用softmax获取v1概率 + v1Prob = self.softmax(Tensor(v_masked,ms.float32)) + v1Prob = v1Prob.asnumpy() + #v1Prob = v1Prob*vMask3D + v1Prob = np.reshape(v1Prob, (np.shape(v1_)[0], -1)) + + # 计算重构后的h1概率分布 + h1Prob = self.sigmoid(Tensor(np.matmul(v1Prob, self.weights) + self.hBias,ms.float32)) + h1Prob = h1Prob.asnumpy() + # 技术recon + negCD = np.matmul(np.transpose(v1Prob), h1Prob) + + #更新权重和偏置 + self.weights = self.momentum*self.weights + self.lr * (posCD - negCD) + self.vBias = self.momentum*self.vBias + self.lr * np.mean(v0 - v1Prob, axis=0) + self.hBias = self.momentum*self.hBias + self.lr * np.mean(h0Prob - h1Prob, axis=0) + #''' + + + def OnePass(self,v0): + h0 = self.sigmoid(Tensor(np.matmul(v0, self.weights) + self.hBias,ms.float32)) + h0 = h0.asnumpy() + v1 = self.sigmoid(Tensor(np.matmul(h0, np.transpose(self.weights)) + self.vBias,ms.float32)) + v1 = v1.asnumpy() + return v1 + diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/RBMAlgo.py" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/RBMAlgo.py" new file mode 100644 index 0000000000000000000000000000000000000000..3d63ffd9436527faa43e0f257d81e71d59ba79a0 --- /dev/null +++ "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/RBMAlgo.py" @@ -0,0 +1,82 @@ +from surprise import AlgoBase +from surprise import PredictionImpossible +import numpy as np +from RBM import RBM + +# 对接和处理RBM的输入输出 +class RBMAlgo(AlgoBase): + + def __init__(self, hDims=100, rLevels=10, lr=0.001, momentum = 0.9,batchSize=100, epochs=20, sim_options={}): + AlgoBase.__init__(self) + self.epochs = epochs + self.hDims = hDims + self.rLevels = rLevels + self.lr = lr + self.momentum = momentum + self.batchSize = batchSize + + + def fit(self, trainset, getRec = 0): + AlgoBase.fit(self, trainset) + self.nUsers = trainset.n_users + self.nItems = trainset.n_items + #self.trainXrealR = np.zeros([self.nUsers, self.nItems], dtype=np.float32) + self.trainXbinR = np.zeros([self.nUsers, self.nItems, self.rLevels], dtype=np.int32) + + for (iuid, iiid, rating) in trainset.all_ratings(): + #self.trainXrealR[int(iuid), int(iiid)] = float(rating) + adjustedRating = int(float(rating)*2.0 + 1e-5) - 1 #调整为0-9的十个整数 + self.trainXbinR[int(iuid), int(iiid), adjustedRating] = 1 #标记存在的rating + + + # 展平为2D矩阵,每行代表一个uer的i,r信息 + self.trainXbinR = np.reshape(self.trainXbinR, [self.trainXbinR.shape[0], -1]) + + + if (getRec == 0): + print('hDims: ',self.hDims, ' lr: ',self.lr, ' momentum: ',self.momentum) + verbose = True + else: + verbose = False + #self.epochs = max(self.epochs + np.random.randint(-2,2),0) + # 创建RBM实例并训练,可见层结点为nItems*rLevels + self.rbm = RBM(self.trainXbinR.shape[1], hDims=self.hDims, lr=self.lr, momentum = self.momentum, rLevels = self.rLevels, batchSize=self.batchSize, epochs=self.epochs) + self.rbm.Train(self.trainXbinR,verbose = verbose) + + #开辟nUsers*nItems大小矩阵储存rating + self.predictedRatings = np.zeros([self.nUsers, self.nItems], dtype=np.float32) + + #''' + recs = self.rbm.OnePass(self.trainXbinR) + + #归一化recs + recs = recs.reshape(self.nUsers,-1,self.rLevels) + recs_sum = np.sum(recs,axis=2) + recs_sum = np.reshape(np.repeat(recs_sum, self.rLevels, axis = 1),recs.shape) + recs_noml = recs/recs_sum + + #转换为rating + rBase = np.arange(self.rLevels) + rBase = np.tile(rBase,(self.nUsers*self.nItems,1)) + rBase = rBase.reshape(self.nUsers,-1,self.rLevels) + + # 行列号是用innerID表示的 + self.predictedRatings = getRec+(np.sum(rBase*recs_noml,axis = 2)+1)*0.5 + #''' + + + return self + + #对父类AlgoBase的estimate()进行实现,以供父类的test()调用 + def estimate(self, iuid, iiid): + + if not (self.trainset.knows_user(iuid) and self.trainset.knows_item(iiid)): + raise PredictionImpossible('User and/or item is unkown.') + + rating = self.predictedRatings[iuid, iiid] + + if (rating < 0.001): + raise PredictionImpossible('No valid prediction exists.') + + return rating + \ No newline at end of file diff --git "a/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/README.md" "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/README.md" new file mode 100644 index 0000000000000000000000000000000000000000..d2bc3ec2f3461318c5b84137a53b3d5f37b5f00b --- /dev/null +++ "b/code/2021_autumn/\350\221\243\346\235\260\346\237\257-RBM\345\234\250\345\215\217\345\220\214\350\277\207\346\273\244\347\232\204\347\224\265\345\275\261\346\216\250\350\215\220\347\263\273\347\273\237\344\270\255\347\232\204\345\272\224\347\224\250/README.md" @@ -0,0 +1,19 @@ +# RBM在协同过滤的电影推荐系统中的应用 +## 简介 +- 使用[MovieLens](https://grouplens.org/datasets/movielens/)的ml-latest-small数据集,一共671名用户对9066部电影的100004条评分。 +- 实现“Ruslan Salakhutdinov, Andriy Mnih, Geoffrey Hinton. Restricted Boltzmann machines for collaborative filtering[P]. Machine learning, 2007.”文中所提的基于协同过滤的RBM。(更多参考文献见文档) +- 增加“长尾效应”评测指标和TopN推荐的实现 + +## 环境配置 +- Python 3.7.11 +- Mindspore 1.3.0 +- scikit-surprise 1.1.1 +> 可本地配置相关环境运行 +> 也可选用ModelArts平台中的配置为tensorflow1.15-mindspore1.3.0-cann5.0.2-euler2.8-aarch64或者mindspore1.2.0-openmpi2.1.1-ubuntu18.04的镜像,激活mindspore环境安装surprise库后即可运行 + +## 文件介绍 +- Main.py 和 Main.ipynb 为主函数文件,两者功能一样,根据相应的环境选择一个就行 +- RBM.py 主要负责网络的训练和学习 +- RBMAlgo.py 负责对接和处理RBM的输入输出 +- MovieLens.py 负责读取和整理MovieLens数据集 +- Evaluator.py 负责构建训练/测试数据集、评价系统相关指标以及生成TopN推荐结果 \ No newline at end of file