# Coconet_Visual / app.py
# Vedant Pungliya
# Files
# af19b86 unverified
import streamlit as st
import os
import json
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import plotly.graph_objects as go
from typing import List
import html
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
def get_available_cluster_files(layer_dir: str) -> List[str]:
    """Return the cluster files in *layer_dir*, sorted by name.

    A file qualifies when its name has the shape ``clusters-<algorithm>-<k>.txt``
    with a purely numeric ``<k>`` (e.g. ``clusters-agg-10.txt`` or
    ``clusters-kmeans-500.txt``).  Every qualifying name is also listed in the
    Streamlit sidebar as a side effect.
    """
    def _is_cluster_file(name: str) -> bool:
        # Cheap prefix/suffix test first, then the exact three-part shape.
        if not (name.startswith('clusters-') and name.endswith('.txt')):
            return False
        pieces = name.replace('.txt', '').split('-')
        return len(pieces) == 3 and pieces[2].isdigit()

    matches = sorted(name for name in os.listdir(layer_dir) if _is_cluster_file(name))
    for name in matches:
        st.sidebar.write(f"- {name}")
    return matches
def parse_cluster_filename(filename: str) -> tuple:
    """Split a cluster file name into ``(algorithm, size)``.

    ``'clusters-kmeans-500.txt'`` -> ``('kmeans', 500)``.
    """
    stem = filename.replace('.txt', '')
    pieces = stem.split('-')
    algorithm = pieces[1]
    size = int(pieces[2])
    return algorithm, size
def _split_cluster_line(line: str):
    """Split one cluster-file line into ``(token, parts)``; return None if malformed.

    Lines have the form ``token|||occurrence|||sentence_id|||token_idx|||cluster_id``.
    Splitting from the right keeps tokens that themselves contain pipes
    (``|``, ``||``, ``|=``, ...) intact, because the trailing fields are numeric.
    """
    parts = line.strip().rsplit('|||', 4)
    if len(parts) != 5:
        return None
    return parts[0].strip(), parts


