File size: 898 Bytes
59e21a9 4266281 59e21a9 4266281 59e21a9 4266281 59e21a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import streamlit as st
from transformers import BertTokenizer, BertForTokenClassification
from transformers import pipeline
# Carica il modello e il tokenizer
model = BertForTokenClassification.from_pretrained("./hotel_model")
tokenizer = BertTokenizer.from_pretrained("./hotel_model")
# Funzione di inferenza
def predict_entities(text):
# Tokenizza il testo
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
# Ottieni le predizioni dal modello
outputs = model(**inputs)
logits = outputs.logits
predicted_ids = torch.argmax(logits, dim=-1)
# Converti le predizioni in etichette (potresti aver bisogno di una mappatura delle etichette)
return predicted_ids
st.title("Hotel Bot")
query = st.text_input("Inserisci una query:")
if query:
entities = predict_entities(query)
st.write(f"Entità estratte: {entities}")
|