import json
from collections import Counter

import matplotlib.pyplot as plt
import torch
from torch.nn import CrossEntropyLoss, Dropout
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW  # transformers' AdamW is deprecated; use the torch implementation
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification
from sklearn.metrics import accuracy_score


class YelpDataset(Dataset):
    """Tokenizes Yelp review texts and maps 1-5 star ratings to labels 0-4."""

    def __init__(self, texts, ratings, tokenizer, max_length=128):
        self.texts = texts
        self.ratings = ratings
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        rating = self.ratings[idx] - 1  # shift 1-5 star ratings to class labels 0-4
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(rating, dtype=torch.long)
        }


def compute_class_weights(labels):
    """Return inverse-frequency weights (total / count) for each class."""
    class_counts = Counter(labels)
    total_samples = sum(class_counts.values())
    weights = {cls: total_samples / count for cls, count in class_counts.items()}
    return weights


def train_model(model, train_loader, optimizer, loss_fn, epochs, device, val_loader=None, patience=3):
    model.to(device)
    best_loss = float('inf')
    patience_counter = 0
    train_accuracies, val_accuracies = [], []
    val_loss_per_batch, val_accuracy_per_batch = [], []
    avg_tokens_all_batches = []
    total_tokens_across_epochs = 0
    total_batches = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct, total = 0, 0
        epoch_tokens = 0

        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs.logits, labels)  # class-weighted cross-entropy
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            preds = torch.argmax(outputs.logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            epoch_tokens += input_ids.size(0) * input_ids.size(1)  # padded tokens in this batch (batch_size * max_length)
            total_batches += 1

            # Track validation metrics after every training batch (expensive: this runs
            # the full validation set each time). evaluate_model leaves the model in
            # eval mode, so switch back to train mode afterwards.
            if val_loader:
                val_loss, val_accuracy = evaluate_model(model, val_loader, device, loss_fn=loss_fn, return_loss=True)
                val_loss_per_batch.append(val_loss)
                val_accuracy_per_batch.append(val_accuracy)
                model.train()

        total_tokens_across_epochs += epoch_tokens
        avg_tokens_all_batches.append(total_tokens_across_epochs / total_batches)

        avg_loss = total_loss / len(train_loader)
        train_accuracy = correct / total
        train_accuracies.append(train_accuracy)
        print(f"Epoch {epoch + 1}/{epochs} - Training Loss: {avg_loss:.4f} - Training Accuracy: {train_accuracy:.4f}")

        # Early stopping on validation loss
        if val_loader:
            val_loss, val_accuracy = evaluate_model(model, val_loader, device, loss_fn=loss_fn, return_loss=True)
            val_accuracies.append(val_accuracy)
            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping triggered.")
                    break

    print(f"Average tokens per batch across all epochs: {avg_tokens_all_batches[-1]:.2f}")

    if val_loader:
        plot_accuracies(train_accuracies, val_accuracies)

    plt.figure()
    plt.title('Average Tokens per Epoch')
    plt.plot(range(1, len(avg_tokens_all_batches) + 1), avg_tokens_all_batches, label='Avg Tokens per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Tokens')
    plt.legend()

    plt.figure()
    plt.title('Validation Metrics Over Batches')
    plt.plot(val_loss_per_batch, label='Validation Loss')
    plt.plot(val_accuracy_per_batch, label='Validation Accuracy')
    plt.xlabel('Batches')
    plt.ylabel('Metrics')
    plt.legend()
    plt.show()


def evaluate_model(model, val_loader, device, loss_fn=None, return_loss=False):
    model.eval()
    all_preds, all_labels = [], []
    total_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            if return_loss:
                loss = loss_fn(outputs.logits, labels)
                total_loss += loss.item()

    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Validation Accuracy: {accuracy:.4f}")
    if return_loss:
        return total_loss / len(val_loader), accuracy
    return accuracy


def plot_accuracies(train_accuracies, val_accuracies):
    epochs = range(1, len(train_accuracies) + 1)
    plt.plot(epochs, train_accuracies, label='Training Accuracy')
    plt.plot(epochs, val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()


def stream_file(file_path, chunk_size):
    """Stream the Yelp reviews JSONL file as chunks of (text, stars) tuples."""
    with open(file_path, 'r', encoding='utf-8') as file:
        chunk = []
        for line in file:
            record = json.loads(line.strip())
            if "stars" in record and isinstance(record["stars"], (int, float)):
                chunk.append((record["text"], int(record["stars"])))
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk


if __name__ == "__main__":
    file_path = "yelp_academic_dataset_review.json"
    chunk_size = 10000  # process 10,000 reviews at a time

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS

    model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=5)
    model.config.pad_token_id = tokenizer.pad_token_id
    # NOTE: this attaches an extra dropout module, but GPT2ForSequenceClassification's
    # forward pass never calls it; dropout is controlled via the model config
    # (resid_pdrop, embd_pdrop, attn_pdrop).
    model.dropout = Dropout(p=0.1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

    # Estimate class weights from the first 5 chunks (50,000 reviews)
    all_ratings = []
    for i, chunk in enumerate(stream_file(file_path, chunk_size)):
        _, ratings = zip(*chunk)
        all_ratings.extend(ratings)
        if i + 1 == 5:
            break
    class_weights = compute_class_weights(all_ratings)
    # Sorted star keys 1-5 line up with dataset labels 0-4
    weights_tensor = torch.tensor([class_weights[i] for i in sorted(class_weights)], dtype=torch.float).to(device)
    loss_fn = CrossEntropyLoss(weight=weights_tensor)

    # Build a fixed validation set from the first chunk.
    # NOTE: these samples also appear in the first training chunk below, so the
    # reported validation scores are optimistic.
    first_chunk = next(stream_file(file_path, chunk_size))
    val_texts, val_ratings = zip(*first_chunk[:2000])  # 2,000 samples for the validation set
    val_dataset = YelpDataset(val_texts, val_ratings, tokenizer)
    val_loader = DataLoader(val_dataset, batch_size=8)

    # Train on the first 10 chunks, one epoch per chunk
    for chunk_idx, chunk in enumerate(stream_file(file_path, chunk_size)):
        if chunk_idx + 1 > 10:
            break
        print(f"Processing chunk #{chunk_idx + 1}")
        texts, ratings = zip(*chunk)
        dataset = YelpDataset(texts, ratings, tokenizer)
        loader = DataLoader(dataset, batch_size=8, shuffle=True)
        train_model(model, loader, optimizer, loss_fn, epochs=1, device=device, val_loader=val_loader)

    evaluate_model(model, val_loader, device)