Devashish-Nagpal commited on
Commit
ddf9ea0
·
1 Parent(s): 345c7d2

Updated the approach and shifted the flask app to gradio app for deployment on HuggingFace

Browse files
Files changed (5) hide show
  1. app/db.py +4 -1
  2. app/main.py +100 -45
  3. app/nlp.py +68 -3
  4. dockerfile +7 -2
  5. requirements.txt +0 -0
app/db.py CHANGED
@@ -1,12 +1,14 @@
1
  import sqlite3
2
  from sqlite3 import Error
3
- from app.utils import validate_sql
 
4
 
5
  class Database:
6
  def __init__(self, db_path='data/database.sqlite'):
7
  self.db_path = db_path
8
 
9
  def execute_query(self, query):
 
10
  if not validate_sql(query):
11
  return {"error": "Invalid SQL query. Only SELECT queries are allowed at this point."}
12
  try:
@@ -18,4 +20,5 @@ class Database:
18
  conn.close()
19
  return {"columns": columns, "data": results}
20
  except Error as e:
 
21
  return {"error": str(e)}
 
1
  import sqlite3
2
  from sqlite3 import Error
3
+ from .utils import validate_sql
4
+ import logging
5
 
6
  class Database:
7
  def __init__(self, db_path='data/database.sqlite'):
8
  self.db_path = db_path
9
 
10
  def execute_query(self, query):
11
+ logging.info(f"Executing SQL: {query}")
12
  if not validate_sql(query):
13
  return {"error": "Invalid SQL query. Only SELECT queries are allowed at this point."}
14
  try:
 
20
  conn.close()
21
  return {"columns": columns, "data": results}
22
  except Error as e:
23
+ logging.error(f"SQL Error: {str(e)}")
24
  return {"error": str(e)}
app/main.py CHANGED
@@ -1,53 +1,108 @@
1
- from flask import Flask, request, render_template_string
2
- from app.nlp import NLPToSQL
3
- from app.db import Database
 
4
 
5
- app = Flask(__name__)
6
  nlp = NLPToSQL()
7
  db = Database()
8
 
