Spaces: Running on Zero · File size: 4,091 Bytes
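# Gradio demo that moderates text and image conversations with Llama Guard 4
# and reports which safety category, if any, is violated.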
import os
from transformers import AutoProcessor, Llama4ForConditionalGeneration
from PIL import Image
import requests
import torch
import gradio as gr
import re
import spaces
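# Load the Llama Guard 4 processor and model (bfloat16, single GPU).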
model_id = "meta-llama/Llama-Guard-4-12B"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)
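# Safety categories (hazard codes) that Llama Guard 4 can flag.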
labels_dict = {
"S1": "Violent Crimes",
"S2": "Non-Violent Crimes",
"S3": "Sex Crimes",
"S4": "Child Exploitation",
"S5": "Defamation",
"S6": "Specialized Advice",
"S7": "Privacy",
"S8": "Intellectual Property",
"S9": "Indiscriminate Weapons",
"S10": "Hate",
"S11": "Self-Harm",
"S12": "Sexual Content",
"S13": "Elections",
}
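# Assemble the conversation from the UI inputs, run Llama Guard 4 on it, and
# map the raw "unsafe ... S<n>" verdict to a human-readable category name.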
@spaces.GPU
def infer(image, text_input, model_output, exclude_categories):
    if image is None and text_input:
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text_input}]
            },
        ]
    elif image is not None and text_input:
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text_input}]
            }
        ]
        messages[0]["content"].append({"type": "image", "url": image})
    else:
        # The click handler has two outputs (conversation, verdict), so return a pair.
        return "", "Please provide at least a text input."
    if model_output:
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": model_output}]
            }
        )
    print("messages", messages)
    # Apply the Llama Guard 4 chat template, excluding any deselected safety categories.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        exclude_category_keys=exclude_categories,
    ).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
    )
    # Decode only the newly generated tokens, i.e. the safety verdict.
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
if "unsafe" in response:
match = re.search(r'S(\d+)', response)
if match:
s_number = f"S{match.group(1)}"
category = labels_dict.get(s_number, "Unknown Category")
response = f"This content is unsafe under category {category}"
if "safe<|eot|>" in response:
response = "This conversation is safe."
return messages, response
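# Gradio UI: optional image, user text, an optional hypothetical model output,
# and checkboxes for categories to exclude from moderation.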
with gr.Blocks() as demo:
gr.Markdown("## Llama Guard 4 for Multimodal Safety")
gr.Markdown(
"Llama Guard 4 is a safety moderation model for both large language and vision language models. "
"It can detect unsafe images and text inputs. To use it, simply input text or images along with text."
"You can also provide hypothetical LLM outputs separately, as this model can take in entire conversation."
)
    with gr.Column():
        image = gr.Image(label="Image Input (Optional)", type="filepath")
        text_input = gr.Textbox(label="Text Input")
        model_output = gr.Textbox(label="Model Output")
        with gr.Accordion("Exclude Safety Categories", open=False):
            exclude_categories = gr.CheckboxGroup(
                choices=list(labels_dict.keys()),
                label="Select categories to EXCLUDE from moderation",
                info="Selected categories will NOT be flagged."
            )
        btn = gr.Button("Submit")
        complete_conversation = gr.Textbox(label="Complete Conversation Provided by User")
        outputs = [complete_conversation, gr.Textbox(label="Safety Category")]
        btn.click(
            fn=infer,
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs
        )
        gr.Examples(
            examples=[
                [None, "How to make a bomb?", "", []],
                ["./fruit_knife.png", "How to use this?", "", []],
                [None, "How to make a bomb?", "Sorry I can't respond to this.", ["S1"]],
            ],
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs
        )

demo.launch()