# app.py - load the quantized base model plus a local LoRA checkpoint (preferred),
# use the tokenizer from the base model, generate device-safely, and serve a Gradio UI with sliders.
import os
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from peft import PeftModel
# ---- USER CONFIG ----
# If ADAPTER_LOCAL_DIR exists, that local checkpoint (e.g. checkpoint-9000) will be used.
ADAPTER_LOCAL_DIR = os.environ.get("ADAPTER_LOCAL_DIR", "qwen_lora_sft_output/checkpoint-9000")
HF_ADAPTER_REPO = "GilbertAkham/gilbert-qwen-multitask-lora" # fallback adapter repo id
BASE_MODEL = "Qwen/Qwen1.5-1.8B-Chat"
# ---------------------
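# A different checkpoint can be selected at launch time via the environment variable, e.g.
#   ADAPTER_LOCAL_DIR=qwen_lora_sft_output/checkpoint-12000 python app.py
# (checkpoint-12000 is an illustrative folder name; point it at whatever your training run produced.)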
class MultitaskInference:
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = None
self.tokenizer = None
self._load_model_and_tokenizer()
def _load_model_and_tokenizer(self):
compute_dtype = torch.float16 if self.device == "cuda" else torch.float32
# Use tokenizer from base model (recommended)
print("Loading tokenizer from base model:", BASE_MODEL)
try:
self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False, trust_remote_code=True)
except Exception as e:
print("Failed to load tokenizer from base model:", e)
print("Trying tokenizer from local adapter or HF adapter repo as fallback...")
# fallback attempt
try:
self.tokenizer = AutoTokenizer.from_pretrained(HF_ADAPTER_REPO, use_fast=False, trust_remote_code=True)
except Exception as e2:
raise RuntimeError("Cannot load tokenizer from base or adapter repos.") from e2
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Prepare bitsandbytes config when CUDA is available
bnb_config = None
if self.device == "cuda":
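            # QLoRA-style settings: NF4 4-bit weights with double quantization
            # (the quantization constants are themselves quantized) keep the 1.8B-parameter
            # base to roughly 1-2 GB of GPU memory, while matmuls still run in float16.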
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype,
)
print("Using 4-bit quantized loader (bitsandbytes) for the base model.")
# Load the base model (quantized if possible)
print("Loading base model:", BASE_MODEL)
try:
self.base = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto" if self.device == "cuda" else None,
quantization_config=bnb_config,
torch_dtype=compute_dtype if self.device == "cuda" else torch.float32,
trust_remote_code=True,
)
except Exception as e:
raise RuntimeError(f"Failed to load base model {BASE_MODEL}: {e}")
# Load LoRA adapter: prefer local checkpoint folder if present
adapter_source = None
if os.path.exists(ADAPTER_LOCAL_DIR) and os.path.isdir(ADAPTER_LOCAL_DIR):
adapter_source = ADAPTER_LOCAL_DIR
print("Found local adapter checkpoint:", ADAPTER_LOCAL_DIR)
else:
adapter_source = HF_ADAPTER_REPO
print("Local adapter not found β will try to load adapter from HF repo:", HF_ADAPTER_REPO)
print(f"Loading LoRA adapter from: {adapter_source}")
try:
# PeftModel.from_pretrained can accept a local path or a repo id
self.model = PeftModel.from_pretrained(self.base, adapter_source, torch_dtype=compute_dtype if self.device == "cuda" else torch.float32)
except Exception as e:
raise RuntimeError(f"Failed to load LoRA adapter from {adapter_source}: {e}")
        # Move the model to the target device. With device_map="auto" and/or 4-bit
        # quantization the weights are already placed and .to() may raise, so that
        # case is tolerated and skipped.
        if self.device == "cuda":
            try:
                self.model.to(self.device)
            except Exception:
                # Weights already dispatched by accelerate/bitsandbytes; nothing to do.
                pass
        else:
            self.model.to(self.device)
self.model.eval()
print("Model + adapter loaded. Device:", self.device)
def generate_response(self, task_type: str, input_text: str, max_new_tokens: int = 200, temperature: float = 0.7, top_p: float = 0.9):
task_prompts = {
"email": "Draft an email reply",
"story": "Continue the story",
"tech": "Answer the technical question",
"summary": "Summarize the content",
"chat": "Provide a helpful chat response"
}
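        # The "### Task / ### Input / ### Output" template presumably mirrors the format used
        # during LoRA fine-tuning; the output is later split on the last "### Output:" marker
        # so only the model's completion is returned.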
prompt = f"### Task: {task_prompts.get(task_type,'Provide a reply')}\n\n### Input:\n{input_text}\n\n### Output:\n"
        # Tokenize, then move the tensors to the model's device; with device_map="auto"
        # the first layers may sit on a specific GPU, so self.model.device is the safe target.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
try:
with torch.no_grad():
out = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
repetition_penalty=1.1,
)
text = self.tokenizer.decode(out[0], skip_special_tokens=True)
if "### Output:" in text:
text = text.split("### Output:")[-1].strip()
return text
except Exception as e:
return f"β Generation error: {e}"
# Create engine (this will load model on startup)
engine = MultitaskInference()
# Gradio UI
def process_request(task_type, user_input, max_tokens, temperature, top_p):
if not user_input or not user_input.strip():
return "β οΈ Please enter some input text."
return engine.generate_response(task_type, user_input, max_new_tokens=int(max_tokens), temperature=float(temperature), top_p=float(top_p))
examples = [
["chat", "Hey β my VPN won't connect. Any suggestions?"],
["email", "Subject: Project update\nBody: Please share the status of Task A."],
["story", "The lighthouse blinked twice and the fog rolled in..."],
["tech", "What is the difference between model.eval() and model.train() in PyTorch?"],
["summary", "AI systems are transforming industries through automation and data insights..."],
]
with gr.Blocks(title="Gilbert Multitask AI", theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"## π Gilbert Multitask AI\n\n**Base model:** {BASE_MODEL}\n\nLoRA adapter: local `{ADAPTER_LOCAL_DIR}` if present, otherwise `{HF_ADAPTER_REPO}`."
)
with gr.Row():
with gr.Column(scale=1):
task_type = gr.Dropdown(choices=["chat", "email", "story", "tech", "summary"], value="chat", label="Task")
max_tokens = gr.Slider(50, 1024, value=200, step=10, label="Max new tokens")
temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=2):
            input_box = gr.Textbox(lines=8, label="Input")
            output_box = gr.Textbox(lines=10, label="Generated Response", show_copy_button=True)
            btn = gr.Button("Generate")

    # Examples sit below the row so they can target the real, visible input components.
    gr.Examples(examples=examples, inputs=[task_type, input_box])

    btn.click(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)
    input_box.submit(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)