Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,388 +1,292 @@
- import os
-
-
import gradio as gr
import pandas as pd
-
from datasets import load_dataset
import evaluate

-
-
-
- ROUGE = evaluate.load("rouge")
- SACREBLEU = evaluate.load("sacrebleu")
-
- # ---------- Small helpers: accuracy & F1 (macro) without scikit-learn ----------
- def _accuracy_score(y_pred: List[str], y_true: List[str]) -> float:
-     paired = [(p, t) for p, t in zip(y_pred, y_true) if p is not None]
-     if not paired:
-         return 0.0
-     correct = sum(1 for p, t in paired if str(p) == str(t))
-     return correct / len(paired)
-
- def _f1_macro_score(y_pred: List[str], y_true: List[str]) -> float:
-     paired = [(p, t) for p, t in zip(y_pred, y_true) if p is not None]
-     if not paired:
-         return 0.0
-     yp, yt = zip(*paired)
-     labels = sorted(set(yt))
-     def _f1_for(label: str) -> float:
-         tp = sum(1 for p, t in zip(yp, yt) if p == label and t == label)
-         fp = sum(1 for p, t in zip(yp, yt) if p == label and t != label)
-         fn = sum(1 for p, t in zip(yp, yt) if p != label and t == label)
-         if tp == 0 and (fp == 0 or fn == 0):
-             return 0.0
-         prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
-         rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
-         return (2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else 0.0
-     scores = [ _f1_for(lbl) for lbl in labels ]
-     return sum(scores) / len(scores) if scores else 0.0
- # -----------------------------------------------------------------------------
-
- TASKS: Dict[str, Dict[str, str]] = {
-     "sentiment": {
-         "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT SST-2",
-         "cardiffnlp/twitter-roberta-base-sentiment-latest": "RoBERTa Twitter Sentiment"
-     },
-     "zero-shot-classification": {
-         "facebook/bart-large-mnli": "BART MNLI",
-         "joeddav/xlm-roberta-large-xnli": "XLM-R XNLI"
-     },
-     "summarization": {
-         "facebook/bart-large-cnn": "BART CNN",
-         "google/pegasus-xsum": "Pegasus XSum"
-     },
-     "translation_en_fr": {
-         "Helsinki-NLP/opus-mt-en-fr": "Opus-MT EN to FR",
-         "facebook/m2m100_418M": "M2M100 418M"
-     }
- }

-
-
-
-
-
}

