xj committed
Commit f7a3ffd · 1 Parent(s): a76d79d

feat: added a Gradio interface.

Files changed (4)
  1. README.md +27 -3
  2. gradio_app.py +353 -0
  3. modeling_tio.py +12 -2
  4. utils_tio.py +56 -0
README.md CHANGED
@@ -9,6 +9,27 @@ language:
 
TiO is an Interactive Visual Grounding Model for Disambiguation. (WIP)
 
+## Online / offline Demo
+
+```python
+import torch
+from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+
+model_id = "jxu124/TiO"
+model = AutoModel.from_pretrained(
+    model_id,
+    trust_remote_code=True,
+    torch_dtype=torch.float16,
+    device_map='cuda',
+    # load_in_4bit=True,
+    # bnb_4bit_compute_dtype=torch.float16,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
+image_processor = AutoImageProcessor.from_pretrained(model_id)
+# Set up and launch the Gradio demo
+model.get_gradio_demo(tokenizer, image_processor).\
+    queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)
+```
+
## Mini-Example
```python
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
@@ -21,18 +42,21 @@ import requests
tokenizer = AutoTokenizer.from_pretrained("jxu124/TiO", use_fast=False)
image_processor = AutoImageProcessor.from_pretrained("jxu124/TiO")
model = AutoModel.from_pretrained("jxu124/TiO", trust_remote_code=True)
-model = model.to(torch.float16).cuda()  # It will be faster when using float16.
+model = model.to(torch.float16).cuda()  # float16 makes inference faster.

# Prepare example
image = Image.open(BytesIO(requests.get("http://images.cocodataset.org/val2014/COCO_val2014_000000429913.jpg").content))
-text = " #instruction: guess what i want? \n #context: \"human: look that man in white! \""
+text = """\
+#instruction: can you specify which region the context describes?
+#context:
+human: look that man in white!"""

# Inference
with torch.no_grad():
    pt_txt = tokenizer([text], return_tensors="pt").input_ids.cuda()
    pt_img = image_processor([image], return_tensors="pt").pixel_values.to(torch.float16).cuda()
    gen = model.generate(pt_txt, patch_images=pt_img, top_p=0.5, do_sample=True, no_repeat_ngram_size=3, max_length=256)
-print(tokenizer.batch_decode(gen, skip_special_tokens=True))
+print([s.replace("not yet.", "") for s in tokenizer.batch_decode(gen, skip_special_tokens=True)])
# e.g. [' is he the one who just threw the ball?']  # Output varies between runs because sampling is enabled.
```
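For reference, the `#instruction` / `#context` prompt layout used in the updated README example is the same one the new demo assembles from the chat history (see `get_prompt` in gradio_app.py below). A minimal sketch with a made-up dialog turn:

```python
# Illustration only (invented dialog): the prompt the README example and the demo's
# get_prompt() helper both build for the model.
system_prompt = "can you specify which region the context describes?"
dialog = [("look that man in white!", "is he the one who just threw the ball?")]
prompt = (
    f"#instruction: {system_prompt}\n"
    + "#context:\n"
    + "".join(f"human: {q}\nagent: {a}\n" for q, a in dialog)
    + "human: yes, that's him."
)
print(prompt)
# #instruction: can you specify which region the context describes?
# #context:
# human: look that man in white!
# agent: is he the one who just threw the ball?
# human: yes, that's him.
```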
gradio_app.py ADDED
@@ -0,0 +1,353 @@
+from threading import Thread
+from typing import Iterator
+from transformers import AutoModel, AutoTokenizer, AutoImageProcessor, TextIteratorStreamer
+from PIL import Image as PILImage
+import tempfile
+import torch
+import gradio as gr
+
+
+def get_gradio_demo(model, tokenizer, image_processor) -> gr.Interface:
+
+    def get_prompt(message: str, chat_history: list[tuple[str, str]],
+                   system_prompt: str) -> str:
+        texts = [f'#instruction: {system_prompt}\n', '#context:\n']
+        texts += [f"human: {user_input.strip()}\nagent: {response.strip()}\n" for user_input, response in chat_history if isinstance(user_input, str)]
+        texts += [f'human: {message.strip()}']
+        return ''.join(texts)
+
+    def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
+        prompt = get_prompt(message, chat_history, system_prompt)
+        input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
+        return input_ids.shape[-1]
+
+    def run(image: PILImage.Image,
+            message: str,
+            chat_history: list[tuple[str, str]],
+            system_prompt: str,
+            max_new_tokens: int = 192,
+            temperature: float = 0.1,
+            top_p: float = 0.9,
+            top_k: int = 50) -> Iterator[str]:
+        prompt = get_prompt(message, chat_history, system_prompt)
+        patch_images = image_processor([image], return_tensors="pt").pixel_values.to(torch.float16).to('cuda')
+        inputs = tokenizer([prompt], return_tensors='pt').to('cuda')
+
+        streamer = TextIteratorStreamer(tokenizer, timeout=10.)
+        generate_kwargs = dict(
+            inputs,
+            patch_images=patch_images,
+            streamer=streamer,
+            max_length=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            num_beams=1,
+        )
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        outputs = []
+        for text in streamer:
+            outputs.append(text)
+            yield ''.join(outputs).replace("not yet.", "").replace("<s>", "").replace("</s>", "").strip()
+
+    # -------
+
+    DEFAULT_SYSTEM_PROMPT = """can you specify which region the context describes?"""
+    MAX_MAX_NEW_TOKENS = 512
+    DEFAULT_MAX_NEW_TOKENS = 128
+    MAX_INPUT_TOKEN_LENGTH = 512
+
+    DESCRIPTION = """<h1 align="center">TiO Demo</h1>
+    <div align="center">https://huggingface.co/jxu124/TiO</div>
+    """
+
+    LICENSE = """
+    <p/>
+
+    ---
+    """
+
+    if not torch.cuda.is_available():
+        DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
+
+    def upload_image(file_obj):
+        chatbot = [[(file_obj.name,), None]]
+        return (gr.update(visible=False), gr.update(interactive=True, placeholder='Type a message...',), chatbot)
+
+    def clear_and_save_textbox(message: str) -> tuple[str, str]:
+        return '', message
+
+    def display_input(message: str,
+                      history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+        if len(history) == 0:
+            raise gr.Error('Upload an image first and try again.')
+        history.append((message, ''))
+        return history
+
+    def delete_prev_fn(
+            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+        try:
+            message, _ = history.pop()
+        except IndexError:
+            message = ''
+        return history, message or ''
+
+    def generate(
+        message: str,
+        history_with_input: list[tuple[str, str]],
+        system_prompt: str,
+        max_new_tokens: int,
+        temperature: float,
+        top_p: float,
+        top_k: int,
+    ) -> Iterator[list[tuple[str, str]]]:
+        if max_new_tokens > MAX_MAX_NEW_TOKENS:
+            raise ValueError
+
+        image = PILImage.open(history_with_input[0][0][0])
+        history = history_with_input[:-1]
+        generator = run(image, message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+        try:
+            first_response = next(generator)
+            yield history + [(message, first_response)]
+        except StopIteration:
+            yield history + [(message, '')]
+        for response in generator:
+            chatbot = history + [(message, response)]
+            if "region:" in response:
+                bboxes = model.utils.sbbox_to_bbox(response)
+                if len(bboxes):
+                    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
+                        model.utils.show_mask(image, bboxes).save(f)
+                        chatbot += [(None, (f.name,))]
+            yield chatbot
+
+    def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
+        generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 192, 1, 0.95, 50)
+        for x in generator:
+            pass
+        return '', x
+
+    def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
+        input_token_length = get_input_token_length(message, chat_history[:-1], system_prompt)
+        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+            raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
+
+    with gr.Blocks() as demo:
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Group():
+            chatbot = gr.Chatbot(label='Chatbot')
+            imagebox = gr.File(
+                file_types=["image"],
+                show_label=False,
+            )
+            with gr.Row():
+                textbox = gr.Textbox(
+                    container=False,
+                    show_label=False,
+                    interactive=False,
+                    placeholder='Upload an image...',
+                    scale=10,
+                )
+                submit_button = gr.Button('Submit',
+                                          variant='primary',
+                                          scale=1,
+                                          min_width=0)
+        with gr.Row():
+            retry_button = gr.Button('🔄 Retry', variant='secondary')
+            undo_button = gr.Button('↩️ Undo', variant='secondary')
+            clear_button = gr.Button('🗑️ Clear', variant='secondary')
+
+        saved_input = gr.State()
+
+        with gr.Accordion(label='Advanced options', open=False):
+            system_prompt = gr.Textbox(label='System prompt',
+                                       value=DEFAULT_SYSTEM_PROMPT,
+                                       lines=6)
+            max_new_tokens = gr.Slider(
+                label='Max new tokens',
+                minimum=1,
+                maximum=MAX_MAX_NEW_TOKENS,
+                step=1,
+                value=DEFAULT_MAX_NEW_TOKENS,
+            )
+            temperature = gr.Slider(
+                label='Temperature',
+                minimum=0.1,
+                maximum=4.0,
+                step=0.1,
+                value=0.5,
+            )
+            top_p = gr.Slider(
+                label='Top-p (nucleus sampling)',
+                minimum=0.05,
+                maximum=1.0,
+                step=0.05,
+                value=0.9,
+            )
+            top_k = gr.Slider(
+                label='Top-k',
+                minimum=1,
+                maximum=1000,
+                step=1,
+                value=20,
+            )
+
+        gr.Markdown(LICENSE)
+        imagebox.upload(
+            fn=upload_image,
+            inputs=imagebox,
+            outputs=[imagebox, textbox, chatbot],
+            api_name=None,
+            queue=False,
+        )
+
+        textbox.submit(
+            fn=clear_and_save_textbox,
+            inputs=textbox,
+            outputs=[textbox, saved_input],
+            api_name=None,
+            queue=False,
+        ).then(
+            fn=display_input,
+            inputs=[saved_input, chatbot],
+            outputs=chatbot,
+            api_name=None,
+            queue=False,
+        ).then(
+            fn=check_input_token_length,
+            inputs=[saved_input, chatbot, system_prompt],
+            api_name=None,
+            queue=False,
+        ).success(
+            fn=generate,
+            inputs=[
+                saved_input,
+                chatbot,
+                system_prompt,
+                max_new_tokens,
+                temperature,
+                top_p,
+                top_k,
+            ],
+            outputs=chatbot,
+            api_name="generate",
+        )
+
+        button_event_preprocess = submit_button.click(
+            fn=clear_and_save_textbox,
+            inputs=textbox,
+            outputs=[textbox, saved_input],
+            api_name=None,
+            queue=False,
+        ).then(
+            fn=display_input,
+            inputs=[saved_input, chatbot],
+            outputs=chatbot,
+            api_name=None,
+            queue=False,
+        ).then(
+            fn=check_input_token_length,
+            inputs=[saved_input, chatbot, system_prompt],
+            api_name=None,
+            queue=False,
+        ).success(
+            fn=generate,
+            inputs=[
+                saved_input,
+                chatbot,
+                system_prompt,
+                max_new_tokens,
+                temperature,
+                top_p,
+                top_k,
+            ],
+            outputs=chatbot,
+            api_name=None,
+        )
+
+        retry_button.click(
+            fn=delete_prev_fn,
+            inputs=chatbot,
+            outputs=[chatbot, saved_input],
+            api_name=None,
+            queue=False,
+        ).then(
+            fn=display_input,
+            inputs=[saved_input, chatbot],
+            outputs=chatbot,
+            api_name=None,
+            queue=False,
+        ).then(
+            fn=generate,
+            inputs=[
+                saved_input,
+                chatbot,
+                system_prompt,
+                max_new_tokens,
+                temperature,
+                top_p,
+                top_k,
+            ],
+            outputs=chatbot,
+            api_name=None,
+        )
+
+        undo_button.click(
+            fn=delete_prev_fn,
+            inputs=chatbot,
+            outputs=[chatbot, saved_input],
+            api_name=None,
+            queue=False,
+        ).then(
+            fn=lambda x: x,
+            inputs=[saved_input],
+            outputs=textbox,
+            api_name=None,
+            queue=False,
+        )
+
+        clear_button.click(
+            fn=lambda: ([], '', gr.update(value=None, visible=True), gr.update(interactive=False, placeholder='Upload an image...',)),
+            outputs=[chatbot, saved_input, imagebox, textbox],
+            queue=False,
+            api_name=None,
+        )
+
+    return demo
+
+
+def main(model_id: str = 'jxu124/TiO', host: str = "0.0.0.0", port: int = None):
+    if torch.cuda.is_available():
+        model = AutoModel.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+            device_map='cuda',
+            # load_in_4bit=True,
+            # bnb_4bit_compute_dtype=torch.float16,
+        )
+    else:
+        model = None
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
+    image_processor = AutoImageProcessor.from_pretrained(model_id)
+    model.get_gradio_demo(tokenizer, image_processor).queue(max_size=20).launch(server_name=host, server_port=port)
+
+
+if __name__ == "__main__":
+    import fire
+    fire.Fire(main)
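Since gradio_app.py wires `main()` through python-fire, the demo can presumably also be launched standalone, outside of transformers' remote-code path. A sketch (the port value is only an example; the flag names mirror `main()`'s parameters):

```python
# Sketch only. Equivalent CLI call via fire.Fire(main):
#   python gradio_app.py --model_id jxu124/TiO --host 0.0.0.0 --port 7860
from gradio_app import main  # assumes gradio_app.py is on the import path

main(model_id="jxu124/TiO", host="0.0.0.0", port=7860)
```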
modeling_tio.py CHANGED
@@ -87,8 +87,12 @@ def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSI
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos))
-   log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
-   log_pos = log_pos.int()
+   # import pdb; pdb.set_trace()
+   # log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
+   # log_pos = log_pos.int()
+   import numpy as np
+   log_pos = np.ceil(np.log(abs_pos.numpy() / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
+   log_pos = torch.tensor(log_pos.astype('int64'))
    bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos * sign).long()
    return bucket_pos + bucket_size - 1

@@ -2013,3 +2017,9 @@ class TiOModel(TiOPreTrainedModel):
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs
+
+
+from .utils_tio import Utils
+from .gradio_app import get_gradio_demo
+TiOModel.utils = Utils
+TiOModel.get_gradio_demo = get_gradio_demo
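The `make_token_bucket_position` change swaps the `torch.ceil`/`torch.log` computation for an equivalent NumPy round trip. A toy check of the new code path, with assumed `bucket_size`/`max_position` values, in case the intent is unclear:

```python
# Toy reproduction (assumed sizes) of the NumPy-based log-bucketing now used in
# make_token_bucket_position: offsets within ±mid keep their relative position,
# larger offsets are mapped to logarithmically spaced buckets.
import math
import numpy as np
import torch

bucket_size, max_position = 32, 1024
mid = bucket_size // 2
relative_pos = torch.arange(-40, 41).unsqueeze(0)  # toy relative positions
sign = torch.sign(relative_pos)
abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos))
log_pos = np.ceil(np.log(abs_pos.numpy() / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
log_pos = torch.tensor(log_pos.astype('int64'))
bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos * sign).long()
print(bucket_pos + bucket_size - 1)  # non-negative bucket indices
```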
utils_tio.py ADDED
@@ -0,0 +1,56 @@
+from PIL import Image as PILImage
+from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+import re
+import cv2
+import numpy as np
+
+
+class Utils():
+    def xywh2xyxy(b):
+        b[..., 2:] += b[..., :2]
+        return b
+
+    def bbox_to_sbbox(bbox):
+        # xyxy in [0, 1]
+        assert len(bbox) == 4
+        bbox = (np.asarray(bbox) * 1000).astype(np.int16)
+        bbox = np.clip(bbox, 0, 999)
+        bbox = " ".join([f"<bin_{i}>" for i in bbox])
+        return bbox
+
+    def sbbox_to_bbox(sbbox):
+        sbbox = [re.findall(r"<bin_(\d+)>", s)[:4] for s in sbbox.split("region:")]
+        bbox = np.asarray([s for s in sbbox if len(s) >= 4], dtype=int)
+        bbox = np.clip(bbox / 1000, 1e-3, 1 - 1e-3)
+        return bbox.reshape(-1, 4)
+
+    def make_dialog_context(dialog: list, text_human: str = None) -> str:
+        # dialog: [("pass me an apple.", "which apple do you want?"), ...]
+        context = "".join([f"human: {d[0]}\nagent: {d[1]}\n" for d in dialog])
+        if text_human is not None:
+            context += f"human: {text_human}"
+        return context
+
+    def show_mask(image: PILImage.Image, bboxes=None, masks=None, show_id=False, text_size=1) -> PILImage.Image:
+        """ Draw masks/boxes onto the image: only the RGB values of the marked regions are changed. """
+        import colorsys
+        colors = [tuple(int(c * 255) for c in colorsys.hsv_to_rgb(i * 1.0 / 36, 1, 1)) for i in range(36)]
+        size = image.size
+        image = np.asarray(image)
+        if bboxes is not None:
+            bboxes = np.array(bboxes).reshape(-1, 4)
+            for k, bbox in enumerate(bboxes):
+                bbox = (np.asarray(bbox) * np.asarray([*size, *size])).astype(int)
+                image = cv2.rectangle(image, tuple(bbox[:2]), tuple(bbox[2:]), tuple(colors[k]), thickness=2)
+            if show_id:
+                for k, bbox in enumerate(bboxes):
+                    bbox = (np.asarray(bbox) * np.asarray([*size, *size])).astype(int)
+                    image = cv2.putText(image, str(k), tuple(bbox[:2] + np.array([2, 28 * text_size])), cv2.FONT_HERSHEY_SIMPLEX, text_size, (255, 255, 255), 2, cv2.LINE_AA)
+                    image = cv2.putText(image, str(k), tuple(bbox[:2] + np.array([2, 28 * text_size])), cv2.FONT_HERSHEY_SIMPLEX, text_size, tuple(colors[k%len(colors)]), 1, cv2.LINE_AA)
+
+        if masks is not None:
+            for k, mask in enumerate(masks):
+                mask_color = (mask[..., None] * colors[k%len(colors)][:3]).astype(np.uint8)
+                image_mask = cv2.addWeighted(mask_color, 0.5, image * mask[..., None], 0.5, 0)
+                image = cv2.add(image * ~mask[..., None], image_mask)
+        return PILImage.fromarray(image)
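To make the `<bin_*>` region format concrete, here is a small usage sketch of the helpers above. The coordinates are invented, and it assumes utils_tio.py is importable (inside the model repo these are reached as `model.utils.*`, since modeling_tio.py attaches the `Utils` class directly):

```python
# Hypothetical usage of Utils from utils_tio.py (invented coordinates).
from utils_tio import Utils

# The model encodes boxes as <bin_0..999> tokens following a "region:" marker.
sbbox = "region: <bin_132> <bin_88> <bin_410> <bin_655>"
print(Utils.sbbox_to_bbox(sbbox))  # -> [[0.132 0.088 0.41  0.655]] (normalized xyxy)

# The inverse direction quantizes a normalized xyxy box back into <bin_*> tokens.
print(Utils.bbox_to_sbbox([0.132, 0.088, 0.410, 0.655]))
# -> a string like "<bin_132> <bin_88> <bin_409> <bin_655>" (exact bins depend on float truncation)
```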