kv-press / app.py
ariG23498's picture
ariG23498 HF staff
Update app.py
7c5e5ba verified
import gradio as gr
import requests
from bs4 import BeautifulSoup
import torch
from transformers import pipeline
from kvpress import (
ExpectedAttentionPress,
KnormPress,
ObservedAttentionPress,
RandomPress,
SnapKVPress,
StreamingLLMPress,
TOVAPress,
)
import spaces
@spaces.GPU
def process_request(url, question, press_type, compression_ratio):
    """Answer a question about a web page using the KV-cache-compressed pipeline.

    The text of every ``<p>`` element on the page at ``url`` is concatenated
    and used as the context; the selected press (built with
    ``compression_ratio``) is handed to the module-level ``pipe`` together
    with the question.

    Args:
        url: Address of the article to download.
        question: Question to ask about the article.
        press_type: Key into the module-level ``press_map``.
        compression_ratio: Value passed positionally to the press constructor.

    Returns:
        A single string: the predicted answer, or a human-readable error
        message. Exactly one value is always returned because the UI binds
        this function to a single output component.
    """
    # Validate the cheap input first so no network request or GPU work is
    # done for a selection that cannot succeed.
    press_class = press_map.get(press_type)
    if press_class is None:
        return "Invalid press type selected."

    try:
        # Fetch the article; bound the wait and reject HTTP error pages so
        # an error body is never fed to the model as context.
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        soup = BeautifulSoup(response.content, "html.parser")
        article_text = "".join(p.text for p in soup.find_all("p"))
        if not article_text.strip():
            return "No paragraph text found at the given URL."
        context = article_text + "\n\n"

        # Initialize the press
        press = press_class(compression_ratio)

        # Generate prediction
        return pipe(context, question=question, press=press)["answer"]
    except Exception as e:
        # UI boundary: surface the failure in the output box instead of
        # crashing the request.
        return str(e)
# Load pipeline
# Model checkpoint, target device and attention backend used for inference.
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
device = "cuda:0"
# NOTE(review): some presses may need access to attention weights, which
# "sdpa" might not expose - confirm every entry of press_map works with it.
attn_implementation = "sdpa"

# Built once at import time and shared by every request.
# The "kv-press-text-generation" task is presumably registered by the
# kvpress import - TODO confirm.
pipe = pipeline(
    task="kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
    model_kwargs=dict(attn_implementation=attn_implementation),
)
# Mapping of press types
# Keys are the labels shown in the UI dropdown; values are the press classes
# instantiated per request. Insertion order determines the dropdown order.
press_map = dict(
    ExpectedAttentionPress=ExpectedAttentionPress,
    KnormPress=KnormPress,
    ObservedAttentionPress=ObservedAttentionPress,
    RandomPress=RandomPress,
    SnapKVPress=SnapKVPress,
    StreamingLLMPress=StreamingLLMPress,
    TOVAPress=TOVAPress,
)
# Gradio UI
def gradio_interface():
    """Build the Gradio Blocks UI for the question-answering demo.

    The layout has a row with the article URL and question text boxes, a row
    with the press-type dropdown and compression-ratio slider, an output text
    box, and a submit button wired to ``process_request``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Wikipedia Article Question Answering with KV-Press\n"
            "Enter a Wikipedia article URL, type a question, and select a press type "
            "with a compression ratio to get an answer.\n"
        )

        with gr.Row():
            url_input = gr.Textbox(
                label="Wikipedia Article URL",
                placeholder="Enter the Wikipedia article URL here",
            )
            question_input = gr.Textbox(
                label="Question",
                placeholder="Type your question here",
            )

        with gr.Row():
            press_selector = gr.Dropdown(
                choices=list(press_map),
                value="ExpectedAttentionPress",
                label="Select Press Type",
            )
            # kvpress presses require 0 <= compression_ratio < 1, so the
            # slider stops at 0.9 rather than offering 1.0, which would only
            # ever produce an error.
            compression_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                step=0.1,
                value=0.5,
                label="Compression Ratio",
            )

        output = gr.Textbox(label="Output", lines=10)
        submit_button = gr.Button("Submit")

        # One output component: process_request must return a single string.
        submit_button.click(
            process_request,
            inputs=[url_input, question_input, press_selector, compression_slider],
            outputs=[output],
        )
    return demo
# Build the UI at import time and keep it in the module-level name `demo`.
demo = gradio_interface()
# Start the Gradio server with its default launch settings.
demo.launch()