import os
from typing import Dict, List, Optional, NamedTuple, Tuple
from datetime import datetime

from datasets import load_dataset

# Module configuration. Each value may be overridden through the environment
# variable of the same name; the second argument is the fallback default.
EXAM_DATASET_ID = os.environ.get(
    "EXAM_DATASET_ID", "agents-course/unit_1_quiz_student_responses"
)
CERTIFICATE_MODELS_DIR = os.environ.get(
    "CERTIFICATE_MODELS_DIR", "./certificate_models"
)
# Image handed out to users who pass.
CERTIFICATE_PATH = os.path.join(CERTIFICATE_MODELS_DIR, "certificate.png")

# Minimum score to pass, expressed as a fraction (0.8 means 80%).
PASSING_THRESHOLD = float(os.environ.get("PASSING_THRESHOLD", "0.8"))
# Minimum number of answered questions required to pass.
MIN_QUESTIONS = int(os.environ.get("MIN_QUESTIONS", "1"))


class CertificateResult(NamedTuple):
    """Immutable outcome of a certification check.

    Fields are positional in the declared order; only ``results_df`` has a
    default (None).
    """

    # Text to show the user describing the outcome.
    message: str
    # Path to the certificate image, or None when no certificate is issued.
    certificate_path: Optional[str]
    # Score the pass/fail decision was based on.
    pass_percentage: float
    # True when the certification requirements were met.
    passed: bool
    # Optional extra payload (presumably a results table -- TODO confirm).
    results_df: Optional[object] = None


def get_user_results(username: str) -> List[Dict]:
    """
    Load every quiz-response row stored for one user.

    The user's rows are read from a single per-user parquet file inside
    EXAM_DATASET_ID, using streaming mode, and fully materialized.

    Args:
        username: The Hugging Face username to check

    Returns:
        List of the user's quiz result rows (the dataset's "train" split)

    Raises:
        Any exception raised while loading; it is printed, then re-raised.
    """
    try:
        # Hyphens become "000" in the file name -- presumably this mirrors
        # how the per-user files were named on upload (TODO confirm).
        sanitized_username = username.replace("-", "000")
        parquet_file = f"data/{sanitized_username}-00000-of-00001.parquet"

        dataset = load_dataset(
            path=EXAM_DATASET_ID,
            streaming=True,
            data_files=parquet_file,
        )
        rows = [row for row in dataset["train"]]

        print(f"Found {len(rows)} results for user {sanitized_username}")
        return rows
    except Exception as e:
        print(f"Error in get_user_results: {str(e)}")
        raise


def _attempt_score(attempt: Dict) -> float:
    """Return an attempt's effective score: its grade when positive, else correct/total.

    A grade that is None, zero (the default for a missing "grade" key),
    negative, or NaN (``nan > 0`` is False) is treated as absent. The score
    then falls back to the attempt's correct/total ratio, or 0.0 when the
    attempt has no rows.
    """
    grade = attempt["grade"]
    if grade is not None and grade > 0:
        return float(grade)
    if attempt["total"] > 0:
        return attempt["correct"] / attempt["total"]
    return 0.0


