|
import torch
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
|
from sklearn.metrics import accuracy_score
|
|
import json
|
|
from collections import Counter
|
|
from torch.nn import CrossEntropyLoss, Dropout
|
|
from torch.nn.utils import clip_grad_norm_
|
|
import matplotlib.pyplot as plt
|
|
|
|
class YelpDataset(Dataset):
    """Map-style dataset of (review text, star rating) pairs for BERT fine-tuning.

    Reviews are tokenized lazily, one per ``__getitem__`` call, with the
    tokenizer asked to pad/truncate to ``max_length`` tokens.
    """

    def __init__(self, texts, ratings, tokenizer, max_length=128):
        # ``texts`` and ``ratings`` are parallel sequences sharing one index.
        self.texts, self.ratings = texts, ratings
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of reviews (the length of ``texts``)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized review at ``idx`` together with its label.

        The dict holds ``input_ids`` and ``attention_mask`` with the tokenizer's
        leading batch dimension squeezed away, plus a ``long`` scalar ``label``
        equal to the rating minus one (ratings are presumably 1-based stars).
        """
        label = torch.tensor(self.ratings[idx] - 1, dtype=torch.long)
        encoded = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        sample = {name: encoded[name].squeeze(0) for name in ('input_ids', 'attention_mask')}
        sample['label'] = label
        return sample
|
|
|
|
def compute_class_weights(labels):
    """Return an inverse-frequency weight for every distinct label.

    Each weight is ``total_count / count(label)``, so rarer labels receive
    larger weights. Labels that never occur are absent from the returned dict,
    whose keys follow first-occurrence order.
    """
    frequencies = Counter(labels)
    total = sum(frequencies.values())
    weights = {}
    for label, freq in frequencies.items():
        weights[label] = total / freq
    return weights
|
|
|
|
def train_model(model, train_loader, optimizer, epochs, device, val_loader=None, patience=3,
                loss_fn=None, val_every=250, plot=True):
    """Fine-tune ``model`` with gradient clipping, periodic validation and early stopping.

    Each step computes ``loss_fn(logits, labels)``, clips the gradient norm to
    1.0 and steps ``optimizer``. When ``val_loader`` is given, the model is
    evaluated every ``val_every`` training batches (for the "over batches"
    curves) and at the end of every epoch. The end-of-epoch validation loss
    drives early stopping with the given ``patience``.

    Args:
        model: classifier whose forward accepts ``input_ids`` and
            ``attention_mask`` keywords and returns an object with ``.logits``.
        train_loader: iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        optimizer: optimizer over the model's parameters.
        epochs: maximum number of passes over ``train_loader``.
        device: device the model and batches are moved to.
        val_loader: optional validation loader with the same batch format.
        patience: consecutive epochs without validation-loss improvement
            before stopping.
        loss_fn: training loss applied to ``(logits, labels)``. Defaults to a
            module-level ``loss_fn`` if one is defined, else an unweighted
            ``CrossEntropyLoss``. (Validation loss comes from ``evaluate_model``.)
        val_every: run intra-epoch validation every this many training batches
            (counted across epochs); ``None``/``0`` disables it.
        plot: draw the training curves with matplotlib when true.

    Returns:
        dict with per-epoch ``train_accuracies``, ``val_accuracies`` and
        ``avg_tokens_per_epoch`` (mean non-padding tokens per batch), plus the
        intra-epoch ``val_steps``, ``val_loss_per_batch`` and
        ``val_accuracy_per_batch``.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    if loss_fn is None:
        # Backward compatibility: the script defines a class-weighted ``loss_fn`` globally.
        loss_fn = globals().get("loss_fn")
    if loss_fn is None:
        loss_fn = CrossEntropyLoss()

    model.to(device)
    run_validation = bool(val_loader)
    best_loss = float('inf')
    patience_counter = 0
    global_step = 0
    total_tokens = 0
    history = {
        "train_accuracies": [],
        "val_accuracies": [],
        "val_steps": [],
        "val_loss_per_batch": [],
        "val_accuracy_per_batch": [],
        "avg_tokens_per_epoch": [],
    }

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = total = 0
        epoch_tokens = 0
        epoch_batches = 0

        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            # No ``labels=``: the model would compute its own unweighted loss,
            # which is discarded because ``loss_fn`` is what gets optimised.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs.logits, labels)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            preds = torch.argmax(outputs.logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            # Count real tokens: with fixed-length padding the padded size is constant.
            epoch_tokens += int(attention_mask.sum().item())
            epoch_batches += 1
            global_step += 1

            if run_validation and val_every and global_step % val_every == 0:
                val_loss, val_accuracy = evaluate_model(model, val_loader, device, return_loss=True)
                history["val_steps"].append(global_step)
                history["val_loss_per_batch"].append(val_loss)
                history["val_accuracy_per_batch"].append(val_accuracy)
                # evaluate_model may switch to eval mode; keep dropout active for training.
                model.train()

        if epoch_batches == 0:
            raise ValueError("train_loader yielded no batches")

        total_tokens += epoch_tokens
        avg_loss = total_loss / epoch_batches
        train_accuracy = correct / total
        history["train_accuracies"].append(train_accuracy)
        history["avg_tokens_per_epoch"].append(epoch_tokens / epoch_batches)
        print(f"Epoch {epoch + 1}/{epochs} - Training Loss: {avg_loss:.4f} - Training Accuracy: {train_accuracy:.4f}")

        if run_validation:
            val_loss, val_accuracy = evaluate_model(model, val_loader, device, return_loss=True)
            history["val_accuracies"].append(val_accuracy)

            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    if global_step:
        print(f"Average tokens per batch across all epochs: {total_tokens / global_step:.2f}")

    if plot:
        _plot_training_history(history, show_validation=run_validation)
    return history


