# Web-viewer header captured together with this file (not part of the program):
# Spaces: Running | File size: 6,848 Bytes | revision 38e4cc9 | lines 1-205
import os
import json
import time
import re
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
# Hugging Face API token, read from the environment. Falls back to an empty
# string when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Load the model catalogue. JSON text is UTF-8, so decode it explicitly
# instead of relying on the platform's default encoding.
with open("models.json", "r", encoding="utf-8") as f:
    models_data = json.load(f)

# Collect the "id" of every entry in the top-level "data" list.
model_ids = [model["id"] for model in models_data["data"]]

# Limit to the first 20 models
model_ids = model_ids[:20]
def extract_svg(text):
    """Extract the first SVG element from a model response.

    Fenced code blocks are searched first, trying the ``svg``, ``xml``,
    ``html`` and unlabelled fences in that order (only the first fence
    matching each pattern is inspected). If none of them yields an SVG,
    the whole text is searched, which also covers SVG wrapped in thinking
    tags or other markup.

    Args:
        text: Raw response text. ``None`` or any other non-string value is
            treated as "no SVG" instead of raising ``TypeError``.

    Returns:
        The matched ``<svg ...>...</svg>`` substring, or ``None`` when no
        complete SVG element is present.
    """
    # A response may carry no textual content at all; there is nothing to
    # search in that case.
    if not isinstance(text, str):
        return None

    flags = re.DOTALL | re.IGNORECASE
    svg_pattern = r"<svg[^>]*>.*?</svg>"

    # Most specific fence label first; the bare fence is the catch-all.
    code_block_patterns = (
        r"```svg\s*(.*?)\s*```",
        r"```xml\s*(.*?)\s*```",
        r"```html\s*(.*?)\s*```",
        r"```\s*(.*?)\s*```",
    )
    for pattern in code_block_patterns:
        block_match = re.search(pattern, text, flags)
        if block_match is None:
            continue
        content = block_match.group(1)
        if "<svg" not in content:
            continue
        svg_match = re.search(svg_pattern, content, flags)
        if svg_match:
            return svg_match.group(0)

    # No usable code block: look for an SVG anywhere in the text.
    svg_match = re.search(svg_pattern, text, flags)
    return svg_match.group(0) if svg_match else None
def test_model_with_temperature(model_id, temperature):
    """Ask one model for a pelican-on-a-bicycle SVG at one temperature.

    Sends a single chat-completion request to the Hugging Face router and
    records the outcome. Failures are captured in the returned record
    rather than raised.

    Args:
        model_id: Model identifier placed in the request's "model" field.
        temperature: Sampling temperature placed in the request.

    Returns:
        A dict with the keys "model_id", "temperature", "timestamp",
        "success", "response_time", "svg_content", "error" and
        "raw_response". "success" is True only when an SVG was extracted
        from the reply; otherwise "error" describes what went wrong.
    """
    print(f"Testing {model_id} with temperature {temperature}...")
    result = {
        "model_id": model_id,
        "temperature": temperature,
        "timestamp": datetime.now().isoformat(),
        "success": False,
        "response_time": None,
        "svg_content": None,
        "error": None,
        "raw_response": None,
    }
    prompt = (
        "Create a pelican riding a bicycle using SVG. "
        "Return only the SVG code without any explanation or markdown formatting. "
        "The SVG should be a complete, valid SVG document starting with <svg> "
        "and ending with </svg>."
    )
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    data = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 2000,
        "temperature": temperature,
    }
    try:
        start_time = time.time()
        response = requests.post(
            "https://router.huggingface.co/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=60,
        )
        # Wall-clock time of the HTTP round trip only; parsing is excluded.
        result["response_time"] = time.time() - start_time

        if response.status_code != 200:
            result["error"] = f"HTTP {response.status_code}: {response.text}"
            return result

        response_data = response.json()
        choices = response_data.get("choices")
        message = choices[0].get("message") if choices else None
        if not message:
            result["error"] = "Empty response from model"
            return result

        # "content" may be absent or null; report that as an empty response
        # instead of letting it surface as a KeyError/TypeError message.
        response_text = message.get("content")
        result["raw_response"] = response_text
        if response_text is None:
            result["error"] = "Empty response from model"
            return result

        # Extract SVG
        svg_content = extract_svg(response_text)
        if svg_content:
            result["svg_content"] = svg_content
            result["success"] = True
        else:
            result["error"] = "No valid SVG found in response"
    except Exception as e:
        # Benchmark boundary: one failing model (network error, bad JSON,
        # unexpected payload shape) must not abort the whole run.
        result["error"] = str(e)
        print(f"Error testing {model_id} with temperature {temperature}: {e}")
    return result
