File size: 6,848 Bytes
38e4cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import os
import json
import time
import re
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

# Hugging Face API token, taken from the environment (empty string when unset)
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Parse the model catalogue from disk
with open("models.json", "r") as f:
    models_data = json.load(f)

# Collect every catalogue entry's ID, then keep only the first 20
model_ids = [entry["id"] for entry in models_data["data"]][:20]


def extract_svg(text):
    """Extract the first SVG element from a model response.

    Fenced code blocks are searched first, trying ```svg, ```xml, ```html and
    finally untagged fences, in that order. Only the first fence matching
    each pattern is inspected. If none of those holds an SVG, the whole text
    is searched directly.

    Args:
        text: Raw response text. ``None``, a non-string value or an empty
            string is treated as "no SVG" rather than raising ``TypeError``.

    Returns:
        The matched ``<svg ...>...</svg>`` substring, or ``None`` when no
        complete SVG element is present. Matching is case-insensitive and
        stops at the first closing ``</svg>`` tag.
    """
    # Models may return null/non-text content; there is nothing to search then.
    if not isinstance(text, str) or not text:
        return None

    flags = re.DOTALL | re.IGNORECASE
    svg_pattern = r"<svg[^>]*>.*?</svg>"

    # Ordered from most to least specific fence marker
    code_block_patterns = [
        r"```svg\s*(.*?)\s*```",
        r"```xml\s*(.*?)\s*```",
        r"```html\s*(.*?)\s*```",
        r"```\s*(.*?)\s*```",
    ]

    for pattern in code_block_patterns:
        block = re.search(pattern, text, flags)
        if block is None:
            continue
        # Rely on the case-insensitive regex alone so an upper-case <SVG>
        # inside a fence is treated the same as a lower-case one.
        svg_match = re.search(svg_pattern, block.group(1), flags)
        if svg_match:
            return svg_match.group(0)

    # No fenced SVG: fall back to the first SVG anywhere in the text
    svg_match = re.search(svg_pattern, text, flags)
    if svg_match:
        return svg_match.group(0)

    return None


def test_model_with_temperature(model_id, temperature):
    """Ask one model for the pelican SVG at a given temperature and record the outcome.

    Sends a single chat-completion request to the Hugging Face router and
    tries to extract an SVG from the reply. This function does not raise:
    any exception is caught, printed and stored in the result's ``error``
    field.

    Args:
        model_id: Model identifier placed in the request's ``model`` field.
        temperature: Sampling temperature placed in the request payload.

    Returns:
        A dict with the keys ``model_id``, ``temperature``, ``timestamp``
        (``datetime.now()`` in ISO format), ``success``, ``response_time``
        (seconds, set only when the HTTP call returned), ``svg_content``,
        ``error`` and ``raw_response``.
    """
    print(f"Testing {model_id} with temperature {temperature}...")

    result = {
        "model_id": model_id,
        "temperature": temperature,
        "timestamp": datetime.now().isoformat(),
        "success": False,
        "response_time": None,
        "svg_content": None,
        "error": None,
        "raw_response": None,
    }

    prompt = """Create a pelican riding a bicycle using SVG. Return only the SVG code without any explanation or markdown formatting. The SVG should be a complete, valid SVG document starting with <svg> and ending with </svg>."""

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }

    data = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 2000,
        "temperature": temperature,
    }

    try:
        # perf_counter is monotonic, so the measured latency cannot be skewed
        # by wall-clock adjustments while the request is in flight.
        start_time = time.perf_counter()

        response = requests.post(
            "https://router.huggingface.co/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=60,
        )

        result["response_time"] = time.perf_counter() - start_time

        if response.status_code != 200:
            result["error"] = f"HTTP {response.status_code}: {response.text}"
            return result

        response_data = response.json()
        choices = response_data.get("choices")
        message = choices[0].get("message") if choices else None
        if not message:
            result["error"] = "Empty response from model"
            return result

        response_text = message.get("content")
        result["raw_response"] = response_text

        # A message with null/missing content is an empty reply; report it
        # as such instead of failing inside the SVG extraction.
        if response_text is None:
            result["error"] = "Empty response from model"
            return result

        svg_content = extract_svg(response_text)
        if svg_content:
            result["svg_content"] = svg_content
            result["success"] = True
        else:
            result["error"] = "No valid SVG found in response"

    except Exception as e:
        # Deliberately broad: one failing model must not abort the benchmark.
        result["error"] = str(e)
        print(f"Error testing {model_id} with temperature {temperature}: {e}")

    return result


def main():
    """Run the benchmark for every model/temperature pair and write the reports.

    Each model in ``model_ids`` is queried at temperatures 0, 0.5 and 1.0
    through a pool of 10 worker threads. The per-test results are written to
    ``benchmark_results.json`` (in completion order), the aggregate
    statistics to ``benchmark_stats.json``, and a summary is printed.
    """
    temperatures = [0, 0.5, 1.0]
    print(f"Testing {len(model_ids)} models with {len(temperatures)} temperature settings...")

    # One task per (model, temperature) combination, models in the outer loop
    test_tasks = [(model_id, temp) for model_id in model_ids for temp in temperatures]

    results = []

    # Fan the requests out over a thread pool and collect them as they finish
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_task = {}
        for task in test_tasks:
            pending = executor.submit(test_model_with_temperature, *task)
            future_to_task[pending] = task

        for future in as_completed(future_to_task):
            model_id, temp = future_to_task[future]
            try:
                result = future.result()
                results.append(result)
                status = "Success" if result["success"] else "Failed"
                print(f"Completed {model_id} (temp={temp}): {status}")
            except Exception as e:
                print(f"Exception for {model_id} (temp={temp}): {e}")
                failure = {
                    "model_id": model_id,
                    "temperature": temp,
                    "success": False,
                    "error": str(e),
                }
                results.append(failure)

    # Persist the raw per-test results
    with open("benchmark_results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Aggregate statistics
    total_tests = len(results)
    successful_tests = len([r for r in results if r.get("success", False)])

    # Distinct models with at least one success, in order of first success
    successful_model_ids = list(
        dict.fromkeys(r["model_id"] for r in results if r.get("success", False))
    )

    # Only tests that recorded a (truthy) response time count towards the mean
    timings = [r.get("response_time", 0) for r in results if r.get("response_time")]
    if timings:
        average_response_time = sum(timings) / len(timings)
    else:
        average_response_time = 0

    stats = {
        "total_models": len(model_ids),
        "temperatures_tested": temperatures,
        "total_tests": total_tests,
        "successful_tests": successful_tests,
        "failed_tests": total_tests - successful_tests,
        "models_with_at_least_one_success": len(successful_model_ids),
        "average_response_time": average_response_time,
        "successful_model_ids": successful_model_ids,
    }

    with open("benchmark_stats.json", "w") as f:
        json.dump(stats, f, indent=2)

    print("\nBenchmark complete!")
    print(f"Total models tested: {stats['total_models']}")
    print(f"Temperature settings: {stats['temperatures_tested']}")
    print(f"Total tests: {stats['total_tests']}")
    print(f"Successful tests: {stats['successful_tests']}")
    print(f"Failed tests: {stats['failed_tests']}")
    print(f"Models with at least one success: {stats['models_with_at_least_one_success']}")
    print(f"Average response time: {stats['average_response_time']:.2f}s")


if __name__ == "__main__":
    main()