Spaces:
Runtime error
Runtime error
File size: 2,032 Bytes
2561b63 73d1e6e 2561b63 73d1e6e 2561b63 73d1e6e b79c971 e1b962a 2561b63 e1b962a 2561b63 b79c971 05346b7 bcdca08 b79c971 e1b962a 2561b63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
#!/usr/bin/env python
from huggingface_hub import snapshot_download
from src.backend.envs import EVAL_REQUESTS_PATH_BACKEND
from src.backend.manage_requests import get_eval_requests
from src.backend.manage_requests import EvalRequest
from src.backend.run_eval_suite import run_evaluation
from src.backend.tasks.xsum.task import XSum
from lm_eval.tasks import initialize_tasks, include_task_folder
from lm_eval import tasks, evaluator, utils
from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task
from src.envs import QUEUE_REPO
def main(
    *,
    model_substring: str = "bloom-560m",
    device: str = "mps",
    limit: int = 10,
    num_fewshot: int = 1,
    debug: bool = False,
) -> None:
    """Smoke-test the lm-eval harness on one queued model and the selected task(s).

    Fetches eval requests via ``get_eval_requests`` (``QUEUE_REPO`` /
    ``EVAL_REQUESTS_PATH_BACKEND``) and picks the first request whose model
    name contains ``model_substring``. It then initialises lm-eval's tasks,
    including the project folder ``src/backend/tasks/``, and runs
    ``evaluator.simple_evaluate`` once per task in ``TASKS_HARNESS``,
    printing the per-task results.

    Args:
        model_substring: Substring matched against ``EvalRequest.model`` to
            choose which queued request to evaluate.
        device: Device string forwarded to ``simple_evaluate``. The default
            ``"mps"`` is PyTorch's Apple-silicon backend; pass e.g. ``"cpu"``
            or ``"cuda"`` on other machines.
        limit: Maximum number of examples evaluated per task. It is kept small
            so the run is quick.
        num_fewshot: Few-shot count forwarded to ``simple_evaluate``.
        debug: If True, enter the debugger after the evaluation loop.

    Raises:
        ValueError: If no fetched request matches ``model_substring``.
    """
    # To refresh the local copy of the requests queue first, re-enable:
    # snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)

    PENDING_STATUS = "PENDING"
    RUNNING_STATUS = "RUNNING"
    FINISHED_STATUS = "FINISHED"
    FAILED_STATUS = "FAILED"

    # Statuses handed to get_eval_requests, presumably as a filter. All four
    # are listed, so this is NOT restricted to FINISHED requests; trim the
    # list to narrow the search.
    status = [PENDING_STATUS, RUNNING_STATUS, FINISHED_STATUS, FAILED_STATUS]
    eval_requests: list[EvalRequest] = get_eval_requests(
        job_status=status, hf_repo=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH_BACKEND
    )

    # Report a missing model explicitly instead of failing with a bare IndexError.
    eval_request = next((r for r in eval_requests if model_substring in r.model), None)
    if eval_request is None:
        raise ValueError(
            f"No eval request found for a model matching {model_substring!r} "
            f"among {len(eval_requests)} request(s) in {EVAL_REQUESTS_PATH_BACKEND}"
        )

    # Alternative task: my_task = Task("memo-trap", "acc", "memo-trap", 0)
    my_task = Task("xsum", "rougeLsum", "XSum", 2)
    # NOTE(review): if Task's last field is the task's own few-shot count, it
    # is not used here; `num_fewshot` is what simple_evaluate receives.
    TASKS_HARNESS = [my_task]
    # To run the full suite instead: TASKS_HARNESS = [task.value for task in Tasks]

    # Presumably makes the project's custom tasks visible to lm-eval before
    # its task registry is initialised; confirm against lm_eval.tasks.
    include_task_folder("src/backend/tasks/")
    initialize_tasks('INFO')
    print(tasks.ALL_TASKS)

    for task in TASKS_HARNESS:
        print(f"Selected Tasks: [{task}]")
        results = evaluator.simple_evaluate(
            model="hf",
            model_args=eval_request.get_model_args(),
            tasks=[task.benchmark],
            num_fewshot=num_fewshot,
            batch_size=1,
            device=device,
            use_cache=None,
            limit=limit,
            write_out=True,
        )
        print('AAA', results["results"])

    # Pause in the debugger only when explicitly asked to. An unconditional
    # breakpoint() would halt or hang unattended runs.
    if debug:
        breakpoint()
# Script entry point: run the smoke test only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()
|