import json
import os
import re
import time
from typing import Callable, Dict, List, Optional, Set, Tuple

import pandas as pd
from google import genai
from google.genai import types
from tqdm import tqdm

# Gemini model used for every explanation request.
MODEL_NAME = "gemini-2.0-flash"

# Leading text of the placeholder explanation returned when all retries fail.
# On resume, checkpointed results starting with it are re-attempted.
ERROR_PREFIX = "Error generating explanation"

# Seconds to sleep between failed attempts in the retry helper.
RETRY_WAIT_SECONDS = 60

# Columns of the predictions file that process_tokens reads.
_REQUIRED_PREDICTION_COLUMNS = ("Token", "Top 1", "line_idx", "position_idx")


def configure_genai(api_key: str):
    """Configure the Gemini API with the provided key.

    The key is stored in the ``GEMINI_API_KEY`` environment variable, which is
    where the request helpers in this module read it from.
    """
    os.environ["GEMINI_API_KEY"] = api_key


def _normalize_cluster_id(value: object) -> str:
    """Return *value* as a cluster-id string in the form ``load_clusters`` uses.

    ``load_clusters`` keys clusters by ``str(int(id))`` (so ``"007"`` -> ``"7"``);
    prediction ids must be normalized the same way or dictionary lookups miss
    silently. Integral floats such as ``7.0`` (what pandas yields for an integer
    column containing NaN) also map to ``"7"``. Anything else is returned
    stripped but otherwise unchanged.
    """
    text = str(value).strip()
    if text.isdigit():
        return str(int(text))
    try:
        number = float(text)
    except ValueError:
        return text
    if number.is_integer():
        return str(int(number))
    return text


def load_predictions(task: str, layer: int) -> Optional[pd.DataFrame]:
    """Load the tab-separated predictions file for *task* / *layer*.

    Reads ``src/codebert/<task>/layer<layer>/predictions_layer_<layer>.csv``,
    casts ``Token`` to ``str`` and adds a ``predicted_cluster`` column holding
    the ``Top 1`` value normalized to match the keys of ``load_clusters``.

    Returns:
        The DataFrame, or None when the file is missing, cannot be parsed, or
        lacks one of the columns ``process_tokens`` reads.
    """
    predictions_path = os.path.join(
        "src", "codebert", task, f"layer{layer}", f"predictions_layer_{layer}.csv"
    )
    if not os.path.exists(predictions_path):
        return None
    try:
        df = pd.read_csv(predictions_path, delimiter='\t')
    except Exception as e:
        print(f"Error loading predictions: {str(e)}")
        return None

    missing = [col for col in _REQUIRED_PREDICTION_COLUMNS if col not in df.columns]
    if missing:
        print(f"Error loading predictions: missing columns {missing}")
        return None

    df['Token'] = df['Token'].astype(str)
    df['predicted_cluster'] = df['Top 1'].map(_normalize_cluster_id)
    return df


