Chunhua Liao commited on
Commit
39c2104
·
1 Parent(s): e07720a

milestone v1: works so far

Browse files
Files changed (3) hide show
  1. config.yaml +6 -1
  2. graph_visualizer.py +35 -0
  3. proposal-gen-v1.py +30 -4
config.yaml CHANGED
@@ -13,7 +13,7 @@ openrouter_base_url: "https://openrouter.ai/api/v1"
13
  llm_model: "google/gemini-2.0-flash-001"
14
 
15
  # Number of hypotheses to generate
16
- num_hypotheses: 3
17
 
18
  # Elo K-factor
19
  elo_k_factor: 32
@@ -32,3 +32,8 @@ fastapi_host: "0.0.0.0"
32
 
33
  # FastAPI port
34
  fastapi_port: 8000
 
 
 
 
 
 
13
  llm_model: "google/gemini-2.0-flash-001"
14
 
15
  # Number of hypotheses to generate
16
+ num_hypotheses: 6
17
 
18
  # Elo K-factor
19
  elo_k_factor: 32
 
32
 
33
  # FastAPI port
34
  fastapi_port: 8000
35
+
36
+ # Temperature settings for each step
37
+ step_temperatures:
38
+ generation: 0.7
39
+ reflection: 0.5
graph_visualizer.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ import matplotlib.pyplot as plt
3
+
4
+ def visualize_graph(adjacency_graph):
5
+ """Visualizes an adjacency graph using networkx and matplotlib."""
6
+
7
+ graph = nx.DiGraph() # Use DiGraph for directed graph
8
+
9
+ for node, edges in adjacency_graph.items():
10
+ for edge in edges:
11
+ graph.add_edge(node, edge['other_id'], weight=edge['similarity'])
12
+
13
+ pos = nx.spring_layout(graph) # Node positioning for visualization
14
+
15
+ # Get edge weights and normalize them
16
+ weights = [edge['weight'] for u, v, edge in graph.edges(data=True)]
17
+ normalized_weights = [(w - min(weights)) / (max(weights) - min(weights)) for w in weights]
18
+
19
+ # Create a color map
20
+ cmap = plt.cm.viridis
21
+
22
+ # Map normalized weights to colors
23
+ edge_colors = [cmap(w) for w in normalized_weights]
24
+
25
+ nx.draw(graph, pos, with_labels=True, node_size=1500, node_color="skyblue", font_size=10, edge_color=edge_colors)
26
+ edge_labels = nx.get_edge_attributes(graph, 'weight')
27
+ nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8)
28
+
29
+
30
+ plt.title("Adjacency Graph")
31
+ plt.show()
32
+
33
+ if __name__ == "__main__":
34
+ adjacency_graph = {"G3250":[{"other_id":"E2029","similarity":0.8886164931432184},{"other_id":"G4687","similarity":0.7799796722164661},{"other_id":"G2491","similarity":0.7358682896118928}],"E2029":[{"other_id":"G3250","similarity":0.589185898055919},{"other_id":"G4687","similarity":0.5547903202019775},{"other_id":"G2491","similarity":0.4763465778429552}],"G4687":[{"other_id":"G3250","similarity":0.8534847661087587},{"other_id":"E2029","similarity":0.382888810662511},{"other_id":"G2491","similarity":0.9591597530883424}],"G2491":[{"other_id":"G3250","similarity":0.11935305775711214},{"other_id":"E2029","similarity":0.3629634156202275},{"other_id":"G4687","similarity":0.810511185411589}]}
35
+ visualize_graph(adjacency_graph)
proposal-gen-v1.py CHANGED
@@ -38,6 +38,9 @@ def load_config(config_path: str) -> Dict:
38
  except AttributeError as e:
39
  print("Error: Invalid logging level in config file")
40
  exit(1)
 
 
 
41
 
42
 
43
  def setup_logger(log_filename):
@@ -57,10 +60,17 @@ def setup_logger(log_filename):
57
  # Load configuration at the start
58
  config = load_config("config.yaml")
59
 
