infinitymatter committed on
Commit
f0edb7b
·
verified ·
1 Parent(s): 40dad27

Upload 4 files

Browse files
Files changed (4) hide show
  1. fetch_data.py +15 -0
  2. generate_schema.py +44 -0
  3. main.py +29 -0
  4. synthetic_generator.py +69 -0
fetch_data.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import pandas as pd
3
+ from io import BytesIO
4
+ from Utils.config import DATASET_URLS
5
+
6
def fetch_real_data(domain, *, timeout=30):
    """Download the reference CSV dataset registered for ``domain``.

    Args:
        domain: Key looked up in ``DATASET_URLS``.
        timeout: Seconds to wait for the server before giving up. Passed
            straight to ``requests.get``.

    Returns:
        A ``pandas.DataFrame`` parsed from the downloaded CSV body.

    Raises:
        ValueError: If ``DATASET_URLS`` has no (non-empty) URL for ``domain``.
        requests.exceptions.RequestException: If the request fails, times
            out, or the server answers with an HTTP error status.
    """
    url = DATASET_URLS.get(domain)
    if not url:
        raise ValueError(f"No URL found for domain: {domain}")

    # requests applies no timeout unless one is given, so an unresponsive
    # server would otherwise block this call indefinitely.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # Parse from the raw bytes so pandas handles the decoding itself.
    return pd.read_csv(BytesIO(response.content))
generate_schema.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ import os
5
+ from dotenv import load_dotenv
6
+
7
# Pull variables from a local .env file into the process environment.
load_dotenv()
API_KEY = os.getenv("hf_token")

# generate_schema() posts to HF_MODEL_URL with HEADERS, so both names must
# exist at module level. The endpoint is read from the HF_MODEL_URL
# environment variable (None when unset -- NOTE(review): confirm the
# variable name / intended model endpoint with the author).
HF_MODEL_URL = os.getenv("HF_MODEL_URL")
# Send the bearer token only when one is configured.
HEADERS = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}
9
+
10
+
11
def generate_schema(user_prompt):
    """Ask the hosted model for a synthetic-dataset schema.

    Sends a fixed instruction prompt followed by ``user_prompt`` to
    ``HF_MODEL_URL`` and parses the JSON object found in the model's reply.

    Args:
        user_prompt: Free-text description of the dataset to design.

    Returns:
        The decoded JSON object from the model's reply (the prompt asks for
        the keys ``columns``, ``types`` and ``size``, but the reply is not
        validated here), or ``{"error": "<message>"}`` when the request
        fails or the reply contains no parseable JSON object.
    """
    system_prompt = """
    You are an expert data scientist designing synthetic datasets.
    For any given dataset description, generate:
    - Column names
    - Data types (string, int, float, date)
    - Approximate row count

    Output in **pure JSON** format like:
    {
        "columns": ["PatientID", "Age", "Gender", "Diagnosis"],
        "types": ["int", "int", "string", "string"],
        "size": 500
    }
    """

    prompt = system_prompt + "\n\nUser request: " + user_prompt
    payload = {
        "inputs": prompt,
        "options": {"wait_for_model": True},
    }

    try:
        # wait_for_model can keep the request open while the model loads, so
        # the timeout is generous -- but never unbounded.
        response = requests.post(
            HF_MODEL_URL, headers=HEADERS, json=payload, timeout=120
        )
    except requests.exceptions.RequestException as exc:
        # Keep the "errors come back as a dict" contract for transport
        # failures as well as for bad status codes.
        return {"error": f"API request failed: {exc}"}

    if response.status_code != 200:
        return {"error": f"API request failed. Status Code: {response.status_code}"}

    try:
        generated = response.json()[0]["generated_text"]
        return _parse_schema_text(generated, prompt)
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers JSON decoding errors; the others cover a reply
        # that is not the expected [{"generated_text": ...}] shape.
        return {"error": "Invalid JSON output from model. Try again."}


def _parse_schema_text(generated, prompt):
    """Return the first JSON object found in the model's generated text."""
    # The reply may start with an echo of the prompt, whose example JSON
    # would otherwise be mistaken for the answer; keep only what follows it.
    if prompt in generated:
        generated = generated.rsplit(prompt, 1)[-1]

    start = generated.find("{")
    if start == -1:
        raise ValueError("no JSON object in model output")

    # raw_decode stops at the end of the first object, so trailing
    # commentary after the JSON does not break parsing.
    schema, _ = json.JSONDecoder().raw_decode(generated[start:])
    return schema
