import ast
import gc
import os
import time

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import streamlit as st
from huggingface_hub import hf_hub_download, login

st.set_page_config(layout="wide")

# Hugging Face credentials and the dataset repo are supplied via environment variables.
hf_token = os.getenv('HF_TOKEN')
hf_repo = os.getenv('HF_REPO')
login(token=hf_token)

data_source = st.sidebar.radio("Source", ["Danbooru", "Gelbooru", "Rule 34"], index=0)

# Each source has its own Parquet file inside the same dataset repo.
if data_source == "Danbooru":
    parquet_file = hf_hub_download(repo_id=hf_repo, filename=os.getenv('PARQUET_FILE1'), repo_type="dataset")
elif data_source == "Gelbooru":
    parquet_file = hf_hub_download(repo_id=hf_repo, filename=os.getenv('PARQUET_FILE2'), repo_type="dataset")
elif data_source == "Rule 34":
    parquet_file = hf_hub_download(repo_id=hf_repo, filename=os.getenv('PARQUET_FILE3'), repo_type="dataset")


@st.cache_resource
def load_parquet_metadata(parquet_file):
    """Read row count, post_id range and available ratings from the Parquet file."""
    try:
        parquet_dataset = pq.ParquetFile(parquet_file)
        metadata = parquet_dataset.metadata
        num_rows = metadata.num_rows

        sample_df = next(parquet_dataset.iter_batches(batch_size=10)).to_pandas()

        if 'post_id' in sample_df.columns:
            try:
                # Prefer row-group statistics: no data pages need to be read.
                min_post_id = float('inf')
                max_post_id = float('-inf')
                for i in range(parquet_dataset.metadata.num_row_groups):
                    row_group = parquet_dataset.metadata.row_group(i)
                    for j in range(row_group.num_columns):
                        col = row_group.column(j)
                        if col.path_in_schema == 'post_id':
                            stats = col.statistics
                            if stats is not None and stats.has_min_max:
                                min_post_id = min(min_post_id, stats.min)
                                max_post_id = max(max_post_id, stats.max)
                if min_post_id == float('inf') or max_post_id == float('-inf'):
                    raise ValueError("Invalid post_id range")
            except Exception as e:
                st.warning(f"Unable to get post_id range from statistics: {str(e)}")
                # Fall back to sampling a few row groups from the already-open ParquetFile.
                min_post_id = float('inf')
                max_post_id = float('-inf')
                first_batch = next(parquet_dataset.iter_batches(batch_size=1000))
                first_df = first_batch.to_pandas()
                min_post_id = min(min_post_id, first_df['post_id'].min())
                max_post_id = max(max_post_id, first_df['post_id'].max())
                num_row_groups = parquet_dataset.num_row_groups
                sample_indices = [0, num_row_groups // 2, num_row_groups - 1]
                for idx in sample_indices:
                    if 0 <= idx < num_row_groups:
                        batch = parquet_dataset.read_row_group(idx).to_pandas()
                        min_post_id = min(min_post_id, batch['post_id'].min())
                        max_post_id = max(max_post_id, batch['post_id'].max())
        else:
            min_post_id = 0
            max_post_id = 100000

        available_ratings = []
        if 'rating' in sample_df.columns:
            # Sample up to three row groups to discover which rating values are present.
            ratings_set = set()
            for i in range(min(3, parquet_dataset.num_row_groups)):
                sample = parquet_dataset.read_row_group(i, columns=['rating']).to_pandas()
                ratings_set.update(sample['rating'].unique())
            available_ratings = sorted(ratings_set)
        else:
            available_ratings = ['general']

        print(f"Metadata loaded: {num_rows} rows, post_id range: {min_post_id}-{max_post_id}")
        return {
            'num_rows': num_rows,
            'min_post_id': int(min_post_id),
            'max_post_id': int(max_post_id),
            'available_ratings': available_ratings,
            'columns': sample_df.columns.tolist()
        }
    except Exception as e:
        st.error(f"Error loading Parquet metadata: {str(e)}")
        return {
            'num_rows': 0,
            'min_post_id': 0,
            'max_post_id': 100000,
            'available_ratings': ['general'],
            'columns': []
        }


def get_filtered_batch(parquet_file, filters, needed_columns, sort_option):
    """Scan the Parquet dataset with pushed-down filters and return a sorted DataFrame."""
    try:
        dataset = ds.dataset(parquet_file, format='parquet')
        pa_filters = []
        for col, op, val in filters:
            # Numeric columns use range comparisons; the rating filter is a membership test.
            if col in ['post_id', 'ava_score', 'aesthetic_score']:
                if op == '>=':
                    pa_filters.append(ds.field(col) >= val)
                elif op == '<=':
                    pa_filters.append(ds.field(col) <= val)
            elif op == 'in' and len(val) > 0:
                # OR together one equality expression per selected rating value.
                rating_filters = [ds.field(col) == r for r in val]
                if rating_filters:
                    or_expr = rating_filters[0]
                    for rf in rating_filters[1:]:
                        or_expr = or_expr | rf
                    pa_filters.append(or_expr)

        # AND all expressions so PyArrow can push the combined filter down into the scan.
        final_filter = None
        if pa_filters:
            final_filter = pa_filters[0]
            for f in pa_filters[1:]:
                final_filter = final_filter & f

        scanner = dataset.scanner(columns=needed_columns, filter=final_filter)
        df = scanner.to_table().to_pandas()
        df.set_index('post_id', inplace=True)

        if sort_option == "Post ID (Descending)":
            df = df.sort_values(by=df.index.name, ascending=False)
        elif sort_option == "Post ID (Ascending)":
            df = df.sort_values(by=df.index.name, ascending=True)
        elif sort_option == "AVA Score":
            df = df.sort_values(by='ava_score', ascending=False)
        elif sort_option == "Aesthetic Score":
            df = df.sort_values(by='aesthetic_score', ascending=False)

        return df
    except Exception as e:
        st.error(f"Error reading batch: {str(e)}")
        return pd.DataFrame()


def tags_to_set(tags_value):
    """Normalize a 'tags' cell (list, NumPy array, scalar or None) into a Python set."""
    if isinstance(tags_value, list):
        return set(tags_value)
    if isinstance(tags_value, (np.ndarray, np.generic)):
        return set(tags_value.tolist()) if tags_value.size > 0 else set()
    if tags_value:
        return {tags_value}
    return set()


def process_tags_for_filtering(df, selected_tags, undesired_tags):
    """Keep rows whose tags contain all selected tags and none of the undesired ones."""
    if not selected_tags and not undesired_tags:
        return df

    mask = np.ones(len(df), dtype=bool)

    if selected_tags:
        for i, tags_list in enumerate(df['tags']):
            if mask[i] and not selected_tags.issubset(tags_to_set(tags_list)):
                mask[i] = False

    if undesired_tags:
        for i, tags_list in enumerate(df['tags']):
            if mask[i] and undesired_tags.intersection(tags_to_set(tags_list)):
                mask[i] = False

    return df[mask]


@st.cache_data(ttl=600)
def get_filtered_data(parquet_file, filters_str, sort_option, selected_tags_str, undesired_tags_str, page_number, items_per_page):
    # Filters and tag sets arrive as repr() strings so the arguments stay hashable for
    # st.cache_data; ast.literal_eval safely turns them back into Python objects.
    filters = ast.literal_eval(filters_str)
    selected_tags = set(ast.literal_eval(selected_tags_str))
    undesired_tags = set(ast.literal_eval(undesired_tags_str))
    needed_columns = ['post_id', 'tags', 'ava_score', 'aesthetic_score', 'rating', 'large_file_url']
    df = get_filtered_batch(parquet_file, filters, needed_columns, sort_option)
    if selected_tags or undesired_tags:
        df = process_tags_for_filtering(df, selected_tags, undesired_tags)
    return df


st.title(f'{data_source} Images')

metadata = load_parquet_metadata(parquet_file)

score_range = st.sidebar.slider('Select AVA Score range', min_value=0.0, max_value=10.0, value=(5.0, 10.0), step=0.1)
score_range_v2 = st.sidebar.slider('Select Aesthetic Score range', min_value=0.0, max_value=10.0, value=(9.0, 10.0), step=0.1)

min_post_id = metadata['min_post_id']
max_post_id = metadata['max_post_id']
post_id_range = st.sidebar.slider('Select Post ID range', min_value=min_post_id, max_value=max_post_id, value=(min_post_id, max_post_id), step=1000)

available_ratings = metadata['available_ratings']
selected_ratings = st.sidebar.multiselect(
    'Select ratings to include',
    options=available_ratings,
    default=[],
    help='Filter images by their rating category'
)

page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1)
items_per_page = 50
sort_option = st.sidebar.selectbox('Sort by', options=['Post ID (Descending)', 'Post ID (Ascending)', 'AVA Score', 'Aesthetic Score'], index=0)
user_input_tags = st.text_input('Enter tags (space-separated)', value='1girl scenery', help='Filter images based on tags. Use "-" to exclude tags.')
selected_tags = {tag.strip() for tag in user_input_tags.split() if tag.strip() and not tag.strip().startswith('-')}
undesired_tags = {tag[1:] for tag in user_input_tags.split() if tag.startswith('-')}

