import sys
import os
import subprocess # For calling generate.py
import tempfile # For handling temporary image files
from typing import Optional
from PIL import Image as PILImage
import gradio as gr
import time # For timing

# --- Deployment constants ---
# Location of the cloned nanoVLM checkout, the generation script inside it,
# and the Hugging Face model repository handed to that script.
NANOVLM_REPO_PATH = "/app/nanoVLM"
GENERATE_SCRIPT_PATH = "/app/nanoVLM/generate.py"
MODEL_REPO_ID = "lusxvr/nanoVLM-222M"

# Put the nanoVLM checkout at the front of the import search path (once).
if NANOVLM_REPO_PATH not in sys.path:
    print(f"DEBUG: Adding {NANOVLM_REPO_PATH} to sys.path")
    sys.path.insert(0, NANOVLM_REPO_PATH)

# Start-up diagnostics: effective import path, Gradio version, and the
# script/model this app is configured to use.
print(f"DEBUG: Python sys.path: {sys.path}")
print(f"DEBUG: Gradio version: {gr.__version__}")
print(f"DEBUG: Using generate.py script at: {GENERATE_SCRIPT_PATH}")
print(f"DEBUG: Using model repo ID: {MODEL_REPO_ID}")

# --- Inference: run nanoVLM's generate.py in a subprocess ---

def _stream_text(stream) -> str:
    """Return a captured process stream as stripped text ('' when nothing was captured)."""
    if not stream:
        return ""
    if isinstance(stream, bytes):
        # subprocess.TimeoutExpired carries raw bytes even when text=True was requested.
        stream = stream.decode("utf-8", errors="replace")
    return stream.strip()


def _first_generation(stdout: str) -> Optional[str]:
    """Return the text following the first '>> Generation 1:' marker, or None if no line has it."""
    marker = ">> Generation 1:"
    for line in stdout.splitlines():
        candidate = line.strip()
        if candidate.startswith(marker):
            return candidate[len(marker):].strip()
    return None