9
- HTML_TEMPLATE = """
10
- !DOCTYPE html>
11
- <html>
12
- <head><title> Chat Assistant </title></head>
13
- <body>
14
- <h1> Database Chat Assistant</h1>
15
- <form method="POST">
16
- <input type="text" name="query" placeholder= "Enter your query..." size="50">
17
- <button type="submit">Ask</button>
18
- </form>
19
- {% if response %}
20
- <h3> Response: </h3>
21
- <pre>{{ response }}</pre>
22
- {% endif %}
23
- {% if error %}
24
- <p style = "color:red;">{{ error }} </p>
25
- {% endif %}
26
- </body>
27
- </html>
28
- """
29
-
30
- @app.route("/", methods=["GET", "POST"])
31
- def index():
32
- if request.method == 'POST':
33
- user_query = request.form['query']
34
- try:
35
- sql = nlp.query_to_sql(user_query)
36
- result = db.execute_query(sql)
37
- if 'error' in result:
38
- return render_template_string(HTML_TEMPLATE, error = result['error'])
39
- if not result['data']:
40
- return render_template_string(HTML_TEMPLATE, error = "No data found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- response = " | ".join(result['columns']) + "\n"
43
- response += "-"*50 + "\n"
44
- for row in result['data']:
45
- response += " | ".join(str(cell) for cell in row) + "\n"
46
- return render_template_string(HTML_TEMPLATE, response = response)
47
 
48
- except Exception as e:
49
- return render_template_string(HTML_TEMPLATE, error = f"Error: {str(e)}")
 
50
 
51
- return render_template_string(HTML_TEMPLATE)
52
 
53
-
 
 
1
+ import logging
2
+ import gradio as gr
3
+ from .nlp import NLPToSQL
4
+ from .db import Database
5
 
 
6
  nlp = NLPToSQL()
7
  db = Database()
8
 
9
+ logging.basicConfig(level=logging.INFO)
10
+
11
+ def chatbot(user_query):
12
+ try:
13
+ sql = nlp.query_to_sql(user_query)
14
+ logging.info(f"Generated SQL: {sql}")
15
+ result = db.execute_query(sql)
16
+
17
+ if 'error' in result:
18
+ return f"❌ Error: {result['error']}"
19
+ if not result['data']:
20
+ return "⚠️ No data found."
21
+
22
+ # Formatting SQL output
23
+ response = "**Query:**\n" + sql + "\n\n"
24
+ response += "**Result:**\n"
25
+ response += " | ".join(result['columns']) + "\n"
26
+ response += "-"*50 + "\n"
27
+ for row in result['data']:
28
+ response += " | ".join(str(cell) for cell in row) + "\n"
29
+ return response
30
+
31
+ except Exception as e:
32
+ logging.error(f"Error: {str(e)}")
33
+ return f"❌ Error: {str(e)}"
34
+
35
+ demo = gr.Interface(
36
+ fn=chatbot,
37
+ inputs=gr.Textbox(lines=2, placeholder="Ask your database anything..."),
38
+ outputs="markdown",
39
+ title="SQL Chat Assistant",
40
+ description="Enter a natural language question, and get the corresponding SQL query & result."
41
+ )
42
+
43
+ if __name__ == "__main__":
44
+ demo.launch(server_name="0.0.0.0", server_port=7860)
45
+
46
+
47
+
48
+ # Flask-based application can also be used
49
+
50
+ # import logging
51
+ # from flask import Flask, request, render_template_string
52
+ # from .nlp import NLPToSQL
53
+ # from .db import Database
54
+
55
+ # app = Flask(__name__)
56
+ # nlp = NLPToSQL()
57
+ # db = Database()
58
+
59
+ # logging.basicConfig(level=logging.INFO)
60
+
61
+ # HTML_TEMPLATE = """
62
+ # <!DOCTYPE html>
63
+ # <html>
64
+ # <head><title> Chat Assistant </title></head>
65
+ # <body>
66
+ # <h1> Database Chat Assistant</h1>
67
+ # <form method="POST">
68
+ # <input type="text" name="query" placeholder= "Enter your query..." size="50">
69
+ # <button type="submit">Ask</button>
70
+ # </form>
71
+ # {% if response %}
72
+ # <h3> Response: </h3>
73
+ # <pre>{{ response }}</pre>
74
+ # {% endif %}
75
+ # {% if error %}
76
+ # <p style = "color:red;">{{ error }} </p>
77
+ # {% endif %}
78
+ # </body>
79
+ # </html>
80
+ # """
81
+
82
+ # @app.route("/", methods=["GET", "POST"])
83
+ # def index():
84
+ # if request.method == 'POST':
85
+ # user_query = request.form['query']
86
+ # try:
87
+ # sql = nlp.query_to_sql(user_query)
88
+ # logging.info(f"Generated SQL: {sql}")
89
+ # result = db.execute_query(sql)
90
+ # if 'error' in result:
91
+ # return render_template_string(HTML_TEMPLATE, error=result['error'])
92
+ # if not result['data']:
93
+ # return render_template_string(HTML_TEMPLATE, error="No data found")
94
 
95
+ # response = " | ".join(result['columns']) + "\n"
96
+ # response += "-"*50 + "\n"
97
+ # for row in result['data']:
98
+ # response += " | ".join(str(cell) for cell in row) + "\n"
99
+ # return render_template_string(HTML_TEMPLATE, response=response)
100
 
101
+ # except Exception as e:
102
+ # logging.error(f"Error: {str(e)}")
103
+ # return render_template_string(HTML_TEMPLATE, error=f"Error: {str(e)}")
104
 
105
+ # return render_template_string(HTML_TEMPLATE)
106
 
107
+ # if __name__ == "__main__":
108
+ # app.run(debug=True)
app/nlp.py CHANGED
@@ -1,6 +1,8 @@
1
  from transformers import pipeline
 
 
2
 
3
- class NLPToSQL:
4
  def __init__(self):
5
  self.model = pipeline(
6
  "text2text-generation",
@@ -9,6 +11,69 @@ class NLPToSQL:
9
  )
10
 
11
  def query_to_sql(self, user_query):
12
- prompt = f"translate English to SQL: {user_query}"
 
 
 
 
 
 
 
 
13
  result = self.model(prompt, max_length=200)
14
- return result[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import pipeline
2
+ import re
3
+ from typing import Dict
4
 
5
+ class NLPToSQL2:
6
  def __init__(self):
7
  self.model = pipeline(
8
  "text2text-generation",
 
11
  )
12
 
13
  def query_to_sql(self, user_query):
14
+ prompt = (f"Generate a valid SQL query in the correct format based on the following schema:\n"
15
+ f"Table1: Employees\n"
16
+ f"Columns: ID, Name, Department, Salary\n"
17
+ f"Table2: Departments\n"
18
+ f"Columns: Name, Manager\n"
19
+ f"Natural Language: {user_query}"
20
+ f"SQL query:"
21
+ )
22
+
23
  result = self.model(prompt, max_length=200)
24
+ sql = result[0]['generated_text']
25
+
26
+ return sql
27
+
28
+ class NLPToSQL:
29
+ def __init__(self):
30
+ self.query_patterns: Dict[str, str] = {
31
+ r"show\s+(?:me\s+)?all\s+employees?\s+in\s+(?:the\s+)?(\w+)\s+department":
32
+ "SELECT * FROM Employees WHERE LOWER(Department) = LOWER('{}')",
33
+
34
+ r"who\s+is\s+(?:the\s+)?manager\s+of\s+(?:the\s+)?(\w+)\s+department":
35
+ "SELECT Manager FROM Departments WHERE LOWER(Name) = LOWER('{}')",
36
+
37
+ r"list\s+(?:all\s+)?employees?\s+hired\s+after\s+(\d{4}-\d{2}-\d{2})":
38
+ "SELECT * FROM Employees WHERE Hire_Date > '{}'",
39
+
40
+ r"what\s+is\s+(?:the\s+)?total\s+salary\s+(?:expense\s+)?for\s+(?:the\s+)?(\w+)\s+department":
41
+ "SELECT SUM(Salary) as Total_Salary FROM Employees WHERE LOWER(Department) = LOWER('{}')",
42
+
43
+ r"show\s+(?:me\s+)?(?:the\s+)?salary\s+of\s+(\w+)":
44
+ "SELECT Salary FROM Employees WHERE LOWER(Name) = LOWER('{}')",
45
+
46
+ r"list\s+(?:all\s+)?employees?\s+with\s+salary\s+(?:greater|more)\s+than\s+(\d+)":
47
+ "SELECT * FROM Employees WHERE Salary > {}",
48
+
49
+ r"(?:show|list)\s+(?:me\s+)?all\s+departments":
50
+ "SELECT * FROM Departments",
51
+
52
+ r"(?:show|list)\s+(?:me\s+)?all\s+employees":
53
+ "SELECT * FROM Employees"
54
+ }
55
+
56
+ def query_to_sql(self, user_query: str) -> str:
57
+ normalized_query = " ".join(user_query.lower().split())
58
+
59
+ for pattern, sql_template in self.query_patterns.items():
60
+ match = re.search(pattern, normalized_query, re.IGNORECASE)
61
+ if match:
62
+ if match.groups():
63
+ return sql_template.format(*match.groups())
64
+ return sql_template
65
+
66
+ return self._generate_fallback_query(normalized_query)
67
+
68
+ def _generate_fallback_query(self, query: str) -> str:
69
+ if any(word in query for word in ['department', 'manager']):
70
+ return "SELECT * FROM Departments"
71
+ return "SELECT * FROM Employees"
72
+
73
+ def sanitize_sql(self, sql: str) -> str:
74
+ sql = re.sub(r'[;"]', '', sql)
75
+ sql = sql.replace("'", "''")
76
+ if not sql.strip().endswith(';'):
77
+ sql = f"{sql};"
78
+
79
+ return sql
dockerfile CHANGED
@@ -10,6 +10,11 @@ COPY . .
10
 
11
  RUN pip install --no-cache-dir -r requirements.txt
12
 
13
- EXPOSE 8000
 
 
14
 
15
- CMD ["gunicorn", "--bind", "0.0.0.0:8000", "app.main:app"]
 
 
 
 
10
 
11
  RUN pip install --no-cache-dir -r requirements.txt
12
 
13
+ EXPOSE 7860
14
+ # EXPOSE 8000
15
+ #(for flask)
16
 
17
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "app.main:app"]
18
+
19
+ # (for flask)
20
+ # CMD ["gunicorn", "--bind", "0.0.0.0:8000", "app.main:app"]
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