def load_clusters(task: str, layer: int) -> Optional[Dict[str, List[Dict]]]:
    """Load token occurrences grouped by cluster id.

    Reads ``src/codebert/<task>/layer<layer>/clusters-350.txt``. Each usable
    line has five ``|||``-separated fields,
    ``token ||| occurrence ||| line_num ||| col_num ||| cluster_id``; only the
    part of ``cluster_id`` before any ``|`` is used.

    Returns:
        ``{cluster_id: [{"token": str, "line_num": int, "col_num": int}, ...]}``
        with ids normalized via ``str(int(...))``, or None if the file is
        missing or unreadable. Blank, malformed, or non-numeric-id lines are
        skipped.
    """
    clusters_path = os.path.join("src", "codebert", task, f"layer{layer}", "clusters-350.txt")
    if not os.path.exists(clusters_path):
        return None

    clusters: Dict[str, List[Dict]] = {}
    try:
        with open(clusters_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                parts = [p.strip() for p in line.split('|||')]
                if len(parts) != 5:
                    continue
                token, _occurrence, line_num, col_num, cluster_id = parts
                cluster_id = cluster_id.split('|')[0].strip()
                if not cluster_id.isdigit():
                    continue
                try:
                    # int() can still reject strings isdigit() accepts (e.g.
                    # superscript digits), so conversion failures skip the line.
                    key = str(int(cluster_id))
                    entry = {'token': token, 'line_num': int(line_num), 'col_num': int(col_num)}
                except ValueError:
                    continue
                clusters.setdefault(key, []).append(entry)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading clusters: {str(e)}")
        return None
    return clusters


def load_sentences(task: str, layer: int, file_name: str) -> List[str]:
    """Load the raw lines of *file_name* (newlines included).

    Looks in ``src/codebert/<task>/layer<layer>/`` first and falls back to
    ``src/codebert/<task>/``. Returns an empty list if the file cannot be read.
    """
    file_path = os.path.join("src", "codebert", task, f"layer{layer}", file_name)
    if not os.path.exists(file_path):
        file_path = os.path.join("src", "codebert", task, file_name)
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.readlines()
    except Exception as e:
        print(f"Error loading sentences from {file_path}: {str(e)}")
        return []


def _highlight_token(sentence: str, token: str, position_idx: Optional[int] = None) -> str:
    """Return *sentence* with ONE occurrence of *token* wrapped in ``[[ ]]``.

    Preference order:
      1. the whitespace-delimited word at index *position_idx*, if it equals *token*;
      2. the first whitespace-delimited word equal to *token*;
      3. the first substring occurrence of *token*.
    Marking a single occurrence keeps a short token (e.g. ``i``) from also
    highlighting every substring match elsewhere. An empty token leaves the
    sentence unchanged.
    """
    if not token:
        return sentence
    words = list(re.finditer(r"\S+", sentence))
    if (
        position_idx is not None
        and 0 <= position_idx < len(words)
        and words[position_idx].group() == token
    ):
        target = words[position_idx]
    else:
        target = next((w for w in words if w.group() == token), None)
    if target is None:
        return sentence.replace(token, f"[[{token}]]", 1)
    return f"{sentence[:target.start()]}[[{token}]]{sentence[target.end():]}"


def _generate_text(prompt: str) -> str:
    """Send *prompt* to Gemini as one user turn and return the streamed reply, stripped.

    Uses the key stored in the ``GEMINI_API_KEY`` environment variable.
    Exceptions raised by the client propagate to the caller.
    """
    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=prompt)],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1.0,
        response_mime_type="text/plain",
    )
    pieces: List[str] = []
    for chunk in client.models.generate_content_stream(
        model=MODEL_NAME,
        contents=contents,
        config=generate_content_config,
    ):
        # chunk.text can be None when a chunk carries no text part;
        # concatenating None would raise TypeError and waste a retry.
        if chunk.text:
            pieces.append(chunk.text)
    return "".join(pieces).strip()


def get_gemini_explanation(sentence: str, highlighted_token: str, cluster_words: List[str],
                           position_idx: Optional[int] = None) -> str:
    """Get explanation from Gemini about the relationship between the token and cluster words.

    Args:
        sentence: The sentence containing the token.
        highlighted_token: Token to mark with ``[[ ]]`` in the prompt.
        cluster_words: Words of the token's predicted cluster.
        position_idx: Optional index of the token among the whitespace-separated
            words of *sentence* (NOTE(review): assumed to match the predictions
            file's ``position_idx`` — only used when the word there equals the
            token, so a mismatch falls back to the first match).

    Raises:
        Whatever the Gemini client raises; see the ``*_with_retry`` wrapper.
    """
    highlighted_sentence = _highlight_token(sentence, highlighted_token, position_idx)
    word_list = ', '.join(cluster_words)
    prompt = (
        "Do you find any common semantic, structural, lexical and topical relation "
        "between the word highlighted in the sentence (enclosed in [[ ]]) and the "
        "following list of words? Give a more specific and concise summary about the "
        "most prominent relation among these words. "
        f"Sentence: {highlighted_sentence} "
        f"List of words: {word_list} "
        "Answer concisely and to the point."
    )
    return _generate_text(prompt)


def is_cls_token(token: str) -> bool:
    """Check if a token is a CLS token (i.e. starts with ``[CLS]``)."""
    return token.startswith('[CLS]')