def call_generate_script(
    image_path: str,
    prompt_text: str,
    *,
    script_path: Optional[str] = None,
    model_repo_id: Optional[str] = None,
    timeout_seconds: float = 55,
    max_new_tokens: int = 30,
) -> str:
    """Run the generation script on one image/prompt pair and return its first generation.

    The script is executed with the interpreter that is running this module
    (``sys.executable``), so it sees the same installed packages. Failures are
    reported in the returned string rather than raised.

    Args:
        image_path: Path of the image file passed to the script via ``--image``.
        prompt_text: Prompt passed to the script via ``--prompt``.
        script_path: Script to execute; defaults to ``GENERATE_SCRIPT_PATH``.
        model_repo_id: Value for ``--hf_model``; defaults to ``MODEL_REPO_ID``.
        timeout_seconds: Maximum time the subprocess may run before it is killed.
        max_new_tokens: Value for ``--max_new_tokens``.

    Returns:
        The text after ``>> Generation 1:`` in the script's STDOUT on success.
        Otherwise a diagnostic string: ``"Error: ..."`` for a non-zero exit
        code or a timeout, ``"[Parsing failed...]"`` when the marker line is
        missing, or ``"Unexpected error calling script: ..."`` for anything else.
    """
    if script_path is None:
        script_path = GENERATE_SCRIPT_PATH
    if model_repo_id is None:
        model_repo_id = MODEL_REPO_ID

    print("\n--- DEBUG (call_generate_script) ---")
    print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Calling with image_path='{image_path}', prompt='{prompt_text}'")

    # Arguments understood by nanoVLM's generate.py (it has no --device option).
    cmd_args = [
        sys.executable or "python", "-u", script_path,
        "--hf_model", model_repo_id,
        "--image", image_path,
        "--prompt", prompt_text,
        "--generations", "1",
        "--max_new_tokens", str(max_new_tokens),
    ]
    print(f"Executing command: {' '.join(cmd_args)}")

    process_identifier = "generate.py_process"
    start_time = time.time()
    try:
        process = subprocess.run(
            cmd_args,
            capture_output=True,
            text=True,
            check=False,  # the return code is inspected manually below
            timeout=timeout_seconds,
        )

        duration = time.time() - start_time
        print(f"Subprocess ({process_identifier}) finished in {duration:.2f} seconds.")
        print(f"generate.py RETURN CODE: {process.returncode}")

        # Keep the real stream contents separate from the log placeholders so
        # that "is there any output?" checks are not fooled by the placeholder.
        stdout = _stream_text(process.stdout)
        stderr = _stream_text(process.stderr)
        stdout_for_log = stdout or "[No STDOUT from generate.py]"
        stderr_for_log = stderr or "[No STDERR from generate.py]"

        print(f"---------- generate.py STDOUT ({process_identifier}) START ----------\n{stdout_for_log}\n---------- generate.py STDOUT ({process_identifier}) END ----------")
        if stderr or process.returncode != 0:
            print(f"---------- generate.py STDERR ({process_identifier}) START ----------\n{stderr_for_log}\n---------- generate.py STDERR ({process_identifier}) END ----------")

        if process.returncode != 0:
            error_message = f"Error: Generation script failed (code {process.returncode})."
            if "unrecognized arguments" in stderr:
                error_message += " Argument mismatch with script."
            elif "syntax error" in stderr.lower():
                error_message += " Syntax error in script."
            print(error_message)
            return error_message + f" STDERR Snippet: {stderr_for_log[:300]}"

        # generate.py prints each result as "  >> Generation {i+1}: {out}";
        # only the first generation is requested and returned.
        generated_text = _first_generation(stdout)
        if generated_text is None:
            print(f"Could not find '>> Generation 1:' line in generate.py output. Raw STDOUT was:\n{stdout_for_log}")
            if stdout:
                generated_text = f"[Parsing failed] STDOUT: {stdout[:500]}"
            else:
                generated_text = "[Parsing failed, no STDOUT from script]"
        else:
            print(f"Parsed generated text: '{generated_text}'")

        print(f"Returning parsed text: '{generated_text}'")
        return generated_text

    except subprocess.TimeoutExpired as e:
        duration = time.time() - start_time
        print(f"ERROR: generate.py ({process_identifier}) timed out after {duration:.2f} seconds (limit: {timeout_seconds}s).")
        stdout_on_timeout = _stream_text(e.stdout) or "[No STDOUT on timeout]"
        stderr_on_timeout = _stream_text(e.stderr) or "[No STDERR on timeout]"
        print(f"STDOUT on timeout:\n{stdout_on_timeout}")
        print(f"STDERR on timeout:\n{stderr_on_timeout}")
        return f"Error: Generation script timed out after {timeout_seconds}s. Model loading and generation may be too slow for CPU."
    except Exception as e:
        # Top-level boundary for the UI callback: log everything, return a message.
        duration = time.time() - start_time
        print(f"ERROR: An unexpected error occurred ({process_identifier}) after {duration:.2f}s: {type(e).__name__} - {e}")
        import traceback
        traceback.print_exc()
        return f"Unexpected error calling script: {str(e)}"
    finally:
        print("--- END (call_generate_script) ---")

# --- Gradio callback: validates the inputs, stages the image on disk and
# --- delegates to call_generate_script.

