alvarobartt (HF Staff) committed
Commit 8284ee8 · verified · 1 parent: f74a199

Update handler.py

Files changed (1): handler.py (+6, -15)
handler.py CHANGED
@@ -42,12 +42,6 @@ class EndpointHandler:
             new_message["content"] += content["text"]
         elif content["type"] == "image_url":
             images.append(load_image(content["image_url"]["url"]))
-            logger.info(
-                "Loaded image using `transformers.image_utils.load_image`"
-            )
-            logger.info(
-                f"Current {new_message['content']} text if any contains {new_message['content'].count(IMAGE_TOKENS)} image tokens"
-            )
         if new_message["content"].count(
             f"{IMAGE_TOKENS}{SEPARATOR}"
         ) < len(images):
@@ -72,27 +66,24 @@ class EndpointHandler:
         inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
         inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
         inputs = inputs.to("cuda").to(torch.bfloat16)
-        logger.info(f"Inputs contains {inputs=}")

         generation_args = {
             "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
-            "temperature": data.get("temperature", 0.7),
-            "do_sample": False,
-            "use_cache": True,
+            "temperature": data.get("temperature", 0.0),
+            "do_sample": False,  # temperature won't really work unless this is set to True
+            "use_cache": False,  # disabled as otherwise the same prompt with different images won't download the image again
             "num_beams": 1,
         }
-        logger.info(f"Running text generation with the following {generation_args=}")
+        logger.info(
+            f"Running text generation with the following {generation_args=} (skipped {set(data.keys()) - set(generation_args.keys())})"
+        )

         with torch.inference_mode():
-            logger.info(f"Inputs contains {inputs['input_ids']=}")
             generate_ids = self.model.generate(**inputs, **generation_args)
-            logger.info(f"Generate IDs contains {generate_ids=}")

-        logger.info(f"Generated {generate_ids=}")
         generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
         response = self.processor.decode(
             generate_ids[0], skip_special_tokens=True
         ).strip()
-        logger.info(f"Generated the {response=}")

         return {"generated_text": response}
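For context on the generation_args change: with do_sample=False (and num_beams=1), transformers runs greedy decoding and temperature has no effect, which is why the default drops from 0.7 to 0.0 and the inline comment was added. Below is a minimal, self-contained sketch of how a request payload maps onto generation_args and what the new "skipped" log line would report; the payload values here are hypothetical, only the keys the handler reads matter.

# Sketch: request payload -> generation_args mapping (payload values are hypothetical).
data = {
    "max_tokens": 256,     # legacy alias, only used when "max_new_tokens" is absent
    "temperature": 0.7,    # accepted but inert while do_sample=False (greedy decoding)
    "top_p": 0.9,          # not read by the handler at all
}

generation_args = {
    "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
    "temperature": data.get("temperature", 0.0),
    "do_sample": False,  # temperature won't really work unless this is set to True
    "use_cache": False,
    "num_beams": 1,
}

assert generation_args["max_new_tokens"] == 256  # fell back to "max_tokens"
# The new log line reports payload keys that generation_args does not contain:
print(set(data.keys()) - set(generation_args.keys()))
# -> {'max_tokens', 'top_p'}; note "max_tokens" is listed as skipped even though
#    it was consumed via the fallback, since only key names are compared.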
 
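The slice before decode is unchanged by this commit, but it is why the handler works at all: for decoder-only generation as used here, model.generate returns the prompt tokens followed by the newly generated ones, so the handler drops the first input_ids.shape[-1] columns before decoding. A dummy-tensor sketch of just that slice (no model involved; the token ids are made up):

# Sketch of the prompt-stripping slice with made-up token ids; no model needed.
import torch

input_ids = torch.tensor([[101, 7592, 2088]])               # 3-token prompt
generate_ids = torch.tensor([[101, 7592, 2088, 999, 102]])  # generate() echoes the prompt

new_tokens = generate_ids[:, input_ids.shape[-1]:]          # keep only generated tokens
print(new_tokens)  # tensor([[999, 102]])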