Update handler.py

handler.py · +6 −15 · CHANGED
@@ -42,12 +42,6 @@ class EndpointHandler:
                     new_message["content"] += content["text"]
                 elif content["type"] == "image_url":
                     images.append(load_image(content["image_url"]["url"]))
-                    logger.info(
-                        "Loaded image using `transformers.image_utils.load_image`"
-                    )
-                    logger.info(
-                        f"Current {new_message['content']} text if any contains {new_message['content'].count(IMAGE_TOKENS)} image tokens"
-                    )
             if new_message["content"].count(
                 f"{IMAGE_TOKENS}{SEPARATOR}"
             ) < len(images):
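
For context, this hunk sits inside the loop that flattens OpenAI-style chat messages: text parts are concatenated and image parts are fetched with transformers.image_utils.load_image. Below is a minimal sketch of that pattern. The IMAGE_TOKENS and SEPARATOR values are stand-ins (the real constants are defined elsewhere in handler.py), and the prepend-missing-placeholders step is an assumption about what the truncated "if" branch goes on to do:

from transformers.image_utils import load_image

# Stand-in values; the real constants are defined elsewhere in handler.py.
IMAGE_TOKENS = "<image>"
SEPARATOR = "\n"

def flatten_message(message):
    """Concatenate text parts and load any referenced images."""
    new_message = {"role": message["role"], "content": ""}
    images = []
    for content in message["content"]:
        if content["type"] == "text":
            new_message["content"] += content["text"]
        elif content["type"] == "image_url":
            # load_image accepts URLs, local file paths, and PIL images.
            images.append(load_image(content["image_url"]["url"]))
    # Assumed behaviour: if the text carries fewer image placeholders than
    # images were supplied, prepend one placeholder per missing image.
    missing = len(images) - new_message["content"].count(f"{IMAGE_TOKENS}{SEPARATOR}")
    if missing > 0:
        new_message["content"] = f"{IMAGE_TOKENS}{SEPARATOR}" * missing + new_message["content"]
    return new_message, images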
@@ -72,27 +66,24 @@ class EndpointHandler:
             inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
             inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
         inputs = inputs.to("cuda").to(torch.bfloat16)
-        logger.info(f"Inputs contains {inputs=}")

         generation_args = {
             "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
-            "temperature": data.get("temperature", 0.0),
-            "do_sample": False,
-            "use_cache": True,
+            "temperature": data.get("temperature", 0.0),
+            "do_sample": False,  # temperature won't really work unless this is set to True
+            "use_cache": False,  # disabled as otherwise the same prompt with different images won't download the image again
             "num_beams": 1,
         }
-        logger.info(
+        logger.info(
+            f"Running text generation with the following {generation_args=} (skipped {set(data.keys()) - set(generation_args.keys())})"
+        )

         with torch.inference_mode():
-            logger.info(f"Inputs contains {inputs['input_ids']=}")
             generate_ids = self.model.generate(**inputs, **generation_args)
-            logger.info(f"Generate IDs contains {generate_ids=}")

-            logger.info(f"Generated {generate_ids=}")
         generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
         response = self.processor.decode(
             generate_ids[0], skip_special_tokens=True
         ).strip()
-        logger.info(f"Generated the {response=}")

         return {"generated_text": response}
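
A note on the input-preparation lines at the top of this hunk: the processor output gets a leading batch dimension via unsqueeze(0) before the whole batch is moved to CUDA and cast to bfloat16. A rough sketch, assuming a LLaVA-style processor that returns pixel_values and image_sizes; the dim() guard is an assumption, since the diff only shows that the unsqueeze lines sit inside some conditional:

import torch

def prepare_inputs(processor, prompt, images):
    """Tokenize text and images, add a batch dim, move to GPU in bfloat16."""
    inputs = processor(text=prompt, images=images, return_tensors="pt")
    # Assumed guard: some processors return unbatched image tensors here.
    if inputs["pixel_values"].dim() == 4:
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
    # .to(dtype) on a BatchFeature casts only floating-point tensors,
    # so input_ids stay integer-typed.
    return inputs.to("cuda").to(torch.bfloat16)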
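On the generation_args change itself: in transformers, temperature only applies in sampling mode, so with do_sample=False decoding is greedy and the value is ignored (newer versions warn about this, which is what the new inline comment alludes to). One hedged alternative that lets a client-supplied temperature actually take effect; build_generation_args is a hypothetical helper, not code from this handler:

def build_generation_args(data):
    """Hypothetical helper: map request fields onto generate() kwargs."""
    args = {
        "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
        "num_beams": 1,
    }
    temperature = data.get("temperature", 0.0)
    if temperature > 0:
        # Sampling mode: temperature takes effect.
        args["do_sample"] = True
        args["temperature"] = temperature
    else:
        # Greedy decoding; transformers ignores temperature in this mode.
        args["do_sample"] = False
    return args

One caveat on the new use_cache comment: in generate(), use_cache toggles the key/value attention cache, so disabling it mainly slows down decoding; it does not control whether images are fetched again.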
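Finally, the decode path kept by this hunk: generate() returns the prompt tokens followed by the completion, so slicing off the first input_ids.shape[-1] positions keeps only the newly generated tokens before decoding. A compact sketch of the pattern (names are illustrative):

import torch

def generate_response(model, processor, inputs, generation_args):
    """Run generation and decode only the newly produced tokens."""
    with torch.inference_mode():
        generate_ids = model.generate(**inputs, **generation_args)
    # generate() echoes the prompt first; keep only the completion tokens.
    new_tokens = generate_ids[:, inputs["input_ids"].shape[-1]:]
    return processor.decode(new_tokens[0], skip_special_tokens=True).strip()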