Upload 3 files
- new_user.json +8 -0
- train.py +181 -0
- transactions.csv +0 -0
new_user.json
ADDED
@@ -0,0 +1,8 @@
{
    "user_id": 500,
    "liked_categories": [
        "Cosmetics",
        "HairCare",
        "Food"
    ]
}
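How new_user.json is consumed is not shown in this commit (the Space's app code is not part of the diff), so the following is only a minimal sketch, assuming the file is read as a cold-start user profile; the file name and fields are the only things taken from the commit:

import json

# Read the cold-start user profile uploaded in this commit (hypothetical consumer).
with open("new_user.json") as f:
    profile = json.load(f)

print(profile["user_id"])           # 500
print(profile["liked_categories"])  # ['Cosmetics', 'HairCare', 'Food']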
train.py
ADDED
@@ -0,0 +1,181 @@
# Imports (standard library and third-party)
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import torch_geometric
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import degree
from tqdm.notebook import tqdm
from sklearn import preprocessing as pp
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split
import scipy.sparse as sp
from networkx.algorithms import bipartite
from networkx.drawing.layout import bipartite_layout
import datetime
import os
import fcntl


os.makedirs("/app/models", exist_ok=True)

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load dataset and normalize column names
df = pd.read_csv('transactions.csv')
df = df.rename(columns={'Customer_ID': 'userId', 'Rating': 'rating', 'Date': 'timestamp'})
df_collab = df.drop(['Transaction_ID', 'Quantity', 'Price', 'Product_Category', 'Product_Name', 'Product_Brand'], axis=1)
df_collab = df_collab[['userId', 'rating', 'timestamp', 'Item_ID']]
df_collab = df_collab[df_collab['rating'] >= 3]  # keep only positive interactions

# Train/test split
train, test = train_test_split(df_collab.values, test_size=0.2, random_state=16)
train_df = pd.DataFrame(train, columns=df_collab.columns)
test_df = pd.DataFrame(test, columns=df_collab.columns)

# Label encoding: map raw user/item IDs to contiguous integer indices
le_user = pp.LabelEncoder()
le_item = pp.LabelEncoder()
train_df['user_id_idx'] = le_user.fit_transform(train_df['userId'].values)
train_df['item_id_idx'] = le_item.fit_transform(train_df['Item_ID'].values)

train_user_ids = train_df['userId'].unique()
train_item_ids = train_df['Item_ID'].unique()

# Keep only test interactions whose user and item appear in the training set;
# .copy() avoids pandas SettingWithCopyWarning on the assignments below
test_df = test_df[
    (test_df['userId'].isin(train_user_ids)) & (test_df['Item_ID'].isin(train_item_ids))
].copy()
test_df['user_id_idx'] = le_user.transform(test_df['userId'].values)
test_df['item_id_idx'] = le_item.transform(test_df['Item_ID'].values)

n_users = train_df['user_id_idx'].nunique()
n_items = train_df['item_id_idx'].nunique()

# Data loader: samples a batch of (user, positive item, negative item) triples for BPR training
def data_loader(data, batch_size, n_usr, n_itm):
    def sample_neg(x):
        # Rejection-sample an item the user has not interacted with
        while True:
            neg_id = random.randint(0, n_itm - 1)
            if neg_id not in x:
                return neg_id

    interacted_items_df = data.groupby('user_id_idx')['item_id_idx'].apply(list).reset_index()
    indices = list(range(n_usr))

    if n_usr < batch_size:
        users = [random.choice(indices) for _ in range(batch_size)]
    else:
        users = random.sample(indices, batch_size)
    users.sort()
    users_df = pd.DataFrame(users, columns=['users'])

    interacted_items_df = pd.merge(interacted_items_df, users_df, how='right', left_on='user_id_idx', right_on='users')
    pos_items = interacted_items_df['item_id_idx'].apply(lambda x: random.choice(x)).values
    neg_items = interacted_items_df['item_id_idx'].apply(lambda x: sample_neg(x)).values

    # Item indices are offset by n_usr so users and items share one embedding table
    return (
        torch.LongTensor(list(users)).to(device),
        torch.LongTensor(list(pos_items)).to(device) + n_usr,
        torch.LongTensor(list(neg_items)).to(device) + n_usr
    )

data_loader(train_df, 2, n_users, n_items)  # smoke-test the loader with a tiny batch

# Graph construction: undirected bipartite user-item edges
u_t = torch.LongTensor(train_df.user_id_idx)
i_t = torch.LongTensor(train_df.item_id_idx) + n_users
train_edge_index = torch.stack((torch.cat([u_t, i_t]), torch.cat([i_t, u_t]))).to(device)

# LightGCN convolution: symmetric-normalized neighbor aggregation with no learned weights
class LightGCNConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        from_, to_ = edge_index
        deg = degree(to_, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[from_] * deg_inv_sqrt[to_]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

# Recommendation model: stacked LightGCN layers over a joint user/item embedding
class RecSysGNN(nn.Module):
    def __init__(self, latent_dim, num_layers, num_users, num_items):
        super(RecSysGNN, self).__init__()
        self.embedding = nn.Embedding(num_users + num_items, latent_dim)
        self.convs = nn.ModuleList(LightGCNConv() for _ in range(num_layers))
        self.init_parameters()

    def init_parameters(self):
        nn.init.normal_(self.embedding.weight, std=0.1)

    def forward(self, edge_index):
        emb0 = self.embedding.weight
        embs = [emb0]
        emb = emb0
        for conv in self.convs:
            emb = conv(x=emb, edge_index=edge_index)
            embs.append(emb)
        # Final representation is the mean of the embeddings across all layers
        return emb0, torch.mean(torch.stack(embs, dim=0), dim=0)

    def encode_minibatch(self, users, pos_items, neg_items, edge_index):
        emb0, out = self(edge_index)
        return out[users], out[pos_items], out[neg_items], emb0[users], emb0[pos_items], emb0[neg_items]

# BPR loss: push positive-item scores above negative-item scores, plus an L2 term
def compute_bpr_loss(users, users_emb, pos_emb, neg_emb, user_emb0, pos_emb0, neg_emb0):
    reg_loss = (1 / 2) * (user_emb0.norm().pow(2) + pos_emb0.norm().pow(2) + neg_emb0.norm().pow(2)) / float(len(users))
    pos_scores = torch.mul(users_emb, pos_emb).sum(dim=1)
    neg_scores = torch.mul(users_emb, neg_emb).sum(dim=1)
    bpr_loss = torch.mean(F.softplus(neg_scores - pos_scores))
    return bpr_loss, reg_loss


# Ensure `n_users` and `n_items` are defined before creating the model
if "n_users" not in globals() or "n_items" not in globals():
    raise ValueError("Ensure `n_users` and `n_items` are defined before initializing the model.")


if __name__ == "__main__":
    model = RecSysGNN(latent_dim=64, num_layers=3, num_users=n_users, num_items=n_items).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        users, pos_items, neg_items = data_loader(train_df, batch_size=256, n_usr=n_users, n_itm=n_items)
        optimizer.zero_grad()

        users_emb, pos_emb, neg_emb, user_emb0, pos_emb0, neg_emb0 = model.encode_minibatch(users, pos_items, neg_items, train_edge_index)
        # Note: reg_loss is reported below but not added to the training objective
        loss, reg_loss = compute_bpr_loss(users, users_emb, pos_emb, neg_emb, user_emb0, pos_emb0, neg_emb0)

        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1}: Loss = {loss.item():.4f}, Reg Loss = {reg_loss.item():.4f}")

    # Use a writable directory
    model_dir = "./models"  # saves inside the current working directory
    os.makedirs(model_dir, exist_ok=True)

    try:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(model_dir, f"model_{timestamp}.pth")

        # Hold an exclusive lock while writing so concurrent runs don't clobber the file
        with open(model_path, 'wb') as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            torch.save(model.state_dict(), f)
            fcntl.flock(f, fcntl.LOCK_UN)

        print(f"Model saved: {model_path}")

    except Exception as e:
        print(f"Error saving model: {e}")
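train.py stops after saving the state dict; no inference code is included in this commit. The sketch below is therefore an assumption about how the checkpoint could be used, reusing RecSysGNN, train_edge_index, n_users, le_item, and device from train.py; model_path stands for the file written by the save block above:

# Hypothetical inference sketch (not part of the commit): restore the
# checkpoint and rank all items for one encoded user by dot-product score.
model = RecSysGNN(latent_dim=64, num_layers=3, num_users=n_users, num_items=n_items).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

with torch.no_grad():
    _, out = model(train_edge_index)   # layer-averaged embeddings
    user_emb = out[:n_users]           # rows 0..n_users-1 hold users
    item_emb = out[n_users:]           # remaining rows hold items

    user_idx = 0                                  # encoded index of some training user
    scores = item_emb @ user_emb[user_idx]        # relevance score per item
    top_items = torch.topk(scores, k=10).indices.cpu()
    print(le_item.inverse_transform(top_items.numpy()))  # back to raw Item_IDs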
transactions.csv
ADDED
The diff for this file is too large to render. See raw diff.
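The transactions.csv contents are not rendered, but the column operations in train.py (the renames of Customer_ID, Rating, and Date, and the drops of Transaction_ID, Quantity, Price, Product_Category, Product_Name, and Product_Brand) imply its header. Below is a tiny illustrative frame with that inferred schema; every value in the row is fabricated:

import pandas as pd

# Schema inferred from train.py's rename/drop calls; the row itself is made up.
sample = pd.DataFrame([{
    "Transaction_ID": 1,
    "Customer_ID": 500,
    "Item_ID": 42,
    "Product_Category": "Cosmetics",
    "Product_Name": "Lipstick",
    "Product_Brand": "ExampleBrand",
    "Quantity": 2,
    "Price": 9.99,
    "Rating": 4,
    "Date": "2024-01-15",
}])
print(sample.columns.tolist())  # the columns train.py expects to find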