# NOTE(review): the following lines are Hugging Face web-UI artifacts captured
# when this file was copied from a blob page — not part of the program.
# Commented out so the module is importable; original text preserved below.
# seawolf2357's picture
# Update app.py
# 0e6d23a verified
# raw
# history blame
# 3.27 kB
import gradio as gr
import requests
import os
import json
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
# Sentence-embedding model used for semantic similarity search.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Dataset configurations to pull from the medical-QA collection.
# (The original listed (display_name, config) pairs whose members were
# identical, so a flat list of config names is equivalent.)
_DATASET_CONFIGS = [
    "all-processed",
    "chatdoctor-icliniq",
    "chatdoctor_healthcaremagic",
]

# Map each config name to its loaded DatasetDict.
all_datasets = {
    config: load_dataset("lavita/medical-qa-datasets", config)
    for config in _DATASET_CONFIGS
}
def find_most_similar_data(query):
    """Return the dataset entry most similar to *query*, or None.

    Scans every item of every split in ``all_datasets``, embeds each
    item's question/answer text with the sentence-transformer ``model``,
    and returns the combined text with the highest cosine similarity to
    the query embedding. Items missing either field are skipped.

    Item embeddings are cached across calls: the original re-encoded the
    entire corpus on every query, which dominated per-request latency.
    Encoding is deterministic, so caching does not change the result.
    """
    query_embedding = model.encode(query, convert_to_tensor=True)
    cache = find_most_similar_data._embedding_cache
    most_similar = None
    highest_similarity = -1
    for dataset in all_datasets.values():
        for split in dataset.keys():
            for item in dataset[split]:
                # Guard clause: only items with both fields participate.
                if 'question' not in item or 'answer' not in item:
                    continue
                # NOTE(review): the label text below appears to be
                # mojibake-garbled Korean for "question:"/"answer:".
                # Preserved byte-for-byte — changing it would change the
                # embeddings and therefore the ranking.
                item_text = f"์งˆ๋ฌธ: {item['question']} ๋‹ต๋ณ€: {item['answer']}"
                item_embedding = cache.get(item_text)
                if item_embedding is None:
                    item_embedding = model.encode(item_text, convert_to_tensor=True)
                    cache[item_text] = item_embedding
                similarity = util.pytorch_cos_sim(query_embedding, item_embedding).item()
                if similarity > highest_similarity:
                    highest_similarity = similarity
                    most_similar = item_text
    return most_similar

# Per-process cache of item embeddings (item text -> tensor).
find_most_similar_data._embedding_cache = {}
def respond_with_prefix(message, history, max_tokens=10000, temperature=0.7, top_p=0.95):
    """Stream a model reply for *message* via the local generation API.

    The prompt is the system prefix + the user message, augmented with
    the most similar dataset entry (if any). Yields progressively longer
    strings, following Gradio's streaming-chat convention.

    *history* is required by the ChatInterface signature but unused.
    """
    # Placeholder system prefix (runtime text preserved byte-for-byte;
    # it reads as a "insert the original prefix here" stub).
    system_prefix = """
์—ฌ๊ธฐ์— ์›๋ž˜ ์ฝ”๋“œ์˜ ์‹œ์Šคํ…œ ํ”„๋ฆฌํ”ฝ์Šค๋ฅผ ์‚ฝ์ž…ํ•˜์„ธ์š”.
"""
    # Prepend the system prefix to the user's message.
    modified_message = system_prefix + message

    # Augment the prompt with the most similar Q/A entry, when one exists.
    similar_data = find_most_similar_data(message)
    if similar_data:
        modified_message += "\n\n" + similar_data

    data = {
        "model": "jinjavis:latest",
        "prompt": modified_message,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p
    }

    # Stream the generation. `with` guarantees the HTTP connection is
    # released even if the consumer abandons the generator mid-stream
    # (the original never closed the streamed response).
    with requests.post("http://hugpu.ai:7877/api/generate", json=data, stream=True) as response:
        partial_message = ""
        for line in response.iter_lines():
            if not line:
                continue
            try:
                result = json.loads(line)
            except json.JSONDecodeError as e:
                # Best-effort: report the bad line and keep consuming the
                # stream, matching the original behavior.
                print(f"Failed to decode JSON: {e}")
                yield "An error occurred while processing your request."
                continue
            if result.get("done", False):
                break
            partial_message += result.get('response', '')
            yield partial_message
# Chat UI: streams respond_with_prefix, exposing the generation knobs
# (max tokens / temperature / top-p) as extra slider inputs.
demo = gr.ChatInterface(
    fn=respond_with_prefix,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=120000, value=4000, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-P"),
    ],
    theme="Nymbo/Nymbo_Theme",
)

if __name__ == "__main__":
    # Bound the request queue so concurrent users cannot pile up unbounded.
    demo.queue(max_size=4).launch()