Steph974 commited on
Commit
8c2d966
·
verified ·
1 Parent(s): 1bbcb6a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ app_code = """\
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
6
+
7
+ # Load the fine-tuned SBERT model from Hugging Face
8
+ model_name = "Steph974/SBERT-FineTuned-Classifier" # Your uploaded model
9
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ # Ensure the model is on the correct device
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
+ model.eval()
16
+
17
+ def predict_similarity(sentence1, sentence2):
18
+ \"\"\"
19
+ Predicts the probability of two sentences belonging to the same class (1) or different (0).
20
+ Returns probability instead of class label.
21
+ \"\"\"
22
+ # Tokenize input
23
+ inputs = tokenizer(sentence1, sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
24
+ inputs = {key: value.to(device) for key, value in inputs.items()} # Move tensors to model device
25
+
26
+ # Perform inference
27
+ with torch.no_grad():
28
+ outputs = model(**inputs)
29
+
30
+ # Get probabilities
31
+ probabilities = F.softmax(outputs.logits, dim=1).cpu().numpy()[0]
32
+ proba_same = probabilities[1] # Probability that sentences are in the same class
33
+ proba_diff = probabilities[0] # Probability that sentences are different
34
+
35
+ return {
36
+ "Same Class Probability": round(proba_same * 100, 2),
37
+ "Different Class Probability": round(proba_diff * 100, 2)
38
+ }
39
+
40
+ # Gradio UI
41
+ interface = gr.Interface(
42
+ fn=predict_similarity,
43
+ inputs=[
44
+ gr.Textbox(label="Sentence 1", placeholder="Enter the first sentence..."),
45
+ gr.Textbox(label="Sentence 2", placeholder="Enter the second sentence...")
46
+ ],
47
+ outputs=gr.Label(label="Prediction Probabilities"),
48
+ title="SBERT Sentence-Pair Similarity",
49
+ description="Enter two sentences and see how similar they are according to the fine-tuned SBERT model.",
50
+ theme="huggingface",
51
+ )
52
+
53
+ # Launch the Gradio app
54
+ interface.launch()
55
+ """
56
+
57
+ # Save to app.py
58
+ with open("app.py", "w") as f:
59
+ f.write(app_code)