Spaces:
Sleeping
Sleeping
import asyncio | |
import math | |
import random | |
from time import time | |
from types import SimpleNamespace | |
from typing import Dict, List, Literal, Tuple | |
from varco_arena_core.prompts import ComparisonPromptBase | |
from .match import Match | |
async def limited_coro(coro, semaphore):
    """Await *coro* while holding *semaphore*, capping concurrent execution.

    Args:
        coro: The awaitable to run.
        semaphore: An async context manager (e.g. ``asyncio.Semaphore``) that
            limits how many wrapped coroutines run at once, or ``None`` to run
            without any limit.

    Returns:
        Whatever *coro* returns.
    """
    if semaphore is None:
        # ``async with None`` would raise before the coroutine is ever
        # awaited, so an absent limiter simply means "run unthrottled".
        return await coro
    async with semaphore:
        return await coro
def log2_power_of_two(n):
    """Return the exact base-2 logarithm of *n*, which must be a power of two.

    Args:
        n: A positive integer that is a power of two (1, 2, 4, 8, ...).

    Returns:
        The integer exponent ``e`` such that ``2 ** e == n``.

    Raises:
        ValueError: If *n* is zero, negative, or not a power of two.
    """
    # A power of two has exactly one bit set, so clearing the lowest set bit
    # (n & (n - 1)) must leave zero. Zero itself is rejected explicitly.
    if n & (n - 1) != 0 or n == 0:
        raise ValueError("n must be a positive power of 2")
    # With a single set bit, the exponent is that bit's zero-based position.
    return n.bit_length() - 1
