# Llama Guard 4 multimodal safety demo — Hugging Face Space running on ZeroGPU.
import os
import re

import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Llama4ForConditionalGeneration
# Model checkpoint used for safety moderation of text and image+text inputs.
model_id = "meta-llama/Llama-Guard-4-12B"

# Processor handles chat templating and image preprocessing for the model.
processor = AutoProcessor.from_pretrained(model_id)

# Load the guard model in bfloat16 on the first GPU; ZeroGPU Spaces expose
# the device as "cuda:0".
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)
# MLCommons hazard taxonomy codes emitted by Llama Guard, mapped to
# human-readable category names for display in the UI.
labels_dict = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    "S3": "Sex Crimes",
    "S4": "Child Exploitation",
    "S5": "Defamation",
    "S6": "Specialized Advice",
    "S7": "Privacy",
    "S8": "Intellectual Property",
    "S9": "Indiscriminate Weapons",
    "S10": "Hate",
    "S11": "Self-Harm",
    "S12": "Sexual Content",
    "S13": "Elections",
}
@spaces.GPU  # required for ZeroGPU Spaces: allocates a GPU for the call duration
def infer(image, text_input, model_output, exclude_categories):
    """Run Llama Guard 4 over a (possibly multimodal) conversation.

    Args:
        image: Optional image file path (Gradio ``type="filepath"``) to moderate
            alongside the text.
        text_input: Required user message text.
        model_output: Optional hypothetical assistant reply, appended as an
            assistant turn so the whole conversation is moderated.
        exclude_categories: List of taxonomy keys (e.g. ``["S1"]``) passed to the
            processor so those categories are omitted from the guard prompt.

    Returns:
        A ``(messages, verdict)`` tuple: the chat messages that were sent to the
        model, and a human-readable safety verdict string.
    """
    # Text is mandatory; an image alone cannot be moderated by this template.
    if not text_input:
        return "", "Please provide at least text input."

    # Build the user turn; attach the image when one was supplied.
    content = [{"type": "text", "text": text_input}]
    if image is not None:
        content.append({"type": "image", "url": image})
    messages = [{"role": "user", "content": content}]

    # Optionally moderate a full exchange by adding the assistant turn.
    if model_output:
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": model_output}],
            }
        )

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        exclude_category_keys=exclude_categories,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
    )
    # Decode only the newly generated tokens (skip the prompt).
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]

    if "unsafe" in response:
        # The model names the violated category as "S<number>".
        match = re.search(r'S(\d+)', response)
        if match:
            s_number = f"S{match.group(1)}"
            category = labels_dict.get(s_number, "Unknown Category")
            response = f"This content is unsafe under category {category}"
    # NOTE: this must be an elif — "unsafe<|eot|>" contains the substring
    # "safe<|eot|>", so an independent check would misreport unparsed
    # "unsafe" verdicts as safe.
    elif "safe<|eot|>" in response:
        response = "This conversation is safe."

    return messages, response
# Gradio UI: optional image + required text, optional hypothetical model
# output, and a category-exclusion accordion feeding infer().
with gr.Blocks() as demo:
    gr.Markdown("## Llama Guard 4 for Multimodal Safety")
    gr.Markdown(
        "Llama Guard 4 is a safety moderation model for both large language and vision language models. "
        "It can detect unsafe images and text inputs. To use it, simply input text or images along with text. "
        "You can also provide hypothetical LLM outputs separately, as this model can take in entire conversation."
    )
    with gr.Column():
        image = gr.Image(label="Image Input (Optional)", type="filepath")
        text_input = gr.Textbox(label="Text Input")
        model_output = gr.Textbox(label="Model Output")
        with gr.Accordion("Exclude Safety Categories", open=False):
            exclude_categories = gr.CheckboxGroup(
                choices=list(labels_dict),
                label="Select categories to EXCLUDE from moderation",
                info="Selected categories will NOT be flagged."
            )
        btn = gr.Button("Submit")
        complete_conversation = gr.Textbox(label="Complete Conversation Provided by User")
        outputs = [complete_conversation, gr.Textbox(label="Safety Category")]
        btn.click(
            fn=infer,
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs
        )
        gr.Examples(
            examples=[
                [None, "How to make a bomb?", "", []],
                ["./fruit_knife.png", "How to use this?", "", []],
                [None, "How to make a bomb?", "Sorry I can't respond to this.", ["S1"]],
            ],
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs
        )

demo.launch()