Karthik1610 committed on
Commit 03ecb7b · verified · 1 Parent(s): ec47f09

Update app.py

Files changed (1)
app.py +259 -355
app.py CHANGED
@@ -1,388 +1,292 @@
- import os, time, json, io
- from typing import Dict, List, Any, Optional, Tuple
-
  import gradio as gr
  import pandas as pd
-
  from datasets import load_dataset
  import evaluate

- from huggingface_hub import InferenceClient, ModelCard
-
- # Keep evaluate for ROUGE and SacreBLEU only (no sklearn required)
- ROUGE = evaluate.load("rouge")
- SACREBLEU = evaluate.load("sacrebleu")
-
- # ---------- Small helpers: accuracy & F1 (macro) without scikit-learn ----------
- def _accuracy_score(y_pred: List[str], y_true: List[str]) -> float:
-     paired = [(p, t) for p, t in zip(y_pred, y_true) if p is not None]
-     if not paired:
-         return 0.0
-     correct = sum(1 for p, t in paired if str(p) == str(t))
-     return correct / len(paired)
-
- def _f1_macro_score(y_pred: List[str], y_true: List[str]) -> float:
-     paired = [(p, t) for p, t in zip(y_pred, y_true) if p is not None]
-     if not paired:
-         return 0.0
-     yp, yt = zip(*paired)
-     labels = sorted(set(yt))
-     def _f1_for(label: str) -> float:
-         tp = sum(1 for p, t in zip(yp, yt) if p == label and t == label)
-         fp = sum(1 for p, t in zip(yp, yt) if p == label and t != label)
-         fn = sum(1 for p, t in zip(yp, yt) if p != label and t == label)
-         if tp == 0 and (fp == 0 or fn == 0):
-             return 0.0
-         prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
-         rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
-         return (2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else 0.0
-     scores = [ _f1_for(lbl) for lbl in labels ]
-     return sum(scores) / len(scores) if scores else 0.0
- # -----------------------------------------------------------------------------
-
- TASKS: Dict[str, Dict[str, str]] = {
-     "sentiment": {
-         "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT SST-2",
-         "cardiffnlp/twitter-roberta-base-sentiment-latest": "RoBERTa Twitter Sentiment"
-     },
-     "zero-shot-classification": {
-         "facebook/bart-large-mnli": "BART MNLI",
-         "joeddav/xlm-roberta-large-xnli": "XLM-R XNLI"
-     },
-     "summarization": {
-         "facebook/bart-large-cnn": "BART CNN",
-         "google/pegasus-xsum": "Pegasus XSum"
-     },
-     "translation_en_fr": {
-         "Helsinki-NLP/opus-mt-en-fr": "Opus-MT EN to FR",
-         "facebook/m2m100_418M": "M2M100 418M"
-     }
- }

- METRIC_POLICY = {
-     "sentiment": {"requires": ["label"], "metrics": ["accuracy", "f1_macro"]},
-     "zero-shot-classification": {"requires": [], "metrics": ["accuracy_if_labels", "f1_macro_if_labels"]},
-     "summarization": {"requires": ["reference"], "metrics": ["rougeL", "rouge1_opt", "rouge2_opt"]},
-     "translation_en_fr": {"requires": ["reference"], "metrics": ["sacrebleu", "chrf_opt"]},
  }

- def validate_token(hf_token: str) -> Tuple[bool, str]:
-     if not hf_token or not hf_token.strip().startswith("hf_"):
-         return False, "Paste a valid Hugging Face token starting with hf_"
-     return True, "Token format OK. We'll use it only for this session."
-
- def load_hub_dataset(ds_id: str, config: Optional[str], split: Optional[str], sample_size: int) -> Tuple[pd.DataFrame, Dict[str, Any]]:
-     kwargs = {}
-     if config:
-         kwargs["name"] = config
-     if split:
-         kwargs["split"] = split
-     ds = load_dataset(ds_id, **kwargs)
-     if not split:
-         for sp in ["test", "validation", "train"]:
-             if sp in ds:
-                 split = sp
-                 break
-     d = ds[split].to_pandas()
-     if sample_size and sample_size < len(d):
-         d = d.sample(n=sample_size, random_state=42)
-     meta = {"dataset_id": ds_id, "config": config, "split": split}
-     return d, meta

- def decide_metrics(task: str, mapped_cols: List[str]) -> List[str]:
-     policy = METRIC_POLICY.get(task, {})
-     reqs = set(policy.get("requires", []))
-     if not reqs.issubset(set(mapped_cols)):
-         return ["latency_only"]
-     out = []
-     for m in policy.get("metrics", []):
-         if m.endswith("_opt"):
-             continue
-         if m.endswith("_if_labels") and "label" not in mapped_cols:
-             continue
-         out.append(m)
-     return out
-
- def normalize_cls_label(pred_label: str, label_names: Optional[List[str]]):
-     if label_names is None:
-         return pred_label
-     low = str(pred_label).lower()
-     for name in label_names:
-         if low == str(name).lower():
-             return name
-     if low.startswith("pos"):
-         for name in label_names:
-             if "pos" in str(name).lower():
-                 return name
-     if low.startswith("neg"):
-         for name in label_names:
-             if "neg" in str(name).lower():
-                 return name
-     return pred_label

- def run_remote_inference(task: str, model_id: str, token: str, texts: List[str], zs_labels: Optional[List[str]] = None,
-                          gen_params: Optional[Dict[str, Any]] = None, timeout_s: int = 20) -> Tuple[List[Any], float]:
      client = InferenceClient(model=model_id, token=token)
-     gen_params = gen_params or {}
-     outputs = []
-     t0 = time.perf_counter()
-     for t in texts:
          try:
-             if task == "summarization":
-                 out = client.summarization(t, **gen_params)
-                 outputs.append(out)
-             elif task == "translation_en_fr":
-                 out = client.translation(t, src_lang="en", tgt_lang="fr", **gen_params)
-                 outputs.append(out)
-             elif task == "sentiment":
-                 out = client.text_classification(t)
-                 outputs.append(out)
-             elif task == "zero-shot-classification":
-                 if not zs_labels:
-                     outputs.append({"label": None, "score": None})
-                 else:
-                     out = client.zero_shot_classification(t, labels=zs_labels)
-                     outputs.append(out)
              else:
-                 outputs.append(None)
          except Exception as e:
-             outputs.append({"error": str(e)})
-     latency = (time.perf_counter() - t0) / max(1, len(texts))
-     return outputs, latency

- def compute_metrics(task: str, preds: List[Any], refs: Optional[List[Any]], label_names: Optional[List[str]] = None) -> Dict[str, float]:
      metrics = {}
-     if task in ["sentiment", "zero-shot-classification"] and refs is not None:
-         if label_names and isinstance(refs[0], (int, float)):
-             refs = [label_names[int(x)] for x in refs]
-         y_pred = []
-         for p in preds:
-             if isinstance(p, list) and len(p) and isinstance(p[0], dict) and "label" in p[0]:
-                 y_pred.append(normalize_cls_label(p[0]["label"], label_names))
-             elif isinstance(p, dict) and "label" in p:
-                 y_pred.append(normalize_cls_label(p.get("label"), label_names))
              else:
-                 y_pred.append(None)
-         y_true = [str(x) for x in refs]
-         metrics["accuracy"] = _accuracy_score(y_pred, y_true)
-         metrics["f1_macro"] = _f1_macro_score(y_pred, y_true)
-     elif task == "summarization" and refs is not None:
-         preds_text = []
-         for p in preds:
-             if isinstance(p, dict) and "summary_text" in p:
-                 preds_text.append(p["summary_text"])
-             elif isinstance(p, list) and len(p) and isinstance(p[0], dict) and "summary_text" in p[0]:
-                 preds_text.append(p[0]["summary_text"])
-             elif isinstance(p, str):
-                 preds_text.append(p)
              else:
-                 preds_text.append("")
-         metrics.update(ROUGE.compute(predictions=preds_text, references=refs))
-     elif task == "translation_en_fr" and refs is not None:
-         preds_text = []
-         for p in preds:
-             if isinstance(p, dict) and "translation_text" in p:
-                 preds_text.append(p["translation_text"])
-             elif isinstance(p, list) and len(p) and isinstance(p[0], dict) and "translation_text" in p[0]:
-                 preds_text.append(p[0]["translation_text"])
-             elif isinstance(p, str):
-                 preds_text.append(p)
              else:
-                 preds_text.append("")
-         metrics.update(SACREBLEU.compute(predictions=preds_text, references=[[r] for r in refs]))
      return metrics

- def lint_model(model_id: str, token: Optional[str]) -> Dict[str, Any]:
-     out = {"model": model_id, "readiness": 0, "checks": []}
      try:
-         card = ModelCard.load(model_id, token=token)
-         data = card.data.to_dict() if hasattr(card, 'data') else {}
-         pipeline_tag = data.get("pipeline_tag")
-         license_ = data.get("license")
-         has_tags = bool(data.get("tags"))
-         score = 0
-         score += 25 if pipeline_tag else 0
-         score += 25 if license_ else 0
-         score += 25 if has_tags else 0
-         score += 25
-         out["readiness"] = score
-         out["checks"].append({"pipeline_tag": pipeline_tag, "license": license_, "has_tags": has_tags})
-     except Exception as e:
-         out["checks"].append({"error": str(e)})
-     return out
-
- def run_benchmark(hf_token: str, compute_mode: str, task: str, curated_models: List[str], custom_model: str,
-                   ds_source: str, ds_id: str, ds_config: str, ds_split: str, csv_file, text_col: str,
-                   label_col: str, ref_col: str, sample_size: int, zs_labels_csv: str,
-                   max_new_tokens: int, temperature: float, batch_size: int, timeout_s: int):
-     models = []
-     if curated_models:
-         models.extend(curated_models)
-     if custom_model and custom_model.strip():
-         models.append(custom_model.strip())
-     models = list(dict.fromkeys(models))
-     if not models:
-         return {"error": "Pick at least one model"}
-
-     if ds_source == "hub":
-         df, meta = load_hub_dataset(ds_id, ds_config or None, ds_split or None, sample_size)
-     else:
-         if csv_file is None:
-             return {"error": "Upload a CSV"}
-         df = pd.read_csv(csv_file.name)
-         if sample_size and sample_size < len(df):
-             df = df.sample(n=sample_size, random_state=42)
-         meta = {"dataset_id": "uploaded_csv", "config": None, "split": None}
-
-     if text_col not in df.columns:
-         text_col = text_col or df.columns[0]
-     labels = df[label_col].tolist() if label_col and label_col in df.columns else None
-     refs = df[ref_col].tolist() if ref_col and ref_col in df.columns else None
-
-     zs_labels = [x.strip() for x in zs_labels_csv.split(',')] if (task == "zero-shot-classification" and zs_labels_csv) else None
-
-     all_preds = {}
-     metrics_table = []
-
-     for mid in models:
-         preds, avg_lat = run_remote_inference(
-             task=task,
-             model_id=mid,
-             token=hf_token,
-             texts=df[text_col].astype(str).tolist(),
-             zs_labels=zs_labels,
-             gen_params={"max_new_tokens": int(max_new_tokens), "temperature": float(temperature)},
-             timeout_s=int(timeout_s)
-         )
-         all_preds[mid] = preds
-         m = compute_metrics(task, preds, refs if task in ["summarization", "translation_en_fr"] else labels, label_names=None)
-         m["avg_latency_s"] = avg_lat
-         metrics_table.append({"model": mid, **m})
-
-     preview = pd.DataFrame({"text": df[text_col].astype(str).tolist()})
-     if labels is not None:
-         preview["label"] = labels
-     if refs is not None:
-         preview["reference"] = refs
-     for mid, preds in all_preds.items():
-         col = []
-         for p in preds:
-             if isinstance(p, dict):
-                 col.append(p.get("summary_text") or p.get("translation_text") or p.get("label") or str(p))
-             elif isinstance(p, list) and len(p) and isinstance(p[0], dict):
-                 col.append(p[0].get("summary_text") or p[0].get("translation_text") or p[0].get("label") or str(p[0]))
-             else:
-                 col.append(str(p))
-         preview[mid] = col
-
-     csv_buf = io.StringIO()
-     preview.to_csv(csv_buf, index=False)
-     csv_bytes = io.BytesIO(csv_buf.getvalue().encode("utf-8"))
-
-     lints = [lint_model(m, hf_token) for m in models]
-
-     return {
-         "metrics": pd.DataFrame(metrics_table),
-         "preview": preview.head(20),
-         "download": ("predictions.csv", csv_bytes),
-         "lint": lints,
-         "session": {"task": task, "models": models, "dataset": meta, "sample_size": sample_size}
-     }
-
- def build_ui():
-     # Use Interface instead of Blocks to avoid the JSON schema parsing bug
-     def benchmark_interface(hf_token, task, curated_models_text, custom_model, ds_id, ds_config, ds_split,
-                             text_col, label_col, ref_col, sample_size, zs_labels_csv, max_new_tokens,
-                             temperature, timeout_s, csv_file=None):

-         # Parse curated models from text input
-         curated_models = [m.strip() for m in curated_models_text.split('\n') if m.strip()] if curated_models_text else []

-         # Validate token
-         if not hf_token or not hf_token.strip().startswith("hf_"):
-             return "Error: Please provide a valid HF token", "", "", ""

-         try:
-             out = run_benchmark(
-                 hf_token=hf_token,
-                 compute_mode="Remote (Inference API)",
-                 task=task,
-                 curated_models=curated_models,
-                 custom_model=custom_model,
-                 ds_source="hub" if csv_file is None else "csv",
-                 ds_id=ds_id,
-                 ds_config=ds_config,
-                 ds_split=ds_split,
-                 csv_file=csv_file,
-                 text_col=text_col,
-                 label_col=label_col,
-                 ref_col=ref_col,
-                 sample_size=int(sample_size),
-                 zs_labels_csv=zs_labels_csv,
-                 max_new_tokens=int(max_new_tokens),
-                 temperature=float(temperature),
-                 batch_size=8,
-                 timeout_s=int(timeout_s)
-             )

-             if isinstance(out, dict) and "error" in out:
-                 return f"Error: {out['error']}", "", "", ""

-             # Format outputs as strings
-             metrics_str = out["metrics"].to_string() if not out["metrics"].empty else "No metrics computed"
-             preview_str = out["preview"].to_string() if not out["preview"].empty else "No preview available"
-             lint_str = json.dumps(out["lint"], indent=2)
-             session_str = json.dumps(out["session"], indent=2)

-             return metrics_str, preview_str, lint_str, session_str

-         except Exception as e:
-             return f"Error: {str(e)}", "", "", ""
-
-     # Create Interface instead of Blocks
-     demo = gr.Interface(
-         fn=benchmark_interface,
-         inputs=[
-             gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_..."),
-             gr.Dropdown(choices=list(TASKS.keys()), label="Task", value="sentiment"),
-             gr.Textbox(label="Curated Models (one per line)", lines=3,
-                        placeholder="distilbert-base-uncased-finetuned-sst-2-english\ncardiffnlp/twitter-roberta-base-sentiment-latest"),
-             gr.Textbox(label="Custom Model ID (optional)", placeholder="username/my-finetune"),
-             gr.Textbox(label="Dataset ID", value="imdb"),
-             gr.Textbox(label="Config (optional)"),
-             gr.Textbox(label="Split (optional)"),
-             gr.Textbox(label="Text Column", value="text"),
-             gr.Textbox(label="Label Column (optional)"),
-             gr.Textbox(label="Reference Column (optional)"),
-             gr.Slider(20, 500, value=100, step=10, label="Sample Size"),
-             gr.Textbox(label="Zero-shot Labels (comma-separated)"),
-             gr.Number(value=128, label="Max New Tokens"),
-             gr.Number(value=0.7, label="Temperature"),
-             gr.Number(value=20, label="Timeout (seconds)"),
-             gr.File(file_types=[".csv"], label="CSV File (optional)")
-         ],
-         outputs=[
-             gr.Textbox(label="Metrics", lines=10),
-             gr.Textbox(label="Preview (first 20 rows)", lines=15),
-             gr.Textbox(label="Model Lint Results", lines=8),
-             gr.Textbox(label="Session Info", lines=5)
-         ],
-         title="AI Model Benchmark Hub",
-         description="Compare AI models on various tasks using the Hugging Face Inference API"
-     )

      return demo

- app = build_ui()
-
- # Simple launch configuration for HF Spaces
  if __name__ == "__main__":
-     # Check if we're on HF Spaces
-     if "SPACE_ID" in os.environ:
-         app.launch()
-     else:
-         # Local development
-         try:
-             app.launch(share=True)
-         except:
-             app.launch(server_name="127.0.0.1", server_port=7860)

+ import os
+ import time
+ import json
  import gradio as gr
  import pandas as pd
+ from typing import List, Optional, Dict, Any
  from datasets import load_dataset
+ from huggingface_hub import InferenceClient
  import evaluate

+ # Load evaluation metrics
+ rouge = evaluate.load("rouge")
+ sacrebleu = evaluate.load("sacrebleu")

+ # Model configurations
+ MODELS = {
+     "sentiment": [
+         "distilbert-base-uncased-finetuned-sst-2-english",
+         "cardiffnlp/twitter-roberta-base-sentiment-latest"
+     ],
+     "summarization": [
+         "facebook/bart-large-cnn",
+         "google/pegasus-xsum"
+     ],
+     "translation": [
+         "Helsinki-NLP/opus-mt-en-fr",
+         "facebook/m2m100_418M"
+     ]
  }

+ def validate_token(token: str) -> bool:
+     """Validate HF token format"""
+     return token and token.strip().startswith("hf_")

+ def accuracy_score(predictions: List[str], labels: List[str]) -> float:
+     """Calculate accuracy without sklearn"""
+     if len(predictions) != len(labels):
+         return 0.0
+     correct = sum(1 for p, l in zip(predictions, labels) if str(p).lower() == str(l).lower())
+     return correct / len(labels) if labels else 0.0

+ def run_inference(model_id: str, texts: List[str], task: str, token: str) -> List[Dict]:
+     """Run inference using HF Inference API"""
      client = InferenceClient(model=model_id, token=token)
+     results = []
+
+     for text in texts:
          try:
+             if task == "sentiment":
+                 result = client.text_classification(text)
+                 results.append(result[0] if isinstance(result, list) else result)
+             elif task == "summarization":
+                 result = client.summarization(text, max_length=150)
+                 results.append(result)
+             elif task == "translation":
+                 result = client.translation(text, src_lang="en", tgt_lang="fr")
+                 results.append(result)
              else:
+                 results.append({"error": "Unsupported task"})
          except Exception as e:
+             results.append({"error": str(e)})
+
+     return results

+ def compute_metrics(task: str, predictions: List[Dict], references: Optional[List[str]] = None) -> Dict[str, float]:
+     """Compute task-specific metrics"""
      metrics = {}
+
+     if task == "sentiment" and references:
+         pred_labels = []
+         for pred in predictions:
+             if isinstance(pred, dict) and "label" in pred:
+                 pred_labels.append(pred["label"])
              else:
+                 pred_labels.append("UNKNOWN")
+
+         metrics["accuracy"] = accuracy_score(pred_labels, references)
+
+     elif task == "summarization" and references:
+         pred_texts = []
+         for pred in predictions:
+             if isinstance(pred, dict) and "summary_text" in pred:
+                 pred_texts.append(pred["summary_text"])
              else:
+                 pred_texts.append("")
+
+         rouge_scores = rouge.compute(predictions=pred_texts, references=references)
+         metrics.update(rouge_scores)
+
+     elif task == "translation" and references:
+         pred_texts = []
+         for pred in predictions:
+             if isinstance(pred, dict) and "translation_text" in pred:
+                 pred_texts.append(pred["translation_text"])
              else:
+                 pred_texts.append("")
+
+         bleu_scores = sacrebleu.compute(predictions=pred_texts, references=[[ref] for ref in references])
+         metrics.update(bleu_scores)
+
      return metrics

+ def benchmark_models(
+     hf_token: str,
+     task: str,
+     selected_models: List[str],
+     dataset_name: str,
+     sample_size: int,
+     text_column: str,
+     label_column: str
+ ):
+     """Main benchmarking function"""
+
+     # Validate token
+     if not validate_token(hf_token):
+         return "❌ Invalid HuggingFace token. Please provide a token starting with 'hf_'", "", ""
+
+     if not selected_models:
+         return "❌ Please select at least one model", "", ""
+
      try:
+         # Load dataset
+         dataset = load_dataset(dataset_name, split="test")
+         if sample_size > 0:
+             dataset = dataset.select(range(min(sample_size, len(dataset))))

+         df = dataset.to_pandas()

+         if text_column not in df.columns:
+             return f"❌ Text column '{text_column}' not found in dataset", "", ""

+         texts = df[text_column].astype(str).tolist()
+         references = df[label_column].tolist() if label_column in df.columns else None
+
+         # Results storage
+         all_results = []
+         detailed_results = {"text": texts}
+
+         # Run benchmarks
+         for model_id in selected_models:
+             print(f"Running inference with {model_id}...")

+             start_time = time.time()
+             predictions = run_inference(model_id, texts, task, hf_token)
+             inference_time = time.time() - start_time

+             # Compute metrics
+             metrics = compute_metrics(task, predictions, references)
+             metrics["model"] = model_id
+             metrics["inference_time"] = round(inference_time, 2)
+             metrics["samples"] = len(texts)

+             all_results.append(metrics)

+             # Store predictions for detailed view
+             pred_texts = []
+             for pred in predictions:
+                 if isinstance(pred, dict):
+                     if "label" in pred:
+                         pred_texts.append(pred["label"])
+                     elif "summary_text" in pred:
+                         pred_texts.append(pred["summary_text"])
+                     elif "translation_text" in pred:
+                         pred_texts.append(pred["translation_text"])
+                     else:
+                         pred_texts.append(str(pred))
+                 else:
+                     pred_texts.append(str(pred))
+
+             detailed_results[model_id] = pred_texts
+
+         # Create results DataFrames
+         results_df = pd.DataFrame(all_results)
+         detailed_df = pd.DataFrame(detailed_results)
+
+         # Format results for display
+         results_str = "📊 **Benchmark Results:**\n\n"
+         results_str += results_df.to_string(index=False)
+
+         detailed_str = "🔍 **Detailed Predictions (first 10 samples):**\n\n"
+         detailed_str += detailed_df.head(10).to_string(index=False)
+
+         # Create summary
+         summary = f"✅ **Benchmark Complete!**\n\n"
+         summary += f"**Task:** {task}\n"
+         summary += f"**Dataset:** {dataset_name}\n"
+         summary += f"**Models tested:** {len(selected_models)}\n"
+         summary += f"**Samples processed:** {len(texts)}\n"
+         summary += f"**Total time:** {sum(r['inference_time'] for r in all_results):.2f}s\n"
+
+         return summary, results_str, detailed_str
+
+     except Exception as e:
+         return f"❌ Error: {str(e)}", "", ""
+
+ # Create Gradio interface
197
+ def create_interface():
198
+ with gr.Blocks(title="AI Model Benchmark Hub") as demo:
199
+ gr.Markdown("# πŸ§ͺ AI Model Benchmark Hub")
200
+ gr.Markdown("Compare AI models on various tasks using HuggingFace Inference API")
201
+
202
+ with gr.Row():
203
+ with gr.Column():
204
+ gr.Markdown("### πŸ”‘ Authentication")
205
+ hf_token = gr.Textbox(
206
+ label="HuggingFace Token",
207
+ type="password",
208
+ placeholder="hf_...",
209
+ info="Get your token from https://huggingface.co/settings/tokens"
210
+ )
211
+
212
+ gr.Markdown("### πŸ“‹ Task Selection")
213
+ task = gr.Dropdown(
214
+ choices=["sentiment", "summarization", "translation"],
215
+ label="Task",
216
+ value="sentiment"
217
+ )
218
+
219
+ model_choices = gr.CheckboxGroup(
220
+ choices=MODELS["sentiment"],
221
+ label="Select Models",
222
+ value=[MODELS["sentiment"][0]]
223
+ )
224
+
225
+ def update_models(selected_task):
226
+ return gr.update(choices=MODELS[selected_task], value=[MODELS[selected_task][0]])
227
+
228
+ task.change(update_models, inputs=[task], outputs=[model_choices])
229
+
230
+ with gr.Column():
231
+ gr.Markdown("### πŸ“Š Dataset Configuration")
232
+ dataset_name = gr.Textbox(
233
+ label="Dataset Name",
234
+ value="imdb",
235
+ placeholder="e.g., imdb, amazon_reviews_multi"
236
+ )
237
+
238
+ sample_size = gr.Slider(
239
+ minimum=10,
240
+ maximum=1000,
241
+ value=50,
242
+ step=10,
243
+ label="Sample Size"
244
+ )
245
+
246
+ text_column = gr.Textbox(
247
+ label="Text Column Name",
248
+ value="text",
249
+ placeholder="e.g., text, review, sentence"
250
+ )
251
+
252
+ label_column = gr.Textbox(
253
+ label="Label Column Name (optional)",
254
+ value="label",
255
+ placeholder="e.g., label, sentiment, rating"
256
+ )
257
+
258
+ run_btn = gr.Button("πŸš€ Run Benchmark", variant="primary", size="lg")
259
+
260
+ gr.Markdown("---")
261
+
262
+ with gr.Row():
263
+ with gr.Column():
264
+ summary_output = gr.Markdown(label="Summary")
265
+
266
+ with gr.Row():
267
+ with gr.Column():
268
+ results_output = gr.Markdown(label="Results")
269
+ with gr.Column():
270
+ detailed_output = gr.Markdown(label="Detailed Output")
271
+
272
+ # Connect the interface
273
+ run_btn.click(
274
+ benchmark_models,
275
+ inputs=[
276
+ hf_token,
277
+ task,
278
+ model_choices,
279
+ dataset_name,
280
+ sample_size,
281
+ text_column,
282
+ label_column
283
+ ],
284
+ outputs=[summary_output, results_output, detailed_output]
285
+ )
286
 
287
  return demo
288
 
289
+ # Launch the app
 
 
290
  if __name__ == "__main__":
291
+ app = create_interface()
292
+ app.launch()
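
For reviewers who want to sanity-check the refactored helpers without a token or any call to the Inference API, a minimal local sketch might look like the following. It is not part of this commit; it assumes the new app.py is importable from the working directory and that the evaluate metrics it loads at import time are available locally.

# Hypothetical smoke test (not part of the commit): exercises the new helpers
# in app.py with hand-made predictions instead of real Inference API output.
from app import accuracy_score, compute_metrics

fake_preds = [{"label": "POSITIVE", "score": 0.99},
              {"label": "NEGATIVE", "score": 0.97}]
refs = ["positive", "negative"]

# accuracy_score compares labels case-insensitively, so this prints 1.0
print(accuracy_score([p["label"] for p in fake_preds], refs))

# compute_metrics routes the "sentiment" task through accuracy_score
print(compute_metrics("sentiment", fake_preds, refs))  # {'accuracy': 1.0}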