import gradio as gr
import requests
import os
import json
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
# Load the sentence-embedding model
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
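# Note (not in the original): all-MiniLM-L6-v2 maps text to 384-dimensional
# vectors and is trained primarily on English, so retrieval quality for
# Korean queries (which this app's prefix targets) may be reduced.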
# Load the datasets
datasets = [
    ("all-processed", "all-processed"),
    ("chatdoctor-icliniq", "chatdoctor-icliniq"),
    ("chatdoctor_healthcaremagic", "chatdoctor_healthcaremagic"),
]

all_datasets = {}
for dataset_name, config in datasets:
    all_datasets[dataset_name] = load_dataset("lavita/medical-qa-datasets", config)
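# All three configs are materialized at startup; every split of every config
# is then scanned linearly on each query (see find_most_similar_data below,
# and the precomputed-embedding sketch that follows it).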
def find_most_similar_data(query):
    query_embedding = model.encode(query, convert_to_tensor=True)
    most_similar = None
    highest_similarity = -1
    # Brute-force scan: embed every question/answer pair and keep the best match
    for dataset_name, dataset in all_datasets.items():
        for split in dataset.keys():
            for item in dataset[split]:
                if 'question' in item and 'answer' in item:
                    item_text = f"Question: {item['question']} Answer: {item['answer']}"
                    item_embedding = model.encode(item_text, convert_to_tensor=True)
                    similarity = util.pytorch_cos_sim(query_embedding, item_embedding).item()
                    if similarity > highest_similarity:
                        highest_similarity = similarity
                        most_similar = item_text
    return most_similar
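# Optional sketch (not wired into the app): the loop above re-encodes every
# dataset item on each call, which is slow. Assuming the same question/answer
# schema, the corpus can be embedded once and ranked with util.semantic_search.
_corpus_texts = None
_corpus_embeddings = None

def find_most_similar_data_fast(query):
    global _corpus_texts, _corpus_embeddings
    if _corpus_embeddings is None:  # build the embedding cache on first use
        _corpus_texts = [
            f"Question: {item['question']} Answer: {item['answer']}"
            for dataset in all_datasets.values()
            for split in dataset.keys()
            for item in dataset[split]
            if 'question' in item and 'answer' in item
        ]
        _corpus_embeddings = model.encode(_corpus_texts, convert_to_tensor=True)
    query_embedding = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, _corpus_embeddings, top_k=1)[0]
    return _corpus_texts[hits[0]['corpus_id']] if hits else None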
def respond_with_prefix(message, history, max_tokens=10000, temperature=0.7, top_p=0.95):
    # Prefix logic that steers the model toward Korean-language answers
    system_prefix = """
    Insert the original code's system prefix here.
    """
    modified_message = system_prefix + message  # prepend the prefix to the user message

    # Find the most similar data in the datasets
    similar_data = find_most_similar_data(message)
    if similar_data:
        modified_message += "\n\n" + similar_data  # append the similar data as context
    data = {
        "model": "jinjavis:latest",
        "prompt": modified_message,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p
    }

    # API request (streamed)
    response = requests.post("http://hugpu.ai:7877/api/generate", json=data, stream=True)
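    # The endpoint streams newline-delimited JSON chunks: each carries a piece
    # of the completion in "response", and a final chunk sets "done". This
    # matches Ollama's /api/generate format; note that on stock Ollama the
    # sampling parameters would nest under an "options" key (with "num_predict"
    # rather than "max_tokens"), so the flat payload above is assumed to suit
    # this particular server.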
partial_message = "" | |
for line in response.iter_lines(): | |
if line: | |
try: | |
result = json.loads(line) | |
if result.get("done", False): | |
break | |
new_text = result.get('response', '') | |
partial_message += new_text | |
yield partial_message | |
except json.JSONDecodeError as e: | |
print(f"Failed to decode JSON: {e}") | |
yield "An error occurred while processing your request." | |
demo = gr.ChatInterface(
    fn=respond_with_prefix,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=120000, value=4000, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-P")
    ],
    theme="Nymbo/Nymbo_Theme"
)
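# queue(max_size=4) below caps the request queue at four pending jobs; further
# submissions are rejected with an error until a slot frees up.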
if __name__ == "__main__":
    demo.queue(max_size=4).launch()