import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import json
from typing import List, Dict
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# Global model cache and loading status
MODEL_CACHE = {}
MODEL_LOADING_STATUS = {}
MODEL_LOADING_LOCK = threading.Lock()


def check_model_loading_status(model_names: List[str]) -> Dict:
    """Check the loading status of multiple models."""
    with MODEL_LOADING_LOCK:
        status = {}
        for model_name in model_names:
            if model_name in MODEL_CACHE:
                status[model_name] = "ready"
            elif model_name in MODEL_LOADING_STATUS:
                status[model_name] = MODEL_LOADING_STATUS[model_name]
            else:
                status[model_name] = "not_loaded"
        return status

def load_model_with_status_tracking(model_name: str):
    """Load a model while tracking its progress in MODEL_LOADING_STATUS."""
    with MODEL_LOADING_LOCK:
        if model_name in MODEL_CACHE:
            return MODEL_CACHE[model_name], None
        if model_name in MODEL_LOADING_STATUS:
            return None, f"โมเดล {model_name} กำลังโหลดอยู่..."
        MODEL_LOADING_STATUS[model_name] = "loading"
    try:
        print(f"🔄 เริ่มโหลดโมเดล {model_name}...")
        # Update status
        with MODEL_LOADING_LOCK:
            MODEL_LOADING_STATUS[model_name] = "downloading"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        with MODEL_LOADING_LOCK:
            MODEL_LOADING_STATUS[model_name] = "loading_model"
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        with MODEL_LOADING_LOCK:
            MODEL_LOADING_STATUS[model_name] = "creating_pipeline"
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        with MODEL_LOADING_LOCK:
            MODEL_CACHE[model_name] = generator
            MODEL_LOADING_STATUS[model_name] = "ready"
        print(f"✅ โหลดโมเดล {model_name} สำเร็จ")
        return generator, None
    except Exception as e:
        error_msg = f"❌ ไม่สามารถโหลดโมเดล {model_name}: {str(e)}"
        print(error_msg)
        with MODEL_LOADING_LOCK:
            # Clear the failed status so a later call can retry the load
            if model_name in MODEL_LOADING_STATUS:
                del MODEL_LOADING_STATUS[model_name]
        return None, error_msg

def preload_models_async(model_names: List[str], progress_callback=None):
    """Preload several models in background threads."""
    def load_single_model(model_name):
        generator, error = load_model_with_status_tracking(model_name)
        if progress_callback:
            progress_callback(model_name, "ready" if generator else "error", error)
        return model_name, generator, error

    results = {}
    # Limit concurrent loading to keep memory usage manageable
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = {executor.submit(load_single_model, model): model for model in model_names}
        for future in as_completed(futures):
            model_name, generator, error = future.result()
            results[model_name] = {"generator": generator, "error": error}
    return results
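
# Usage sketch (hypothetical model names): preload two small models and log
# progress as each one finishes. The callback receives (model_name, status, error).
#
#   def log_progress(name, status, error):
#       print(name, status, error or "")
#
#   preload_models_async(["distilgpt2", "gpt2"], progress_callback=log_progress)
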
# Predefined task templates with Thai language support
TASK_TEMPLATES = {
    "text_generation": {
        "name": "การสร้างข้อความ (Text Generation)",
        "template": "เขียนเรื่องราวสร้างสรรค์เกี่ยวกับ {topic}",
        "description": "สร้างข้อความสร้างสรรค์ภาษาไทยจากหัวข้อที่กำหนด"
    },
    "question_answering": {
        "name": "คำถาม-คำตอบ (Question Answering)",
        "template": "คำถาม: {question}\nคำตอบ:",
        "description": "สร้างคู่คำถาม-คำตอบภาษาไทย"
    },
    "summarization": {
        "name": "การสรุปข้อความ (Text Summarization)",
        "template": "สรุปข้อความต่อไปนี้: {text}",
        "description": "สร้างตัวอย่างการสรุปข้อความภาษาไทย"
    },
    "translation": {
        "name": "การแปลภาษา (Translation)",
        "template": "แปลจาก {source_lang} เป็น {target_lang}: {text}",
        "description": "สร้างคู่ข้อมูลสำหรับการแปลภาษา"
    },
    "classification": {
        "name": "การจำแนกข้อความ (Text Classification)",
        "template": "จำแนกอารมณ์ของข้อความนี้: {text}\nอารมณ์:",
        "description": "สร้างตัวอย่างการจำแนกอารมณ์หรือหมวดหมู่ของข้อความ"
    },
    "conversation": {
        "name": "บทสนทนา (Conversation)",
        "template": "มนุษย์: {input}\nผู้ช่วย:",
        "description": "สร้างข้อมูลบทสนทนาภาษาไทย"
    },
    "instruction_following": {
        "name": "การทำตามคำสั่ง (Instruction Following)",
        "template": "คำสั่ง: {instruction}\nการตอบสนอง:",
        "description": "สร้างคู่คำสั่ง-การตอบสนองภาษาไทย"
    },
    "thai_poetry": {
        "name": "กวีนิพนธ์ไทย (Thai Poetry)",
        "template": "แต่งกวีนิพนธ์เกี่ยวกับ {topic} ในรูปแบบ {style}",
        "description": "สร้างกวีนิพนธ์ไทยในรูปแบบต่างๆ"
    },
    "thai_news": {
        "name": "ข่าวภาษาไทย (Thai News)",
        "template": "เขียนข่าวภาษาไทยเกี่ยวกับ {topic} ในหัวข้อ {category}",
        "description": "สร้างข้อความข่าวภาษาไทยในหมวดหมู่ต่างๆ"
    }
}

# Thai language models from Hugging Face
THAI_MODELS = {
    "typhoon-7b": {
        "name": "🌪️ Typhoon-7B (SCB10X)",
        "model_id": "scb10x/typhoon-7b",
        "description": "โมเดลภาษาไทยขนาด 7B พารามิเตอร์ ประสิทธิภาพสูง"
    },
    "openthaigpt": {
        "name": "🇹🇭 OpenThaiGPT 1.5-7B",
        "model_id": "openthaigpt/openthaigpt1.5-7b-instruct",
        "description": "โมเดลภาษาไทยรองรับคำสั่งและบทสนทนาหลายรอบ"
    },
    "wangchanlion": {
        "name": "🦁 Gemma2-9B WangchanLION",
        "model_id": "aisingapore/Gemma2-9b-WangchanLIONv2-instruct",
        "description": "โมเดลขนาด 9B รองรับไทย-อังกฤษ พัฒนาโดย AI Singapore"
    },
    "sambalingo": {
        "name": "🌍 SambaLingo-Thai-Base",
        "model_id": "sambanovasystems/SambaLingo-Thai-Base",
        "description": "โมเดลภาษาไทยพื้นฐาน รองรับทั้งไทยและอังกฤษ"
    },
    "other": {
        "name": "🔧 โมเดลอื่นๆ (Custom)",
        "model_id": "custom",
        "description": "ระบุชื่อโมเดลที่ต้องการใช้งานเอง"
    }
}

def load_file_data(file_path: str) -> List[Dict]:
    """Load data from an uploaded CSV, JSON, or TXT file."""
    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
            return df.to_dict('records')
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        elif file_path.endswith('.txt'):
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            return [{'text': line.strip()} for line in lines if line.strip()]
        else:
            raise ValueError("Unsupported file format. Use CSV, JSON, or TXT files.")
    except Exception as e:
        raise Exception(f"Error reading file: {str(e)}")

def generate_from_template(template: str, data_row: Dict) -> str:
    """Fill a prompt template with the fields of one data row."""
    try:
        return template.format(**data_row)
    except KeyError as e:
        return f"Template error: Missing field {e}"
def load_model(model_name):
    """Load a Hugging Face model for text generation (no caching)."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        return generator, None
    except Exception as e:
        return None, str(e)

def generate_dataset(model_name, prompt_template, num_samples, max_length, temperature, top_p):
    """Generate a dataset using a Hugging Face model."""
    try:
        generator, error = load_model(model_name)
        if error:
            # Return the same 4-tuple shape as the success path
            return None, None, None, f"Error loading model: {error}"
        dataset = []
        for i in range(num_samples):
            # Generate text
            generated = generator(
                prompt_template,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
                do_sample=True
            )
            generated_text = generated[0]['generated_text']
            dataset.append({
                'id': i + 1,
                'prompt': prompt_template,
                'generated_text': generated_text,
                'full_text': generated_text
            })
        # Convert to DataFrame for display
        df = pd.DataFrame(dataset)
        # Create downloadable files
        csv_data = df.to_csv(index=False)
        json_data = json.dumps(dataset, indent=2, ensure_ascii=False)
        return df, csv_data, json_data, None
    except Exception as e:
        return None, None, None, f"Error generating dataset: {str(e)}"

def generate_dataset_from_task(model_name, task_type, custom_template, file_data, num_samples, max_length, temperature, top_p):
    """Generate a dataset using task templates or uploaded file data."""
    try:
        generator, error = load_model(model_name)
        if error:
            # Keep the 4-tuple shape expected by callers
            return None, None, None, f"Error loading model: {error}"
        dataset = []
        # Determine the template to use
        if custom_template and custom_template.strip():
            template = custom_template
        elif task_type in TASK_TEMPLATES:
            template = TASK_TEMPLATES[task_type]["template"]
        else:
            template = "Generate text: {input}"
        # Generate samples
        for i in range(num_samples):
            data_row = None
            if file_data and len(file_data) > 0:
                # Cycle through the uploaded rows
                data_row = file_data[i % len(file_data)]
                prompt = generate_from_template(template, data_row)
            else:
                # Fall back to placeholder values for the template fields
                prompt = template.replace("{topic}", "artificial intelligence") \
                    .replace("{question}", "What is machine learning?") \
                    .replace("{text}", "Sample text for processing") \
                    .replace("{input}", f"Sample input {i+1}") \
                    .replace("{instruction}", f"Complete this task {i+1}")
            # Generate text
            generated = generator(
                prompt,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
                do_sample=True,
                pad_token_id=generator.tokenizer.eos_token_id
            )
            generated_text = generated[0]['generated_text']
            dataset.append({
                'id': i + 1,
                'task_type': task_type,
                'prompt': prompt,
                'generated_text': generated_text,
                'original_data': data_row
            })
        # Convert to DataFrame for display
        df = pd.DataFrame(dataset)
        # Create downloadable files
        csv_data = df.to_csv(index=False)
        json_data = json.dumps(dataset, indent=2, ensure_ascii=False)
        return df, csv_data, json_data, None
    except Exception as e:
        return None, None, None, f"Error generating dataset: {str(e)}"

# Multi-model generation status tracking
class ModelStatus:
    """Thread-safe tracker that lets several model workers share one record pool.

    Each record moves through: unclaimed -> processing -> completed. A worker
    that fails on a record resets it to "pending" so another model can retry it.
    """

    def __init__(self):
        self.models = {}
        self.record_status = {}  # record_id: {"status": ..., "model": model_name}
        self.completed_records = []
        self.lock = threading.Lock()

    def set_record_processing(self, record_id: int, model_name: str):
        with self.lock:
            self.record_status[record_id] = {"status": "processing", "model": model_name}

    def set_record_completed(self, record_id: int, result: dict):
        with self.lock:
            self.record_status[record_id]["status"] = "completed"
            self.completed_records.append(result)

    def get_next_available_record(self, total_records: int, model_name: str) -> int:
        with self.lock:
            for i in range(total_records):
                if i not in self.record_status or self.record_status[i]["status"] == "pending":
                    # Claim the record as "processing" right away so no other
                    # worker can grab it between this call and generation.
                    self.record_status[i] = {"status": "processing", "model": model_name}
                    return i
            return -1  # No available records

    def get_progress(self, total_records: int) -> dict:
        with self.lock:
            completed = len([r for r in self.record_status.values() if r["status"] == "completed"])
            processing = len([r for r in self.record_status.values() if r["status"] == "processing"])
            return {
                "completed": completed,
                "processing": processing,
                "total": total_records,
                "percentage": (completed / total_records * 100) if total_records > 0 else 0
            }
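
# Claim/complete lifecycle sketch: a worker claims the next free record,
# generates, then reports the result; -1 means the pool is exhausted.
#
#   tracker = ModelStatus()
#   rid = tracker.get_next_available_record(total_records=3, model_name="gpt2")
#   if rid != -1:
#       tracker.set_record_completed(rid, {"id": rid + 1, "generated_text": "..."})
#   print(tracker.get_progress(3))  # e.g. {"completed": 1, "processing": 0, ...}
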
def load_model_with_cache(model_name: str, cache: dict):
    """Load a model with caching and console progress feedback."""
    if model_name in cache:
        return cache[model_name], None
    try:
        print(f"🔄 กำลังโหลดโมเดล {model_name}...")
        if "typhoon" in model_name.lower():
            # Typhoon repos need trust_remote_code; load in half precision
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,  # Use half precision
                device_map="auto",
                trust_remote_code=True
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        cache[model_name] = generator
        print(f"✅ โหลดโมเดล {model_name} สำเร็จ")
        return generator, None
    except Exception as e:
        error_msg = f"❌ ไม่สามารถโหลดโมเดล {model_name}: {str(e)}"
        print(error_msg)
        return None, error_msg

def generate_single_record(generator, prompt: str, record_id: int, model_name: str,
                           max_length: int, temperature: float, top_p: float,
                           task_type: str, original_data: dict, status_tracker: ModelStatus):
    """Generate a single record with the given model."""
    try:
        # Mark record as processing (idempotent; the worker already claimed it)
        status_tracker.set_record_processing(record_id, model_name)
        # Generate text; fall back to pad_token_id when eos_token_id is unset
        generated = generator(
            prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=generator.tokenizer.eos_token_id
            if getattr(generator.tokenizer, 'eos_token_id', None) is not None
            else generator.tokenizer.pad_token_id
        )
        generated_text = generated[0]['generated_text']
        result = {
            'id': record_id + 1,
            'model_used': model_name,
            'task_type': task_type,
            'prompt': prompt,
            'generated_text': generated_text,
            'original_data': original_data,
            'generation_time': time.time()
        }
        # Mark record as completed
        status_tracker.set_record_completed(record_id, result)
        return result
    except Exception:
        # If generation fails, release the record so another model can retry it
        with status_tracker.lock:
            if record_id in status_tracker.record_status:
                status_tracker.record_status[record_id]["status"] = "pending"
        return None

def model_worker(model_name: str, model_cache: dict, prompts: List[str],
                 task_type: str, original_data_list: List[dict],
                 max_length: int, temperature: float, top_p: float,
                 status_tracker: ModelStatus, progress_callback=None):
    """Worker loop: one model keeps claiming records until none remain."""
    # Load model
    generator, error = load_model_with_cache(model_name, model_cache)
    if error:
        return f"Error loading {model_name}: {error}"
    total_records = len(prompts)
    processed_count = 0
    while True:
        # Get next available record
        record_id = status_tracker.get_next_available_record(total_records, model_name)
        if record_id == -1:  # No more records available
            break
        # Generate record
        prompt = prompts[record_id]
        original_data = original_data_list[record_id] if original_data_list else None
        result = generate_single_record(
            generator, prompt, record_id, model_name,
            max_length, temperature, top_p, task_type,
            original_data, status_tracker
        )
        if result:
            processed_count += 1
        # Update progress
        if progress_callback:
            progress = status_tracker.get_progress(total_records)
            progress_callback(progress, model_name, processed_count)
    return f"{model_name}: Processed {processed_count} records"

def generate_dataset_multi_model(selected_models: List[str], task_type: str, custom_template: str,
                                 file_data: List[dict], num_samples: int, max_length: int,
                                 temperature: float, top_p: float, progress_callback=None):
    """Generate a dataset using multiple models collaboratively."""
    try:
        # Prepare prompts
        prompts = []
        original_data_list = []
        # Determine template
        if custom_template and custom_template.strip():
            template = custom_template
        elif task_type in TASK_TEMPLATES:
            template = TASK_TEMPLATES[task_type]["template"]
        else:
            template = "Generate text: {input}"
        # Generate prompts for all records
        for i in range(num_samples):
            if file_data and len(file_data) > 0:
                data_row = file_data[i % len(file_data)]
                prompt = generate_from_template(template, data_row)
                original_data_list.append(data_row)
            else:
                # Use template with placeholder values
                prompt = template.replace("{topic}", f"หัวข้อที่ {i+1}") \
                    .replace("{question}", f"คำถามที่ {i+1} เกี่ยวกับการเรียนรู้ของเครื่อง") \
                    .replace("{text}", f"ข้อความตัวอย่างที่ {i+1} สำหรับการประมวลผล") \
                    .replace("{input}", f"ข้อมูลนำเข้าที่ {i+1}") \
                    .replace("{instruction}", f"คำสั่งที่ {i+1}: ให้ทำงานนี้") \
                    .replace("{category}", "เทคโนโลยี") \
                    .replace("{style}", "โคลงสี่สุภาพ")
                original_data_list.append(None)
            prompts.append(prompt)
        # Initialize status tracker
        status_tracker = ModelStatus()
        model_cache = {}
        # Start one worker thread per model
        with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
            futures = []
            for model_name in selected_models:
                future = executor.submit(
                    model_worker, model_name, model_cache, prompts,
                    task_type, original_data_list, max_length,
                    temperature, top_p, status_tracker, progress_callback
                )
                futures.append((future, model_name))
            # Wait for all workers to complete
            for future, model_name in futures:
                try:
                    result = future.result(timeout=300)  # 5 minute timeout per model
                    print(f"Model {model_name} completed: {result}")
                except Exception as e:
                    print(f"Model {model_name} failed: {str(e)}")
        # Collect results in record order
        dataset = sorted(status_tracker.completed_records, key=lambda x: x['id'])
        if not dataset:
            return None, None, None, "ไม่สามารถสร้างข้อมูลได้"
        # Convert to DataFrame
        df = pd.DataFrame(dataset)
        # Create downloadable files
        csv_data = df.to_csv(index=False)
        json_data = json.dumps(dataset, indent=2, ensure_ascii=False)
        return df, csv_data, json_data, None
    except Exception as e:
        return None, None, None, f"Error in multi-model generation: {str(e)}"
def create_interface():
    with gr.Blocks(title="🇹🇭 Thai Dataset Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤗 เครื่องมือสร้างชุดข้อมูลภาษาไทยคุณภาพสูง")
        gr.Markdown("⚡ **เคล็ดลับ**: ใช้โมเดลใดก็ได้จาก Hugging Face - เริ่มต้นด้วยโมเดลเล็กๆ เพื่อทดสอบก่อน")
        with gr.Row():
            with gr.Column():
                # Flexible model input
                gr.Markdown("### 🤖 เลือกโมเดลจาก Hugging Face")
                gr.Markdown("💡 **คำแนะนำ**: ใส่ชื่อโมเดลจาก [Hugging Face](https://huggingface.co/models) เช่น `microsoft/DialoGPT-small`, `gpt2`, `scb10x/typhoon-7b`")
                model_input_mode = gr.Radio(
                    choices=[
                        ("📝 ใส่ชื่อโมเดลเอง", "manual"),
                        ("📋 เลือกจากรายการแนะนำ", "suggested"),
                        ("🔀 ใช้หลายโมเดลพร้อมกัน", "multiple")
                    ],
                    value="manual",
                    label="วิธีการเลือกโมเดล"
                )
                # Manual model input
                manual_model_group = gr.Group(visible=True)
                with manual_model_group:
                    single_model_name = gr.Textbox(
                        label="ชื่อโมเดลจาก Hugging Face",
                        value="microsoft/DialoGPT-small",
                        placeholder="เช่น gpt2, microsoft/DialoGPT-medium, scb10x/typhoon-7b",
                        info="ใส่ชื่อโมเดลที่ต้องการใช้งาน"
                    )
                    model_verification = gr.Button("🔍 ตรวจสอบโมเดล", variant="secondary", size="sm")
                    model_download = gr.Button("⬇️ ดาวน์โหลดโมเดล", variant="secondary", size="sm")
                    model_status = gr.Textbox(
                        label="สถานะโมเดล",
                        value="ยังไม่ได้ตรวจสอบ",
                        interactive=False
                    )
                    # Wire the verify button to a lightweight existence check
                    def verify_model(model_name):
                        try:
                            # Loading just the tokenizer is much faster than the full model
                            AutoTokenizer.from_pretrained(model_name)
                            return gr.update(value=f"✅ พบโมเดล {model_name} ใน Hugging Face", interactive=False)
                        except Exception as e:
                            return gr.update(value=f"❌ ไม่พบโมเดลหรือโหลดไม่ได้: {str(e)}", interactive=False)

                    model_verification.click(
                        fn=verify_model,
                        inputs=[single_model_name],
                        outputs=[model_status]
                    )

                    # Download button: preload the model into the local cache
                    def download_model(model_name):
                        try:
                            t0 = time.time()
                            yield gr.update(value=f"⏳ กำลังดาวน์โหลดและโหลดโมเดล {model_name} ...", interactive=False)
                            # Load tokenizer and model to warm the cache
                            AutoTokenizer.from_pretrained(model_name)
                            AutoModelForCausalLM.from_pretrained(model_name)
                            t1 = time.time()
                            yield gr.update(value=f"✅ โหลดโมเดล {model_name} สำเร็จใน {t1-t0:.1f} วินาที", interactive=False)
                        except Exception as e:
                            yield gr.update(value=f"❌ ไม่สามารถโหลดโมเดล: {str(e)}", interactive=False)

                    model_download.click(
                        fn=download_model,
                        inputs=[single_model_name],
                        outputs=[model_status]
                    )
                # Suggested models
                suggested_model_group = gr.Group(visible=False)
                with suggested_model_group:
                    gr.Markdown("#### โมเดลแนะนำ")
                    suggested_models = gr.Dropdown(
                        choices=[
                            # Small/fast models
                            ("⚡ DistilGPT2 (เล็ก, เร็ว)", "distilgpt2"),
                            ("⚡ GPT2 (กลาง)", "gpt2"),
                            ("⚡ DialoGPT-small (บทสนทนา)", "microsoft/DialoGPT-small"),
                            ("⚡ DialoGPT-medium (บทสนทนา)", "microsoft/DialoGPT-medium"),
                            # Thai models
                            ("🇹🇭 Typhoon-7B (ไทย, ใหญ่)", "scb10x/typhoon-7b"),
                            ("🇹🇭 OpenThaiGPT-1.5-7B (ไทย)", "openthaigpt/openthaigpt1.5-7b-instruct"),
                            ("🇹🇭 WangchanLION-7B (ไทย)", "aisingapore/llama2-7b-chat-thai"),
                            # Multilingual models
                            ("🌍 mGPT (หลายภาษา)", "ai-forever/mGPT"),
                            ("🌍 Bloom-560m (หลายภาษา, เล็ก)", "bigscience/bloom-560m"),
                            ("🌍 Bloom-1b1 (หลายภาษา)", "bigscience/bloom-1b1"),
                            # Instruction-following (note: Flan-T5 is a seq2seq
                            # architecture and will not load via AutoModelForCausalLM)
                            ("🎯 Flan-T5-small (คำสั่ง)", "google/flan-t5-small"),
                            ("🎯 Flan-T5-base (คำสั่ง)", "google/flan-t5-base"),
                            # Other popular models
                            ("🔥 OPT-350m (Meta)", "facebook/opt-350m"),
                            ("🔥 OPT-1.3b (Meta)", "facebook/opt-1.3b"),
                        ],
                        value="distilgpt2",
                        label="เลือกโมเดลแนะนำ"
                    )
                # Multiple models
                multiple_model_group = gr.Group(visible=False)
                with multiple_model_group:
                    multiple_model_names = gr.Textbox(
                        label="ชื่อโมเดลหลายตัว (แยกด้วยเครื่องหมายจุลภาค)",
                        value="distilgpt2, microsoft/DialoGPT-small",
                        placeholder="gpt2, microsoft/DialoGPT-medium, scb10x/typhoon-7b",
                        lines=3,
                        info="ใส่ชื่อโมเดลหลายตัวแยกด้วยเครื่องหมายจุลภาค"
                    )
                    model_distribution_mode = gr.Radio(
                        choices=[
                            ("🔄 แบ่งงานกัน (Collaborative)", "collaborative"),
                            ("🎲 สุ่มเลือก (Random)", "random"),
                            ("📊 เท่าๆ กัน (Round-robin)", "round_robin")
                        ],
                        value="collaborative",
                        label="วิธีการใช้โมเดลหลายตัว"
                    )
                # Model info display
                current_models_display = gr.Textbox(
                    label="โมเดลที่จะใช้",
                    value="microsoft/DialoGPT-small",
                    interactive=False
                )
                # Task selection with Thai tasks
                gr.Markdown("### 📝 เลือกประเภทงาน")
                task_dropdown = gr.Dropdown(
                    choices=[(v["name"], k) for k, v in TASK_TEMPLATES.items()],
                    value="text_generation",
                    label="ประเภทงานที่ต้องการ"
                )
                task_description = gr.Textbox(
                    label="คำอธิบายงาน",
                    value=TASK_TEMPLATES["text_generation"]["description"],
                    interactive=False
                )
                # File upload section
                gr.Markdown("### 📁 อัปโหลดข้อมูลต้นฉบับ (ไม่บังคับ)")
                gr.Markdown("อัปโหลดไฟล์ CSV, JSON หรือ TXT ที่มีข้อมูลต้นฉบับภาษาไทย")
                file_upload = gr.File(
                    label="อัปโหลดไฟล์ข้อมูล",
                    file_types=[".csv", ".json", ".txt"]
                )
                file_preview = gr.Dataframe(
                    label="ตัวอย่างข้อมูลจากไฟล์ (5 แถวแรก)",
                    visible=False
                )
                # State holding the uploaded file's rows (must exist before use)
                file_data_state = gr.State()

                def handle_file_upload(file):
                    if file is None:
                        return gr.update(visible=False), None
                    try:
                        if file.name.endswith('.csv'):
                            df = pd.read_csv(file.name)
                        elif file.name.endswith('.json'):
                            with open(file.name, 'r', encoding='utf-8') as f:
                                data = json.load(f)
                            df = pd.DataFrame(data)
                        elif file.name.endswith('.txt'):
                            with open(file.name, 'r', encoding='utf-8') as f:
                                lines = f.readlines()
                            df = pd.DataFrame({'text': [line.strip() for line in lines if line.strip()]})
                        else:
                            # Wrap the message so the Dataframe component can display it
                            return gr.update(visible=True, value=pd.DataFrame({'error': ['ไม่รองรับไฟล์นี้']})), None
                        preview = df.head(5)
                        # Return the preview plus the full data as a list of dicts
                        return gr.update(visible=True, value=preview), df.to_dict('records')
                    except Exception as e:
                        return gr.update(visible=True, value=pd.DataFrame({'error': [f"❌ อ่านไฟล์ผิดพลาด: {str(e)}"]})), None

                file_upload.change(
                    fn=handle_file_upload,
                    inputs=[file_upload],
                    outputs=[file_preview, file_data_state]
                )
                # Template customization with multi-prompt support
                gr.Markdown("### 🎯 ปรับแต่งเทมเพลตและ Prompt")
                gr.Markdown("ใช้ {ชื่อฟิลด์} สำหรับตัวแปรในเทมเพลต")
                prompt_mode = gr.Radio(
                    choices=[
                        ("📝 Prompt เดียว (Single)", "single"),
                        ("📋 หลาย Prompt (Multiple)", "multiple"),
                        ("🎲 สุ่มจาก Template (Random)", "random")
                    ],
                    value="single",
                    label="โหมดการใส่ Prompt"
                )
                # Single prompt mode
                single_prompt_group = gr.Group(visible=True)
                with single_prompt_group:
                    template_display = gr.Textbox(
                        label="เทมเพลตปัจจุบัน",
                        value=TASK_TEMPLATES["text_generation"]["template"],
                        interactive=False
                    )
                    custom_template = gr.Textbox(
                        label="เทมเพลตกำหนดเอง (ไม่บังคับ)",
                        lines=3,
                        placeholder="สร้างเทมเพลตของคุณเองที่นี่..."
                    )
                # Multiple prompts mode
                multi_prompt_group = gr.Group(visible=False)
                with multi_prompt_group:
                    gr.Markdown("#### 📋 ใส่หลาย Prompt (แต่ละบรรทัดคือ prompt หนึ่งตัว)")
                    multi_prompts = gr.Textbox(
                        label="Prompts หลายตัว (แยกด้วยการขึ้นบรรทัดใหม่)",
                        lines=10,
                        placeholder="""เขียนเรื่องราวเกี่ยวกับการผจญภัยในป่า
สร้างบทสนทนาระหว่างครูกับนักเรียน
อธิบายวิธีการทำอาหารไทย
เขียนบทกวีเกี่ยวกับธรรมชาติ
สร้างเรื่องสั้นเกี่ยวกับมิตรภาพ"""
                    )
                    prompt_distribution = gr.Radio(
                        choices=[
                            ("📊 กระจายเท่าๆ กัน", "equal"),
                            ("🎯 ตามสัดส่วนที่กำหนด", "weighted"),
                            ("🎲 สุ่ม", "random")
                        ],
                        value="equal",
                        label="วิธีการกระจาย Prompt"
                    )
                    prompt_weights = gr.Textbox(
                        label="น้ำหนักของแต่ละ Prompt (เช่น 2,1,3,1,2)",
                        placeholder="2,1,3,1,2",
                        visible=False
                    )
                # Random template mode
                random_prompt_group = gr.Group(visible=False)
                with random_prompt_group:
                    gr.Markdown("#### 🎲 สุ่ม Prompt จาก Template ที่เลือก")
                    random_templates = gr.CheckboxGroup(
                        choices=[(v["name"], k) for k, v in TASK_TEMPLATES.items()],
                        value=["text_generation", "conversation"],
                        label="เลือก Template ที่จะสุ่ม"
                    )
                    random_variables = gr.Textbox(
                        label="ตัวแปรสำหรับสุ่ม (JSON format)",
                        lines=5,
                        value="""{
    "topic": ["การเดินทาง", "เทคโนโลยี", "อาหาร", "ธรรมชาติ", "ศิลปะ"],
    "question": ["AI คืออะไร", "โลกร้อนคืออะไร", "การศึกษาสำคัญอย่างไร"],
    "instruction": ["เขียนบทความ", "สรุปข้อมูล", "วิเคราะห์ปัญหา"]
}""",
                        placeholder="ใส่ตัวแปรในรูปแบบ JSON"
                    )
                # Prompt preview and count
                prompt_preview = gr.Textbox(
                    label="ตัวอย่าง Prompt ที่จะใช้",
                    lines=3,
                    interactive=False
                )
                prompt_count = gr.Textbox(
                    label="จำนวน Prompt ที่พร้อมใช้",
                    value="1 prompt",
                    interactive=False
                )
                # Row-count preset (file_data_state is already defined above;
                # the duplicate gr.State() here shadowed it and broke file input)
                row_preset = gr.Dropdown(
                    choices=[
                        ("10 แถว", 10),
                        ("100 แถว", 100),
                        ("500 แถว", 500),
                        ("1000 แถว", 1000)
                    ],
                    value=10,
                    label="จำนวนแถวข้อมูลที่ต้องการสร้าง"
                )
                # Custom row count (overrides the preset when filled in)
                custom_rows = gr.Textbox(
                    label="จำนวนแถวกำหนดเอง (ถ้าเว้นว่างจะใช้ค่าจากด้านบน)",
                    placeholder="ใส่ตัวเลข เช่น 123"
                )
                # Text-generation settings
                max_length = gr.Slider(
                    minimum=16,
                    maximum=2048,
                    value=128,
                    step=1,
                    label="ความยาวสูงสุดของข้อความที่สร้าง (max_length)"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.05,
                    label="Temperature (ความสุ่ม)"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                    label="Top-p (nucleus sampling)"
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=32,
                    value=1,
                    step=1,
                    label="Batch size"
                )
                # Generate button
                generate_btn = gr.Button("🚀 สร้างข้อมูล", variant="primary")
                # Data quality settings
                gr.Markdown("### 🧼 การจัดการคุณภาพข้อมูล")
                enable_cleaning = gr.Checkbox(
                    label="เปิดใช้การทำความสะอาดข้อมูล",
                    value=True
                )
                remove_duplicates = gr.Checkbox(
                    label="ลบข้อมูลซ้ำซ้อน",
                    value=True
                )
                min_quality_score = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.1,
                    label="คะแนนคุณภาพขั้นต่ำ (0-1)"
                )
                # Train/val/test split option
                create_splits = gr.Checkbox(
                    label="แยกชุดข้อมูลเป็น train/val/test",
                    value=False
                )
                # Export settings
                gr.Markdown("### 📦 การส่งออกข้อมูล")
                export_format = gr.CheckboxGroup(
                    choices=[
                        ("📊 CSV (Excel, Spreadsheet)", "csv"),
                        ("📋 JSON (Web APIs, General)", "json"),
                        ("📄 JSONL (Fine-tuning, Streaming)", "jsonl"),
                        ("🤗 Hugging Face Dataset (Complete Package)", "huggingface"),
                        ("📝 TXT (Simple Text)", "txt"),
                        ("🗃️ Parquet (Big Data, Analytics)", "parquet"),
                        ("📋 TSV (Tab-separated)", "tsv"),
                        ("🎯 Custom Format", "custom")
                    ],
                    value=["csv", "json"],
                    label="เลือกรูปแบบไฟล์ที่ต้องการ (สามารถเลือกหลายแบบ)"
                )
                # Custom format settings
                custom_format_group = gr.Group(visible=False)
                with custom_format_group:
                    gr.Markdown("#### 🎯 การตั้งค่ารูปแบบกำหนดเอง")
                    custom_template_format = gr.Textbox(
                        label="Template สำหรับแต่ละ record",
                        value="Input: {input}\nOutput: {output}\n---",
                        lines=3,
                        placeholder="ใช้ {field_name} สำหรับข้อมูล"
                    )
                    custom_file_extension = gr.Textbox(
                        label="นามสกุลไฟล์",
                        value="txt",
                        placeholder="เช่น txt, md, xml"
                    )
                # Advanced export options
                with gr.Accordion("⚙️ ตัวเลือกขั้นสูง", open=False):
                    include_metadata = gr.Checkbox(
                        label="รวม Metadata (model_used, timestamp, etc.)",
                        value=True
                    )
                    include_quality_score = gr.Checkbox(
                        label="รวม Quality Score",
                        value=True
                    )
                    file_naming_pattern = gr.Textbox(
                        label="รูปแบบชื่อไฟล์",
                        value="thai_dataset_{task}_{timestamp}",
                        placeholder="ใช้ {task}, {timestamp}, {model}, {count}"
                    )
                    compression = gr.Radio(
                        choices=[
                            ("ไม่บีบอัด", "none"),
                            ("ZIP", "zip"),
                            ("GZIP", "gzip")
                        ],
                        value="none",
                        label="การบีบอัดไฟล์"
                    )
                # ...existing code...
            with gr.Column():
                with gr.Tabs():
                    with gr.TabItem("📊 ตัวอย่างข้อมูล"):
                        dataset_preview = gr.Dataframe(
                            headers=["id", "task_type", "input", "output", "quality_score"],
                            interactive=False
                        )
                        status_message = gr.Markdown(
                            value="",
                            visible=True
                        )
                        # States that hold generated artifacts
                        csv_data_state = gr.State()
                        json_data_state = gr.State()
                        dataset_card_state = gr.State()
                        hf_export_state = gr.State()
                        loading_status = gr.State()
                        # Holds the per-format files produced after generation
                        generated_files_state = gr.State({})
                    with gr.TabItem("📈 รายงานคุณภาพ"):
                        quality_report = gr.JSON(
                            label="รายงานคุณภาพข้อมูล",
                            visible=True
                        )
                        quality_summary = gr.Markdown(
                            value="สร้างข้อมูลเสร็จแล้วจึงจะแสดงรายงานคุณภาพ"
                        )
                    with gr.TabItem("💾 ดาวน์โหลด"):
                        gr.Markdown("### 💾 ดาวน์โหลดชุดข้อมูลในรูปแบบต่างๆ")
                        download_status = gr.Markdown("สร้างข้อมูลเสร็จแล้วจึงจะสามารถดาวน์โหลดได้")
                        # Download buttons, shown or hidden per selected format
                        with gr.Row():
                            csv_btn = gr.Button("📊 CSV", variant="secondary", visible=False)
                            json_btn = gr.Button("📋 JSON", variant="secondary", visible=False)
                            jsonl_btn = gr.Button("📄 JSONL", variant="secondary", visible=False)
                            txt_btn = gr.Button("📝 TXT", variant="secondary", visible=False)
                        with gr.Row():
                            parquet_btn = gr.Button("🗃️ Parquet", variant="secondary", visible=False)
                            tsv_btn = gr.Button("📋 TSV", variant="secondary", visible=False)
                            hf_btn = gr.Button("🤗 HF Dataset", variant="secondary", visible=False)
                            custom_btn = gr.Button("🎯 Custom", variant="secondary", visible=False)
                        # Download file targets
                        csv_download = gr.File(label="CSV File", visible=False)
                        json_download = gr.File(label="JSON File", visible=False)
                        jsonl_download = gr.File(label="JSONL File", visible=False)
                        txt_download = gr.File(label="TXT File", visible=False)
                        parquet_download = gr.File(label="Parquet File", visible=False)
                        tsv_download = gr.File(label="TSV File", visible=False)
                        hf_download = gr.File(label="HF Dataset Package", visible=False)
                        custom_download = gr.File(label="Custom Format", visible=False)
                        # All formats in one package
                        with gr.Row():
                            package_btn = gr.Button("📦 ดาวน์โหลดทั้งหมด (ZIP)", variant="primary")
                        package_download = gr.File(label="Complete Package", visible=False)
        # ...existing code for states...
        def update_export_format_visibility(selected_formats):
            """Toggle download buttons to match the selected export formats."""
            return [
                gr.update(visible=("csv" in selected_formats)),
                gr.update(visible=("json" in selected_formats)),
                gr.update(visible=("jsonl" in selected_formats)),
                gr.update(visible=("txt" in selected_formats)),
                gr.update(visible=("parquet" in selected_formats)),
                gr.update(visible=("tsv" in selected_formats)),
                gr.update(visible=("huggingface" in selected_formats)),
                gr.update(visible=("custom" in selected_formats)),
                gr.update(visible=("custom" in selected_formats))  # custom settings group
            ]
        def generate_multiple_formats(data, selected_formats, include_metadata, include_quality_score,
                                      file_naming_pattern, custom_template_format, custom_file_extension,
                                      task_type, compression):
            """Render the generated records in each requested export format."""
            from datetime import datetime
            import tempfile
            import gzip

            if not data:
                return {}
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_name = data[0].get('model_used', 'unknown').replace('/', '_')
            # Prepare data
            export_data = []
            for record in data:
                export_record = {}
                export_record['input'] = record.get('prompt', '')
                export_record['output'] = record.get('generated_text', '')
                if include_metadata:
                    export_record['metadata'] = {
                        'model_used': record.get('model_used', ''),
                        'task_type': record.get('task_type', ''),
                        'timestamp': record.get('generation_time', '')
                    }
                if include_quality_score and 'quality_score' in record:
                    export_record['quality_score'] = record['quality_score']
                export_data.append(export_record)
            # Generate filename
            filename_base = file_naming_pattern.format(
                task=task_type,
                timestamp=timestamp,
                model=model_name,
                count=len(export_data)
            )
            generated_files = {}
            # Generate each format
            if "csv" in selected_formats:
                df = pd.DataFrame(export_data)
                csv_content = df.to_csv(index=False)
                generated_files['csv'] = (f"{filename_base}.csv", csv_content)
            if "json" in selected_formats:
                json_content = json.dumps(export_data, indent=2, ensure_ascii=False)
                generated_files['json'] = (f"{filename_base}.json", json_content)
            if "jsonl" in selected_formats:
                jsonl_content = '\n'.join(json.dumps(record, ensure_ascii=False) for record in export_data)
                generated_files['jsonl'] = (f"{filename_base}.jsonl", jsonl_content)
            if "txt" in selected_formats:
                txt_content = '\n'.join(f"Input: {record['input']}\nOutput: {record['output']}\n---" for record in export_data)
                generated_files['txt'] = (f"{filename_base}.txt", txt_content)
            if "tsv" in selected_formats:
                df = pd.DataFrame(export_data)
                tsv_content = df.to_csv(index=False, sep='\t')
                generated_files['tsv'] = (f"{filename_base}.tsv", tsv_content)
            if "parquet" in selected_formats:
                # Requires pyarrow (or fastparquet) to be installed
                df = pd.DataFrame(export_data)
                with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as tmp:
                    temp_parquet = tmp.name
                df.to_parquet(temp_parquet)
                with open(temp_parquet, 'rb') as f:
                    parquet_content = f.read()
                generated_files['parquet'] = (f"{filename_base}.parquet", parquet_content)
            if "custom" in selected_formats:
                custom_content = []
                for record in export_data:
                    formatted = custom_template_format.format(**record)
                    custom_content.append(formatted)
                custom_text = '\n'.join(custom_content)
                generated_files['custom'] = (f"{filename_base}.{custom_file_extension}", custom_text)
            # Apply compression if selected
            if compression == "gzip":
                for format_name, (filename, content) in generated_files.items():
                    if isinstance(content, str):
                        content = content.encode('utf-8')
                    generated_files[format_name] = (filename + '.gz', gzip.compress(content))
            return generated_files
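
        # Shape sketch of the return value (hypothetical two-format request):
        #   {"csv":  ("thai_dataset_text_generation_20240101_120000.csv",  "<csv text>"),
        #    "json": ("thai_dataset_text_generation_20240101_120000.json", "<json text>")}
        # String contents are written to temp files at download time; Parquet is bytes.
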
        def create_complete_package(generated_files, compression):
            """Bundle every generated format into a single ZIP archive."""
            import tempfile
            import zipfile

            if not generated_files:
                return None
            # NamedTemporaryFile avoids the race in the deprecated tempfile.mktemp
            with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp:
                temp_zip = tmp.name
            with zipfile.ZipFile(temp_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for format_name, (filename, content) in generated_files.items():
                    if isinstance(content, str):
                        content = content.encode('utf-8')
                    zipf.writestr(filename, content)
            return temp_zip
        def download_specific_format(format_name, generated_files):
            """Write one generated format to a temp file and reveal its download slot."""
            import tempfile

            if generated_files and format_name in generated_files:
                filename, content = generated_files[format_name]
                if isinstance(content, str):
                    content = content.encode('utf-8')
                # gr.File needs a real path, so persist the content to disk
                with tempfile.NamedTemporaryFile(suffix='_' + filename, delete=False) as tmp:
                    tmp.write(content)
                    temp_file = tmp.name
                return gr.update(visible=True, value=temp_file)
            return gr.update(visible=False)
        # Event handlers
        export_format.change(
            fn=update_export_format_visibility,
            inputs=[export_format],
            outputs=[csv_btn, json_btn, jsonl_btn, txt_btn, parquet_btn, tsv_btn, hf_btn, custom_btn, custom_format_group]
        )
        # ...existing code for other event handlers...
        # Download button handlers; generated_files_state still needs to be
        # populated by the generation pipeline before these return files.
        csv_btn.click(
            fn=lambda files: download_specific_format('csv', files),
            inputs=[generated_files_state],
            outputs=[csv_download]
        )
        json_btn.click(
            fn=lambda files: download_specific_format('json', files),
            inputs=[generated_files_state],
            outputs=[json_download]
        )
        jsonl_btn.click(
            fn=lambda files: download_specific_format('jsonl', files),
            inputs=[generated_files_state],
            outputs=[jsonl_download]
        )
        txt_btn.click(
            fn=lambda files: download_specific_format('txt', files),
            inputs=[generated_files_state],
            outputs=[txt_download]
        )
        parquet_btn.click(
            fn=lambda files: download_specific_format('parquet', files),
            inputs=[generated_files_state],
            outputs=[parquet_download]
        )
        tsv_btn.click(
            fn=lambda files: download_specific_format('tsv', files),
            inputs=[generated_files_state],
            outputs=[tsv_download]
        )
        hf_btn.click(
            fn=lambda files: download_specific_format('huggingface', files),
            inputs=[generated_files_state],
            outputs=[hf_download]
        )
        custom_btn.click(
            fn=lambda files: download_specific_format('custom', files),
            inputs=[generated_files_state],
            outputs=[custom_download]
        )
        package_btn.click(
            fn=lambda files, comp: gr.update(visible=True, value=create_complete_package(files, comp)),
            inputs=[generated_files_state, compression],
            outputs=[package_download]
        )
        # Wire the generate button to the main processing function
        generate_btn.click(
            fn=process_with_flexible_models,
            inputs=[model_input_mode, single_model_name, suggested_models, multiple_model_names,
                    model_distribution_mode, task_dropdown, prompt_mode, custom_template,
                    multi_prompts, random_templates, random_variables, file_data_state,
                    row_preset, custom_rows, max_length, temperature, top_p, batch_size,
                    enable_cleaning, remove_duplicates, min_quality_score,
                    create_splits, export_format],
            outputs=[dataset_preview, status_message, quality_report, quality_summary,
                     csv_data_state, json_data_state, dataset_card_state, hf_export_state,
                     loading_status]
        )
    return demo

def validate_models_before_generation(*args, **kwargs):
    # TODO: implement validation logic
    return None

def process_with_flexible_models(input_mode, single_model, suggested_model, multiple_models,
                                 model_distribution_mode, task_type, prompt_mode, custom_template,
                                 multi_prompts, random_templates, random_variables, file_data,
                                 row_preset, custom_rows, max_length, temperature, top_p, batch_size,
                                 enable_cleaning, remove_duplicates, min_quality_score,
                                 create_splits, export_format):
    """Process generation with flexible model selection."""
    # Resolve which models to actually use
    def get_selected_models(input_mode, single_model, suggested_model, multiple_models):
        if input_mode == "manual":
            return [single_model.strip()] if single_model and single_model.strip() else []
        elif input_mode == "suggested":
            return [suggested_model] if suggested_model else []
        elif input_mode == "multiple":
            # Split on commas and strip whitespace
            return [m.strip() for m in multiple_models.split(",") if m.strip()]
        return []

    # Resolve how many rows to generate
    def get_final_row_count(row_preset, custom_rows):
        try:
            if custom_rows and str(custom_rows).strip():
                return int(custom_rows)
            return int(row_preset)
        except Exception:
            return 10

    # Get selected models
    selected_models = get_selected_models(input_mode, single_model, suggested_model, multiple_models)
    if not selected_models:
        yield (
            gr.update(visible=False),
            gr.update(visible=True, value="❌ กรุณาเลือกโมเดลอย่างน้อยหนึ่งตัว"),
            {}, "ไม่มีโมเดล", None, None, None, None,
            "❌ ไม่ได้เลือกโมเดล"
        )
        return
    num_samples = get_final_row_count(row_preset, custom_rows)
    try:
        yield (
            gr.update(visible=False),
            gr.update(visible=True, value=f"🔄 กำลังสร้างข้อมูล {num_samples} แถว..."),
            {}, "กำลังสร้าง...", None, None, None, None,
            "🔄 กำลังประมวลผล..."
        )
        # For now, generate with the first selected model only
        model_name = selected_models[0]
        df, csv_data, json_data, error = generate_dataset_from_task(
            model_name, task_type, custom_template, file_data,
            num_samples, max_length, temperature, top_p
        )
        if error:
            yield (
                gr.update(visible=False),
                gr.update(visible=True, value=f"❌ เกิดข้อผิดพลาด: {error}"),
                {}, "เกิดข้อผิดพลาด", None, None, None, None,
                f"❌ {error}"
            )
            return
        # Basic quality reporting
        raw_data = df.to_dict('records')
        quality_report = {
            "total_records": len(raw_data),
            "models_used": selected_models
        }
        final_df = pd.DataFrame(raw_data)
        final_csv = final_df.to_csv(index=False)
        final_json = json.dumps(raw_data, indent=2, ensure_ascii=False)
        dataset_card = f"# Dataset generated with {model_name}\n\nRecords: {len(raw_data)}"
        success_msg = f"✅ สร้างข้อมูลสำเร็จ! ได้ {len(raw_data)} แถว"
        quality_summary = f"📊 จำนวนข้อมูล: {len(raw_data)} แถว"
        yield (
            gr.update(visible=True, value=final_df),
            gr.update(visible=True, value=success_msg),
            quality_report,
            quality_summary,
            final_csv,
            final_json,
            dataset_card,
            None,
            "✅ เสร็จสิ้น!"
        )
    except Exception as e:
        yield (
            gr.update(visible=False),
            gr.update(visible=True, value=f"❌ ข้อผิดพลาด: {str(e)}"),
            {}, "เกิดข้อผิดพลาด", None, None, None, None,
            f"❌ {str(e)}"
        )

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )