File size: 2,319 Bytes
8b84056
52779c2
 
8b84056
52779c2
577b861
8b84056
577b861
 
52779c2
8b84056
577b861
8b84056
577b861
8b84056
577b861
8b84056
577b861
 
 
8b84056
577b861
 
 
52779c2
8b84056
 
 
577b861
8b84056
52779c2
 
 
 
577b861
52779c2
 
 
 
 
 
 
 
 
 
 
 
577b861
52779c2
577b861
 
52779c2
577b861
 
 
52779c2
577b861
 
52779c2
8b84056
 
 
52779c2
577b861
52779c2
 
577b861
52779c2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import threading
from datasets import load_dataset
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Global variables for dataset status
# Shared between the background loader thread and the Gradio UI callbacks;
# writes (and the UI-refresh read) are serialized through `dataset_lock`.
dataset_loaded = False  # becomes True once load_dataset_in_background() succeeds
dataset_status_message = "Dataset is still loading..."  # human-readable status surfaced in the UI
dataset_lock = threading.Lock()  # guards the two status variables above

def load_dataset_in_background():
    """Fetch the FineWeb dataset, persist it to CSV, and record the outcome.

    Intended to run on a daemon thread. Communicates progress to the UI
    exclusively through the module-level ``dataset_loaded`` /
    ``dataset_status_message`` globals, updated under ``dataset_lock``.
    """
    global dataset_loaded, dataset_status_message
    try:
        # Load dataset from Hugging Face, then save to CSV for later use.
        fineweb = load_dataset("HuggingFaceFW/fineweb", split="train")
        fineweb.to_csv("data.csv")
    except Exception as exc:
        # Report any failure (network, disk, parsing) via the status message
        # instead of letting the background thread die silently.
        with dataset_lock:
            dataset_loaded = False
            dataset_status_message = f"Error loading dataset: {exc}"
    else:
        with dataset_lock:
            dataset_loaded = True
            dataset_status_message = "Dataset loaded successfully!"

# Start dataset loading in background thread.
# daemon=True so this thread never blocks interpreter shutdown if the
# (very large) download is still in progress when the app exits.
threading.Thread(target=load_dataset_in_background, daemon=True).start()

# Load GPT-2 for inference.
# Model and tokenizer are loaded eagerly at import time; `generator` is the
# shared text-generation pipeline used by generate_response() below.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# device=-1 pins the pipeline to CPU.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=-1)

# Function to generate response
def generate_response(prompt):
    """Generate one sampled GPT-2 completion for *prompt*.

    Args:
        prompt: User text to complete. Empty or whitespace-only input
            short-circuits to "" rather than letting the model free-run.

    Returns:
        The generated text (the pipeline includes the prompt in its
        output), stripped of surrounding whitespace.
    """
    if not prompt or not prompt.strip():
        return ""
    responses = generator(
        prompt,
        # Cap the number of *new* tokens. The previous max_length=100
        # counted the prompt too, so long prompts could leave zero room
        # for generation (or trigger a length error).
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
    )
    return responses[0]['generated_text'].strip()

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## GPT-2 AI Assistant with Background Dataset Loading")
    status_box = gr.Textbox(
        value=dataset_status_message,
        label="Dataset Loading Status",
        interactive=False,
        lines=2,
    )

    def refresh_status():
        """Return the latest loader status message, read under the lock."""
        with dataset_lock:
            return dataset_status_message

    # Manual refresh: re-reads the shared status written by the loader thread.
    gr.Button("Check Dataset Status").click(refresh_status, outputs=status_box)

    gr.Markdown("### Chat with the AI")
    prompt_input = gr.Textbox(label="Your prompt", placeholder="Ask me anything...")
    response_output = gr.Textbox(label="AI Response", lines=10)

    # Wire the button straight to the generator function — no wrapper needed.
    ask_button = gr.Button("Ask")
    ask_button.click(generate_response, inputs=prompt_input, outputs=response_output)

demo.launch()