Spaces: Runtime error
Cartinoe5930 committed
Commit c366528 • 1 Parent(s): bfead6e
Update model_inference.py
Browse files: model_inference.py +55 -40

model_inference.py CHANGED
@@ -13,32 +13,42 @@ def load_json(prompt_path, endpoint_path):

    return prompt_dict, endpoint_dict

-def construct_message(
-    summarize_prompt = f"[Response 1]: {contexts[0]}\n[Response 2]: {contexts[1]}\nResponse 3: {contexts[2]}\n\nThese are response of each model to a certain question. Summarize comprehensively without compromising the meaning of each response."
-        model="gpt-3.5-turbo-16k-0613",
-        messages=message,
-        max_tokens=256,
-        n=1
-    )
-    prefix_string =
+def construct_message(agent_context, instruction, idx):
+    prefix_string = "Here are a list of opinions from different agents: "
+
+    prefix_string = prefix_string + agent_context + "\n\n Write a summary of the different opinions from each of the individual agent."
+
+    message = [{"role": "user", "content": prefix_string}]
+
+    try:
+        completion = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo-0613",
+            messages=message,
+            max_tokens=256,
+            n=1
+        )['choices'][0]['message']['content']
+    except:
+        print("retrying ChatGPT due to an error......")
+        time.sleep(5)
+        return construct_message(agent_context, instruction, idx)
+
+    prefix_string = f"Here is a summary of responses from other agents: {completion}"
+    prefix_string = prefix_string + "\n\n Use this summarization carefully as additional advice, can you provide an updated answer? Make sure to state your answer at the end of the response." + instruction
+    return prefix_string

+def summarize_message(agent_contexts, instruction, idx):
+    prefix_string = "Here are a list of opinions from different agents: "
+
+    for agent in agent_contexts:
+        agent_response = agent[-1]["content"]
+        response = "\n\n One agent response: ```{}```".format(agent_response)
+
+        prefix_string = prefix_string + response
+
+    prefix_string = prefix_string + "\n\n Write a summary of the different opinions from each of the individual agent."
+    completion = construct_message(prefix_string, instruction, idx)
+
+    return completion

 def generate_question(agents, question):
     agent_contexts = [[{"model": agent, "content": question}] for agent in agents]
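For context, the two helpers added in this hunk work as a pair: summarize_message concatenates the latest reply of every agent, and construct_message asks gpt-3.5-turbo-0613 to summarize that text and then wraps the summary in the "use this summarization carefully as additional advice" instruction sent back to each agent. Below is a minimal usage sketch, not part of the commit; the agent contexts, the import, and the API key are illustrative assumptions (the diff does not show where openai.api_key is set).

# Illustrative sketch only; placeholder values throughout.
import openai
from model_inference import summarize_message  # assumes the module imports cleanly on its own

openai.api_key = "sk-..."  # placeholder OpenAI key

toy_contexts = [
    [{"model": "alpaca", "content": "I think the answer is 4."}],
    [{"model": "orca", "content": "The answer should be 4."}],
    [{"model": "third_model", "content": "4."}],  # "third_model" is a placeholder name
]

# Returns the next-round instruction: a ChatGPT summary of the three replies, the
# "provide an updated answer" request, and the original question appended at the end.
next_prompt = summarize_message(toy_contexts, "What is 2 + 2?", idx=1)
print(next_prompt)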
@@ -47,7 +57,7 @@ def generate_question(agents, question):

    return agent_contexts, content

-def Inference(model_list, question, API_KEY, auth_token, round, cot):
+def Inference(model_list, question, API_KEY, cot):
    if len(model_list) != 3:
        raise ValueError("Please choose just '3' models! Neither more nor less!")

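The updated entry point drops the auth_token and round parameters, keeping only the model list, the question, the API key, and the chain-of-thought flag. A hypothetical call under the new signature might look like this (model names and key are placeholders, and exactly three models are required or the ValueError above is raised):

# Hypothetical invocation; not part of the commit.
result = Inference(
    model_list=["alpaca", "orca", "third_model"],  # placeholders; must be keys of endpoint_dict
    question="What is the capital of France?",
    API_KEY="sk-...",                              # placeholder API key
    cot=True,                                      # appends "Let's think step by step." to each prompt
)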
@@ -58,16 +68,21 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):
    def generate_answer(model, formatted_prompt):
        API_URL = endpoint_dict[model]["API_URL"]
        headers = endpoint_dict[model]["headers"]
-        payload = {
+        payload = {
+            "inputs": formatted_prompt,
+            "parameters": {
+                "max_new_tokens": 256
+            }
+        }
        try:
            resp = requests.post(API_URL, json=payload, headers=headers)
            response = resp.json()
        except:
            print("retrying due to an error......")
            time.sleep(5)
-            return generate_answer(
+            return generate_answer(model, formatted_prompt)

-        return {"model": model, "content": response[0]["generated_text"]
+        return {"model": model, "content": response[0]["generated_text"]}

    def prompt_formatting(model, instruction, cot):
        if model == "alpaca" or model == "orca":
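The rebuilt payload follows the Hugging Face Inference API text-generation format ("inputs" plus "parameters", with the reply read back from response[0]["generated_text"]), and generate_answer now retries itself with the same arguments whenever the request or the JSON decoding fails. The endpoint_dict entries loaded by load_json are therefore expected to look roughly like the sketch below; the URL and token are placeholders, since the endpoint file itself is not part of this diff.

# Assumed shape of one endpoint_dict entry (illustrative placeholders only):
endpoint_dict = {
    "alpaca": {
        "API_URL": "https://api-inference.huggingface.co/models/<owner>/<model>",
        "headers": {"Authorization": "Bearer hf_..."},
    }
}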
@@ -77,37 +92,37 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):

        if cot:
            instruction += "Let's think step by step."

-        return {"model": model, "content": prompt.format(instruction)}
+        return {"model": model, "content": prompt.format(instruction=instruction)}

    agents = len(model_list)
-    rounds =
+    rounds = 2

+    agent_contexts, content = generate_question(agents=model_list, question=args.question)

+    message = []

    # Debate
    for debate in range(rounds+1):
        # Refer to the summarized previous response
        if debate != 0:
+            message.append(summarize_message(agent_contexts, content, 2 * debate - 1))
-            for i in range(agent_contexts):
+            for i in range(len(agent_contexts)):
                agent_contexts[i].append(prompt_formatting(agent_contexts[i][-1]["model"], message, args.cot))

        # Generate new response based on summarized response
        for agent_context in agent_contexts:
-            completion = generate_answer(agent_context[-1]["model"], agent_context[-1]["content"]
+            completion = generate_answer(agent_context[-1]["model"], agent_context[-1]["content"])
            agent_context.append(completion)

    models_response = {
+        f"{model_list[0]}": [agent_contexts[0][1]["content"], agent_contexts[0][3]["content"], agent_contexts[0][-1]["content"]],
+        f"{model_list[1]}": [agent_contexts[1][1]["content"], agent_contexts[1][3]["content"], agent_contexts[1][-1]["content"]],
+        f"{model_list[2]}": [agent_contexts[2][1]["content"], agent_contexts[2][3]["content"], agent_contexts[2][-1]["content"]]
    }
    response_summarization = [
+        message[0], message[1]
    ]
-    generated_description
+    generated_description = {"question": content, "agent_response": models_response, "summarization": response_summarization}

    return generated_description
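With rounds = 2, the debate loop runs three times: the opening pass collects each model's first answer, and each later pass first appends a summarized instruction from summarize_message and then a fresh answer. Each agent_contexts[i] therefore ends up as [question, answer_0, prompt_1, answer_1, prompt_2, answer_2], which is why models_response reads indices 1, 3, and -1 and why response_summarization keeps exactly the two entries of message. Note that the committed loop still reads args.question and args.cot rather than the question and cot parameters, so it assumes an args object at module scope. The returned structure, with illustrative values, looks roughly like:

# Illustrative shape of generated_description (values abbreviated; keys follow the code above):
generated_description = {
    "question": "What is 2 + 2?",
    "agent_response": {
        "alpaca": ["answer from round 0", "answer from round 1", "answer from round 2"],
        "orca": ["...", "...", "..."],
        "third_model": ["...", "...", "..."],
    },
    "summarization": ["summary before round 1", "summary before round 2"],
}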