sonsus's picture
new feat: o4-mini supported
9e784df
import asyncio
import math
import random
from time import time
from types import SimpleNamespace
from typing import Dict, List, Literal, Tuple
from varco_arena_core.prompts import ComparisonPromptBase
from .match import Match
async def limited_coro(coro, semaphore):
    """Await ``coro`` while holding ``semaphore`` and return its result.

    Args:
        coro: The awaitable to run.
        semaphore: An async context manager (e.g. ``asyncio.Semaphore``)
            that bounds how many coroutines run at once, or ``None`` to
            run ``coro`` without any concurrency limit.

    Returns:
        Whatever ``coro`` returns.
    """
    # Callers in this module default their semaphore argument to None;
    # ``async with None`` would raise, so run unthrottled in that case.
    if semaphore is None:
        return await coro
    async with semaphore:
        return await coro
def log2_power_of_two(n):
    """Return the exponent ``k`` for which ``2 ** k == n``.

    Args:
        n: An integer that must be an exact, positive power of two.

    Returns:
        The base-2 logarithm of ``n`` as an int.

    Raises:
        ValueError: If ``n`` is zero or is not a power of two.
    """
    # A power of two has exactly one set bit, so clearing the lowest set
    # bit (n & (n - 1)) must leave nothing behind.
    single_bit = (n & (n - 1)) == 0 and n != 0
    if not single_bit:
        raise ValueError("n must be a positive power of 2")
    # With a single set bit, its position is the bit length minus one.
    return int(n).bit_length() - 1
class Tournament:
    """
    Perform single elimination tournament of model outputs.

    Participants are shuffled, padded with ``None`` placeholders (byes) up
    to the next power of two, and then paired round by round until a single
    slot remains. Every match that is actually evaluated contributes one
    record to the list returned by :meth:`async_run`.
    """

    # Human-readable round names keyed by the number of slots in the round.
    # Key name and value format is just right for web view. Do not remove
    # or modify.
    _ROUND_LABELS = {2: "final", 4: "semi-final", 8: "quarter-final"}

    def __init__(self, participants, evaluation_model):
        """Store the participant list and the evaluation model identifier.

        Args:
            participants: List of participant objects. It is shuffled in
                place by :meth:`async_run`.
            evaluation_model: Identifier of the judge model; the value
                ``"debug"`` selects ``Match.async_dbg_eval``.
        """
        self.participants = participants
        self.evaluation_model = evaluation_model

    @classmethod
    def _round_label(cls, bracket_size):
        """Return the display name of a round with ``bracket_size`` slots."""
        return cls._ROUND_LABELS.get(bracket_size, f"round-{bracket_size}")

    async def _play_match(self, a, b, match_order, cur_round, depth, prompt_obj):
        """Resolve one bracket slot pair.

        Args:
            a: Participant shown in position A, or ``None`` for an empty slot.
            b: Participant shown in position B, or ``None`` for an empty slot.
            match_order: Index of this match within its round.
            cur_round: Display name of the round.
            depth: Bracket depth of the round (later used for bracket
                re-building).
            prompt_obj: Comparison prompt forwarded to
                ``Match.async_comp_eval``.

        Returns:
            ``(advancing, records)`` where ``advancing`` is the participant
            moving to the next round (``None`` if nobody advances) and
            ``records`` is a one-element list with the match record, or
            ``None`` when no evaluation produced a record.
        """
        # Bye: with at most one real participant there is nothing to judge;
        # whoever is present (possibly nobody) advances.
        if a is None or b is None:
            return (a if a is not None else b), None

        match = Match(A=a, B=b, eval_model=self.evaluation_model)
        if self.evaluation_model == "debug":
            winner, match_result = await match.async_dbg_eval()
        else:
            winner, match_result = await match.async_comp_eval(
                comp_prompt=prompt_obj,
            )

        # No winner reported: nobody advances and nothing is recorded.
        if winner is None:
            return None, None

        advancing = a if winner == "A" else b
        verdict = match_result[0]
        record = {
            # Here *_a / *_b refer to positions inside the prompt,
            # not to Match.A / Match.B.
            "task": a.task,
            "model_a": a.model_id,
            "model_b": b.model_id,
            "winner": "A" if verdict["A"] > verdict["B"] else "B",
            "prob_a": verdict["A"],
            "prob_b": verdict["B"],
            "evaluation_model": self.evaluation_model,
            "instruction": a.instruction,
            "source": a.source,
            "generated_a": a.generated,
            "generated_b": b.generated,
            "tournament_idx": a.tournament_idx,
            "round": cur_round,
            "depth": depth,
            "match_order_in_round": match_order,
            "tstamp": time(),
            "api_call_kwargs": verdict["api_call_kwargs"],
            "actual_response_text": verdict["actual_response_text"],
        }
        # The language fields are optional: copy them when the participant
        # carries them and silently skip them otherwise.
        try:
            record["source_language"] = a.source_language
            record["target_language"] = a.target_language
        except AttributeError:
            pass
        return advancing, [record]

    async def async_run(
        self,
        prompt_obj: "ComparisonPromptBase" = None,
        semaphore: asyncio.Semaphore = None,
    ) -> List[Dict]:
        """Run the whole tournament and return every recorded match.

        Args:
            prompt_obj: Comparison prompt handed to each evaluated match.
            semaphore: Optional semaphore bounding how many matches are
                evaluated concurrently; ``None`` means no limit.

        Returns:
            A flat list of match record dicts, in round order. Empty when
            there are fewer than two participants.
        """
        random.shuffle(self.participants)

        n_participants = len(self.participants)
        if n_participants == 0:
            # Nothing to pair up; avoids sizing a bracket for zero entries.
            return []

        # Smallest power of two >= n_participants, in exact integer math.
        bracket_size = 1 << (n_participants - 1).bit_length()
        # Empty slots are byes: their opponent advances without a match.
        bracket = self.participants + [None] * (bracket_size - n_participants)

        results = []
        while len(bracket) > 1:
            size = len(bracket)
            half = size // 2
            # Pair the i-th slot of the first half with the i-th slot from
            # the end of the second half.
            upper = bracket[:half]
            lower = bracket[half:][::-1]

            cur_round = self._round_label(size)
            # Later used for bracket re-building (the final has depth 0).
            cur_depth = log2_power_of_two(size) - 1

            jobs = []
            for order, (top, bottom) in enumerate(zip(upper, lower)):
                # Swap the A/B positions on every odd-numbered match.
                if order % 2 == 1:
                    top, bottom = bottom, top
                job = self._play_match(
                    top, bottom, order, cur_round, cur_depth, prompt_obj
                )
                if semaphore is not None:
                    job = limited_coro(job, semaphore)
                jobs.append(job)

            outcomes = await asyncio.gather(*jobs)
            for _, records in outcomes:
                if records is not None:
                    results.extend(records)
            bracket = [advancing for advancing, _ in outcomes]

        return results