# app.py - Fixed: load quantized base + local LoRA checkpoint (preferred),
# tokenizer from base, device-safe generation, Gradio UI with sliders.
import os
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
# ---- USER CONFIG ----
# If ADAPTER_LOCAL_DIR exists, that local checkpoint (e.g. checkpoint-9000) will be used.
ADAPTER_LOCAL_DIR = os.environ.get("ADAPTER_LOCAL_DIR", "qwen_lora_sft_output/checkpoint-9000")
HF_ADAPTER_REPO = "GilbertAkham/gilbert-qwen-multitask-lora" # fallback adapter repo id
BASE_MODEL = "Qwen/Qwen1.5-1.8B-Chat"
# ---------------------
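# Because ADAPTER_LOCAL_DIR is read from the environment above, a different checkpoint
# can be selected without editing this file, e.g. (hypothetical path):
#   ADAPTER_LOCAL_DIR=qwen_lora_sft_output/checkpoint-12000 python app.py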
class MultitaskInference:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        compute_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Use tokenizer from base model (recommended)
        print("Loading tokenizer from base model:", BASE_MODEL)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False, trust_remote_code=True)
        except Exception as e:
            print("Failed to load tokenizer from base model:", e)
            print("Trying tokenizer from local adapter or HF adapter repo as fallback...")
            # fallback attempt
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(HF_ADAPTER_REPO, use_fast=False, trust_remote_code=True)
            except Exception as e2:
                raise RuntimeError("Cannot load tokenizer from base or adapter repos.") from e2

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
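        # Many causal-LM tokenizers ship without a dedicated pad token; reusing EOS as the
        # pad token is a common fallback so generate() always has a valid pad_token_id.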
        # Prepare bitsandbytes config when CUDA is available
        bnb_config = None
        if self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            print("Using 4-bit quantized loader (bitsandbytes) for the base model.")
        # Load the base model (quantized if possible)
        print("Loading base model:", BASE_MODEL)
        try:
            self.base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                device_map="auto" if self.device == "cuda" else None,
                quantization_config=bnb_config,
                torch_dtype=compute_dtype if self.device == "cuda" else torch.float32,
                trust_remote_code=True,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load base model {BASE_MODEL}: {e}")
        # Load LoRA adapter: prefer the local checkpoint folder if present
        if os.path.isdir(ADAPTER_LOCAL_DIR):
            adapter_source = ADAPTER_LOCAL_DIR
            print("Found local adapter checkpoint:", ADAPTER_LOCAL_DIR)
        else:
            adapter_source = HF_ADAPTER_REPO
            print("Local adapter not found; will try to load adapter from HF repo:", HF_ADAPTER_REPO)

        print(f"Loading LoRA adapter from: {adapter_source}")
        try:
            # PeftModel.from_pretrained accepts either a local path or a Hub repo id
            self.model = PeftModel.from_pretrained(
                self.base,
                adapter_source,
                torch_dtype=compute_dtype if self.device == "cuda" else torch.float32,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load LoRA adapter from {adapter_source}: {e}")

        # Move model to device (PeftModel wraps the base model)
        if self.device == "cuda":
            # model is partitioned by device_map if bnb is used; still ensure it is on cuda
            try:
                self.model.to(self.device)
            except Exception:
                # .to('cuda') may be unnecessary (or rejected) when device_map='auto' already placed the weights
                pass
        else:
            self.model.to(self.device)

        self.model.eval()
        print("Model + adapter loaded. Device:", self.device)
    def generate_response(self, task_type: str, input_text: str, max_new_tokens: int = 200, temperature: float = 0.7, top_p: float = 0.9):
        task_prompts = {
            "email": "Draft an email reply",
            "story": "Continue the story",
            "tech": "Answer the technical question",
            "summary": "Summarize the content",
            "chat": "Provide a helpful chat response",
        }
        prompt = f"### Task: {task_prompts.get(task_type, 'Provide a reply')}\n\n### Input:\n{input_text}\n\n### Output:\n"
        # Tokenize, then move tensors to the same device as the model
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        try:
            with torch.no_grad():
                out = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )
            text = self.tokenizer.decode(out[0], skip_special_tokens=True)
            if "### Output:" in text:
                text = text.split("### Output:")[-1].strip()
            return text
        except Exception as e:
            return f"❌ Generation error: {e}"
# Create engine (this will load the model on startup)
engine = MultitaskInference()

# Gradio UI
def process_request(task_type, user_input, max_tokens, temperature, top_p):
    if not user_input or not user_input.strip():
        return "⚠️ Please enter some input text."
    return engine.generate_response(task_type, user_input, max_new_tokens=int(max_tokens), temperature=float(temperature), top_p=float(top_p))
examples = [
    ["chat", "Hey, my VPN won't connect. Any suggestions?"],
    ["email", "Subject: Project update\nBody: Please share the status of Task A."],
    ["story", "The lighthouse blinked twice and the fog rolled in..."],
    ["tech", "What is the difference between model.eval() and model.train() in PyTorch?"],
    ["summary", "AI systems are transforming industries through automation and data insights..."],
]
with gr.Blocks(title="Gilbert Multitask AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"## 🚀 Gilbert Multitask AI\n\n**Base model:** {BASE_MODEL}\n\n"
        f"LoRA adapter: local `{ADAPTER_LOCAL_DIR}` if present, otherwise `{HF_ADAPTER_REPO}`."
    )
    with gr.Row():
        with gr.Column(scale=1):
            task_type = gr.Dropdown(choices=["chat", "email", "story", "tech", "summary"], value="chat", label="Task")
            max_tokens = gr.Slider(50, 1024, value=200, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=2):
            input_box = gr.Textbox(lines=8, label="Input")
            output_box = gr.Textbox(lines=10, label="Generated Response", show_copy_button=True)
            btn = gr.Button("Generate")

    # Examples are wired to the visible input box so selecting one fills both the task and the input
    gr.Examples(examples=examples, inputs=[task_type, input_box])

    btn.click(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)
    input_box.submit(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)
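# Optional: demo.queue() can be called before launch() to serialize concurrent generation requests.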
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)