def load_cluster_sentences(model_dir: str, language: str, cluster_type: str, layer: int, cluster_file: str, tokens: List[str] = None, specific_cluster: str = None):
    """Load sentences and their cluster assignments for a given model and layer.

    Args:
        model_dir: Directory containing one sub-directory per language.
        language: Language sub-directory; must contain ``input.in`` (one
            whitespace-tokenized sentence per line) and ``layer<N>/`` dirs.
        cluster_type: Clustering sub-directory inside ``layer<layer>``.
        layer: Layer number.
        cluster_file: Cluster assignment file inside that directory.
        tokens: If given, only occurrences of these tokens are returned.
        specific_cluster: If given, only this cluster ID is loaded.

    Returns:
        ``defaultdict(list)`` mapping cluster ID -> list of dicts with keys
        ``sentence``, ``token``, ``token_idx``, ``occurrence``, ``sentence_id``
        and ``all_cluster_tokens`` (the set of every token assigned to that
        cluster; the same set object is shared by all entries of a cluster).
        Occurrences whose sentence/index is out of range, or whose position
        does not actually hold the token, are dropped; malformed lines are
        skipped.
    """
    sentence_file = os.path.join(model_dir, language, "input.in")
    cluster_file_path = os.path.join(model_dir, language, f"layer{layer}", cluster_type, cluster_file)

    with open(sentence_file, 'r', encoding='utf-8') as f:
        all_sentences = [line.strip() for line in f]

    wanted_tokens = None if tokens is None else set(tokens)
    cluster_sentences = defaultdict(list)
    cluster_tokens = defaultdict(set)  # cluster_id -> every token in that cluster

    # A single pass suffices: each entry stores a *reference* to its cluster's
    # token set, so tokens added later are visible to earlier entries too.
    with open(cluster_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            split = _split_cluster_line(line)
            if split is None:
                continue
            token, parts = split
            cluster_id = parts[4].strip()
            if specific_cluster is not None and cluster_id != specific_cluster:
                continue

            # Every well-formed line contributes to the cluster vocabulary,
            # even when its numeric fields cannot be parsed below.
            cluster_tokens[cluster_id].add(token)

            if wanted_tokens is not None and token not in wanted_tokens:
                continue
            try:
                occurrence = int(parts[1])
                sentence_id = int(parts[2])
                token_idx = int(parts[3])
            except ValueError:
                continue

            if not 0 <= sentence_id < len(all_sentences):
                continue
            sentence_tokens = all_sentences[sentence_id].split()
            # Keep only occurrences whose recorded position really holds the token.
            if 0 <= token_idx < len(sentence_tokens) and sentence_tokens[token_idx] == token:
                cluster_sentences[cluster_id].append({
                    "sentence": all_sentences[sentence_id],
                    "token": token,
                    "token_idx": token_idx,
                    "occurrence": occurrence,
                    "sentence_id": sentence_id,
                    "all_cluster_tokens": cluster_tokens[cluster_id]
                })
    return cluster_sentences
def create_sentence_html(sentence, sent_info, cluster_tokens=None):
    """Return an HTML snippet showing *sentence* with its target token highlighted.

    Args:
        sentence: The whitespace-tokenized sentence text.
        sent_info: Dict with ``token``, ``token_idx`` and ``sentence_id`` keys
            (an entry as produced by ``load_cluster_sentences``).
        cluster_tokens: Unused; kept for backward compatibility.

    Returns:
        A ``<div>`` string intended for ``st.markdown(..., unsafe_allow_html=True)``.
        Sentence and token text are HTML-escaped, so source code such as
        ``List<String>`` is displayed literally instead of being parsed as tags.
    """
    target_token = html.escape(sent_info["token"])
    target_idx = sent_info["token_idx"]
    line_number = sent_info["sentence_id"]

    pieces = [
        '<div style="font-family: monospace; padding: 10px; margin: 5px 0; background-color: #f5f5f5; border-radius: 5px;">',
        '<div style="margin-bottom: 5px;">',
    ]
    # Escape token by token; escaping never introduces whitespace, so indices match the raw split.
    for i, token in enumerate(sentence.split()):
        escaped = html.escape(token)
        if i == target_idx:
            pieces.append(f'<span style="color: red; font-weight: bold;">{escaped}</span> ')
        else:
            pieces.append(f'{escaped} ')
    pieces.append('</div>')
    pieces.append(f'<div style="color: #666; font-size: 0.9em;">Token: <code>{target_token}</code> (Line: {line_number}, Index: {target_idx})</div>')
    pieces.append('</div>')
    return ''.join(pieces)
def display_cluster_analysis(model_name: str, language: str, cluster_type: str, selected_layer: int, cluster_file: str, clustering_method: str):
    """Display cluster analysis for selected model and layer.

    Renders a cluster picker with previous/next buttons, then, for the chosen
    cluster, a word cloud, its context sentences (clustered tokens highlighted)
    and a chat panel.

    Args:
        model_name: Data directory (joined as a relative path).
        language: Language sub-directory.
        cluster_type: Clustering sub-directory inside ``layer<N>``.
        selected_layer: Layer number.
        cluster_file: Cluster file name, e.g. ``clusters-kmeans-500.txt``.
        clustering_method: Label forwarded to ``chat_with_cluster``.
    """
    # One load is enough: it already holds every cluster together with its
    # full token set, so the selected cluster is taken from it directly.
    cluster_sentences = load_cluster_sentences(model_name, language, cluster_type, selected_layer, cluster_file)

    algorithm, size = parse_cluster_filename(cluster_file)
    st.write(f"### Analyzing {algorithm.upper()} clustering with {size} clusters")

    cluster_ids = sorted(cluster_sentences.keys(), key=int)  # numeric order
    if not cluster_ids:
        st.error("No clusters found in the data")
        return

    # The index of the shown cluster is remembered per (layer, algorithm, size)
    state_key = f"cluster_index_{selected_layer}_{algorithm}_{size}"
    if state_key not in st.session_state:
        st.session_state[state_key] = 0
    # Ensure the index is valid for the current number of clusters
    if st.session_state[state_key] >= len(cluster_ids):
        st.session_state[state_key] = 0

    col1, col2, col3, col4 = st.columns([3, 1, 1, 7])
    st.write("")

    with col1:
        select_key = f"cluster_select_{selected_layer}_{algorithm}_{size}"
        selected_cluster = st.selectbox(
            "Select cluster",
            range(len(cluster_ids)),
            index=st.session_state[state_key],
            format_func=lambda x: f"Cluster {cluster_ids[x]}",
            label_visibility="collapsed",  # hides the label but keeps accessibility
            key=select_key
        )
        # Update session state when the dropdown changes
        if selected_cluster != st.session_state[state_key]:
            st.session_state[state_key] = selected_cluster
            st.rerun()

    with col2:
        if st.button("◀", use_container_width=True):
            st.session_state[state_key] = max(0, st.session_state[state_key] - 1)
            st.rerun()

    with col3:
        if st.button("▶", use_container_width=True):
            st.session_state[state_key] = min(len(cluster_ids) - 1, st.session_state[state_key] + 1)
            st.rerun()

    st.write("")

    cluster_id = cluster_ids[st.session_state[state_key]]
    sentences_data = cluster_sentences[cluster_id]

    col_main, col_chat = st.columns([2, 1])

    with col_main:
        st.write("### Word Cloud")
        wc = create_wordcloud(sentences_data)
        if wc:
            _, cloud_col, _ = st.columns([1, 3, 1])
            with cloud_col:
                fig, ax = plt.subplots(figsize=(10, 6))
                ax.axis('off')
                ax.imshow(wc, interpolation='bilinear')
                st.pyplot(fig)
                plt.close(fig)

        st.write("### Context Sentences")

        # Group occurrences by (escaped) sentence so each sentence is shown once
        unique_sentences = {}
        for sent_info in sentences_data:
            sentence_text = html.escape(sent_info["sentence"])
            token = html.escape(sent_info["token"])
            token_idx = sent_info["token_idx"]
            if sentence_text not in unique_sentences:
                unique_sentences[sentence_text] = {
                    "tokens": [(token, token_idx)],
                    "sentence_id": sent_info["sentence_id"]
                }
            else:
                unique_sentences[sentence_text]["tokens"].append((token, token_idx))

        for sentence_text, info in unique_sentences.items():
            highlighted = {idx for _, idx in info["tokens"]}
            body = "".join(
                f"<span style='color: red; font-weight: bold;'>{word}</span> " if i in highlighted else f"{word} "
                for i, word in enumerate(sentence_text.split())
            )
            token_list = ", ".join(f"<code>{t}</code> (Index: {idx})" for t, idx in info["tokens"])
            # Lines are kept unindented: indented markdown lines would render as a code block
            html_content = "\n".join([
                "<div style='font-family: monospace; padding: 10px; margin: 5px 0; background-color: #f5f5f5; border-radius: 5px;'>",
                "<div style='margin-bottom: 5px;'>",
                body,
                "</div>",
                "<div style='color: #666; font-size: 0.9em;'>",
                f"Tokens: {token_list}",
                f"(Line: {info['sentence_id']})",
                "</div>",
                "</div>",
            ])
            st.markdown(html_content, unsafe_allow_html=True)

    with col_chat:
        chat_with_cluster(sentences_data, clustering_method, cluster_id)
def main():
    """Streamlit entry point: sidebar selections plus the chosen analysis view.

    The sidebar cascades data set -> language -> layer -> clustering type ->
    cluster file.  Choices are remembered in ``st.session_state``; changing a
    higher-level choice clears the remembered lower-level ones so their
    selectboxes fall back to the first option.
    """
    st.set_page_config(layout="wide")
    st.title("Coconet Visual Analysis")

    analysis_modes = ["Individual Clusters", "Search And Analysis", "Search By Line",
                      "Line Token Distribution", "Token Pairs", "View Input File"]

    # Initialize session state for selections if they don't exist
    session_defaults = (
        ('model_name', None),
        ('selected_language', None),
        ('selected_layer', None),
        ('selected_cluster_type', None),
        ('selected_cluster_file', None),
        ('analysis_mode', "Individual Clusters"),
    )
    for state_name, default in session_defaults:
        if state_name not in st.session_state:
            st.session_state[state_name] = default

    # Data sets are the visible sub-directories next to this script.
    # NOTE(review): the analysis helpers below receive the bare model_name and
    # join it as a relative path, so this presumably assumes the app runs with
    # this directory as the working directory — confirm.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    available_models = [d for d in os.listdir(current_dir)
                        if os.path.isdir(os.path.join(current_dir, d)) and not d.startswith('.') and d != '__pycache__']

    model_name = st.sidebar.selectbox(
        "Select Data",
        available_models,
        key="model_select",
        index=available_models.index(st.session_state.model_name) if st.session_state.model_name in available_models else 0
    )
    if model_name != st.session_state.model_name:
        st.session_state.model_name = model_name
        st.session_state.selected_language = None
        st.session_state.selected_layer = None
        st.session_state.selected_cluster_type = None
        st.session_state.selected_cluster_file = None
    if not model_name:
        st.error("No models found")
        return

    # Languages available for the selected model
    model_dir = os.path.join(current_dir, model_name)
    available_languages = [d for d in os.listdir(model_dir)
                           if os.path.isdir(os.path.join(model_dir, d))]
    selected_language = st.sidebar.selectbox(
        "Select Language",
        available_languages,
        key="language_select",
        index=available_languages.index(st.session_state.selected_language) if st.session_state.selected_language in available_languages else 0
    )
    if selected_language != st.session_state.selected_language:
        st.session_state.selected_language = selected_language
        st.session_state.selected_layer = None
        st.session_state.selected_cluster_type = None
        st.session_state.selected_cluster_file = None
    if not selected_language:
        st.error("No languages found for selected model")
        return

    # Numbered layer directories ("layer0", "layer1", ...). Directories such as
    # "layer_old" are skipped instead of crashing int(), as the other helpers do.
    language_dir = os.path.join(model_dir, selected_language)
    available_layers = []
    for d in os.listdir(language_dir):
        if not (d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d))):
            continue
        try:
            available_layers.append(int(d.replace('layer', '')))
        except ValueError:
            continue
    if not available_layers:
        st.error("No layer directories found")
        return
    available_layers.sort()

    selected_layer = st.sidebar.selectbox(
        "Select Layer",
        available_layers,
        key="layer_select",
        index=available_layers.index(st.session_state.selected_layer) if st.session_state.selected_layer in available_layers else 0
    )
    if selected_layer != st.session_state.selected_layer:
        st.session_state.selected_layer = selected_layer
        st.session_state.selected_cluster_type = None
        st.session_state.selected_cluster_file = None

    # Clustering types are the sub-directories of the layer directory
    layer_dir = os.path.join(language_dir, f"layer{selected_layer}")
    available_cluster_types = [d for d in os.listdir(layer_dir)
                               if os.path.isdir(os.path.join(layer_dir, d))]
    selected_cluster_type = st.sidebar.selectbox(
        "Select Clustering Type",
        available_cluster_types,
        key="cluster_type_select",
        index=available_cluster_types.index(st.session_state.selected_cluster_type) if st.session_state.selected_cluster_type in available_cluster_types else 0
    )
    if selected_cluster_type != st.session_state.selected_cluster_type:
        st.session_state.selected_cluster_type = selected_cluster_type
        st.session_state.selected_cluster_file = None
    if not selected_cluster_type:
        st.error("No clustering types found for selected layer")
        return

    cluster_dir = os.path.join(layer_dir, selected_cluster_type)
    available_cluster_files = get_available_cluster_files(cluster_dir)
    if not available_cluster_files:
        st.error("No cluster files found in the selected layer")
        return

    selected_cluster_file = st.sidebar.selectbox(
        "Select Clustering",
        available_cluster_files,
        key="cluster_file_select",
        format_func=lambda x: f"{parse_cluster_filename(x)[0].upper()} (k={parse_cluster_filename(x)[1]})",
        index=available_cluster_files.index(st.session_state.selected_cluster_file) if st.session_state.selected_cluster_file in available_cluster_files else 0
    )
    st.session_state.selected_cluster_file = selected_cluster_file

    # Fall back to the first mode if the remembered one is no longer offered
    stored_mode = st.session_state.analysis_mode
    analysis_mode = st.sidebar.radio(
        "Select Analysis Mode",
        analysis_modes,
        key="analysis_mode_select",
        index=analysis_modes.index(stored_mode) if stored_mode in analysis_modes else 0
    )
    st.session_state.analysis_mode = analysis_mode

    if analysis_mode == "Individual Clusters":
        display_cluster_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file, selected_cluster_type)
    elif analysis_mode == "Search And Analysis":
        handle_token_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "Search By Line":
        handle_line_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "Line Token Distribution":
        handle_line_token_distribution(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "Token Pairs":
        display_token_pair_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
    elif analysis_mode == "View Input File":
        display_input_file(model_name, selected_language)
def display_token_evolution(evolution_data: dict, tokens: List[str]):
    """Plot, per layer, how many clusters contain each token.

    Args:
        evolution_data: Result of ``analyze_token_evolution``: ``layers`` (x
            values), ``individual_counts`` / ``exclusive_counts`` (token ->
            per-layer counts) and ``combined_counts`` (per-layer co-occurrence
            counts; only used when several tokens are given).
        tokens: Tokens to plot, in legend order.  Any number is accepted;
            line colours repeat after the third token.
    """
    st.write("### Evolution Analysis for Token(s)")

    fig = go.Figure()

    colors = {
        'individual': ['#3498db', '#e74c3c', '#2ecc71'],  # Blue, Red, Green
        'exclusive': ['#9b59b6', '#f1c40f', '#1abc9c'],  # Purple, Yellow, Turquoise
        'combined': '#34495e'  # Dark Gray
    }
    several_tokens = len(tokens) > 1

    for i, token in enumerate(tokens):
        # Cycle the palettes so more than three tokens never raises IndexError
        fig.add_trace(go.Scatter(
            x=evolution_data['layers'],
            y=evolution_data['individual_counts'][token],
            name=f"'{token}' (Total)",
            mode='lines+markers',
            line=dict(color=colors['individual'][i % len(colors['individual'])], width=2),
            marker=dict(size=8)
        ))
        # Exclusive counts only differ from totals when several tokens are compared
        if several_tokens:
            fig.add_trace(go.Scatter(
                x=evolution_data['layers'],
                y=evolution_data['exclusive_counts'][token],
                name=f"'{token}' (Exclusive)",
                mode='lines+markers',
                line=dict(color=colors['exclusive'][i % len(colors['exclusive'])], width=2, dash='dot'),
                marker=dict(size=8)
            ))

    if several_tokens:
        fig.add_trace(go.Scatter(
            x=evolution_data['layers'],
            y=evolution_data['combined_counts'],
            name='Co-occurring',
            mode='lines+markers',
            line=dict(color=colors['combined'], width=2),
            marker=dict(size=8)
        ))

    fig.update_layout(
        title=dict(
            text='Token Evolution Across Layers',
            font=dict(size=20)
        ),
        xaxis_title=dict(
            text='Layer',
            font=dict(size=14)
        ),
        yaxis_title=dict(
            text='Number of Clusters',
            font=dict(size=14)
        ),
        hovermode='x unified',
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        )
    )
    fig.update_xaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot')
    fig.update_yaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot')

    st.plotly_chart(fig, use_container_width=True)
