import streamlit as st
import os
import json
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import plotly.graph_objects as go
from typing import List
import html
from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
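 
# Expected on-disk layout, inferred from the path handling in this script
# (the placeholders below are illustrative, not a spec):
#   <model_dir>/<language>/input.in                                  -> one tokenized sentence per line
#   <model_dir>/<language>/layer<N>/<cluster_type>/clusters-<algo>-<k>.txt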

def get_available_cluster_files(layer_dir: str) -> List[str]:
    """Get list of available cluster files and extract their types and sizes."""
    cluster_files = []
    for file in os.listdir(layer_dir):
        if file.startswith('clusters-') and file.endswith('.txt'):
            # Parse files like 'clusters-agg-10.txt' or 'clusters-kmeans-500.txt'
            parts = file.replace('.txt', '').split('-')
            if len(parts) == 3 and parts[2].isdigit():
                cluster_files.append(file)
    for file in sorted(cluster_files):
        st.sidebar.write(f"- {file}")
    return sorted(cluster_files)

def parse_cluster_filename(filename: str) -> tuple:
    """Parse cluster filename to get algorithm and size."""
    parts = filename.replace('.txt', '').split('-')
    return parts[1], int(parts[2])  # returns (algorithm, size)
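
# Example, assuming the file-naming scheme used above:
#   parse_cluster_filename('clusters-agg-500.txt')   -> ('agg', 500)
#   parse_cluster_filename('clusters-kmeans-10.txt') -> ('kmeans', 10)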

def load_cluster_sentences(model_dir: str, language: str, cluster_type: str, layer: int, cluster_file: str, tokens: List[str] = None, specific_cluster: str = None):
    """Load sentences and their cluster assignments for a given model and layer."""
    sentence_file = os.path.join(model_dir, language, "input.in")
    cluster_file_path = os.path.join(model_dir, language, f"layer{layer}", cluster_type, cluster_file)

    # Load all sentences first
    with open(sentence_file, 'r', encoding='utf-8') as f:
        all_sentences = [line.strip() for line in f]

    # Process cluster file to get sentence mappings
    cluster_sentences = defaultdict(list)
    cluster_tokens = defaultdict(set)  # Track all tokens in each cluster

    # First pass: collect all tokens for each cluster
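    # The per-line format below is inferred from the parsing code, not from a spec:
    #   token ||| occurrence ||| sentence_id ||| token_idx ||| cluster_id
    # Lines whose token is a literal '|' or '||' confuse a plain '|||' split,
    # hence the pipe-count special cases in both passes.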
    with open(cluster_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped_line = line.strip()
            pipe_count = stripped_line.count('|')
            if pipe_count == 13:
                token = '|'
                parts = stripped_line.split('|||')
                cluster_id = parts[4].strip()
            elif pipe_count == 14:
                token = '||'
                parts = stripped_line.split('|||')
                cluster_id = parts[4].strip()
            else:
                parts = stripped_line.split('|||')
                if len(parts) != 5:
                    continue
                token = parts[0].strip()
                cluster_id = parts[4].strip()

            # Only collect tokens for the specific cluster if specified
            if specific_cluster is None or cluster_id == specific_cluster:
                cluster_tokens[cluster_id].add(token)

    # Second pass: collect sentences
    with open(cluster_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped_line = line.strip()
            pipe_count = stripped_line.count('|')
            if pipe_count == 13:
                token = '|'
                parts = stripped_line.split('|||')
                occurrence = 1
                sentence_id = int(parts[2])
                token_idx = int(parts[3])
                cluster_id = parts[4].strip()
            elif pipe_count == 14:
                token = '||'
                parts = stripped_line.split('|||')
                occurrence = 1
                sentence_id = int(parts[2])
                token_idx = int(parts[3])
                cluster_id = parts[4].strip()
            else:
                parts = stripped_line.split('|||')
                if len(parts) != 5:
                    continue
                token = parts[0].strip()
                try:
                    occurrence = int(parts[1])
                    sentence_id = int(parts[2])
                    token_idx = int(parts[3])
                    cluster_id = parts[4].strip()
                except ValueError:
                    continue

            # Include sentences if:
            # 1. This is the specific cluster we're looking for AND
            # 2. Either no specific tokens requested OR token is in our search list
            if ((specific_cluster is None or cluster_id == specific_cluster) and
                    (tokens is None or token in tokens)):
                if 0 <= sentence_id < len(all_sentences):
                    sentence_tokens = all_sentences[sentence_id].split()
                    if 0 <= token_idx < len(sentence_tokens):
                        # Verify the token actually appears at the specified index
                        if sentence_tokens[token_idx] == token:
                            cluster_sentences[cluster_id].append({
                                "sentence": all_sentences[sentence_id],
                                "token": token,
                                "token_idx": token_idx,
                                "occurrence": occurrence,
                                "sentence_id": sentence_id,
                                "all_cluster_tokens": cluster_tokens[cluster_id]
                            })

    return cluster_sentences

def create_sentence_html(sentence, sent_info, cluster_tokens=None):
    """Create HTML for a sentence with the target token highlighted.

    Args:
        sentence: The full sentence text
        sent_info: Dictionary containing token and position info
        cluster_tokens: Set of all unique tokens in this cluster (unused)
    """
    # Build the markup as a plain string (named html_str so it does not shadow the html module)
    html_str = '<div style="font-family: monospace; padding: 10px; margin: 5px 0; background-color: #f5f5f5; border-radius: 5px;">'
    html_str += '<div style="margin-bottom: 5px;">'

    # Get token information
    target_token = sent_info["token"]
    target_idx = sent_info["token_idx"]
    line_number = sent_info["sentence_id"]

    # Split the tokenized sentence
    tokens = sentence.split()

    # Highlight only the target token in red
    for i, token in enumerate(tokens):
        if i == target_idx:
            # Target token in red
            html_str += f'<span style="color: red; font-weight: bold;">{token}</span> '
        else:
            # Regular tokens
            html_str += f'{token} '

    html_str += '</div>'
    html_str += f'<div style="color: #666; font-size: 0.9em;">Token: <code>{target_token}</code> (Line: {line_number}, Index: {target_idx})</div>'
    html_str += '</div>'
    return html_str

def display_cluster_analysis(model_name: str, language: str, cluster_type: str, selected_layer: int, cluster_file: str, clustering_method: str):
    """Display cluster analysis for selected model and layer."""
    # Load cluster data
    cluster_sentences = load_cluster_sentences(model_name, language, cluster_type, selected_layer, cluster_file)

    # Get clustering algorithm and size from filename
    algorithm, size = parse_cluster_filename(cluster_file)
    st.write(f"### Analyzing {algorithm.upper()} clustering with {size} clusters")

    # Create cluster selection with navigation buttons
    cluster_ids = sorted(cluster_sentences.keys(), key=lambda x: int(x))  # Sort numerically, no prefix removal needed
    if not cluster_ids:
        st.error("No clusters found in the data")
        return

    # Create a unique key for this combination of parameters
    state_key = f"cluster_index_{selected_layer}_{algorithm}_{size}"

    # Initialize or reset session state if needed
    if state_key not in st.session_state:
        st.session_state[state_key] = 0

    # Ensure the index is valid for the current number of clusters
    if st.session_state[state_key] >= len(cluster_ids):
        st.session_state[state_key] = 0

    # Create columns with adjusted widths for better spacing
    col1, col2, col3, col4 = st.columns([3, 1, 1, 7])  # Adjusted column ratios

    # Add some vertical space before the controls
    st.write("")

    with col1:
        # Use a key that includes relevant parameters to force refresh
        select_key = f"cluster_select_{selected_layer}_{algorithm}_{size}"
        selected_cluster = st.selectbox(
            "Select cluster",
            range(len(cluster_ids)),
            index=st.session_state[state_key],
            format_func=lambda x: f"Cluster {cluster_ids[x]}",
            label_visibility="collapsed",  # Hides the label but keeps accessibility
            key=select_key
        )
        # Update session state when dropdown changes
        if selected_cluster != st.session_state[state_key]:
            st.session_state[state_key] = selected_cluster
            st.rerun()

    # Previous cluster button with custom styling
    with col2:
        if st.button("◀", use_container_width=True):
            st.session_state[state_key] = max(0, st.session_state[state_key] - 1)
            st.rerun()

    # Next cluster button with custom styling
    with col3:
        if st.button("▶", use_container_width=True):
            st.session_state[state_key] = min(len(cluster_ids) - 1, st.session_state[state_key] + 1)
            st.rerun()

    # Add some vertical space after the controls
    st.write("")

    # Get the current cluster
    cluster_id = cluster_ids[st.session_state[state_key]]

    # Load cluster data with specific cluster ID
    cluster_sentences = load_cluster_sentences(
        model_name,
        language,
        cluster_type,
        selected_layer,
        cluster_file,
        specific_cluster=cluster_id
    )
    sentences_data = cluster_sentences[cluster_id]

    # Create two columns for the main content
    col_main, col_chat = st.columns([2, 1])

    with col_main:
        # Get all unique tokens in this cluster
        cluster_tokens = sentences_data[0]["all_cluster_tokens"] if sentences_data else set()

        # Display word cloud for this cluster
        st.write("### Word Cloud")
        wc = create_wordcloud(sentences_data)
        if wc:
            # Create a centered column for the word cloud
            col1, col2, col3 = st.columns([1, 3, 1])  # Increased middle column width
            with col2:
                # Clear any existing matplotlib figures
                plt.clf()
                # Create new figure with larger size
                fig, ax = plt.subplots(figsize=(10, 6))  # Increased from (5, 3) to (10, 6)
                ax.axis('off')
                ax.imshow(wc, interpolation='bilinear')
                # Display the figure
                st.pyplot(fig)
                # Clean up
                plt.close(fig)

        # Display context sentences for this cluster
        st.write("### Context Sentences")

        # Create a dictionary to track sentences by their text
        unique_sentences = {}

        # First pass: collect all sentences and their token information
        for sent_info in sentences_data:
            # Escape any HTML special characters in the token and sentence
            sentence_text = html.escape(sent_info["sentence"])
            token = html.escape(sent_info["token"])
            token_idx = sent_info["token_idx"]
            if sentence_text not in unique_sentences:
                unique_sentences[sentence_text] = {
                    "tokens": [(token, token_idx)],
                    "sentence_id": sent_info["sentence_id"]
                }
            else:
                unique_sentences[sentence_text]["tokens"].append((token, token_idx))

        # Second pass: display each unique sentence with all its tokens highlighted
        for sentence_text, info in unique_sentences.items():
            # Create HTML with multiple tokens highlighted
            tokens = sentence_text.split()
            html_content = """
            <div style='font-family: monospace; padding: 10px; margin: 5px 0; background-color: #f5f5f5; border-radius: 5px;'>
            <div style='margin-bottom: 5px;'>
            """
            # Highlight all relevant tokens in the sentence
            for i, token in enumerate(tokens):
                if any(i == idx for _, idx in info["tokens"]):
                    # This is one of our target tokens - highlight it
                    html_content += f"<span style='color: red; font-weight: bold;'>{token}</span> "
                else:
                    # Regular token
                    html_content += f"{token} "

            # Add token information footer
            html_content += f"""
            </div>
            <div style='color: #666; font-size: 0.9em;'>
            Tokens: {", ".join([f"<code>{t}</code> (Index: {idx})" for t, idx in info["tokens"]])}
            (Line: {info["sentence_id"]})
            </div>
            </div>
            """
            st.markdown(html_content, unsafe_allow_html=True)

    with col_chat:
        # Add chat interface
        chat_with_cluster(sentences_data, clustering_method, cluster_id)

def main():
    # Set page to use full width
    st.set_page_config(layout="wide")
    st.title("Coconet Visual Analysis")

    # Initialize session state for selections if they don't exist
    if 'model_name' not in st.session_state:
        st.session_state.model_name = None
    if 'selected_language' not in st.session_state:
        st.session_state.selected_language = None
    if 'selected_layer' not in st.session_state:
        st.session_state.selected_layer = None
    if 'selected_cluster_type' not in st.session_state:
        st.session_state.selected_cluster_type = None
    if 'selected_cluster_file' not in st.session_state:
        st.session_state.selected_cluster_file = None
    if 'analysis_mode' not in st.session_state:
        st.session_state.analysis_mode = "Individual Clusters"

    # Get available models (directories in the current directory)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    available_models = [d for d in os.listdir(current_dir)
                        if os.path.isdir(os.path.join(current_dir, d)) and not d.startswith('.') and not d == '__pycache__']

    # Model selection
    model_name = st.sidebar.selectbox(
        "Select Data",
        available_models,
        key="model_select",
        index=available_models.index(st.session_state.model_name) if st.session_state.model_name in available_models else 0
    )
    if model_name != st.session_state.model_name:
        st.session_state.model_name = model_name
        st.session_state.selected_language = None
        st.session_state.selected_layer = None
        st.session_state.selected_cluster_type = None
        st.session_state.selected_cluster_file = None

    if not model_name:
        st.error("No models found")
        return

    # Get available languages for the selected model
    model_dir = os.path.join(current_dir, model_name)
    available_languages = [d for d in os.listdir(model_dir)
                           if os.path.isdir(os.path.join(model_dir, d))]

    # Language selection
    selected_language = st.sidebar.selectbox(
        "Select Language",
        available_languages,
        key="language_select",
        index=available_languages.index(st.session_state.selected_language) if st.session_state.selected_language in available_languages else 0
    )
    if selected_language != st.session_state.selected_language:
        st.session_state.selected_language = selected_language
        st.session_state.selected_layer = None
        st.session_state.selected_cluster_type = None
        st.session_state.selected_cluster_file = None

    if not selected_language:
        st.error("No languages found for selected model")
        return

    # Get available layers
    language_dir = os.path.join(model_dir, selected_language)
    layer_dirs = [d for d in os.listdir(language_dir)
                  if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d))]
    if not layer_dirs:
        st.error("No layer directories found")
        return

    # Extract layer numbers
    available_layers = sorted([int(d.replace('layer', '')) for d in layer_dirs])

    # Layer selection
    selected_layer = st.sidebar.selectbox(
        "Select Layer",
        available_layers,
        key="layer_select",
        index=available_layers.index(st.session_state.selected_layer) if st.session_state.selected_layer in available_layers else 0
    )
    if selected_layer != st.session_state.selected_layer:
        st.session_state.selected_layer = selected_layer
        st.session_state.selected_cluster_type = None
        st.session_state.selected_cluster_file = None

    # Get available clustering types
    layer_dir = os.path.join(language_dir, f"layer{selected_layer}")
    available_cluster_types = [d for d in os.listdir(layer_dir)
                               if os.path.isdir(os.path.join(layer_dir, d))]

    # Clustering type selection
    selected_cluster_type = st.sidebar.selectbox(
        "Select Clustering Type",
        available_cluster_types,
        key="cluster_type_select",
        index=available_cluster_types.index(st.session_state.selected_cluster_type) if st.session_state.selected_cluster_type in available_cluster_types else 0
    )
    if selected_cluster_type != st.session_state.selected_cluster_type:
        st.session_state.selected_cluster_type = selected_cluster_type
        st.session_state.selected_cluster_file = None

    if not selected_cluster_type:
        st.error("No clustering types found for selected layer")
        return

    # Get available cluster files
    cluster_dir = os.path.join(layer_dir, selected_cluster_type)
    available_cluster_files = get_available_cluster_files(cluster_dir)
    if not available_cluster_files:
        st.error("No cluster files found in the selected layer")
        return

    # Cluster file selection
    selected_cluster_file = st.sidebar.selectbox(
        "Select Clustering",
        available_cluster_files,
        key="cluster_file_select",
        format_func=lambda x: f"{parse_cluster_filename(x)[0].upper()} (k={parse_cluster_filename(x)[1]})",
        index=available_cluster_files.index(st.session_state.selected_cluster_file) if st.session_state.selected_cluster_file in available_cluster_files else 0
    )
    st.session_state.selected_cluster_file = selected_cluster_file

    # Analysis mode selection
    analysis_mode = st.sidebar.radio(
        "Select Analysis Mode",
        ["Individual Clusters", "Search And Analysis", "Search By Line",
         "Line Token Distribution", "Token Pairs", "View Input File"],
        key="analysis_mode_select",
        index=["Individual Clusters", "Search And Analysis", "Search By Line",
               "Line Token Distribution", "Token Pairs", "View Input File"].index(st.session_state.analysis_mode)
    )
    st.session_state.analysis_mode = analysis_mode

    # Call appropriate analysis function based on mode
    if analysis_mode == "Individual Clusters":
        display_cluster_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file, selected_cluster_type)
    elif analysis_mode == "Search And Analysis":
        handle_token_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "Search By Line":
        handle_line_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "Line Token Distribution":
        handle_line_token_distribution(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "Token Pairs":
        display_token_pair_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "View Input File":
        display_input_file(model_name, selected_language)

def display_token_evolution(evolution_data: dict, tokens: List[str]):
    """Display evolution analysis for tokens"""
    st.write("### Evolution Analysis for Token(s)")

    # Create main evolution graph
    fig = go.Figure()

    # Colors for different types of lines
    colors = {
        'individual': ['#3498db', '#e74c3c', '#2ecc71'],  # Blue, Red, Green
        'exclusive': ['#9b59b6', '#f1c40f', '#1abc9c'],   # Purple, Yellow, Turquoise
        'combined': '#34495e'                              # Dark Gray
    }

    # Add individual count lines
    for i, token in enumerate(tokens):
        fig.add_trace(go.Scatter(
            x=evolution_data['layers'],
            y=evolution_data['individual_counts'][token],
            name=f"'{token}' (Total)",
            mode='lines+markers',
            line=dict(color=colors['individual'][i], width=2),
            marker=dict(size=8)
        ))

        # Add exclusive count lines only for multiple tokens
        if len(tokens) > 1:
            fig.add_trace(go.Scatter(
                x=evolution_data['layers'],
                y=evolution_data['exclusive_counts'][token],
                name=f"'{token}' (Exclusive)",
                mode='lines+markers',
                line=dict(color=colors['exclusive'][i], width=2, dash='dot'),
                marker=dict(size=8)
            ))

    # Add combined counts if multiple tokens
    if len(tokens) > 1:
        fig.add_trace(go.Scatter(
            x=evolution_data['layers'],
            y=evolution_data['combined_counts'],
            name='Co-occurring',
            mode='lines+markers',
            line=dict(color=colors['combined'], width=2),
            marker=dict(size=8)
        ))

    # Update layout
    fig.update_layout(
        title=dict(
            text='Token Evolution Across Layers',
            font=dict(size=20)
        ),
        xaxis_title=dict(
            text='Layer',
            font=dict(size=14)
        ),
        yaxis_title=dict(
            text='Number of Clusters',
            font=dict(size=14)
        ),
        hovermode='x unified',
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        )
    )

    # Add gridlines
    fig.update_xaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot')
    fig.update_yaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot')

    st.plotly_chart(fig, use_container_width=True)

def find_clusters_for_token(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, search_token: str) -> set:
    """Find cluster IDs containing the exact token"""
    matching_clusters = set()
    try:
        cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
        with open(cluster_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped_line = line.strip()
                pipe_count = stripped_line.count('|')

                # Get token and cluster ID based on pipe count
                if pipe_count == 13:
                    token = '|'
                    parts = stripped_line.split('|||')
                    cluster_id = parts[4].strip()
                elif pipe_count == 14:
                    token = '||'
                    parts = stripped_line.split('|||')
                    cluster_id = parts[4].strip()
                else:
                    parts = stripped_line.split('|||')
                    if len(parts) != 5:
                        continue
                    token = parts[0].strip()
                    cluster_id = parts[4].strip()

                # Use exact token matching
                if token == search_token:
                    matching_clusters.add(cluster_id)
    except Exception as e:
        st.error(f"Error reading cluster file: {e}")
        return set()

    return matching_clusters

def handle_token_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Handle token search functionality"""
    st.write("### Token Search")

    # Initialize session state for search results if needed
    if 'search_results_state' not in st.session_state:
        st.session_state.search_results_state = {
            'matching_tokens': [],
            'matching_tokens2': [],
            'last_search': None,
            'last_search2': None,
            'search_mode': 'single'
        }
    elif 'search_mode' not in st.session_state.search_results_state:
        st.session_state.search_results_state['search_mode'] = 'single'

    # Radio button for search mode
    search_mode = st.radio(
        "Search Mode",
        ["Single Token", "Token Pair"],
        key="search_mode_radio",
        index=0 if st.session_state.search_results_state['search_mode'] == 'single' else 1
    )

    # Update search mode in session state
    st.session_state.search_results_state['search_mode'] = 'single' if search_mode == "Single Token" else 'pair'

    if search_mode == "Single Token":
        # Single token search interface
        search_token = st.text_input("Search for token:")
        if search_token:
            # Find matching clusters
            clusters = find_clusters_for_token(
                model_name,
                language,
                cluster_type,
                layer,
                cluster_file,
                search_token
            )
            if clusters:
                selected_token = search_token
                if selected_token:
                    # Update state
                    st.session_state.search_results_state.update({
                        'matching_tokens': [selected_token],
                        'last_search': selected_token
                    })
                    # Display cluster details
                    display_cluster_details(
                        model_name,
                        language,
                        cluster_type,
                        selected_token,
                        cluster_file
                    )
            else:
                st.warning(f"No clusters found containing token: '{search_token}'")

    else:  # Token Pair search
        col1, col2 = st.columns(2)
        with col1:
            search_token1 = st.text_input("Search for first token:")
        with col2:
            search_token2 = st.text_input("Search for second token:")

        if search_token1 and search_token2:
            # Find matching clusters for both tokens
            clusters1 = find_clusters_for_token(
                model_name,
                language,
                cluster_type,
                layer,
                cluster_file,
                search_token1
            )
            clusters2 = find_clusters_for_token(
                model_name,
                language,
                cluster_type,
                layer,
                cluster_file,
                search_token2
            )

            # Find intersection of clusters
            matching_clusters = clusters1 & clusters2

            if matching_clusters:
                # Update state
                st.session_state.search_results_state.update({
                    'matching_tokens': [search_token1],
                    'matching_tokens2': [search_token2],
                    'last_search': search_token1,
                    'last_search2': search_token2
                })
                # Display cluster details
                display_cluster_details(
                    model_name,
                    language,
                    cluster_type,
                    search_token1,
                    cluster_file,
                    second_token=search_token2
                )
            else:
                st.warning(f"No clusters found containing both tokens: '{search_token1}' and '{search_token2}'")

def analyze_token_evolution(model_name: str, language: str, cluster_type: str, layer: int, tokens: List[str], cluster_file: str) -> dict:
    """Analyze token evolution across all available layers"""
    # Get all available layers by checking directories
    language_dir = os.path.join(model_name, language)
    available_layers = []
    for d in os.listdir(language_dir):
        if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)):
            try:
                layer_num = int(d.replace('layer', ''))
                available_layers.append(layer_num)
            except ValueError:
                continue
    available_layers.sort()  # Sort layers numerically

    evolution_data = {
        'layers': [],  # filled below with only those layers whose cluster file exists, so x/y lengths stay aligned
        'individual_counts': {token: [] for token in tokens},
        'exclusive_counts': {token: [] for token in tokens},  # Track exclusive counts
        'combined_counts': [] if len(tokens) > 1 else None
    }

    # Extract cluster size from the filename (e.g., "clusters-agg-500.txt" -> "500")
    cluster_size = cluster_file.split('-')[-1].replace('.txt', '')
    # Handle shortened form for agglomerative clustering
    cluster_type_short = "agg" if cluster_type == "agglomerative" else cluster_type

    for current_layer in available_layers:
        # Get clusters for each token
        token_clusters = {}
        cluster_file_path = os.path.join(
            model_name,
            language,
            f"layer{current_layer}",
            cluster_type,
            f"clusters-{cluster_type_short}-{cluster_size}.txt"
        )

        # Skip layer if cluster file doesn't exist
        if not os.path.exists(cluster_file_path):
            continue
        evolution_data['layers'].append(current_layer)

        for token in tokens:
            clusters = find_clusters_for_token(
                model_name,
                language,
                cluster_type,
                current_layer,
                f"clusters-{cluster_type_short}-{cluster_size}.txt",
                token
            )
            token_clusters[token] = set(clusters)
            evolution_data['individual_counts'][token].append(len(clusters))

        # Calculate exclusive and co-occurring clusters
        if len(tokens) > 1:
            # Calculate co-occurrences
            cooccurring_clusters = set.intersection(*[token_clusters[token] for token in tokens])
            evolution_data['combined_counts'].append(len(cooccurring_clusters))

            # Calculate exclusive counts for each token
            for token in tokens:
                other_tokens = set(tokens) - {token}
                other_clusters = set.union(*[token_clusters[t] for t in other_tokens]) if other_tokens else set()
                exclusive_clusters = token_clusters[token] - other_clusters
                evolution_data['exclusive_counts'][token].append(len(exclusive_clusters))
        else:
            # For single token, exclusive count is the same as individual count
            evolution_data['exclusive_counts'][tokens[0]] = evolution_data['individual_counts'][tokens[0]]

    return evolution_data

def find_clusters_with_multiple_tokens(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, tokens: List[str]) -> dict:
    """Find clusters containing multiple specified tokens"""
    clusters = defaultdict(lambda: {'matching_tokens': {token: set() for token in tokens}})
    try:
        cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
        with open(cluster_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('|||')
                if len(parts) == 5:
                    token = parts[0].strip()
                    cluster_id = parts[4].strip()
                    for search_token in tokens:
                        if token == search_token:
                            # Add all tokens from this cluster
                            with open(cluster_file_path, 'r', encoding='utf-8') as f2:
                                for line2 in f2:
                                    parts2 = line2.strip().split('|||')
                                    if len(parts2) == 5 and parts2[4].strip() == cluster_id:
                                        clusters[cluster_id]['matching_tokens'][search_token].add(parts2[0].strip())
    except Exception as e:
        st.error(f"Error reading cluster file: {e}")
        return {}

    # Filter to only keep clusters with all tokens
    return {k: v for k, v in clusters.items() if all(v['matching_tokens'][token] for token in tokens)}

def handle_semantic_tag_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Handle semantic tag search functionality"""
    st.write("### Semantic Tag Search")
    st.info("This feature will be implemented soon.")

def display_token_pair_analysis(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Display analysis for predefined token pairs"""
    st.write("### Token Pair Analysis")

    # Get predefined token pairs
    token_pairs = get_predefined_token_pairs()

    # Create tabs for each category
    tabs = st.tabs(list(token_pairs.keys()))
    for tab, (category, data) in zip(tabs, token_pairs.items()):
        with tab:
            st.write(f"### {category}")
            st.write(data["description"])

            # Instead of using expanders, use a selectbox to choose the token pair
            pairs = data["pairs"]
            pair_labels = [f"{t1} vs {t2}" for t1, t2 in pairs]
            selected_pair_idx = st.selectbox(
                "Select token pair",
                range(len(pairs)),
                format_func=lambda i: pair_labels[i],
                key=f"pair_select_{category}"
            )

            # Get the selected pair
            token1, token2 = pairs[selected_pair_idx]

            # Update state for token pair search
            st.session_state.search_results_state = {
                'matching_tokens': [token1],
                'matching_tokens2': [token2],
                'last_search': token1,
                'last_search2': token2,
                'search_mode': 'pair'
            }

            # Display results
            st.write("### Search Results")

            # Create tabs for different views
            tab1, tab2 = st.tabs(["Evolution Analysis", "Co-occurring Clusters"])
            with tab1:
                evolution_data = analyze_token_evolution(
                    model_name,
                    language,
                    cluster_type,
                    layer,
                    [token1, token2],
                    cluster_file
                )
                if evolution_data:
                    display_token_evolution(evolution_data, [token1, token2])
            with tab2:
                display_cluster_details(
                    model_name,
                    language,
                    cluster_type,
                    token1,
                    cluster_file,
                    second_token=token2
                )

def get_predefined_token_pairs():
    """Return predefined token pairs organized by categories"""
    return {
        "Control Flow": {
            "description": "Different control flow constructs",
            "pairs": [
                ("for", "while"),
                ("if", "switch"),
                ("break", "continue"),
                ("try", "catch")
            ]
        },
        "Access Modifiers": {
            "description": "Access and modifier keywords",
            "pairs": [
                ("public", "private"),
                ("static", "final"),
                ("abstract", "interface")
            ]
        },
        "Variable/Type": {
            "description": "Variable and type-related tokens",
            "pairs": [
                ("int", "Integer"),
                ("null", "Optional"),
                ("var", "String")  # Example of var vs explicit type
            ]
        },
        "Collections": {
            "description": "Collection-related tokens",
            "pairs": [
                ("List", "Array"),
                ("ArrayList", "LinkedList"),
                ("HashMap", "TreeMap"),
                ("Set", "List")
            ]
        },
        "Threading": {
            "description": "Threading and concurrency tokens",
            "pairs": [
                ("synchronized", "volatile"),
                ("Runnable", "Callable"),
                ("wait", "sleep")
            ]
        },
        "Object-Oriented": {
            "description": "Object-oriented programming tokens",
            "pairs": [
                ("extends", "implements"),
                ("this", "super"),
                ("new", "clone")
            ]
        }
    }

def create_wordcloud(tokens, token1=None, token2=None):
    """Create and return a word cloud from tokens with frequencies"""
    if not tokens:
        return None

    # Create frequency dict by counting occurrences
    freq_dict = {}

    # If tokens is a list of dictionaries (from cluster data)
    if isinstance(tokens, list) and tokens and isinstance(tokens[0], dict):
        # Count token occurrences in the cluster
        for token_info in tokens:
            token = token_info["token"]
            freq_dict[token] = freq_dict.get(token, 0) + 1
    else:
        # If tokens is a set/list of strings, convert to frequency dict
        tokens_list = list(tokens) if isinstance(tokens, set) else tokens
        for token in tokens_list:
            freq_dict[token] = freq_dict.get(token, 0) + 1

    # Normalize frequencies with a base size
    max_freq = max(freq_dict.values())
    base_size = 30  # Base size for all tokens
    normalized_freq = {token: base_size + ((count / max_freq) * 70) for token, count in freq_dict.items()}

    # Boost frequency of searched tokens if provided
    if token1:
        normalized_freq[token1] = normalized_freq.get(token1, 0) + 5
    if token2:
        normalized_freq[token2] = normalized_freq.get(token2, 0) + 5

    # Custom colormap with dark shades of brown, green, and blue
    wc = WordCloud(
        width=800, height=400,
        background_color='white',
        max_words=100,
        prefer_horizontal=1.0,  # Make all words horizontal
        colormap='Dark2'  # Dark colormap with browns, greens, blues
    ).generate_from_frequencies(normalized_freq)

    return wc

def display_cluster_details(model_name: str, language: str, cluster_type: str, token: str, cluster_file: str, second_token: str = None):
    """Display detailed cluster information organized by layers"""
    # Get all available layers
    language_dir = os.path.join(model_name, language)
    available_layers = []
    for d in os.listdir(language_dir):
        if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)):
            try:
                layer_num = int(d.replace('layer', ''))
                available_layers.append(layer_num)
            except ValueError:
                continue
    available_layers.sort()

    # Create tabs for each layer
    layer_tabs = st.tabs([f"Layer {layer}" for layer in available_layers])

    # Handle shortened form for agglomerative clustering
    cluster_type_short = "agg" if cluster_type == "agglomerative" else cluster_type
    cluster_size = cluster_file.split('-')[-1].replace('.txt', '')

    for layer, tab in zip(available_layers, layer_tabs):
        with tab:
            # Find clusters containing the token(s)
            matching_clusters = find_clusters_for_token(
                model_name,
                language,
                cluster_type,
                layer,
                f"clusters-{cluster_type_short}-{cluster_size}.txt",
                token
            )
            if second_token:
                matching_clusters2 = find_clusters_for_token(
                    model_name,
                    language,
                    cluster_type,
                    layer,
                    f"clusters-{cluster_type_short}-{cluster_size}.txt",
                    second_token
                )
                # Find intersection of clusters containing both tokens
                matching_clusters &= matching_clusters2

            if matching_clusters:
                # Sort cluster IDs numerically
                cluster_ids = sorted(matching_clusters, key=lambda x: int(x))

                # Create dropdown for cluster selection
                selected_cluster = st.selectbox(
                    f"Select cluster from Layer {layer}",
                    cluster_ids,
                    format_func=lambda x: f"Cluster {x}",
                    key=f"cluster_select_{layer}_{token}_{second_token if second_token else ''}"
                )
                if selected_cluster:
                    # Load all cluster data for the selected cluster
                    cluster_data = load_cluster_sentences(
                        model_name,
                        language,
                        cluster_type,
                        layer,
                        f"clusters-{cluster_type_short}-{cluster_size}.txt",
                        specific_cluster=selected_cluster
                    )
                    if selected_cluster in cluster_data:
                        # Create two columns for the main content
                        col_main, col_chat = st.columns([2, 1])

                        with col_main:
                            shown_sentences = set()

                            # Get all unique tokens in this cluster
                            cluster_tokens = cluster_data[selected_cluster][0]["all_cluster_tokens"] if cluster_data[selected_cluster] else set()

                            # Display word cloud for this cluster
                            st.write("### Word Cloud")
                            wc = create_wordcloud(cluster_data[selected_cluster], token1=token, token2=second_token)
                            if wc:
                                # Create a centered column for the word cloud
                                col1, col2, col3 = st.columns([1, 3, 1])
                                with col2:
                                    plt.clf()
                                    fig, ax = plt.subplots(figsize=(10, 6))
                                    ax.axis('off')
                                    ax.imshow(wc, interpolation='bilinear')
                                    st.pyplot(fig)
                                    plt.close(fig)

                            # First show sentences containing searched tokens
                            if second_token:
                                st.write(f"#### Context for searched tokens: '{token}' and '{second_token}'")
                            else:
                                st.write(f"#### Context for searched token: '{token}'")

                            html_output = []
                            for sent_info in cluster_data[selected_cluster]:
                                if sent_info["token"] in [token, second_token] and sent_info["sentence"] not in shown_sentences:
                                    html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
                                    shown_sentences.add(sent_info["sentence"])
                                    html_output.append("<hr>")
                            if html_output:
                                st.markdown("\n".join(html_output), unsafe_allow_html=True)

                            # Then show all other sentences in the cluster
                            st.write("#### All sentences in cluster")
                            html_output = []
                            for sent_info in cluster_data[selected_cluster]:
                                if sent_info["sentence"] not in shown_sentences:
                                    html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
                                    shown_sentences.add(sent_info["sentence"])
                                    html_output.append("<hr>")
                            if not html_output:
                                st.info("No additional sentences in this cluster.")
                            else:
                                st.markdown("\n".join(html_output), unsafe_allow_html=True)

                        with col_chat:
                            # Add chat interface with cluster ID
                            chat_with_cluster(cluster_data[selected_cluster], cluster_type, selected_cluster)
                    else:
                        st.info(f"No sentences found for cluster {selected_cluster}")
            else:
                if second_token:
                    st.info(f"No clusters containing both '{token}' and '{second_token}' found in Layer {layer}")
                else:
                    st.info(f"No clusters containing '{token}' found in Layer {layer}")

def display_input_file(model_name: str, language: str):
    """Display the contents of the input.in file"""
    st.write("### Input File Contents")
    input_file = os.path.join(model_name, language, "input.in")
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Add line numbers and display in a scrollable container
        # (rstrip avoids doubled blank lines, since readlines keeps the trailing newline)
        numbered_lines = [f"{i+1:4d} | {line.rstrip()}" for i, line in enumerate(lines)]
        st.code('\n'.join(numbered_lines), language='text')

        # Display some statistics
        st.write(f"Total lines: {len(lines)}")
    except Exception as e:
        st.error(f"Error reading input file: {e}")

def setup_gemini():
    """Setup Gemini model with API key and temperature setting"""
    load_dotenv()  # Load environment variables from .env file
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        st.error("Please set GEMINI_API_KEY in your .env file")
        return False
    genai.configure(api_key=api_key)

    # Create generation config with temperature 0.4
    generation_config = {
        "temperature": 0.4,
        "top_p": 1,
        "top_k": 1,
        "max_output_tokens": 2048,
    }
    return generation_config
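
# A minimal .env file placed alongside this script is assumed to supply the key read
# above; the value shown is a placeholder, not a real credential:
#   GEMINI_API_KEY=<your-api-key>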

def get_cluster_context(cluster_sentences):
    """Format cluster data into a clear context for Gemini, focusing on searched tokens only"""
    # Get unique searched tokens and their frequencies
    token_counts = {}
    for sent_info in cluster_sentences:
        token = sent_info["token"]
        token_counts[token] = token_counts.get(token, 0) + 1

    # Format the context
    context = "Here is the data for this cluster:\n\n"

    # Add token frequency information
    context += "Tokens being analyzed:\n"
    for token, count in sorted(token_counts.items(), key=lambda x: (-x[1], x[0])):
        context += f"- '{token}' (appears {count} times)\n"

    # Add sentence examples with token highlighting
    context += "\nExample sentences (analyzed tokens marked with *):\n"
    unique_sentences = {}
    for sent_info in cluster_sentences:
        sentence = sent_info["sentence"]
        if sentence not in unique_sentences:
            tokens = sentence.split()
            marked_tokens = tokens.copy()
            marked_tokens[sent_info["token_idx"]] = f"*{tokens[sent_info['token_idx']]}*"
            unique_sentences[sentence] = " ".join(marked_tokens)
    for marked_sentence in unique_sentences.values():
        context += f"- {marked_sentence}\n"

    return context

def chat_with_cluster(cluster_sentences, clustering_method, cluster_id=None):
    """Create a chat interface for discussing the cluster with Gemini"""
    # Initialize Gemini with temperature setting
    generation_config = setup_gemini()
    if not generation_config:
        return

    model = genai.GenerativeModel('gemini-2.0-flash', generation_config=generation_config)

    # Create a unique key for this specific cluster's chat
    cluster_tokens = {sent_info["token"] for sent_info in cluster_sentences}
    # Include cluster_id in the key if provided
    cluster_key = f"cluster_chat_{cluster_id}_{'-'.join(sorted(cluster_tokens))}" if cluster_id else f"cluster_chat_{'-'.join(sorted(cluster_tokens))}"
    history_key = f"{cluster_key}_history"

    # Reset chat and history on each page load
    st.session_state[cluster_key] = model.start_chat(history=[])
    st.session_state[history_key] = []

    # Get cluster context and send initial message
    context = get_cluster_context(cluster_sentences)
    initial_message = f"""You are a helpful assistant analyzing Java code clusters. You will help users understand
    patterns and relationships in the provided cluster data. Here is the context for this cluster:
    {context}
    Please analyze this data and help users understand what each token is doing in the cluster. Be concise and to the point.
    """

    # Send initial message if history is empty
    if not st.session_state[history_key]:
        response = st.session_state[cluster_key].send_message(initial_message)
        st.session_state[history_key].append(("assistant", response.text))

    # Display chat interface
    st.write("### Chat with Gemini about this Cluster")
    st.write("Ask questions about patterns, relationships, or insights in this cluster.")

    # Display chat history
    for role, message in st.session_state[history_key]:
        with st.chat_message(role):
            st.write(message)

    # Chat input with cluster-specific key
    user_input = st.text_input("Your question:", key=f"input_{cluster_key}")
    if user_input:
        try:
            # Display user message
            with st.chat_message("user"):
                st.write(user_input)
            # Store user message in history
            st.session_state[history_key].append(("user", user_input))

            # Get and display Gemini's response
            with st.chat_message("assistant"):
                response = st.session_state[cluster_key].send_message(user_input)
                st.write(response.text)
            # Store assistant's response in history
            st.session_state[history_key].append(("assistant", response.text))
        except Exception as e:
            st.error(f"Error communicating with Gemini: {e}")

def handle_line_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Handle line number search functionality"""
    st.write("### Search by Line Number")

    # Create two columns for search inputs
    col1, col2 = st.columns([1, 1])
    with col1:
        # Input for line number
        line_number = st.number_input("Enter line number:", min_value=1, step=1)
    with col2:
        # Optional token search
        token_filter = st.text_input("Filter by token (optional):", "")

    if st.button("Search"):
        # Find clusters containing the line number
        clusters = find_clusters_for_line(
            model_name,
            language,
            cluster_type,
            layer,
            cluster_file,
            line_number
        )
        if clusters:
            if token_filter:
                # Filter clusters to only those containing both the line and the token
                filtered_clusters = set()
                for cluster_id in clusters:
                    cluster_data = load_cluster_sentences(
                        model_name,
                        language,
                        cluster_type,
                        layer,
                        cluster_file,
                        specific_cluster=cluster_id
                    )
                    if any(sent_info["token"] == token_filter for sent_info in cluster_data[cluster_id]):
                        filtered_clusters.add(cluster_id)
                clusters = filtered_clusters
                if not clusters:
                    st.warning(f"No clusters found containing both line {line_number} and token '{token_filter}'")
                    return
                st.success(f"Found {len(clusters)} clusters containing line {line_number} and token '{token_filter}'")
            else:
                st.success(f"Found {len(clusters)} clusters containing line {line_number}")

            # Create tabs for each cluster
            cluster_tabs = st.tabs([f"Cluster {cluster_id}" for cluster_id in sorted(clusters)])
            for tab, cluster_id in zip(cluster_tabs, sorted(clusters)):
                with tab:
                    # Load cluster data
                    cluster_data = load_cluster_sentences(
                        model_name,
                        language,
                        cluster_type,
                        layer,
                        cluster_file,
                        specific_cluster=cluster_id
                    )
                    if cluster_id in cluster_data:
                        # Create two columns for the main content
                        col_main, col_chat = st.columns([2, 1])

                        with col_main:
                            # Display word cloud
                            st.write("### Word Cloud")
                            wc = create_wordcloud(cluster_data[cluster_id], token_filter if token_filter else None)
                            if wc:
                                col1, col2, col3 = st.columns([1, 3, 1])
                                with col2:
                                    plt.clf()
                                    fig, ax = plt.subplots(figsize=(10, 6))
                                    ax.axis('off')
                                    ax.imshow(wc, interpolation='bilinear')
                                    st.pyplot(fig)
                                    plt.close(fig)

                            # Display sentences
                            st.write("### Sentences in Cluster")
                            shown_sentences = set()

                            # First show the searched line
                            st.write("#### Searched Line")
                            html_output = []
                            for sent_info in cluster_data[cluster_id]:
                                if sent_info["sentence_id"] == line_number - 1:  # Adjust for 0-based indexing
                                    html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
                                    shown_sentences.add(sent_info["sentence"])
                                    break
                            if html_output:
                                st.markdown("\n".join(html_output), unsafe_allow_html=True)

                            # Then show sentences with the filtered token (if specified)
                            if token_filter:
                                st.write(f"#### Sentences containing '{token_filter}'")
                                html_output = []
                                for sent_info in cluster_data[cluster_id]:
                                    if (sent_info["token"] == token_filter and
                                            sent_info["sentence"] not in shown_sentences):
                                        html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
                                        shown_sentences.add(sent_info["sentence"])
                                if html_output:
                                    st.markdown("\n".join(html_output), unsafe_allow_html=True)

                            # Finally show other sentences
                            st.write("#### Other sentences in cluster")
                            html_output = []
                            for sent_info in cluster_data[cluster_id]:
                                if sent_info["sentence"] not in shown_sentences:
                                    html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
                                    shown_sentences.add(sent_info["sentence"])
                            if html_output:
                                st.markdown("\n".join(html_output), unsafe_allow_html=True)

                        with col_chat:
                            # Add chat interface with cluster ID
                            chat_with_cluster(cluster_data[cluster_id], cluster_type, cluster_id)
        else:
            st.warning(f"No clusters found containing line {line_number}")

def find_clusters_for_line(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, line_number: int) -> set:
    """Find cluster IDs containing the specified line number"""
    matching_clusters = set()
    try:
        cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
        with open(cluster_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped_line = line.strip()
                parts = stripped_line.split('|||')
                if len(parts) >= 3:
                    try:
                        sentence_id = int(parts[2])
                        if sentence_id == line_number - 1:  # Adjust for 0-based indexing
                            cluster_id = parts[4].strip()
                            matching_clusters.add(cluster_id)
                    except (ValueError, IndexError):
                        continue
    except Exception as e:
        st.error(f"Error reading cluster file: {e}")
        return set()

    return matching_clusters

# Line token distribution analysis
def handle_line_token_distribution(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Display cluster distribution for each token in a specific line"""
    st.write("### Line Token Distribution")

    # Input for line number
    line_number = st.number_input("Enter line number:", min_value=1, step=1)

    if st.button("Analyze"):
        # Load the input file to get the line content
        input_file = os.path.join(model_name, language, "input.in")
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            if line_number <= len(lines):
                line_content = lines[line_number - 1].strip()
                tokens = line_content.split()

                # Create a dictionary to store cluster assignments for each token
                token_clusters = {}

                # Find clusters for each token in the line
                cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
                with open(cluster_file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        parts = line.strip().split('|||')
                        if len(parts) >= 5:
                            try:
                                sentence_id = int(parts[2])
                                token_idx = int(parts[3])
                                if sentence_id == line_number - 1:  # Adjust for 0-based indexing
                                    token = parts[0].strip()
                                    cluster_id = parts[4].strip()
                                    token_clusters[(token, token_idx)] = cluster_id
                            except (ValueError, IndexError):
                                continue

                # Find clusters that contain different tokens
                cluster_to_unique_tokens = {}
                for (token, idx), cluster in token_clusters.items():
                    if cluster not in cluster_to_unique_tokens:
                        cluster_to_unique_tokens[cluster] = set()
                    cluster_to_unique_tokens[cluster].add(token)

                # Filter for clusters with different tokens (more than one unique token)
                clusters_with_different_tokens = {
                    cluster: tokens for cluster, tokens in cluster_to_unique_tokens.items()
                    if len(tokens) > 1
                }
                has_mixed_clusters = len(clusters_with_different_tokens) > 0

                # Display results
                col1, col2 = st.columns([3, 1])
                with col1:
                    st.write("#### Line Content:")
                    st.code(line_content, language="java")
                with col2:
                    st.write("#### Uniqueness Check:")
                    st.checkbox("All different tokens have unique clusters", value=not has_mixed_clusters, disabled=True)

                # Show clusters with different tokens
                if has_mixed_clusters:
                    st.write("##### Clusters with Different Tokens:")
                    for cluster, unique_tokens in clusters_with_different_tokens.items():
                        # Get all indices for each token in this cluster
                        token_positions = {}
                        for token in unique_tokens:
                            positions = [idx for (t, idx), c in token_clusters.items()
                                         if t == token and c == cluster]
                            token_positions[token] = positions
                        # Format the display string
                        tokens_str = ", ".join(
                            f"'{token}' (idx: {', '.join(map(str, positions))})"
                            for token, positions in token_positions.items()
                        )
                        st.markdown(f"- Cluster **{cluster}**: {tokens_str}")

                st.write("#### Token Distribution:")
                # Create a table showing token distributions
                data = []
                for i, token in enumerate(tokens):
                    cluster = token_clusters.get((token, i), "N/A")
                    data.append({
                        "Token Index": i,
                        "Token": token,
                        "Cluster": cluster
                    })
                df = pd.DataFrame(data)
                st.table(df)

                # Create a visualization of the distribution
                st.write("#### Distribution Visualization:")
                # Create a Sankey diagram
                source = []
                target = []
                value = []
                # Node labels: token nodes first (indices 0..len(tokens)-1),
                # followed by one node per distinct cluster (including "N/A").
                label = list(tokens)

                # Add tokens as source nodes and link each one to its cluster node
                for i, token in enumerate(tokens):
                    source.append(i)
                    cluster = token_clusters.get((token, i), "N/A")
                    if cluster not in label[len(tokens):]:
                        label.append(cluster)
                    target_idx = len(tokens) + label[len(tokens):].index(cluster)
                    target.append(target_idx)
                    value.append(1)
                # Create and display the Sankey diagram
                fig = go.Figure(data=[go.Sankey(
                    node=dict(
                        pad=15,
                        thickness=20,
                        line=dict(color="black", width=0.5),
                        label=label,
                        color="blue"
                    ),
                    link=dict(
                        source=source,
                        target=target,
                        value=value
                    )
                )])
                fig.update_layout(title_text="Token to Cluster Distribution", font_size=10)
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.error(f"Line number {line_number} is out of range. File has {len(lines)} lines.")
        except Exception as e:
            st.error(f"Error analyzing line: {e}")

if __name__ == "__main__":
    main()
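
# Launch note (assumes this file is the Space's Streamlit entry point; the file name
# below is a placeholder):
#   streamlit run app.py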