Update train.py
Browse files
train.py
CHANGED
@@ -1,24 +1,34 @@
|
|
1 |
-
from transformers import Trainer, TrainingArguments
|
2 |
from datasets import Dataset
|
3 |
import json
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# Carica il dataset
|
6 |
with open('entity_dataset.json', 'r') as f:
|
7 |
dataset = json.load(f)
|
8 |
|
9 |
-
#
|
10 |
def prepare_dataset(dataset):
|
11 |
-
|
12 |
-
|
13 |
-
"text": [entry["query"] for entry in dataset],
|
14 |
-
"labels": [entry["entities"] for entry in dataset]
|
15 |
-
}
|
16 |
-
return Dataset.from_dict(data)
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
|
|
|
|
22 |
|
23 |
# Imposta i parametri di addestramento
|
24 |
training_args = TrainingArguments(
|
@@ -27,6 +37,7 @@ training_args = TrainingArguments(
|
|
27 |
learning_rate=2e-5, # Tasso di apprendimento
|
28 |
per_device_train_batch_size=16, # Dimensione del batch
|
29 |
num_train_epochs=3, # Numero di epoche
|
|
|
30 |
)
|
31 |
|
32 |
# Inizializza il trainer
|
@@ -38,3 +49,7 @@ trainer = Trainer(
|
|
38 |
|
39 |
# Avvia l'addestramento
|
40 |
trainer.train()
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
|
2 |
from datasets import Dataset
|
3 |
import json
|
4 |
+
import torch
|
5 |
+
|
6 |
+
# Carica il tokenizer e il modello pre-addestrato (dbmdz/bert-base-italian-uncased)
|
7 |
+
tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-italian-uncased")
|
8 |
+
model = BertForTokenClassification.from_pretrained("dbmdz/bert-base-italian-uncased", num_labels=5) # Aggiungi il numero corretto di etichette (labels)
|
9 |
|
10 |
# Carica il dataset
|
11 |
with open('entity_dataset.json', 'r') as f:
|
12 |
dataset = json.load(f)
|
13 |
|
14 |
+
# Funzione per preparare i dati
|
15 |
def prepare_dataset(dataset):
|
16 |
+
input_texts = [entry["query"] for entry in dataset]
|
17 |
+
labels = [entry["entities"] for entry in dataset]
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
# Tokenizza i dati di input
|
20 |
+
encodings = tokenizer(input_texts, truncation=True, padding=True, max_length=512)
|
21 |
+
# Aggiungi le etichette (entità) come output
|
22 |
+
# Qui supponiamo che tu stia etichettando solo le entità (puoi adattare la funzione per il tuo caso)
|
23 |
+
# Nota: dovresti mappare le etichette in modo che corrispondano al formato richiesto per BERT
|
24 |
+
|
25 |
+
# Associa le etichette agli input tokenizzati
|
26 |
+
encodings['labels'] = torch.tensor(labels)
|
27 |
|
28 |
+
return Dataset.from_dict(encodings)
|
29 |
+
|
30 |
+
# Prepara il dataset per l'addestramento
|
31 |
+
train_dataset = prepare_dataset(dataset)
|
32 |
|
33 |
# Imposta i parametri di addestramento
|
34 |
training_args = TrainingArguments(
|
|
|
37 |
learning_rate=2e-5, # Tasso di apprendimento
|
38 |
per_device_train_batch_size=16, # Dimensione del batch
|
39 |
num_train_epochs=3, # Numero di epoche
|
40 |
+
weight_decay=0.01 # Peso di decadimento (per evitare overfitting)
|
41 |
)
|
42 |
|
43 |
# Inizializza il trainer
|
|
|
49 |
|
50 |
# Avvia l'addestramento
|
51 |
trainer.train()
|
52 |
+
|
53 |
+
# Salva il modello addestrato
|
54 |
+
model.save_pretrained("./hotel_model")
|
55 |
+
tokenizer.save_pretrained("./hotel_model")
|