Daddario commited on
Commit
cd21527
·
verified ·
1 Parent(s): c748179

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +40 -0
train.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Trainer, TrainingArguments, BertForTokenClassification
2
+ from datasets import Dataset
3
+ import json
4
+
5
+ # Carica il dataset
6
+ with open('entity_dataset.json', 'r') as f:
7
+ dataset = json.load(f)
8
+
9
+ # Prepara il dataset per l'addestramento
10
+ def prepare_dataset(dataset):
11
+ # Converti il dataset in un formato adatto per Hugging Face Dataset
12
+ data = {
13
+ "text": [entry["query"] for entry in dataset],
14
+ "labels": [entry["entities"] for entry in dataset]
15
+ }
16
+ return Dataset.from_dict(data)
17
+
18
+ train_dataset = prepare_dataset(dataset)
19
+
20
+ # Carica il modello pre-addestrato
21
+ model = BertForTokenClassification.from_pretrained("dbmdz/bert-base-italian-uncased")
22
+
23
+ # Imposta i parametri di addestramento
24
+ training_args = TrainingArguments(
25
+ output_dir="./results", # Cartella di output
26
+ evaluation_strategy="epoch", # Come viene eseguita la valutazione
27
+ learning_rate=2e-5, # Tasso di apprendimento
28
+ per_device_train_batch_size=16, # Dimensione del batch
29
+ num_train_epochs=3, # Numero di epoche
30
+ )
31
+
32
+ # Inizializza il trainer
33
+ trainer = Trainer(
34
+ model=model, # Il modello
35
+ args=training_args, # I parametri di addestramento
36
+ train_dataset=train_dataset, # Il dataset di addestramento
37
+ )
38
+
39
+ # Avvia l'addestramento
40
+ trainer.train()