Commit
·
231b875
1
Parent(s):
b81532f
Add vizdoom
Browse files
app.py
CHANGED
@@ -39,6 +39,23 @@ def get_user_models(hf_username, env_tag, lib_tag):
|
|
39 |
return user_model_ids
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def get_metadata(model_id):
|
43 |
"""
|
44 |
Get model metadata (contains evaluation data)
|
@@ -225,18 +242,22 @@ def certification(hf_username, first_name, last_name):
|
|
225 |
},
|
226 |
{
|
227 |
"unit": "Unit 8 PII",
|
228 |
-
"env": "
|
229 |
-
"library": "
|
230 |
-
"min_result":
|
231 |
"best_result": 0,
|
232 |
"best_model_id": "",
|
233 |
"passed_": False
|
234 |
},
|
235 |
]
|
236 |
for unit in results_certification:
|
|
|
237 |
# Get user model
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
240 |
# Calculate the best result and get the best_model_id
|
241 |
best_result, best_model_id = calculate_best_result(user_models)
|
242 |
|
|
|
39 |
return user_model_ids
|
40 |
|
41 |
|
42 |
+
def get_user_sf_models(hf_username, env_tag, lib_tag):
|
43 |
+
models_sf = []
|
44 |
+
models = api.list_models(author=hf_username, filter=["reinforcement-learning", lib_tag])
|
45 |
+
|
46 |
+
user_model_ids = [x.modelId for x in models]
|
47 |
+
|
48 |
+
for model in user_model_ids:
|
49 |
+
meta = get_metadata(model)
|
50 |
+
if meta is None:
|
51 |
+
continue
|
52 |
+
result = meta["model-index"][0]["results"][0]["dataset"]["name"]
|
53 |
+
if result == env_tag:
|
54 |
+
models_sf.append(model)
|
55 |
+
|
56 |
+
return models_sf
|
57 |
+
|
58 |
+
|
59 |
def get_metadata(model_id):
|
60 |
"""
|
61 |
Get model metadata (contains evaluation data)
|
|
|
242 |
},
|
243 |
{
|
244 |
"unit": "Unit 8 PII",
|
245 |
+
"env": "doom_health_gathering_supreme",
|
246 |
+
"library": "sample-factory",
|
247 |
+
"min_result": 5,
|
248 |
"best_result": 0,
|
249 |
"best_model_id": "",
|
250 |
"passed_": False
|
251 |
},
|
252 |
]
|
253 |
for unit in results_certification:
|
254 |
+
if unit["unit"] != "Unit 8 PII":
|
255 |
# Get user model
|
256 |
+
user_models = get_user_models(hf_username, unit['env'], unit['library'])
|
257 |
+
# For sample factory vizdoom we don't have env tag for now
|
258 |
+
else:
|
259 |
+
user_models = get_user_sf_models(hf_username, unit['env'], unit['library'])
|
260 |
+
|
261 |
# Calculate the best result and get the best_model_id
|
262 |
best_result, best_model_id = calculate_best_result(user_models)
|
263 |
|