import gradio as gr
# --- Import necessary classes ---
from transformers import (
TrOCRProcessor,
VisionEncoderDecoderModel,
pipeline,
AutoTokenizer,
RobertaForSequenceClassification,
AutoConfig # <--- Import AutoConfig
)
# ---
from PIL import Image
import traceback
import warnings
import json
import os
import shutil # Not used below; handy if you want to clear the model cache from Python (e.g. shutil.rmtree)
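# Runtime dependencies implied by the imports above: gradio, transformers, torch (for return_tensors="pt"), and Pillow.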
# --- Model IDs ---
TROCR_MODELS = {
"Printed Text": "microsoft/trocr-large-printed",
"Handwritten": "microsoft/trocr-large-handwritten",
}
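# The detector is a roberta-large sequence classifier fine-tuned for binary Human/AI classification;
# its config maps id2label to {0: 'Human', 1: 'AI'} (verified by the label test further below).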
DETECTOR_MODEL_ID = "SuperAnnotate/roberta-large-llm-content-detector"
print(f"Using AI Detector Model: {DETECTOR_MODEL_ID}")
# --- Load OCR Models ---
print("Loading OCR models...")
OCR_PIPELINES = {}
for name, model_id in TROCR_MODELS.items():
try:
proc = TrOCRProcessor.from_pretrained(model_id)
mdl = VisionEncoderDecoderModel.from_pretrained(model_id)
OCR_PIPELINES[name] = (proc, mdl)
print(f"Loaded {name} OCR model.")
except Exception as e:
print(f"Error loading OCR model {name} ({model_id}): {e}")
# --- Explicitly load config, tokenizer, and model ---
print(f"Loading AI detector components ({DETECTOR_MODEL_ID})...")
DETECTOR_PIPELINE = None
detector_tokenizer = None
detector_model = None
try:
# 1. Load Configuration FIRST
print("Loading detector config...")
detector_config = AutoConfig.from_pretrained(DETECTOR_MODEL_ID)
print(f"Loaded config. Expected hidden size: {detector_config.hidden_size}") # Should be 1024
# Sanity check: roberta-large uses hidden_size 1024 (roberta-base uses 768); any other value means the cached config is wrong or corrupted, so halt early.
if detector_config.hidden_size != 1024:
raise ValueError(f"Loaded config specifies hidden size {detector_config.hidden_size}, but expected 1024 for roberta-large. Check cache for {DETECTOR_MODEL_ID}.")
# 2. Load Tokenizer
print("Loading detector tokenizer...")
detector_tokenizer = AutoTokenizer.from_pretrained(DETECTOR_MODEL_ID)
# 3. Load Model using the specific class AND the loaded config
print("Loading detector model with loaded config...")
detector_model = RobertaForSequenceClassification.from_pretrained(
DETECTOR_MODEL_ID,
config=detector_config # <--- Pass the loaded config
)
print("AI detector model and tokenizer loaded successfully.")
# 4. Create Pipeline
print("Creating AI detector pipeline...")
DETECTOR_PIPELINE = pipeline(
"text-classification",
model=detector_model,
tokenizer=detector_tokenizer,
top_k=None
)
print("Created AI detector pipeline.")
# --- Optional: Label Test ---
if DETECTOR_PIPELINE:
try:
print("Testing detector pipeline labels...")
sample_output = DETECTOR_PIPELINE("This is a reasonably long test sentence to check the model labels.", truncation=True)
print(f"Sample detector output structure: {sample_output}")
if sample_output and isinstance(sample_output, list) and len(sample_output) > 0:
if isinstance(sample_output[0], list) and len(sample_output[0]) > 0:
labels = [item.get('label', 'N/A') for item in sample_output[0] if isinstance(item, dict)]
print(f"Detected labels from sample run: {labels}")
elif isinstance(sample_output[0], dict):
labels = [item.get('label', 'N/A') for item in sample_output if isinstance(item, dict)]
print(f"Detected labels from sample run (non-nested): {labels}")
if detector_model and detector_model.config and detector_model.config.id2label:
print(f"Labels from model config: {detector_model.config.id2label}") # Should show {0: 'Human', 1: 'AI'}
except Exception as test_e:
print(f"Could not perform detector label test: {test_e}")
traceback.print_exc()
except Exception as e:
print(f"CRITICAL Error loading AI detector components ({DETECTOR_MODEL_ID}): {e}")
traceback.print_exc()
# --- Simplified Cache Clearing Suggestion ---
# Get cache path using environment variable or default
hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
hub_cache_path = os.path.join(hf_home, "hub") # Models are usually in the 'hub' subfolder
print("\n--- TROUBLESHOOTING SUGGESTION ---")
print(f"The model loading failed: {e}")
print("\nThis *strongly* indicates a problem with the cached files for this model.")
print("The most likely solution is to MANUALLY clear the cache for this model.")
print(f"\n1. Stop this application.")
print(f"2. Go to your Hugging Face hub cache directory (usually found under '{hub_cache_path}').")
print(f" (If you've set HF_HOME environment variable, check there instead: '{hf_home}')")
# Construct the model-specific cache folder name
model_cache_folder_name = f"models--{DETECTOR_MODEL_ID.replace('/', '--')}"
print(f"3. Delete the specific folder for this model: '{model_cache_folder_name}'")
print(f" Full path example: {os.path.join(hub_cache_path, model_cache_folder_name)}")
print(f"4. Restart the application. This will force a fresh download.")
print("\nMake sure no other applications are using the cache while deleting.")
print("--- END TROUBLESHOOTING ---")
# ---
# DETECTOR_PIPELINE remains None
# --- Score parsing and analysis helpers ---
# (get_ai_and_human_scores maps the detector's "AI"/"Human" labels, as defined in the model config, to display strings)
def get_ai_and_human_scores(results):
"""
Processes detector results to get likelihood scores for both AI and Human classes.
Handles various label formats including 'AI'/'Human', 'LABEL_0'/'LABEL_1', etc.
Args:
results: Raw output from the text-classification pipeline (a list of dicts, possibly nested one level for a single input string).
Returns:
tuple: (ai_display_string, human_display_string)
"""
ai_prob = 0.0
human_prob = 0.0
status_message = "Status: Initializing..." # Default status
if not results:
print("Warning: Received empty results for AI detection.")
status_message = "Error: No results received"
return status_message, "N/A"
# Handle potential nested list structure
score_list = []
if isinstance(results, list) and len(results) > 0:
if isinstance(results[0], list) and len(results[0]) > 0:
score_list = results[0]
elif isinstance(results[0], dict):
score_list = results
else:
status_message = f"Error: Unexpected detector output format (inner list type: {type(results[0])})"
print(f"Warning: {status_message}. Results[0]: {results[0]}")
return status_message, "N/A"
else:
status_message = f"Error: Unexpected detector output format (outer type: {type(results)})"
print(f"Warning: {status_message}. Results: {results}")
return status_message, "N/A"
# Build label→score map (uppercase labels for robust matching)
lbl2score = {}
parse_errors = []
for entry in score_list:
if isinstance(entry, dict) and "label" in entry and "score" in entry:
try:
score = float(entry["score"])
lbl2score[entry["label"].upper()] = score
except (ValueError, TypeError):
parse_errors.append(f"Invalid score format: {entry}")
else:
parse_errors.append(f"Invalid entry format: {entry}")
if parse_errors:
print(f"Warning: Encountered parsing errors in score list: {parse_errors}")
if not lbl2score:
status_message = "Error: Could not parse any valid scores from detector output"
print(f"Warning: {status_message}. Score list was: {score_list}")
return status_message, "N/A"
label_keys_found = ", ".join(lbl2score.keys())
found_pair = False
inferred = False
# --- Determine AI and Human probabilities based on labels ---
upper_keys = lbl2score.keys()
# Prioritize AI/HUMAN as per model config
if "AI" in upper_keys and "HUMAN" in upper_keys:
ai_prob = lbl2score["AI"]
human_prob = lbl2score["HUMAN"]
found_pair = True
status_message = "OK (Used AI/HUMAN labels)"
# Fallbacks
elif "LABEL_1" in upper_keys and "LABEL_0" in upper_keys:
ai_prob = lbl2score["LABEL_1"]
human_prob = lbl2score["LABEL_0"]
found_pair = True
status_message = "OK (Used LABEL_1/LABEL_0 - Check Mapping)"
print("Warning: Used fallback LABEL_1/LABEL_0. Config expects AI/HUMAN.")
# Add other fallbacks if necessary (FAKE/REAL, MACHINE/HUMAN)
# Inference logic
if not found_pair:
if "AI" in upper_keys:
ai_prob = lbl2score["AI"]
human_prob = max(0.0, 1.0 - ai_prob)
inferred = True
status_message = "OK (Inferred from AI label)"
elif "HUMAN" in upper_keys:
human_prob = lbl2score["HUMAN"]
ai_prob = max(0.0, 1.0 - human_prob)
inferred = True
status_message = "OK (Inferred from HUMAN label)"
# Add fallback inference if needed
if not found_pair and not inferred:  # error only if neither a direct label pair nor a single-label inference succeeded
status_message = f"Error: Could not determine AI/Human pair from labels [{label_keys_found}]"
print(f"Warning: {status_message}")
# --- Format output strings ---
ai_display_str = f"{ai_prob*100:.2f}%"
human_display_str = f"{human_prob*100:.2f}%"
if "Error:" in status_message:
ai_display_str = status_message
human_display_str = "N/A"
print(f"Score Status: {status_message}. AI={ai_display_str}, Human={human_display_str}")
return ai_display_str, human_display_str
# --- analyze_image function ---
def analyze_image(image: Image.Image, ocr_choice: str):
"""Performs OCR and AI Content Detection, returns both AI and Human %."""
extracted = ""
ai_result_str = "N/A"
human_result_str = "N/A"
status_update = "Awaiting input..."
if image is None:
status_update = "Please upload an image first."
return extracted, ai_result_str, human_result_str, status_update
if not ocr_choice or ocr_choice not in TROCR_MODELS:
status_update = "Please select a valid OCR model."
return extracted, ai_result_str, human_result_str, status_update
if OCR_PIPELINES.get(ocr_choice) is None:
return "", "N/A", "N/A", f"Error: OCR model '{ocr_choice}' failed to load or is unavailable."
if DETECTOR_PIPELINE is None:
return "", "N/A", "N/A", f"Critical Error: AI Detector model ({DETECTOR_MODEL_ID}) failed during startup. Check logs for details (possible cache issue?)."
try:
status_update = f"Processing with {ocr_choice} OCR..."
print(status_update)
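# TrOCR pipeline: the processor converts the PIL image to pixel_values, the encoder-decoder model generates text token IDs,
# and batch_decode turns them back into a string.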
proc, mdl = OCR_PIPELINES[ocr_choice]
if image.mode != "RGB": image = image.convert("RGB")
pix = proc(images=image, return_tensors="pt").pixel_values
tokens = mdl.generate(pix, max_length=1024)
extracted = proc.batch_decode(tokens, skip_special_tokens=True)[0]
extracted = extracted.strip()
if not extracted:
status_update = "OCR completed, but no text was extracted."
print(status_update)
return extracted, "N/A", "N/A", status_update
status_update = f"Detecting AI/Human content in {len(extracted)} characters..."
print(status_update)
results = DETECTOR_PIPELINE(extracted, truncation=True)  # truncate long OCR output to the model's maximum input length
ai_result_str, human_result_str = get_ai_and_human_scores(results)
if "Error:" in ai_result_str:
status_update = ai_result_str
else:
status_update = "Analysis complete."
print(f"Final Status: {status_update}")
return extracted, ai_result_str, human_result_str, status_update
except Exception as e:
error_msg = f"Error during image analysis: {e}"
print(error_msg)
traceback.print_exc()
status_update = error_msg
return extracted, "Error", "Error", status_update
# --- classify_text function ---
def classify_text(text: str):
"""Classifies provided text, returning both AI and Human %."""
ai_result_str = "N/A"
human_result_str = "N/A"
if DETECTOR_PIPELINE is None:
return f"Critical Error: AI Detector model ({DETECTOR_MODEL_ID}) failed during startup. Check logs for details (possible cache issue?).", "N/A"
if not text or text.isspace():
return "Please enter some text.", "N/A"
print("Classifying text...")
try:
results = DETECTOR_PIPELINE(text, truncation=True)  # truncate to the model's maximum input length to avoid errors on very long text
ai_result_str, human_result_str = get_ai_and_human_scores(results)
if "Error:" not in ai_result_str:
print("Classification complete.")
else:
print(f"Classification completed with issues: {ai_result_str}")
return ai_result_str, human_result_str
except Exception as e:
error_msg = f"Error during text classification: {e}"
print(error_msg)
traceback.print_exc()
return error_msg, "Error"
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
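# Layout: two tabs share the same detector; the first runs OCR on an uploaded image, the second classifies pasted text directly.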
gr.Markdown(
f"""
## OCR + AI/Human Content Detection
Upload an image or paste text. For images, the text is first extracted via OCR; it is then analyzed
by an AI content detector (`{DETECTOR_MODEL_ID}`)
to estimate the likelihood of it being AI-generated vs. human-written.
**Disclaimer:** AI content detection is challenging and not 100% accurate. These likelihoods
are estimates based on the model's training data and may not be definitive.
Performance varies with text type, length, and AI generation methods.
**Label Assumption:** Uses the model's configured labels (`AI`/`Human`). Fallbacks for other label formats are included but may be less reliable if the model deviates from its configuration.
"""
)
with gr.Tab("Analyze Image"):
with gr.Row():
with gr.Column(scale=2):
img_in = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
with gr.Column(scale=1):
ocr_dd = gr.Dropdown(
list(TROCR_MODELS.keys()), label="1. Select OCR Model", info="Choose based on text type in image."
)
run_btn = gr.Button("2. Analyze Image", variant="primary")
status_img = gr.Label(value="Awaiting image analysis...", label="Status")
with gr.Row():
text_out_img = gr.Textbox(label="Extracted Text", lines=10, interactive=False)
with gr.Column(scale=1):
ai_out_img = gr.Textbox(label="AI Likelihood %", interactive=False)
with gr.Column(scale=1):
human_out_img = gr.Textbox(label="Human Likelihood %", interactive=False)
run_btn.click(
fn=analyze_image,
inputs=[img_in, ocr_dd],
outputs=[text_out_img, ai_out_img, human_out_img, status_img],
queue=True
)
with gr.Tab("Classify Text"):
with gr.Column():
text_in_classify = gr.Textbox(label="Paste or type text here", lines=10)
classify_btn = gr.Button("Classify Text", variant="primary")
with gr.Row():
with gr.Column(scale=1):
ai_out_classify = gr.Textbox(label="AI Likelihood %", interactive=False)
with gr.Column(scale=1):
human_out_classify = gr.Textbox(label="Human Likelihood %", interactive=False)
classify_btn.click(
fn=classify_text,
inputs=[text_in_classify],
outputs=[ai_out_classify, human_out_classify],
queue=True
)
gr.HTML(f"<footer style='text-align:center; margin-top: 20px; color: grey;'>Powered by TrOCR & {DETECTOR_MODEL_ID}</footer>")
if __name__ == "__main__":
print("Starting Gradio demo...")
demo.launch(share=False, server_name="0.0.0.0")
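# server_name="0.0.0.0" binds to all interfaces (needed inside containers such as Hugging Face Spaces); set share=True for a temporary public Gradio link.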