Commit
·
ddf9ea0
1
Parent(s):
345c7d2
Updated the approach and shifted the flask app to gradio app for deployment on HuggingFace
Browse files- app/db.py +4 -1
- app/main.py +100 -45
- app/nlp.py +68 -3
- dockerfile +7 -2
- requirements.txt +0 -0
app/db.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
import sqlite3
|
2 |
from sqlite3 import Error
|
3 |
-
from
|
|
|
4 |
|
5 |
class Database:
|
6 |
def __init__(self, db_path='data/database.sqlite'):
|
7 |
self.db_path = db_path
|
8 |
|
9 |
def execute_query(self, query):
|
|
|
10 |
if not validate_sql(query):
|
11 |
return {"error": "Invalid SQL query. Only SELECT queries are allowed at this point."}
|
12 |
try:
|
@@ -18,4 +20,5 @@ class Database:
|
|
18 |
conn.close()
|
19 |
return {"columns": columns, "data": results}
|
20 |
except Error as e:
|
|
|
21 |
return {"error": str(e)}
|
|
|
1 |
import sqlite3
|
2 |
from sqlite3 import Error
|
3 |
+
from .utils import validate_sql
|
4 |
+
import logging
|
5 |
|
6 |
class Database:
|
7 |
def __init__(self, db_path='data/database.sqlite'):
|
8 |
self.db_path = db_path
|
9 |
|
10 |
def execute_query(self, query):
|
11 |
+
logging.info(f"Executing SQL: {query}")
|
12 |
if not validate_sql(query):
|
13 |
return {"error": "Invalid SQL query. Only SELECT queries are allowed at this point."}
|
14 |
try:
|
|
|
20 |
conn.close()
|
21 |
return {"columns": columns, "data": results}
|
22 |
except Error as e:
|
23 |
+
logging.error(f"SQL Error: {str(e)}")
|
24 |
return {"error": str(e)}
|
app/main.py
CHANGED
@@ -1,53 +1,108 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
from
|
|
|
4 |
|
5 |
-
app = Flask(__name__)
|
6 |
nlp = NLPToSQL()
|
7 |
db = Database()
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
"""
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
|
49 |
-
|
|
|
50 |
|
51 |
-
|
52 |
|
53 |
-
|
|
|
|
1 |
+
import logging
|
2 |
+
import gradio as gr
|
3 |
+
from .nlp import NLPToSQL
|
4 |
+
from .db import Database
|
5 |
|
|
|
6 |
nlp = NLPToSQL()
|
7 |
db = Database()
|
8 |
|
9 |
+
logging.basicConfig(level=logging.INFO)
|
10 |
+
|
11 |
+
def chatbot(user_query):
|
12 |
+
try:
|
13 |
+
sql = nlp.query_to_sql(user_query)
|
14 |
+
logging.info(f"Generated SQL: {sql}")
|
15 |
+
result = db.execute_query(sql)
|
16 |
+
|
17 |
+
if 'error' in result:
|
18 |
+
return f"❌ Error: {result['error']}"
|
19 |
+
if not result['data']:
|
20 |
+
return "⚠️ No data found."
|
21 |
+
|
22 |
+
# Formatting SQL output
|
23 |
+
response = "**Query:**\n" + sql + "\n\n"
|
24 |
+
response += "**Result:**\n"
|
25 |
+
response += " | ".join(result['columns']) + "\n"
|
26 |
+
response += "-"*50 + "\n"
|
27 |
+
for row in result['data']:
|
28 |
+
response += " | ".join(str(cell) for cell in row) + "\n"
|
29 |
+
return response
|
30 |
+
|
31 |
+
except Exception as e:
|
32 |
+
logging.error(f"Error: {str(e)}")
|
33 |
+
return f"❌ Error: {str(e)}"
|
34 |
+
|
35 |
+
demo = gr.Interface(
|
36 |
+
fn=chatbot,
|
37 |
+
inputs=gr.Textbox(lines=2, placeholder="Ask your database anything..."),
|
38 |
+
outputs="markdown",
|
39 |
+
title="SQL Chat Assistant",
|
40 |
+
description="Enter a natural language question, and get the corresponding SQL query & result."
|
41 |
+
)
|
42 |
+
|
43 |
+
if __name__ == "__main__":
|
44 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
# Flask-based application can also be used
|
49 |
+
|
50 |
+
# import logging
|
51 |
+
# from flask import Flask, request, render_template_string
|
52 |
+
# from .nlp import NLPToSQL
|
53 |
+
# from .db import Database
|
54 |
+
|
55 |
+
# app = Flask(__name__)
|
56 |
+
# nlp = NLPToSQL()
|
57 |
+
# db = Database()
|
58 |
+
|
59 |
+
# logging.basicConfig(level=logging.INFO)
|
60 |
+
|
61 |
+
# HTML_TEMPLATE = """
|
62 |
+
# <!DOCTYPE html>
|
63 |
+
# <html>
|
64 |
+
# <head><title> Chat Assistant </title></head>
|
65 |
+
# <body>
|
66 |
+
# <h1> Database Chat Assistant</h1>
|
67 |
+
# <form method="POST">
|
68 |
+
# <input type="text" name="query" placeholder= "Enter your query..." size="50">
|
69 |
+
# <button type="submit">Ask</button>
|
70 |
+
# </form>
|
71 |
+
# {% if response %}
|
72 |
+
# <h3> Response: </h3>
|
73 |
+
# <pre>{{ response }}</pre>
|
74 |
+
# {% endif %}
|
75 |
+
# {% if error %}
|
76 |
+
# <p style = "color:red;">{{ error }} </p>
|
77 |
+
# {% endif %}
|
78 |
+
# </body>
|
79 |
+
# </html>
|
80 |
+
# """
|
81 |
+
|
82 |
+
# @app.route("/", methods=["GET", "POST"])
|
83 |
+
# def index():
|
84 |
+
# if request.method == 'POST':
|
85 |
+
# user_query = request.form['query']
|
86 |
+
# try:
|
87 |
+
# sql = nlp.query_to_sql(user_query)
|
88 |
+
# logging.info(f"Generated SQL: {sql}")
|
89 |
+
# result = db.execute_query(sql)
|
90 |
+
# if 'error' in result:
|
91 |
+
# return render_template_string(HTML_TEMPLATE, error=result['error'])
|
92 |
+
# if not result['data']:
|
93 |
+
# return render_template_string(HTML_TEMPLATE, error="No data found")
|
94 |
|
95 |
+
# response = " | ".join(result['columns']) + "\n"
|
96 |
+
# response += "-"*50 + "\n"
|
97 |
+
# for row in result['data']:
|
98 |
+
# response += " | ".join(str(cell) for cell in row) + "\n"
|
99 |
+
# return render_template_string(HTML_TEMPLATE, response=response)
|
100 |
|
101 |
+
# except Exception as e:
|
102 |
+
# logging.error(f"Error: {str(e)}")
|
103 |
+
# return render_template_string(HTML_TEMPLATE, error=f"Error: {str(e)}")
|
104 |
|
105 |
+
# return render_template_string(HTML_TEMPLATE)
|
106 |
|
107 |
+
# if __name__ == "__main__":
|
108 |
+
# app.run(debug=True)
|
app/nlp.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from transformers import pipeline
|
|
|
|
|
2 |
|
3 |
-
class
|
4 |
def __init__(self):
|
5 |
self.model = pipeline(
|
6 |
"text2text-generation",
|
@@ -9,6 +11,69 @@ class NLPToSQL:
|
|
9 |
)
|
10 |
|
11 |
def query_to_sql(self, user_query):
|
12 |
-
prompt = f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
result = self.model(prompt, max_length=200)
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import pipeline
|
2 |
+
import re
|
3 |
+
from typing import Dict
|
4 |
|
5 |
+
class NLPToSQL2:
|
6 |
def __init__(self):
|
7 |
self.model = pipeline(
|
8 |
"text2text-generation",
|
|
|
11 |
)
|
12 |
|
13 |
def query_to_sql(self, user_query):
|
14 |
+
prompt = (f"Generate a valid SQL query in the correct format based on the following schema:\n"
|
15 |
+
f"Table1: Employees\n"
|
16 |
+
f"Columns: ID, Name, Department, Salary\n"
|
17 |
+
f"Table2: Departments\n"
|
18 |
+
f"Columns: Name, Manager\n"
|
19 |
+
f"Natural Language: {user_query}"
|
20 |
+
f"SQL query:"
|
21 |
+
)
|
22 |
+
|
23 |
result = self.model(prompt, max_length=200)
|
24 |
+
sql = result[0]['generated_text']
|
25 |
+
|
26 |
+
return sql
|
27 |
+
|
28 |
+
class NLPToSQL:
|
29 |
+
def __init__(self):
|
30 |
+
self.query_patterns: Dict[str, str] = {
|
31 |
+
r"show\s+(?:me\s+)?all\s+employees?\s+in\s+(?:the\s+)?(\w+)\s+department":
|
32 |
+
"SELECT * FROM Employees WHERE LOWER(Department) = LOWER('{}')",
|
33 |
+
|
34 |
+
r"who\s+is\s+(?:the\s+)?manager\s+of\s+(?:the\s+)?(\w+)\s+department":
|
35 |
+
"SELECT Manager FROM Departments WHERE LOWER(Name) = LOWER('{}')",
|
36 |
+
|
37 |
+
r"list\s+(?:all\s+)?employees?\s+hired\s+after\s+(\d{4}-\d{2}-\d{2})":
|
38 |
+
"SELECT * FROM Employees WHERE Hire_Date > '{}'",
|
39 |
+
|
40 |
+
r"what\s+is\s+(?:the\s+)?total\s+salary\s+(?:expense\s+)?for\s+(?:the\s+)?(\w+)\s+department":
|
41 |
+
"SELECT SUM(Salary) as Total_Salary FROM Employees WHERE LOWER(Department) = LOWER('{}')",
|
42 |
+
|
43 |
+
r"show\s+(?:me\s+)?(?:the\s+)?salary\s+of\s+(\w+)":
|
44 |
+
"SELECT Salary FROM Employees WHERE LOWER(Name) = LOWER('{}')",
|
45 |
+
|
46 |
+
r"list\s+(?:all\s+)?employees?\s+with\s+salary\s+(?:greater|more)\s+than\s+(\d+)":
|
47 |
+
"SELECT * FROM Employees WHERE Salary > {}",
|
48 |
+
|
49 |
+
r"(?:show|list)\s+(?:me\s+)?all\s+departments":
|
50 |
+
"SELECT * FROM Departments",
|
51 |
+
|
52 |
+
r"(?:show|list)\s+(?:me\s+)?all\s+employees":
|
53 |
+
"SELECT * FROM Employees"
|
54 |
+
}
|
55 |
+
|
56 |
+
def query_to_sql(self, user_query: str) -> str:
|
57 |
+
normalized_query = " ".join(user_query.lower().split())
|
58 |
+
|
59 |
+
for pattern, sql_template in self.query_patterns.items():
|
60 |
+
match = re.search(pattern, normalized_query, re.IGNORECASE)
|
61 |
+
if match:
|
62 |
+
if match.groups():
|
63 |
+
return sql_template.format(*match.groups())
|
64 |
+
return sql_template
|
65 |
+
|
66 |
+
return self._generate_fallback_query(normalized_query)
|
67 |
+
|
68 |
+
def _generate_fallback_query(self, query: str) -> str:
|
69 |
+
if any(word in query for word in ['department', 'manager']):
|
70 |
+
return "SELECT * FROM Departments"
|
71 |
+
return "SELECT * FROM Employees"
|
72 |
+
|
73 |
+
def sanitize_sql(self, sql: str) -> str:
|
74 |
+
sql = re.sub(r'[;"]', '', sql)
|
75 |
+
sql = sql.replace("'", "''")
|
76 |
+
if not sql.strip().endswith(';'):
|
77 |
+
sql = f"{sql};"
|
78 |
+
|
79 |
+
return sql
|
dockerfile
CHANGED
@@ -10,6 +10,11 @@ COPY . .
|
|
10 |
|
11 |
RUN pip install --no-cache-dir -r requirements.txt
|
12 |
|
13 |
-
EXPOSE
|
|
|
|
|
14 |
|
15 |
-
CMD ["gunicorn", "--bind", "0.0.0.0:
|
|
|
|
|
|
|
|
10 |
|
11 |
RUN pip install --no-cache-dir -r requirements.txt
|
12 |
|
13 |
+
EXPOSE 7860
|
14 |
+
# EXPOSE 8000
|
15 |
+
#(for flask)
|
16 |
|
17 |
+
CMD ["gunicorn", "--bind", "0.0.0.0:7860", "app.main:app"]
|
18 |
+
|
19 |
+
# (for flask)
|
20 |
+
# CMD ["gunicorn", "--bind", "0.0.0.0:8000", "app.main:app"]
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|