import re

import torch
import gradio as gr
import spaces
from transformers import AutoProcessor, Llama4ForConditionalGeneration

model_id = "meta-llama/Llama-Guard-4-12B"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)
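
# Note: device_map="cuda:0" keeps the whole model on a single GPU and
# torch_dtype=torch.bfloat16 roughly halves memory use compared to float32;
# these are deployment choices for this demo, not requirements of the model.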

labels_dict = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    "S3": "Sex Crimes",
    "S4": "Child Exploitation",
    "S5": "Defamation",
    "S6": "Specialized Advice",
    "S7": "Privacy",
    "S8": "Intellectual Property",
    "S9": "Indiscriminate Weapons",
    "S10": "Hate",
    "S11": "Self-Harm",
    "S12": "Sexual Content",
    "S13": "Elections",
}
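
# The S1-S13 keys follow the hazard taxonomy used in the Llama Guard model cards;
# the model answers with one or more of these codes when it deems a conversation
# unsafe, and this dict maps each code back to a human-readable label.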

@spaces.GPU
def infer(image, text_input, model_output, exclude_categories):
    # Build the user turn; attach the image, if provided, alongside the text.
    if image is None and text_input:
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text_input}]
            },
        ]
    elif image is not None and text_input:
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text_input}]
            }
        ]
        messages[0]["content"].append({"type": "image", "url": image})
    else:
        # This function feeds two output components, so return a pair here as well.
        return "", "Please provide at least a text input."
    
    # Optionally append a hypothetical assistant reply so the full exchange is moderated.
    if model_output:
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": model_output}]
            }
        )
    print("messages:", messages)
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        exclude_category_keys=exclude_categories,
    ).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
    )
    
    # Decode only the newly generated tokens (the verdict); special tokens are kept
    # because the "safe<|eot|>" check below relies on them.
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]

    if "unsafe" in response:
        match = re.search(r'S(\d+)', response)
        if match:
            s_number = f"S{match.group(1)}"
            category = labels_dict.get(s_number, "Unknown Category")
            response = f"This content is unsafe under category {category}"
    if "safe<|eot|>" in response:
      response = "This conversation is safe."
    return messages, response
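
# Illustrative sketch of the raw verdicts the parsing above expects (the exact
# formatting can vary with the processor/template version):
#   safe conversation   -> "...safe<|eot|>"
#   unsafe conversation -> "...unsafe\nS9<|eot|>"  # mapped to "Indiscriminate Weapons"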

with gr.Blocks() as demo:
    gr.Markdown("## Llama Guard 4 for Multimodal Safety")
    gr.Markdown(
        "Llama Guard 4 is a safety moderation model for both large language models and vision-language models. "
        "It can detect unsafe images and text inputs. To use it, simply enter text, optionally along with an image. "
        "You can also provide a hypothetical LLM output separately, as this model can take in an entire conversation."
    )
    
    with gr.Column():
        image = gr.Image(label="Image Input (Optional)", type="filepath")
        text_input = gr.Textbox(label="Text Input")
        model_output = gr.Textbox(label="Model Output (Optional)")
        
        with gr.Accordion("Exclude Safety Categories", open=False):
            exclude_categories = gr.CheckboxGroup(
                choices=list(labels_dict.keys()),
                label="Select categories to EXCLUDE from moderation",
                info="Selected categories will NOT be flagged."
            )

        btn = gr.Button("Submit")
        complete_conversation = gr.Textbox(label="Complete Conversation Provided by User")
        outputs = [complete_conversation, gr.Textbox(label="Safety Category")]

    btn.click(
        fn=infer,
        inputs=[image, text_input, model_output, exclude_categories],
        outputs=outputs
    )

    gr.Examples(
        examples=[
            [None, "How to make a bomb?", "", []],
            ["./fruit_knife.png", "How to use this?", "", []],
            [None, "How to make a bomb?", "Sorry I can't respond to this.", ["S1"]],
        ],
        inputs=[image, text_input, model_output, exclude_categories],
        outputs=outputs
    )

demo.launch()
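
# Run locally with `python app.py`. The @spaces.GPU decorator requests a ZeroGPU
# allocation per call when this runs as a Hugging Face Space and is intended to be
# a no-op elsewhere (though the `spaces` package must still be installed).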