Spaces:
Running
Running
import os | |
import logging | |
import gradio as gr | |
from typing import Iterator | |
from gateway import request_generation | |
# Setup logging
logging.basicConfig(level=logging.INFO)


def _require_env(name: str) -> str:
    """Return the value of a required environment variable.

    Args:
        name: name of the environment variable to read.

    Returns:
        The variable's value (guaranteed non-empty).

    Raises:
        EnvironmentError: if the variable is unset or set to an empty string.
    """
    value = os.getenv(name)
    if not value:
        raise EnvironmentError(f"{name} is not set.")
    return value


# Validate environment variables at import time so a misconfigured Space
# fails immediately instead of on the first chat request.
CLOUD_GATEWAY_API = _require_env("API_ENDPOINT")
MODEL_NAME: str = _require_env("MODEL_NAME")
# Only the presence of the key is checked here; its validity is not verified.
API_KEY = _require_env("API_KEY")
# Create the request header once and reuse it for every request.
HEADER = {"x-api-key": API_KEY}
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to the backend and stream the accumulated reply to the UI.

    Each text chunk produced by ``request_generation`` is appended to what has
    been received so far, and the full accumulated text is yielded after every
    chunk so the chat window always shows the complete reply.

    Args:
        message (str): input message from the user.
        chat_history (list): chat history of the session as supplied by
            ``gr.ChatInterface``. With ``type="messages"`` this is a list of
            ``{"role": ..., "content": ...}`` dicts. It is currently NOT
            forwarded to the backend.
        system_prompt (str): system prompt sent along with the message.
        temperature (float, optional): sampling temperature used to modulate
            the next-token probabilities. Defaults to 0.6.
        frequency_penalty (float, optional): frequency penalty forwarded to
            the backend. Defaults to 0.0.
        presence_penalty (float, optional): presence penalty forwarded to the
            backend. Defaults to 0.0.

    Yields:
        str: the reply text accumulated so far.
    """
    # NOTE(review): chat_history is not passed to request_generation, so each
    # turn is answered without prior conversation context. Confirm whether
    # the gateway accepts a history argument and forward it if so.
    outputs = []
    for text in request_generation(
        header=HEADER,
        message=message,
        system_prompt=system_prompt,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        outputs.append(text)
        # Yield the whole reply so far, not just the new chunk.
        yield "".join(outputs)
# Markdown text passed as the `description` of the chat interface below.
description = (
    "\n"
    "This Space is an Alpha release that demonstrates the "
    "[Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B) model running "
    "on AMD MI300 infrastructure. The space is built with Qwen 3 "
    "[License](https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE). "
    "Feel free to play with it!\n"
)
# Default text pre-filled in the "System prompt" textbox.
_DEFAULT_SYSTEM_PROMPT = (
    "You are a highly capable AI assistant. Provide accurate, concise, and "
    "fact-based responses that are directly relevant to the user's query. "
    "Avoid speculation, ensure logical consistency, and maintain clarity in "
    "longer outputs. Keep answers well-structured and under 1200 tokens "
    "unless explicitly requested otherwise."
)

# Sample prompts offered as clickable examples; each becomes a one-item row.
_EXAMPLE_PROMPTS = (
    "Plan a three-day trip to Washington DC for Cherry Blossom Festival.",
    "Compose a short, joyful musical piece for kids celebrating spring sunshine and blossom.",
    "Can you explain briefly to me what is the Python programming language?",
    "Explain the plot of Cinderella in a sentence.",
    "How many hours does it take a man to eat a Helicopter?",
    "Write a 100-word article on 'Benefits of Open-Source in AI research'.",
)


def _penalty_slider(label: str) -> "gr.Slider":
    """Build a penalty slider spanning [-2.0, 2.0] in 0.1 steps, starting at 0.0."""
    return gr.Slider(label=label, minimum=-2.0, maximum=2.0, step=0.1, value=0.0)


# Chat UI wiring: the additional inputs are passed to `generate` positionally
# after (message, history), so their order must match its signature.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=2, allow_tags=True),
    stop_btn=None,
    additional_inputs=[
        gr.Textbox(label="System prompt", value=_DEFAULT_SYSTEM_PROMPT, lines=3),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.3),
        _penalty_slider("Frequency penalty"),
        _penalty_slider("Presence penalty"),
    ],
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=False,
    title="Qwen3-30B-A3B",
    description=description,
)
if __name__ == "__main__": | |
demo.queue( | |
max_size=int(os.getenv("QUEUE")), | |
default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")), | |
).launch() | |