Cartinoe5930 committed
Commit 5b531fe · 1 Parent(s): 1d16515

Update model_inference.py

Files changed (1):
  model_inference.py +2 -4
model_inference.py CHANGED
@@ -57,10 +57,7 @@ def generate_question(agents, question):
 
     return agent_contexts, content
 
-def Inference(model_list, question, API_KEY, cot):
-    if len(model_list) != 3:
-        raise ValueError("Please choose just '3' models! Neither more nor less!")
-
+def Inference(model_list, question, API_KEY, cot, HF_TOKEN):
     openai.api_key = API_KEY
 
     prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
@@ -68,6 +65,7 @@ def Inference(model_list, question, API_KEY, cot):
     def generate_answer(model, formatted_prompt):
         API_URL = endpoint_dict[model]["API_URL"]
         headers = endpoint_dict[model]["headers"]
+        headers["Authorization"] += HF_TOKEN
         payload = {
             "inputs": formatted_prompt,
             "parameters": {