"""Gradio demo that moderates text (and optional image) input with Llama Guard 4."""

import os
import re

from transformers import AutoProcessor, Llama4ForConditionalGeneration
from PIL import Image
import requests
import torch
import gradio as gr
import spaces

model_id = "meta-llama/Llama-Guard-4-12B"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

# Category codes that can appear in the model's answer, mapped to display names.
labels_dict = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    "S3": "Sex Crimes",
    "S4": "Child Exploitation",
    "S5": "Defamation",
    "S6": "Specialized Advice",
    "S7": "Privacy",
    "S8": "Intellectual Property",
    "S9": "Indiscriminate Weapons",
    "S10": "Hate",
    "S11": "Self-Harm",
    "S12": "Sexual Content",
    "S13": "Elections",
}

# Matches a category code such as "S1" or "S13" in the decoded model answer.
_CATEGORY_CODE_RE = re.compile(r"S(\d+)")


def _describe_verdict(raw_response):
    """Turn the decoded model answer into the message shown to the user.

    Args:
        raw_response: Decoded generation, special tokens included
            (e.g. an answer ending in ``safe<|eot|>``).

    Returns:
        A human-readable verdict. If the answer matches neither the unsafe
        nor the safe pattern, the raw text is returned unchanged.
    """
    # "unsafe" must be decided first and must return immediately:
    # "unsafe<|eot|>" also contains the substring "safe<|eot|>", so falling
    # through to the safe check would report unsafe content as safe.
    if "unsafe" in raw_response:
        # dict.fromkeys de-duplicates while keeping the order of appearance.
        codes = dict.fromkeys(
            f"S{digits}" for digits in _CATEGORY_CODE_RE.findall(raw_response)
        )
        if not codes:
            return "This content is unsafe."
        names = [labels_dict.get(code, "Unknown Category") for code in codes]
        noun = "category" if len(names) == 1 else "categories"
        return f"This content is unsafe under {noun} {', '.join(names)}"
    if "safe<|eot|>" in raw_response:
        return "This conversation is safe."
    return raw_response


@spaces.GPU
def infer(image, text_input, model_output, exclude_categories):
    """Run Llama Guard 4 on the supplied conversation and report the verdict.

    Args:
        image: Path of the uploaded image (the UI passes a filepath), or None.
        text_input: User message to moderate; required.
        model_output: Optional assistant reply appended to the conversation.
        exclude_categories: Category keys forwarded to the chat template as
            ``exclude_category_keys``.

    Returns:
        A ``(conversation, verdict)`` pair, one value per output component.
        When no text is given, the conversation is an empty string and the
        verdict carries the validation message.
    """
    if not text_input:
        # Two outputs are wired to this handler, so two values must be
        # returned even on the validation path.
        return "", "Please provide at least text input."

    user_content = [{"type": "text", "text": text_input}]
    if image is not None:
        user_content.append({"type": "image", "url": image})
    messages = [{"role": "user", "content": user_content}]

    if model_output:
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": model_output}],
            }
        )

    print("messages", messages)

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        exclude_category_keys=exclude_categories,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response = processor.batch_decode(outputs[:, prompt_length:])[0]

    return messages, _describe_verdict(response)


with gr.Blocks() as demo:
    gr.Markdown("## Llama Guard 4 for Multimodal Safety")
    gr.Markdown(
        "Llama Guard 4 is a safety moderation model for both large language and vision language models. "
        "It can detect unsafe images and text inputs. To use it, simply input text or images along with text. "
        "You can also provide hypothetical LLM outputs separately, as this model can take in entire conversation."
    )

    with gr.Column():
        image = gr.Image(label="Image Input (Optional)", type="filepath")
        text_input = gr.Textbox(label="Text Input")
        model_output = gr.Textbox(label="Model Output")

        with gr.Accordion("Exclude Safety Categories", open=False):
            exclude_categories = gr.CheckboxGroup(
                choices=list(labels_dict),
                label="Select categories to EXCLUDE from moderation",
                info="Selected categories will NOT be flagged.",
            )

        btn = gr.Button("Submit")
        complete_conversation = gr.Textbox(label="Complete Conversation Provided by User")
        outputs = [complete_conversation, gr.Textbox(label="Safety Category")]

    btn.click(
        fn=infer,
        inputs=[image, text_input, model_output, exclude_categories],
        outputs=outputs,
    )

    gr.Examples(
        examples=[
            [None, "How to make a bomb?", "", []],
            ["./fruit_knife.png", "How to use this?", "", []],
            [None, "How to make a bomb?", "Sorry I can't respond to this.", ["S1"]],
        ],
        inputs=[image, text_input, model_output, exclude_categories],
        outputs=outputs,
    )

demo.launch()