import logging
import os

from dotenv import load_dotenv
import google.genai as genai
from google.api_core import retry
from PIL import Image
from smolagents import ChatMessage

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Gemini API Retry Patch ---
# Retry only on rate-limit (429) and service-unavailable (503) errors.
is_retriable = lambda e: isinstance(e, genai.errors.APIError) and e.code in {429, 503}

# Apply the retry wrapper only once (retry.Retry sets __wrapped__ via functools.wraps)
if not hasattr(genai.models.Models.generate_content, "__wrapped__"):
    genai.models.Models.generate_content = retry.Retry(
        predicate=is_retriable,
        initial=1.0,     # Initial delay in seconds
        maximum=60.0,    # Maximum delay between attempts
        multiplier=2.0,  # Exponential backoff multiplier
        timeout=300.0,   # Give up after 300 seconds in total
    )(genai.models.Models.generate_content)
    logger.info("Applied retry logic to Gemini API calls")
# --- End Patch ---

SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].

YOUR FINAL ANSWER should be:
- A number OR
- As few words as possible OR
- A comma separated list of numbers and/or strings

Rules for formatting:
1. If asked for a number:
   - Don't use commas
   - Don't use units ($, %, etc.) unless specified
2. If asked for a string:
   - Don't use articles
   - Don't use abbreviations (e.g. for cities)
   - Write digits in plain text unless specified
3. If asked for a comma separated list:
   - Apply the above rules to each element
   - Separate elements with commas
   - No spaces after commas

Remember: There is only one correct answer. Be precise and concise."""


class GeminiLLM:
    def __init__(self, model="gemini-2.0-flash"):
        self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
        self.model_name = model

        # Generation settings
        self.generation_config = {
            "temperature": 0,           # Greedy decoding for reproducible answers
            "top_p": 1,                 # Consider the full token distribution
            "top_k": 1,                 # Keep only the single most probable token
            "max_output_tokens": 2048,  # Maximum response length
        }

    def generate(self, prompt, image=None):
        try:
            # Prepend the system prompt to the user question
            full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {prompt}"

            # Accept either a file path or a PIL.Image for multimodal requests
            if isinstance(image, str):
                logger.debug(f"Loading image from path: {image}")
                image = Image.open(image)

            contents = [full_prompt] if image is None else [full_prompt, image]
            response = self.client.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=self.generation_config,
            )

            # Extract the FINAL ANSWER portion if the model followed the template
            content = response.text.strip()
            if "FINAL ANSWER:" in content:
                final_answer = content.split("FINAL ANSWER:")[-1].strip()
                return ChatMessage(role="assistant", content=final_answer)
            return ChatMessage(role="assistant", content=content)

        except genai.errors.APIError as e:
            # If we see a 429/503 here, the patched retry wrapper has already
            # exhausted its attempts; log it and propagate to the caller.
            if e.code in {429, 503}:
                logger.warning(f"Rate limit or server error (code {e.code}) persisted after retries")
            raise
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return ChatMessage(role="assistant", content=f"Error: {e}")


class LLMClient:
    """Wrapper class around GeminiLLM that provides a unified interface."""

    def __init__(self):
        """Initialize the LLM client."""
        self.llm = GeminiLLM()

    def generate_response(self, question: str, context: str = "", system_prompt: str = "") -> str:
        """Generate a response, combining system prompt, context, and question."""
        if system_prompt:
prompt = f"{system_prompt}\n\n" else: prompt = "" if context: prompt += f"Context:\n{context}\n\n" prompt += f"Question: {question}" # Generate response response = self.llm.generate(prompt) return response.content