# app.py - Fixed: load quantized base + local LoRA checkpoint (preferred),
# tokenizer from base, device-safe generation, Gradio UI with sliders.
import os

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
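# Expected runtime dependencies (typically pinned in requirements.txt): torch, transformers,
# peft, gradio, accelerate (needed for device_map="auto"), and bitsandbytes on CUDA hosts.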
# ---- USER CONFIG ----
# If ADAPTER_LOCAL_DIR exists, that local checkpoint (e.g. checkpoint-9000) will be used.
ADAPTER_LOCAL_DIR = os.environ.get("ADAPTER_LOCAL_DIR", "qwen_lora_sft_output/checkpoint-9000")
HF_ADAPTER_REPO = "GilbertAkham/gilbert-qwen-multitask-lora"  # fallback adapter repo id
BASE_MODEL = "Qwen/Qwen1.5-1.8B-Chat"
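# NOTE: a LoRA adapter is tied to the base checkpoint it was trained on; if you change
# BASE_MODEL, you also need an adapter trained against that same model.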
# ---------------------


class MultitaskInference:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        compute_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Use tokenizer from base model (recommended)
        print("Loading tokenizer from base model:", BASE_MODEL)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False, trust_remote_code=True)
        except Exception as e:
            print("Failed to load tokenizer from base model:", e)
            print("Trying tokenizer from local adapter or HF adapter repo as fallback...")
            # fallback attempt
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(HF_ADAPTER_REPO, use_fast=False, trust_remote_code=True)
            except Exception as e2:
                raise RuntimeError("Cannot load tokenizer from base or adapter repos.") from e2
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Prepare bitsandbytes config when CUDA is available
        bnb_config = None
        if self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
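            # NF4 stores the base weights as 4-bit NormalFloat and double-quantizes the
            # quantization constants; matmuls still run in compute_dtype (fp16 here).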
| print("Using 4-bit quantized loader (bitsandbytes) for the base model.") | |
| # Load the base model (quantized if possible) | |
| print("Loading base model:", BASE_MODEL) | |
| try: | |
| self.base = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| device_map="auto" if self.device == "cuda" else None, | |
| quantization_config=bnb_config, | |
| torch_dtype=compute_dtype if self.device == "cuda" else torch.float32, | |
| trust_remote_code=True, | |
| ) | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load base model {BASE_MODEL}: {e}") | |
| # Load LoRA adapter: prefer local checkpoint folder if present | |
| adapter_source = None | |
| if os.path.exists(ADAPTER_LOCAL_DIR) and os.path.isdir(ADAPTER_LOCAL_DIR): | |
| adapter_source = ADAPTER_LOCAL_DIR | |
| print("Found local adapter checkpoint:", ADAPTER_LOCAL_DIR) | |
| else: | |
| adapter_source = HF_ADAPTER_REPO | |
| print("Local adapter not found β will try to load adapter from HF repo:", HF_ADAPTER_REPO) | |
| print(f"Loading LoRA adapter from: {adapter_source}") | |
| try: | |
| # PeftModel.from_pretrained can accept a local path or a repo id | |
| self.model = PeftModel.from_pretrained(self.base, adapter_source, torch_dtype=compute_dtype if self.device == "cuda" else torch.float32) | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load LoRA adapter from {adapter_source}: {e}") | |
| # Move model to device (PeftModel wraps base model) | |
| if self.device == "cuda": | |
| # model is partitioned by device_map if bnb used; still ensure on cuda | |
| try: | |
| self.model.to(self.device) | |
| except Exception: | |
| # sometimes .to('cuda') is not required when device_map='auto' already placed weights | |
| pass | |
| else: | |
| self.model.to(self.device) | |
| self.model.eval() | |
| print("Model + adapter loaded. Device:", self.device) | |
| def generate_response(self, task_type: str, input_text: str, max_new_tokens: int = 200, temperature: float = 0.7, top_p: float = 0.9): | |
| task_prompts = { | |
| "email": "Draft an email reply", | |
| "story": "Continue the story", | |
| "tech": "Answer the technical question", | |
| "summary": "Summarize the content", | |
| "chat": "Provide a helpful chat response" | |
| } | |
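        # Assumption: the "### Task / ### Input / ### Output" template below mirrors the
        # prompt format used during LoRA fine-tuning; the two must stay in sync.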
| prompt = f"### Task: {task_prompts.get(task_type,'Provide a reply')}\n\n### Input:\n{input_text}\n\n### Output:\n" | |
| # Tokenize then move tensors to same device as model | |
| inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024) | |
| # Move inputs to model device | |
| inputs = {k: v.to(self.model.device) for k, v in inputs.items()} | |
| try: | |
| with torch.no_grad(): | |
| out = self.model.generate( | |
| **inputs, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| do_sample=True, | |
| pad_token_id=self.tokenizer.eos_token_id, | |
| repetition_penalty=1.1, | |
| ) | |
| text = self.tokenizer.decode(out[0], skip_special_tokens=True) | |
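            # decode() returns the prompt plus the completion, so keep only the text
            # after the final "### Output:" marker.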
| if "### Output:" in text: | |
| text = text.split("### Output:")[-1].strip() | |
| return text | |
| except Exception as e: | |
| return f"β Generation error: {e}" | |
| # Create engine (this will load model on startup) | |
| engine = MultitaskInference() | |
| # Gradio UI | |
| def process_request(task_type, user_input, max_tokens, temperature, top_p): | |
| if not user_input or not user_input.strip(): | |
| return "β οΈ Please enter some input text." | |
| return engine.generate_response(task_type, user_input, max_new_tokens=int(max_tokens), temperature=float(temperature), top_p=float(top_p)) | |
| examples = [ | |
| ["chat", "Hey β my VPN won't connect. Any suggestions?"], | |
| ["email", "Subject: Project update\nBody: Please share the status of Task A."], | |
| ["story", "The lighthouse blinked twice and the fog rolled in..."], | |
| ["tech", "What is the difference between model.eval() and model.train() in PyTorch?"], | |
| ["summary", "AI systems are transforming industries through automation and data insights..."], | |
| ] | |
| with gr.Blocks(title="Gilbert Multitask AI", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown( | |
| f"## π Gilbert Multitask AI\n\n**Base model:** {BASE_MODEL}\n\nLoRA adapter: local `{ADAPTER_LOCAL_DIR}` if present, otherwise `{HF_ADAPTER_REPO}`." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| task_type = gr.Dropdown(choices=["chat", "email", "story", "tech", "summary"], value="chat", label="Task") | |
| max_tokens = gr.Slider(50, 1024, value=200, step=10, label="Max new tokens") | |
| temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature") | |
| top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p") | |
| gr.Examples(examples=examples, inputs=[task_type, gr.Textbox(visible=False)]) | |
| with gr.Column(scale=2): | |
| input_box = gr.Textbox(lines=8, label="Input") | |
| output_box = gr.Textbox(lines=10, label="Generated Response", show_copy_button=True) | |
| btn = gr.Button("Generate") | |
| btn.click(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box) | |
| input_box.submit(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box) | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860, share=False) | |