Spaces:
Runtime error
Runtime error
bugfix
Browse files
app.py
CHANGED
@@ -15,7 +15,7 @@ model = BlipForConditionalGeneration.from_pretrained(
|
|
15 |
|
16 |
def inference(raw_image, model_n, question, strategy):
|
17 |
if model_n == 'Image Captioning':
|
18 |
-
|
19 |
with torch.no_grad():
|
20 |
if strategy == "Beam search":
|
21 |
config = GenerationConfig(
|
@@ -24,7 +24,7 @@ def inference(raw_image, model_n, question, strategy):
|
|
24 |
max_length=20,
|
25 |
min_length=5,
|
26 |
)
|
27 |
-
captions = model.generate(
|
28 |
else:
|
29 |
config = GenerationConfig(
|
30 |
do_sample=True,
|
@@ -32,7 +32,7 @@ def inference(raw_image, model_n, question, strategy):
|
|
32 |
max_length=20,
|
33 |
min_length=5,
|
34 |
)
|
35 |
-
captions = model.generate(
|
36 |
caption = processor.decode(captions[0], skip_special_tokens=True)
|
37 |
caption = caption.replace(' ', '')
|
38 |
return 'caption: '+caption
|
|
|
15 |
|
16 |
def inference(raw_image, model_n, question, strategy):
|
17 |
if model_n == 'Image Captioning':
|
18 |
+
input = processor(raw_image).to(device, torch.float16)
|
19 |
with torch.no_grad():
|
20 |
if strategy == "Beam search":
|
21 |
config = GenerationConfig(
|
|
|
24 |
max_length=20,
|
25 |
min_length=5,
|
26 |
)
|
27 |
+
captions = model.generate(**input, generation_config=config)
|
28 |
else:
|
29 |
config = GenerationConfig(
|
30 |
do_sample=True,
|
|
|
32 |
max_length=20,
|
33 |
min_length=5,
|
34 |
)
|
35 |
+
captions = model.generate(**input, generation_config=config)
|
36 |
caption = processor.decode(captions[0], skip_special_tokens=True)
|
37 |
caption = caption.replace(' ', '')
|
38 |
return 'caption: '+caption
|