#!/usr/bin/env python3
"""
Gradio application for inference with the Phi-2 model using LoRA/QLoRA adapters.
Pre-loads the model and provides a simple chat interface.
"""

import os

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Define constants
DEFAULT_MODEL_PATH = "./adapters"       # Path to the trained adapters
DEFAULT_BASE_MODEL = "microsoft/phi-2"  # Base model name
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 50

# Global variables that hold the loaded model and tokenizer
model = None
tokenizer = None


def load_model(
    model_path=DEFAULT_MODEL_PATH,
    base_model=DEFAULT_BASE_MODEL,
    use_qlora=True,
    device="cuda"
):
    """Load the base model and, if present, the adapter weights."""
    global model, tokenizer

    print(f"Loading tokenizer from {base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Configure model loading parameters
    model_kwargs = {"trust_remote_code": True}

    # Set up 4-bit quantization for QLoRA if enabled
    if use_qlora:
        print("Using 4-bit quantization (QLoRA)")
        compute_dtype = torch.float16
        if torch.cuda.is_bf16_supported():
            compute_dtype = torch.bfloat16
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True
        )
        model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["torch_dtype"] = torch.float16 if torch.cuda.is_available() else torch.float32

    # Check whether the adapter path exists
    if not os.path.exists(model_path):
        print(f"Warning: Model path '{model_path}' does not exist. Using base model only.")

    # Load the base model
    print(f"Loading base model {base_model}...")
    base = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    # Load adapter weights if available
    if os.path.exists(model_path) and os.path.exists(os.path.join(model_path, "adapter_config.json")):
        print(f"Loading {'QLoRA' if use_qlora else 'LoRA'} adapters from {model_path}...")
        model = PeftModel.from_pretrained(base, model_path)

        # Special handling for QLoRA: keep norm layers in float32 for stability
        # and give the embedding/output layers a consistent working dtype
        if use_qlora:
            print("Harmonizing model layer dtypes...")
            working_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            for name, module in model.named_modules():
                if any(x in name for x in ["lm_head", "embed_tokens"]):
                    module.to(working_dtype)
                elif "norm" in name:
                    module.to(torch.float32)  # Norms should stay in fp32 for stability
    else:
        model = base
        print("Using base model without adapters")

    # Move the model to the target device. 4-bit bitsandbytes weights are already
    # placed on the GPU at load time and do not support .to(), so only move the
    # model explicitly in the non-quantized case.
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    if not use_qlora:
        model = model.to(device)
    model.eval()

    print(f"Model loaded successfully on {next(model.parameters()).device}!")
    return model, tokenizer


def generate_response(prompt, chat_history):
    """Generate a text response from the model."""
    global model, tokenizer

    if model is None or tokenizer is None:
        return chat_history + [(prompt, "Model not loaded yet. Please wait a moment.")]
    # Format the prompt in Phi-2's instruct style
    formatted_prompt = f"Instruct: {prompt}\nOutput:"

    # Tokenize the input prompt
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    # Generate text with robust error handling
    try:
        with torch.no_grad():
            # Token IDs and the attention mask must stay integer-typed
            input_ids = input_ids.to(torch.long)
            attention_mask = attention_mask.to(torch.long)

            # First attempt with sampling parameters
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=DEFAULT_TEMPERATURE,
                top_p=DEFAULT_TOP_P,
                top_k=DEFAULT_TOP_K,
                pad_token_id=tokenizer.pad_token_id,
            )
    except Exception as e:
        print(f"Generation error: {str(e)}")
        try:
            # Fallback: greedy decoding under autocast for more stability
            print("Attempting fallback generation...")
            autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
            autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
            with torch.autocast(device_type=autocast_device, dtype=autocast_dtype):
                generated_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                    do_sample=False,  # Greedy decoding
                    pad_token_id=tokenizer.pad_token_id,
                )
        except Exception as e2:
            return chat_history + [(prompt, f"Error generating response: {str(e2)}")]

    # Decode the generated text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Extract just the output part
    output = generated_text.split("Output:")[1].strip() if "Output:" in generated_text else generated_text

    # Update the chat history
    return chat_history + [(prompt, output)]


# Example prompts that demonstrate model capabilities
examples = [
    ["Explain the concept of quantum computing in simple terms."],
    ["Write a short story about a robot that learns to paint."],
    ["What are some ethical considerations when developing AI systems?"],
    ["How can I improve my productivity while working from home?"],
    ["Create a meal plan for a vegetarian diet that provides sufficient protein."]
]

# Initialize the model at startup
print("Pre-loading the model...")
try:
    model, tokenizer = load_model()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    print("The app will still start, but you may need to check your model path.")

# Create the Gradio interface
with gr.Blocks(title="Supervised Fine-Tuned (SFT) Phi-2 with QLoRA Adapters") as demo:
    gr.Markdown("# Supervised Fine-Tuned (SFT) Phi-2 with QLoRA Adapters")
    gr.Markdown(
        "- Base (foundation) model: Phi-2\n"
        "- Supervised fine-tuning (SFT) on the [OpenAssistant dataset](https://huggingface.co/datasets/OpenAssistant/oasst1?row=0)\n"
        "- QLoRA adapters keep the 4-bit quantized base model frozen and train only a small set of low-rank adapter weights\n"
        "- This gives the model the ability to answer questions rather than just continue text\n"
        "- Chat with the SFT Phi-2 model with QLoRA adapters below"
    )

    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        msg = gr.Textbox(
            label="Type your message here",
            placeholder="Ask me anything...",
            show_label=False,
            scale=9
        )
        send_btn = gr.Button("Send", scale=1)

    clear = gr.Button("Clear Chat")

    # Examples section
    gr.Markdown("### Example Capabilities")
    gr.Examples(
        examples=examples,
        inputs=msg,
        outputs=chatbot,
        fn=generate_response,
        cache_examples=False,
        examples_per_page=5
    )

    # Set up event handlers
    send_btn.click(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    msg.submit(generate_response, [msg, chatbot], [chatbot]).then(
        lambda: "", None, msg  # Clear the input box after sending
    )
    clear.click(lambda: [], None, chatbot)

# Launch the app
if __name__ == "__main__":
    # Report GPU status
    if torch.cuda.is_available():
        print(f"CUDA available: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    else:
        print("CUDA not available, using CPU. This will be very slow for inference.")

    # Launch the Gradio app
    demo.launch(share=True)