Spaces:
Sleeping
Sleeping
Commit
·
5b531fe
1
Parent(s):
1d16515
Update model_inference.py
Browse files- model_inference.py +2 -4
model_inference.py
CHANGED
@@ -57,10 +57,7 @@ def generate_question(agents, question):
|
|
57 |
|
58 |
return agent_contexts, content
|
59 |
|
60 |
-
def Inference(model_list, question, API_KEY, cot):
|
61 |
-
if len(model_list) != 3:
|
62 |
-
raise ValueError("Please choose just '3' models! Neither more nor less!")
|
63 |
-
|
64 |
openai.api_key = API_KEY
|
65 |
|
66 |
prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
|
@@ -68,6 +65,7 @@ def Inference(model_list, question, API_KEY, cot):
|
|
68 |
def generate_answer(model, formatted_prompt):
|
69 |
API_URL = endpoint_dict[model]["API_URL"]
|
70 |
headers = endpoint_dict[model]["headers"]
|
|
|
71 |
payload = {
|
72 |
"inputs": formatted_prompt,
|
73 |
"parameters": {
|
|
|
57 |
|
58 |
return agent_contexts, content
|
59 |
|
60 |
+
def Inference(model_list, question, API_KEY, cot, HF_TOKEN):
|
|
|
|
|
|
|
61 |
openai.api_key = API_KEY
|
62 |
|
63 |
prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
|
|
|
65 |
def generate_answer(model, formatted_prompt):
|
66 |
API_URL = endpoint_dict[model]["API_URL"]
|
67 |
headers = endpoint_dict[model]["headers"]
|
68 |
+
headers["Authorization"] += HF_TOKEN
|
69 |
payload = {
|
70 |
"inputs": formatted_prompt,
|
71 |
"parameters": {
|