File size: 898 Bytes
59e21a9
4266281
59e21a9
 
4266281
 
 
59e21a9
 
 
4266281
 
 
 
 
 
 
 
59e21a9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import streamlit as st
from transformers import BertTokenizer, BertForTokenClassification
from transformers import pipeline

# Carica il modello e il tokenizer
model = BertForTokenClassification.from_pretrained("./hotel_model")
tokenizer = BertTokenizer.from_pretrained("./hotel_model")

# Funzione di inferenza
def predict_entities(text):
    # Tokenizza il testo
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Ottieni le predizioni dal modello
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_ids = torch.argmax(logits, dim=-1)
    # Converti le predizioni in etichette (potresti aver bisogno di una mappatura delle etichette)
    return predicted_ids

st.title("Hotel Bot")
query = st.text_input("Inserisci una query:")

if query:
    entities = predict_entities(query)
    st.write(f"Entità estratte: {entities}")