# Sarathi.AI / app.py
# Last commit a2ef1b5 by JDhruv14: "Update app.py" (verified)
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from datasets import load_dataset
# Hub repo of the model to serve; the MODEL_ID environment variable overrides it.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/Qwen2.5-3B-Gita-FT")
# Persona prompt used as the system message for every conversation.
GITA_SYSTEM_PROMPT = """You are Lord Krishna—the serene, compassionate teacher of the Bhagavad Gita."""

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Make sure a pad token exists: many chat models ship without one and reuse EOS as PAD.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    # bfloat16 whenever CUDA is visible; otherwise let transformers pick the dtype.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
)

# NOTE(review): `dataset` is not referenced anywhere else in this file; unless
# something outside it imports this name, the download only adds startup time.
dataset = load_dataset("JDhruv14/Bhagavad-Gita-QA")
def _msgs_from_history(history, system_text):
msgs = []
if system_text:
msgs.append({"role": "system", "content": system_text})
if not history:
return msgs
if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
for m in history:
role, content = m.get("role"), m.get("content")
if role in ("user", "assistant", "system") and content:
msgs.append({"role": role, "content": content})
else:
for user, assistant in history:
if user:
msgs.append({"role": "user", "content": user})
if assistant:
msgs.append({"role": "assistant", "content": assistant})
return msgs
def _eos_ids(tok):
ids = set()
if tok.eos_token_id is not None:
if isinstance(tok.eos_token_id, (list, tuple)):
ids.update(tok.eos_token_id)
else:
ids.add(tok.eos_token_id)
try:
im_end = tok.convert_tokens_to_ids("<|im_end|>")
if im_end is not None and im_end != tok.unk_token_id:
ids.add(im_end)
except Exception:
pass
return list(ids)
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate the model's reply to *message* given the earlier conversation.

    Args:
        message: The new user turn (text).
        history: Earlier turns, in any format ``_msgs_from_history`` accepts.
        system_text: System prompt placed before the conversation (may be empty).
        temperature: Sampling temperature (sampling is always enabled here).
        top_p: Nucleus-sampling probability mass.
        max_new: Maximum number of new tokens to generate.
        min_new: Minimum number of new tokens to generate.

    Returns:
        The generated reply as a string, with special tokens removed and
        surrounding whitespace stripped.
    """
    msgs = _msgs_from_history(history, system_text)
    msgs.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains every special token the model
    # expects. Letting the tokenizer add its own again (e.g. a BOS token) would
    # duplicate them on models whose tokenizer prepends one.
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    gen_cfg_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=pad_id,
    )
    eos = _eos_ids(tokenizer)
    if eos:
        # Stop on the tokenizer's EOS id(s) and, when present, the ChatML end-of-turn id.
        gen_cfg_kwargs["eos_token_id"] = eos
    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)
    # The causal LM output starts with the prompt tokens; keep only the continuation.
    new_tokens = out[:, prompt_len:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
@spaces.GPU()
def gradio_fn(message, history):
    """ChatInterface callback: answer *message* in the Krishna persona.

    Thin wrapper around ``chat_fn`` that fixes the system prompt and the
    sampling settings. The ``spaces.GPU`` decorator requests a GPU for each
    call when running on a Hugging Face ZeroGPU Space.
    """
    settings = dict(temperature=0.7, top_p=0.95, max_new=512, min_new=0)
    return chat_fn(
        message=message,
        history=history,
        system_text=GITA_SYSTEM_PROMPT,
        **settings,
    )
# Page styling: background image, a narrow left-aligned "glass" chat panel,
# and a full-width panel on small screens.
# `overflow-y: auto` (instead of `hidden`) keeps the page free of scrollbars
# whenever it fits, but still lets short viewports (phones, landscape, small
# windows) scroll down to the input box instead of cutting it off.
_APP_CSS = """
:root { --chat-w: 520px; }
html, body {
height: 100%;
overflow-y: auto; /* scroll only when the panel is taller than the viewport */
margin: 0;
}
/* Full-screen background image with a soft dark overlay */
body {
background:
linear-gradient(0deg, rgba(0,0,0,.28), rgba(0,0,0,.28)),
url("https://huggingface.co/spaces/JDhruv14/gita/resolve/main/bg.jpg") center / cover no-repeat fixed; /* <- change filename if needed */
}
/* Left-aligned, narrower chat panel */
.gradio-container {
max-width: var(--chat-w);
width: var(--chat-w);
margin-left: 16px;
margin-right: auto;
padding: 20px;
font-family: sans-serif;
position: relative;
/* optional glass effect for readability */
background: rgba(0,0,0,.30);
border-radius: 16px;
backdrop-filter: blur(6px);
}
.chatbot {
height: 480px !important;
overflow-y: auto;
}
@media (max-width: 720px){
:root { --chat-w: 92vw; }
.gradio-container { margin-left: 4vw; }
}
"""

with gr.Blocks(css=_APP_CSS) as demo:
    # Title banner above the chat.
    gr.Markdown(
        """
<div style='text-align: center; padding: 10px;'>
<h1 style='font-size: 2.0em; margin-bottom: 0.2em;'><span style='color: #4F46E5;'>🪷 Sarathi.AI</span></h1>
<p style='font-size: 1.0em; color: #bbb;'>Gita’s Eternal Teachings, Guided by AI 🕉️</p>
</div>
""",
        elem_id="header",
    )
    # Chat UI in "messages" format; gradio_fn receives the history as role/content dicts.
    gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Namaste!",
            "What is my duty?",
            "What is a Guna?",
            "What can I do to stop overthinking?",
        ],
        chatbot=gr.Chatbot(type="messages", elem_classes="chatbot"),
        type="messages",
    )

if __name__ == "__main__":
    demo.launch()