|
import logging
import sqlite3
from pathlib import Path

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
# Hugging Face checkpoint used by generate_sql() to turn requests into SQL.
MODEL_NAME = "budecosystem/sql-millennials-13b"

# Tokenizer and weights are loaded once, at import time, and then shared by
# every request this process handles.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# ASGI application object; the HTTP route is registered on it further down.
app = FastAPI()
|
|
|
def generate_sql(query: str, max_new_tokens: int = 256) -> str:
    """Ask the language model to produce SQL for a natural-language *query*.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        query: Request text, tokenized and fed to the model as-is.
            NOTE(review): no prompt template is applied; confirm whether the
            checkpoint expects a specific instruction/schema format.
        max_new_tokens: Maximum number of tokens to generate after the
            prompt. Without an explicit limit ``generate()`` falls back to the
            model's generation config (often ``max_length=20`` total tokens),
            which truncates or empties the SQL.

    Returns:
        The decoded completion only (prompt removed, special tokens skipped,
        surrounding whitespace stripped).
    """
    log = logging.getLogger(__name__)
    log.debug("text-to-SQL prompt: %r", query)

    inputs = tokenizer(query, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # For a causal (decoder-only) LM, generate() returns the prompt tokens
    # followed by the new ones. Decoding the whole sequence would hand the
    # user's prompt to the SQL engine, so keep only the generated suffix.
    prompt_len = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_len:]
    sql_query = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    log.debug("generated SQL: %r", sql_query)
    return sql_query
|
|
|
def execute_sql(sql_query: str, db_path: str = "./ecommerce.db", read_only: bool = True):
    """Run one SQL statement against a SQLite database and fetch its rows.

    The statement normally comes from the language model, i.e. it is derived
    from untrusted request text. By default the database is therefore opened
    in SQLite's read-only URI mode, so any write or DDL statement fails inside
    SQLite and is reported like any other execution error.

    Args:
        sql_query: SQL text to execute (``Cursor.execute`` runs a single
            statement).
        db_path: Path to the SQLite database file.
        read_only: If true (default), open the database with ``mode=ro``.
            Pass ``False`` to allow writes; changes are then committed.

    Returns:
        ``cursor.fetchall()`` (a list of row tuples) on success, or
        ``str(exc)`` if preparing/executing the statement raised.

    Raises:
        sqlite3.OperationalError: if the database file cannot be opened. In
            read-only mode this includes a missing file, which is no longer
            silently created as an empty database.
    """
    if read_only:
        # as_uri() needs an absolute path and percent-encodes special characters.
        uri = Path(db_path).resolve().as_uri() + "?mode=ro"
        conn = sqlite3.connect(uri, uri=True)
    else:
        conn = sqlite3.connect(db_path)

    try:
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        if not read_only:
            conn.commit()
    except Exception as e:  # API boundary: the error text is returned to the client
        result = str(e)
    finally:
        # Always release the connection, even on BaseException.
        conn.close()

    return result
|
|
|
class QueryRequest(BaseModel):
    """JSON request body accepted by ``POST /generate_sql/``.

    Attributes:
        text: Natural-language request to be translated into SQL.
    """

    # Declared as ``str`` so pydantic validates it; get_sql() forwards it
    # unchanged to generate_sql().
    text: str
|
|
|
@app.post("/generate_sql/")
def get_sql(query: QueryRequest):
    """Translate the request text into SQL, run it, and return both.

    Response shape: ``{"sql": <generated statement>, "result": <value>}``,
    where ``<value>`` is whatever ``execute_sql`` returned (fetched rows, or
    an error message string).
    NOTE(review): the generated SQL runs against the database without any
    authentication on this route; confirm that is acceptable.
    """
    statement = generate_sql(query.text)
    # The dict literal evaluates "sql" first and then calls execute_sql,
    # which preserves the original generate-then-execute order.
    return {"sql": statement, "result": execute_sql(statement)}
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app on every network interface, port 7860.
    # NOTE(review): binding 0.0.0.0 exposes the SQL-executing endpoint to the
    # whole network; confirm that is intended.
    uvicorn.run(app, port=7860, host="0.0.0.0")