# chronos / wikiPreTrained.py
# Manoj Kumar
# Mark Phase 1
# e6f4fec
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import re
# Fine-tuned T5 checkpoint directory; both the model weights and the matching
# tokenizer are read from this same path when the module is imported.
_MODEL_DIR = "./t5_sql_finetuned"
model = T5ForConditionalGeneration.from_pretrained(_MODEL_DIR)
tokenizer = T5Tokenizer.from_pretrained(_MODEL_DIR)
# Define a simple function to check if the question is schema-related or SQL-related
def is_schema_question(question: str):
schema_keywords = ["columns", "tables", "structure", "schema", "relations", "fields"]
return any(keyword in question.lower() for keyword in schema_keywords)
def extract_table_name(question: str):
    """Return the table name mentioned in *question*, or None if none is found.

    Two phrasings are recognised, tried in this order:
      * a word following "for", "in" or "from" (an optional leading
        "the"/"a"/"an" is skipped), e.g. "relations for orders",
        "columns in the users table" -> "orders", "users";
      * a word directly preceding "table"/"tables",
        e.g. "what columns does the products table have?" -> "products".

    Keywords match case-insensitively and only as whole words, so "in" inside
    "contain" is not treated as a preposition. The returned name keeps the
    casing it has in *question*. Determiners and generic words such as
    "which", "table" or "schema" are never returned as table names.
    """
    # Words that can sit in a table-name position but are never table names.
    non_table_words = {
        "a", "an", "the", "this", "that", "these", "those", "which", "what",
        "each", "every", "all", "any", "my", "your", "our",
        "table", "tables", "schema", "database",
    }
    patterns = (
        # "columns in products", "relations for the orders table"
        r"\b(?:for|in|from)\s+(?:(?:the|a|an)\s+)?(\w+)",
        # "the products table", "users tables"
        r"\b(\w+)\s+tables?\b",
    )
    for pattern in patterns:
        for match in re.finditer(pattern, question, re.IGNORECASE):
            candidate = match.group(1)
            if candidate.lower() not in non_table_words:
                return candidate
    # No plausible table name detected.
    return None
def generate_sql(question: str, schema: dict, model, tokenizer, device, max_length: int = 128):
    """Translate a natural-language data question into SQL using the seq2seq model.

    Args:
        question: The data question, fed to the model verbatim,
            e.g. "What is the price of the product with ID 123?".
        schema: Currently unused. No schema context is added to the model
            input. The parameter is kept for interface compatibility.
        model: A loaded ``T5ForConditionalGeneration`` (or compatible) model.
            It is moved to *device* in place; this is a no-op if it is already there.
        tokenizer: The tokenizer that matches *model*.
        device: Torch device (or device string) to run generation on.
        max_length: Maximum length, in tokens, of the generated sequence.

    Returns:
        The decoded SQL string, with special tokens stripped.
    """
    # Model weights and input tensors must be on the same device. Inputs are
    # moved to ``device``, so the model must be too. Without this, a model
    # still on CPU fails when ``device`` is CUDA.
    model = model.to(device)
    inputs = tokenizer(question, return_tensors="pt").to(device)
    with torch.no_grad():
        # Pass the full encoding (input_ids + attention_mask), not just input_ids.
        generated_ids = model.generate(**inputs, max_length=max_length)
    # A single question was encoded, so decode the first (only) sequence.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
def _resolve_table_name(question: str, schema: dict):
    """Return the table *question* refers to, preferring names present in *schema*.

    *question* is expected to be lower-cased already. If ``extract_table_name``
    finds a name that is a schema table, that name is used. Otherwise the
    schema table whose name occurs earliest in *question* (as a whole word)
    wins. If neither applies, the raw extracted name (possibly None) is
    returned so the caller can report it as unknown.
    """
    extracted = extract_table_name(question)
    if extracted in schema:
        return extracted
    # Fallback: look for schema table names mentioned anywhere in the question,
    # e.g. "what columns does the products table have?" -> "products".
    best_name, best_pos = None, None
    for name in schema:
        match = re.search(rf"\b{re.escape(str(name).lower())}\b", question)
        if match and (best_pos is None or match.start() < best_pos):
            best_name, best_pos = name, match.start()
    return best_name if best_name is not None else extracted


def handle_schema_question(question: str, schema: dict):
    """Answer a question about the schema rather than about the data in it.

    Supported questions (matched case-insensitively):
      * columns/fields of a table -> that table's ``"columns"`` entry;
      * relations/relationships of a table -> that table's ``"relations"`` entry;
      * which tables exist -> list of the schema's table names.

    Args:
        question: Natural-language question.
        schema: Mapping of table name -> dict with ``"columns"`` and
            ``"relations"`` entries.

    Returns:
        The requested schema entry as stored in *schema*. Otherwise, a
        human-readable message when the table is unknown, the table has no
        relations, or the question is not understood.
    """
    question = question.lower()

    if "columns" in question or "fields" in question:
        table_name = _resolve_table_name(question, schema)
        if table_name:
            if table_name not in schema:
                return f"Table '{table_name}' not found in the schema."
            return schema[table_name]["columns"]
    elif "relations" in question or "relationships" in question:
        table_name = _resolve_table_name(question, schema)
        if table_name:
            if table_name not in schema:
                return f"Table '{table_name}' not found in the schema."
            relations = schema[table_name]["relations"]
            # Tables without relations store None (or an empty list); say so
            # explicitly instead of returning a bare None to the caller.
            if not relations:
                return f"Table '{table_name}' has no relations."
            return relations
    elif "tables" in question:
        return list(schema.keys())

    # Question too vague, or no table could be identified.
    return "Sorry, I couldn't understand your schema question. Could you rephrase?"
# Example schema for your custom use case.
# Each table name maps to:
#   "columns":   list of column names
#   "relations": list of "local_column -> other_table.column" strings, or None
#                when the table has no outgoing relations
custom_schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        # A list, like every other table's relations (was a bare string).
        "relations": ["category_id -> categories.id"],
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "user_id", "product_id", "order_date"],
        "relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
    },
    "users": {
        "columns": ["user_id", "first_name", "last_name", "email", "phone_number", "address"],
        "relations": None,
    },
}
def answer_question(question: str, schema: dict, model, tokenizer, device):
    """Route *question* to schema lookup or SQL generation and format the answer.

    Schema questions (as judged by ``is_schema_question``) are answered from
    *schema* and prefixed with "Schema Information: ". Every other question
    is turned into SQL by ``generate_sql`` and prefixed with
    "Generated SQL Query: ".
    """
    if not is_schema_question(question):
        # Data question: let the model produce an SQL query.
        sql_query = generate_sql(question, schema, model, tokenizer, device)
        return f"Generated SQL Query: {sql_query}"

    # Schema question: answer directly from the schema dictionary.
    response = handle_schema_question(question, schema)
    return f"Schema Information: {response}"
# Example usage. It runs only when this file is executed directly, so importing
# the module does not trigger the demo queries.
if __name__ == "__main__":
    # Example input questions
    question_1 = "What columns does the products table have?"
    question_2 = "What is the price of the product with product_id 123?"

    # Uses the module-level `model` and `tokenizer` loaded above. The model
    # must be on the same device as the tokenized inputs, or generation fails
    # on CUDA machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Handle schema question
    response_1 = answer_question(question_1, custom_schema, model, tokenizer, device)
    print(response_1)  # Should give you the columns of the products table

    # Handle SQL query question
    response_2 = answer_question(question_2, custom_schema, model, tokenizer, device)
    print(response_2)  # Should generate an SQL query for fetching the price