# my-api/app.py
import os
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
from queue import Queue, Empty
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Gated model: the Hugging Face access token is read from the MY_API_LLAMA_3_1 secret.
model_id = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("MY_API_LLAMA_3_1"))

# The model is loaded lazily in a background thread; the queue hands the
# loaded model (or None on failure) back to the request handler.
model = None
model_load_queue = Queue()
def load_model():
    """Load the model once and hand it back to the waiting request handler via the queue."""
    global model
    try:
        if model is None:
            logger.info("Loading model...")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=os.environ.get("MY_API_LLAMA_3_1"),
                torch_dtype=torch.bfloat16,
                device_map="auto",
                low_cpu_mem_usage=True,
                load_in_8bit=True
            )
            logger.info("Model loaded successfully")
        model_load_queue.put(model)
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        # Signal failure so the waiting request handler does not block forever.
        model_load_queue.put(None)
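
# Note (assumption): on recent transformers releases, passing load_in_8bit
# directly to from_pretrained() is deprecated in favour of a quantization
# config object. A minimal sketch of the equivalent call:
#
#     from transformers import BitsAndBytesConfig
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         token=os.environ.get("MY_API_LLAMA_3_1"),
#         torch_dtype=torch.bfloat16,
#         device_map="auto",
#         low_cpu_mem_usage=True,
#         quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#     )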

@spaces.GPU(duration=120)  # ZeroGPU: request a GPU for up to 120 s per call
def generate_response(chat, kwargs):
    global model
    try:
        # Lazily load the model on the first request and wait for the loader
        # thread to report either the model or None (failure).
        if model is None:
            logger.info("Starting model loading thread")
            Thread(target=load_model).start()
            model = model_load_queue.get(timeout=120)
            if model is None:
                return "Failed to load the model. Please try again later."
        logger.info("Preparing input for generation")
        inputs = tokenizer(chat, return_tensors="pt").to(model.device)
        # The streamer yields decoded text chunks as generate() produces tokens.
        streamer = TextIteratorStreamer(tokenizer, timeout=120.0, skip_prompt=True, skip_special_tokens=True)
        # 'seed' is not a valid generate() argument, so drop it if present.
        kwargs.pop('seed', None)
        generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
        logger.info("Starting generation thread")
        # Run generate() in a worker thread so this thread can drain the streamer.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        output = ""
        try:
            for new_text in streamer:
                output += new_text
                # Strip a trailing end-of-sequence marker if the model emits one.
                if output.endswith("</s>"):
                    output = output[:-4]
                    break
        except Empty:
            logger.warning("Timeout occurred during generation")
        logger.info("Generation completed")
        return output
    except Exception as e:
        logger.error(f"Error in generate_response: {str(e)}")
        return f"An error occurred: {str(e)}"

def function(prompt, history=None):
    # Rebuild the conversation in [INST] ... [/INST] markup. Note that this is
    # the Llama 2 / Mistral chat format; the Llama 3.1 base checkpoint ships no
    # chat template, so these tags act only as plain-text separators here.
    chat = "<s>"
    for user_prompt, bot_response in (history or []):
        chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> <s>"
    chat += f"[INST] {prompt} [/INST]"
    kwargs = dict(
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.5,
        top_p=0.95,
        repetition_penalty=1.0
    )
    return generate_response(chat, kwargs)
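
# Sketch (assumption): if the Instruct variant were used instead
# ("meta-llama/Meta-Llama-3.1-8B-Instruct"), its tokenizer ships a proper
# chat template and the prompt could be built like this:
#
#     def build_chat(prompt, history):
#         messages = []
#         for user_prompt, bot_response in history or []:
#             messages.append({"role": "user", "content": user_prompt})
#             messages.append({"role": "assistant", "content": bot_response})
#         messages.append({"role": "user", "content": prompt})
#         return tokenizer.apply_chat_template(
#             messages, tokenize=False, add_generation_prompt=True
#         )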

# Note: the submit_btn/stop_btn/retry_btn/undo_btn/clear_btn keyword arguments
# follow the Gradio 4.x ChatInterface API.
interface = gr.ChatInterface(
    fn=function,
    chatbot=gr.Chatbot(
        avatar_images=None,
        container=False,
        show_copy_button=True,
        layout='bubble',
        render_markdown=True,
        line_breaks=True
    ),
    css='h1 {font-size:22px;} h2 {font-size:20px;} h3 {font-size:18px;} h4 {font-size:16px;}',
    autofocus=True,
    fill_height=True,
    analytics_enabled=False,
    submit_btn='Chat',
    stop_btn=None,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None
)

interface.launch(show_api=True, share=True)
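
# Sketch (assumption): with launch(show_api=True), the ChatInterface exposes a
# /chat endpoint that can be called programmatically via gradio_client, e.g.:
#
#     from gradio_client import Client
#     client = Client("wiklif/my-api")  # hypothetical Space id
#     print(client.predict("Hello!", api_name="/chat"))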