whitphx (HF Staff) committed
Commit 945bdba · 1 Parent(s): fd2180b

Add batchsize param

.claude/settings.local.json CHANGED
@@ -15,7 +15,12 @@
       "Bash(timeout 120 npm run bench:cli -- Xenova/distilbert-base-uncased feature-extraction --mode warm --repeats 2 --device webgpu --dtype q8)",
       "Bash(timeout 180 npm run bench:cli -- Xenova/bert-base-uncased feature-extraction --mode warm --repeats 2 --device wasm --dtype fp32)",
       "Bash(timeout 180 npm run bench:cli -- Xenova/bert-base-uncased feature-extraction --mode warm --repeats 2 --device webgpu --dtype fp32)",
-      "Bash(timeout 180 npm run bench:cli -- Xenova/bert-base-uncased feature-extraction --mode warm --repeats 2 --device wasm --dtype q8)"
+      "Bash(timeout 180 npm run bench:cli -- Xenova/bert-base-uncased feature-extraction --mode warm --repeats 2 --device wasm --dtype q8)",
+      "Bash(timeout 180 npm run bench:cli -- Xenova/bert-base-uncased feature-extraction --mode warm --repeats 2 --device webgpu --dtype q8 --batch-size 8)",
+      "Bash(timeout 300 npm run bench:cli -- Xenova/gpt2 feature-extraction --mode warm --repeats 2 --device webgpu --dtype q8 --batch-size 1)",
+      "Bash(timeout 300 npm run bench -- Xenova/gpt2 feature-extraction --mode warm --repeats 2 --dtype q8 --batch-size 1 --cache-dir .bench-cache/warm)",
+      "Bash(timeout 300 npm run bench:cli -- Xenova/t5-small feature-extraction --mode warm --repeats 2 --device webgpu --dtype fp32 --batch-size 1)",
+      "Bash(timeout 300 npm run bench:cli -- Xenova/t5-small feature-extraction --mode warm --repeats 2 --device wasm --dtype fp32 --batch-size 1)"
     ],
     "deny": [],
     "ask": []
bench-node/src/index.ts CHANGED
@@ -18,6 +18,7 @@ const mode = (getArg("mode", "warm") as "warm" | "cold");
 const repeats = Math.max(1, parseInt(getArg("repeats", "3") || "3", 10));
 const cacheDir = getArg("cache-dir", path.resolve(".bench-cache/default"))!;
 const dtype = getArg("dtype"); // optional: fp32, fp16, q8, q4, etc.
+const batchSize = Math.max(1, parseInt(getArg("batch-size", "1") || "1", 10));
 
 // Point library cache to a dedicated directory for controllable cold/warm behavior
 env.cacheDir = cacheDir;
@@ -41,15 +42,18 @@ async function benchOnce() {
   const pipe = await pipeline(task, modelId, options);
   const t1 = performance.now();
 
+  // Prepare batch input
+  const inputs = Array(batchSize).fill("The quick brown fox jumps over the lazy dog.");
+
   const t2 = performance.now();
-  await pipe("The quick brown fox jumps over the lazy dog.");
+  await pipe(inputs);
   const t3 = performance.now();
 
   // Run additional inferences to measure subsequent performance
   const subsequentTimes: number[] = [];
   for (let i = 0; i < 3; i++) {
     const t4 = performance.now();
-    await pipe("The quick brown fox jumps over the lazy dog.");
+    await pipe(inputs);
     const t5 = performance.now();
     subsequentTimes.push(+(t5 - t4).toFixed(1));
   }
@@ -62,12 +66,13 @@ async function benchOnce() {
 }
 
 async function main() {
-  console.log(`Model : ${modelId}`);
-  console.log(`Task : ${task}`);
-  console.log(`Mode : ${mode}`);
-  console.log(`Repeats: ${repeats}`);
-  console.log(`DType : ${dtype || 'auto'}`);
-  console.log(`Cache : ${cacheDir}`);
+  console.log(`Model : ${modelId}`);
+  console.log(`Task : ${task}`);
+  console.log(`Mode : ${mode}`);
+  console.log(`Repeats : ${repeats}`);
+  console.log(`DType : ${dtype || 'auto'}`);
+  console.log(`Batch Size: ${batchSize}`);
+  console.log(`Cache : ${cacheDir}`);
 
   const loads: number[] = [];
   const firsts: number[] = [];
@@ -79,7 +84,8 @@ async function main() {
   const warmOptions: any = {};
   if (dtype) warmOptions.dtype = dtype;
   const warm = await pipeline(task, modelId, warmOptions);
-  await warm("warmup");
+  const warmupInputs = Array(batchSize).fill("warmup");
+  await warm(warmupInputs);
 
   for (let i = 0; i < repeats; i++) {
     const r = await benchOnce();
@@ -105,6 +111,7 @@ async function main() {
     task,
     mode,
     repeats,
+    batchSize,
     cacheDir,
     metrics: {
       load_ms: { p50: +percentile(loads, 0.5).toFixed(1), p90: +percentile(loads, 0.9).toFixed(1), raw: loads },
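
The new --batch-size flag goes through the same getArg helper as the existing flags. That helper is not shown in this diff, so the following is only a minimal sketch of what such a parser might look like, not the project's actual implementation:

// Hypothetical sketch of a getArg-style helper compatible with the calls above.
// Assumes flags arrive on process.argv as "--name value" or "--name=value".
function getArg(name: string, fallback?: string): string | undefined {
  const prefix = `--${name}=`;
  for (let i = 2; i < process.argv.length; i++) {
    const arg = process.argv[i];
    if (arg === `--${name}`) return process.argv[i + 1] ?? fallback;
    if (arg.startsWith(prefix)) return arg.slice(prefix.length);
  }
  return fallback;
}

// Mirrors the new line in the diff: default to 1 and clamp to at least 1.
const batchSize = Math.max(1, parseInt(getArg("batch-size", "1") || "1", 10));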
bench-web/src/cli.ts CHANGED
@@ -16,18 +16,20 @@ const mode = getArg("mode", "warm") as "warm" | "cold";
 const repeats = Math.max(1, parseInt(getArg("repeats", "3") || "3", 10));
 const device = getArg("device", "webgpu") as "webgpu" | "wasm";
 const dtype = getArg("dtype"); // optional: fp32, fp16, q8, q4, etc.
+const batchSize = Math.max(1, parseInt(getArg("batch-size", "1") || "1", 10));
 const browserType = getArg("browser", "chromium") as "chromium" | "firefox" | "webkit";
 const headed = getArg("headed") === "true";
 
 async function main() {
-  console.log(`Model : ${modelId}`);
-  console.log(`Task : ${task}`);
-  console.log(`Mode : ${mode}`);
-  console.log(`Repeats : ${repeats}`);
-  console.log(`Device : ${device}`);
-  console.log(`DType : ${dtype || 'auto'}`);
-  console.log(`Browser : ${browserType}`);
-  console.log(`Headed : ${headed}`);
+  console.log(`Model : ${modelId}`);
+  console.log(`Task : ${task}`);
+  console.log(`Mode : ${mode}`);
+  console.log(`Repeats : ${repeats}`);
+  console.log(`Device : ${device}`);
+  console.log(`DType : ${dtype || 'auto'}`);
+  console.log(`Batch Size : ${batchSize}`);
+  console.log(`Browser : ${browserType}`);
+  console.log(`Headed : ${headed}`);
 
   // Start Vite dev server
   const server = await createServer({
@@ -88,23 +90,49 @@ async function main() {
 
   // Check WebGPU availability if using webgpu device
   if (device === "webgpu") {
-    const gpuAvailable = await page.evaluate(() => {
-      return 'gpu' in navigator;
+    const gpuInfo = await page.evaluate(async () => {
+      if (!('gpu' in navigator)) {
+        return { available: false, adapter: null, features: null };
+      }
+      try {
+        const adapter = await (navigator as any).gpu.requestAdapter();
+        if (!adapter) {
+          return { available: false, adapter: null, features: null };
+        }
+        const features = Array.from(adapter.features || []);
+        const limits = adapter.limits ? {
+          maxTextureDimension2D: adapter.limits.maxTextureDimension2D,
+          maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX,
+        } : null;
+        return {
+          available: true,
+          adapterInfo: adapter.info ? adapter.info.description : 'Unknown',
+          features,
+          limits
+        };
+      } catch (e) {
+        return { available: false, adapter: null, error: String(e) };
+      }
     });
 
-    if (!gpuAvailable) {
+    if (!gpuInfo.available) {
       console.error("\n❌ WebGPU is not available in this browser!");
       console.error("Make sure to use --enable-unsafe-webgpu flag for Chromium.");
+      if (gpuInfo.error) console.error("Error:", gpuInfo.error);
       throw new Error("WebGPU not available");
     }
 
     console.log("✓ WebGPU is available");
+    console.log(` Adapter: ${gpuInfo.adapterInfo}`);
+    if (gpuInfo.features && gpuInfo.features.length > 0) {
+      console.log(` Features: ${gpuInfo.features.slice(0, 3).join(', ')}${gpuInfo.features.length > 3 ? '...' : ''}`);
+    }
   }
 
   // Use the exposed CLI function from main.ts
-  const result = await page.evaluate(({ modelId, task, mode, repeats, device, dtype }) => {
-    return (window as any).runBenchmarkCLI({ modelId, task, mode, repeats, device, dtype });
-  }, { modelId, task, mode, repeats, device, dtype });
+  const result = await page.evaluate(({ modelId, task, mode, repeats, device, dtype, batchSize }) => {
+    return (window as any).runBenchmarkCLI({ modelId, task, mode, repeats, device, dtype, batchSize });
+  }, { modelId, task, mode, repeats, device, dtype, batchSize });
 
   console.log("\n" + JSON.stringify(result, null, 2));
 
bench-web/src/main.ts CHANGED
@@ -36,21 +36,25 @@ async function clearCaches({ clearSession = false }: { clearSession?: boolean }
     if (clearSession) sessionStorage.clear();
   } catch { }
 }
-async function benchOnce(modelId: string, task: string, device: string, dtype?: string) {
+async function benchOnce(modelId: string, task: string, device: string, dtype?: string, batchSize: number = 1) {
   const t0 = now();
   const options: any = { device };
   if (dtype) options.dtype = dtype;
   const pipe = await pipeline(task, modelId, options);
   const t1 = now();
+
+  // Prepare batch input
+  const inputs = Array(batchSize).fill("The quick brown fox jumps over the lazy dog.");
+
   const t2 = now();
-  await pipe("The quick brown fox jumps over the lazy dog.");
+  await pipe(inputs);
   const t3 = now();
 
   // Run additional inferences to measure subsequent performance
   const subsequentTimes: number[] = [];
   for (let i = 0; i < 3; i++) {
     const t4 = now();
-    await pipe("The quick brown fox jumps over the lazy dog.");
+    await pipe(inputs);
     const t5 = now();
     subsequentTimes.push(+(t5 - t4).toFixed(1));
   }
@@ -61,12 +65,12 @@ async function benchOnce(modelId: string, task: string, device: string, dtype?:
     subsequent_infer_ms: subsequentTimes
   };
 }
-async function runMany(modelId: string, task: string, repeats: number, device: string, dtype?: string) {
+async function runMany(modelId: string, task: string, repeats: number, device: string, dtype?: string, batchSize: number = 1) {
   const loads: number[] = [];
   const firsts: number[] = [];
   const subsequents: number[] = [];
   for (let i = 0; i < repeats; i++) {
-    const r = await benchOnce(modelId, task, device, dtype);
+    const r = await benchOnce(modelId, task, device, dtype, batchSize);
     loads.push(r.load_ms);
     firsts.push(r.first_infer_ms);
     subsequents.push(...r.subsequent_infer_ms);
@@ -77,16 +81,17 @@ async function runMany(modelId: string, task: string, repeats: number, device: s
     subsequent_infer_ms: { p50: +percentile(subsequents, 0.5).toFixed(1), p90: +percentile(subsequents, 0.9).toFixed(1), raw: subsequents },
   };
 }
-async function runCold(modelId: string, task: string, repeats: number, device: string, dtype?: string) {
+async function runCold(modelId: string, task: string, repeats: number, device: string, dtype?: string, batchSize: number = 1) {
   statusEl.textContent = "clearing caches (cold)...";
   await clearCaches();
   statusEl.textContent = "running (cold)...";
-  const metrics = await runMany(modelId, task, repeats, device, dtype);
+  const metrics = await runMany(modelId, task, repeats, device, dtype, batchSize);
   const result: any = {
     platform: "browser",
     runtime: navigator.userAgent,
     mode: "cold",
     repeats,
+    batchSize,
     model: modelId,
     task,
     device,
@@ -96,19 +101,21 @@ async function runCold(modelId: string, task: string, repeats: number, device: s
   if (dtype) result.dtype = dtype;
   return result;
 }
-async function runWarmDirect(modelId: string, task: string, repeats: number, device: string, dtype?: string) {
+async function runWarmDirect(modelId: string, task: string, repeats: number, device: string, dtype?: string, batchSize: number = 1) {
   statusEl.textContent = "prefetching (warmup) ...";
   const options: any = { device };
   if (dtype) options.dtype = dtype;
   const p = await pipeline(task, modelId, options);
-  await p("warmup");
+  const warmupInputs = Array(batchSize).fill("warmup");
+  await p(warmupInputs);
   statusEl.textContent = "running (warm)...";
-  const metrics = await runMany(modelId, task, repeats, device, dtype);
+  const metrics = await runMany(modelId, task, repeats, device, dtype, batchSize);
   const result: any = {
     platform: "browser",
     runtime: navigator.userAgent,
     mode: "warm",
     repeats,
+    batchSize,
     model: modelId,
     task,
     device,
@@ -117,20 +124,21 @@ async function runWarmDirect(modelId: string, task: string, repeats: number, dev
   if (dtype) result.dtype = dtype;
   return result;
 }
-async function runWarm(modelId: string, task: string, repeats: number, device: string, dtype?: string) {
+async function runWarm(modelId: string, task: string, repeats: number, device: string, dtype?: string, batchSize: number = 1) {
   const flag = sessionStorage.getItem("__warm_ready__");
   if (!flag) {
     statusEl.textContent = "prefetching (warmup) ...";
     const options: any = { device };
     if (dtype) options.dtype = dtype;
     const p = await pipeline(task, modelId, options);
-    await p("warmup");
-    sessionStorage.setItem("__warm_ready__", JSON.stringify({ modelId, task, repeats, device, dtype }));
+    const warmupInputs = Array(batchSize).fill("warmup");
+    await p(warmupInputs);
+    sessionStorage.setItem("__warm_ready__", JSON.stringify({ modelId, task, repeats, device, dtype, batchSize }));
     location.reload();
     return null;
   } else {
     sessionStorage.removeItem("__warm_ready__");
-    return await runWarmDirect(modelId, task, repeats, device, dtype);
+    return await runWarmDirect(modelId, task, repeats, device, dtype, batchSize);
   }
 }
 async function run() {
@@ -160,11 +168,12 @@ btn.addEventListener("click", () => {
 });
 
 // Expose for CLI use
-(window as any).runBenchmarkCLI = async function (params: { modelId: string, task: string, mode: string, repeats: number, device: string, dtype?: string }) {
+(window as any).runBenchmarkCLI = async function (params: { modelId: string, task: string, mode: string, repeats: number, device: string, dtype?: string, batchSize?: number }) {
+  const batchSize = params.batchSize || 1;
   if (params.mode === "cold") {
-    return await runCold(params.modelId, params.task, params.repeats, params.device, params.dtype);
+    return await runCold(params.modelId, params.task, params.repeats, params.device, params.dtype, batchSize);
   } else {
     // For warm, use the direct function that skips reload logic
-    return await runWarmDirect(params.modelId, params.task, params.repeats, params.device, params.dtype);
+    return await runWarmDirect(params.modelId, params.task, params.repeats, params.device, params.dtype, batchSize);
   }
 };
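
The reason Array(batchSize).fill(...) works as a drop-in replacement for the single test sentence is that Transformers.js pipelines accept an array of inputs and run them as one batched call. A minimal standalone sketch; the package name, model, and options below are illustrative assumptions, not taken from this commit:

// Assumes Transformers.js v3 (@huggingface/transformers), which supports the
// device/dtype options used throughout this repo.
import { pipeline } from "@huggingface/transformers";

async function demoBatchedInference(batchSize: number) {
  const extractor = await pipeline(
    "feature-extraction",
    "Xenova/all-MiniLM-L6-v2",
    { device: "wasm", dtype: "q8" },
  );
  const inputs = Array(batchSize).fill("The quick brown fox jumps over the lazy dog.");
  // One call over the whole batch; the output Tensor's first dimension is batchSize.
  const output = await extractor(inputs);
  console.log(output.dims); // e.g. [batchSize, sequenceLength, hiddenSize]
}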