hotel_bot / train.py
Daddario's picture
Update train.py
b7ffce6 verified
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
from datasets import Dataset
import json
import torch
# Carica il tokenizer e il modello pre-addestrato (dbmdz/bert-base-italian-uncased)
tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-italian-uncased")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-base-italian-uncased", num_labels=5) # Aggiungi il numero corretto di etichette (labels)
# Carica il dataset
with open('entity_dataset.json', 'r') as f:
dataset = json.load(f)
# Funzione per preparare i dati
def prepare_dataset(dataset):
input_texts = [entry["query"] for entry in dataset]
labels = [entry["entities"] for entry in dataset]
# Tokenizza i dati di input
encodings = tokenizer(input_texts, truncation=True, padding=True, max_length=512)
# Aggiungi le etichette (entità) come output
# Qui supponiamo che tu stia etichettando solo le entità (puoi adattare la funzione per il tuo caso)
# Nota: dovresti mappare le etichette in modo che corrispondano al formato richiesto per BERT
# Associa le etichette agli input tokenizzati
encodings['labels'] = torch.tensor(labels)
return Dataset.from_dict(encodings)
# Prepara il dataset per l'addestramento
train_dataset = prepare_dataset(dataset)
# Imposta i parametri di addestramento
training_args = TrainingArguments(
output_dir="./results", # Cartella di output
evaluation_strategy="epoch", # Come viene eseguita la valutazione
learning_rate=2e-5, # Tasso di apprendimento
per_device_train_batch_size=16, # Dimensione del batch
num_train_epochs=3, # Numero di epoche
weight_decay=0.01 # Peso di decadimento (per evitare overfitting)
)
# Inizializza il trainer
trainer = Trainer(
model=model, # Il modello
args=training_args, # I parametri di addestramento
train_dataset=train_dataset, # Il dataset di addestramento
)
# Avvia l'addestramento
trainer.train()
# Salva il modello addestrato
model.save_pretrained("./hotel_model")
tokenizer.save_pretrained("./hotel_model")