- def validate_token(
-
-
-     return True, "Token format OK. We'll use it only for this session."
-
- def load_hub_dataset(ds_id: str, config: Optional[str], split: Optional[str], sample_size: int) -> Tuple[pd.DataFrame, Dict[str, Any]]:
-     kwargs = {}
-     if config:
-         kwargs["name"] = config
-     if split:
-         kwargs["split"] = split
-     ds = load_dataset(ds_id, **kwargs)
-     if not split:
-         for sp in ["test", "validation", "train"]:
-             if sp in ds:
-                 split = sp
-                 break
-     d = ds[split].to_pandas()
-     if sample_size and sample_size < len(d):
-         d = d.sample(n=sample_size, random_state=42)
-     meta = {"dataset_id": ds_id, "config": config, "split": split}
-     return d, meta

- def
-
-
-
-
-
-     for m in policy.get("metrics", []):
-         if m.endswith("_opt"):
-             continue
-         if m.endswith("_if_labels") and "label" not in mapped_cols:
-             continue
-         out.append(m)
-     return out
-
- def normalize_cls_label(pred_label: str, label_names: Optional[List[str]]):
-     if label_names is None:
-         return pred_label
-     low = str(pred_label).lower()
-     for name in label_names:
-         if low == str(name).lower():
-             return name
-     if low.startswith("pos"):
-         for name in label_names:
-             if "pos" in str(name).lower():
-                 return name
-     if low.startswith("neg"):
-         for name in label_names:
-             if "neg" in str(name).lower():
-                 return name
-     return pred_label

- def
-
    client = InferenceClient(model=model_id, token=token)
-
-
-
-     for t in texts:
        try:
-             if task == "
-
-
-             elif task == "
-
-
-             elif task == "
-
-
-             elif task == "zero-shot-classification":
-                 if not zs_labels:
-                     outputs.append({"label": None, "score": None})
-                 else:
-                     out = client.zero_shot_classification(t, labels=zs_labels)
-                     outputs.append(out)
            else:
-
        except Exception as e:
-
-
-     return

- def compute_metrics(task: str,
    metrics = {}
-
-
-
-
-
-
-                 y_pred.append(normalize_cls_label(p[0]["label"], label_names))
-             elif isinstance(p, dict) and "label" in p:
-                 y_pred.append(normalize_cls_label(p.get("label"), label_names))
            else:
-
-
-         metrics["accuracy"] =
-
-     elif task == "summarization" and
-
-         for
-             if isinstance(
-
-             elif isinstance(p, list) and len(p) and isinstance(p[0], dict) and "summary_text" in p[0]:
-                 preds_text.append(p[0]["summary_text"])
-             elif isinstance(p, str):
-                 preds_text.append(p)
            else:
-
-
-
-
-
-
-
-
-
-                 preds_text.append(p)
            else:
-
-
    return metrics

- def
-
    try:
-
-
-
-
-         has_tags = bool(data.get("tags"))
-         score = 0
-         score += 25 if pipeline_tag else 0
-         score += 25 if license_ else 0
-         score += 25 if has_tags else 0
-         score += 25
-         out["readiness"] = score
-         out["checks"].append({"pipeline_tag": pipeline_tag, "license": license_, "has_tags": has_tags})
-     except Exception as e:
-         out["checks"].append({"error": str(e)})
-     return out
-
- def run_benchmark(hf_token: str, compute_mode: str, task: str, curated_models: List[str], custom_model: str,
-                   ds_source: str, ds_id: str, ds_config: str, ds_split: str, csv_file, text_col: str,
-                   label_col: str, ref_col: str, sample_size: int, zs_labels_csv: str,
-                   max_new_tokens: int, temperature: float, batch_size: int, timeout_s: int):
-     models = []
-     if curated_models:
-         models.extend(curated_models)
-     if custom_model and custom_model.strip():
-         models.append(custom_model.strip())
-     models = list(dict.fromkeys(models))
-     if not models:
-         return {"error": "Pick at least one model"}
-
-     if ds_source == "hub":
-         df, meta = load_hub_dataset(ds_id, ds_config or None, ds_split or None, sample_size)
-     else:
-         if csv_file is None:
-             return {"error": "Upload a CSV"}
-         df = pd.read_csv(csv_file.name)
-         if sample_size and sample_size < len(df):
-             df = df.sample(n=sample_size, random_state=42)
-         meta = {"dataset_id": "uploaded_csv", "config": None, "split": None}
-
-     if text_col not in df.columns:
-         text_col = text_col or df.columns[0]
-     labels = df[label_col].tolist() if label_col and label_col in df.columns else None
-     refs = df[ref_col].tolist() if ref_col and ref_col in df.columns else None
-
-     zs_labels = [x.strip() for x in zs_labels_csv.split(',')] if (task == "zero-shot-classification" and zs_labels_csv) else None
-
-     all_preds = {}
-     metrics_table = []
-
-     for mid in models:
-         preds, avg_lat = run_remote_inference(
-             task=task,
-             model_id=mid,
-             token=hf_token,
-             texts=df[text_col].astype(str).tolist(),
-             zs_labels=zs_labels,
-             gen_params={"max_new_tokens": int(max_new_tokens), "temperature": float(temperature)},
-             timeout_s=int(timeout_s)
-         )
-         all_preds[mid] = preds
-         m = compute_metrics(task, preds, refs if task in ["summarization", "translation_en_fr"] else labels, label_names=None)
-         m["avg_latency_s"] = avg_lat
-         metrics_table.append({"model": mid, **m})
-
-     preview = pd.DataFrame({"text": df[text_col].astype(str).tolist()})
-     if labels is not None:
-         preview["label"] = labels
-     if refs is not None:
-         preview["reference"] = refs
-     for mid, preds in all_preds.items():
-         col = []
-         for p in preds:
-             if isinstance(p, dict):
-                 col.append(p.get("summary_text") or p.get("translation_text") or p.get("label") or str(p))
-             elif isinstance(p, list) and len(p) and isinstance(p[0], dict):
-                 col.append(p[0].get("summary_text") or p[0].get("translation_text") or p[0].get("label") or str(p[0]))
-             else:
-                 col.append(str(p))
-         preview[mid] = col
-
-     csv_buf = io.StringIO()
-     preview.to_csv(csv_buf, index=False)
-     csv_bytes = io.BytesIO(csv_buf.getvalue().encode("utf-8"))
-
-     lints = [lint_model(m, hf_token) for m in models]
-
-     return {
-         "metrics": pd.DataFrame(metrics_table),
-         "preview": preview.head(20),
-         "download": ("predictions.csv", csv_bytes),
-         "lint": lints,
-         "session": {"task": task, "models": models, "dataset": meta, "sample_size": sample_size}
-     }
-
- def build_ui():
-     # Use Interface instead of Blocks to avoid the JSON schema parsing bug
-     def benchmark_interface(hf_token, task, curated_models_text, custom_model, ds_id, ds_config, ds_split,
-                             text_col, label_col, ref_col, sample_size, zs_labels_csv, max_new_tokens,
-                             temperature, timeout_s, csv_file=None):

-
-         curated_models = [m.strip() for m in curated_models_text.split('\n') if m.strip()] if curated_models_text else []

-
-
-             return "Error: Please provide a valid HF token", "", "", ""

-
-
-
-
-
-
-
-
-
-
-             ds_split=ds_split,
-             csv_file=csv_file,
-             text_col=text_col,
-             label_col=label_col,
-             ref_col=ref_col,
-             sample_size=int(sample_size),
-             zs_labels_csv=zs_labels_csv,
-             max_new_tokens=int(max_new_tokens),
-             temperature=float(temperature),
-             batch_size=8,
-             timeout_s=int(timeout_s)
-         )

-
-

-         #
-
-
-
-

-

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

    return demo

-
-
- # Simple launch configuration for HF Spaces
if __name__ == "__main__":
-
-
-     app.launch()
- else:
-     # Local development
-     try:
-         app.launch(share=True)
-     except:
-         app.launch(server_name="127.0.0.1", server_port=7860)
+ import os
+ import time
+ import json
import gradio as gr
import pandas as pd
+ from typing import List, Optional, Dict, Any
from datasets import load_dataset
+ from huggingface_hub import InferenceClient
import evaluate

+ # Load evaluation metrics
+ rouge = evaluate.load("rouge")
+ sacrebleu = evaluate.load("sacrebleu")

+ # Model configurations
+ MODELS = {
+     "sentiment": [
+         "distilbert-base-uncased-finetuned-sst-2-english",
+         "cardiffnlp/twitter-roberta-base-sentiment-latest"
+     ],
+     "summarization": [
+         "facebook/bart-large-cnn",
+         "google/pegasus-xsum"
+     ],
+     "translation": [
+         "Helsinki-NLP/opus-mt-en-fr",
+         "facebook/m2m100_418M"
+     ]
}

+ def validate_token(token: str) -> bool:
+     """Validate HF token format"""
+     return token and token.strip().startswith("hf_")

+ def accuracy_score(predictions: List[str], labels: List[str]) -> float:
+     """Calculate accuracy without sklearn"""
+     if len(predictions) != len(labels):
+         return 0.0
+     correct = sum(1 for p, l in zip(predictions, labels) if str(p).lower() == str(l).lower())
+     return correct / len(labels) if labels else 0.0

+ def run_inference(model_id: str, texts: List[str], task: str, token: str) -> List[Dict]:
+     """Run inference using HF Inference API"""
    client = InferenceClient(model=model_id, token=token)
+     results = []
+
+     for text in texts:
        try:
+             if task == "sentiment":
+                 result = client.text_classification(text)
+                 results.append(result[0] if isinstance(result, list) else result)
+             elif task == "summarization":
+                 result = client.summarization(text, max_length=150)
+                 results.append(result)
+             elif task == "translation":
+                 result = client.translation(text, src_lang="en", tgt_lang="fr")
+                 results.append(result)
            else:
+                 results.append({"error": "Unsupported task"})
        except Exception as e:
+             results.append({"error": str(e)})
+
+     return results

+ def compute_metrics(task: str, predictions: List[Dict], references: Optional[List[str]] = None) -> Dict[str, float]:
+     """Compute task-specific metrics"""
    metrics = {}
+
+     if task == "sentiment" and references:
+         pred_labels = []
+         for pred in predictions:
+             if isinstance(pred, dict) and "label" in pred:
+                 pred_labels.append(pred["label"])
            else:
+                 pred_labels.append("UNKNOWN")
+
+         metrics["accuracy"] = accuracy_score(pred_labels, references)
+
+     elif task == "summarization" and references:
+         pred_texts = []
+         for pred in predictions:
+             if isinstance(pred, dict) and "summary_text" in pred:
+                 pred_texts.append(pred["summary_text"])
            else:
+                 pred_texts.append("")
+
+         rouge_scores = rouge.compute(predictions=pred_texts, references=references)
+         metrics.update(rouge_scores)
+
+     elif task == "translation" and references:
+         pred_texts = []
+         for pred in predictions:
+             if isinstance(pred, dict) and "translation_text" in pred:
+                 pred_texts.append(pred["translation_text"])
            else:
+                 pred_texts.append("")
+
+         bleu_scores = sacrebleu.compute(predictions=pred_texts, references=[[ref] for ref in references])
+         metrics.update(bleu_scores)
+
    return metrics

+ def benchmark_models(
+     hf_token: str,
+     task: str,
+     selected_models: List[str],
+     dataset_name: str,
+     sample_size: int,
+     text_column: str,
+     label_column: str
+ ):
+     """Main benchmarking function"""
+
+     # Validate token
+     if not validate_token(hf_token):
+         return "❌ Invalid HuggingFace token. Please provide a token starting with 'hf_'", "", ""
+
+     if not selected_models:
+         return "❌ Please select at least one model", "", ""
+
    try:
+         # Load dataset
+         dataset = load_dataset(dataset_name, split="test")
+         if sample_size > 0:
+             dataset = dataset.select(range(min(sample_size, len(dataset))))

+         df = dataset.to_pandas()

+         if text_column not in df.columns:
+             return f"❌ Text column '{text_column}' not found in dataset", "", ""

+         texts = df[text_column].astype(str).tolist()
+         references = df[label_column].tolist() if label_column in df.columns else None
+
+         # Results storage
+         all_results = []
+         detailed_results = {"text": texts}
+
+         # Run benchmarks
+         for model_id in selected_models:
+             print(f"Running inference with {model_id}...")

+             start_time = time.time()
+             predictions = run_inference(model_id, texts, task, hf_token)
+             inference_time = time.time() - start_time

+             # Compute metrics
+             metrics = compute_metrics(task, predictions, references)
+             metrics["model"] = model_id
+             metrics["inference_time"] = round(inference_time, 2)
+             metrics["samples"] = len(texts)

+             all_results.append(metrics)

+             # Store predictions for detailed view
+             pred_texts = []
+             for pred in predictions:
+                 if isinstance(pred, dict):
+                     if "label" in pred:
+                         pred_texts.append(pred["label"])
+                     elif "summary_text" in pred:
+                         pred_texts.append(pred["summary_text"])
+                     elif "translation_text" in pred:
+                         pred_texts.append(pred["translation_text"])
+                     else:
+                         pred_texts.append(str(pred))
+                 else:
+                     pred_texts.append(str(pred))
+
+             detailed_results[model_id] = pred_texts
+
+         # Create results DataFrames
+         results_df = pd.DataFrame(all_results)
+         detailed_df = pd.DataFrame(detailed_results)
+
+         # Format results for display
+         results_str = "**Benchmark Results:**\n\n"
+         results_str += results_df.to_string(index=False)
+
+         detailed_str = "**Detailed Predictions (first 10 samples):**\n\n"
+         detailed_str += detailed_df.head(10).to_string(index=False)
+
+         # Create summary
+         summary = f"✅ **Benchmark Complete!**\n\n"
+         summary += f"**Task:** {task}\n"
+         summary += f"**Dataset:** {dataset_name}\n"
+         summary += f"**Models tested:** {len(selected_models)}\n"
+         summary += f"**Samples processed:** {len(texts)}\n"
+         summary += f"**Total time:** {sum(r['inference_time'] for r in all_results):.2f}s\n"
+
+         return summary, results_str, detailed_str
+
+     except Exception as e:
+         return f"❌ Error: {str(e)}", "", ""
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="AI Model Benchmark Hub") as demo:
+         gr.Markdown("# 🧪 AI Model Benchmark Hub")
+         gr.Markdown("Compare AI models on various tasks using HuggingFace Inference API")
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("### Authentication")
+                 hf_token = gr.Textbox(
+                     label="HuggingFace Token",
+                     type="password",
+                     placeholder="hf_...",
+                     info="Get your token from https://huggingface.co/settings/tokens"
+                 )
+
+                 gr.Markdown("### Task Selection")
+                 task = gr.Dropdown(
+                     choices=["sentiment", "summarization", "translation"],
+                     label="Task",
+                     value="sentiment"
+                 )
+
+                 model_choices = gr.CheckboxGroup(
+                     choices=MODELS["sentiment"],
+                     label="Select Models",
+                     value=[MODELS["sentiment"][0]]
+                 )
+
+                 def update_models(selected_task):
+                     return gr.update(choices=MODELS[selected_task], value=[MODELS[selected_task][0]])
+
+                 task.change(update_models, inputs=[task], outputs=[model_choices])
+
+             with gr.Column():
+                 gr.Markdown("### Dataset Configuration")
+                 dataset_name = gr.Textbox(
+                     label="Dataset Name",
+                     value="imdb",
+                     placeholder="e.g., imdb, amazon_reviews_multi"
+                 )
+
+                 sample_size = gr.Slider(
+                     minimum=10,
+                     maximum=1000,
+                     value=50,
+                     step=10,
+                     label="Sample Size"
+                 )
+
+                 text_column = gr.Textbox(
+                     label="Text Column Name",
+                     value="text",
+                     placeholder="e.g., text, review, sentence"
+                 )
+
+                 label_column = gr.Textbox(
+                     label="Label Column Name (optional)",
+                     value="label",
+                     placeholder="e.g., label, sentiment, rating"
+                 )
+
+         run_btn = gr.Button("Run Benchmark", variant="primary", size="lg")
+
+         gr.Markdown("---")
+
+         with gr.Row():
+             with gr.Column():
+                 summary_output = gr.Markdown(label="Summary")
+
+         with gr.Row():
+             with gr.Column():
+                 results_output = gr.Markdown(label="Results")
+             with gr.Column():
+                 detailed_output = gr.Markdown(label="Detailed Output")
+
+         # Connect the interface
+         run_btn.click(
+             benchmark_models,
+             inputs=[
+                 hf_token,
+                 task,
+                 model_choices,
+                 dataset_name,
+                 sample_size,
+                 text_column,
+                 label_column
+             ],
+             outputs=[summary_output, results_output, detailed_output]
+         )

    return demo

+ # Launch the app
if __name__ == "__main__":
+     app = create_interface()
+     app.launch()