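# Fine-tune a German DeBERTa (GEBERTa) checkpoint as a binary classifier that
# labels texts as political ('politic') or non-political ('other').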
import pandas as pd
import torch
from transformers import DebertaV2ForSequenceClassification, DebertaV2Tokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from tqdm import tqdm
from datasets import Dataset, load_dataset
import numpy as np
import wandb
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
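
# Output directory, base checkpoint, maximum sequence length, and label mapping.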
output_dir = './german_politic_DeBERTa-v2-base'
model_name = "ikim-uk-essen/geberta-base"
max_length = 512
id2label = {0: 'other', 1: 'politic'}
label2id = {'other': 0, 'politic': 1}
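
# Track the run with Weights & Biases ('xxx' is a placeholder entity).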
wandb.init(project="german_politic_yes_no_classifier", entity="xxx", name="german_politic_DeBERTa")
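
# Load the pretrained encoder with a two-class classification head and its tokenizer.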
model = DebertaV2ForSequenceClassification.from_pretrained(model_name, num_labels = 2, id2label=id2label, label2id=label2id, output_attentions = False, output_hidden_states = False)
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name, do_lower_case=False, max_length = max_length, TOKENIZERS_PARALLELISM=True)
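
# Load the labelled dataset and hold out 20% of it as a test split.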
dataset = load_dataset("SinclairSchneider/trainset_political_text_yes_no_german")
dataset = dataset['train'].train_test_split(0.2)
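
# Tokenize the texts; padding is applied per batch by the data collator below.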
def preprocess(sample):
    return tokenizer(sample["text"], truncation=True)
dataset_tokenized = dataset.map(preprocess, batched = True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
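
# Accuracy plus weighted precision/recall/F1, reported on every evaluation pass.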
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
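
# Training hyperparameters: evaluate and checkpoint once per epoch and keep the best model.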
training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=4,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="wandb",
    fp16=False,
    logging_steps=8,
    disable_tqdm=False,
)
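
# The Hugging Face Trainer wires together model, data, collator, and metrics.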
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
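
# Train, then persist the final model and tokenizer to the output directory.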
trainer.train()
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
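
# Optional sanity check (not part of the original script): a minimal sketch of loading
# the freshly saved checkpoint with a text-classification pipeline and scoring one
# example. The sentence below is made up purely for illustration.
from transformers import pipeline

classifier = pipeline("text-classification", model=output_dir, tokenizer=output_dir)
print(classifier("Der Bundestag debattiert heute über das neue Klimaschutzgesetz."))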