# services/openai_service.py: OpenAI validation and streaming-generation helpers
import openai
import traceback
import json
import asyncio
from typing import Dict, Optional, Tuple, List, AsyncGenerator
from langsmith import traceable

try:
    import config
    from utils import format_context_for_openai  # shared context-to-prompt formatter
except ImportError:
    print("Error: Failed to import config or utils in openai_service.py")
    raise SystemExit("Failed imports in openai_service.py")

# --- Globals ---
openai_async_client: Optional[openai.AsyncOpenAI] = None
is_openai_ready: bool = False
openai_status_message: str = "OpenAI service not initialized."

# --- Initialization ---
def init_openai_client() -> Tuple[bool, str]:
    """Initializes the OpenAI async client."""
    global openai_async_client, is_openai_ready, openai_status_message
    if is_openai_ready:
        return True, openai_status_message
    if not config.OPENAI_API_KEY:
        openai_status_message = "Error: OPENAI_API_KEY not found in Secrets."
        is_openai_ready = False
        return False, openai_status_message
    try:
        openai_async_client = openai.AsyncOpenAI(api_key=config.OPENAI_API_KEY)
        # One client serves both tasks, so the status message names both models
        openai_status_message = f"OpenAI service ready (Validate: {config.OPENAI_VALIDATION_MODEL}, Generate: {config.OPENAI_GENERATION_MODEL})."
        is_openai_ready = True
        print("OpenAI Service: Async client initialized.")
        return True, openai_status_message
    except Exception as e:
        error_msg = f"Error initializing OpenAI async client: {type(e).__name__} - {e}"
        print(error_msg)
        traceback.print_exc()
        openai_status_message = error_msg
        is_openai_ready = False
        openai_async_client = None
        return False, openai_status_message

def get_openai_status() -> Tuple[bool, str]:
    """Returns the current status of the OpenAI service."""
    if not is_openai_ready:
        init_openai_client()
    return is_openai_ready, openai_status_message
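
# Startup wiring sketch (hypothetical; the real entry point lives elsewhere in
# the app). Callers typically initialize eagerly and surface the status string:
#
#     ready, status = init_openai_client()
#     if not ready:
#         print(f"OpenAI unavailable: {status}")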

# --- Validation Function ---
@traceable(name="openai-validate-paragraph")
async def validate_relevance_openai(
    paragraph_data: Dict, user_question: str, paragraph_index: int
) -> Optional[Dict]:
    """Asks the validation model whether one paragraph directly helps answer the user question; returns a result dict, or None on transport errors."""
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        print(f"OpenAI validation failed (Para {paragraph_index+1}): Client not ready - {msg}")
        return None
    safe_paragraph_data = paragraph_data.copy() if isinstance(paragraph_data, dict) else {}
    if not paragraph_data or not isinstance(paragraph_data, dict):
        return {"validation": {"contains_relevant_info": False, "justification": "Input data empty/invalid."}, "paragraph_data": safe_paragraph_data}
    hebrew_text = paragraph_data.get('hebrew_text', '').strip()
    english_text = paragraph_data.get('english_text', '').strip()
    if not hebrew_text and not english_text:
        return {"validation": {"contains_relevant_info": False, "justification": "Paragraph text empty."}, "paragraph_data": safe_paragraph_data}
    validation_model = config.OPENAI_VALIDATION_MODEL
    prompt_content = f"""User Question (Hebrew):
"{user_question}"

Text Paragraph (Paragraph {paragraph_index+1}):
Hebrew:
---
{hebrew_text or "(No Hebrew)"}
---
English:
---
{english_text or "(No English)"}
---

Instruction:
Analyze the Text Paragraph. Determine if it contains information that *directly* answers or significantly contributes to answering the User Question.
Respond ONLY with valid JSON: {{"contains_relevant_info": boolean, "justification": "Brief Hebrew explanation"}}
Example: {{"contains_relevant_info": true, "justification": "..."}} OR {{"contains_relevant_info": false, "justification": "..."}}
Output only the JSON object."""
    try:
        response = await openai_async_client.chat.completions.create(
            model=validation_model,
            messages=[{"role": "user", "content": prompt_content}],
            temperature=0.1,
            max_tokens=150,
            response_format={"type": "json_object"},
        )
        json_string = response.choices[0].message.content
        try:
            validation_result = json.loads(json_string)
            if (not isinstance(validation_result, dict)
                    or 'contains_relevant_info' not in validation_result
                    or 'justification' not in validation_result
                    or not isinstance(validation_result['contains_relevant_info'], bool)
                    or not isinstance(validation_result['justification'], str)):
                print(f"Error (OpenAI Validate {paragraph_index+1}): Invalid JSON structure: {validation_result}")
                return {"validation": {"contains_relevant_info": False, "justification": "Error: Invalid response format."}, "paragraph_data": safe_paragraph_data}
            return {"validation": validation_result, "paragraph_data": safe_paragraph_data}
        except json.JSONDecodeError as json_err:
            print(f"Error (OpenAI Validate {paragraph_index+1}): Failed JSON decode: {json_err}. Response: {json_string}")
            return {"validation": {"contains_relevant_info": False, "justification": "Error: Failed to parse JSON response."}, "paragraph_data": safe_paragraph_data}
    except openai.RateLimitError as e:
        print(f"Error (OpenAI Validate {paragraph_index+1}): Rate Limit: {e}")
        return {"validation": {"contains_relevant_info": False, "justification": "Error: Rate limit hit."}, "paragraph_data": safe_paragraph_data}
    except openai.APIConnectionError as e:
        print(f"Error (OpenAI Validate {paragraph_index+1}): Connection Error: {e}")
        return None
    except openai.APIStatusError as e:
        print(f"Error (OpenAI Validate {paragraph_index+1}): API Status {e.status_code}: {e.response}")
        return None
    except Exception as e:
        print(f"Error (OpenAI Validate {paragraph_index+1}): Unexpected: {type(e).__name__}")
        traceback.print_exc()
        return None
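
# Illustrative fan-out sketch (an assumption, not part of the original service):
# validate_relevance_openai judges one paragraph at a time, so a caller would
# typically run it concurrently over all retrieved paragraphs. The helper name
# and the shape of `paragraphs` are hypothetical.
async def _example_validate_all(paragraphs: List[Dict], user_question: str) -> List[Optional[Dict]]:
    """Hypothetical helper: validates every paragraph concurrently."""
    tasks = [
        validate_relevance_openai(para, user_question, idx)
        for idx, para in enumerate(paragraphs)
    ]
    # gather preserves input order, so results[i] corresponds to paragraphs[i]
    return await asyncio.gather(*tasks)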


# --- Generation Function ---
@traceable(name="openai-generate-stream")
async def generate_openai_stream(
    messages: List[Dict],
    context_documents: List[Dict],
) -> AsyncGenerator[str, None]:
    """
    Generates a response stream using OpenAI GPT model based on history and context.
    Yields text chunks or an error message string.
    """
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        yield f"--- Error: OpenAI client not available for generation: {msg} ---"
        return

    try:
        # Validate context format before formatting
        if not isinstance(context_documents, list) or not all(isinstance(item, dict) for item in context_documents):
            yield "--- Error: Invalid format for context_documents (expected List[Dict]). ---"
            return

        # Format context using the shared utility function; an empty result or a
        # "No..." string means nothing usable could be formatted
        formatted_context = format_context_for_openai(context_documents)
        if not formatted_context or formatted_context.startswith("No"):
            yield "--- Error: No valid context provided or formatted for OpenAI generator. ---"
            return

        # Find the latest user message from history
        last_user_msg_content = "User question not found."
        if messages and isinstance(messages, list):
            for msg_ in reversed(messages):
                if isinstance(msg_, dict) and msg_.get("role") == "user":
                    last_user_msg_content = str(msg_.get("content") or "")
                    break

        # Construct the final user prompt for the generation model
        user_prompt_content = (
            f"Source Texts:\n{formatted_context}\n\n"
            f"User Question:\n{last_user_msg_content}\n\n"
            "Answer (in Hebrew, based ONLY on the Source Texts provided):"
        )

        # Prepare messages for the API call - System Prompt + User Prompt
        api_messages = [
            {"role": "system", "content": config.OPENAI_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt_content}
        ]

        generation_model = config.OPENAI_GENERATION_MODEL
        print(f" -> Sending stream request to OpenAI (Model: {generation_model})...")
        print(f" -> User Prompt Content (start): {user_prompt_content[:300]}...") # Log start of prompt

        # Make the streaming API call
        stream = await openai_async_client.chat.completions.create(
            model=generation_model,
            messages=api_messages,
            temperature=0.5,  # Adjust temperature as needed
            max_tokens=3000, # Set a reasonable max token limit
            stream=True
        )

        print(f" -> OpenAI stream processing...")
        async for chunk in stream:
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content # Yield the text chunk
            # Add a small sleep to avoid blocking the event loop entirely
            await asyncio.sleep(0.01)
        print(f" -> OpenAI stream finished.")

    # --- Exception Handling ---
    except openai.RateLimitError as e:
        error_msg = f"\n\n--- Error: OpenAI rate limit exceeded during generation: {e} ---"
        print(error_msg); traceback.print_exc(); yield error_msg
    except openai.APIConnectionError as e:
        error_msg = f"\n\n--- Error: OpenAI connection error during generation: {e} ---"
        print(error_msg); traceback.print_exc(); yield error_msg
    except openai.APIStatusError as e:
        error_msg = f"\n\n--- Error: OpenAI API status error ({e.status_code}) during generation: {e.response} ---"
        print(error_msg); traceback.print_exc(); yield error_msg
    except Exception as e:
        error_msg = f"\n\n--- Error: Unexpected error during OpenAI generation: {type(e).__name__} - {e} ---"
        print(error_msg); traceback.print_exc(); yield error_msg
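
# Minimal consumption sketch (hypothetical; the real caller lives elsewhere in
# the app). It drains the async generator above and prints chunks as they
# arrive; the message history and context documents shown are placeholders.
async def _example_consume_stream() -> None:
    messages = [{"role": "user", "content": "..."}]  # hypothetical chat history
    context = [{"hebrew_text": "...", "english_text": "..."}]  # hypothetical retrieved paragraphs
    async for chunk in generate_openai_stream(messages, context):
        print(chunk, end="", flush=True)  # stream each chunk to stdout immediately

# To try it manually:
#     asyncio.run(_example_consume_stream())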