main.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pandas as pd
3
+ from generate_schema import generate_schema
4
+ from fetch_data import fetch_real_data
5
+ from synthetic_generator import train_and_generate_synthetic
6
+
7
def main():
    """Run the pipeline: schema generation, data fetch, synthetic generation.

    Command-line arguments:
        --prompt: Required description of the desired dataset.
        --domain: Domain whose real dataset is fetched (default
            ``"healthcare"``).

    Raises:
        SystemExit: If schema generation fails, the schema lacks a required
            key, or the fetched data does not contain every schema column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Describe the dataset you want")
    parser.add_argument("--domain", type=str, default="healthcare", help="Domain to fetch real data from (optional)")
    args = parser.parse_args()

    # Step 1: Generate schema using LLM
    schema = generate_schema(args.prompt)
    print(f"📊 Generated schema: {schema}")

    # generate_schema signals failure with {"error": ...}; stop here with a
    # readable message instead of failing later with a KeyError.
    if not isinstance(schema, dict) or "error" in schema:
        raise SystemExit(f"Schema generation failed: {schema}")
    missing_keys = [key for key in ("columns", "types", "size") if key not in schema]
    if missing_keys:
        raise SystemExit(f"Generated schema is missing keys: {missing_keys}")

    # Step 2: Fetch real data (optional)
    real_data = fetch_real_data(args.domain)

    # Step 3: Preprocess (if necessary)
    # The schema comes from a language model, so its columns are not
    # guaranteed to exist in the real dataset; report which ones are absent.
    missing_columns = [col for col in schema["columns"] if col not in real_data.columns]
    if missing_columns:
        raise SystemExit(
            f"Real data for domain '{args.domain}' lacks schema columns: {missing_columns}"
        )
    real_data = real_data[schema["columns"]]  # Match columns from schema
    print(f"✅ Fetched real data with shape: {real_data.shape}")

    # Step 4: Train GAN and generate synthetic data
    output_path = f"outputs/synthetic_{args.domain}.csv"
    train_and_generate_synthetic(real_data, schema, output_path)


if __name__ == "__main__":
    main()
synthetic_generator.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from ctgan import CTGAN
3
+ from sklearn.preprocessing import LabelEncoder
4
+ import os
5
+ import json
6
+ import requests
7
+
8
def train_and_generate_synthetic(real_data, schema, output_path):
    """Train a CTGAN model on ``real_data`` and write synthetic rows to CSV.

    Args:
        real_data: DataFrame used for training. It is not modified.
        schema: Mapping with ``columns`` and ``types`` (parallel lists) and
            ``size`` (number of synthetic rows to sample). Columns whose
            type is ``"string"`` are treated as categorical.
        output_path: Destination CSV path; its parent directory is created
            when missing.
    """
    categorical_cols = [
        col
        for col, dtype in zip(schema['columns'], schema['types'])
        if dtype == 'string'
    ]

    # Encode on a copy: the caller's frame (often a slice of a larger one)
    # must not have its categorical columns silently replaced by integers.
    training_data = real_data.copy()

    # Keep each fitted encoder so the sampled codes can be decoded again.
    label_encoders = {}
    for col in categorical_cols:
        encoder = LabelEncoder()
        training_data[col] = encoder.fit_transform(training_data[col])
        label_encoders[col] = encoder

    # Train CTGAN
    gan = CTGAN(epochs=300)
    gan.fit(training_data, categorical_cols)

    # Generate synthetic data
    synthetic_data = gan.sample(schema['size'])

    # Decode categorical columns back to their original labels
    for col, encoder in label_encoders.items():
        synthetic_data[col] = encoder.inverse_transform(synthetic_data[col])

    # Create the directory that actually contains output_path rather than a
    # fixed 'outputs' folder, so any destination path works.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    synthetic_data.to_csv(output_path, index=False)
    print(f"✅ Synthetic data saved to {output_path}")
34
+
35
def generate_schema(prompt):
    """Fetch a dataset schema from an external API and validate its keys.

    Args:
        prompt: Text sent to the API as the ``prompt`` field.

    Returns:
        The decoded schema dict, or ``None`` when the request fails or the
        response body is not JSON.

    Raises:
        ValueError: If the decoded response is not a dict containing the
            keys ``columns``, ``types`` and ``size``.
    """
    API_URL = "https://api.example.com/schema"  # Replace with correct API URL

    # Read the token from the environment instead of sending a hard-coded
    # placeholder; omit the header entirely when no token is configured.
    # NOTE(review): "hf_token" mirrors the variable name used by
    # generate_schema.py -- confirm it is the intended one here.
    token = os.getenv("hf_token")
    headers = {"Authorization": f"Bearer {token}"} if token else {}

    try:
        response = requests.post(
            API_URL, json={"prompt": prompt}, headers=headers, timeout=30
        )
    except requests.exceptions.RequestException as e:
        print(f"❌ API request failed: {e}")
        return None

    print("🔍 Raw API Response:", response.text)  # Debugging line

    # Parsed in its own try so transport errors and malformed bodies are
    # reported separately; ValueError covers every JSON decoding error.
    try:
        schema = response.json()
    except ValueError:
        print("❌ Failed to parse JSON response. API might be down or returning non-JSON data.")
        return None

    # Validate required keys (a non-dict body is rejected the same way)
    required_keys = ('columns', 'types', 'size')
    if not isinstance(schema, dict) or any(key not in schema for key in required_keys):
        raise ValueError("❌ Invalid schema format! Expected keys: 'columns', 'types', 'size'")

    print("✅ Valid Schema Received:", schema)  # Debugging line
    return schema
59
+
60
def fetch_data(domain):
    """Load the local CSV dataset for ``domain`` from ``datasets/``.

    Args:
        domain: Dataset name; the file read is ``datasets/<domain>.csv``,
            relative to the current working directory.

    Returns:
        A non-empty ``pandas.DataFrame`` with the file's contents.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If the file has no data rows (including a completely
            empty file).
    """
    data_path = f"datasets/{domain}.csv"
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"❌ Dataset for {domain} not found.")

    try:
        df = pd.read_csv(data_path)
    except pd.errors.EmptyDataError as exc:
        # A zero-byte file makes pandas raise before a frame exists; report
        # it with the same error as a header-only file.
        raise ValueError("❌ Loaded data is invalid!") from exc

    if df.empty:
        raise ValueError("❌ Loaded data is invalid!")
    return df