gxy commited on
Commit
e58a035
·
1 Parent(s): 86d1668
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -15,7 +15,7 @@ model = BlipForConditionalGeneration.from_pretrained(
15
 
16
  def inference(raw_image, model_n, question, strategy):
17
  if model_n == 'Image Captioning':
18
- image = processor(raw_image).to(device, torch.float16)
19
  with torch.no_grad():
20
  if strategy == "Beam search":
21
  config = GenerationConfig(
@@ -24,7 +24,7 @@ def inference(raw_image, model_n, question, strategy):
24
  max_length=20,
25
  min_length=5,
26
  )
27
- captions = model.generate(image, generation_config=config)
28
  else:
29
  config = GenerationConfig(
30
  do_sample=True,
@@ -32,7 +32,7 @@ def inference(raw_image, model_n, question, strategy):
32
  max_length=20,
33
  min_length=5,
34
  )
35
- captions = model.generate(image, generation_config=config)
36
  caption = processor.decode(captions[0], skip_special_tokens=True)
37
  caption = caption.replace(' ', '')
38
  return 'caption: '+caption
 
15
 
16
  def inference(raw_image, model_n, question, strategy):
17
  if model_n == 'Image Captioning':
18
+ input = processor(raw_image).to(device, torch.float16)
19
  with torch.no_grad():
20
  if strategy == "Beam search":
21
  config = GenerationConfig(
 
24
  max_length=20,
25
  min_length=5,
26
  )
27
+ captions = model.generate(**input, generation_config=config)
28
  else:
29
  config = GenerationConfig(
30
  do_sample=True,
 
32
  max_length=20,
33
  min_length=5,
34
  )
35
+ captions = model.generate(**input, generation_config=config)
36
  caption = processor.decode(captions[0], skip_special_tokens=True)
37
  caption = caption.replace(' ', '')
38
  return 'caption: '+caption