Update script.py
Browse files
script.py
CHANGED
@@ -11,6 +11,7 @@ import pandas as pd
|
|
11 |
import numpy as np
|
12 |
from datasets import load_dataset
|
13 |
from typing import Dict
|
|
|
14 |
|
15 |
def empty_solution(sample):
|
16 |
'''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
|
@@ -48,7 +49,6 @@ if __name__ == "__main__":
|
|
48 |
print(os.system('ls -lahtr /tmp/data/'))
|
49 |
print('/tmp/data/data')
|
50 |
print(os.system('ls -lahtrR /tmp/data/data'))
|
51 |
-
|
52 |
|
53 |
data_path_test_server = Path('/tmp/data')
|
54 |
data_path_local = Path().home() / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'
|
@@ -66,15 +66,7 @@ if __name__ == "__main__":
|
|
66 |
repo_type="dataset",
|
67 |
)
|
68 |
data_path = data_path_test_server
|
69 |
-
|
70 |
-
|
71 |
print(data_path)
|
72 |
-
|
73 |
-
# dataset = load_dataset(params['dataset'], trust_remote_code=True, use_auth_token=params['token'])
|
74 |
-
# data_files = {
|
75 |
-
# "validation": [str(p) for p in [*data_path.rglob('*validation*.arrow')]+[*data_path.rglob('*public*/**/*.tar')]],
|
76 |
-
# "test": [str(p) for p in [*data_path.rglob('*test*.arrow')]+[*data_path.rglob('*private*/**/*.tar')]],
|
77 |
-
# }
|
78 |
data_files = {
|
79 |
"validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
|
80 |
"test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
|
@@ -86,31 +78,12 @@ if __name__ == "__main__":
|
|
86 |
trust_remote_code=True,
|
87 |
writer_batch_size=100
|
88 |
)
|
89 |
-
|
90 |
-
# if TEST_ENV:
|
91 |
-
# dataset = load_dataset(
|
92 |
-
# "webdataset",
|
93 |
-
# data_files=data_files,
|
94 |
-
# trust_remote_code=True,
|
95 |
-
# # streaming=True
|
96 |
-
# )
|
97 |
print('load with webdataset')
|
98 |
-
# else:
|
99 |
-
|
100 |
-
# dataset = load_dataset(
|
101 |
-
# "arrow",
|
102 |
-
# data_files=data_files,
|
103 |
-
# trust_remote_code=True,
|
104 |
-
# # streaming=True
|
105 |
-
# )
|
106 |
-
# print('load with arrow')
|
107 |
-
|
108 |
-
|
109 |
print(dataset, flush=True)
|
110 |
-
# dataset = load_dataset('webdataset', data_files={)
|
111 |
|
112 |
print('------------ Now you can do your solution ---------------')
|
113 |
solution = []
|
|
|
114 |
for subset_name in dataset:
|
115 |
for i, sample in enumerate(tqdm(dataset[subset_name])):
|
116 |
# replace this with your solution
|
@@ -121,15 +94,15 @@ if __name__ == "__main__":
|
|
121 |
except Exception as e:
|
122 |
print (f"Faile due to {e}")
|
123 |
pred_vertices, pred_edges = empty_solution(sample)
|
124 |
-
|
|
|
125 |
solution.append({
|
126 |
'order_id': sample['order_id'],
|
127 |
'wf_vertices': pred_vertices.tolist(),
|
128 |
'wf_edges': pred_edges
|
129 |
})
|
130 |
-
|
131 |
print('------------ Saving results ---------------')
|
132 |
-
print (len(solution))
|
133 |
from time import sleep
|
134 |
sleep(30)
|
135 |
sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
|
|
|
11 |
import numpy as np
|
12 |
from datasets import load_dataset
|
13 |
from typing import Dict
|
14 |
+
from tqdm import tqdm
|
15 |
|
16 |
def empty_solution(sample):
|
17 |
'''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
|
|
|
49 |
print(os.system('ls -lahtr /tmp/data/'))
|
50 |
print('/tmp/data/data')
|
51 |
print(os.system('ls -lahtrR /tmp/data/data'))
|
|
|
52 |
|
53 |
data_path_test_server = Path('/tmp/data')
|
54 |
data_path_local = Path().home() / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'
|
|
|
66 |
repo_type="dataset",
|
67 |
)
|
68 |
data_path = data_path_test_server
|
|
|
|
|
69 |
print(data_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
data_files = {
|
71 |
"validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
|
72 |
"test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
|
|
|
78 |
trust_remote_code=True,
|
79 |
writer_batch_size=100
|
80 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
print('load with webdataset')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
print(dataset, flush=True)
|
|
|
83 |
|
84 |
print('------------ Now you can do your solution ---------------')
|
85 |
solution = []
|
86 |
+
num_fails = 0
|
87 |
for subset_name in dataset:
|
88 |
for i, sample in enumerate(tqdm(dataset[subset_name])):
|
89 |
# replace this with your solution
|
|
|
94 |
except Exception as e:
|
95 |
print (f"Faile due to {e}")
|
96 |
pred_vertices, pred_edges = empty_solution(sample)
|
97 |
+
num_fails+=1
|
98 |
+
#print (f'{pred_vertices=}, {pred_edges=}')
|
99 |
solution.append({
|
100 |
'order_id': sample['order_id'],
|
101 |
'wf_vertices': pred_vertices.tolist(),
|
102 |
'wf_edges': pred_edges
|
103 |
})
|
|
|
104 |
print('------------ Saving results ---------------')
|
105 |
+
print (f"Processed {len(solution)} entries, get {num_fails} fails")
|
106 |
from time import sleep
|
107 |
sleep(30)
|
108 |
sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
|