versae commited on
Commit
e685151
·
verified ·
1 Parent(s): f25504b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -74,11 +74,11 @@ def download_img(identifier, url):
74
  def predict(image=None, text=None, sketch=None):
75
  if image is not None:
76
  input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
77
- topk = {"local": 100}
78
  else:
79
  if text:
80
  query = text
81
- topk = {text: 100}
82
  else:
83
  x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
84
  with torch.no_grad():
@@ -86,7 +86,7 @@ def predict(image=None, text=None, sketch=None):
86
  probabilities = torch.nn.functional.softmax(out[0], dim=0)
87
  values, indices = torch.topk(probabilities, 5)
88
  query = LABELS[indices[0]]
89
- topk = {LABELS[i]: v.item() for i, v in zip(indices, values)}
90
  input_embeddings = compute_text_embeddings([query]).detach().numpy()
91
 
92
  n_results = 3
 
74
  def predict(image=None, text=None, sketch=None):
75
  if image is not None:
76
  input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
77
+ topk = {"local": 1}
78
  else:
79
  if text:
80
  query = text
81
+ topk = {text: 1}
82
  else:
83
  x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
84
  with torch.no_grad():
 
86
  probabilities = torch.nn.functional.softmax(out[0], dim=0)
87
  values, indices = torch.topk(probabilities, 5)
88
  query = LABELS[indices[0]]
89
+ topk = {LABELS[i]: v.item() / 100.0 for i, v in zip(indices, values)}
90
  input_embeddings = compute_text_embeddings([query]).detach().numpy()
91
 
92
  n_results = 3