def find_clusters_for_token(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, search_token: str) -> set:
    """Return the IDs of all clusters that contain *search_token* (exact match).

    Reads ``<model_name>/<language>/layer<layer>/<cluster_type>/<cluster_file>``,
    whose lines look like ``token|||occurrence|||sentence_id|||token_idx|||cluster_id``.
    Malformed lines are skipped.  If the file cannot be read or decoded, an
    error is shown in the UI and an empty set is returned.
    """
    matching_clusters = set()
    cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
    try:
        with open(cluster_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Split from the right: only the token may contain pipes
                # ('|', '||', '|=', ...); the trailing four fields are numeric.
                parts = line.strip().rsplit('|||', 4)
                if len(parts) != 5:
                    continue
                if parts[0].strip() == search_token:
                    matching_clusters.add(parts[4].strip())
    except (OSError, UnicodeDecodeError) as e:
        st.error(f"Error reading cluster file: {e}")
        return set()
    return matching_clusters
def handle_token_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Render the "Search And Analysis" view.

    Offers a single-token or a token-pair search.  A search only proceeds when
    the token (or, in pair mode, both tokens together) occurs in at least one
    cluster of the currently selected layer; the per-layer details are then
    rendered by ``display_cluster_details``.  The last search is remembered in
    ``st.session_state.search_results_state``.
    """
    st.write("### Token Search")

    if 'search_results_state' not in st.session_state:
        st.session_state.search_results_state = {
            'matching_tokens': [],
            'matching_tokens2': [],
            'last_search': None,
            'last_search2': None,
            'search_mode': 'single'
        }
    search_state = st.session_state.search_results_state
    search_state.setdefault('search_mode', 'single')

    search_mode = st.radio(
        "Search Mode",
        ["Single Token", "Token Pair"],
        key="search_mode_radio",
        index=0 if search_state['search_mode'] == 'single' else 1
    )
    single_mode = search_mode == "Single Token"
    search_state['search_mode'] = 'single' if single_mode else 'pair'

    if single_mode:
        search_token = st.text_input("Search for token:")
        if not search_token:
            return
        clusters = find_clusters_for_token(model_name, language, cluster_type, layer, cluster_file, search_token)
        if not clusters:
            st.warning(f"No clusters found containing token: '{search_token}'")
            return
        search_state.update({
            'matching_tokens': [search_token],
            'last_search': search_token
        })
        display_cluster_details(model_name, language, cluster_type, search_token, cluster_file)
        return

    # Token-pair search: both tokens must share at least one cluster
    left_col, right_col = st.columns(2)
    with left_col:
        first_token = st.text_input("Search for first token:")
    with right_col:
        second_token = st.text_input("Search for second token:")
    if not (first_token and second_token):
        return

    first_clusters = find_clusters_for_token(model_name, language, cluster_type, layer, cluster_file, first_token)
    second_clusters = find_clusters_for_token(model_name, language, cluster_type, layer, cluster_file, second_token)
    if not (first_clusters & second_clusters):
        st.warning(f"No clusters found containing both tokens: '{first_token}' and '{second_token}'")
        return

    search_state.update({
        'matching_tokens': [first_token],
        'matching_tokens2': [second_token],
        'last_search': first_token,
        'last_search2': second_token
    })
    display_cluster_details(model_name, language, cluster_type, first_token, cluster_file, second_token=second_token)
def analyze_token_evolution(model_name: str, language: str, cluster_type: str, layer: int, tokens: List[str], cluster_file: str) -> dict:
    """Count, for every layer, how many clusters contain each token.

    The selected *cluster_file* name is looked up in
    ``<model_name>/<language>/layer<N>/<cluster_type>/`` for every numbered
    layer directory; layers lacking that file are left out of the result.

    Args:
        model_name: Data directory (joined as a relative path).
        language: Language sub-directory.
        cluster_type: Clustering sub-directory inside each layer directory.
        layer: Unused; kept for interface compatibility.
        tokens: Tokens to track (exact match).
        cluster_file: Cluster file name, e.g. ``clusters-agg-500.txt``.

    Returns:
        Dict with ``layers`` (the layers actually analysed, ascending; aligned
        with every per-layer list), ``individual_counts`` (token -> number of
        clusters containing it), ``exclusive_counts`` (token -> clusters that
        contain it but none of the other tokens; for a single token this is
        the same list as its individual counts) and ``combined_counts``
        (clusters containing all tokens, or None for fewer than two tokens).
    """
    language_dir = os.path.join(model_name, language)
    candidate_layers = []
    for d in os.listdir(language_dir):
        if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)):
            try:
                candidate_layers.append(int(d.replace('layer', '')))
            except ValueError:
                continue
    candidate_layers.sort()

    multiple = len(tokens) > 1
    evolution_data = {
        'layers': [],
        'individual_counts': {token: [] for token in tokens},
        'exclusive_counts': {token: [] for token in tokens},
        'combined_counts': [] if multiple else None
    }

    for current_layer in candidate_layers:
        # Use the selected file name as-is; rebuilding it from the directory
        # name could silently point at a different clustering.
        cluster_file_path = os.path.join(model_name, language, f"layer{current_layer}", cluster_type, cluster_file)
        if not os.path.exists(cluster_file_path):
            continue
        # Record the layer only once it has data, keeping x and y aligned.
        evolution_data['layers'].append(current_layer)

        token_clusters = {}
        for token in tokens:
            clusters = find_clusters_for_token(model_name, language, cluster_type, current_layer, cluster_file, token)
            token_clusters[token] = set(clusters)
            evolution_data['individual_counts'][token].append(len(clusters))

        if multiple:
            cooccurring_clusters = set.intersection(*[token_clusters[token] for token in tokens])
            evolution_data['combined_counts'].append(len(cooccurring_clusters))
            for token in tokens:
                other_tokens = set(tokens) - {token}
                other_clusters = set.union(*[token_clusters[t] for t in other_tokens]) if other_tokens else set()
                exclusive_clusters = token_clusters[token] - other_clusters
                evolution_data['exclusive_counts'][token].append(len(exclusive_clusters))

    if len(tokens) == 1:
        # For a single token, exclusive count is the same as individual count
        evolution_data['exclusive_counts'][tokens[0]] = evolution_data['individual_counts'][tokens[0]]

    return evolution_data
def find_clusters_with_multiple_tokens(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, tokens: List[str]) -> dict:
    """Find the clusters that contain every token in *tokens*.

    The cluster file is read once (the previous version re-read the whole file
    for every matching line).

    Returns:
        ``{cluster_id: {'matching_tokens': {search_token: all_tokens_in_cluster}}}``
        for each cluster containing all search tokens, ordered by the cluster's
        first match in the file.  Each search token maps to its own copy of the
        cluster's full token set.  Returns ``{}`` if *tokens* is empty, or if
        the file cannot be read (an error is then shown in the UI).
    """
    wanted = set(tokens)
    cluster_members = defaultdict(set)  # cluster_id -> every token in the cluster
    hits = defaultdict(set)  # cluster_id -> search tokens seen in the cluster
    cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
    try:
        with open(cluster_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Split from the right so pipe-containing tokens stay intact.
                parts = line.strip().rsplit('|||', 4)
                if len(parts) != 5:
                    continue
                token = parts[0].strip()
                cluster_id = parts[4].strip()
                cluster_members[cluster_id].add(token)
                if token in wanted:
                    hits[cluster_id].add(token)
    except (OSError, UnicodeDecodeError) as e:
        st.error(f"Error reading cluster file: {e}")
        return {}

    return {
        cluster_id: {'matching_tokens': {t: set(cluster_members[cluster_id]) for t in tokens}}
        for cluster_id, found in hits.items()
        if found >= wanted
    }
def handle_semantic_tag_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Placeholder for the semantic tag search view.

    Only shows a heading and a "coming soon" notice; all arguments are
    currently unused.
    """
    heading = "### Semantic Tag Search"
    notice = "This feature will be implemented soon."
    st.write(heading)
    st.info(notice)
def display_token_pair_analysis(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Render the "Token Pairs" view.

    Each category from ``get_predefined_token_pairs`` gets its own tab.  Inside
    a tab the user picks one pair; that pair is stored as a pair search in
    ``st.session_state.search_results_state`` and shown in two sub-tabs: the
    cross-layer evolution chart and the clusters containing both tokens.
    """
    st.write("### Token Pair Analysis")

    categories = get_predefined_token_pairs()
    category_tabs = st.tabs(list(categories.keys()))

    for category_tab, category in zip(category_tabs, categories):
        details = categories[category]
        with category_tab:
            st.write(f"### {category}")
            st.write(details["description"])

            # A selectbox (rather than one expander per pair) keeps the tab compact
            candidate_pairs = details["pairs"]
            labels = [f"{left} vs {right}" for left, right in candidate_pairs]
            pair_index = st.selectbox(
                "Select token pair",
                range(len(candidate_pairs)),
                format_func=lambda i, labels=labels: labels[i],
                key=f"pair_select_{category}"
            )
            first_token, second_token = candidate_pairs[pair_index]

            st.session_state.search_results_state = {
                'matching_tokens': [first_token],
                'matching_tokens2': [second_token],
                'last_search': first_token,
                'last_search2': second_token,
                'search_mode': 'pair'
            }

            st.write("### Search Results")
            evolution_tab, clusters_tab = st.tabs(["Evolution Analysis", "Co-occurring Clusters"])

            with evolution_tab:
                evolution_data = analyze_token_evolution(
                    model_name, language, cluster_type, layer,
                    [first_token, second_token], cluster_file
                )
                if evolution_data:
                    display_token_evolution(evolution_data, [first_token, second_token])

            with clusters_tab:
                display_cluster_details(
                    model_name, language, cluster_type, first_token, cluster_file,
                    second_token=second_token
                )
def get_predefined_token_pairs():
    """Return the curated token pairs offered in the "Token Pairs" view.

    Returns:
        An insertion-ordered dict mapping category name to a dict with a
        ``"description"`` string and a ``"pairs"`` list of ``(token, token)``
        tuples.  A fresh structure is built on every call.
    """
    catalogue = (
        ("Control Flow", "Different control flow constructs",
         (("for", "while"), ("if", "switch"), ("break", "continue"), ("try", "catch"))),
        ("Access Modifiers", "Access and modifier keywords",
         (("public", "private"), ("static", "final"), ("abstract", "interface"))),
        # ("var", "String") contrasts `var` with an explicit type
        ("Variable/Type", "Variable and type-related tokens",
         (("int", "Integer"), ("null", "Optional"), ("var", "String"))),
        ("Collections", "Collection-related tokens",
         (("List", "Array"), ("ArrayList", "LinkedList"), ("HashMap", "TreeMap"), ("Set", "List"))),
        ("Threading", "Threading and concurrency tokens",
         (("synchronized", "volatile"), ("Runnable", "Callable"), ("wait", "sleep"))),
        ("Object-Oriented", "Object-oriented programming tokens",
         (("extends", "implements"), ("this", "super"), ("new", "clone"))),
    )
    return {
        category: {"description": description, "pairs": list(pairs)}
        for category, description, pairs in catalogue
    }
def create_wordcloud(tokens, token1=None, token2=None):
    """Create a word cloud from cluster tokens; return None for empty input.

    Args:
        tokens: Either a list of cluster entry dicts (each entry's ``"token"``
            is counted once, as produced by ``load_cluster_sentences``) or a
            set/list of token strings.
        token1: Optional searched token to emphasise.
        token2: Optional second searched token to emphasise.

    Returns:
        A generated ``WordCloud``, or None if *tokens* is empty.
    """
    if not tokens:
        return None

    freq_dict = defaultdict(int)
    if isinstance(tokens, list) and isinstance(tokens[0], dict):
        for token_info in tokens:
            freq_dict[token_info["token"]] += 1
    else:
        for token in tokens:
            freq_dict[token] += 1

    # Map counts onto sizes between 30 and 100 so rare tokens stay legible.
    max_freq = max(freq_dict.values())
    base_size = 30
    normalized_freq = {token: base_size + ((count / max_freq) * 70) for token, count in freq_dict.items()}

    # Emphasise searched tokens, but never inject a token that is not in this data.
    for searched in (token1, token2):
        if searched and searched in normalized_freq:
            normalized_freq[searched] += 5

    wc = WordCloud(
        width=800, height=400,
        background_color='white',
        max_words=100,
        prefer_horizontal=1.0,  # all words horizontal
        colormap='Dark2'  # matplotlib colormap name used to colour the words
    ).generate_from_frequencies(normalized_freq)
    return wc
def display_cluster_details(model_name: str, language: str, cluster_type: str, token: str, cluster_file: str, second_token: str = None):
    """Show, layer by layer, the clusters containing the searched token(s).

    One tab is rendered per numbered ``layer<N>`` directory.  In each tab the
    selected *cluster_file* name is looked up; the user picks one of the
    clusters containing *token* (and *second_token*, if given) and sees its
    word cloud, the sentences with the searched tokens, the remaining
    sentences of the cluster, and a chat panel.
    """
    language_dir = os.path.join(model_name, language)
    available_layers = []
    for d in os.listdir(language_dir):
        if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)):
            try:
                available_layers.append(int(d.replace('layer', '')))
            except ValueError:
                continue
    available_layers.sort()
    if not available_layers:
        # st.tabs() cannot be called with an empty list
        st.info("No layer directories found")
        return

    layer_tabs = st.tabs([f"Layer {layer}" for layer in available_layers])

    for layer, tab in zip(available_layers, layer_tabs):
        with tab:
            # Use the selected file name as-is; rebuilding it from the
            # directory name could point at a different clustering.
            cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
            if not os.path.exists(cluster_file_path):
                st.info(f"Cluster file '{cluster_file}' not found in Layer {layer}")
                continue

            matching_clusters = find_clusters_for_token(model_name, language, cluster_type, layer, cluster_file, token)
            if second_token:
                # Keep only clusters containing both tokens
                matching_clusters &= find_clusters_for_token(model_name, language, cluster_type, layer, cluster_file, second_token)

            if not matching_clusters:
                if second_token:
                    st.info(f"No clusters containing both '{token}' and '{second_token}' found in Layer {layer}")
                else:
                    st.info(f"No clusters containing '{token}' found in Layer {layer}")
                continue

            cluster_ids = sorted(matching_clusters, key=int)  # numeric order
            selected_cluster = st.selectbox(
                f"Select cluster from Layer {layer}",
                cluster_ids,
                format_func=lambda x: f"Cluster {x}",
                key=f"cluster_select_{layer}_{token}_{second_token if second_token else ''}"
            )
            if not selected_cluster:
                continue

            cluster_data = load_cluster_sentences(
                model_name, language, cluster_type, layer, cluster_file,
                specific_cluster=selected_cluster
            )
            if selected_cluster not in cluster_data:
                st.info(f"No sentences found for cluster {selected_cluster}")
                continue
            entries = cluster_data[selected_cluster]

            col_main, col_chat = st.columns([2, 1])
            with col_main:
                st.write("### Word Cloud")
                wc = create_wordcloud(entries, token1=token, token2=second_token)
                if wc:
                    _, cloud_col, _ = st.columns([1, 3, 1])
                    with cloud_col:
                        fig, ax = plt.subplots(figsize=(10, 6))
                        ax.axis('off')
                        ax.imshow(wc, interpolation='bilinear')
                        st.pyplot(fig)
                        plt.close(fig)

                # Sentences are shown once: first those with searched tokens, then the rest
                shown_sentences = set()
                searched_tokens = {token, second_token}

                if second_token:
                    st.write(f"#### Context for searched tokens: '{token}' and '{second_token}'")
                else:
                    st.write(f"#### Context for searched token: '{token}'")
                html_output = []
                for sent_info in entries:
                    if sent_info["token"] in searched_tokens and sent_info["sentence"] not in shown_sentences:
                        html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
                        shown_sentences.add(sent_info["sentence"])
                        html_output.append("<hr>")
                if html_output:
                    st.markdown("\n".join(html_output), unsafe_allow_html=True)

                st.write("#### All sentences in cluster")
                html_output = []
                for sent_info in entries:
                    if sent_info["sentence"] not in shown_sentences:
                        html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
                        shown_sentences.add(sent_info["sentence"])
                        html_output.append("<hr>")
                if not html_output:
                    st.info("No additional sentences in this cluster.")
                else:
                    st.markdown("\n".join(html_output), unsafe_allow_html=True)

            with col_chat:
                chat_with_cluster(entries, cluster_type, selected_cluster)
def display_input_file(model_name: str, language: str):
    """Show ``<model_name>/<language>/input.in`` with 1-based line numbers.

    Renders the numbered lines in an ``st.code`` block followed by the total
    line count. A read/decode failure is reported with ``st.error``.
    """
    st.write("### Input File Contents")
    input_file = os.path.join(model_name, language, "input.in")
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            # Drop each line's own newline: the join below inserts one, and
            # keeping both would put a blank line after every source line.
            lines = [line.rstrip('\n') for line in f]
    except (OSError, UnicodeDecodeError) as e:
        st.error(f"Error reading input file: {e}")
        return
    numbered_lines = [f"{i:4d} | {line}" for i, line in enumerate(lines, start=1)]
    st.code('\n'.join(numbered_lines), language='text')
    st.write(f"Total lines: {len(lines)}")
from dotenv import load_dotenv
def setup_gemini():
    """Configure the Gemini client and return the generation settings.

    Loads environment variables from a ``.env`` file, reads
    ``GEMINI_API_KEY`` and hands it to ``genai.configure``.

    Returns:
        dict: generation config (temperature 0.4, top_p 1, top_k 1,
        max_output_tokens 2048) when the key is set; ``False`` after showing
        a Streamlit error when it is missing or empty.
    """
    load_dotenv()
    gemini_key = os.getenv("GEMINI_API_KEY")
    if gemini_key:
        genai.configure(api_key=gemini_key)
        # Moderate temperature; deterministic-leaning sampling (top_k=1).
        return dict(
            temperature=0.4,
            top_p=1,
            top_k=1,
            max_output_tokens=2048,
        )
    st.error("Please set GEMINI_API_KEY in your .env file")
    return False
def get_cluster_context(cluster_sentences):
    """Format one cluster's entries as plain-text context for the Gemini prompt.

    Args:
        cluster_sentences: iterable of dicts with "token", "sentence" and
            "token_idx" keys, where "token_idx" indexes into
            ``sentence.split()``.

    Returns:
        str: a frequency list of the analyzed tokens (most frequent first,
        ties alphabetical), followed by every distinct sentence in first-seen
        order with *all* of its analyzed token positions wrapped in ``*``.
    """
    from collections import Counter

    cluster_sentences = list(cluster_sentences)  # iterated more than once below
    token_counts = Counter(info["token"] for info in cluster_sentences)

    # sentence text -> set of analyzed positions. A sentence can contain
    # several analyzed tokens, so collect every index instead of only the
    # first entry's. Dict insertion order keeps first-seen sentence order.
    marked_positions = {}
    for info in cluster_sentences:
        marked_positions.setdefault(info["sentence"], set()).add(info["token_idx"])

    parts = ["Here is the data for this cluster:\n\n", "Tokens being analyzed:\n"]
    for token, count in sorted(token_counts.items(), key=lambda item: (-item[1], item[0])):
        parts.append(f"- '{token}' (appears {count} times)\n")

    parts.append("\nExample sentences (analyzed tokens marked with *):\n")
    for sentence, positions in marked_positions.items():
        words = sentence.split()
        for idx in positions:
            # An index that does not fit the sentence is left unmarked rather
            # than aborting the whole prompt with IndexError.
            if 0 <= idx < len(words):
                words[idx] = f"*{words[idx]}*"
        parts.append(f"- {' '.join(words)}\n")
    return "".join(parts)
def chat_with_cluster(cluster_sentences, clustering_method, cluster_id=None):
    """Render a Gemini chat panel for discussing one cluster.

    The Gemini chat session and the displayed message history live in
    ``st.session_state`` under a key built from ``cluster_id`` and the
    cluster's distinct tokens. Streamlit re-executes the script on every
    interaction, so they are created only when missing. The initial analysis
    prompt is therefore sent once per cluster per session, and follow-up
    questions continue the same conversation instead of starting over.

    Args:
        cluster_sentences: list of dicts for the cluster; "token" is read
            here and the list is passed to ``get_cluster_context``.
        clustering_method: not used by this function; kept for interface
            compatibility with existing callers.
        cluster_id: optional cluster identifier, made part of the state key
            so different clusters get separate chats.
    """
    generation_config = setup_gemini()
    if not generation_config:
        return

    cluster_tokens = {sent_info["token"] for sent_info in cluster_sentences}
    token_part = '-'.join(sorted(cluster_tokens))
    cluster_key = f"cluster_chat_{cluster_id}_{token_part}" if cluster_id else f"cluster_chat_{token_part}"
    history_key = f"{cluster_key}_history"
    # st.text_input keeps its value across reruns; remember the last question
    # sent so it is not re-sent on every subsequent rerun.
    last_question_key = f"{cluster_key}_last_question"

    if cluster_key not in st.session_state or history_key not in st.session_state:
        model = genai.GenerativeModel('gemini-2.0-flash', generation_config=generation_config)
        st.session_state[cluster_key] = model.start_chat(history=[])
        st.session_state[history_key] = []
    chat = st.session_state[cluster_key]
    history = st.session_state[history_key]  # list mutated in place below

    if not history:
        context = get_cluster_context(cluster_sentences)
        initial_message = f"""You are a helpful assistant analyzing Java code clusters. You will help users understand
patterns and relationships in the provided cluster data. Here is the context for this cluster:
{context}
Please analyze this data and help users understand what each token is doing in the cluster. Be concise and to the point.
"""
        try:
            initial_reply = chat.send_message(initial_message).text
        except Exception as e:
            # API/network failure: report it and retry on a later rerun.
            st.error(f"Error communicating with Gemini: {e}")
            return
        history.append(("assistant", initial_reply))

    st.write("### Chat with Gemini about this Cluster")
    st.write("Ask questions about patterns, relationships, or insights in this cluster.")
    for role, message in history:
        with st.chat_message(role):
            st.write(message)

    user_input = st.text_input("Your question:", key=f"input_{cluster_key}")
    if not user_input or user_input == st.session_state.get(last_question_key):
        return
    # Mark as handled before calling the API so a failing request is not
    # silently retried on every later rerun; editing the question retries it.
    st.session_state[last_question_key] = user_input
    with st.chat_message("user"):
        st.write(user_input)
    try:
        with st.chat_message("assistant"):
            reply = chat.send_message(user_input).text
            st.write(reply)
    except Exception as e:
        st.error(f"Error communicating with Gemini: {e}")
        return
    history.append(("user", user_input))
    history.append(("assistant", reply))
def _html_for_new_sentences(sentences, shown, wanted_token=None):
    """Return sentence HTML for entries not yet in ``shown`` (optionally only those whose token is ``wanted_token``); adds them to ``shown``."""
    html_parts = []
    for info in sentences:
        if wanted_token is not None and info["token"] != wanted_token:
            continue
        if info["sentence"] in shown:
            continue
        html_parts.append(create_sentence_html(info["sentence"], info))
        shown.add(info["sentence"])
    return html_parts


def _render_line_search_cluster(sentences, line_number, token_filter):
    """Render the word cloud and grouped sentences of one cluster for the line-search view.

    Sentences appear at most once, in three groups: the searched line
    (``sentence_id == line_number - 1``), then entries whose token equals
    ``token_filter`` (only when a filter is set), then everything else.
    """
    st.write("### Word Cloud")
    wc = create_wordcloud(sentences, token_filter if token_filter else None)
    if wc:
        _, centre, _ = st.columns([1, 3, 1])
        with centre:
            # plt.subplots creates a fresh figure; closing it afterwards
            # avoids accumulating open figures across reruns.
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.axis('off')
            ax.imshow(wc, interpolation='bilinear')
            st.pyplot(fig)
            plt.close(fig)

    st.write("### Sentences in Cluster")
    shown = set()

    st.write("#### Searched Line")
    # sentence_id is 0-based, line_number is 1-based.
    searched = next((info for info in sentences if info["sentence_id"] == line_number - 1), None)
    if searched is not None:
        st.markdown(create_sentence_html(searched["sentence"], searched), unsafe_allow_html=True)
        shown.add(searched["sentence"])

    if token_filter:
        st.write(f"#### Sentences containing '{token_filter}'")
        html_parts = _html_for_new_sentences(sentences, shown, wanted_token=token_filter)
        if html_parts:
            st.markdown("\n".join(html_parts), unsafe_allow_html=True)

    st.write("#### Other sentences in cluster")
    html_parts = _html_for_new_sentences(sentences, shown)
    if html_parts:
        st.markdown("\n".join(html_parts), unsafe_allow_html=True)


def handle_line_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Render the "Search by Line Number" view.

    The user picks a 1-based line of ``input.in`` and optionally a token.
    After clicking Search, every cluster with an entry for that line (and,
    if a token is given, an entry for that token) is shown in its own tab
    with a word cloud, its sentences and a Gemini chat panel.

    ``st.button`` is True only on the rerun triggered by the click, so the
    submitted search is stored in ``st.session_state``. Without that, the
    results — including the chat panel, whose own input triggers a rerun —
    would vanish on the next interaction.
    """
    st.write("### Search by Line Number")
    col1, col2 = st.columns([1, 1])
    with col1:
        # Explicit key so this widget cannot collide with another
        # identically-labelled number_input rendered in the same run.
        line_number = st.number_input("Enter line number:", min_value=1, step=1, key="line_search_line_number")
    with col2:
        token_filter = st.text_input("Filter by token (optional):", "")

    search_key = f"line_search_{model_name}_{language}_{cluster_type}_{layer}_{cluster_file}"
    if st.button("Search", key="line_search_button"):
        # Tokens come from whitespace splitting, so surrounding spaces can never match.
        st.session_state[search_key] = (int(line_number), token_filter.strip())
    if search_key not in st.session_state:
        return
    line_number, token_filter = st.session_state[search_key]

    clusters = find_clusters_for_line(model_name, language, cluster_type, layer, cluster_file, line_number)
    if not clusters:
        st.warning(f"No clusters found containing line {line_number}")
        return

    # load_cluster_sentences re-reads the cluster file on every call; load each cluster once.
    loaded = {}

    def cluster_rows(cluster_id):
        if cluster_id not in loaded:
            loaded[cluster_id] = load_cluster_sentences(
                model_name,
                language,
                cluster_type,
                layer,
                cluster_file,
                specific_cluster=cluster_id,
            )
        return loaded[cluster_id]

    if token_filter:
        clusters = {
            cluster_id
            for cluster_id in clusters
            if any(info["token"] == token_filter for info in cluster_rows(cluster_id).get(cluster_id, []))
        }
        if not clusters:
            st.warning(f"No clusters found containing both line {line_number} and token '{token_filter}'")
            return
        st.success(f"Found {len(clusters)} clusters containing line {line_number} and token '{token_filter}'")
    else:
        st.success(f"Found {len(clusters)} clusters containing line {line_number}")

    ordered = sorted(clusters)
    cluster_tabs = st.tabs([f"Cluster {cluster_id}" for cluster_id in ordered])
    for tab, cluster_id in zip(cluster_tabs, ordered):
        with tab:
            cluster_data = cluster_rows(cluster_id)
            if cluster_id not in cluster_data:
                continue
            col_main, col_chat = st.columns([2, 1])
            with col_main:
                _render_line_search_cluster(cluster_data[cluster_id], line_number, token_filter)
            with col_chat:
                chat_with_cluster(cluster_data[cluster_id], cluster_type, cluster_id)
def find_clusters_for_line(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, line_number: int) -> set:
    """Return the IDs of clusters that have at least one entry for a given line.

    Each cluster-file line is split on '|||'. Field 2 is the 0-based
    sentence id and field 4 is the cluster id. Lines with fewer than five
    fields, or with a non-integer sentence id, are skipped.

    Args:
        model_name, language, cluster_type, layer, cluster_file: locate the
            file ``<model_name>/<language>/layer<layer>/<cluster_type>/<cluster_file>``.
        line_number: 1-based line number in ``input.in``.

    Returns:
        set: matching cluster IDs (strings). An empty set, after showing a
        Streamlit error, if the cluster file cannot be read.
    """
    target_sentence = line_number - 1  # cluster file uses 0-based sentence ids
    cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
    matching_clusters = set()
    try:
        with open(cluster_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('|||')
                # The cluster id is field 4, so all five fields are required.
                if len(parts) < 5:
                    continue
                try:
                    sentence_id = int(parts[2])
                except ValueError:
                    continue
                if sentence_id == target_sentence:
                    matching_clusters.add(parts[4].strip())
    except (OSError, UnicodeDecodeError) as e:
        st.error(f"Error reading cluster file: {e}")
        return set()
    return matching_clusters
# Add new function to handle line token distribution
def _read_sentence_token_clusters(cluster_file_path, sentence_id):
    """Return ``{(token, token_idx): cluster_id}`` for every cluster-file entry of one 0-based sentence id.

    Lines are split on '|||'; fields 0, 2, 3 and 4 are token, sentence id,
    token index and cluster id. Short lines and non-integer ids are skipped.
    """
    token_clusters = {}
    with open(cluster_file_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            parts = raw_line.strip().split('|||')
            if len(parts) < 5:
                continue
            try:
                entry_sentence = int(parts[2])
                token_idx = int(parts[3])
            except ValueError:
                continue
            if entry_sentence == sentence_id:
                token_clusters[(parts[0].strip(), token_idx)] = parts[4].strip()
    return token_clusters


def _token_cluster_sankey(tokens, token_clusters):
    """Build Sankey labels and links from each token position to its cluster.

    Nodes ``0 .. len(tokens) - 1`` are the line's tokens in order. They are
    followed by one node per distinct cluster, in first-seen order. Tokens
    without an assignment link to a single "N/A" node.

    Returns:
        tuple: ``(labels, sources, targets, values)`` for ``go.Sankey``.
    """
    labels = list(tokens)
    cluster_nodes = {}
    sources, targets = [], []
    for i, token in enumerate(tokens):
        cluster = token_clusters.get((token, i), "N/A")
        if cluster not in cluster_nodes:
            cluster_nodes[cluster] = len(labels)
            labels.append(cluster)
        sources.append(i)
        targets.append(cluster_nodes[cluster])
    return labels, sources, targets, [1] * len(tokens)


def handle_line_token_distribution(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
    """Show which cluster each token of one ``input.in`` line was assigned to.

    After "Analyze" is clicked, renders:
    - the line;
    - a check of whether any cluster holds more than one distinct token text
      from that line (listing those clusters with token positions);
    - a token -> cluster table;
    - a Sankey diagram.

    Tokens are matched to cluster entries by ``(text, index in line.split())``.
    """
    st.write("### Line Token Distribution")
    # Explicit key so this widget cannot collide with another
    # identically-labelled number_input rendered in the same run.
    line_number = st.number_input("Enter line number:", min_value=1, step=1, key="line_distribution_line_number")
    if not st.button("Analyze"):
        return

    input_file = os.path.join(model_name, language, "input.in")
    cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        if line_number > len(lines):
            st.error(f"Line number {line_number} is out of range. File has {len(lines)} lines.")
            return
        line_content = lines[line_number - 1].strip()
        # Cluster file sentence ids are 0-based.
        token_clusters = _read_sentence_token_clusters(cluster_file_path, line_number - 1)
    except (OSError, UnicodeDecodeError) as e:
        st.error(f"Error analyzing line: {e}")
        return
    tokens = line_content.split()

    # cluster id -> {token text: [positions in the line]}, in cluster-file order
    cluster_members = {}
    for (token, idx), cluster in token_clusters.items():
        cluster_members.setdefault(cluster, {}).setdefault(token, []).append(idx)
    mixed_clusters = {cluster: members for cluster, members in cluster_members.items() if len(members) > 1}

    col1, col2 = st.columns([3, 1])
    with col1:
        st.write("#### Line Content:")
        st.code(line_content, language="java")
    with col2:
        st.write("#### Uniqueness Check:")
        st.checkbox("All different tokens have unique clusters", value=not mixed_clusters, disabled=True)
        if mixed_clusters:
            st.write("##### Clusters with Different Tokens:")
            for cluster, members in mixed_clusters.items():
                tokens_str = ", ".join(
                    f"'{token}' (idx: {', '.join(map(str, positions))})"
                    for token, positions in members.items()
                )
                st.markdown(f"- Cluster **{cluster}**: {tokens_str}")

    st.write("#### Token Distribution:")
    st.table(pd.DataFrame([
        {"Token Index": i, "Token": token, "Cluster": token_clusters.get((token, i), "N/A")}
        for i, token in enumerate(tokens)
    ]))

    st.write("#### Distribution Visualization:")
    labels, sources, targets, values = _token_cluster_sankey(tokens, token_clusters)
    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color="blue"
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values
        )
    )])
    fig.update_layout(title_text="Token to Cluster Distribution", font_size=10)
    st.plotly_chart(fig, use_container_width=True)
if __name__ == "__main__":
    # Run the app's entry point when this file is executed as a script.
    # NOTE(review): main() is not visible in this chunk; presumably defined earlier in the module.
    main()