Manoj Kumar commited on
Commit
80d5bbb
·
1 Parent(s): f860f0a
Files changed (3) hide show
  1. README.md +1 -1
  2. main.py +45 -0
  3. requirements.txt +4 -1
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
- app_file: t5.py
9
  pinned: false
10
  python: 3.9
11
  ---
 
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
+ app_file: main.py
9
  pinned: false
10
  python: 3.9
11
  ---
main.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import uvicorn
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
6
+
7
+ app = FastAPI()
8
+
9
+ # Load fine-tuned text-to-SQL model
10
+ MODEL_NAME = "budecosystem/sql-millennials-13b"
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) #AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
13
+
14
+ def generate_sql(query):
15
+ print(query)
16
+ inputs = tokenizer(query, return_tensors="pt")
17
+ outputs = model.generate(**inputs)
18
+ print(outputs)
19
+ sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
20
+ print("======>", sql_query)
21
+ return sql_query
22
+
23
+ def execute_sql(sql_query):
24
+ conn = sqlite3.connect("./ecommerce.db")
25
+ cursor = conn.cursor()
26
+ try:
27
+ cursor.execute(sql_query)
28
+ result = cursor.fetchall()
29
+ conn.commit()
30
+ except Exception as e:
31
+ result = str(e)
32
+ conn.close()
33
+ return result
34
+
35
+ class QueryRequest(BaseModel):
36
+ text: str
37
+
38
+ @app.post("/generate_sql/")
39
+ def get_sql(query: QueryRequest):
40
+ sql_query = generate_sql(query.text)
41
+ result = execute_sql(sql_query)
42
+ return {"sql": sql_query, "result": result}
43
+
44
+ if __name__ == "__main__":
45
+ uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt CHANGED
@@ -4,4 +4,7 @@ accelerate>=0.26.0
4
  tiktoken
5
  datasets
6
  sentencepiece
7
- tqdm
 
 
 
 
4
  tiktoken
5
  datasets
6
  sentencepiece
7
+ tqdm
8
+ pydantic
9
+ fastapi
10
+ uvicorn