import re

import gradio as gr
import pytesseract
from deep_translator import GoogleTranslator
from langdetect import DetectorFactory
from langdetect import detect
from PIL import Image
from transformers import pipeline

# Shared translator (auto-detected source -> Spanish). It is used to translate
# the keyword lists when a message is detected as Spanish.
translator = GoogleTranslator(source="auto", target="es")


def _load_keywords(path):
    """Return the non-blank lines of *path*, stripped and lower-cased."""
    with open(path, "r", encoding="utf-8") as handle:
        return [entry.strip().lower() for entry in handle if entry.strip()]


# Keyword lists for each scam category: one keyword per line, assumed English.
SMISHING_KEYWORDS = _load_keywords("smishing_keywords.txt")
OTHER_SCAM_KEYWORDS = _load_keywords("other_scam_keywords.txt")

# Zero-shot classification pipeline and the labels it chooses between.
model_name = "joeddav/xlm-roberta-large-xnli"
classifier = pipeline("zero-shot-classification", model=model_name)
CANDIDATE_LABELS = ["SMiShing", "Other Scam", "Legitimate"]


# langdetect is randomized unless seeded. Pinning the seed makes repeated
# detections of the same text return the same language.
DetectorFactory.seed = 0

# Spanish translations of (SMISHING_KEYWORDS, OTHER_SCAM_KEYWORDS), filled only
# after a pass in which every keyword translated successfully. A transient
# translator failure is then retried on a later request instead of being
# cached.
_SPANISH_KEYWORDS_CACHE = {}


def _translate_keywords_to_spanish(keywords):
    """
    Translate each keyword with the module-level `translator` and lower-case it.

    A keyword whose translation raises or comes back empty is kept unchanged.
    A failed lookup therefore never produces an empty keyword, which would
    match every message because `"" in text` is always True.

    Returns:
        (translated, complete): the translated list, and whether every
        keyword was translated successfully.
    """
    translated = []
    complete = True
    for kw in keywords:
        try:
            result = translator.translate(kw)
        except Exception:  # translator/network failure: keep the original keyword
            result = None
        if not result:
            complete = False
            result = kw
        translated.append(result.lower())
    return translated, complete


def get_keywords_by_language(text: str):
    """
    Detect the language of `text` and return the keyword lists to match against it.

    The language is detected with `langdetect` on the first 200 characters.
    For Spanish ("es"), the English keyword lists are translated to Spanish.
    The translation is done once and then cached. Any other language, or a
    detection failure, uses the English lists.

    Returns:
        (smishing_keywords, other_scam_keywords, lang), where `lang` is
        "es" or "en".
    """
    snippet = text[:200]  # a prefix is enough for detection
    try:
        detected_lang = detect(snippet)
    except Exception:
        # e.g. empty or feature-less text; fall back to English
        detected_lang = "en"

    if detected_lang != "es":
        return SMISHING_KEYWORDS, OTHER_SCAM_KEYWORDS, "en"

    cached = _SPANISH_KEYWORDS_CACHE.get("es")
    if cached is None:
        smishing, smishing_ok = _translate_keywords_to_spanish(SMISHING_KEYWORDS)
        other, other_ok = _translate_keywords_to_spanish(OTHER_SCAM_KEYWORDS)
        cached = (smishing, other)
        if smishing_ok and other_ok:
            _SPANISH_KEYWORDS_CACHE["es"] = cached

    smishing, other = cached
    # Hand out copies so callers can never mutate the cached lists.
    return list(smishing), list(other), "es"


def boost_probabilities(
    probabilities: dict,
    text: str,
    *,
    keyword_boost: float = 0.30,
    url_boost: float = 0.35,
) -> dict:
    """
    Adjust zero-shot probabilities using keyword matches and URLs.

    Each keyword from the language-appropriate lists (chosen by
    `get_keywords_by_language`) that occurs as a substring of the lower-cased
    text adds `keyword_boost` to its category. Any http(s) URL adds
    `url_boost` to "SMiShing". The total boost is subtracted from
    "Legitimate". All three values are then clamped at 0 and re-normalized to
    sum to 1. If everything ends up at 0, the result is 100% "Legitimate".

    Args:
        probabilities: label -> score mapping from the classifier. A missing
            "SMiShing" or "Other Scam" defaults to 0.0; a missing
            "Legitimate" defaults to 1.0.
        text: the message text that was classified.
        keyword_boost: amount added per matched keyword (default 0.30).
        url_boost: amount added to "SMiShing" when a URL is present
            (default 0.35).

    Returns:
        dict with float values for "SMiShing", "Other Scam" and "Legitimate",
        plus the string "detected_lang" as reported by
        `get_keywords_by_language`.
    """
    lower_text = text.lower()
    smishing_keywords, other_scam_keywords, detected_lang = get_keywords_by_language(text)

    smishing_count = sum(1 for kw in smishing_keywords if kw in lower_text)
    other_scam_count = sum(1 for kw in other_scam_keywords if kw in lower_text)

    smishing_boost = keyword_boost * smishing_count
    other_scam_boost = keyword_boost * other_scam_count

    # The presence of any link counts toward SMiShing.
    if re.search(r"https?://[^\s]+", lower_text):
        smishing_boost += url_boost

    p_smishing = probabilities.get("SMiShing", 0.0) + smishing_boost
    p_other_scam = probabilities.get("Other Scam", 0.0) + other_scam_boost
    p_legit = probabilities.get("Legitimate", 1.0) - (smishing_boost + other_scam_boost)

    # Clamp to 0 (the subtraction above can push "Legitimate" negative).
    p_smishing = max(p_smishing, 0.0)
    p_other_scam = max(p_other_scam, 0.0)
    p_legit = max(p_legit, 0.0)

    # Re-normalize so the three values sum to 1.
    total = p_smishing + p_other_scam + p_legit
    if total > 0:
        p_smishing /= total
        p_other_scam /= total
        p_legit /= total
    else:
        p_smishing, p_other_scam, p_legit = 0.0, 0.0, 1.0

    return {
        "SMiShing": p_smishing,
        "Other Scam": p_other_scam,
        "Legitimate": p_legit,
        "detected_lang": detected_lang,
    }


