Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
import os | |
import faiss | |
import torch | |
import numpy as np | |
from request import get_ft, get_topk | |
@st.cache_resource
def load_model():
    """Load the ``dinov2_vits14`` DINOv2 model once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    ``st.cache_resource`` is what actually makes this "load once": without it
    ``torch.hub.load`` and the device transfer would run on every click.

    Returns:
        The model in eval mode, placed on CUDA when available, else on CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
    model.eval()
    model.to(device)
    return model
@st.cache_resource
def load_index(index_path):
    """Read a FAISS index from disk once per path and reuse it across reruns.

    ``st.cache_resource`` keys the cache on ``index_path`` and returns the same
    index object on every Streamlit rerun instead of re-reading the file.

    Args:
        index_path: Path of the serialized FAISS index file.

    Returns:
        The index object produced by ``faiss.read_index``.
    """
    return faiss.read_index(index_path)
def distance_to_similarity(distances, temp=1e-4):
    """Convert nearest-neighbour distances into attribution weights.

    Applies a temperature softmax to the negated distances, so the closest
    neighbour gets the largest weight and the weights of each row sum to 1.

    Args:
        distances: Array-like of distances. A 2-D array is normalised row by
            row (one row per query); a 1-D array is treated as a single row.
        temp: Softmax temperature, must be > 0. Smaller values concentrate
            the weight on the closest neighbour.

    Returns:
        A new float64 NumPy array with the same shape as ``distances``. The
        input is left unmodified.

    Raises:
        ValueError: If ``temp`` is not positive.
    """
    if temp <= 0:
        raise ValueError(f"temp must be positive, got {temp!r}")
    dists = np.array(distances, dtype=np.float64)  # copy: never mutate the caller's array
    if dists.size == 0:
        return dists
    # Shift by the row minimum so the largest logit is exactly 0. Softmax is
    # shift-invariant, and this keeps np.exp from overflowing to inf (which
    # would turn the weights into NaN) when temp is small, e.g. 1e-5.
    logits = (dists.min(axis=-1, keepdims=True) - dists) / temp
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)
def calculate_rewards(subscription, num_generations, author_share, ro_share, num_users_k, similarities):
    """Compute display-ready reward figures for each retrieved training image.

    Args:
        subscription: Monthly subscription price per user, in euros.
        num_generations: Average number of image generations per user per month.
        author_share: Percentage (0-100) of each payment allocated to authors.
        ro_share: Percentage (0-100) of each payment allocated to rights owners.
        num_users_k: Number of paying subscribers, in thousands.
        similarities: 2-D array-like of attribution weights; only the first
            row (the first query) is used, one entry per retrieved image.

    Returns:
        A list with one dict of pre-formatted strings per entry of
        ``similarities[0]``. Amounts suffixed ``c€`` are in euro cents,
        amounts suffixed ``€`` are in euros.
    """
    num_users = num_users_k * 1000
    paid_per_gen = subscription / num_generations  # euros per generation
    aro_share = paid_per_gen * (author_share + ro_share) / 100 * 100  # euro cents

    rewards = []
    for sim in similarities[0]:
        training_data_reward = aro_share * sim  # euro cents
        # Monthly payout for this image across all subscribers, split by each
        # party's own percentage. This equals
        #   training_data_reward / 100 * num_users / ((author + ro) / 100) * share / 100
        # but never divides by (author + ro), which is 0 when both shares are 0.
        base_month = paid_per_gen * sim * num_users  # euros
        author_month_reward = base_month * author_share / 100
        ro_month_reward = base_month * ro_share / 100
        rewards.append({
            'paid_per_month': f"{subscription:.0f}€",
            'paid_per_gen': f"{paid_per_gen:.2f}€",
            'aro_share': f"{aro_share:.2f}c€",
            'attribution': f"{sim*100:.0f}%",
            'training_data_reward': f"{training_data_reward:.2f}c€",
            'author_month_reward': f"{author_month_reward:.0f}€",
            'ro_month_reward': f"{ro_month_reward:.0f}€"
        })
    return rewards
@st.cache_resource
def _load_urls(urls_path):
    """Read the image URL list once and reuse it across reruns.

    The results table maps FAISS id ``i`` to line ``i`` of this file, so every
    line is kept (blank ones included) to preserve that alignment. Trailing
    whitespace/newlines are stripped so the URLs can be passed to ``st.image``.
    """
    with open(urls_path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _open_rgb(source):
    """Open an image from a path or file-like object and return an RGB copy."""
    with Image.open(source) as img:
        return img.convert('RGB')


def _on_query_upload():
    """Make a newly uploaded file the query image.

    Runs only when the uploader's value changes, so an existing upload no
    longer overrides a later "Select Image N" click. Clearing the upload
    keeps the current query image.
    """
    uploaded = st.session_state.get("query_upload")
    if uploaded is not None:
        st.session_state.query_image = _open_rgb(uploaded)


def _render_intro():
    """Render the page title, the short introduction and the "Learn more" expander."""
    st.title("Reward Simulator")
    # st.markdown dedents its body, so the indentation below is not rendered.
    st.markdown("""
        This simulator helps estimate potential rewards for authors and rights owners when their images are used
        to train AI image generation models.
        """)
    with st.expander("Learn more about how it works"):
        st.markdown("""
            ### How it works
            1. Select or upload an image that represents AI-generated content
            2. The system finds similar images that might have influenced the generation in a database of 10M images (OpenImages)
            3. Based on your parameters, it calculates potential rewards for:
               - Original image authors
               - Rights owners (e.g., stock photo companies, galleries)

            ### Key assumptions
            - Attribution scores indicate the level of influence of training images
            - Rewards are distributed based on subscription revenue
            - Calculations use simplified models and are for demonstration purposes

            ### Use cases
            - Explore fair compensation models for AI training data
            - Simulate different revenue sharing scenarios
            - Understand the relationship between model training and attribution
            """)


def _sidebar_parameters():
    """Render the sidebar inputs and return their current values.

    Returns:
        Tuple ``(subscription, num_generations, author_share, ro_share,
        num_users_k, num_neighbors)``.
    """
    with st.sidebar:
        st.subheader("Reward Simulation Parameters")
        subscription = st.number_input(
            "User Monthly Subscription (€/month)",
            min_value=1.0,
            max_value=100.0,
            value=12.0,
            help="Monthly subscription fee for users. Examples:\n"
                 "- Adobe Firefly: starts at $4.99\n"
                 "- Midjourney: starts at $10\n"
                 "- DALL·E 3: included with ChatGPT Plus ($20)\n"
                 "- Getty Edify: €45 for 25 generations (includes legal protection)"
        )
        num_generations = st.number_input(
            "Number of Image Generations per Month",
            min_value=1,
            max_value=1000,
            value=60,
            help="Number of generations done by one user on average per month.\n"
                 "Examples:\n"
                 "- Adobe Firefly basic plan: 100 generations/month\n"
                 "- Midjourney basic plan: 200 generations/month"
        )
        author_share = st.number_input(
            "Authors share (%)",
            min_value=0.0,
            max_value=100.0,
            value=5.0,
            help="Percentage of subscription allocated to authors. Typical examples:\n"
                 "- Printed books: 5-15%\n"
                 "- Music (performers + songwriters): 15-30% of net revenues\n"
                 "- Stock photography: 15-45% of revenues"
        )
        ro_share = st.number_input(
            "Right Owners Share (%)",
            min_value=0.0,
            max_value=100.0,
            value=10.0,
            help="Percentage of subscription allocated to right owners. Examples:\n"
                 "- Music CMOs (ASCAP, BMI, SACEM): 2-8% of net revenues\n"
                 "- Image stocks: retain 55-85% of revenues"
        )
        num_users_k = st.number_input(
            "Subscribers (in thousands)",
            min_value=1,
            max_value=10000,
            value=500,
            help="Number of paid users in thousands.\n"
                 "Note: Exact figures aren't public, but Midjourney is estimated\n"
                 "to have 2 to 5 million paying subscribers"
        )
        num_neighbors = st.slider(
            "Number of similar images",
            min_value=1,
            max_value=20,
            value=8,
            help="Number of most similar images to display and calculate rewards for"
        )
    return subscription, num_generations, author_share, ro_share, num_users_k, num_neighbors


def _select_query_image():
    """Let the user pick a bundled example or upload an image; show and return it.

    The choice is kept in ``st.session_state.query_image`` so it survives reruns.
    """
    default_images = {
        "Image 1": "assets/1.jpg",
        "Image 2": "assets/2.jpg",
        "Image 3": "assets/3.jpg"
    }
    st.subheader("Select an image to attribute rewards to similar images")
    if 'query_image' not in st.session_state:
        st.session_state.query_image = _open_rgb(default_images["Image 1"])

    for col, (name, path) in zip(st.columns(3), default_images.items()):
        img = _open_rgb(path)
        col.image(img, caption=name, width=200, use_container_width=True)
        # Center the button by placing it in the middle of three 1:2:1 sub-columns.
        button_cols = col.columns([1, 2, 1])
        if button_cols[1].button(f"Select {name}", use_container_width=True):
            st.session_state.query_image = img

    st.subheader("Or upload your own image:")
    st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        key="query_upload",
        on_change=_on_query_upload,
    )

    st.subheader("Your query image:")
    st.image(st.session_state.query_image, caption="Selected Image", width=300)
    return st.session_state.query_image


def _metric_rows_html(rows):
    """Build the HTML for a stack of hoverable label/value rows.

    Args:
        rows: Iterable of ``(tooltip, label, value)`` string triples.

    Returns:
        A single-line HTML string styled by the ``metric-*`` CSS classes.
    """
    import html

    parts = ['<div style="margin-bottom: 8px;">']
    for tooltip, label, value in rows:
        parts.append(
            f'<div class="metric-row" title="{html.escape(tooltip)}">'
            f'<span class="metric-label">{html.escape(label)}</span>'
            f'<span class="metric-value">{html.escape(value)}</span>'
            '</div>'
        )
    parts.append('</div>')
    return "".join(parts)


def _inject_css():
    """Inject the page-wide CSS used by the metric rows."""
    st.markdown(
        """
        <style>
        .metric-row {
            display: flex;
            justify-content: space-between;
            align-items: flex-start;
            padding: 4px;
            margin: 2px 0;
            border-radius: 2px;
            cursor: pointer;
        }
        .metric-row:hover {
            background-color: #f0f2f6;
        }
        .metric-label {
            font-family: monospace;
            margin-right: 4px;
            font-size: 12px;
        }
        .metric-value {
            font-family: monospace;
            font-size: 12px;
        }
        </style>
        """,
        unsafe_allow_html=True
    )


def _render_results(indices, rewards, urls):
    """Render one row per retrieved image: thumbnail plus its reward breakdown.

    Args:
        indices: Search result ids; only the first row (first query) is used.
        rewards: Output of ``calculate_rewards``, aligned with ``indices[0]``.
        urls: Image URLs, where line ``i`` corresponds to index id ``i``.
    """
    st.subheader("Similar Images and Rewards")
    for rank, (img_idx, reward) in enumerate(zip(indices[0], rewards), start=1):
        # FAISS uses id -1 for missing neighbours; urls[-1] would silently show
        # an unrelated image, so skip those slots.
        if img_idx < 0:
            continue
        cols = st.columns(4)
        # Column 1: similar image
        cols[0].image(urls[img_idx], caption=f"Similar Image {rank}", width=150)
        # Column 2: basic metrics
        cols[1].markdown(_metric_rows_html([
            ("Average payment per month and per paid user",
             "Monthly Subscription:", reward['paid_per_month']),
            ("Payment per generated image = Monthly Subscription / Number of Generations",
             "Paid per Generation:", reward['paid_per_gen']),
            ("Share for Authors and Right Owners = Paid per Gen × (Authors + Right Owners share%)",
             "Authors & RO Share:", reward['aro_share']),
        ]), unsafe_allow_html=True)
        # Column 3: attribution
        cols[2].markdown(_metric_rows_html([
            ("Percentage indicating how much this training data contributed to the generated image",
             "Attribution:", reward['attribution']),
            ("Reward for training data usage = Authors & RO Share × Attribution%",
             "Training Data Reward:", reward['training_data_reward']),
        ]), unsafe_allow_html=True)
        # Column 4: monthly rewards
        cols[3].markdown(_metric_rows_html([
            ("Monthly reward = Training Data Reward × Number of Subscribers / (Authors + RO share%) × Author Share%",
             "Author Monthly Reward:", reward['author_month_reward']),
            ("Monthly reward = Training Data Reward × Number of Subscribers / (Authors + RO share%) × RO Share%",
             "Right Owner Monthly Reward:", reward['ro_month_reward']),
        ]), unsafe_allow_html=True)
        # Divider between result rows.
        st.markdown("---")


def main():
    """Run the Reward Simulator page.

    Streamlit re-executes this on every widget interaction. Each run renders
    the intro and sidebar, picks the query image, retrieves its nearest
    neighbours from the FAISS index, turns distances into attribution weights,
    and displays the per-image reward breakdown.
    """
    _render_intro()

    # Load model, index and URL list (each cached across reruns).
    model = load_model()
    index = load_index("data/openimages_index.bin")
    urls = _load_urls("data/openimages_urls.txt")

    (subscription, num_generations, author_share, ro_share,
     num_users_k, num_neighbors) = _sidebar_parameters()

    query_image = _select_query_image()

    # Get features and search
    features = get_ft(model, query_image)
    distances, indices = get_topk(index, features, topk=num_neighbors)
    similarities = distance_to_similarity(distances, temp=1e-5)

    rewards = calculate_rewards(subscription, num_generations,
                                author_share, ro_share, num_users_k, similarities)

    _inject_css()
    _render_results(indices, rewards, urls)


if __name__ == "__main__":
    main()