---
base_model: unsloth/qwen2-vl-2b-instruct-unsloth-bnb-4bit
library_name: peft
model_name: Khanandeh-0.1-Persian-OCR-2B-Instruct
tags:
  - generated_from_trainer
  - unsloth
  - trl
  - sft
licence: license
---

Model Card for Khanandeh-0.1-Persian-OCR-2B-Instruct

This model is a fine-tuned version of unsloth/qwen2-vl-2b-instruct-unsloth-bnb-4bit for Persian OCR. It was trained with supervised fine-tuning (SFT) using TRL.

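You can load this model with the transformers and qwen_vl_utils libraries. Since the model is published as a PEFT checkpoint on top of a 4-bit (bitsandbytes) base model, peft and bitsandbytes are installed as well: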

# Quote the version specifier so the shell does not treat ">" as an output redirect
!pip install -U transformers qwen_vl_utils "accelerate>=0.26.0" peft
!pip install -U bitsandbytes
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import os
from qwen_vl_utils import process_vision_info

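model_name = "oddadmix/Khanandeh-0.1-Persian-OCR-2B-Instruct"
# torch_dtype="auto" uses the dtype stored in the checkpoint; device_map="auto" places the weights on the available GPU(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_name)
max_tokens = 2000  # upper bound on the number of newly generated tokens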

prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
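# `image` is a PIL.Image of the document page to transcribe
image = Image.open("page.jpg")  # replace with the path to your own page image
src = os.path.abspath("image.png")  # temporary copy that is passed to qwen_vl_utils
image.save(src)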

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": f"file://{src}"},
            {"type": "text", "text": prompt},
        ],
    }
]
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)  # move the input tensors to the model's device
generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
# generate() returns the prompt tokens followed by the new tokens; keep only the new ones
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
os.remove(src)  # delete the temporary image file
print(output_text)
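The slicing step near the end (`out_ids[len(in_ids) :]`) is needed because `model.generate` returns each prompt followed by the tokens it generated. The sketch below illustrates that behavior with plain Python lists so it runs without downloading the model; the token IDs and the `trim_prompt` helper name are hypothetical, not real Qwen2-VL tokens or part of any library.

```python
def trim_prompt(input_ids, generated_ids):
    """Drop the echoed prompt from each generated sequence.

    Assumes each row of generated_ids starts with the matching row of
    input_ids, which is how generate() output is laid out for this model.
    """
    return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]


# Hypothetical token IDs: a 4-token prompt followed by 3 generated tokens.
input_ids = [[101, 7, 8, 9]]
generated_ids = [[101, 7, 8, 9, 42, 43, 2]]

trimmed = trim_prompt(input_ids, generated_ids)
print(trimmed)  # [[42, 43, 2]]
```

Without this step, `batch_decode` would also return the prompt text in front of the transcription.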