def smishing_detector(input_type, text, image):
    """
    Classify a message as "SMiShing", "Other Scam" or "Legitimate".

    Args:
        input_type: "Text" to classify `text`. Any other value (the UI offers
            "Screenshot") runs Tesseract OCR with lang="spa+eng" on `image`.
        text: pasted message text; may be None or empty.
        image: uploaded screenshot (the UI supplies a PIL image); may be None.

    Returns:
        A dict rendered as JSON. If no text could be obtained, it holds the
        label "No text provided" with confidence 0.0. Otherwise it reports:
        - the detected language
        - the classifier's original probabilities and the keyword/URL-boosted
          probabilities (rounded to 3 places)
        - the winning label and its confidence
        - the keywords and URLs found in the text
    """
    if input_type == "Text":
        combined_text = text.strip() if text else ""
    elif image is not None:
        # Screenshot mode: extract the message text with OCR.
        combined_text = pytesseract.image_to_string(image, lang="spa+eng").strip()
    else:
        combined_text = ""

    if not combined_text:
        return {
            "text_used_for_classification": "(none)",
            "label": "No text provided",
            "confidence": 0.0,
            "keywords_found": [],
            "urls_found": []
        }

    # Zero-shot classification over CANDIDATE_LABELS.
    result = classifier(
        sequences=combined_text,
        candidate_labels=CANDIDATE_LABELS,
        hypothesis_template="This message is {}."
    )
    original_probs = {k: float(v) for k, v in zip(result["labels"], result["scores"])}

    boosted = boost_probabilities(original_probs, combined_text)

    # Remove the language code before filtering. It is the only non-numeric
    # entry, so the float filter below would drop it, and the reported
    # language would then always fall back to "en".
    detected_lang = boosted.pop("detected_lang", "en")
    boosted = {k: float(v) for k, v in boosted.items() if isinstance(v, (int, float))}

    # Final classification: the label with the highest boosted probability.
    final_label = max(boosted, key=boosted.get)
    final_confidence = round(boosted[final_label], 3)

    # Recompute matches for display.
    lower_text = combined_text.lower()
    smishing_keys, scam_keys, _ = get_keywords_by_language(combined_text)

    found_urls = re.findall(r"(https?://[^\s]+)", lower_text)
    found_smishing = [kw for kw in smishing_keys if kw in lower_text]
    found_other_scam = [kw for kw in scam_keys if kw in lower_text]

    return {
        "detected_language": detected_lang,
        "text_used_for_classification": combined_text,
        "original_probabilities": {k: round(v, 3) for k, v in original_probs.items()},
        "boosted_probabilities": {k: round(v, 3) for k, v in boosted.items()},
        "label": final_label,
        "confidence": final_confidence,
        "smishing_keywords_found": found_smishing,
        "other_scam_keywords_found": found_other_scam,
        "urls_found": found_urls,
    }


# Markdown help text shown under the title of the interface.
_DESCRIPTION = """
Select "Text" or "Screenshot" above. 
- If "Text", only use the textbox. 
- If "Screenshot", only upload an image. 
The app will classify the message as SMiShing, Other Scam, or Legitimate.
"""

# Input widgets, passed to smishing_detector positionally as
# (input_type, text, image).
_input_components = [
    gr.Radio(
        choices=["Text", "Screenshot"],
        label="Choose input type",
        value="Text",
        info="Select 'Text' to paste a message, or 'Screenshot' to upload an image.",
    ),
    gr.Textbox(
        lines=3,
        label="Paste Suspicious SMS Text",
        placeholder="Type or paste the message here...",
    ),
    gr.Image(
        type="pil",
        label="Upload a Screenshot",
    ),
]

demo = gr.Interface(
    fn=smishing_detector,
    inputs=_input_components,
    outputs="json",
    title="SMiShing & Scam Detector",
    description=_DESCRIPTION,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()