Update chatbot.py
Browse files- chatbot.py +1 -7
chatbot.py
CHANGED
@@ -40,12 +40,6 @@ class Chatbot:
|
|
40 |
|
41 |
def load_models(self):
|
42 |
self.model = SentenceTransformer('clip-ViT-B-32')
|
43 |
-
self.bert_model_name = "bert-base-uncased"
|
44 |
-
self.bert_model = BertModel.from_pretrained(self.bert_model_name)
|
45 |
-
self.bert_tokenizer = BertTokenizer.from_pretrained(self.bert_model_name)
|
46 |
-
self.gpt2_model_name = "gpt2"
|
47 |
-
self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
|
48 |
-
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
|
49 |
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
50 |
self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
|
51 |
|
@@ -81,7 +75,7 @@ class Chatbot:
|
|
81 |
docs.append(doc)
|
82 |
return docs
|
83 |
|
84 |
-
def get_results(self, query, embeddings, top_k=
|
85 |
query_embedding = self.model.encode([query])
|
86 |
cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
|
87 |
top_results = torch.topk(cos_scores, k=top_k)
|
|
|
40 |
|
41 |
def load_models(self):
|
42 |
self.model = SentenceTransformer('clip-ViT-B-32')
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
44 |
self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
|
45 |
|
|
|
75 |
docs.append(doc)
|
76 |
return docs
|
77 |
|
78 |
+
def get_results(self, query, embeddings, top_k=5):
|
79 |
query_embedding = self.model.encode([query])
|
80 |
cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
|
81 |
top_results = torch.topk(cos_scores, k=top_k)
|