import gradio as gr
import pandas as pd
import requests
from io import StringIO
import os
# Access the Hugging Face API token from environment variables
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
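# For local runs, the token can be exported before launching the app (a shell
# sketch, assuming a bash-like environment; the token value is a placeholder):
#
#   export HF_API_TOKEN=hf_xxx...
#   python app.py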
# Define your prompt template
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
def preprocess_user_prompt(user_prompt):
    # Placeholder hook for any cleaning/normalization of the user's description
    # before it is inserted into the prompt template
    return user_prompt
def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt
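# A minimal usage sketch (hypothetical description and columns, kept as a
# comment so it does not run at import time):
#
#   prompt = format_prompt(
#       "Generate a dataset for predicting house prices",
#       ["Size", "Location", "Number of Bedrooms", "Price"],
#   )
#   # `prompt` is the template above with every {description} and {columns}
#   # placeholder filled in, ready to send to the model.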
# Inference URL of the Streamlit Space that hosts the model.
# Note: this is the Space's page URL; the actual API route the Space exposes
# may differ, so adjust it to match your app.
inference_endpoint = "https://huggingface.co/spaces/yakine/model"
def generate_synthetic_data(description, columns):
    try:
        # Format the prompt for the model
        # (adjust this to match the prompt format your Streamlit app expects)
        formatted_prompt = f"{description}, with columns: {', '.join(columns)}"

        # Send a POST request to the Streamlit Space API
        headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json"
        }
        data = {
            "inputs": formatted_prompt,  # Adjust to the input key your Streamlit app expects
            "parameters": {
                "max_new_tokens": 512,
                "top_p": 0.90,
                "temperature": 0.8
            }
        }
        response = requests.post(inference_endpoint, json=data, headers=headers)

        if response.status_code != 200:
            return f"Error: {response.status_code}, {response.text}"

        # Extract the generated text from the response
        # (adjust the key based on your Streamlit Space's response structure)
        generated_text = response.json().get('data')
        if generated_text is None:
            return "Error: response did not contain a 'data' field"
        return generated_text
    except Exception as e:
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"
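# A rough sketch of the expected round trip (hypothetical inputs and response
# shape, kept as a comment so it does not run at import time):
#
#   csv_text = generate_synthetic_data(
#       "house prices", ["Size", "Location", "Number of Bedrooms", "Price"]
#   )
#   # On success, `csv_text` should be a CSV string starting with the header
#   # row; on failure it is a string beginning with "Error: ...".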
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    # Request the data in chunks and concatenate, since each generation is
    # capped at roughly `rows_per_generation` rows
    data_frames = []
    num_iterations = num_rows // rows_per_generation
    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        if generated_data.startswith("Error"):
            return generated_data  # Propagate the error message
        df_synthetic = process_generated_data(generated_data)
        data_frames.append(df_synthetic)
    return pd.concat(data_frames, ignore_index=True)
def process_generated_data(csv_data):
    # Parse the raw CSV text returned by the model into a DataFrame
    data = StringIO(csv_data)
    df = pd.read_csv(data)
    return df
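# Model output sometimes includes prose before the CSV header, which would
# break pd.read_csv. A hedged sketch of a trimming step (this helper and its
# header-detection heuristic are assumptions, not part of the original app):
#
#   def trim_to_csv(text, first_column):
#       # Drop everything before the line that starts the CSV header
#       idx = text.find(first_column)
#       return text[idx:] if idx != -1 else text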
def main(description, columns):
    description = description.strip()
    columns = [col.strip() for col in columns.split(',')]
    df_synthetic = generate_large_synthetic_data(description, columns)
    if isinstance(df_synthetic, str) and df_synthetic.startswith("Error"):
        return df_synthetic  # Return the error message if any
    return df_synthetic.to_csv(index=False)
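# Example invocation outside Gradio (hypothetical inputs, kept as a comment):
#
#   csv_out = main("Generate a dataset for predicting students' grades",
#                  "name, age, course, grade")
#   # `csv_out` is either the combined CSV text or an "Error: ..." string.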
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="Description", placeholder="e.g., Generate a dataset for predicting students' grades"),
        gr.Textbox(label="Columns (comma-separated)", placeholder="e.g., name, age, course, grade")
    ],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns.",
    api_name="generate"  # Name the prediction endpoint directly here
)
# Run the Gradio app
iface.launch(server_name="0.0.0.0", server_port=7860)