import os
from io import StringIO

import gradio as gr
import pandas as pd
import requests
# Read the Hugging Face API token from the environment (on Spaces, set it as a secret)
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
# Prompt template for the generation request
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.

Description: {description}

The dataset should include the following columns: {columns}

Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.

Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'

Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...

Description:
{description}

Columns:
{columns}

Output: """
def preprocess_user_prompt(user_prompt):
    # Placeholder hook for sanitizing or normalizing the user's description
    return user_prompt
def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt
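
# Illustrative example (uses the sample description from the template above):
# format_prompt("Generate a dataset for predicting house prices",
#               ["Size", "Location", "Number of Bedrooms", "Price"])
# returns the template with both {description} and {columns} occurrences filled
# in, ending in "Output: " so the model continues with CSV rows.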
# Inference endpoint for the hosted model Space. Note: this is the Space's page
# URL; the route the hosted app actually exposes for inference may differ.
inference_endpoint = "https://huggingface.co/spaces/yakine/model"
def generate_synthetic_data(description, columns):
    try:
        # Build the prompt from the template defined above
        formatted_prompt = format_prompt(description, columns)

        # Send a POST request to the Space's inference API
        headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json"
        }
        data = {
            "inputs": formatted_prompt,  # Adjust according to the input expected by the hosted app
            "parameters": {
                # Note: 512 new tokens will truncate well before the 100 rows the
                # prompt asks for; raise this if the endpoint allows longer outputs.
                "max_new_tokens": 512,
                "top_p": 0.90,
                "temperature": 0.8
            }
        }
        response = requests.post(inference_endpoint, json=data, headers=headers, timeout=120)

        if response.status_code != 200:
            return f"Error: {response.status_code}, {response.text}"

        # Extract the generated text; adjust the key to match the actual response structure
        generated_text = response.json().get('data')
        if generated_text is None:
            return f"Error: unexpected response structure: {response.text[:200]}"
        if isinstance(generated_text, list):
            generated_text = generated_text[0]
        return generated_text
    except Exception as e:
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    data_frames = []
    num_iterations = num_rows // rows_per_generation
    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        if generated_data.startswith("Error"):
            return generated_data  # propagate the error message
        df_synthetic = process_generated_data(generated_data)
        data_frames.append(df_synthetic)
    # The prompt forbids duplicates within a generation, but separate batches can
    # still overlap, so deduplicate across the concatenated result.
    return pd.concat(data_frames, ignore_index=True).drop_duplicates(ignore_index=True)
def process_generated_data(csv_data):
    # Parse the model's CSV output into a DataFrame
    data = StringIO(csv_data)
    df = pd.read_csv(data)
    return df
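
# Defensive parsing sketch (an assumption about likely failure modes, not
# observed behavior): models sometimes wrap CSV output in prose or code fences.
# This trims everything before the first line that looks like the header row.
def extract_csv_block(raw_text, first_column):
    lines = raw_text.splitlines()
    for i, line in enumerate(lines):
        if first_column in line and "," in line:
            return "\n".join(lines[i:])
    return raw_text  # fall back to the raw text if no header line is found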
def main(description, columns):
    description = description.strip()
    columns = [col.strip() for col in columns.split(',')]
    df_synthetic = generate_large_synthetic_data(description, columns)
    if isinstance(df_synthetic, str):
        return df_synthetic  # a string here is an error message
    return df_synthetic.to_csv(index=False)
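
# Quick manual check (hedged: this hits the live endpoint and consumes quota):
#   print(main("Generate a dataset for predicting house prices",
#              "Size, Location, Number of Bedrooms, Price")[:200])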
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="Description", placeholder="e.g., Generate a dataset for predicting students' grades"),
        gr.Textbox(label="Columns (comma-separated)", placeholder="e.g., name, age, course, grade")
    ],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns.",
    api_name="generate"  # Exposes the endpoint as /generate for API callers
)
# Run the Gradio app
iface.launch(server_name="0.0.0.0", server_port=7860)
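
# Programmatic usage sketch (as comments, since launch() blocks; assumes the
# Space is public, and the Space id placeholder must be replaced with the real
# one, which is not given in this file):
#
#   from gradio_client import Client
#   client = Client("<user>/<space-name>")
#   csv_text = client.predict(
#       "Generate a dataset for predicting house prices",
#       "Size, Location, Number of Bedrooms, Price",
#       api_name="/generate",
#   )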