---
language: dv
tags:
- vision
- image-to-text
- OCR
- Dhivehi
- PaliGemma2
- thaana
license: apache-2.0
datasets:
- alakxender/dhivehi-vrd-images
metrics:
- accuracy
base_model:
- google/paligemma2-3b-pt-224
library_name: transformers
---
# PaliGemma2 VRD Dhivehi OCR Model
## Model Description
This is a fine-tuned version of the PaliGemma2 model, optimized for Optical Character Recognition (OCR) of Dhivehi (Thaana-script) text in images.
It is based on the `google/paligemma2-3b-pt-224` architecture and has been fine-tuned to read and transcribe Dhivehi text from images.
## Model Details
- **Model type:** Vision-Language Model
- **Base model:** google/paligemma2-3b-pt-224
- **Fine-tuning approach:** QLoRA
- **Input format:** Images with text
- **Output format:** Text transcription
- **Supported languages:** Primarily Dhivehi
## How to Use
### Option 1: Direct Loading
```python
from transformers.image_utils import load_image
import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
# Print GPU information
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    print(f"GPU memory cached: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB")

model_id = "alakxender/paligemma2-qlora-dhivehi-ocr-224-sl-14k"

# Load the model in bfloat16 to match the dtype cast applied to the inputs below
print("Loading model...")
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

print("Loading image...")
image = load_image("ocr1.png")

print("Processing image...")
prompt = "What text is written in this image?"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to("cuda")
input_len = model_inputs["input_ids"].shape[-1]

print("Model inputs device:", model_inputs["input_ids"].device)
print("Model device:", model.device)

print("Generating output...")
with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)

print("Done!")
```
### Option 2: Memory-Efficient PEFT Loading
```python
from transformers.image_utils import load_image
import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from peft import PeftModel, PeftConfig
# Define model ID
model_id = "alakxender/paligemma2-qlora-dhivehi-ocr-224-sl-14k"
# Load the PEFT configuration to get the base model path
print("Loading PEFT configuration...")
peft_config = PeftConfig.from_pretrained(model_id)
# Load the base model
print(f"Loading base model: {peft_config.base_model_name_or_path}...")
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load the adapter on top of the base model
print(f"Loading PEFT adapter: {model_id}...")
model = PeftModel.from_pretrained(base_model, model_id)

# Load the processor from the base model
processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path)

print("Loading image...")
image = load_image("ocr1.png")

print("Processing image...")
prompt = "What text is written in this image?"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16)

# Move inputs to the same device as the model
if hasattr(model, 'device'):
    device = model.device
else:
    # If the device isn't directly accessible, infer it from the model parameters
    device = next(model.parameters()).device
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
input_len = model_inputs["input_ids"].shape[-1]

print("Generating output...")
with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)

print("Done!")
```
## Training Details
- **Base Model:** google/paligemma2-3b-pt-224
- **Dataset:** alakxender/dhivehi-vrd-mix-img-questions
- **Training Configuration** (see the configuration sketch after this list):
- Batch size: 2 per device
- Gradient accumulation steps: 8
- Effective batch size: 16
- Learning rate: 2e-5
- Weight decay: 1e-6
- Adam β2: 0.999
- Warmup steps: 2
- Training steps: 20,000
- Epochs: 1
- Mixed precision: bfloat16
- **QLoRA Configuration:**
- Quantization: 4-bit NF4
- LoRA rank (r): 8
- Target modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
- Task type: CAUSAL_LM
- Optimizer: paged_adamw_8bit
- **Data Processing:**
- Image resize method: LANCZOS
- Input format: RGB images
- Text prompt format: "answer [question]"
- **Training Metrics:**
- Initial loss: ~15
- Final loss: ~2
- Learning rate: Decreasing from 1.5e-5 to 5e-6
- Gradient norm: Stabilized around 20-60
- Model checkpointing: Every 1000 steps
- Logging frequency: Every 100 steps
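The quantization, LoRA, and trainer settings listed above map onto a configuration roughly like the following. This is a minimal sketch using the standard `transformers`/`peft`/`bitsandbytes` APIs; the `output_dir` name is illustrative and the actual training script is not published.

```python
import torch
from transformers import BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig

# 4-bit NF4 quantization with bfloat16 compute, matching the mixed-precision setting
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Rank-8 LoRA on the attention and MLP projection layers listed above
lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)

# Trainer settings: effective batch size 2 x 8 = 16, paged 8-bit AdamW, bfloat16
training_args = TrainingArguments(
    output_dir="paligemma2-dhivehi-ocr",  # illustrative output path
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    warmup_steps=2,
    max_steps=20_000,
    bf16=True,
    optim="paged_adamw_8bit",
    save_steps=1000,
    logging_steps=100,
)
```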
## Performance
The model showed consistent improvement during training:
- Loss decreased significantly in the first 5k steps and stabilized afterwards
- Gradient norms remained in a healthy range throughout training
- Learning rate was automatically adjusted following a linear decay schedule
- Training completed successfully with convergence in loss metrics
- Training progress was monitored using Weights & Biases
## Model Architecture
This model uses Parameter-Efficient Fine-Tuning (PEFT) with QLoRA:
- **Quantization:** 4-bit quantization for memory efficiency
- **LoRA Adaptation:** Low-rank adaptation of key transformer components
- **Memory Optimization:** Uses paged optimizer for efficient memory usage
- **Mixed Precision:** bfloat16 for training stability and speed
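One way to see how small the adapted portion is: after loading the adapter with PEFT (as in Option 2 above), count the LoRA parameters against the full model. This is a quick illustrative check, not part of the released code.

```python
from transformers import PaliGemmaForConditionalGeneration
from peft import PeftConfig, PeftModel

model_id = "alakxender/paligemma2-qlora-dhivehi-ocr-224-sl-14k"
peft_config = PeftConfig.from_pretrained(model_id)

base_model = PaliGemmaForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, model_id)

# LoRA weights are the only parameters added by the adapter; they are a tiny
# fraction of the ~3B base parameters.
lora_params = sum(p.numel() for n, p in model.named_parameters() if "lora_" in n)
total_params = sum(p.numel() for p in model.parameters())
print(f"LoRA parameters: {lora_params:,} / {total_params:,} "
      f"({100 * lora_params / total_params:.3f}%)")
```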
## Limitations
- Primarily optimized for Dhivehi text
- Performance may vary with different image qualities and text styles
- Performance on handwritten text has not been evaluated and may be limited
## Dataset
- **Source Dataset:** alakxender/dhivehi-vrd-images (mixed)
- **Processed Dataset:** alakxender/dhivehi-vrd-b1-img-questions
- **Dataset Size:**
- Total: 1M+ samples
- **Question Types:**
The dataset uses a variety of question prompts for OCR tasks, including:
```
- "What text is written in this image?"
- "Can you read and transcribe the Dhivehi text shown in this image?"
- "What is the Dhivehi text visible in this image?"
- "Please read out the text content from this image"
- "What Dhivehi text can you see in this image?"
- "Is there any text visible in this image? If so, what does it say?"
- "Could you transcribe the Dhivehi text shown in this image?"
- "What does the text in this image say?"
- "Can you read the Dhivehi text in this image? What does it say?"
- "Please identify and transcribe any text visible in this image"
- "What Dhivehi text is present in this image?"
```
- **Dataset Format:**
- Features:
- `image`: Image containing Dhivehi text
- `question`: Randomly selected question from the question pool
- `answer`: Ground truth Dhivehi text transcription
- Processing: Memory-efficient chunked processing (10,000 samples per chunk)
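If the processed dataset is published under the name above, a sample can be streamed and turned into the documented "answer [question]" prompt like this. This is a sketch; it assumes a `train` split and the feature names listed above.

```python
from datasets import load_dataset

# Stream so the 1M+ samples are not downloaded up front
ds = load_dataset("alakxender/dhivehi-vrd-b1-img-questions", split="train", streaming=True)

sample = next(iter(ds))
image = sample["image"]        # image containing Dhivehi text
question = sample["question"]  # one of the OCR questions from the pool above
answer = sample["answer"]      # ground-truth Dhivehi transcription

# Prompt format used during training, as described under "Data Processing"
prompt = f"answer {question}"
print(prompt)
print(answer)
```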