class Tournament:
    """
    Perform single elimination tournament of model outputs

    The participants are shuffled, padded with ``None`` "bye" slots up to the
    next power of two, and then paired off round by round until one slot
    remains. Every match that was actually judged contributes one record
    (a dict) to the list returned by :meth:`async_run`.
    """

    def __init__(self, participants, evaluation_model):
        # participants: list of participant objects; the code below reads
        #   .task, .model_id, .instruction, .source, .generated and
        #   .tournament_idx from them.
        # evaluation_model: judge identifier; the literal "debug" selects
        #   Match.async_dbg_eval instead of Match.async_comp_eval.
        self.participants = participants
        self.evaluation_model = evaluation_model

    @staticmethod
    def _round_name(bracket_size: int) -> str:
        """Return the human-readable name of a round with *bracket_size* slots."""
        # Key name and value format are just right for the web view.
        # Do not remove or modify.
        if bracket_size == 2:
            return "final"
        if bracket_size == 4:
            return "semi-final"
        if bracket_size == 8:
            return "quarter-final"
        return f"round-{bracket_size}"

    async def _get_match_results(
        self,
        a,
        b,
        match_order,
        cur_round,
        depth: int = None,
        prompt_obj=None,
    ) -> Tuple[SimpleNamespace, List[Dict]]:
        """Play one match and return ``(advancing_participant, records)``.

        Args:
            a: Participant placed in position A, or ``None`` for an empty slot.
            b: Participant placed in position B, or ``None`` for an empty slot.
            match_order: Index of this match within its round.
            cur_round: Human-readable round name stored in the record.
            depth: Bracket depth stored in the record.
            prompt_obj: Comparison prompt forwarded to ``Match.async_comp_eval``.

        Returns:
            A pair whose first item is the participant that advances (``None``
            when nobody does) and whose second item is a one-element list with
            the match record, or ``None`` when no match was judged.
        """
        # Bye: with at least one empty slot nothing is judged and whichever
        # participant is present (possibly none) advances.
        if a is None or b is None:
            return (a if a is not None else b), None

        match = Match(A=a, B=b, eval_model=self.evaluation_model)
        if self.evaluation_model == "debug":
            winner, match_result = await match.async_dbg_eval()
        else:
            winner, match_result = await match.async_comp_eval(
                comp_prompt=prompt_obj,
            )

        # No winner reported: the slot stays empty in the next round and no
        # record is emitted.
        if winner is None:
            return None, None

        next_round_participant = a if winner == "A" else b
        outcome = match_result[0]
        # Here *_a / *_b refer to the position inside the prompt, not to
        # Match.A / Match.B.
        # NOTE(review): the recorded "winner" is derived from the two scores,
        # while the advancing participant follows the `winner` value returned
        # by Match - confirm the two always agree.
        record = {
            "task": a.task,
            "model_a": a.model_id,
            "model_b": b.model_id,
            "winner": "A" if outcome["A"] > outcome["B"] else "B",
            "prob_a": outcome["A"],
            "prob_b": outcome["B"],
            "evaluation_model": self.evaluation_model,
            "instruction": a.instruction,
            "source": a.source,
            "generated_a": a.generated,
            "generated_b": b.generated,
            "tournament_idx": a.tournament_idx,
            "round": cur_round,
            "depth": depth,
            "match_order_in_round": match_order,
            "tstamp": time(),
            "api_call_kwargs": outcome["api_call_kwargs"],
            "actual_response_text": outcome["actual_response_text"],
        }
        # The language fields are optional: participants without them simply
        # produce a record that lacks these keys.
        try:
            record["source_language"] = a.source_language
            record["target_language"] = a.target_language
        except AttributeError:
            pass
        return next_round_participant, [record]

    async def async_run(
        self,
        prompt_obj: ComparisonPromptBase = None,
        semaphore: asyncio.Semaphore = None,
    ) -> List[Dict]:
        """Run the whole bracket and return every judged match record.

        ``self.participants`` is shuffled in place before the bracket is built.

        Args:
            prompt_obj: Comparison prompt forwarded to each judged match.
            semaphore: Optional limiter for concurrently running matches;
                ``None`` runs all matches of a round without a limit.

        Returns:
            Match records in round order (earliest round first). Empty when
            there are fewer than two participants.
        """
        if not self.participants:
            # Nothing to play; also avoids computing a bracket size for zero.
            return []

        random.shuffle(self.participants)
        result = []

        # Smallest power of two that fits every participant, computed with
        # exact integer arithmetic.
        next_higher_power_of_two = 1 << (len(self.participants) - 1).bit_length()
        # Empty slots that let their opponent advance without a match (byes).
        winners_number_of_byes = next_higher_power_of_two - len(self.participants)
        current_round_participant = self.participants + [None] * winners_number_of_byes

        while len(current_round_participant) > 1:
            bracket_size = len(current_round_participant)
            half_length = bracket_size // 2
            # Pair the first half with the reversed second half, so the i-th
            # slot meets the i-th slot counted from the end.
            first = current_round_participant[:half_length]
            last = current_round_participant[half_length:]
            last.reverse()

            cur_round = self._round_name(bracket_size)
            # Later used for bracket re-building; the final has depth 0.
            cur_depth = log2_power_of_two(bracket_size) - 1

            match_jobs = []
            for match_order, (side_a, side_b) in enumerate(zip(first, last)):
                # Odd-numbered matches swap which participant sits in
                # position A (presumably to alternate prompt positions -
                # TODO confirm).
                if match_order % 2 == 1:
                    side_a, side_b = side_b, side_a
                job = self._get_match_results(
                    side_a,
                    side_b,
                    match_order,
                    cur_round,
                    depth=cur_depth,
                    prompt_obj=prompt_obj,
                )
                if semaphore is not None:
                    job = limited_coro(job, semaphore)
                match_jobs.append(job)

            # gather() preserves submission order, so winners keep their
            # bracket positions for the next round.
            pairs_of_winner_and_result = await asyncio.gather(*match_jobs)
            current_round_participant = [
                advancing for advancing, _ in pairs_of_winner_and_result
            ]
            for _, records in pairs_of_winner_and_result:
                if records is not None:
                    result.extend(records)

        return result