vikhyatk zesquirrelnator commited on
Commit
fa8398d
·
verified ·
1 Parent(s): d551111

Creating a handler.py file to support HF dedicated inference endpoints (#18)

Browse files

- Creating a handler.py file to support HF dedicated inference endpoints (aa98b0abeb8c874aa41f1bd96fe831f84dcd7e6a)
- Update handler.py (43b595da26e564d4f3a877609dbb95b38f367ecb)


Co-authored-by: georges casassovici <[email protected]>

Files changed (1) hide show
  1. handler.py +58 -0
handler.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from PIL import Image
3
+ import torch
4
+ from io import BytesIO
5
+ import base64
6
+
7
+ class EndpointHandler:
8
+ def __init__(self, model_dir):
9
+ self.model_id = "vikhyatk/moondream2"
10
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True)
11
+ self.tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2", trust_remote_code=True)
12
+
13
+ # Check if CUDA (GPU support) is available and then set the device to GPU or CPU
14
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ self.model.to(self.device)
16
+
17
+ def preprocess_image(self, encoded_image):
18
+ """Decode and preprocess the input image."""
19
+ decoded_image = base64.b64decode(encoded_image)
20
+ img = Image.open(BytesIO(decoded_image)).convert("RGB")
21
+ return img
22
+
23
+ def __call__(self, data):
24
+ """Handle the incoming request."""
25
+ try:
26
+ # Extract the inputs from the data
27
+ inputs = data.pop("inputs", data)
28
+ input_image = inputs['image']
29
+ question = inputs.get('question', "move to the red ball")
30
+
31
+ # Preprocess the image
32
+ img = self.preprocess_image(input_image)
33
+
34
+ # Perform inference
35
+ enc_image = self.model.encode_image(img).to(self.device)
36
+ answer = self.model.answer_question(enc_image, question, self.tokenizer)
37
+
38
+ # If the output is a tensor, move it back to CPU and convert to list
39
+ if isinstance(answer, torch.Tensor):
40
+ answer = answer.cpu().numpy().tolist()
41
+
42
+ # Create the response
43
+ response = {
44
+ "statusCode": 200,
45
+ "body": {
46
+ "answer": answer
47
+ }
48
+ }
49
+ return response
50
+ except Exception as e:
51
+ # Handle any errors
52
+ response = {
53
+ "statusCode": 500,
54
+ "body": {
55
+ "error": str(e)
56
+ }
57
+ }
58
+ return response