Cartinoe5930 commited on
Commit
c366528
1 Parent(s): bfead6e

Update model_inference.py

Browse files
Files changed (1) hide show
  1. model_inference.py +55 -40
model_inference.py CHANGED
@@ -13,32 +13,42 @@ def load_json(prompt_path, endpoint_path):
13
 
14
  return prompt_dict, endpoint_dict
15
 
16
- def construct_message(agents, instruction, idx):
17
- if len(agents) == 0:
18
- prompt = "Can you double check that your answer is correct. Please reiterate your answer, making sure to state your answer at the end of the response."
19
- return prompt
20
-
21
- contexts = [agents[0][idx]['content'], agents[1][idx]['content'], agents[2][idx]['content']]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # system prompt & user prompt for gpt-3.5-turbo
24
- sys_prompt = f"I want you to act as a summarizer. You can look at multiple responses and summarize the main points of them so that the meaning is not lost. Multiple responses will be given, which are responses from several different models to a single question. And you should use your excellent summarizing skills to output the best summary."
25
- summarize_prompt = f"[Response 1]: {contexts[0]}\n[Response 2]: {contexts[1]}\nResponse 3: {contexts[2]}\n\nThese are response of each model to a certain question. Summarize comprehensively without compromising the meaning of each response."
26
 
27
- message = [
28
- {"role": "system", "content": sys_prompt},
29
- {"role": "user", "content": summarize_prompt},
30
- ]
31
 
32
- completion = openai.ChatCompletion.create(
33
- model="gpt-3.5-turbo-16k-0613",
34
- messages=message,
35
- max_tokens=256,
36
- n=1
37
- )
38
 
39
- prefix_string = f"This is the summarization of recent/updated opinions from other agents: {completion}"
40
- prefix_string = prefix_string + "\n\n Use this summarization carefully as additional advice, can you provide an updated answer? Make sure to state your answer at the end of the response." + instruction
41
- return prefix_string
 
42
 
43
  def generate_question(agents, question):
44
  agent_contexts = [[{"model": agent, "content": question}] for agent in agents]
@@ -47,7 +57,7 @@ def generate_question(agents, question):
47
 
48
  return agent_contexts, content
49
 
50
- def Inference(model_list, question, API_KEY, auth_token, round, cot):
51
  if len(model_list) != 3:
52
  raise ValueError("Please choose just '3' models! Neither more nor less!")
53
 
@@ -58,16 +68,21 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):
58
  def generate_answer(model, formatted_prompt):
59
  API_URL = endpoint_dict[model]["API_URL"]
60
  headers = endpoint_dict[model]["headers"]
61
- payload = {"inputs": formatted_prompt}
 
 
 
 
 
62
  try:
63
  resp = requests.post(API_URL, json=payload, headers=headers)
64
  response = resp.json()
65
  except:
66
  print("retrying due to an error......")
67
  time.sleep(5)
68
- return generate_answer(API_URL, headers, payload)
69
 
70
- return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
71
 
72
  def prompt_formatting(model, instruction, cot):
73
  if model == "alpaca" or model == "orca":
@@ -77,37 +92,37 @@ def Inference(model_list, question, API_KEY, auth_token, round, cot):
77
 
78
  if cot:
79
  instruction += "Let's think step by step."
80
-
81
- return {"model": model, "content": prompt.format(instruction)}
82
-
83
  agents = len(model_list)
84
- rounds = round
85
 
86
- generated_description = []
87
 
88
- agent_contexts, content = generate_question(agents=model_list, question=question)
89
 
90
  # Debate
91
  for debate in range(rounds+1):
92
  # Refer to the summarized previous response
93
  if debate != 0:
94
- message = construct_message(agent_contexts, content, 2 * debate - 1)
95
- for i in range(agent_contexts):
96
  agent_contexts[i].append(prompt_formatting(agent_contexts[i][-1]["model"], message, args.cot))
97
 
98
  # Generate new response based on summarized response
99
  for agent_context in agent_contexts:
100
- completion = generate_answer(agent_context[-1]["model"], agent_context[-1]["content"] if debate != 0 else prompt_formatting(agent_context[-1]["model"], agent_context[-1]["content"], args.cot)["content"])
101
  agent_context.append(completion)
102
 
103
  models_response = {
104
- f"{args.m1}": [agent_contexts[0][1]["content"], agent_contexts[0][3]["content"], agent_contexts[0][-1]["content"]],
105
- f"{args.m2}": [agent_contexts[1][1]["content"], agent_contexts[1][3]["content"], agent_contexts[1][-1]["content"]],
106
- f"{args.m3}": [agent_contexts[2][1]["content"], agent_contexts[2][3]["content"], agent_contexts[2][-1]["content"]]
107
  }
