From 00eaa884cd3437f3e8741d1a5ae82cde74b553b0 Mon Sep 17 00:00:00 2001 From: anettapik <120940816+anettapik@users.noreply.github.com> Date: Tue, 21 Nov 2023 12:43:04 +0300 Subject: [PATCH 1/7] =?UTF-8?q?=D0=92=D1=82=D0=BE=D1=80=D0=B0=D1=8F=20?= =?UTF-8?q?=D0=B4=D0=BE=D0=BC=D0=B0=D1=88=D0=BD=D1=8F=D1=8F=20=D1=80=D0=B0?= =?UTF-8?q?=D0=B1=D0=BE=D1=82=D0=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- hw_2.ipynb | 440 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 440 insertions(+) create mode 100644 hw_2.ipynb diff --git a/hw_2.ipynb b/hw_2.ipynb new file mode 100644 index 00000000..b4d81fa6 --- /dev/null +++ b/hw_2.ipynb @@ -0,0 +1,440 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from pprint import pprint\n", + "\n", + "import copy\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "from implicit.nearest_neighbours import TFIDFRecommender, BM25Recommender\n", + "from implicit.als import AlternatingLeastSquares\n", + "\n", + "\n", + "from rectools import Columns\n", + "from rectools.dataset import Interactions, Dataset\n", + "from rectools.metrics import Precision, Recall, MeanInvUserFreq, Serendipity, calc_metrics, MAP, MRR\n", + "from rectools.models import ImplicitItemKNNWrapperModel, RandomModel, PopularModel\n", + "from rectools.model_selection import TimeRangeSplitter" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv('data_original/interactions.csv', parse_dates=['last_watch_dt'])\n", + "\n", + "df.rename(\n", + " columns={\n", + " 'last_watch_dt': Columns.Datetime,\n", + " 'total_dur': Columns.Weight\n", + " }, \n", + " inplace=True) \n", + "\n", + "interactions = Interactions(df)\n", + "\n", + "\n", + "users = pd.read_csv('data_original/users.csv')\n", + "items = 
pd.read_csv('data_original/items.csv')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idageincomesexkids_flg
373089666262age_65_infincome_20_40Ж0
\n", + "
" + ], + "text/plain": [ + " user_id age income sex kids_flg\n", + "373089 666262 age_65_inf income_20_40 Ж 0" + ] + }, + "execution_count": 124, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users[users['user_id'] == 666262]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Функция для расчета метрик" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image.png](attachment:image.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "# Модели: rectools.models.RandomModel(random_state=32), rectools.models.PopularModel() с параметрами по умолчанию\n", + "models = {\n", + " \"random\": RandomModel(random_state=32),\n", + " \"popular\": PopularModel()\n", + "}\n", + "\n", + "# Метрики: 2 ранжирующие, 2 классификационные, 2 beyond-accuracy. Считаем по порогам 1, 5, 10. MAP обязательно\n", + "metrics = {\n", + " # классификационные\n", + " \"prec@1\": Precision(k=1),\n", + " \"prec@10\": Precision(k=5),\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall\": Recall(k=1),\n", + " \"recall\": Recall(k=5),\n", + " \"recall\": Recall(k=10),\n", + " # ранжирующие\n", + " \"MAP\": MAP(k=1),\n", + " \"MAP\": MAP(k=5),\n", + " \"MAP\": MAP(k=10),\n", + " # среднее значение обратного ранга\n", + " \"MRR\": MRR(k=1),\n", + " \"MRR\": MRR(k=5),\n", + " \"MRR\": MRR(k=10),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}\n", + "\n", + "# 3 фолда для кросс-валидации по неделе\n", + "n_splits = 3\n", + "test_size = \"14D\"\n", + "\n", + "# Инициализированный Splitter для кросс-валидации\n", + "cv = TimeRangeSplitter(\n", + " test_size= test_size,\n", + " n_splits=n_splits,\n", + " filter_already_seen=True,\n", + " filter_cold_items=True,\n", + " filter_cold_users=True,\n", + ")\n", + "\n", + "# Количество рекомендаций для 
генерации (K)\n", + "K_RECOS = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset.construct(\n", + " interactions_df=interactions,\n", + " user_features_df=None,\n", + " item_features_df=None,\n", + " )\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 5 µs, sys: 0 ns, total: 5 µs\n", + "Wall time: 32.2 µs\n" + ] + } + ], + "source": [ + "%%time\n", + "import time\n", + "\n", + "def evaluate_models(interactions, models, metrics, cv, K_RECOS):\n", + " results = []\n", + " trained_models = {}\n", + "\n", + " # n_splits = cv.get_n_splits()\n", + " fold_iterator = cv.split(interactions, collect_fold_stats=True)\n", + "\n", + " for train_ids, test_ids, fold_info in tqdm(fold_iterator, total=n_splits):\n", + " print(f\"\\n==================== Fold {fold_info['i_split']}\")\n", + " pprint(fold_info)\n", + "\n", + " df_train = interactions.df.iloc[train_ids]\n", + " # Создаем RecTools Dataset через метод construct на train взаимодействиях для каждого фолда\n", + " dataset = Dataset.construct(df_train)\n", + " # Определили test\n", + " df_test = interactions.df.iloc[test_ids] # Предполагается, что Columns.UserItem определено\n", + " test_users = np.unique(df_test[Columns.User])\n", + "\n", + " catalog = df_train[Columns.Item].unique() # Каталог для рекомендаций\n", + "\n", + " # Обучаем модель (не забываем сделать deepcopy), рекоменуем K айтемов для каждого юзера, считаем метрики на test\n", + " for model_name, model in models.items():\n", + " \n", + " model_copy = copy.deepcopy(model)\n", + " # время перед началом обучения\n", + " start_time = time.time()\n", + " model.fit(dataset)\n", + " recos = model.recommend(\n", + " users=test_users,\n", + " dataset=dataset,\n", + " k=K_RECOS,\n", + " filter_viewed=True,\n", + " )\n", + " metric_values = 
calc_metrics(\n", + " metrics,\n", + " reco=recos,\n", + " interactions=df_test,\n", + " prev_interactions=df_train,\n", + " catalog=catalog,\n", + " )\n", + " \n", + " # время обучения\n", + " training_time = time.time() - start_time\n", + "\n", + " res = {\"fold\": fold_info[\"i_split\"], \"model\": model_name, \"training_time\": training_time}\n", + " res.update(metric_values)\n", + " results.append(res)\n", + "\n", + " # Сохраняем обученную модель\n", + " if fold_info['i_split'] == n_splits - 1: # Последний фолд\n", + " trained_models[model_name] = model_copy\n", + " \n", + "\n", + " # Результат оборачиваем в pandas DataFrame и усредняем по фолдам\n", + " results_df = pd.DataFrame(results)\n", + " average_results = results_df.groupby('model').mean()\n", + " average_results = average_results.reset_index()\n", + " return average_results, trained_models\n", + "\n", + "# %%time\n", + "# df_rec, trained_models = evaluate_models(interactions, models, metrics, cv, K_RECOS)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/3 [00:00 Date: Tue, 21 Nov 2023 12:46:08 +0300 Subject: [PATCH 2/7] hw_2 From 092e10ee66677c370328ff23f1d62f7d752a4259 Mon Sep 17 00:00:00 2001 From: anettapik <120940816+anettapik@users.noreply.github.com> Date: Tue, 21 Nov 2023 13:01:53 +0300 Subject: [PATCH 3/7] hw_2 From eb1f30e37deae39df0a01e955a39c1d026d5381f Mon Sep 17 00:00:00 2001 From: Anna Pikuleva Date: Mon, 27 Nov 2023 16:55:25 +0300 Subject: [PATCH 4/7] =?UTF-8?q?=D0=9F=D0=BE=D0=BF=D1=8B=D1=82=D0=BA=D0=B0?= =?UTF-8?q?=203=20=D0=B4=D0=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- HW-3.1.ipynb | 4027 ++++++++++++++++ notebooks/HW-3.1.ipynb | 4027 ++++++++++++++++ notebooks/HW-3.2-rectools-research.ipynb | 725 +++ notebooks/HW-3.3-rectools-cv.ipynb | 4387 ++++++++++++++++++ 
notebooks/HW-3.4-model-for-online-recs.ipynb | 1161 +++++ userknn.py | 112 + 6 files changed, 14439 insertions(+) create mode 100644 HW-3.1.ipynb create mode 100644 notebooks/HW-3.1.ipynb create mode 100644 notebooks/HW-3.2-rectools-research.ipynb create mode 100644 notebooks/HW-3.3-rectools-cv.ipynb create mode 100644 notebooks/HW-3.4-model-for-online-recs.ipynb create mode 100644 userknn.py diff --git a/HW-3.1.ipynb b/HW-3.1.ipynb new file mode 100644 index 00000000..c05a3b71 --- /dev/null +++ b/HW-3.1.ipynb @@ -0,0 +1,4027 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "398a86d9", + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import sys\n", + "sys.path.append('../')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8dbe6bf0", + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.express as px\n", + "import numpy as np\n", + "import pandas as pd\n", + "import scipy as sp\n", + "import requests\n", + "from tqdm.auto import tqdm\n", + "from scipy.stats import mode\n", + "from implicit.nearest_neighbours import CosineRecommender, TFIDFRecommender, BM25Recommender\n", + "from rectools import Columns\n", + "from rectools.model_selection import TimeRangeSplitter\n", + "from rectools.metrics import Precision, Recall, MAP, MeanInvUserFreq, Serendipity, calc_metrics\n", + "from rectools.dataset.interactions import Interactions\n", + "\n", + "from service.utils.user_knn import UserKnn" + ] + }, + { + "cell_type": "markdown", + "id": "b1baa79f", + "metadata": {}, + "source": [ + "# Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f2a9e540", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((5476251, 5), (840197, 5), (15963, 14))" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions = 
pd.read_csv('../data/kion_train/interactions.csv')\n", + "users = pd.read_csv('../data/kion_train/users.csv')\n", + "items = pd.read_csv('../data/kion_train/items.csv')\n", + "\n", + "interactions.shape, users.shape, items.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "456d25f4", + "metadata": {}, + "outputs": [], + "source": [ + "interactions.rename(\n", + " columns={\n", + " 'last_watch_dt': Columns.Datetime,\n", + " 'total_dur': Columns.Weight\n", + " }, \n", + " inplace=True) \n", + "\n", + "interactions[Columns.Datetime] = pd.to_datetime(interactions[Columns.Datetime])" + ] + }, + { + "cell_type": "markdown", + "id": "6f7b9b0c", + "metadata": {}, + "source": [ + "## Intersection" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7c9c0c94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_iddatetimeweightwatched_pct
017654995062021-05-11425072.0
169931716592021-05-298317100.0
265668371072021-05-09100.0
386461376382021-07-0514483100.0
496486895062021-04-306725100.0
5476246648596122252021-08-13760.0
547624754686296732021-04-13230849.0
5476248697262152972021-08-201830763.0
5476249384202161972021-04-196203100.0
547625031970944362021-08-15392145.0
\n", + "
" + ], + "text/plain": [ + " user_id item_id datetime weight watched_pct\n", + "0 176549 9506 2021-05-11 4250 72.0\n", + "1 699317 1659 2021-05-29 8317 100.0\n", + "2 656683 7107 2021-05-09 10 0.0\n", + "3 864613 7638 2021-07-05 14483 100.0\n", + "4 964868 9506 2021-04-30 6725 100.0\n", + "5476246 648596 12225 2021-08-13 76 0.0\n", + "5476247 546862 9673 2021-04-13 2308 49.0\n", + "5476248 697262 15297 2021-08-20 18307 63.0\n", + "5476249 384202 16197 2021-04-19 6203 100.0\n", + "5476250 319709 4436 2021-08-15 3921 45.0" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([interactions.head(), interactions.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c5c3ce6c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Interactions dataframe shape: (5476251, 5)\n", + "Unique users in interactions: 962179\n", + "Unique items in interactions: 15706\n" + ] + } + ], + "source": [ + "print(f\"Interactions dataframe shape: {interactions.shape}\")\n", + "print(f\"Unique users in interactions: {interactions[Columns.User].nunique()}\")\n", + "print(f\"Unique items in interactions: {interactions[Columns.Item].nunique()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0214a978", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "min date in interactions: 2021-03-13 00:00:00\n", + "max date in interactions: 2021-08-22 00:00:00\n" + ] + } + ], + "source": [ + "max_date = interactions[Columns.Datetime].max()\n", + "min_date = interactions[Columns.Datetime].min()\n", + "\n", + "print(f\"min date in interactions: {min_date}\")\n", + "print(f\"max date in interactions: {max_date}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7829e796", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + 
"RangeIndex: 5476251 entries, 0 to 5476250\n", + "Data columns (total 5 columns):\n", + " # Column Dtype \n", + "--- ------ ----- \n", + " 0 user_id int64 \n", + " 1 item_id int64 \n", + " 2 datetime datetime64[ns]\n", + " 3 weight int64 \n", + " 4 watched_pct float64 \n", + "dtypes: datetime64[ns](1), float64(1), int64(3)\n", + "memory usage: 208.9 MB\n" + ] + } + ], + "source": [ + "interactions.info()" + ] + }, + { + "cell_type": "markdown", + "id": "57cddf34", + "metadata": {}, + "source": [ + "## Users" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "de5dea16", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idageincomesexkids_flg
0973171age_25_34income_60_90М1
1962099age_18_24income_20_40М0
21047345age_45_54income_40_60Ж0
3721985age_45_54income_20_40Ж0
4704055age_35_44income_60_90Ж0
840192339025age_65_infincome_0_20Ж0
840193983617age_18_24income_20_40Ж1
840194251008NaNNaNNaN0
840195590706NaNNaNЖ0
840196166555age_65_infincome_20_40Ж0
\n", + "
" + ], + "text/plain": [ + " user_id age income sex kids_flg\n", + "0 973171 age_25_34 income_60_90 М 1\n", + "1 962099 age_18_24 income_20_40 М 0\n", + "2 1047345 age_45_54 income_40_60 Ж 0\n", + "3 721985 age_45_54 income_20_40 Ж 0\n", + "4 704055 age_35_44 income_60_90 Ж 0\n", + "840192 339025 age_65_inf income_0_20 Ж 0\n", + "840193 983617 age_18_24 income_20_40 Ж 1\n", + "840194 251008 NaN NaN NaN 0\n", + "840195 590706 NaN NaN Ж 0\n", + "840196 166555 age_65_inf income_20_40 Ж 0" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([users.head(), users.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e4e6d2f5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Users dataframe shape (840197, 5)\n", + "Unique users: 840197\n" + ] + } + ], + "source": [ + "print(f\"Users dataframe shape {users.shape}\")\n", + "print(f\"Unique users: {users['user_id'].nunique()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "98b4ff6c", + "metadata": {}, + "source": [ + "## Items" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "19b43ff0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idcontent_typetitletitle_origrelease_yeargenrescountriesfor_kidsage_ratingstudiosdirectorsactorsdescriptionkeywords
010711filmПоговори с нейHable con ella2002.0драмы, зарубежные, детективы, мелодрамыИспанияNaN16.0NaNПедро АльмодоварАдольфо Фернандес, Ана Фернандес, Дарио Гранди...Мелодрама легендарного Педро Альмодовара «Пого...Поговори, ней, 2002, Испания, друзья, любовь, ...
12508filmГолые перцыSearch Party2014.0зарубежные, приключения, комедииСШАNaN16.0NaNСкот АрмстронгАдам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...Уморительная современная комедия на популярную...Голые, перцы, 2014, США, друзья, свадьбы, прео...
159614538seriesСреди камнейDarklands2019.0драмы, спорт, криминалРоссия0.018.0NaNМарк О’Коннор, Конор МакМахонДэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд...Семнадцатилетний Дэмиен мечтает вырваться за п...Среди, камней, 2019, Россия
159623206seriesГошаNaN2019.0комедииРоссия0.016.0NaNМихаил МироновМкртыч Арзуманян, Виктория РунцоваДобродушный Гоша не может выйти из дома, чтобы...Гоша, 2019, Россия
\n", + "
" + ], + "text/plain": [ + " item_id content_type title title_orig release_year \\\n", + "0 10711 film Поговори с ней Hable con ella 2002.0 \n", + "1 2508 film Голые перцы Search Party 2014.0 \n", + "15961 4538 series Среди камней Darklands 2019.0 \n", + "15962 3206 series Гоша NaN 2019.0 \n", + "\n", + " genres countries for_kids \\\n", + "0 драмы, зарубежные, детективы, мелодрамы Испания NaN \n", + "1 зарубежные, приключения, комедии США NaN \n", + "15961 драмы, спорт, криминал Россия 0.0 \n", + "15962 комедии Россия 0.0 \n", + "\n", + " age_rating studios directors \\\n", + "0 16.0 NaN Педро Альмодовар \n", + "1 16.0 NaN Скот Армстронг \n", + "15961 18.0 NaN Марк О’Коннор, Конор МакМахон \n", + "15962 16.0 NaN Михаил Миронов \n", + "\n", + " actors \\\n", + "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n", + "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n", + "15961 Дэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд... \n", + "15962 Мкртыч Арзуманян, Виктория Рунцова \n", + "\n", + " description \\\n", + "0 Мелодрама легендарного Педро Альмодовара «Пого... \n", + "1 Уморительная современная комедия на популярную... \n", + "15961 Семнадцатилетний Дэмиен мечтает вырваться за п... \n", + "15962 Добродушный Гоша не может выйти из дома, чтобы... \n", + "\n", + " keywords \n", + "0 Поговори, ней, 2002, Испания, друзья, любовь, ... \n", + "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 
\n", + "15961 Среди, камней, 2019, Россия \n", + "15962 Гоша, 2019, Россия " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([items.head(2), items.tail(2)])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8c8fb319", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Items dataframe shape (15963, 14)\n", + "Unique item_id: 15963\n" + ] + } + ], + "source": [ + "print(f\"Items dataframe shape {items.shape}\")\n", + "print(f\"Unique item_id: {items['item_id'].nunique()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b35b460", + "metadata": {}, + "source": [ + "# userkNN model CV" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f60e6ecb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "alignmentgroup": "True", + "hovertemplate": "variable=user_id
datetime=%{x}
value=%{y}", + "legendgroup": "user_id", + "marker": { + "color": "#636efa", + "pattern": { + "shape": "" + } + }, + "name": "user_id", + "offsetgroup": "user_id", + "orientation": "v", + "showlegend": true, + "textposition": "auto", + "type": "bar", + "x": [ + "2021-03-13T00:00:00", + "2021-03-14T00:00:00", + "2021-03-15T00:00:00", + "2021-03-16T00:00:00", + "2021-03-17T00:00:00", + "2021-03-18T00:00:00", + "2021-03-19T00:00:00", + "2021-03-20T00:00:00", + "2021-03-21T00:00:00", + "2021-03-22T00:00:00", + "2021-03-23T00:00:00", + "2021-03-24T00:00:00", + "2021-03-25T00:00:00", + "2021-03-26T00:00:00", + "2021-03-27T00:00:00", + "2021-03-28T00:00:00", + "2021-03-29T00:00:00", + "2021-03-30T00:00:00", + "2021-03-31T00:00:00", + "2021-04-01T00:00:00", + "2021-04-02T00:00:00", + "2021-04-03T00:00:00", + "2021-04-04T00:00:00", + "2021-04-05T00:00:00", + "2021-04-06T00:00:00", + "2021-04-07T00:00:00", + "2021-04-08T00:00:00", + "2021-04-09T00:00:00", + "2021-04-10T00:00:00", + "2021-04-11T00:00:00", + "2021-04-12T00:00:00", + "2021-04-13T00:00:00", + "2021-04-14T00:00:00", + "2021-04-15T00:00:00", + "2021-04-16T00:00:00", + "2021-04-17T00:00:00", + "2021-04-18T00:00:00", + "2021-04-19T00:00:00", + "2021-04-20T00:00:00", + "2021-04-21T00:00:00", + "2021-04-22T00:00:00", + "2021-04-23T00:00:00", + "2021-04-24T00:00:00", + "2021-04-25T00:00:00", + "2021-04-26T00:00:00", + "2021-04-27T00:00:00", + "2021-04-28T00:00:00", + "2021-04-29T00:00:00", + "2021-04-30T00:00:00", + "2021-05-01T00:00:00", + "2021-05-02T00:00:00", + "2021-05-03T00:00:00", + "2021-05-04T00:00:00", + "2021-05-05T00:00:00", + "2021-05-06T00:00:00", + "2021-05-07T00:00:00", + "2021-05-08T00:00:00", + "2021-05-09T00:00:00", + "2021-05-10T00:00:00", + "2021-05-11T00:00:00", + "2021-05-12T00:00:00", + "2021-05-13T00:00:00", + "2021-05-14T00:00:00", + "2021-05-15T00:00:00", + "2021-05-16T00:00:00", + "2021-05-17T00:00:00", + "2021-05-18T00:00:00", + "2021-05-19T00:00:00", + "2021-05-20T00:00:00", + 
"2021-05-21T00:00:00", + "2021-05-22T00:00:00", + "2021-05-23T00:00:00", + "2021-05-24T00:00:00", + "2021-05-25T00:00:00", + "2021-05-26T00:00:00", + "2021-05-27T00:00:00", + "2021-05-28T00:00:00", + "2021-05-29T00:00:00", + "2021-05-30T00:00:00", + "2021-05-31T00:00:00", + "2021-06-01T00:00:00", + "2021-06-02T00:00:00", + "2021-06-03T00:00:00", + "2021-06-04T00:00:00", + "2021-06-05T00:00:00", + "2021-06-06T00:00:00", + "2021-06-07T00:00:00", + "2021-06-08T00:00:00", + "2021-06-09T00:00:00", + "2021-06-10T00:00:00", + "2021-06-11T00:00:00", + "2021-06-12T00:00:00", + "2021-06-13T00:00:00", + "2021-06-14T00:00:00", + "2021-06-15T00:00:00", + "2021-06-16T00:00:00", + "2021-06-17T00:00:00", + "2021-06-18T00:00:00", + "2021-06-19T00:00:00", + "2021-06-20T00:00:00", + "2021-06-21T00:00:00", + "2021-06-22T00:00:00", + "2021-06-23T00:00:00", + "2021-06-24T00:00:00", + "2021-06-25T00:00:00", + "2021-06-26T00:00:00", + "2021-06-27T00:00:00", + "2021-06-28T00:00:00", + "2021-06-29T00:00:00", + "2021-06-30T00:00:00", + "2021-07-01T00:00:00", + "2021-07-02T00:00:00", + "2021-07-03T00:00:00", + "2021-07-04T00:00:00", + "2021-07-05T00:00:00", + "2021-07-06T00:00:00", + "2021-07-07T00:00:00", + "2021-07-08T00:00:00", + "2021-07-09T00:00:00", + "2021-07-10T00:00:00", + "2021-07-11T00:00:00", + "2021-07-12T00:00:00", + "2021-07-13T00:00:00", + "2021-07-14T00:00:00", + "2021-07-15T00:00:00", + "2021-07-16T00:00:00", + "2021-07-17T00:00:00", + "2021-07-18T00:00:00", + "2021-07-19T00:00:00", + "2021-07-20T00:00:00", + "2021-07-21T00:00:00", + "2021-07-22T00:00:00", + "2021-07-23T00:00:00", + "2021-07-24T00:00:00", + "2021-07-25T00:00:00", + "2021-07-26T00:00:00", + "2021-07-27T00:00:00", + "2021-07-28T00:00:00", + "2021-07-29T00:00:00", + "2021-07-30T00:00:00", + "2021-07-31T00:00:00", + "2021-08-01T00:00:00", + "2021-08-02T00:00:00", + "2021-08-03T00:00:00", + "2021-08-04T00:00:00", + "2021-08-05T00:00:00", + "2021-08-06T00:00:00", + "2021-08-07T00:00:00", + "2021-08-08T00:00:00", + 
"2021-08-09T00:00:00", + "2021-08-10T00:00:00", + "2021-08-11T00:00:00", + "2021-08-12T00:00:00", + "2021-08-13T00:00:00", + "2021-08-14T00:00:00", + "2021-08-15T00:00:00", + "2021-08-16T00:00:00", + "2021-08-17T00:00:00", + "2021-08-18T00:00:00", + "2021-08-19T00:00:00", + "2021-08-20T00:00:00", + "2021-08-21T00:00:00", + "2021-08-22T00:00:00" + ], + "xaxis": "x", + "y": [ + 16104, + 15606, + 12363, + 12643, + 12753, + 12788, + 13657, + 15346, + 15560, + 12752, + 13147, + 13435, + 12698, + 13909, + 15657, + 16112, + 12783, + 13101, + 13460, + 12966, + 14084, + 15431, + 15346, + 12642, + 12528, + 13129, + 13827, + 14416, + 15937, + 16046, + 12835, + 12322, + 12451, + 12275, + 13342, + 15464, + 16275, + 14286, + 20420, + 23200, + 21274, + 22127, + 26161, + 28964, + 21625, + 22590, + 21406, + 19987, + 21406, + 23479, + 24767, + 26267, + 25983, + 23941, + 23510, + 23201, + 27550, + 25986, + 27242, + 20957, + 20578, + 20729, + 21152, + 24530, + 24914, + 20960, + 20574, + 21561, + 22712, + 25697, + 27895, + 29978, + 24317, + 23667, + 22529, + 23881, + 24131, + 29035, + 31308, + 26821, + 26587, + 27577, + 28683, + 33150, + 34795, + 37096, + 31402, + 31107, + 32896, + 38964, + 37935, + 38619, + 42125, + 38973, + 35993, + 57686, + 41440, + 42174, + 43679, + 47989, + 39127, + 39693, + 41688, + 38394, + 41428, + 45898, + 48903, + 43301, + 43887, + 67749, + 53900, + 46642, + 48832, + 52812, + 43375, + 41380, + 41163, + 41592, + 40955, + 44798, + 46250, + 42487, + 43764, + 43128, + 43010, + 44878, + 49714, + 54139, + 45541, + 44431, + 44422, + 46313, + 46911, + 50317, + 54378, + 48531, + 49324, + 50267, + 50585, + 53121, + 59499, + 62128, + 53495, + 52181, + 51911, + 51047, + 53745, + 59316, + 61454, + 52794, + 53712, + 55617, + 56497, + 55843, + 61644, + 66546, + 54546, + 54311, + 56789, + 58640, + 60145, + 68834, + 71171 + ], + "yaxis": "y" + } + ], + "layout": { + "barmode": "relative", + "legend": { + "title": { + "text": "variable" + }, + "tracegroupgap": 0 + }, + 
"margin": { + "t": 60 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" 
+ ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], 
+ "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + 
"fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + 
"hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "datetime" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "value" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = px.bar(interactions.groupby(Columns.Datetime)[Columns.User].agg('count'))\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "id": "43f216d0", + "metadata": {}, + "source": [ + "Из графика видны **недельные тенденции** просмотров, поэтому следует fold-ы разделять по 7 дней, но т.к. на семинаре дали \"намек\", что private dataset имеет количество дней, меньшее чем 7. Поэтому фолды будут разбиваться на **5 и 7 дней**" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "07fbdb30", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.to_datetime('23-05-2021', format='%d-%m-%Y').weekday()" + ] + }, + { + "cell_type": "markdown", + "id": "2ff625b2", + "metadata": {}, + "source": [ + "### train test split" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "759ba346", + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_range(\n", + " last_date: pd.Timestamp, \n", + " n_folds: int = 7, \n", + " unit: str = \"W\", \n", + " n_units: int = 1, \n", + " show: bool = True,\n", + "):\n", + " periods = n_folds + 1\n", + " freq = f\"{n_units}{unit}\"\n", + " \n", + " start_date = last_date - pd.Timedelta(n_folds * n_units + n_units, unit=unit) \n", + " \n", + " date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n", + " \n", + " if show:\n", + " print(\n", + " f\"start_date: {start_date}\\n\"\n", + " f\"last_date: {last_date}\\n\"\n", + " f\"periods: {periods}\\n\"\n", + " f\"freq: {freq}\\n\"\n", + " f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\\n\"\n", + " )\n", + " \n", + " return date_range" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "38bfd397", + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG_CV = 
{\n", + " \"cv_v1\": {\n", + " \"n_folds\": 7,\n", + " \"unit\": \"W\",\n", + " \"n_units\": 1,\n", + " },\n", + " \"cv_v2\": {\n", + " \"n_folds\": 7,\n", + " \"unit\": \"D\",\n", + " \"n_units\": 5,\n", + " }, \n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f518e089", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Timestamp('2021-08-22 00:00:00')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "last_date = interactions[Columns.Datetime].max().normalize()\n", + "last_date" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1fd68b9b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "***Folds v1***\n", + "start_date: 2021-07-13 00:00:00\n", + "last_date: 2021-08-22 00:00:00\n", + "periods: 8\n", + "freq: 5D\n", + "Test fold borders: ['2021-07-13' '2021-07-18' '2021-07-23' '2021-07-28' '2021-08-02'\n", + " '2021-08-07' '2021-08-12' '2021-08-17']\n", + "\n" + ] + } + ], + "source": [ + "print(\"***Folds v1***\")\n", + "date_range_v1 = create_data_range(\n", + " last_date, \n", + " n_folds=CONFIG_CV[\"cv_v2\"][\"n_folds\"], \n", + " unit=CONFIG_CV[\"cv_v2\"][\"unit\"], \n", + " n_units=CONFIG_CV[\"cv_v2\"][\"n_units\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "efc59555", + "metadata": {}, + "source": [ + "**генерируем фолды** " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9fae43f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Real number of folds: 7\n" + ] + } + ], + "source": [ + "cv_v1 = TimeRangeSplitter(\n", + " date_range=date_range_v1,\n", + " filter_already_seen=True,\n", + " filter_cold_items=True,\n", + " filter_cold_users=True,\n", + ")\n", + "print(f\"Real number of folds: {cv_v1.get_n_splits(Interactions(interactions))}\")\n", + "\n", + "CV = [cv_v1]" + ] + }, + { + "cell_type": 
"markdown", + "id": "e15a83a7", + "metadata": {}, + "source": [ + "**Формируем метрики**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8f7742c6", + "metadata": {}, + "outputs": [], + "source": [ + "metrics = {\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@10\": MAP(k=10),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b21a1ecf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'cosine_userknn_K30': ,\n", + " 'tfidf_userknn_K30': ,\n", + " 'bm25_userknn_K30': ,\n", + " 'cosine_userknn_K40': ,\n", + " 'tfidf_userknn_K40': ,\n", + " 'bm25_userknn_K40': }" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "K = [30, 40]\n", + "models = dict()\n", + "\n", + "for k in K:\n", + " models[f\"cosine_userknn_K{k}\"] = CosineRecommender(K=k)\n", + " models[f\"tfidf_userknn_K{k}\"] = TFIDFRecommender(K=k)\n", + " models[f\"bm25_userknn_K{k}\"] = BM25Recommender(K=k)\n", + "\n", + "models" + ] + }, + { + "cell_type": "markdown", + "id": "0103149a", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e78b8221", + "metadata": {}, + "outputs": [], + "source": [ + "N_USERS = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "50dcff0b", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "results = []\n", + "\n", + "for idx, cv in enumerate(CV):\n", + " print(f\"\\n CV version {idx}\")\n", + " fold_iterator = cv.split(Interactions(interactions), collect_fold_stats=True)\n", + "\n", + " for i_fold, (train_ids, test_ids, fold_info) in enumerate(fold_iterator):\n", + " print(f\"\\n==================== Fold {i_fold}\")\n", + " pprint(fold_info)\n", + "\n", + " df_train = interactions.iloc[train_ids].copy()\n", + " 
df_test = interactions.iloc[test_ids][Columns.UserItem].copy()\n", + "\n", + " catalog = df_train[Columns.Item].unique()\n", + "\n", + " for model_name, model in models.items():\n", + " userknn_model = UserKnn(model=model, N_users=N_USERS, use_weight_idf=True)\n", + " userknn_model.fit(df_train)\n", + "\n", + " if 'bm25' in model_name:\n", + " recos = userknn_model.predict(df_test, bmp25=True)\n", + " else:\n", + " recos = userknn_model.predict(df_test)\n", + "\n", + " metric_values = calc_metrics(\n", + " metrics,\n", + " reco=recos,\n", + " interactions=df_test,\n", + " prev_interactions=df_train,\n", + " catalog=catalog,\n", + " )\n", + "\n", + " full_model_name = f\"{model_name}_cv-{idx}\"\n", + " fold = {\"fold\": i_fold, \"model\": full_model_name}\n", + " fold.update(metric_values)\n", + " results.append(fold)" + ] + }, + { + "cell_type": "markdown", + "id": "708ec5c2", + "metadata": {}, + "source": [ + "Обучение работало больше 10 часов. При перезапуске ноутбука эта ячейка была случайно запущена и остановлена, поэтому она завершилась с ошибкой; вывод с ошибкой убран для наглядности." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "d7e2ffa7", + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
foldmodelprec@10recall@10MAP@10noveltyserendipity
00cosine_userknn_K30_cv-00.0035570.0211280.0036958.3314910.000040
10tfidf_userknn_K30_cv-00.0064390.0391020.0073358.1550510.000048
20bm25_userknn_K30_cv-00.0025930.0134940.0025319.3984670.000081
30cosine_userknn_K40_cv-00.0032820.0193230.0034018.5615230.000043
40tfidf_userknn_K40_cv-00.0061780.0374580.0069578.3004040.000052
50bm25_userknn_K40_cv-00.0022410.0112550.0022109.6755330.000081
61cosine_userknn_K30_cv-00.0035050.0200020.0035808.3982480.000046
71tfidf_userknn_K30_cv-00.0063280.0368440.0070228.2401330.000058
81bm25_userknn_K30_cv-00.0027220.0138560.0026589.4846920.000088
91cosine_userknn_K40_cv-00.0032450.0183680.0033058.6269060.000047
101tfidf_userknn_K40_cv-00.0061500.0359640.0069168.3779880.000061
111bm25_userknn_K40_cv-00.0024060.0120670.0023939.7564580.000086
122cosine_userknn_K30_cv-00.0032610.0184980.0032958.4392630.000047
132tfidf_userknn_K30_cv-00.0059400.0342330.0064798.2623670.000059
142bm25_userknn_K30_cv-00.0027200.0134220.0025309.5356310.000091
152cosine_userknn_K40_cv-00.0030450.0170860.0031008.6615850.000050
162tfidf_userknn_K40_cv-00.0059140.0340710.0064398.3966180.000063
172bm25_userknn_K40_cv-00.0024040.0116380.0022319.7991190.000090
183cosine_userknn_K30_cv-00.0032770.0187860.0033958.4449860.000045
193tfidf_userknn_K30_cv-00.0060230.0341710.0063288.2765030.000059
203bm25_userknn_K30_cv-00.0026200.0127620.0024979.5609840.000091
213cosine_userknn_K40_cv-00.0030760.0175120.0031738.6581500.000045
223tfidf_userknn_K40_cv-00.0059190.0333680.0062538.3991690.000062
233bm25_userknn_K40_cv-00.0023370.0112730.0022539.8163250.000089
244cosine_userknn_K30_cv-00.0031180.0180640.0031578.4858990.000042
254tfidf_userknn_K30_cv-00.0059110.0336260.0063968.2824280.000059
264bm25_userknn_K30_cv-00.0025370.0123680.0024709.5996450.000086
274cosine_userknn_K40_cv-00.0028720.0165090.0028838.7119840.000043
284tfidf_userknn_K40_cv-00.0057930.0330280.0062618.4166800.000062
294bm25_userknn_K40_cv-00.0022130.0108600.0021799.8662010.000085
305cosine_userknn_K30_cv-00.0030030.0162520.0028998.4989680.000043
315tfidf_userknn_K30_cv-00.0055270.0309420.0058238.3252730.000057
325bm25_userknn_K30_cv-00.0025970.0122630.0023869.6469570.000100
335cosine_userknn_K40_cv-00.0027650.0147130.0026618.7175590.000047
345tfidf_userknn_K40_cv-00.0055450.0308920.0058178.4540910.000059
355bm25_userknn_K40_cv-00.0023020.0107770.0021359.9140420.000100
366cosine_userknn_K30_cv-00.0029630.0165320.0028878.5638090.000050
376tfidf_userknn_K30_cv-00.0053300.0307170.0057638.3662590.000064
386bm25_userknn_K30_cv-00.0025710.0126910.0024789.7150970.000100
396cosine_userknn_K40_cv-00.0027690.0154480.0026758.7750580.000051
406tfidf_userknn_K40_cv-00.0052840.0304180.0056978.4884730.000066
416bm25_userknn_K40_cv-00.0023400.0112780.0022089.9646640.000099
\n", + "
" + ], + "text/plain": [ + " fold model prec@10 recall@10 MAP@10 novelty \\\n", + "0 0 cosine_userknn_K30_cv-0 0.003557 0.021128 0.003695 8.331491 \n", + "1 0 tfidf_userknn_K30_cv-0 0.006439 0.039102 0.007335 8.155051 \n", + "2 0 bm25_userknn_K30_cv-0 0.002593 0.013494 0.002531 9.398467 \n", + "3 0 cosine_userknn_K40_cv-0 0.003282 0.019323 0.003401 8.561523 \n", + "4 0 tfidf_userknn_K40_cv-0 0.006178 0.037458 0.006957 8.300404 \n", + "5 0 bm25_userknn_K40_cv-0 0.002241 0.011255 0.002210 9.675533 \n", + "6 1 cosine_userknn_K30_cv-0 0.003505 0.020002 0.003580 8.398248 \n", + "7 1 tfidf_userknn_K30_cv-0 0.006328 0.036844 0.007022 8.240133 \n", + "8 1 bm25_userknn_K30_cv-0 0.002722 0.013856 0.002658 9.484692 \n", + "9 1 cosine_userknn_K40_cv-0 0.003245 0.018368 0.003305 8.626906 \n", + "10 1 tfidf_userknn_K40_cv-0 0.006150 0.035964 0.006916 8.377988 \n", + "11 1 bm25_userknn_K40_cv-0 0.002406 0.012067 0.002393 9.756458 \n", + "12 2 cosine_userknn_K30_cv-0 0.003261 0.018498 0.003295 8.439263 \n", + "13 2 tfidf_userknn_K30_cv-0 0.005940 0.034233 0.006479 8.262367 \n", + "14 2 bm25_userknn_K30_cv-0 0.002720 0.013422 0.002530 9.535631 \n", + "15 2 cosine_userknn_K40_cv-0 0.003045 0.017086 0.003100 8.661585 \n", + "16 2 tfidf_userknn_K40_cv-0 0.005914 0.034071 0.006439 8.396618 \n", + "17 2 bm25_userknn_K40_cv-0 0.002404 0.011638 0.002231 9.799119 \n", + "18 3 cosine_userknn_K30_cv-0 0.003277 0.018786 0.003395 8.444986 \n", + "19 3 tfidf_userknn_K30_cv-0 0.006023 0.034171 0.006328 8.276503 \n", + "20 3 bm25_userknn_K30_cv-0 0.002620 0.012762 0.002497 9.560984 \n", + "21 3 cosine_userknn_K40_cv-0 0.003076 0.017512 0.003173 8.658150 \n", + "22 3 tfidf_userknn_K40_cv-0 0.005919 0.033368 0.006253 8.399169 \n", + "23 3 bm25_userknn_K40_cv-0 0.002337 0.011273 0.002253 9.816325 \n", + "24 4 cosine_userknn_K30_cv-0 0.003118 0.018064 0.003157 8.485899 \n", + "25 4 tfidf_userknn_K30_cv-0 0.005911 0.033626 0.006396 8.282428 \n", + "26 4 bm25_userknn_K30_cv-0 0.002537 0.012368 0.002470 
9.599645 \n", + "27 4 cosine_userknn_K40_cv-0 0.002872 0.016509 0.002883 8.711984 \n", + "28 4 tfidf_userknn_K40_cv-0 0.005793 0.033028 0.006261 8.416680 \n", + "29 4 bm25_userknn_K40_cv-0 0.002213 0.010860 0.002179 9.866201 \n", + "30 5 cosine_userknn_K30_cv-0 0.003003 0.016252 0.002899 8.498968 \n", + "31 5 tfidf_userknn_K30_cv-0 0.005527 0.030942 0.005823 8.325273 \n", + "32 5 bm25_userknn_K30_cv-0 0.002597 0.012263 0.002386 9.646957 \n", + "33 5 cosine_userknn_K40_cv-0 0.002765 0.014713 0.002661 8.717559 \n", + "34 5 tfidf_userknn_K40_cv-0 0.005545 0.030892 0.005817 8.454091 \n", + "35 5 bm25_userknn_K40_cv-0 0.002302 0.010777 0.002135 9.914042 \n", + "36 6 cosine_userknn_K30_cv-0 0.002963 0.016532 0.002887 8.563809 \n", + "37 6 tfidf_userknn_K30_cv-0 0.005330 0.030717 0.005763 8.366259 \n", + "38 6 bm25_userknn_K30_cv-0 0.002571 0.012691 0.002478 9.715097 \n", + "39 6 cosine_userknn_K40_cv-0 0.002769 0.015448 0.002675 8.775058 \n", + "40 6 tfidf_userknn_K40_cv-0 0.005284 0.030418 0.005697 8.488473 \n", + "41 6 bm25_userknn_K40_cv-0 0.002340 0.011278 0.002208 9.964664 \n", + "\n", + " serendipity \n", + "0 0.000040 \n", + "1 0.000048 \n", + "2 0.000081 \n", + "3 0.000043 \n", + "4 0.000052 \n", + "5 0.000081 \n", + "6 0.000046 \n", + "7 0.000058 \n", + "8 0.000088 \n", + "9 0.000047 \n", + "10 0.000061 \n", + "11 0.000086 \n", + "12 0.000047 \n", + "13 0.000059 \n", + "14 0.000091 \n", + "15 0.000050 \n", + "16 0.000063 \n", + "17 0.000090 \n", + "18 0.000045 \n", + "19 0.000059 \n", + "20 0.000091 \n", + "21 0.000045 \n", + "22 0.000062 \n", + "23 0.000089 \n", + "24 0.000042 \n", + "25 0.000059 \n", + "26 0.000086 \n", + "27 0.000043 \n", + "28 0.000062 \n", + "29 0.000085 \n", + "30 0.000043 \n", + "31 0.000057 \n", + "32 0.000100 \n", + "33 0.000047 \n", + "34 0.000059 \n", + "35 0.000100 \n", + "36 0.000050 \n", + "37 0.000064 \n", + "38 0.000100 \n", + "39 0.000051 \n", + "40 0.000066 \n", + "41 0.000099 " + ] + }, + "execution_count": 46, + "metadata": 
{}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics = pd.DataFrame(results)\n", + "df_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "a0334b9a", + "metadata": {}, + "outputs": [], + "source": [ + "df_metrics.to_pickle(\"../data/hw_3/df_metrics.pickle\")" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "446530ce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
foldprec@10recall@10MAP@10noveltyserendipity
model
bm25_userknn_K30_cv-03.00.0026230.0129800.0025079.5630680.000091
bm25_userknn_K40_cv-03.00.0023200.0113070.0022309.8274770.000090
cosine_userknn_K30_cv-03.00.0032410.0184660.0032728.4518090.000045
cosine_userknn_K40_cv-03.00.0030080.0169940.0030288.6732520.000047
tfidf_userknn_K30_cv-03.00.0059280.0342340.0064498.2725730.000058
tfidf_userknn_K40_cv-03.00.0058260.0336000.0063348.4047750.000061
\n", + "
" + ], + "text/plain": [ + " fold prec@10 recall@10 MAP@10 novelty \\\n", + "model \n", + "bm25_userknn_K30_cv-0 3.0 0.002623 0.012980 0.002507 9.563068 \n", + "bm25_userknn_K40_cv-0 3.0 0.002320 0.011307 0.002230 9.827477 \n", + "cosine_userknn_K30_cv-0 3.0 0.003241 0.018466 0.003272 8.451809 \n", + "cosine_userknn_K40_cv-0 3.0 0.003008 0.016994 0.003028 8.673252 \n", + "tfidf_userknn_K30_cv-0 3.0 0.005928 0.034234 0.006449 8.272573 \n", + "tfidf_userknn_K40_cv-0 3.0 0.005826 0.033600 0.006334 8.404775 \n", + "\n", + " serendipity \n", + "model \n", + "bm25_userknn_K30_cv-0 0.000091 \n", + "bm25_userknn_K40_cv-0 0.000090 \n", + "cosine_userknn_K30_cv-0 0.000045 \n", + "cosine_userknn_K40_cv-0 0.000047 \n", + "tfidf_userknn_K30_cv-0 0.000058 \n", + "tfidf_userknn_K40_cv-0 0.000061 " + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics.groupby('model').mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "5fb9ba9f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
prec@10recall@10MAP@10noveltyserendipity
model
bm25_userknn_K30_cv-00.0000720.0006120.0000830.1044680.000007
bm25_userknn_K40_cv-00.0000740.0004420.0000810.0973590.000007
cosine_userknn_K30_cv-00.0002310.0017490.0003140.0746990.000003
cosine_userknn_K40_cv-00.0002130.0016030.0002950.0693100.000003
tfidf_userknn_K30_cv-00.0003980.0030030.0005770.0666270.000005
tfidf_userknn_K40_cv-00.0003210.0025340.0004870.0595650.000004
\n", + "
" + ], + "text/plain": [ + " prec@10 recall@10 MAP@10 novelty serendipity\n", + "model \n", + "bm25_userknn_K30_cv-0 0.000072 0.000612 0.000083 0.104468 0.000007\n", + "bm25_userknn_K40_cv-0 0.000074 0.000442 0.000081 0.097359 0.000007\n", + "cosine_userknn_K30_cv-0 0.000231 0.001749 0.000314 0.074699 0.000003\n", + "cosine_userknn_K40_cv-0 0.000213 0.001603 0.000295 0.069310 0.000003\n", + "tfidf_userknn_K30_cv-0 0.000398 0.003003 0.000577 0.066627 0.000005\n", + "tfidf_userknn_K40_cv-0 0.000321 0.002534 0.000487 0.059565 0.000004" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics.groupby('model').std()[metrics.keys()]" + ] + }, + { + "cell_type": "markdown", + "id": "41828ee5", + "metadata": {}, + "source": [ + "По **offline**-метрикам лучше всего себя показывает модель TFIDFRecommender,\n", + "поэтому далее для TFIDFRecommender подбирается K" + ] + }, + { + "cell_type": "markdown", + "id": "7a8a0a41", + "metadata": {}, + "source": [ + "# Подбор оптимального K для TFIDFRecommender" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "1e91892d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tfidf_userknn_K50': ,\n", + " 'tfidf_userknn_K60': ,\n", + " 'tfidf_userknn_K70': }" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "N_USERS = 50\n", + "\n", + "# Т.к. 
метрики для К 30 и 40 уже есть\n", + "K = [k for k in range(50, 71, 10)]\n", + "models = dict()\n", + "\n", + "for k in K:\n", + " models[f\"tfidf_userknn_K{k}\"] = TFIDFRecommender(K=k)\n", + "models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7c2c43b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==================== Fold 0\n", + "{'End date': Timestamp('2021-07-18 00:00:00', freq='5D'),\n", + " 'Start date': Timestamp('2021-07-13 00:00:00', freq='5D'),\n", + " 'Test': 156580,\n", + " 'Test items': 5793,\n", + " 'Test users': 68150,\n", + " 'Train': 3281612,\n", + " 'Train items': 14754,\n", + " 'Train users': 652905}\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "211234f034a54bae86b94dff33b9f5c4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/652905 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idlast_watch_dttotal_durwatched_pct
017654995062021-05-11425072.0
169931716592021-05-298317100.0
265668371072021-05-09100.0
386461376382021-07-0514483100.0
496486895062021-04-306725100.0
5476246648596122252021-08-13760.0
547624754686296732021-04-13230849.0
5476248697262152972021-08-201830763.0
5476249384202161972021-04-196203100.0
547625031970944362021-08-15392145.0
\n", + "" + ], + "text/plain": [ + " user_id item_id last_watch_dt total_dur watched_pct\n", + "0 176549 9506 2021-05-11 4250 72.0\n", + "1 699317 1659 2021-05-29 8317 100.0\n", + "2 656683 7107 2021-05-09 10 0.0\n", + "3 864613 7638 2021-07-05 14483 100.0\n", + "4 964868 9506 2021-04-30 6725 100.0\n", + "5476246 648596 12225 2021-08-13 76 0.0\n", + "5476247 546862 9673 2021-04-13 2308 49.0\n", + "5476248 697262 15297 2021-08-20 18307 63.0\n", + "5476249 384202 16197 2021-04-19 6203 100.0\n", + "5476250 319709 4436 2021-08-15 3921 45.0" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([interactions.head(), interactions.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "dc4d9fd7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(962179,)" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions['user_id'].unique().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "b7861d19", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(961833, 1.0),\n", + " (961849, 1.0),\n", + " (961857, 1.0),\n", + " (961871, 1.0),\n", + " (961873, 1.0),\n", + " (961876, 1.0),\n", + " (961887, 1.0),\n", + " (961907, 1.0),\n", + " (961910, 1.0),\n", + " (961912, 1.0)]" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import dill\n", + "\n", + "with open('../service/weights/userKNN/userknn_tfidf_k30.dill', 'rb') as f:\n", + " userknn = dill.load(f)\n", + "\n", + "userknn.similar_items(962178, 10)" + ] + }, + { + "cell_type": "markdown", + "id": "1905033a", + "metadata": {}, + "source": [ + "# Popular Model" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "2df74dba", + "metadata": {}, + "outputs": [], + "source": [ + "from rectools.models import PopularModel\n", + "from 
rectools.dataset import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "6ba37a73", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Timestamp('2021-08-22 00:00:00')" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max_date = interactions[Columns.Datetime].max().normalize()\n", + "max_date" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "901353f9", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "train = interactions[[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]][\n", + " interactions[Columns.Datetime] < max_date - pd.Timedelta(5, \"D\")]\n", + "\n", + "test = interactions[[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]][\n", + " interactions[Columns.Datetime] >= max_date - pd.Timedelta(5, \"D\")]\n", + "\n", + "dataset_train = Dataset.construct(train)" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "id": "f08e3579", + "metadata": {}, + "outputs": [], + "source": [ + "popilarity_models = {\n", + " \"popular\": PopularModel(),\n", + " \"popular_mw\": PopularModel(popularity=\"mean_weight\")\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "id": "03c3bfb6", + "metadata": {}, + "outputs": [], + "source": [ + "popilarity_models[\"popular\"].fit(dataset_train)\n", + "popilarity_models[\"popular_mw\"].fit(dataset_train);" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "id": "0d7de49e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 24, 20, 31, 15, 167, 81, 89, 135, 355, 116])" + ] + }, + "execution_count": 146, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "popilarity_models[\"popular\"].popularity_list[0][:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "id": "05ff208d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"array([11363, 11681, 12841, 13017, 2069, 13691, 13552, 13397, 11774,\n", + " 12913])" + ] + }, + "execution_count": 147, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "popilarity_models[\"popular_mw\"].popularity_list[0][:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 148, + "id": "00ef735c", + "metadata": {}, + "outputs": [], + "source": [ + "pecos_pop = popilarity_models[\"popular\"].recommend(\n", + " users=test[Columns.User].unique(),\n", + " dataset=dataset,\n", + " k=100,\n", + " filter_viewed=False,\n", + ")\n", + "\n", + "pecos_pop_mw = popilarity_models[\"popular_mw\"].recommend(\n", + " users=test[Columns.User].unique(),\n", + " dataset=dataset,\n", + " k=100,\n", + " filter_viewed=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 152, + "id": "b302db55", + "metadata": {}, + "outputs": [], + "source": [ + "metrics = {\n", + " \"prec@5\": Precision(k=5),\n", + " \"recall@5\": Recall(k=5),\n", + " \"MAP@5\": MAP(k=5),\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@20\": MAP(k=20),\n", + " \"prec@20\": Precision(k=20),\n", + " \"recall@20\": Recall(k=20),\n", + " \"MAP@100\": MAP(k=100),\n", + " \"prec@100\": Precision(k=100),\n", + " \"recall@100\": Recall(k=100),\n", + " \"MAP@100\": MAP(k=100),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}\n", + "catalog = train[Columns.Item].unique()\n", + "metric_values_pop = calc_metrics(metrics, pecos_pop, test, train, catalog)\n", + "metric_values_pop_mean_weight = calc_metrics(metrics, pecos_pop_mw, test, train, catalog)" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "id": "9631093b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'prec@5': 0.0017855613317256697,\n", + " 'recall@5': 0.004623809755660008,\n", + " 'prec@10': 0.0011648975773029461,\n", + " 'recall@10': 0.005682095875283048,\n", + " 'prec@20': 
0.0010502526799891945,\n", + " 'recall@20': 0.00880186008464912,\n", + " 'prec@100': 0.003247020220987923,\n", + " 'recall@100': 0.16609031082955295,\n", + " 'MAP@5': 0.0013179725619140792,\n", + " 'MAP@20': 0.0016695313583723814,\n", + " 'MAP@100': 0.005578924867474493,\n", + " 'novelty': 9.976033936531364,\n", + " 'serendipity': 1.2752762676592953e-05}" + ] + }, + "execution_count": 153, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metric_values_pop" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "id": "5d55b781", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'prec@5': 9.09252633867684e-05,\n", + " 'recall@5': 0.00014799438063171262,\n", + " 'prec@10': 4.612151041357817e-05,\n", + " 'recall@10': 0.00015458316783365238,\n", + " 'prec@20': 2.635514880775895e-05,\n", + " 'recall@20': 0.00016946607539568094,\n", + " 'prec@100': 0.00015147621777259455,\n", + " 'recall@100': 0.0065476971391510656,\n", + " 'MAP@5': 3.0257754846536496e-05,\n", + " 'MAP@20': 3.1771198360212185e-05,\n", + " 'MAP@100': 0.00011355765992119742,\n", + " 'novelty': 17.423655787689828,\n", + " 'serendipity': 1.8991632826477633e-06}" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metric_values_pop_mean_weight" + ] + }, + { + "cell_type": "markdown", + "id": "e5a4a011", + "metadata": {}, + "source": [ + "**На офлайн метриках выигрывает обычная модель по популярному**" + ] + }, + { + "cell_type": "markdown", + "id": "5875fab7", + "metadata": {}, + "source": [ + "# Save item_idf data" + ] + }, + { + "cell_type": "markdown", + "id": "6589996f", + "metadata": {}, + "source": [ + "Создаем датасет со взвешенными item-ами по механизму idf для использования в будущем" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "d62cabb9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexidf
095067.150811
116598.524953
271075.821207
376388.407093
466867.778734
.........
15701783314.822785
15702912514.822785
157031006414.822785
157041301914.822785
157051054214.822785
\n", + "

15706 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " index idf\n", + "0 9506 7.150811\n", + "1 1659 8.524953\n", + "2 7107 5.821207\n", + "3 7638 8.407093\n", + "4 6686 7.778734\n", + "... ... ...\n", + "15701 7833 14.822785\n", + "15702 9125 14.822785\n", + "15703 10064 14.822785\n", + "15704 13019 14.822785\n", + "15705 10542 14.822785\n", + "\n", + "[15706 rows x 2 columns]" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item_cnt = Counter(interactions['item_id'].values)\n", + "item_idf = pd.DataFrame.from_dict(item_cnt, orient='index', columns=['doc_freq']).reset_index()\n", + "n = interactions.shape[0]\n", + "item_idf['idf'] = item_idf['doc_freq'].apply(lambda x: np.log((1 + n) / (1 + x) + 1))\n", + "del item_idf['doc_freq']\n", + "item_idf" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "7da47dfc", + "metadata": {}, + "outputs": [], + "source": [ + "item_idf = item_idf.sort_values(\"idf\", ascending=False)\n", + "item_idf.to_csv('../data/kion_train/items_idf.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdce2b60", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/HW-3.1.ipynb b/notebooks/HW-3.1.ipynb new file mode 100644 index 00000000..c05a3b71 --- /dev/null +++ b/notebooks/HW-3.1.ipynb @@ -0,0 +1,4027 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "398a86d9", + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "import warnings\n", + 
"warnings.filterwarnings(\"ignore\")\n", + "\n", + "import sys\n", + "sys.path.append('../')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8dbe6bf0", + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.express as px\n", + "import numpy as np\n", + "import pandas as pd\n", + "import scipy as sp\n", + "import requests\n", + "from tqdm.auto import tqdm\n", + "from scipy.stats import mode\n", + "from implicit.nearest_neighbours import CosineRecommender, TFIDFRecommender, BM25Recommender\n", + "from rectools import Columns\n", + "from rectools.model_selection import TimeRangeSplitter\n", + "from rectools.metrics import Precision, Recall, MAP, MeanInvUserFreq, Serendipity, calc_metrics\n", + "from rectools.dataset.interactions import Interactions\n", + "\n", + "from service.utils.user_knn import UserKnn" + ] + }, + { + "cell_type": "markdown", + "id": "b1baa79f", + "metadata": {}, + "source": [ + "# Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f2a9e540", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((5476251, 5), (840197, 5), (15963, 14))" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", + "users = pd.read_csv('../data/kion_train/users.csv')\n", + "items = pd.read_csv('../data/kion_train/items.csv')\n", + "\n", + "interactions.shape, users.shape, items.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "456d25f4", + "metadata": {}, + "outputs": [], + "source": [ + "interactions.rename(\n", + " columns={\n", + " 'last_watch_dt': Columns.Datetime,\n", + " 'total_dur': Columns.Weight\n", + " }, \n", + " inplace=True) \n", + "\n", + "interactions[Columns.Datetime] = pd.to_datetime(interactions[Columns.Datetime])" + ] + }, + { + "cell_type": "markdown", + "id": "6f7b9b0c", + "metadata": {}, + "source": [ + "## Intersection" + ] + 
}, + { + "cell_type": "code", + "execution_count": 5, + "id": "7c9c0c94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_iddatetimeweightwatched_pct
017654995062021-05-11425072.0
169931716592021-05-298317100.0
265668371072021-05-09100.0
386461376382021-07-0514483100.0
496486895062021-04-306725100.0
5476246648596122252021-08-13760.0
547624754686296732021-04-13230849.0
5476248697262152972021-08-201830763.0
5476249384202161972021-04-196203100.0
547625031970944362021-08-15392145.0
\n", + "
" + ], + "text/plain": [ + " user_id item_id datetime weight watched_pct\n", + "0 176549 9506 2021-05-11 4250 72.0\n", + "1 699317 1659 2021-05-29 8317 100.0\n", + "2 656683 7107 2021-05-09 10 0.0\n", + "3 864613 7638 2021-07-05 14483 100.0\n", + "4 964868 9506 2021-04-30 6725 100.0\n", + "5476246 648596 12225 2021-08-13 76 0.0\n", + "5476247 546862 9673 2021-04-13 2308 49.0\n", + "5476248 697262 15297 2021-08-20 18307 63.0\n", + "5476249 384202 16197 2021-04-19 6203 100.0\n", + "5476250 319709 4436 2021-08-15 3921 45.0" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([interactions.head(), interactions.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c5c3ce6c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Interactions dataframe shape: (5476251, 5)\n", + "Unique users in interactions: 962179\n", + "Unique items in interactions: 15706\n" + ] + } + ], + "source": [ + "print(f\"Interactions dataframe shape: {interactions.shape}\")\n", + "print(f\"Unique users in interactions: {interactions[Columns.User].nunique()}\")\n", + "print(f\"Unique items in interactions: {interactions[Columns.Item].nunique()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0214a978", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "min date in interactions: 2021-03-13 00:00:00\n", + "max date in interactions: 2021-08-22 00:00:00\n" + ] + } + ], + "source": [ + "max_date = interactions[Columns.Datetime].max()\n", + "min_date = interactions[Columns.Datetime].min()\n", + "\n", + "print(f\"min date in interactions: {min_date}\")\n", + "print(f\"max date in interactions: {max_date}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7829e796", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + 
"RangeIndex: 5476251 entries, 0 to 5476250\n", + "Data columns (total 5 columns):\n", + " # Column Dtype \n", + "--- ------ ----- \n", + " 0 user_id int64 \n", + " 1 item_id int64 \n", + " 2 datetime datetime64[ns]\n", + " 3 weight int64 \n", + " 4 watched_pct float64 \n", + "dtypes: datetime64[ns](1), float64(1), int64(3)\n", + "memory usage: 208.9 MB\n" + ] + } + ], + "source": [ + "interactions.info()" + ] + }, + { + "cell_type": "markdown", + "id": "57cddf34", + "metadata": {}, + "source": [ + "## Users" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "de5dea16", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idageincomesexkids_flg
0973171age_25_34income_60_90М1
1962099age_18_24income_20_40М0
21047345age_45_54income_40_60Ж0
3721985age_45_54income_20_40Ж0
4704055age_35_44income_60_90Ж0
840192339025age_65_infincome_0_20Ж0
840193983617age_18_24income_20_40Ж1
840194251008NaNNaNNaN0
840195590706NaNNaNЖ0
840196166555age_65_infincome_20_40Ж0
\n", + "
" + ], + "text/plain": [ + " user_id age income sex kids_flg\n", + "0 973171 age_25_34 income_60_90 М 1\n", + "1 962099 age_18_24 income_20_40 М 0\n", + "2 1047345 age_45_54 income_40_60 Ж 0\n", + "3 721985 age_45_54 income_20_40 Ж 0\n", + "4 704055 age_35_44 income_60_90 Ж 0\n", + "840192 339025 age_65_inf income_0_20 Ж 0\n", + "840193 983617 age_18_24 income_20_40 Ж 1\n", + "840194 251008 NaN NaN NaN 0\n", + "840195 590706 NaN NaN Ж 0\n", + "840196 166555 age_65_inf income_20_40 Ж 0" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([users.head(), users.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e4e6d2f5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Users dataframe shape (840197, 5)\n", + "Unique users: 840197\n" + ] + } + ], + "source": [ + "print(f\"Users dataframe shape {users.shape}\")\n", + "print(f\"Unique users: {users['user_id'].nunique()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "98b4ff6c", + "metadata": {}, + "source": [ + "## Items" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "19b43ff0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idcontent_typetitletitle_origrelease_yeargenrescountriesfor_kidsage_ratingstudiosdirectorsactorsdescriptionkeywords
010711filmПоговори с нейHable con ella2002.0драмы, зарубежные, детективы, мелодрамыИспанияNaN16.0NaNПедро АльмодоварАдольфо Фернандес, Ана Фернандес, Дарио Гранди...Мелодрама легендарного Педро Альмодовара «Пого...Поговори, ней, 2002, Испания, друзья, любовь, ...
12508filmГолые перцыSearch Party2014.0зарубежные, приключения, комедииСШАNaN16.0NaNСкот АрмстронгАдам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...Уморительная современная комедия на популярную...Голые, перцы, 2014, США, друзья, свадьбы, прео...
159614538seriesСреди камнейDarklands2019.0драмы, спорт, криминалРоссия0.018.0NaNМарк О’Коннор, Конор МакМахонДэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд...Семнадцатилетний Дэмиен мечтает вырваться за п...Среди, камней, 2019, Россия
159623206seriesГошаNaN2019.0комедииРоссия0.016.0NaNМихаил МироновМкртыч Арзуманян, Виктория РунцоваДобродушный Гоша не может выйти из дома, чтобы...Гоша, 2019, Россия
\n", + "
" + ], + "text/plain": [ + " item_id content_type title title_orig release_year \\\n", + "0 10711 film Поговори с ней Hable con ella 2002.0 \n", + "1 2508 film Голые перцы Search Party 2014.0 \n", + "15961 4538 series Среди камней Darklands 2019.0 \n", + "15962 3206 series Гоша NaN 2019.0 \n", + "\n", + " genres countries for_kids \\\n", + "0 драмы, зарубежные, детективы, мелодрамы Испания NaN \n", + "1 зарубежные, приключения, комедии США NaN \n", + "15961 драмы, спорт, криминал Россия 0.0 \n", + "15962 комедии Россия 0.0 \n", + "\n", + " age_rating studios directors \\\n", + "0 16.0 NaN Педро Альмодовар \n", + "1 16.0 NaN Скот Армстронг \n", + "15961 18.0 NaN Марк О’Коннор, Конор МакМахон \n", + "15962 16.0 NaN Михаил Миронов \n", + "\n", + " actors \\\n", + "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n", + "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n", + "15961 Дэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд... \n", + "15962 Мкртыч Арзуманян, Виктория Рунцова \n", + "\n", + " description \\\n", + "0 Мелодрама легендарного Педро Альмодовара «Пого... \n", + "1 Уморительная современная комедия на популярную... \n", + "15961 Семнадцатилетний Дэмиен мечтает вырваться за п... \n", + "15962 Добродушный Гоша не может выйти из дома, чтобы... \n", + "\n", + " keywords \n", + "0 Поговори, ней, 2002, Испания, друзья, любовь, ... \n", + "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 
\n", + "15961 Среди, камней, 2019, Россия \n", + "15962 Гоша, 2019, Россия " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([items.head(2), items.tail(2)])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8c8fb319", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Items dataframe shape (15963, 14)\n", + "Unique item_id: 15963\n" + ] + } + ], + "source": [ + "print(f\"Items dataframe shape {items.shape}\")\n", + "print(f\"Unique item_id: {items['item_id'].nunique()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b35b460", + "metadata": {}, + "source": [ + "# userkNN model CV" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f60e6ecb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "alignmentgroup": "True", + "hovertemplate": "variable=user_id
datetime=%{x}
value=%{y}", + "legendgroup": "user_id", + "marker": { + "color": "#636efa", + "pattern": { + "shape": "" + } + }, + "name": "user_id", + "offsetgroup": "user_id", + "orientation": "v", + "showlegend": true, + "textposition": "auto", + "type": "bar", + "x": [ + "2021-03-13T00:00:00", + "2021-03-14T00:00:00", + "2021-03-15T00:00:00", + "2021-03-16T00:00:00", + "2021-03-17T00:00:00", + "2021-03-18T00:00:00", + "2021-03-19T00:00:00", + "2021-03-20T00:00:00", + "2021-03-21T00:00:00", + "2021-03-22T00:00:00", + "2021-03-23T00:00:00", + "2021-03-24T00:00:00", + "2021-03-25T00:00:00", + "2021-03-26T00:00:00", + "2021-03-27T00:00:00", + "2021-03-28T00:00:00", + "2021-03-29T00:00:00", + "2021-03-30T00:00:00", + "2021-03-31T00:00:00", + "2021-04-01T00:00:00", + "2021-04-02T00:00:00", + "2021-04-03T00:00:00", + "2021-04-04T00:00:00", + "2021-04-05T00:00:00", + "2021-04-06T00:00:00", + "2021-04-07T00:00:00", + "2021-04-08T00:00:00", + "2021-04-09T00:00:00", + "2021-04-10T00:00:00", + "2021-04-11T00:00:00", + "2021-04-12T00:00:00", + "2021-04-13T00:00:00", + "2021-04-14T00:00:00", + "2021-04-15T00:00:00", + "2021-04-16T00:00:00", + "2021-04-17T00:00:00", + "2021-04-18T00:00:00", + "2021-04-19T00:00:00", + "2021-04-20T00:00:00", + "2021-04-21T00:00:00", + "2021-04-22T00:00:00", + "2021-04-23T00:00:00", + "2021-04-24T00:00:00", + "2021-04-25T00:00:00", + "2021-04-26T00:00:00", + "2021-04-27T00:00:00", + "2021-04-28T00:00:00", + "2021-04-29T00:00:00", + "2021-04-30T00:00:00", + "2021-05-01T00:00:00", + "2021-05-02T00:00:00", + "2021-05-03T00:00:00", + "2021-05-04T00:00:00", + "2021-05-05T00:00:00", + "2021-05-06T00:00:00", + "2021-05-07T00:00:00", + "2021-05-08T00:00:00", + "2021-05-09T00:00:00", + "2021-05-10T00:00:00", + "2021-05-11T00:00:00", + "2021-05-12T00:00:00", + "2021-05-13T00:00:00", + "2021-05-14T00:00:00", + "2021-05-15T00:00:00", + "2021-05-16T00:00:00", + "2021-05-17T00:00:00", + "2021-05-18T00:00:00", + "2021-05-19T00:00:00", + "2021-05-20T00:00:00", + 
"2021-05-21T00:00:00", + "2021-05-22T00:00:00", + "2021-05-23T00:00:00", + "2021-05-24T00:00:00", + "2021-05-25T00:00:00", + "2021-05-26T00:00:00", + "2021-05-27T00:00:00", + "2021-05-28T00:00:00", + "2021-05-29T00:00:00", + "2021-05-30T00:00:00", + "2021-05-31T00:00:00", + "2021-06-01T00:00:00", + "2021-06-02T00:00:00", + "2021-06-03T00:00:00", + "2021-06-04T00:00:00", + "2021-06-05T00:00:00", + "2021-06-06T00:00:00", + "2021-06-07T00:00:00", + "2021-06-08T00:00:00", + "2021-06-09T00:00:00", + "2021-06-10T00:00:00", + "2021-06-11T00:00:00", + "2021-06-12T00:00:00", + "2021-06-13T00:00:00", + "2021-06-14T00:00:00", + "2021-06-15T00:00:00", + "2021-06-16T00:00:00", + "2021-06-17T00:00:00", + "2021-06-18T00:00:00", + "2021-06-19T00:00:00", + "2021-06-20T00:00:00", + "2021-06-21T00:00:00", + "2021-06-22T00:00:00", + "2021-06-23T00:00:00", + "2021-06-24T00:00:00", + "2021-06-25T00:00:00", + "2021-06-26T00:00:00", + "2021-06-27T00:00:00", + "2021-06-28T00:00:00", + "2021-06-29T00:00:00", + "2021-06-30T00:00:00", + "2021-07-01T00:00:00", + "2021-07-02T00:00:00", + "2021-07-03T00:00:00", + "2021-07-04T00:00:00", + "2021-07-05T00:00:00", + "2021-07-06T00:00:00", + "2021-07-07T00:00:00", + "2021-07-08T00:00:00", + "2021-07-09T00:00:00", + "2021-07-10T00:00:00", + "2021-07-11T00:00:00", + "2021-07-12T00:00:00", + "2021-07-13T00:00:00", + "2021-07-14T00:00:00", + "2021-07-15T00:00:00", + "2021-07-16T00:00:00", + "2021-07-17T00:00:00", + "2021-07-18T00:00:00", + "2021-07-19T00:00:00", + "2021-07-20T00:00:00", + "2021-07-21T00:00:00", + "2021-07-22T00:00:00", + "2021-07-23T00:00:00", + "2021-07-24T00:00:00", + "2021-07-25T00:00:00", + "2021-07-26T00:00:00", + "2021-07-27T00:00:00", + "2021-07-28T00:00:00", + "2021-07-29T00:00:00", + "2021-07-30T00:00:00", + "2021-07-31T00:00:00", + "2021-08-01T00:00:00", + "2021-08-02T00:00:00", + "2021-08-03T00:00:00", + "2021-08-04T00:00:00", + "2021-08-05T00:00:00", + "2021-08-06T00:00:00", + "2021-08-07T00:00:00", + "2021-08-08T00:00:00", + 
"2021-08-09T00:00:00", + "2021-08-10T00:00:00", + "2021-08-11T00:00:00", + "2021-08-12T00:00:00", + "2021-08-13T00:00:00", + "2021-08-14T00:00:00", + "2021-08-15T00:00:00", + "2021-08-16T00:00:00", + "2021-08-17T00:00:00", + "2021-08-18T00:00:00", + "2021-08-19T00:00:00", + "2021-08-20T00:00:00", + "2021-08-21T00:00:00", + "2021-08-22T00:00:00" + ], + "xaxis": "x", + "y": [ + 16104, + 15606, + 12363, + 12643, + 12753, + 12788, + 13657, + 15346, + 15560, + 12752, + 13147, + 13435, + 12698, + 13909, + 15657, + 16112, + 12783, + 13101, + 13460, + 12966, + 14084, + 15431, + 15346, + 12642, + 12528, + 13129, + 13827, + 14416, + 15937, + 16046, + 12835, + 12322, + 12451, + 12275, + 13342, + 15464, + 16275, + 14286, + 20420, + 23200, + 21274, + 22127, + 26161, + 28964, + 21625, + 22590, + 21406, + 19987, + 21406, + 23479, + 24767, + 26267, + 25983, + 23941, + 23510, + 23201, + 27550, + 25986, + 27242, + 20957, + 20578, + 20729, + 21152, + 24530, + 24914, + 20960, + 20574, + 21561, + 22712, + 25697, + 27895, + 29978, + 24317, + 23667, + 22529, + 23881, + 24131, + 29035, + 31308, + 26821, + 26587, + 27577, + 28683, + 33150, + 34795, + 37096, + 31402, + 31107, + 32896, + 38964, + 37935, + 38619, + 42125, + 38973, + 35993, + 57686, + 41440, + 42174, + 43679, + 47989, + 39127, + 39693, + 41688, + 38394, + 41428, + 45898, + 48903, + 43301, + 43887, + 67749, + 53900, + 46642, + 48832, + 52812, + 43375, + 41380, + 41163, + 41592, + 40955, + 44798, + 46250, + 42487, + 43764, + 43128, + 43010, + 44878, + 49714, + 54139, + 45541, + 44431, + 44422, + 46313, + 46911, + 50317, + 54378, + 48531, + 49324, + 50267, + 50585, + 53121, + 59499, + 62128, + 53495, + 52181, + 51911, + 51047, + 53745, + 59316, + 61454, + 52794, + 53712, + 55617, + 56497, + 55843, + 61644, + 66546, + 54546, + 54311, + 56789, + 58640, + 60145, + 68834, + 71171 + ], + "yaxis": "y" + } + ], + "layout": { + "barmode": "relative", + "legend": { + "title": { + "text": "variable" + }, + "tracegroupgap": 0 + }, + 
"margin": { + "t": 60 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" 
+ ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], 
+ "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + 
"fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + 
"hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "datetime" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "value" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = px.bar(interactions.groupby(Columns.Datetime)[Columns.User].agg('count'))\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "id": "43f216d0", + "metadata": {}, + "source": [ + "Из графика видны **недельные тенденции** просмотров, поэтому fold-ы следовало бы разделять по 7 дней. Однако на семинаре дали \"намек\", что private dataset охватывает меньше 7 дней, поэтому фолды будут разбиваться на **5 и 7 дней**." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "07fbdb30", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.to_datetime('23-05-2021', format='%d-%m-%Y').weekday()" + ] + }, + { + "cell_type": "markdown", + "id": "2ff625b2", + "metadata": {}, + "source": [ + "### train test split" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "759ba346", + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_range(\n", + " last_date: pd.Timestamp, \n", + " n_folds: int = 7, \n", + " unit: str = \"W\", \n", + " n_units: int = 1, \n", + " show: bool = True,\n", + "):\n", + " periods = n_folds + 1\n", + " freq = f\"{n_units}{unit}\"\n", + " \n", + " start_date = last_date - pd.Timedelta(n_folds * n_units + n_units, unit=unit) \n", + " \n", + " date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n", + " \n", + " if show:\n", + " print(\n", + " f\"start_date: {start_date}\\n\"\n", + " f\"last_date: {last_date}\\n\"\n", + " f\"periods: {periods}\\n\"\n", + " f\"freq: {freq}\\n\"\n", + " f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\\n\"\n", + " )\n", + " \n", + " return date_range" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "38bfd397", + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG_CV = 
{\n", + " \"cv_v1\": {\n", + " \"n_folds\": 7,\n", + " \"unit\": \"W\",\n", + " \"n_units\": 1,\n", + " },\n", + " \"cv_v2\": {\n", + " \"n_folds\": 7,\n", + " \"unit\": \"D\",\n", + " \"n_units\": 5,\n", + " }, \n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f518e089", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Timestamp('2021-08-22 00:00:00')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "last_date = interactions[Columns.Datetime].max().normalize()\n", + "last_date" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1fd68b9b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "***Folds v1***\n", + "start_date: 2021-07-13 00:00:00\n", + "last_date: 2021-08-22 00:00:00\n", + "periods: 8\n", + "freq: 5D\n", + "Test fold borders: ['2021-07-13' '2021-07-18' '2021-07-23' '2021-07-28' '2021-08-02'\n", + " '2021-08-07' '2021-08-12' '2021-08-17']\n", + "\n" + ] + } + ], + "source": [ + "print(\"***Folds v1***\")\n", + "date_range_v1 = create_data_range(\n", + " last_date, \n", + " n_folds=CONFIG_CV[\"cv_v2\"][\"n_folds\"], \n", + " unit=CONFIG_CV[\"cv_v2\"][\"unit\"], \n", + " n_units=CONFIG_CV[\"cv_v2\"][\"n_units\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "efc59555", + "metadata": {}, + "source": [ + "**генерируем фолды** " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9fae43f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Real number of folds: 7\n" + ] + } + ], + "source": [ + "cv_v1 = TimeRangeSplitter(\n", + " date_range=date_range_v1,\n", + " filter_already_seen=True,\n", + " filter_cold_items=True,\n", + " filter_cold_users=True,\n", + ")\n", + "print(f\"Real number of folds: {cv_v1.get_n_splits(Interactions(interactions))}\")\n", + "\n", + "CV = [cv_v1]" + ] + }, + { + "cell_type": 
"markdown", + "id": "e15a83a7", + "metadata": {}, + "source": [ + "**Формируем метрики**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8f7742c6", + "metadata": {}, + "outputs": [], + "source": [ + "metrics = {\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@10\": MAP(k=10),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b21a1ecf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'cosine_userknn_K30': ,\n", + " 'tfidf_userknn_K30': ,\n", + " 'bm25_userknn_K30': ,\n", + " 'cosine_userknn_K40': ,\n", + " 'tfidf_userknn_K40': ,\n", + " 'bm25_userknn_K40': }" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "K = [30, 40]\n", + "models = dict()\n", + "\n", + "for k in K:\n", + " models[f\"cosine_userknn_K{k}\"] = CosineRecommender(K=k)\n", + " models[f\"tfidf_userknn_K{k}\"] = TFIDFRecommender(K=k)\n", + " models[f\"bm25_userknn_K{k}\"] = BM25Recommender(K=k)\n", + "\n", + "models" + ] + }, + { + "cell_type": "markdown", + "id": "0103149a", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e78b8221", + "metadata": {}, + "outputs": [], + "source": [ + "N_USERS = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "50dcff0b", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "results = []\n", + "\n", + "for idx, cv in enumerate(CV):\n", + " print(f\"\\n CV version {idx}\")\n", + " fold_iterator = cv.split(Interactions(interactions), collect_fold_stats=True)\n", + "\n", + " for i_fold, (train_ids, test_ids, fold_info) in enumerate(fold_iterator):\n", + " print(f\"\\n==================== Fold {i_fold}\")\n", + " pprint(fold_info)\n", + "\n", + " df_train = interactions.iloc[train_ids].copy()\n", + " 
df_test = interactions.iloc[test_ids][Columns.UserItem].copy()\n", + "\n", + " catalog = df_train[Columns.Item].unique()\n", + "\n", + " for model_name, model in models.items():\n", + " userknn_model = UserKnn(model=model, N_users=N_USERS, use_weight_idf=True)\n", + " userknn_model.fit(df_train)\n", + "\n", + " if 'bm25' in model_name:\n", + " recos = userknn_model.predict(df_test, bmp25=True)\n", + " else:\n", + " recos = userknn_model.predict(df_test)\n", + "\n", + " metric_values = calc_metrics(\n", + " metrics,\n", + " reco=recos,\n", + " interactions=df_test,\n", + " prev_interactions=df_train,\n", + " catalog=catalog,\n", + " )\n", + "\n", + " full_model_name = f\"{model_name}_cv-{idx}\"\n", + " fold = {\"fold\": i_fold, \"model\": full_model_name}\n", + " fold.update(metric_values)\n", + " results.append(fold)" + ] + }, + { + "cell_type": "markdown", + "id": "708ec5c2", + "metadata": {}, + "source": [ + "Ячейка работала больше 10 часов. При перезапуске ноутбука она была случайно запущена повторно и остановлена, поэтому завершилась с ошибкой; вывод с ошибкой удалён для наглядности." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "d7e2ffa7", + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
foldmodelprec@10recall@10MAP@10noveltyserendipity
00cosine_userknn_K30_cv-00.0035570.0211280.0036958.3314910.000040
10tfidf_userknn_K30_cv-00.0064390.0391020.0073358.1550510.000048
20bm25_userknn_K30_cv-00.0025930.0134940.0025319.3984670.000081
30cosine_userknn_K40_cv-00.0032820.0193230.0034018.5615230.000043
40tfidf_userknn_K40_cv-00.0061780.0374580.0069578.3004040.000052
50bm25_userknn_K40_cv-00.0022410.0112550.0022109.6755330.000081
61cosine_userknn_K30_cv-00.0035050.0200020.0035808.3982480.000046
71tfidf_userknn_K30_cv-00.0063280.0368440.0070228.2401330.000058
81bm25_userknn_K30_cv-00.0027220.0138560.0026589.4846920.000088
91cosine_userknn_K40_cv-00.0032450.0183680.0033058.6269060.000047
101tfidf_userknn_K40_cv-00.0061500.0359640.0069168.3779880.000061
111bm25_userknn_K40_cv-00.0024060.0120670.0023939.7564580.000086
122cosine_userknn_K30_cv-00.0032610.0184980.0032958.4392630.000047
132tfidf_userknn_K30_cv-00.0059400.0342330.0064798.2623670.000059
142bm25_userknn_K30_cv-00.0027200.0134220.0025309.5356310.000091
152cosine_userknn_K40_cv-00.0030450.0170860.0031008.6615850.000050
162tfidf_userknn_K40_cv-00.0059140.0340710.0064398.3966180.000063
172bm25_userknn_K40_cv-00.0024040.0116380.0022319.7991190.000090
183cosine_userknn_K30_cv-00.0032770.0187860.0033958.4449860.000045
193tfidf_userknn_K30_cv-00.0060230.0341710.0063288.2765030.000059
203bm25_userknn_K30_cv-00.0026200.0127620.0024979.5609840.000091
213cosine_userknn_K40_cv-00.0030760.0175120.0031738.6581500.000045
223tfidf_userknn_K40_cv-00.0059190.0333680.0062538.3991690.000062
233bm25_userknn_K40_cv-00.0023370.0112730.0022539.8163250.000089
244cosine_userknn_K30_cv-00.0031180.0180640.0031578.4858990.000042
254tfidf_userknn_K30_cv-00.0059110.0336260.0063968.2824280.000059
264bm25_userknn_K30_cv-00.0025370.0123680.0024709.5996450.000086
274cosine_userknn_K40_cv-00.0028720.0165090.0028838.7119840.000043
284tfidf_userknn_K40_cv-00.0057930.0330280.0062618.4166800.000062
294bm25_userknn_K40_cv-00.0022130.0108600.0021799.8662010.000085
305cosine_userknn_K30_cv-00.0030030.0162520.0028998.4989680.000043
315tfidf_userknn_K30_cv-00.0055270.0309420.0058238.3252730.000057
325bm25_userknn_K30_cv-00.0025970.0122630.0023869.6469570.000100
335cosine_userknn_K40_cv-00.0027650.0147130.0026618.7175590.000047
345tfidf_userknn_K40_cv-00.0055450.0308920.0058178.4540910.000059
355bm25_userknn_K40_cv-00.0023020.0107770.0021359.9140420.000100
366cosine_userknn_K30_cv-00.0029630.0165320.0028878.5638090.000050
376tfidf_userknn_K30_cv-00.0053300.0307170.0057638.3662590.000064
386bm25_userknn_K30_cv-00.0025710.0126910.0024789.7150970.000100
396cosine_userknn_K40_cv-00.0027690.0154480.0026758.7750580.000051
406tfidf_userknn_K40_cv-00.0052840.0304180.0056978.4884730.000066
416bm25_userknn_K40_cv-00.0023400.0112780.0022089.9646640.000099
\n", + "
" + ], + "text/plain": [ + " fold model prec@10 recall@10 MAP@10 novelty \\\n", + "0 0 cosine_userknn_K30_cv-0 0.003557 0.021128 0.003695 8.331491 \n", + "1 0 tfidf_userknn_K30_cv-0 0.006439 0.039102 0.007335 8.155051 \n", + "2 0 bm25_userknn_K30_cv-0 0.002593 0.013494 0.002531 9.398467 \n", + "3 0 cosine_userknn_K40_cv-0 0.003282 0.019323 0.003401 8.561523 \n", + "4 0 tfidf_userknn_K40_cv-0 0.006178 0.037458 0.006957 8.300404 \n", + "5 0 bm25_userknn_K40_cv-0 0.002241 0.011255 0.002210 9.675533 \n", + "6 1 cosine_userknn_K30_cv-0 0.003505 0.020002 0.003580 8.398248 \n", + "7 1 tfidf_userknn_K30_cv-0 0.006328 0.036844 0.007022 8.240133 \n", + "8 1 bm25_userknn_K30_cv-0 0.002722 0.013856 0.002658 9.484692 \n", + "9 1 cosine_userknn_K40_cv-0 0.003245 0.018368 0.003305 8.626906 \n", + "10 1 tfidf_userknn_K40_cv-0 0.006150 0.035964 0.006916 8.377988 \n", + "11 1 bm25_userknn_K40_cv-0 0.002406 0.012067 0.002393 9.756458 \n", + "12 2 cosine_userknn_K30_cv-0 0.003261 0.018498 0.003295 8.439263 \n", + "13 2 tfidf_userknn_K30_cv-0 0.005940 0.034233 0.006479 8.262367 \n", + "14 2 bm25_userknn_K30_cv-0 0.002720 0.013422 0.002530 9.535631 \n", + "15 2 cosine_userknn_K40_cv-0 0.003045 0.017086 0.003100 8.661585 \n", + "16 2 tfidf_userknn_K40_cv-0 0.005914 0.034071 0.006439 8.396618 \n", + "17 2 bm25_userknn_K40_cv-0 0.002404 0.011638 0.002231 9.799119 \n", + "18 3 cosine_userknn_K30_cv-0 0.003277 0.018786 0.003395 8.444986 \n", + "19 3 tfidf_userknn_K30_cv-0 0.006023 0.034171 0.006328 8.276503 \n", + "20 3 bm25_userknn_K30_cv-0 0.002620 0.012762 0.002497 9.560984 \n", + "21 3 cosine_userknn_K40_cv-0 0.003076 0.017512 0.003173 8.658150 \n", + "22 3 tfidf_userknn_K40_cv-0 0.005919 0.033368 0.006253 8.399169 \n", + "23 3 bm25_userknn_K40_cv-0 0.002337 0.011273 0.002253 9.816325 \n", + "24 4 cosine_userknn_K30_cv-0 0.003118 0.018064 0.003157 8.485899 \n", + "25 4 tfidf_userknn_K30_cv-0 0.005911 0.033626 0.006396 8.282428 \n", + "26 4 bm25_userknn_K30_cv-0 0.002537 0.012368 0.002470 
9.599645 \n", + "27 4 cosine_userknn_K40_cv-0 0.002872 0.016509 0.002883 8.711984 \n", + "28 4 tfidf_userknn_K40_cv-0 0.005793 0.033028 0.006261 8.416680 \n", + "29 4 bm25_userknn_K40_cv-0 0.002213 0.010860 0.002179 9.866201 \n", + "30 5 cosine_userknn_K30_cv-0 0.003003 0.016252 0.002899 8.498968 \n", + "31 5 tfidf_userknn_K30_cv-0 0.005527 0.030942 0.005823 8.325273 \n", + "32 5 bm25_userknn_K30_cv-0 0.002597 0.012263 0.002386 9.646957 \n", + "33 5 cosine_userknn_K40_cv-0 0.002765 0.014713 0.002661 8.717559 \n", + "34 5 tfidf_userknn_K40_cv-0 0.005545 0.030892 0.005817 8.454091 \n", + "35 5 bm25_userknn_K40_cv-0 0.002302 0.010777 0.002135 9.914042 \n", + "36 6 cosine_userknn_K30_cv-0 0.002963 0.016532 0.002887 8.563809 \n", + "37 6 tfidf_userknn_K30_cv-0 0.005330 0.030717 0.005763 8.366259 \n", + "38 6 bm25_userknn_K30_cv-0 0.002571 0.012691 0.002478 9.715097 \n", + "39 6 cosine_userknn_K40_cv-0 0.002769 0.015448 0.002675 8.775058 \n", + "40 6 tfidf_userknn_K40_cv-0 0.005284 0.030418 0.005697 8.488473 \n", + "41 6 bm25_userknn_K40_cv-0 0.002340 0.011278 0.002208 9.964664 \n", + "\n", + " serendipity \n", + "0 0.000040 \n", + "1 0.000048 \n", + "2 0.000081 \n", + "3 0.000043 \n", + "4 0.000052 \n", + "5 0.000081 \n", + "6 0.000046 \n", + "7 0.000058 \n", + "8 0.000088 \n", + "9 0.000047 \n", + "10 0.000061 \n", + "11 0.000086 \n", + "12 0.000047 \n", + "13 0.000059 \n", + "14 0.000091 \n", + "15 0.000050 \n", + "16 0.000063 \n", + "17 0.000090 \n", + "18 0.000045 \n", + "19 0.000059 \n", + "20 0.000091 \n", + "21 0.000045 \n", + "22 0.000062 \n", + "23 0.000089 \n", + "24 0.000042 \n", + "25 0.000059 \n", + "26 0.000086 \n", + "27 0.000043 \n", + "28 0.000062 \n", + "29 0.000085 \n", + "30 0.000043 \n", + "31 0.000057 \n", + "32 0.000100 \n", + "33 0.000047 \n", + "34 0.000059 \n", + "35 0.000100 \n", + "36 0.000050 \n", + "37 0.000064 \n", + "38 0.000100 \n", + "39 0.000051 \n", + "40 0.000066 \n", + "41 0.000099 " + ] + }, + "execution_count": 46, + "metadata": 
{}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics = pd.DataFrame(results)\n", + "df_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "a0334b9a", + "metadata": {}, + "outputs": [], + "source": [ + "df_metrics.to_pickle(\"../data/hw_3/df_metrics.pickle\")" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "446530ce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
foldprec@10recall@10MAP@10noveltyserendipity
model
bm25_userknn_K30_cv-03.00.0026230.0129800.0025079.5630680.000091
bm25_userknn_K40_cv-03.00.0023200.0113070.0022309.8274770.000090
cosine_userknn_K30_cv-03.00.0032410.0184660.0032728.4518090.000045
cosine_userknn_K40_cv-03.00.0030080.0169940.0030288.6732520.000047
tfidf_userknn_K30_cv-03.00.0059280.0342340.0064498.2725730.000058
tfidf_userknn_K40_cv-03.00.0058260.0336000.0063348.4047750.000061
\n", + "
" + ], + "text/plain": [ + " fold prec@10 recall@10 MAP@10 novelty \\\n", + "model \n", + "bm25_userknn_K30_cv-0 3.0 0.002623 0.012980 0.002507 9.563068 \n", + "bm25_userknn_K40_cv-0 3.0 0.002320 0.011307 0.002230 9.827477 \n", + "cosine_userknn_K30_cv-0 3.0 0.003241 0.018466 0.003272 8.451809 \n", + "cosine_userknn_K40_cv-0 3.0 0.003008 0.016994 0.003028 8.673252 \n", + "tfidf_userknn_K30_cv-0 3.0 0.005928 0.034234 0.006449 8.272573 \n", + "tfidf_userknn_K40_cv-0 3.0 0.005826 0.033600 0.006334 8.404775 \n", + "\n", + " serendipity \n", + "model \n", + "bm25_userknn_K30_cv-0 0.000091 \n", + "bm25_userknn_K40_cv-0 0.000090 \n", + "cosine_userknn_K30_cv-0 0.000045 \n", + "cosine_userknn_K40_cv-0 0.000047 \n", + "tfidf_userknn_K30_cv-0 0.000058 \n", + "tfidf_userknn_K40_cv-0 0.000061 " + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics.groupby('model').mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "5fb9ba9f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
prec@10recall@10MAP@10noveltyserendipity
model
bm25_userknn_K30_cv-00.0000720.0006120.0000830.1044680.000007
bm25_userknn_K40_cv-00.0000740.0004420.0000810.0973590.000007
cosine_userknn_K30_cv-00.0002310.0017490.0003140.0746990.000003
cosine_userknn_K40_cv-00.0002130.0016030.0002950.0693100.000003
tfidf_userknn_K30_cv-00.0003980.0030030.0005770.0666270.000005
tfidf_userknn_K40_cv-00.0003210.0025340.0004870.0595650.000004
\n", + "
" + ], + "text/plain": [ + " prec@10 recall@10 MAP@10 novelty serendipity\n", + "model \n", + "bm25_userknn_K30_cv-0 0.000072 0.000612 0.000083 0.104468 0.000007\n", + "bm25_userknn_K40_cv-0 0.000074 0.000442 0.000081 0.097359 0.000007\n", + "cosine_userknn_K30_cv-0 0.000231 0.001749 0.000314 0.074699 0.000003\n", + "cosine_userknn_K40_cv-0 0.000213 0.001603 0.000295 0.069310 0.000003\n", + "tfidf_userknn_K30_cv-0 0.000398 0.003003 0.000577 0.066627 0.000005\n", + "tfidf_userknn_K40_cv-0 0.000321 0.002534 0.000487 0.059565 0.000004" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics.groupby('model').std()[metrics.keys()]" + ] + }, + { + "cell_type": "markdown", + "id": "41828ee5", + "metadata": {}, + "source": [ + "По **offline**-метрикам лучше всего себя показывает модель TFIDFRecommender.\n", + "Далее подбираем K для TFIDFRecommender." + ] + }, + { + "cell_type": "markdown", + "id": "7a8a0a41", + "metadata": {}, + "source": [ + "# Подбор оптимального K для TFIDFRecommender" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "1e91892d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tfidf_userknn_K50': ,\n", + " 'tfidf_userknn_K60': ,\n", + " 'tfidf_userknn_K70': }" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "N_USERS = 50\n", + "\n", + "# Т.к. 
метрики для К 30 и 40 уже есть\n", + "K = [k for k in range(50, 71, 10)]\n", + "models = dict()\n", + "\n", + "for k in K:\n", + " models[f\"tfidf_userknn_K{k}\"] = TFIDFRecommender(K=k)\n", + "models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7c2c43b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==================== Fold 0\n", + "{'End date': Timestamp('2021-07-18 00:00:00', freq='5D'),\n", + " 'Start date': Timestamp('2021-07-13 00:00:00', freq='5D'),\n", + " 'Test': 156580,\n", + " 'Test items': 5793,\n", + " 'Test users': 68150,\n", + " 'Train': 3281612,\n", + " 'Train items': 14754,\n", + " 'Train users': 652905}\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "211234f034a54bae86b94dff33b9f5c4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/652905 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idlast_watch_dttotal_durwatched_pct
017654995062021-05-11425072.0
169931716592021-05-298317100.0
265668371072021-05-09100.0
386461376382021-07-0514483100.0
496486895062021-04-306725100.0
5476246648596122252021-08-13760.0
547624754686296732021-04-13230849.0
5476248697262152972021-08-201830763.0
5476249384202161972021-04-196203100.0
547625031970944362021-08-15392145.0
\n", + "" + ], + "text/plain": [ + " user_id item_id last_watch_dt total_dur watched_pct\n", + "0 176549 9506 2021-05-11 4250 72.0\n", + "1 699317 1659 2021-05-29 8317 100.0\n", + "2 656683 7107 2021-05-09 10 0.0\n", + "3 864613 7638 2021-07-05 14483 100.0\n", + "4 964868 9506 2021-04-30 6725 100.0\n", + "5476246 648596 12225 2021-08-13 76 0.0\n", + "5476247 546862 9673 2021-04-13 2308 49.0\n", + "5476248 697262 15297 2021-08-20 18307 63.0\n", + "5476249 384202 16197 2021-04-19 6203 100.0\n", + "5476250 319709 4436 2021-08-15 3921 45.0" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([interactions.head(), interactions.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "dc4d9fd7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(962179,)" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions['user_id'].unique().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "b7861d19", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(961833, 1.0),\n", + " (961849, 1.0),\n", + " (961857, 1.0),\n", + " (961871, 1.0),\n", + " (961873, 1.0),\n", + " (961876, 1.0),\n", + " (961887, 1.0),\n", + " (961907, 1.0),\n", + " (961910, 1.0),\n", + " (961912, 1.0)]" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import dill\n", + "\n", + "with open('../service/weights/userKNN/userknn_tfidf_k30.dill', 'rb') as f:\n", + " userknn = dill.load(f)\n", + "\n", + "userknn.similar_items(962178, 10)" + ] + }, + { + "cell_type": "markdown", + "id": "1905033a", + "metadata": {}, + "source": [ + "# Popular Model" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "2df74dba", + "metadata": {}, + "outputs": [], + "source": [ + "from rectools.models import PopularModel\n", + "from 
rectools.dataset import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "6ba37a73", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Timestamp('2021-08-22 00:00:00')" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max_date = interactions[Columns.Datetime].max().normalize()\n", + "max_date" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "901353f9", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "train = interactions[[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]][\n", + " interactions[Columns.Datetime] < max_date - pd.Timedelta(5, \"D\")]\n", + "\n", + "test = interactions[[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]][\n", + " interactions[Columns.Datetime] >= max_date - pd.Timedelta(5, \"D\")]\n", + "\n", + "dataset_train = Dataset.construct(train)" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "id": "f08e3579", + "metadata": {}, + "outputs": [], + "source": [ + "popilarity_models = {\n", + " \"popular\": PopularModel(),\n", + " \"popular_mw\": PopularModel(popularity=\"mean_weight\")\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "id": "03c3bfb6", + "metadata": {}, + "outputs": [], + "source": [ + "popilarity_models[\"popular\"].fit(dataset_train)\n", + "popilarity_models[\"popular_mw\"].fit(dataset_train);" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "id": "0d7de49e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 24, 20, 31, 15, 167, 81, 89, 135, 355, 116])" + ] + }, + "execution_count": 146, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "popilarity_models[\"popular\"].popularity_list[0][:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "id": "05ff208d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"array([11363, 11681, 12841, 13017, 2069, 13691, 13552, 13397, 11774,\n", + " 12913])" + ] + }, + "execution_count": 147, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "popilarity_models[\"popular_mw\"].popularity_list[0][:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 148, + "id": "00ef735c", + "metadata": {}, + "outputs": [], + "source": [ + "pecos_pop = popilarity_models[\"popular\"].recommend(\n", + " users=test[Columns.User].unique(),\n", + " dataset=dataset,\n", + " k=100,\n", + " filter_viewed=False,\n", + ")\n", + "\n", + "pecos_pop_mw = popilarity_models[\"popular_mw\"].recommend(\n", + " users=test[Columns.User].unique(),\n", + " dataset=dataset,\n", + " k=100,\n", + " filter_viewed=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 152, + "id": "b302db55", + "metadata": {}, + "outputs": [], + "source": [ + "metrics = {\n", + " \"prec@5\": Precision(k=5),\n", + " \"recall@5\": Recall(k=5),\n", + " \"MAP@5\": MAP(k=5),\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@20\": MAP(k=20),\n", + " \"prec@20\": Precision(k=20),\n", + " \"recall@20\": Recall(k=20),\n", + " \"MAP@100\": MAP(k=100),\n", + " \"prec@100\": Precision(k=100),\n", + " \"recall@100\": Recall(k=100),\n", + " \"MAP@100\": MAP(k=100),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}\n", + "catalog = train[Columns.Item].unique()\n", + "metric_values_pop = calc_metrics(metrics, pecos_pop, test, train, catalog)\n", + "metric_values_pop_mean_weight = calc_metrics(metrics, pecos_pop_mw, test, train, catalog)" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "id": "9631093b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'prec@5': 0.0017855613317256697,\n", + " 'recall@5': 0.004623809755660008,\n", + " 'prec@10': 0.0011648975773029461,\n", + " 'recall@10': 0.005682095875283048,\n", + " 'prec@20': 
0.0010502526799891945,\n", + " 'recall@20': 0.00880186008464912,\n", + " 'prec@100': 0.003247020220987923,\n", + " 'recall@100': 0.16609031082955295,\n", + " 'MAP@5': 0.0013179725619140792,\n", + " 'MAP@20': 0.0016695313583723814,\n", + " 'MAP@100': 0.005578924867474493,\n", + " 'novelty': 9.976033936531364,\n", + " 'serendipity': 1.2752762676592953e-05}" + ] + }, + "execution_count": 153, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metric_values_pop" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "id": "5d55b781", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'prec@5': 9.09252633867684e-05,\n", + " 'recall@5': 0.00014799438063171262,\n", + " 'prec@10': 4.612151041357817e-05,\n", + " 'recall@10': 0.00015458316783365238,\n", + " 'prec@20': 2.635514880775895e-05,\n", + " 'recall@20': 0.00016946607539568094,\n", + " 'prec@100': 0.00015147621777259455,\n", + " 'recall@100': 0.0065476971391510656,\n", + " 'MAP@5': 3.0257754846536496e-05,\n", + " 'MAP@20': 3.1771198360212185e-05,\n", + " 'MAP@100': 0.00011355765992119742,\n", + " 'novelty': 17.423655787689828,\n", + " 'serendipity': 1.8991632826477633e-06}" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metric_values_pop_mean_weight" + ] + }, + { + "cell_type": "markdown", + "id": "e5a4a011", + "metadata": {}, + "source": [ + "**На офлайн метриках выигрывает обычная модель по популярному**" + ] + }, + { + "cell_type": "markdown", + "id": "5875fab7", + "metadata": {}, + "source": [ + "# Save item_idf data" + ] + }, + { + "cell_type": "markdown", + "id": "6589996f", + "metadata": {}, + "source": [ + "Создаем датасет со взвешенными item-ами по механизму idf для использования в будущем" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "d62cabb9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexidf
095067.150811
116598.524953
271075.821207
376388.407093
466867.778734
.........
15701783314.822785
15702912514.822785
157031006414.822785
157041301914.822785
157051054214.822785
\n", + "

15706 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " index idf\n", + "0 9506 7.150811\n", + "1 1659 8.524953\n", + "2 7107 5.821207\n", + "3 7638 8.407093\n", + "4 6686 7.778734\n", + "... ... ...\n", + "15701 7833 14.822785\n", + "15702 9125 14.822785\n", + "15703 10064 14.822785\n", + "15704 13019 14.822785\n", + "15705 10542 14.822785\n", + "\n", + "[15706 rows x 2 columns]" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item_cnt = Counter(interactions['item_id'].values)\n", + "item_idf = pd.DataFrame.from_dict(item_cnt, orient='index', columns=['doc_freq']).reset_index()\n", + "n = interactions.shape[0]\n", + "item_idf['idf'] = item_idf['doc_freq'].apply(lambda x: np.log((1 + n) / (1 + x) + 1))\n", + "del item_idf['doc_freq']\n", + "item_idf" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "7da47dfc", + "metadata": {}, + "outputs": [], + "source": [ + "item_idf = item_idf.sort_values(\"idf\", ascending=False)\n", + "item_idf.to_csv('../data/kion_train/items_idf.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdce2b60", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/HW-3.2-rectools-research.ipynb b/notebooks/HW-3.2-rectools-research.ipynb new file mode 100644 index 00000000..ed456f5f --- /dev/null +++ b/notebooks/HW-3.2-rectools-research.ipynb @@ -0,0 +1,725 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "855d49cd", + "metadata": {}, + "outputs": [], + "source": [ + "import 
pandas as pd\n", + "import numpy as np\n", + "import scipy as sp\n", + "import requests\n", + "from tqdm.auto import tqdm\n", + "from scipy.stats import mode \n", + "from pprint import pprint\n", + "from implicit.nearest_neighbours import CosineRecommender\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "from rectools import Columns\n", + "\n", + "pd.set_option('display.max_columns', None)\n", + "pd.set_option('display.max_colwidth', 200)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "655cd033", + "metadata": {}, + "outputs": [], + "source": [ + "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", + "\n", + "interactions.rename(columns={'last_watch_dt': Columns.Datetime,\n", + " 'total_dur': Columns.Weight}, \n", + " inplace=True) \n", + "\n", + "interactions['datetime'] = pd.to_datetime(interactions['datetime'])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "193c411d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Start date and last date of the test fold: (Timestamp('2021-08-08 00:00:00'), Timestamp('2021-08-22 00:00:00'))\n", + "Test fold borders: ['2021-08-08' '2021-08-15']\n", + "Real number of folds: 1\n" + ] + } + ], + "source": [ + "from rectools.model_selection import TimeRangeSplitter\n", + "from rectools.dataset import Interactions\n", + "\n", + "n_folds = 1\n", + "unit = \"W\"\n", + "n_units = 1\n", + "periods = n_folds + 1\n", + "freq = f\"{n_units}{unit}\"\n", + "\n", + "last_date = interactions[Columns.Datetime].max().normalize()\n", + "start_date = last_date - pd.Timedelta(n_folds * n_units + 1, unit=unit) \n", + "print(f\"Start date and last date of the test fold: {start_date, last_date}\")\n", + " \n", + "date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n", + "print(f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\")\n", + "\n", + "# generator 
of folds\n", + "cv = TimeRangeSplitter(\n", + " date_range=date_range,\n", + " filter_already_seen=True,\n", + " filter_cold_items=True,\n", + " filter_cold_users=True,\n", + ")\n", + "print(f\"Real number of folds: {cv.get_n_splits(Interactions(interactions))}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "38b80f9f", + "metadata": {}, + "outputs": [], + "source": [ + "(train_ids, test_ids, fold_info) = cv.split(Interactions(interactions), collect_fold_stats=True).__next__()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e3051991", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0, 1, 2, ..., 5476245, 5476247, 5476249])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7bc27a2f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 6, 33, 56, ..., 5476229, 5476230, 5476240])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ffdaad0c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "users_mapping amount: 842129\n", + "items_mapping amount: 15404\n" + ] + } + ], + "source": [ + "train = interactions.loc[train_ids]\n", + "test = interactions.loc[test_ids]\n", + "\n", + "users_inv_mapping = dict(enumerate(train['user_id'].unique()))\n", + "users_mapping = {v: k for k, v in users_inv_mapping.items()}\n", + "\n", + "items_inv_mapping = dict(enumerate(train['item_id'].unique()))\n", + "items_mapping = {v: k for k, v in items_inv_mapping.items()}\n", + "\n", + "print(f\"users_mapping amount: {len(users_mapping)}\")\n", + "print(f\"items_mapping amount: {len(items_mapping)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + 
"id": "a6664026", + "metadata": {}, + "outputs": [], + "source": [ + "from rectools.dataset import Dataset\n", + "\n", + "dataset = Dataset.construct(\n", + " interactions_df=train,\n", + " user_features_df=None,\n", + " item_features_df=None\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "88f5a65c", + "metadata": {}, + "source": [ + "# ItemKNN CosineRecommender" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9c4682c5", + "metadata": {}, + "outputs": [], + "source": [ + "from implicit.nearest_neighbours import CosineRecommender\n", + "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", + "\n", + "item_knn = ImplicitItemKNNWrapperModel(model=CosineRecommender(K=30))\n", + "item_knn.fit(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "198faaa4", + "metadata": {}, + "outputs": [], + "source": [ + "recs_itemknn = item_knn.recommend(\n", + " test['user_id'].unique(), \n", + " dataset=dataset, \n", + " k=10, \n", + " filter_viewed=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "76d1a3f5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
010164581044020431.6311501
110164587348043.9999622
21016458121928033.5995303
3101645819867999.8057314
4101645844577763.2046075
\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank\n", + "0 1016458 10440 20431.631150 1\n", + "1 1016458 734 8043.999962 2\n", + "2 1016458 12192 8033.599530 3\n", + "3 1016458 1986 7999.805731 4\n", + "4 1016458 4457 7763.204607 5" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recs_itemknn.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "c075a976", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'prec@10': 0.017311708814214132,\n", + " 'recall@10': 0.09520897568691472,\n", + " 'MAP@10': 0.023145528903990274,\n", + " 'novelty': 8.05318572965277,\n", + " 'serendipity': 6.63288816067437e-05}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", + "\n", + "# calculate several classic (precision@k and recall@k) and \"beyond accuracy\" metrics\n", + "metrics = {\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@10\": MAP(k=10),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}\n", + "\n", + "catalog = train['item_id'].unique()\n", + "\n", + "metric_values_itemknn_cosine = calc_metrics(\n", + " metrics,\n", + " reco=recs_itemknn,\n", + " interactions=test,\n", + " prev_interactions=train,\n", + " catalog=catalog\n", + " )\n", + "\n", + "metric_values_itemknn_cosine" + ] + }, + { + "cell_type": "markdown", + "id": "b439f7fb", + "metadata": {}, + "source": [ + "# ItemKNN TFIDFRecommender" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "e31f5560", + "metadata": {}, + "outputs": [], + "source": [ + "from implicit.nearest_neighbours import TFIDFRecommender\n", + "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", + "\n", + "item_knn_tfidf = 
ImplicitItemKNNWrapperModel(model=TFIDFRecommender(K=30))\n", + "item_knn_tfidf.fit(dataset);" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "360eafab", + "metadata": {}, + "outputs": [], + "source": [ + "recs_itemknn_tfidf = item_knn_tfidf.recommend(\n", + " test['user_id'].unique(), \n", + " dataset=dataset, \n", + " k=10, \n", + " filter_viewed=False \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "63c31f04", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
010164581044021745.3769271
11016458445710234.8633082
2101645871028987.8781293
31016458121928957.1098134
4101645819868369.8324485
\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank\n", + "0 1016458 10440 21745.376927 1\n", + "1 1016458 4457 10234.863308 2\n", + "2 1016458 7102 8987.878129 3\n", + "3 1016458 12192 8957.109813 4\n", + "4 1016458 1986 8369.832448 5" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recs_itemknn_tfidf.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "7a4d01f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'prec@10': 0.023772589549238603,\n", + " 'recall@10': 0.12652382351172245,\n", + " 'MAP@10': 0.03005237337960426,\n", + " 'novelty': 6.699663403861505,\n", + " 'serendipity': 0.00010222896681730396}" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", + "\n", + "metrics = {\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@10\": MAP(k=10),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}\n", + "\n", + "catalog = train['item_id'].unique()\n", + "\n", + "metric_values_itemknn_tfidf = calc_metrics(\n", + " metrics,\n", + " reco=recs_itemknn_tfidf,\n", + " interactions=test,\n", + " prev_interactions=train,\n", + " catalog=catalog\n", + " )\n", + "\n", + "metric_values_itemknn_tfidf" + ] + }, + { + "cell_type": "markdown", + "id": "2270cb27", + "metadata": {}, + "source": [ + "# ItemKNN BM25Recommender" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "c7997faf", + "metadata": {}, + "outputs": [], + "source": [ + "from implicit.nearest_neighbours import BM25Recommender\n", + "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", + "\n", + "item_knn_bmp = ImplicitItemKNNWrapperModel(model=BM25Recommender(K=30))\n", + "item_knn_bmp.fit(dataset);" + ] + }, + { + "cell_type":
"code", + "execution_count": 29, + "id": "c7ceb0e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
01016458104406.854547e+111
11016458152972.323138e+112
21016458138651.724740e+113
3101645897281.383208e+114
4101645841511.149358e+115
\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank\n", + "0 1016458 10440 6.854547e+11 1\n", + "1 1016458 15297 2.323138e+11 2\n", + "2 1016458 13865 1.724740e+11 3\n", + "3 1016458 9728 1.383208e+11 4\n", + "4 1016458 4151 1.149358e+11 5" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recs_itemknn_bmp = item_knn_bmp.recommend(\n", + " test['user_id'].unique(), \n", + " dataset=dataset, \n", + " k=10, \n", + " filter_viewed=False \n", + ")\n", + "\n", + "recs_itemknn_bmp.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "e99f3649", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'prec@10': 0.03252208701450242,\n", + " 'recall@10': 0.1683399650610623,\n", + " 'MAP@10': 0.04827657497255996,\n", + " 'novelty': 3.9201705312554833,\n", + " 'serendipity': 2.616232292298612e-05}" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", + "\n", + "metrics = {\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@10\": MAP(k=10),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}\n", + "\n", + "catalog = train['item_id'].unique()\n", + "\n", + "metric_values_itemknn_bmp = calc_metrics(\n", + " metrics,\n", + " reco=recs_itemknn_bmp,\n", + " interactions=test,\n", + " prev_interactions=train,\n", + " catalog=catalog\n", + " )\n", + "\n", + "metric_values_itemknn_bmp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84fe056a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/HW-3.3-rectools-cv.ipynb b/notebooks/HW-3.3-rectools-cv.ipynb new file mode 100644 index 00000000..e5f56e68 --- /dev/null +++ b/notebooks/HW-3.3-rectools-cv.ipynb @@ -0,0 +1,4387 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "id": "f0145080", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import scipy as sp\n", + "import requests\n", + "from tqdm.auto import tqdm\n", + "from scipy.stats import mode \n", + "from pprint import pprint\n", + "from implicit.nearest_neighbours import CosineRecommender, TFIDFRecommender, BM25Recommender\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "from rectools import Columns\n", + "from rectools.model_selection import TimeRangeSplitter\n", + "from rectools.dataset import Dataset, Interactions\n", + "from rectools.models.popular import PopularModel\n", + "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", + "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", + "\n", + "pd.set_option('display.max_columns', None)\n", + "pd.set_option('display.max_colwidth', 200)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "95ab759c", + "metadata": {}, + "outputs": [], + "source": [ + "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", + "\n", + "interactions.rename(columns={\n", + " 'last_watch_dt': Columns.Datetime,\n", + " 'total_dur': Columns.Weight\n", + " }, \n", + " inplace=True\n", + ") \n", + "\n", + "interactions['datetime'] = pd.to_datetime(interactions['datetime'])" + ] + }, + { + "cell_type": "markdown", + "id": "fbd3f42d", + "metadata": {}, + "source": [ + "# Split" + ] + }, + { + 
"cell_type": "markdown", + "id": "c89fcc74", + "metadata": {}, + "source": [ + "В соответствии с предположением из ноутбука \"HW-3.1\" сделаем **валидацию по 5 дней и по 7 дней**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "368c7cf6", + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_range(\n", + " last_date: pd.Timestamp, \n", + " n_folds: int = 7, \n", + " unit: str = \"W\", \n", + " n_units: int = 1, \n", + " show: bool = True,\n", + "):\n", + " periods = n_folds + 1\n", + " freq = f\"{n_units}{unit}\"\n", + " \n", + " start_date = last_date - pd.Timedelta(n_folds * n_units + n_units, unit=unit) \n", + " \n", + " date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n", + " \n", + " if show:\n", + " print(\n", + " f\"start_date: {start_date}\\n\"\n", + " f\"last_date: {last_date}\\n\"\n", + " f\"periods: {periods}\\n\"\n", + " f\"freq: {freq}\\n\"\n", + " f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\\n\"\n", + " )\n", + " \n", + " return date_range" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "29af1fa3", + "metadata": {}, + "outputs": [], + "source": [ + "CONFIG_CV = {\n", + " \"cv_v1\": {\n", + " \"n_folds\": 5,\n", + " \"unit\": \"W\",\n", + " \"n_units\": 1,\n", + " },\n", + " \"cv_v2\": {\n", + " \"n_folds\": 5,\n", + " \"unit\": \"D\",\n", + " \"n_units\": 5,\n", + " }, \n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3fdeb5a3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Timestamp('2021-08-22 00:00:00')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "last_date = interactions[Columns.Datetime].max().normalize()\n", + "last_date" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9ee0372b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "***Folds 
v1***\n", + "start_date: 2021-07-11 00:00:00\n", + "last_date: 2021-08-22 00:00:00\n", + "periods: 6\n", + "freq: 1W\n", + "Test fold borders: ['2021-07-11' '2021-07-18' '2021-07-25' '2021-08-01' '2021-08-08'\n", + " '2021-08-15']\n", + "\n", + "***Folds v2***\n", + "start_date: 2021-07-23 00:00:00\n", + "last_date: 2021-08-22 00:00:00\n", + "periods: 6\n", + "freq: 5D\n", + "Test fold borders: ['2021-07-23' '2021-07-28' '2021-08-02' '2021-08-07' '2021-08-12'\n", + " '2021-08-17']\n", + "\n" + ] + } + ], + "source": [ + "print(\"***Folds v1***\")\n", + "date_range_v1 = create_data_range(\n", + " last_date, \n", + " n_folds=CONFIG_CV[\"cv_v1\"][\"n_folds\"], \n", + " unit=CONFIG_CV[\"cv_v1\"][\"unit\"], \n", + " n_units=CONFIG_CV[\"cv_v1\"][\"n_units\"]\n", + ")\n", + "\n", + "print(\"***Folds v2***\")\n", + "date_range_v2 = create_data_range(\n", + " last_date, \n", + " n_folds=CONFIG_CV[\"cv_v2\"][\"n_folds\"], \n", + " unit=CONFIG_CV[\"cv_v2\"][\"unit\"], \n", + " n_units=CONFIG_CV[\"cv_v2\"][\"n_units\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "63d80785", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Real number of folds: 5\n", + "Real number of folds: 5\n" + ] + } + ], + "source": [ + "cv_v1 = TimeRangeSplitter(\n", + " date_range=date_range_v1,\n", + " filter_already_seen=True,\n", + " filter_cold_items=True,\n", + " filter_cold_users=True,\n", + ")\n", + "print(f\"Real number of folds: {cv_v1.get_n_splits(Interactions(interactions))}\")\n", + "\n", + "cv_v2 = TimeRangeSplitter(\n", + " date_range=date_range_v2,\n", + " filter_already_seen=True,\n", + " filter_cold_items=True,\n", + " filter_cold_users=True,\n", + ")\n", + "print(f\"Real number of folds: {cv_v2.get_n_splits(Interactions(interactions))}\")\n", + "\n", + "CV = [cv_v1, cv_v2]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1d4bc5e3", + "metadata": {}, + "outputs": [], + "source": [ + 
"metrics = {\n", + " \"prec@5\": Precision(k=5),\n", + " \"recall@5\": Recall(k=5),\n", + " \"MAP@5\": MAP(k=5),\n", + " \"prec@10\": Precision(k=10),\n", + " \"recall@10\": Recall(k=10),\n", + " \"MAP@10\": MAP(k=10),\n", + " \"novelty\": MeanInvUserFreq(k=10),\n", + " \"serendipity\": Serendipity(k=10),\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "f480a12f", + "metadata": {}, + "source": [ + "# Find best models" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "48888d0d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'popular': ,\n", + " 'popular_mw': ,\n", + " 'cosine_userknn_K30': ,\n", + " 'tfidf_userknn_K30': ,\n", + " 'bm25_userknn_K30': ,\n", + " 'cosine_userknn_K40': ,\n", + " 'tfidf_userknn_K40': ,\n", + " 'bm25_userknn_K40': ,\n", + " 'cosine_userknn_K50': ,\n", + " 'tfidf_userknn_K50': ,\n", + " 'bm25_userknn_K50': ,\n", + " 'cosine_userknn_K60': ,\n", + " 'tfidf_userknn_K60': ,\n", + " 'bm25_userknn_K60': }" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "K = [30, 40, 50, 60]\n", + "models = {\n", + " \"popular\": PopularModel(),\n", + " \"popular_mw\": PopularModel(popularity=\"mean_weight\")\n", + "}\n", + "\n", + "for k in K:\n", + " models[f\"popular\"]\n", + " models[f\"cosine_userknn_K{k}\"] = ImplicitItemKNNWrapperModel(model=CosineRecommender(K=k))\n", + " models[f\"tfidf_userknn_K{k}\"] = ImplicitItemKNNWrapperModel(model=TFIDFRecommender(K=k))\n", + " models[f\"bm25_userknn_K{k}\"] = ImplicitItemKNNWrapperModel(model=BM25Recommender(K=k))\n", + "\n", + "models" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "240478ad", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " ***CV_0***\n", + "\n", + "==================== Fold 0\n", + "{'End date': Timestamp('2021-07-18 00:00:00', freq='W-SUN'),\n", + " 'Start date': 
Timestamp('2021-07-11 00:00:00', freq='W-SUN'),\n", + " 'Test': 214489,\n", + " 'Test items': 6313,\n", + " 'Test users': 84234,\n", + " 'Train': 3192875,\n", + " 'Train items': 14711,\n", + " 'Train users': 640144}\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "54fd89ff19334e3182f264d9c492bc0f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/14 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
"
foldmodelprec@5recall@5prec@10recall@10MAP@5MAP@10noveltyserendipitycv
00popular_view-False0.0579270.1656440.0407500.2256730.0815120.0912683.5276320.000000fold_1w
10popular_view-True0.0670680.1877850.0431740.2368550.1064160.1146743.7705130.000003fold_1w
20popular_mw_view-False0.0000000.0000000.0000000.0000000.0000000.00000018.2251010.000000fold_1w
30popular_mw_view-True0.0000000.0000000.0000000.0000000.0000000.00000018.2251130.000000fold_1w
40cosine_userknn_K30_view-False0.0233090.0739180.0221430.1297750.0243640.0326517.9141100.000048fold_1w
....................................
2754cosine_userknn_K60_view-True0.0308600.0877970.0234320.1295780.0527720.0590919.1529680.000122fold_5d
2764tfidf_userknn_K60_view-False0.0197570.0608710.0214000.1222600.0196980.0285576.6513340.000095fold_5d
2774tfidf_userknn_K60_view-True0.0428030.1162650.0321730.1704580.0694600.0779126.7271280.000180fold_5d
2784bm25_userknn_K60_view-False0.0370060.1074420.0289580.1623460.0378950.0461993.9205840.000024fold_5d
2794bm25_userknn_K60_view-True0.0495680.1399710.0346460.1918270.0841810.0920224.0025740.000038fold_5d
\n", + "

280 rows × 11 columns

\n", + "" + ], + "text/plain": [ + " fold model prec@5 recall@5 prec@10 \\\n", + "0 0 popular_view-False 0.057927 0.165644 0.040750 \n", + "1 0 popular_view-True 0.067068 0.187785 0.043174 \n", + "2 0 popular_mw_view-False 0.000000 0.000000 0.000000 \n", + "3 0 popular_mw_view-True 0.000000 0.000000 0.000000 \n", + "4 0 cosine_userknn_K30_view-False 0.023309 0.073918 0.022143 \n", + ".. ... ... ... ... ... \n", + "275 4 cosine_userknn_K60_view-True 0.030860 0.087797 0.023432 \n", + "276 4 tfidf_userknn_K60_view-False 0.019757 0.060871 0.021400 \n", + "277 4 tfidf_userknn_K60_view-True 0.042803 0.116265 0.032173 \n", + "278 4 bm25_userknn_K60_view-False 0.037006 0.107442 0.028958 \n", + "279 4 bm25_userknn_K60_view-True 0.049568 0.139971 0.034646 \n", + "\n", + " recall@10 MAP@5 MAP@10 novelty serendipity cv \n", + "0 0.225673 0.081512 0.091268 3.527632 0.000000 fold_1w \n", + "1 0.236855 0.106416 0.114674 3.770513 0.000003 fold_1w \n", + "2 0.000000 0.000000 0.000000 18.225101 0.000000 fold_1w \n", + "3 0.000000 0.000000 0.000000 18.225113 0.000000 fold_1w \n", + "4 0.129775 0.024364 0.032651 7.914110 0.000048 fold_1w \n", + ".. ... ... ... ... ... ... 
\n", + "275 0.129578 0.052772 0.059091 9.152968 0.000122 fold_5d \n", + "276 0.122260 0.019698 0.028557 6.651334 0.000095 fold_5d \n", + "277 0.170458 0.069460 0.077912 6.727128 0.000180 fold_5d \n", + "278 0.162346 0.037895 0.046199 3.920584 0.000024 fold_5d \n", + "279 0.191827 0.084181 0.092022 4.002574 0.000038 fold_5d \n", + "\n", + "[280 rows x 11 columns]" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics = pd.DataFrame(results)\n", + "\n", + "df_metrics['cv'] = 'fold_1w'\n", + "df_metrics.loc[df_metrics[240:].index, 'cv'] = 'fold_5d'\n", + "\n", + "df_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "2c075a81", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
  prec@5recall@5prec@10recall@10MAP@5MAP@10noveltyserendipity
cvmodel        
fold_1wbm25_userknn_K30_view-False0.0421850.1198230.0341540.1853470.0425250.0526723.9462480.000025
bm25_userknn_K30_view-True0.0578090.1584620.0399610.2134740.0952320.1039424.0229380.000040
bm25_userknn_K40_view-False0.0421440.1197360.0341430.1853240.0424800.0526323.9468740.000024
bm25_userknn_K40_view-True0.0578090.1584590.0399780.2136720.0952250.1039584.0190520.000040
bm25_userknn_K50_view-False0.0427580.1212820.0346770.1878470.0430140.0533383.9522220.000024
bm25_userknn_K50_view-True0.0587270.1606350.0404860.2159230.0965760.1053574.0196010.000040
bm25_userknn_K60_view-False0.0427370.1212300.0346660.1877960.0429910.0533143.9540010.000024
bm25_userknn_K60_view-True0.0587330.1606450.0404890.2159890.0965820.1053694.0197450.000040
cosine_userknn_K30_view-False0.0182530.0566890.0183590.1056300.0186880.0258658.0175230.000059
cosine_userknn_K30_view-True0.0355550.0989330.0263640.1430560.0604240.0672809.2559170.000110
cosine_userknn_K40_view-False0.0182420.0566650.0184200.1058860.0186730.0258897.9988740.000059
cosine_userknn_K40_view-True0.0357950.0994330.0266170.1440550.0606590.0676009.1942320.000112
cosine_userknn_K50_view-False0.0185740.0575860.0187360.1074510.0189640.0262837.9764920.000059
cosine_userknn_K50_view-True0.0365740.1013450.0270880.1460780.0617260.0687059.1359440.000112
cosine_userknn_K60_view-False0.0185870.0576330.0187920.1077450.0189680.0263177.9643440.000059
cosine_userknn_K60_view-True0.0367750.1017880.0272630.1468410.0619650.0689959.0997830.000113
popular_mw_view-False0.0000010.0000040.0000010.0000050.0000010.00000118.4532010.000000
popular_mw_view-True0.0000010.0000040.0000010.0000050.0000010.00000118.4532120.000000
popular_view-False0.0478130.1347100.0336350.1829810.0675170.0751913.4627720.000000
popular_view-True0.0548010.1519420.0360520.1947680.0853480.0923823.7265770.000002
tfidf_userknn_K30_view-False0.0230980.0696630.0243060.1356420.0228240.0325196.7433550.000089
tfidf_userknn_K30_view-True0.0472470.1262870.0351810.1830160.0768820.0859166.9707800.000163
tfidf_userknn_K40_view-False0.0230880.0696410.0243660.1360420.0228090.0325576.7253990.000089
tfidf_userknn_K40_view-True0.0475380.1269040.0354420.1841430.0772930.0864016.9208620.000164
tfidf_userknn_K50_view-False0.0233680.0703830.0246770.1375270.0230520.0329166.7182830.000088
tfidf_userknn_K50_view-True0.0482450.1284920.0358830.1859640.0782480.0874186.8985940.000163
tfidf_userknn_K60_view-False0.0233350.0702720.0247020.1376080.0230200.0329066.7094850.000088
tfidf_userknn_K60_view-True0.0483400.1286640.0360000.1863860.0782780.0874876.8729720.000164
fold_5dbm25_userknn_K30_view-False0.0370910.1076370.0289830.1624440.0380040.0462953.9161130.000024
bm25_userknn_K30_view-True0.0495240.1398800.0346900.1919240.0841410.0920164.0087720.000040
bm25_userknn_K40_view-False0.0370410.1075290.0289650.1623580.0379570.0462533.9168820.000024
bm25_userknn_K40_view-True0.0495350.1398850.0346590.1918370.0841560.0920154.0040600.000039
bm25_userknn_K50_view-False0.0369970.1071580.0293990.1636720.0379280.0464873.9190890.000024
bm25_userknn_K50_view-True0.0500520.1405600.0352420.1936600.0843320.0924104.0038040.000039
bm25_userknn_K60_view-False0.0369840.1071180.0293940.1636480.0379040.0464643.9209690.000024
bm25_userknn_K60_view-True0.0500670.1405850.0352470.1937300.0843450.0924244.0038000.000039
cosine_userknn_K30_view-False0.0150530.0478120.0152010.0903210.0156960.0217948.0592570.000062
cosine_userknn_K30_view-True0.0301280.0862880.0227960.1271860.0520750.0582429.3133310.000118
cosine_userknn_K40_view-False0.0150880.0478410.0152530.0905090.0156930.0218108.0395780.000062
cosine_userknn_K40_view-True0.0304070.0868690.0230640.1281230.0523320.0585509.2453400.000120
cosine_userknn_K50_view-False0.0153600.0485870.0158260.0933330.0159890.0224308.0218540.000062
cosine_userknn_K50_view-True0.0312470.0885020.0239710.1324190.0535310.0601719.1785550.000119
cosine_userknn_K60_view-False0.0153780.0486400.0158630.0936170.0159980.0224648.0092990.000062
cosine_userknn_K60_view-True0.0314330.0888970.0241430.1332180.0537360.0604409.1400930.000120
popular_mw_view-False0.0000120.0000470.0000060.0000470.0000160.00001618.5009540.000001
popular_mw_view-True0.0000120.0000470.0000060.0000470.0000160.00001618.5009640.000001
popular_view-False0.0412890.1193840.0274250.1536150.0614460.0669593.4328550.000000
popular_view-True0.0471830.1346070.0301480.1685700.0769260.0821673.7145430.000002
tfidf_userknn_K30_view-False0.0198470.0612400.0213000.1217630.0198270.0285606.6921510.000096
tfidf_userknn_K30_view-True0.0423270.1152380.0315110.1680460.0689490.0771826.8414240.000178
tfidf_userknn_K40_view-False0.0197870.0609900.0213720.1222950.0197550.0285856.6717820.000096
tfidf_userknn_K40_view-True0.0425410.1157180.0318530.1694190.0691930.0775726.7890490.000178
tfidf_userknn_K50_view-False0.0201520.0619330.0218290.1244330.0201690.0292106.6699910.000095
tfidf_userknn_K50_view-True0.0431760.1170330.0325920.1725790.0705340.0792026.7754160.000176
tfidf_userknn_K60_view-False0.0201230.0618390.0218090.1242790.0201360.0291716.6608280.000094
tfidf_userknn_K60_view-True0.0432480.1171350.0326970.1727720.0705900.0792716.7478700.000176
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_metrics_mean = df_metrics.groupby(['cv', 'model'])[\n", + " 'prec@5', 'recall@5', 'prec@10', 'recall@10', 'MAP@5', 'MAP@10', 'novelty', 'serendipity'\n", + "].mean()\n", + "\n", + "df_metrics_mean.style.highlight_max(color='lightgreen', axis=0)" + ] + }, + { + "cell_type": "markdown", + "id": "c6f89d3a", + "metadata": {}, + "source": [ + "Из результатов видно, что средние значения метрик у моделей **bm25** **наилучшие**, причём на недельном фолде метрики выше, чем на 5-дневном.\n", + "\n", + "- Следует проверить, статистически различимы эти значения или нет. Для этого нужно посмотреть на разброс (стандартное отклонение) метрик по фолдам: если он меньше, чем разница между средними значениями метрик, то можно сделать вывод, что значения метрик статистически различимы." + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "dbe6f6e9", + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
prec@5recall@5prec@10recall@10MAP@5MAP@10noveltyserendipity
cvmodel
fold_1wbm25_userknn_K30_view-False0.0040420.0113800.0033190.0180500.0041440.0052350.0297301.677056e-06
bm25_userknn_K30_view-True0.0053480.0145870.0027870.0136200.0098230.0098710.0150063.357219e-06
bm25_userknn_K40_view-False0.0040360.0113930.0033100.0180330.0041490.0052340.0296131.666772e-06
bm25_userknn_K40_view-True0.0053340.0145920.0027740.0135110.0098240.0098590.0147093.359684e-06
bm25_userknn_K50_view-False0.0037770.0110130.0030870.0174760.0040390.0050710.0293291.807696e-06
bm25_userknn_K50_view-True0.0048940.0139550.0024670.0124850.0095620.0095320.0148613.329699e-06
bm25_userknn_K60_view-False0.0037740.0110020.0030880.0174920.0040380.0050740.0293681.805538e-06
bm25_userknn_K60_view-True0.0048940.0139630.0024660.0124780.0095620.0095290.0147383.340070e-06
cosine_userknn_K30_view-False0.0023930.0078030.0019180.0113260.0025180.0030910.0473375.304930e-06
cosine_userknn_K30_view-True0.0036240.0104000.0017860.0088100.0070000.0068420.0470279.087191e-06
cosine_userknn_K40_view-False0.0023890.0077940.0019080.0112950.0025170.0030810.0464455.302222e-06
cosine_userknn_K40_view-True0.0035680.0102570.0017520.0085470.0069510.0067840.0468249.276498e-06
cosine_userknn_K50_view-False0.0023160.0077730.0018440.0112840.0025120.0030720.0463455.581875e-06
cosine_userknn_K50_view-True0.0033820.0100320.0016460.0082760.0069180.0067300.0467999.814832e-06
cosine_userknn_K60_view-False0.0023140.0077920.0018590.0113540.0025140.0030830.0464545.564036e-06
cosine_userknn_K60_view-True0.0034040.0100880.0016380.0082460.0069610.0067690.0471339.793017e-06
popular_mw_view-False0.0000020.0000060.0000010.0000050.0000020.0000020.1251581.003157e-07
popular_mw_view-True0.0000020.0000060.0000010.0000050.0000020.0000020.1251551.003157e-07
popular_view-False0.0052120.0145720.0038290.0217390.0061000.0070750.0301750.000000e+00
popular_view-True0.0061040.0166560.0037930.0208750.0089550.0096530.0191312.429424e-07
tfidf_userknn_K30_view-False0.0022440.0069950.0019450.0106640.0023810.0029650.0398778.695037e-06
tfidf_userknn_K30_view-True0.0033320.0090010.0019680.0086110.0063740.0063900.0696381.279797e-05
tfidf_userknn_K40_view-False0.0022470.0070180.0019330.0105820.0023840.0029490.0408518.509896e-06
tfidf_userknn_K40_view-True0.0033190.0089040.0019240.0082880.0063000.0062880.0708191.317232e-05
tfidf_userknn_K50_view-False0.0022010.0070680.0018850.0105910.0024010.0029550.0404928.580259e-06
tfidf_userknn_K50_view-True0.0031950.0088190.0018010.0077960.0063570.0062790.0654901.351216e-05
tfidf_userknn_K60_view-False0.0022040.0070880.0018860.0106590.0024120.0029720.0408278.400699e-06
tfidf_userknn_K60_view-True0.0031990.0088290.0018090.0078410.0064200.0063480.0660501.359432e-05
fold_5dbm25_userknn_K30_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K30_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K40_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K40_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K50_view-False0.0000390.0004560.0006120.0018670.0000020.0003660.0007251.109718e-07
bm25_userknn_K50_view-True0.0006970.0008470.0008170.0025360.0002280.0005520.0012428.703289e-07
bm25_userknn_K60_view-False0.0000310.0004580.0006160.0018410.0000130.0003750.0005438.345726e-08
bm25_userknn_K60_view-True0.0007060.0008690.0008500.0026910.0002320.0005690.0017341.165137e-06
cosine_userknn_K30_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K30_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K40_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K40_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K50_view-False0.0003550.0009810.0007060.0033660.0004120.0007980.0010416.443968e-07
cosine_userknn_K50_view-True0.0007920.0014680.0009660.0047400.0013080.0018170.0183502.644474e-06
cosine_userknn_K60_view-False0.0003770.0010290.0007350.0035480.0004270.0008310.0032367.481876e-07
cosine_userknn_K60_view-True0.0008100.0015550.0010060.0051480.0013630.0019080.0182082.395850e-06
popular_mw_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
popular_mw_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
popular_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
popular_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K30_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K30_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K40_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K40_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K50_view-False0.0005680.0014770.0005630.0026220.0006370.0008480.0140731.071478e-06
tfidf_userknn_K50_view-True0.0006550.0012350.0007460.0033030.0015870.0019030.0288473.757778e-06
tfidf_userknn_K60_view-False0.0005170.0013700.0005780.0028550.0006190.0008680.0134261.001711e-06
tfidf_userknn_K60_view-True0.0006300.0012300.0007420.0032730.0015980.0019220.0293354.434436e-06
\n", + "
" + ], + "text/plain": [ + " prec@5 recall@5 prec@10 \\\n", + "cv model \n", + "fold_1w bm25_userknn_K30_view-False 0.004042 0.011380 0.003319 \n", + " bm25_userknn_K30_view-True 0.005348 0.014587 0.002787 \n", + " bm25_userknn_K40_view-False 0.004036 0.011393 0.003310 \n", + " bm25_userknn_K40_view-True 0.005334 0.014592 0.002774 \n", + " bm25_userknn_K50_view-False 0.003777 0.011013 0.003087 \n", + " bm25_userknn_K50_view-True 0.004894 0.013955 0.002467 \n", + " bm25_userknn_K60_view-False 0.003774 0.011002 0.003088 \n", + " bm25_userknn_K60_view-True 0.004894 0.013963 0.002466 \n", + " cosine_userknn_K30_view-False 0.002393 0.007803 0.001918 \n", + " cosine_userknn_K30_view-True 0.003624 0.010400 0.001786 \n", + " cosine_userknn_K40_view-False 0.002389 0.007794 0.001908 \n", + " cosine_userknn_K40_view-True 0.003568 0.010257 0.001752 \n", + " cosine_userknn_K50_view-False 0.002316 0.007773 0.001844 \n", + " cosine_userknn_K50_view-True 0.003382 0.010032 0.001646 \n", + " cosine_userknn_K60_view-False 0.002314 0.007792 0.001859 \n", + " cosine_userknn_K60_view-True 0.003404 0.010088 0.001638 \n", + " popular_mw_view-False 0.000002 0.000006 0.000001 \n", + " popular_mw_view-True 0.000002 0.000006 0.000001 \n", + " popular_view-False 0.005212 0.014572 0.003829 \n", + " popular_view-True 0.006104 0.016656 0.003793 \n", + " tfidf_userknn_K30_view-False 0.002244 0.006995 0.001945 \n", + " tfidf_userknn_K30_view-True 0.003332 0.009001 0.001968 \n", + " tfidf_userknn_K40_view-False 0.002247 0.007018 0.001933 \n", + " tfidf_userknn_K40_view-True 0.003319 0.008904 0.001924 \n", + " tfidf_userknn_K50_view-False 0.002201 0.007068 0.001885 \n", + " tfidf_userknn_K50_view-True 0.003195 0.008819 0.001801 \n", + " tfidf_userknn_K60_view-False 0.002204 0.007088 0.001886 \n", + " tfidf_userknn_K60_view-True 0.003199 0.008829 0.001809 \n", + "fold_5d bm25_userknn_K30_view-False NaN NaN NaN \n", + " bm25_userknn_K30_view-True NaN NaN NaN \n", + " bm25_userknn_K40_view-False NaN NaN 
NaN \n", + " bm25_userknn_K40_view-True NaN NaN NaN \n", + " bm25_userknn_K50_view-False 0.000039 0.000456 0.000612 \n", + " bm25_userknn_K50_view-True 0.000697 0.000847 0.000817 \n", + " bm25_userknn_K60_view-False 0.000031 0.000458 0.000616 \n", + " bm25_userknn_K60_view-True 0.000706 0.000869 0.000850 \n", + " cosine_userknn_K30_view-False NaN NaN NaN \n", + " cosine_userknn_K30_view-True NaN NaN NaN \n", + " cosine_userknn_K40_view-False NaN NaN NaN \n", + " cosine_userknn_K40_view-True NaN NaN NaN \n", + " cosine_userknn_K50_view-False 0.000355 0.000981 0.000706 \n", + " cosine_userknn_K50_view-True 0.000792 0.001468 0.000966 \n", + " cosine_userknn_K60_view-False 0.000377 0.001029 0.000735 \n", + " cosine_userknn_K60_view-True 0.000810 0.001555 0.001006 \n", + " popular_mw_view-False NaN NaN NaN \n", + " popular_mw_view-True NaN NaN NaN \n", + " popular_view-False NaN NaN NaN \n", + " popular_view-True NaN NaN NaN \n", + " tfidf_userknn_K30_view-False NaN NaN NaN \n", + " tfidf_userknn_K30_view-True NaN NaN NaN \n", + " tfidf_userknn_K40_view-False NaN NaN NaN \n", + " tfidf_userknn_K40_view-True NaN NaN NaN \n", + " tfidf_userknn_K50_view-False 0.000568 0.001477 0.000563 \n", + " tfidf_userknn_K50_view-True 0.000655 0.001235 0.000746 \n", + " tfidf_userknn_K60_view-False 0.000517 0.001370 0.000578 \n", + " tfidf_userknn_K60_view-True 0.000630 0.001230 0.000742 \n", + "\n", + " recall@10 MAP@5 MAP@10 \\\n", + "cv model \n", + "fold_1w bm25_userknn_K30_view-False 0.018050 0.004144 0.005235 \n", + " bm25_userknn_K30_view-True 0.013620 0.009823 0.009871 \n", + " bm25_userknn_K40_view-False 0.018033 0.004149 0.005234 \n", + " bm25_userknn_K40_view-True 0.013511 0.009824 0.009859 \n", + " bm25_userknn_K50_view-False 0.017476 0.004039 0.005071 \n", + " bm25_userknn_K50_view-True 0.012485 0.009562 0.009532 \n", + " bm25_userknn_K60_view-False 0.017492 0.004038 0.005074 \n", + " bm25_userknn_K60_view-True 0.012478 0.009562 0.009529 \n", + " 
cosine_userknn_K30_view-False 0.011326 0.002518 0.003091 \n", + " cosine_userknn_K30_view-True 0.008810 0.007000 0.006842 \n", + " cosine_userknn_K40_view-False 0.011295 0.002517 0.003081 \n", + " cosine_userknn_K40_view-True 0.008547 0.006951 0.006784 \n", + " cosine_userknn_K50_view-False 0.011284 0.002512 0.003072 \n", + " cosine_userknn_K50_view-True 0.008276 0.006918 0.006730 \n", + " cosine_userknn_K60_view-False 0.011354 0.002514 0.003083 \n", + " cosine_userknn_K60_view-True 0.008246 0.006961 0.006769 \n", + " popular_mw_view-False 0.000005 0.000002 0.000002 \n", + " popular_mw_view-True 0.000005 0.000002 0.000002 \n", + " popular_view-False 0.021739 0.006100 0.007075 \n", + " popular_view-True 0.020875 0.008955 0.009653 \n", + " tfidf_userknn_K30_view-False 0.010664 0.002381 0.002965 \n", + " tfidf_userknn_K30_view-True 0.008611 0.006374 0.006390 \n", + " tfidf_userknn_K40_view-False 0.010582 0.002384 0.002949 \n", + " tfidf_userknn_K40_view-True 0.008288 0.006300 0.006288 \n", + " tfidf_userknn_K50_view-False 0.010591 0.002401 0.002955 \n", + " tfidf_userknn_K50_view-True 0.007796 0.006357 0.006279 \n", + " tfidf_userknn_K60_view-False 0.010659 0.002412 0.002972 \n", + " tfidf_userknn_K60_view-True 0.007841 0.006420 0.006348 \n", + "fold_5d bm25_userknn_K30_view-False NaN NaN NaN \n", + " bm25_userknn_K30_view-True NaN NaN NaN \n", + " bm25_userknn_K40_view-False NaN NaN NaN \n", + " bm25_userknn_K40_view-True NaN NaN NaN \n", + " bm25_userknn_K50_view-False 0.001867 0.000002 0.000366 \n", + " bm25_userknn_K50_view-True 0.002536 0.000228 0.000552 \n", + " bm25_userknn_K60_view-False 0.001841 0.000013 0.000375 \n", + " bm25_userknn_K60_view-True 0.002691 0.000232 0.000569 \n", + " cosine_userknn_K30_view-False NaN NaN NaN \n", + " cosine_userknn_K30_view-True NaN NaN NaN \n", + " cosine_userknn_K40_view-False NaN NaN NaN \n", + " cosine_userknn_K40_view-True NaN NaN NaN \n", + " cosine_userknn_K50_view-False 0.003366 0.000412 0.000798 \n", + " 
cosine_userknn_K50_view-True 0.004740 0.001308 0.001817 \n", + " cosine_userknn_K60_view-False 0.003548 0.000427 0.000831 \n", + " cosine_userknn_K60_view-True 0.005148 0.001363 0.001908 \n", + " popular_mw_view-False NaN NaN NaN \n", + " popular_mw_view-True NaN NaN NaN \n", + " popular_view-False NaN NaN NaN \n", + " popular_view-True NaN NaN NaN \n", + " tfidf_userknn_K30_view-False NaN NaN NaN \n", + " tfidf_userknn_K30_view-True NaN NaN NaN \n", + " tfidf_userknn_K40_view-False NaN NaN NaN \n", + " tfidf_userknn_K40_view-True NaN NaN NaN \n", + " tfidf_userknn_K50_view-False 0.002622 0.000637 0.000848 \n", + " tfidf_userknn_K50_view-True 0.003303 0.001587 0.001903 \n", + " tfidf_userknn_K60_view-False 0.002855 0.000619 0.000868 \n", + " tfidf_userknn_K60_view-True 0.003273 0.001598 0.001922 \n", + "\n", + " novelty serendipity \n", + "cv model \n", + "fold_1w bm25_userknn_K30_view-False 0.029730 1.677056e-06 \n", + " bm25_userknn_K30_view-True 0.015006 3.357219e-06 \n", + " bm25_userknn_K40_view-False 0.029613 1.666772e-06 \n", + " bm25_userknn_K40_view-True 0.014709 3.359684e-06 \n", + " bm25_userknn_K50_view-False 0.029329 1.807696e-06 \n", + " bm25_userknn_K50_view-True 0.014861 3.329699e-06 \n", + " bm25_userknn_K60_view-False 0.029368 1.805538e-06 \n", + " bm25_userknn_K60_view-True 0.014738 3.340070e-06 \n", + " cosine_userknn_K30_view-False 0.047337 5.304930e-06 \n", + " cosine_userknn_K30_view-True 0.047027 9.087191e-06 \n", + " cosine_userknn_K40_view-False 0.046445 5.302222e-06 \n", + " cosine_userknn_K40_view-True 0.046824 9.276498e-06 \n", + " cosine_userknn_K50_view-False 0.046345 5.581875e-06 \n", + " cosine_userknn_K50_view-True 0.046799 9.814832e-06 \n", + " cosine_userknn_K60_view-False 0.046454 5.564036e-06 \n", + " cosine_userknn_K60_view-True 0.047133 9.793017e-06 \n", + " popular_mw_view-False 0.125158 1.003157e-07 \n", + " popular_mw_view-True 0.125155 1.003157e-07 \n", + " popular_view-False 0.030175 0.000000e+00 \n", + " 
popular_view-True 0.019131 2.429424e-07 \n", + " tfidf_userknn_K30_view-False 0.039877 8.695037e-06 \n", + " tfidf_userknn_K30_view-True 0.069638 1.279797e-05 \n", + " tfidf_userknn_K40_view-False 0.040851 8.509896e-06 \n", + " tfidf_userknn_K40_view-True 0.070819 1.317232e-05 \n", + " tfidf_userknn_K50_view-False 0.040492 8.580259e-06 \n", + " tfidf_userknn_K50_view-True 0.065490 1.351216e-05 \n", + " tfidf_userknn_K60_view-False 0.040827 8.400699e-06 \n", + " tfidf_userknn_K60_view-True 0.066050 1.359432e-05 \n", + "fold_5d bm25_userknn_K30_view-False NaN NaN \n", + " bm25_userknn_K30_view-True NaN NaN \n", + " bm25_userknn_K40_view-False NaN NaN \n", + " bm25_userknn_K40_view-True NaN NaN \n", + " bm25_userknn_K50_view-False 0.000725 1.109718e-07 \n", + " bm25_userknn_K50_view-True 0.001242 8.703289e-07 \n", + " bm25_userknn_K60_view-False 0.000543 8.345726e-08 \n", + " bm25_userknn_K60_view-True 0.001734 1.165137e-06 \n", + " cosine_userknn_K30_view-False NaN NaN \n", + " cosine_userknn_K30_view-True NaN NaN \n", + " cosine_userknn_K40_view-False NaN NaN \n", + " cosine_userknn_K40_view-True NaN NaN \n", + " cosine_userknn_K50_view-False 0.001041 6.443968e-07 \n", + " cosine_userknn_K50_view-True 0.018350 2.644474e-06 \n", + " cosine_userknn_K60_view-False 0.003236 7.481876e-07 \n", + " cosine_userknn_K60_view-True 0.018208 2.395850e-06 \n", + " popular_mw_view-False NaN NaN \n", + " popular_mw_view-True NaN NaN \n", + " popular_view-False NaN NaN \n", + " popular_view-True NaN NaN \n", + " tfidf_userknn_K30_view-False NaN NaN \n", + " tfidf_userknn_K30_view-True NaN NaN \n", + " tfidf_userknn_K40_view-False NaN NaN \n", + " tfidf_userknn_K40_view-True NaN NaN \n", + " tfidf_userknn_K50_view-False 0.014073 1.071478e-06 \n", + " tfidf_userknn_K50_view-True 0.028847 3.757778e-06 \n", + " tfidf_userknn_K60_view-False 0.013426 1.001711e-06 \n", + " tfidf_userknn_K60_view-True 0.029335 4.434436e-06 " + ] + }, + "execution_count": 51, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "df_metrics_std = df_metrics.groupby(['cv', 'model'])[\n", + " 'prec@5', 'recall@5', 'prec@10', 'recall@10', 'MAP@5', 'MAP@10', 'novelty', 'serendipity'\n", + "].std()\n", + "\n", + "df_metrics_std" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "58ad9d07", + "metadata": {}, + "outputs": [], + "source": [ + "df_metrics_1w_mean = df_metrics_mean.loc[\"fold_1w\"]\n", + "df_metrics_1w_std = df_metrics_std.loc[\"fold_1w\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "52bffafc", + "metadata": {}, + "outputs": [], + "source": [ + "best_model = \"bm25_userknn_K60_view-True\"\n", + "col_metrics = list(metrics.keys())\n", + "std_best_metrics = df_metrics_1w_std[df_metrics_1w_std[\"model\"] == best_model][col_metrics].values[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "0059da61", + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "prec@5 0.004894\n", + "recall@5 0.013963\n", + "prec@10 0.002466\n", + "recall@10 0.012478\n", + "MAP@5 0.009562\n", + "MAP@10 0.009529\n", + "novelty 0.014738\n", + "serendipity 0.000003\n", + "Name: bm25_userknn_K60_view-True, dtype: float64\n", + "\n", + "===Сравнение с bm25_userknn_K30_view-False\n", + "prec@5 0.000852\n", + "recall@5 0.002584\n", + "prec@10 -0.000854\n", + "recall@10 -0.005571\n", + "MAP@5 0.005418\n", + "MAP@10 0.004293\n", + "novelty -0.014992\n", + "serendipity 0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с bm25_userknn_K30_view-True\n", + "prec@5 -4.545181e-04\n", + "recall@5 -6.232778e-04\n", + "prec@10 -3.208519e-04\n", + "recall@10 -1.141463e-03\n", + "MAP@5 -2.606624e-04\n", + "MAP@10 -3.423636e-04\n", + "novelty -2.682734e-04\n", + "serendipity -1.714842e-08\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с 
bm25_userknn_K40_view-False\n", + "prec@5 0.000858\n", + "recall@5 0.002570\n", + "prec@10 -0.000844\n", + "recall@10 -0.005554\n", + "MAP@5 0.005413\n", + "MAP@10 0.004294\n", + "novelty -0.014875\n", + "serendipity 0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с bm25_userknn_K40_view-True\n", + "prec@5 -4.403880e-04\n", + "recall@5 -6.290758e-04\n", + "prec@10 -3.080923e-04\n", + "recall@10 -1.032860e-03\n", + "MAP@5 -2.622039e-04\n", + "MAP@10 -3.305713e-04\n", + "novelty 2.940168e-05\n", + "serendipity -1.961419e-08\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с bm25_userknn_K50_view-False\n", + "prec@5 0.001117\n", + "recall@5 0.002950\n", + "prec@10 -0.000621\n", + "recall@10 -0.004998\n", + "MAP@5 0.005523\n", + "MAP@10 0.004457\n", + "novelty -0.014591\n", + "serendipity 0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с bm25_userknn_K50_view-True\n", + "prec@5 -6.188074e-08\n", + "recall@5 8.340206e-06\n", + "prec@10 -1.581047e-06\n", + "recall@10 -6.557768e-06\n", + "MAP@5 8.532418e-08\n", + "MAP@10 -3.212381e-06\n", + "novelty -1.225529e-04\n", + "serendipity 1.037090e-08\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с bm25_userknn_K60_view-False\n", + "prec@5 0.001119\n", + "recall@5 0.002961\n", + "prec@10 -0.000622\n", + "recall@10 -0.005014\n", + "MAP@5 0.005524\n", + "MAP@10 0.004455\n", + "novelty -0.014630\n", + "serendipity 0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K30_view-False\n", + "prec@5 0.002501\n", + "recall@5 0.006161\n", + "prec@10 0.000548\n", + "recall@10 0.001152\n", + "MAP@5 0.007044\n", + "MAP@10 0.006437\n", + "novelty -0.032599\n", + "serendipity -0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K30_view-True\n", + "prec@5 0.001270\n", + 
"recall@5 0.003563\n", + "prec@10 0.000680\n", + "recall@10 0.003668\n", + "MAP@5 0.002563\n", + "MAP@10 0.002687\n", + "novelty -0.032289\n", + "serendipity -0.000006\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K40_view-False\n", + "prec@5 0.002504\n", + "recall@5 0.006169\n", + "prec@10 0.000558\n", + "recall@10 0.001184\n", + "MAP@5 0.007046\n", + "MAP@10 0.006448\n", + "novelty -0.031707\n", + "serendipity -0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K40_view-True\n", + "prec@5 0.001325\n", + "recall@5 0.003706\n", + "prec@10 0.000714\n", + "recall@10 0.003931\n", + "MAP@5 0.002611\n", + "MAP@10 0.002744\n", + "novelty -0.032086\n", + "serendipity -0.000006\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K50_view-False\n", + "prec@5 0.002578\n", + "recall@5 0.006190\n", + "prec@10 0.000622\n", + "recall@10 0.001195\n", + "MAP@5 0.007051\n", + "MAP@10 0.006456\n", + "novelty -0.031607\n", + "serendipity -0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K50_view-True\n", + "prec@5 0.001511\n", + "recall@5 0.003931\n", + "prec@10 0.000820\n", + "recall@10 0.004203\n", + "MAP@5 0.002644\n", + "MAP@10 0.002798\n", + "novelty -0.032061\n", + "serendipity -0.000006\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K60_view-False\n", + "prec@5 0.002580\n", + "recall@5 0.006172\n", + "prec@10 0.000607\n", + "recall@10 0.001124\n", + "MAP@5 0.007049\n", + "MAP@10 0.006446\n", + "novelty -0.031716\n", + "serendipity -0.000002\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с cosine_userknn_K60_view-True\n", + "prec@5 0.001489\n", + "recall@5 0.003875\n", + "prec@10 0.000827\n", + "recall@10 0.004232\n", + "MAP@5 0.002601\n", + "MAP@10 0.002760\n", + "novelty 
-0.032395\n", + "serendipity -0.000006\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с popular_mw_view-False\n", + "prec@5 0.004892\n", + "recall@5 0.013958\n", + "prec@10 0.002465\n", + "recall@10 0.012473\n", + "MAP@5 0.009561\n", + "MAP@10 0.009527\n", + "novelty -0.110420\n", + "serendipity 0.000003\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с popular_mw_view-True\n", + "prec@5 0.004892\n", + "recall@5 0.013958\n", + "prec@10 0.002465\n", + "recall@10 0.012473\n", + "MAP@5 0.009561\n", + "MAP@10 0.009527\n", + "novelty -0.110417\n", + "serendipity 0.000003\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с popular_view-False\n", + "prec@5 -0.000319\n", + "recall@5 -0.000609\n", + "prec@10 -0.001363\n", + "recall@10 -0.009260\n", + "MAP@5 0.003462\n", + "MAP@10 0.002453\n", + "novelty -0.015437\n", + "serendipity 0.000003\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с popular_view-True\n", + "prec@5 -0.001210\n", + "recall@5 -0.002692\n", + "prec@10 -0.001327\n", + "recall@10 -0.008397\n", + "MAP@5 0.000607\n", + "MAP@10 -0.000124\n", + "novelty -0.004393\n", + "serendipity 0.000003\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K30_view-False\n", + "prec@5 0.002649\n", + "recall@5 0.006968\n", + "prec@10 0.000521\n", + "recall@10 0.001815\n", + "MAP@5 0.007181\n", + "MAP@10 0.006564\n", + "novelty -0.025139\n", + "serendipity -0.000005\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K30_view-True\n", + "prec@5 0.001561\n", + "recall@5 0.004963\n", + "prec@10 0.000498\n", + "recall@10 0.003867\n", + "MAP@5 0.003188\n", + "MAP@10 0.003139\n", + "novelty -0.054900\n", + "serendipity -0.000009\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K40_view-False\n", + 
"prec@5 0.002647\n", + "recall@5 0.006945\n", + "prec@10 0.000532\n", + "recall@10 0.001897\n", + "MAP@5 0.007178\n", + "MAP@10 0.006579\n", + "novelty -0.026113\n", + "serendipity -0.000005\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K40_view-True\n", + "prec@5 0.001575\n", + "recall@5 0.005059\n", + "prec@10 0.000542\n", + "recall@10 0.004190\n", + "MAP@5 0.003262\n", + "MAP@10 0.003240\n", + "novelty -0.056080\n", + "serendipity -0.000010\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K50_view-False\n", + "prec@5 0.002693\n", + "recall@5 0.006895\n", + "prec@10 0.000581\n", + "recall@10 0.001887\n", + "MAP@5 0.007161\n", + "MAP@10 0.006574\n", + "novelty -0.025754\n", + "serendipity -0.000005\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K50_view-True\n", + "prec@5 0.001698\n", + "recall@5 0.005144\n", + "prec@10 0.000665\n", + "recall@10 0.004682\n", + "MAP@5 0.003205\n", + "MAP@10 0.003249\n", + "novelty -0.050752\n", + "serendipity -0.000010\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K60_view-False\n", + "prec@5 0.002690\n", + "recall@5 0.006875\n", + "prec@10 0.000580\n", + "recall@10 0.001819\n", + "MAP@5 0.007150\n", + "MAP@10 0.006557\n", + "novelty -0.026089\n", + "serendipity -0.000005\n", + "dtype: float64\n", + "=========================\n", + "\n", + "===Сравнение с tfidf_userknn_K60_view-True\n", + "prec@5 0.001694\n", + "recall@5 0.005134\n", + "prec@10 0.000657\n", + "recall@10 0.004637\n", + "MAP@5 0.003142\n", + "MAP@10 0.003180\n", + "novelty -0.051312\n", + "serendipity -0.000010\n", + "dtype: float64\n", + "=========================\n" + ] + } + ], + "source": [ + "print(df_metrics_1w_std.loc[best_model])\n", + "for model in df_metrics_1w_mean.index:\n", + " if model != best_model:\n", + " print(f\"\\n===Сравнение с {model}\")\n", 
+ " print(df_metrics_1w_mean.loc[best_model] - df_metrics_1w_mean.loc[model])\n", + " print(\"=========================\")" + ] + }, + { + "cell_type": "markdown", + "id": "0675ba9b", + "metadata": {}, + "source": [ + "Лучшей модели большинством из моделей видны статистические различия, кроме всех моделей bmp (логично, потому что лучшая модель bmp с k = 60) и моделью tfidf, где для рекомендаций стоял флаг filter_viewed = True, что означает рекомендовать не одинаковые элементы для всех пользователей" + ] + }, + { + "cell_type": "markdown", + "id": "e233b183", + "metadata": {}, + "source": [ + "# Обучение на всех имеющихся данных и формирование оффлайн рекомендаций" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "30e985b6", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset.construct(\n", + " interactions_df=interactions,\n", + " user_features_df=None,\n", + " item_features_df=None\n", + ")\n", + "\n", + "bmp25_k60_model = ImplicitItemKNNWrapperModel(BM25Recommender(K=60))\n", + "bmp25_k60_model.fit(dataset)\n", + "\n", + "K_RECOS = 30\n", + " \n", + "recos_offline_bmp25 = bmp25_k60_model.recommend(\n", + " users=interactions[Columns.User].unique(),\n", + " dataset=dataset,\n", + " k=K_RECOS,\n", + " filter_viewed=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "4034d96f", + "metadata": {}, + "outputs": [], + "source": [ + "recos_offline_bmp25.to_csv(\"../data/hw_3/bmp_25_k60_rectools.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d52a48b5", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset.construct(\n", + " interactions_df=interactions,\n", + " user_features_df=None,\n", + " item_features_df=None\n", + ")\n", + "\n", + "tfidf_k60_model = ImplicitItemKNNWrapperModel(TFIDFRecommender(K=60))\n", + "tfidf_k60_model.fit(dataset)\n", + "\n", + "K_RECOS = 30\n", + " \n", + "recos_offline_tfidf = tfidf_k60_model.recommend(\n", + " 
users=interactions[Columns.User].unique(),\n", + " dataset=dataset,\n", + " k=K_RECOS,\n", + " filter_viewed=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d74a0ee6", + "metadata": {}, + "outputs": [], + "source": [ + "recos_offline_tfidf.to_csv(\"../data/hw_3/tfidf_k60_rectools.csv\", index=False)" + ] + }, + { + "cell_type": "markdown", + "id": "0164df93", + "metadata": {}, + "source": [ + "# Формирование рекомендаций для cold users" + ] + }, + { + "cell_type": "markdown", + "id": "5af7d214", + "metadata": {}, + "source": [ + "Среди моделей на основе популярного наилучшего качества по метрикам достигала модель popular на основе количества уникальных пользователей, взаимодействовавших с элементом, НО модель по среднему весу взаимодействия с элементами показывает по метрике новелти очень высокие результаты, поэтому стоит попробовать обе модели" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "51fdeae3", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset.construct(\n", + " interactions_df=interactions,\n", + " user_features_df=None,\n", + " item_features_df=None\n", + ")\n", + "\n", + "popular_model = PopularModel()\n", + "popular_model.fit(dataset)\n", + "\n", + "item_inv = dict(enumerate(interactions[\"item_id\"].unique()))\n", + "recos_pop = []\n", + "for item_pop in popular_model.popularity_list[0]:\n", + " recos_pop.append(item_inv[item_pop])\n", + "\n", + "df_pop_recos = pd.DataFrame({\"item_id\": recos_pop})\n", + "\n", + "df_pop_recos.to_csv(\"../data/hw_3/popular_item.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "52981cce", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset.construct(\n", + " interactions_df=interactions,\n", + " user_features_df=None,\n", + " item_features_df=None\n", + ")\n", + "\n", + "popular_model_mw = PopularModel(popularity=\"mean_weight\")\n", + "popular_model_mw.fit(dataset)\n", + "\n", + 
"item_inv = dict(enumerate(interactions[\"item_id\"].unique()))\n", + "recos_pop = []\n", + "for item_pop in popular_model_mw.popularity_list[0]:\n", + " recos_pop.append(item_inv[item_pop])\n", + "\n", + "df_pop_recos_mw = pd.DataFrame({\"item_id\": recos_pop})\n", + "\n", + "df_pop_recos_mw.to_csv(\"../data/hw_3/popular_mean_weight_item.csv\", index=False)" + ] + }, + { + "cell_type": "markdown", + "id": "170efd3c", + "metadata": {}, + "source": [ + "# Блендинг результатов моделей" + ] + }, + { + "cell_type": "markdown", + "id": "878f0b90", + "metadata": {}, + "source": [ + "Механизм блендинга будет выглядеть следующим образом:\n", + "\n", + "1. Берутся рекомендации, сделанные моделями tfidf и bmp25, конкатятся результаты, удаляются дубликаты item-ов\n", + "2. Берется заготовленный датасет items c полями item_id и idf\n", + "3. Смотрится idf, чем он выше, тем выше будет стоять item в выдаче\n", + "\n", + "Такой подход обусловлен тем, что idf показывает обратную частоту item, соответственно в выдаче наверх будут попадать item, с которым меньшее количество раз взаимодействовали пользователи, т.е. в перспективе такой подход может предлагать item, с которыми ни один пользователь не взаимодействовал или взаимодействовали очень мало, т.е. может решиться проблема длинного хвоста." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3b35f8ff", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "163f79b9", + "metadata": {}, + "outputs": [], + "source": [ + "df_bmp_recs = pd.read_csv(\"../data/hw_3/bmp_25_k60_rectools.csv\")\n", + "df_tfidf_recs = pd.read_csv(\"../data/hw_3/tfidf_k60_rectools.csv\") " + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "c842edef", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
0176549138658.899597e+101
1176549104408.153085e+102
2176549152977.204604e+103
317654937346.953473e+104
417654941514.674591e+105
2886225869726254341.615419e+1026
2886225969726211321.605160e+1027
2886226069726274761.566697e+1028
28862261697262112371.546907e+1029
28862262697262129951.542308e+1030
\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank\n", + "0 176549 13865 8.899597e+10 1\n", + "1 176549 10440 8.153085e+10 2\n", + "2 176549 15297 7.204604e+10 3\n", + "3 176549 3734 6.953473e+10 4\n", + "4 176549 4151 4.674591e+10 5\n", + "28862258 697262 5434 1.615419e+10 26\n", + "28862259 697262 1132 1.605160e+10 27\n", + "28862260 697262 7476 1.566697e+10 28\n", + "28862261 697262 11237 1.546907e+10 29\n", + "28862262 697262 12995 1.542308e+10 30" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([df_bmp_recs.head(), df_bmp_recs.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "576f23a7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
01765491174913575.6601851
11765491627011946.8727082
21765491198511355.6931193
31765491315910375.5006474
41765491526610269.0196905
2886225869726261921294.34241426
28862259697262116401277.33233327
2886226069726274761262.91937728
28862261697262141213.49928129
2886226269726237841200.34778530
\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank\n", + "0 176549 11749 13575.660185 1\n", + "1 176549 16270 11946.872708 2\n", + "2 176549 11985 11355.693119 3\n", + "3 176549 13159 10375.500647 4\n", + "4 176549 15266 10269.019690 5\n", + "28862258 697262 6192 1294.342414 26\n", + "28862259 697262 11640 1277.332333 27\n", + "28862260 697262 7476 1262.919377 28\n", + "28862261 697262 14 1213.499281 29\n", + "28862262 697262 3784 1200.347785 30" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.concat([df_tfidf_recs.head(), df_tfidf_recs.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dc5bfdf3", + "metadata": {}, + "outputs": [], + "source": [ + "del df_tfidf_recs['rank'], df_bmp_recs['rank'], df_tfidf_recs['score'], df_bmp_recs['score']" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "edcc93bd", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_id
010975571132
110975575658
21097557142
310975573734
4109755716228
5109755712192
6109755713865
710975572657
810975579728
910975574880
10109755711778
1110975579996
1210975578636
1310975573935
1410975575803
1510975574457
1610975571844
1710975576382
1810975574716
1910975574495
4102337203734
4102337303935
4102337407417
4102337504495
4102337606382
4102337705803
4102337801844
41023379011778
4102338008636
4102338109996
4102338202657
41023383016228
4102338404880
41023385013865
410233860142
4102338706443
4102338804740
4102338906809
41023390010440
41023391014901
\n", + "
" + ], + "text/plain": [ + " user_id item_id\n", + "0 1097557 1132\n", + "1 1097557 5658\n", + "2 1097557 142\n", + "3 1097557 3734\n", + "4 1097557 16228\n", + "5 1097557 12192\n", + "6 1097557 13865\n", + "7 1097557 2657\n", + "8 1097557 9728\n", + "9 1097557 4880\n", + "10 1097557 11778\n", + "11 1097557 9996\n", + "12 1097557 8636\n", + "13 1097557 3935\n", + "14 1097557 5803\n", + "15 1097557 4457\n", + "16 1097557 1844\n", + "17 1097557 6382\n", + "18 1097557 4716\n", + "19 1097557 4495\n", + "41023372 0 3734\n", + "41023373 0 3935\n", + "41023374 0 7417\n", + "41023375 0 4495\n", + "41023376 0 6382\n", + "41023377 0 5803\n", + "41023378 0 1844\n", + "41023379 0 11778\n", + "41023380 0 8636\n", + "41023381 0 9996\n", + "41023382 0 2657\n", + "41023383 0 16228\n", + "41023384 0 4880\n", + "41023385 0 13865\n", + "41023386 0 142\n", + "41023387 0 6443\n", + "41023388 0 4740\n", + "41023389 0 6809\n", + "41023390 0 10440\n", + "41023391 0 14901" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_all_recs = pd.concat(\n", + " [\n", + " df_bmp_recs, df_tfidf_recs\n", + " ],\n", + " ignore_index=True\n", + ").sort_values(\n", + " [\"user_id\"], ascending=False\n", + ").drop_duplicates(\n", + " [\"user_id\", \"item_id\"]\n", + ").reset_index(drop=True)\n", + "\n", + "pd.concat([df_all_recs.head(20), df_all_recs.tail(20)])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1267df73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(15706, 2)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexidf
095067.150811
116598.524953
271075.821207
376388.407093
466867.778734
\n", + "
" + ], + "text/plain": [ + " index idf\n", + "0 9506 7.150811\n", + "1 1659 8.524953\n", + "2 7107 5.821207\n", + "3 7638 8.407093\n", + "4 6686 7.778734" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item_idf = pd.read_csv(\"../data/kion_train/items_idf.csv\")\n", + "print(item_idf.shape)\n", + "item_idf.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "68c2c0c0", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idindexidf
141097557580358036.840585
171097557638263826.806090
251097557747674766.545666
181097557471647166.480408
34109755714146.467549
24109755711640116406.318255
211097557543454346.226266
01097557113211326.183141
10109755711778117786.134312
131097557393539356.067242
231097557619261926.038563
11097557565856586.025091
261097557378437845.990008
4109755716228162285.756312
311097557741774175.715013
321097557782978295.615193
191097557449544955.563930
22109755714431144315.558556
151097557445744575.548639
28109755712995129955.495888
410233740741774175.715013
410233610782978295.615193
410233750449544955.563930
41023358014431144315.558556
410233640445744575.548639
41023363012995129955.495888
410233780184418445.419019
41023366011237112375.365593
410233600757175715.267906
410233880474047405.078522
410233800863686365.041418
410233810999699964.992277
410233890680968094.917360
4102338601421424.801620
410233840488048804.610045
410233820265726574.392592
410233720373437344.306872
410233650415141514.111983
41023385013865138653.825227
41023390010440104403.333947
\n", + "
" + ], + "text/plain": [ + " user_id item_id index idf\n", + "14 1097557 5803 5803 6.840585\n", + "17 1097557 6382 6382 6.806090\n", + "25 1097557 7476 7476 6.545666\n", + "18 1097557 4716 4716 6.480408\n", + "34 1097557 14 14 6.467549\n", + "24 1097557 11640 11640 6.318255\n", + "21 1097557 5434 5434 6.226266\n", + "0 1097557 1132 1132 6.183141\n", + "10 1097557 11778 11778 6.134312\n", + "13 1097557 3935 3935 6.067242\n", + "23 1097557 6192 6192 6.038563\n", + "1 1097557 5658 5658 6.025091\n", + "26 1097557 3784 3784 5.990008\n", + "4 1097557 16228 16228 5.756312\n", + "31 1097557 7417 7417 5.715013\n", + "32 1097557 7829 7829 5.615193\n", + "19 1097557 4495 4495 5.563930\n", + "22 1097557 14431 14431 5.558556\n", + "15 1097557 4457 4457 5.548639\n", + "28 1097557 12995 12995 5.495888\n", + "41023374 0 7417 7417 5.715013\n", + "41023361 0 7829 7829 5.615193\n", + "41023375 0 4495 4495 5.563930\n", + "41023358 0 14431 14431 5.558556\n", + "41023364 0 4457 4457 5.548639\n", + "41023363 0 12995 12995 5.495888\n", + "41023378 0 1844 1844 5.419019\n", + "41023366 0 11237 11237 5.365593\n", + "41023360 0 7571 7571 5.267906\n", + "41023388 0 4740 4740 5.078522\n", + "41023380 0 8636 8636 5.041418\n", + "41023381 0 9996 9996 4.992277\n", + "41023389 0 6809 6809 4.917360\n", + "41023386 0 142 142 4.801620\n", + "41023384 0 4880 4880 4.610045\n", + "41023382 0 2657 2657 4.392592\n", + "41023372 0 3734 3734 4.306872\n", + "41023365 0 4151 4151 4.111983\n", + "41023385 0 13865 13865 3.825227\n", + "41023390 0 10440 10440 3.333947" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_all_recs = df_all_recs.merge(\n", + " item_idf, left_on='item_id', right_on='index', how='left'\n", + ").sort_values(['user_id', 'idf'], ascending=False)\n", + "\n", + "pd.concat([df_all_recs.head(20), df_all_recs.tail(20)])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "36ecc1dd", + "metadata": {}, + "outputs": 
[], + "source": [ + "del df_all_recs['index'], df_all_recs['idf']" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "2c868313", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Количество пользователей, у которорых рекомендаций меньше 10: 21\n" + ] + } + ], + "source": [ + "count_recs_by_users = df_all_recs.user_id.value_counts()\n", + "print(f\"Количество пользователей, у которорых рекомендаций меньше 10: {len(count_recs_by_users[count_recs_by_users < 10])}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9820e5ab", + "metadata": {}, + "source": [ + "Для пользователей, у которых будет меньше рекомендаций, чем k_recs, рекомендации **будут пополняться популярным**" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "69d1fc9c", + "metadata": {}, + "outputs": [], + "source": [ + "df_popular = pd.read_csv('../data/hw_3/popular_item.csv')\n", + "users_need = count_recs_by_users[count_recs_by_users < 10].index" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7242e8e3", + "metadata": {}, + "outputs": [], + "source": [ + "k_recs = 10\n", + "users, recs = [], []\n", + "for user, count in dict(count_recs_by_users[count_recs_by_users < 10]).items():\n", + " need_recs = k_recs - count\n", + " users.extend([user for _ in range(need_recs)])\n", + " recs.extend(df_popular[\"item_id\"][:need_recs].to_list())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "29eaaadc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_id
010975575803
110975576382
210975577476
310975574716
4109755714
.........
4102338702657
4102338803734
4102338904151
41023390013865
41023391010440
\n", + "

41023392 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id\n", + "0 1097557 5803\n", + "1 1097557 6382\n", + "2 1097557 7476\n", + "3 1097557 4716\n", + "4 1097557 14\n", + "... ... ...\n", + "41023387 0 2657\n", + "41023388 0 3734\n", + "41023389 0 4151\n", + "41023390 0 13865\n", + "41023391 0 10440\n", + "\n", + "[41023392 rows x 2 columns]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_all_recs" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "30df7064", + "metadata": {}, + "outputs": [], + "source": [ + "df_need = pd.DataFrame({\"user_id\": users, \"item_id\": recs})\n", + "df_all_recs = pd.concat([df_all_recs, df_need], ignore_index=True).sort_values(\"user_id\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "34f2a303", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Количество пользователей, у которорых рекомендаций меньше 10: 0\n" + ] + } + ], + "source": [ + "count_recs_by_users = df_all_recs.user_id.value_counts()\n", + "print(f\"Количество пользователей, у которорых рекомендаций меньше 10: {len(count_recs_by_users[count_recs_by_users < 10])}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f9d7c2bc", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_recs.to_csv(\"../data/hw_3/blending_tfidf_bmp25_idf_rectools.csv\", index=False)" + ] + }, + { + "cell_type": "markdown", + "id": "b8e2c037", + "metadata": {}, + "source": [ + "Offline рекомендации не работали с блендингом, решил уменьшить количество рекомендаций для одного юзера до 10 и заработало" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6e76be8c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(9621050, 2)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_all_recs['rank'] = 
df_all_recs.groupby('user_id').cumcount() + 1\n", + "df_all_recs_top10 = df_all_recs[df_all_recs['rank'] <= 10]\n", + "del df_all_recs_top10['rank']\n", + "df_all_recs_top10.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "94d693eb", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_recs_top10.to_csv(\"../data/hw_3/blending_tfidf_bmp25_idf_rectools_10.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38f155ab", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/HW-3.4-model-for-online-recs.ipynb b/notebooks/HW-3.4-model-for-online-recs.ipynb new file mode 100644 index 00000000..fafafb1f --- /dev/null +++ b/notebooks/HW-3.4-model-for-online-recs.ipynb @@ -0,0 +1,1161 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "faa0d200", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "54683b63", + "metadata": {}, + "outputs": [], + "source": [ + "import typing as tp\n", + "\n", + "import dill\n", + "import pandas as pd\n", + "import numpy as np\n", + "from implicit.nearest_neighbours import BM25Recommender, TFIDFRecommender\n", + "from rectools import Columns\n", + "import scipy as sp" + ] + }, + { + "cell_type": "markdown", + "id": "e00f73f1", + "metadata": {}, + "source": [ + "В ноутбуку \"HW-3.3\" c помощью стратегии валидации по неделям были отобраны несколько моделей с наиболее 
высокими метриками:\n", + "\n", + "- BMP25Recommender с гиперпараметром k = 60\n", + "- TFIDFRecommender с гиперпараметром k = 60\n", + "\n", + "Для этих моделей сформированы оффлайн рекомендации, которые показали 0.10384918 и 0.09577425 соответственно.\n", + "\n", + "Для формирования онлайн рекомендаций следует обучить те же архитектуры моделей с такими же гиперпараметрами из библиотеки implicit" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4e7ecfc4", + "metadata": {}, + "outputs": [], + "source": [ + "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", + "\n", + "interactions.rename(columns={\n", + " 'last_watch_dt': Columns.Datetime,\n", + " 'total_dur': Columns.Weight\n", + " }, \n", + " inplace=True\n", + ") \n", + "\n", + "interactions['datetime'] = pd.to_datetime(interactions['datetime'])" + ] + }, + { + "cell_type": "markdown", + "id": "55fbfe8e", + "metadata": {}, + "source": [ + "# Create train data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "57f1394b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unique users: 962179\n", + "Unique items: 15706\n" + ] + } + ], + "source": [ + "# формирование id для user и item\n", + "users_inv_mapping = dict(enumerate(interactions['user_id'].unique()))\n", + "users_mapping = {v: k for k, v in users_inv_mapping.items()}\n", + "items_inv_mapping = dict(enumerate(interactions['item_id'].unique()))\n", + "items_mapping = {v: k for k, v in items_inv_mapping.items()}\n", + "print(f\"Unique users: {len(users_inv_mapping)}\")\n", + "print(f\"Unique items: {len(items_inv_mapping)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bc704533", + "metadata": {}, + "outputs": [], + "source": [ + "def get_matrix(\n", + " df: pd.DataFrame,\n", + " user_col: str = Columns.User,\n", + " item_col: str = Columns.Item,\n", + " weight_col: str = None,\n", + " users_mapping: tp.Dict[int, int] = 
None,\n", + " items_mapping: tp.Dict[int, int] = None\n", + "):\n", + "\n", + " if weight_col:\n", + " weights = df[weight_col].astype(np.float32)\n", + " else:\n", + " weights = np.ones(len(df), dtype=np.float32)\n", + "\n", + " interaction_matrix = sp.sparse.coo_matrix((\n", + " weights,\n", + " (\n", + " df[user_col].map(users_mapping.get),\n", + " df[item_col].map(items_mapping.get)\n", + " )\n", + " ))\n", + "\n", + " watched = df.groupby(user_col).agg({item_col: list})\n", + " return interaction_matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "00f46a78", + "metadata": {}, + "outputs": [], + "source": [ + "weight_matrix = get_matrix(\n", + " df=interactions,\n", + " users_mapping=users_mapping,\n", + " items_mapping=items_mapping\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1cc272f9", + "metadata": {}, + "source": [ + "# Models train" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2d8ec8a1", + "metadata": {}, + "outputs": [], + "source": [ + "model_implicit_tfidf = TFIDFRecommender(K=60)\n", + "model_implicit_bmp25 = BM25Recommender(K=60)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "17990623", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0ce1d4c0e2184dfa8d159906a145f011", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/962179 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_id
user_id
0[7102, 14359, 15297, 6006, 9728, 12192]
1[3669, 10440]
2[7571, 3541, 15266, 13867, 12841, 10770, 4475,...
3[12192, 9728, 16406, 15719, 10440, 3475, 2025,...
4[4700, 6317]
1097553[24, 13058, 12463, 12659]
1097554[16361, 496, 1053, 11275, 4580, 1151, 849, 350...
1097555[14703, 140, 9728, 496, 6916, 4662, 4880]
1097556[12812]
1097557[4151, 3182, 15297]
\n", + "" + ], + "text/plain": [ + " item_id\n", + "user_id \n", + "0 [7102, 14359, 15297, 6006, 9728, 12192]\n", + "1 [3669, 10440]\n", + "2 [7571, 3541, 15266, 13867, 12841, 10770, 4475,...\n", + "3 [12192, 9728, 16406, 15719, 10440, 3475, 2025,...\n", + "4 [4700, 6317]\n", + "1097553 [24, 13058, 12463, 12659]\n", + "1097554 [16361, 496, 1053, 11275, 4580, 1151, 849, 350...\n", + "1097555 [14703, 140, 9728, 496, 6916, 4662, 4880]\n", + "1097556 [12812]\n", + "1097557 [4151, 3182, 15297]" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watched = interactions.groupby('user_id').agg({'item_id': list})\n", + "pd.concat([watched.head(), watched.tail()])" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "1460393c", + "metadata": {}, + "outputs": [], + "source": [ + "def recs_mapper(user, model, user_mapping, user_inv_mapping, k_reco: int = 10, bmp: bool = False):\n", + " user_id = user_mapping[user]\n", + " recs = model.similar_items(user_id, N=k_reco)\n", + " result = pd.DataFrame(\n", + " {\n", + " \"sim_user_id\": [user_inv_mapping[user] for user, _ in recs], \n", + " \"sim\": [sim for _, sim in recs] def\n", + " }\n", + " )\n", + " \n", + " if bmp:\n", + " return result[result['sim_user_id'] != user]\n", + " else: \n", + " return result[~(result['sim'] >= 1)] " + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "011fe4fb", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "sample_users = interactions[Columns.User].sample(100).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "3ff09747", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12861\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sim_user_idsim
17372390.479295
210695580.427898
39333710.419511
44098500.391727
59892530.384045
68176360.380609
710784200.372851
81635950.370077
910037830.368852
\n", + "
" + ], + "text/plain": [ + " sim_user_id sim\n", + "1 737239 0.479295\n", + "2 1069558 0.427898\n", + "3 933371 0.419511\n", + "4 409850 0.391727\n", + "5 989253 0.384045\n", + "6 817636 0.380609\n", + "7 1078420 0.372851\n", + "8 163595 0.370077\n", + "9 1003783 0.368852" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(sample_users[0])\n", + "df_sim = recs_mapper(sample_users[0], model_implicit_tfidf, users_mapping, users_inv_mapping)\n", + "df_sim" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "757a24ec", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sim_user_idsimitem_id
07372390.47929510755
07372390.479295496
07372390.47929512324
07372390.47929510219
07372390.4792956898
07372390.47929514476
07372390.47929513411
07372390.4792959194
07372390.4792956404
07372390.47929514961
07372390.47929512995
110695580.4278985287
110695580.42789813973
110695580.42789813865
34098500.3917277793
49892530.3840456033
49892530.384045799
49892530.3840459617
49892530.3840455405
49892530.38404513849
49892530.38404512846
58176360.3806092981
610784200.3728513935
610784200.37285110283
71635950.3700779728
\n", + "
" + ], + "text/plain": [ + " sim_user_id sim item_id\n", + "0 737239 0.479295 10755\n", + "0 737239 0.479295 496\n", + "0 737239 0.479295 12324\n", + "0 737239 0.479295 10219\n", + "0 737239 0.479295 6898\n", + "0 737239 0.479295 14476\n", + "0 737239 0.479295 13411\n", + "0 737239 0.479295 9194\n", + "0 737239 0.479295 6404\n", + "0 737239 0.479295 14961\n", + "0 737239 0.479295 12995\n", + "1 1069558 0.427898 5287\n", + "1 1069558 0.427898 13973\n", + "1 1069558 0.427898 13865\n", + "3 409850 0.391727 7793\n", + "4 989253 0.384045 6033\n", + "4 989253 0.384045 799\n", + "4 989253 0.384045 9617\n", + "4 989253 0.384045 5405\n", + "4 989253 0.384045 13849\n", + "4 989253 0.384045 12846\n", + "5 817636 0.380609 2981\n", + "6 1078420 0.372851 3935\n", + "6 1078420 0.372851 10283\n", + "7 163595 0.370077 9728" + ] + }, + "execution_count": 100, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_sim = df_sim.merge(\n", + " watched, left_on=['sim_user_id'], right_on=['user_id'], how='left'\n", + ").explode('item_id').sort_values(\n", + " [ 'sim'], ascending=False\n", + ").drop_duplicates(\n", + " ['item_id'], keep='first'\n", + ")\n", + "df_sim" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "87a9e994", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12861\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sim_user_idsim
173723985.368217
2100621675.500227
393337171.779085
498925371.633147
5106955871.248461
612473569.342326
7107842067.839079
828985467.379224
940985066.214999
\n", + "
" + ], + "text/plain": [ + " sim_user_id sim\n", + "1 737239 85.368217\n", + "2 1006216 75.500227\n", + "3 933371 71.779085\n", + "4 989253 71.633147\n", + "5 1069558 71.248461\n", + "6 124735 69.342326\n", + "7 1078420 67.839079\n", + "8 289854 67.379224\n", + "9 409850 66.214999" + ] + }, + "execution_count": 101, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(sample_users[0])\n", + "df_sim = recs_mapper(sample_users[0], model_implicit_bmp25, users_mapping, users_inv_mapping, bmp=True)\n", + "df_sim" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "2557f73f", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sim_user_idsimitem_id
073723985.36821710755
073723985.36821714961
073723985.3682176404
073723985.3682179194
073723985.36821713411
073723985.368217496
073723985.3682176898
073723985.36821710219
073723985.36821714476
073723985.36821712324
073723985.36821712995
1100621675.50022713325
1100621675.5002278891
1100621675.5002275287
398925371.63314713865
398925371.633147799
398925371.6331476033
398925371.6331479617
398925371.63314713849
398925371.63314712846
398925371.6331475405
4106955871.24846113973
512473569.3423269288
512473569.3423262100
512473569.34232614242
512473569.3423264702
6107842067.8390793935
6107842067.83907910283
6107842067.8390792981
728985467.37922416021
728985467.3792244116
728985467.37922415464
840985066.2149997793
\n", + "
" + ], + "text/plain": [ + " sim_user_id sim item_id\n", + "0 737239 85.368217 10755\n", + "0 737239 85.368217 14961\n", + "0 737239 85.368217 6404\n", + "0 737239 85.368217 9194\n", + "0 737239 85.368217 13411\n", + "0 737239 85.368217 496\n", + "0 737239 85.368217 6898\n", + "0 737239 85.368217 10219\n", + "0 737239 85.368217 14476\n", + "0 737239 85.368217 12324\n", + "0 737239 85.368217 12995\n", + "1 1006216 75.500227 13325\n", + "1 1006216 75.500227 8891\n", + "1 1006216 75.500227 5287\n", + "3 989253 71.633147 13865\n", + "3 989253 71.633147 799\n", + "3 989253 71.633147 6033\n", + "3 989253 71.633147 9617\n", + "3 989253 71.633147 13849\n", + "3 989253 71.633147 12846\n", + "3 989253 71.633147 5405\n", + "4 1069558 71.248461 13973\n", + "5 124735 69.342326 9288\n", + "5 124735 69.342326 2100\n", + "5 124735 69.342326 14242\n", + "5 124735 69.342326 4702\n", + "6 1078420 67.839079 3935\n", + "6 1078420 67.839079 10283\n", + "6 1078420 67.839079 2981\n", + "7 289854 67.379224 16021\n", + "7 289854 67.379224 4116\n", + "7 289854 67.379224 15464\n", + "8 409850 66.214999 7793" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_sim = df_sim.merge(\n", + " watched, left_on=['sim_user_id'], right_on=['user_id'], how='left'\n", + ").explode('item_id').sort_values(\n", + " [ 'sim'], ascending=False\n", + ").drop_duplicates(\n", + " ['item_id'], keep='first'\n", + ")\n", + "df_sim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e60a9c9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + 
"nbformat": 4, + "nbformat_minor": 5 +} diff --git a/userknn.py b/userknn.py new file mode 100644 index 00000000..e7cf55ab --- /dev/null +++ b/userknn.py @@ -0,0 +1,112 @@ +from typing import Dict +from collections import Counter + +import pandas as pd +import numpy as np +import scipy as sp +from implicit.nearest_neighbours import ItemItemRecommender + + +class UserKnn(): + """Class for fit-perdict UserKNN model + based on ItemKNN model from implicit.nearest_neighbours + """ + + def __init__(self, model: ItemItemRecommender, N_users: int = 50): + self.N_users = N_users + self.model = model + self.is_fitted = False + + def get_mappings(self, train): + self.users_inv_mapping = dict(enumerate(train['user_id'].unique())) + self.users_mapping = {v: k for k, v in self.users_inv_mapping.items()} + + self.items_inv_mapping = dict(enumerate(train['item_id'].unique())) + self.items_mapping = {v: k for k, v in self.items_inv_mapping.items()} + + def get_matrix(self, df: pd.DataFrame, + user_col: str = 'user_id', + item_col: str = 'item_id', + weight_col: str = None, + users_mapping: Dict[int, int] = None, + items_mapping: Dict[int, int] = None): + + if weight_col: + weights = df[weight_col].astype(np.float32) + else: + weights = np.ones(len(df), dtype=np.float32) + + self.interaction_matrix = sp.sparse.coo_matrix(( + weights, + ( + df[item_col].map(self.items_mapping.get), + df[user_col].map(self.users_mapping.get) + ) + )) + + self.watched = df\ + .groupby(user_col, as_index=False)\ + .agg({item_col: list})\ + .rename(columns={user_col: 'sim_user_id'}) + + return self.interaction_matrix + + def idf(self, n: int, x: float): + return np.log((1 + n) / (1 + x) + 1) + + def _count_item_idf(self, df: pd.DataFrame): + item_cnt = Counter(df['item_id'].values) + item_idf = pd.DataFrame.from_dict(item_cnt, orient='index', + columns=['doc_freq']).reset_index() + item_idf['idf'] = item_idf['doc_freq'].apply(lambda x: self.idf(self.n, x)) + self.item_idf = item_idf + + def fit(self, 
train: pd.DataFrame): + self.user_knn = self.model + self.get_mappings(train) + self.weights_matrix = self.get_matrix(train, + users_mapping=self.users_mapping, + items_mapping=self.items_mapping) + + self.n = train.shape[0] + self._count_item_idf(train) + + self.user_knn.fit(self.weights_matrix) + self.is_fitted = True + + def _generate_recs_mapper(self, model: ItemItemRecommender, user_mapping: Dict[int, int], + user_inv_mapping: Dict[int, int], N: int): + def _recs_mapper(user): + user_id = self.users_mapping[user] + users, sim = model.similar_items(user_id, N=N) + return [self.users_inv_mapping[user] for user in users], sim + return _recs_mapper + + def predict(self, test: pd.DataFrame, N_recs: int = 10): + + if not self.is_fitted: + raise ValueError("Please call fit before predict") + + mapper = self._generate_recs_mapper( + model=self.user_knn, + user_mapping=self.users_mapping, + user_inv_mapping=self.users_inv_mapping, + N=self.N_users + ) + + recs = pd.DataFrame({'user_id': test['user_id'].unique()}) + recs['sim_user_id'], recs['sim'] = zip(*recs['user_id'].map(mapper)) + recs = recs.set_index('user_id').apply(pd.Series.explode).reset_index() + + recs = recs[~(recs['user_id'] == recs['sim_user_id'])]\ + .merge(self.watched, on=['sim_user_id'], how='left')\ + .explode('item_id')\ + .sort_values(['user_id', 'sim'], ascending=False)\ + .drop_duplicates(['user_id', 'item_id'], keep='first')\ + .merge(self.item_idf, left_on='item_id', right_on='index', how='left') + + recs['score'] = recs['sim'] * recs['idf'] + recs = recs.sort_values(['user_id', 'score'], ascending=False) + recs['rank'] = recs.groupby('user_id').cumcount() + 1 + return recs[recs['rank'] <= N_recs][['user_id', 'item_id', 'score', 'rank']] + \ No newline at end of file From 07f2cf022396bd77e5d62b43e3a700516775211a Mon Sep 17 00:00:00 2001 From: anettapik <120940816+anettapik@users.noreply.github.com> Date: Mon, 27 Nov 2023 17:00:17 +0300 Subject: [PATCH 5/7] Delete notebooks directory --- 
notebooks/HW-3.1.ipynb | 4027 ---------------- notebooks/HW-3.2-rectools-research.ipynb | 725 --- notebooks/HW-3.3-rectools-cv.ipynb | 4387 ------------------ notebooks/HW-3.4-model-for-online-recs.ipynb | 1161 ----- 4 files changed, 10300 deletions(-) delete mode 100644 notebooks/HW-3.1.ipynb delete mode 100644 notebooks/HW-3.2-rectools-research.ipynb delete mode 100644 notebooks/HW-3.3-rectools-cv.ipynb delete mode 100644 notebooks/HW-3.4-model-for-online-recs.ipynb diff --git a/notebooks/HW-3.1.ipynb b/notebooks/HW-3.1.ipynb deleted file mode 100644 index c05a3b71..00000000 --- a/notebooks/HW-3.1.ipynb +++ /dev/null @@ -1,4027 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "398a86d9", - "metadata": {}, - "outputs": [], - "source": [ - "from pprint import pprint\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "import sys\n", - "sys.path.append('../')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "8dbe6bf0", - "metadata": {}, - "outputs": [], - "source": [ - "import plotly.express as px\n", - "import numpy as np\n", - "import pandas as pd\n", - "import scipy as sp\n", - "import requests\n", - "from tqdm.auto import tqdm\n", - "from scipy.stats import mode\n", - "from implicit.nearest_neighbours import CosineRecommender, TFIDFRecommender, BM25Recommender\n", - "from rectools import Columns\n", - "from rectools.model_selection import TimeRangeSplitter\n", - "from rectools.metrics import Precision, Recall, MAP, MeanInvUserFreq, Serendipity, calc_metrics\n", - "from rectools.dataset.interactions import Interactions\n", - "\n", - "from service.utils.user_knn import UserKnn" - ] - }, - { - "cell_type": "markdown", - "id": "b1baa79f", - "metadata": {}, - "source": [ - "# Data" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "f2a9e540", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "((5476251, 5), (840197, 5), (15963, 14))" - ] - }, - 
"execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", - "users = pd.read_csv('../data/kion_train/users.csv')\n", - "items = pd.read_csv('../data/kion_train/items.csv')\n", - "\n", - "interactions.shape, users.shape, items.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "456d25f4", - "metadata": {}, - "outputs": [], - "source": [ - "interactions.rename(\n", - " columns={\n", - " 'last_watch_dt': Columns.Datetime,\n", - " 'total_dur': Columns.Weight\n", - " }, \n", - " inplace=True) \n", - "\n", - "interactions[Columns.Datetime] = pd.to_datetime(interactions[Columns.Datetime])" - ] - }, - { - "cell_type": "markdown", - "id": "6f7b9b0c", - "metadata": {}, - "source": [ - "## Intersection" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7c9c0c94", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_iddatetimeweightwatched_pct
017654995062021-05-11425072.0
169931716592021-05-298317100.0
265668371072021-05-09100.0
386461376382021-07-0514483100.0
496486895062021-04-306725100.0
5476246648596122252021-08-13760.0
547624754686296732021-04-13230849.0
5476248697262152972021-08-201830763.0
5476249384202161972021-04-196203100.0
547625031970944362021-08-15392145.0
\n", - "
" - ], - "text/plain": [ - " user_id item_id datetime weight watched_pct\n", - "0 176549 9506 2021-05-11 4250 72.0\n", - "1 699317 1659 2021-05-29 8317 100.0\n", - "2 656683 7107 2021-05-09 10 0.0\n", - "3 864613 7638 2021-07-05 14483 100.0\n", - "4 964868 9506 2021-04-30 6725 100.0\n", - "5476246 648596 12225 2021-08-13 76 0.0\n", - "5476247 546862 9673 2021-04-13 2308 49.0\n", - "5476248 697262 15297 2021-08-20 18307 63.0\n", - "5476249 384202 16197 2021-04-19 6203 100.0\n", - "5476250 319709 4436 2021-08-15 3921 45.0" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.concat([interactions.head(), interactions.tail()])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "c5c3ce6c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Interactions dataframe shape: (5476251, 5)\n", - "Unique users in interactions: 962179\n", - "Unique items in interactions: 15706\n" - ] - } - ], - "source": [ - "print(f\"Interactions dataframe shape: {interactions.shape}\")\n", - "print(f\"Unique users in interactions: {interactions[Columns.User].nunique()}\")\n", - "print(f\"Unique items in interactions: {interactions[Columns.Item].nunique()}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "0214a978", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "min date in interactions: 2021-03-13 00:00:00\n", - "max date in interactions: 2021-08-22 00:00:00\n" - ] - } - ], - "source": [ - "max_date = interactions[Columns.Datetime].max()\n", - "min_date = interactions[Columns.Datetime].min()\n", - "\n", - "print(f\"min date in interactions: {min_date}\")\n", - "print(f\"max date in interactions: {max_date}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "7829e796", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - 
"RangeIndex: 5476251 entries, 0 to 5476250\n", - "Data columns (total 5 columns):\n", - " # Column Dtype \n", - "--- ------ ----- \n", - " 0 user_id int64 \n", - " 1 item_id int64 \n", - " 2 datetime datetime64[ns]\n", - " 3 weight int64 \n", - " 4 watched_pct float64 \n", - "dtypes: datetime64[ns](1), float64(1), int64(3)\n", - "memory usage: 208.9 MB\n" - ] - } - ], - "source": [ - "interactions.info()" - ] - }, - { - "cell_type": "markdown", - "id": "57cddf34", - "metadata": {}, - "source": [ - "## Users" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "de5dea16", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idageincomesexkids_flg
0973171age_25_34income_60_90М1
1962099age_18_24income_20_40М0
21047345age_45_54income_40_60Ж0
3721985age_45_54income_20_40Ж0
4704055age_35_44income_60_90Ж0
840192339025age_65_infincome_0_20Ж0
840193983617age_18_24income_20_40Ж1
840194251008NaNNaNNaN0
840195590706NaNNaNЖ0
840196166555age_65_infincome_20_40Ж0
\n", - "
" - ], - "text/plain": [ - " user_id age income sex kids_flg\n", - "0 973171 age_25_34 income_60_90 М 1\n", - "1 962099 age_18_24 income_20_40 М 0\n", - "2 1047345 age_45_54 income_40_60 Ж 0\n", - "3 721985 age_45_54 income_20_40 Ж 0\n", - "4 704055 age_35_44 income_60_90 Ж 0\n", - "840192 339025 age_65_inf income_0_20 Ж 0\n", - "840193 983617 age_18_24 income_20_40 Ж 1\n", - "840194 251008 NaN NaN NaN 0\n", - "840195 590706 NaN NaN Ж 0\n", - "840196 166555 age_65_inf income_20_40 Ж 0" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.concat([users.head(), users.tail()])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "e4e6d2f5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Users dataframe shape (840197, 5)\n", - "Unique users: 840197\n" - ] - } - ], - "source": [ - "print(f\"Users dataframe shape {users.shape}\")\n", - "print(f\"Unique users: {users['user_id'].nunique()}\")" - ] - }, - { - "cell_type": "markdown", - "id": "98b4ff6c", - "metadata": {}, - "source": [ - "## Items" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "19b43ff0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
item_idcontent_typetitletitle_origrelease_yeargenrescountriesfor_kidsage_ratingstudiosdirectorsactorsdescriptionkeywords
010711filmПоговори с нейHable con ella2002.0драмы, зарубежные, детективы, мелодрамыИспанияNaN16.0NaNПедро АльмодоварАдольфо Фернандес, Ана Фернандес, Дарио Гранди...Мелодрама легендарного Педро Альмодовара «Пого...Поговори, ней, 2002, Испания, друзья, любовь, ...
12508filmГолые перцыSearch Party2014.0зарубежные, приключения, комедииСШАNaN16.0NaNСкот АрмстронгАдам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...Уморительная современная комедия на популярную...Голые, перцы, 2014, США, друзья, свадьбы, прео...
159614538seriesСреди камнейDarklands2019.0драмы, спорт, криминалРоссия0.018.0NaNМарк О’Коннор, Конор МакМахонДэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд...Семнадцатилетний Дэмиен мечтает вырваться за п...Среди, камней, 2019, Россия
159623206seriesГошаNaN2019.0комедииРоссия0.016.0NaNМихаил МироновМкртыч Арзуманян, Виктория РунцоваДобродушный Гоша не может выйти из дома, чтобы...Гоша, 2019, Россия
\n", - "
" - ], - "text/plain": [ - " item_id content_type title title_orig release_year \\\n", - "0 10711 film Поговори с ней Hable con ella 2002.0 \n", - "1 2508 film Голые перцы Search Party 2014.0 \n", - "15961 4538 series Среди камней Darklands 2019.0 \n", - "15962 3206 series Гоша NaN 2019.0 \n", - "\n", - " genres countries for_kids \\\n", - "0 драмы, зарубежные, детективы, мелодрамы Испания NaN \n", - "1 зарубежные, приключения, комедии США NaN \n", - "15961 драмы, спорт, криминал Россия 0.0 \n", - "15962 комедии Россия 0.0 \n", - "\n", - " age_rating studios directors \\\n", - "0 16.0 NaN Педро Альмодовар \n", - "1 16.0 NaN Скот Армстронг \n", - "15961 18.0 NaN Марк О’Коннор, Конор МакМахон \n", - "15962 16.0 NaN Михаил Миронов \n", - "\n", - " actors \\\n", - "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n", - "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n", - "15961 Дэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд... \n", - "15962 Мкртыч Арзуманян, Виктория Рунцова \n", - "\n", - " description \\\n", - "0 Мелодрама легендарного Педро Альмодовара «Пого... \n", - "1 Уморительная современная комедия на популярную... \n", - "15961 Семнадцатилетний Дэмиен мечтает вырваться за п... \n", - "15962 Добродушный Гоша не может выйти из дома, чтобы... \n", - "\n", - " keywords \n", - "0 Поговори, ней, 2002, Испания, друзья, любовь, ... \n", - "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 
\n", - "15961 Среди, камней, 2019, Россия \n", - "15962 Гоша, 2019, Россия " - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.concat([items.head(2), items.tail(2)])" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "8c8fb319", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Items dataframe shape (15963, 14)\n", - "Unique item_id: 15963\n" - ] - } - ], - "source": [ - "print(f\"Items dataframe shape {items.shape}\")\n", - "print(f\"Unique item_id: {items['item_id'].nunique()}\")" - ] - }, - { - "cell_type": "markdown", - "id": "2b35b460", - "metadata": {}, - "source": [ - "# userkNN model CV" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f60e6ecb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - " \n", - " " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "alignmentgroup": "True", - "hovertemplate": "variable=user_id
datetime=%{x}
value=%{y}", - "legendgroup": "user_id", - "marker": { - "color": "#636efa", - "pattern": { - "shape": "" - } - }, - "name": "user_id", - "offsetgroup": "user_id", - "orientation": "v", - "showlegend": true, - "textposition": "auto", - "type": "bar", - "x": [ - "2021-03-13T00:00:00", - "2021-03-14T00:00:00", - "2021-03-15T00:00:00", - "2021-03-16T00:00:00", - "2021-03-17T00:00:00", - "2021-03-18T00:00:00", - "2021-03-19T00:00:00", - "2021-03-20T00:00:00", - "2021-03-21T00:00:00", - "2021-03-22T00:00:00", - "2021-03-23T00:00:00", - "2021-03-24T00:00:00", - "2021-03-25T00:00:00", - "2021-03-26T00:00:00", - "2021-03-27T00:00:00", - "2021-03-28T00:00:00", - "2021-03-29T00:00:00", - "2021-03-30T00:00:00", - "2021-03-31T00:00:00", - "2021-04-01T00:00:00", - "2021-04-02T00:00:00", - "2021-04-03T00:00:00", - "2021-04-04T00:00:00", - "2021-04-05T00:00:00", - "2021-04-06T00:00:00", - "2021-04-07T00:00:00", - "2021-04-08T00:00:00", - "2021-04-09T00:00:00", - "2021-04-10T00:00:00", - "2021-04-11T00:00:00", - "2021-04-12T00:00:00", - "2021-04-13T00:00:00", - "2021-04-14T00:00:00", - "2021-04-15T00:00:00", - "2021-04-16T00:00:00", - "2021-04-17T00:00:00", - "2021-04-18T00:00:00", - "2021-04-19T00:00:00", - "2021-04-20T00:00:00", - "2021-04-21T00:00:00", - "2021-04-22T00:00:00", - "2021-04-23T00:00:00", - "2021-04-24T00:00:00", - "2021-04-25T00:00:00", - "2021-04-26T00:00:00", - "2021-04-27T00:00:00", - "2021-04-28T00:00:00", - "2021-04-29T00:00:00", - "2021-04-30T00:00:00", - "2021-05-01T00:00:00", - "2021-05-02T00:00:00", - "2021-05-03T00:00:00", - "2021-05-04T00:00:00", - "2021-05-05T00:00:00", - "2021-05-06T00:00:00", - "2021-05-07T00:00:00", - "2021-05-08T00:00:00", - "2021-05-09T00:00:00", - "2021-05-10T00:00:00", - "2021-05-11T00:00:00", - "2021-05-12T00:00:00", - "2021-05-13T00:00:00", - "2021-05-14T00:00:00", - "2021-05-15T00:00:00", - "2021-05-16T00:00:00", - "2021-05-17T00:00:00", - "2021-05-18T00:00:00", - "2021-05-19T00:00:00", - "2021-05-20T00:00:00", - 
"2021-05-21T00:00:00", - "2021-05-22T00:00:00", - "2021-05-23T00:00:00", - "2021-05-24T00:00:00", - "2021-05-25T00:00:00", - "2021-05-26T00:00:00", - "2021-05-27T00:00:00", - "2021-05-28T00:00:00", - "2021-05-29T00:00:00", - "2021-05-30T00:00:00", - "2021-05-31T00:00:00", - "2021-06-01T00:00:00", - "2021-06-02T00:00:00", - "2021-06-03T00:00:00", - "2021-06-04T00:00:00", - "2021-06-05T00:00:00", - "2021-06-06T00:00:00", - "2021-06-07T00:00:00", - "2021-06-08T00:00:00", - "2021-06-09T00:00:00", - "2021-06-10T00:00:00", - "2021-06-11T00:00:00", - "2021-06-12T00:00:00", - "2021-06-13T00:00:00", - "2021-06-14T00:00:00", - "2021-06-15T00:00:00", - "2021-06-16T00:00:00", - "2021-06-17T00:00:00", - "2021-06-18T00:00:00", - "2021-06-19T00:00:00", - "2021-06-20T00:00:00", - "2021-06-21T00:00:00", - "2021-06-22T00:00:00", - "2021-06-23T00:00:00", - "2021-06-24T00:00:00", - "2021-06-25T00:00:00", - "2021-06-26T00:00:00", - "2021-06-27T00:00:00", - "2021-06-28T00:00:00", - "2021-06-29T00:00:00", - "2021-06-30T00:00:00", - "2021-07-01T00:00:00", - "2021-07-02T00:00:00", - "2021-07-03T00:00:00", - "2021-07-04T00:00:00", - "2021-07-05T00:00:00", - "2021-07-06T00:00:00", - "2021-07-07T00:00:00", - "2021-07-08T00:00:00", - "2021-07-09T00:00:00", - "2021-07-10T00:00:00", - "2021-07-11T00:00:00", - "2021-07-12T00:00:00", - "2021-07-13T00:00:00", - "2021-07-14T00:00:00", - "2021-07-15T00:00:00", - "2021-07-16T00:00:00", - "2021-07-17T00:00:00", - "2021-07-18T00:00:00", - "2021-07-19T00:00:00", - "2021-07-20T00:00:00", - "2021-07-21T00:00:00", - "2021-07-22T00:00:00", - "2021-07-23T00:00:00", - "2021-07-24T00:00:00", - "2021-07-25T00:00:00", - "2021-07-26T00:00:00", - "2021-07-27T00:00:00", - "2021-07-28T00:00:00", - "2021-07-29T00:00:00", - "2021-07-30T00:00:00", - "2021-07-31T00:00:00", - "2021-08-01T00:00:00", - "2021-08-02T00:00:00", - "2021-08-03T00:00:00", - "2021-08-04T00:00:00", - "2021-08-05T00:00:00", - "2021-08-06T00:00:00", - "2021-08-07T00:00:00", - "2021-08-08T00:00:00", - 
"2021-08-09T00:00:00", - "2021-08-10T00:00:00", - "2021-08-11T00:00:00", - "2021-08-12T00:00:00", - "2021-08-13T00:00:00", - "2021-08-14T00:00:00", - "2021-08-15T00:00:00", - "2021-08-16T00:00:00", - "2021-08-17T00:00:00", - "2021-08-18T00:00:00", - "2021-08-19T00:00:00", - "2021-08-20T00:00:00", - "2021-08-21T00:00:00", - "2021-08-22T00:00:00" - ], - "xaxis": "x", - "y": [ - 16104, - 15606, - 12363, - 12643, - 12753, - 12788, - 13657, - 15346, - 15560, - 12752, - 13147, - 13435, - 12698, - 13909, - 15657, - 16112, - 12783, - 13101, - 13460, - 12966, - 14084, - 15431, - 15346, - 12642, - 12528, - 13129, - 13827, - 14416, - 15937, - 16046, - 12835, - 12322, - 12451, - 12275, - 13342, - 15464, - 16275, - 14286, - 20420, - 23200, - 21274, - 22127, - 26161, - 28964, - 21625, - 22590, - 21406, - 19987, - 21406, - 23479, - 24767, - 26267, - 25983, - 23941, - 23510, - 23201, - 27550, - 25986, - 27242, - 20957, - 20578, - 20729, - 21152, - 24530, - 24914, - 20960, - 20574, - 21561, - 22712, - 25697, - 27895, - 29978, - 24317, - 23667, - 22529, - 23881, - 24131, - 29035, - 31308, - 26821, - 26587, - 27577, - 28683, - 33150, - 34795, - 37096, - 31402, - 31107, - 32896, - 38964, - 37935, - 38619, - 42125, - 38973, - 35993, - 57686, - 41440, - 42174, - 43679, - 47989, - 39127, - 39693, - 41688, - 38394, - 41428, - 45898, - 48903, - 43301, - 43887, - 67749, - 53900, - 46642, - 48832, - 52812, - 43375, - 41380, - 41163, - 41592, - 40955, - 44798, - 46250, - 42487, - 43764, - 43128, - 43010, - 44878, - 49714, - 54139, - 45541, - 44431, - 44422, - 46313, - 46911, - 50317, - 54378, - 48531, - 49324, - 50267, - 50585, - 53121, - 59499, - 62128, - 53495, - 52181, - 51911, - 51047, - 53745, - 59316, - 61454, - 52794, - 53712, - 55617, - 56497, - 55843, - 61644, - 66546, - 54546, - 54311, - 56789, - 58640, - 60145, - 68834, - 71171 - ], - "yaxis": "y" - } - ], - "layout": { - "barmode": "relative", - "legend": { - "title": { - "text": "variable" - }, - "tracegroupgap": 0 - }, - 
"margin": { - "t": 60 - }, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" 
- ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], 
- "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - 
"fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - 
"hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "xaxis": { - "anchor": "y", - "domain": [ - 0, - 1 - ], - "title": { - "text": "datetime" - } - }, - "yaxis": { - "anchor": "x", - "domain": [ - 0, - 1 - ], - "title": { - "text": "value" - } - } - } - }, - "text/html": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig = px.bar(interactions.groupby(Columns.Datetime)[Columns.User].agg('count'))\n", - "fig.show()" - ] - }, - { - "cell_type": "markdown", - "id": "43f216d0", - "metadata": {}, - "source": [ - "Из графика видны **недельные тенденции** просмотров, поэтому следует fold-ы разделять по 7 дней, но т.к. на семинаре дали \"намек\", что private dataset имеет количество дней, меньшее чем 7. Поэтому фолды будут разбиваться на **5 и 7 дней**" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "07fbdb30", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "6" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.to_datetime('23-05-2021', format='%d-%m-%Y').weekday()" - ] - }, - { - "cell_type": "markdown", - "id": "2ff625b2", - "metadata": {}, - "source": [ - "### train test split" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "759ba346", - "metadata": {}, - "outputs": [], - "source": [ - "def create_data_range(\n", - " last_date: pd.Timestamp, \n", - " n_folds: int = 7, \n", - " unit: str = \"W\", \n", - " n_units: int = 1, \n", - " show: bool = True,\n", - "):\n", - " periods = n_folds + 1\n", - " freq = f\"{n_units}{unit}\"\n", - " \n", - " start_date = last_date - pd.Timedelta(n_folds * n_units + n_units, unit=unit) \n", - " \n", - " date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n", - " \n", - " if show:\n", - " print(\n", - " f\"start_date: {start_date}\\n\"\n", - " f\"last_date: {last_date}\\n\"\n", - " f\"periods: {periods}\\n\"\n", - " f\"freq: {freq}\\n\"\n", - " f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\\n\"\n", - " )\n", - " \n", - " return date_range" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "38bfd397", - "metadata": {}, - "outputs": [], - "source": [ - "CONFIG_CV = 
{\n", - " \"cv_v1\": {\n", - " \"n_folds\": 7,\n", - " \"unit\": \"W\",\n", - " \"n_units\": 1,\n", - " },\n", - " \"cv_v2\": {\n", - " \"n_folds\": 7,\n", - " \"unit\": \"D\",\n", - " \"n_units\": 5,\n", - " }, \n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "f518e089", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Timestamp('2021-08-22 00:00:00')" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "last_date = interactions[Columns.Datetime].max().normalize()\n", - "last_date" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "1fd68b9b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "***Folds v1***\n", - "start_date: 2021-07-13 00:00:00\n", - "last_date: 2021-08-22 00:00:00\n", - "periods: 8\n", - "freq: 5D\n", - "Test fold borders: ['2021-07-13' '2021-07-18' '2021-07-23' '2021-07-28' '2021-08-02'\n", - " '2021-08-07' '2021-08-12' '2021-08-17']\n", - "\n" - ] - } - ], - "source": [ - "print(\"***Folds v1***\")\n", - "date_range_v1 = create_data_range(\n", - " last_date, \n", - " n_folds=CONFIG_CV[\"cv_v2\"][\"n_folds\"], \n", - " unit=CONFIG_CV[\"cv_v2\"][\"unit\"], \n", - " n_units=CONFIG_CV[\"cv_v2\"][\"n_units\"]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "efc59555", - "metadata": {}, - "source": [ - "**генерируем фолды** " - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "9fae43f6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Real number of folds: 7\n" - ] - } - ], - "source": [ - "cv_v1 = TimeRangeSplitter(\n", - " date_range=date_range_v1,\n", - " filter_already_seen=True,\n", - " filter_cold_items=True,\n", - " filter_cold_users=True,\n", - ")\n", - "print(f\"Real number of folds: {cv_v1.get_n_splits(Interactions(interactions))}\")\n", - "\n", - "CV = [cv_v1]" - ] - }, - { - "cell_type": 
"markdown", - "id": "e15a83a7", - "metadata": {}, - "source": [ - "**Формируем метрики**" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "8f7742c6", - "metadata": {}, - "outputs": [], - "source": [ - "metrics = {\n", - " \"prec@10\": Precision(k=10),\n", - " \"recall@10\": Recall(k=10),\n", - " \"MAP@10\": MAP(k=10),\n", - " \"novelty\": MeanInvUserFreq(k=10),\n", - " \"serendipity\": Serendipity(k=10),\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "b21a1ecf", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'cosine_userknn_K30': ,\n", - " 'tfidf_userknn_K30': ,\n", - " 'bm25_userknn_K30': ,\n", - " 'cosine_userknn_K40': ,\n", - " 'tfidf_userknn_K40': ,\n", - " 'bm25_userknn_K40': }" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "K = [30, 40]\n", - "models = dict()\n", - "\n", - "for k in K:\n", - " models[f\"cosine_userknn_K{k}\"] = CosineRecommender(K=k)\n", - " models[f\"tfidf_userknn_K{k}\"] = TFIDFRecommender(K=k)\n", - " models[f\"bm25_userknn_K{k}\"] = BM25Recommender(K=k)\n", - "\n", - "models" - ] - }, - { - "cell_type": "markdown", - "id": "0103149a", - "metadata": {}, - "source": [ - "## Training" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "e78b8221", - "metadata": {}, - "outputs": [], - "source": [ - "N_USERS = 50" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "50dcff0b", - "metadata": {}, - "outputs": [], - "source": [ - "%%time\n", - "\n", - "results = []\n", - "\n", - "for idx, cv in enumerate(CV):\n", - " print(f\"\\n CV version {idx}\")\n", - " fold_iterator = cv.split(Interactions(interactions), collect_fold_stats=True)\n", - "\n", - " for i_fold, (train_ids, test_ids, fold_info) in enumerate(fold_iterator):\n", - " print(f\"\\n==================== Fold {i_fold}\")\n", - " pprint(fold_info)\n", - "\n", - " df_train = interactions.iloc[train_ids].copy()\n", - " 
df_test = interactions.iloc[test_ids][Columns.UserItem].copy()\n", - "\n", - " catalog = df_train[Columns.Item].unique()\n", - "\n", - " for model_name, model in models.items():\n", - " userknn_model = UserKnn(model=model, N_users=N_USERS, use_weight_idf=True)\n", - " userknn_model.fit(df_train)\n", - "\n", - " if 'bm25' in model_name:\n", - " recos = userknn_model.predict(df_test, bmp25=True)\n", - " else:\n", - " recos = userknn_model.predict(df_test)\n", - "\n", - " metric_values = calc_metrics(\n", - " metrics,\n", - " reco=recos,\n", - " interactions=df_test,\n", - " prev_interactions=df_train,\n", - " catalog=catalog,\n", - " )\n", - "\n", - " full_model_name = f\"{model_name}_cv-{idx}\"\n", - " fold = {\"fold\": i_fold, \"model\": full_model_name}\n", - " fold.update(metric_values)\n", - " results.append(fold)" - ] - }, - { - "cell_type": "markdown", - "id": "708ec5c2", - "metadata": {}, - "source": [ - "Работало больше 10 часов, случайно при перезапуске ноутбука была вызвана ячейка и остановлена, поэтому завершилась с ошибкой, поэтому ошибку убрали для лучшего вида" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "d7e2ffa7", - "metadata": { - "collapsed": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
foldmodelprec@10recall@10MAP@10noveltyserendipity
00cosine_userknn_K30_cv-00.0035570.0211280.0036958.3314910.000040
10tfidf_userknn_K30_cv-00.0064390.0391020.0073358.1550510.000048
20bm25_userknn_K30_cv-00.0025930.0134940.0025319.3984670.000081
30cosine_userknn_K40_cv-00.0032820.0193230.0034018.5615230.000043
40tfidf_userknn_K40_cv-00.0061780.0374580.0069578.3004040.000052
50bm25_userknn_K40_cv-00.0022410.0112550.0022109.6755330.000081
61cosine_userknn_K30_cv-00.0035050.0200020.0035808.3982480.000046
71tfidf_userknn_K30_cv-00.0063280.0368440.0070228.2401330.000058
81bm25_userknn_K30_cv-00.0027220.0138560.0026589.4846920.000088
91cosine_userknn_K40_cv-00.0032450.0183680.0033058.6269060.000047
101tfidf_userknn_K40_cv-00.0061500.0359640.0069168.3779880.000061
111bm25_userknn_K40_cv-00.0024060.0120670.0023939.7564580.000086
122cosine_userknn_K30_cv-00.0032610.0184980.0032958.4392630.000047
132tfidf_userknn_K30_cv-00.0059400.0342330.0064798.2623670.000059
142bm25_userknn_K30_cv-00.0027200.0134220.0025309.5356310.000091
152cosine_userknn_K40_cv-00.0030450.0170860.0031008.6615850.000050
162tfidf_userknn_K40_cv-00.0059140.0340710.0064398.3966180.000063
172bm25_userknn_K40_cv-00.0024040.0116380.0022319.7991190.000090
183cosine_userknn_K30_cv-00.0032770.0187860.0033958.4449860.000045
193tfidf_userknn_K30_cv-00.0060230.0341710.0063288.2765030.000059
203bm25_userknn_K30_cv-00.0026200.0127620.0024979.5609840.000091
213cosine_userknn_K40_cv-00.0030760.0175120.0031738.6581500.000045
223tfidf_userknn_K40_cv-00.0059190.0333680.0062538.3991690.000062
233bm25_userknn_K40_cv-00.0023370.0112730.0022539.8163250.000089
244cosine_userknn_K30_cv-00.0031180.0180640.0031578.4858990.000042
254tfidf_userknn_K30_cv-00.0059110.0336260.0063968.2824280.000059
264bm25_userknn_K30_cv-00.0025370.0123680.0024709.5996450.000086
274cosine_userknn_K40_cv-00.0028720.0165090.0028838.7119840.000043
284tfidf_userknn_K40_cv-00.0057930.0330280.0062618.4166800.000062
294bm25_userknn_K40_cv-00.0022130.0108600.0021799.8662010.000085
305cosine_userknn_K30_cv-00.0030030.0162520.0028998.4989680.000043
315tfidf_userknn_K30_cv-00.0055270.0309420.0058238.3252730.000057
325bm25_userknn_K30_cv-00.0025970.0122630.0023869.6469570.000100
335cosine_userknn_K40_cv-00.0027650.0147130.0026618.7175590.000047
345tfidf_userknn_K40_cv-00.0055450.0308920.0058178.4540910.000059
355bm25_userknn_K40_cv-00.0023020.0107770.0021359.9140420.000100
366cosine_userknn_K30_cv-00.0029630.0165320.0028878.5638090.000050
376tfidf_userknn_K30_cv-00.0053300.0307170.0057638.3662590.000064
386bm25_userknn_K30_cv-00.0025710.0126910.0024789.7150970.000100
396cosine_userknn_K40_cv-00.0027690.0154480.0026758.7750580.000051
406tfidf_userknn_K40_cv-00.0052840.0304180.0056978.4884730.000066
416bm25_userknn_K40_cv-00.0023400.0112780.0022089.9646640.000099
\n", - "
" - ], - "text/plain": [ - " fold model prec@10 recall@10 MAP@10 novelty \\\n", - "0 0 cosine_userknn_K30_cv-0 0.003557 0.021128 0.003695 8.331491 \n", - "1 0 tfidf_userknn_K30_cv-0 0.006439 0.039102 0.007335 8.155051 \n", - "2 0 bm25_userknn_K30_cv-0 0.002593 0.013494 0.002531 9.398467 \n", - "3 0 cosine_userknn_K40_cv-0 0.003282 0.019323 0.003401 8.561523 \n", - "4 0 tfidf_userknn_K40_cv-0 0.006178 0.037458 0.006957 8.300404 \n", - "5 0 bm25_userknn_K40_cv-0 0.002241 0.011255 0.002210 9.675533 \n", - "6 1 cosine_userknn_K30_cv-0 0.003505 0.020002 0.003580 8.398248 \n", - "7 1 tfidf_userknn_K30_cv-0 0.006328 0.036844 0.007022 8.240133 \n", - "8 1 bm25_userknn_K30_cv-0 0.002722 0.013856 0.002658 9.484692 \n", - "9 1 cosine_userknn_K40_cv-0 0.003245 0.018368 0.003305 8.626906 \n", - "10 1 tfidf_userknn_K40_cv-0 0.006150 0.035964 0.006916 8.377988 \n", - "11 1 bm25_userknn_K40_cv-0 0.002406 0.012067 0.002393 9.756458 \n", - "12 2 cosine_userknn_K30_cv-0 0.003261 0.018498 0.003295 8.439263 \n", - "13 2 tfidf_userknn_K30_cv-0 0.005940 0.034233 0.006479 8.262367 \n", - "14 2 bm25_userknn_K30_cv-0 0.002720 0.013422 0.002530 9.535631 \n", - "15 2 cosine_userknn_K40_cv-0 0.003045 0.017086 0.003100 8.661585 \n", - "16 2 tfidf_userknn_K40_cv-0 0.005914 0.034071 0.006439 8.396618 \n", - "17 2 bm25_userknn_K40_cv-0 0.002404 0.011638 0.002231 9.799119 \n", - "18 3 cosine_userknn_K30_cv-0 0.003277 0.018786 0.003395 8.444986 \n", - "19 3 tfidf_userknn_K30_cv-0 0.006023 0.034171 0.006328 8.276503 \n", - "20 3 bm25_userknn_K30_cv-0 0.002620 0.012762 0.002497 9.560984 \n", - "21 3 cosine_userknn_K40_cv-0 0.003076 0.017512 0.003173 8.658150 \n", - "22 3 tfidf_userknn_K40_cv-0 0.005919 0.033368 0.006253 8.399169 \n", - "23 3 bm25_userknn_K40_cv-0 0.002337 0.011273 0.002253 9.816325 \n", - "24 4 cosine_userknn_K30_cv-0 0.003118 0.018064 0.003157 8.485899 \n", - "25 4 tfidf_userknn_K30_cv-0 0.005911 0.033626 0.006396 8.282428 \n", - "26 4 bm25_userknn_K30_cv-0 0.002537 0.012368 0.002470 
9.599645 \n", - "27 4 cosine_userknn_K40_cv-0 0.002872 0.016509 0.002883 8.711984 \n", - "28 4 tfidf_userknn_K40_cv-0 0.005793 0.033028 0.006261 8.416680 \n", - "29 4 bm25_userknn_K40_cv-0 0.002213 0.010860 0.002179 9.866201 \n", - "30 5 cosine_userknn_K30_cv-0 0.003003 0.016252 0.002899 8.498968 \n", - "31 5 tfidf_userknn_K30_cv-0 0.005527 0.030942 0.005823 8.325273 \n", - "32 5 bm25_userknn_K30_cv-0 0.002597 0.012263 0.002386 9.646957 \n", - "33 5 cosine_userknn_K40_cv-0 0.002765 0.014713 0.002661 8.717559 \n", - "34 5 tfidf_userknn_K40_cv-0 0.005545 0.030892 0.005817 8.454091 \n", - "35 5 bm25_userknn_K40_cv-0 0.002302 0.010777 0.002135 9.914042 \n", - "36 6 cosine_userknn_K30_cv-0 0.002963 0.016532 0.002887 8.563809 \n", - "37 6 tfidf_userknn_K30_cv-0 0.005330 0.030717 0.005763 8.366259 \n", - "38 6 bm25_userknn_K30_cv-0 0.002571 0.012691 0.002478 9.715097 \n", - "39 6 cosine_userknn_K40_cv-0 0.002769 0.015448 0.002675 8.775058 \n", - "40 6 tfidf_userknn_K40_cv-0 0.005284 0.030418 0.005697 8.488473 \n", - "41 6 bm25_userknn_K40_cv-0 0.002340 0.011278 0.002208 9.964664 \n", - "\n", - " serendipity \n", - "0 0.000040 \n", - "1 0.000048 \n", - "2 0.000081 \n", - "3 0.000043 \n", - "4 0.000052 \n", - "5 0.000081 \n", - "6 0.000046 \n", - "7 0.000058 \n", - "8 0.000088 \n", - "9 0.000047 \n", - "10 0.000061 \n", - "11 0.000086 \n", - "12 0.000047 \n", - "13 0.000059 \n", - "14 0.000091 \n", - "15 0.000050 \n", - "16 0.000063 \n", - "17 0.000090 \n", - "18 0.000045 \n", - "19 0.000059 \n", - "20 0.000091 \n", - "21 0.000045 \n", - "22 0.000062 \n", - "23 0.000089 \n", - "24 0.000042 \n", - "25 0.000059 \n", - "26 0.000086 \n", - "27 0.000043 \n", - "28 0.000062 \n", - "29 0.000085 \n", - "30 0.000043 \n", - "31 0.000057 \n", - "32 0.000100 \n", - "33 0.000047 \n", - "34 0.000059 \n", - "35 0.000100 \n", - "36 0.000050 \n", - "37 0.000064 \n", - "38 0.000100 \n", - "39 0.000051 \n", - "40 0.000066 \n", - "41 0.000099 " - ] - }, - "execution_count": 46, - "metadata": 
{}, - "output_type": "execute_result" - } - ], - "source": [ - "df_metrics = pd.DataFrame(results)\n", - "df_metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "a0334b9a", - "metadata": {}, - "outputs": [], - "source": [ - "df_metrics.to_pickle(\"../data/hw_3/df_metrics.pickle\")" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "446530ce", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
foldprec@10recall@10MAP@10noveltyserendipity
model
bm25_userknn_K30_cv-03.00.0026230.0129800.0025079.5630680.000091
bm25_userknn_K40_cv-03.00.0023200.0113070.0022309.8274770.000090
cosine_userknn_K30_cv-03.00.0032410.0184660.0032728.4518090.000045
cosine_userknn_K40_cv-03.00.0030080.0169940.0030288.6732520.000047
tfidf_userknn_K30_cv-03.00.0059280.0342340.0064498.2725730.000058
tfidf_userknn_K40_cv-03.00.0058260.0336000.0063348.4047750.000061
\n", - "
" - ], - "text/plain": [ - " fold prec@10 recall@10 MAP@10 novelty \\\n", - "model \n", - "bm25_userknn_K30_cv-0 3.0 0.002623 0.012980 0.002507 9.563068 \n", - "bm25_userknn_K40_cv-0 3.0 0.002320 0.011307 0.002230 9.827477 \n", - "cosine_userknn_K30_cv-0 3.0 0.003241 0.018466 0.003272 8.451809 \n", - "cosine_userknn_K40_cv-0 3.0 0.003008 0.016994 0.003028 8.673252 \n", - "tfidf_userknn_K30_cv-0 3.0 0.005928 0.034234 0.006449 8.272573 \n", - "tfidf_userknn_K40_cv-0 3.0 0.005826 0.033600 0.006334 8.404775 \n", - "\n", - " serendipity \n", - "model \n", - "bm25_userknn_K30_cv-0 0.000091 \n", - "bm25_userknn_K40_cv-0 0.000090 \n", - "cosine_userknn_K30_cv-0 0.000045 \n", - "cosine_userknn_K40_cv-0 0.000047 \n", - "tfidf_userknn_K30_cv-0 0.000058 \n", - "tfidf_userknn_K40_cv-0 0.000061 " - ] - }, - "execution_count": 48, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_metrics.groupby('model').mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "5fb9ba9f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
prec@10recall@10MAP@10noveltyserendipity
model
bm25_userknn_K30_cv-00.0000720.0006120.0000830.1044680.000007
bm25_userknn_K40_cv-00.0000740.0004420.0000810.0973590.000007
cosine_userknn_K30_cv-00.0002310.0017490.0003140.0746990.000003
cosine_userknn_K40_cv-00.0002130.0016030.0002950.0693100.000003
tfidf_userknn_K30_cv-00.0003980.0030030.0005770.0666270.000005
tfidf_userknn_K40_cv-00.0003210.0025340.0004870.0595650.000004
\n", - "
" - ], - "text/plain": [ - " prec@10 recall@10 MAP@10 novelty serendipity\n", - "model \n", - "bm25_userknn_K30_cv-0 0.000072 0.000612 0.000083 0.104468 0.000007\n", - "bm25_userknn_K40_cv-0 0.000074 0.000442 0.000081 0.097359 0.000007\n", - "cosine_userknn_K30_cv-0 0.000231 0.001749 0.000314 0.074699 0.000003\n", - "cosine_userknn_K40_cv-0 0.000213 0.001603 0.000295 0.069310 0.000003\n", - "tfidf_userknn_K30_cv-0 0.000398 0.003003 0.000577 0.066627 0.000005\n", - "tfidf_userknn_K40_cv-0 0.000321 0.002534 0.000487 0.059565 0.000004" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_metrics.groupby('model').std()[metrics.keys()]" - ] - }, - { - "cell_type": "markdown", - "id": "41828ee5", - "metadata": {}, - "source": [ - "по **ofline** метрикам лучше всего себя показывает модель TFIDFRecommender\n", - "TFIDFRecommender подбор К" - ] - }, - { - "cell_type": "markdown", - "id": "7a8a0a41", - "metadata": {}, - "source": [ - "# Подбор оптимального K для TFIDFRecommender" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "1e91892d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tfidf_userknn_K50': ,\n", - " 'tfidf_userknn_K60': ,\n", - " 'tfidf_userknn_K70': }" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "N_USERS = 50\n", - "\n", - "# Т.к. 
метрики для К 30 и 40 уже есть\n", - "K = [k for k in range(50, 71, 10)]\n", - "models = dict()\n", - "\n", - "for k in K:\n", - " models[f\"tfidf_userknn_K{k}\"] = TFIDFRecommender(K=k)\n", - "models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e7c2c43b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "==================== Fold 0\n", - "{'End date': Timestamp('2021-07-18 00:00:00', freq='5D'),\n", - " 'Start date': Timestamp('2021-07-13 00:00:00', freq='5D'),\n", - " 'Test': 156580,\n", - " 'Test items': 5793,\n", - " 'Test users': 68150,\n", - " 'Train': 3281612,\n", - " 'Train items': 14754,\n", - " 'Train users': 652905}\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "211234f034a54bae86b94dff33b9f5c4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/652905 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idlast_watch_dttotal_durwatched_pct
017654995062021-05-11425072.0
169931716592021-05-298317100.0
265668371072021-05-09100.0
386461376382021-07-0514483100.0
496486895062021-04-306725100.0
5476246648596122252021-08-13760.0
547624754686296732021-04-13230849.0
5476248697262152972021-08-201830763.0
5476249384202161972021-04-196203100.0
547625031970944362021-08-15392145.0
\n", - "" - ], - "text/plain": [ - " user_id item_id last_watch_dt total_dur watched_pct\n", - "0 176549 9506 2021-05-11 4250 72.0\n", - "1 699317 1659 2021-05-29 8317 100.0\n", - "2 656683 7107 2021-05-09 10 0.0\n", - "3 864613 7638 2021-07-05 14483 100.0\n", - "4 964868 9506 2021-04-30 6725 100.0\n", - "5476246 648596 12225 2021-08-13 76 0.0\n", - "5476247 546862 9673 2021-04-13 2308 49.0\n", - "5476248 697262 15297 2021-08-20 18307 63.0\n", - "5476249 384202 16197 2021-04-19 6203 100.0\n", - "5476250 319709 4436 2021-08-15 3921 45.0" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.concat([interactions.head(), interactions.tail()])" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "dc4d9fd7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(962179,)" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "interactions['user_id'].unique().shape" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "b7861d19", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[(961833, 1.0),\n", - " (961849, 1.0),\n", - " (961857, 1.0),\n", - " (961871, 1.0),\n", - " (961873, 1.0),\n", - " (961876, 1.0),\n", - " (961887, 1.0),\n", - " (961907, 1.0),\n", - " (961910, 1.0),\n", - " (961912, 1.0)]" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import dill\n", - "\n", - "with open('../service/weights/userKNN/userknn_tfidf_k30.dill', 'rb') as f:\n", - " userknn = dill.load(f)\n", - "\n", - "userknn.similar_items(962178, 10)" - ] - }, - { - "cell_type": "markdown", - "id": "1905033a", - "metadata": {}, - "source": [ - "# Popular Model" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "2df74dba", - "metadata": {}, - "outputs": [], - "source": [ - "from rectools.models import PopularModel\n", - "from 
rectools.dataset import Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "6ba37a73", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Timestamp('2021-08-22 00:00:00')" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "max_date = interactions[Columns.Datetime].max().normalize()\n", - "max_date" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "901353f9", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "train = interactions[[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]][\n", - " interactions[Columns.Datetime] < max_date - pd.Timedelta(5, \"D\")]\n", - "\n", - "test = interactions[[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]][\n", - " interactions[Columns.Datetime] >= max_date - pd.Timedelta(5, \"D\")]\n", - "\n", - "dataset_train = Dataset.construct(train)" - ] - }, - { - "cell_type": "code", - "execution_count": 144, - "id": "f08e3579", - "metadata": {}, - "outputs": [], - "source": [ - "popilarity_models = {\n", - " \"popular\": PopularModel(),\n", - " \"popular_mw\": PopularModel(popularity=\"mean_weight\")\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 145, - "id": "03c3bfb6", - "metadata": {}, - "outputs": [], - "source": [ - "popilarity_models[\"popular\"].fit(dataset_train)\n", - "popilarity_models[\"popular_mw\"].fit(dataset_train);" - ] - }, - { - "cell_type": "code", - "execution_count": 146, - "id": "0d7de49e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 24, 20, 31, 15, 167, 81, 89, 135, 355, 116])" - ] - }, - "execution_count": 146, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "popilarity_models[\"popular\"].popularity_list[0][:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 147, - "id": "05ff208d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - 
"array([11363, 11681, 12841, 13017, 2069, 13691, 13552, 13397, 11774,\n", - " 12913])" - ] - }, - "execution_count": 147, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "popilarity_models[\"popular_mw\"].popularity_list[0][:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 148, - "id": "00ef735c", - "metadata": {}, - "outputs": [], - "source": [ - "pecos_pop = popilarity_models[\"popular\"].recommend(\n", - " users=test[Columns.User].unique(),\n", - " dataset=dataset,\n", - " k=100,\n", - " filter_viewed=False,\n", - ")\n", - "\n", - "pecos_pop_mw = popilarity_models[\"popular_mw\"].recommend(\n", - " users=test[Columns.User].unique(),\n", - " dataset=dataset,\n", - " k=100,\n", - " filter_viewed=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 152, - "id": "b302db55", - "metadata": {}, - "outputs": [], - "source": [ - "metrics = {\n", - " \"prec@5\": Precision(k=5),\n", - " \"recall@5\": Recall(k=5),\n", - " \"MAP@5\": MAP(k=5),\n", - " \"prec@10\": Precision(k=10),\n", - " \"recall@10\": Recall(k=10),\n", - " \"MAP@20\": MAP(k=20),\n", - " \"prec@20\": Precision(k=20),\n", - " \"recall@20\": Recall(k=20),\n", - " \"MAP@100\": MAP(k=100),\n", - " \"prec@100\": Precision(k=100),\n", - " \"recall@100\": Recall(k=100),\n", - " \"MAP@100\": MAP(k=100),\n", - " \"novelty\": MeanInvUserFreq(k=10),\n", - " \"serendipity\": Serendipity(k=10),\n", - "}\n", - "catalog = train[Columns.Item].unique()\n", - "metric_values_pop = calc_metrics(metrics, pecos_pop, test, train, catalog)\n", - "metric_values_pop_mean_weight = calc_metrics(metrics, pecos_pop_mw, test, train, catalog)" - ] - }, - { - "cell_type": "code", - "execution_count": 153, - "id": "9631093b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'prec@5': 0.0017855613317256697,\n", - " 'recall@5': 0.004623809755660008,\n", - " 'prec@10': 0.0011648975773029461,\n", - " 'recall@10': 0.005682095875283048,\n", - " 'prec@20': 
0.0010502526799891945,\n", - " 'recall@20': 0.00880186008464912,\n", - " 'prec@100': 0.003247020220987923,\n", - " 'recall@100': 0.16609031082955295,\n", - " 'MAP@5': 0.0013179725619140792,\n", - " 'MAP@20': 0.0016695313583723814,\n", - " 'MAP@100': 0.005578924867474493,\n", - " 'novelty': 9.976033936531364,\n", - " 'serendipity': 1.2752762676592953e-05}" - ] - }, - "execution_count": 153, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metric_values_pop" - ] - }, - { - "cell_type": "code", - "execution_count": 154, - "id": "5d55b781", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'prec@5': 9.09252633867684e-05,\n", - " 'recall@5': 0.00014799438063171262,\n", - " 'prec@10': 4.612151041357817e-05,\n", - " 'recall@10': 0.00015458316783365238,\n", - " 'prec@20': 2.635514880775895e-05,\n", - " 'recall@20': 0.00016946607539568094,\n", - " 'prec@100': 0.00015147621777259455,\n", - " 'recall@100': 0.0065476971391510656,\n", - " 'MAP@5': 3.0257754846536496e-05,\n", - " 'MAP@20': 3.1771198360212185e-05,\n", - " 'MAP@100': 0.00011355765992119742,\n", - " 'novelty': 17.423655787689828,\n", - " 'serendipity': 1.8991632826477633e-06}" - ] - }, - "execution_count": 154, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metric_values_pop_mean_weight" - ] - }, - { - "cell_type": "markdown", - "id": "e5a4a011", - "metadata": {}, - "source": [ - "**На офлайн метриках выигрывает обычная модель по популярному**" - ] - }, - { - "cell_type": "markdown", - "id": "5875fab7", - "metadata": {}, - "source": [ - "# Save item_idf data" - ] - }, - { - "cell_type": "markdown", - "id": "6589996f", - "metadata": {}, - "source": [ - "Создаем датасет со взвешенными item-ами по механизму idf для использования в будущем" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "d62cabb9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
indexidf
095067.150811
116598.524953
271075.821207
376388.407093
466867.778734
.........
15701783314.822785
15702912514.822785
157031006414.822785
157041301914.822785
157051054214.822785
\n", - "

15706 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " index idf\n", - "0 9506 7.150811\n", - "1 1659 8.524953\n", - "2 7107 5.821207\n", - "3 7638 8.407093\n", - "4 6686 7.778734\n", - "... ... ...\n", - "15701 7833 14.822785\n", - "15702 9125 14.822785\n", - "15703 10064 14.822785\n", - "15704 13019 14.822785\n", - "15705 10542 14.822785\n", - "\n", - "[15706 rows x 2 columns]" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "item_cnt = Counter(interactions['item_id'].values)\n", - "item_idf = pd.DataFrame.from_dict(item_cnt, orient='index', columns=['doc_freq']).reset_index()\n", - "n = interactions.shape[0]\n", - "item_idf['idf'] = item_idf['doc_freq'].apply(lambda x: np.log((1 + n) / (1 + x) + 1))\n", - "del item_idf['doc_freq']\n", - "item_idf" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "7da47dfc", - "metadata": {}, - "outputs": [], - "source": [ - "item_idf = item_idf.sort_values(\"idf\", ascending=False)\n", - "item_idf.to_csv('../data/kion_train/items_idf.csv', index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fdce2b60", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/HW-3.2-rectools-research.ipynb b/notebooks/HW-3.2-rectools-research.ipynb deleted file mode 100644 index ed456f5f..00000000 --- a/notebooks/HW-3.2-rectools-research.ipynb +++ /dev/null @@ -1,725 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "855d49cd", - "metadata": {}, - "outputs": [], - "source": [ - 
"import pandas as pd\n", - "import numpy as np\n", - "import scipy as sp\n", - "import requests\n", - "from tqdm.auto import tqdm\n", - "from scipy.stats import mode \n", - "from pprint import pprint\n", - "from implicit.nearest_neighbours import CosineRecommender\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "from rectools import Columns\n", - "\n", - "pd.set_option('display.max_columns', None)\n", - "pd.set_option('display.max_colwidth', 200)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "655cd033", - "metadata": {}, - "outputs": [], - "source": [ - "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", - "\n", - "interactions.rename(columns={'last_watch_dt': Columns.Datetime,\n", - " 'total_dur': Columns.Weight}, \n", - " inplace=True) \n", - "\n", - "interactions['datetime'] = pd.to_datetime(interactions['datetime'])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "193c411d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Start date and last date of the test fold: (Timestamp('2021-08-08 00:00:00'), Timestamp('2021-08-22 00:00:00'))\n", - "Test fold borders: ['2021-08-08' '2021-08-15']\n", - "Real number of folds: 1\n" - ] - } - ], - "source": [ - "from rectools.model_selection import TimeRangeSplitter\n", - "from rectools.dataset import Interactions\n", - "\n", - "n_folds = 1\n", - "unit = \"W\"\n", - "n_units = 1\n", - "periods = n_folds + 1\n", - "freq = f\"{n_units}{unit}\"\n", - "\n", - "last_date = interactions[Columns.Datetime].max().normalize()\n", - "start_date = last_date - pd.Timedelta(n_folds * n_units + 1, unit=unit) \n", - "print(f\"Start date and last date of the test fold: {start_date, last_date}\")\n", - " \n", - "date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n", - "print(f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\")\n", - "\n", - "# 
generator of folds\n", - "cv = TimeRangeSplitter(\n", - " date_range=date_range,\n", - " filter_already_seen=True,\n", - " filter_cold_items=True,\n", - " filter_cold_users=True,\n", - ")\n", - "print(f\"Real number of folds: {cv.get_n_splits(Interactions(interactions))}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "38b80f9f", - "metadata": {}, - "outputs": [], - "source": [ - "(train_ids, test_ids, fold_info) = cv.split(Interactions(interactions), collect_fold_stats=True).__next__()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "e3051991", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 0, 1, 2, ..., 5476245, 5476247, 5476249])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_ids" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "7bc27a2f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 6, 33, 56, ..., 5476229, 5476230, 5476240])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_ids" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "ffdaad0c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "users_mapping amount: 842129\n", - "items_mapping amount: 15404\n" - ] - } - ], - "source": [ - "train = interactions.loc[train_ids]\n", - "test = interactions.loc[test_ids]\n", - "\n", - "users_inv_mapping = dict(enumerate(train['user_id'].unique()))\n", - "users_mapping = {v: k for k, v in users_inv_mapping.items()}\n", - "\n", - "items_inv_mapping = dict(enumerate(train['item_id'].unique()))\n", - "items_mapping = {v: k for k, v in items_inv_mapping.items()}\n", - "\n", - "print(f\"users_mapping amount: {len(users_mapping)}\")\n", - "print(f\"items_mapping amount: {len(items_mapping)}\")" - ] - }, - { - "cell_type": "code", - 
"execution_count": 12, - "id": "a6664026", - "metadata": {}, - "outputs": [], - "source": [ - "from rectools.dataset import Dataset\n", - "\n", - "dataset = Dataset.construct(\n", - " interactions_df=train,\n", - " user_features_df=None,\n", - " item_features_df=None\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "88f5a65c", - "metadata": {}, - "source": [ - "# ItemKNN CosineRecommender" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "9c4682c5", - "metadata": {}, - "outputs": [], - "source": [ - "from implicit.nearest_neighbours import CosineRecommender\n", - "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", - "\n", - "item_knn = ImplicitItemKNNWrapperModel(model=CosineRecommender(K=30))\n", - "item_knn.fit(dataset);" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "198faaa4", - "metadata": {}, - "outputs": [], - "source": [ - "recs_itemknn = item_knn.recommend(\n", - " test['user_id'].unique(), \n", - " dataset=dataset, \n", - " k=10, \n", - " filter_viewed=False\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "76d1a3f5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
010164581044020431.6311501
110164587348043.9999622
21016458121928033.5995303
3101645819867999.8057314
4101645844577763.2046075
\n", - "
" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 1016458 10440 20431.631150 1\n", - "1 1016458 734 8043.999962 2\n", - "2 1016458 12192 8033.599530 3\n", - "3 1016458 1986 7999.805731 4\n", - "4 1016458 4457 7763.204607 5" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recs_itemknn.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "c075a976", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'prec@10': 0.017311708814214132,\n", - " 'recall@10': 0.09520897568691472,\n", - " 'MAP@10': 0.023145528903990274,\n", - " 'novelty': 8.05318572965277,\n", - " 'serendipity': 6.63288816067437e-05}" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", - "\n", - "# calculate several classic (precision@k and recall@k) and \"beyond accuracy\" metrics\n", - "metrics = {\n", - " \"prec@10\": Precision(k=10),\n", - " \"recall@10\": Recall(k=10),\n", - " \"MAP@10\": MAP(k=10),\n", - " \"novelty\": MeanInvUserFreq(k=10),\n", - " \"serendipity\": Serendipity(k=10),\n", - "}\n", - "\n", - "catalog = train['item_id'].unique()\n", - "\n", - "metric_values_itemknn_cosine = calc_metrics(\n", - " metrics,\n", - " reco=recs_itemknn,\n", - " interactions=test,\n", - " prev_interactions=train,\n", - " catalog=catalog\n", - " )\n", - "\n", - "metric_values_itemknn_cosine" - ] - }, - { - "cell_type": "markdown", - "id": "b439f7fb", - "metadata": {}, - "source": [ - "# ItemKNN TFIDFRecommender" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "e31f5560", - "metadata": {}, - "outputs": [], - "source": [ - "from implicit.nearest_neighbours import TFIDFRecommender\n", - "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", - "\n", - "item_knn_tfidf = 
ImplicitItemKNNWrapperModel(model=TFIDFRecommender(K=30))\n", - "item_knn_tfidf.fit(dataset);" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "360eafab", - "metadata": {}, - "outputs": [], - "source": [ - "recs_itemknn_tfidf = item_knn_tfidf.recommend(\n", - " test['user_id'].unique(), \n", - " dataset=dataset, \n", - " k=10, \n", - " filter_viewed=False \n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "63c31f04", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
010164581044021745.3769271
11016458445710234.8633082
2101645871028987.8781293
31016458121928957.1098134
4101645819868369.8324485
\n", - "
" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 1016458 10440 21745.376927 1\n", - "1 1016458 4457 10234.863308 2\n", - "2 1016458 7102 8987.878129 3\n", - "3 1016458 12192 8957.109813 4\n", - "4 1016458 1986 8369.832448 5" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recs_itemknn_tfidf.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "7a4d01f7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'prec@10': 0.023772589549238603,\n", - " 'recall@10': 0.12652382351172245,\n", - " 'MAP@10': 0.03005237337960426,\n", - " 'novelty': 6.699663403861505,\n", - " 'serendipity': 0.00010222896681730396}" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", - "\n", - "metrics = {\n", - " \"prec@10\": Precision(k=10),\n", - " \"recall@10\": Recall(k=10),\n", - " \"MAP@10\": MAP(k=10),\n", - " \"novelty\": MeanInvUserFreq(k=10),\n", - " \"serendipity\": Serendipity(k=10),\n", - "}\n", - "\n", - "catalog = train['item_id'].unique()\n", - "\n", - "metric_values_itemknn_tfidf = calc_metrics(\n", - " metrics,\n", - " reco=recs_itemknn_tfidf,\n", - " interactions=test,\n", - " prev_interactions=train,\n", - " catalog=catalog\n", - " )\n", - "\n", - "metric_values_itemknn_tfidf" - ] - }, - { - "cell_type": "markdown", - "id": "2270cb27", - "metadata": {}, - "source": [ - "# UserKNN BMP25" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "c7997faf", - "metadata": {}, - "outputs": [], - "source": [ - "from implicit.nearest_neighbours import BM25Recommender\n", - "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", - "\n", - "item_knn_bmp = ImplicitItemKNNWrapperModel(model=BM25Recommender(K=30))\n", - "item_knn_bmp.fit(dataset);" - ] - }, - { - "cell_type": 
"code", - "execution_count": 29, - "id": "c7ceb0e5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
01016458104406.854547e+111
11016458152972.323138e+112
21016458138651.724740e+113
3101645897281.383208e+114
4101645841511.149358e+115
\n", - "
" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 1016458 10440 6.854547e+11 1\n", - "1 1016458 15297 2.323138e+11 2\n", - "2 1016458 13865 1.724740e+11 3\n", - "3 1016458 9728 1.383208e+11 4\n", - "4 1016458 4151 1.149358e+11 5" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recs_itemknn_bmp = item_knn_bmp.recommend(\n", - " test['user_id'].unique(), \n", - " dataset=dataset, \n", - " k=10, \n", - " filter_viewed=False \n", - ")\n", - "\n", - "recs_itemknn_bmp.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "e99f3649", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'prec@10': 0.03252208701450242,\n", - " 'recall@10': 0.1683399650610623,\n", - " 'MAP@10': 0.04827657497255996,\n", - " 'novelty': 3.9201705312554833,\n", - " 'serendipity': 2.616232292298612e-05}" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", - "\n", - "metrics = {\n", - " \"prec@10\": Precision(k=10),\n", - " \"recall@10\": Recall(k=10),\n", - " \"MAP@10\": MAP(k=10),\n", - " \"novelty\": MeanInvUserFreq(k=10),\n", - " \"serendipity\": Serendipity(k=10),\n", - "}\n", - "\n", - "catalog = train['item_id'].unique()\n", - "\n", - "metric_values_itemknn_bmp = calc_metrics(\n", - " metrics,\n", - " reco=recs_itemknn_bmp,\n", - " interactions=test,\n", - " prev_interactions=train,\n", - " catalog=catalog\n", - " )\n", - "\n", - "metric_values_itemknn_bmp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84fe056a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - 
"file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/HW-3.3-rectools-cv.ipynb b/notebooks/HW-3.3-rectools-cv.ipynb deleted file mode 100644 index e5f56e68..00000000 --- a/notebooks/HW-3.3-rectools-cv.ipynb +++ /dev/null @@ -1,4387 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 7, - "id": "f0145080", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "import scipy as sp\n", - "import requests\n", - "from tqdm.auto import tqdm\n", - "from scipy.stats import mode \n", - "from pprint import pprint\n", - "from implicit.nearest_neighbours import CosineRecommender, TFIDFRecommender, BM25Recommender\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "from rectools import Columns\n", - "from rectools.model_selection import TimeRangeSplitter\n", - "from rectools.dataset import Dataset, Interactions\n", - "from rectools.models.popular import PopularModel\n", - "from rectools.models.implicit_knn import ImplicitItemKNNWrapperModel\n", - "from rectools.metrics import Precision, Recall, MeanInvUserFreq, MAP, Serendipity, calc_metrics\n", - "\n", - "pd.set_option('display.max_columns', None)\n", - "pd.set_option('display.max_colwidth', 200)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "95ab759c", - "metadata": {}, - "outputs": [], - "source": [ - "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", - "\n", - "interactions.rename(columns={\n", - " 'last_watch_dt': Columns.Datetime,\n", - " 'total_dur': Columns.Weight\n", - " }, \n", - " inplace=True\n", - ") \n", - "\n", - "interactions['datetime'] = pd.to_datetime(interactions['datetime'])" - ] - }, - { - "cell_type": "markdown", - "id": "fbd3f42d", - "metadata": {}, - "source": [ - "# Split" - ] - }, - 
{ - "cell_type": "markdown", - "id": "c89fcc74", - "metadata": {}, - "source": [ - "В соответствии с предположением из ноутбука \"HW-3.1\" сделаем **валидацию по 5 дней и по 7 дней**" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "368c7cf6", - "metadata": {}, - "outputs": [], - "source": [ - "def create_data_range(\n", - " last_date: pd.Timestamp, \n", - " n_folds: int = 7, \n", - " unit: str = \"W\", \n", - " n_units: int = 1, \n", - " show: bool = True,\n", - "):\n", - " periods = n_folds + 1\n", - " freq = f\"{n_units}{unit}\"\n", - " \n", - " start_date = last_date - pd.Timedelta(n_folds * n_units + n_units, unit=unit) \n", - " \n", - " date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n", - " \n", - " if show:\n", - " print(\n", - " f\"start_date: {start_date}\\n\"\n", - " f\"last_date: {last_date}\\n\"\n", - " f\"periods: {periods}\\n\"\n", - " f\"freq: {freq}\\n\"\n", - " f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\\n\"\n", - " )\n", - " \n", - " return date_range" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "29af1fa3", - "metadata": {}, - "outputs": [], - "source": [ - "CONFIG_CV = {\n", - " \"cv_v1\": {\n", - " \"n_folds\": 5,\n", - " \"unit\": \"W\",\n", - " \"n_units\": 1,\n", - " },\n", - " \"cv_v2\": {\n", - " \"n_folds\": 5,\n", - " \"unit\": \"D\",\n", - " \"n_units\": 5,\n", - " }, \n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "3fdeb5a3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Timestamp('2021-08-22 00:00:00')" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "last_date = interactions[Columns.Datetime].max().normalize()\n", - "last_date" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "9ee0372b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "***Folds 
v1***\n", - "start_date: 2021-07-11 00:00:00\n", - "last_date: 2021-08-22 00:00:00\n", - "periods: 6\n", - "freq: 1W\n", - "Test fold borders: ['2021-07-11' '2021-07-18' '2021-07-25' '2021-08-01' '2021-08-08'\n", - " '2021-08-15']\n", - "\n", - "***Folds v2***\n", - "start_date: 2021-07-23 00:00:00\n", - "last_date: 2021-08-22 00:00:00\n", - "periods: 6\n", - "freq: 5D\n", - "Test fold borders: ['2021-07-23' '2021-07-28' '2021-08-02' '2021-08-07' '2021-08-12'\n", - " '2021-08-17']\n", - "\n" - ] - } - ], - "source": [ - "print(\"***Folds v1***\")\n", - "date_range_v1 = create_data_range(\n", - " last_date, \n", - " n_folds=CONFIG_CV[\"cv_v1\"][\"n_folds\"], \n", - " unit=CONFIG_CV[\"cv_v1\"][\"unit\"], \n", - " n_units=CONFIG_CV[\"cv_v1\"][\"n_units\"]\n", - ")\n", - "\n", - "print(\"***Folds v2***\")\n", - "date_range_v2 = create_data_range(\n", - " last_date, \n", - " n_folds=CONFIG_CV[\"cv_v2\"][\"n_folds\"], \n", - " unit=CONFIG_CV[\"cv_v2\"][\"unit\"], \n", - " n_units=CONFIG_CV[\"cv_v2\"][\"n_units\"]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "63d80785", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Real number of folds: 5\n", - "Real number of folds: 5\n" - ] - } - ], - "source": [ - "cv_v1 = TimeRangeSplitter(\n", - " date_range=date_range_v1,\n", - " filter_already_seen=True,\n", - " filter_cold_items=True,\n", - " filter_cold_users=True,\n", - ")\n", - "print(f\"Real number of folds: {cv_v1.get_n_splits(Interactions(interactions))}\")\n", - "\n", - "cv_v2 = TimeRangeSplitter(\n", - " date_range=date_range_v2,\n", - " filter_already_seen=True,\n", - " filter_cold_items=True,\n", - " filter_cold_users=True,\n", - ")\n", - "print(f\"Real number of folds: {cv_v2.get_n_splits(Interactions(interactions))}\")\n", - "\n", - "CV = [cv_v1, cv_v2]" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "1d4bc5e3", - "metadata": {}, - "outputs": [], - "source": [ - 
"metrics = {\n", - " \"prec@5\": Precision(k=5),\n", - " \"recall@5\": Recall(k=5),\n", - " \"MAP@5\": MAP(k=5),\n", - " \"prec@10\": Precision(k=10),\n", - " \"recall@10\": Recall(k=10),\n", - " \"MAP@10\": MAP(k=10),\n", - " \"novelty\": MeanInvUserFreq(k=10),\n", - " \"serendipity\": Serendipity(k=10),\n", - "}" - ] - }, - { - "cell_type": "markdown", - "id": "f480a12f", - "metadata": {}, - "source": [ - "# Find best models" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "48888d0d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'popular': ,\n", - " 'popular_mw': ,\n", - " 'cosine_userknn_K30': ,\n", - " 'tfidf_userknn_K30': ,\n", - " 'bm25_userknn_K30': ,\n", - " 'cosine_userknn_K40': ,\n", - " 'tfidf_userknn_K40': ,\n", - " 'bm25_userknn_K40': ,\n", - " 'cosine_userknn_K50': ,\n", - " 'tfidf_userknn_K50': ,\n", - " 'bm25_userknn_K50': ,\n", - " 'cosine_userknn_K60': ,\n", - " 'tfidf_userknn_K60': ,\n", - " 'bm25_userknn_K60': }" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "K = [30, 40, 50, 60]\n", - "models = {\n", - " \"popular\": PopularModel(),\n", - " \"popular_mw\": PopularModel(popularity=\"mean_weight\")\n", - "}\n", - "\n", - "for k in K:\n", - " models[f\"popular\"]\n", - " models[f\"cosine_userknn_K{k}\"] = ImplicitItemKNNWrapperModel(model=CosineRecommender(K=k))\n", - " models[f\"tfidf_userknn_K{k}\"] = ImplicitItemKNNWrapperModel(model=TFIDFRecommender(K=k))\n", - " models[f\"bm25_userknn_K{k}\"] = ImplicitItemKNNWrapperModel(model=BM25Recommender(K=k))\n", - "\n", - "models" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "240478ad", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " ***CV_0***\n", - "\n", - "==================== Fold 0\n", - "{'End date': Timestamp('2021-07-18 00:00:00', freq='W-SUN'),\n", - " 'Start date': 
Timestamp('2021-07-11 00:00:00', freq='W-SUN'),\n", - " 'Test': 214489,\n", - " 'Test items': 6313,\n", - " 'Test users': 84234,\n", - " 'Train': 3192875,\n", - " 'Train items': 14711,\n", - " 'Train users': 640144}\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "54fd89ff19334e3182f264d9c492bc0f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/14 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - 
"
foldmodelprec@5recall@5prec@10recall@10MAP@5MAP@10noveltyserendipitycv
00popular_view-False0.0579270.1656440.0407500.2256730.0815120.0912683.5276320.000000fold_1w
10popular_view-True0.0670680.1877850.0431740.2368550.1064160.1146743.7705130.000003fold_1w
20popular_mw_view-False0.0000000.0000000.0000000.0000000.0000000.00000018.2251010.000000fold_1w
30popular_mw_view-True0.0000000.0000000.0000000.0000000.0000000.00000018.2251130.000000fold_1w
40cosine_userknn_K30_view-False0.0233090.0739180.0221430.1297750.0243640.0326517.9141100.000048fold_1w
....................................
2754cosine_userknn_K60_view-True0.0308600.0877970.0234320.1295780.0527720.0590919.1529680.000122fold_5d
2764tfidf_userknn_K60_view-False0.0197570.0608710.0214000.1222600.0196980.0285576.6513340.000095fold_5d
2774tfidf_userknn_K60_view-True0.0428030.1162650.0321730.1704580.0694600.0779126.7271280.000180fold_5d
2784bm25_userknn_K60_view-False0.0370060.1074420.0289580.1623460.0378950.0461993.9205840.000024fold_5d
2794bm25_userknn_K60_view-True0.0495680.1399710.0346460.1918270.0841810.0920224.0025740.000038fold_5d
\n", - "

280 rows × 11 columns

\n", - "" - ], - "text/plain": [ - " fold model prec@5 recall@5 prec@10 \\\n", - "0 0 popular_view-False 0.057927 0.165644 0.040750 \n", - "1 0 popular_view-True 0.067068 0.187785 0.043174 \n", - "2 0 popular_mw_view-False 0.000000 0.000000 0.000000 \n", - "3 0 popular_mw_view-True 0.000000 0.000000 0.000000 \n", - "4 0 cosine_userknn_K30_view-False 0.023309 0.073918 0.022143 \n", - ".. ... ... ... ... ... \n", - "275 4 cosine_userknn_K60_view-True 0.030860 0.087797 0.023432 \n", - "276 4 tfidf_userknn_K60_view-False 0.019757 0.060871 0.021400 \n", - "277 4 tfidf_userknn_K60_view-True 0.042803 0.116265 0.032173 \n", - "278 4 bm25_userknn_K60_view-False 0.037006 0.107442 0.028958 \n", - "279 4 bm25_userknn_K60_view-True 0.049568 0.139971 0.034646 \n", - "\n", - " recall@10 MAP@5 MAP@10 novelty serendipity cv \n", - "0 0.225673 0.081512 0.091268 3.527632 0.000000 fold_1w \n", - "1 0.236855 0.106416 0.114674 3.770513 0.000003 fold_1w \n", - "2 0.000000 0.000000 0.000000 18.225101 0.000000 fold_1w \n", - "3 0.000000 0.000000 0.000000 18.225113 0.000000 fold_1w \n", - "4 0.129775 0.024364 0.032651 7.914110 0.000048 fold_1w \n", - ".. ... ... ... ... ... ... 
\n", - "275 0.129578 0.052772 0.059091 9.152968 0.000122 fold_5d \n", - "276 0.122260 0.019698 0.028557 6.651334 0.000095 fold_5d \n", - "277 0.170458 0.069460 0.077912 6.727128 0.000180 fold_5d \n", - "278 0.162346 0.037895 0.046199 3.920584 0.000024 fold_5d \n", - "279 0.191827 0.084181 0.092022 4.002574 0.000038 fold_5d \n", - "\n", - "[280 rows x 11 columns]" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_metrics = pd.DataFrame(results)\n", - "\n", - "df_metrics['cv'] = 'fold_1w'\n", - "df_metrics.loc[df_metrics[240:].index, 'cv'] = 'fold_5d'\n", - "\n", - "df_metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "2c075a81", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", 
- " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - 
" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
  prec@5recall@5prec@10recall@10MAP@5MAP@10noveltyserendipity
cvmodel        
fold_1wbm25_userknn_K30_view-False0.0421850.1198230.0341540.1853470.0425250.0526723.9462480.000025
bm25_userknn_K30_view-True0.0578090.1584620.0399610.2134740.0952320.1039424.0229380.000040
bm25_userknn_K40_view-False0.0421440.1197360.0341430.1853240.0424800.0526323.9468740.000024
bm25_userknn_K40_view-True0.0578090.1584590.0399780.2136720.0952250.1039584.0190520.000040
bm25_userknn_K50_view-False0.0427580.1212820.0346770.1878470.0430140.0533383.9522220.000024
bm25_userknn_K50_view-True0.0587270.1606350.0404860.2159230.0965760.1053574.0196010.000040
bm25_userknn_K60_view-False0.0427370.1212300.0346660.1877960.0429910.0533143.9540010.000024
bm25_userknn_K60_view-True0.0587330.1606450.0404890.2159890.0965820.1053694.0197450.000040
cosine_userknn_K30_view-False0.0182530.0566890.0183590.1056300.0186880.0258658.0175230.000059
cosine_userknn_K30_view-True0.0355550.0989330.0263640.1430560.0604240.0672809.2559170.000110
cosine_userknn_K40_view-False0.0182420.0566650.0184200.1058860.0186730.0258897.9988740.000059
cosine_userknn_K40_view-True0.0357950.0994330.0266170.1440550.0606590.0676009.1942320.000112
cosine_userknn_K50_view-False0.0185740.0575860.0187360.1074510.0189640.0262837.9764920.000059
cosine_userknn_K50_view-True0.0365740.1013450.0270880.1460780.0617260.0687059.1359440.000112
cosine_userknn_K60_view-False0.0185870.0576330.0187920.1077450.0189680.0263177.9643440.000059
cosine_userknn_K60_view-True0.0367750.1017880.0272630.1468410.0619650.0689959.0997830.000113
popular_mw_view-False0.0000010.0000040.0000010.0000050.0000010.00000118.4532010.000000
popular_mw_view-True0.0000010.0000040.0000010.0000050.0000010.00000118.4532120.000000
popular_view-False0.0478130.1347100.0336350.1829810.0675170.0751913.4627720.000000
popular_view-True0.0548010.1519420.0360520.1947680.0853480.0923823.7265770.000002
tfidf_userknn_K30_view-False0.0230980.0696630.0243060.1356420.0228240.0325196.7433550.000089
tfidf_userknn_K30_view-True0.0472470.1262870.0351810.1830160.0768820.0859166.9707800.000163
tfidf_userknn_K40_view-False0.0230880.0696410.0243660.1360420.0228090.0325576.7253990.000089
tfidf_userknn_K40_view-True0.0475380.1269040.0354420.1841430.0772930.0864016.9208620.000164
tfidf_userknn_K50_view-False0.0233680.0703830.0246770.1375270.0230520.0329166.7182830.000088
tfidf_userknn_K50_view-True0.0482450.1284920.0358830.1859640.0782480.0874186.8985940.000163
tfidf_userknn_K60_view-False0.0233350.0702720.0247020.1376080.0230200.0329066.7094850.000088
tfidf_userknn_K60_view-True0.0483400.1286640.0360000.1863860.0782780.0874876.8729720.000164
fold_5dbm25_userknn_K30_view-False0.0370910.1076370.0289830.1624440.0380040.0462953.9161130.000024
bm25_userknn_K30_view-True0.0495240.1398800.0346900.1919240.0841410.0920164.0087720.000040
bm25_userknn_K40_view-False0.0370410.1075290.0289650.1623580.0379570.0462533.9168820.000024
bm25_userknn_K40_view-True0.0495350.1398850.0346590.1918370.0841560.0920154.0040600.000039
bm25_userknn_K50_view-False0.0369970.1071580.0293990.1636720.0379280.0464873.9190890.000024
bm25_userknn_K50_view-True0.0500520.1405600.0352420.1936600.0843320.0924104.0038040.000039
bm25_userknn_K60_view-False0.0369840.1071180.0293940.1636480.0379040.0464643.9209690.000024
bm25_userknn_K60_view-True0.0500670.1405850.0352470.1937300.0843450.0924244.0038000.000039
cosine_userknn_K30_view-False0.0150530.0478120.0152010.0903210.0156960.0217948.0592570.000062
cosine_userknn_K30_view-True0.0301280.0862880.0227960.1271860.0520750.0582429.3133310.000118
cosine_userknn_K40_view-False0.0150880.0478410.0152530.0905090.0156930.0218108.0395780.000062
cosine_userknn_K40_view-True0.0304070.0868690.0230640.1281230.0523320.0585509.2453400.000120
cosine_userknn_K50_view-False0.0153600.0485870.0158260.0933330.0159890.0224308.0218540.000062
cosine_userknn_K50_view-True0.0312470.0885020.0239710.1324190.0535310.0601719.1785550.000119
cosine_userknn_K60_view-False0.0153780.0486400.0158630.0936170.0159980.0224648.0092990.000062
cosine_userknn_K60_view-True0.0314330.0888970.0241430.1332180.0537360.0604409.1400930.000120
popular_mw_view-False0.0000120.0000470.0000060.0000470.0000160.00001618.5009540.000001
popular_mw_view-True0.0000120.0000470.0000060.0000470.0000160.00001618.5009640.000001
popular_view-False0.0412890.1193840.0274250.1536150.0614460.0669593.4328550.000000
popular_view-True0.0471830.1346070.0301480.1685700.0769260.0821673.7145430.000002
tfidf_userknn_K30_view-False0.0198470.0612400.0213000.1217630.0198270.0285606.6921510.000096
tfidf_userknn_K30_view-True0.0423270.1152380.0315110.1680460.0689490.0771826.8414240.000178
tfidf_userknn_K40_view-False0.0197870.0609900.0213720.1222950.0197550.0285856.6717820.000096
tfidf_userknn_K40_view-True0.0425410.1157180.0318530.1694190.0691930.0775726.7890490.000178
tfidf_userknn_K50_view-False0.0201520.0619330.0218290.1244330.0201690.0292106.6699910.000095
tfidf_userknn_K50_view-True0.0431760.1170330.0325920.1725790.0705340.0792026.7754160.000176
tfidf_userknn_K60_view-False0.0201230.0618390.0218090.1242790.0201360.0291716.6608280.000094
tfidf_userknn_K60_view-True0.0432480.1171350.0326970.1727720.0705900.0792716.7478700.000176
\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_metrics_mean = df_metrics.groupby(['cv', 'model'])[\n", - " 'prec@5', 'recall@5', 'prec@10', 'recall@10', 'MAP@5', 'MAP@10', 'novelty', 'serendipity'\n", - "].mean()\n", - "\n", - "df_metrics_mean.style.highlight_max(color='lightgreen', axis=0)" - ] - }, - { - "cell_type": "markdown", - "id": "c6f89d3a", - "metadata": {}, - "source": [ - "Из результатов видно, что среднее значение метрик моделей **bmp** имеют **наилучшие** значения, причем на недельном фолде метрики выше, чем на 5 дневном \n", - "\n", - "- Следует проверить статистически различимы значения или нет. Для этого следует посмотреть дисперсию и если дисперсия меньше чем различия между средними значениями метрик, то можно сделать вывод, что значения метрик статистически различны" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "dbe6f6e9", - "metadata": { - "collapsed": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
prec@5recall@5prec@10recall@10MAP@5MAP@10noveltyserendipity
cvmodel
fold_1wbm25_userknn_K30_view-False0.0040420.0113800.0033190.0180500.0041440.0052350.0297301.677056e-06
bm25_userknn_K30_view-True0.0053480.0145870.0027870.0136200.0098230.0098710.0150063.357219e-06
bm25_userknn_K40_view-False0.0040360.0113930.0033100.0180330.0041490.0052340.0296131.666772e-06
bm25_userknn_K40_view-True0.0053340.0145920.0027740.0135110.0098240.0098590.0147093.359684e-06
bm25_userknn_K50_view-False0.0037770.0110130.0030870.0174760.0040390.0050710.0293291.807696e-06
bm25_userknn_K50_view-True0.0048940.0139550.0024670.0124850.0095620.0095320.0148613.329699e-06
bm25_userknn_K60_view-False0.0037740.0110020.0030880.0174920.0040380.0050740.0293681.805538e-06
bm25_userknn_K60_view-True0.0048940.0139630.0024660.0124780.0095620.0095290.0147383.340070e-06
cosine_userknn_K30_view-False0.0023930.0078030.0019180.0113260.0025180.0030910.0473375.304930e-06
cosine_userknn_K30_view-True0.0036240.0104000.0017860.0088100.0070000.0068420.0470279.087191e-06
cosine_userknn_K40_view-False0.0023890.0077940.0019080.0112950.0025170.0030810.0464455.302222e-06
cosine_userknn_K40_view-True0.0035680.0102570.0017520.0085470.0069510.0067840.0468249.276498e-06
cosine_userknn_K50_view-False0.0023160.0077730.0018440.0112840.0025120.0030720.0463455.581875e-06
cosine_userknn_K50_view-True0.0033820.0100320.0016460.0082760.0069180.0067300.0467999.814832e-06
cosine_userknn_K60_view-False0.0023140.0077920.0018590.0113540.0025140.0030830.0464545.564036e-06
cosine_userknn_K60_view-True0.0034040.0100880.0016380.0082460.0069610.0067690.0471339.793017e-06
popular_mw_view-False0.0000020.0000060.0000010.0000050.0000020.0000020.1251581.003157e-07
popular_mw_view-True0.0000020.0000060.0000010.0000050.0000020.0000020.1251551.003157e-07
popular_view-False0.0052120.0145720.0038290.0217390.0061000.0070750.0301750.000000e+00
popular_view-True0.0061040.0166560.0037930.0208750.0089550.0096530.0191312.429424e-07
tfidf_userknn_K30_view-False0.0022440.0069950.0019450.0106640.0023810.0029650.0398778.695037e-06
tfidf_userknn_K30_view-True0.0033320.0090010.0019680.0086110.0063740.0063900.0696381.279797e-05
tfidf_userknn_K40_view-False0.0022470.0070180.0019330.0105820.0023840.0029490.0408518.509896e-06
tfidf_userknn_K40_view-True0.0033190.0089040.0019240.0082880.0063000.0062880.0708191.317232e-05
tfidf_userknn_K50_view-False0.0022010.0070680.0018850.0105910.0024010.0029550.0404928.580259e-06
tfidf_userknn_K50_view-True0.0031950.0088190.0018010.0077960.0063570.0062790.0654901.351216e-05
tfidf_userknn_K60_view-False0.0022040.0070880.0018860.0106590.0024120.0029720.0408278.400699e-06
tfidf_userknn_K60_view-True0.0031990.0088290.0018090.0078410.0064200.0063480.0660501.359432e-05
fold_5dbm25_userknn_K30_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K30_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K40_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K40_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
bm25_userknn_K50_view-False0.0000390.0004560.0006120.0018670.0000020.0003660.0007251.109718e-07
bm25_userknn_K50_view-True0.0006970.0008470.0008170.0025360.0002280.0005520.0012428.703289e-07
bm25_userknn_K60_view-False0.0000310.0004580.0006160.0018410.0000130.0003750.0005438.345726e-08
bm25_userknn_K60_view-True0.0007060.0008690.0008500.0026910.0002320.0005690.0017341.165137e-06
cosine_userknn_K30_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K30_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K40_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K40_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
cosine_userknn_K50_view-False0.0003550.0009810.0007060.0033660.0004120.0007980.0010416.443968e-07
cosine_userknn_K50_view-True0.0007920.0014680.0009660.0047400.0013080.0018170.0183502.644474e-06
cosine_userknn_K60_view-False0.0003770.0010290.0007350.0035480.0004270.0008310.0032367.481876e-07
cosine_userknn_K60_view-True0.0008100.0015550.0010060.0051480.0013630.0019080.0182082.395850e-06
popular_mw_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
popular_mw_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
popular_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
popular_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K30_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K30_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K40_view-FalseNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K40_view-TrueNaNNaNNaNNaNNaNNaNNaNNaN
tfidf_userknn_K50_view-False0.0005680.0014770.0005630.0026220.0006370.0008480.0140731.071478e-06
tfidf_userknn_K50_view-True0.0006550.0012350.0007460.0033030.0015870.0019030.0288473.757778e-06
tfidf_userknn_K60_view-False0.0005170.0013700.0005780.0028550.0006190.0008680.0134261.001711e-06
tfidf_userknn_K60_view-True0.0006300.0012300.0007420.0032730.0015980.0019220.0293354.434436e-06
\n", - "
" - ], - "text/plain": [ - " prec@5 recall@5 prec@10 \\\n", - "cv model \n", - "fold_1w bm25_userknn_K30_view-False 0.004042 0.011380 0.003319 \n", - " bm25_userknn_K30_view-True 0.005348 0.014587 0.002787 \n", - " bm25_userknn_K40_view-False 0.004036 0.011393 0.003310 \n", - " bm25_userknn_K40_view-True 0.005334 0.014592 0.002774 \n", - " bm25_userknn_K50_view-False 0.003777 0.011013 0.003087 \n", - " bm25_userknn_K50_view-True 0.004894 0.013955 0.002467 \n", - " bm25_userknn_K60_view-False 0.003774 0.011002 0.003088 \n", - " bm25_userknn_K60_view-True 0.004894 0.013963 0.002466 \n", - " cosine_userknn_K30_view-False 0.002393 0.007803 0.001918 \n", - " cosine_userknn_K30_view-True 0.003624 0.010400 0.001786 \n", - " cosine_userknn_K40_view-False 0.002389 0.007794 0.001908 \n", - " cosine_userknn_K40_view-True 0.003568 0.010257 0.001752 \n", - " cosine_userknn_K50_view-False 0.002316 0.007773 0.001844 \n", - " cosine_userknn_K50_view-True 0.003382 0.010032 0.001646 \n", - " cosine_userknn_K60_view-False 0.002314 0.007792 0.001859 \n", - " cosine_userknn_K60_view-True 0.003404 0.010088 0.001638 \n", - " popular_mw_view-False 0.000002 0.000006 0.000001 \n", - " popular_mw_view-True 0.000002 0.000006 0.000001 \n", - " popular_view-False 0.005212 0.014572 0.003829 \n", - " popular_view-True 0.006104 0.016656 0.003793 \n", - " tfidf_userknn_K30_view-False 0.002244 0.006995 0.001945 \n", - " tfidf_userknn_K30_view-True 0.003332 0.009001 0.001968 \n", - " tfidf_userknn_K40_view-False 0.002247 0.007018 0.001933 \n", - " tfidf_userknn_K40_view-True 0.003319 0.008904 0.001924 \n", - " tfidf_userknn_K50_view-False 0.002201 0.007068 0.001885 \n", - " tfidf_userknn_K50_view-True 0.003195 0.008819 0.001801 \n", - " tfidf_userknn_K60_view-False 0.002204 0.007088 0.001886 \n", - " tfidf_userknn_K60_view-True 0.003199 0.008829 0.001809 \n", - "fold_5d bm25_userknn_K30_view-False NaN NaN NaN \n", - " bm25_userknn_K30_view-True NaN NaN NaN \n", - " bm25_userknn_K40_view-False NaN NaN 
NaN \n", - " bm25_userknn_K40_view-True NaN NaN NaN \n", - " bm25_userknn_K50_view-False 0.000039 0.000456 0.000612 \n", - " bm25_userknn_K50_view-True 0.000697 0.000847 0.000817 \n", - " bm25_userknn_K60_view-False 0.000031 0.000458 0.000616 \n", - " bm25_userknn_K60_view-True 0.000706 0.000869 0.000850 \n", - " cosine_userknn_K30_view-False NaN NaN NaN \n", - " cosine_userknn_K30_view-True NaN NaN NaN \n", - " cosine_userknn_K40_view-False NaN NaN NaN \n", - " cosine_userknn_K40_view-True NaN NaN NaN \n", - " cosine_userknn_K50_view-False 0.000355 0.000981 0.000706 \n", - " cosine_userknn_K50_view-True 0.000792 0.001468 0.000966 \n", - " cosine_userknn_K60_view-False 0.000377 0.001029 0.000735 \n", - " cosine_userknn_K60_view-True 0.000810 0.001555 0.001006 \n", - " popular_mw_view-False NaN NaN NaN \n", - " popular_mw_view-True NaN NaN NaN \n", - " popular_view-False NaN NaN NaN \n", - " popular_view-True NaN NaN NaN \n", - " tfidf_userknn_K30_view-False NaN NaN NaN \n", - " tfidf_userknn_K30_view-True NaN NaN NaN \n", - " tfidf_userknn_K40_view-False NaN NaN NaN \n", - " tfidf_userknn_K40_view-True NaN NaN NaN \n", - " tfidf_userknn_K50_view-False 0.000568 0.001477 0.000563 \n", - " tfidf_userknn_K50_view-True 0.000655 0.001235 0.000746 \n", - " tfidf_userknn_K60_view-False 0.000517 0.001370 0.000578 \n", - " tfidf_userknn_K60_view-True 0.000630 0.001230 0.000742 \n", - "\n", - " recall@10 MAP@5 MAP@10 \\\n", - "cv model \n", - "fold_1w bm25_userknn_K30_view-False 0.018050 0.004144 0.005235 \n", - " bm25_userknn_K30_view-True 0.013620 0.009823 0.009871 \n", - " bm25_userknn_K40_view-False 0.018033 0.004149 0.005234 \n", - " bm25_userknn_K40_view-True 0.013511 0.009824 0.009859 \n", - " bm25_userknn_K50_view-False 0.017476 0.004039 0.005071 \n", - " bm25_userknn_K50_view-True 0.012485 0.009562 0.009532 \n", - " bm25_userknn_K60_view-False 0.017492 0.004038 0.005074 \n", - " bm25_userknn_K60_view-True 0.012478 0.009562 0.009529 \n", - " 
cosine_userknn_K30_view-False 0.011326 0.002518 0.003091 \n", - " cosine_userknn_K30_view-True 0.008810 0.007000 0.006842 \n", - " cosine_userknn_K40_view-False 0.011295 0.002517 0.003081 \n", - " cosine_userknn_K40_view-True 0.008547 0.006951 0.006784 \n", - " cosine_userknn_K50_view-False 0.011284 0.002512 0.003072 \n", - " cosine_userknn_K50_view-True 0.008276 0.006918 0.006730 \n", - " cosine_userknn_K60_view-False 0.011354 0.002514 0.003083 \n", - " cosine_userknn_K60_view-True 0.008246 0.006961 0.006769 \n", - " popular_mw_view-False 0.000005 0.000002 0.000002 \n", - " popular_mw_view-True 0.000005 0.000002 0.000002 \n", - " popular_view-False 0.021739 0.006100 0.007075 \n", - " popular_view-True 0.020875 0.008955 0.009653 \n", - " tfidf_userknn_K30_view-False 0.010664 0.002381 0.002965 \n", - " tfidf_userknn_K30_view-True 0.008611 0.006374 0.006390 \n", - " tfidf_userknn_K40_view-False 0.010582 0.002384 0.002949 \n", - " tfidf_userknn_K40_view-True 0.008288 0.006300 0.006288 \n", - " tfidf_userknn_K50_view-False 0.010591 0.002401 0.002955 \n", - " tfidf_userknn_K50_view-True 0.007796 0.006357 0.006279 \n", - " tfidf_userknn_K60_view-False 0.010659 0.002412 0.002972 \n", - " tfidf_userknn_K60_view-True 0.007841 0.006420 0.006348 \n", - "fold_5d bm25_userknn_K30_view-False NaN NaN NaN \n", - " bm25_userknn_K30_view-True NaN NaN NaN \n", - " bm25_userknn_K40_view-False NaN NaN NaN \n", - " bm25_userknn_K40_view-True NaN NaN NaN \n", - " bm25_userknn_K50_view-False 0.001867 0.000002 0.000366 \n", - " bm25_userknn_K50_view-True 0.002536 0.000228 0.000552 \n", - " bm25_userknn_K60_view-False 0.001841 0.000013 0.000375 \n", - " bm25_userknn_K60_view-True 0.002691 0.000232 0.000569 \n", - " cosine_userknn_K30_view-False NaN NaN NaN \n", - " cosine_userknn_K30_view-True NaN NaN NaN \n", - " cosine_userknn_K40_view-False NaN NaN NaN \n", - " cosine_userknn_K40_view-True NaN NaN NaN \n", - " cosine_userknn_K50_view-False 0.003366 0.000412 0.000798 \n", - " 
cosine_userknn_K50_view-True 0.004740 0.001308 0.001817 \n", - " cosine_userknn_K60_view-False 0.003548 0.000427 0.000831 \n", - " cosine_userknn_K60_view-True 0.005148 0.001363 0.001908 \n", - " popular_mw_view-False NaN NaN NaN \n", - " popular_mw_view-True NaN NaN NaN \n", - " popular_view-False NaN NaN NaN \n", - " popular_view-True NaN NaN NaN \n", - " tfidf_userknn_K30_view-False NaN NaN NaN \n", - " tfidf_userknn_K30_view-True NaN NaN NaN \n", - " tfidf_userknn_K40_view-False NaN NaN NaN \n", - " tfidf_userknn_K40_view-True NaN NaN NaN \n", - " tfidf_userknn_K50_view-False 0.002622 0.000637 0.000848 \n", - " tfidf_userknn_K50_view-True 0.003303 0.001587 0.001903 \n", - " tfidf_userknn_K60_view-False 0.002855 0.000619 0.000868 \n", - " tfidf_userknn_K60_view-True 0.003273 0.001598 0.001922 \n", - "\n", - " novelty serendipity \n", - "cv model \n", - "fold_1w bm25_userknn_K30_view-False 0.029730 1.677056e-06 \n", - " bm25_userknn_K30_view-True 0.015006 3.357219e-06 \n", - " bm25_userknn_K40_view-False 0.029613 1.666772e-06 \n", - " bm25_userknn_K40_view-True 0.014709 3.359684e-06 \n", - " bm25_userknn_K50_view-False 0.029329 1.807696e-06 \n", - " bm25_userknn_K50_view-True 0.014861 3.329699e-06 \n", - " bm25_userknn_K60_view-False 0.029368 1.805538e-06 \n", - " bm25_userknn_K60_view-True 0.014738 3.340070e-06 \n", - " cosine_userknn_K30_view-False 0.047337 5.304930e-06 \n", - " cosine_userknn_K30_view-True 0.047027 9.087191e-06 \n", - " cosine_userknn_K40_view-False 0.046445 5.302222e-06 \n", - " cosine_userknn_K40_view-True 0.046824 9.276498e-06 \n", - " cosine_userknn_K50_view-False 0.046345 5.581875e-06 \n", - " cosine_userknn_K50_view-True 0.046799 9.814832e-06 \n", - " cosine_userknn_K60_view-False 0.046454 5.564036e-06 \n", - " cosine_userknn_K60_view-True 0.047133 9.793017e-06 \n", - " popular_mw_view-False 0.125158 1.003157e-07 \n", - " popular_mw_view-True 0.125155 1.003157e-07 \n", - " popular_view-False 0.030175 0.000000e+00 \n", - " 
popular_view-True 0.019131 2.429424e-07 \n", - " tfidf_userknn_K30_view-False 0.039877 8.695037e-06 \n", - " tfidf_userknn_K30_view-True 0.069638 1.279797e-05 \n", - " tfidf_userknn_K40_view-False 0.040851 8.509896e-06 \n", - " tfidf_userknn_K40_view-True 0.070819 1.317232e-05 \n", - " tfidf_userknn_K50_view-False 0.040492 8.580259e-06 \n", - " tfidf_userknn_K50_view-True 0.065490 1.351216e-05 \n", - " tfidf_userknn_K60_view-False 0.040827 8.400699e-06 \n", - " tfidf_userknn_K60_view-True 0.066050 1.359432e-05 \n", - "fold_5d bm25_userknn_K30_view-False NaN NaN \n", - " bm25_userknn_K30_view-True NaN NaN \n", - " bm25_userknn_K40_view-False NaN NaN \n", - " bm25_userknn_K40_view-True NaN NaN \n", - " bm25_userknn_K50_view-False 0.000725 1.109718e-07 \n", - " bm25_userknn_K50_view-True 0.001242 8.703289e-07 \n", - " bm25_userknn_K60_view-False 0.000543 8.345726e-08 \n", - " bm25_userknn_K60_view-True 0.001734 1.165137e-06 \n", - " cosine_userknn_K30_view-False NaN NaN \n", - " cosine_userknn_K30_view-True NaN NaN \n", - " cosine_userknn_K40_view-False NaN NaN \n", - " cosine_userknn_K40_view-True NaN NaN \n", - " cosine_userknn_K50_view-False 0.001041 6.443968e-07 \n", - " cosine_userknn_K50_view-True 0.018350 2.644474e-06 \n", - " cosine_userknn_K60_view-False 0.003236 7.481876e-07 \n", - " cosine_userknn_K60_view-True 0.018208 2.395850e-06 \n", - " popular_mw_view-False NaN NaN \n", - " popular_mw_view-True NaN NaN \n", - " popular_view-False NaN NaN \n", - " popular_view-True NaN NaN \n", - " tfidf_userknn_K30_view-False NaN NaN \n", - " tfidf_userknn_K30_view-True NaN NaN \n", - " tfidf_userknn_K40_view-False NaN NaN \n", - " tfidf_userknn_K40_view-True NaN NaN \n", - " tfidf_userknn_K50_view-False 0.014073 1.071478e-06 \n", - " tfidf_userknn_K50_view-True 0.028847 3.757778e-06 \n", - " tfidf_userknn_K60_view-False 0.013426 1.001711e-06 \n", - " tfidf_userknn_K60_view-True 0.029335 4.434436e-06 " - ] - }, - "execution_count": 51, - "metadata": {}, - 
"output_type": "execute_result" - } - ], - "source": [ - "df_metrics_std = df_metrics.groupby(['cv', 'model'])[\n", - " 'prec@5', 'recall@5', 'prec@10', 'recall@10', 'MAP@5', 'MAP@10', 'novelty', 'serendipity'\n", - "].std()\n", - "\n", - "df_metrics_std" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "id": "58ad9d07", - "metadata": {}, - "outputs": [], - "source": [ - "df_metrics_1w_mean = df_metrics_mean.loc[\"fold_1w\"]\n", - "df_metrics_1w_std = df_metrics_std.loc[\"fold_1w\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "id": "52bffafc", - "metadata": {}, - "outputs": [], - "source": [ - "best_model = \"bm25_userknn_K60_view-True\"\n", - "col_metrics = list(metrics.keys())\n", - "std_best_metrics = df_metrics_1w_std[df_metrics_1w_std[\"model\"] == best_model][col_metrics].values[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "0059da61", - "metadata": { - "collapsed": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "prec@5 0.004894\n", - "recall@5 0.013963\n", - "prec@10 0.002466\n", - "recall@10 0.012478\n", - "MAP@5 0.009562\n", - "MAP@10 0.009529\n", - "novelty 0.014738\n", - "serendipity 0.000003\n", - "Name: bm25_userknn_K60_view-True, dtype: float64\n", - "\n", - "===Сравнение с bm25_userknn_K30_view-False\n", - "prec@5 0.000852\n", - "recall@5 0.002584\n", - "prec@10 -0.000854\n", - "recall@10 -0.005571\n", - "MAP@5 0.005418\n", - "MAP@10 0.004293\n", - "novelty -0.014992\n", - "serendipity 0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с bm25_userknn_K30_view-True\n", - "prec@5 -4.545181e-04\n", - "recall@5 -6.232778e-04\n", - "prec@10 -3.208519e-04\n", - "recall@10 -1.141463e-03\n", - "MAP@5 -2.606624e-04\n", - "MAP@10 -3.423636e-04\n", - "novelty -2.682734e-04\n", - "serendipity -1.714842e-08\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с 
bm25_userknn_K40_view-False\n", - "prec@5 0.000858\n", - "recall@5 0.002570\n", - "prec@10 -0.000844\n", - "recall@10 -0.005554\n", - "MAP@5 0.005413\n", - "MAP@10 0.004294\n", - "novelty -0.014875\n", - "serendipity 0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с bm25_userknn_K40_view-True\n", - "prec@5 -4.403880e-04\n", - "recall@5 -6.290758e-04\n", - "prec@10 -3.080923e-04\n", - "recall@10 -1.032860e-03\n", - "MAP@5 -2.622039e-04\n", - "MAP@10 -3.305713e-04\n", - "novelty 2.940168e-05\n", - "serendipity -1.961419e-08\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с bm25_userknn_K50_view-False\n", - "prec@5 0.001117\n", - "recall@5 0.002950\n", - "prec@10 -0.000621\n", - "recall@10 -0.004998\n", - "MAP@5 0.005523\n", - "MAP@10 0.004457\n", - "novelty -0.014591\n", - "serendipity 0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с bm25_userknn_K50_view-True\n", - "prec@5 -6.188074e-08\n", - "recall@5 8.340206e-06\n", - "prec@10 -1.581047e-06\n", - "recall@10 -6.557768e-06\n", - "MAP@5 8.532418e-08\n", - "MAP@10 -3.212381e-06\n", - "novelty -1.225529e-04\n", - "serendipity 1.037090e-08\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с bm25_userknn_K60_view-False\n", - "prec@5 0.001119\n", - "recall@5 0.002961\n", - "prec@10 -0.000622\n", - "recall@10 -0.005014\n", - "MAP@5 0.005524\n", - "MAP@10 0.004455\n", - "novelty -0.014630\n", - "serendipity 0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K30_view-False\n", - "prec@5 0.002501\n", - "recall@5 0.006161\n", - "prec@10 0.000548\n", - "recall@10 0.001152\n", - "MAP@5 0.007044\n", - "MAP@10 0.006437\n", - "novelty -0.032599\n", - "serendipity -0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K30_view-True\n", - "prec@5 0.001270\n", - 
"recall@5 0.003563\n", - "prec@10 0.000680\n", - "recall@10 0.003668\n", - "MAP@5 0.002563\n", - "MAP@10 0.002687\n", - "novelty -0.032289\n", - "serendipity -0.000006\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K40_view-False\n", - "prec@5 0.002504\n", - "recall@5 0.006169\n", - "prec@10 0.000558\n", - "recall@10 0.001184\n", - "MAP@5 0.007046\n", - "MAP@10 0.006448\n", - "novelty -0.031707\n", - "serendipity -0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K40_view-True\n", - "prec@5 0.001325\n", - "recall@5 0.003706\n", - "prec@10 0.000714\n", - "recall@10 0.003931\n", - "MAP@5 0.002611\n", - "MAP@10 0.002744\n", - "novelty -0.032086\n", - "serendipity -0.000006\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K50_view-False\n", - "prec@5 0.002578\n", - "recall@5 0.006190\n", - "prec@10 0.000622\n", - "recall@10 0.001195\n", - "MAP@5 0.007051\n", - "MAP@10 0.006456\n", - "novelty -0.031607\n", - "serendipity -0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K50_view-True\n", - "prec@5 0.001511\n", - "recall@5 0.003931\n", - "prec@10 0.000820\n", - "recall@10 0.004203\n", - "MAP@5 0.002644\n", - "MAP@10 0.002798\n", - "novelty -0.032061\n", - "serendipity -0.000006\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K60_view-False\n", - "prec@5 0.002580\n", - "recall@5 0.006172\n", - "prec@10 0.000607\n", - "recall@10 0.001124\n", - "MAP@5 0.007049\n", - "MAP@10 0.006446\n", - "novelty -0.031716\n", - "serendipity -0.000002\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с cosine_userknn_K60_view-True\n", - "prec@5 0.001489\n", - "recall@5 0.003875\n", - "prec@10 0.000827\n", - "recall@10 0.004232\n", - "MAP@5 0.002601\n", - "MAP@10 0.002760\n", - "novelty 
-0.032395\n", - "serendipity -0.000006\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с popular_mw_view-False\n", - "prec@5 0.004892\n", - "recall@5 0.013958\n", - "prec@10 0.002465\n", - "recall@10 0.012473\n", - "MAP@5 0.009561\n", - "MAP@10 0.009527\n", - "novelty -0.110420\n", - "serendipity 0.000003\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с popular_mw_view-True\n", - "prec@5 0.004892\n", - "recall@5 0.013958\n", - "prec@10 0.002465\n", - "recall@10 0.012473\n", - "MAP@5 0.009561\n", - "MAP@10 0.009527\n", - "novelty -0.110417\n", - "serendipity 0.000003\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с popular_view-False\n", - "prec@5 -0.000319\n", - "recall@5 -0.000609\n", - "prec@10 -0.001363\n", - "recall@10 -0.009260\n", - "MAP@5 0.003462\n", - "MAP@10 0.002453\n", - "novelty -0.015437\n", - "serendipity 0.000003\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с popular_view-True\n", - "prec@5 -0.001210\n", - "recall@5 -0.002692\n", - "prec@10 -0.001327\n", - "recall@10 -0.008397\n", - "MAP@5 0.000607\n", - "MAP@10 -0.000124\n", - "novelty -0.004393\n", - "serendipity 0.000003\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K30_view-False\n", - "prec@5 0.002649\n", - "recall@5 0.006968\n", - "prec@10 0.000521\n", - "recall@10 0.001815\n", - "MAP@5 0.007181\n", - "MAP@10 0.006564\n", - "novelty -0.025139\n", - "serendipity -0.000005\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K30_view-True\n", - "prec@5 0.001561\n", - "recall@5 0.004963\n", - "prec@10 0.000498\n", - "recall@10 0.003867\n", - "MAP@5 0.003188\n", - "MAP@10 0.003139\n", - "novelty -0.054900\n", - "serendipity -0.000009\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K40_view-False\n", - 
"prec@5 0.002647\n", - "recall@5 0.006945\n", - "prec@10 0.000532\n", - "recall@10 0.001897\n", - "MAP@5 0.007178\n", - "MAP@10 0.006579\n", - "novelty -0.026113\n", - "serendipity -0.000005\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K40_view-True\n", - "prec@5 0.001575\n", - "recall@5 0.005059\n", - "prec@10 0.000542\n", - "recall@10 0.004190\n", - "MAP@5 0.003262\n", - "MAP@10 0.003240\n", - "novelty -0.056080\n", - "serendipity -0.000010\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K50_view-False\n", - "prec@5 0.002693\n", - "recall@5 0.006895\n", - "prec@10 0.000581\n", - "recall@10 0.001887\n", - "MAP@5 0.007161\n", - "MAP@10 0.006574\n", - "novelty -0.025754\n", - "serendipity -0.000005\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K50_view-True\n", - "prec@5 0.001698\n", - "recall@5 0.005144\n", - "prec@10 0.000665\n", - "recall@10 0.004682\n", - "MAP@5 0.003205\n", - "MAP@10 0.003249\n", - "novelty -0.050752\n", - "serendipity -0.000010\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K60_view-False\n", - "prec@5 0.002690\n", - "recall@5 0.006875\n", - "prec@10 0.000580\n", - "recall@10 0.001819\n", - "MAP@5 0.007150\n", - "MAP@10 0.006557\n", - "novelty -0.026089\n", - "serendipity -0.000005\n", - "dtype: float64\n", - "=========================\n", - "\n", - "===Сравнение с tfidf_userknn_K60_view-True\n", - "prec@5 0.001694\n", - "recall@5 0.005134\n", - "prec@10 0.000657\n", - "recall@10 0.004637\n", - "MAP@5 0.003142\n", - "MAP@10 0.003180\n", - "novelty -0.051312\n", - "serendipity -0.000010\n", - "dtype: float64\n", - "=========================\n" - ] - } - ], - "source": [ - "print(df_metrics_1w_std.loc[best_model])\n", - "for model in df_metrics_1w_mean.index:\n", - " if model != best_model:\n", - " print(f\"\\n===Сравнение с {model}\")\n", 
- " print(df_metrics_1w_mean.loc[best_model] - df_metrics_1w_mean.loc[model])\n", - " print(\"=========================\")" - ] - }, - { - "cell_type": "markdown", - "id": "0675ba9b", - "metadata": {}, - "source": [ - "Лучшей модели большинством из моделей видны статистические различия, кроме всех моделей bmp (логично, потому что лучшая модель bmp с k = 60) и моделью tfidf, где для рекомендаций стоял флаг filter_viewed = True, что означает рекомендовать не одинаковые элементы для всех пользователей" - ] - }, - { - "cell_type": "markdown", - "id": "e233b183", - "metadata": {}, - "source": [ - "# Обучение на всех имеющихся данных и формирование оффлайн рекомендаций" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "30e985b6", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = Dataset.construct(\n", - " interactions_df=interactions,\n", - " user_features_df=None,\n", - " item_features_df=None\n", - ")\n", - "\n", - "bmp25_k60_model = ImplicitItemKNNWrapperModel(BM25Recommender(K=60))\n", - "bmp25_k60_model.fit(dataset)\n", - "\n", - "K_RECOS = 30\n", - " \n", - "recos_offline_bmp25 = bmp25_k60_model.recommend(\n", - " users=interactions[Columns.User].unique(),\n", - " dataset=dataset,\n", - " k=K_RECOS,\n", - " filter_viewed=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "4034d96f", - "metadata": {}, - "outputs": [], - "source": [ - "recos_offline_bmp25.to_csv(\"../data/hw_3/bmp_25_k60_rectools.csv\", index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "d52a48b5", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = Dataset.construct(\n", - " interactions_df=interactions,\n", - " user_features_df=None,\n", - " item_features_df=None\n", - ")\n", - "\n", - "tfidf_k60_model = ImplicitItemKNNWrapperModel(TFIDFRecommender(K=60))\n", - "tfidf_k60_model.fit(dataset)\n", - "\n", - "K_RECOS = 30\n", - " \n", - "recos_offline_tfidf = tfidf_k60_model.recommend(\n", - " 
users=interactions[Columns.User].unique(),\n", - " dataset=dataset,\n", - " k=K_RECOS,\n", - " filter_viewed=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "d74a0ee6", - "metadata": {}, - "outputs": [], - "source": [ - "recos_offline_tfidf.to_csv(\"../data/hw_3/tfidf_k60_rectools.csv\", index=False)" - ] - }, - { - "cell_type": "markdown", - "id": "0164df93", - "metadata": {}, - "source": [ - "# Формирование рекомендаций для cold users" - ] - }, - { - "cell_type": "markdown", - "id": "5af7d214", - "metadata": {}, - "source": [ - "По моделям на основе популярного наилучшего качества достигали метрики по модели popular на основе количества уникальных пользователей взаимодействовавших с элементом, НО по среднему весу взаимодействия с элементами модель показывает по метрики новелти очень высокие результаты, поэтому стоит попробовать обе из моделей" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "51fdeae3", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = Dataset.construct(\n", - " interactions_df=interactions,\n", - " user_features_df=None,\n", - " item_features_df=None\n", - ")\n", - "\n", - "popular_model = PopularModel()\n", - "popular_model.fit(dataset)\n", - "\n", - "item_inv = dict(enumerate(interactions[\"item_id\"].unique()))\n", - "recos_pop = []\n", - "for item_pop in popular_model.popularity_list[0]:\n", - " recos_pop.append(item_inv[item_pop])\n", - "\n", - "df_pop_recos = pd.DataFrame({\"item_id\": recos_pop})\n", - "\n", - "df_pop_recos.to_csv(\"../data/hw_3/popular_item.csv\", index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "52981cce", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = Dataset.construct(\n", - " interactions_df=interactions,\n", - " user_features_df=None,\n", - " item_features_df=None\n", - ")\n", - "\n", - "popular_model_mw = PopularModel(popularity=\"mean_weight\")\n", - "popular_model_mw.fit(dataset)\n", - "\n", - 
"item_inv = dict(enumerate(interactions[\"item_id\"].unique()))\n", - "recos_pop = []\n", - "for item_pop in popular_model_mw.popularity_list[0]:\n", - " recos_pop.append(item_inv[item_pop])\n", - "\n", - "df_pop_recos_mw = pd.DataFrame({\"item_id\": recos_pop})\n", - "\n", - "df_pop_recos_mw.to_csv(\"../data/hw_3/popular_mean_weight_item.csv\", index=False)" - ] - }, - { - "cell_type": "markdown", - "id": "170efd3c", - "metadata": {}, - "source": [ - "# Блендинг результатов моделей" - ] - }, - { - "cell_type": "markdown", - "id": "878f0b90", - "metadata": {}, - "source": [ - "Механизм блендинга будет выглядить следующим образом:\n", - "\n", - "1. Берутся рекомендации, сделанные моделями tfidf и bmp25, конкатятся результаты, удялются дубликаты item-ов\n", - "2. Берется заготовленный датаест items c полями item_id и idf\n", - "3. смотрится idf, чем он выше, тем выше будет стоять item в выдаче\n", - "\n", - "Такой подход обусловлен тем, что idf показывает обратную частоту item, соответственно в выдаче наверх будут попадать item, с которым меньшее количество раз взаимодейстовали пользователи, т.е. в перспективе такой подход может предлагать item, с которыми ни один пользователь не взаимодействовал или взаимодействовали очень мало, т.е. может решиться проблема длинного хвоста." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "3b35f8ff", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "163f79b9", - "metadata": {}, - "outputs": [], - "source": [ - "df_bmp_recs = pd.read_csv(\"../data/hw_3/bmp_25_k60_rectools.csv\")\n", - "df_tfidf_recs = pd.read_csv(\"../data/hw_3/tfidf_k60_rectools.csv\") " - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "c842edef", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
0176549138658.899597e+101
1176549104408.153085e+102
2176549152977.204604e+103
317654937346.953473e+104
417654941514.674591e+105
2886225869726254341.615419e+1026
2886225969726211321.605160e+1027
2886226069726274761.566697e+1028
28862261697262112371.546907e+1029
28862262697262129951.542308e+1030
\n", - "
" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 176549 13865 8.899597e+10 1\n", - "1 176549 10440 8.153085e+10 2\n", - "2 176549 15297 7.204604e+10 3\n", - "3 176549 3734 6.953473e+10 4\n", - "4 176549 4151 4.674591e+10 5\n", - "28862258 697262 5434 1.615419e+10 26\n", - "28862259 697262 1132 1.605160e+10 27\n", - "28862260 697262 7476 1.566697e+10 28\n", - "28862261 697262 11237 1.546907e+10 29\n", - "28862262 697262 12995 1.542308e+10 30" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.concat([df_bmp_recs.head(), df_bmp_recs.tail()])" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "576f23a7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
01765491174913575.6601851
11765491627011946.8727082
21765491198511355.6931193
31765491315910375.5006474
41765491526610269.0196905
2886225869726261921294.34241426
28862259697262116401277.33233327
2886226069726274761262.91937728
28862261697262141213.49928129
2886226269726237841200.34778530
\n", - "
" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 176549 11749 13575.660185 1\n", - "1 176549 16270 11946.872708 2\n", - "2 176549 11985 11355.693119 3\n", - "3 176549 13159 10375.500647 4\n", - "4 176549 15266 10269.019690 5\n", - "28862258 697262 6192 1294.342414 26\n", - "28862259 697262 11640 1277.332333 27\n", - "28862260 697262 7476 1262.919377 28\n", - "28862261 697262 14 1213.499281 29\n", - "28862262 697262 3784 1200.347785 30" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.concat([df_tfidf_recs.head(), df_tfidf_recs.tail()])" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "dc5bfdf3", - "metadata": {}, - "outputs": [], - "source": [ - "del df_tfidf_recs['rank'], df_bmp_recs['rank'], df_tfidf_recs['score'], df_bmp_recs['score']" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "edcc93bd", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_id
010975571132
110975575658
21097557142
310975573734
4109755716228
5109755712192
6109755713865
710975572657
810975579728
910975574880
10109755711778
1110975579996
1210975578636
1310975573935
1410975575803
1510975574457
1610975571844
1710975576382
1810975574716
1910975574495
4102337203734
4102337303935
4102337407417
4102337504495
4102337606382
4102337705803
4102337801844
41023379011778
4102338008636
4102338109996
4102338202657
41023383016228
4102338404880
41023385013865
410233860142
4102338706443
4102338804740
4102338906809
41023390010440
41023391014901
\n", - "
" - ], - "text/plain": [ - " user_id item_id\n", - "0 1097557 1132\n", - "1 1097557 5658\n", - "2 1097557 142\n", - "3 1097557 3734\n", - "4 1097557 16228\n", - "5 1097557 12192\n", - "6 1097557 13865\n", - "7 1097557 2657\n", - "8 1097557 9728\n", - "9 1097557 4880\n", - "10 1097557 11778\n", - "11 1097557 9996\n", - "12 1097557 8636\n", - "13 1097557 3935\n", - "14 1097557 5803\n", - "15 1097557 4457\n", - "16 1097557 1844\n", - "17 1097557 6382\n", - "18 1097557 4716\n", - "19 1097557 4495\n", - "41023372 0 3734\n", - "41023373 0 3935\n", - "41023374 0 7417\n", - "41023375 0 4495\n", - "41023376 0 6382\n", - "41023377 0 5803\n", - "41023378 0 1844\n", - "41023379 0 11778\n", - "41023380 0 8636\n", - "41023381 0 9996\n", - "41023382 0 2657\n", - "41023383 0 16228\n", - "41023384 0 4880\n", - "41023385 0 13865\n", - "41023386 0 142\n", - "41023387 0 6443\n", - "41023388 0 4740\n", - "41023389 0 6809\n", - "41023390 0 10440\n", - "41023391 0 14901" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_all_recs = pd.concat(\n", - " [\n", - " df_bmp_recs, df_tfidf_recs\n", - " ],\n", - " ignore_index=True\n", - ").sort_values(\n", - " [\"user_id\"], ascending=False\n", - ").drop_duplicates(\n", - " [\"user_id\", \"item_id\"]\n", - ").reset_index(drop=True)\n", - "\n", - "pd.concat([df_all_recs.head(20), df_all_recs.tail(20)])" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "1267df73", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(15706, 2)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
indexidf
095067.150811
116598.524953
271075.821207
376388.407093
466867.778734
\n", - "
" - ], - "text/plain": [ - " index idf\n", - "0 9506 7.150811\n", - "1 1659 8.524953\n", - "2 7107 5.821207\n", - "3 7638 8.407093\n", - "4 6686 7.778734" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "item_idf = pd.read_csv(\"../data/kion_train/items_idf.csv\")\n", - "print(item_idf.shape)\n", - "item_idf.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "68c2c0c0", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idindexidf
141097557580358036.840585
171097557638263826.806090
251097557747674766.545666
181097557471647166.480408
34109755714146.467549
24109755711640116406.318255
211097557543454346.226266
01097557113211326.183141
10109755711778117786.134312
131097557393539356.067242
231097557619261926.038563
11097557565856586.025091
261097557378437845.990008
4109755716228162285.756312
311097557741774175.715013
321097557782978295.615193
191097557449544955.563930
22109755714431144315.558556
151097557445744575.548639
28109755712995129955.495888
410233740741774175.715013
410233610782978295.615193
410233750449544955.563930
41023358014431144315.558556
410233640445744575.548639
41023363012995129955.495888
410233780184418445.419019
41023366011237112375.365593
410233600757175715.267906
410233880474047405.078522
410233800863686365.041418
410233810999699964.992277
410233890680968094.917360
4102338601421424.801620
410233840488048804.610045
410233820265726574.392592
410233720373437344.306872
410233650415141514.111983
41023385013865138653.825227
41023390010440104403.333947
\n", - "
" - ], - "text/plain": [ - " user_id item_id index idf\n", - "14 1097557 5803 5803 6.840585\n", - "17 1097557 6382 6382 6.806090\n", - "25 1097557 7476 7476 6.545666\n", - "18 1097557 4716 4716 6.480408\n", - "34 1097557 14 14 6.467549\n", - "24 1097557 11640 11640 6.318255\n", - "21 1097557 5434 5434 6.226266\n", - "0 1097557 1132 1132 6.183141\n", - "10 1097557 11778 11778 6.134312\n", - "13 1097557 3935 3935 6.067242\n", - "23 1097557 6192 6192 6.038563\n", - "1 1097557 5658 5658 6.025091\n", - "26 1097557 3784 3784 5.990008\n", - "4 1097557 16228 16228 5.756312\n", - "31 1097557 7417 7417 5.715013\n", - "32 1097557 7829 7829 5.615193\n", - "19 1097557 4495 4495 5.563930\n", - "22 1097557 14431 14431 5.558556\n", - "15 1097557 4457 4457 5.548639\n", - "28 1097557 12995 12995 5.495888\n", - "41023374 0 7417 7417 5.715013\n", - "41023361 0 7829 7829 5.615193\n", - "41023375 0 4495 4495 5.563930\n", - "41023358 0 14431 14431 5.558556\n", - "41023364 0 4457 4457 5.548639\n", - "41023363 0 12995 12995 5.495888\n", - "41023378 0 1844 1844 5.419019\n", - "41023366 0 11237 11237 5.365593\n", - "41023360 0 7571 7571 5.267906\n", - "41023388 0 4740 4740 5.078522\n", - "41023380 0 8636 8636 5.041418\n", - "41023381 0 9996 9996 4.992277\n", - "41023389 0 6809 6809 4.917360\n", - "41023386 0 142 142 4.801620\n", - "41023384 0 4880 4880 4.610045\n", - "41023382 0 2657 2657 4.392592\n", - "41023372 0 3734 3734 4.306872\n", - "41023365 0 4151 4151 4.111983\n", - "41023385 0 13865 13865 3.825227\n", - "41023390 0 10440 10440 3.333947" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_all_recs = df_all_recs.merge(\n", - " item_idf, left_on='item_id', right_on='index', how='left'\n", - ").sort_values(['user_id', 'idf'], ascending=False)\n", - "\n", - "pd.concat([df_all_recs.head(20), df_all_recs.tail(20)])" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "36ecc1dd", - "metadata": {}, - "outputs": 
[], - "source": [ - "del df_all_recs['index'], df_all_recs['idf']" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "2c868313", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Количество пользователей, у которорых рекомендаций меньше 10: 21\n" - ] - } - ], - "source": [ - "count_recs_by_users = df_all_recs.user_id.value_counts()\n", - "print(f\"Количество пользователей, у которорых рекомендаций меньше 10: {len(count_recs_by_users[count_recs_by_users < 10])}\")" - ] - }, - { - "cell_type": "markdown", - "id": "9820e5ab", - "metadata": {}, - "source": [ - "Для пользователей, у которых будет меньше рекомендаций, чем k_recs, рекомендации **будут пополняться популярным**" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "69d1fc9c", - "metadata": {}, - "outputs": [], - "source": [ - "df_popular = pd.read_csv('../data/hw_3/popular_item.csv')\n", - "users_need = count_recs_by_users[count_recs_by_users < 10].index" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "7242e8e3", - "metadata": {}, - "outputs": [], - "source": [ - "k_recs = 10\n", - "users, recs = [], []\n", - "for user, count in dict(count_recs_by_users[count_recs_by_users < 10]).items():\n", - " need_recs = k_recs - count\n", - " users.extend([user for _ in range(need_recs)])\n", - " recs.extend(df_popular[\"item_id\"][:need_recs].to_list())" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "29eaaadc", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_id
010975575803
110975576382
210975577476
310975574716
4109755714
.........
4102338702657
4102338803734
4102338904151
41023390013865
41023391010440
\n", - "

41023392 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " user_id item_id\n", - "0 1097557 5803\n", - "1 1097557 6382\n", - "2 1097557 7476\n", - "3 1097557 4716\n", - "4 1097557 14\n", - "... ... ...\n", - "41023387 0 2657\n", - "41023388 0 3734\n", - "41023389 0 4151\n", - "41023390 0 13865\n", - "41023391 0 10440\n", - "\n", - "[41023392 rows x 2 columns]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_all_recs" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "30df7064", - "metadata": {}, - "outputs": [], - "source": [ - "df_need = pd.DataFrame({\"user_id\": users, \"item_id\": recs})\n", - "df_all_recs = pd.concat([df_all_recs, df_need], ignore_index=True).sort_values(\"user_id\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "34f2a303", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Количество пользователей, у которорых рекомендаций меньше 10: 0\n" - ] - } - ], - "source": [ - "count_recs_by_users = df_all_recs.user_id.value_counts()\n", - "print(f\"Количество пользователей, у которорых рекомендаций меньше 10: {len(count_recs_by_users[count_recs_by_users < 10])}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "f9d7c2bc", - "metadata": {}, - "outputs": [], - "source": [ - "df_all_recs.to_csv(\"../data/hw_3/blending_tfidf_bmp25_idf_rectools.csv\", index=False)" - ] - }, - { - "cell_type": "markdown", - "id": "b8e2c037", - "metadata": {}, - "source": [ - "Offline рекомендации не работали с блендингом, решил уменьшить количество рекомендаций для одного юзера до 10 и заработало" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "6e76be8c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(9621050, 2)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_all_recs['rank'] = 
df_all_recs.groupby('user_id').cumcount() + 1\n", - "df_all_recs_top10 = df_all_recs[df_all_recs['rank'] <= 10]\n", - "del df_all_recs_top10['rank']\n", - "df_all_recs_top10.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "94d693eb", - "metadata": {}, - "outputs": [], - "source": [ - "df_all_recs_top10.to_csv(\"../data/hw_3/blending_tfidf_bmp25_idf_rectools_10.csv\", index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "38f155ab", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/HW-3.4-model-for-online-recs.ipynb b/notebooks/HW-3.4-model-for-online-recs.ipynb deleted file mode 100644 index fafafb1f..00000000 --- a/notebooks/HW-3.4-model-for-online-recs.ipynb +++ /dev/null @@ -1,1161 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "faa0d200", - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "54683b63", - "metadata": {}, - "outputs": [], - "source": [ - "import typing as tp\n", - "\n", - "import dill\n", - "import pandas as pd\n", - "import numpy as np\n", - "from implicit.nearest_neighbours import BM25Recommender, TFIDFRecommender\n", - "from rectools import Columns\n", - "import scipy as sp" - ] - }, - { - "cell_type": "markdown", - "id": "e00f73f1", - "metadata": {}, - "source": [ - "В ноутбуку \"HW-3.3\" c помощью стратегии валидации по неделям были отобраны несколько моделей с наиболее 
высокими метриками:\n", - "\n", - "- BMP25Recommender с гиперпараметром k = 60\n", - "- TFIDFRecommender с гиперпараметром k = 60\n", - "\n", - "Для этих моделей сформированы оффлайн рекомендации, которые показали 0.10384918 и 0.09577425 соответственно.\n", - "\n", - "Для формирования онлайн рекомендаций следует обучить те же архитектуры моделей с такими же гиперпараметрами из библиотеки implicit" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "4e7ecfc4", - "metadata": {}, - "outputs": [], - "source": [ - "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n", - "\n", - "interactions.rename(columns={\n", - " 'last_watch_dt': Columns.Datetime,\n", - " 'total_dur': Columns.Weight\n", - " }, \n", - " inplace=True\n", - ") \n", - "\n", - "interactions['datetime'] = pd.to_datetime(interactions['datetime'])" - ] - }, - { - "cell_type": "markdown", - "id": "55fbfe8e", - "metadata": {}, - "source": [ - "# Create train data" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "57f1394b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Unique users: 962179\n", - "Unique items: 15706\n" - ] - } - ], - "source": [ - "# формирование id для user и item\n", - "users_inv_mapping = dict(enumerate(interactions['user_id'].unique()))\n", - "users_mapping = {v: k for k, v in users_inv_mapping.items()}\n", - "items_inv_mapping = dict(enumerate(interactions['item_id'].unique()))\n", - "items_mapping = {v: k for k, v in items_inv_mapping.items()}\n", - "print(f\"Unique users: {len(users_inv_mapping)}\")\n", - "print(f\"Unique items: {len(items_inv_mapping)}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "bc704533", - "metadata": {}, - "outputs": [], - "source": [ - "def get_matrix(\n", - " df: pd.DataFrame,\n", - " user_col: str = Columns.User,\n", - " item_col: str = Columns.Item,\n", - " weight_col: str = None,\n", - " users_mapping: tp.Dict[int, int] = 
None,\n", - " items_mapping: tp.Dict[int, int] = None\n", - "):\n", - "\n", - " if weight_col:\n", - " weights = df[weight_col].astype(np.float32)\n", - " else:\n", - " weights = np.ones(len(df), dtype=np.float32)\n", - "\n", - " interaction_matrix = sp.sparse.coo_matrix((\n", - " weights,\n", - " (\n", - " df[user_col].map(users_mapping.get),\n", - " df[item_col].map(items_mapping.get)\n", - " )\n", - " ))\n", - "\n", - " watched = df.groupby(user_col).agg({item_col: list})\n", - " return interaction_matrix" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "00f46a78", - "metadata": {}, - "outputs": [], - "source": [ - "weight_matrix = get_matrix(\n", - " df=interactions,\n", - " users_mapping=users_mapping,\n", - " items_mapping=items_mapping\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "1cc272f9", - "metadata": {}, - "source": [ - "# Models train" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "2d8ec8a1", - "metadata": {}, - "outputs": [], - "source": [ - "model_implicit_tfidf = TFIDFRecommender(K=60)\n", - "model_implicit_bmp25 = BM25Recommender(K=60)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "17990623", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0ce1d4c0e2184dfa8d159906a145f011", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/962179 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
item_id
user_id
0[7102, 14359, 15297, 6006, 9728, 12192]
1[3669, 10440]
2[7571, 3541, 15266, 13867, 12841, 10770, 4475,...
3[12192, 9728, 16406, 15719, 10440, 3475, 2025,...
4[4700, 6317]
1097553[24, 13058, 12463, 12659]
1097554[16361, 496, 1053, 11275, 4580, 1151, 849, 350...
1097555[14703, 140, 9728, 496, 6916, 4662, 4880]
1097556[12812]
1097557[4151, 3182, 15297]
\n", - "" - ], - "text/plain": [ - " item_id\n", - "user_id \n", - "0 [7102, 14359, 15297, 6006, 9728, 12192]\n", - "1 [3669, 10440]\n", - "2 [7571, 3541, 15266, 13867, 12841, 10770, 4475,...\n", - "3 [12192, 9728, 16406, 15719, 10440, 3475, 2025,...\n", - "4 [4700, 6317]\n", - "1097553 [24, 13058, 12463, 12659]\n", - "1097554 [16361, 496, 1053, 11275, 4580, 1151, 849, 350...\n", - "1097555 [14703, 140, 9728, 496, 6916, 4662, 4880]\n", - "1097556 [12812]\n", - "1097557 [4151, 3182, 15297]" - ] - }, - "execution_count": 66, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "watched = interactions.groupby('user_id').agg({'item_id': list})\n", - "pd.concat([watched.head(), watched.tail()])" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "1460393c", - "metadata": {}, - "outputs": [], - "source": [ - "def recs_mapper(user, model, user_mapping, user_inv_mapping, k_reco: int = 10, bmp: bool = False):\n", - " user_id = user_mapping[user]\n", - " recs = model.similar_items(user_id, N=k_reco)\n", - " result = pd.DataFrame(\n", - " {\n", - " \"sim_user_id\": [user_inv_mapping[user] for user, _ in recs], \n", - " \"sim\": [sim for _, sim in recs] def\n", - " }\n", - " )\n", - " \n", - " if bmp:\n", - " return result[result['sim_user_id'] != user]\n", - " else: \n", - " return result[~(result['sim'] >= 1)] " - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "id": "011fe4fb", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "sample_users = interactions[Columns.User].sample(100).tolist()" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "3ff09747", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12861\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sim_user_idsim
17372390.479295
210695580.427898
39333710.419511
44098500.391727
59892530.384045
68176360.380609
710784200.372851
81635950.370077
910037830.368852
\n", - "
" - ], - "text/plain": [ - " sim_user_id sim\n", - "1 737239 0.479295\n", - "2 1069558 0.427898\n", - "3 933371 0.419511\n", - "4 409850 0.391727\n", - "5 989253 0.384045\n", - "6 817636 0.380609\n", - "7 1078420 0.372851\n", - "8 163595 0.370077\n", - "9 1003783 0.368852" - ] - }, - "execution_count": 99, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(sample_users[0])\n", - "df_sim = recs_mapper(sample_users[0], model_implicit_tfidf, users_mapping, users_inv_mapping)\n", - "df_sim" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "757a24ec", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sim_user_idsimitem_id
07372390.47929510755
07372390.479295496
07372390.47929512324
07372390.47929510219
07372390.4792956898
07372390.47929514476
07372390.47929513411
07372390.4792959194
07372390.4792956404
07372390.47929514961
07372390.47929512995
110695580.4278985287
110695580.42789813973
110695580.42789813865
34098500.3917277793
49892530.3840456033
49892530.384045799
49892530.3840459617
49892530.3840455405
49892530.38404513849
49892530.38404512846
58176360.3806092981
610784200.3728513935
610784200.37285110283
71635950.3700779728
\n", - "
" - ], - "text/plain": [ - " sim_user_id sim item_id\n", - "0 737239 0.479295 10755\n", - "0 737239 0.479295 496\n", - "0 737239 0.479295 12324\n", - "0 737239 0.479295 10219\n", - "0 737239 0.479295 6898\n", - "0 737239 0.479295 14476\n", - "0 737239 0.479295 13411\n", - "0 737239 0.479295 9194\n", - "0 737239 0.479295 6404\n", - "0 737239 0.479295 14961\n", - "0 737239 0.479295 12995\n", - "1 1069558 0.427898 5287\n", - "1 1069558 0.427898 13973\n", - "1 1069558 0.427898 13865\n", - "3 409850 0.391727 7793\n", - "4 989253 0.384045 6033\n", - "4 989253 0.384045 799\n", - "4 989253 0.384045 9617\n", - "4 989253 0.384045 5405\n", - "4 989253 0.384045 13849\n", - "4 989253 0.384045 12846\n", - "5 817636 0.380609 2981\n", - "6 1078420 0.372851 3935\n", - "6 1078420 0.372851 10283\n", - "7 163595 0.370077 9728" - ] - }, - "execution_count": 100, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_sim = df_sim.merge(\n", - " watched, left_on=['sim_user_id'], right_on=['user_id'], how='left'\n", - ").explode('item_id').sort_values(\n", - " [ 'sim'], ascending=False\n", - ").drop_duplicates(\n", - " ['item_id'], keep='first'\n", - ")\n", - "df_sim" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "id": "87a9e994", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12861\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sim_user_idsim
173723985.368217
2100621675.500227
393337171.779085
498925371.633147
5106955871.248461
612473569.342326
7107842067.839079
828985467.379224
940985066.214999
\n", - "
" - ], - "text/plain": [ - " sim_user_id sim\n", - "1 737239 85.368217\n", - "2 1006216 75.500227\n", - "3 933371 71.779085\n", - "4 989253 71.633147\n", - "5 1069558 71.248461\n", - "6 124735 69.342326\n", - "7 1078420 67.839079\n", - "8 289854 67.379224\n", - "9 409850 66.214999" - ] - }, - "execution_count": 101, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(sample_users[0])\n", - "df_sim = recs_mapper(sample_users[0], model_implicit_bmp25, users_mapping, users_inv_mapping, bmp=True)\n", - "df_sim" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "id": "2557f73f", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sim_user_idsimitem_id
073723985.36821710755
073723985.36821714961
073723985.3682176404
073723985.3682179194
073723985.36821713411
073723985.368217496
073723985.3682176898
073723985.36821710219
073723985.36821714476
073723985.36821712324
073723985.36821712995
1100621675.50022713325
1100621675.5002278891
1100621675.5002275287
398925371.63314713865
398925371.633147799
398925371.6331476033
398925371.6331479617
398925371.63314713849
398925371.63314712846
398925371.6331475405
4106955871.24846113973
512473569.3423269288
512473569.3423262100
512473569.34232614242
512473569.3423264702
6107842067.8390793935
6107842067.83907910283
6107842067.8390792981
728985467.37922416021
728985467.3792244116
728985467.37922415464
840985066.2149997793
\n", - "
" - ], - "text/plain": [ - " sim_user_id sim item_id\n", - "0 737239 85.368217 10755\n", - "0 737239 85.368217 14961\n", - "0 737239 85.368217 6404\n", - "0 737239 85.368217 9194\n", - "0 737239 85.368217 13411\n", - "0 737239 85.368217 496\n", - "0 737239 85.368217 6898\n", - "0 737239 85.368217 10219\n", - "0 737239 85.368217 14476\n", - "0 737239 85.368217 12324\n", - "0 737239 85.368217 12995\n", - "1 1006216 75.500227 13325\n", - "1 1006216 75.500227 8891\n", - "1 1006216 75.500227 5287\n", - "3 989253 71.633147 13865\n", - "3 989253 71.633147 799\n", - "3 989253 71.633147 6033\n", - "3 989253 71.633147 9617\n", - "3 989253 71.633147 13849\n", - "3 989253 71.633147 12846\n", - "3 989253 71.633147 5405\n", - "4 1069558 71.248461 13973\n", - "5 124735 69.342326 9288\n", - "5 124735 69.342326 2100\n", - "5 124735 69.342326 14242\n", - "5 124735 69.342326 4702\n", - "6 1078420 67.839079 3935\n", - "6 1078420 67.839079 10283\n", - "6 1078420 67.839079 2981\n", - "7 289854 67.379224 16021\n", - "7 289854 67.379224 4116\n", - "7 289854 67.379224 15464\n", - "8 409850 66.214999 7793" - ] - }, - "execution_count": 102, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_sim = df_sim.merge(\n", - " watched, left_on=['sim_user_id'], right_on=['user_id'], how='left'\n", - ").explode('item_id').sort_values(\n", - " [ 'sim'], ascending=False\n", - ").drop_duplicates(\n", - " ['item_id'], keep='first'\n", - ")\n", - "df_sim" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7e60a9c9", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - 
"nbformat": 4, - "nbformat_minor": 5 -} From 5b7733e5f86c1737ba7f4ed7afe9b016181939e3 Mon Sep 17 00:00:00 2001 From: anettapik <120940816+anettapik@users.noreply.github.com> Date: Wed, 13 Dec 2023 13:00:13 +0300 Subject: [PATCH 6/7] Add files via upload MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hw_5_dssm.ipynb - ноутбук с обучением DSSM модели, добавлены текстовые фичи, доработан генератор с учетом просмотренного контента hw_5_recbole.ipynb - обучено несколько моделей разных архитектур с использрванием библиотеки recbole hw_5_autoencoder.ipynb - обучен автоэнкодер, подобрана архитектура и параметры --- hw_5_autoencoder.ipynb | 1922 +++++++++++++++++++ hw_5_dssm.ipynb | 4034 ++++++++++++++++++++++++++++++++++++++++ hw_5_recbool.ipynb | 1 + 3 files changed, 5957 insertions(+) create mode 100644 hw_5_autoencoder.ipynb create mode 100644 hw_5_dssm.ipynb create mode 100644 hw_5_recbool.ipynb diff --git a/hw_5_autoencoder.ipynb b/hw_5_autoencoder.ipynb new file mode 100644 index 00000000..750a793d --- /dev/null +++ b/hw_5_autoencoder.ipynb @@ -0,0 +1,1922 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "7_8DlX_2jZzT", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:32:12.524590Z", + "iopub.status.busy": "2023-01-22T12:32:12.523513Z", + "iopub.status.idle": "2023-01-22T12:32:12.529931Z", + "shell.execute_reply": "2023-01-22T12:32:12.528298Z", + "shell.execute_reply.started": "2023-01-22T12:32:12.524533Z" + }, + "id": "7_8DlX_2jZzT" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import os\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "IczRXBXHjZzV", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:32:13.867299Z", + "iopub.status.busy": "2023-01-22T12:32:13.866000Z", + "iopub.status.idle": "2023-01-22T12:32:16.353124Z", + "shell.execute_reply": "2023-01-22T12:32:16.352004Z", + 
"shell.execute_reply.started": "2023-01-22T12:32:13.867251Z" + }, + "id": "IczRXBXHjZzV" + }, + "outputs": [], + "source": [ + "from IPython.display import display, clear_output\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm.notebook import tqdm\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import Dataset, DataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "mA1MfXOnjZzW", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:41:13.399626Z", + "iopub.status.busy": "2023-01-22T12:41:13.398452Z", + "iopub.status.idle": "2023-01-22T12:41:19.723408Z", + "shell.execute_reply": "2023-01-22T12:41:19.722114Z", + "shell.execute_reply.started": "2023-01-22T12:41:13.399496Z" + }, + "id": "mA1MfXOnjZzW" + }, + "outputs": [], + "source": [ + "interactions_df = pd.read_csv('interactions_processed_kion.csv')\n", + "users_df = pd.read_csv('users_processed_kion.csv')\n", + "items_df = pd.read_csv('items_processed_kion.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "G5cP9QcUjZzW", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "execution": { + "iopub.execute_input": "2023-01-22T12:41:19.726341Z", + "iopub.status.busy": "2023-01-22T12:41:19.725645Z", + "iopub.status.idle": "2023-01-22T12:41:19.751544Z", + "shell.execute_reply": "2023-01-22T12:41:19.750286Z", + "shell.execute_reply.started": "2023-01-22T12:41:19.726296Z" + }, + "id": "G5cP9QcUjZzW", + "outputId": "9fc311f0-6f5b-4327-9bbc-1f6bdee3f918" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idlast_watch_dttotal_durwatched_pct
017654995062021-05-11425072
169931716592021-05-298317100
265668371072021-05-09100
386461376382021-07-0514483100
496486895062021-04-306725100
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " user_id item_id last_watch_dt total_dur watched_pct\n", + "0 176549 9506 2021-05-11 4250 72\n", + "1 699317 1659 2021-05-29 8317 100\n", + "2 656683 7107 2021-05-09 10 0\n", + "3 864613 7638 2021-07-05 14483 100\n", + "4 964868 9506 2021-04-30 6725 100" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b4omWvMOjZzX", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:41:21.721270Z", + "iopub.status.busy": "2023-01-22T12:41:21.720745Z", + "iopub.status.idle": "2023-01-22T12:41:22.116852Z", + "shell.execute_reply": "2023-01-22T12:41:22.115397Z", + "shell.execute_reply.started": "2023-01-22T12:41:21.721229Z" + }, + "id": "b4omWvMOjZzX" + }, + "outputs": [], + "source": [ + "interactions_df = interactions_df[interactions_df['last_watch_dt'] < '2021-04-01']" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "JAuH-fG0jZzX", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2023-01-22T12:41:23.195240Z", + "iopub.status.busy": "2023-01-22T12:41:23.194661Z", + "iopub.status.idle": "2023-01-22T12:41:23.202760Z", + "shell.execute_reply": "2023-01-22T12:41:23.201745Z", + "shell.execute_reply.started": "2023-01-22T12:41:23.195188Z" + }, + "id": "JAuH-fG0jZzX", + "outputId": "3e259108-40e9-48fc-c212-1bbd7e9e21a1" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(263874, 5)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions_df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "rWCoSNwWjZzX", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2023-01-22T12:41:25.368988Z", + "iopub.status.busy": 
"2023-01-22T12:41:25.367925Z", + "iopub.status.idle": "2023-01-22T12:41:25.558751Z", + "shell.execute_reply": "2023-01-22T12:41:25.557372Z", + "shell.execute_reply.started": "2023-01-22T12:41:25.368937Z" + }, + "id": "rWCoSNwWjZzX", + "outputId": "c7738105-161d-4c43-9028-21b64809ad04" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# users: 86614\n", + "# users with at least 5 interactions: 14563\n" + ] + } + ], + "source": [ + "users_interactions_count_df = interactions_df.groupby(['user_id', 'item_id']).size().groupby('user_id').size()\n", + "print('# users: %d' % len(users_interactions_count_df))\n", + "users_with_enough_interactions_df = users_interactions_count_df[users_interactions_count_df >= 5].reset_index()[['user_id']]\n", + "print('# users with at least 5 interactions: %d' % len(users_with_enough_interactions_df))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "qDCcr1_UjZzY", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2023-01-22T12:41:27.227318Z", + "iopub.status.busy": "2023-01-22T12:41:27.226717Z", + "iopub.status.idle": "2023-01-22T12:41:27.326827Z", + "shell.execute_reply": "2023-01-22T12:41:27.325761Z", + "shell.execute_reply.started": "2023-01-22T12:41:27.227269Z" + }, + "id": "qDCcr1_UjZzY", + "outputId": "cc44175d-eef0-42b9-839a-4c1a0efa5ce4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# of interactions: 263874\n", + "# of interactions from users with at least 5 interactions: 142670\n" + ] + } + ], + "source": [ + "print('# of interactions: %d' % len(interactions_df))\n", + "interactions_from_selected_users_df = interactions_df.merge(users_with_enough_interactions_df, \n", + " how = 'right',\n", + " left_on = 'user_id',\n", + " right_on = 'user_id')\n", + "print('# of interactions from users with at least 5 interactions: %d' % len(interactions_from_selected_users_df))" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 9, + "id": "bs9IdB8fjZzY", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:41:30.431311Z", + "iopub.status.busy": "2023-01-22T12:41:30.430823Z", + "iopub.status.idle": "2023-01-22T12:41:30.436607Z", + "shell.execute_reply": "2023-01-22T12:41:30.435654Z", + "shell.execute_reply.started": "2023-01-22T12:41:30.431275Z" + }, + "id": "bs9IdB8fjZzY" + }, + "outputs": [], + "source": [ + "import math" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "MTW_Y4iOjZzY", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 376 + }, + "execution": { + "iopub.execute_input": "2023-01-22T12:41:32.237281Z", + "iopub.status.busy": "2023-01-22T12:41:32.236079Z", + "iopub.status.idle": "2023-01-22T12:41:32.403346Z", + "shell.execute_reply": "2023-01-22T12:41:32.401909Z", + "shell.execute_reply.started": "2023-01-22T12:41:32.237217Z" + }, + "id": "MTW_Y4iOjZzY", + "outputId": "8027a8a8-0bf9-4c97-9b69-5281653709c9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# of unique user/item interactions: 142670\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idwatched_pct
0218496.375039
12143456.658211
221102836.658211
321122616.658211
421159976.658211
5329526.044394
63243824.954196
73248076.658211
832104366.658211
932121326.658211
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " user_id item_id watched_pct\n", + "0 21 849 6.375039\n", + "1 21 4345 6.658211\n", + "2 21 10283 6.658211\n", + "3 21 12261 6.658211\n", + "4 21 15997 6.658211\n", + "5 32 952 6.044394\n", + "6 32 4382 4.954196\n", + "7 32 4807 6.658211\n", + "8 32 10436 6.658211\n", + "9 32 12132 6.658211" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def smooth_user_preference(x):\n", + " return math.log(1+x, 2)\n", + " \n", + "interactions_full_df = interactions_from_selected_users_df \\\n", + " .groupby(['user_id', 'item_id'])['watched_pct'].sum() \\\n", + " .apply(smooth_user_preference).reset_index()\n", + "print('# of unique user/item interactions: %d' % len(interactions_full_df))\n", + "interactions_full_df.head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "wNyqdsCxjZzZ", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2023-01-22T12:41:34.443808Z", + "iopub.status.busy": "2023-01-22T12:41:34.443346Z", + "iopub.status.idle": "2023-01-22T12:41:34.651267Z", + "shell.execute_reply": "2023-01-22T12:41:34.650080Z", + "shell.execute_reply.started": "2023-01-22T12:41:34.443774Z" + }, + "id": "wNyqdsCxjZzZ", + "outputId": "e2a2e169-78ef-4f8e-c099-eb56de49338e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# interactions on Train set: 114136\n", + "# interactions on Test set: 28534\n" + ] + } + ], + "source": [ + "interactions_train_df, interactions_test_df = train_test_split(interactions_full_df,\n", + " stratify=interactions_full_df['user_id'], \n", + " test_size=0.20,\n", + " random_state=42)\n", + "\n", + "print('# interactions on Train set: %d' % len(interactions_train_df))\n", + "print('# interactions on Test set: %d' % len(interactions_test_df))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "v1M9fBagjZzZ", + 
"metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:41:38.570246Z", + "iopub.status.busy": "2023-01-22T12:41:38.568905Z", + "iopub.status.idle": "2023-01-22T12:41:38.583223Z", + "shell.execute_reply": "2023-01-22T12:41:38.581705Z", + "shell.execute_reply.started": "2023-01-22T12:41:38.570182Z" + }, + "id": "v1M9fBagjZzZ" + }, + "outputs": [], + "source": [ + "#Indexing by personId to speed up the searches during evaluation\n", + "interactions_full_indexed_df = interactions_full_df.set_index('user_id')\n", + "interactions_train_indexed_df = interactions_train_df.set_index('user_id')\n", + "interactions_test_indexed_df = interactions_test_df.set_index('user_id')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "Ra2TntFUjZzZ", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:41:42.934656Z", + "iopub.status.busy": "2023-01-22T12:41:42.934139Z", + "iopub.status.idle": "2023-01-22T12:41:42.940917Z", + "shell.execute_reply": "2023-01-22T12:41:42.939611Z", + "shell.execute_reply.started": "2023-01-22T12:41:42.934617Z" + }, + "id": "Ra2TntFUjZzZ" + }, + "outputs": [], + "source": [ + "def get_items_interacted(person_id, interactions_df):\n", + " # Get the user's data and merge in the movie information.\n", + " interacted_items = interactions_df.loc[person_id]['item_id']\n", + " return set(interacted_items if type(interacted_items) == pd.Series else [interacted_items])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "xpP7YjhRjZzZ", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:41:53.435832Z", + "iopub.status.busy": "2023-01-22T12:41:53.435366Z", + "iopub.status.idle": "2023-01-22T12:41:53.455616Z", + "shell.execute_reply": "2023-01-22T12:41:53.454525Z", + "shell.execute_reply.started": "2023-01-22T12:41:53.435796Z" + }, + "id": "xpP7YjhRjZzZ" + }, + "outputs": [], + "source": [ + "#Top-N accuracy metrics consts\n", + "EVAL_RANDOM_SAMPLE_NON_INTERACTED_ITEMS = 
100\n", + "\n", + "class ModelEvaluator:\n", + "\n", + " def get_not_interacted_items_sample(self, person_id, sample_size, seed=42):\n", + " interacted_items = get_items_interacted(person_id, interactions_full_indexed_df)\n", + " all_items = set(articles_df['item_id'])\n", + " non_interacted_items = all_items - interacted_items\n", + "\n", + " random.seed(seed)\n", + " non_interacted_items_sample = random.sample(non_interacted_items, sample_size)\n", + " return set(non_interacted_items_sample)\n", + "\n", + " def _verify_hit_top_n(self, item_id, recommended_items, topn): \n", + " try:\n", + " index = next(i for i, c in enumerate(recommended_items) if c == item_id)\n", + " except:\n", + " index = -1\n", + " hit = int(index in range(0, topn))\n", + " return hit, index\n", + "\n", + " def evaluate_model_for_user(self, model, person_id):\n", + " #Getting the items in test set\n", + " interacted_values_testset = interactions_test_indexed_df.loc[person_id]\n", + " if type(interacted_values_testset['item_id']) == pd.Series:\n", + " person_interacted_items_testset = set(interacted_values_testset['item_id'])\n", + " else:\n", + " person_interacted_items_testset = set([int(interacted_values_testset['item_id'])]) \n", + " interacted_items_count_testset = len(person_interacted_items_testset) \n", + "\n", + " #Getting a ranked recommendation list from a model for a given user\n", + " person_recs_df = model.recommend_items(person_id, \n", + " items_to_ignore=get_items_interacted(person_id, \n", + " interactions_train_indexed_df), \n", + " topn=10000000000)\n", + "\n", + " hits_at_5_count = 0\n", + " hits_at_10_count = 0\n", + " #For each item the user has interacted in test set\n", + " for item_id in person_interacted_items_testset:\n", + " #Getting a random sample (100) items the user has not interacted \n", + " #(to represent items that are assumed to be no relevant to the user)\n", + " non_interacted_items_sample = self.get_not_interacted_items_sample(person_id, \n", + " 
sample_size=EVAL_RANDOM_SAMPLE_NON_INTERACTED_ITEMS, \n", + " seed=item_id%(2**32))\n", + "\n", + " #Combining the current interacted item with the 100 random items\n", + " items_to_filter_recs = non_interacted_items_sample.union(set([item_id]))\n", + "\n", + " #Filtering only recommendations that are either the interacted item or from a random sample of 100 non-interacted items\n", + " valid_recs_df = person_recs_df[person_recs_df['item_id'].isin(items_to_filter_recs)] \n", + " valid_recs = valid_recs_df['item_id'].values\n", + " #Verifying if the current interacted item is among the Top-N recommended items\n", + " hit_at_5, index_at_5 = self._verify_hit_top_n(item_id, valid_recs, 5)\n", + " hits_at_5_count += hit_at_5\n", + " hit_at_10, index_at_10 = self._verify_hit_top_n(item_id, valid_recs, 10)\n", + " hits_at_10_count += hit_at_10\n", + "\n", + " #Recall is the rate of the interacted items that are ranked among the Top-N recommended items, \n", + " #when mixed with a set of non-relevant items\n", + " recall_at_5 = hits_at_5_count / float(interacted_items_count_testset)\n", + " recall_at_10 = hits_at_10_count / float(interacted_items_count_testset)\n", + "\n", + " person_metrics = {'hits@5_count':hits_at_5_count, \n", + " 'hits@10_count':hits_at_10_count, \n", + " 'interacted_count': interacted_items_count_testset,\n", + " 'recall@5': recall_at_5,\n", + " 'recall@10': recall_at_10}\n", + " return person_metrics\n", + "\n", + " def evaluate_model(self, model):\n", + " #print('Running evaluation for users')\n", + " people_metrics = []\n", + " for idx, person_id in enumerate(tqdm(list(interactions_test_indexed_df.index.unique().values))):\n", + " #if idx % 100 == 0 and idx > 0:\n", + " # print('%d users processed' % idx)\n", + " person_metrics = self.evaluate_model_for_user(model, person_id) \n", + " person_metrics['user_id'] = person_id\n", + " people_metrics.append(person_metrics)\n", + " print('%d users processed' % idx)\n", + "\n", + " detailed_results_df = 
pd.DataFrame(people_metrics) \\\n", + " .sort_values('interacted_count', ascending=False)\n", + " \n", + " global_recall_at_5 = detailed_results_df['hits@5_count'].sum() / float(detailed_results_df['interacted_count'].sum())\n", + " global_recall_at_10 = detailed_results_df['hits@10_count'].sum() / float(detailed_results_df['interacted_count'].sum())\n", + " \n", + " global_metrics = {'modelName': model.get_model_name(),\n", + " 'recall@5': global_recall_at_5,\n", + " 'recall@10': global_recall_at_10} \n", + " return global_metrics, detailed_results_df\n", + " \n", + "model_evaluator = ModelEvaluator() " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "bt-Ko_HMjZza", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:41:57.779034Z", + "iopub.status.busy": "2023-01-22T12:41:57.777417Z", + "iopub.status.idle": "2023-01-22T12:41:57.787389Z", + "shell.execute_reply": "2023-01-22T12:41:57.785909Z", + "shell.execute_reply.started": "2023-01-22T12:41:57.778960Z" + }, + "id": "bt-Ko_HMjZza" + }, + "outputs": [], + "source": [ + "from IPython.display import display, clear_output\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm.notebook import tqdm\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import Dataset, DataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "6ySqiCo5jZza", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:42:03.271305Z", + "iopub.status.busy": "2023-01-22T12:42:03.270810Z", + "iopub.status.idle": "2023-01-22T12:42:03.278535Z", + "shell.execute_reply": "2023-01-22T12:42:03.277141Z", + "shell.execute_reply.started": "2023-01-22T12:42:03.271268Z" + }, + "id": "6ySqiCo5jZza" + }, + "outputs": [], + "source": [ + "\n", + "# Constants\n", + 
"SEED = 42 # random seed for reproducibility\n", + "LR = 1e-3 # learning rate, controls the speed of the training\n", + "WEIGHT_DECAY = 0.01 # lambda for L2 reg. ()\n", + "NUM_EPOCHS = 200 # num training epochs (how many times each instance will be processed)\n", + "GAMMA = 0.9995 # learning rate scheduler parameter\n", + "BATCH_SIZE = 3000 # training batch size\n", + "EVAL_BATCH_SIZE = 3000 # evaluation batch size.\n", + "DEVICE = 'cuda' #'cuda' # device to make the calculations on" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "FtzzvibljZza", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:42:05.933060Z", + "iopub.status.busy": "2023-01-22T12:42:05.931911Z", + "iopub.status.idle": "2023-01-22T12:42:05.969002Z", + "shell.execute_reply": "2023-01-22T12:42:05.967458Z", + "shell.execute_reply.started": "2023-01-22T12:42:05.933000Z" + }, + "id": "FtzzvibljZza" + }, + "outputs": [], + "source": [ + "total_df = interactions_train_df.append(interactions_test_indexed_df.reset_index())\n", + "total_df['user_id'], users_keys = total_df.user_id.factorize()\n", + "total_df['item_id'], items_keys = total_df.item_id.factorize()\n", + "\n", + "train_encoded = total_df.iloc[:len(interactions_train_df)].values\n", + "test_encoded = total_df.iloc[len(interactions_train_df):].values" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "crbEdHiJjZza", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:42:09.354000Z", + "iopub.status.busy": "2023-01-22T12:42:09.352465Z", + "iopub.status.idle": "2023-01-22T12:42:09.967185Z", + "shell.execute_reply": "2023-01-22T12:42:09.965725Z", + "shell.execute_reply.started": "2023-01-22T12:42:09.353932Z" + }, + "id": "crbEdHiJjZza" + }, + "outputs": [], + "source": [ + "from scipy.sparse import csr_matrix\n", + "shape = [int(total_df['user_id'].max()+1), int(total_df['item_id'].max()+1)]\n", + "X_train = csr_matrix((train_encoded[:, 2], (train_encoded[:, 0], 
train_encoded[:, 1])), shape=shape).toarray()\n", + "X_test = csr_matrix((test_encoded[:, 2], (test_encoded[:, 0], test_encoded[:, 1])), shape=shape).toarray()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "sFeJZsDJjZzb", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:42:12.745785Z", + "iopub.status.busy": "2023-01-22T12:42:12.745283Z", + "iopub.status.idle": "2023-01-22T12:42:12.754320Z", + "shell.execute_reply": "2023-01-22T12:42:12.752855Z", + "shell.execute_reply.started": "2023-01-22T12:42:12.745745Z" + }, + "id": "sFeJZsDJjZzb" + }, + "outputs": [], + "source": [ + "# Initialize the DataObject, which must return an element (features vector x and target value y)\n", + "# for a given idx. This class must also have a length atribute\n", + "class UserOrientedDataset(Dataset):\n", + " def __init__(self, X):\n", + " super().__init__() # to initialize the parent class\n", + " self.X = X.astype(np.float32)\n", + " self.len = len(X)\n", + "\n", + " def __len__(self): # We use __func__ for implementing in-built python functions\n", + " return self.len\n", + "\n", + " def __getitem__(self, index):\n", + " return self.X[index]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "AoCCUSpUjZzb", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:42:16.254953Z", + "iopub.status.busy": "2023-01-22T12:42:16.254416Z", + "iopub.status.idle": "2023-01-22T12:42:17.434704Z", + "shell.execute_reply": "2023-01-22T12:42:17.433103Z", + "shell.execute_reply.started": "2023-01-22T12:42:16.254903Z" + }, + "id": "AoCCUSpUjZzb" + }, + "outputs": [], + "source": [ + "# Initialize DataLoaders - objects, which sample instances from DataObject-s\n", + "train_dl = DataLoader(\n", + " UserOrientedDataset(X_train),\n", + " batch_size = BATCH_SIZE,\n", + " shuffle = True\n", + ")\n", + "\n", + "test_dl = DataLoader(\n", + " UserOrientedDataset(X_test),\n", + " batch_size = EVAL_BATCH_SIZE,\n", + " shuffle = 
False\n", + ")\n", + "\n", + "dls = {'train': train_dl, 'test': test_dl}" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "b94CXGocjZzb", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:53:12.965059Z", + "iopub.status.busy": "2023-01-22T12:53:12.964527Z", + "iopub.status.idle": "2023-01-22T12:53:12.975037Z", + "shell.execute_reply": "2023-01-22T12:53:12.973690Z", + "shell.execute_reply.started": "2023-01-22T12:53:12.965016Z" + }, + "id": "b94CXGocjZzb" + }, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " def __init__(self, in_and_out_features = 8287):\n", + " super().__init__()\n", + " self.in_and_out_features = in_and_out_features\n", + " self.hidden_size = 500\n", + "\n", + " self.sequential = nn.Sequential( \n", + " nn.Linear(in_and_out_features, self.hidden_size), \n", + " nn.ReLU(), \n", + " nn.Linear(self.hidden_size, in_and_out_features) # Another Linear transformation\n", + " )\n", + "\n", + " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n", + " x = self.sequential(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "aY_vqVZLjZzb", + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:54:25.315144Z", + "iopub.status.busy": "2023-01-22T12:54:25.314623Z", + "iopub.status.idle": "2023-01-22T12:54:26.136714Z", + "shell.execute_reply": "2023-01-22T12:54:26.135715Z", + "shell.execute_reply.started": "2023-01-22T12:54:25.315101Z" + }, + "id": "aY_vqVZLjZzb" + }, + "outputs": [], + "source": [ + "torch.manual_seed(SEED) # Fix random seed to have reproducible weights of model layers\n", + "\n", + "model = Model()\n", + "model.to(DEVICE)\n", + "\n", + "# Initialize GD method, which will update the weights of the model\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", + "# Initialize learning rate scheduler, which will decrease LR according to some 
rule\n", + "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)\n", + "\n", + "def rmse_for_sparse(x_pred, x_true):\n", + " mask = (x_true > 0)\n", + " sq_diff = (x_pred * mask - x_true) ** 2\n", + " mse = sq_diff.sum() / mask.sum()\n", + " return mse ** (1/2)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "LdlKerxfjZzb", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 419 + }, + "execution": { + "iopub.execute_input": "2023-01-22T12:54:33.544338Z", + "iopub.status.busy": "2023-01-22T12:54:33.543734Z" + }, + "id": "LdlKerxfjZzb", + "outputId": "0bc103bb-151d-449f-b7b2-f670b8970d92" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTrain RMSETest RMSE
002.3150152.295504
112.1916362.224912
221.9554972.108439
331.8361192.027701
441.7367832.026640
............
1951950.2886581.330020
1961960.2779171.331115
1971970.3070821.330125
1981980.3029801.331673
1991990.3073371.329725
\n", + "

200 rows × 3 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " Epoch Train RMSE Test RMSE\n", + "0 0 2.315015 2.295504\n", + "1 1 2.191636 2.224912\n", + "2 2 1.955497 2.108439\n", + "3 3 1.836119 2.027701\n", + "4 4 1.736783 2.026640\n", + ".. ... ... ...\n", + "195 195 0.288658 1.330020\n", + "196 196 0.277917 1.331115\n", + "197 197 0.307082 1.330125\n", + "198 198 0.302980 1.331673\n", + "199 199 0.307337 1.329725\n", + "\n", + "[200 rows x 3 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Training loop\n", + "metrics_dict = {\n", + " \"Epoch\": [],\n", + " \"Train RMSE\": [],\n", + " \"Test RMSE\": [],\n", + "}\n", + "\n", + "# Train loop\n", + "for epoch in range(NUM_EPOCHS):\n", + " metrics_dict[\"Epoch\"].append(epoch)\n", + " for stage in ['train', 'test']:\n", + " with torch.set_grad_enabled(stage == 'train'): # Whether to start building a graph for a backward pass\n", + " if stage == 'train':\n", + " model.train() # Enable some \"special\" layers (will speak about later)\n", + " else:\n", + " model.eval() # Disable some \"special\" layers (will speak about later)\n", + "\n", + " loss_at_stage = 0 \n", + " for batch in dls[stage]:\n", + " batch = batch.to(DEVICE)\n", + " x_pred = model(batch) # forward pass: model(x_batch) -> calls forward()\n", + " loss = rmse_for_sparse(x_pred, batch) # ¡Important! 
y_pred is always the first arg\n", + " if stage == \"train\":\n", + " loss.backward() # Calculate the gradients of all the parameters wrt loss\n", + " optimizer.step() # Update the parameters\n", + " scheduler.step()\n", + " optimizer.zero_grad() # Zero the saved gradient\n", + " loss_at_stage += loss.item() * len(batch)\n", + " rmse_at_stage = (loss_at_stage / len(dls[stage].dataset)) ** (1/2)\n", + " metrics_dict[f\"{stage.title()} RMSE\"].append(rmse_at_stage)\n", + " \n", + " if (epoch == NUM_EPOCHS - 1) or epoch % 10 == 9:\n", + " clear_output(wait=True)\n", + " display(pd.DataFrame(metrics_dict))" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "ZXCPjyMajZzb", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZXCPjyMajZzb", + "outputId": "a0448c4f-5e53-409b-c277-fc704e617202" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.3084, 2.5601, 1.0144, ..., -0.0948, -0.1467, 0.3106],\n", + " [ 0.1575, 0.8934, 0.1315, ..., -0.1049, 0.0096, 0.0350],\n", + " [ 0.6704, 1.5142, 0.6962, ..., -0.2259, -0.0353, 0.0676],\n", + " ...,\n", + " [ 0.3153, 1.1243, 0.1393, ..., -0.1222, -0.1398, 0.0617],\n", + " [ 0.3214, 1.9313, 0.3253, ..., -0.1548, -0.0918, -0.0392],\n", + " [ 0.3434, 0.9318, -0.0341, ..., -0.1714, -0.0446, 0.1267]],\n", + " device='cuda:0')" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.no_grad():\n", + " X_pred = model(torch.Tensor(X_test).to(DEVICE))\n", + "X_pred" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "bkSfO9fgjZzc", + "metadata": { + "id": "bkSfO9fgjZzc" + }, + "outputs": [], + "source": [ + "class AERecommender:\n", + " \n", + " MODEL_NAME = 'Autoencoder'\n", + " \n", + " def __init__(self, X_preds, X_train_and_val, X_test):\n", + "\n", + " self.X_preds = X_preds.cpu().detach().numpy()\n", + " self.X_train_and_val = X_train_and_val\n", + " self.X_test = X_test\n", + " \n", + " 
def get_model_name(self):\n", + " return self.MODEL_NAME\n", + " \n", + " def recommend_items(self, user_id, items_to_select_idx, topn=10, verbose=False):\n", + " user_preds = self.X_preds[user_id][items_to_select_idx]\n", + " items_idx = items_to_select_idx[np.argsort(-user_preds)[:topn]]\n", + "\n", + " # Recommend the highest predicted rating movies that the user hasn't seen yet.\n", + " return items_idx\n", + "\n", + " def evaluate(self, size=100):\n", + "\n", + " X_total = self.X_train_and_val + self.X_test\n", + "\n", + " true_5 = []\n", + " true_10 = []\n", + "\n", + " for user_id in range(len(X_test)):\n", + " non_zero = np.argwhere(self.X_test[user_id] > 0).ravel()\n", + " all_nonzero = np.argwhere(X_total[user_id] > 0).ravel()\n", + " select_from = np.setdiff1d(np.arange(X_total.shape[1]), all_nonzero)\n", + "\n", + " for non_zero_idx in non_zero:\n", + " random_non_interacted_100_items = np.random.choice(select_from, size=20, replace=False)\n", + " preds = self.recommend_items(user_id, np.append(random_non_interacted_100_items, non_zero_idx), topn=10)\n", + " true_5.append(non_zero_idx in preds[:5])\n", + " true_10.append(non_zero_idx in preds)\n", + "\n", + " return {\"recall@5\": np.mean(true_5), \"recall@10\": np.mean(true_10)}\n", + " \n", + "ae_recommender_model = AERecommender(X_pred, X_train, X_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "yRBbD9xmjZzc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yRBbD9xmjZzc", + "outputId": "d407d2b7-ee44-4299-9b29-046f41deb396" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'recall@5': 0.08641891035330142, 'recall@10': 0.25274264483602643}" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ae_global_metrics = ae_recommender_model.evaluate()\n", + "ae_global_metrics" + ] + }, + { + "cell_type": "markdown", + "id": "ydc-4MJn-KFM", + "metadata": { + "id": "ydc-4MJn-KFM" + 
}, + "source": [ + "Проведем эксперименты с моделями и гиперпараметрами" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "GZfxQH7Z-hMK", + "metadata": { + "id": "GZfxQH7Z-hMK" + }, + "outputs": [], + "source": [ + "def train_model():\n", + " torch.manual_seed(SEED) # Fix random seed to have reproducible weights of model layers\n", + "\n", + " model = Model()\n", + " model.to(DEVICE)\n", + "\n", + " # Initialize GD method, which will update the weights of the model\n", + " optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", + " # Initialize learning rate scheduler, which will decrease LR according to some rule\n", + " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)\n", + "\n", + "\n", + " # Training loop\n", + " metrics_dict = {\n", + " \"Epoch\": [],\n", + " \"Train RMSE\": [],\n", + " \"Test RMSE\": [],\n", + " }\n", + "\n", + " # Train loop\n", + " for epoch in range(NUM_EPOCHS):\n", + " metrics_dict[\"Epoch\"].append(epoch)\n", + " for stage in ['train', 'test']:\n", + " with torch.set_grad_enabled(stage == 'train'): # Whether to start building a graph for a backward pass\n", + " if stage == 'train':\n", + " model.train() # Enable some \"special\" layers (will speak about later)\n", + " else:\n", + " model.eval() # Disable some \"special\" layers (will speak about later)\n", + "\n", + " loss_at_stage = 0 \n", + " for batch in dls[stage]:\n", + " batch = batch.to(DEVICE)\n", + " x_pred = model(batch) # forward pass: model(x_batch) -> calls forward()\n", + " loss = rmse_for_sparse(x_pred, batch) # ¡Important! 
y_pred is always the first arg\n", + " if stage == \"train\":\n", + " loss.backward() # Calculate the gradients of all the parameters wrt loss\n", + " optimizer.step() # Update the parameters\n", + " scheduler.step()\n", + " optimizer.zero_grad() # Zero the saved gradient\n", + " loss_at_stage += loss.item() * len(batch)\n", + " rmse_at_stage = (loss_at_stage / len(dls[stage].dataset)) ** (1/2)\n", + " metrics_dict[f\"{stage.title()} RMSE\"].append(rmse_at_stage)\n", + " \n", + " with torch.no_grad():\n", + " X_pred = model(torch.Tensor(X_test).to(DEVICE))\n", + "\n", + " ae_recommender_model = AERecommender(X_pred, X_train, X_train)\n", + "\n", + " ae_global_metrics = ae_recommender_model.evaluate()\n", + "\n", + " metrics_dict[\"recall@5\"] = ae_global_metrics[\"recall@5\"]\n", + " metrics_dict[\"recall@10\"] = ae_global_metrics[\"recall@10\"]\n", + "\n", + "\n", + " return metrics_dict" + ] + }, + { + "cell_type": "markdown", + "id": "iYS06bYkA5uD", + "metadata": { + "id": "iYS06bYkA5uD" + }, + "source": [ + "C изначальной архитектурой" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "s69HDH9P-PZl", + "metadata": { + "id": "s69HDH9P-PZl" + }, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " def __init__(self, in_and_out_features = 8287):\n", + " super().__init__()\n", + " self.in_and_out_features = in_and_out_features\n", + " self.hidden_size = 500\n", + "\n", + " self.sequential = nn.Sequential( \n", + " nn.Linear(in_and_out_features, self.hidden_size), \n", + " nn.ReLU(), \n", + " nn.Linear(self.hidden_size, in_and_out_features) # Another Linear transformation\n", + " )\n", + "\n", + " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n", + " x = self.sequential(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "TytUsH6vA9Wo", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TytUsH6vA9Wo", + 
"outputId": "464573c4-6c3f-4b04-ac32-3ea09fb84f08" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lr:0.001 ne:0.001 bs:3000 ....\n", + "lr:0.001 ne:0.001 bs:4500 ....\n", + "lr:0.001 ne:0.001 bs:3000 ....\n", + "lr:0.001 ne:0.001 bs:4500 ....\n", + "lr:0.0003 ne:0.0003 bs:3000 ....\n", + "lr:0.0003 ne:0.0003 bs:4500 ....\n", + "lr:0.0003 ne:0.0003 bs:3000 ....\n", + "lr:0.0003 ne:0.0003 bs:4500 ....\n" + ] + } + ], + "source": [ + "first_arch_metrics = {}\n", + "\n", + "for lr in [0.001, 0.0003]:\n", + " for ne in [100, 200]:\n", + " for bs in [3000, 4500]:\n", + " \n", + " print(f\"lr:{lr} ne:{lr} bs:{bs} ....\" )\n", + "\n", + " LR = lr\n", + " NUM_EPOCHS = ne\n", + " BATCH_SIZE = bs\n", + "\n", + " first_arch_metrics[f\"lr:{lr} ne:{ne} bs:{bs}\"] = train_model()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "HnEm5GLZDAuC", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HnEm5GLZDAuC", + "outputId": "d652f3d6-c81a-4e35-e070-7350ce956120" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lr:0.001 ne:100 bs:3000 0.0856485318926931 0.24734999561176826\n", + "lr:0.001 ne:100 bs:4500 0.08339590626737009 0.2456629642992969\n", + "lr:0.001 ne:200 bs:3000 0.08698450466615308 0.2526061220708553\n", + "lr:0.001 ne:200 bs:4500 0.0867699688923128 0.2529181741055321\n", + "lr:0.0003 ne:100 bs:3000 0.08751109247467015 0.25363979443572215\n", + "lr:0.0003 ne:100 bs:4500 0.08879830711771187 0.25470272167883995\n", + "lr:0.0003 ne:200 bs:3000 0.08135781641588735 0.23005061093937415\n", + "lr:0.0003 ne:200 bs:4500 0.08193316235482266 0.23188391664310024\n" + ] + } + ], + "source": [ + "for i in first_arch_metrics.keys():\n", + " print(i, first_arch_metrics[i]['recall@5'], first_arch_metrics[i]['recall@10'])" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "ZaDleIoiPyCJ", + "metadata": { + "id": "ZaDleIoiPyCJ" + }, + "outputs": 
[], + "source": [] + }, + { + "cell_type": "markdown", + "id": "DOeEZG5_A9p4", + "metadata": { + "id": "DOeEZG5_A9p4" + }, + "source": [ + "Усложним архитектуру" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "rTOhYgiX-fEr", + "metadata": { + "id": "rTOhYgiX-fEr" + }, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " def __init__(self, in_and_out_features = 8287):\n", + " super().__init__()\n", + " self.in_and_out_features = in_and_out_features\n", + " self.hidden_size = 512\n", + "\n", + " self.sequential = nn.Sequential( \n", + " nn.Linear(in_and_out_features, 4096), \n", + " nn.ReLU(), \n", + "\n", + " nn.Linear(4096, self.hidden_size), \n", + " nn.ReLU(),\n", + "\n", + " nn.Linear(self.hidden_size, 4096), \n", + " nn.ReLU(), \n", + "\n", + " nn.Linear(4096, in_and_out_features) # Another Linear transformation\n", + " )\n", + "\n", + " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n", + " x = self.sequential(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "TPRykgiN-fNe", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TPRykgiN-fNe", + "outputId": "a32247b1-3a86-479d-aa0f-df75119e18e2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lr:0.001 ne:100 bs:3000 ....\n", + "lr:0.001 ne:100 bs:4500 ....\n", + "lr:0.001 ne:200 bs:3000 ....\n", + "lr:0.001 ne:200 bs:4500 ....\n", + "lr:0.0003 ne:100 bs:3000 ....\n", + "lr:0.0003 ne:100 bs:4500 ....\n", + "lr:0.0003 ne:200 bs:3000 ....\n", + "lr:0.0003 ne:200 bs:4500 ....\n" + ] + } + ], + "source": [ + "second_arch_metrics = {}\n", + "\n", + "for lr in [0.001, 0.0003]:\n", + " for ne in [100, 200]:\n", + " for bs in [3000, 4500]:\n", + " \n", + " print(f\"lr:{lr} ne:{ne} bs:{bs} ....\" )\n", + "\n", + " LR = lr\n", + " NUM_EPOCHS = ne\n", + " BATCH_SIZE = bs\n", + "\n", + " second_arch_metrics[f\"lr:{lr} ne:{ne} 
bs:{bs}\"] = train_model()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "wMGBNnslD4ax", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wMGBNnslD4ax", + "outputId": "76b97e64-342e-4ef5-b71e-f6a88c7daf36" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lr:0.001 ne:100 bs:3000 0.14852701688006476 0.363608881781037\n", + "lr:0.001 ne:100 bs:4500 0.14894633680166167 0.3632090651116074\n", + "lr:0.001 ne:200 bs:3000 0.15524588725169922 0.35925965654772934\n", + "lr:0.001 ne:200 bs:4500 0.1548265673301023 0.35853803621753927\n", + "lr:0.0003 ne:100 bs:3000 0.15456327342584375 0.37245360663890703\n", + "lr:0.0003 ne:100 bs:4500 0.15338332666972218 0.3731167172125952\n", + "lr:0.0003 ne:200 bs:3000 0.15394892098257384 0.3672852448145728\n", + "lr:0.0003 ne:200 bs:4500 0.1538026465913191 0.36697319277989604\n" + ] + } + ], + "source": [ + "for i in second_arch_metrics.keys():\n", + " print(i, second_arch_metrics[i]['recall@5'], second_arch_metrics[i]['recall@10'])" + ] + }, + { + "cell_type": "markdown", + "id": "k6zRyd-uD0Fy", + "metadata": { + "id": "k6zRyd-uD0Fy" + }, + "source": [ + "Добавим еще слоев: " + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "08GqJ7iu-fPo", + "metadata": { + "id": "08GqJ7iu-fPo" + }, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " def __init__(self, in_and_out_features = 8287):\n", + " super().__init__()\n", + " self.in_and_out_features = in_and_out_features\n", + " self.hidden_size = 512\n", + "\n", + " self.sequential = nn.Sequential( \n", + " nn.Linear(in_and_out_features, 6000), \n", + " nn.ReLU(), \n", + "\n", + " nn.Linear(6000, 3000), \n", + " nn.ReLU(),\n", + "\n", + " nn.Linear(3000, 1024), \n", + " nn.ReLU(),\n", + "\n", + " nn.Linear(1024, self.hidden_size), \n", + " nn.ReLU(),\n", + "\n", + " nn.Linear(self.hidden_size, 1024), \n", + " nn.ReLU(),\n", + "\n", + " nn.Linear(1024, 3000), \n", + " 
nn.ReLU(),\n", + "\n", + " nn.Linear(3000, 6000), \n", + " nn.ReLU(), \n", + "\n", + " nn.Linear(6000, in_and_out_features) # Another Linear transformation\n", + " )\n", + "\n", + " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n", + " x = self.sequential(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "AV4bbBpd-fSg", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AV4bbBpd-fSg", + "outputId": "2a985b96-d452-490b-d73c-419e0341b0d8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lr:0.0003 ne:100 bs:3000 ....\n", + "lr:0.0003 ne:100 bs:4500 ....\n", + "lr:0.0003 ne:200 bs:3000 ....\n", + "lr:0.0003 ne:200 bs:4500 ....\n" + ] + } + ], + "source": [ + "third_arch_metrics = {}\n", + "\n", + "for lr in [0.0003]:\n", + " for ne in [100, 200]:\n", + " for bs in [3000, 4500]:\n", + " \n", + " print(f\"lr:{lr} ne:{ne} bs:{bs} ....\" )\n", + "\n", + " LR = lr\n", + " NUM_EPOCHS = ne\n", + " BATCH_SIZE = bs\n", + "\n", + " third_arch_metrics[f\"lr:{lr} ne:{ne} bs:{bs}\"] = train_model()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "v1gCEb8aFQc-", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "v1gCEb8aFQc-", + "outputId": "f7084cf6-b161-4951-cc6d-1abe32766205" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lr:0.0003 ne:100 bs:3000 0.24635532975123603 0.6131237383833754\n", + "lr:0.0003 ne:100 bs:4500 0.2430007703784606 0.6135430583049724\n", + "lr:0.0003 ne:200 bs:3000 0.2589251757730602 0.6017533423698401\n", + "lr:0.0003 ne:200 bs:4500 0.2589739339034784 0.6040157196212469\n" + ] + } + ], + "source": [ + "for i in third_arch_metrics.keys():\n", + " print(i, third_arch_metrics[i]['recall@5'], third_arch_metrics[i]['recall@10'])" + ] + }, + { + "cell_type": "markdown", + "id": "k0qGe8sVaZq4", + "metadata": { 
+ "id": "k0qGe8sVaZq4" + }, + "source": [ + "Модель обучена. Лучшей по метрике recall@5 является модель последней архитектуры со следующими подобранными гиперпараметрами:\n", + "\n", + "* LR: 0.0003\n", + "* NUM_EPOCHS: 200\n", + "* BATCH_SIZE: 4500\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sd7WVXXYo7H1", + "metadata": { + "id": "sd7WVXXYo7H1" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/hw_5_dssm.ipynb b/hw_5_dssm.ipynb new file mode 100644 index 00000000..258bae7e --- /dev/null +++ b/hw_5_dssm.ipynb @@ -0,0 +1,4034 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:22.841107Z", + "iopub.status.busy": "2023-01-22T16:23:22.840365Z", + "iopub.status.idle": "2023-01-22T16:23:22.850076Z", + "shell.execute_reply": "2023-01-22T16:23:22.848844Z", + "shell.execute_reply.started": "2023-01-22T16:23:22.841044Z" + } + }, + "outputs": [], + "source": [ + "import ast\n", + "import json\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "import pandas as pd\n", + "import pickle\n", + "import tensorflow as tf\n", + "import tensorflow.keras.backend as K\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from collections import Counter\n", + "from random import randint, random\n", + "from scipy.sparse import coo_matrix, hstack\n", + "from
sklearn.metrics.pairwise import euclidean_distances, cosine_distances, cosine_similarity\n", + "from sklearn.metrics.pairwise import euclidean_distances as ED\n", + "from tensorflow import keras\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:39:51.661446Z", + "start_time": "2021-10-28T18:39:51.563879Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:22.852847Z", + "iopub.status.busy": "2023-01-22T16:23:22.851743Z", + "iopub.status.idle": "2023-01-22T16:23:29.088896Z", + "shell.execute_reply": "2023-01-22T16:23:29.087873Z", + "shell.execute_reply.started": "2023-01-22T16:23:22.852800Z" + }, + "id": "25508632" + }, + "outputs": [], + "source": [ + "interactions_df = pd.read_csv('interactions_processed_kion.csv')\n", + "users_df = pd.read_csv('users_processed_kion.csv')\n", + "items_df = pd.read_csv('items_processed_kion.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:35.447336Z", + "start_time": "2021-10-28T18:40:35.434541Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:29.097384Z", + "iopub.status.busy": "2023-01-22T16:23:29.094877Z", + "iopub.status.idle": "2023-01-22T16:23:29.123826Z", + "shell.execute_reply": "2023-01-22T16:23:29.123005Z", + "shell.execute_reply.started": "2023-01-22T16:23:29.097341Z" + }, + "id": "f5eacb31", + "outputId": "37b5c35b-4f4b-48ea-9012-a6ce7eed31c7" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idageincomesexkids_flg
0973171age_25_34income_60_90MTrue
1962099age_18_24income_20_40MFalse
21047345age_45_54income_40_60FFalse
3721985age_45_54income_20_40FFalse
4704055age_35_44income_60_90FFalse
\n", + "
" + ], + "text/plain": [ + " user_id age income sex kids_flg\n", + "0 973171 age_25_34 income_60_90 M True\n", + "1 962099 age_18_24 income_20_40 M False\n", + "2 1047345 age_45_54 income_40_60 F False\n", + "3 721985 age_45_54 income_20_40 F False\n", + "4 704055 age_35_44 income_60_90 F False" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:36.103997Z", + "start_time": "2021-10-28T18:40:36.094699Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:29.130158Z", + "iopub.status.busy": "2023-01-22T16:23:29.127963Z", + "iopub.status.idle": "2023-01-22T16:23:29.145149Z", + "shell.execute_reply": "2023-01-22T16:23:29.144033Z", + "shell.execute_reply.started": "2023-01-22T16:23:29.130122Z" + }, + "id": "61669d0d" + }, + "outputs": [], + "source": [ + "items_df = items_df.rename(columns = {'id' : 'item_id'})" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:36.378293Z", + "start_time": "2021-10-28T18:40:36.370946Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:29.146754Z", + "iopub.status.busy": "2023-01-22T16:23:29.146394Z", + "iopub.status.idle": "2023-01-22T16:23:29.166993Z", + "shell.execute_reply": "2023-01-22T16:23:29.165796Z", + "shell.execute_reply.started": "2023-01-22T16:23:29.146717Z" + }, + "id": "25f4462e", + "outputId": "5cc6c801-f866-4b52-aada-f5226a5ebc21" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idcontent_typetitletitle_origgenrescountriesfor_kidsage_ratingstudiosdirectorsactorsdescriptionkeywordsrelease_year_cat
010711filmпоговори с нейHable con ellaдрамы, зарубежные, детективы, мелодрамыиспанияFalse16.0unknownпедро альмодоварАдольфо Фернандес, Ана Фернандес, Дарио Гранди...Мелодрама легендарного Педро Альмодовара «Пого...Поговори, ней, 2002, Испания, друзья, любовь, ...2000-2010
12508filmголые перцыSearch Partyзарубежные, приключения, комедиисшаFalse16.0unknownскот армстронгАдам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...Уморительная современная комедия на популярную...Голые, перцы, 2014, США, друзья, свадьбы, прео...2010-2020
210716filmтактическая силаTactical Forceкриминал, зарубежные, триллеры, боевики, комедииканадаFalse16.0unknownадам п. калтрароАдриан Холмс, Даррен Шалави, Джерри Вассерман,...Профессиональный рестлер Стив Остин («Все или ...Тактическая, сила, 2011, Канада, бандиты, ганг...2010-2020
37868film45 лет45 Yearsдрамы, зарубежные, мелодрамывеликобританияFalse16.0unknownэндрю хэйАлександра Риддлстон-Барретт, Джеральдин Джейм...Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...45, лет, 2015, Великобритания, брак, жизнь, лю...2010-2020
416268filmвсе решает мгновениеNaNдрамы, спорт, советские, мелодрамысссрFalse12.0ленфильмвиктор садовскийАлександр Абдулов, Александр Демьяненко, Алекс...Расчетливая чаровница из советского кинохита «...Все, решает, мгновение, 1978, СССР, сильные, ж...1970-1980
\n", + "
" + ], + "text/plain": [ + " item_id content_type title title_orig \\\n", + "0 10711 film поговори с ней Hable con ella \n", + "1 2508 film голые перцы Search Party \n", + "2 10716 film тактическая сила Tactical Force \n", + "3 7868 film 45 лет 45 Years \n", + "4 16268 film все решает мгновение NaN \n", + "\n", + " genres countries for_kids \\\n", + "0 драмы, зарубежные, детективы, мелодрамы испания False \n", + "1 зарубежные, приключения, комедии сша False \n", + "2 криминал, зарубежные, триллеры, боевики, комедии канада False \n", + "3 драмы, зарубежные, мелодрамы великобритания False \n", + "4 драмы, спорт, советские, мелодрамы ссср False \n", + "\n", + " age_rating studios directors \\\n", + "0 16.0 unknown педро альмодовар \n", + "1 16.0 unknown скот армстронг \n", + "2 16.0 unknown адам п. калтраро \n", + "3 16.0 unknown эндрю хэй \n", + "4 12.0 ленфильм виктор садовский \n", + "\n", + " actors \\\n", + "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n", + "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n", + "2 Адриан Холмс, Даррен Шалави, Джерри Вассерман,... \n", + "3 Александра Риддлстон-Барретт, Джеральдин Джейм... \n", + "4 Александр Абдулов, Александр Демьяненко, Алекс... \n", + "\n", + " description \\\n", + "0 Мелодрама легендарного Педро Альмодовара «Пого... \n", + "1 Уморительная современная комедия на популярную... \n", + "2 Профессиональный рестлер Стив Остин («Все или ... \n", + "3 Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... \n", + "4 Расчетливая чаровница из советского кинохита «... \n", + "\n", + " keywords release_year_cat \n", + "0 Поговори, ней, 2002, Испания, друзья, любовь, ... 2000-2010 \n", + "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 2010-2020 \n", + "2 Тактическая, сила, 2011, Канада, бандиты, ганг... 2010-2020 \n", + "3 45, лет, 2015, Великобритания, брак, жизнь, лю... 2010-2020 \n", + "4 Все, решает, мгновение, 1978, СССР, сильные, ж... 
1970-1980 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:36.607688Z", + "start_time": "2021-10-28T18:40:36.597640Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:29.169473Z", + "iopub.status.busy": "2023-01-22T16:23:29.168713Z", + "iopub.status.idle": "2023-01-22T16:23:29.183035Z", + "shell.execute_reply": "2023-01-22T16:23:29.181327Z", + "shell.execute_reply.started": "2023-01-22T16:23:29.169432Z" + }, + "id": "b41964d3", + "outputId": "b4c8f3d5-e7af-4e29-d2e8-0defb6993b35" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idlast_watch_dttotal_durwatched_pct
017654995062021-05-11425072
169931716592021-05-298317100
265668371072021-05-09100
386461376382021-07-0514483100
496486895062021-04-306725100
\n", + "
" + ], + "text/plain": [ + " user_id item_id last_watch_dt total_dur watched_pct\n", + "0 176549 9506 2021-05-11 4250 72\n", + "1 699317 1659 2021-05-29 8317 100\n", + "2 656683 7107 2021-05-09 10 0\n", + "3 864613 7638 2021-07-05 14483 100\n", + "4 964868 9506 2021-04-30 6725 100" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cd252422" + }, + "source": [ + "## Готовим фичи пользователей" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pBdccMPAr7KR" + }, + "source": [ + "Посмотрим, какие фичи в датасете пользователей являются категориальными и закодируем их с помощью one-hot encoding." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:37.156260Z", + "start_time": "2021-10-28T18:40:37.138422Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:29.185708Z", + "iopub.status.busy": "2023-01-22T16:23:29.184841Z", + "iopub.status.idle": "2023-01-22T16:23:29.504659Z", + "shell.execute_reply": "2023-01-22T16:23:29.503366Z", + "shell.execute_reply.started": "2023-01-22T16:23:29.185668Z" + }, + "id": "692270ac", + "outputId": "7491ab1f-f9fb-4921-e383-7ecf5569e999" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idage_age_18_24age_age_25_34age_age_35_44age_age_45_54age_age_55_64age_age_65_infage_age_unknownincome_income_0_20income_income_150_infincome_income_20_40income_income_40_60income_income_60_90income_income_90_150income_income_unknownsex_Fsex_Msex_sex_unknownkids_flg_Falsekids_flg_True
0973171FalseTrueFalseFalseFalseFalseFalseFalseFalseFalseFalseTrueFalseFalseFalseTrueFalseFalseTrue
1962099TrueFalseFalseFalseFalseFalseFalseFalseFalseTrueFalseFalseFalseFalseFalseTrueFalseTrueFalse
21047345FalseFalseFalseTrueFalseFalseFalseFalseFalseFalseTrueFalseFalseFalseTrueFalseFalseTrueFalse
3721985FalseFalseFalseTrueFalseFalseFalseFalseFalseTrueFalseFalseFalseFalseTrueFalseFalseTrueFalse
4704055FalseFalseTrueFalseFalseFalseFalseFalseFalseFalseFalseTrueFalseFalseTrueFalseFalseTrueFalse
\n", + "
" + ], + "text/plain": [ + " user_id age_age_18_24 age_age_25_34 age_age_35_44 age_age_45_54 \\\n", + "0 973171 False True False False \n", + "1 962099 True False False False \n", + "2 1047345 False False False True \n", + "3 721985 False False False True \n", + "4 704055 False False True False \n", + "\n", + " age_age_55_64 age_age_65_inf age_age_unknown income_income_0_20 \\\n", + "0 False False False False \n", + "1 False False False False \n", + "2 False False False False \n", + "3 False False False False \n", + "4 False False False False \n", + "\n", + " income_income_150_inf income_income_20_40 income_income_40_60 \\\n", + "0 False False False \n", + "1 False True False \n", + "2 False False True \n", + "3 False True False \n", + "4 False False False \n", + "\n", + " income_income_60_90 income_income_90_150 income_income_unknown sex_F \\\n", + "0 True False False False \n", + "1 False False False False \n", + "2 False False False True \n", + "3 False False False True \n", + "4 True False False True \n", + "\n", + " sex_M sex_sex_unknown kids_flg_False kids_flg_True \n", + "0 True False False True \n", + "1 True False True False \n", + "2 False False True False \n", + "3 False False True False \n", + "4 False False True False " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "user_cat_feats = [\"age\", \"income\", \"sex\", \"kids_flg\"]\n", + "# из исходного датафрейма оставим только user_id - этот признак нам понадобится позже\n", + "# для того, чтобы сопоставлять пользователей из датафрейма пользователей с пользователями \n", + "# из датафрейма с взаимодействиями\n", + "users_ohe_df = users_df.user_id\n", + "for feat in user_cat_feats:\n", + " # получаем датафрейм с one-hot encoding для каждой категориальной фичи\n", + " ohe_feat_df = pd.get_dummies(users_df[feat], prefix=feat)\n", + " # конкатенируем one-hot датафрейм с датафреймом, \n", + " # который мы получили на предыдущем шаге\n", + " users_ohe_df = 
pd.concat([users_ohe_df, ohe_feat_df], axis=1)\n", + "\n", + "users_ohe_df.head()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "74cdbd93" + }, + "source": [ + "## Готовим фичи айтемов" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5kHzJ91Mr35c" + }, + "source": [ + "Кодируем их точно так же - one-hot'ом." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:29.507174Z", + "iopub.status.busy": "2023-01-22T16:23:29.506716Z", + "iopub.status.idle": "2023-01-22T16:23:29.528115Z", + "shell.execute_reply": "2023-01-22T16:23:29.526826Z", + "shell.execute_reply.started": "2023-01-22T16:23:29.507133Z" + }, + "id": "-2Wd9upSsCle", + "outputId": "671c2446-81f5-4e32-e24f-3aec9c8a2076" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idcontent_typetitletitle_origgenrescountriesfor_kidsage_ratingstudiosdirectorsactorsdescriptionkeywordsrelease_year_cat
010711filmпоговори с нейHable con ellaдрамы, зарубежные, детективы, мелодрамыиспанияFalse16.0unknownпедро альмодоварАдольфо Фернандес, Ана Фернандес, Дарио Гранди...Мелодрама легендарного Педро Альмодовара «Пого...Поговори, ней, 2002, Испания, друзья, любовь, ...2000-2010
12508filmголые перцыSearch Partyзарубежные, приключения, комедиисшаFalse16.0unknownскот армстронгАдам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...Уморительная современная комедия на популярную...Голые, перцы, 2014, США, друзья, свадьбы, прео...2010-2020
210716filmтактическая силаTactical Forceкриминал, зарубежные, триллеры, боевики, комедииканадаFalse16.0unknownадам п. калтрароАдриан Холмс, Даррен Шалави, Джерри Вассерман,...Профессиональный рестлер Стив Остин («Все или ...Тактическая, сила, 2011, Канада, бандиты, ганг...2010-2020
37868film45 лет45 Yearsдрамы, зарубежные, мелодрамывеликобританияFalse16.0unknownэндрю хэйАлександра Риддлстон-Барретт, Джеральдин Джейм...Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...45, лет, 2015, Великобритания, брак, жизнь, лю...2010-2020
416268filmвсе решает мгновениеNaNдрамы, спорт, советские, мелодрамысссрFalse12.0ленфильмвиктор садовскийАлександр Абдулов, Александр Демьяненко, Алекс...Расчетливая чаровница из советского кинохита «...Все, решает, мгновение, 1978, СССР, сильные, ж...1970-1980
\n", + "
" + ], + "text/plain": [ + " item_id content_type title title_orig \\\n", + "0 10711 film поговори с ней Hable con ella \n", + "1 2508 film голые перцы Search Party \n", + "2 10716 film тактическая сила Tactical Force \n", + "3 7868 film 45 лет 45 Years \n", + "4 16268 film все решает мгновение NaN \n", + "\n", + " genres countries for_kids \\\n", + "0 драмы, зарубежные, детективы, мелодрамы испания False \n", + "1 зарубежные, приключения, комедии сша False \n", + "2 криминал, зарубежные, триллеры, боевики, комедии канада False \n", + "3 драмы, зарубежные, мелодрамы великобритания False \n", + "4 драмы, спорт, советские, мелодрамы ссср False \n", + "\n", + " age_rating studios directors \\\n", + "0 16.0 unknown педро альмодовар \n", + "1 16.0 unknown скот армстронг \n", + "2 16.0 unknown адам п. калтраро \n", + "3 16.0 unknown эндрю хэй \n", + "4 12.0 ленфильм виктор садовский \n", + "\n", + " actors \\\n", + "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n", + "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n", + "2 Адриан Холмс, Даррен Шалави, Джерри Вассерман,... \n", + "3 Александра Риддлстон-Барретт, Джеральдин Джейм... \n", + "4 Александр Абдулов, Александр Демьяненко, Алекс... \n", + "\n", + " description \\\n", + "0 Мелодрама легендарного Педро Альмодовара «Пого... \n", + "1 Уморительная современная комедия на популярную... \n", + "2 Профессиональный рестлер Стив Остин («Все или ... \n", + "3 Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... \n", + "4 Расчетливая чаровница из советского кинохита «... \n", + "\n", + " keywords release_year_cat \n", + "0 Поговори, ней, 2002, Испания, друзья, любовь, ... 2000-2010 \n", + "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 2010-2020 \n", + "2 Тактическая, сила, 2011, Канада, бандиты, ганг... 2010-2020 \n", + "3 45, лет, 2015, Великобритания, брак, жизнь, лю... 2010-2020 \n", + "4 Все, решает, мгновение, 1978, СССР, сильные, ж... 
1970-1980 " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:37.792147Z", + "start_time": "2021-10-28T18:40:37.537501Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:29.534806Z", + "iopub.status.busy": "2023-01-22T16:23:29.533869Z", + "iopub.status.idle": "2023-01-22T16:23:30.291045Z", + "shell.execute_reply": "2023-01-22T16:23:30.289998Z", + "shell.execute_reply.started": "2023-01-22T16:23:29.534762Z" + }, + "id": "7a94ef7e", + "outputId": "1ea7a769-8c2d-43d5-f2cb-47500bc0a7ba" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idcontent_type_filmcontent_type_seriesrelease_year_cat_1920-1930release_year_cat_1930-1940release_year_cat_1940-1950release_year_cat_1950-1960release_year_cat_1960-1970release_year_cat_1970-1980release_year_cat_1980-1990...directors_ярив хоровицdirectors_ярон зильберманdirectors_ярополк лапшинdirectors_ярослав лупийdirectors_ярроу чейни, скотт моужерdirectors_ясина сезарdirectors_ясуоми умэцуdirectors_ёдзи фукуяма, ацуко фукусима, николас де креси, синъитиро ватанабэ, сёдзи кавамориdirectors_ёлкин туйчиевdirectors_ён сан-хо
010711TrueFalseFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
12508TrueFalseFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
210716TrueFalseFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
37868TrueFalseFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
416268TrueFalseFalseFalseFalseFalseFalseTrueFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", + "

5 rows × 8589 columns

\n", + "
" + ], + "text/plain": [ + " item_id content_type_film content_type_series \\\n", + "0 10711 True False \n", + "1 2508 True False \n", + "2 10716 True False \n", + "3 7868 True False \n", + "4 16268 True False \n", + "\n", + " release_year_cat_1920-1930 release_year_cat_1930-1940 \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + " release_year_cat_1940-1950 release_year_cat_1950-1960 \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + " release_year_cat_1960-1970 release_year_cat_1970-1980 \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False True \n", + "\n", + " release_year_cat_1980-1990 ... directors_ярив хоровиц \\\n", + "0 False ... False \n", + "1 False ... False \n", + "2 False ... False \n", + "3 False ... False \n", + "4 False ... False \n", + "\n", + " directors_ярон зильберман directors_ярополк лапшин \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + " directors_ярослав лупий directors_ярроу чейни, скотт моужер \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + " directors_ясина сезар directors_ясуоми умэцу \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + " directors_ёдзи фукуяма, ацуко фукусима, николас де креси, синъитиро ватанабэ, сёдзи кавамори \\\n", + "0 False \n", + "1 False \n", + "2 False \n", + "3 False \n", + "4 False \n", + "\n", + " directors_ёлкин туйчиев directors_ён сан-хо \n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + "[5 rows x 8589 columns]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + 
], + "source": [ + "item_cat_feats = ['content_type', 'release_year_cat',\n", + " 'for_kids', 'age_rating', \n", + " 'studios', 'countries', 'directors']\n", + "\n", + "items_ohe_df = items_df.item_id\n", + "\n", + "for feat in item_cat_feats:\n", + " ohe_feat_df = pd.get_dummies(items_df[feat], prefix=feat)\n", + " items_ohe_df = pd.concat([items_ohe_df, ohe_feat_df], axis=1) \n", + "\n", + "items_ohe_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:30.294678Z", + "iopub.status.busy": "2023-01-22T16:23:30.294379Z", + "iopub.status.idle": "2023-01-22T16:23:30.316137Z", + "shell.execute_reply": "2023-01-22T16:23:30.314916Z", + "shell.execute_reply.started": "2023-01-22T16:23:30.294651Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idcontent_typetitletitle_origgenrescountriesfor_kidsage_ratingstudiosdirectorsactorsdescriptionkeywordsrelease_year_cat
010711filmпоговори с нейHable con ellaдрамы, зарубежные, детективы, мелодрамыиспанияFalse16.0unknownпедро альмодоварАдольфо Фернандес, Ана Фернандес, Дарио Гранди...Мелодрама легендарного Педро Альмодовара «Пого...Поговори, ней, 2002, Испания, друзья, любовь, ...2000-2010
12508filmголые перцыSearch Partyзарубежные, приключения, комедиисшаFalse16.0unknownскот армстронгАдам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...Уморительная современная комедия на популярную...Голые, перцы, 2014, США, друзья, свадьбы, прео...2010-2020
210716filmтактическая силаTactical Forceкриминал, зарубежные, триллеры, боевики, комедииканадаFalse16.0unknownадам п. калтрароАдриан Холмс, Даррен Шалави, Джерри Вассерман,...Профессиональный рестлер Стив Остин («Все или ...Тактическая, сила, 2011, Канада, бандиты, ганг...2010-2020
37868film45 лет45 Yearsдрамы, зарубежные, мелодрамывеликобританияFalse16.0unknownэндрю хэйАлександра Риддлстон-Барретт, Джеральдин Джейм...Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...45, лет, 2015, Великобритания, брак, жизнь, лю...2010-2020
416268filmвсе решает мгновениеNaNдрамы, спорт, советские, мелодрамысссрFalse12.0ленфильмвиктор садовскийАлександр Абдулов, Александр Демьяненко, Алекс...Расчетливая чаровница из советского кинохита «...Все, решает, мгновение, 1978, СССР, сильные, ж...1970-1980
\n", + "
" + ], + "text/plain": [ + " item_id content_type title title_orig \\\n", + "0 10711 film поговори с ней Hable con ella \n", + "1 2508 film голые перцы Search Party \n", + "2 10716 film тактическая сила Tactical Force \n", + "3 7868 film 45 лет 45 Years \n", + "4 16268 film все решает мгновение NaN \n", + "\n", + " genres countries for_kids \\\n", + "0 драмы, зарубежные, детективы, мелодрамы испания False \n", + "1 зарубежные, приключения, комедии сша False \n", + "2 криминал, зарубежные, триллеры, боевики, комедии канада False \n", + "3 драмы, зарубежные, мелодрамы великобритания False \n", + "4 драмы, спорт, советские, мелодрамы ссср False \n", + "\n", + " age_rating studios directors \\\n", + "0 16.0 unknown педро альмодовар \n", + "1 16.0 unknown скот армстронг \n", + "2 16.0 unknown адам п. калтраро \n", + "3 16.0 unknown эндрю хэй \n", + "4 12.0 ленфильм виктор садовский \n", + "\n", + " actors \\\n", + "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n", + "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n", + "2 Адриан Холмс, Даррен Шалави, Джерри Вассерман,... \n", + "3 Александра Риддлстон-Барретт, Джеральдин Джейм... \n", + "4 Александр Абдулов, Александр Демьяненко, Алекс... \n", + "\n", + " description \\\n", + "0 Мелодрама легендарного Педро Альмодовара «Пого... \n", + "1 Уморительная современная комедия на популярную... \n", + "2 Профессиональный рестлер Стив Остин («Все или ... \n", + "3 Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... \n", + "4 Расчетливая чаровница из советского кинохита «... \n", + "\n", + " keywords release_year_cat \n", + "0 Поговори, ней, 2002, Испания, друзья, любовь, ... 2000-2010 \n", + "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 2010-2020 \n", + "2 Тактическая, сила, 2011, Канада, бандиты, ганг... 2010-2020 \n", + "3 45, лет, 2015, Великобритания, брак, жизнь, лю... 2010-2020 \n", + "4 Все, решает, мгновение, 1978, СССР, сильные, ж... 
1970-1980 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Добавим текстовые фичи\n", + "С помощью TfidfVectorizer получим TF-IDF-векторы следующих колонок: genres, keywords (колонка description в коде ниже не используется)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:30.318830Z", + "iopub.status.busy": "2023-01-22T16:23:30.318190Z", + "iopub.status.idle": "2023-01-22T16:23:30.335666Z", + "shell.execute_reply": "2023-01-22T16:23:30.334811Z", + "shell.execute_reply.started": "2023-01-22T16:23:30.318792Z" + } + }, + "outputs": [], + "source": [ + "from sklearn.feature_extraction.text import TfidfVectorizer" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:30.338779Z", + "iopub.status.busy": "2023-01-22T16:23:30.338019Z", + "iopub.status.idle": "2023-01-22T16:23:31.386164Z", + "shell.execute_reply": "2023-01-22T16:23:31.385106Z", + "shell.execute_reply.started": "2023-01-22T16:23:30.338741Z" + } + }, + "outputs": [], + "source": [ + "for column in ['genres', 'keywords']:\n", + " tv = TfidfVectorizer(max_features = 500)\n", + " t = pd.DataFrame.sparse.from_spmatrix(tv.fit_transform(items_df[column]))\n", + " t.columns = [column + '_' + str(x) for x in t.columns]\n", + " items_ohe_df = pd.concat([items_ohe_df, t], axis = 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:31.388255Z", + "iopub.status.busy": "2023-01-22T16:23:31.387847Z", + "iopub.status.idle": "2023-01-22T16:23:31.483917Z", + "shell.execute_reply": "2023-01-22T16:23:31.482751Z", + "shell.execute_reply.started": "2023-01-22T16:23:31.388214Z" + } + }, + "outputs": [ + { + "data": { + 
"text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idcontent_type_filmcontent_type_seriesrelease_year_cat_1920-1930release_year_cat_1930-1940release_year_cat_1940-1950release_year_cat_1950-1960release_year_cat_1960-1970release_year_cat_1970-1980release_year_cat_1980-1990...keywords_490keywords_491keywords_492keywords_493keywords_494keywords_495keywords_496keywords_497keywords_498keywords_499
010711TrueFalseFalseFalseFalseFalseFalseFalseFalse...0.00.00.00.00.00.00.00.00.00.0
12508TrueFalseFalseFalseFalseFalseFalseFalseFalse...0.00.00.00.00.00.00.00.00.00.0
210716TrueFalseFalseFalseFalseFalseFalseFalseFalse...0.00.00.00.00.00.00.00.00.00.0
37868TrueFalseFalseFalseFalseFalseFalseFalseFalse...0.00.00.00.00.00.00.00.00.00.0
416268TrueFalseFalseFalseFalseFalseFalseTrueFalse...0.00.00.00.00.00.00.00.00.00.0
\n", + "

5 rows × 9197 columns

\n", + "
" + ], + "text/plain": [ + " item_id content_type_film content_type_series \\\n", + "0 10711 True False \n", + "1 2508 True False \n", + "2 10716 True False \n", + "3 7868 True False \n", + "4 16268 True False \n", + "\n", + " release_year_cat_1920-1930 release_year_cat_1930-1940 \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + " release_year_cat_1940-1950 release_year_cat_1950-1960 \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False False \n", + "\n", + " release_year_cat_1960-1970 release_year_cat_1970-1980 \\\n", + "0 False False \n", + "1 False False \n", + "2 False False \n", + "3 False False \n", + "4 False True \n", + "\n", + " release_year_cat_1980-1990 ... keywords_490 keywords_491 keywords_492 \\\n", + "0 False ... 0.0 0.0 0.0 \n", + "1 False ... 0.0 0.0 0.0 \n", + "2 False ... 0.0 0.0 0.0 \n", + "3 False ... 0.0 0.0 0.0 \n", + "4 False ... 0.0 0.0 0.0 \n", + "\n", + " keywords_493 keywords_494 keywords_495 keywords_496 keywords_497 \\\n", + "0 0.0 0.0 0.0 0.0 0.0 \n", + "1 0.0 0.0 0.0 0.0 0.0 \n", + "2 0.0 0.0 0.0 0.0 0.0 \n", + "3 0.0 0.0 0.0 0.0 0.0 \n", + "4 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " keywords_498 keywords_499 \n", + "0 0.0 0.0 \n", + "1 0.0 0.0 \n", + "2 0.0 0.0 \n", + "3 0.0 0.0 \n", + "4 0.0 0.0 \n", + "\n", + "[5 rows x 9197 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_ohe_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cc595c20" + }, + "source": [ + "## Сделаем матрицу взаимодействий" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:37.898427Z", + "start_time": "2021-10-28T18:40:37.864067Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:31.486206Z", + "iopub.status.busy": "2023-01-22T16:23:31.485812Z", 
+ "iopub.status.idle": "2023-01-22T16:23:31.604748Z", + "shell.execute_reply": "2023-01-22T16:23:31.603679Z", + "shell.execute_reply.started": "2023-01-22T16:23:31.486170Z" + }, + "id": "79c9bca3", + "outputId": "6f6148e0-8de7-4ffc-82d9-cf396db1ed98" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "item_id\n", + "10440 202457\n", + "15297 193123\n", + "9728 132865\n", + "13865 122119\n", + "4151 91167\n", + " ... \n", + "8076 1\n", + "8954 1\n", + "15664 1\n", + "818 1\n", + "10542 1\n", + "Name: count, Length: 15706, dtype: int64" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions_df.item_id.value_counts()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YAfqm8asrBfG" + }, + "source": [ + "В датасете взаимодействий есть непопулярные фильмы и малоактивные пользователи. Кроме того, в таблице взаимодействий есть события с низким качеством взаимодействия - когда юзер начал смотреть фильм, но вскоре после начала просмотра выключил.\n", + "\n", + "Отфильтруем такие события*, малоактивных юзеров и непопулярные фильмы.\n", + "\n", + "Можете не фильтровать такие события, тогда у вас будет больше негативных примеров." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:38.103819Z", + "start_time": "2021-10-28T18:40:38.070117Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:31.606489Z", + "iopub.status.busy": "2023-01-22T16:23:31.606197Z", + "iopub.status.idle": "2023-01-22T16:23:31.985392Z", + "shell.execute_reply": "2023-01-22T16:23:31.984254Z", + "shell.execute_reply.started": "2023-01-22T16:23:31.606462Z" + }, + "id": "17334e80", + "outputId": "bfbe26dd-7778-42ad-c5dd-283635fcafa6" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "user_id\n", + "416206 1341\n", + "1010539 764\n", + "555233 685\n", + "11526 676\n", + "409259 625\n", + " ... 
\n", + "45493 1\n", + "615194 1\n", + "96848 1\n", + "425823 1\n", + "697262 1\n", + "Name: count, Length: 962179, dtype: int64" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions_df.user_id.value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:39.717096Z", + "start_time": "2021-10-28T18:40:38.759740Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:31.987509Z", + "iopub.status.busy": "2023-01-22T16:23:31.986995Z", + "iopub.status.idle": "2023-01-22T16:23:34.897911Z", + "shell.execute_reply": "2023-01-22T16:23:34.896578Z", + "shell.execute_reply.started": "2023-01-22T16:23:31.987469Z" + }, + "id": "076e4ebc", + "outputId": "85c15fd2-12bb-478c-e00f-4f2b7bbcd6ab" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "N users before: 962179\n", + "N items before: 15706\n", + "\n", + "N users after: 79515\n", + "N items after: 6901\n" + ] + } + ], + "source": [ + "print(f\"N users before: {interactions_df.user_id.nunique()}\")\n", + "print(f\"N items before: {interactions_df.item_id.nunique()}\\n\")\n", + "\n", + "# отфильтруем все события взаимодействий, в которых пользователь посмотрел\n", + "# фильм менее чем на 10 процентов\n", + "interactions_df = interactions_df[interactions_df.watched_pct > 10]\n", + "\n", + "# соберем всех пользователей, которые посмотрели \n", + "# больше 10 фильмов (можете выбрать другой порог)\n", + "valid_users = []\n", + "\n", + "c = Counter(interactions_df.user_id)\n", + "for user_id, entries in c.most_common():\n", + " if entries > 10:\n", + " valid_users.append(user_id)\n", + "\n", + "# и соберем все фильмы, которые посмотрели больше 10 пользователей\n", + "valid_items = []\n", + "\n", + "c = Counter(interactions_df.item_id)\n", + "for item_id, entries in c.most_common():\n", + " if entries > 10:\n", + " 
valid_items.append(item_id)\n", + "\n", + "# отбросим непопулярные фильмы и неактивных юзеров\n", + "interactions_df = interactions_df[interactions_df.user_id.isin(valid_users)]\n", + "interactions_df = interactions_df[interactions_df.item_id.isin(valid_items)]\n", + "\n", + "print(f\"N users after: {interactions_df.user_id.nunique()}\")\n", + "print(f\"N items after: {interactions_df.item_id.nunique()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a9163fb2" + }, + "source": [ + "После фильтрации может получиться так, что некоторые айтемы/юзеры есть в датасете взаимодействий, но при этом они отсутствуют в датасетах айтемов/юзеров или наоборот. Поэтому найдем id айтемов и id юзеров, которые есть во всех датасетах и оставим только их." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:40.231703Z", + "start_time": "2021-10-28T18:40:39.718626Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:34.900180Z", + "iopub.status.busy": "2023-01-22T16:23:34.899760Z", + "iopub.status.idle": "2023-01-22T16:23:36.064882Z", + "shell.execute_reply": "2023-01-22T16:23:36.063765Z", + "shell.execute_reply.started": "2023-01-22T16:23:34.900142Z" + }, + "id": "d55848e1", + "outputId": "48609a0b-06b9-4a5e-f8f6-061db1c6dcb2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "65974\n", + "6901\n" + ] + } + ], + "source": [ + "common_users = set(interactions_df.user_id.unique()).intersection(set(users_ohe_df.user_id.unique()))\n", + "common_items = set(interactions_df.item_id.unique()).intersection(set(items_ohe_df.item_id.unique()))\n", + "\n", + "print(len(common_users))\n", + "print(len(common_items))\n", + "\n", + "interactions_df = interactions_df[interactions_df.item_id.isin(common_items)]\n", + "interactions_df = interactions_df[interactions_df.user_id.isin(common_users)]\n", + "\n", + "items_ohe_df = 
items_ohe_df[items_ohe_df.item_id.isin(common_items)]\n", + "users_ohe_df = users_ohe_df[users_ohe_df.user_id.isin(common_users)]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1e8b9480" + }, + "source": [ + "\n", + "Соберем взаимодействия в матрицу user*item так, чтобы в строках этой матрицы были user_id, в столбцах - item_id, а на пересечениях строк и столбцов - единица, если пользователь взаимодействовал с айтемом и ноль, если нет.\n", + "\n", + "Такую матрицу удобно собирать в numpy array, однако нужно помнить, что numpy array индексируется порядковыми индексами, а нам же удобнее использовать item_id и user_id.\n", + "\n", + "Создадим некие внутренние индексы для user_id и item_id - uid и iid. Для этого просто соберем все user_id и item_id и пронумеруем их по порядку." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:40:40.346587Z", + "start_time": "2021-10-28T18:40:40.233046Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:36.066990Z", + "iopub.status.busy": "2023-01-22T16:23:36.066574Z", + "iopub.status.idle": "2023-01-22T16:23:36.211726Z", + "shell.execute_reply": "2023-01-22T16:23:36.210597Z", + "shell.execute_reply.started": "2023-01-22T16:23:36.066949Z" + }, + "id": "81679fb0", + "outputId": "0c6bf7ce-1ea0-46c2-9d70-42b32bf08c7e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0, 1, 2, 3, 4]\n", + "[0, 1, 2, 3, 4]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idlast_watch_dttotal_durwatched_pctuidiid
017654995062021-05-11425072106163944
169931716592021-05-29831710042131675
610164583542021-08-1416722561024139
78840096932021-08-047031453150279
14532484372021-04-186598923103485
\n", + "
" + ], + "text/plain": [ + " user_id item_id last_watch_dt total_dur watched_pct uid iid\n", + "0 176549 9506 2021-05-11 4250 72 10616 3944\n", + "1 699317 1659 2021-05-29 8317 100 42131 675\n", + "6 1016458 354 2021-08-14 1672 25 61024 139\n", + "7 884009 693 2021-08-04 703 14 53150 279\n", + "14 5324 8437 2021-04-18 6598 92 310 3485" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions_df[\"uid\"] = interactions_df[\"user_id\"].astype(\"category\")\n", + "interactions_df[\"uid\"] = interactions_df[\"uid\"].cat.codes\n", + "\n", + "interactions_df[\"iid\"] = interactions_df[\"item_id\"].astype(\"category\")\n", + "interactions_df[\"iid\"] = interactions_df[\"iid\"].cat.codes\n", + "\n", + "print(sorted(interactions_df.iid.unique())[:5])\n", + "print(sorted(interactions_df.uid.unique())[:5])\n", + "interactions_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "61c855e5" + }, + "source": [ + "Отнормируем матрицу взаимодействий" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:36.214161Z", + "iopub.status.busy": "2023-01-22T16:23:36.213276Z", + "iopub.status.idle": "2023-01-22T16:23:36.223246Z", + "shell.execute_reply": "2023-01-22T16:23:36.222069Z", + "shell.execute_reply.started": "2023-01-22T16:23:36.214121Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0 3944\n", + "1 675\n", + "6 139\n", + "7 279\n", + "14 3485\n", + " ... 
\n", + "5476218 169\n", + "5476224 923\n", + "5476226 5610\n", + "5476239 2929\n", + "5476249 6766\n", + "Name: iid, Length: 1463641, dtype: int16" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactions_df.iid" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:03.360248Z", + "start_time": "2021-10-28T18:40:40.348057Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:36.225520Z", + "iopub.status.busy": "2023-01-22T16:23:36.224590Z", + "iopub.status.idle": "2023-01-22T16:23:43.629733Z", + "shell.execute_reply": "2023-01-22T16:23:43.628568Z", + "shell.execute_reply.started": "2023-01-22T16:23:36.225480Z" + }, + "id": "3feced70" + }, + "outputs": [], + "source": [ + "interactions_vec = np.zeros((interactions_df.uid.nunique(), \n", + " interactions_df.iid.nunique())) \n", + "\n", + "for user_id, item_id in zip(interactions_df.uid, interactions_df.iid):\n", + " interactions_vec[user_id, item_id] += 1\n", + "\n", + "\n", + "res = interactions_vec.sum(axis=1)\n", + "for i in range(len(interactions_vec)):\n", + " interactions_vec[i] /= res[i]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:03.416061Z", + "start_time": "2021-10-28T18:41:03.363462Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:43.634195Z", + "iopub.status.busy": "2023-01-22T16:23:43.631362Z", + "iopub.status.idle": "2023-01-22T16:23:43.711673Z", + "shell.execute_reply": "2023-01-22T16:23:43.710586Z", + "shell.execute_reply.started": "2023-01-22T16:23:43.634161Z" + }, + "id": "9f5ec90f", + "outputId": "9acdfe45-aa4e-4a64-ffdc-1a750390ae84" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6897\n", + "6901\n", + "65974\n", + "65974\n", + "{11805, 9788, 11501, 1734}\n" + ] + } + ], + "source": [ + 
"print(interactions_df.item_id.nunique())\n", + "print(items_ohe_df.item_id.nunique())\n", + "print(interactions_df.user_id.nunique())\n", + "print(users_ohe_df.user_id.nunique())\n", + "\n", + "print(set(items_ohe_df.item_id.unique()) - set(interactions_df.item_id.unique()))" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:23:43.713551Z", + "iopub.status.busy": "2023-01-22T16:23:43.713196Z", + "iopub.status.idle": "2023-01-22T16:23:44.238808Z", + "shell.execute_reply": "2023-01-22T16:23:44.237691Z", + "shell.execute_reply.started": "2023-01-22T16:23:43.713517Z" + } + }, + "outputs": [], + "source": [ + "items_ohe_df = items_ohe_df[~items_ohe_df.item_id.isin([11805, 9788, 11501, 1734])]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "19e69bae" + }, + "source": [ + "Для того, чтобы можно было удобно превратить iid/uid в item_id/user_id и наоборот соберем словари \n", + "\n", + "{iid: item_id}, {uid: user_id} и {item_id: iid}, {user_id: uid}." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:03.637495Z", + "start_time": "2021-10-28T18:41:03.417544Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:44.243767Z", + "iopub.status.busy": "2023-01-22T16:23:44.243422Z", + "iopub.status.idle": "2023-01-22T16:23:44.817126Z", + "shell.execute_reply": "2023-01-22T16:23:44.816088Z", + "shell.execute_reply.started": "2023-01-22T16:23:44.243739Z" + }, + "id": "c8a84024" + }, + "outputs": [], + "source": [ + "iid_to_item_id = interactions_df[[\"iid\", \"item_id\"]].drop_duplicates().set_index(\"iid\").to_dict()[\"item_id\"]\n", + "item_id_to_iid = interactions_df[[\"iid\", \"item_id\"]].drop_duplicates().set_index(\"item_id\").to_dict()[\"iid\"]\n", + "\n", + "uid_to_user_id = interactions_df[[\"uid\", \"user_id\"]].drop_duplicates().set_index(\"uid\").to_dict()[\"user_id\"]\n", + "user_id_to_uid = interactions_df[[\"uid\", \"user_id\"]].drop_duplicates().set_index(\"user_id\").to_dict()[\"uid\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "48ca5204" + }, + "source": [ + "И проиндексируем датасеты users_ohe_df и items_ohe_df по внутренним айди:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:03.744883Z", + "start_time": "2021-10-28T18:41:03.638719Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:44.819593Z", + "iopub.status.busy": "2023-01-22T16:23:44.818859Z", + "iopub.status.idle": "2023-01-22T16:23:44.930257Z", + "shell.execute_reply": "2023-01-22T16:23:44.929032Z", + "shell.execute_reply.started": "2023-01-22T16:23:44.819553Z" + }, + "id": "4c4980ac" + }, + "outputs": [], + "source": [ + "items_ohe_df[\"iid\"] = items_ohe_df[\"item_id\"].apply(lambda x: item_id_to_iid[x])\n", + "items_ohe_df = items_ohe_df.set_index(\"iid\")\n", + "\n", + "users_ohe_df[\"uid\"] = users_ohe_df[\"user_id\"].apply(lambda x: 
user_id_to_uid[x])\n", + "users_ohe_df = users_ohe_df.set_index(\"uid\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:03.749717Z", + "start_time": "2021-10-28T18:41:03.746067Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:23:44.932306Z", + "iopub.status.busy": "2023-01-22T16:23:44.931684Z", + "iopub.status.idle": "2023-01-22T16:23:44.939719Z", + "shell.execute_reply": "2023-01-22T16:23:44.938755Z", + "shell.execute_reply.started": "2023-01-22T16:23:44.932267Z" + }, + "id": "22c26d39" + }, + "outputs": [], + "source": [ + "def triplet_loss(y_true, y_pred, n_dims=128, alpha=0.4):\n", + " # будем ожидать, что на вход функции прилетит три сконкатенированных \n", + " # вектора - вектор юзера и два вектора айтема\n", + " anchor = y_pred[:, 0:n_dims]\n", + " positive = y_pred[:, n_dims:n_dims*2]\n", + " negative = y_pred[:, n_dims*2:n_dims*3]\n", + "\n", + " # считаем расстояния от вектора юзера до вектора хорошего айтема\n", + " pos_dist = K.sum(K.square(anchor - positive), axis=1)\n", + " # и до плохого\n", + " neg_dist = K.sum(K.square(anchor - negative), axis=1)\n", + "\n", + " # считаем лосс\n", + " basic_loss = pos_dist - neg_dist + alpha\n", + " loss = K.maximum(basic_loss, 0.0) # возвращаем ноль, если лосс отрицательный\n", + " \n", + " return loss\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T19:19:05.615364Z", + "start_time": "2021-10-28T19:19:05.612463Z" + }, + "id": "4de262b4" + }, + "source": [ + "Попробуйте другие лоссы, например, BPR Triplet loss" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:18:11.520568Z", + "iopub.status.busy": "2023-01-22T16:18:11.519791Z", + "iopub.status.idle": "2023-01-22T16:18:11.535194Z", + "shell.execute_reply": "2023-01-22T16:18:11.533962Z", + 
"shell.execute_reply.started": "2023-01-22T16:18:11.520528Z" + } + }, + "outputs": [], + "source": [ + "def bpr_triplet_loss(y_true, y_pred, n_dims=128):\n", + " \n", + " from keras import backend as K\n", + " \n", + " anchor = y_pred[:, 0:n_dims]\n", + " positive = y_pred[:, n_dims:n_dims*2]\n", + " negative = y_pred[:, n_dims*2:n_dims*3]\n", + "\n", + " # BPR loss\n", + " loss = 1.0 - K.sigmoid(\n", + " K.sum(anchor * positive, axis=-1, keepdims=True) -\n", + " K.sum(anchor * negative, axis=-1, keepdims=True))\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-23T11:20:03.327838Z", + "start_time": "2021-10-23T11:20:03.324389Z" + }, + "id": "85d618b6" + }, + "source": [ + "## Генератор и семплирование\n", + "\n", + "- хорошим примером будет тот айтем, который был взят из датасета взаимодействий в соответствии с распределением просмотренных айтемов для этого юзера;\n", + "- Для негативного буду рандомно брать айтем из 100 наиболее непохожих по евклидовому расстоянию на положительный айтем по вектору жанр и ключевые слова, который человек при этом не смотрел \n", + "\n", + "Т. о., если например человек посмотрел целиком триллер, то в негативный для него должно попасть что-то вроде мелодрамы, при этом ключевые слова тоже будут сильно отличаться \n", + "\n", + "\n", + "Сформируем заранее следующий словарь - для каждого айтема: список из ста наиболее непохожих айтемов. Тогда в генераторе нужно будет взять рандомное значение их ста айтемов для положительного айтема. 
Если считать это в моменте работы генератора, то получается чрезвычайно долго, а здесь обращение к словарю - O(1), и взятие рандомного значения такое же по сложности, как в простом генераторе\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T15:37:25.302855Z", + "iopub.status.busy": "2023-01-22T15:37:25.302472Z", + "iopub.status.idle": "2023-01-22T15:37:32.215552Z", + "shell.execute_reply": "2023-01-22T15:37:32.214488Z", + "shell.execute_reply.started": "2023-01-22T15:37:25.302820Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 6897/6897 [00:02<00:00, 2513.96it/s]\n" + ] + } + ], + "source": [ + "# формируем словарь\n", + "\n", + "fts = items_ohe_df[[x for x in items_ohe_df if 'genre' in x or 'keywords' in x]]\n", + "\n", + "distances = pd.DataFrame(ED(fts))\n", + "distances.columns = list(fts.index)\n", + "distances.index = fts.index\n", + "\n", + "distance_dict = {}\n", + "for i in tqdm(distances.columns):\n", + " distance_dict[i] = list(distances[i].sort_values()[-100:].index)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:39:46.221189Z", + "iopub.status.busy": "2023-01-22T12:39:46.220459Z", + "iopub.status.idle": "2023-01-22T12:39:49.254755Z", + "shell.execute_reply": "2023-01-22T12:39:49.253714Z", + "shell.execute_reply.started": "2023-01-22T12:39:46.221147Z" + } + }, + "outputs": [], + "source": [ + "iids_ = np.array(fts.index)\n", + "user_interactions = interactions_df.groupby(\"uid\")['iid'].apply(lambda x: np.array(x.unique())).to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T15:37:35.885246Z", + "iopub.status.busy": "2023-01-22T15:37:35.884025Z", + "iopub.status.idle": "2023-01-22T15:37:35.891070Z", + "shell.execute_reply": 
"2023-01-22T15:37:35.889851Z", + "shell.execute_reply.started": "2023-01-22T15:37:35.885203Z" + } + }, + "outputs": [], + "source": [ + "def get_negative_sample(pos_i, uid_i, distance_dict):\n", + " \n", + " neg_i = np.random.choice(distance_dict[pos_i])\n", + " \n", + " return neg_i" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T12:39:49.263942Z", + "iopub.status.busy": "2023-01-22T12:39:49.263310Z", + "iopub.status.idle": "2023-01-22T12:39:49.273779Z", + "shell.execute_reply": "2023-01-22T12:39:49.272870Z", + "shell.execute_reply.started": "2023-01-22T12:39:49.263906Z" + } + }, + "outputs": [], + "source": [ + "# функция для нахождения отрицательных item\n", + "\n", + "# очень долго работает \n", + "def get_negative_sample_old(pos_i, uid_i, fts, iids_, user_interactions):\n", + " \n", + " # айтемы , с которыми взаимодействовал юзер, их исключим\n", + " user_watched_items = user_interactions[uid_i]\n", + " \n", + " # векторы айтмов, которые не смотрел юзер, и по которым посчитаем евклидовы дистанции,\n", + " # чтобы найти самые непохожие на тот айтем, который юзер смотрел\n", + "\n", + " # из всего списка item вычитаем те, с которыми пользователь взаимодействовал\n", + " # список item которых пользователь не видел\n", + " inters = np.setdiff1d(iids_, user_watched_items, assume_unique=True)\n", + " \n", + " fts_ = fts.loc[inters].sample(n = 100)\n", + " \n", + " # вектор позитивного айтема \n", + " pos_item_fts = pd.DataFrame(fts.loc[pos_i, :]).T\n", + " \n", + " # считаем дистанции\n", + " dists = ED(fts_, pos_item_fts)\n", + " \n", + " # берем десять самых непохожих и непросмотренных юзером айтемов и из них случайно выбираем один \n", + " fts_['dists'] = dists\n", + " fts_ = fts_[['dists']]\n", + " neg_candidates = fts_.sort_values(by = \"dists\")[-10:].index\n", + " \n", + " neg_i = np.random.choice(neg_candidates)\n", + " \n", + " return neg_i" + ] + }, + { + "cell_type": 
"code", + "execution_count": 37, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:03.755386Z", + "start_time": "2021-10-28T18:41:03.750664Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:37:53.122995Z", + "iopub.status.busy": "2023-01-22T15:37:53.122612Z", + "iopub.status.idle": "2023-01-22T15:37:53.132222Z", + "shell.execute_reply": "2023-01-22T15:37:53.130866Z", + "shell.execute_reply.started": "2023-01-22T15:37:53.122960Z" + }, + "id": "7829878b" + }, + "outputs": [], + "source": [ + "def generator(items, users, interactions, batch_size=1024):\n", + " while True:\n", + " uid_meta = []\n", + " uid_interaction = []\n", + " pos = []\n", + " neg = []\n", + " for _ in range(batch_size):\n", + " # берем рандомный uid\n", + " uid_i = randint(0, interactions.shape[0]-1)\n", + " # id хорошего айтема\n", + " pos_i = np.random.choice(range(interactions.shape[1]), p=interactions[uid_i])\n", + " # id плохого айтема\n", + " #neg_i = np.random.choice(range(interactions.shape[1]))\n", + " #neg_i = get_negative_sample_old(pos_i, uid_i, fts, iids_, user_interactions)\n", + " neg_i = get_negative_sample(pos_i, uid_i, distance_dict)\n", + " # фичи юзера\n", + " uid_meta.append(users.iloc[uid_i])\n", + " # вектор айтемов, с которыми юзер взаимодействовал\n", + " uid_interaction.append(interactions_vec[uid_i])\n", + " # фичи хорошего айтема\n", + " pos.append(items.iloc[pos_i])\n", + " # фичи плохого айтема\n", + " neg.append(items.iloc[neg_i])\n", + " \n", + " yield [np.array(uid_meta), np.array(uid_interaction), np.array(pos), np.array(neg)], [np.array(uid_meta), np.array(uid_interaction)]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.386864Z", + "start_time": "2021-10-28T18:41:03.756363Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:37:57.807501Z", + "iopub.status.busy": "2023-01-22T15:37:57.807136Z", + "iopub.status.idle": 
"2023-01-22T15:38:48.900316Z", + "shell.execute_reply": "2023-01-22T15:38:48.899211Z", + "shell.execute_reply.started": "2023-01-22T15:37:57.807471Z" + }, + "id": "af9d3c3b", + "outputId": "1040f567-f64a-4ccb-91a8-48034694dfdc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "вектор фичей юзера: (1024, 19)\n", + "вектор взаимодействий юзера с айтемами: (1024, 6897)\n", + "вектор 'хорошего' айтема: (1024, 9196)\n", + "вектор 'плохого' айтема: (1024, 9196)\n", + "\n", + "вектор фичей юзера: (1024, 19)\n", + "вектор взаимодействий юзера с айтемами: (1024, 6897)\n" + ] + } + ], + "source": [ + "# инициализируем генератор\n", + "gen = generator(items=items_ohe_df.drop([\"item_id\"], axis=1), \n", + " users=users_ohe_df.drop([\"user_id\"], axis=1), \n", + " interactions=interactions_vec, batch_size=1024)\n", + "\n", + "ret = next(gen)\n", + "\n", + "\n", + "print(f\"вектор фичей юзера: {ret[0][0].shape}\")\n", + "print(f\"вектор взаимодействий юзера с айтемами: {ret[0][1].shape}\")\n", + "print(f\"вектор 'хорошего' айтема: {ret[0][2].shape}\")\n", + "print(f\"вектор 'плохого' айтема: {ret[0][3].shape}\")\n", + "print()\n", + "print(f\"вектор фичей юзера: {ret[1][0].shape}\")\n", + "print(f\"вектор взаимодействий юзера с айтемами: {ret[1][1].shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8bcc3e80" + }, + "source": [ + "##Генаратор, который будет использовать информацию о качестве взаимодействия юзеров с айтемами для более репрезентативного сэмплирования\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.493030Z", + "start_time": "2021-10-28T18:41:16.388592Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:48.903047Z", + "iopub.status.busy": "2023-01-22T15:38:48.902586Z", + "iopub.status.idle": "2023-01-22T15:38:49.025937Z", + "shell.execute_reply": "2023-01-22T15:38:49.024831Z", + 
"shell.execute_reply.started": "2023-01-22T15:38:48.902992Z" + }, + "id": "967b819f", + "outputId": "2f7a5885-dcb3-4ab8-80f8-57a21635595d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "N_FACTORS: 128\n", + "ITEM_MODEL_SHAPE: (9196,)\n", + "USER_META_MODEL_SHAPE: (19,)\n", + "USER_INTERACTION_MODEL_SHAPE: (6897,)\n" + ] + } + ], + "source": [ + "N_FACTORS = 128\n", + "\n", + "# в датасетах есть столбец user_id/item_id, помним, что он не является фичей для обучения!\n", + "ITEM_MODEL_SHAPE = (items_ohe_df.drop([\"item_id\"], axis=1).shape[1], ) \n", + "USER_META_MODEL_SHAPE = (users_ohe_df.drop([\"user_id\"], axis=1).shape[1], )\n", + "\n", + "USER_INTERACTION_MODEL_SHAPE = (interactions_vec.shape[1], )\n", + "\n", + "print(f\"N_FACTORS: {N_FACTORS}\")\n", + "print(f\"ITEM_MODEL_SHAPE: {ITEM_MODEL_SHAPE}\")\n", + "print(f\"USER_META_MODEL_SHAPE: {USER_META_MODEL_SHAPE}\")\n", + "print(f\"USER_INTERACTION_MODEL_SHAPE: {USER_INTERACTION_MODEL_SHAPE}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.816499Z", + "start_time": "2021-10-28T18:41:16.494387Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:49.027755Z", + "iopub.status.busy": "2023-01-22T15:38:49.027467Z", + "iopub.status.idle": "2023-01-22T15:38:53.151538Z", + "shell.execute_reply": "2023-01-22T15:38:53.150595Z", + "shell.execute_reply.started": "2023-01-22T15:38:49.027729Z" + }, + "id": "de649a01" + }, + "outputs": [], + "source": [ + "def item_model(n_factors=N_FACTORS):\n", + " # входной слой\n", + " inp = keras.layers.Input(shape=ITEM_MODEL_SHAPE)\n", + " \n", + " # полносвязный слой\n", + " layer_1 = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n", + " kernel_regularizer=keras.regularizers.l2(1e-6),\n", + " activity_regularizer=keras.regularizers.l2(l2=1e-6))(inp)\n", + "\n", + " # делаем residual connection - складываем два слоя, \n", + " # 
чтобы градиенты не затухали во время обучения\n", + " layer_2 = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n", + " kernel_regularizer=keras.regularizers.l2(1e-6),\n", + " activity_regularizer=keras.regularizers.l2(l2=1e-6))(layer_1)\n", + " \n", + " add = keras.layers.Add()([layer_1, layer_2])\n", + " \n", + " # выходной слой\n", + " out = keras.layers.Dense(N_FACTORS, activation='linear', use_bias=False,\n", + " kernel_regularizer=keras.regularizers.l2(1e-6),\n", + " activity_regularizer=keras.regularizers.l2(l2=1e-6))(add)\n", + " \n", + " return keras.models.Model(inp, out)\n", + "\n", + "\n", + "def user_model(n_factors=N_FACTORS):\n", + " # входной слой для вектора фичей юзера (из users_ohe_df)\n", + " inp_meta = keras.layers.Input(shape=USER_META_MODEL_SHAPE)\n", + " # входной слой для вектора просмотров (из interactions_vec)\n", + " inp_interaction = keras.layers.Input(shape=USER_INTERACTION_MODEL_SHAPE)\n", + "\n", + " # полносвязный слой\n", + " layer_1_meta = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n", + " kernel_regularizer=keras.regularizers.l2(1e-6),\n", + " activity_regularizer=keras.regularizers.l2(l2=1e-6))(inp_meta)\n", + "\n", + " layer_1_interaction = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n", + " kernel_regularizer=keras.regularizers.l2(1e-6),\n", + " activity_regularizer=keras.regularizers.l2(l2=1e-6))(inp_interaction)\n", + "\n", + " # делаем residual connection - складываем два слоя,\n", + " # чтобы градиенты не затухали во время обучения\n", + " layer_2_meta = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n", + " kernel_regularizer=keras.regularizers.l2(1e-6),\n", + " activity_regularizer=keras.regularizers.l2(l2=1e-6))(layer_1_meta)\n", + " \n", + "\n", + " add = keras.layers.Add()([layer_1_meta, layer_2_meta])\n", + " \n", + " # конкатенируем вектор фичей с вектором просмотров\n", + " concat_meta_interaction = keras.layers.Concatenate()([add, 
layer_1_interaction])\n", + " \n", + " # выходной слой\n", + " out = keras.layers.Dense(N_FACTORS, activation='linear', use_bias=False,\n", + " kernel_regularizer=keras.regularizers.l2(1e-6),\n", + " activity_regularizer=keras.regularizers.l2(l2=1e-6))(concat_meta_interaction)\n", + " \n", + " return keras.models.Model([inp_meta, inp_interaction], out)\n", + "\n", + "# инициализируем модели юзера и айтема\n", + "i2v = item_model()\n", + "u2v = user_model()\n", + "\n", + "# вход для вектора фичей юзера (из users_ohe_df)\n", + "ancor_meta_in = keras.layers.Input(shape=USER_META_MODEL_SHAPE)\n", + "# вход для вектора просмотра юзера (из interactions_vec)\n", + "ancor_interaction_in = keras.layers.Input(shape=USER_INTERACTION_MODEL_SHAPE)\n", + "\n", + "# вход для вектора \"хорошего\" айтема\n", + "pos_in = keras.layers.Input(shape=ITEM_MODEL_SHAPE)\n", + "# вход для вектора \"плохого\" айтема\n", + "neg_in = keras.layers.Input(shape=ITEM_MODEL_SHAPE)\n", + "\n", + "# получаем вектор юзера\n", + "ancor = u2v([ancor_meta_in, ancor_interaction_in])\n", + "# получаем вектор \"хорошего\" айтема\n", + "pos = i2v(pos_in)\n", + "# получаем вектор \"плохого\" айтема\n", + "neg = i2v(neg_in)\n", + "\n", + "# конкатенируем полученные векторы\n", + "res = keras.layers.Concatenate(name=\"concat_ancor_pos_neg\")([ancor, pos, neg])\n", + "\n", + "# собираем модель\n", + "model = keras.models.Model([ancor_meta_in, ancor_interaction_in, pos_in, neg_in], res)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.822662Z", + "start_time": "2021-10-28T18:41:16.817857Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:53.154784Z", + "iopub.status.busy": "2023-01-22T15:38:53.154419Z", + "iopub.status.idle": "2023-01-22T15:38:53.789679Z", + "shell.execute_reply": "2023-01-22T15:38:53.788675Z", + "shell.execute_reply.started": "2023-01-22T15:38:53.154748Z" + }, + "id": "e912d920" + }, + "outputs": 
[], + "source": [ + "model_name = 'recsys_resnet_linear'\n", + "\n", + "# логируем процесс обучения в тензорборд\n", + "t_board = keras.callbacks.TensorBoard(log_dir=f'runs/{model_name}')\n", + "\n", + "# уменьшаем learning_rate, если лосс долго не уменьшается (в течение двух эпох)\n", + "decay = keras.callbacks.ReduceLROnPlateau(monitor='loss', patience=2, factor=0.8, verbose=1)\n", + "\n", + "# сохраняем модель после каждой эпохи (save_best_only не задан, поэтому чекпоинт пишется независимо от того, уменьшился ли лосс)\n", + "check = keras.callbacks.ModelCheckpoint(filepath=model_name + '/epoch{epoch}-{loss:.2f}.h5', monitor=\"loss\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.832365Z", + "start_time": "2021-10-28T18:41:16.824484Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:53.792105Z", + "iopub.status.busy": "2023-01-22T15:38:53.791371Z", + "iopub.status.idle": "2023-01-22T15:38:53.808624Z", + "shell.execute_reply": "2023-01-22T15:38:53.807732Z", + "shell.execute_reply.started": "2023-01-22T15:38:53.792041Z" + }, + "id": "f95049f6" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.\n", + "WARNING:absl:`lr` is deprecated in Keras optimizer, please use `learning_rate` or use the legacy optimizer, e.g.,tf.keras.optimizers.legacy.Adam.\n" + ] + } + ], + "source": [ + "# компилируем модель, используем оптимайзер Adam и triplet loss\n", + "opt = keras.optimizers.Adam(lr=0.001)\n", + "model.compile(loss=triplet_loss, optimizer=opt)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.867472Z", + "start_time": "2021-10-28T18:41:16.833753Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:53.811821Z", + 
"iopub.status.busy": "2023-01-22T15:38:53.811155Z", + "iopub.status.idle": "2023-01-22T15:38:53.852098Z", + "shell.execute_reply": "2023-01-22T15:38:53.851090Z", + "shell.execute_reply.started": "2023-01-22T15:38:53.811786Z" + }, + "id": "fb9382d0", + "outputId": "2eca9a17-1544-4e27-a483-b86d11391767" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_3\"\n", + "__________________________________________________________________________________________________\n", + " Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + " input_8 (InputLayer) [(None, 9196)] 0 [] \n", + " \n", + " dense_7 (Dense) (None, 128) 1177088 ['input_8[0][0]'] \n", + " \n", + " dense_8 (Dense) (None, 128) 16384 ['dense_7[0][0]'] \n", + " \n", + " add_2 (Add) (None, 128) 0 ['dense_7[0][0]', \n", + " 'dense_8[0][0]'] \n", + " \n", + " dense_9 (Dense) (None, 128) 16384 ['add_2[0][0]'] \n", + " \n", + "==================================================================================================\n", + "Total params: 1209856 (4.62 MB)\n", + "Trainable params: 1209856 (4.62 MB)\n", + "Non-trainable params: 0 (0.00 Byte)\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "# модель айтема\n", + "item_model().summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.923402Z", + "start_time": "2021-10-28T18:41:16.868877Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:53.854198Z", + "iopub.status.busy": "2023-01-22T15:38:53.853594Z", + "iopub.status.idle": "2023-01-22T15:38:53.908177Z", + "shell.execute_reply": "2023-01-22T15:38:53.907222Z", + "shell.execute_reply.started": "2023-01-22T15:38:53.854161Z" + }, + "id": "286149d1", + "outputId": 
"4284ba09-05ef-4963-c637-67e919701d19" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_4\"\n", + "__________________________________________________________________________________________________\n", + " Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + " input_9 (InputLayer) [(None, 19)] 0 [] \n", + " \n", + " dense_10 (Dense) (None, 128) 2432 ['input_9[0][0]'] \n", + " \n", + " dense_12 (Dense) (None, 128) 16384 ['dense_10[0][0]'] \n", + " \n", + " input_10 (InputLayer) [(None, 6897)] 0 [] \n", + " \n", + " add_3 (Add) (None, 128) 0 ['dense_10[0][0]', \n", + " 'dense_12[0][0]'] \n", + " \n", + " dense_11 (Dense) (None, 128) 882816 ['input_10[0][0]'] \n", + " \n", + " concatenate_1 (Concatenate (None, 256) 0 ['add_3[0][0]', \n", + " ) 'dense_11[0][0]'] \n", + " \n", + " dense_13 (Dense) (None, 128) 32768 ['concatenate_1[0][0]'] \n", + " \n", + "==================================================================================================\n", + "Total params: 934400 (3.56 MB)\n", + "Trainable params: 934400 (3.56 MB)\n", + "Non-trainable params: 0 (0.00 Byte)\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "# модель юзера\n", + "user_model().summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T18:41:16.929341Z", + "start_time": "2021-10-28T18:41:16.924663Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:53.909970Z", + "iopub.status.busy": "2023-01-22T15:38:53.909370Z", + "iopub.status.idle": "2023-01-22T15:38:53.917202Z", + "shell.execute_reply": "2023-01-22T15:38:53.916103Z", + "shell.execute_reply.started": "2023-01-22T15:38:53.909934Z" + }, + "id": "d9f25a3f", + "outputId": "6f9a3700-4420-4345-8331-82f7207b566b" 
+ }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_2\"\n", + "__________________________________________________________________________________________________\n", + " Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + " input_4 (InputLayer) [(None, 19)] 0 [] \n", + " \n", + " input_5 (InputLayer) [(None, 6897)] 0 [] \n", + " \n", + " input_6 (InputLayer) [(None, 9196)] 0 [] \n", + " \n", + " input_7 (InputLayer) [(None, 9196)] 0 [] \n", + " \n", + " model_1 (Functional) (None, 128) 934400 ['input_4[0][0]', \n", + " 'input_5[0][0]'] \n", + " \n", + " model (Functional) (None, 128) 1209856 ['input_6[0][0]', \n", + " 'input_7[0][0]'] \n", + " \n", + " concat_ancor_pos_neg (Conc (None, 384) 0 ['model_1[0][0]', \n", + " atenate) 'model[0][0]', \n", + " 'model[1][0]'] \n", + " \n", + "==================================================================================================\n", + "Total params: 2144256 (8.18 MB)\n", + "Trainable params: 2144256 (8.18 MB)\n", + "Non-trainable params: 0 (0.00 Byte)\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "# общая модель\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T19:15:21.657529Z", + "start_time": "2021-10-28T19:15:16.365923Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T15:38:53.919463Z", + "iopub.status.busy": "2023-01-22T15:38:53.918611Z", + "iopub.status.idle": "2023-01-22T16:17:01.448835Z", + "shell.execute_reply": "2023-01-22T16:17:01.447888Z", + "shell.execute_reply.started": "2023-01-22T15:38:53.919424Z" + }, + "id": "99d50830", + "outputId": "cee25813-2173-460f-e6f2-024d75d1db08" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "Epoch 1/30\n", + "100/100 [==============================] - 42s 418ms/step - loss: 0.4123 - lr: 0.0010\n", + "Epoch 2/30\n", + "100/100 [==============================] - 42s 425ms/step - loss: 0.2973 - lr: 0.0010\n", + "Epoch 3/30\n", + "100/100 [==============================] - 42s 423ms/step - loss: 0.2713 - lr: 0.0010\n", + "Epoch 4/30\n", + "100/100 [==============================] - 642s 6s/step - loss: 0.2240 - lr: 0.0010\n", + "Epoch 5/30\n", + "100/100 [==============================] - 41s 413ms/step - loss: 0.2200 - lr: 0.0010\n", + "Epoch 6/30\n", + "100/100 [==============================] - 41s 415ms/step - loss: 0.1929 - lr: 0.0010\n", + "Epoch 7/30\n", + "100/100 [==============================] - 42s 427ms/step - loss: 0.1727 - lr: 0.0010\n", + "Epoch 8/30\n", + "100/100 [==============================] - 42s 423ms/step - loss: 0.1849 - lr: 0.0010\n", + "Epoch 9/30\n", + "100/100 [==============================] - 41s 418ms/step - loss: 0.1594 - lr: 0.0010\n", + "Epoch 10/30\n", + "100/100 [==============================] - 42s 420ms/step - loss: 0.1485 - lr: 0.0010\n", + "Epoch 11/30\n", + "100/100 [==============================] - 41s 417ms/step - loss: 0.1523 - lr: 0.0010\n", + "Epoch 12/30\n", + "100/100 [==============================] - 41s 415ms/step - loss: 0.1328 - lr: 0.0010\n", + "Epoch 13/30\n", + "100/100 [==============================] - 41s 419ms/step - loss: 0.1407 - lr: 0.0010\n", + "Epoch 14/30\n", + "100/100 [==============================] - ETA: 0s - loss: 0.1524\n", + "Epoch 14: ReduceLROnPlateau reducing learning rate to 0.000800000037997961.\n", + "100/100 [==============================] - 42s 421ms/step - loss: 0.1524 - lr: 0.0010\n", + "Epoch 15/30\n", + "100/100 [==============================] - 42s 420ms/step - loss: 0.1304 - lr: 8.0000e-04\n", + "Epoch 16/30\n", + "100/100 [==============================] - 42s 421ms/step - loss: 0.1305 - lr: 8.0000e-04\n", + "Epoch 17/30\n", + "100/100 
[==============================] - 41s 415ms/step - loss: 0.1299 - lr: 8.0000e-04\n", + "Epoch 18/30\n", + "100/100 [==============================] - 41s 416ms/step - loss: 0.1332 - lr: 8.0000e-04\n", + "Epoch 19/30\n", + "100/100 [==============================] - 578s 6s/step - loss: 0.1128 - lr: 8.0000e-04\n", + "Epoch 20/30\n", + "100/100 [==============================] - 298s 3s/step - loss: 0.1168 - lr: 8.0000e-04\n", + "Epoch 21/30\n", + "100/100 [==============================] - ETA: 0s - loss: 0.1145\n", + "Epoch 21: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.\n", + "100/100 [==============================] - 42s 424ms/step - loss: 0.1145 - lr: 8.0000e-04\n", + "Epoch 22/30\n", + "100/100 [==============================] - 42s 423ms/step - loss: 0.1230 - lr: 6.4000e-04\n", + "Epoch 23/30\n", + "100/100 [==============================] - 42s 421ms/step - loss: 0.1108 - lr: 6.4000e-04\n", + "Epoch 24/30\n", + "100/100 [==============================] - 42s 422ms/step - loss: 0.0976 - lr: 6.4000e-04\n", + "Epoch 25/30\n", + "100/100 [==============================] - 42s 422ms/step - loss: 0.1116 - lr: 6.4000e-04\n", + "Epoch 26/30\n", + "100/100 [==============================] - ETA: 0s - loss: 0.0996\n", + "Epoch 26: ReduceLROnPlateau reducing learning rate to 0.0005120000336319208.\n", + "100/100 [==============================] - 43s 430ms/step - loss: 0.0996 - lr: 6.4000e-04\n", + "Epoch 27/30\n", + "100/100 [==============================] - 43s 433ms/step - loss: 0.1010 - lr: 5.1200e-04\n", + "Epoch 28/30\n", + "100/100 [==============================] - ETA: 0s - loss: 0.0984\n", + "Epoch 28: ReduceLROnPlateau reducing learning rate to 0.00040960004553198815.\n", + "100/100 [==============================] - 42s 429ms/step - loss: 0.0984 - lr: 5.1200e-04\n", + "Epoch 29/30\n", + "100/100 [==============================] - 43s 430ms/step - loss: 0.0865 - lr: 4.0960e-04\n", + "Epoch 30/30\n", + "100/100 
[==============================] - 43s 433ms/step - loss: 0.1049 - lr: 4.0960e-04\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# начинаем обучение, не забывая дропнуть столбцы item_id и user_id \n", + "# из датафреймов при инициализации генератора.\n", + "\n", + "# batch_size можно (и лучше) поставить побольше, если вы не ограничены в ресурсах\n", + "\n", + "model.fit(generator(items=items_ohe_df.drop([\"item_id\"], axis=1), \n", + " users=users_ohe_df.drop([\"user_id\"], axis=1), \n", + " interactions=interactions_vec,\n", + " batch_size=16), \n", + " steps_per_epoch=100, \n", + " epochs=30, \n", + " initial_epoch=0,\n", + " callbacks=[decay, t_board, check]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:17:01.453483Z", + "iopub.status.busy": "2023-01-22T16:17:01.453198Z", + "iopub.status.idle": "2023-01-22T16:17:01.486783Z", + "shell.execute_reply": "2023-01-22T16:17:01.485812Z", + "shell.execute_reply.started": "2023-01-22T16:17:01.453458Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built.
`model.compile_metrics` will be empty until you train or evaluate the model.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" + ] + } + ], + "source": [ + "i2v.save('i2v.hdf5')\n", + "u2v.save('u2v.hdf5')" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T19:15:26.511958Z", + "start_time": "2021-10-28T19:15:26.151899Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:24:28.685695Z", + "iopub.status.busy": "2023-01-22T16:24:28.685290Z", + "iopub.status.idle": "2023-01-22T16:24:30.854186Z", + "shell.execute_reply": "2023-01-22T16:24:30.853120Z", + "shell.execute_reply.started": "2023-01-22T16:24:28.685657Z" + }, + "id": "94d23f62", + "outputId": "4a500ea7-fa38-4455-a51a-2bd0113fa2f2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/1 [==============================] - 0s 129ms/step\n", + "1/1 [==============================] - 0s 29ms/step\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[0.76927984]], dtype=float32)" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# берем рандомного юзера\n", + "rand_uid = np.random.choice(list(users_ohe_df.index))\n", + "\n", + "# получаем фичи юзера и вектор его просмотров айтемов\n", + "user_meta_feats = users_ohe_df.drop([\"user_id\"], axis=1).iloc[rand_uid]\n", + "user_interaction_vec = interactions_vec[rand_uid]\n", + "\n", + "# берем рандомный айтем\n", + "rand_iid = np.random.choice(list(items_ohe_df.index))\n", + "# получаем фичи айтема\n", + "item_feats = items_ohe_df.drop([\"item_id\"], axis=1).iloc[rand_iid]\n", + "\n", + "# получаем вектор юзера\n", + "user_vec = u2v.predict([np.array(user_meta_feats).reshape(1, -1), \n", + " 
np.array(user_interaction_vec).reshape(1, -1)])\n", + "\n", + "# и вектор айтема\n", + "item_vec = i2v.predict(np.array(item_feats).reshape(1, -1))\n", + "\n", + "# считаем расстояние между вектором юзера и вектором айтема\n", + "from sklearn.metrics.pairwise import euclidean_distances as ED\n", + "\n", + "ED(user_vec, item_vec)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-28T19:15:28.951471Z", + "start_time": "2021-10-28T19:15:27.763367Z" + }, + "execution": { + "iopub.execute_input": "2023-01-22T16:24:35.398767Z", + "iopub.status.busy": "2023-01-22T16:24:35.398342Z", + "iopub.status.idle": "2023-01-22T16:24:37.179114Z", + "shell.execute_reply": "2023-01-22T16:24:37.177336Z", + "shell.execute_reply.started": "2023-01-22T16:24:35.398731Z" + }, + "id": "d537d3e8", + "outputId": "6bdce370-c348-4f0b-cd4f-3cbb0ec3d019" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "216/216 [==============================] - 0s 883us/step\n" + ] + } + ], + "source": [ + "# получаем фичи всех айтемов\n", + "items_feats = items_ohe_df.drop([\"item_id\"], axis=1).to_numpy()\n", + "# получаем векторы всех айтемов\n", + "items_vecs = i2v.predict(items_feats)\n", + "\n", + "# считаем расстояния\n", + "dists = ED(user_vec, items_vecs)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:24:37.200481Z", + "iopub.status.busy": "2023-01-22T16:24:37.199790Z", + "iopub.status.idle": "2023-01-22T16:24:37.219685Z", + "shell.execute_reply": "2023-01-22T16:24:37.218365Z", + "shell.execute_reply.started": "2023-01-22T16:24:37.200421Z" + }, + "id": "udY36b_l0okL", + "outputId": "53287102-2434-490e-d668-b8085515d4b8" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(6897, 128)" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_vecs.shape" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:24:38.051890Z", + "iopub.status.busy": "2023-01-22T16:24:38.051199Z", + "iopub.status.idle": "2023-01-22T16:24:38.063043Z", + "shell.execute_reply": "2023-01-22T16:24:38.061416Z", + "shell.execute_reply.started": "2023-01-22T16:24:38.051840Z" + }, + "id": "XasFl6RN0snT" + }, + "outputs": [], + "source": [ + "users_meta_feats = users_ohe_df.drop([\"user_id\"], axis=1)\n", + "users_interaction_vec = interactions_vec" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:24:38.561257Z", + "iopub.status.busy": "2023-01-22T16:24:38.560146Z", + "iopub.status.idle": "2023-01-22T16:24:38.568144Z", + "shell.execute_reply": "2023-01-22T16:24:38.566777Z", + "shell.execute_reply.started": "2023-01-22T16:24:38.561176Z" + }, + "id": "cntEZU450_MI", + "outputId": "c9aace32-281a-4b0e-8e0d-8b0f1088ce9b" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 19)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users_meta_feats.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:24:40.475433Z", + "iopub.status.busy": "2023-01-22T16:24:40.472691Z", + "iopub.status.idle": "2023-01-22T16:24:40.484559Z", + "shell.execute_reply": "2023-01-22T16:24:40.483559Z", + "shell.execute_reply.started": "2023-01-22T16:24:40.475392Z" + }, + "id": "kQ1EZolS1B1Y", + "outputId": "ca9dc5eb-5519-4c75-f1d8-4945941a46d1" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 6897)" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users_interaction_vec.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "execution": { + 
"iopub.execute_input": "2023-01-22T16:24:40.786186Z", + "iopub.status.busy": "2023-01-22T16:24:40.785775Z", + "iopub.status.idle": "2023-01-22T16:24:40.797826Z", + "shell.execute_reply": "2023-01-22T16:24:40.796580Z", + "shell.execute_reply.started": "2023-01-22T16:24:40.786151Z" + }, + "id": "hKU4MD7M1dp5", + "outputId": "bd79e8e2-8a82-4ff7-d1ac-849e3425c2c8" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 19)" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.array(users_meta_feats).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:24:41.775332Z", + "iopub.status.busy": "2023-01-22T16:24:41.774265Z", + "iopub.status.idle": "2023-01-22T16:24:41.780836Z", + "shell.execute_reply": "2023-01-22T16:24:41.779665Z", + "shell.execute_reply.started": "2023-01-22T16:24:41.775281Z" + } + }, + "outputs": [], + "source": [ + "del interactions_vec\n", + "del users_df, interactions_df" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:24:52.966474Z", + "iopub.status.busy": "2023-01-22T16:24:52.965002Z", + "iopub.status.idle": "2023-01-22T16:24:57.402446Z", + "shell.execute_reply": "2023-01-22T16:24:57.401009Z", + "shell.execute_reply.started": "2023-01-22T16:24:52.966417Z" + }, + "id": "x16g5FM21XGJ", + "outputId": "9b43e3e6-f98b-466a-8b0d-c2aeb0ca724e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "625/625 [==============================] - 1s 809us/step\n", + "625/625 [==============================] - 0s 779us/step\n", + "812/812 [==============================] - 1s 763us/step\n" + ] + } + ], + "source": [ + "users_vec_1 = u2v.predict([np.array(users_meta_feats.iloc[:20000]), \n", + " np.array(users_interaction_vec[:20000])])\n", + "users_vec_2 = 
u2v.predict([np.array(users_meta_feats.iloc[20000:40000]), \n", + " np.array(users_interaction_vec[20000:40000])])\n", + "users_vec_3 = u2v.predict([np.array(users_meta_feats.iloc[40000:]), \n", + " np.array(users_interaction_vec[40000:])])\n", + "users_vec = np.concatenate((users_vec_1, users_vec_2, users_vec_3))" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:25:54.668894Z", + "iopub.status.busy": "2023-01-22T16:25:54.667629Z", + "iopub.status.idle": "2023-01-22T16:25:54.674606Z", + "shell.execute_reply": "2023-01-22T16:25:54.673447Z", + "shell.execute_reply.started": "2023-01-22T16:25:54.668856Z" + } + }, + "outputs": [], + "source": [ + "del users_vec_1, users_vec_2, users_vec_3, users_interaction_vec" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:25:57.451831Z", + "iopub.status.busy": "2023-01-22T16:25:57.451189Z", + "iopub.status.idle": "2023-01-22T16:25:57.458982Z", + "shell.execute_reply": "2023-01-22T16:25:57.457745Z", + "shell.execute_reply.started": "2023-01-22T16:25:57.451795Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 128)" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users_vec.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:20:21.497209Z", + "iopub.status.busy": "2023-01-22T16:20:21.496765Z", + "iopub.status.idle": "2023-01-22T16:20:21.504388Z", + "shell.execute_reply": "2023-01-22T16:20:21.503250Z", + "shell.execute_reply.started": "2023-01-22T16:20:21.497158Z" + }, + "id": "G4pntPu10ogl", + "outputId": "557b6f56-dff5-46e9-da97-569001c59c79" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(6897, 128)" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" 
+ } + ], + "source": [ + "items_vecs.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:26:12.212846Z", + "iopub.status.busy": "2023-01-22T16:26:12.212440Z", + "iopub.status.idle": "2023-01-22T16:26:19.980077Z", + "shell.execute_reply": "2023-01-22T16:26:19.978704Z", + "shell.execute_reply.started": "2023-01-22T16:26:12.212812Z" + }, + "id": "hnUX3Yte2Jcw" + }, + "outputs": [], + "source": [ + "dists = ED(users_vec, items_vecs)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:26:33.221953Z", + "iopub.status.busy": "2023-01-22T16:26:33.220783Z", + "iopub.status.idle": "2023-01-22T16:26:33.231255Z", + "shell.execute_reply": "2023-01-22T16:26:33.229877Z", + "shell.execute_reply.started": "2023-01-22T16:26:33.221902Z" + }, + "id": "MDgiwnnu2KHk", + "outputId": "ae9eeb6f-8a29-4195-8bdf-eeaa51b6d049" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 6897)" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dists.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:26:36.257959Z", + "iopub.status.busy": "2023-01-22T16:26:36.257347Z", + "iopub.status.idle": "2023-01-22T16:26:45.531254Z", + "shell.execute_reply": "2023-01-22T16:26:45.530120Z", + "shell.execute_reply.started": "2023-01-22T16:26:36.257910Z" + }, + "id": "Ru8IQwSV2UrB" + }, + "outputs": [], + "source": [ + "top10_iids_1 = np.argsort(dists[:20000], axis=1)[:,:10]\n", + "top10_iids_2 = np.argsort(dists[20000:40000], axis=1)[:,:10]\n", + "top10_iids_3 = np.argsort(dists[40000:], axis=1)[:,:10]\n", + "top10_iids = np.concatenate((top10_iids_1, top10_iids_2, top10_iids_3))" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": { + "execution": { + 
"iopub.execute_input": "2023-01-22T16:28:25.544273Z", + "iopub.status.busy": "2023-01-22T16:28:25.543272Z", + "iopub.status.idle": "2023-01-22T16:28:25.551809Z", + "shell.execute_reply": "2023-01-22T16:28:25.550511Z", + "shell.execute_reply.started": "2023-01-22T16:28:25.544233Z" + }, + "id": "pAzg23jU3TSo", + "outputId": "baeb6ea9-ca6f-4bb9-da7b-3fc16669db23" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 10)" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top10_iids.reshape(dists.shape[0], 10).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:28:37.182827Z", + "iopub.status.busy": "2023-01-22T16:28:37.182088Z", + "iopub.status.idle": "2023-01-22T16:28:37.190183Z", + "shell.execute_reply": "2023-01-22T16:28:37.188831Z", + "shell.execute_reply.started": "2023-01-22T16:28:37.182788Z" + }, + "id": "ehH1-C-S6yE9", + "outputId": "5a08578e-7bc1-404a-db00-5d2114b7ad28" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 10)" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top10_iids.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:28:47.517537Z", + "iopub.status.busy": "2023-01-22T16:28:47.516704Z", + "iopub.status.idle": "2023-01-22T16:28:47.800629Z", + "shell.execute_reply": "2023-01-22T16:28:47.799272Z", + "shell.execute_reply.started": "2023-01-22T16:28:47.517501Z" + }, + "id": "srptkYsFsk1V" + }, + "outputs": [], + "source": [ + "top10_iids_item = [iid_to_item_id[iid] for iid in top10_iids.reshape(-1)]" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:28:51.654826Z", + "iopub.status.busy": "2023-01-22T16:28:51.653959Z", + "iopub.status.idle": 
"2023-01-22T16:28:51.700602Z", + "shell.execute_reply": "2023-01-22T16:28:51.699239Z", + "shell.execute_reply.started": "2023-01-22T16:28:51.654791Z" + }, + "id": "GWCz9zErskwn" + }, + "outputs": [], + "source": [ + "top10_iids_item = np.array(top10_iids_item).reshape(top10_iids.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:28:57.535876Z", + "iopub.status.busy": "2023-01-22T16:28:57.535194Z", + "iopub.status.idle": "2023-01-22T16:28:57.543046Z", + "shell.execute_reply": "2023-01-22T16:28:57.541704Z", + "shell.execute_reply.started": "2023-01-22T16:28:57.535836Z" + }, + "id": "pNq_brUisknx", + "outputId": "f1980332-ef9e-4920-b470-675a567c815a" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(65974, 10)" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top10_iids_item.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:29:00.647959Z", + "iopub.status.busy": "2023-01-22T16:29:00.646906Z", + "iopub.status.idle": "2023-01-22T16:29:00.657077Z", + "shell.execute_reply": "2023-01-22T16:29:00.655386Z", + "shell.execute_reply.started": "2023-01-22T16:29:00.647919Z" + }, + "id": "z6ussvRSth2h" + }, + "outputs": [], + "source": [ + "df_dssm = pd.DataFrame(columns = ['user_id', 'item_id'])" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:29:08.741226Z", + "iopub.status.busy": "2023-01-22T16:29:08.740780Z", + "iopub.status.idle": "2023-01-22T16:29:08.751118Z", + "shell.execute_reply": "2023-01-22T16:29:08.750073Z", + "shell.execute_reply.started": "2023-01-22T16:29:08.741183Z" + }, + "id": "Y9XvpPzRu82h", + "outputId": "8397c843-1427-444e-af8d-f5c5cd19a80d" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_id
\n", + "
" + ], + "text/plain": [ + "Empty DataFrame\n", + "Columns: [user_id, item_id]\n", + "Index: []" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_dssm.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:29:16.894986Z", + "iopub.status.busy": "2023-01-22T16:29:16.894575Z", + "iopub.status.idle": "2023-01-22T16:29:16.955651Z", + "shell.execute_reply": "2023-01-22T16:29:16.954612Z", + "shell.execute_reply.started": "2023-01-22T16:29:16.894935Z" + }, + "id": "KieINSdwvIu7" + }, + "outputs": [], + "source": [ + "df_dssm = pd.DataFrame({'user_id': [uid_to_user_id[uid] for uid in np.arange(top10_iids_item.shape[0])]})" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:29:18.345819Z", + "iopub.status.busy": "2023-01-22T16:29:18.345404Z", + "iopub.status.idle": "2023-01-22T16:29:18.371714Z", + "shell.execute_reply": "2023-01-22T16:29:18.370527Z", + "shell.execute_reply.started": "2023-01-22T16:29:18.345785Z" + }, + "id": "RSYHUj7IuzT1" + }, + "outputs": [], + "source": [ + "df_dssm['item_id'] = list(top10_iids_item)" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:29:19.355843Z", + "iopub.status.busy": "2023-01-22T16:29:19.355038Z", + "iopub.status.idle": "2023-01-22T16:29:20.100815Z", + "shell.execute_reply": "2023-01-22T16:29:20.099612Z", + "shell.execute_reply.started": "2023-01-22T16:29:19.355801Z" + }, + "id": "xdPs4HY874OZ" + }, + "outputs": [], + "source": [ + "df_dssm = df_dssm.explode('item_id')\n", + "df_dssm['rank'] = df_dssm.groupby('user_id').cumcount() + 1\n", + "df_dssm = df_dssm.groupby('user_id').agg({'item_id': list}).reset_index()" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": { + "execution": { + 
"iopub.execute_input": "2023-01-22T16:29:20.104964Z", + "iopub.status.busy": "2023-01-22T16:29:20.104605Z", + "iopub.status.idle": "2023-01-22T16:29:20.117444Z", + "shell.execute_reply": "2023-01-22T16:29:20.115641Z", + "shell.execute_reply.started": "2023-01-22T16:29:20.104915Z" + }, + "id": "C8kdzzf6wAuz", + "outputId": "df930923-8ad5-45f8-f76c-ab847237802d" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_id
02[4457, 4151, 142, 9988, 4475, 4740, 9169, 5982...
121[4457, 3734, 9988, 4740, 2954, 2657, 4151, 152...
253[4457, 2220, 4151, 142, 4740, 15297, 2657, 134...
360[4457, 4151, 142, 9988, 3734, 6443, 4740, 2954...
481[4151, 4740, 2657, 4457, 15297, 281, 142, 9169...
\n", + "
" + ], + "text/plain": [ + " user_id item_id\n", + "0 2 [4457, 4151, 142, 9988, 4475, 4740, 9169, 5982...\n", + "1 21 [4457, 3734, 9988, 4740, 2954, 2657, 4151, 152...\n", + "2 53 [4457, 2220, 4151, 142, 4740, 15297, 2657, 134...\n", + "3 60 [4457, 4151, 142, 9988, 3734, 6443, 4740, 2954...\n", + "4 81 [4151, 4740, 2657, 4457, 15297, 281, 142, 9169..." + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_dssm.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": { + "execution": { + "iopub.execute_input": "2023-01-22T16:29:31.350415Z", + "iopub.status.busy": "2023-01-22T16:29:31.349997Z", + "iopub.status.idle": "2023-01-22T16:29:31.715454Z", + "shell.execute_reply": "2023-01-22T16:29:31.714324Z", + "shell.execute_reply.started": "2023-01-22T16:29:31.350382Z" + } + }, + "outputs": [], + "source": [ + "df_dssm.to_csv('dssm_predictions.csv', index = False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/hw_5_recbool.ipynb b/hw_5_recbool.ipynb new file mode 100644 index 00000000..cf414323 --- /dev/null +++ b/hw_5_recbool.ipynb @@ -0,0 +1 @@ 
+{"cells":[{"cell_type":"code","execution_count":18,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T17:56:37.034499Z","iopub.status.busy":"2023-01-22T17:56:37.034012Z","iopub.status.idle":"2023-01-22T17:56:37.042666Z","shell.execute_reply":"2023-01-22T17:56:37.041481Z","shell.execute_reply.started":"2023-01-22T17:56:37.034455Z"},"papermill":{"duration":1.244043,"end_time":"2022-11-27T16:33:29.277270","exception":false,"start_time":"2022-11-27T16:33:28.033227","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import ast\n","import json\n","import matplotlib.pyplot as plt\n","import numpy as np\n","import os\n","import pandas as pd\n","import pickle\n","\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","from collections import Counter\n","from random import randint, random\n","from scipy.sparse import coo_matrix, hstack\n","from sklearn.metrics.pairwise import euclidean_distances, cosine_distances, cosine_similarity"]},{"cell_type":"code","execution_count":20,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:33.523160Z","iopub.status.busy":"2023-01-22T18:02:33.522766Z","iopub.status.idle":"2023-01-22T18:02:36.724444Z","shell.execute_reply":"2023-01-22T18:02:36.723409Z","shell.execute_reply.started":"2023-01-22T18:02:33.523126Z"},"papermill":{"duration":6.445298,"end_time":"2022-11-27T16:33:35.747539","exception":false,"start_time":"2022-11-27T16:33:29.302241","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["interactions_df = pd.read_csv('interactions_processed_kion.csv')\n","users_df = pd.read_csv('users_processed_kion.csv')\n","items_df = 
pd.read_csv('items_processed_kion.csv')"]},{"cell_type":"code","execution_count":21,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:41.118088Z","iopub.status.busy":"2023-01-22T18:02:41.117711Z","iopub.status.idle":"2023-01-22T18:02:42.100146Z","shell.execute_reply":"2023-01-22T18:02:42.098848Z","shell.execute_reply.started":"2023-01-22T18:02:41.118057Z"},"papermill":{"duration":0.925082,"end_time":"2022-11-27T16:33:36.677439","exception":false,"start_time":"2022-11-27T16:33:35.752357","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["interactions_df['t_dat'] = pd.to_datetime(interactions_df['last_watch_dt'], format=\"%Y-%m-%d\")\n","interactions_df['timestamp'] = interactions_df.t_dat.values.astype(np.int64) // 10 ** 9"]},{"cell_type":"code","execution_count":22,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:42.111635Z","iopub.status.busy":"2023-01-22T18:02:42.110287Z","iopub.status.idle":"2023-01-22T18:02:42.408437Z","shell.execute_reply":"2023-01-22T18:02:42.407310Z","shell.execute_reply.started":"2023-01-22T18:02:42.111593Z"},"papermill":{"duration":0.284147,"end_time":"2022-11-27T16:33:36.966533","exception":false,"start_time":"2022-11-27T16:33:36.682386","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["df = interactions_df[['user_id', 'item_id', 'timestamp']].rename(\n"," columns={'user_id': 'user_id:token', 'item_id': 'item_id:token', 'timestamp': 'timestamp:float'})"]},{"cell_type":"code","execution_count":23,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:43.902049Z","iopub.status.busy":"2023-01-22T18:02:43.901227Z","iopub.status.idle":"2023-01-22T18:02:43.927071Z","shell.execute_reply":"2023-01-22T18:02:43.925875Z","shell.execute_reply.started":"2023-01-22T18:02:43.902007Z"},"trusted":true},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
user_id:tokenitem_id:tokentimestamp:float
017654995061620691200
169931716591622246400
265668371071620518400
386461376381625443200
496486895061619740800
............
5476246648596122251628812800
547624754686296731618272000
5476248697262152971629417600
5476249384202161971618790400
547625031970944361628985600
\n","

5476251 rows × 3 columns

\n","
"],"text/plain":[" user_id:token item_id:token timestamp:float\n","0 176549 9506 1620691200\n","1 699317 1659 1622246400\n","2 656683 7107 1620518400\n","3 864613 7638 1625443200\n","4 964868 9506 1619740800\n","... ... ... ...\n","5476246 648596 12225 1628812800\n","5476247 546862 9673 1618272000\n","5476248 697262 15297 1629417600\n","5476249 384202 16197 1618790400\n","5476250 319709 4436 1628985600\n","\n","[5476251 rows x 3 columns]"]},"execution_count":23,"metadata":{},"output_type":"execute_result"}],"source":["df"]},{"cell_type":"code","execution_count":25,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:47.909321Z","iopub.status.busy":"2023-01-22T18:02:47.908208Z","iopub.status.idle":"2023-01-22T18:02:54.589560Z","shell.execute_reply":"2023-01-22T18:02:54.588499Z","shell.execute_reply.started":"2023-01-22T18:02:47.909281Z"},"papermill":{"duration":7.834652,"end_time":"2022-11-27T16:33:45.906924","exception":false,"start_time":"2022-11-27T16:33:38.072272","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["df.to_csv('recbox_data/recbox_data.inter', index=False, sep='\\t')"]},{"cell_type":"code","execution_count":28,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:54.592386Z","iopub.status.busy":"2023-01-22T18:02:54.591996Z","iopub.status.idle":"2023-01-22T18:02:55.527789Z","shell.execute_reply":"2023-01-22T18:02:55.526787Z","shell.execute_reply.started":"2023-01-22T18:02:54.592332Z"},"papermill":{"duration":3.067001,"end_time":"2022-11-27T16:34:04.068318","exception":false,"start_time":"2022-11-27T16:34:01.001317","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import logging\n","from logging import getLogger\n","from recbole.config import Config\n","from recbole.data import create_dataset, data_preparation\n","from recbole.model.sequential_recommender import GRU4Rec, Caser\n","from recbole.trainer import Trainer\n","from recbole.utils import init_seed, init_logger\n","from 
recbole.quick_start import run_recbole"]},{"cell_type":"code","execution_count":29,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:04:51.862473Z","iopub.status.busy":"2023-01-22T18:04:51.862041Z","iopub.status.idle":"2023-01-22T18:04:51.900690Z","shell.execute_reply":"2023-01-22T18:04:51.899741Z","shell.execute_reply.started":"2023-01-22T18:04:51.862435Z"},"papermill":{"duration":0.145622,"end_time":"2022-11-27T16:34:04.220395","exception":false,"start_time":"2022-11-27T16:34:04.074773","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["parameter_dict = {\n"," 'data_path': '',\n"," 'USER_ID_FIELD': 'user_id',\n"," 'ITEM_ID_FIELD': 'item_id',\n"," 'TIME_FIELD': 'timestamp',\n"," 'device': 'GPU',\n"," 'user_inter_num_interval': \"[40,inf)\",\n"," 'item_inter_num_interval': \"[40,inf)\",\n"," 'load_col': {'inter': ['user_id', 'item_id', 'timestamp']},\n"," 'neg_sampling': None,\n"," 'epochs': 10,\n"," 'verbose': -1,\n"," 'show_progress' : False,\n"," 'eval_args': {\n"," 'split': {'RS': [9, 0, 1]},\n"," 'group_by': 'user',\n"," 'order': 'TO',\n"," 'mode': 'full'}\n","}\n","config = Config(model='MultiVAE', dataset='recbox_data', config_dict=parameter_dict)\n","\n","# init random seed\n","init_seed(config['seed'], config['reproducibility'])\n","\n","# logger initialization\n","init_logger(config)\n","logger = getLogger()\n","# Create handlers\n","c_handler = logging.StreamHandler()\n","c_handler.setLevel(logging.INFO)\n","logger.addHandler(c_handler)\n","\n","# write config info into log\n","# 
logger.info(config)"]},{"cell_type":"code","execution_count":30,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:04:55.538201Z","iopub.status.busy":"2023-01-22T18:04:55.537818Z","iopub.status.idle":"2023-01-22T18:05:32.322220Z","shell.execute_reply":"2023-01-22T18:05:32.321423Z","shell.execute_reply.started":"2023-01-22T18:04:55.538170Z"},"papermill":{"duration":42.583583,"end_time":"2022-11-27T16:34:46.811041","exception":false,"start_time":"2022-11-27T16:34:04.227458","status":"completed"},"tags":[],"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["11 Dec 11:56 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n"]}],"source":["dataset = create_dataset(config)\n","logger.info(dataset)"]},{"cell_type":"code","execution_count":31,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:05:32.324208Z","iopub.status.busy":"2023-01-22T18:05:32.323852Z","iopub.status.idle":"2023-01-22T18:05:34.256086Z","shell.execute_reply":"2023-01-22T18:05:34.255320Z","shell.execute_reply.started":"2023-01-22T18:05:32.324171Z"},"papermill":{"duration":2.241551,"end_time":"2022-11-27T16:34:49.059852","exception":false,"start_time":"2022-11-27T16:34:46.818301","status":"completed"},"tags":[],"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["11 Dec 11:56 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 
'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 11:56 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n"]}],"source":["# dataset splitting\n","train_data, valid_data, test_data = data_preparation(config, dataset)"]},{"cell_type":"code","execution_count":32,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:05:34.257762Z","iopub.status.busy":"2023-01-22T18:05:34.257174Z","iopub.status.idle":"2023-01-22T18:05:34.262360Z","shell.execute_reply":"2023-01-22T18:05:34.261553Z","shell.execute_reply.started":"2023-01-22T18:05:34.257723Z"},"papermill":{"duration":0.01694,"end_time":"2022-11-27T16:34:49.085164","exception":false,"start_time":"2022-11-27T16:34:49.068224","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import time"]},{"cell_type":"markdown","metadata":{},"source":["### Использование различных архитектур"]},{"cell_type":"code","execution_count":33,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:05:41.096708Z","iopub.status.busy":"2023-01-22T18:05:41.096214Z","iopub.status.idle":"2023-01-22T18:11:38.568018Z","shell.execute_reply":"2023-01-22T18:11:38.567070Z","shell.execute_reply.started":"2023-01-22T18:05:41.096667Z"},"papermill":{"duration":27259.293886,"end_time":"2022-11-28T00:09:08.387403","exception":false,"start_time":"2022-11-27T16:34:49.093517","status":"completed"},"tags":[],"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["running LightGCN...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 11:56 
INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 11:56 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = 
[40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","n_layers = 2\n","reg_weight = 1e-05\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = 
None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","n_layers = 2\n","reg_weight = 1e-05\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose 
= -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 11:58 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 11:58 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 11:58 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 11:58 INFO LightGCN(\n"," (user_embedding): Embedding(13355, 64)\n"," (item_embedding): Embedding(3294, 64)\n"," (mf_loss): BPRLoss()\n"," (reg_loss): EmbLoss()\n",")\n","Trainable parameters: 1065536\n","LightGCN(\n"," (user_embedding): Embedding(13355, 64)\n"," (item_embedding): Embedding(3294, 64)\n"," (mf_loss): BPRLoss()\n"," (reg_loss): 
EmbLoss()\n",")\n","Trainable parameters: 1065536\n","11 Dec 11:58 INFO FLOPs: 0.0\n","FLOPs: 0.0\n","11 Dec 12:00 INFO epoch 0 training [time: 96.30s, train loss: 201.8552]\n","epoch 0 training [time: 96.30s, train loss: 201.8552]\n","11 Dec 12:00 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:01 INFO epoch 1 training [time: 91.67s, train loss: 166.0586]\n","epoch 1 training [time: 91.67s, train loss: 166.0586]\n","11 Dec 12:01 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:03 INFO epoch 2 training [time: 106.54s, train loss: 156.0221]\n","epoch 2 training [time: 106.54s, train loss: 156.0221]\n","11 Dec 12:03 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:05 INFO epoch 3 training [time: 125.49s, train loss: 149.5909]\n","epoch 3 training [time: 125.49s, train loss: 149.5909]\n","11 Dec 12:05 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:07 INFO epoch 4 training [time: 123.82s, train loss: 146.1851]\n","epoch 4 training [time: 123.82s, train loss: 146.1851]\n","11 Dec 12:07 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:10 INFO epoch 5 training [time: 135.29s, train loss: 143.7300]\n","epoch 5 training [time: 135.29s, train loss: 143.7300]\n","11 Dec 12:10 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:12 INFO epoch 6 training [time: 152.14s, train loss: 141.2300]\n","epoch 6 training [time: 152.14s, train loss: 141.2300]\n","11 Dec 12:12 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: 
saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:14 INFO epoch 7 training [time: 114.99s, train loss: 137.4871]\n","epoch 7 training [time: 114.99s, train loss: 137.4871]\n","11 Dec 12:14 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:16 INFO epoch 8 training [time: 128.70s, train loss: 133.3195]\n","epoch 8 training [time: 128.70s, train loss: 133.3195]\n","11 Dec 12:16 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:18 INFO epoch 9 training [time: 126.52s, train loss: 129.4056]\n","epoch 9 training [time: 126.52s, train loss: 129.4056]\n","11 Dec 12:18 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:18 INFO Loading model structure and parameters from saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Loading model structure and parameters from saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:18 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 48.50 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.07 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 48.50 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.07 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 12:18 INFO best valid : None\n","best valid : None\n","11 Dec 12:18 INFO test result: OrderedDict([('recall@10', 0.0792), ('mrr@10', 0.1685), ('ndcg@10', 0.0795), ('hit@10', 0.3385), ('precision@10', 0.0441)])\n","test 
result: OrderedDict([('recall@10', 0.0792), ('mrr@10', 0.1685), ('ndcg@10', 0.0795), ('hit@10', 0.3385), ('precision@10', 0.0441)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 21.95 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0792), ('mrr@10', 0.1685), ('ndcg@10', 0.0795), ('hit@10', 0.3385), ('precision@10', 0.0441)])}\n","running MultiVAE...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 12:18 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 12:18 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 
4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","mlp_hidden_size = [600]\n","latent_dimension = 128\n","dropout_prob = 0.5\n","anneal_cap = 0.2\n","total_anneal_steps = 200000\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = 
True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = 
tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","mlp_hidden_size = [600]\n","latent_dimension = 128\n","dropout_prob = 0.5\n","anneal_cap = 0.2\n","total_anneal_steps = 200000\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 12:21 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 12:21 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 12:21 INFO [Evaluation]: 
eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 12:21 WARNING Max value of user's history interaction records has reached 20.9471766848816% of the total.\n","Max value of user's history interaction records has reached 20.9471766848816% of the total.\n","11 Dec 12:21 INFO MultiVAE(\n"," (encoder): Sequential(\n"," (0): Linear(in_features=3294, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=128, bias=True)\n"," )\n"," (decoder): Sequential(\n"," (0): Linear(in_features=64, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=3294, bias=True)\n"," )\n",")\n","Trainable parameters: 4072622\n","MultiVAE(\n"," (encoder): Sequential(\n"," (0): Linear(in_features=3294, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=128, bias=True)\n"," )\n"," (decoder): Sequential(\n"," (0): Linear(in_features=64, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=3294, bias=True)\n"," )\n",")\n","Trainable parameters: 4072622\n","11 Dec 12:21 INFO FLOPs: 4068000.0\n","FLOPs: 4068000.0\n","11 Dec 12:21 INFO epoch 0 training [time: 2.16s, train loss: 3249.3142]\n","epoch 0 training [time: 2.16s, train loss: 3249.3142]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 1 training [time: 1.96s, train loss: 3098.4010]\n","epoch 1 training [time: 1.96s, train loss: 3098.4010]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 2 
training [time: 1.97s, train loss: 3045.1938]\n","epoch 2 training [time: 1.97s, train loss: 3045.1938]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 3 training [time: 2.02s, train loss: 3008.0520]\n","epoch 3 training [time: 2.02s, train loss: 3008.0520]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 4 training [time: 2.58s, train loss: 2949.4743]\n","epoch 4 training [time: 2.58s, train loss: 2949.4743]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 5 training [time: 2.14s, train loss: 2917.6707]\n","epoch 5 training [time: 2.14s, train loss: 2917.6707]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 6 training [time: 2.28s, train loss: 2897.4954]\n","epoch 6 training [time: 2.28s, train loss: 2897.4954]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 7 training [time: 2.03s, train loss: 2885.5641]\n","epoch 7 training [time: 2.03s, train loss: 2885.5641]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 8 training [time: 2.38s, train loss: 2871.9012]\n","epoch 8 training [time: 2.38s, train loss: 2871.9012]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 9 training [time: 2.46s, train loss: 2851.2055]\n","epoch 9 training [time: 2.46s, train loss: 
2851.2055]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO Loading model structure and parameters from saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Loading model structure and parameters from saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 74.20 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.08 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 74.20 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.08 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 12:22 INFO best valid : None\n","best valid : None\n","11 Dec 12:22 INFO test result: OrderedDict([('recall@10', 0.0839), ('mrr@10', 0.1687), ('ndcg@10', 0.0823), ('hit@10', 0.3494), ('precision@10', 0.0465)])\n","test result: OrderedDict([('recall@10', 0.0839), ('mrr@10', 0.1687), ('ndcg@10', 0.0823), ('hit@10', 0.3494), ('precision@10', 0.0465)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 3.87 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0839), ('mrr@10', 0.1687), ('ndcg@10', 0.0823), ('hit@10', 0.3494), ('precision@10', 0.0465)])}\n","running RecVAE...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 12:22 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', 
'--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 12:22 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id 
= None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","hidden_dimension = 600\n","latent_dimension = 200\n","dropout_prob = 0.5\n","beta = 0.2\n","gamma = 0.005\n","mixture_weights = [0.15, 0.75, 0.1]\n","n_enc_epochs = 3\n","n_dec_epochs = 1\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 
10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","hidden_dimension = 600\n","latent_dimension = 200\n","dropout_prob = 0.5\n","beta = 0.2\n","gamma = 0.005\n","mixture_weights = [0.15, 0.75, 0.1]\n","n_enc_epochs = 3\n","n_dec_epochs = 1\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = 
False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 12:25 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 12:25 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 12:25 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 12:25 WARNING Max value of user's history interaction records has reached 20.9471766848816% of the total.\n","Max value of user's history interaction records has reached 20.9471766848816% 
of the total.\n","11 Dec 12:25 INFO RecVAE(\n"," (encoder): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," (prior): CompositePrior(\n"," (encoder_old): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," )\n"," (decoder): Linear(in_features=200, out_features=3294, bias=True)\n",")\n","Trainable parameters: 4327894\n","RecVAE(\n"," (encoder): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," 
(fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," (prior): CompositePrior(\n"," (encoder_old): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," )\n"," (decoder): Linear(in_features=200, out_features=3294, bias=True)\n",")\n","Trainable parameters: 4327894\n","11 Dec 12:25 INFO FLOPs: 4321200.0\n","FLOPs: 4321200.0\n","11 Dec 12:25 INFO epoch 0 training [time: 23.87s, train loss: 2354.4009]\n","epoch 0 training [time: 23.87s, train loss: 2354.4009]\n","11 Dec 12:25 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:26 INFO epoch 1 training [time: 
26.41s, train loss: 2247.2854]\n","epoch 1 training [time: 26.41s, train loss: 2247.2854]\n","11 Dec 12:26 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:26 INFO epoch 2 training [time: 26.19s, train loss: 2184.4206]\n","epoch 2 training [time: 26.19s, train loss: 2184.4206]\n","11 Dec 12:26 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:27 INFO epoch 3 training [time: 28.26s, train loss: 2147.9836]\n","epoch 3 training [time: 28.26s, train loss: 2147.9836]\n","11 Dec 12:27 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:27 INFO epoch 4 training [time: 25.59s, train loss: 2108.6837]\n","epoch 4 training [time: 25.59s, train loss: 2108.6837]\n","11 Dec 12:27 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:27 INFO epoch 5 training [time: 19.99s, train loss: 2073.2995]\n","epoch 5 training [time: 19.99s, train loss: 2073.2995]\n","11 Dec 12:27 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 6 training [time: 23.58s, train loss: 2043.1616]\n","epoch 6 training [time: 23.58s, train loss: 2043.1616]\n","11 Dec 12:28 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 7 training [time: 24.14s, train loss: 2013.9314]\n","epoch 7 training [time: 24.14s, train loss: 2013.9314]\n","11 Dec 12:28 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 8 training [time: 12.42s, train loss: 1998.7426]\n","epoch 8 training [time: 12.42s, train loss: 1998.7426]\n","11 Dec 12:28 
INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 9 training [time: 10.69s, train loss: 1973.2974]\n","epoch 9 training [time: 10.69s, train loss: 1973.2974]\n","11 Dec 12:28 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO Loading model structure and parameters from saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Loading model structure and parameters from saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:29 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 50.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.09 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 50.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.09 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 12:29 INFO best valid : None\n","best valid : None\n","11 Dec 12:29 INFO test result: OrderedDict([('recall@10', 0.0844), ('mrr@10', 0.1662), ('ndcg@10', 0.0816), ('hit@10', 0.3519), ('precision@10', 0.0468)])\n","test result: OrderedDict([('recall@10', 0.0844), ('mrr@10', 0.1662), ('ndcg@10', 0.0816), ('hit@10', 0.3519), ('precision@10', 0.0468)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 6.70 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0844), ('mrr@10', 0.1662), ('ndcg@10', 0.0816), ('hit@10', 0.3519), ('precision@10', 0.0468)])}\n","CPU times: user 25min 39s, sys: 9min 10s, total: 
34min 49s\n","Wall time: 32min 31s\n"]}],"source":["%%time\n","model_list = [ \"LightGCN\", \"MultiVAE\", \"RecVAE\"] \n","\n","for model_name in model_list:\n"," print(f\"running {model_name}...\")\n"," start = time.time()\n"," result = run_recbole(model=model_name, dataset = 'recbox_data',config_dict = parameter_dict)\n"," t = time.time() - start\n"," print(f\"It took {t/60:.2f} mins\")\n"," print(result)"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Collecting kmeans-pytorch\n"," Downloading kmeans_pytorch-0.3-py3-none-any.whl (4.4 kB)\n","Installing collected packages: kmeans-pytorch\n","Successfully installed kmeans-pytorch-0.3\n","Note: you may need to restart the kernel to use updated packages.\n"]}],"source":["%pip install kmeans-pytorch"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[],"source":["from kmeans_pytorch import kmeans"]},{"cell_type":"code","execution_count":37,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:14:48.482175Z","iopub.status.busy":"2023-01-22T18:14:48.481796Z","iopub.status.idle":"2023-01-22T19:32:27.636297Z","shell.execute_reply":"2023-01-22T19:32:27.635371Z","shell.execute_reply.started":"2023-01-22T18:14:48.482143Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["running CORE...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 13:40 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 13:40 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = 
saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = 
None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","inner_size = 256\n","n_layers = 2\n","n_heads = 2\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","dnn_type = trm\n","sess_dropout = 0.2\n","item_dropout = 0.2\n","temperature = 0.07\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 
'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","inner_size = 256\n","n_layers = 2\n","n_heads = 2\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","dnn_type = trm\n","sess_dropout = 0.2\n","item_dropout = 0.2\n","temperature = 0.07\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = 
-1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 13:41 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 13:42 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 13:42 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 13:42 INFO CORE(\n"," (sess_dropout): Dropout(p=0.2, inplace=False)\n"," (item_dropout): Dropout(p=0.2, inplace=False)\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (net): TransNet(\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): TransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x TransformerLayer(\n"," 
(multi_head_attention): MultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (softmax): Softmax(dim=-1)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (fn): Linear(in_features=64, out_features=1, bias=True)\n"," )\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 314177\n","CORE(\n"," (sess_dropout): Dropout(p=0.2, inplace=False)\n"," (item_dropout): Dropout(p=0.2, inplace=False)\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (net): TransNet(\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): TransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x TransformerLayer(\n"," (multi_head_attention): MultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (softmax): Softmax(dim=-1)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): 
Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (fn): Linear(in_features=64, out_features=1, bias=True)\n"," )\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 314177\n","11 Dec 13:42 INFO FLOPs: 4986664.0\n","FLOPs: 4986664.0\n","11 Dec 14:56 INFO epoch 0 training [time: 4465.02s, train loss: 3014.5362]\n","epoch 0 training [time: 4465.02s, train loss: 3014.5362]\n","11 Dec 14:56 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 15:22 INFO epoch 1 training [time: 1551.87s, train loss: 2695.9846]\n","epoch 1 training [time: 1551.87s, train loss: 2695.9846]\n","11 Dec 15:22 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 15:51 INFO epoch 2 training [time: 1725.13s, train loss: 2613.6121]\n","epoch 2 training [time: 1725.13s, train loss: 2613.6121]\n","11 Dec 15:51 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 16:20 INFO epoch 3 training [time: 1788.98s, train loss: 2579.6137]\n","epoch 3 training [time: 1788.98s, train loss: 2579.6137]\n","11 Dec 16:20 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 17:53 INFO epoch 4 training [time: 5548.65s, train loss: 2562.4357]\n","epoch 4 training [time: 5548.65s, train loss: 2562.4357]\n","11 Dec 17:53 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 18:23 INFO epoch 5 training [time: 1835.07s, train 
loss: 2551.7563]\n","epoch 5 training [time: 1835.07s, train loss: 2551.7563]\n","11 Dec 18:23 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 18:52 INFO epoch 6 training [time: 1691.97s, train loss: 2545.4003]\n","epoch 6 training [time: 1691.97s, train loss: 2545.4003]\n","11 Dec 18:52 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 19:23 INFO epoch 7 training [time: 1858.18s, train loss: 2540.3678]\n","epoch 7 training [time: 1858.18s, train loss: 2540.3678]\n","11 Dec 19:23 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 20:06 INFO epoch 8 training [time: 2614.77s, train loss: 2537.1684]\n","epoch 8 training [time: 2614.77s, train loss: 2537.1684]\n","11 Dec 20:06 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 21:22 INFO epoch 9 training [time: 4561.69s, train loss: 2534.2584]\n","epoch 9 training [time: 4561.69s, train loss: 2534.2584]\n","11 Dec 21:22 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 21:22 INFO Loading model structure and parameters from saved/CORE-Dec-11-2023_13-42-03.pth\n","Loading model structure and parameters from saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 21:23 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 46.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.77 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage 
|\n","+=============+===============+\n","| CPU | 46.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.77 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 21:23 INFO best valid : None\n","best valid : None\n","11 Dec 21:23 INFO test result: OrderedDict([('recall@10', 0.0921), ('mrr@10', 0.0297), ('ndcg@10', 0.044), ('hit@10', 0.0921), ('precision@10', 0.0092)])\n","test result: OrderedDict([('recall@10', 0.0921), ('mrr@10', 0.0297), ('ndcg@10', 0.044), ('hit@10', 0.0921), ('precision@10', 0.0092)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 463.62 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0921), ('mrr@10', 0.0297), ('ndcg@10', 0.044), ('hit@10', 0.0921), ('precision@10', 0.0092)])}\n","running LightSANs...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 21:23 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 21:23 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 
10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","k_interests = 5\n","n_layers = 2\n","n_heads = 2\n","hidden_size = 64\n","inner_size = 256\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","numerical_features = 
[]\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = 
None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","k_interests = 5\n","n_layers = 2\n","n_heads = 2\n","hidden_size = 64\n","inner_size = 256\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 21:31 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 
'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 21:31 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 21:31 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 21:31 INFO LightSANs(\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): LightTransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x LightTransformerLayer(\n"," (multi_head_attention): LightMultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (attpooling_key): ItemToInterestAggregation()\n"," (attpooling_value): ItemToInterestAggregation()\n"," (pos_q_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_k_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_ln): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, 
elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 332288\n","LightSANs(\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): LightTransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x LightTransformerLayer(\n"," (multi_head_attention): LightMultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (attpooling_key): ItemToInterestAggregation()\n"," (attpooling_value): ItemToInterestAggregation()\n"," (pos_q_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_k_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_ln): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, 
elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 332288\n","11 Dec 21:31 INFO FLOPs: 5785664.0\n","FLOPs: 5785664.0\n","11 Dec 22:10 INFO epoch 0 training [time: 2297.51s, train loss: 2745.5162]\n","epoch 0 training [time: 2297.51s, train loss: 2745.5162]\n","11 Dec 22:10 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 22:35 INFO epoch 1 training [time: 1523.42s, train loss: 2594.7601]\n","epoch 1 training [time: 1523.42s, train loss: 2594.7601]\n","11 Dec 22:35 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 23:00 INFO epoch 2 training [time: 1474.23s, train loss: 2552.7842]\n","epoch 2 training [time: 1474.23s, train loss: 2552.7842]\n","11 Dec 23:00 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 23:24 INFO epoch 3 training [time: 1459.72s, train loss: 2529.9439]\n","epoch 3 training [time: 1459.72s, train loss: 2529.9439]\n","11 Dec 23:24 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 23:47 INFO epoch 4 training [time: 1402.02s, train loss: 2516.8341]\n","epoch 4 training [time: 1402.02s, train loss: 2516.8341]\n","11 Dec 23:47 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 00:44 INFO epoch 5 training [time: 3408.17s, train loss: 2508.0956]\n","epoch 5 training [time: 3408.17s, train loss: 2508.0956]\n","12 Dec 00:44 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 01:06 INFO epoch 6 training [time: 1326.43s, train loss: 2501.2301]\n","epoch 6 training [time: 
1326.43s, train loss: 2501.2301]\n","12 Dec 01:06 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 02:02 INFO epoch 7 training [time: 3351.58s, train loss: 2496.2055]\n","epoch 7 training [time: 3351.58s, train loss: 2496.2055]\n","12 Dec 02:02 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 06:08 INFO epoch 8 training [time: 14777.87s, train loss: 2491.3191]\n","epoch 8 training [time: 14777.87s, train loss: 2491.3191]\n","12 Dec 06:08 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 08:40 INFO epoch 9 training [time: 9116.53s, train loss: 2487.0272]\n","epoch 9 training [time: 9116.53s, train loss: 2487.0272]\n","12 Dec 08:40 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 08:40 INFO Loading model structure and parameters from saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Loading model structure and parameters from saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 08:41 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 30.70 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.82 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 30.70 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.82 G/8.00 G |\n","+-------------+---------------+\n","12 Dec 08:41 INFO best valid : None\n","best valid : None\n","12 Dec 08:41 
INFO test result: OrderedDict([('recall@10', 0.1029), ('mrr@10', 0.0358), ('ndcg@10', 0.0513), ('hit@10', 0.1029), ('precision@10', 0.0103)])\n","test result: OrderedDict([('recall@10', 0.1029), ('mrr@10', 0.0358), ('ndcg@10', 0.0513), ('hit@10', 0.1029), ('precision@10', 0.0103)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 677.87 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.1029), ('mrr@10', 0.0358), ('ndcg@10', 0.0513), ('hit@10', 0.1029), ('precision@10', 0.0103)])}\n","running NextItNet...\n"]},{"name":"stderr","output_type":"stream","text":["12 Dec 08:41 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","12 Dec 08:41 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 
'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","kernel_size = 3\n","block_num = 5\n","dilations = [1, 4]\n","reg_weight = 1e-05\n","loss_type = CE\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 
'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = 
item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","kernel_size = 3\n","block_num = 5\n","dilations = [1, 4]\n","reg_weight = 1e-05\n","loss_type = CE\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","12 Dec 08:43 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","12 Dec 08:43 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] 
train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","12 Dec 08:43 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","12 Dec 08:43 INFO NextItNet(\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (residual_blocks): Sequential(\n"," (0): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (1): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (2): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (3): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (4): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, 
elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (5): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (6): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (7): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (8): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (9): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," )\n"," (final_layer): Linear(in_features=64, out_features=64, bias=True)\n"," (loss_fct): CrossEntropyLoss()\n"," (reg_loss): RegLoss()\n",")\n","Trainable parameters: 464576\n","NextItNet(\n"," (item_embedding): 
Embedding(3294, 64, padding_idx=0)\n"," (residual_blocks): Sequential(\n"," (0): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (1): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (2): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (3): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (4): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (5): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (6): 
ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (7): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (8): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (9): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," )\n"," (final_layer): Linear(in_features=64, out_features=64, bias=True)\n"," (loss_fct): CrossEntropyLoss()\n"," (reg_loss): RegLoss()\n",")\n","Trainable parameters: 464576\n","12 Dec 08:43 INFO FLOPs: 12423360.0\n","FLOPs: 12423360.0\n","12 Dec 10:08 INFO epoch 0 training [time: 5095.82s, train loss: 2732.6105]\n","epoch 0 training [time: 5095.82s, train loss: 2732.6105]\n","12 Dec 10:08 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 11:33 INFO epoch 1 training [time: 5097.32s, train loss: 2601.4325]\n","epoch 1 training [time: 5097.32s, train loss: 2601.4325]\n","12 Dec 11:33 INFO Saving current: 
saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 13:31 INFO epoch 2 training [time: 7077.48s, train loss: 2554.3704]\n","epoch 2 training [time: 7077.48s, train loss: 2554.3704]\n","12 Dec 13:31 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 14:39 INFO epoch 3 training [time: 4117.24s, train loss: 2529.2335]\n","epoch 3 training [time: 4117.24s, train loss: 2529.2335]\n","12 Dec 14:39 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 17:57 INFO epoch 4 training [time: 11838.39s, train loss: 2512.5584]\n","epoch 4 training [time: 11838.39s, train loss: 2512.5584]\n","12 Dec 17:57 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 18:42 INFO epoch 5 training [time: 2724.19s, train loss: 2497.9890]\n","epoch 5 training [time: 2724.19s, train loss: 2497.9890]\n","12 Dec 18:42 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 19:30 INFO epoch 6 training [time: 2856.86s, train loss: 2485.4469]\n","epoch 6 training [time: 2856.86s, train loss: 2485.4469]\n","12 Dec 19:30 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 20:13 INFO epoch 7 training [time: 2609.22s, train loss: 2474.9533]\n","epoch 7 training [time: 2609.22s, train loss: 2474.9533]\n","12 Dec 20:13 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 21:18 INFO epoch 8 training [time: 3904.76s, train loss: 2465.3467]\n","epoch 8 training [time: 3904.76s, train loss: 2465.3467]\n","12 Dec 21:18 INFO Saving current: 
saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 22:13 INFO epoch 9 training [time: 3272.21s, train loss: 2456.8602]\n","epoch 9 training [time: 3272.21s, train loss: 2456.8602]\n","12 Dec 22:13 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 22:13 INFO Loading model structure and parameters from saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Loading model structure and parameters from saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 22:16 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 11.80 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.74 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 11.80 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.74 G/8.00 G |\n","+-------------+---------------+\n","12 Dec 22:16 INFO best valid : None\n","best valid : None\n","12 Dec 22:16 INFO test result: OrderedDict([('recall@10', 0.0922), ('mrr@10', 0.0329), ('ndcg@10', 0.0466), ('hit@10', 0.0922), ('precision@10', 0.0092)])\n","test result: OrderedDict([('recall@10', 0.0922), ('mrr@10', 0.0329), ('ndcg@10', 0.0466), ('hit@10', 0.0922), ('precision@10', 0.0092)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 814.55 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0922), ('mrr@10', 0.0329), ('ndcg@10', 0.0466), ('hit@10', 0.0922), ('precision@10', 0.0092)])}\n","CPU times: user 1d 24min 58s, sys: 13h 28min 
5s, total: 1d 13h 53min 4s\n","Wall time: 1d 8h 36min 2s\n"]}],"source":["%%time\n","model_list = [\"CORE\", \"LightSANs\", \"NextItNet\",] \n","\n","parameter_dict[\"train_neg_sample_args\"] = None\n","\n","for model_name in model_list:\n"," print(f\"running {model_name}...\")\n"," start = time.time()\n"," result = run_recbole(model=model_name, dataset = 'recbox_data', config_dict = parameter_dict)\n"," t = time.time() - start\n"," print(f\"It took {t/60:.2f} mins\")\n"," print(result)"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.13"},"papermill":{"default_parameters":{},"duration":27491.154881,"end_time":"2022-11-28T00:11:27.624787","environment_variables":{},"exception":null,"input_path":"__notebook__.ipynb","output_path":"__notebook__.ipynb","parameters":{},"start_time":"2022-11-27T16:33:16.469906","version":"2.3.4"}},"nbformat":4,"nbformat_minor":5} From 9238b51fb75bf2ac1571b24f98fc55e23acb863b Mon Sep 17 00:00:00 2001 From: Anna Pikuleva Date: Fri, 15 Dec 2023 09:08:54 +0300 Subject: [PATCH 7/7] all changes --- service/api/views.py | 58 +++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/service/api/views.py b/service/api/views.py index 24cf4a7f..95178c1e 100644 --- a/service/api/views.py +++ b/service/api/views.py @@ -1,20 +1,38 @@ from typing import List - -from fastapi import APIRouter, FastAPI, Request +from fastapi import APIRouter, Depends, FastAPI, Request from pydantic import BaseModel - -from service.api.exceptions import UserNotFoundError +import dill +from service.api.exceptions import ModelNotFoundError, UnauthorizedUserError, UserNotFoundError from service.log import app_logger +import pandas as pd +from service.models import 
recommend_popular + + +# load predictions of dssm model +dssm_preds = pd.read_csv("dssm_predictions.csv") +dssm_preds.item_id = dssm_preds.item_id.apply(lambda x: [int(i) for i in x[1:-1].split(", ")]) + + +# get popular recommendations +interactions = pd.read_csv('data/interactions.csv') +interactions['last_watch_dt'] = pd.to_datetime(interactions['last_watch_dt']) +interactions.rename( + columns={ + 'last_watch_dt': 'datetime', + 'total_dur': 'weight', + }, + inplace=True, + ) +popular_recs = recommend_popular(interactions) +popular_recs_30 = recommend_popular(interactions, days = 30) class RecoResponse(BaseModel): user_id: int items: List[int] - router = APIRouter() - @router.get( path="/health", tags=["Health"], @@ -23,26 +41,32 @@ async def health() -> str: return "I am alive" + @router.get( path="/reco/{model_name}/{user_id}", tags=["Recommendations"], - response_model=RecoResponse, + response_model=RecoResponse ) async def get_reco( - request: Request, - model_name: str, - user_id: int, -) -> RecoResponse: - app_logger.info(f"Request for model: {model_name}, user_id: {user_id}") - - # Write your code here + request: Request, + model_name: str, + user_id: int, + # token=Depends(bearer) + ) -> RecoResponse: + # app_logger.info(f"Request for model: {model_name}, user_id: {user_id}") + app_logger.info(f"Request for model: {model_name}") + app_logger.info(f"Request for user: {user_id}") if user_id > 10**9: raise UserNotFoundError(error_message=f"User {user_id} not found") + + if model_name == "DSSM": + try: + recs_list = dssm_preds[dssm_preds.user_id == user_id].item_id.values[0] + except: + recs_list = popular_recs_30 - k_recs = request.app.state.k_recs - reco = list(range(k_recs)) - return RecoResponse(user_id=user_id, items=reco) + return RecoResponse(user_id=user_id, items=recs_list) def add_views(app: FastAPI) -> None: