File size: 1,782 Bytes
26e7e24
 
4c921ac
26e7e24
29a7c13
4c921ac
 
f2ceb0a
29a7c13
b580227
 
c6f760b
f2ceb0a
c6f760b
29a7c13
ec2ca3d
f2ceb0a
 
ec2ca3d
 
f2ceb0a
26e7e24
317ab43
ec2ca3d
 
f2ceb0a
26e7e24
29a7c13
317ab43
f2ceb0a
 
 
 
 
29a7c13
 
317ab43
 
c6f760b
 
26e7e24
 
f2ceb0a
ec2ca3d
f2ceb0a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Pick the compute device: GPU when CUDA is available, otherwise CPU.
# Probe CUDA once and reuse the answer for every GPU-dependent setting below.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"

# Load the model: half precision with automatic device placement on GPU,
# full precision with default placement on CPU.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
if _cuda_available:
    _torch_dtype, _device_map = torch.float16, "auto"
else:
    _torch_dtype, _device_map = torch.float32, None

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=_torch_dtype,
    device_map=_device_map,
    trust_remote_code=True,
)
model.eval()  # inference only

# Load the tokenizer and reuse the end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Chatbot handler used by the Gradio interface.
def chatbot(user_input):
    """Generate a single-turn reply to ``user_input``.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        user_input: The user's message. ``None``, an empty string, or a
            whitespace-only string is treated as "no message".

    Returns:
        The model's reply as a string, without the prompt echoed back, or
        ``"Please enter a message."`` when no message was supplied.
    """
    if user_input is None or not user_input.strip():
        return "Please enter a message."

    # The checkpoint loaded by this file is a chat-tuned model, so wrap the
    # message in the tokenizer's chat template instead of feeding raw text.
    messages = [{"role": "user", "content": user_input}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Move the inputs to wherever the model's weights actually live.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # `early_stopping` is intentionally not passed: it only affects beam
        # search, and this call samples with a single beam.
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=50,
            temperature=0.6,
            top_p=0.8,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # `generate` returns prompt + continuation for decoder-only models; keep
    # only the newly generated tokens so the reply does not repeat the input.
    new_tokens = output[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Build and launch the interface without share=True.
iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="Octagon 2.0 Chatbot")

# `launch()` has no `ssl` keyword (its SSL options are `ssl_keyfile`,
# `ssl_certfile`, `ssl_keyfile_password` and `ssl_verify`), so passing
# `ssl=False` raised a TypeError before the server could start. Plain HTTP is
# already the default when no certificate is supplied.
iface.launch(debug=True)