# misinfo/src/preprocess/caption.py
import os
from typing import Tuple
import pandas as pd
from tqdm import tqdm
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.utils.path_utils import get_project_root
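
# Assumption (not stated in the file): this script is run from the project root
# as a module, e.g. `python -m src.preprocess.caption`, so the `src.` imports resolve.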
# Initialize BLIP model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-large"
)
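
# Assumption beyond the original script: captioning runs on CPU by default.
# With a CUDA GPU available, moving the model over (and sending `inputs` to
# `model.device` before `model.generate`) is a large speed-up, e.g.:
#     import torch
#     if torch.cuda.is_available():
#         model = model.to("cuda")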

PROJECT_ROOT = get_project_root()
RAW_DIR = PROJECT_ROOT / "data/raw/factify"
PROCESSED_DIR = PROJECT_ROOT / "data/preprocessed"

BATCH_SIZE = 20  # Number of rows to process per batch
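# Smaller batches checkpoint more often (less lost work after a crash) at the
# cost of rewriting the full output CSV more frequently.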


def generate_caption(image_path: str) -> str:
    """Generate a caption for an image given its path relative to the project root."""
    try:
        image = Image.open(PROJECT_ROOT / image_path).convert("RGB")
        inputs = processor(image, return_tensors="pt")
        output = model.generate(**inputs)
        return processor.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return ""


def process_image_row(row: pd.Series) -> Tuple[str, str, str, str]:
    """Generate captions for a row's images and build caption-enriched text columns."""
    claim_image_caption = generate_caption(row["claim_image"])
    evidence_image_caption = generate_caption(row["evidence_image"])
    # Append each caption to its text field so downstream text models see both.
    claim_enriched = f"{row['claim']}. {claim_image_caption}"
    evidence_enriched = f"{row['evidence']}. {evidence_image_caption}"
    return (
        claim_image_caption,
        evidence_image_caption,
        claim_enriched,
        evidence_enriched,
    )


def get_last_processed_index(df: pd.DataFrame) -> int:
    """
    Find the last processed row index by scanning backwards from the end
    until a row whose evidence_image_caption is not NA is found.
    Returns -1 if no processed rows exist.
    """
    for idx in range(len(df) - 1, -1, -1):
        if pd.notna(df.loc[idx, "evidence_image_caption"]):
            return idx
    return -1
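
# An equivalent vectorized lookup (assumes the default RangeIndex these CSVs
# load with):
#     mask = df["evidence_image_caption"].notna()
#     return int(mask[mask].index[-1]) if mask.any() else -1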


def process_csv(input_csv: str, output_csv: str) -> None:
    """Process the CSV in batches and write results incrementally so runs can resume."""
    input_df = pd.read_csv(input_csv)

    caption_columns = [
        "claim_image_caption",
        "evidence_image_caption",
        "claim_enriched",
        "evidence_enriched",
    ]

    # Initialize or load the output DataFrame. A stale output file whose length
    # no longer matches the input is discarded and rebuilt from scratch.
    output_df = None
    if os.path.exists(output_csv):
        output_df = pd.read_csv(output_csv)
        if len(output_df) != len(input_df):
            print(
                "Mismatch in input and output CSV lengths. Reinitializing output CSV..."
            )
            output_df = None
    if output_df is None:
        output_df = input_df.copy()
        for col in caption_columns:
            output_df[col] = pd.NA

    # Resume from the last row that already has a caption.
    last_processed_idx = get_last_processed_index(output_df)
    print(f"Resuming from index {last_processed_idx + 1}")

# Process remaining rows in batches
total_rows = len(input_df)
with tqdm(total=total_rows, initial=last_processed_idx + 1) as pbar:
for idx in range(last_processed_idx + 1, total_rows, BATCH_SIZE):
batch_end = min(idx + BATCH_SIZE, total_rows)
# Process each row in the batch
for row_idx in range(idx, batch_end):
row = input_df.iloc[row_idx]
# Skip if already processed
if pd.notna(output_df.at[row_idx, "evidence_image_caption"]):
continue
# Process the row
claim_cap, evidence_cap, claim_enr, evidence_enr = process_image_row(
row
)
# Update the output DataFrame
output_df.loc[row_idx, "claim_image_caption"] = claim_cap
output_df.loc[row_idx, "evidence_image_caption"] = evidence_cap
output_df.loc[row_idx, "claim_enriched"] = claim_enr
output_df.loc[row_idx, "evidence_enriched"] = evidence_enr
pbar.update(1)
# Save after each batch
output_df.to_csv(output_csv, index=False)
print(f"Saved progress at index {batch_end}")


if __name__ == "__main__":
    for name in ["train", "test"]:
        input_csv = f"{PROCESSED_DIR}/{name}.csv"
        output_csv = f"{PROCESSED_DIR}/{name}_enriched.csv"
        if not os.path.exists(input_csv):
            raise FileNotFoundError(f"Input CSV file does not exist: {input_csv}")
        process_csv(input_csv, output_csv)
        print(f"Processing complete. Output saved to {output_csv}")