60
- def call_llm(prompt: str) -> str:
61
  """
62
  Calls an LLM via the OpenRouter API and returns the response.
63
 
 
 
 
 
 
 
 
64
  Args:
65
  prompt (str): The input prompt for the LLM.
66
 
@@ -76,6 +86,7 @@ def call_llm(prompt: str) -> str:
76
  completion = client.chat.completions.create(
77
  model=config["llm_model"],
78
  messages=[{"role": "user", "content": prompt}],
 
79
  )
80
  except Exception as e:
81
  retries = config.get("max_retries", 3)
@@ -96,6 +107,7 @@ def call_llm(prompt: str) -> str:
96
  completion = client.chat.completions.create(
97
  model=config["llm_model"],
98
  messages=[{"role": "user", "content": prompt}],
 
99
  )
100
  if completion.choices and len(completion.choices) > 0:
101
  return completion.choices[0].message.content
@@ -291,7 +303,8 @@ def call_llm_for_generation(prompt: str, num_hypotheses: int = 3) -> List[Dict]:
291
  # Modify the prompt to request JSON output
292
  prompt += "\n\nPlease return the response as a JSON array of objects, where each object has a 'title' and 'text' key."
293
 
294
- response = call_llm(prompt)
 
295
  logger.info("LLM response: %s", response)
296
 
297
  if "API call failed" in response:
@@ -341,7 +354,8 @@ def call_llm_for_reflection(hypothesis_text: str) -> Dict:
341
  f"Return the response as a JSON object with the following keys: 'novelty_review', 'feasibility_review', 'comment', 'references'."
342
 
343
  )
344
- response = call_llm(prompt)
 
345
  logger.info("LLM reflection for hypothesis: %s, response: %s", hypothesis_text, response)
346
 
347
  if "API call failed" in response:
@@ -422,7 +436,19 @@ def run_pairwise_debate(hypoA: Hypothesis, hypoB: Hypothesis) -> Hypothesis:
422
  the novelty and feasibility scores.
423
  """
424
  mapping = {"HIGH": 3, "MEDIUM": 2, "LOW": 1, None: 0}
425
- return mapping.get(h.novelty_review, 0) + mapping.get(h.feasibility_review, 0)
 
 
 
 
 
 
 
 
 
 
 
 
426
  scoreA = score(hypoA)
427
  scoreB = score(hypoB)
428
  winner = hypoA if scoreA > scoreB else hypoB if scoreB > scoreA else random.choice([hypoA, hypoB])
 
38
  except AttributeError as e:
39
  print("Error: Invalid logging level in config file")
40
  exit(1)
41
+ except KeyError as e:
42
+ print(f"Error: Missing key in config file: {e}")
43
+ exit(1)
44
 
45
 
46
  def setup_logger(log_filename):
 
60
  # Load configuration at the start
61
  config = load_config("config.yaml")
62
 
63
+ def call_llm(prompt: str, temperature: float = 0.7) -> str:
64
  """
65
  Calls an LLM via the OpenRouter API and returns the response.
66
 
67
+ Args:
68
+ prompt (str): The input prompt for the LLM.
69
+ temperature (float, optional): The temperature setting for the LLM. Defaults to 0.7.
70
+
71
+ Returns:
72
+ str: The LLM's response.
73
+
74
  Args:
75
  prompt (str): The input prompt for the LLM.
76
 
 
86
  completion = client.chat.completions.create(
87
  model=config["llm_model"],
88
  messages=[{"role": "user", "content": prompt}],
89
+ temperature=temperature, # Pass temperature to the API call
90
  )
91
  except Exception as e:
92
  retries = config.get("max_retries", 3)
 
107
  completion = client.chat.completions.create(
108
  model=config["llm_model"],
109
  messages=[{"role": "user", "content": prompt}],
110
+ temperature=temperature, # Pass temperature to the API call
111
  )
112
  if completion.choices and len(completion.choices) > 0:
113
  return completion.choices[0].message.content
 
303
  # Modify the prompt to request JSON output
304
  prompt += "\n\nPlease return the response as a JSON array of objects, where each object has a 'title' and 'text' key."
305
 
306
+ # Call LLM with the appropriate temperature
307
+ response = call_llm(prompt, temperature=config["step_temperatures"]["generation"])
308
  logger.info("LLM response: %s", response)
309
 
310
  if "API call failed" in response:
 
354
  f"Return the response as a JSON object with the following keys: 'novelty_review', 'feasibility_review', 'comment', 'references'."
355
 
356
  )
357
+ # Call LLM with the appropriate temperature
358
+ response = call_llm(prompt, temperature=config["step_temperatures"]["reflection"])
359
  logger.info("LLM reflection for hypothesis: %s, response: %s", hypothesis_text, response)
360
 
361
  if "API call failed" in response:
 
436
  the novelty and feasibility scores.
437
  """
438
  mapping = {"HIGH": 3, "MEDIUM": 2, "LOW": 1, None: 0}
439
+ score_novelty = 0
440
+ if isinstance(h.novelty_review, str):
441
+ score_novelty = mapping.get(h.novelty_review, 0)
442
+ else:
443
+ logger.error(f"Invalid novelty_review type: {type(h.novelty_review)}, value: {h.novelty_review}")
444
+
445
+ score_feasibility = 0
446
+ if isinstance(h.feasibility_review, str):
447
+ score_feasibility = mapping.get(h.feasibility_review, 0)
448
+ else:
449
+ logger.error(f"Invalid feasibility_review type: {type(h.feasibility_review)}, value: {h.feasibility_review}")
450
+
451
+ return score_novelty + score_feasibility
452
  scoreA = score(hypoA)
453
  scoreB = score(hypoB)
454
  winner = hypoA if scoreA > scoreB else hypoB if scoreB > scoreA else random.choice([hypoA, hypoB])