# Norod78's picture
# Update app.py
# eccb754 verified
import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time
# User-facing Hebrew strings rendered by the chat UI (title and markdown
# description). They must be stored as real UTF-8 Hebrew text; the previous
# literals were mis-decoded bytes and rendered as unreadable glyphs.
TITLE = " מודל מבוסס גמה 3 ליצירת שירים מטופשים בעברית "
DESCRIPTION= """
ניתן לבקש שיר על בסיס טקסט, תמונה ווידאו
בכל פעם, יווצר שיר שונה, אז אם לא אהבתם, אפשר לנסות שוב עם אותו הפרומפט
[המודל זמין להורדה](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune)
המודל כּוּיַּיל ע״י [דורון אדלר](https://linktr.ee/Norod78)
"""
# Model configuration: the checkpoint and its processor are loaded once, at
# import time, and shared by every request.
model_4b_name = "Norod78/gemma-3_4b_hebrew-lyrics-finetune"
model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name, device_map="auto", torch_dtype=torch.bfloat16
)
# Inference only: switch the module to evaluation mode.
model_4b.eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)
# I will add timestamp later
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` frames from a video as RGB PIL images.

    Frames are taken at a fixed stride of ``total_frames // num_frames``
    (at least 1), starting from frame 0. Positions that cannot be read are
    skipped, so fewer than ``num_frames`` images may be returned.

    Args:
        video_path: Path of the video file to read.
        num_frames: Maximum number of frames to sample. Values <= 0 yield
            an empty list.

    Returns:
        A list of ``PIL.Image`` objects in RGB order; empty if the video
        cannot be opened or no frame could be decoded.
    """
    # Guard first: a zero count would otherwise divide by zero below.
    if num_frames <= 0:
        return []

    frames = []
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return frames

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(total_frames // num_frames, 1)
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if not ret:
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    finally:
        # Always free the capture handle, even if decoding raised.
        cap.release()
    return frames
def format_message(content, files):
    """Build the multimodal content list for one user message.

    The text is split on the literal ``<image>`` placeholder; each
    placeholder consumes the next attached file (in order) and inserts it
    as an image at that position. Files left over after the text has been
    processed are appended afterwards: ``.jpg``/``.jpeg``/``.png`` files as
    a single image, ``.mp4``/``.mov`` files as a sequence of sampled frames.
    Files with any other suffix are ignored.

    Args:
        content: Prompt text, possibly containing ``<image>`` placeholders.
            May be empty or ``None``.
        files: Iterable of attachments, each either a path string or an
            object exposing the path as ``.name``. May be ``None``. The
            caller's list is not modified.

    Returns:
        A list of ``{"type": "text", ...}`` / ``{"type": "image", ...}``
        dictionaries in display order.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')
    video_suffixes = ('.mp4', '.mov')

    # Normalize every attachment to a path string up front, and consume them
    # through an iterator so the caller's list is left untouched.
    paths = [f if isinstance(f, str) else f.name for f in (files or [])]
    remaining = iter(paths)

    message_content = []

    if content:
        parts = content.split('<image>')
        last_index = len(parts) - 1
        for i, part in enumerate(parts):
            stripped = part.strip()
            if stripped:
                message_content.append({"type": "text", "text": stripped})
            # Every part except the last was followed by a placeholder.
            if i < last_index:
                placeholder_path = next(remaining, None)
                if placeholder_path is not None:
                    img = Image.open(placeholder_path)
                    message_content.append({"type": "image", "image": img})

    # Attachments not claimed by a placeholder go after the text.
    for file_path in remaining:
        suffix = Path(file_path).suffix.lower()
        if suffix in image_suffixes:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif suffix in video_suffixes:
            for frame in extract_video_frames(file_path):
                message_content.append({"type": "image", "image": frame})

    return message_content
def format_conversation_history(chat_history):
    """Convert chat history entries into chat-template messages.

    Consecutive ``user`` entries are merged into a single user message whose
    content is a list of parts: string content becomes a text part, list
    content is spliced in as-is, anything else is stringified into a text
    part. Each ``assistant`` entry first flushes any pending user parts and
    is then emitted as its own message with a single text part. Entries
    with any other role are ignored.

    Args:
        chat_history: Iterable of dicts with ``"role"`` and ``"content"``.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dictionaries.
    """
    conversation = []
    pending_parts = []

    for entry in chat_history:
        speaker, payload = entry["role"], entry["content"]

        if speaker == "assistant":
            # Close the user turn that preceded this reply, if any.
            if pending_parts:
                conversation.append({"role": "user", "content": pending_parts})
                pending_parts = []
            reply_part = {"type": "text", "text": str(payload)}
            conversation.append({"role": "assistant", "content": [reply_part]})
            continue

        if speaker != "user":
            continue

        if isinstance(payload, list):
            pending_parts.extend(payload)
        elif isinstance(payload, str):
            pending_parts.append({"type": "text", "text": payload})
        else:
            pending_parts.append({"type": "text", "text": str(payload)})

    # A trailing user turn with no assistant reply yet.
    if pending_parts:
        conversation.append({"role": "user", "content": pending_parts})

    return conversation
@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream the model's reply to the newest user input.

    Args:
        input_data: Either a dict with a ``"text"`` key (and optional
            ``"files"``), or any other value, which is stringified and
            treated as plain text with no attachments.
        chat_history: Prior turns as dicts with ``"role"`` / ``"content"``.
        max_new_tokens: Upper bound on generated tokens.
        system_prompt: Optional system instruction; omitted when empty.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, once per streamed chunk.
    """
    # Separate the prompt text from its attachments.
    is_multimodal_payload = isinstance(input_data, dict) and "text" in input_data
    if is_multimodal_payload:
        prompt_text = input_data["text"]
        attachments = input_data.get("files", [])
    else:
        prompt_text = str(input_data)
        attachments = []

    user_turn = {"role": "user", "content": format_message(prompt_text, attachments)}

    # Conversation = optional system message + formatted history.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(format_conversation_history(chat_history))

    # Fold the new input into a trailing user turn rather than emitting two
    # consecutive user messages.
    ends_with_user = bool(messages) and messages[-1]["role"] == "user"
    if ends_with_user:
        messages[-1]["content"].extend(user_turn["content"])
    else:
        messages.append(user_turn)

    model, processor = model_4b, processor_4b

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(inputs)
    generation_kwargs.update(
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    # Generation runs on a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()

    reply_so_far = ""
    for chunk in streamer:
        reply_so_far += chunk
        yield reply_so_far
# Chat UI definition. The Hebrew strings below are user-facing and must be
# stored as real UTF-8 Hebrew text; the previous literals were mis-decoded
# bytes and rendered as unreadable glyphs.
chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=True, show_copy_button=True, type="messages"),
    # Order matters: these map positionally onto generate_response's
    # parameters after (input_data, chat_history): max_new_tokens,
    # system_prompt, temperature, top_p, top_k, repetition_penalty.
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="אתה משורר ישראלי, כותב שירים בעברית",
            lines=4,
            placeholder="שנה את ההגדרות של המודל",
            text_align='right', rtl=True
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        [{"text": "כתוב לי בבקשה שיר המתאר את התמונה", "files": ["examples/image1.jpg"]}],
        [{"text": "תפוח אדמה עם חרדה חברתית"}]
    ],
    textbox=gr.MultimodalTextbox(
        rtl=True,
        label="קלט",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="בקשו שיר ו/או העלו תמונה",
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="הפסק",
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    # Cap the request queue at 20 pending jobs before launching the server.
    chat_interface.queue(max_size=20).launch()