108
  response_summarization = [
109
- agent_contexts[0][2], agent_contexts[0][4]
110
  ]
111
- generated_description.append({"question": content, "agent_response": models_response, "summarization": response_summarization})
112
 
113
  return generated_description
 
13
 
14
  return prompt_dict, endpoint_dict
15
 
16
+ def construct_message(agent_context, instruction, idx):
17
+ prefix_string = "Here are a list of opinions from different agents: "
18
+
19
+ prefix_string = prefix_string + agent_context + "\n\n Write a summary of the different opinions from each of the individual agent."
20
+
21
+ message = [{"role": "user", "content": prefix_string}]
22
+
23
+ try:
24
+ completion = openai.ChatCompletion.create(
25
+ model="gpt-3.5-turbo-0613",
26
+ messages=message,
27
+ max_tokens=256,
28
+ n=1
29
+ )['choices'][0]['message']['content']
30
+ except:
31
+ print("retrying ChatGPT due to an error......")
32
+ time.sleep(5)
33
+ return construct_message(agent_context, instruction, idx)
34
+
35
+ prefix_string = f"Here is a summary of responses from other agents: {completion}"
36
+ prefix_string = prefix_string + "\n\n Use this summarization carefully as additional advice, can you provide an updated answer? Make sure to state your answer at the end of the response." + instruction
37
+ return prefix_string
38
 
39
+ def summarize_message(agent_contexts, instruction, idx):
40
+ prefix_string = "Here are a list of opinions from different agents: "
 
41
 
42
+ for agent in agent_contexts:
43
+ agent_response = agent[-1]["content"]
44
+ response = "\n\n One agent response: ```{}```".format(agent_response)
 
45
 
46
+ prefix_string = prefix_string + response
 
 
 
 
 
47
 
48
+ prefix_string = prefix_string + "\n\n Write a summary of the different opinions from each of the individual agent."
49
+ completion = construct_message(prefix_string, instruction, idx)
50
+
51
+ return completion
52
 
53
  def generate_question(agents, question):
54
  agent_contexts = [[{"model": agent, "content": question}] for agent in agents]
 
57
 
58
  return agent_contexts, content
59
 
60
+ def Inference(model_list, question, API_KEY, cot):
61
  if len(model_list) != 3:
62
  raise ValueError("Please choose just '3' models! Neither more nor less!")
63
 
 
68
  def generate_answer(model, formatted_prompt):
69
  API_URL = endpoint_dict[model]["API_URL"]
70
  headers = endpoint_dict[model]["headers"]
71
+ payload = {
72
+ "inputs": formatted_prompt,
73
+ "parameters": {
74
+ "max_new_tokens": 256
75
+ }
76
+ }
77
  try:
78
  resp = requests.post(API_URL, json=payload, headers=headers)
79
  response = resp.json()
80
  except:
81
  print("retrying due to an error......")
82
  time.sleep(5)
83
+ return generate_answer(model, formatted_prompt)
84
 
85
+ return {"model": model, "content": response[0]["generated_text"]}
86
 
87
  def prompt_formatting(model, instruction, cot):
88
  if model == "alpaca" or model == "orca":
 
92
 
93
  if cot:
94
  instruction += "Let's think step by step."
95
+
96
+ return {"model": model, "content": prompt.format(instruction=instruction)}
97
+
98
  agents = len(model_list)
99
+ rounds = 2
100
 
101
+ agent_contexts, content = generate_question(agents=model_list, question=args.question)
102
 
103
+ message = []
104
 
105
  # Debate
106
  for debate in range(rounds+1):
107
  # Refer to the summarized previous response
108
  if debate != 0:
109
+ message.append(summarize_message(agent_contexts, content, 2 * debate - 1))
110
+ for i in range(len(agent_contexts)):
111
  agent_contexts[i].append(prompt_formatting(agent_contexts[i][-1]["model"], message, args.cot))
112
 
113
  # Generate new response based on summarized response
114
  for agent_context in agent_contexts:
115
+ completion = generate_answer(agent_context[-1]["model"], agent_context[-1]["content"])
116
  agent_context.append(completion)
117
 
118
  models_response = {
119
+ f"{model_list[0]}": [agent_contexts[0][1]["content"], agent_contexts[0][3]["content"], agent_contexts[0][-1]["content"]],
120
+ f"{model_list[1]}": [agent_contexts[1][1]["content"], agent_contexts[1][3]["content"], agent_contexts[1][-1]["content"]],
121
+ f"{model_list[2]}": [agent_contexts[2][1]["content"], agent_contexts[2][3]["content"], agent_contexts[2][-1]["content"]]
122
  }
123
  response_summarization = [
124
+ message[0], message[1]
125
  ]
126
+ generated_description = {"question": content, "agent_response": models_response, "summarization": response_summarization})
127
 
128
  return generated_description