def _plot_training_history(history, show_validation):
    """Draw accuracy, token-count and intra-epoch validation curves from a train_model history."""
    if show_validation:
        plot_accuracies(history["train_accuracies"], history["val_accuracies"])

    tokens = history["avg_tokens_per_epoch"]
    plt.figure()
    plt.title('Average Tokens per Epoch')
    plt.plot(range(1, len(tokens) + 1), tokens, label='Avg Tokens per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Tokens')
    plt.legend()

    if show_validation:
        plt.figure()
        plt.title('Validation Metrics Over Batches')
        plt.plot(history["val_steps"], history["val_loss_per_batch"], label='Validation Loss')
        plt.plot(history["val_steps"], history["val_accuracy_per_batch"], label='Validation Accuracy')
        plt.xlabel('Batches')
        plt.ylabel('Metrics')
        plt.legend()

    plt.show()
|
|
|
|
def evaluate_model(model, val_loader, device, return_loss=False, loss_fn=None):
    """Evaluate ``model`` on ``val_loader``, print the accuracy and return it.

    Runs in eval mode under ``torch.no_grad()`` and afterwards restores the
    train/eval mode the model had before the call. A caller inside a training
    loop therefore keeps dropout enabled.

    Args:
        model: classifier whose forward accepts ``input_ids`` and
            ``attention_mask`` keywords and returns an object with ``.logits``.
        val_loader: iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        device: device the batches are moved to (the model itself is not moved).
        return_loss: also compute the mean per-batch loss.
        loss_fn: loss applied to ``(logits, labels)`` when ``return_loss`` is
            set. Defaults to a module-level ``loss_fn`` if one is defined, else
            an unweighted ``CrossEntropyLoss``.

    Returns:
        ``accuracy``, or ``(mean_batch_loss, accuracy)`` when ``return_loss``
        is true. Values are ``nan`` when ``val_loader`` yields nothing.
    """
    if return_loss and loss_fn is None:
        # Backward compatibility: the script defines a class-weighted ``loss_fn`` globally.
        loss_fn = globals().get("loss_fn")
        if loss_fn is None:
            loss_fn = CrossEntropyLoss()

    was_training = model.training
    model.eval()
    correct = total = 0
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
                num_batches += 1

                if return_loss:
                    total_loss += loss_fn(logits, labels).item()
    finally:
        # Restore the caller's mode (train_model relies on dropout staying on).
        model.train(was_training)

    accuracy = correct / total if total else float('nan')
    print(f"Validation Accuracy: {accuracy:.4f}")
    if return_loss:
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        return avg_loss, accuracy
    return accuracy
