ruben3010 commited on
Commit
f15388e
·
verified ·
1 Parent(s): 640419b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -2
handler.py CHANGED
@@ -27,8 +27,11 @@ class EndpointHandler():
27
 
28
  # Load the Qwen2-VL model
29
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
30
- self.model_dir, torch_dtype="auto", device_map="auto"
31
- )
 
 
 
32
  self.processor = AutoProcessor.from_pretrained(self.model_dir)
33
 
34
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
27
 
28
  # Load the Qwen2-VL model
29
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
30
+ self.model_dir,
31
+ torch_dtype=torch.bfloat16,
32
+ attn_implementation="flash_attention_2",
33
+ device_map="auto",
34
+ )
35
  self.processor = AutoProcessor.from_pretrained(self.model_dir)
36
 
37
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: