Luigi commited on
Commit
56a78b6
·
1 Parent(s): bfdf2ce

type handling

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -81,7 +81,21 @@ def caption_frame(frame, model_id, interval_ms, sys_prompt, usr_prompt, device):
81
  tokenize=True,
82
  return_dict=True,
83
  return_tensors='pt'
84
- ).to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  debug_msgs.append(f'Tokenize: {int((time.time()-t1)*1000)} ms')
86
 
87
  # Inference
 
81
  tokenize=True,
82
  return_dict=True,
83
  return_tensors='pt'
84
+ )
85
+ # Move inputs to correct device and dtype (matching model parameters)
86
+ param_dtype = next(model.parameters()).dtype
87
+ cast_inputs = {}
88
+ for k, v in inputs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ if v.dtype.is_floating_point:
91
+ # cast floating-point tensors to model's parameter dtype
92
+ cast_inputs[k] = v.to(device=model.device, dtype=param_dtype)
93
+ else:
94
+ # move integer/mask tensors without changing dtype
95
+ cast_inputs[k] = v.to(device=model.device)
96
+ else:
97
+ cast_inputs[k] = v
98
+ inputs = cast_inputs
99
  debug_msgs.append(f'Tokenize: {int((time.time()-t1)*1000)} ms')
100
 
101
  # Inference