def get_gemini_explanation_for_cls(sentence: str, cluster_words: List[str],
                                   context_sentences: List[str]) -> str:
    """Get explanation from Gemini about the CLS token and its relationship with the cluster.

    The prompt includes the sentence, the cluster words and the sentences those
    cluster words came from (newline-joined), or a placeholder when there are
    no context sentences.
    """
    context_text = "\n".join(context_sentences) if context_sentences else "No context sentences available."
    word_list = ', '.join(cluster_words)
    prompt = (
        "[CLS] tokens represent the entire sentence.\n"
        "For this sentence, explain the semantic, structural, lexical, or topical meaning "
        "in relation to the list of words from similar contexts. What cohesive meaning "
        "does this sentence share with the contextual themes? "
        f"Original Sentence: {sentence} "
        f"List of cluster words: {word_list} "
        f"Context Sentences of the list of cluster words: {context_text} "
        "Answer concisely and to the point about the semantic or topical meaning this "
        "sentence shares with the contexts."
    )
    return _generate_text(prompt)


def _call_with_retry(request: Callable[[], str], max_retries: int) -> str:
    """Call *request* up to *max_retries* times, sleeping between failures.

    Returns the first successful result. If every attempt fails (or
    *max_retries* <= 0) returns an error string starting with ``ERROR_PREFIX``
    instead of raising, so a batch run can continue.
    """
    last_error: Optional[Exception] = None
    for attempt in range(1, max_retries + 1):
        try:
            return request()
        except Exception as e:
            last_error = e
            print(f"\nEncountered {type(e).__name__}: {str(e)}")
            if attempt < max_retries:
                print(f"Waiting {RETRY_WAIT_SECONDS} seconds before retry {attempt}/{max_retries}...")
                time.sleep(RETRY_WAIT_SECONDS)
            else:
                print(f"Max retries ({max_retries}) reached. Returning error message.")
    return f"{ERROR_PREFIX} after {max_retries} attempts: {str(last_error)}"


def get_gemini_explanation_with_retry(sentence: str, highlighted_token: str, cluster_words: List[str],
                                      max_retries: int = 3, position_idx: Optional[int] = None) -> str:
    """Get explanation from Gemini with retry logic; never raises on API errors."""
    return _call_with_retry(
        lambda: get_gemini_explanation(sentence, highlighted_token, cluster_words, position_idx),
        max_retries,
    )


def get_gemini_explanation_for_cls_with_retry(sentence: str, cluster_words: List[str],
                                              context_sentences: List[str], max_retries: int = 3) -> str:
    """Get explanation for CLS tokens with retry logic; never raises on API errors."""
    return _call_with_retry(
        lambda: get_gemini_explanation_for_cls(sentence, cluster_words, context_sentences),
        max_retries,
    )


def _save_results(path: str, results: List[Dict]) -> None:
    """Write *results* to *path* as UTF-8 JSON, atomically.

    Writes to a temporary file first and renames it over *path*, so a crash
    mid-write cannot leave a truncated checkpoint that breaks resuming.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, path)


def _load_interim_results(interim_file: str) -> Tuple[List[Dict], Set[Tuple[int, int]]]:
    """Return ``(results, processed_indices)`` checkpointed by an earlier run.

    Results whose explanation is the retry-exhausted error placeholder are
    dropped so those tokens are attempted again. Returns ``([], set())`` when
    the file is absent or cannot be parsed.
    """
    if not os.path.exists(interim_file):
        return [], set()
    try:
        with open(interim_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        results = [
            r for r in loaded
            if not str(r.get('explanation', '')).startswith(ERROR_PREFIX)
        ]
        processed = {(r['line_idx'], r['position_idx']) for r in results}
    except Exception as e:
        print(f"Error loading interim file: {str(e)}")
        return [], set()
    print(f"Resuming from {len(results)} previously processed tokens")
    return results, processed


def _cluster_context(cluster_entries: List[Dict], input_sentences: List[str]) -> Tuple[List[str], List[str]]:
    """Return ``(cluster_words, context_sentences)`` for one cluster.

    ``cluster_words`` is the sorted set of tokens (sorted so prompts are
    reproducible across runs). ``context_sentences`` are the stripped
    ``input_sentences`` lines referenced by the entries' ``line_num``,
    de-duplicated in first-seen order; out-of-range line numbers are ignored.
    """
    cluster_words = sorted({entry['token'] for entry in cluster_entries})
    context_sentences = list(dict.fromkeys(
        input_sentences[entry['line_num']].strip()
        for entry in cluster_entries
        if 0 <= entry['line_num'] < len(input_sentences)
    ))
    return cluster_words, context_sentences


def process_tokens(task: str, layer: int, api_key: str):
    """Process the first 15 tokens for a given task and layer with API rate limiting and error handling.

    For each prediction row whose sentence and predicted cluster are available,
    asks Gemini for an explanation (CLS tokens get a sentence-level prompt).
    Results are checkpointed after every token to
    ``token_explanations_layer_<layer>_test15.json`` (which is also used to
    resume an interrupted run) and written at the end to
    ``token_explanations_layer_<layer>_first15.json``.
    """
    configure_genai(api_key)

    predictions_df = load_predictions(task, layer)
    clusters = load_clusters(task, layer)
    dev_sentences = load_sentences(task, layer, "dev.in")
    input_sentences = load_sentences(task, layer, "input.in")

    if predictions_df is None or clusters is None:
        print("Failed to load required data")
        return

    predictions_df = predictions_df.head(15)
    print(f"Limited processing to first {len(predictions_df)} tokens")

    batch_size = 15  # API limit of 15 calls per minute
    call_count = 0
    start_time = time.time()

    output_dir = os.path.join("src", "codebert", task, f"layer{layer}")
    os.makedirs(output_dir, exist_ok=True)

    interim_file = os.path.join(output_dir, f"token_explanations_layer_{layer}_test15.json")
    results, processed_indices = _load_interim_results(interim_file)

    for _, row in tqdm(predictions_df.iterrows(), total=len(predictions_df), desc="Processing tokens"):
        token = row['Token']
        predicted_cluster = row['predicted_cluster']
        try:
            line_idx = int(row['line_idx'])
            position_idx = int(row['position_idx'])
        except (TypeError, ValueError):
            continue  # missing/NaN index: nothing to look up

        if (line_idx, position_idx) in processed_indices:
            continue
        if not 0 <= line_idx < len(dev_sentences):
            continue
        original_sentence = dev_sentences[line_idx].strip()
        if predicted_cluster not in clusters:
            continue
        cluster_words, context_sentences = _cluster_context(clusters[predicted_cluster], input_sentences)

        # Rate limiting: before making call batch_size + 1 in the current
        # window, wait out the rest of the minute and start a new window.
        if call_count >= batch_size:
            elapsed = time.time() - start_time
            if elapsed < 60:
                wait_time = 60 - elapsed
                print(f"\nReached API limit of {batch_size} calls. Waiting for {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            call_count = 0
            start_time = time.time()
        call_count += 1

        cls_token = is_cls_token(token)
        try:
            if cls_token:
                explanation = get_gemini_explanation_for_cls_with_retry(
                    original_sentence, cluster_words, context_sentences)
            else:
                explanation = get_gemini_explanation_with_retry(
                    original_sentence, token, cluster_words, position_idx=position_idx)

            results.append({
                'token': token,
                'is_cls_token': cls_token,
                'line_idx': line_idx,
                'position_idx': position_idx,
                'predicted_cluster': predicted_cluster,
                'original_sentence': original_sentence,
                'cluster_words': cluster_words,
                'context_sentences': context_sentences,
                'explanation': explanation,
            })
            processed_indices.add((line_idx, position_idx))

            _save_results(interim_file, results)
            print(f"\nSaved results to: {interim_file}")
        except Exception as e:
            print(f"\nUnexpected error processing token {token}: {str(e)}")
            # The failure may have been the save itself, so don't let the
            # emergency save crash the run.
            try:
                _save_results(interim_file, results)
                print(f"Emergency save to: {interim_file}")
            except (OSError, TypeError, ValueError) as save_error:
                print(f"Emergency save failed: {str(save_error)}")
            print("Waiting 60 seconds before continuing...")
            time.sleep(60)
            call_count = 0
            start_time = time.time()

    output_file = os.path.join(output_dir, f"token_explanations_layer_{layer}_first15.json")
    _save_results(output_file, results)
    print(f"Results saved to: {output_file}")


def main():
    """Run the 15-token explanation test for the configured task and layer.

    The Gemini API key is taken from the ``GEMINI_API_KEY`` environment
    variable; a key committed to source is readable by anyone with access to
    the repository. NOTE(review): the key previously hard-coded here should be
    revoked.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise SystemExit("Set the GEMINI_API_KEY environment variable before running.")
    task = "language_classification"  # Replace with your task name
    layer = 11  # Replace with your layer number
    process_tokens(task, layer, api_key)


if __name__ == "__main__":
    main()