import asyncio
import logging
import os
import time
from typing import List, Dict, Any, Union

import google.generativeai as genai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, validator, Field
|
|
# Load environment variables via python-dotenv before the lookups below.
load_dotenv()

# Module-wide logging at INFO level; every handler in this file uses `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Language Agent (Gemini Pro - Generalized)")

# Credentials and model selection come from the environment; only the model
# name has a fallback value.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
GEMINI_MODEL_NAME = os.environ.get("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")

if GOOGLE_API_KEY:
    # A configuration failure is logged rather than raised, so the module
    # still imports when the SDK rejects the key.
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        logger.info(
            f"Google Generative AI configured for model {GEMINI_MODEL_NAME}."
        )
    except Exception as config_error:
        logger.error(f"Failed to configure Google Generative AI: {config_error}")
else:
    # A missing key is only warned about at import time.
    logger.warning("GOOGLE_API_KEY not found.")
|
|
|
|
class EarningsSummaryLLM(BaseModel):
    """Earnings outcome for one company, as supplied to the prompt builder."""

    # Label used to name the company in the generated summary line.
    ticker: str
    # Signed surprise value, printed with a trailing "%" in the prompt.
    # Values >= 0 are worded as a beat, values < 0 as a miss.
    surprise_pct: float
|
|
|
|
class AnalysisDataLLM(BaseModel):
    """Portfolio analysis figures that are summarised inside the LLM prompt."""

    # Name of the segment being described; used verbatim in the summary line.
    target_label: str = "the portfolio"

    # Both allocations are multiplied by 100 when rendered, so they are
    # presumably fractions of AUM (e.g. 0.22) -- TODO confirm with the caller.
    current_allocation: float = 0.0
    yesterday_allocation: float = 0.0

    # Rendered as-is with the words "percentage points"; the sign selects the
    # "up by" / "down by" wording.
    allocation_change_percentage_points: float = 0.0

    # Declared with the alias "earnings_surprises_for_target"; defaults to a
    # fresh empty list per instance.
    earnings_surprises: List[EarningsSummaryLLM] = Field(
        alias="earnings_surprises_for_target",
        default_factory=list,
    )
|
|
|
|
class BriefRequest(BaseModel):
    """Request body accepted by the /generate_brief endpoint."""

    # Free-text question, inserted into the prompt as "User Query".
    user_query: str
    # Structured figures turned into the analysis summary of the prompt.
    analysis: AnalysisDataLLM
    # Optional document snippets; the endpoint only uses the first two.
    retrieved_docs: List[str] = Field(default_factory=list)
|
|
|
|
def construct_gemini_prompt(
    user_query: str, analysis_data: AnalysisDataLLM, docs_context: str
) -> str:
    """Build the full text prompt sent to Gemini for a morning market brief.

    Args:
        user_query: The user's question, embedded verbatim in the prompt.
        analysis_data: Allocation figures and earnings surprises to summarise.
        docs_context: Pre-joined document text placed under the filings
            context heading.

    Returns:
        The prompt string: role instructions, the user query, a generated
        analysis summary, the document context, and closing guidance.
    """
    change_points = analysis_data.allocation_change_percentage_points
    yesterday_pct = f"{analysis_data.yesterday_allocation*100:.0f}%"

    # Changes within +/-0.01 points are reported as "stable".
    # NOTE(review): the change is printed with one decimal, so values between
    # 0.01 and 0.05 render as "0.0" -- confirm the threshold is intended.
    if change_points > 0.01:
        alloc_change_str = (
            f"up by {change_points:.1f} percentage points from yesterday "
            f"(approx. {yesterday_pct})."
        )
    elif change_points < -0.01:
        alloc_change_str = (
            f"down by {abs(change_points):.1f} percentage points from yesterday "
            f"(approx. {yesterday_pct})."
        )
    else:
        alloc_change_str = f"remaining stable around {yesterday_pct} yesterday."

    analysis_summary_str = (
        f"For {analysis_data.target_label}, the current allocation is "
        f"{analysis_data.current_allocation*100:.0f}% of AUM, {alloc_change_str}\n"
    )

    # One phrase per company; a non-negative surprise is worded as a beat.
    # `or []` keeps a None value on the "no surprises" path, as before.
    earnings_parts = [
        f"{item.ticker} "
        f"{'beat estimates by' if item.surprise_pct >= 0 else 'missed estimates by'} "
        f"{abs(item.surprise_pct):.1f}%"
        for item in (analysis_data.earnings_surprises or [])
    ]
    if earnings_parts:
        analysis_summary_str += (
            "Key earnings updates: " + ", ".join(earnings_parts) + "."
        )
    else:
        analysis_summary_str += (
            "No notable earnings surprises reported for this segment."
        )

    # Every instruction sentence ends with a space or newline so adjacent
    # literals do not fuse into one run-on token when concatenated.
    prompt = (
        f"You are a professional financial assistant. Based on the user's query and the provided data, "
        f"deliver a concise, spoken-style morning market brief for a portfolio manager. "
        f"The brief should start with 'Good morning.'\n\n"
        f"User Query: {user_query}\n\n"
        f"Key Portfolio and Market Analysis:\n{analysis_summary_str}\n\n"
        f"Relevant Filings Context (if any):\n{docs_context}\n\n"
        f"If the user's query mentions a specific region or sector not covered by the 'Key Portfolio and Market Analysis', "
        f"you can state that specific data for that exact query aspect was not available in the analysis provided. "
        f"Mention any specific company earnings surprises from the analysis clearly (e.g., 'TSMC beat estimates by X%, Samsung missed by Y%'). "
        f"If there's information about broad regional sentiment or rising yields in the 'docs_context', incorporate it naturally. Otherwise, focus on the provided analysis."
    )
    return prompt
