File size: 1,919 Bytes
b39d091
072eb93
b283808
90c5c05
072eb93
 
 
 
 
b39d091
 
072eb93
b39d091
072eb93
 
 
b39d091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
072eb93
 
b39d091
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, StopStringCriteria, StoppingCriteriaList
import torch

# Load the tokenizer and model (trust_remote_code is required: Hymba ships
# custom modeling code in the HF repo).
repo_name = "nvidia/Hymba-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)

# Pick the best available device instead of assuming CUDA; cast to float16
# only on GPU (fp16 on CPU is slow and poorly supported).
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    model = model.to(device).to(torch.float16)
else:
    model = model.to(device)

# Conversation history, shared module-level state (single-user demo only).
messages = [
    {"role": "system", "content": "You are a helpful assistant."}
]

# Stop generation as soon as the end-of-sequence string is produced.
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["</s>"])])

# Chat function for Gradio interface
def chat_function(user_input):
    """Append *user_input* to the conversation, generate a reply, and return it.

    Uses the module-level ``messages``, ``tokenizer``, ``model`` and
    ``stopping_criteria`` globals, so the history is shared by every client
    of this demo.
    """
    # Add user message to the conversation history
    messages.append({"role": "user", "content": user_input})

    # Render the history through the model's chat template. A plain
    # tokenizer() call cannot consume a list of role/content dicts — it
    # expects strings — so apply_chat_template is required here.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Greedy decoding: temperature is meaningless (and warns) when
    # do_sample=False, so it is deliberately not passed.
    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=False,
        use_cache=True,
        stopping_criteria=stopping_criteria
    )

    # Decode only the newly generated tokens; outputs[0] also contains the
    # echoed prompt, which must not leak into the reply or the history.
    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)

    # Record the assistant turn so later turns see the full conversation.
    messages.append({"role": "assistant", "content": response})

    return response

# Set up the Gradio interface. chat_function returns a plain string, so the
# output component is a Textbox (the old gr.inputs / gr.outputs namespaces
# were removed in Gradio 3+). live=True is deliberately omitted: it would
# trigger a full model generation on every keystroke.
iface = gr.Interface(
    fn=chat_function,
    inputs=gr.Textbox(label="Your message", placeholder="Enter your message here..."),
    outputs=gr.Textbox(label="Response"),
    title="Hymba Chatbot",
    description="Chat with the Hymba-1.5B-Instruct model!"
)

# Launch the Gradio interface
iface.launch()