filters = [
    ('ava_score', '>=', score_range[0]),
    ('ava_score', '<=', score_range[1]),
    ('aesthetic_score', '>=', score_range_v2[0]),
    ('aesthetic_score', '<=', score_range_v2[1]),
    ('post_id', '>=', post_id_range[0]),
    ('post_id', '<=', post_id_range[1]),
]
if selected_ratings:
    filters.append(('rating', 'in', selected_ratings))

filters_str = repr(filters)
selected_tags_str = repr(list(selected_tags))
undesired_tags_str = repr(list(undesired_tags))

start_time = time.time()
current_batch = get_filtered_data(
    parquet_file, filters_str, sort_option,
    selected_tags_str, undesired_tags_str,
    page_number, items_per_page
)
print(f"Data retrieved in {time.time() - start_time:.2f} seconds")

batch_start = (page_number - 1) * items_per_page
end_idx = min(batch_start + items_per_page, len(current_batch))
current_data = current_batch.iloc[batch_start:end_idx] if batch_start < len(current_batch) else pd.DataFrame()

st.sidebar.write(f"Images on this page: {len(current_data)}")
st.sidebar.write(f"Total filtered sample: {len(current_batch)}")

columns_per_row = 5
rows = [current_data.iloc[i:i + columns_per_row] for i in range(0, len(current_data), columns_per_row)]

