misinfo / src /preprocess /caption.py
gyigit's picture
history blame
4.62 kB
import os
from typing import Tuple
import pandas as pd
from tqdm import tqdm
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.utils.path_utils import get_project_root
# Initialize the BLIP processor and captioning model once at import time so
# every call to generate_caption() reuses the same loaded weights.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
# NOTE(review): the checkpoint argument was missing from this call; it is
# assumed to be the same checkpoint the processor is loaded from — confirm.
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)

# Project-relative data locations.
PROJECT_ROOT = get_project_root()
RAW_DIR = PROJECT_ROOT / "data/raw/factify"
PROCESSED_DIR = PROJECT_ROOT / "data/preprocessed"

BATCH_SIZE = 20  # Number of rows to process per batch (one checkpoint save per batch)
def generate_caption(image_path: str) -> str:
    """Generate a caption for an image given its path.

    Args:
        image_path: Image location; it is joined onto ``PROJECT_ROOT``.

    Returns:
        The decoded caption, or an empty string if the image could not be
        opened or captioned.
    """
    try:
        # Open inside a context manager so the file handle is released as soon
        # as the pixel data has been converted to an in-memory RGB image.
        with Image.open(f"{PROJECT_ROOT}/{image_path}") as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(image, return_tensors="pt")
        output = model.generate(**inputs)
        return processor.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort: one unreadable image must not abort a long
        # batch run, so report it and fall back to an empty caption.
        print(f"Error processing image {image_path}: {e}")
        return ""
def process_image_row(row: pd.Series) -> Tuple[str, str, str, str]:
    """Process a single row to generate captions and enriched columns.

    Args:
        row: A row providing the ``claim``, ``evidence``, ``claim_image`` and
            ``evidence_image`` fields.

    Returns:
        A tuple ``(claim_image_caption, evidence_image_caption,
        claim_enriched, evidence_enriched)``, where each enriched value is the
        original text followed by ``". "`` and the matching image caption.
    """
    claim_image_caption = generate_caption(row["claim_image"])
    evidence_image_caption = generate_caption(row["evidence_image"])
    claim_enriched = f"{row['claim']}. {claim_image_caption}"
    evidence_enriched = f"{row['evidence']}. {evidence_image_caption}"
    return (
        claim_image_caption,
        evidence_image_caption,
        claim_enriched,
        evidence_enriched,
    )
def get_last_processed_index(df: pd.DataFrame) -> int:
    """Return the position of the last processed row.

    A row counts as processed when its ``evidence_image_caption`` value is
    not NA.

    Args:
        df: DataFrame containing an ``evidence_image_caption`` column.

    Returns:
        The zero-based position of the last row whose
        ``evidence_image_caption`` is not NA, or -1 if no processed rows are
        found (including when ``df`` is empty).
    """
    # One vectorized pass over the column instead of a per-row ``df.loc``
    # lookup; positions equal labels for the default RangeIndex.
    processed_positions = df["evidence_image_caption"].notna().to_numpy().nonzero()[0]
    if processed_positions.size == 0:
        return -1
    return int(processed_positions[-1])
def _load_or_init_output(input_df: pd.DataFrame, output_csv: str) -> pd.DataFrame:
    """Return the checkpoint DataFrame stored at output_csv, or a fresh copy of input_df."""
    enriched_columns = [
        "claim_image_caption",
        "evidence_image_caption",
        "claim_enriched",
        "evidence_enriched",
    ]

    output_df = None
    if os.path.exists(output_csv):
        loaded_df = pd.read_csv(output_csv)
        if len(loaded_df) == len(input_df):
            output_df = loaded_df
        else:
            print(
                "Mismatch in input and output CSV lengths. Reinitializing output CSV..."
            )

    if output_df is None:
        # No usable checkpoint: start from the input with empty result columns.
        output_df = input_df.copy()
        for col in enriched_columns:
            output_df[col] = pd.NA
    else:
        # Tolerate a checkpoint written before a result column existed.
        for col in enriched_columns:
            if col not in output_df.columns:
                output_df[col] = pd.NA

    # A column that is entirely empty is read back from CSV as float64; cast
    # to object so string captions can be assigned without a dtype conflict.
    for col in enriched_columns:
        output_df[col] = output_df[col].astype(object)
    return output_df


def process_csv(input_csv: str, output_csv: str) -> None:
    """Processes the CSV in chunks and writes results incrementally with efficient checkpointing.

    Args:
        input_csv: Path of the CSV to enrich.
        output_csv: Path of the enriched CSV. If it already exists with the
            same number of rows as the input, processing resumes after its
            last processed row; otherwise it is reinitialized from the input.
    """
    # Load input DataFrame
    input_df = pd.read_csv(input_csv)

    # Initialize or load output DataFrame
    output_df = _load_or_init_output(input_df, output_csv)

    # Find the last processed index
    last_processed_idx = get_last_processed_index(output_df)
    print(f"Resuming from index {last_processed_idx + 1}")

    # Process remaining rows in batches
    total_rows = len(input_df)
    with tqdm(total=total_rows, initial=last_processed_idx + 1) as pbar:
        for idx in range(last_processed_idx + 1, total_rows, BATCH_SIZE):
            batch_end = min(idx + BATCH_SIZE, total_rows)

            # Process each row in the batch
            for row_idx in range(idx, batch_end):
                row = input_df.iloc[row_idx]

                # Skip if already processed
                if pd.notna(output_df.at[row_idx, "evidence_image_caption"]):
                    pbar.update(1)
                    continue

                # Process the row
                claim_cap, evidence_cap, claim_enr, evidence_enr = process_image_row(
                    row
                )

                # Update the output DataFrame
                output_df.loc[row_idx, "claim_image_caption"] = claim_cap
                output_df.loc[row_idx, "evidence_image_caption"] = evidence_cap
                output_df.loc[row_idx, "claim_enriched"] = claim_enr
                output_df.loc[row_idx, "evidence_enriched"] = evidence_enr
                pbar.update(1)

            # Save after each batch so an interrupted run loses at most one batch
            output_df.to_csv(output_csv, index=False)
            print(f"Saved progress at index {batch_end}")
if __name__ == "__main__":
    # Enrich each dataset split in turn; a missing input file aborts the run.
    for split in ("train", "test"):
        input_csv = f"{PROCESSED_DIR}/{split}.csv"
        output_csv = f"{PROCESSED_DIR}/{split}_enriched.csv"

        if os.path.exists(input_csv):
            process_csv(input_csv, output_csv)
            print(f"Processing complete. Output saved to {output_csv}")
        else:
            raise FileNotFoundError(f"Input CSV file does not exist: {input_csv}")