# Hugging Face Space: GPT-2 assistant with background dataset loading.
import threading

import gradio as gr
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Shared state for the background dataset load.
# All reads/writes of these two values must hold dataset_lock.
dataset_loaded = False
dataset_status_message = "Dataset is still loading..."
dataset_lock = threading.Lock()
def load_dataset_in_background(
    dataset_name="HuggingFaceFW/fineweb",
    split="train",
    csv_path="data.csv",
):
    """Download a dataset and export it to CSV, reporting progress via globals.

    Intended to run on a background thread. The outcome is published through
    the module-level ``dataset_loaded`` / ``dataset_status_message`` pair,
    guarded by ``dataset_lock``.

    Args:
        dataset_name: Hugging Face dataset repo id to download.
        split: Dataset split to load.
        csv_path: Destination file for the CSV export.
    """
    global dataset_loaded, dataset_status_message
    try:
        # NOTE(review): the full fineweb train split is extremely large;
        # consider streaming or a sampled subset before shipping this.
        dataset = load_dataset(dataset_name, split=split)
        dataset.to_csv(csv_path)  # persist for later use
        with dataset_lock:
            dataset_loaded = True
            dataset_status_message = "Dataset loaded successfully!"
    except Exception as e:  # worker-thread boundary: report, don't crash the app
        with dataset_lock:
            dataset_loaded = False
            dataset_status_message = f"Error loading dataset: {e}"
# Kick off the dataset download without blocking app start-up; daemon=True
# means the thread won't keep the process alive on shutdown.
threading.Thread(target=load_dataset_in_background, daemon=True).start()

# Load GPT-2 once at import time for inference.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# device=-1 forces CPU inference.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=-1)
# Function to generate response
def generate_response(prompt, max_length=100):
    """Generate one sampled GPT-2 completion for *prompt*.

    Args:
        prompt: Text to continue.
        max_length: Total token budget (prompt + completion) for generation.

    Returns:
        The generated text (which includes the prompt), stripped of
        surrounding whitespace.
    """
    responses = generator(
        prompt,
        max_length=max_length,
        do_sample=True,  # sampling rather than greedy decoding
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
    )
    # The pipeline returns a list of dicts; take the single sequence requested.
    return responses[0]['generated_text'].strip()
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## GPT-2 AI Assistant with Background Dataset Loading")

    status_box = gr.Textbox(
        value=dataset_status_message,
        label="Dataset Loading Status",
        interactive=False,
        lines=2,
    )

    def refresh_status():
        """Return the latest dataset-loading status (read under the lock)."""
        with dataset_lock:
            return dataset_status_message

    refresh_button = gr.Button("Check Dataset Status")
    refresh_button.click(refresh_status, outputs=status_box)

    gr.Markdown("### Chat with the AI")
    prompt_input = gr.Textbox(label="Your prompt", placeholder="Ask me anything...")
    response_output = gr.Textbox(label="AI Response", lines=10)

    def chat(prompt):
        """Gradio callback: forward the user's prompt to the GPT-2 generator."""
        return generate_response(prompt)

    gr.Button("Ask").click(chat, inputs=prompt_input, outputs=response_output)

# Guard the launch so importing this module doesn't start a server;
# Spaces runs app.py as __main__, so runtime behavior is unchanged.
if __name__ == "__main__":
    demo.launch()