def calculate_pass_percentage(results: List[Dict]) -> Tuple[float, int]:
    """
    Calculate the user's pass percentage and number of questions from their results.

    The dataset structure has:
    - is_correct: bool indicating if answer was correct
    - grade: float64 indicating overall grade
    - datetime: string of attempt timestamp

    Rows sharing the same "datetime" value are grouped into one attempt. Each
    attempt is scored by its grade, taken from the first row seen for that
    timestamp, when that grade is positive. Otherwise it is scored by its
    correct/total ratio. The attempt with the highest such score is reported.
    Ranking and reporting use the same score, so an attempt with a missing
    grade competes on its ratio instead of being ranked as 0.

    Args:
        results: List of quiz results

    Returns:
        Tuple of (best attempt's score, number of questions in that attempt);
        (0.0, 0) for empty input

    Raises:
        KeyError if a row lacks "datetime" or "is_correct" (printed, then
        re-raised)
    """
    try:
        if not results:
            return 0.0, 0

        # Group results by datetime to get distinct attempts
        attempts: Dict[object, Dict] = {}
        for result in results:
            timestamp = result["datetime"]
            if timestamp not in attempts:
                attempts[timestamp] = {
                    "correct": 0,
                    "total": 0,
                    "grade": result.get("grade", 0.0),
                }
            attempt = attempts[timestamp]
            attempt["total"] += 1
            if result["is_correct"]:
                attempt["correct"] += 1

        # Rank by the same score that is returned, so ranking and reporting agree.
        best_attempt = max(attempts.values(), key=_attempt_score)
        return _attempt_score(best_attempt), best_attempt["total"]

    except Exception as e:
        print(f"Error in calculate_pass_percentage: {str(e)}")
        raise


def has_passed(pass_percentage: float, num_questions: int) -> bool:
    """
    Decide whether a score clears the certification bar.

    Both conditions must hold: at least MIN_QUESTIONS questions answered,
    and a score of at least PASSING_THRESHOLD.

    Args:
        pass_percentage: User's highest quiz score
        num_questions: Number of questions answered

    Returns:
        True when the user passed, False otherwise
    """
    if num_questions < MIN_QUESTIONS:
        return False
    return pass_percentage >= PASSING_THRESHOLD


def get_certificate_result(
    pass_percentage: float, num_questions: int
) -> CertificateResult:
    """
    Turn a computed score into the CertificateResult shown to the user.

    Args:
        pass_percentage: User's highest quiz score
        num_questions: Number of questions answered

    Returns:
        CertificateResult with passed=True and certificate_path set to
        CERTIFICATE_PATH when has_passed() accepts the score; otherwise
        passed=False and certificate_path=None
    """
    if not has_passed(pass_percentage, num_questions):
        # Not yet eligible: no certificate, just encouragement.
        return CertificateResult(
            message="""
            You haven't completed all the requirements yet. \n
            Keep trying! 💪
            """,
            certificate_path=None,
            pass_percentage=pass_percentage,
            passed=False,
        )

    # Eligible: point the caller at the certificate image.
    congratulations = """
            Congratulations, you successfully completed the course! 🎉 \n
            You can download your certificate below ⬇️ 
            """
    return CertificateResult(
        message=congratulations,
        certificate_path=CERTIFICATE_PATH,
        pass_percentage=pass_percentage,
        passed=True,
    )


def check_certification(username: str) -> CertificateResult:
    """
    Check if a user has completed the certification requirements.

    Fetches the user's quiz rows, scores them, and builds the result. This is
    the top-level boundary: any exception is printed and converted into a
    failed result carrying a generic message, rather than propagating.

    Args:
        username: The Hugging Face username to check

    Returns:
        CertificateResult containing pass status and details; a failed result
        with pass_percentage 0.0 when no rows are found or an error occurs
    """
    try:
        # Get user's quiz results
        results = get_user_results(username)
        if not results:
            return CertificateResult(
                message="No quiz results found. Please complete the quiz first.",
                certificate_path=None,
                pass_percentage=0.0,
                passed=False,
            )

        # Calculate pass percentage and get appropriate certificate result
        pass_percentage, num_questions = calculate_pass_percentage(results)
        return get_certificate_result(pass_percentage, num_questions)

    except Exception as e:
        # Keep the user-facing message generic, but record the cause so the
        # failure can be diagnosed instead of being discarded silently.
        print(f"Error in check_certification: {str(e)}")
        error_msg = """
        There was an error checking your certification status.
        Please try again later or contact support if the issue persists.
        """
        return CertificateResult(
            message=error_msg,
            certificate_path=None,
            pass_percentage=0.0,
            passed=False,
        )