|
import streamlit as st
import torch

from transformers import BertTokenizer, BertForTokenClassification

from transformers import pipeline
|
|
|
|
|
@st.cache_resource
def _load_model_and_tokenizer(model_dir="./hotel_model"):
    """Load the fine-tuned token-classification model and its tokenizer.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` keeps one loaded copy alive across
    reruns instead of reading the weights from disk each time.

    Args:
        model_dir: Directory containing the saved model and tokenizer files.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = BertForTokenClassification.from_pretrained(model_dir)
    # Inference only: make sure dropout is disabled so predictions are deterministic.
    loaded_model.eval()
    loaded_tokenizer = BertTokenizer.from_pretrained(model_dir)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer()
|
|
|
|
|
def predict_entities(text):
    """Run the token-classification model on ``text`` and return per-token class ids.

    Args:
        text: The raw user query to tag.

    Returns:
        A tensor of predicted label ids, one per token produced by the
        tokenizer (special tokens included), with a leading batch dimension
        of size 1 for a single string input. Ids can be mapped to label
        names via ``model.config.id2label``.
    """
    # Truncate to 512 tokens, the position limit of standard BERT checkpoints.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Pick the highest-scoring label at each token position.
    return torch.argmax(outputs.logits, dim=-1)
|
|
|
st.title("Hotel Bot")

query = st.text_input("Inserisci una query:")

# Streamlit reruns this script on every interaction; only run inference once
# the user has typed something other than whitespace.
if query.strip():
    predicted_ids = predict_entities(query)
    # A single string is tokenized as a batch of one, so row 0 holds this
    # query's per-token predictions.
    label_ids = predicted_ids[0].tolist()
    # Re-tokenize with the same truncation settings to recover the token
    # strings that line up with each prediction (special tokens included).
    tokens = tokenizer.convert_ids_to_tokens(
        tokenizer(query, truncation=True, max_length=512)["input_ids"]
    )
    # Show human-readable (token, label) pairs instead of raw class indices.
    entities = [
        (token, model.config.id2label[label_id])
        for token, label_id in zip(tokens, label_ids)
    ]
    st.write(f"Entità estratte: {entities}")
|
|