dmytromishkin commited on
Commit
f7be0c9
·
verified ·
1 Parent(s): b00eea2

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +8 -10
script.py CHANGED
@@ -15,11 +15,10 @@ from joblib import Parallel, delayed
15
  import os
16
  import json
17
  import gc
18
-
19
  from hoho2025.example_solutions import predict_wireframe
20
  # check the https://github.com/s23dr/hoho2025/blob/main/hoho2025/example_solutions.py for the example solution
21
 
22
- def empty_solution(sample):
23
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
24
  return np.zeros((2,3)), [(0, 1)]
25
 
@@ -105,7 +104,7 @@ if __name__ == "__main__":
105
  try:
106
  pred_vertices, pred_edges = predict_wireframe(sample)
107
  except:
108
- pred_vertices, pred_edges = empty_solution(sample)
109
  if i %10 == 0:
110
  gc.collect()
111
  return {
@@ -114,13 +113,12 @@ if __name__ == "__main__":
114
  'wf_edges': pred_edges
115
  }
116
  num_cores = 4
117
- for subset_name in dataset:
118
- samples = list(dataset[subset_name])
119
- # Process samples in parallel with simple tqdm progress tracking
120
- results = Parallel(n_jobs=num_cores)(
121
- delayed(process_sample)(sample, i) for sample in tqdm(samples, desc=f"Processing {subset_name}")
122
- )
123
- solution.extend(results)
124
 
125
  print('------------ Saving results ---------------')
126
  sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
 
15
  import os
16
  import json
17
  import gc
 
18
  from hoho2025.example_solutions import predict_wireframe
19
  # check the https://github.com/s23dr/hoho2025/blob/main/hoho2025/example_solutions.py for the example solution
20
 
21
+ def empty_solution():
22
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
23
  return np.zeros((2,3)), [(0, 1)]
24
 
 
104
  try:
105
  pred_vertices, pred_edges = predict_wireframe(sample)
106
  except:
107
+ pred_vertices, pred_edges = empty_solution()
108
  if i %10 == 0:
109
  gc.collect()
110
  return {
 
113
  'wf_edges': pred_edges
114
  }
115
  num_cores = 4
116
+
117
+ for subset_name in dataset.keys():
118
+ print (f"Predicting {subset_name}")
119
+ for i, sample in enumerate(tqdm(dataset[subset_name])):
120
+ res = process_sample(sample, i)
121
+ solution.append(res)
 
122
 
123
  print('------------ Saving results ---------------')
124
  sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])