|
import logging

import gradio as gr
import requests
from bs4 import BeautifulSoup
import torch
from transformers import pipeline
from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)
import spaces
|
|
|
@spaces.GPU
def process_request(url, question, press_type, compression_ratio):
    """Answer a question about a web page using a KV-cache-compressed LLM.

    The page at ``url`` is downloaded and the text of all its ``<p>`` elements
    is concatenated into a context string. The module-level ``pipe`` then
    answers ``question`` over that context, using the press selected by
    ``press_type``.

    Args:
        url: Address of the page to fetch (the UI asks for a Wikipedia article).
        question: Question to answer from the page content.
        press_type: Key of ``press_map`` naming the kvpress class to use.
        compression_ratio: Passed to the press constructor as its
            ``compression_ratio``.

    Returns:
        The predicted answer, or a human-readable error message. The result is
        always a single string, because the Submit button is wired to exactly
        one output component.
    """
    # Validate cheap inputs before doing any network or model work.
    if not url or not question:
        return "Please provide both an article URL and a question."

    press_class = press_map.get(press_type)
    if press_class is None:
        return "Invalid press type selected."

    try:
        # Without a timeout, a stalled server would block this handler forever.
        # NOTE(review): the fetch runs inside the @spaces.GPU-decorated call, so
        # it presumably holds the GPU allocation while waiting on the network.
        response = requests.get(url, timeout=30)
        # Fail on 4xx/5xx instead of answering questions about an error page.
        response.raise_for_status()

        soup = BeautifulSoup(response.content, "html.parser")
        # NOTE(review): paragraphs are joined with no separator (as before);
        # this assumes each paragraph's text carries its own trailing newline.
        text = "".join(p.text for p in soup.find_all("p"))
        if not text.strip():
            return "No paragraph text found at the given URL."
        context = text + "\n\n"

        press = press_class(compression_ratio=compression_ratio)
        return pipe(context, question=question, press=press)["answer"]
    except Exception as e:
        # Top-level UI boundary: keep the traceback for the operator and show
        # the user a readable message instead of failing the event.
        logging.getLogger(__name__).exception(
            "process_request failed for url=%r", url
        )
        return f"Error ({type(e).__name__}): {e}"
|
|
|
|
|
# Model configuration. The pipeline below is built once, at import time, and
# shared by every request that process_request handles.
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # model identifier given to pipeline
attn_implementation = "sdpa"  # forwarded to the model through model_kwargs
device = "cuda:0"  # device the pipeline is placed on

# "kv-press-text-generation" is presumably registered as a task by importing
# kvpress — TODO confirm.
pipe = pipeline(
    task="kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
    model_kwargs=dict(attn_implementation=attn_implementation),
)
|
|
|
|
|
# Dropdown label -> kvpress press class that process_request instantiates.
# Insertion order is also the order of the choices shown in the dropdown.
# NOTE(review): kvpress documents ObservedAttentionPress as needing eager
# attention, while the pipeline here is loaded with "sdpa" — confirm it works.
press_map = dict(
    ExpectedAttentionPress=ExpectedAttentionPress,
    KnormPress=KnormPress,
    ObservedAttentionPress=ObservedAttentionPress,
    RandomPress=RandomPress,
    SnapKVPress=SnapKVPress,
    StreamingLLMPress=StreamingLLMPress,
    TOVAPress=TOVAPress,
)
|
|
|
|
|
def gradio_interface():
    """Build the question-answering UI.

    The layout collects an article URL, a question, a press type (one of the
    keys of ``press_map``) and a compression ratio. Clicking Submit calls
    ``process_request`` with those four values and shows its result.

    Returns:
        The constructed, not-yet-launched ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        # Kept flush-left: Markdown renders lines indented by four or more
        # spaces as a code block.
        gr.Markdown(
            "# Wikipedia Article Question Answering with KV-Press\n\n"
            "Enter a Wikipedia article URL, type a question, and select a press "
            "type with a compression ratio to get an answer.\n"
        )

        with gr.Row():
            url_input = gr.Textbox(
                label="Wikipedia Article URL",
                placeholder="Enter the Wikipedia article URL here",
            )
            question_input = gr.Textbox(
                label="Question",
                placeholder="Type your question here",
            )

        with gr.Row():
            press_selector = gr.Dropdown(
                choices=list(press_map.keys()),
                value="ExpectedAttentionPress",
                label="Select Press Type",
            )
            # kvpress presses require 0 <= compression_ratio < 1 (a ratio of 1.0
            # would evict the whole cache), so the slider stops at 0.9.
            compression_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                step=0.1,
                value=0.5,
                label="Compression Ratio",
            )

        output = gr.Textbox(label="Output", lines=10)
        submit_button = gr.Button("Submit")

        # The inputs list must follow process_request's positional parameters.
        submit_button.click(
            process_request,
            inputs=[url_input, question_input, press_selector, compression_slider],
            outputs=[output],
        )

    return demo
|
|
|
# Build the app at module level so tools that import this file can find
# ``demo``; start the server only when the file is run as a script.
demo = gradio_interface()

if __name__ == "__main__":
    demo.launch()
|
|