# llama-guard-4 / app.py
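"""Gradio Space that runs Llama Guard 4 to moderate user prompts (text and/or
image) and optional model responses against its safety categories."""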

import re

import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration
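
# Load the Llama Guard 4 processor and model once at startup (bfloat16 on a
# single GPU). The meta-llama checkpoint is gated on the Hugging Face Hub, so
# the Space needs access to it.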
model_id = "meta-llama/Llama-Guard-4-12B"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)
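
# Safety category codes emitted by Llama Guard 4, mapped to human-readable labels.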
labels_dict = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    "S3": "Sex Crimes",
    "S4": "Child Exploitation",
    "S5": "Defamation",
    "S6": "Specialized Advice",
    "S7": "Privacy",
    "S8": "Intellectual Property",
    "S9": "Indiscriminate Weapons",
    "S10": "Hate",
    "S11": "Self-Harm",
    "S12": "Sexual Content",
    "S13": "Elections",
}
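
# @spaces.GPU allocates a GPU for the duration of each call when the Space runs
# on ZeroGPU hardware.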
@spaces.GPU
def infer(image, text_input, model_output, exclude_categories):
    # Build the conversation: a user turn (text, optionally with an image) and,
    # if provided, a hypothetical assistant turn to be classified as well.
    if text_input:
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text_input}],
            }
        ]
        if image is not None:
            messages[0]["content"].append({"type": "image", "url": image})
    else:
        return "", "Please provide at least a text input."

    if model_output:
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": model_output}],
            }
        )
    print("messages", messages)

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        exclude_category_keys=exclude_categories,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
    )
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]

    # Llama Guard answers with "safe" or "unsafe" followed by the violated category code.
    if "unsafe" in response:
        match = re.search(r"S(\d+)", response)
        if match:
            s_number = f"S{match.group(1)}"
            category = labels_dict.get(s_number, "Unknown Category")
            response = f"This content is unsafe under category {category}"
    elif "safe<|eot|>" in response:
        response = "This conversation is safe."
    return messages, response
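
# Gradio UI: an optional image, a required text prompt, an optional assistant
# response, and a list of category codes to exclude from moderation.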
with gr.Blocks() as demo:
    gr.Markdown("## Llama Guard 4 for Multimodal Safety")
    gr.Markdown(
        "Llama Guard 4 is a safety moderation model for both large language models and vision language models. "
        "It can detect unsafe images and text inputs. To use it, simply input text, or an image along with text. "
        "You can also provide a hypothetical LLM output separately, as the model can take in an entire conversation."
    )
    with gr.Column():
        image = gr.Image(label="Image Input (Optional)", type="filepath")
        text_input = gr.Textbox(label="Text Input")
        model_output = gr.Textbox(label="Model Output")
        with gr.Accordion("Exclude Safety Categories", open=False):
            exclude_categories = gr.CheckboxGroup(
                choices=list(labels_dict.keys()),
                label="Select categories to EXCLUDE from moderation",
                info="Selected categories will NOT be flagged.",
            )
        btn = gr.Button("Submit")
        complete_conversation = gr.Textbox(label="Complete Conversation Provided by User")
        outputs = [complete_conversation, gr.Textbox(label="Safety Category")]
        btn.click(
            fn=infer,
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs,
        )
        gr.Examples(
            examples=[
                [None, "How to make a bomb?", "", []],
                ["./fruit_knife.png", "How to use this?", "", []],
                [None, "How to make a bomb?", "Sorry, I can't respond to this.", ["S1"]],
            ],
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs,
        )

demo.launch()
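
# Quick sanity check outside the UI (hypothetical manual call; assumes the model
# is loaded and a GPU is available):
#   conversation, verdict = infer(None, "How to make a bomb?", "", [])
#   print(verdict)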