for row in rows:
    cols = st.columns(columns_per_row)
    for col, (_, row_data) in zip(cols, row.iterrows()):
        with col:
            post_id = row_data.name
            if data_source == "Danbooru":
                link = f"https://danbooru.donmai.us/posts/{post_id}"
            elif data_source == "Gelbooru":
                link = f"https://gelbooru.com/index.php?page=post&s=view&id={post_id}"
            elif data_source == "Rule 34":
                link = f"https://rule34.xxx/index.php?page=post&s=view&id={post_id}"
            st.image(
                row_data['large_file_url'],
                caption=f"ID: {post_id}, AVA: {row_data['ava_score']:.2f}, Aesthetic: {row_data['aesthetic_score']:.2f}\n{link}",
                use_container_width=True
            )


def histogram_slider(df, column1, column2):
    if df.empty:
        return

    sample_size = min(5000, len(df))
    if len(df) > sample_size:
        step = len(df) // sample_size
        indices = np.arange(0, len(df), step)[:sample_size]
        sample_data = df.iloc[indices]
    else:
        sample_data = df

    hist1, bin_edges1 = np.histogram(sample_data[column1].dropna(), bins=30)
    hist2, bin_edges2 = np.histogram(sample_data[column2].dropna(), bins=30)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=(bin_edges1[:-1] + bin_edges1[1:]) / 2,
        y=hist1,
        name=column1,
        opacity=0.75,
        width=(bin_edges1[1] - bin_edges1[0])
    ))
    fig.add_trace(go.Bar(
        x=(bin_edges2[:-1] + bin_edges2[1:]) / 2,
        y=hist2,
        name=column2,
        opacity=0.75,
        width=(bin_edges2[1] - bin_edges2[0])
    ))
    fig.update_layout(
        barmode='overlay',
        bargap=0.1,
        height=200,
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    )
    st.sidebar.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})
    del sample_data, hist1, hist2, bin_edges1, bin_edges2
    gc.collect()


if not current_batch.empty:
    start_time = time.time()
    histogram_slider(current_batch, 'ava_score', 'aesthetic_score')
    print(f"Histogram displayed: {time.time() - start_time:.2f} seconds")