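"""Pelican-on-a-bicycle SVG benchmark.

Asks each of the first 20 models listed in models.json (via the Hugging Face
router's chat completions API) to draw a pelican riding a bicycle as SVG, at
temperatures 0, 0.5, and 1.0, then writes per-test results to
benchmark_results.json and aggregate statistics to benchmark_stats.json.
Reads the HF_TOKEN environment variable for authentication.
"""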
import os
import json
import time
import re
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Load models
with open("models.json", "r") as f:
    models_data = json.load(f)

# Extract model IDs
model_ids = [model["id"] for model in models_data["data"]]

# Limit to first 20 models
model_ids = model_ids[:20]
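
# Note: models.json is assumed to be a model listing of the form
#   {"data": [{"id": "org/model-name", ...}, ...]}
# (inferred from the fields read above; only "id" is used here).
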
def extract_svg(text):
    """Extract SVG content from model response"""
    # First, check for code blocks with different markers
    code_block_patterns = [
        r"```svg\s*(.*?)\s*```",
        r"```xml\s*(.*?)\s*```",
        r"```html\s*(.*?)\s*```",
        r"```\s*(.*?)\s*```",
    ]
    for pattern in code_block_patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            content = match.group(1)
            # Extract SVG from the code block content
            if "<svg" in content:
                svg_match = re.search(
                    r"<svg[^>]*>.*?</svg>", content, re.DOTALL | re.IGNORECASE
                )
                if svg_match:
                    return svg_match.group(0)

    # If no code blocks, look for SVG directly in the text
    # Handle cases where SVG might be in thinking tags or other wrappers
    svg_pattern = r"<svg[^>]*>.*?</svg>"
    svg_match = re.search(svg_pattern, text, re.DOTALL | re.IGNORECASE)
    if svg_match:
        return svg_match.group(0)

    return None
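
# Example behaviour (hypothetical responses): both a fenced block and bare
# markup resolve to the same <svg>...</svg> string; text with no <svg> element
# yields None.
#   extract_svg("```svg\n<svg viewBox='0 0 10 10'><rect/></svg>\n```")
#   extract_svg("Here you go: <svg viewBox='0 0 10 10'><rect/></svg>")
#   extract_svg("Sorry, I can't draw that.")  # -> None
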
def test_model_with_temperature(model_id, temperature):
    """Test a single model with a specific temperature"""
    print(f"Testing {model_id} with temperature {temperature}...")

    result = {
        "model_id": model_id,
        "temperature": temperature,
        "timestamp": datetime.now().isoformat(),
        "success": False,
        "response_time": None,
        "svg_content": None,
        "error": None,
        "raw_response": None,
    }

    prompt = """Create a pelican riding a bicycle using SVG. Return only the SVG code without any explanation or markdown formatting. The SVG should be a complete, valid SVG document starting with <svg> and ending with </svg>."""

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    data = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 2000,
        "temperature": temperature,
    }

    try:
        start_time = time.time()
        response = requests.post(
            "https://router.huggingface.co/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=60,
        )
        response_time = time.time() - start_time
        result["response_time"] = response_time

        if response.status_code == 200:
            response_data = response.json()
            if response_data.get("choices") and response_data["choices"][0].get(
                "message"
            ):
                response_text = response_data["choices"][0]["message"]["content"]
                result["raw_response"] = response_text

                # Extract SVG
                svg_content = extract_svg(response_text)
                if svg_content:
                    result["svg_content"] = svg_content
                    result["success"] = True
                else:
                    result["error"] = "No valid SVG found in response"
            else:
                result["error"] = "Empty response from model"
        else:
            result["error"] = f"HTTP {response.status_code}: {response.text}"
    except Exception as e:
        result["error"] = str(e)
        print(f"Error testing {model_id} with temperature {temperature}: {e}")

    return result
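
# Standalone usage sketch (the model ID below is illustrative, not from
# models.json; any ID served by the router should work the same way):
#   result = test_model_with_temperature("some-org/some-model", 0.5)
#   print(result["success"], result["response_time"], result["error"])
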
def main():
    temperatures = [0, 0.5, 1.0]
    print(f"Testing {len(model_ids)} models with {len(temperatures)} temperature settings...")

    results = []

    # Create test tasks for each model and temperature combination
    test_tasks = []
    for model_id in model_ids:
        for temp in temperatures:
            test_tasks.append((model_id, temp))

    # Use ThreadPoolExecutor for concurrent requests
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_task = {
            executor.submit(test_model_with_temperature, task[0], task[1]): task
            for task in test_tasks
        }
        for future in as_completed(future_to_task):
            task = future_to_task[future]
            model_id, temp = task
            try:
                result = future.result()
                results.append(result)
                print(
                    f"Completed {model_id} (temp={temp}): {'Success' if result['success'] else 'Failed'}"
                )
            except Exception as e:
                print(f"Exception for {model_id} (temp={temp}): {e}")
                results.append({
                    "model_id": model_id,
                    "temperature": temp,
                    "success": False,
                    "error": str(e),
                })

    # Save results
    with open("benchmark_results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Generate statistics
    total_tests = len(results)
    successful_tests = sum(1 for r in results if r.get("success", False))

    # Group by model to count unique models with at least one success
    models_with_success = {}
    for r in results:
        if r.get("success", False):
            models_with_success[r["model_id"]] = True

    stats = {
        "total_models": len(model_ids),
        "temperatures_tested": temperatures,
        "total_tests": total_tests,
        "successful_tests": successful_tests,
        "failed_tests": total_tests - successful_tests,
        "models_with_at_least_one_success": len(models_with_success),
        "average_response_time": (
            sum(r.get("response_time", 0) for r in results if r.get("response_time"))
            / len([r for r in results if r.get("response_time")])
            if any(r.get("response_time") for r in results)
            else 0
        ),
        "successful_model_ids": list(models_with_success.keys()),
    }

    with open("benchmark_stats.json", "w") as f:
        json.dump(stats, f, indent=2)

    print("\nBenchmark complete!")
    print(f"Total models tested: {stats['total_models']}")
    print(f"Temperature settings: {stats['temperatures_tested']}")
    print(f"Total tests: {stats['total_tests']}")
    print(f"Successful tests: {stats['successful_tests']}")
    print(f"Failed tests: {stats['failed_tests']}")
    print(f"Models with at least one success: {stats['models_with_at_least_one_success']}")
    print(f"Average response time: {stats['average_response_time']:.2f}s")

if __name__ == "__main__":
    main()