import os
import queue
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

############################################################
# Model setup (modify as needed)
############################################################

DESCRIPTION = """\

# Hi, I'm Gemma 2 (2B) 👋

This is a demo of google/gemma-2-2b-it, fine-tuned for instruction following.
For more details, please check the post.

👉 Looking for a larger version? Try the 27B in HuggingChat and the 9B in this Space.
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096
model.eval()


############################################################
# Generator function (streaming approach)
############################################################
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Generate text from the model and stream tokens back to the UI."""
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the context budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread; the streamer yields decoded
    # text pieces as they are produced so the UI can update incrementally.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
    except queue.Empty:
        # The streamer timed out waiting for new tokens; end the stream
        # cleanly instead of surfacing a traceback.
        return
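
# ----------------------------------------------------------
# Optional smoke test (illustrative addition, not part of the
# Gradio UI). A minimal sketch that drains the streaming
# generator once and returns the final accumulated reply; it
# assumes the model above loaded successfully. The function is
# defined for local debugging and is never called by the app.
# ----------------------------------------------------------
def smoke_test(prompt: str = "Hello!", max_new_tokens: int = 32) -> str:
    """Run one generation pass outside the UI and return the full reply."""
    reply = ""
    for partial in generate(prompt, chat_history=[], max_new_tokens=max_new_tokens):
        reply = partial  # each yield is the text accumulated so far
    return reply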

############################################################
# CREATE_INTERFACE function returning a gr.Blocks
############################################################
def create_interface() -> gr.Blocks:
    """
    Build a custom Blocks interface containing:
      - A Chatbot with user/bot icons
      - A ChatInterface that uses the chatbot
      - Custom example suggestions with special styling
    """
    gemma_css = """
    :root {
      --gradient-start: #66AEEF; /* lighter top */
      --gradient-end: #F0F8FF;   /* very light at bottom */
    }

    /* Overall page & container background gradient */
    html, body, .gradio-container {
      margin: 0;
      padding: 0;
      background: linear-gradient(to bottom, var(--gradient-start), var(--gradient-end));
      font-family: "Helvetica", sans-serif;
      color: #333; /* dark gray for better contrast */
    }

    /* Make anchor (link) text a clearly visible dark blue */
    a, a:visited {
      color: #02497A !important;
      text-decoration: underline;
    }

    /* Center the top headings in the description */
    .gradio-container h1 {
      margin-top: 0.8em;
      margin-bottom: 0.5em;
      text-align: center;
      color: #fff; /* white text on the top gradient for pop */
    }

    /* Chat container background: a very light blue, distinct from the outer gradient */
    .chatbot, .chatbot .wrap, .chat-interface, .chat-interface .wrap {
      background-color: #F8FDFF !important;
    }

    /* Remove harsh frames around chat messages */
    .chatbot .message, .chat-message {
      border: none !important;
      position: relative;
    }

    /* Icons for user and bot messages (Chatbot) */
    .chatbot .user .chat-avatar {
      background: url('user.png') center center no-repeat;
      background-size: cover;
    }
    .chatbot .bot .chat-avatar {
      background: url('gemma.png') center center no-repeat;
      background-size: cover;
    }

    /* Icons for user and bot messages (ChatInterface) */
    .chat-message.user::before {
      content: '';
      display: inline-block;
      background: url('user.png') center center no-repeat;
      background-size: cover;
      width: 24px;
      height: 24px;
      margin-right: 8px;
      vertical-align: middle;
    }
    .chat-message.bot::before {
      content: '';
      display: inline-block;
      background: url('gemma.png') center center no-repeat;
      background-size: cover;
      width: 24px;
      height: 24px;
      margin-right: 8px;
      vertical-align: middle;
    }

    /* Chat bubbles (ChatInterface) */
    .chat-message.user {
      background-color: #0284C7 !important;
      color: #FFFFFF !important;
      border-radius: 8px;
      padding: 8px 12px;
      margin: 6px 0;
    }
    .chat-message.bot {
      background-color: #EFF8FF !important;
      color: #333 !important;
      border-radius: 8px;
      padding: 8px 12px;
      margin: 6px 0;
    }

    /* Chat input area */
    .chat-input textarea {
      background-color: #FFFFFF;
      color: #333;
      border: 1px solid #66AEEF;
      border-radius: 6px;
      padding: 8px;
    }

    /* Sliders & other controls */
    form.sliders input[type="range"] {
      accent-color: #66AEEF;
    }
    form.sliders label {
      color: #333;
    }
    .gradio-button, .chat-send-btn {
      background-color: #0284C7 !important;
      color: #FFFFFF !important;
      border-radius: 5px;
      border: none;
      cursor: pointer;
    }
    .gradio-button:hover, .chat-send-btn:hover {
      background-color: #026FA6 !important;
    }

    /* Style the example "pill" buttons (ChatInterface) */
    .gr-examples {
      display: flex !important;
      flex-wrap: wrap;
      gap: 16px;
      justify-content: center;
      margin-bottom: 1em !important;
    }
    .gr-examples button.example {
      background-color: #EFF8FF !important;
      border: 1px solid #66AEEF !important;
      border-radius: 8px !important;
      color: #333 !important;
      padding: 10px 16px !important;
      cursor: pointer !important;
      transition: background-color 0.2s !important;
    }
    .gr-examples button.example:hover {
      background-color: #E0F2FF !important;
    }

    /* Additional spacing / small tweaks */
    #duplicate-button {
      margin: auto;
      background: #1565c0;
      border-radius: 100vh;
      color: #fff;
    }
    """

    with gr.Blocks(css=gemma_css) as app:
        # A heading or custom markdown
        gr.Markdown(DESCRIPTION)

        # A custom Gradio Chatbot (useful if you want both Chatbot and ChatInterface)
        chatbot = gr.Chatbot(
            label="Gemma Chat (Blocks-based)",
            type="messages",  # history as {"role", "content"} dicts, matching generate()
            avatar_images=("user.png", "gemma.png"),
            height=450,
            show_copy_button=True,
        )

        # A ChatInterface that drives the generate function and reuses the
        # same chatbot component. Note: the custom CSS is already applied on
        # the outer gr.Blocks above; passing it again here would be redundant.
        gr.ChatInterface(
            fn=generate,
            chatbot=chatbot,  # link the Chatbot to the ChatInterface
            type="messages",
            description="Gemma 2",
            additional_inputs=[
                gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                ),
                gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=0.6,
                ),
                gr.Slider(
                    label="Top-p (nucleus sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                ),
                gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=50,
                ),
                gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.2,
                ),
            ],
            examples=[
                ["Hello there! How are you doing?"],
                ["Can you explain briefly to me what is the Python programming language?"],
                ["Explain the plot of Cinderella in a sentence."],
                ["How many hours does it take a man to eat a Helicopter?"],
                ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
            ],
            cache_examples=False,
            fill_height=True,
        )

    return app


############################################################
# Main script entry
############################################################
def main():
    demo = create_interface()
    # Launch the app with a request queue for concurrency/streaming
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860, debug=True)


if __name__ == "__main__":
    main()
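
# ----------------------------------------------------------
# Notes (editor's assumptions, not stated in the original script):
# - Required packages: gradio, spaces, torch, and transformers;
#   accelerate is also needed because device_map="auto" relies on it.
# - The avatar assets user.png and gemma.png must sit next to this
#   file for the Chatbot icons and the CSS background rules to resolve.
# ----------------------------------------------------------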