import streamlit as st
import os
import json
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import plotly.graph_objects as go
from typing import List
import html
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from dotenv import load_dotenv
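# Expected on-disk layout (inferred from the path handling in this file):
#   <model_dir>/<language>/input.in
#   <model_dir>/<language>/layer<N>/<cluster_type>/clusters-<algorithm>-<k>.txt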
def get_available_cluster_files(layer_dir: str) -> List[str]:
"""Get list of available cluster files and extract their types and sizes."""
cluster_files = []
for file in os.listdir(layer_dir):
if file.startswith('clusters-') and file.endswith('.txt'):
# Parse files like 'clusters-agg-10.txt' or 'clusters-kmeans-500.txt'
parts = file.replace('.txt', '').split('-')
if len(parts) == 3 and parts[2].isdigit():
cluster_files.append(file)
for file in sorted(cluster_files):
st.sidebar.write(f"- {file}")
return sorted(cluster_files)
def parse_cluster_filename(filename: str) -> tuple:
"""Parse cluster filename to get algorithm and size."""
parts = filename.replace('.txt', '').split('-')
return parts[1], int(parts[2]) # returns (algorithm, size)
def load_cluster_sentences(model_dir: str, language: str, cluster_type: str, layer: int, cluster_file: str, tokens: List[str] = None, specific_cluster: str = None):
"""Load sentences and their cluster assignments for a given model and layer."""
sentence_file = os.path.join(model_dir, language, "input.in")
cluster_file_path = os.path.join(model_dir, language, f"layer{layer}", cluster_type, cluster_file)
# Load all sentences first
with open(sentence_file, 'r', encoding='utf-8') as f:
all_sentences = [line.strip() for line in f]
# Process cluster file to get sentence mappings
cluster_sentences = defaultdict(list)
cluster_tokens = defaultdict(set) # Track all tokens in each cluster
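# Each line of the cluster file is assumed to follow the format:
#   token|||occurrence|||sentence_id|||token_idx|||cluster_id
# Tokens that are themselves '|' or '||' inflate the pipe count, so they are special-cased below.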
# First pass: collect all tokens for each cluster
with open(cluster_file_path, 'r', encoding='utf-8') as f:
for line in f:
stripped_line = line.strip()
pipe_count = stripped_line.count('|')
if pipe_count == 13:
token = '|'
parts = stripped_line.split('|||')
cluster_id = parts[4].strip()
elif pipe_count == 14:
token = '||'
parts = stripped_line.split('|||')
cluster_id = parts[4].strip()
else:
parts = stripped_line.split('|||')
if len(parts) != 5:
continue
token = parts[0].strip()
cluster_id = parts[4].strip()
# Only collect tokens for the specific cluster if specified
if specific_cluster is None or cluster_id == specific_cluster:
cluster_tokens[cluster_id].add(token)
# Second pass: collect sentences
with open(cluster_file_path, 'r', encoding='utf-8') as f:
for line in f:
stripped_line = line.strip()
pipe_count = stripped_line.count('|')
if pipe_count == 13:
token = '|'
parts = stripped_line.split('|||')
occurrence = 1
sentence_id = int(parts[2])
token_idx = int(parts[3])
cluster_id = parts[4].strip()
elif pipe_count == 14:
token = '||'
parts = stripped_line.split('|||')
occurrence = 1
sentence_id = int(parts[2])
token_idx = int(parts[3])
cluster_id = parts[4].strip()
else:
parts = stripped_line.split('|||')
if len(parts) != 5:
continue
token = parts[0].strip()
try:
occurrence = int(parts[1])
sentence_id = int(parts[2])
token_idx = int(parts[3])
cluster_id = parts[4].strip()
except ValueError:
continue
# Include sentences if:
# 1. This is the specific cluster we're looking for AND
# 2. Either no specific tokens requested OR token is in our search list
if ((specific_cluster is None or cluster_id == specific_cluster) and
(tokens is None or token in tokens)):
if 0 <= sentence_id < len(all_sentences):
sentence_tokens = all_sentences[sentence_id].split()
if 0 <= token_idx < len(sentence_tokens):
# Verify the token actually appears at the specified index
if sentence_tokens[token_idx] == token:
cluster_sentences[cluster_id].append({
"sentence": all_sentences[sentence_id],
"token": token,
"token_idx": token_idx,
"occurrence": occurrence,
"sentence_id": sentence_id,
"all_cluster_tokens": cluster_tokens[cluster_id]
})
return cluster_sentences
def create_sentence_html(sentence, sent_info, cluster_tokens=None):
"""Create HTML for sentence with highlighted tokens
Args:
sentence: The full sentence text
sent_info: Dictionary containing token and position info
cluster_tokens: Set of all unique tokens in this cluster (unused)
"""
# Build the markup with plain strings so st.markdown(..., unsafe_allow_html=True) renders it
# (use a local name that does not shadow the imported `html` module; inline styles are illustrative)
sentence_html = '<div style="font-family: monospace; border: 1px solid #ddd; border-radius: 4px; padding: 8px; margin: 4px 0;">'
sentence_html += '<div>'
# Get token information
target_token = sent_info["token"]
target_idx = sent_info["token_idx"]
line_number = sent_info["sentence_id"]
# Split the tokenized sentence
tokens = sentence.split()
# Highlight only the target token in red
for i, token in enumerate(tokens):
if i == target_idx:
# Target token in red
sentence_html += f'<span style="color: red; font-weight: bold;">{token}</span> '
else:
# Regular tokens
sentence_html += f'<span>{token}</span> '
sentence_html += '</div>'
# Footer with token, line, and index details
sentence_html += f'<div style="font-size: 0.85em; color: #666;">Token: <b>{target_token}</b> (Line: {line_number}, Index: {target_idx})</div>'
sentence_html += '</div>'
return sentence_html
def display_cluster_analysis(model_name: str, language: str, cluster_type: str, selected_layer: int, cluster_file: str, clustering_method: str):
"""Display cluster analysis for selected model and layer."""
# Load cluster data
cluster_sentences = load_cluster_sentences(model_name, language, cluster_type, selected_layer, cluster_file)
# Get clustering algorithm and size from filename
algorithm, size = parse_cluster_filename(cluster_file)
st.write(f"### Analyzing {algorithm.upper()} clustering with {size} clusters")
# Create cluster selection with navigation buttons
cluster_ids = sorted(cluster_sentences.keys(), key=lambda x: int(x)) # Sort numerically, no prefix removal needed
if not cluster_ids:
st.error("No clusters found in the data")
return
# Create a unique key for this combination of parameters
state_key = f"cluster_index_{selected_layer}_{algorithm}_{size}"
# Initialize or reset session state if needed
if state_key not in st.session_state:
st.session_state[state_key] = 0
# Ensure the index is valid for the current number of clusters
if st.session_state[state_key] >= len(cluster_ids):
st.session_state[state_key] = 0
# Create columns with adjusted widths for better spacing
col1, col2, col3, col4 = st.columns([3, 1, 1, 7]) # Adjusted column ratios
# Add some vertical space before the controls
st.write("")
with col1:
# Use a key that includes relevant parameters to force refresh
select_key = f"cluster_select_{selected_layer}_{algorithm}_{size}"
selected_cluster = st.selectbox(
"Select cluster",
range(len(cluster_ids)),
index=st.session_state[state_key],
format_func=lambda x: f"Cluster {cluster_ids[x]}",
label_visibility="collapsed", # Hides the label but keeps accessibility
key=select_key
)
# Update session state when dropdown changes
if selected_cluster != st.session_state[state_key]:
st.session_state[state_key] = selected_cluster
st.rerun()
# Previous cluster button with custom styling
with col2:
if st.button("◀", use_container_width=True):
st.session_state[state_key] = max(0, st.session_state[state_key] - 1)
st.rerun()
# Next cluster button with custom styling
with col3:
if st.button("▶", use_container_width=True):
st.session_state[state_key] = min(len(cluster_ids) - 1, st.session_state[state_key] + 1)
st.rerun()
# Add some vertical space after the controls
st.write("")
# Get the current cluster
cluster_id = cluster_ids[st.session_state[state_key]]
# Load cluster data with specific cluster ID
cluster_sentences = load_cluster_sentences(
model_name,
language,
cluster_type,
selected_layer,
cluster_file,
specific_cluster=cluster_id
)
sentences_data = cluster_sentences[cluster_id]
# Create two columns for the main content
col_main, col_chat = st.columns([2, 1])
with col_main:
# Get all unique tokens in this cluster
cluster_tokens = sentences_data[0]["all_cluster_tokens"] if sentences_data else set()
# Display word cloud for this cluster
st.write("### Word Cloud")
wc = create_wordcloud(sentences_data)
if wc:
# Create a centered column for the word cloud
col1, col2, col3 = st.columns([1, 3, 1]) # Increased middle column width
with col2:
# Clear any existing matplotlib figures
plt.clf()
# Create new figure with larger size
fig, ax = plt.subplots(figsize=(10, 6)) # Increased from (5, 3) to (10, 6)
ax.axis('off')
ax.imshow(wc, interpolation='bilinear')
# Display the figure
st.pyplot(fig)
# Clean up
plt.close(fig)
# Display context sentences for this cluster
st.write("### Context Sentences")
# Create a dictionary to track sentences by their text
unique_sentences = {}
# First pass: collect all sentences and their token information
for sent_info in sentences_data:
# Escape any HTML special characters in the token and sentence
sentence_text = html.escape(sent_info["sentence"])
token = html.escape(sent_info["token"])
token_idx = sent_info["token_idx"]
if sentence_text not in unique_sentences:
unique_sentences[sentence_text] = {
"tokens": [(token, token_idx)],
"sentence_id": sent_info["sentence_id"]
}
else:
unique_sentences[sentence_text]["tokens"].append((token, token_idx))
# Second pass: display each unique sentence with all its tokens highlighted
for sentence_text, info in unique_sentences.items():
# Create HTML with multiple tokens highlighted
tokens = sentence_text.split()
html_content = """
"""
# Highlight all relevant tokens in the sentence
for i, token in enumerate(tokens):
if any(i == idx for _, idx in info["tokens"]):
# This is one of our target tokens - highlight it
html_content += f"{token} "
else:
# Regular token
html_content += f"{token} "
# Add token information footer
html_content += f"""
Tokens: {", ".join([f"{t}
(Index: {idx})" for t, idx in info["tokens"]])}
(Line: {info["sentence_id"]})
"""
st.markdown(html_content, unsafe_allow_html=True)
with col_chat:
# Add chat interface
chat_with_cluster(sentences_data, clustering_method, cluster_id)
def main():
# Set page to use full width
st.set_page_config(layout="wide")
st.title("Coconet Visual Analysis")
# Initialize session state for selections if they don't exist
if 'model_name' not in st.session_state:
st.session_state.model_name = None
if 'selected_language' not in st.session_state:
st.session_state.selected_language = None
if 'selected_layer' not in st.session_state:
st.session_state.selected_layer = None
if 'selected_cluster_type' not in st.session_state:
st.session_state.selected_cluster_type = None
if 'selected_cluster_file' not in st.session_state:
st.session_state.selected_cluster_file = None
if 'analysis_mode' not in st.session_state:
st.session_state.analysis_mode = "Individual Clusters"
# Get available models (directories in the current directory)
current_dir = os.path.dirname(os.path.abspath(__file__))
available_models = [d for d in os.listdir(current_dir)
if os.path.isdir(os.path.join(current_dir, d)) and not d.startswith('.') and not d == '__pycache__']
# Model selection
model_name = st.sidebar.selectbox(
"Select Data",
available_models,
key="model_select",
index=available_models.index(st.session_state.model_name) if st.session_state.model_name in available_models else 0
)
if model_name != st.session_state.model_name:
st.session_state.model_name = model_name
st.session_state.selected_language = None
st.session_state.selected_layer = None
st.session_state.selected_cluster_type = None
st.session_state.selected_cluster_file = None
if not model_name:
st.error("No models found")
return
# Get available languages for the selected model
model_dir = os.path.join(current_dir, model_name)
available_languages = [d for d in os.listdir(model_dir)
if os.path.isdir(os.path.join(model_dir, d))]
# Language selection
selected_language = st.sidebar.selectbox(
"Select Language",
available_languages,
key="language_select",
index=available_languages.index(st.session_state.selected_language) if st.session_state.selected_language in available_languages else 0
)
if selected_language != st.session_state.selected_language:
st.session_state.selected_language = selected_language
st.session_state.selected_layer = None
st.session_state.selected_cluster_type = None
st.session_state.selected_cluster_file = None
if not selected_language:
st.error("No languages found for selected model")
return
# Get available layers
language_dir = os.path.join(model_dir, selected_language)
layer_dirs = [d for d in os.listdir(language_dir)
if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d))]
if not layer_dirs:
st.error("No layer directories found")
return
# Extract layer numbers
available_layers = sorted([int(d.replace('layer', '')) for d in layer_dirs])
# Layer selection
selected_layer = st.sidebar.selectbox(
"Select Layer",
available_layers,
key="layer_select",
index=available_layers.index(st.session_state.selected_layer) if st.session_state.selected_layer in available_layers else 0
)
if selected_layer != st.session_state.selected_layer:
st.session_state.selected_layer = selected_layer
st.session_state.selected_cluster_type = None
st.session_state.selected_cluster_file = None
# Get available clustering types
layer_dir = os.path.join(language_dir, f"layer{selected_layer}")
available_cluster_types = [d for d in os.listdir(layer_dir)
if os.path.isdir(os.path.join(layer_dir, d))]
# Clustering type selection
selected_cluster_type = st.sidebar.selectbox(
"Select Clustering Type",
available_cluster_types,
key="cluster_type_select",
index=available_cluster_types.index(st.session_state.selected_cluster_type) if st.session_state.selected_cluster_type in available_cluster_types else 0
)
if selected_cluster_type != st.session_state.selected_cluster_type:
st.session_state.selected_cluster_type = selected_cluster_type
st.session_state.selected_cluster_file = None
if not selected_cluster_type:
st.error("No clustering types found for selected layer")
return
# Get available cluster files
cluster_dir = os.path.join(layer_dir, selected_cluster_type)
available_cluster_files = get_available_cluster_files(cluster_dir)
if not available_cluster_files:
st.error("No cluster files found in the selected layer")
return
# Cluster file selection
selected_cluster_file = st.sidebar.selectbox(
"Select Clustering",
available_cluster_files,
key="cluster_file_select",
format_func=lambda x: f"{parse_cluster_filename(x)[0].upper()} (k={parse_cluster_filename(x)[1]})",
index=available_cluster_files.index(st.session_state.selected_cluster_file) if st.session_state.selected_cluster_file in available_cluster_files else 0
)
st.session_state.selected_cluster_file = selected_cluster_file
# Analysis mode selection
analysis_modes = ["Individual Clusters", "Search And Analysis", "Search By Line",
"Line Token Distribution", "Token Pairs", "View Input File"]
analysis_mode = st.sidebar.radio(
"Select Analysis Mode",
analysis_modes,
key="analysis_mode_select",
index=analysis_modes.index(st.session_state.analysis_mode)
)
st.session_state.analysis_mode = analysis_mode
# Call appropriate analysis function based on mode
if analysis_mode == "Individual Clusters":
display_cluster_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file, selected_cluster_type)
elif analysis_mode == "Search And Analysis":
handle_token_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
elif analysis_mode == "Search By Line":
handle_line_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
elif analysis_mode == "Line Token Distribution":
handle_line_token_distribution(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
elif analysis_mode == "Token Pairs":
display_token_pair_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file)
elif analysis_mode == "View Input File":
display_input_file(model_name, selected_language)
def display_token_evolution(evolution_data: dict, tokens: List[str]):
"""Display evolution analysis for tokens"""
st.write(f"### Evolution Analysis for Token(s)")
# Create main evolution graph
fig = go.Figure()
# Colors for different types of lines
colors = {
'individual': ['#3498db', '#e74c3c', '#2ecc71'], # Blue, Red, Green
'exclusive': ['#9b59b6', '#f1c40f', '#1abc9c'], # Purple, Yellow, Turquoise
'combined': '#34495e' # Dark Gray
}
# Add individual count lines
for i, token in enumerate(tokens):
fig.add_trace(go.Scatter(
x=evolution_data['layers'],
y=evolution_data['individual_counts'][token],
name=f"'{token}' (Total)",
mode='lines+markers',
line=dict(color=colors['individual'][i], width=2),
marker=dict(size=8)
))
# Add exclusive count lines only for multiple tokens
if len(tokens) > 1:
fig.add_trace(go.Scatter(
x=evolution_data['layers'],
y=evolution_data['exclusive_counts'][token],
name=f"'{token}' (Exclusive)",
mode='lines+markers',
line=dict(color=colors['exclusive'][i], width=2, dash='dot'),
marker=dict(size=8)
))
# Add combined counts if multiple tokens
if len(tokens) > 1:
fig.add_trace(go.Scatter(
x=evolution_data['layers'],
y=evolution_data['combined_counts'],
name='Co-occurring',
mode='lines+markers',
line=dict(color=colors['combined'], width=2),
marker=dict(size=8)
))
# Update layout
fig.update_layout(
title=dict(
text='Token Evolution Across Layers',
font=dict(size=20)
),
xaxis_title=dict(
text='Layer',
font=dict(size=14)
),
yaxis_title=dict(
text='Number of Clusters',
font=dict(size=14)
),
hovermode='x unified',
showlegend=True,
legend=dict(
yanchor="top",
y=0.99,
xanchor="left",
x=0.01
)
)
# Add gridlines
fig.update_xaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot')
fig.update_yaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot')
st.plotly_chart(fig, use_container_width=True)
def find_clusters_for_token(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, search_token: str) -> set:
"""Find cluster IDs containing the exact token"""
matching_clusters = set()
try:
cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
with open(cluster_file_path, 'r', encoding='utf-8') as f:
for line in f:
stripped_line = line.strip()
pipe_count = stripped_line.count('|')
# Get token and cluster ID based on pipe count
if pipe_count == 13:
token = '|'
parts = stripped_line.split('|||')
cluster_id = parts[4].strip()
elif pipe_count == 14:
token = '||'
parts = stripped_line.split('|||')
cluster_id = parts[4].strip()
else:
parts = stripped_line.split('|||')
if len(parts) != 5:
continue
token = parts[0].strip()
cluster_id = parts[4].strip()
# Use exact token matching
if token == search_token:
matching_clusters.add(cluster_id)
except Exception as e:
st.error(f"Error reading cluster file: {e}")
return set()
return matching_clusters
def handle_token_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
"""Handle token search functionality"""
st.write("### Token Search")
# Initialize session state for search results if needed
if 'search_results_state' not in st.session_state:
st.session_state.search_results_state = {
'matching_tokens': [],
'matching_tokens2': [],
'last_search': None,
'last_search2': None,
'search_mode': 'single'
}
elif 'search_mode' not in st.session_state.search_results_state:
st.session_state.search_results_state['search_mode'] = 'single'
# Radio button for search mode
search_mode = st.radio(
"Search Mode",
["Single Token", "Token Pair"],
key="search_mode_radio",
index=0 if st.session_state.search_results_state['search_mode'] == 'single' else 1
)
# Update search mode in session state
st.session_state.search_results_state['search_mode'] = 'single' if search_mode == "Single Token" else 'pair'
if search_mode == "Single Token":
# Single token search interface
search_token = st.text_input("Search for token:")
if search_token:
# Find matching clusters
clusters = find_clusters_for_token(
model_name,
language,
cluster_type,
layer,
cluster_file,
search_token
)
if clusters:
selected_token = search_token
if selected_token:
# Update state
st.session_state.search_results_state.update({
'matching_tokens': [selected_token],
'last_search': selected_token
})
# Display cluster details
display_cluster_details(
model_name,
language,
cluster_type,
selected_token,
cluster_file
)
else:
st.warning(f"No clusters found containing token: '{search_token}'")
else: # Token Pair search
col1, col2 = st.columns(2)
with col1:
search_token1 = st.text_input("Search for first token:")
with col2:
search_token2 = st.text_input("Search for second token:")
if search_token1 and search_token2:
# Find matching clusters for both tokens
clusters1 = find_clusters_for_token(
model_name,
language,
cluster_type,
layer,
cluster_file,
search_token1
)
clusters2 = find_clusters_for_token(
model_name,
language,
cluster_type,
layer,
cluster_file,
search_token2
)
# Find intersection of clusters
matching_clusters = clusters1 & clusters2
if matching_clusters:
# Update state
st.session_state.search_results_state.update({
'matching_tokens': [search_token1],
'matching_tokens2': [search_token2],
'last_search': search_token1,
'last_search2': search_token2
})
# Display cluster details
display_cluster_details(
model_name,
language,
cluster_type,
search_token1,
cluster_file,
second_token=search_token2
)
else:
st.warning(f"No clusters found containing both tokens: '{search_token1}' and '{search_token2}'")
def analyze_token_evolution(model_name: str, language: str, cluster_type: str, layer: int, tokens: List[str], cluster_file: str) -> dict:
"""Analyze token evolution across all available layers"""
# Get all available layers by checking directories
language_dir = os.path.join(model_name, language)
available_layers = []
for d in os.listdir(language_dir):
if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)):
try:
layer_num = int(d.replace('layer', ''))
available_layers.append(layer_num)
except ValueError:
continue
available_layers.sort() # Sort layers numerically
evolution_data = {
'layers': [],  # filled in below with only the layers whose cluster files exist
'individual_counts': {token: [] for token in tokens},
'exclusive_counts': {token: [] for token in tokens},  # Track clusters containing only this token
'combined_counts': [] if len(tokens) > 1 else None
}
# Extract cluster size from the filename (e.g., "clusters-agg-500.txt" -> "500")
cluster_size = cluster_file.split('-')[-1].replace('.txt', '')
# Handle shortened form for agglomerative clustering
cluster_type_short = "agg" if cluster_type == "agglomerative" else cluster_type
for current_layer in available_layers:
# Get clusters for each token
token_clusters = {}
cluster_file_path = os.path.join(
model_name,
language,
f"layer{current_layer}",
cluster_type,
f"clusters-{cluster_type_short}-{cluster_size}.txt"
)
# Skip layer if cluster file doesn't exist
if not os.path.exists(cluster_file_path):
continue
# Record the layer only after confirming its file exists, so the x and y series stay the same length
evolution_data['layers'].append(current_layer)
for token in tokens:
clusters = find_clusters_for_token(
model_name,
language,
cluster_type,
current_layer,
f"clusters-{cluster_type_short}-{cluster_size}.txt",
token
)
token_clusters[token] = set(clusters)
evolution_data['individual_counts'][token].append(len(clusters))
# Calculate exclusive and co-occurring clusters
if len(tokens) > 1:
# Calculate co-occurrences
cooccurring_clusters = set.intersection(*[token_clusters[token] for token in tokens])
evolution_data['combined_counts'].append(len(cooccurring_clusters))
# Calculate exclusive counts for each token
for token in tokens:
other_tokens = set(tokens) - {token}
other_clusters = set.union(*[token_clusters[t] for t in other_tokens]) if other_tokens else set()
exclusive_clusters = token_clusters[token] - other_clusters
evolution_data['exclusive_counts'][token].append(len(exclusive_clusters))
else:
# For single token, exclusive count is the same as individual count
evolution_data['exclusive_counts'][tokens[0]] = evolution_data['individual_counts'][tokens[0]]
return evolution_data
def find_clusters_with_multiple_tokens(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, tokens: List[str]) -> dict:
"""Find clusters containing all of the specified tokens"""
clusters = defaultdict(lambda: {'matching_tokens': {token: set() for token in tokens}})
try:
cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
# Single pass: map each cluster ID to the set of tokens it contains
cluster_to_tokens = defaultdict(set)
with open(cluster_file_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split('|||')
if len(parts) == 5:
cluster_to_tokens[parts[4].strip()].add(parts[0].strip())
# Keep only clusters that contain every searched token, recording the full token set for each
for cluster_id, cluster_token_set in cluster_to_tokens.items():
if all(search_token in cluster_token_set for search_token in tokens):
for search_token in tokens:
clusters[cluster_id]['matching_tokens'][search_token] = set(cluster_token_set)
except Exception as e:
st.error(f"Error reading cluster file: {e}")
return {}
# Filter to only keep clusters with all tokens
return {k: v for k, v in clusters.items() if all(v['matching_tokens'][token] for token in tokens)}
def handle_semantic_tag_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
"""Handle semantic tag search functionality"""
st.write("### Semantic Tag Search")
st.info("This feature will be implemented soon.")
def display_token_pair_analysis(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
"""Display analysis for predefined token pairs"""
st.write("### Token Pair Analysis")
# Get predefined token pairs
token_pairs = get_predefined_token_pairs()
# Create tabs for each category
tabs = st.tabs(list(token_pairs.keys()))
for tab, (category, data) in zip(tabs, token_pairs.items()):
with tab:
st.write(f"### {category}")
st.write(data["description"])
# Instead of using expanders, use a selectbox to choose the token pair
pairs = data["pairs"]
pair_labels = [f"{t1} vs {t2}" for t1, t2 in pairs]
selected_pair_idx = st.selectbox(
"Select token pair",
range(len(pairs)),
format_func=lambda i: pair_labels[i],
key=f"pair_select_{category}"
)
# Get the selected pair
token1, token2 = pairs[selected_pair_idx]
# Update state for token pair search
st.session_state.search_results_state = {
'matching_tokens': [token1],
'matching_tokens2': [token2],
'last_search': token1,
'last_search2': token2,
'search_mode': 'pair'
}
# Display results
st.write("### Search Results")
# Create tabs for different views
tab1, tab2 = st.tabs(["Evolution Analysis", "Co-occurring Clusters"])
with tab1:
evolution_data = analyze_token_evolution(
model_name,
language,
cluster_type,
layer,
[token1, token2],
cluster_file
)
if evolution_data:
display_token_evolution(evolution_data, [token1, token2])
with tab2:
display_cluster_details(
model_name,
language,
cluster_type,
token1,
cluster_file,
second_token=token2
)
def get_predefined_token_pairs():
"""Return predefined token pairs organized by categories"""
return {
"Control Flow": {
"description": "Different control flow constructs",
"pairs": [
("for", "while"),
("if", "switch"),
("break", "continue"),
("try", "catch")
]
},
"Access Modifiers": {
"description": "Access and modifier keywords",
"pairs": [
("public", "private"),
("static", "final"),
("abstract", "interface")
]
},
"Variable/Type": {
"description": "Variable and type-related tokens",
"pairs": [
("int", "Integer"),
("null", "Optional"),
("var", "String") # Example of var vs explicit type
]
},
"Collections": {
"description": "Collection-related tokens",
"pairs": [
("List", "Array"),
("ArrayList", "LinkedList"),
("HashMap", "TreeMap"),
("Set", "List")
]
},
"Threading": {
"description": "Threading and concurrency tokens",
"pairs": [
("synchronized", "volatile"),
("Runnable", "Callable"),
("wait", "sleep")
]
},
"Object-Oriented": {
"description": "Object-oriented programming tokens",
"pairs": [
("extends", "implements"),
("this", "super"),
("new", "clone")
]
}
}
def create_wordcloud(tokens, token1=None, token2=None):
"""Create and return a word cloud from tokens with frequencies"""
if not tokens:
return None
# Create frequency dict by counting occurrences
freq_dict = {}
# If tokens is a list of dictionaries (from cluster data)
if isinstance(tokens, list) and tokens and isinstance(tokens[0], dict):
# Count token occurrences in the cluster
for token_info in tokens:
token = token_info["token"]
freq_dict[token] = freq_dict.get(token, 0) + 1
else:
# If tokens is a set/list of strings, convert to frequency dict
tokens_list = list(tokens) if isinstance(tokens, set) else tokens
for token in tokens_list:
freq_dict[token] = freq_dict.get(token, 0) + 1
# Normalize frequencies with a base size
max_freq = max(freq_dict.values())
base_size = 30 # Base size for all tokens
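# Map raw counts into roughly [base_size, base_size + 70] so even rare tokens remain legible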
normalized_freq = {token: base_size + ((count / max_freq) * 70) for token, count in freq_dict.items()}
# Boost frequency of searched tokens if provided
if token1:
normalized_freq[token1] = normalized_freq.get(token1, 0) + 5
if token2:
normalized_freq[token2] = normalized_freq.get(token2, 0) + 5
# Custom colormap with dark shades of brown, green, and blue
wc = WordCloud(
width=800, height=400,
background_color='white',
max_words=100,
prefer_horizontal=1.0, # Make all words horizontal
colormap='Dark2' # Dark colormap with browns, greens, blues
).generate_from_frequencies(normalized_freq)
return wc
def display_cluster_details(model_name: str, language: str, cluster_type: str, token: str, cluster_file: str, second_token: str = None):
"""Display detailed cluster information organized by layers"""
# Get all available layers
language_dir = os.path.join(model_name, language)
available_layers = []
for d in os.listdir(language_dir):
if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)):
try:
layer_num = int(d.replace('layer', ''))
available_layers.append(layer_num)
except ValueError:
continue
available_layers.sort()
# Create tabs for each layer
layer_tabs = st.tabs([f"Layer {layer}" for layer in available_layers])
# Handle shortened form for agglomerative clustering
cluster_type_short = "agg" if cluster_type == "agglomerative" else cluster_type
cluster_size = cluster_file.split('-')[-1].replace('.txt', '')
for layer, tab in zip(available_layers, layer_tabs):
with tab:
# Find clusters containing the token(s)
matching_clusters = find_clusters_for_token(
model_name,
language,
cluster_type,
layer,
f"clusters-{cluster_type_short}-{cluster_size}.txt",
token
)
if second_token:
matching_clusters2 = find_clusters_for_token(
model_name,
language,
cluster_type,
layer,
f"clusters-{cluster_type_short}-{cluster_size}.txt",
second_token
)
# Find intersection of clusters containing both tokens
matching_clusters &= matching_clusters2
if matching_clusters:
# Sort cluster IDs numerically
cluster_ids = sorted(matching_clusters, key=lambda x: int(x))
# Create dropdown for cluster selection
selected_cluster = st.selectbox(
f"Select cluster from Layer {layer}",
cluster_ids,
format_func=lambda x: f"Cluster {x}",
key=f"cluster_select_{layer}_{token}_{second_token if second_token else ''}"
)
if selected_cluster:
# Load all cluster data for the selected cluster
cluster_data = load_cluster_sentences(
model_name,
language,
cluster_type,
layer,
f"clusters-{cluster_type_short}-{cluster_size}.txt",
specific_cluster=selected_cluster
)
if selected_cluster in cluster_data:
# Create two columns for the main content
col_main, col_chat = st.columns([2, 1])
with col_main:
shown_sentences = set()
# Get all unique tokens in this cluster
cluster_tokens = cluster_data[selected_cluster][0]["all_cluster_tokens"] if cluster_data[selected_cluster] else set()
# Display word cloud for this cluster
st.write("### Word Cloud")
wc = create_wordcloud(cluster_data[selected_cluster], token1=token, token2=second_token)
if wc:
# Create a centered column for the word cloud
col1, col2, col3 = st.columns([1, 3, 1])
with col2:
plt.clf()
fig, ax = plt.subplots(figsize=(10, 6))
ax.axis('off')
ax.imshow(wc, interpolation='bilinear')
st.pyplot(fig)
plt.close(fig)
# First show sentences containing searched tokens
if second_token:
st.write(f"#### Context for searched tokens: '{token}' and '{second_token}'")
else:
st.write(f"#### Context for searched token: '{token}'")
html_output = []
for sent_info in cluster_data[selected_cluster]:
if sent_info["token"] in [token, second_token] and sent_info["sentence"] not in shown_sentences:
html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
shown_sentences.add(sent_info["sentence"])
html_output.append("
")
if html_output:
st.markdown("\n".join(html_output), unsafe_allow_html=True)
# Then show all other sentences in the cluster
st.write("#### All sentences in cluster")
html_output = []
for sent_info in cluster_data[selected_cluster]:
if sent_info["sentence"] not in shown_sentences:
html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
shown_sentences.add(sent_info["sentence"])
html_output.append("
")
if not html_output:
st.info("No additional sentences in this cluster.")
else:
st.markdown("\n".join(html_output), unsafe_allow_html=True)
with col_chat:
# Add chat interface with cluster ID
chat_with_cluster(cluster_data[selected_cluster], cluster_type, selected_cluster)
else:
st.info(f"No sentences found for cluster {selected_cluster}")
else:
if second_token:
st.info(f"No clusters containing both '{token}' and '{second_token}' found in Layer {layer}")
else:
st.info(f"No clusters containing '{token}' found in Layer {layer}")
def display_input_file(model_name: str, language: str):
"""Display the contents of input.in file"""
st.write("### Input File Contents")
input_file = os.path.join(model_name, language, "input.in")
try:
with open(input_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
# Add line numbers and display in a scrollable container
numbered_lines = [f"{i+1:4d} | {line}" for i, line in enumerate(lines)]
st.code('\n'.join(numbered_lines), language='text')
# Display some statistics
st.write(f"Total lines: {len(lines)}")
except Exception as e:
st.error(f"Error reading input file: {e}")
def setup_gemini():
"""Setup Gemini model with API key and temperature setting"""
load_dotenv() # Load environment variables from .env file
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
st.error("Please set GEMINI_API_KEY in your .env file")
return False
genai.configure(api_key=api_key)
# Create generation config with temperature 0.4
generation_config = {
"temperature": 0.4,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 2048,
}
return generation_config
def get_cluster_context(cluster_sentences):
"""Format cluster data into a clear context for Gemini, focusing on searched tokens only"""
# Get unique searched tokens and their frequencies
token_counts = {}
for sent_info in cluster_sentences:
token = sent_info["token"]
token_counts[token] = token_counts.get(token, 0) + 1
# Format the context
context = "Here is the data for this cluster:\n\n"
# Add token frequency information
context += "Tokens being analyzed:\n"
for token, count in sorted(token_counts.items(), key=lambda x: (-x[1], x[0])):
context += f"- '{token}' (appears {count} times)\n"
# Add sentence examples with token highlighting
context += "\nExample sentences (analyzed tokens marked with *):\n"
unique_sentences = {}
for sent_info in cluster_sentences:
sentence = sent_info["sentence"]
if sentence not in unique_sentences:
tokens = sentence.split()
marked_tokens = tokens.copy()
marked_tokens[sent_info["token_idx"]] = f"*{tokens[sent_info['token_idx']]}*"
unique_sentences[sentence] = " ".join(marked_tokens)
for marked_sentence in unique_sentences.values():
context += f"- {marked_sentence}\n"
return context
def chat_with_cluster(cluster_sentences, clustering_method, cluster_id=None):
"""Create a chat interface for discussing the cluster with Gemini"""
# Initialize Gemini with temperature setting
generation_config = setup_gemini()
if not generation_config:
return
model = genai.GenerativeModel('gemini-2.0-flash', generation_config=generation_config)
# Create a unique key for this specific cluster's chat
cluster_tokens = {sent_info["token"] for sent_info in cluster_sentences}
# Include cluster_id in the key if provided
cluster_key = f"cluster_chat_{cluster_id}_{'-'.join(sorted(cluster_tokens))}" if cluster_id else f"cluster_chat_{'-'.join(sorted(cluster_tokens))}"
history_key = f"{cluster_key}_history"
# Reset the chat session and history on every Streamlit rerun so each cluster view starts fresh
st.session_state[cluster_key] = model.start_chat(history=[])
st.session_state[history_key] = []
# Get cluster context and send initial message
context = get_cluster_context(cluster_sentences)
initial_message = f"""You are a helpful assistant analyzing Java code clusters. You will help users understand
patterns and relationships in the provided cluster data. Here is the context for this cluster:
{context}
Please analyze this data and help users understand what each token is doing in the cluster. Be concise and to the point.
"""
# Send initial message if history is empty
if not st.session_state[history_key]:
response = st.session_state[cluster_key].send_message(initial_message)
st.session_state[history_key].append(("assistant", response.text))
# Display chat interface
st.write("### Chat with Gemini about this Cluster")
st.write("Ask questions about patterns, relationships, or insights in this cluster.")
# Display chat history
for role, message in st.session_state[history_key]:
with st.chat_message(role):
st.write(message)
# Chat input with cluster-specific key
user_input = st.text_input("Your question:", key=f"input_{cluster_key}")
if user_input:
try:
# Display user message
with st.chat_message("user"):
st.write(user_input)
# Store user message in history
st.session_state[history_key].append(("user", user_input))
# Get and display Gemini's response
with st.chat_message("assistant"):
response = st.session_state[cluster_key].send_message(user_input)
st.write(response.text)
# Store assistant's response in history
st.session_state[history_key].append(("assistant", response.text))
except Exception as e:
st.error(f"Error communicating with Gemini: {e}")
def handle_line_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
"""Handle line number search functionality"""
st.write("### Search by Line Number")
# Create two columns for search inputs
col1, col2 = st.columns([1, 1])
with col1:
# Input for line number
line_number = st.number_input("Enter line number:", min_value=1, step=1)
with col2:
# Optional token search
token_filter = st.text_input("Filter by token (optional):", "")
if st.button("Search"):
# Find clusters containing the line number
clusters = find_clusters_for_line(
model_name,
language,
cluster_type,
layer,
cluster_file,
line_number
)
if clusters:
if token_filter:
# Filter clusters to only those containing both the line and the token
filtered_clusters = set()
for cluster_id in clusters:
cluster_data = load_cluster_sentences(
model_name,
language,
cluster_type,
layer,
cluster_file,
specific_cluster=cluster_id
)
if any(sent_info["token"] == token_filter for sent_info in cluster_data[cluster_id]):
filtered_clusters.add(cluster_id)
clusters = filtered_clusters
if not clusters:
st.warning(f"No clusters found containing both line {line_number} and token '{token_filter}'")
return
st.success(f"Found {len(clusters)} clusters containing line {line_number} and token '{token_filter}'")
else:
st.success(f"Found {len(clusters)} clusters containing line {line_number}")
# Create tabs for each cluster
cluster_tabs = st.tabs([f"Cluster {cluster_id}" for cluster_id in sorted(clusters)])
for tab, cluster_id in zip(cluster_tabs, sorted(clusters)):
with tab:
# Load cluster data
cluster_data = load_cluster_sentences(
model_name,
language,
cluster_type,
layer,
cluster_file,
specific_cluster=cluster_id
)
if cluster_id in cluster_data:
# Create two columns for the main content
col_main, col_chat = st.columns([2, 1])
with col_main:
# Display word cloud
st.write("### Word Cloud")
wc = create_wordcloud(cluster_data[cluster_id], token_filter if token_filter else None)
if wc:
col1, col2, col3 = st.columns([1, 3, 1])
with col2:
plt.clf()
fig, ax = plt.subplots(figsize=(10, 6))
ax.axis('off')
ax.imshow(wc, interpolation='bilinear')
st.pyplot(fig)
plt.close(fig)
# Display sentences
st.write("### Sentences in Cluster")
shown_sentences = set()
# First show the searched line
st.write("#### Searched Line")
html_output = []
for sent_info in cluster_data[cluster_id]:
if sent_info["sentence_id"] == line_number - 1: # Adjust for 0-based indexing
html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
shown_sentences.add(sent_info["sentence"])
break
if html_output:
st.markdown("\n".join(html_output), unsafe_allow_html=True)
# Then show sentences with the filtered token (if specified)
if token_filter:
st.write(f"#### Sentences containing '{token_filter}'")
html_output = []
for sent_info in cluster_data[cluster_id]:
if (sent_info["token"] == token_filter and
sent_info["sentence"] not in shown_sentences):
html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
shown_sentences.add(sent_info["sentence"])
if html_output:
st.markdown("\n".join(html_output), unsafe_allow_html=True)
# Finally show other sentences
st.write("#### Other sentences in cluster")
html_output = []
for sent_info in cluster_data[cluster_id]:
if sent_info["sentence"] not in shown_sentences:
html_output.append(create_sentence_html(sent_info["sentence"], sent_info))
shown_sentences.add(sent_info["sentence"])
if html_output:
st.markdown("\n".join(html_output), unsafe_allow_html=True)
with col_chat:
# Add chat interface with cluster ID
chat_with_cluster(cluster_data[cluster_id], cluster_type, cluster_id)
else:
st.warning(f"No clusters found containing line {line_number}")
def find_clusters_for_line(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, line_number: int) -> set:
"""Find cluster IDs containing the specified line number"""
matching_clusters = set()
try:
cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
with open(cluster_file_path, 'r', encoding='utf-8') as f:
for line in f:
stripped_line = line.strip()
parts = stripped_line.split('|||')
if len(parts) >= 3:
try:
sentence_id = int(parts[2])
if sentence_id == line_number - 1: # Adjust for 0-based indexing
cluster_id = parts[4].strip()
matching_clusters.add(cluster_id)
except (ValueError, IndexError):
continue
except Exception as e:
st.error(f"Error reading cluster file: {e}")
return set()
return matching_clusters
# Line token distribution analysis
def handle_line_token_distribution(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str):
"""Display cluster distribution for each token in a specific line"""
st.write("### Line Token Distribution")
# Input for line number
line_number = st.number_input("Enter line number:", min_value=1, step=1)
if st.button("Analyze"):
# Load the input file to get the line content
input_file = os.path.join(model_name, language, "input.in")
try:
with open(input_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
if line_number <= len(lines):
line_content = lines[line_number - 1].strip()
tokens = line_content.split()
# Create a dictionary to store cluster assignments for each token
token_clusters = {}
# Find clusters for each token in the line
cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file)
with open(cluster_file_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split('|||')
if len(parts) >= 5:
try:
sentence_id = int(parts[2])
token_idx = int(parts[3])
if sentence_id == line_number - 1: # Adjust for 0-based indexing
token = parts[0].strip()
cluster_id = parts[4].strip()
token_clusters[(token, token_idx)] = cluster_id
except (ValueError, IndexError):
continue
# Find clusters that contain different tokens
cluster_to_unique_tokens = {}
for (token, idx), cluster in token_clusters.items():
if cluster not in cluster_to_unique_tokens:
cluster_to_unique_tokens[cluster] = set()
cluster_to_unique_tokens[cluster].add(token)
# Filter for clusters with different tokens (more than one unique token)
clusters_with_different_tokens = {
cluster: tokens for cluster, tokens in cluster_to_unique_tokens.items()
if len(tokens) > 1
}
has_mixed_clusters = len(clusters_with_different_tokens) > 0
# Display results
col1, col2 = st.columns([3, 1])
with col1:
st.write("#### Line Content:")
st.code(line_content, language="java")
with col2:
st.write("#### Uniqueness Check:")
st.checkbox("All different tokens have unique clusters", value=not has_mixed_clusters, disabled=True)
# Show clusters with different tokens
if has_mixed_clusters:
st.write("##### Clusters with Different Tokens:")
for cluster, unique_tokens in clusters_with_different_tokens.items():
# Get all indices for each token in this cluster
token_positions = {}
for token in unique_tokens:
positions = [idx for (t, idx), c in token_clusters.items()
if t == token and c == cluster]
token_positions[token] = positions
# Format the display string
tokens_str = ", ".join(
f"'{token}' (idx: {', '.join(map(str, positions))})"
for token, positions in token_positions.items()
)
st.markdown(f"- Cluster **{cluster}**: {tokens_str}")
st.write("#### Token Distribution:")
# Create a table showing token distributions
data = []
for i, token in enumerate(tokens):
cluster = token_clusters.get((token, i), "N/A")
data.append({
"Token Index": i,
"Token": token,
"Cluster": cluster
})
df = pd.DataFrame(data)
st.table(df)
# Create a visualization of the distribution
st.write("#### Distribution Visualization:")
# Create a Sankey diagram
source = []
target = []
value = []
# Node labels: token nodes first (indices 0..len(tokens)-1), then one node per distinct cluster
label = list(tokens)
# Link each token position to its assigned cluster ("N/A" when no assignment was found)
for i, token in enumerate(tokens):
cluster = token_clusters.get((token, i), "N/A")
if cluster not in label[len(tokens):]:
label.append(cluster)
source.append(i)
target.append(len(tokens) + label[len(tokens):].index(cluster))
value.append(1)
# Create and display the Sankey diagram
fig = go.Figure(data=[go.Sankey(
node=dict(
pad=15,
thickness=20,
line=dict(color="black", width=0.5),
label=label,
color="blue"
),
link=dict(
source=source,
target=target,
value=value
)
)])
fig.update_layout(title_text="Token to Cluster Distribution", font_size=10)
st.plotly_chart(fig, use_container_width=True)
else:
st.error(f"Line number {line_number} is out of range. File has {len(lines)} lines.")
except Exception as e:
st.error(f"Error analyzing line: {e}")
if __name__ == "__main__":
main()