|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import torch |
|
import re |
|
|
|
|
|
# Fine-tuned T5 checkpoint used by generate_sql for text-to-SQL generation.
# The model and its tokenizer are both loaded from the same local directory
# when this module is executed.
_CHECKPOINT_DIR = "./t5_sql_finetuned"
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT_DIR)
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT_DIR)
|
|
|
|
|
def is_schema_question(question: str):
    """Return True if *question* looks like a question about the schema itself.

    The check is a case-insensitive substring test against a fixed keyword
    set, so e.g. "relationships" is caught by the "relations" keyword.
    """
    lowered = question.lower()
    for word in ("columns", "tables", "structure", "schema", "relations", "fields"):
        if word in lowered:
            return True
    return False
|
|
|
|
|
# "for/in/from [the] [table] <name>", e.g. "columns in the orders table".
# \b keeps "in" from matching inside words such as "main" or "join".
_TABLE_AFTER_PREPOSITION = re.compile(
    r"\b(?:for|in|from)\s+(?:the\s+)?(?:table\s+)?(\w+)", re.IGNORECASE
)
# "<name> table", e.g. "What columns does the products table have?".
_TABLE_BEFORE_NOUN = re.compile(r"\b(\w+)\s+table\b", re.IGNORECASE)


def extract_table_name(question: str):
    """Guess which table a natural-language question refers to.

    Two patterns are tried in order, and the first one that matches wins:

    1. a preposition ("for", "in", "from") followed by the name, optionally
       preceded by "the" and/or "table" ("fields for users",
       "columns in the orders table");
    2. a word directly followed by "table" ("the products table").

    Args:
        question: The question text; keywords are matched case-insensitively
            and the captured name keeps its original case.

    Returns:
        The captured word, or None when neither pattern matches. The result
        is NOT validated against any schema.
    """
    for pattern in (_TABLE_AFTER_PREPOSITION, _TABLE_BEFORE_NOUN):
        match = pattern.search(question)
        if match:
            return match.group(1)
    return None
|
|
|
|
|
|
|
def generate_sql(question: str, schema: dict, model, tokenizer, device, max_length: int = 128):
    """Translate a natural-language question into SQL with the seq2seq model.

    Args:
        question: Natural-language question to translate.
        schema: Accepted for interface compatibility but currently unused;
            only ``question`` is fed to the model. NOTE(review): if the
            checkpoint was fine-tuned with schema context in its input, that
            context should be serialized into the prompt here in the same
            format. Confirm this against the training script.
        model: Seq2seq model exposing ``generate`` (the T5 checkpoint in this
            file).
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized inputs are moved to; it must be the
            device the model lives on.
        max_length: Upper bound on the generated sequence length, in tokens.
            Defaults to 128 (the previously hard-coded value).

    Returns:
        The decoded query with special tokens removed and surrounding
        whitespace stripped.
    """
    # Move the whole encoding (input_ids AND attention_mask) to the device and
    # hand both to generate(), instead of passing input_ids alone.
    inputs = tokenizer(question, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_length=max_length)

    # A single question was encoded, so the batch holds one sequence.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
|
|
|
|
|
def _find_table_in_question(question: str, schema: dict):
    """Return the table name a schema question refers to, or None.

    Tables that exist in *schema* are looked for first, as whole words and
    case-insensitively; when several are mentioned, the earliest one wins.
    This resolves phrasings such as "the products table" that a purely
    pattern-based guess can miss. Only when no known table is mentioned does
    this fall back to ``extract_table_name``, whose guess may not exist in
    *schema* (the caller reports that case).
    """
    lowered = question.lower()
    mentions = []
    for table in schema:
        hit = re.search(rf"\b{re.escape(str(table).lower())}\b", lowered)
        if hit:
            mentions.append((hit.start(), table))
    if mentions:
        return min(mentions, key=lambda item: item[0])[1]
    return extract_table_name(lowered)


def handle_schema_question(question: str, schema: dict):
    """Answer a metadata question directly from *schema*, without the model.

    Args:
        question: Natural-language question; keywords are matched
            case-insensitively as substrings.
        schema: Mapping of table name -> dict with "columns" and "relations"
            entries (the shape used by ``custom_schema`` in this file).

    Returns:
        - the table's "columns" value for questions mentioning columns/fields;
        - the table's "relations" value for questions mentioning
          relations/relationships, or a message when the table has none;
        - the list of table names for questions mentioning "tables";
        - the whole table entry for questions mentioning structure/schema;
        - a "not found" message when the referenced table is not in *schema*;
        - a generic "Sorry ..." message when the question, or the table it
          refers to, cannot be understood.
    """
    fallback = "Sorry, I couldn't understand your schema question. Could you rephrase?"
    question = question.lower()

    # Keyword precedence: columns > relations > tables > structure/schema.
    # The last two keywords are routed here by is_schema_question, so they
    # need an answer too.
    if "columns" in question or "fields" in question:
        section = "columns"
    elif "relations" in question or "relationships" in question:
        section = "relations"
    elif "tables" in question:
        return list(schema.keys())
    elif "structure" in question or "schema" in question:
        section = None  # describe the whole table entry
    else:
        return fallback

    table_name = _find_table_in_question(question, schema)
    if not table_name:
        return fallback
    if table_name not in schema:
        return f"Table '{table_name}' not found in the schema."

    entry = schema[table_name]
    if section is None:
        return entry
    value = entry.get(section)
    if section == "relations" and not value:
        # Give an explicit answer instead of a bare None / empty value.
        return f"Table '{table_name}' has no relations."
    return value
|
|
|
|
|
|
|
# Demo schema passed to answer_question at the bottom of this script.
# Every table maps to the same shape:
#   "columns":   list of column names;
#   "relations": list of "local_column -> other_table.column" references,
#                empty when the table has none (always a list, so callers
#                never have to special-case a bare string or None).
custom_schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": ["category_id -> categories.id"],
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": [],
    },
    "orders": {
        "columns": ["order_id", "user_id", "product_id", "order_date"],
        "relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
    },
    "users": {
        "columns": ["user_id", "first_name", "last_name", "email", "phone_number", "address"],
        "relations": [],
    },
}
|
|
|
def answer_question(question: str, schema: dict, model, tokenizer, device):
    """Route *question* to the schema lookup or to SQL generation.

    Questions flagged by ``is_schema_question`` are answered from *schema* by
    ``handle_schema_question``; all others are translated to SQL by
    ``generate_sql``. The answer is returned as a labelled string.
    """
    if not is_schema_question(question):
        query = generate_sql(question, schema, model, tokenizer, device)
        return f"Generated SQL Query: {query}"

    info = handle_schema_question(question, schema)
    return f"Schema Information: {info}"
|
|
|
|
|
# Demo: one schema question and one question that needs SQL generation.
question_1 = "What columns does the products table have?"
question_2 = "What is the price of the product with product_id 123?"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# generate_sql moves the tokenized inputs to `device`, so the model must live
# on the same device; otherwise generation fails with a device-mismatch error
# whenever CUDA is available.
model.to(device)
# Inference only: make sure dropout and other training-only layers are off.
model.eval()

response_1 = answer_question(question_1, custom_schema, model, tokenizer, device)
print(response_1)

response_2 = answer_question(question_2, custom_schema, model, tokenizer, device)
print(response_2)