|
|
|
|
def plot_accuracies(train, val):
    """Plot training vs. validation accuracy on a new figure and show it.

    ``train_model`` passes one value per epoch, so the x axis is the 1-based
    epoch number. ``val`` may be shorter than ``train`` (e.g. empty).
    """
    train, val = list(train), list(val)
    # New figure so the curves never land on a previously active plot.
    plt.figure()
    plt.plot(range(1, len(train) + 1), train, label="Training Accuracy")
    plt.plot(range(1, len(val) + 1), val, label="Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.show()
|
|
|
|
def stream_file(file_path, chunk_size):
    """Yield lists of ``(text, stars)`` tuples read from a JSON-lines file.

    Each non-blank line is parsed as one JSON value. A record is kept only when
    it is an object with a string ``"text"`` and a numeric (non-bool)
    ``"stars"``, which is truncated to ``int``. Blank lines and incomplete
    records are skipped. Full chunks of ``chunk_size`` records are yielded as
    they fill, then a final shorter chunk if any records remain.

    Raises (when iteration starts / reaches the offending line):
        ValueError: if ``chunk_size`` is not a positive integer.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    with open(file_path, 'r', encoding='utf-8') as file:
        chunk = []
        for line in file:
            line = line.strip()
            if not line:
                continue  # tolerate blank and trailing lines
            record = json.loads(line)
            if not isinstance(record, dict):
                continue
            text = record.get("text")
            stars = record.get("stars")
            # bool is an int subclass; exclude it so true/false are not read as 1/0 stars.
            if not isinstance(text, str) or isinstance(stars, bool) or not isinstance(stars, (int, float)):
                continue
            chunk.append((text, int(stars)))
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk
|
|
|
|
if __name__ == "__main__":
    file_path = "yelp_academic_dataset_review.json"
    chunk_size = 10000
    num_classes = 5      # class id = stars - 1 (see YelpDataset), so stars 1..num_classes
    weight_chunks = 5    # chunks sampled to estimate class frequencies
    train_chunks = 10    # chunks used for training
    val_size = 2000      # leading records of the first chunk, held out for validation

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_classes)
    model.dropout = Dropout(p=0.1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6 keeps
    # that implementation's default.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01, eps=1e-6)

    all_ratings = []
    for i, chunk in enumerate(stream_file(file_path, chunk_size)):
        all_ratings.extend(stars for _, stars in chunk)
        if i + 1 == weight_chunks:
            break

    class_weights = compute_class_weights(all_ratings)
    if not class_weights:
        raise SystemExit(f"No usable reviews found in {file_path}")
    # Index the tensor by star value so position k always means class k (stars k+1).
    # A star value absent from the sample gets the largest observed weight
    # instead of silently shifting or shortening the weight vector.
    fallback_weight = max(class_weights.values())
    weights_tensor = torch.tensor(
        [class_weights.get(stars, fallback_weight) for stars in range(1, num_classes + 1)],
        dtype=torch.float,
    ).to(device)
    # Module-level name: train_model / evaluate_model look up ``loss_fn`` here.
    loss_fn = CrossEntropyLoss(weight=weights_tensor)

    first_chunk = next(stream_file(file_path, chunk_size))
    val_texts, val_ratings = zip(*first_chunk[:val_size])
    val_dataset = YelpDataset(val_texts, val_ratings, tokenizer)
    val_loader = DataLoader(val_dataset, batch_size=8)

    for chunk_idx, chunk in enumerate(stream_file(file_path, chunk_size)):
        if chunk_idx >= train_chunks:
            break
        if chunk_idx == 0:
            # These records are the validation set; training on them would leak.
            chunk = chunk[val_size:]
            if not chunk:
                continue
        print(f"Processing chunk #{chunk_idx + 1}")

        texts, ratings = zip(*chunk)
        dataset = YelpDataset(texts, ratings, tokenizer)
        loader = DataLoader(dataset, batch_size=8, shuffle=True)

        train_model(model, loader, optimizer, epochs=1, device=device, val_loader=val_loader)

    evaluate_model(model, val_loader, device)
|
|
|