import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Constants
MODEL_NAME = "deepseek-ai/deepseek-math-7b-base"
SAVE_PATH = "finetuned_deepseek_math"
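
# SAVE_PATH is assumed to point at a LoRA adapter directory produced by an
# earlier fine-tuning run and saved with model.save_pretrained(SAVE_PATH);
# it should contain adapter_config.json plus the adapter weight file.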

def load_model():
    # 4-bit quantization configuration (currently commented out)
    # bnb_config = BitsAndBytesConfig(
    #     load_in_4bit=True,
    #     bnb_4bit_quant_type="nf4",
    #     bnb_4bit_use_double_quant=True,
    #     bnb_4bit_compute_dtype=torch.bfloat16
    # )

    # Load tokenizer and base model (4-bit quantization disabled above)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="offload"
        # quantization_config=bnb_config
    )
    model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Prepare model for k-bit training and wrap with LoRA via PEFT
    model = prepare_model_for_kbit_training(model)
    lora_config = LoraConfig(
        r=20,
        lora_alpha=40,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, lora_config)
    # Load the fine-tuned adapter weights into the "default" adapter created above
    model.load_adapter(SAVE_PATH, adapter_name="default")
    model.eval()
    return tokenizer, model
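
# Note: to load the base model in 4-bit instead of bfloat16, uncomment
# bnb_config in load_model() and pass quantization_config=bnb_config to
# from_pretrained(); this assumes bitsandbytes is installed and a CUDA GPU
# is available.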

def generate_output(prompt, tokenizer, model):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=20,
            generation_config=model.generation_config
        )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return result

st.title("Deepseek Math Fine-Tuned Model Inference")
st.write("Enter your prompt below:")

# Cache the model in Streamlit's session state so it's loaded only once.
if "tokenizer" not in st.session_state or "model" not in st.session_state:
    st.session_state.tokenizer, st.session_state.model = load_model()

user_input = st.text_input("Prompt", "π + π + π + π = 20 → π =")

if st.button("Generate Output"):
    with st.spinner("Generating answer..."):
        # Use the cached model from session state
        tokenizer = st.session_state.tokenizer
        model = st.session_state.model
        output = generate_output(user_input, tokenizer, model)
    st.success("Output generated!")
    st.write("**Input:**", user_input)
    st.write("**Output:**", output)