import os
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
from queue import Queue, Empty
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
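
# Gated Llama 3.1 checkpoint; the Hugging Face access token is read from the
# MY_API_LLAMA_3_1 environment variable.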
model_id = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("MY_API_LLAMA_3_1"))
model = None
model_load_queue = Queue()
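
# Load the model lazily in a background thread; the loaded model (or None on
# failure) is handed back to the caller through model_load_queue.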
def load_model():
    global model
    try:
        if model is None:
            logger.info("Loading model...")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=os.environ.get("MY_API_LLAMA_3_1"),
                torch_dtype=torch.bfloat16,
                device_map="auto",
                low_cpu_mem_usage=True,
                load_in_8bit=True
            )
            logger.info("Model loaded successfully")
        model_load_queue.put(model)
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        model_load_queue.put(None)
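
# Request a GPU for up to 120 s (Hugging Face Spaces GPU decorator), load the
# model on first use, and stream the generated text token by token.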
@spaces.GPU(duration=120)
def generate_response(chat, kwargs):
    global model
    try:
        if model is None:
            logger.info("Starting model loading thread")
            Thread(target=load_model).start()
            model = model_load_queue.get(timeout=120)
            if model is None:
                return "Failed to load the model. Please try again later."
        logger.info("Preparing input for generation")
        inputs = tokenizer(chat, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)
        if 'seed' in kwargs:
            del kwargs['seed']  # drop 'seed' if present; it is not a valid generate() argument
        generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
        logger.info("Starting generation thread")
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Collect the streamed text; stop early if the end-of-sequence marker appears
        output = ""
        try:
            for new_text in streamer:
                output += new_text
                if output.endswith("</s>"):
                    output = output[:-4]
                    break
        except Empty:
            logger.warning("Timeout occurred during generation")
        logger.info("Generation completed")
        return output
    except Exception as e:
        logger.error(f"Error in generate_response: {str(e)}")
        return f"An error occurred: {str(e)}"
def function(prompt, history=None):
    history = history or []
    chat = "<s>"
    for user_prompt, bot_response in history:
        chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> <s>"
    chat += f"[INST] {prompt} [/INST]"
    kwargs = dict(
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.5,
        top_p=0.95,
        repetition_penalty=1.0
    )
    return generate_response(chat, kwargs)
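
# Wire the generation function into a Gradio chat UI.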
interface = gr.ChatInterface(
    fn=function,
    chatbot=gr.Chatbot(
        avatar_images=None,
        container=False,
        show_copy_button=True,
        layout='bubble',
        render_markdown=True,
        line_breaks=True
    ),
    css='h1 {font-size:22px;} h2 {font-size:20px;} h3 {font-size:18px;} h4 {font-size:16px;}',
    autofocus=True,
    fill_height=True,
    analytics_enabled=False,
    submit_btn='Chat',
    stop_btn=None,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None
)
interface.launch(show_api=True, share=True)