"""Yelp_Review/gpt2.py: fine-tune GPT-2 for 5-class Yelp star-rating classification on streamed review chunks."""
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification
from torch.optim import AdamW  # transformers' bundled AdamW is deprecated; use the torch implementation
from sklearn.metrics import accuracy_score
import json
from collections import Counter
from torch.nn import CrossEntropyLoss
from torch.nn.utils import clip_grad_norm_
import matplotlib.pyplot as plt
class YelpDataset(Dataset):
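    """Dataset of Yelp review texts and star ratings for GPT-2 sequence classification.

    Each item is tokenized on the fly, padded/truncated to ``max_length`` tokens, and its
    1-5 star rating is shifted to a 0-4 class label.
    """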
def __init__(self, texts, ratings, tokenizer, max_length=128):
self.texts = texts
self.ratings = ratings
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
        rating = self.ratings[idx] - 1  # map 1-5 star ratings to class labels 0-4
encoding = self.tokenizer(
text,
truncation=True,
padding="max_length",
max_length=self.max_length,
return_tensors="pt"
)
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'label': torch.tensor(rating, dtype=torch.long)
}
def compute_class_weights(labels):
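    """Return inverse-frequency class weights, e.g. counts {1: 90, 2: 10} -> weights {1: 100/90, 2: 100/10}.

    Rare star ratings get proportionally larger weights so the weighted cross-entropy loss
    does not collapse onto the majority class.
    """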
class_counts = Counter(labels)
total_samples = sum(class_counts.values())
weights = {cls: total_samples / count for cls, count in class_counts.items()}
return weights
def train_model(model, train_loader, optimizer, loss_fn, epochs, device, val_loader=None, patience=3):
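    """Train for ``epochs`` passes over ``train_loader`` using the supplied class-weighted ``loss_fn``.

    Gradients are clipped to max-norm 1.0. If ``val_loader`` is given, the full validation set is
    evaluated after every training batch (for the per-batch plots) and again after every epoch, and
    training stops early once the epoch-level validation loss fails to improve for ``patience`` epochs.
    """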
model.to(device)
best_loss = float('inf')
patience_counter = 0
train_accuracies, val_accuracies = [], []
val_loss_per_batch, val_accuracy_per_batch = [], []
avg_tokens_all_batches = []
total_tokens_across_epochs = 0
total_batches = 0
for epoch in range(epochs):
model.train()
total_loss = 0
correct, total = 0, 0
epoch_tokens = 0
for batch in train_loader:
optimizer.zero_grad()
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
            # forward pass; the class-weighted loss is computed with loss_fn below, so labels are
            # not passed to the model (that would only trigger a redundant unweighted internal loss)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs.logits, labels)
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
preds = torch.argmax(outputs.logits, dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
            epoch_tokens += input_ids.size(1) * input_ids.size(0)  # padded tokens in this batch
total_batches += 1
            # run a full validation pass after every training batch to log the per-batch plots;
            # evaluate_model leaves the model in eval mode, so switch back before the next batch
            if val_loader:
                val_loss, val_accuracy = evaluate_model(model, val_loader, device, loss_fn=loss_fn, return_loss=True)
                val_loss_per_batch.append(val_loss)
                val_accuracy_per_batch.append(val_accuracy)
                model.train()
total_tokens_across_epochs += epoch_tokens
avg_loss = total_loss / len(train_loader)
train_accuracy = correct / total
train_accuracies.append(train_accuracy)
print(f"Epoch {epoch + 1}/{epochs} - Training Loss: {avg_loss:.4f} - Training Accuracy: {train_accuracy:.4f}")
if val_loader:
            val_loss, val_accuracy = evaluate_model(model, val_loader, device, loss_fn=loss_fn, return_loss=True)
val_accuracies.append(val_accuracy)
if val_loss < best_loss:
best_loss = val_loss
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= patience:
print("Early stopping triggered.")
break
avg_tokens_all_batches.append(total_tokens_across_epochs / total_batches)
print(f"Average tokens per batch across all epochs: {avg_tokens_all_batches[-1]:.2f}")
if val_loader:
plot_accuracies(train_accuracies, val_accuracies)
plt.figure()
    plt.title('Average Tokens per Batch (running average, one point per epoch)')
    plt.plot(range(1, len(avg_tokens_all_batches) + 1), avg_tokens_all_batches, label='Avg Tokens per Batch')
plt.xlabel('Epochs')
plt.ylabel('Tokens')
plt.legend()
plt.figure()
plt.title('Validation Metrics Over Batches')
plt.plot(val_loss_per_batch, label='Validation Loss')
plt.plot(val_accuracy_per_batch, label='Validation Accuracy')
plt.xlabel('Batches')
plt.ylabel('Metrics')
plt.legend()
plt.show()
def evaluate_model(model, val_loader, device, loss_fn=None, return_loss=False):
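    """Compute and print accuracy over ``val_loader``; also return the mean ``loss_fn`` loss when ``return_loss`` is True."""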
model.eval()
all_preds, all_labels = [], []
total_loss = 0
with torch.no_grad():
for batch in val_loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
preds = torch.argmax(outputs.logits, dim=1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
if return_loss:
loss = loss_fn(outputs.logits, labels)
total_loss += loss.item()
accuracy = accuracy_score(all_labels, all_preds)
print(f"Validation Accuracy: {accuracy:.4f}")
if return_loss:
return total_loss / len(val_loader), accuracy
return accuracy
def plot_accuracies(train_accuracies, val_accuracies):
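    """Plot per-epoch training and validation accuracy on the current matplotlib figure."""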
epochs = range(1, len(train_accuracies) + 1)
plt.plot(epochs, train_accuracies, label='Training Accuracy')
plt.plot(epochs, val_accuracies, label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
def stream_file(file_path, chunk_size):
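    """Yield lists of (text, stars) tuples, ``chunk_size`` records at a time, from a JSON-lines review file.

    Records without a numeric "stars" field are skipped; the final, possibly smaller, chunk is still yielded.
    """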
with open(file_path, 'r', encoding='utf-8') as file:
chunk = []
        for line in file:
record = json.loads(line.strip())
if "stars" in record and isinstance(record["stars"], (int, float)):
chunk.append((record["text"], int(record["stars"])))
if len(chunk) == chunk_size:
yield chunk
chunk = []
if chunk:
yield chunk
if __name__ == "__main__":
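    # the review dump is JSON-lines (one review record per line), so it is streamed in chunks
    # rather than loaded into memory whole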
file_path = "yelp_academic_dataset_review.json"
chunk_size = 10000 # process 10,000 lines at a time
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
    # GPT-2's dropout (resid_pdrop/embd_pdrop/attn_pdrop, default 0.1) is configured at load
    # time; a standalone nn.Dropout attached to the model is never used in its forward pass.
    model = GPT2ForSequenceClassification.from_pretrained(
        "gpt2", num_labels=5, resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1
    )
    model.config.pad_token_id = tokenizer.pad_token_id
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
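    # estimate class weights from the star distribution of the first five chunks
    # (50,000 reviews at chunk_size = 10,000) rather than scanning the whole file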
all_ratings = []
for i, chunk in enumerate(stream_file(file_path, chunk_size)):
_, ratings = zip(*chunk)
all_ratings.extend(ratings)
if i + 1 == 5:
break
class_weights = compute_class_weights(all_ratings)
weights_tensor = torch.tensor([class_weights[i] for i in sorted(class_weights)], dtype=torch.float).to(device)
loss_fn = CrossEntropyLoss(weight=weights_tensor)
first_chunk = next(stream_file(file_path, chunk_size))
val_texts, val_ratings = zip(*first_chunk[:2000]) # 2000 samples for validation set
val_dataset = YelpDataset(val_texts, val_ratings, tokenizer)
val_loader = DataLoader(val_dataset, batch_size=8)
for chunk_idx, chunk in enumerate(stream_file(file_path, chunk_size)):
if chunk_idx + 1 > 10:
break
        print(f"Processing chunk #{chunk_idx + 1}")
        if chunk_idx == 0:
            chunk = chunk[2000:]  # skip the 2,000 reviews held out above as the validation set
        texts, ratings = zip(*chunk)
dataset = YelpDataset(texts, ratings, tokenizer)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
        train_model(model, loader, optimizer, loss_fn, epochs=1, device=device, val_loader=val_loader)
evaluate_model(model, val_loader, device)
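    # Optional: persist the fine-tuned model and run a quick single-review sanity check;
    # the output directory and sample review text below are illustrative placeholders.
    model.save_pretrained("gpt2-yelp-stars")
    tokenizer.save_pretrained("gpt2-yelp-stars")
    sample = "The food was great but the service was painfully slow."
    enc = tokenizer(sample, truncation=True, padding="max_length", max_length=128, return_tensors="pt").to(device)
    model.eval()
    with torch.no_grad():
        pred = model(**enc).logits.argmax(dim=1).item()
    print(f"Sample review predicted rating: {pred + 1} stars")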