Chunhua Liao committed
Commit · 39c2104 · 1 Parent(s): e07720a

milestone v1: works so far
Files changed:
- config.yaml +6 -1
- graph_visualizer.py +35 -0
- proposal-gen-v1.py +30 -4
config.yaml
CHANGED
@@ -13,7 +13,7 @@ openrouter_base_url: "https://openrouter.ai/api/v1"
 llm_model: "google/gemini-2.0-flash-001"
 
 # Number of hypotheses to generate
-num_hypotheses:
+num_hypotheses: 6
 
 # Elo K-factor
 elo_k_factor: 32
@@ -32,3 +32,8 @@ fastapi_host: "0.0.0.0"
 
 # FastAPI port
 fastapi_port: 8000
+
+# Temperature settings for each step
+step_temperatures:
+  generation: 0.7
+  reflection: 0.5
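For reference, a minimal sketch of how the new step_temperatures block reads back in Python (assuming PyYAML, the usual loader for a .yaml config like this; the snippet is illustrative and not part of the commit):

    import yaml

    # Load the same config.yaml this commit edits.
    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    # proposal-gen-v1.py (below) indexes these keys directly, so a
    # missing key surfaces as a KeyError.
    generation_temp = config["step_temperatures"]["generation"]  # 0.7
    reflection_temp = config["step_temperatures"]["reflection"]  # 0.5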
graph_visualizer.py
ADDED
@@ -0,0 +1,35 @@
+import networkx as nx
+import matplotlib.pyplot as plt
+
+def visualize_graph(adjacency_graph):
+    """Visualizes an adjacency graph using networkx and matplotlib."""
+
+    graph = nx.DiGraph()  # Use DiGraph for directed graph
+
+    for node, edges in adjacency_graph.items():
+        for edge in edges:
+            graph.add_edge(node, edge['other_id'], weight=edge['similarity'])
+
+    pos = nx.spring_layout(graph)  # Node positioning for visualization
+
+    # Get edge weights and normalize them
+    weights = [edge['weight'] for u, v, edge in graph.edges(data=True)]
+    normalized_weights = [(w - min(weights)) / (max(weights) - min(weights)) for w in weights]
+
+    # Create a color map
+    cmap = plt.cm.viridis
+
+    # Map normalized weights to colors
+    edge_colors = [cmap(w) for w in normalized_weights]
+
+    nx.draw(graph, pos, with_labels=True, node_size=1500, node_color="skyblue", font_size=10, edge_color=edge_colors)
+    edge_labels = nx.get_edge_attributes(graph, 'weight')
+    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8)
+
+
+    plt.title("Adjacency Graph")
+    plt.show()
+
+if __name__ == "__main__":
+    adjacency_graph = {"G3250":[{"other_id":"E2029","similarity":0.8886164931432184},{"other_id":"G4687","similarity":0.7799796722164661},{"other_id":"G2491","similarity":0.7358682896118928}],"E2029":[{"other_id":"G3250","similarity":0.589185898055919},{"other_id":"G4687","similarity":0.5547903202019775},{"other_id":"G2491","similarity":0.4763465778429552}],"G4687":[{"other_id":"G3250","similarity":0.8534847661087587},{"other_id":"E2029","similarity":0.382888810662511},{"other_id":"G2491","similarity":0.9591597530883424}],"G2491":[{"other_id":"G3250","similarity":0.11935305775711214},{"other_id":"E2029","similarity":0.3629634156202275},{"other_id":"G4687","similarity":0.810511185411589}]}
+    visualize_graph(adjacency_graph)
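Running python graph_visualizer.py plots the sample graph in a matplotlib window. One edge case the new file leaves open: the min-max normalization divides by max(weights) - min(weights), which raises ZeroDivisionError if every edge carries the same similarity, and the unrounded floats make for long edge labels. A minimal hedged variant of those two steps (the fallback value and rounding are illustrative choices, not part of the commit):

    def normalize(weights):
        """Min-max normalize, guarding against identical weights."""
        lo, hi = min(weights), max(weights)
        if hi == lo:
            # All similarities equal: fall back to a mid-scale color.
            return [0.5] * len(weights)
        return [(w - lo) / (hi - lo) for w in weights]

    def short_edge_labels(graph):
        """Round edge weights so the drawn labels stay readable."""
        return {(u, v): f"{d['weight']:.2f}" for u, v, d in graph.edges(data=True)}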
proposal-gen-v1.py
CHANGED
@@ -38,6 +38,9 @@ def load_config(config_path: str) -> Dict:
     except AttributeError as e:
         print("Error: Invalid logging level in config file")
         exit(1)
+    except KeyError as e:
+        print(f"Error: Missing key in config file: {e}")
+        exit(1)
 
 
 def setup_logger(log_filename):
@@ -57,10 +60,17 @@ def setup_logger(log_filename):
 # Load configuration at the start
 config = load_config("config.yaml")
 
-def call_llm(prompt: str) -> str:
+def call_llm(prompt: str, temperature: float = 0.7) -> str:
     """
     Calls an LLM via the OpenRouter API and returns the response.
 
+    Args:
+        prompt (str): The input prompt for the LLM.
+        temperature (float, optional): The temperature setting for the LLM. Defaults to 0.7.
+
+    Returns:
+        str: The LLM's response.
+
     Args:
         prompt (str): The input prompt for the LLM.
 
@@ -76,6 +86,7 @@ def call_llm(prompt: str) -> str:
         completion = client.chat.completions.create(
             model=config["llm_model"],
             messages=[{"role": "user", "content": prompt}],
+            temperature=temperature,  # Pass temperature to the API call
         )
     except Exception as e:
         retries = config.get("max_retries", 3)
@@ -96,6 +107,7 @@ def call_llm(prompt: str) -> str:
             completion = client.chat.completions.create(
                 model=config["llm_model"],
                 messages=[{"role": "user", "content": prompt}],
+                temperature=temperature,  # Pass temperature to the API call
             )
             if completion.choices and len(completion.choices) > 0:
                 return completion.choices[0].message.content
@@ -291,7 +303,8 @@ def call_llm_for_generation(prompt: str, num_hypotheses: int = 3) -> List[Dict]:
     # Modify the prompt to request JSON output
     prompt += "\n\nPlease return the response as a JSON array of objects, where each object has a 'title' and 'text' key."
 
-    response = call_llm(prompt)
+    # Call LLM with the appropriate temperature
+    response = call_llm(prompt, temperature=config["step_temperatures"]["generation"])
     logger.info("LLM response: %s", response)
 
     if "API call failed" in response:
@@ -341,7 +354,8 @@ def call_llm_for_reflection(hypothesis_text: str) -> Dict:
         f"Return the response as a JSON object with the following keys: 'novelty_review', 'feasibility_review', 'comment', 'references'."
 
     )
-    response = call_llm(prompt)
+    # Call LLM with the appropriate temperature
+    response = call_llm(prompt, temperature=config["step_temperatures"]["reflection"])
     logger.info("LLM reflection for hypothesis: %s, response: %s", hypothesis_text, response)
 
     if "API call failed" in response:
@@ -422,7 +436,19 @@ def run_pairwise_debate(hypoA: Hypothesis, hypoB: Hypothesis) -> Hypothesis:
         the novelty and feasibility scores.
         """
         mapping = {"HIGH": 3, "MEDIUM": 2, "LOW": 1, None: 0}
-        return mapping.get(h.novelty_review, 0) + mapping.get(h.feasibility_review, 0)
+        score_novelty = 0
+        if isinstance(h.novelty_review, str):
+            score_novelty = mapping.get(h.novelty_review, 0)
+        else:
+            logger.error(f"Invalid novelty_review type: {type(h.novelty_review)}, value: {h.novelty_review}")
+
+        score_feasibility = 0
+        if isinstance(h.feasibility_review, str):
+            score_feasibility = mapping.get(h.feasibility_review, 0)
+        else:
+            logger.error(f"Invalid feasibility_review type: {type(h.feasibility_review)}, value: {h.feasibility_review}")
+
+        return score_novelty + score_feasibility
     scoreA = score(hypoA)
     scoreB = score(hypoB)
     winner = hypoA if scoreA > scoreB else hypoB if scoreB > scoreA else random.choice([hypoA, hypoB])
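Note that both call sites index config["step_temperatures"] directly, so removing the new block from config.yaml would surface as a KeyError at the call site, outside the KeyError handler this commit adds to load_config. A hedged defensive variant of the lookup (illustrative only; the commit itself uses direct indexing):

    # Fall back to call_llm's default temperature of 0.7 when the
    # config block or key is absent.
    temperature = config.get("step_temperatures", {}).get("generation", 0.7)
    response = call_llm(prompt, temperature=temperature)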