import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time
# UI strings are Hebrew; English translations are given in the comments.
# Title: "A Gemma 3 based demo for generating silly songs in Hebrew"
TITLE = "דמו מבוסס גמה 3 ליצירת שירים מטופשים בעברית"
# Description: "You can request a song based on text, an image or a video.
# A different song is generated each time, so if you didn't like it, you can try
# again with the same prompt. The model is available for download. The model was
# fine-tuned by Doron Adler."
DESCRIPTION = """
ניתן לבקש שיר על בסיס טקסט, תמונה ווידאו
בכל פעם, מיוצר שיר שונה, אז אם לא אהבתם, אפשר לנסות שוב עם אותו הפרומפט
[המודל זמין להורדה](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune)
המודל כּוּיַּל ע״י [דורון אדלר](https://linktr.ee/Norod78)
"""
# Model config
model_4b_name = "Norod78/gemma-3_4b_hebrew-lyrics-finetune"
model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)
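# Rough sizing note (my estimate, not part of the original file): a ~4B-parameter model
# in bfloat16 takes 2 bytes per parameter, i.e. on the order of 8 GB of weights, so
# device_map="auto" will want a GPU with headroom beyond that for activations and the
# KV cache.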
# TODO: add timestamp support later
def extract_video_frames(video_path, num_frames=8):
    """Sample up to `num_frames` evenly spaced frames from a video as PIL images."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)
    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB before handing off to PIL.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames
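# Illustrative usage (the file name is hypothetical, not part of the app's flow):
#   frames = extract_video_frames("clip.mp4", num_frames=8)
# For a 240-frame clip, step = max(240 // 8, 1) = 30, so frames 0, 30, ..., 210 are
# sampled. Clips shorter than num_frames fall back to step = 1, and reads past the end
# simply fail, so fewer than num_frames images may be returned.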
def format_message(content, files):
    """Build a multimodal content list, honoring inline <image> placeholders in the text."""
    message_content = []
    if content:
        parts = content.split('<image>')
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            # Each <image> placeholder consumes the next uploaded file, in order.
            if i < len(parts) - 1 and files:
                img = Image.open(files.pop(0))
                message_content.append({"type": "image", "image": img})
    # Remaining files with no matching placeholder are appended after the text.
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
            frames = extract_video_frames(file_path)
            for frame in frames:
                message_content.append({"type": "image", "image": frame})
    return message_content
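# Sketch of the structure this produces (the file name "cat.jpg" is hypothetical):
#   format_message("Describe <image> in a song", ["cat.jpg"])
#   -> [{"type": "text", "text": "Describe"},
#       {"type": "image", "image": <PIL.Image>},
#       {"type": "text", "text": "in a song"}]
# This is the per-message "content" layout consumed by processor.apply_chat_template below.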
def format_conversation_history(chat_history):
    """Convert Gradio's "messages" history into the nested content format the processor expects."""
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            # Accumulate consecutive user parts (text and files) into a single user turn.
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
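# Sketch of the conversion (the history values below are illustrative, not from the app):
#   format_conversation_history([
#       {"role": "user", "content": "a user turn"},
#       {"role": "assistant", "content": "an assistant reply"},
#   ])
#   -> [{"role": "user", "content": [{"type": "text", "text": "a user turn"}]},
#       {"role": "assistant", "content": [{"type": "text", "text": "an assistant reply"}]}]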
@spaces.GPU  # Request a ZeroGPU slice for the duration of each generation call.
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    # MultimodalTextbox submits a dict with "text" and "files"; fall back to plain text.
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []
    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history
    # If the history already ends with an open user turn, merge the new content into it.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)
    model = model_4b
    processor = processor_4b
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(model.device)
    # Stream tokens from a background generation thread and yield partial text to Gradio.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
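# Streaming pattern used above: model.generate runs in a background Thread and pushes
# decoded text into the TextIteratorStreamer; iterating over the streamer inside the
# generator body lets Gradio update the chat bubble incrementally as tokens arrive.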
chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=True, show_copy_button=True, type="messages"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            # "You are an Israeli poet, writing songs in Hebrew"
            value="אתה משורר ישראלי, כותב שירים בעברית",
            lines=4,
            # "Change the model's settings"
            placeholder="שנה את ההגדרות של המודל",
            text_align="right", rtl=True
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        # "Please write me a song describing the image"
        [{"text": "כתוב לי בבקשה שיר המתאר את התמונה", "files": ["examples/image1.jpg"]}],
        # "A potato with social anxiety"
        [{"text": "תפוח אדמה עם חרדה חברתית"}]
    ],
    textbox=gr.MultimodalTextbox(
        rtl=True,
        label="קלט",  # "Input"
        file_types=["image", "video"],
        file_count="multiple",
        # "Request a song and/or upload an image"
        placeholder="בקשו שיר ו/או העלו תמונה",
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="הפסק",  # "Stop"
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)
if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch()
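# Launch notes (a sketch, not from the original file): queue(max_size=20) caps the number
# of pending requests at 20, and launch() starts the Gradio server. On a Hugging Face
# Space this file runs automatically as app.py; to try it elsewhere you would run
# `python app.py`, assuming gradio, transformers, torch, opencv-python, pillow and the
# `spaces` package are installed.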