Spaces · Runtime error
	Update app.py
app.py CHANGED
@@ -1,64 +1,56 @@
 import gradio as gr
-[old lines 2-30: the client setup and streaming respond() handler of the stock gr.ChatInterface Space template; not preserved in this view]
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+
+# Load dataset (replace 'daily_dialog' with your dataset)
+
+dataset = load_dataset("nazlicanto/persona-based-chat")
+
+# Choose a base model (DialoGPT)
+model_name = "microsoft/DialoGPT-small"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
+
+# Preprocess the dataset
+def preprocess_data(example):
+    input_text = "User: " + example["dialog"][0] + " Bot: " + example["dialog"][1]
+    return tokenizer(input_text, truncation=True, padding="max_length", max_length=128)
+
+tokenized_dataset = dataset.map(preprocess_data, batched=True)
+
+# Training arguments
+training_args = TrainingArguments(
+    output_dir="./chatbot_model",
+    evaluation_strategy="steps",
+    eval_steps=500,
+    save_steps=1000,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    num_train_epochs=3,
+    save_total_limit=2
 )
 
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=tokenized_dataset["train"],
+    eval_dataset=tokenized_dataset["validation"],
+    tokenizer=tokenizer,
+)
 
-[old lines 63-64: not preserved in this view]
+# Train the model
+def train_model():
+    trainer.train()
+    model.save_pretrained("trained_chatbot")
+    tokenizer.save_pretrained("trained_chatbot")
+    return "Training Complete!"
+
+# Chat interface
+def chatbot(user_input):
+    inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
+    outputs = model.generate(inputs, max_length=150, pad_token_id=tokenizer.eos_token_id)
+    return tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
+
+# Gradio UI
+iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", live=True)
+iface.launch()
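
A few things in the new app.py would likely explain the Space's Runtime error status and keep the Trainer from ever training. With batched=True, dataset.map hands preprocess_data whole batches, so example["dialog"][0] and example["dialog"][1] index the first two rows of the batch rather than the first two turns of one dialog (and raise a KeyError at startup if the persona-based-chat dataset's column is not actually named "dialog"); the DialoGPT/GPT-2 tokenizer ships without a pad token, so padding="max_length" raises; the tokenized dataset carries no labels column, so Trainer cannot compute a causal-LM loss; and tokenized_dataset["validation"] assumes a validation split exists. Below is a minimal sketch of one way to get the training half running, not the committed implementation; the "dialogue" column name is an assumption about the dataset's schema, so check the real column names before relying on it.

# Hedged sketch under the assumptions above.
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2-style tokenizers define no pad token; reuse EOS so padding works.
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset("nazlicanto/persona-based-chat")

def preprocess_data(batch):
    # batched=True passes lists of rows, so build one training string per row.
    # "dialogue" is an assumed column name holding a list of utterances per row.
    texts = [
        "User: " + turns[0] + " Bot: " + turns[1]
        for turns in batch["dialogue"]
    ]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=128)

tokenized_dataset = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# Copies input_ids into labels so Trainer can compute the causal-LM loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./chatbot_model",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    data_collator=data_collator,
)

The evaluation arguments are dropped here until a validation split is confirmed; the collator with mlm=False is what gives the Trainer a loss to minimize.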
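
Separately, nothing in the committed script ever calls train_model(), so even with the pipeline fixed the app would only serve the base DialoGPT-small checkpoint through chatbot(). One hedged way to expose both chat and fine-tuning in a single UI, reusing the trainer, model, and tokenizer objects from the sketch above (component names here are illustrative), is a gr.Blocks layout:

import gradio as gr

def train_model():
    # Blocks the request for the full training run; acceptable for a demo only.
    trainer.train()
    model.save_pretrained("trained_chatbot")
    tokenizer.save_pretrained("trained_chatbot")
    return "Training complete!"

def chatbot(user_input):
    inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    outputs = model.generate(
        inputs, max_length=150, pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)

with gr.Blocks() as demo:
    gr.Markdown("DialoGPT persona chatbot")
    user_box = gr.Textbox(label="You")
    bot_box = gr.Textbox(label="Bot")
    user_box.submit(chatbot, inputs=user_box, outputs=bot_box)

    train_button = gr.Button("Fine-tune")
    train_status = gr.Textbox(label="Training status")
    train_button.click(train_model, outputs=train_status)

demo.launch()

On typical Spaces hardware a full fine-tune launched from a button will usually exceed request limits, so the more common split is to train offline, upload the trained_chatbot weights, and have the Space only load them at startup.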