"""GPT-2 Gradio assistant that fetches a Hugging Face dataset in the background.

On import the module starts a daemon thread that downloads the dataset and
writes it to ``data.csv``, loads GPT-2 into a CPU text-generation pipeline,
and builds a Gradio Blocks UI exposing the loading status and a prompt box.
"""

import logging
import threading

import gradio as gr
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

# Global variables for dataset status.  They are written by the background
# loader thread and read by the UI callbacks; `dataset_lock` guards both.
dataset_loaded = False
dataset_status_message = "Dataset is still loading..."
dataset_lock = threading.Lock()


def load_dataset_in_background():
    """Download the dataset, save it as CSV, and record the outcome.

    Intended to run in a worker thread.  On success `dataset_loaded` becomes
    True; on any failure it stays False and `dataset_status_message` carries
    the error text.  The function never raises: the exception is logged and
    surfaced through the status message instead, because nothing joins the
    thread to observe it.
    """
    global dataset_loaded, dataset_status_message
    try:
        # Load dataset from Hugging Face.
        # NOTE(review): this requests the complete "train" split and then
        # materialises it as one CSV file -- confirm that the disk and
        # bandwidth budget allows it, or switch to a smaller config/streaming.
        dataset = load_dataset("HuggingFaceFW/fineweb", split="train")
        # Save to CSV for later use
        dataset.to_csv("data.csv")
    except Exception as e:  # thread boundary: report instead of dying silently
        logger.exception("Background dataset loading failed")
        with dataset_lock:
            dataset_loaded = False
            dataset_status_message = f"Error loading dataset: {e}"
    else:
        with dataset_lock:
            dataset_loaded = True
            dataset_status_message = "Dataset loaded successfully!"


# Start dataset loading in background thread.  daemon=True so an unfinished
# download does not keep the process alive on shutdown.
threading.Thread(target=load_dataset_in_background, daemon=True).start()

# Load GPT-2 for inference (device=-1 selects the CPU).
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)


def generate_response(prompt):
    """Return a sampled GPT-2 continuation for *prompt*.

    Args:
        prompt: Text typed by the user.

    Returns:
        The generated text (prompt included, as the pipeline returns it) with
        surrounding whitespace stripped, or a short hint when the prompt is
        empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        # Nothing to condition on; avoid handing an empty input to the model.
        return "Please enter a prompt."

    responses = generator(
        prompt,
        # Budget only the *generated* tokens.  `max_length` also counts the
        # prompt, so a prompt of ~100 tokens left no room for any output.
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
    )
    return responses[0]["generated_text"].strip()


def refresh_status():
    """Return the current dataset status message, read under the lock."""
    with dataset_lock:
        return dataset_status_message


def chat(prompt):
    """Gradio callback: forward the prompt to `generate_response`."""
    return generate_response(prompt)


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## GPT-2 AI Assistant with Background Dataset Loading")

    status_box = gr.Textbox(
        value=refresh_status(),
        label="Dataset Loading Status",
        interactive=False,
        lines=2,
    )
    refresh_button = gr.Button("Check Dataset Status")
    refresh_button.click(refresh_status, outputs=status_box)
    # The textbox value above is fixed when the UI is built; refresh it on
    # every page load so visitors see the current state without clicking.
    demo.load(refresh_status, outputs=status_box)

    gr.Markdown("### Chat with the AI")
    prompt_input = gr.Textbox(
        label="Your prompt",
        placeholder="Ask me anything...",
    )
    response_output = gr.Textbox(label="AI Response", lines=10)

    gr.Button("Ask").click(chat, inputs=prompt_input, outputs=response_output)

if __name__ == "__main__":
    demo.launch()