|
|
|
|
# Sampling settings applied to every Gemini call made by this module.
generation_config = genai.types.GenerationConfig(
    max_output_tokens=1024,
    temperature=0.6,
)

# The same blocking threshold is applied to each listed harm category.
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
]
|
|
|
|
def _is_rate_limit_error(exc: Exception) -> bool:
    """Return True when the exception text looks like a rate-limit/quota error."""
    message = str(exc)
    lowered = message.lower()
    return (
        "rate limit" in lowered
        or "quota" in lowered
        or "429" in message
        or "resource_exhausted" in lowered
    )


@app.post("/generate_brief")
async def generate_brief(request: BriefRequest):
    """Generate a spoken-style market brief with Gemini.

    Builds a prompt from the request, calls the configured Gemini model and
    returns the generated text. The call is attempted up to two times; empty
    responses and generic errors wait a fixed delay before the retry, and
    rate-limit-like errors wait with exponential backoff.

    Args:
        request: The user query, analysis figures and optional documents.

    Returns:
        A dict of the form ``{"brief": <generated text>}``.

    Raises:
        HTTPException: 500 when the API key is missing, the model keeps
            returning empty content, or generation fails for another reason;
            400 when the prompt or candidate is blocked; 429 when a
            rate-limit/quota error persists after the retry.
    """
    if not GOOGLE_API_KEY:
        raise HTTPException(status_code=500, detail="Google API Key not configured.")
    logger.info(
        f"Generating brief for query: '{request.user_query}' using Gemini model {GEMINI_MODEL_NAME}"
    )

    # Only the first two retrieved documents are passed to the model.
    docs_context = (
        "\n".join(request.retrieved_docs[:2])
        if request.retrieved_docs
        else "No relevant context from documents found."
    )

    full_prompt = construct_gemini_prompt(
        user_query=request.user_query,
        analysis_data=request.analysis,
        docs_context=docs_context,
    )
    logger.debug(f"Full prompt for Gemini:\n{full_prompt}")

    try:
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL_NAME,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        max_retries = 1
        retry_delay_seconds = 10
        for attempt in range(max_retries + 1):
            try:
                response = await model.generate_content_async(full_prompt)

                if not response.parts:
                    feedback = response.prompt_feedback
                    if feedback and feedback.block_reason:
                        # NOTE(review): getattr guards against SDK versions
                        # whose prompt feedback may not expose
                        # `block_reason_message` -- confirm against the SDK.
                        block_reason_message = (
                            getattr(feedback, "block_reason_message", None)
                            or "Unknown safety block"
                        )
                        logger.error(
                            f"Gemini content generation blocked. Reason: {block_reason_message}"
                        )
                        raise HTTPException(
                            status_code=400,
                            detail=f"Content generation blocked: {block_reason_message}",
                        )

                    logger.error("Gemini response has no parts (empty content).")
                    if attempt == max_retries:
                        raise HTTPException(
                            status_code=500,
                            detail="Gemini returned empty content after retries.",
                        )
                    logger.warning(
                        f"Gemini returned empty content, attempt {attempt+1}/{max_retries+1}. Retrying..."
                    )
                    await asyncio.sleep(retry_delay_seconds)
                    continue

                brief_text = response.text
                logger.info("Gemini content generated successfully.")
                return {"brief": brief_text}

            except HTTPException:
                # Responses raised deliberately above must keep their status
                # code; without this clause the generic handler below would
                # retry them and re-wrap them as a 500.
                raise
            except (
                genai.types.generation_types.BlockedPromptException,
                genai.types.generation_types.StopCandidateException,
            ) as sce_bpe:
                logger.error(
                    f"Gemini generation issue on attempt {attempt+1}: {sce_bpe}"
                )
                raise HTTPException(
                    status_code=400, detail=f"Gemini generation issue: {sce_bpe}"
                ) from sce_bpe
            except Exception as e:
                logger.error(
                    f"Error during Gemini generation on attempt {attempt+1}: {type(e).__name__} - {e}"
                )
                if _is_rate_limit_error(e):
                    if attempt < max_retries:
                        # Exponential backoff: 10s, 20s, ... per attempt.
                        wait_time = retry_delay_seconds * (2**attempt)
                        logger.info(f"Rate limit likely. Retrying in {wait_time}s...")
                        await asyncio.sleep(wait_time)
                        continue
                    logger.error("Max retries reached for rate limit.")
                    raise HTTPException(
                        status_code=429,
                        detail=f"Gemini API rate limit/quota exceeded: {e}",
                    ) from e
                if attempt < max_retries:
                    await asyncio.sleep(retry_delay_seconds)
                    continue
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to generate brief with Gemini: {e}",
                ) from e

        # Safety net: every loop path above returns, continues or raises.
        raise HTTPException(
            status_code=500, detail="Brief generation failed after all attempts."
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Critical error in /generate_brief: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Critical failure in brief generation: {e}"
        ) from e
|
|