def _summarize_results(results, model_ids, temperatures):
    """Build the aggregate statistics dict written to benchmark_stats.json."""
    successful_tests = sum(1 for r in results if r.get("success", False))
    # dict.fromkeys de-duplicates while keeping first-success order.
    successful_model_ids = list(
        dict.fromkeys(r["model_id"] for r in results if r.get("success", False))
    )
    # Only requests that actually got an HTTP response carry a timing.
    response_times = [
        r["response_time"] for r in results if r.get("response_time") is not None
    ]
    average_response_time = (
        sum(response_times) / len(response_times) if response_times else 0
    )
    total_tests = len(results)
    return {
        "total_models": len(model_ids),
        "temperatures_tested": temperatures,
        "total_tests": total_tests,
        "successful_tests": successful_tests,
        "failed_tests": total_tests - successful_tests,
        "models_with_at_least_one_success": len(successful_model_ids),
        "average_response_time": average_response_time,
        "successful_model_ids": successful_model_ids,
    }


def main():
    """Run the benchmark for every model/temperature pair and save the output.

    Each (model, temperature) combination is tested concurrently on a pool
    of 10 threads. Per-test records are written to benchmark_results.json,
    aggregate statistics to benchmark_stats.json, and a summary is printed.
    """
    temperatures = [0, 0.5, 1.0]
    print(f"Testing {len(model_ids)} models with {len(temperatures)} temperature settings...")
    results = []

    # One task per model and temperature combination.
    test_tasks = [(model_id, temp) for model_id in model_ids for temp in temperatures]

    # Use ThreadPoolExecutor for concurrent requests
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_task = {
            executor.submit(test_model_with_temperature, model_id, temp): (model_id, temp)
            for model_id, temp in test_tasks
        }
        # Results are collected in completion order, not submission order.
        for future in as_completed(future_to_task):
            model_id, temp = future_to_task[future]
            try:
                result = future.result()
                results.append(result)
                print(
                    f"Completed {model_id} (temp={temp}): {'Success' if result['success'] else 'Failed'}"
                )
            except Exception as e:
                print(f"Exception for {model_id} (temp={temp}): {e}")
                # Same keys as a normal result so consumers of the results
                # file can rely on one schema.
                results.append({
                    "model_id": model_id,
                    "temperature": temp,
                    "timestamp": datetime.now().isoformat(),
                    "success": False,
                    "response_time": None,
                    "svg_content": None,
                    "error": str(e),
                    "raw_response": None,
                })

    # Save results
    with open("benchmark_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    # Generate statistics
    stats = _summarize_results(results, model_ids, temperatures)
    with open("benchmark_stats.json", "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print("\nBenchmark complete!")
    print(f"Total models tested: {stats['total_models']}")
    print(f"Temperature settings: {stats['temperatures_tested']}")
    print(f"Total tests: {stats['total_tests']}")
    print(f"Successful tests: {stats['successful_tests']}")
    print(f"Failed tests: {stats['failed_tests']}")
    print(f"Models with at least one success: {stats['models_with_at_least_one_success']}")
    print(f"Average response time: {stats['average_response_time']:.2f}s")
# Run the benchmark only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
# (end of captured file view)