import os
import traceback

import gradio as gr
from PIL import Image
from transformers import (
    AutoConfig,
    AutoTokenizer,
    RobertaForSequenceClassification,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    pipeline,
)

TROCR_MODELS = {
    "Printed Text": "microsoft/trocr-large-printed",
    "Handwritten": "microsoft/trocr-large-handwritten",
}
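# The large checkpoints are memory-hungry; "microsoft/trocr-base-printed" and
# "microsoft/trocr-base-handwritten" are lighter drop-in alternatives.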
DETECTOR_MODEL_ID = "SuperAnnotate/roberta-large-llm-content-detector"
print(f"Using AI Detector Model: {DETECTOR_MODEL_ID}")

print("Loading OCR models...")
OCR_PIPELINES = {}
for name, model_id in TROCR_MODELS.items():
    try:
        proc = TrOCRProcessor.from_pretrained(model_id)
        mdl = VisionEncoderDecoderModel.from_pretrained(model_id)
        OCR_PIPELINES[name] = (proc, mdl)
        print(f"Loaded {name} OCR model.")
    except Exception as e:
        print(f"Error loading OCR model {name} ({model_id}): {e}")
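# GPU use is optional; a minimal sketch (assumes a CUDA-enabled torch build,
# which transformers already depends on):
# import torch
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# for _proc, _mdl in OCR_PIPELINES.values():
#     _mdl.to(DEVICE)
# ...then move inputs with pixel_values.to(DEVICE) before calling generate().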
print(f"Loading AI detector components ({DETECTOR_MODEL_ID})...") |
|
DETECTOR_PIPELINE = None |
|
detector_tokenizer = None |
|
detector_model = None |
|
try: |
|
|
|
print("Loading detector config...") |
|
detector_config = AutoConfig.from_pretrained(DETECTOR_MODEL_ID) |
|
print(f"Loaded config. Expected hidden size: {detector_config.hidden_size}") |
|
|
|
|
|
if detector_config.hidden_size != 1024: |
|
raise ValueError(f"Loaded config specifies hidden size {detector_config.hidden_size}, but expected 1024 for roberta-large. Check cache for {DETECTOR_MODEL_ID}.") |
|
|
|
|
|
print("Loading detector tokenizer...") |
|
detector_tokenizer = AutoTokenizer.from_pretrained(DETECTOR_MODEL_ID) |
|
|
|
|
|
print("Loading detector model with loaded config...") |
|
detector_model = RobertaForSequenceClassification.from_pretrained( |
|
DETECTOR_MODEL_ID, |
|
config=detector_config |
|
) |
|
print("AI detector model and tokenizer loaded successfully.") |
|
|
|
|
|
print("Creating AI detector pipeline...") |
|
DETECTOR_PIPELINE = pipeline( |
|
"text-classification", |
|
model=detector_model, |
|
tokenizer=detector_tokenizer, |
|
top_k=None |
|
) |
|
print("Created AI detector pipeline.") |
|
|
|
|
|
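    # With top_k=None the pipeline returns scores for every class, e.g.
    # (illustrative values only):
    #   [[{'label': 'AI', 'score': 0.91}, {'label': 'Human', 'score': 0.09}]]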
    if DETECTOR_PIPELINE:
        try:
            print("Testing detector pipeline labels...")
            sample_output = DETECTOR_PIPELINE(
                "This is a reasonably long test sentence to check the model labels.",
                truncation=True,
            )
            print(f"Sample detector output structure: {sample_output}")

            if sample_output and isinstance(sample_output, list) and len(sample_output) > 0:
                if isinstance(sample_output[0], list) and len(sample_output[0]) > 0:
                    labels = [item.get("label", "N/A") for item in sample_output[0] if isinstance(item, dict)]
                    print(f"Detected labels from sample run: {labels}")
                elif isinstance(sample_output[0], dict):
                    labels = [item.get("label", "N/A") for item in sample_output if isinstance(item, dict)]
                    print(f"Detected labels from sample run (non-nested): {labels}")
            if detector_model and detector_model.config and detector_model.config.id2label:
                print(f"Labels from model config: {detector_model.config.id2label}")
        except Exception as test_e:
            print(f"Could not perform detector label test: {test_e}")
            traceback.print_exc()

except Exception as e:
    print(f"CRITICAL Error loading AI detector components ({DETECTOR_MODEL_ID}): {e}")
    traceback.print_exc()

    hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    hub_cache_path = os.path.join(hf_home, "hub")

    print("\n--- TROUBLESHOOTING SUGGESTION ---")
    print(f"The model loading failed: {e}")
    print("\nThis *strongly* indicates a problem with the cached files for this model.")
    print("The most likely solution is to MANUALLY clear the cache for this model.")
    print("\n1. Stop this application.")
    print(f"2. Go to your Hugging Face hub cache directory (usually '{hub_cache_path}').")
    print(f"   (If you've set the HF_HOME environment variable, check there instead: '{hf_home}')")

    model_cache_folder_name = f"models--{DETECTOR_MODEL_ID.replace('/', '--')}"
    print(f"3. Delete the specific folder for this model: '{model_cache_folder_name}'")
    print(f"   Full path example: {os.path.join(hub_cache_path, model_cache_folder_name)}")
    print("4. Restart the application. This will force a fresh download.")
    print("\nMake sure no other applications are using the cache while deleting.")
    print("--- END TROUBLESHOOTING ---")
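# A programmatic alternative to the manual steps above (a sketch; this
# irreversibly deletes the cached files, so use with care):
#   import shutil
#   cache = os.path.join(os.path.expanduser("~/.cache/huggingface"), "hub")
#   shutil.rmtree(os.path.join(cache, f"models--{DETECTOR_MODEL_ID.replace('/', '--')}"),
#                 ignore_errors=True)
# huggingface_hub also provides an interactive `huggingface-cli delete-cache` command.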

def get_ai_and_human_scores(results):
    """
    Processes detector results to get likelihood scores for both the AI and Human classes.
    Handles various label formats, including 'AI'/'Human' and 'LABEL_0'/'LABEL_1'.

    Returns:
        tuple: (ai_display_string, human_display_string)
    """
    ai_prob = 0.0
    human_prob = 0.0
    status_message = "Status: Initializing..."

    if not results:
        print("Warning: Received empty results for AI detection.")
        status_message = "Error: No results received"
        return status_message, "N/A"

    # Normalise the pipeline output to a flat list of {label, score} dicts.
    score_list = []
    if isinstance(results, list) and len(results) > 0:
        if isinstance(results[0], list) and len(results[0]) > 0:
            score_list = results[0]
        elif isinstance(results[0], dict):
            score_list = results
        else:
            status_message = f"Error: Unexpected detector output format (inner list type: {type(results[0])})"
            print(f"Warning: {status_message}. Results[0]: {results[0]}")
            return status_message, "N/A"
    else:
        status_message = f"Error: Unexpected detector output format (outer type: {type(results)})"
        print(f"Warning: {status_message}. Results: {results}")
        return status_message, "N/A"

    lbl2score = {}
    parse_errors = []
    for entry in score_list:
        if isinstance(entry, dict) and "label" in entry and "score" in entry:
            try:
                score = float(entry["score"])
                lbl2score[entry["label"].upper()] = score
            except (ValueError, TypeError):
                parse_errors.append(f"Invalid score format: {entry}")
        else:
            parse_errors.append(f"Invalid entry format: {entry}")

    if parse_errors:
        print(f"Warning: Encountered parsing errors in score list: {parse_errors}")

    if not lbl2score:
        status_message = "Error: Could not parse any valid scores from detector output"
        print(f"Warning: {status_message}. Score list was: {score_list}")
        return status_message, "N/A"

    label_keys_found = ", ".join(lbl2score.keys())
    found_pair = False
    inferred = False
    upper_keys = lbl2score.keys()

    # Preferred path: the model's configured AI/HUMAN labels.
    if "AI" in upper_keys and "HUMAN" in upper_keys:
        ai_prob = lbl2score["AI"]
        human_prob = lbl2score["HUMAN"]
        found_pair = True
        status_message = "OK (Used AI/HUMAN labels)"
    # Fallback: generic LABEL_1/LABEL_0 names; the mapping is assumed.
    elif "LABEL_1" in upper_keys and "LABEL_0" in upper_keys:
        ai_prob = lbl2score["LABEL_1"]
        human_prob = lbl2score["LABEL_0"]
        found_pair = True
        status_message = "OK (Used LABEL_1/LABEL_0 - Check Mapping)"
        print("Warning: Used fallback LABEL_1/LABEL_0. Config expects AI/HUMAN.")

    # Last resort: infer the missing probability from a single known label.
    if not found_pair:
        if "AI" in upper_keys:
            ai_prob = lbl2score["AI"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from AI label)"
        elif "HUMAN" in upper_keys:
            human_prob = lbl2score["HUMAN"]
            ai_prob = max(0.0, 1.0 - human_prob)
            inferred = True
            status_message = "OK (Inferred from HUMAN label)"

        if not inferred:
            status_message = f"Error: Could not determine AI/Human pair from labels [{label_keys_found}]"
            print(f"Warning: {status_message}")

    ai_display_str = f"{ai_prob * 100:.2f}%"
    human_display_str = f"{human_prob * 100:.2f}%"

    if "Error:" in status_message:
        ai_display_str = status_message
        human_display_str = "N/A"

    print(f"Score Status: {status_message}. AI={ai_display_str}, Human={human_display_str}")
    return ai_display_str, human_display_str
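# Illustrative call (invented values):
#   get_ai_and_human_scores([[{"label": "AI", "score": 0.9},
#                             {"label": "Human", "score": 0.1}]])
#   returns ("90.00%", "10.00%").
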
def analyze_image(image: Image.Image, ocr_choice: str):
    """Performs OCR and AI content detection; returns both AI and Human %."""
    extracted = ""
    ai_result_str = "N/A"
    human_result_str = "N/A"
    status_update = "Awaiting input..."

    if image is None:
        status_update = "Please upload an image first."
        return extracted, ai_result_str, human_result_str, status_update
    if not ocr_choice or ocr_choice not in TROCR_MODELS:
        status_update = "Please select a valid OCR model."
        return extracted, ai_result_str, human_result_str, status_update
    if OCR_PIPELINES.get(ocr_choice) is None:
        return "", "N/A", "N/A", f"Error: OCR model '{ocr_choice}' failed to load or is unavailable."
    if DETECTOR_PIPELINE is None:
        return "", "N/A", "N/A", (
            f"Critical Error: AI Detector model ({DETECTOR_MODEL_ID}) failed during startup. "
            "Check logs for details (possible cache issue?)."
        )

    try:
        status_update = f"Processing with {ocr_choice} OCR..."
        print(status_update)
        proc, mdl = OCR_PIPELINES[ocr_choice]
        if image.mode != "RGB":
            image = image.convert("RGB")
        pix = proc(images=image, return_tensors="pt").pixel_values
        tokens = mdl.generate(pix, max_length=1024)
        extracted = proc.batch_decode(tokens, skip_special_tokens=True)[0].strip()

        if not extracted:
            status_update = "OCR completed, but no text was extracted."
            print(status_update)
            return extracted, "N/A", "N/A", status_update

        status_update = f"Detecting AI/Human content in {len(extracted)} characters..."
        print(status_update)
        # Truncate so long OCR output does not exceed RoBERTa's 512-token limit.
        results = DETECTOR_PIPELINE(extracted, truncation=True)

        ai_result_str, human_result_str = get_ai_and_human_scores(results)

        if "Error:" in ai_result_str:
            status_update = ai_result_str
        else:
            status_update = "Analysis complete."
        print(f"Final Status: {status_update}")

        return extracted, ai_result_str, human_result_str, status_update

    except Exception as e:
        status_update = f"Error during image analysis: {e}"
        print(status_update)
        traceback.print_exc()
        return extracted, "Error", "Error", status_update

def classify_text(text: str):
    """Classifies provided text, returning both AI and Human %."""
    if DETECTOR_PIPELINE is None:
        return (
            f"Critical Error: AI Detector model ({DETECTOR_MODEL_ID}) failed during startup. "
            "Check logs for details (possible cache issue?)."
        ), "N/A"
    if not text or text.isspace():
        return "Please enter some text.", "N/A"

    print("Classifying text...")
    try:
        # As in analyze_image, truncate to the model's 512-token limit.
        results = DETECTOR_PIPELINE(text, truncation=True)
        ai_result_str, human_result_str = get_ai_and_human_scores(results)

        if "Error:" not in ai_result_str:
            print("Classification complete.")
        else:
            print(f"Classification completed with issues: {ai_result_str}")

        return ai_result_str, human_result_str

    except Exception as e:
        error_msg = f"Error during text classification: {e}"
        print(error_msg)
        traceback.print_exc()
        return error_msg, "Error"
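# Truncation drops everything past the 512-token limit. A chunked alternative
# (hypothetical helper, not wired into the UI):
# def classify_long_text(text, chunk_chars=1500):
#     chunks = [text[i:i + chunk_chars] for i in range(0, len(text), chunk_chars)]
#     per_chunk = [DETECTOR_PIPELINE(c, truncation=True) for c in chunks]
#     ...  # aggregate, e.g. average the per-chunk AI scores
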
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        ## OCR + AI/Human Content Detection
        Upload an image or paste text. For images, the tool first extracts text via OCR,
        then analyzes it with an AI content detector (`{DETECTOR_MODEL_ID}`)
        to estimate the likelihood of it being AI-generated vs. human-written.

        **Disclaimer:** AI content detection is challenging and not 100% accurate. These
        likelihoods are estimates based on the model's training data and may not be
        definitive. Performance varies with text type, length, and AI generation methods.

        **Label Assumption:** Uses the model's configured labels (`AI`/`Human`). Fallbacks
        for other label formats are included but may be less reliable if the model deviates
        from its configuration.
        """
    )
    with gr.Tab("Analyze Image"):
        with gr.Row():
            with gr.Column(scale=2):
                img_in = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
            with gr.Column(scale=1):
                ocr_dd = gr.Dropdown(
                    list(TROCR_MODELS.keys()),
                    label="1. Select OCR Model",
                    info="Choose based on the type of text in the image.",
                )
                run_btn = gr.Button("2. Analyze Image", variant="primary")
                status_img = gr.Label(value="Awaiting image analysis...", label="Status")

        with gr.Row():
            text_out_img = gr.Textbox(label="Extracted Text", lines=10, interactive=False)
            with gr.Column(scale=1):
                ai_out_img = gr.Textbox(label="AI Likelihood %", interactive=False)
            with gr.Column(scale=1):
                human_out_img = gr.Textbox(label="Human Likelihood %", interactive=False)

        run_btn.click(
            fn=analyze_image,
            inputs=[img_in, ocr_dd],
            outputs=[text_out_img, ai_out_img, human_out_img, status_img],
            queue=True,
        )

    with gr.Tab("Classify Text"):
        with gr.Column():
            text_in_classify = gr.Textbox(label="Paste or type text here", lines=10)
            classify_btn = gr.Button("Classify Text", variant="primary")
            with gr.Row():
                with gr.Column(scale=1):
                    ai_out_classify = gr.Textbox(label="AI Likelihood %", interactive=False)
                with gr.Column(scale=1):
                    human_out_classify = gr.Textbox(label="Human Likelihood %", interactive=False)

        classify_btn.click(
            fn=classify_text,
            inputs=[text_in_classify],
            outputs=[ai_out_classify, human_out_classify],
            queue=True,
        )

    gr.HTML(f"<footer style='text-align:center; margin-top: 20px; color: grey;'>Powered by TrOCR & {DETECTOR_MODEL_ID}</footer>")
if __name__ == "__main__": |
|
print("Starting Gradio demo...") |
|
demo.launch(share=False, server_name="0.0.0.0") |
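# Note: server_name="0.0.0.0" binds to all network interfaces (useful in
# containers); use "127.0.0.1" to keep the demo local-only.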