import os

import wandb
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    DataCollatorWithPadding,
    DebertaV2ForSequenceClassification,
    DebertaV2Tokenizer,
    Trainer,
    TrainingArguments,
)

# Training configuration
output_dir = './german_politic_DeBERTa-v2-base'
model_name = "ikim-uk-essen/geberta-base"
max_length = 512
id2label = {0: 'other', 1: 'politic'}
label2id = {'other': 0, 'politic': 1}

# Log the run to Weights & Biases ("xxx" is a placeholder; set your own entity).
wandb.init(project="german_politic_yes_no_classifier", entity="xxx", name="german_politic_DeBERTa")

model = DebertaV2ForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    output_attentions=False,
    output_hidden_states=False,
)

# Tokenizer parallelism is controlled via an environment variable, not a
# from_pretrained() argument; the truncation limit is likewise set through
# model_max_length rather than max_length.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
tokenizer = DebertaV2Tokenizer.from_pretrained(
    model_name,
    do_lower_case=False,
    model_max_length=max_length,
)

# Load the dataset and hold out 20% of the training split as a test set.
dataset = load_dataset("SinclairSchneider/trainset_political_text_yes_no_german")
dataset = dataset['train'].train_test_split(test_size=0.2)

def preprocess(sample):
    # Tokenize the text; truncation caps sequences at model_max_length (512).
    return tokenizer(sample["text"], truncation=True)


dataset_tokenized = dataset.map(preprocess, batched=True)

# Pad dynamically to the longest sequence in each batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=4,
    weight_decay=0.01,
    evaluation_strategy="epoch",  # called eval_strategy in newer transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="wandb",
    fp16=False,
    logging_steps=8,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# Unwrap the model if it was wrapped (e.g. by DataParallel), then save model and tokenizer.
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
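
# Optional sanity check after training: reload the saved checkpoint and classify
# one sentence. This is a minimal sketch, not part of the original run; the
# example sentence is an illustrative placeholder.
from transformers import pipeline

classifier = pipeline("text-classification", model=output_dir, tokenizer=output_dir)
print(classifier("Der Bundestag debattiert heute über das neue Haushaltsgesetz."))
# Expected output shape: [{'label': 'politic', 'score': ...}] (label names come from id2label)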