import streamlit as st
import os
import json
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import plotly.graph_objects as go
from typing import List
import html
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold


def get_available_cluster_files(layer_dir: str) -> List[str]:
    """Get list of available cluster files and extract their types and sizes."""
    cluster_files = []
    for file in os.listdir(layer_dir):
        if file.startswith('clusters-') and file.endswith('.txt'):
            # Parse files like 'clusters-agg-10.txt' or 'clusters-kmeans-500.txt'
            parts = file.replace('.txt', '').split('-')
            if len(parts) == 3 and parts[2].isdigit():
                cluster_files.append(file)
    for file in sorted(cluster_files):
        st.sidebar.write(f"- {file}")
    return sorted(cluster_files)


def parse_cluster_filename(filename: str) -> tuple:
    """Parse cluster filename to get algorithm and size."""
    parts = filename.replace('.txt', '').split('-')
    return parts[1], int(parts[2])  # returns (algorithm, size)


def load_cluster_sentences(model_dir: str, language: str, cluster_type: str, layer: int,
                           cluster_file: str, tokens: List[str] = None, specific_cluster: str = None):
    """Load sentences and their cluster assignments for a given model and layer."""
    sentence_file = os.path.join(model_dir, language, "input.in")
    cluster_file_path = os.path.join(model_dir, language, f"layer{layer}", cluster_type, cluster_file)

    # Load all sentences first
    with open(sentence_file, 'r', encoding='utf-8') as f:
        all_sentences = [line.strip() for line in f]

    # Process cluster file to get sentence mappings
    cluster_sentences = defaultdict(list)
    cluster_tokens = defaultdict(set)  # Track all tokens in each cluster

    # First pass: collect all tokens for each cluster
    with open(cluster_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped_line = line.strip()
            pipe_count = stripped_line.count('|')
            if pipe_count == 13:
                token = '|'
                parts = stripped_line.split('|||')
                cluster_id = parts[4].strip()
            elif pipe_count == 14:
                token = '||'
                parts = stripped_line.split('|||')
                cluster_id = parts[4].strip()
            else:
                parts = stripped_line.split('|||')
                if len(parts) != 5:
                    continue
                token = parts[0].strip()
                cluster_id = parts[4].strip()
            # Only collect tokens for the specific cluster if specified
            if specific_cluster is None or cluster_id == specific_cluster:
                cluster_tokens[cluster_id].add(token)

    # Second pass: collect sentences
    with open(cluster_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped_line = line.strip()
            pipe_count = stripped_line.count('|')
            if pipe_count == 13:
                token = '|'
                parts = stripped_line.split('|||')
                occurrence = 1
                sentence_id = int(parts[2])
                token_idx = int(parts[3])
                cluster_id = parts[4].strip()
            elif pipe_count == 14:
                token = '||'
                parts = stripped_line.split('|||')
                occurrence = 1
                sentence_id = int(parts[2])
                token_idx = int(parts[3])
                cluster_id = parts[4].strip()
            else:
                parts = stripped_line.split('|||')
                if len(parts) != 5:
                    continue
                token = parts[0].strip()
                try:
                    occurrence = int(parts[1])
                    sentence_id = int(parts[2])
                    token_idx = int(parts[3])
                    cluster_id = parts[4].strip()
                except ValueError:
                    continue
            # Include sentences if:
            # 1. This is the specific cluster we're looking for AND
            # 2. Either no specific tokens requested OR token is in our search list
            if ((specific_cluster is None or cluster_id == specific_cluster)
                    and (tokens is None or token in tokens)):
                if 0 <= sentence_id < len(all_sentences):
                    sentence_tokens = all_sentences[sentence_id].split()
                    if 0 <= token_idx < len(sentence_tokens):
                        # Verify the token actually appears at the specified index
                        if sentence_tokens[token_idx] == token:
                            cluster_sentences[cluster_id].append({
                                "sentence": all_sentences[sentence_id],
                                "token": token,
                                "token_idx": token_idx,
                                "occurrence": occurrence,
                                "sentence_id": sentence_id,
                                "all_cluster_tokens": cluster_tokens[cluster_id]
                            })
    return cluster_sentences


def create_sentence_html(sentence, sent_info, cluster_tokens=None):
    """Create HTML for a sentence with its target token highlighted

    Args:
        sentence: The full sentence text
        sent_info: Dictionary containing token and position info
        cluster_tokens: Set of all unique tokens in this cluster (unused)
    """
    # Build the markup from single-quoted strings (not triple quotes) so the HTML
    # is not escaped when rendered via st.markdown; inline styling is kept minimal.
    html = '<div style="margin: 6px 0; padding: 8px; border: 1px solid #ddd; border-radius: 4px;">'
    html += '<div style="line-height: 1.8;">'

    # Get token information
    target_token = sent_info["token"]
    target_idx = sent_info["token_idx"]
    line_number = sent_info["sentence_id"]

    # Split the tokenized sentence
    tokens = sentence.split()

    # Highlight only the target token in red
    for i, token in enumerate(tokens):
        if i == target_idx:
            # Target token in red
            html += f'<span style="color: red; font-weight: bold;">{token}</span> '
        else:
            # Regular tokens
            html += f'{token} '
    html += '</div>'

    html += f'<div style="color: gray; font-size: 0.85em;">Token: {target_token} (Line: {line_number}, Index: {target_idx})</div>'
    html += '</div>
' return html def display_cluster_analysis(model_name: str, language: str, cluster_type: str, selected_layer: int, cluster_file: str, clustering_method: str): """Display cluster analysis for selected model and layer.""" # Load cluster data cluster_sentences = load_cluster_sentences(model_name, language, cluster_type, selected_layer, cluster_file) # Get clustering algorithm and size from filename algorithm, size = parse_cluster_filename(cluster_file) st.write(f"### Analyzing {algorithm.upper()} clustering with {size} clusters") # Create cluster selection with navigation buttons cluster_ids = sorted(cluster_sentences.keys(), key=lambda x: int(x)) # Sort numerically, no prefix removal needed if not cluster_ids: st.error("No clusters found in the data") return # Create a unique key for this combination of parameters state_key = f"cluster_index_{selected_layer}_{algorithm}_{size}" # Initialize or reset session state if needed if state_key not in st.session_state: st.session_state[state_key] = 0 # Ensure the index is valid for the current number of clusters if st.session_state[state_key] >= len(cluster_ids): st.session_state[state_key] = 0 # Create columns with adjusted widths for better spacing col1, col2, col3, col4 = st.columns([3, 1, 1, 7]) # Adjusted column ratios # Add some vertical space before the controls st.write("") with col1: # Use a key that includes relevant parameters to force refresh select_key = f"cluster_select_{selected_layer}_{algorithm}_{size}" selected_cluster = st.selectbox( "Select cluster", range(len(cluster_ids)), index=st.session_state[state_key], format_func=lambda x: f"Cluster {cluster_ids[x]}", label_visibility="collapsed", # Hides the label but keeps accessibility key=select_key ) # Update session state when dropdown changes if selected_cluster != st.session_state[state_key]: st.session_state[state_key] = selected_cluster st.rerun() # Previous cluster button with custom styling with col2: if st.button("◀", use_container_width=True): st.session_state[state_key] = max(0, st.session_state[state_key] - 1) st.rerun() # Next cluster button with custom styling with col3: if st.button("▶", use_container_width=True): st.session_state[state_key] = min(len(cluster_ids) - 1, st.session_state[state_key] + 1) st.rerun() # Add some vertical space after the controls st.write("") # Get the current cluster cluster_id = cluster_ids[st.session_state[state_key]] # Load cluster data with specific cluster ID cluster_sentences = load_cluster_sentences( model_name, language, cluster_type, selected_layer, cluster_file, specific_cluster=cluster_id ) sentences_data = cluster_sentences[cluster_id] # Create two columns for the main content col_main, col_chat = st.columns([2, 1]) with col_main: # Get all unique tokens in this cluster cluster_tokens = sentences_data[0]["all_cluster_tokens"] if sentences_data else set() # Display word cloud for this cluster st.write("### Word Cloud") wc = create_wordcloud(sentences_data) if wc: # Create a centered column for the word cloud col1, col2, col3 = st.columns([1, 3, 1]) # Increased middle column width with col2: # Clear any existing matplotlib figures plt.clf() # Create new figure with larger size fig, ax = plt.subplots(figsize=(10, 6)) # Increased from (5, 3) to (10, 6) ax.axis('off') ax.imshow(wc, interpolation='bilinear') # Display the figure st.pyplot(fig) # Clean up plt.close(fig) # Display context sentences for this cluster st.write("### Context Sentences") # Create a dictionary to track sentences by their text unique_sentences = {} # First pass: 
        # collect all sentences and their token information
        for sent_info in sentences_data:
            # Escape any HTML special characters in the token and sentence
            sentence_text = html.escape(sent_info["sentence"])
            token = html.escape(sent_info["token"])
            token_idx = sent_info["token_idx"]
            if sentence_text not in unique_sentences:
                unique_sentences[sentence_text] = {
                    "tokens": [(token, token_idx)],
                    "sentence_id": sent_info["sentence_id"]
                }
            else:
                unique_sentences[sentence_text]["tokens"].append((token, token_idx))

        # Second pass: display each unique sentence with all its tokens highlighted
        for sentence_text, info in unique_sentences.items():
            # Create HTML with multiple tokens highlighted
            tokens = sentence_text.split()
            html_content = """<div style="margin: 6px 0; padding: 8px; border: 1px solid #ddd; border-radius: 4px; line-height: 1.8;">"""

            # Highlight all relevant tokens in the sentence
            for i, token in enumerate(tokens):
                if any(i == idx for _, idx in info["tokens"]):
                    # This is one of our target tokens - highlight it
                    html_content += f'<span style="color: red; font-weight: bold;">{token}</span> '
                else:
                    # Regular token
                    html_content += f"{token} "

            # Add token information footer
            html_content += f"""
                <div style="color: gray; font-size: 0.85em;">Tokens: {", ".join([f"{t} (Index: {idx})" for t, idx in info["tokens"]])} (Line: {info["sentence_id"]})</div>
            </div>
""" st.markdown(html_content, unsafe_allow_html=True) with col_chat: # Add chat interface chat_with_cluster(sentences_data, clustering_method, cluster_id) def main(): # Set page to use full width st.set_page_config(layout="wide") st.title("Coconet Visual Analysis") # Initialize session state for selections if they don't exist if 'model_name' not in st.session_state: st.session_state.model_name = None if 'selected_language' not in st.session_state: st.session_state.selected_language = None if 'selected_layer' not in st.session_state: st.session_state.selected_layer = None if 'selected_cluster_type' not in st.session_state: st.session_state.selected_cluster_type = None if 'selected_cluster_file' not in st.session_state: st.session_state.selected_cluster_file = None if 'analysis_mode' not in st.session_state: st.session_state.analysis_mode = "Individual Clusters" # Get available models (directories in the current directory) current_dir = os.path.dirname(os.path.abspath(__file__)) available_models = [d for d in os.listdir(current_dir) if os.path.isdir(os.path.join(current_dir, d)) and not d.startswith('.') and not d == '__pycache__'] # Model selection model_name = st.sidebar.selectbox( "Select Data", available_models, key="model_select", index=available_models.index(st.session_state.model_name) if st.session_state.model_name in available_models else 0 ) if model_name != st.session_state.model_name: st.session_state.model_name = model_name st.session_state.selected_language = None st.session_state.selected_layer = None st.session_state.selected_cluster_type = None st.session_state.selected_cluster_file = None if not model_name: st.error("No models found") return # Get available languages for the selected model model_dir = os.path.join(current_dir, model_name) available_languages = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))] # Language selection selected_language = st.sidebar.selectbox( "Select Language", available_languages, key="language_select", index=available_languages.index(st.session_state.selected_language) if st.session_state.selected_language in available_languages else 0 ) if selected_language != st.session_state.selected_language: st.session_state.selected_language = selected_language st.session_state.selected_layer = None st.session_state.selected_cluster_type = None st.session_state.selected_cluster_file = None if not selected_language: st.error("No languages found for selected model") return # Get available layers language_dir = os.path.join(model_dir, selected_language) layer_dirs = [d for d in os.listdir(language_dir) if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d))] if not layer_dirs: st.error("No layer directories found") return # Extract layer numbers available_layers = sorted([int(d.replace('layer', '')) for d in layer_dirs]) # Layer selection selected_layer = st.sidebar.selectbox( "Select Layer", available_layers, key="layer_select", index=available_layers.index(st.session_state.selected_layer) if st.session_state.selected_layer in available_layers else 0 ) if selected_layer != st.session_state.selected_layer: st.session_state.selected_layer = selected_layer st.session_state.selected_cluster_type = None st.session_state.selected_cluster_file = None # Get available clustering types layer_dir = os.path.join(language_dir, f"layer{selected_layer}") available_cluster_types = [d for d in os.listdir(layer_dir) if os.path.isdir(os.path.join(layer_dir, d))] # Clustering type selection selected_cluster_type = 
st.sidebar.selectbox( "Select Clustering Type", available_cluster_types, key="cluster_type_select", index=available_cluster_types.index(st.session_state.selected_cluster_type) if st.session_state.selected_cluster_type in available_cluster_types else 0 ) if selected_cluster_type != st.session_state.selected_cluster_type: st.session_state.selected_cluster_type = selected_cluster_type st.session_state.selected_cluster_file = None if not selected_cluster_type: st.error("No clustering types found for selected layer") return # Get available cluster files cluster_dir = os.path.join(layer_dir, selected_cluster_type) available_cluster_files = get_available_cluster_files(cluster_dir) if not available_cluster_files: st.error("No cluster files found in the selected layer") return # Cluster file selection selected_cluster_file = st.sidebar.selectbox( "Select Clustering", available_cluster_files, key="cluster_file_select", format_func=lambda x: f"{parse_cluster_filename(x)[0].upper()} (k={parse_cluster_filename(x)[1]})", index=available_cluster_files.index(st.session_state.selected_cluster_file) if st.session_state.selected_cluster_file in available_cluster_files else 0 ) st.session_state.selected_cluster_file = selected_cluster_file # Analysis mode selection analysis_mode = st.sidebar.radio( "Select Analysis Mode", ["Individual Clusters", "Search And Analysis", "Search By Line", "Line Token Distribution", "Token Pairs", "View Input File"], key="analysis_mode_select", index=["Individual Clusters", "Search And Analysis", "Search By Line", "Line Token Distribution", "Token Pairs", "View Input File"].index(st.session_state.analysis_mode) ) st.session_state.analysis_mode = analysis_mode # Call appropriate analysis function based on mode if analysis_mode == "Individual Clusters": display_cluster_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file, selected_cluster_type) elif analysis_mode == "Search And Analysis": handle_token_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file) elif analysis_mode == "Search By Line": handle_line_search(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file) elif analysis_mode == "Line Token Distribution": handle_line_token_distribution(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file) elif analysis_mode == "Token Pairs": display_token_pair_analysis(model_name, selected_language, selected_cluster_type, selected_layer, selected_cluster_file) elif analysis_mode == "View Input File": display_input_file(model_name, selected_language) def display_token_evolution(evolution_data: dict, tokens: List[str]): """Display evolution analysis for tokens""" st.write(f"### Evolution Analysis for Token(s)") # Create main evolution graph fig = go.Figure() # Colors for different types of lines colors = { 'individual': ['#3498db', '#e74c3c', '#2ecc71'], # Blue, Red, Green 'exclusive': ['#9b59b6', '#f1c40f', '#1abc9c'], # Purple, Yellow, Turquoise 'combined': '#34495e' # Dark Gray } # Add individual count lines for i, token in enumerate(tokens): fig.add_trace(go.Scatter( x=evolution_data['layers'], y=evolution_data['individual_counts'][token], name=f"'{token}' (Total)", mode='lines+markers', line=dict(color=colors['individual'][i], width=2), marker=dict(size=8) )) # Add exclusive count lines only for multiple tokens if len(tokens) > 1: fig.add_trace(go.Scatter( x=evolution_data['layers'], 
y=evolution_data['exclusive_counts'][token], name=f"'{token}' (Exclusive)", mode='lines+markers', line=dict(color=colors['exclusive'][i], width=2, dash='dot'), marker=dict(size=8) )) # Add combined counts if multiple tokens if len(tokens) > 1: fig.add_trace(go.Scatter( x=evolution_data['layers'], y=evolution_data['combined_counts'], name='Co-occurring', mode='lines+markers', line=dict(color=colors['combined'], width=2), marker=dict(size=8) )) # Update layout fig.update_layout( title=dict( text='Token Evolution Across Layers', font=dict(size=20) ), xaxis_title=dict( text='Layer', font=dict(size=14) ), yaxis_title=dict( text='Number of Clusters', font=dict(size=14) ), hovermode='x unified', showlegend=True, legend=dict( yanchor="top", y=0.99, xanchor="left", x=0.01 ) ) # Add gridlines fig.update_xaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot') fig.update_yaxes(gridcolor='LightGray', gridwidth=0.5, griddash='dot') st.plotly_chart(fig, use_container_width=True) def find_clusters_for_token(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, search_token: str) -> set: """Find cluster IDs containing the exact token""" matching_clusters = set() try: cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file) with open(cluster_file_path, 'r', encoding='utf-8') as f: for line in f: stripped_line = line.strip() pipe_count = stripped_line.count('|') # Get token and cluster ID based on pipe count if pipe_count == 13: token = '|' parts = stripped_line.split('|||') cluster_id = parts[4].strip() elif pipe_count == 14: token = '||' parts = stripped_line.split('|||') cluster_id = parts[4].strip() else: parts = stripped_line.split('|||') if len(parts) != 5: continue token = parts[0].strip() cluster_id = parts[4].strip() # Use exact token matching if token == search_token: matching_clusters.add(cluster_id) except Exception as e: st.error(f"Error reading cluster file: {e}") return set() return matching_clusters def handle_token_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str): """Handle token search functionality""" st.write("### Token Search") # Initialize session state for search results if needed if 'search_results_state' not in st.session_state: st.session_state.search_results_state = { 'matching_tokens': [], 'matching_tokens2': [], 'last_search': None, 'last_search2': None, 'search_mode': 'single' } elif 'search_mode' not in st.session_state.search_results_state: st.session_state.search_results_state['search_mode'] = 'single' # Radio button for search mode search_mode = st.radio( "Search Mode", ["Single Token", "Token Pair"], key="search_mode_radio", index=0 if st.session_state.search_results_state['search_mode'] == 'single' else 1 ) # Update search mode in session state st.session_state.search_results_state['search_mode'] = 'single' if search_mode == "Single Token" else 'pair' if search_mode == "Single Token": # Single token search interface search_token = st.text_input("Search for token:") if search_token: # Find matching clusters clusters = find_clusters_for_token( model_name, language, cluster_type, layer, cluster_file, search_token ) if clusters: selected_token = search_token if selected_token: # Update state st.session_state.search_results_state.update({ 'matching_tokens': [selected_token], 'last_search': selected_token }) # Display cluster details display_cluster_details( model_name, language, cluster_type, selected_token, cluster_file ) else: st.warning(f"No clusters found 
containing token: '{search_token}'") else: # Token Pair search col1, col2 = st.columns(2) with col1: search_token1 = st.text_input("Search for first token:") with col2: search_token2 = st.text_input("Search for second token:") if search_token1 and search_token2: # Find matching clusters for both tokens clusters1 = find_clusters_for_token( model_name, language, cluster_type, layer, cluster_file, search_token1 ) clusters2 = find_clusters_for_token( model_name, language, cluster_type, layer, cluster_file, search_token2 ) # Find intersection of clusters matching_clusters = clusters1 & clusters2 if matching_clusters: # Update state st.session_state.search_results_state.update({ 'matching_tokens': [search_token1], 'matching_tokens2': [search_token2], 'last_search': search_token1, 'last_search2': search_token2 }) # Display cluster details display_cluster_details( model_name, language, cluster_type, search_token1, cluster_file, second_token=search_token2 ) else: st.warning(f"No clusters found containing both tokens: '{search_token1}' and '{search_token2}'") def analyze_token_evolution(model_name: str, language: str, cluster_type: str, layer: int, tokens: List[str], cluster_file: str) -> dict: """Analyze token evolution across all available layers""" # Get all available layers by checking directories language_dir = os.path.join(model_name, language) available_layers = [] for d in os.listdir(language_dir): if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)): try: layer_num = int(d.replace('layer', '')) available_layers.append(layer_num) except ValueError: continue available_layers.sort() # Sort layers numerically evolution_data = { 'layers': available_layers, 'individual_counts': {token: [] for token in tokens}, 'exclusive_counts': {token: [] for token in tokens}, # New: track exclusive counts 'combined_counts': [] if len(tokens) > 1 else None } # Extract cluster size from the filename (e.g., "clusters-agg-500.txt" -> "500") cluster_size = cluster_file.split('-')[-1].replace('.txt', '') # Handle shortened form for agglomerative clustering cluster_type_short = "agg" if cluster_type == "agglomerative" else cluster_type for current_layer in available_layers: # Get clusters for each token token_clusters = {} cluster_file_path = os.path.join( model_name, language, f"layer{current_layer}", cluster_type, f"clusters-{cluster_type_short}-{cluster_size}.txt" ) # Skip layer if cluster file doesn't exist if not os.path.exists(cluster_file_path): continue for token in tokens: clusters = find_clusters_for_token( model_name, language, cluster_type, current_layer, f"clusters-{cluster_type_short}-{cluster_size}.txt", token ) token_clusters[token] = set(clusters) evolution_data['individual_counts'][token].append(len(clusters)) # Calculate exclusive and co-occurring clusters if len(tokens) > 1: # Calculate co-occurrences cooccurring_clusters = set.intersection(*[token_clusters[token] for token in tokens]) evolution_data['combined_counts'].append(len(cooccurring_clusters)) # Calculate exclusive counts for each token for token in tokens: other_tokens = set(tokens) - {token} other_clusters = set.union(*[token_clusters[t] for t in other_tokens]) if other_tokens else set() exclusive_clusters = token_clusters[token] - other_clusters evolution_data['exclusive_counts'][token].append(len(exclusive_clusters)) else: # For single token, exclusive count is the same as individual count evolution_data['exclusive_counts'][tokens[0]] = evolution_data['individual_counts'][tokens[0]] return evolution_data def 
find_clusters_with_multiple_tokens(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, tokens: List[str]) -> dict: """Find clusters containing multiple specified tokens""" clusters = defaultdict(lambda: {'matching_tokens': {token: set() for token in tokens}}) try: cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file) with open(cluster_file_path, 'r', encoding='utf-8') as f: for line in f: parts = line.strip().split('|||') if len(parts) == 5: token = parts[0].strip() cluster_id = parts[4].strip() for search_token in tokens: if token == search_token: # Add all tokens from this cluster with open(cluster_file_path, 'r', encoding='utf-8') as f2: for line2 in f2: parts2 = line2.strip().split('|||') if len(parts2) == 5 and parts2[4].strip() == cluster_id: clusters[cluster_id]['matching_tokens'][search_token].add(parts2[0].strip()) except Exception as e: st.error(f"Error reading cluster file: {e}") return {} # Filter to only keep clusters with all tokens return {k: v for k, v in clusters.items() if all(v['matching_tokens'][token] for token in tokens)} def handle_semantic_tag_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str): """Handle semantic tag search functionality""" st.write("### Semantic Tag Search") st.info("This feature will be implemented soon.") def display_token_pair_analysis(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str): """Display analysis for predefined token pairs""" st.write("### Token Pair Analysis") # Get predefined token pairs token_pairs = get_predefined_token_pairs() # Create tabs for each category tabs = st.tabs(list(token_pairs.keys())) for tab, (category, data) in zip(tabs, token_pairs.items()): with tab: st.write(f"### {category}") st.write(data["description"]) # Instead of using expanders, use a selectbox to choose the token pair pairs = data["pairs"] pair_labels = [f"{t1} vs {t2}" for t1, t2 in pairs] selected_pair_idx = st.selectbox( "Select token pair", range(len(pairs)), format_func=lambda i: pair_labels[i], key=f"pair_select_{category}" ) # Get the selected pair token1, token2 = pairs[selected_pair_idx] # Update state for token pair search st.session_state.search_results_state = { 'matching_tokens': [token1], 'matching_tokens2': [token2], 'last_search': token1, 'last_search2': token2, 'search_mode': 'pair' } # Display results st.write("### Search Results") # Create tabs for different views tab1, tab2 = st.tabs(["Evolution Analysis", "Co-occurring Clusters"]) with tab1: evolution_data = analyze_token_evolution( model_name, language, cluster_type, layer, [token1, token2], cluster_file ) if evolution_data: display_token_evolution(evolution_data, [token1, token2]) with tab2: display_cluster_details( model_name, language, cluster_type, token1, cluster_file, second_token=token2 ) def get_predefined_token_pairs(): """Return predefined token pairs organized by categories""" return { "Control Flow": { "description": "Different control flow constructs", "pairs": [ ("for", "while"), ("if", "switch"), ("break", "continue"), ("try", "catch") ] }, "Access Modifiers": { "description": "Access and modifier keywords", "pairs": [ ("public", "private"), ("static", "final"), ("abstract", "interface") ] }, "Variable/Type": { "description": "Variable and type-related tokens", "pairs": [ ("int", "Integer"), ("null", "Optional"), ("var", "String") # Example of var vs explicit type ] }, "Collections": { "description": "Collection-related 
tokens", "pairs": [ ("List", "Array"), ("ArrayList", "LinkedList"), ("HashMap", "TreeMap"), ("Set", "List") ] }, "Threading": { "description": "Threading and concurrency tokens", "pairs": [ ("synchronized", "volatile"), ("Runnable", "Callable"), ("wait", "sleep") ] }, "Object-Oriented": { "description": "Object-oriented programming tokens", "pairs": [ ("extends", "implements"), ("this", "super"), ("new", "clone") ] } } def create_wordcloud(tokens, token1=None, token2=None): """Create and return a word cloud from tokens with frequencies""" if not tokens: return None # Create frequency dict by counting occurrences freq_dict = {} # If tokens is a list of dictionaries (from cluster data) if isinstance(tokens, list) and tokens and isinstance(tokens[0], dict): # Count token occurrences in the cluster for token_info in tokens: token = token_info["token"] freq_dict[token] = freq_dict.get(token, 0) + 1 else: # If tokens is a set/list of strings, convert to frequency dict tokens_list = list(tokens) if isinstance(tokens, set) else tokens for token in tokens_list: freq_dict[token] = freq_dict.get(token, 0) + 1 # Normalize frequencies with a base size max_freq = max(freq_dict.values()) base_size = 30 # Base size for all tokens normalized_freq = {token: base_size + ((count / max_freq) * 70) for token, count in freq_dict.items()} # Boost frequency of searched tokens if provided if token1: normalized_freq[token1] = normalized_freq.get(token1, 0) + 5 if token2: normalized_freq[token2] = normalized_freq.get(token2, 0) + 5 # Custom colormap with dark shades of brown, green, and blue wc = WordCloud( width=800, height=400, background_color='white', max_words=100, prefer_horizontal=1.0, # Make all words horizontal colormap='Dark2' # Dark colormap with browns, greens, blues ).generate_from_frequencies(normalized_freq) return wc def display_cluster_details(model_name: str, language: str, cluster_type: str, token: str, cluster_file: str, second_token: str = None): """Display detailed cluster information organized by layers""" # Get all available layers language_dir = os.path.join(model_name, language) available_layers = [] for d in os.listdir(language_dir): if d.startswith('layer') and os.path.isdir(os.path.join(language_dir, d)): try: layer_num = int(d.replace('layer', '')) available_layers.append(layer_num) except ValueError: continue available_layers.sort() # Create tabs for each layer layer_tabs = st.tabs([f"Layer {layer}" for layer in available_layers]) # Handle shortened form for agglomerative clustering cluster_type_short = "agg" if cluster_type == "agglomerative" else cluster_type cluster_size = cluster_file.split('-')[-1].replace('.txt', '') for layer, tab in zip(available_layers, layer_tabs): with tab: # Find clusters containing the token(s) matching_clusters = find_clusters_for_token( model_name, language, cluster_type, layer, f"clusters-{cluster_type_short}-{cluster_size}.txt", token ) if second_token: matching_clusters2 = find_clusters_for_token( model_name, language, cluster_type, layer, f"clusters-{cluster_type_short}-{cluster_size}.txt", second_token ) # Find intersection of clusters containing both tokens matching_clusters &= matching_clusters2 if matching_clusters: # Sort cluster IDs numerically cluster_ids = sorted(matching_clusters, key=lambda x: int(x)) # Create dropdown for cluster selection selected_cluster = st.selectbox( f"Select cluster from Layer {layer}", cluster_ids, format_func=lambda x: f"Cluster {x}", key=f"cluster_select_{layer}_{token}_{second_token if second_token else ''}" ) if 
selected_cluster: # Load all cluster data for the selected cluster cluster_data = load_cluster_sentences( model_name, language, cluster_type, layer, f"clusters-{cluster_type_short}-{cluster_size}.txt", specific_cluster=selected_cluster ) if selected_cluster in cluster_data: # Create two columns for the main content col_main, col_chat = st.columns([2, 1]) with col_main: shown_sentences = set() # Get all unique tokens in this cluster cluster_tokens = cluster_data[selected_cluster][0]["all_cluster_tokens"] if cluster_data[selected_cluster] else set() # Display word cloud for this cluster st.write("### Word Cloud") wc = create_wordcloud(cluster_data[selected_cluster], token1=token, token2=second_token) if wc: # Create a centered column for the word cloud col1, col2, col3 = st.columns([1, 3, 1]) with col2: plt.clf() fig, ax = plt.subplots(figsize=(10, 6)) ax.axis('off') ax.imshow(wc, interpolation='bilinear') st.pyplot(fig) plt.close(fig) # First show sentences containing searched tokens if second_token: st.write(f"#### Context for searched tokens: '{token}' and '{second_token}'") else: st.write(f"#### Context for searched token: '{token}'") html_output = [] for sent_info in cluster_data[selected_cluster]: if sent_info["token"] in [token, second_token] and sent_info["sentence"] not in shown_sentences: html_output.append(create_sentence_html(sent_info["sentence"], sent_info)) shown_sentences.add(sent_info["sentence"]) html_output.append("
") if html_output: st.markdown("\n".join(html_output), unsafe_allow_html=True) # Then show all other sentences in the cluster st.write("#### All sentences in cluster") html_output = [] for sent_info in cluster_data[selected_cluster]: if sent_info["sentence"] not in shown_sentences: html_output.append(create_sentence_html(sent_info["sentence"], sent_info)) shown_sentences.add(sent_info["sentence"]) html_output.append("
") if not html_output: st.info("No additional sentences in this cluster.") else: st.markdown("\n".join(html_output), unsafe_allow_html=True) with col_chat: # Add chat interface with cluster ID chat_with_cluster(cluster_data[selected_cluster], cluster_type, selected_cluster) else: st.info(f"No sentences found for cluster {selected_cluster}") else: if second_token: st.info(f"No clusters containing both '{token}' and '{second_token}' found in Layer {layer}") else: st.info(f"No clusters containing '{token}' found in Layer {layer}") def display_input_file(model_name: str, language: str): """Display the contents of input.in file""" st.write("### Input File Contents") input_file = os.path.join(model_name, language, "input.in") try: with open(input_file, 'r', encoding='utf-8') as f: lines = f.readlines() # Add line numbers and display in a scrollable container numbered_lines = [f"{i+1:4d} | {line}" for i, line in enumerate(lines)] st.code('\n'.join(numbered_lines), language='text') # Display some statistics st.write(f"Total lines: {len(lines)}") except Exception as e: st.error(f"Error reading input file: {e}") from dotenv import load_dotenv def setup_gemini(): """Setup Gemini model with API key and temperature setting""" load_dotenv() # Load environment variables from .env file api_key = os.getenv("GEMINI_API_KEY") if not api_key: st.error("Please set GEMINI_API_KEY in your .env file") return False genai.configure(api_key=api_key) # Create generation config with temperature 0.4 generation_config = { "temperature": 0.4, "top_p": 1, "top_k": 1, "max_output_tokens": 2048, } return generation_config def get_cluster_context(cluster_sentences): """Format cluster data into a clear context for Gemini, focusing on searched tokens only""" # Get unique searched tokens and their frequencies token_counts = {} for sent_info in cluster_sentences: token = sent_info["token"] token_counts[token] = token_counts.get(token, 0) + 1 # Format the context context = "Here is the data for this cluster:\n\n" # Add token frequency information context += "Tokens being analyzed:\n" for token, count in sorted(token_counts.items(), key=lambda x: (-x[1], x[0])): context += f"- '{token}' (appears {count} times)\n" # Add sentence examples with token highlighting context += "\nExample sentences (analyzed tokens marked with *):\n" unique_sentences = {} for sent_info in cluster_sentences: sentence = sent_info["sentence"] if sentence not in unique_sentences: tokens = sentence.split() marked_tokens = tokens.copy() marked_tokens[sent_info["token_idx"]] = f"*{tokens[sent_info['token_idx']]}*" unique_sentences[sentence] = " ".join(marked_tokens) for marked_sentence in unique_sentences.values(): context += f"- {marked_sentence}\n" return context def chat_with_cluster(cluster_sentences, clustering_method, cluster_id=None): """Create a chat interface for discussing the cluster with Gemini""" # Initialize Gemini with temperature setting generation_config = setup_gemini() if not generation_config: return model = genai.GenerativeModel('gemini-2.0-flash', generation_config=generation_config) # Create a unique key for this specific cluster's chat cluster_tokens = {sent_info["token"] for sent_info in cluster_sentences} # Include cluster_id in the key if provided cluster_key = f"cluster_chat_{cluster_id}_{'-'.join(sorted(cluster_tokens))}" if cluster_id else f"cluster_chat_{'-'.join(sorted(cluster_tokens))}" history_key = f"{cluster_key}_history" # Reset chat and history on each page load st.session_state[cluster_key] = model.start_chat(history=[]) 
st.session_state[history_key] = [] # Get cluster context and send initial message context = get_cluster_context(cluster_sentences) initial_message = f"""You are a helpful assistant analyzing Java code clusters. You will help users understand patterns and relationships in the provided cluster data. Here is the context for this cluster: {context} Please analyze this data and help users understand what each token is doing in the cluster. Be concise and to the point. """ # Send initial message if history is empty if not st.session_state[history_key]: response = st.session_state[cluster_key].send_message(initial_message) st.session_state[history_key].append(("assistant", response.text)) # Display chat interface st.write("### Chat with Gemini about this Cluster") st.write("Ask questions about patterns, relationships, or insights in this cluster.") # Display chat history for role, message in st.session_state[history_key]: with st.chat_message(role): st.write(message) # Chat input with cluster-specific key user_input = st.text_input("Your question:", key=f"input_{cluster_key}") if user_input: try: # Display user message with st.chat_message("user"): st.write(user_input) # Store user message in history st.session_state[history_key].append(("user", user_input)) # Get and display Gemini's response with st.chat_message("assistant"): response = st.session_state[cluster_key].send_message(user_input) st.write(response.text) # Store assistant's response in history st.session_state[history_key].append(("assistant", response.text)) except Exception as e: st.error(f"Error communicating with Gemini: {e}") def handle_line_search(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str): """Handle line number search functionality""" st.write("### Search by Line Number") # Create two columns for search inputs col1, col2 = st.columns([1, 1]) with col1: # Input for line number line_number = st.number_input("Enter line number:", min_value=1, step=1) with col2: # Optional token search token_filter = st.text_input("Filter by token (optional):", "") if st.button("Search"): # Find clusters containing the line number clusters = find_clusters_for_line( model_name, language, cluster_type, layer, cluster_file, line_number ) if clusters: if token_filter: # Filter clusters to only those containing both the line and the token filtered_clusters = set() for cluster_id in clusters: cluster_data = load_cluster_sentences( model_name, language, cluster_type, layer, cluster_file, specific_cluster=cluster_id ) if any(sent_info["token"] == token_filter for sent_info in cluster_data[cluster_id]): filtered_clusters.add(cluster_id) clusters = filtered_clusters if not clusters: st.warning(f"No clusters found containing both line {line_number} and token '{token_filter}'") return st.success(f"Found {len(clusters)} clusters containing line {line_number} and token '{token_filter}'") else: st.success(f"Found {len(clusters)} clusters containing line {line_number}") # Create tabs for each cluster cluster_tabs = st.tabs([f"Cluster {cluster_id}" for cluster_id in sorted(clusters)]) for tab, cluster_id in zip(cluster_tabs, sorted(clusters)): with tab: # Load cluster data cluster_data = load_cluster_sentences( model_name, language, cluster_type, layer, cluster_file, specific_cluster=cluster_id ) if cluster_id in cluster_data: # Create two columns for the main content col_main, col_chat = st.columns([2, 1]) with col_main: # Display word cloud st.write("### Word Cloud") wc = create_wordcloud(cluster_data[cluster_id], 
token_filter if token_filter else None) if wc: col1, col2, col3 = st.columns([1, 3, 1]) with col2: plt.clf() fig, ax = plt.subplots(figsize=(10, 6)) ax.axis('off') ax.imshow(wc, interpolation='bilinear') st.pyplot(fig) plt.close(fig) # Display sentences st.write("### Sentences in Cluster") shown_sentences = set() # First show the searched line st.write("#### Searched Line") html_output = [] for sent_info in cluster_data[cluster_id]: if sent_info["sentence_id"] == line_number - 1: # Adjust for 0-based indexing html_output.append(create_sentence_html(sent_info["sentence"], sent_info)) shown_sentences.add(sent_info["sentence"]) break if html_output: st.markdown("\n".join(html_output), unsafe_allow_html=True) # Then show sentences with the filtered token (if specified) if token_filter: st.write(f"#### Sentences containing '{token_filter}'") html_output = [] for sent_info in cluster_data[cluster_id]: if (sent_info["token"] == token_filter and sent_info["sentence"] not in shown_sentences): html_output.append(create_sentence_html(sent_info["sentence"], sent_info)) shown_sentences.add(sent_info["sentence"]) if html_output: st.markdown("\n".join(html_output), unsafe_allow_html=True) # Finally show other sentences st.write("#### Other sentences in cluster") html_output = [] for sent_info in cluster_data[cluster_id]: if sent_info["sentence"] not in shown_sentences: html_output.append(create_sentence_html(sent_info["sentence"], sent_info)) shown_sentences.add(sent_info["sentence"]) if html_output: st.markdown("\n".join(html_output), unsafe_allow_html=True) with col_chat: # Add chat interface with cluster ID chat_with_cluster(cluster_data[cluster_id], cluster_type, cluster_id) else: st.warning(f"No clusters found containing line {line_number}") def find_clusters_for_line(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str, line_number: int) -> set: """Find cluster IDs containing the specified line number""" matching_clusters = set() try: cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file) with open(cluster_file_path, 'r', encoding='utf-8') as f: for line in f: stripped_line = line.strip() parts = stripped_line.split('|||') if len(parts) >= 3: try: sentence_id = int(parts[2]) if sentence_id == line_number - 1: # Adjust for 0-based indexing cluster_id = parts[4].strip() matching_clusters.add(cluster_id) except (ValueError, IndexError): continue except Exception as e: st.error(f"Error reading cluster file: {e}") return set() return matching_clusters # Add new function to handle line token distribution def handle_line_token_distribution(model_name: str, language: str, cluster_type: str, layer: int, cluster_file: str): """Display cluster distribution for each token in a specific line""" st.write("### Line Token Distribution") # Input for line number line_number = st.number_input("Enter line number:", min_value=1, step=1) if st.button("Analyze"): # Load the input file to get the line content input_file = os.path.join(model_name, language, "input.in") try: with open(input_file, 'r', encoding='utf-8') as f: lines = f.readlines() if line_number <= len(lines): line_content = lines[line_number - 1].strip() tokens = line_content.split() # Create a dictionary to store cluster assignments for each token token_clusters = {} # Find clusters for each token in the line cluster_file_path = os.path.join(model_name, language, f"layer{layer}", cluster_type, cluster_file) with open(cluster_file_path, 'r', encoding='utf-8') as f: for line in f: parts 
= line.strip().split('|||') if len(parts) >= 5: try: sentence_id = int(parts[2]) token_idx = int(parts[3]) if sentence_id == line_number - 1: # Adjust for 0-based indexing token = parts[0].strip() cluster_id = parts[4].strip() token_clusters[(token, token_idx)] = cluster_id except (ValueError, IndexError): continue # Find clusters that contain different tokens cluster_to_unique_tokens = {} for (token, idx), cluster in token_clusters.items(): if cluster not in cluster_to_unique_tokens: cluster_to_unique_tokens[cluster] = set() cluster_to_unique_tokens[cluster].add(token) # Filter for clusters with different tokens (more than one unique token) clusters_with_different_tokens = { cluster: tokens for cluster, tokens in cluster_to_unique_tokens.items() if len(tokens) > 1 } has_mixed_clusters = len(clusters_with_different_tokens) > 0 # Display results col1, col2 = st.columns([3, 1]) with col1: st.write("#### Line Content:") st.code(line_content, language="java") with col2: st.write("#### Uniqueness Check:") st.checkbox("All different tokens have unique clusters", value=not has_mixed_clusters, disabled=True) # Show clusters with different tokens if has_mixed_clusters: st.write("##### Clusters with Different Tokens:") for cluster, unique_tokens in clusters_with_different_tokens.items(): # Get all indices for each token in this cluster token_positions = {} for token in unique_tokens: positions = [idx for (t, idx), c in token_clusters.items() if t == token and c == cluster] token_positions[token] = positions # Format the display string tokens_str = ", ".join( f"'{token}' (idx: {', '.join(map(str, positions))})" for token, positions in token_positions.items() ) st.markdown(f"- Cluster **{cluster}**: {tokens_str}") st.write("#### Token Distribution:") # Create a table showing token distributions data = [] for i, token in enumerate(tokens): cluster = token_clusters.get((token, i), "N/A") data.append({ "Token Index": i, "Token": token, "Cluster": cluster }) df = pd.DataFrame(data) st.table(df) # Create a visualization of the distribution st.write("#### Distribution Visualization:") # Create a Sankey diagram source = [] target = [] value = [] label = [] # Add tokens as source nodes for i, token in enumerate(tokens): source.append(i) cluster = token_clusters.get((token, i), "N/A") if cluster == "N/A": target_idx = len(tokens) else: if cluster not in label[len(tokens):]: label.append(cluster) target_idx = len(tokens) + label[len(tokens):].index(cluster) target.append(target_idx) value.append(1) if i == 0: label.extend(tokens) # Create and display the Sankey diagram fig = go.Figure(data=[go.Sankey( node=dict( pad=15, thickness=20, line=dict(color="black", width=0.5), label=label, color="blue" ), link=dict( source=source, target=target, value=value ) )]) fig.update_layout(title_text="Token to Cluster Distribution", font_size=10) st.plotly_chart(fig, use_container_width=True) else: st.error(f"Line number {line_number} is out of range. File has {len(lines)} lines.") except Exception as e: st.error(f"Error analyzing line: {e}") if __name__ == "__main__": main()
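

# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the app): the cluster files parsed above
# are expected to hold one record per token occurrence, with fields joined by
# '|||':
#     token|||occurrence|||sentence_id|||token_idx|||cluster_id
# Files live under <model_dir>/<language>/layer<N>/<cluster_type>/clusters-<algo>-<k>.txt,
# with the raw sentences in <model_dir>/<language>/input.in (see
# load_cluster_sentences above). The helper name and example values below are
# hypothetical; they only document how load_cluster_sentences() and
# find_clusters_for_token() read each field.
def _example_parse_cluster_record(line: str) -> dict:
    """Map one 'clusters-*.txt' record onto the fields used in this module."""
    parts = line.strip().split('|||')
    return {
        "token": parts[0].strip(),        # surface form of the token
        "occurrence": int(parts[1]),      # occurrence count of the token in the sentence
        "sentence_id": int(parts[2]),     # 0-based line index into input.in
        "token_idx": int(parts[3]),       # 0-based position of the token in the sentence
        "cluster_id": parts[4].strip(),   # cluster assignment (kept as a string)
    }


# Example with a hypothetical record:
#     _example_parse_cluster_record("for|||1|||17|||5|||342")
#     -> {"token": "for", "occurrence": 1, "sentence_id": 17,
#         "token_idx": 5, "cluster_id": "342"}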