import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces  # Hugging Face Spaces ZeroGPU helper (provides the @spaces.GPU decorator)


# MamayLM v0.1: a Ukrainian-focused instruction-tuned Gemma-2-9B model from INSAIT.
model_name = "INSAIT-Institute/MamayLM-Gemma-2-9B-IT-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load in fp16 and let accelerate place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
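
# Optional (a sketch, not verified on this Space): on smaller GPUs the same
# checkpoint could be loaded in 4-bit via bitsandbytes, e.g.:
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#   )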

@spaces.GPU  # run this function on a GPU when hosted on a ZeroGPU Space
def respond(message, chat_history, system_message, max_new_tokens, temperature, top_p):
    # Rebuild the whole conversation as one plain-text prompt.
    prompt = f"{system_message.strip()}\n"
    for user, bot in chat_history:
        prompt += f"User: {user}\nAssistant: {bot}\n"
    prompt += f"User: {message}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        eos_token_id=tokenizer.eos_token_id,
    )
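    # Alternative (sketch): slice the prompt tokens off the output instead of
    # string-splitting the decoded text below:
    #   new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    #   response = tokenizer.decode(new_tokens, skip_special_tokens=True)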
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)

    # The decoded text contains the full prompt; keep only the newest
    # assistant turn and drop anything generated past the next "User:".
    response = decoded.split("Assistant:")[-1].strip().split("User:")[0].strip()
    return response
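
# Note (untested sketch): Gemma-2 tokenizers ship a chat template, so the manual
# "User:/Assistant:" prompt above could instead be built with
# tokenizer.apply_chat_template. Gemma templates reject a separate "system"
# role, so the system message would be folded into the first user turn:
#   messages = [{"role": "user", "content": f"{system_message}\n{message}"}]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt"
#   ).to(model.device)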

def clear_fn():
    # Not wired to the UI yet; a hook for clearing state if a button is added.
    return None

chat = gr.ChatInterface(
    respond,
    additional_inputs=[
        # System prompt (Ukrainian): "You are a language model that is fluent in Ukrainian."
        gr.Textbox(value="Ти мовна модель, яка добре володіє українською.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Привіт!"],  # "Hello!"
        ["Хто такий Пес Патрон?"],  # "Who is Patron the Dog?"
    ],
    title="💬 Chat with MamayLM",
    description="A multi-turn chat interface for MamayLM-v0.1-9B with configurable parameters.",
    theme="soft"
)

chat.launch()