Adchay commited on
Commit
8585229
·
verified ·
1 Parent(s): 290769c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -6
app.py CHANGED
@@ -1,14 +1,37 @@
1
  from fastapi import FastAPI
 
2
  from transformers import pipeline
 
 
3
 
4
  app = FastAPI()
5
 
6
- # Load model at startup (only once)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  classifier = pipeline(
8
  "zero-shot-classification",
9
  model="MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
10
  )
11
 
 
12
  subject_labels = [
13
  "Physics", "Chemistry", "Biology", "Astronomy",
14
  "Earth Science", "Environmental Science",
@@ -25,12 +48,36 @@ subject_labels = [
25
  "Physical Education", "Health Science"
26
  ]
27
 
28
- @app.post("/predict/")
29
- async def predict_topic(text: str):
 
 
 
 
 
 
30
  result = classifier(
31
- text,
32
  candidate_labels=subject_labels,
33
  hypothesis_template="This text is about {}."
34
  )
35
- predicted_topic = result["labels"][0]
36
- return {"topic": predicted_topic}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
  from transformers import pipeline
4
+ import mysql.connector
5
+ import os
6
 
7
  app = FastAPI()
8
 
9
+ # Database connection settings
10
+ DB_HOST = "gateway01.ap-southeast-1.prod.aws.tidbcloud.com"
11
+ DB_PORT = 4000
12
+ DB_USER = "4V44XYoMA7okY9v.root"
13
+ DB_PASS = "aW2CrSwcTgjFhNAb"
14
+ DB_NAME = "final_project"
15
+
16
+ # Create MySQL connection
17
+ conn = mysql.connector.connect(
18
+ host=DB_HOST,
19
+ port=DB_PORT,
20
+ user=DB_USER,
21
+ password=DB_PASS,
22
+ database=DB_NAME,
23
+ ssl_verify_cert=True,
24
+ ssl_verify_identity=True
25
+ )
26
+ cursor = conn.cursor()
27
+
28
+ # Load model
29
  classifier = pipeline(
30
  "zero-shot-classification",
31
  model="MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
32
  )
33
 
34
+ # Labels
35
  subject_labels = [
36
  "Physics", "Chemistry", "Biology", "Astronomy",
37
  "Earth Science", "Environmental Science",
 
48
  "Physical Education", "Health Science"
49
  ]
50
 
51
+ # Request model
52
+ class TextInput(BaseModel):
53
+ student_id: str
54
+ text: str
55
+
56
+ @app.post("/predict")
57
+ def predict_topic(data: TextInput):
58
+ # Predict subject
59
  result = classifier(
60
+ data.text,
61
  candidate_labels=subject_labels,
62
  hypothesis_template="This text is about {}."
63
  )
64
+ predicted_subject = result["labels"][0]
65
+
66
+ # Get first 100 characters of the text
67
+ sample_text = data.text[:100]
68
+
69
+ # Save to DB
70
+ cursor.execute(
71
+ """
72
+ INSERT INTO log_table (student_id, input_sample, subject)
73
+ VALUES (%s, %s, %s)
74
+ """,
75
+ (data.student_id, sample_text, predicted_subject)
76
+ )
77
+ conn.commit()
78
+
79
+ return {
80
+ "student_id": data.student_id,
81
+ "predicted_subject": predicted_subject,
82
+ "sample_text": sample_text
83
+ }