"""Gemini LLM client with retry handling and a strict answer-format system prompt."""

import logging
import os

from dotenv import load_dotenv

import google.genai as genai
from google.api_core import retry
from PIL import Image
from smolagents import ChatMessage

from image_utils import encode_image, decode_image, save_image

# Load environment variables (e.g. GOOGLE_API_KEY) from a local .env file.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Rate-limit (429) and server-unavailable (503) errors are worth retrying.
is_retriable = lambda e: isinstance(e, genai.errors.APIError) and e.code in {429, 503}

# Wrap generate_content in exponential-backoff retries; the __wrapped__ check
# guards against patching the method twice if this module is imported again.
if not hasattr(genai.models.Models.generate_content, "__wrapped__"):
    genai.models.Models.generate_content = retry.Retry(
        predicate=is_retriable,
        initial=1.0,
        maximum=60.0,
        multiplier=2.0,
        timeout=300.0,
    )(genai.models.Models.generate_content)
    logger.info("Applied retry logic to Gemini API calls")
SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. |
|
|
|
YOUR FINAL ANSWER should be: |
|
- A number OR |
|
- As few words as possible OR |
|
- A comma separated list of numbers and/or strings |
|
|
|
Rules for formatting: |
|
1. If asked for a number: |
|
- Don't use commas |
|
- Don't use units ($, %, etc.) unless specified |
|
2. If asked for a string: |
|
- Don't use articles |
|
- Don't use abbreviations (e.g. for cities) |
|
- Write digits in plain text unless specified |
|
3. If asked for a comma separated list: |
|
- Apply the above rules for each element |
|
- Separate elements with commas |
|
- No spaces after commas |
|
|
|
Remember: There is only one correct answer. Be precise and concise.""" |
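
# A hypothetical exchange illustrating the format the prompt enforces:
#   Question: What is the capital of France?
#   Model:    ...reasoning... FINAL ANSWER: Paris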


class GeminiLLM:
    def __init__(self, model="gemini-2.0-flash"):
        self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
        self.model_name = model

        # Deterministic, bounded generation settings.
        self.generation_config = {
            "temperature": 0,
            "top_p": 1,
            "top_k": 1,
            "max_output_tokens": 2048,
        }

    def generate(self, prompt, image=None):
        try:
            full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {prompt}"

            # Build the request contents, loading the image from disk when a
            # file path is passed instead of a PIL image.
            contents = [full_prompt]
            if image is not None:
                if isinstance(image, str):
                    logger.debug(f"Image path: {image}")
                    image = Image.open(image)
                contents.append(image)

            response = self.client.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=self.generation_config,
            )

            content = (response.text or "").strip()
            if "FINAL ANSWER:" in content:
                final_answer = content.split("FINAL ANSWER:")[-1].strip()
                return ChatMessage(role="assistant", content=final_answer)
            return ChatMessage(role="assistant", content=content)
        except genai.errors.APIError as e:
            if e.code in {429, 503}:
                # The retry wrapper has already given up by the time the error
                # reaches this handler, so surface it to the caller.
                logger.warning(f"Rate limit or server error (code {e.code}) persisted after retries")
                raise
            logger.error(f"API error: {e}")
            return ChatMessage(role="assistant", content=f"Error: {e}")
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return ChatMessage(role="assistant", content=f"Error: {e}")


class LLMClient:
    """Wrapper class for LLM to provide a unified interface."""

    def __init__(self):
        """Initialize the LLM client."""
        self.llm = GeminiLLM()

    def generate_response(self, question: str, context: str = "", system_prompt: str = "") -> str:
        """Generate a response using the LLM."""
        # GeminiLLM.generate prepends SYSTEM_PROMPT on top of whatever is
        # assembled here, so system_prompt and context are supplementary.
        if system_prompt:
            prompt = f"{system_prompt}\n\n"
        else:
            prompt = ""

        if context:
            prompt += f"Context:\n{context}\n\n"

        prompt += f"Question: {question}"

        response = self.llm.generate(prompt)
        return response.content
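

# A minimal usage sketch (assumes GOOGLE_API_KEY is set in the environment or a
# local .env file; the question below is just illustrative):
if __name__ == "__main__":
    client = LLMClient()
    answer = client.generate_response("What is the capital of France?")
    print(answer)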