Ryukijano committed on
Commit
68b7617
·
verified ·
1 Parent(s): 3a54a7a

Update vision_llm.py

Browse files
Files changed (1) hide show
  1. vision_llm.py +6 -3
vision_llm.py CHANGED
@@ -2,12 +2,15 @@
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
4
  from PIL import Image
 
5
 
6
  class VisionLLM:
7
- def __init__(self, device="cuda", model_id="google/paligemma2-3b-pt-224"): # Corrected model ID
8
  self.device = device
9
- self.processor = AutoProcessor.from_pretrained(model_id)
10
- self.model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16).to(self.device)
 
 
11
 
12
  def describe_images(self, images, prompt="", max_length=128):
13
  if isinstance(images, list):
 
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
4
  from PIL import Image
5
+ import os
6
 
7
  class VisionLLM:
8
+ def __init__(self, device="cuda", model_id="google/paligemma2-3b-pt-224", use_auth_token=None): # Corrected __init__
9
  self.device = device
10
+ if use_auth_token is None:
11
+ use_auth_token=os.environ.get("HF_TOKEN", None)
12
+ self.processor = AutoProcessor.from_pretrained(model_id, use_auth_token=use_auth_token) #Use the auth token
13
+ self.model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16, use_auth_token=use_auth_token).to(self.device) #Use the auth token
14
 
15
  def describe_images(self, images, prompt="", max_length=128):
16
  if isinstance(images, list):