File size: 4,617 Bytes
04ffb15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from dotenv import load_dotenv
import google.genai as genai
from google.api_core import retry
from PIL import Image
from smolagents import ChatMessage
import logging
from image_utils import encode_image, decode_image, save_image

# Load environment variables before anything reads them
# (GeminiLLM.__init__ reads GOOGLE_API_KEY through os.getenv).
load_dotenv()

# Configure logging: set the root level to INFO and create this module's logger,
# which the retry patch and GeminiLLM below both log through.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Gemini API Retry Patch ---
def is_retriable(e):
    """Tell the retry wrapper whether *e* deserves another attempt.

    Only Gemini ``APIError`` instances whose ``code`` is 429 or 503 qualify;
    every other exception is reported as non-retriable.
    """
    if not isinstance(e, genai.errors.APIError):
        return False
    return e.code in {429, 503}


# Patch at most once: a ``__wrapped__`` attribute on the method is taken as the
# sign that a wrapper is already in place (e.g. this module was imported twice).
if not hasattr(genai.models.Models.generate_content, "__wrapped__"):
    # Exponential backoff: start at 1 s, double each time, cap a single delay
    # at 60 s, and give up once 300 s have elapsed in total.
    genai.models.Models.generate_content = retry.Retry(
        predicate=is_retriable,
        initial=1.0,
        maximum=60.0,
        multiplier=2.0,
        timeout=300.0,
    )(genai.models.Models.generate_content)
    logger.info("Applied retry logic to Gemini API calls")
# --- End Patch ---

# Instruction preamble that GeminiLLM.generate places in front of every question.
# It asks the model to close its reply with "FINAL ANSWER: ...", the marker that
# generate() later splits on to extract the short answer.
SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. 

YOUR FINAL ANSWER should be:
- A number OR
- As few words as possible OR
- A comma separated list of numbers and/or strings

Rules for formatting:
1. If asked for a number:
   - Don't use commas
   - Don't use units ($, %, etc.) unless specified
2. If asked for a string:
   - Don't use articles
   - Don't use abbreviations (e.g. for cities)
   - Write digits in plain text unless specified
3. If asked for a comma separated list:
   - Apply the above rules for each element
   - Separate elements with commas
   - No spaces after commas

Remember: There is only one correct answer. Be precise and concise."""

class GeminiLLM:
    """Wrapper around the google-genai client that answers one question per call.

    Every request is prefixed with ``SYSTEM_PROMPT`` and sent with fixed
    generation settings; the reply is reduced to the text following the last
    ``FINAL ANSWER:`` marker when the model emits one.
    """

    def __init__(self, model="gemini-2.0-flash"):
        """Create the API client and the generation settings.

        Args:
            model: Model name passed to every ``generate_content`` call.
        """
        # The API key comes from the environment (GOOGLE_API_KEY).
        self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
        self.model_name = model
        # Generation settings
        self.generation_config = {
            "temperature": 0,  # Deterministic responses
            "top_p": 1,       # Use all tokens
            "top_k": 1,       # Choose only the most probable token
            "max_output_tokens": 2048,  # Maximum response length
        }

    def generate(self, prompt, image=None):
        """Send ``prompt`` (and optionally an image) to the model.

        Args:
            prompt: Question text; it is appended to ``SYSTEM_PROMPT`` under a
                ``Question:`` label.
            image: Optional image. A ``str`` is treated as a file path and
                opened with PIL; any other non-None value is passed to the API
                unchanged.

        Returns:
            A ``ChatMessage`` with role ``"assistant"``. Its content is the
            text after the last ``FINAL ANSWER:`` marker if present, otherwise
            the whole stripped reply. If the reply has no text, or a non-API
            exception occurs, the content is a string starting with
            ``"Error: "``.

        Raises:
            genai.errors.APIError: Re-raised unchanged for any API error.
        """
        # Image object opened by this call (if any); closed again in ``finally``.
        opened_here = None
        try:
            # Add system prompt to request
            full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {prompt}"
            contents = [full_prompt]

            if image is not None:
                logger.debug(f"Image path: {image}")
                if isinstance(image, str):
                    opened_here = image = Image.open(image)
                contents.append(image)

            response = self.client.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=self.generation_config,
            )

            # ``response.text`` can be None (e.g. a reply without text parts);
            # report that explicitly instead of failing on ``None.strip()``.
            text = response.text
            if text is None:
                logger.warning("Gemini response contained no text")
                return ChatMessage(role="assistant", content="Error: model returned no text")

            # Extract FINAL ANSWER from response
            content = text.strip()
            if "FINAL ANSWER:" in content:
                final_answer = content.split("FINAL ANSWER:")[-1].strip()
                return ChatMessage(role="assistant", content=final_answer)
            return ChatMessage(role="assistant", content=content)
        except genai.errors.APIError as e:
            # NOTE(review): if the module-level retry patch applies to this
            # call, its retries run inside generate_content, i.e. before control
            # reaches this handler - the log wording may be misleading; confirm.
            if e.code in {429, 503}:
                logger.warning(f"Rate limit or server error (code {e.code}), retry logic will handle this")
            raise
        except Exception as e:
            # Deliberate best-effort: report the failure as the answer text.
            logger.error(f"Error generating response: {str(e)}")
            return ChatMessage(role="assistant", content=f"Error: {str(e)}")
        finally:
            # Release the file handle of an image this method opened itself;
            # images supplied by the caller are left untouched.
            if opened_here is not None:
                opened_here.close()

class LLMClient:
    """Facade offering a single text-in/text-out call over the underlying LLM."""

    def __init__(self):
        """Create the backing ``GeminiLLM`` with its default model."""
        self.llm = GeminiLLM()

    def generate_response(self, question: str, context: str = "", system_prompt: str = "") -> str:
        """Assemble a prompt from the given pieces and return the reply text.

        The prompt is built in this order, skipping empty parts:
        ``system_prompt``, a ``Context:`` section, then ``Question: <question>``.

        NOTE(review): ``GeminiLLM.generate`` puts ``SYSTEM_PROMPT`` and its own
        ``Question: `` label in front of whatever it receives, so the label
        appears twice in the final request - confirm this is intended.

        Args:
            question: The question to ask.
            context: Optional background text placed before the question.
            system_prompt: Optional extra instructions placed first.

        Returns:
            The ``content`` of the message returned by ``self.llm.generate``.
        """
        sections = []
        if system_prompt:
            sections.append(f"{system_prompt}\n\n")
        if context:
            sections.append(f"Context:\n{context}\n\n")
        sections.append(f"Question: {question}")

        reply = self.llm.generate("".join(sections))
        return reply.content