def gradio_interface_fn(image_input_pil: "Optional[PILImage.Image]", prompt_input_str: Optional[str]) -> str:
    """Gradio callback: validate the inputs, run the generation script, return its text.

    The uploaded image is converted to RGB if necessary, written to a
    temporary JPEG file, and that file's path is handed to
    ``call_generate_script``. The temporary file is removed before returning,
    whether or not generation succeeded.

    Args:
        image_input_pil: Uploaded image, or ``None`` when nothing was uploaded.
        prompt_input_str: User prompt; ``None``, empty and whitespace-only
            values are rejected.

    Returns:
        The string returned by ``call_generate_script``, or a user-facing
        message when an input is missing or an error occurs.
    """
    print(f"\nDEBUG (gradio_interface_fn): Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Received prompt: '{prompt_input_str}', Image type: {type(image_input_pil)}")

    if image_input_pil is None:
        return "Please upload an image."

    cleaned_prompt = prompt_input_str.strip() if prompt_input_str else ""
    if not cleaned_prompt:
        return "Please provide a non-empty prompt."

    tmp_image_path = None
    try:
        # The image is saved as JPEG, so normalize the mode to RGB first.
        if image_input_pil.mode != "RGB":
            print(f"Converting image from {image_input_pil.mode} to RGB.")
            image_input_pil = image_input_pil.convert("RGB")

        # delete=False: the file must outlive this `with` so the subprocess can read it.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_image_file:
            # Record the path BEFORE writing so the `finally` below still
            # removes the file if save() raises.
            tmp_image_path = tmp_image_file.name
            image_input_pil.save(tmp_image_file, format="JPEG")

        print(f"Temporary image saved to: {tmp_image_path}")

        result_text = call_generate_script(tmp_image_path, cleaned_prompt)

        print(f"Result from call_generate_script: '{result_text}'")
        return result_text

    except Exception as e:
        # UI boundary: report the problem to the user instead of raising.
        print(f"ERROR (gradio_interface_fn): Error processing image or calling script: {type(e).__name__} - {e}")
        import traceback
        traceback.print_exc()
        return f"An error occurred in Gradio interface function: {str(e)}"
    finally:
        if tmp_image_path and os.path.exists(tmp_image_path):
            try:
                os.remove(tmp_image_path)
                print(f"Temporary image {tmp_image_path} removed.")
            except Exception as e_remove:
                # Best effort: a leftover temp file must not mask the real result.
                print(f"WARN: Could not remove temporary image {tmp_image_path}: {e_remove}")
        print("DEBUG (gradio_interface_fn): Exiting.")


# --- Gradio Interface Definition ---
# Markdown shown under the title of the demo page.
description_md = """
## nanoVLM-222M Interactive Demo (via generate.py)
Upload an image and type a prompt. This interface calls the `generate.py` script from
`huggingface/nanoVLM` under the hood to perform inference.
**Note:** Each request re-loads the model via the script, so it might be slow on CPU.
"""


def _build_interface():
    """Assemble the image + prompt -> text interface around gradio_interface_fn."""
    image_box = gr.Image(type="pil", label="Upload Image")
    prompt_box = gr.Textbox(label="Your Prompt / Question", info="e.g., 'describe this image in detail'")
    answer_box = gr.Textbox(label="Generated Text", show_copy_button=True, lines=5)
    return gr.Interface(
        fn=gradio_interface_fn,
        inputs=[image_box, prompt_box],
        outputs=answer_box,
        title="nanoVLM-222M Demo (via Script)",
        description=description_md,
        allow_flagging="never",
    )


print("DEBUG: Defining Gradio interface...")
# Stays None when construction fails; the launcher checks for that.
iface = None
try:
    iface = _build_interface()
    print("DEBUG: Gradio interface defined successfully.")
except Exception as e:
    print(f"CRITICAL ERROR defining Gradio interface: {e}")
    import traceback
    traceback.print_exc()

# --- Launch Gradio App ---
# Every fatal start-up problem below terminates the process with exit status 1
# so that a supervisor (container runtime, shell script) can detect the failure.
if __name__ == "__main__":
    print("DEBUG: Entered __main__ block for Gradio launch.")

    # Without the generation script no request can succeed, so refuse to start.
    if not os.path.exists(GENERATE_SCRIPT_PATH):
        print(f"CRITICAL ERROR: The script {GENERATE_SCRIPT_PATH} was not found. Cannot launch app.")
        sys.exit(1)

    # iface is None when building the interface failed at import time.
    if iface is None:
        print("CRITICAL ERROR: Gradio interface (iface) is None or not defined. Cannot launch.")
        sys.exit(1)

    print("DEBUG: Attempting to launch Gradio interface...")
    try:
        # Listen on all interfaces on port 7860.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        print(f"CRITICAL ERROR launching Gradio interface: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    # When run as a script, launch() blocks while the server is serving, so
    # this line is typically reached only once it stops.
    print("DEBUG: Gradio launch() returned.")