import os
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread
from queue import Queue, Empty
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Base (non-Instruct) Llama 3.1 checkpoint; the repo is gated, so a Hugging Face
# access token is read from the MY_API_LLAMA_3_1 environment variable.
model_id = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("MY_API_LLAMA_3_1"))

# The model is loaded lazily on the first request; the loader thread hands it
# back to the request thread through this queue.
model = None
model_load_queue = Queue()

def load_model():
    """Load the model once and hand it (or None on failure) back via model_load_queue."""
    global model
    try:
        if model is None:
            logger.info("Loading model...")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=os.environ.get("MY_API_LLAMA_3_1"),
                torch_dtype=torch.bfloat16,
                device_map="auto",
                low_cpu_mem_usage=True,
                # 8-bit quantization (requires bitsandbytes); passed via
                # BitsAndBytesConfig, since the bare load_in_8bit kwarg is deprecated
                quantization_config=BitsAndBytesConfig(load_in_8bit=True)
            )
            logger.info("Model loaded successfully")
        model_load_queue.put(model)
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        model_load_queue.put(None)

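# @spaces.GPU reserves a ZeroGPU device for the decorated function, for at most
# `duration` seconds per call (Hugging Face Spaces `spaces` package).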
@spaces.GPU(duration=120)
def generate_response(chat, kwargs):
    """Stream a completion for the pre-formatted prompt `chat`, loading the model lazily."""
    global model
    try:
        if model is None:
            logger.info("Starting model loading thread")
            Thread(target=load_model).start()
            model = model_load_queue.get(timeout=120)
            if model is None:
                return "Nie udało się załadować modelu. Proszę spróbować ponownie później."

        logger.info("Preparing input for generation")
        inputs = tokenizer(chat, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)

        if 'seed' in kwargs:
            del kwargs['seed']

        generation_kwargs = dict(inputs, streamer=streamer, **kwargs)

        logger.info("Starting generation thread")
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output = ""
        try:
            for new_text in streamer:
                output += new_text
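                # "</s>" is the Llama-2 end-of-sequence marker; the Llama 3.1
                # tokenizer uses different special tokens, so this check may never
                # fire, but it is kept as a harmless safeguard.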
                if output.endswith("</s>"):
                    output = output[:-4]
                    break
        except Empty:
            logger.warning("Timeout occurred during generation")

        logger.info("Generation completed")
        return output
    except Exception as e:
        logger.error(f"Error in generate_response: {str(e)}")
        return f"Wystąpił błąd: {str(e)}"

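# NOTE: the [INST]/<s> prompt built below follows the Llama-2 chat format, which the
# base Meta-Llama-3.1-8B checkpoint was not trained on. A sketch of an alternative,
# assuming the gated -Instruct variant were used instead, is the tokenizer's own
# chat template:
#   messages = [{"role": "user", "content": prompt}]
#   chat = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)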
def function(prompt, history=None):
    # Build a Llama-2-style [INST] prompt from the chat history; see the note
    # above on the chat-template mismatch with Llama 3.1.
    chat = "<s>"
    for user_prompt, bot_response in (history or []):
        chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> <s>"
    chat += f"[INST] {prompt} [/INST]"
    kwargs = dict(
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.5,
        top_p=0.95,
        repetition_penalty=1.0
    )

    return generate_response(chat, kwargs)

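# The submit_btn/retry_btn/undo_btn/clear_btn keyword arguments below follow the
# Gradio 4.x ChatInterface API; later Gradio releases reworked these parameters.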
interface = gr.ChatInterface(
    fn=function,
    chatbot=gr.Chatbot(
        avatar_images=None,
        container=False,
        show_copy_button=True,
        layout='bubble',
        render_markdown=True,
        line_breaks=True
    ),
    css='h1 {font-size:22px;} h2 {font-size:20px;} h3 {font-size:18px;} h4 {font-size:16px;}',
    autofocus=True,
    fill_height=True,
    analytics_enabled=False,
    submit_btn='Chat',
    stop_btn=None,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None
)

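# share=True only matters when running locally; on Hugging Face Spaces the app is
# already served publicly and the flag is effectively ignored.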
interface.launch(show_api=True, share=True)