abooze committed on
Commit
da5aaca
·
verified ·
1 Parent(s): 360329c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Hugging Face Hub repository of the fine-tuned chat model served by this app.
model_name = "abooze/ft-deepseek-llm-7b-chat-dpo-pairs"

# Page header.
st.title("💬 RealMind AI")
st.markdown("Chat with RealMind AI!")
10
@st.cache_resource
def load_model():
    """Load tokenizer, model, and generation config once per server process.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of reloading the checkpoint on every interaction.

    Returns:
        A ``(tokenizer, model, generation_config)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )
    cfg = GenerationConfig.from_pretrained(model_name)
    # Presumably the checkpoint ships without an explicit pad token, so EOS is
    # reused for padding — TODO confirm against the model's tokenizer config.
    cfg.pad_token_id = cfg.eos_token_id
    return tok, lm, cfg


tokenizer, model, gen_config = load_model()
24
+
25
+ # Session state to hold chat history
26
+ # if "messages" not in st.session_state:
27
+ # st.session_state.messages = [
28
+ # {"role": "system", "content": "You are a helpful assistant."}
29
+ # ]
30
+
31
+ # # Display chat history
32
+ # for msg in st.session_state.messages:
33
+ # if msg["role"] != "system":
34
+ # st.chat_message(msg["role"]).write(msg["content"])
35
+
36
+ # Chat input
37
+ user_input = st.chat_input("Ask something...")
38
+ if user_input:
39
+ # Add user input to message history
40
+ st.session_state.messages = [{"role": "user", "content": user_input}]
41
+ st.chat_message("user").write(user_input)
42
+
43
+ with st.spinner("Generating response..."):
44
+ input_ids = tokenizer.apply_chat_template(
45
+ st.session_state.messages,
46
+ return_tensors="pt",
47
+ add_generation_prompt=True
48
+ ).to(model.device)
49
+
50
+ outputs = model.generate(
51
+ input_ids=input_ids,
52
+ max_new_tokens=512,
53
+ temperature=0.7,
54
+ top_p=0.95,
55
+ do_sample=True,
56
+ pad_token_id=tokenizer.eos_token_id
57
+ )
58
+
59
+ # Decode and extract response
60
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
61
+ # Extract only the assistant's last reply
62
+ assistant_reply = response.split("<|assistant|>\n")[-1].strip()
63
+
64
+ st.chat_message("assistant").write(assistant_reply)
65
+ st.session_state.messages.append({"role": "assistant", "content": assistant_reply})