Update script.py
Browse files
script.py
CHANGED
@@ -15,11 +15,10 @@ from joblib import Parallel, delayed
|
|
15 |
import os
|
16 |
import json
|
17 |
import gc
|
18 |
-
|
19 |
from hoho2025.example_solutions import predict_wireframe
|
20 |
# check the https://github.com/s23dr/hoho2025/blob/main/hoho2025/example_solutions.py for the example solution
|
21 |
|
22 |
-
def empty_solution(
|
23 |
'''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
|
24 |
return np.zeros((2,3)), [(0, 1)]
|
25 |
|
@@ -105,7 +104,7 @@ if __name__ == "__main__":
|
|
105 |
try:
|
106 |
pred_vertices, pred_edges = predict_wireframe(sample)
|
107 |
except:
|
108 |
-
pred_vertices, pred_edges = empty_solution(
|
109 |
if i %10 == 0:
|
110 |
gc.collect()
|
111 |
return {
|
@@ -114,13 +113,12 @@ if __name__ == "__main__":
|
|
114 |
'wf_edges': pred_edges
|
115 |
}
|
116 |
num_cores = 4
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
solution.extend(results)
|
124 |
|
125 |
print('------------ Saving results ---------------')
|
126 |
sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
|
|
|
15 |
import os
|
16 |
import json
|
17 |
import gc
|
|
|
18 |
from hoho2025.example_solutions import predict_wireframe
|
19 |
# check the https://github.com/s23dr/hoho2025/blob/main/hoho2025/example_solutions.py for the example solution
|
20 |
|
21 |
+
def empty_solution():
|
22 |
'''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
|
23 |
return np.zeros((2,3)), [(0, 1)]
|
24 |
|
|
|
104 |
try:
|
105 |
pred_vertices, pred_edges = predict_wireframe(sample)
|
106 |
except:
|
107 |
+
pred_vertices, pred_edges = empty_solution()
|
108 |
if i %10 == 0:
|
109 |
gc.collect()
|
110 |
return {
|
|
|
113 |
'wf_edges': pred_edges
|
114 |
}
|
115 |
num_cores = 4
|
116 |
+
|
117 |
+
for subset_name in dataset.keys():
|
118 |
+
print (f"Predicting {subset_name}")
|
119 |
+
for i, sample in enumerate(tqdm(dataset[subset_name])):
|
120 |
+
res = process_sample(sample, i)
|
121 |
+
solution.append(res)
|
|
|
122 |
|
123 |
print('------------ Saving results ---------------')
|
124 |
sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
|