import streamlit as st
import pandas as pd
import pyarrow.parquet as pq
import pyarrow.dataset as ds
import ast
import time
import os
import plotly.graph_objects as go
import gc
import numpy as np
from huggingface_hub import hf_hub_download, login
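
# Configuration comes from environment variables (set as Space secrets):
# HF_TOKEN holds the Hugging Face access token, HF_REPO the dataset repo id,
# and PARQUET_FILE1/2/3 the per-source Parquet filenames inside that repo.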
st.set_page_config(layout="wide")

hf_token = os.getenv('HF_TOKEN')
hf_repo = os.getenv('HF_REPO')
login(token=hf_token)
data_source = st.sidebar.radio("Source", ["Danbooru", "Gelbooru", "Rule 34"], index=0)

parquet_env_vars = {
    "Danbooru": 'PARQUET_FILE1',
    "Gelbooru": 'PARQUET_FILE2',
    "Rule 34": 'PARQUET_FILE3',
}
parquet_file = hf_hub_download(repo_id=hf_repo, filename=os.getenv(parquet_env_vars[data_source]), repo_type="dataset")
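
# Reads row counts, the post_id range, and the available ratings without
# loading the full file: row-group statistics are tried first, with a
# fallback that samples a few row groups directly.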
def load_parquet_metadata(parquet_file):
    try:
        parquet_dataset = pq.ParquetFile(parquet_file)
        metadata = parquet_dataset.metadata
        num_rows = metadata.num_rows
        sample_df = next(parquet_dataset.iter_batches(batch_size=10)).to_pandas()
        if 'post_id' in sample_df.columns:
            try:
                # Derive the post_id range from row-group column statistics,
                # which avoids scanning any actual data pages.
                min_post_id = float('inf')
                max_post_id = float('-inf')
                for i in range(parquet_dataset.metadata.num_row_groups):
                    row_group = parquet_dataset.metadata.row_group(i)
                    for j in range(row_group.num_columns):
                        col = row_group.column(j)
                        if col.path_in_schema == 'post_id':
                            stats = col.statistics
                            if stats is not None:
                                min_post_id = min(min_post_id, stats.min)
                                max_post_id = max(max_post_id, stats.max)
                if min_post_id == float('inf') or max_post_id == float('-inf'):
                    raise ValueError("Invalid post_id range")
            except Exception as e:
                st.warning(f"Unable to get post_id range from statistics: {str(e)}")
                min_post_id = float('inf')
                max_post_id = float('-inf')
                # Fallback: sample the first batch plus the first, middle,
                # and last row groups to estimate the post_id range.
                # (pq.ParquetReader is not part of the public pyarrow.parquet
                # API; ParquetFile provides the same iteration methods.)
                with pq.ParquetFile(parquet_file) as reader:
                    first_batch = next(reader.iter_batches(batch_size=1000))
                    first_df = first_batch.to_pandas()
                    min_post_id = min(min_post_id, first_df['post_id'].min())
                    max_post_id = max(max_post_id, first_df['post_id'].max())
                    num_row_groups = reader.num_row_groups
                    for idx in {0, num_row_groups // 2, num_row_groups - 1}:
                        if 0 <= idx < num_row_groups:
                            batch = reader.read_row_group(idx).to_pandas()
                            min_post_id = min(min_post_id, batch['post_id'].min())
                            max_post_id = max(max_post_id, batch['post_id'].max())
        else:
            min_post_id = 0
            max_post_id = 100000
        available_ratings = []
        if 'rating' in sample_df.columns:
            ratings_set = set()
            for i in range(min(3, parquet_dataset.num_row_groups)):
                sample = parquet_dataset.read_row_group(i, columns=['rating']).to_pandas()
                ratings_set.update(sample['rating'].unique())
            available_ratings = sorted(ratings_set)
        else:
            available_ratings = ['general']
        print(f"Metadata loaded: {num_rows} rows, post_id range: {min_post_id}-{max_post_id}")
        return {
            'num_rows': num_rows,
            'min_post_id': int(min_post_id),
            'max_post_id': int(max_post_id),
            'available_ratings': available_ratings,
            'columns': sample_df.columns.tolist(),
        }
    except Exception as e:
        st.error(f"Error loading Parquet metadata: {str(e)}")
        return {
            'num_rows': 0,
            'min_post_id': 0,
            'max_post_id': 100000,
            'available_ratings': ['general'],
            'columns': [],
        }
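
# Pushes the numeric and rating filters down into the pyarrow dataset scanner
# so only matching rows are materialized, then applies the chosen sort order.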
def get_filtered_batch(parquet_file, filters, needed_columns, sort_option):
    try:
        dataset = ds.dataset(parquet_file, format='parquet')
        # Translate (column, op, value) triples into pyarrow dataset
        # expressions so filtering happens during the scan.
        pa_filters = []
        for col, op, val in filters:
            if col in ['post_id', 'ava_score', 'aesthetic_score']:
                if op == '>=':
                    pa_filters.append(ds.field(col) >= val)
                elif op == '<=':
                    pa_filters.append(ds.field(col) <= val)
            elif op == 'in' and len(val) > 0:
                pa_filters.append(ds.field(col).isin(val))
        final_filter = None
        if pa_filters:
            final_filter = pa_filters[0]
            for f in pa_filters[1:]:
                final_filter = final_filter & f
        scanner = dataset.scanner(columns=needed_columns, filter=final_filter)
        df = scanner.to_table().to_pandas()
        df.set_index('post_id', inplace=True)
        if sort_option == "Post ID (Descending)":
            df = df.sort_index(ascending=False)
        elif sort_option == "Post ID (Ascending)":
            df = df.sort_index(ascending=True)
        elif sort_option == "AVA Score":
            df = df.sort_values(by='ava_score', ascending=False)
        elif sort_option == "Aesthetic Score":
            df = df.sort_values(by='aesthetic_score', ascending=False)
        return df
    except Exception as e:
        st.error(f"Error reading batch: {str(e)}")
        return pd.DataFrame()
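
# Tag filtering happens in pandas after the scan. Tag values come back from
# Parquet as Python lists or numpy arrays depending on the reader, so a small
# helper normalizes each row's tags to a set first.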
def _to_tag_set(tags_list):
    # Normalize a row's tags (list, numpy array, scalar, or empty) to a set.
    if isinstance(tags_list, list):
        return set(tags_list)
    if isinstance(tags_list, np.ndarray):
        return set(tags_list.tolist()) if tags_list.size > 0 else set()
    if isinstance(tags_list, np.generic):
        # A lone numpy scalar is a single tag, not an iterable of tags.
        return {tags_list.item()}
    if tags_list:
        return {tags_list}
    return set()

def process_tags_for_filtering(df, selected_tags, undesired_tags):
    if not selected_tags and not undesired_tags:
        return df
    mask = np.ones(len(df), dtype=bool)
    if selected_tags:
        for i, tags_list in enumerate(df['tags']):
            if mask[i] and not selected_tags.issubset(_to_tag_set(tags_list)):
                mask[i] = False
    if undesired_tags:
        for i, tags_list in enumerate(df['tags']):
            if mask[i] and undesired_tags.intersection(_to_tag_set(tags_list)):
                mask[i] = False
    return df[mask]
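
# The filter and tag arguments are passed as repr() strings and parsed back
# with ast.literal_eval; the string round-trip keeps every argument hashable,
# so this function could be wrapped with @st.cache_data if caching is desired.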
def get_filtered_data(parquet_file, filters_str, sort_option, selected_tags_str, undesired_tags_str, page_number, items_per_page):
    # ast.literal_eval parses the repr() strings safely, unlike bare eval().
    filters = ast.literal_eval(filters_str)
    selected_tags = set(ast.literal_eval(selected_tags_str))
    undesired_tags = set(ast.literal_eval(undesired_tags_str))
    needed_columns = ['post_id', 'tags', 'ava_score', 'aesthetic_score', 'rating', 'large_file_url']
    df = get_filtered_batch(parquet_file, filters, needed_columns, sort_option)
    if selected_tags or undesired_tags:
        df = process_tags_for_filtering(df, selected_tags, undesired_tags)
    # page_number and items_per_page are unused here; pagination is applied
    # by the caller on the returned frame.
    return df
st.title(f'{data_source} Images')

metadata = load_parquet_metadata(parquet_file)

score_range = st.sidebar.slider('Select AVA Score range', min_value=0.0, max_value=10.0, value=(5.0, 10.0), step=0.1)
score_range_v2 = st.sidebar.slider('Select Aesthetic Score range', min_value=0.0, max_value=10.0, value=(9.0, 10.0), step=0.1)

min_post_id = metadata['min_post_id']
max_post_id = metadata['max_post_id']
post_id_range = st.sidebar.slider('Select Post ID range',
                                  min_value=min_post_id,
                                  max_value=max_post_id,
                                  value=(min_post_id, max_post_id),
                                  step=1000)

available_ratings = metadata['available_ratings']
selected_ratings = st.sidebar.multiselect(
    'Select ratings to include',
    options=available_ratings,
    default=[],
    help='Filter images by their rating category'
)

page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1)
items_per_page = 50
sort_option = st.sidebar.selectbox('Sort by', options=['Post ID (Descending)', 'Post ID (Ascending)', 'AVA Score', 'Aesthetic Score'], index=0)

user_input_tags = st.text_input('Enter tags (space-separated)', value='1girl scenery', help='Filter images based on tags. Use "-" to exclude tags.')
selected_tags = {tag.strip() for tag in user_input_tags.split() if tag.strip() and not tag.strip().startswith('-')}
undesired_tags = {tag[1:] for tag in user_input_tags.split() if tag.startswith('-')}
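# Example: entering "1girl scenery -text" keeps posts tagged with both
# "1girl" and "scenery" and drops any post tagged "text".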
filters = [
    ('ava_score', '>=', score_range[0]),
    ('ava_score', '<=', score_range[1]),
    ('aesthetic_score', '>=', score_range_v2[0]),
    ('aesthetic_score', '<=', score_range_v2[1]),
    ('post_id', '>=', post_id_range[0]),
    ('post_id', '<=', post_id_range[1]),
]
if selected_ratings:
    filters.append(('rating', 'in', selected_ratings))

filters_str = repr(filters)
selected_tags_str = repr(list(selected_tags))
undesired_tags_str = repr(list(undesired_tags))
start_time = time.time()
current_batch = get_filtered_data(
    parquet_file, filters_str, sort_option,
    selected_tags_str, undesired_tags_str,
    page_number, items_per_page
)
print(f"Data retrieved in {time.time() - start_time:.2f} seconds")

batch_start = (page_number - 1) * items_per_page
end_idx = min(batch_start + items_per_page, len(current_batch))
current_data = current_batch.iloc[batch_start:end_idx] if batch_start < len(current_batch) else pd.DataFrame()

st.sidebar.write(f"Images on this page: {len(current_data)}")
st.sidebar.write(f"Total filtered sample: {len(current_batch)}")
# Render the current page as a grid, five images per row, each captioned
# with its ID, scores, and a link back to the source booru.
columns_per_row = 5
rows = [current_data.iloc[i:i + columns_per_row] for i in range(0, len(current_data), columns_per_row)]
for row in rows:
    cols = st.columns(columns_per_row)
    for col, (_, row_data) in zip(cols, row.iterrows()):
        with col:
            post_id = row_data.name
            if data_source == "Danbooru":
                link = f"https://danbooru.donmai.us/posts/{post_id}"
            elif data_source == "Gelbooru":
                link = f"https://gelbooru.com/index.php?page=post&s=view&id={post_id}"
            elif data_source == "Rule 34":
                link = f"https://rule34.xxx/index.php?page=post&s=view&id={post_id}"
            st.image(row_data['large_file_url'], caption=f"ID: {post_id}, AVA: {row_data['ava_score']:.2f}, Aesthetic: {row_data['aesthetic_score']:.2f}\n{link}", use_container_width=True)
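
# Draws overlaid histograms of the two score columns in the sidebar,
# downsampling to at most 5000 rows so large result sets stay responsive.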
def histogram_slider(df, column1, column2):
    if df.empty:
        return
    # Evenly subsample large frames instead of plotting every row.
    sample_size = min(5000, len(df))
    if len(df) > sample_size:
        step = len(df) // sample_size
        indices = np.arange(0, len(df), step)[:sample_size]
        sample_data = df.iloc[indices]
    else:
        sample_data = df
    hist1, bin_edges1 = np.histogram(sample_data[column1].dropna(), bins=30)
    hist2, bin_edges2 = np.histogram(sample_data[column2].dropna(), bins=30)
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=(bin_edges1[:-1] + bin_edges1[1:]) / 2,
        y=hist1,
        name=column1,
        opacity=0.75,
        width=(bin_edges1[1] - bin_edges1[0])
    ))
    fig.add_trace(go.Bar(
        x=(bin_edges2[:-1] + bin_edges2[1:]) / 2,
        y=hist2,
        name=column2,
        opacity=0.75,
        width=(bin_edges2[1] - bin_edges2[0])
    ))
    fig.update_layout(
        barmode='overlay',
        bargap=0.1,
        height=200,
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    )
    st.sidebar.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})
    # Release the sampled data promptly; the Space runs with limited memory.
    del sample_data, hist1, hist2, bin_edges1, bin_edges2
    gc.collect()
if not current_batch.empty:
    start_time = time.time()
    histogram_slider(current_batch, 'ava_score', 'aesthetic_score')
    print(f"Histogram displayed: {time.time() - start_time:.2f} seconds")