import os
import socket

import gradio as gr
from openai import OpenAI
# Log the host name and IP address on startup (handy when debugging the deployment).
hostname = socket.gethostname()
ip_addr = socket.gethostbyname(hostname)
print("Your Computer Name is: " + hostname)
print("Your Computer IP Address is: " + ip_addr)
DESCRIPTION = """
# Cloned from MediaTek Research Breeze-7B
MediaTek Research Breeze-7B (hereinafter referred to as Breeze-7B) is a language model family that builds on top of [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1), specifically intended for Traditional Chinese use.
[Breeze-7B-Base](https://huggingface.co/MediaTek-Research/Breeze-7B-Base-v1_0) is the base model for the Breeze-7B series.
It is suitable for use if you have substantial fine-tuning data to tune it for your specific use case.
[Breeze-7B-Instruct](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-v1_0) derives from the base model Breeze-7B-Base, making the resulting model amenable to be used as-is for commonly seen tasks.
This App is cloned from [Demo-MR-Breeze-7B](https://huggingface.co/spaces/MediaTek-Research/Demo-MR-Breeze-7B)
"""
LICENSE = """
"""
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."
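# The demo talks to an OpenAI-compatible endpoint; its URL, bearer token, and
# model name are supplied through environment variables.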
API_URL = os.environ.get("API_URL")
TOKEN = os.environ.get("TOKEN")
MODEL_NAME = os.environ.get("MODEL_NAME")

# Constants kept from the upstream demo; they are not referenced elsewhere in this app.
TOKENIZER_REPO = "MediaTek-Research/Breeze-7B-Instruct-v1_0"
HEADERS = {
    "Authorization": f"Bearer {TOKEN}",
    "Content-Type": "application/json",
    "accept": "application/json",
}
MAX_SEC = 30
MAX_INPUT_LENGTH = 5000

client = OpenAI(
    base_url=API_URL,
    api_key=TOKEN,
)
def chat_with_openai(client, model_name, messages, temperature=0.5, max_tokens=1024, top_p=0.5):
    """Stream a chat completion from the OpenAI-compatible endpoint, yielding text deltas."""
    chat_completion = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )
    for chunk in chat_completion:
        delta = chunk.choices[0].delta.content
        if delta:  # the final streamed chunk carries content=None
            yield delta
def refusal_condition(query):
    # Refuse queries about cross-strait political topics.
    query_remove_space = query.replace(' ', '').lower()
    is_including_tw = any(
        x in query_remove_space
        for x in ['台灣', '台湾', 'taiwan', 'tw', '中華民國', '中华民国'])
    is_including_cn = any(
        x in query_remove_space
        for x in ['中國', '中国', 'cn', 'china', '大陸', '內地', '大陆', '内地', '中華人民共和國', '中华人民共和国'])
    if is_including_tw and is_including_cn:
        return True
    return any(
        x in query_remove_space
        for x in ['一個中國', '兩岸', '一中原則', '一中政策', '一个中国', '两岸', '一中原则'])
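
# Build the Gradio chat interface.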
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    system_prompt = gr.Textbox(label='System prompt',
                               value=DEFAULT_SYSTEM_PROMPT,
                               lines=1)

    with gr.Accordion(label='Advanced options', open=False):
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=32,
            maximum=2048,
            step=1,
            value=1024,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            value=0.01,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.01,
        )

    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True)
    with gr.Row():
        msg = gr.Textbox(
            container=False,
            show_label=False,
            placeholder='Type a message...',
            scale=10,
            lines=6,
        )
        submit_button = gr.Button('Submit',
                                  variant='primary',
                                  scale=1,
                                  min_width=0)

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the last user message so Retry/Undo can restore it.
    saved_input = gr.State()
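
    # Chat callbacks: `user` echoes the submitted message into the history,
    # and `bot` streams the model's reply into the last history entry.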
    def user(user_message, history):
        return "", history + [[user_message, None]]
    def bot(history, max_new_tokens, temperature, top_p, system_prompt):
        # Rebuild the full conversation so the endpoint sees every prior turn.
        chat_data = []
        system_prompt = system_prompt.strip()
        if system_prompt:
            chat_data.append({"role": "system", "content": system_prompt})
        for user_msg, assistant_msg in history:
            chat_data.append({"role": "user", "content": user_msg if user_msg is not None else ''})
            if assistant_msg is not None:
                chat_data.append({"role": "assistant", "content": assistant_msg})

        if refusal_condition(history[-1][0]):
            # Safety refusal: replace the conversation with a canned notice asking the user to clear the chat.
            history = [['[安全拒答啟動]', '[安全拒答啟動] 請清除再開啟對話']]
            yield history
        else:
            stream = chat_with_openai(
                client,
                MODEL_NAME,
                chat_data,
                temperature,
                max_new_tokens,
                top_p)
            history[-1][1] = ''
            for delta in stream:
                history[-1][1] += delta
                yield history
            # Strip a trailing end-of-sequence token if the backend leaks one.
            if history[-1][1].endswith('</s>'):
                history[-1][1] = history[-1][1][:-4]
                yield history
            if refusal_condition(history[-1][1]):
                # The reply itself touched a refused topic: append a disclaimer.
                history[-1][1] += '\n\n**[免責聲明: 此模型並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。]**'
                yield history
        print('== Record ==\nQuery: {query}\nResponse: {response}'.format(
            query=repr(history[-1][0]), response=repr(history[-1][1])))
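
    # Pressing Enter in the textbox and clicking Submit trigger the same
    # user -> bot pipeline, so both paths behave identically.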
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        fn=bot,
        inputs=[
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            system_prompt,
        ],
        outputs=chatbot,
    )

    submit_button.click(
        user, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            system_prompt,
        ],
        outputs=chatbot,
    )
    def delete_prev_fn(
            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
        # Pop the most recent exchange; return the removed user message (or '').
        try:
            message, _ = history.pop()
        except IndexError:
            message = ''
        return history, message or ''

    def display_input(message: str,
                      history: list[tuple[str, str]]) -> list[tuple[str, str]]:
        # Re-append a user message with an empty pending reply.
        history.append((message, ''))
        return history
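
    # Retry pops the last exchange, re-displays the same user message, and
    # regenerates the reply; Undo pops it and restores the text to the input box.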
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            system_prompt,
        ],
        outputs=chatbot,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=msg,
        api_name=False,
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    gr.Markdown(LICENSE)
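
# Serve up to 10 concurrent requests through the Gradio queue.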
demo.queue(default_concurrency_limit=10)
demo.launch()