File size: 2,010 Bytes
cb44543
0cdd743
65b4243
cb44543
0cdd743
cb44543
 
 
0cdd743
aff055d
0cdd743
 
 
5d74cf2
0cdd743
 
 
 
90aea8c
 
 
5d74cf2
7976834
90aea8c
7976834
c3e3be2
 
 
 
 
 
d5448bc
7976834
 
 
5d74cf2
7976834
0cdd743
cb44543
aff055d
cb44543
 
 
 
0cdd743
 
4a57f52
cb44543
aff055d
 
 
 
 
 
5d74cf2
 
 
 
 
 
 
cb44543
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

# One chat-model client, built once at import time and shared by every
# call to respond().
model = init_chat_model(
    model="gemini-2.0-flash",
    model_provider="google_genai",
)


def respond(
    user_input: str,
    dialog_history: list[dict],
    system_message: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """
    Generate the assistant's reply for one chat turn.

    Called by the Gradio ChatInterface with the new user message and the
    prior conversation, followed by the values of the ``additional_inputs``
    components (system message, max tokens, temperature, top-p).

    Args:
        user_input: The message the user just submitted.
        dialog_history: Earlier turns as dicts with "role" and "content"
            keys. Only "user" and "assistant" turns are forwarded to the
            model; turns with any other role are ignored.
        system_message: Instruction sent as a system message. Skipped when
            empty or whitespace-only.
        max_new_tokens: Upper bound on generated tokens; passed to the model
            as ``max_output_tokens``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The ``content`` of the model's response message.
    """
    # Apply the sampling settings to a per-call copy rather than assigning
    # them onto the shared module-level ``model``. Mutating the global lets
    # one request's settings leak into another request served concurrently.
    request_model = model.model_copy(
        update={
            "temperature": temperature,
            # Slider values may arrive as floats; the token limit is a count.
            "max_output_tokens": int(max_new_tokens),
            "top_p": top_p,
        }
    )

    model_input = []
    # An empty system prompt (e.g. the textbox was cleared) is omitted
    # instead of being sent as an empty message.
    if system_message and system_message.strip():
        model_input.append(SystemMessage(content=system_message))

    # Map Gradio "messages"-format roles onto LangChain message classes.
    role_to_message = {"user": HumanMessage, "assistant": AIMessage}
    for turn in dialog_history or []:
        message_cls = role_to_message.get(turn["role"])
        if message_cls is not None:
            model_input.append(message_cls(content=turn["content"]))

    model_input.append(HumanMessage(content=user_input))

    response = request_model.invoke(model_input)
    return response.content


# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface
demo = gr.ChatInterface(
    fn=respond,
    # History is delivered as a list of {"role": ..., "content": ...} dicts,
    # which is the shape respond() reads.
    type="messages",
    # save_history=True,
    # The values of these components are passed to respond() after the user
    # message and history, in this order: system message, max new tokens,
    # temperature, top-p.
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.",
                   label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512,
                  step=1, label="Max new tokens"),
        # Gemini models accept temperatures in [0.0, 2.0]; larger values are
        # rejected by the API, so the slider is capped at 2.0.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7,
                  step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


def main() -> None:
    """Start the Gradio server hosting the chat demo."""
    demo.launch()


if __name__ == "__main__":
    main()