Update vision_llm.py
vision_llm.py (+6 −3)
@@ -2,12 +2,15 @@
 from transformers import AutoProcessor, AutoModelForVision2Seq
 import torch
 from PIL import Image
+import os
 
 class VisionLLM:
-    def __init__(self, device="cuda", model_id="google/paligemma2-3b-pt-224"):  # Corrected
+    def __init__(self, device="cuda", model_id="google/paligemma2-3b-pt-224", use_auth_token=None):  # Corrected __init__
         self.device = device
-        self.processor = AutoProcessor.from_pretrained(model_id)
-        self.model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16).to(self.device)
+        if use_auth_token is None:
+            use_auth_token = os.environ.get("HF_TOKEN", None)
+        self.processor = AutoProcessor.from_pretrained(model_id, use_auth_token=use_auth_token)  # Use the auth token
+        self.model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16, use_auth_token=use_auth_token).to(self.device)  # Use the auth token
 
     def describe_images(self, images, prompt="", max_length=128):
         if isinstance(images, list):
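For context, a minimal usage sketch of the class after this change, not part of the commit itself. It assumes vision_llm.py is importable, that the token (from HF_TOKEN or the new use_auth_token argument) grants access to the gated PaliGemma weights, and that describe_images returns the generated text:

# Hypothetical usage example; names outside vision_llm.py are placeholders.
import os
from PIL import Image
from vision_llm import VisionLLM

# Option 1: rely on the HF_TOKEN environment variable picked up in __init__.
os.environ["HF_TOKEN"] = "hf_..."  # placeholder token
llm = VisionLLM(device="cuda")

# Option 2: pass the token explicitly, bypassing the environment lookup.
llm = VisionLLM(device="cuda", use_auth_token="hf_...")

image = Image.open("example.jpg")  # placeholder image path
print(llm.describe_images([image], prompt="describe the scene"))

Passing the token through the constructor keeps the environment variable as a fallback, so existing callers that only set HF_TOKEN continue to work unchanged.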