davanstrien HF staff committed on
Commit
42e0474
·
1 Parent(s): ada4842

add param_count to model metadata and update requirements for compatibility

Browse files
Files changed (3) hide show
  1. main.py +24 -5
  2. requirements.in +1 -1
  3. requirements.txt +26 -8
main.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import sys
5
  from contextlib import asynccontextmanager
6
  from datetime import datetime
7
- from typing import List
8
 
9
  import chromadb
10
  import dateutil.parser
@@ -194,7 +194,14 @@ def setup_database():
194
 
195
  if model_collection.count() < model_row_count:
196
  model_df = model_df.select(
197
- ["modelId", "summary", "likes", "downloads", "last_modified"]
 
 
 
 
 
 
 
198
  )
199
  model_df = model_df.collect()
200
  total_rows = len(model_df)
@@ -210,11 +217,15 @@ def setup_database():
210
  "likes": int(likes),
211
  "downloads": int(downloads),
212
  "last_modified": str(last_modified),
 
 
 
213
  }
214
- for likes, downloads, last_modified in zip(
215
  batch_df.select(["likes"]).to_series().to_list(),
216
  batch_df.select(["downloads"]).to_series().to_list(),
217
  batch_df.select(["last_modified"]).to_series().to_list(),
 
218
  )
219
  ],
220
  )
@@ -252,6 +263,7 @@ class ModelQueryResult(BaseModel):
252
  summary: str
253
  likes: int
254
  downloads: int
 
255
 
256
 
257
  class ModelQueryResponse(BaseModel):
@@ -471,6 +483,10 @@ async def process_search_results(results, id_field, k, sort_by, exclude_id=None)
471
  "downloads": results["metadatas"][0][i]["downloads"],
472
  }
473
 
 
 
 
 
474
  if id_field == "dataset":
475
  query_results.append(QueryResult(**result))
476
  else:
@@ -546,21 +562,24 @@ async def get_trending_models_with_summaries(
546
 
547
  # Fetch summaries from ChromaDB
548
  collection = client.get_collection("model_cards")
549
- summaries = collection.get(ids=model_ids, include=["documents"])
550
 
551
- # Create mapping of model_id to summary
552
  id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
 
553
 
554
  # Combine data
555
  results = []
556
  for model in trending_models:
557
  if model["modelId"] in id_to_summary:
 
558
  result = ModelQueryResult(
559
  model_id=model["modelId"],
560
  similarity=1.0, # Not applicable for trending
561
  summary=id_to_summary[model["modelId"]],
562
  likes=model.get("likes", 0),
563
  downloads=model.get("downloads", 0),
 
564
  )
565
  results.append(result)
566
 
 
4
  import sys
5
  from contextlib import asynccontextmanager
6
  from datetime import datetime
7
+ from typing import List, Optional
8
 
9
  import chromadb
10
  import dateutil.parser
 
194
 
195
  if model_collection.count() < model_row_count:
196
  model_df = model_df.select(
197
+ [
198
+ "modelId",
199
+ "summary",
200
+ "likes",
201
+ "downloads",
202
+ "last_modified",
203
+ "param_count",
204
+ ]
205
  )
206
  model_df = model_df.collect()
207
  total_rows = len(model_df)
 
217
  "likes": int(likes),
218
  "downloads": int(downloads),
219
  "last_modified": str(last_modified),
220
+ "param_count": int(param_count)
221
+ if param_count is not None
222
+ else 0,
223
  }
224
+ for likes, downloads, last_modified, param_count in zip(
225
  batch_df.select(["likes"]).to_series().to_list(),
226
  batch_df.select(["downloads"]).to_series().to_list(),
227
  batch_df.select(["last_modified"]).to_series().to_list(),
228
+ batch_df.select(["param_count"]).to_series().to_list(),
229
  )
230
  ],
231
  )
 
263
  summary: str
264
  likes: int
265
  downloads: int
266
+ param_count: Optional[int] = None
267
 
268
 
269
  class ModelQueryResponse(BaseModel):
 
483
  "downloads": results["metadatas"][0][i]["downloads"],
484
  }
485
 
486
+ # Add param_count for models if it exists in metadata
487
+ if id_field == "model" and "param_count" in results["metadatas"][0][i]:
488
+ result["param_count"] = results["metadatas"][0][i]["param_count"]
489
+
490
  if id_field == "dataset":
491
  query_results.append(QueryResult(**result))
492
  else:
 
562
 
563
  # Fetch summaries from ChromaDB
564
  collection = client.get_collection("model_cards")
565
+ summaries = collection.get(ids=model_ids, include=["documents", "metadatas"])
566
 
567
+ # Create mapping of model_id to summary and metadata
568
  id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
569
+ id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))
570
 
571
  # Combine data
572
  results = []
573
  for model in trending_models:
574
  if model["modelId"] in id_to_summary:
575
+ metadata = id_to_metadata.get(model["modelId"], {})
576
  result = ModelQueryResult(
577
  model_id=model["modelId"],
578
  similarity=1.0, # Not applicable for trending
579
  summary=id_to_summary[model["modelId"]],
580
  likes=model.get("likes", 0),
581
  downloads=model.get("downloads", 0),
582
+ param_count=metadata.get("param_count"),
583
  )
584
  results.append(result)
585
 
requirements.in CHANGED
@@ -1,6 +1,6 @@
1
  aiohttp
2
  cashews
3
- chromadb
4
  datasets
5
  einops
6
  fastapi
 
1
  aiohttp
2
  cashews
3
+ chromadb==1.0.0b0
4
  datasets
5
  einops
6
  fastapi
requirements.txt CHANGED
@@ -19,10 +19,13 @@ anyio==4.8.0
19
  asgiref==3.8.1
20
  # via opentelemetry-instrumentation-asgi
21
  attrs==25.1.0
22
- # via aiohttp
 
 
 
23
  backoff==2.2.1
24
  # via posthog
25
- bcrypt==4.2.1
26
  # via chromadb
27
  build==1.2.2.post1
28
  # via chromadb
@@ -40,7 +43,7 @@ charset-normalizer==3.4.1
40
  # via requests
41
  chroma-hnswlib==0.7.6
42
  # via chromadb
43
- chromadb==0.6.3
44
  # via -r requirements.in
45
  click==8.1.8
46
  # via
@@ -59,11 +62,13 @@ dill==0.3.8
59
  # via
60
  # datasets
61
  # multiprocess
 
 
62
  durationpy==0.9
63
  # via kubernetes
64
  einops==0.8.1
65
  # via -r requirements.in
66
- fastapi==0.115.8
67
  # via
68
  # -r requirements.in
69
  # chromadb
@@ -135,6 +140,10 @@ jinja2==3.1.5
135
  # via torch
136
  joblib==1.4.2
137
  # via scikit-learn
 
 
 
 
138
  kubernetes==32.0.1
139
  # via chromadb
140
  markdown-it-py==3.0.0
@@ -228,9 +237,9 @@ pandas==2.2.3
228
  # via datasets
229
  pillow==11.1.0
230
  # via sentence-transformers
231
- polars==1.23.0
232
  # via -r requirements.in
233
- posthog==3.15.1
234
  # via chromadb
235
  propcache==0.3.0
236
  # via
@@ -281,6 +290,10 @@ pyyaml==6.0.2
281
  # kubernetes
282
  # transformers
283
  # uvicorn
 
 
 
 
284
  regex==2024.11.6
285
  # via transformers
286
  requests==2.32.3
@@ -297,6 +310,10 @@ rich==13.9.4
297
  # via
298
  # chromadb
299
  # typer
 
 
 
 
300
  rsa==4.9
301
  # via google-auth
302
  safetensors==0.5.3
@@ -309,7 +326,7 @@ scipy==1.15.2
309
  # sentence-transformers
310
  sentence-transformers==3.4.1
311
  # via -r requirements.in
312
- setuptools==75.8.1
313
  # via torch
314
  shellingham==1.5.4
315
  # via typer
@@ -350,7 +367,7 @@ tqdm==4.67.1
350
  # transformers
351
  transformers==4.49.0
352
  # via sentence-transformers
353
- typer==0.15.1
354
  # via chromadb
355
  typing-extensions==4.12.2
356
  # via
@@ -361,6 +378,7 @@ typing-extensions==4.12.2
361
  # opentelemetry-sdk
362
  # pydantic
363
  # pydantic-core
 
364
  # torch
365
  # typer
366
  tzdata==2025.1
 
19
  asgiref==3.8.1
20
  # via opentelemetry-instrumentation-asgi
21
  attrs==25.1.0
22
+ # via
23
+ # aiohttp
24
+ # jsonschema
25
+ # referencing
26
  backoff==2.2.1
27
  # via posthog
28
+ bcrypt==4.3.0
29
  # via chromadb
30
  build==1.2.2.post1
31
  # via chromadb
 
43
  # via requests
44
  chroma-hnswlib==0.7.6
45
  # via chromadb
46
+ chromadb==1.0.0b0
47
  # via -r requirements.in
48
  click==8.1.8
49
  # via
 
62
  # via
63
  # datasets
64
  # multiprocess
65
+ distro==1.9.0
66
+ # via posthog
67
  durationpy==0.9
68
  # via kubernetes
69
  einops==0.8.1
70
  # via -r requirements.in
71
+ fastapi==0.115.9
72
  # via
73
  # -r requirements.in
74
  # chromadb
 
140
  # via torch
141
  joblib==1.4.2
142
  # via scikit-learn
143
+ jsonschema==4.23.0
144
+ # via chromadb
145
+ jsonschema-specifications==2024.10.1
146
+ # via jsonschema
147
  kubernetes==32.0.1
148
  # via chromadb
149
  markdown-it-py==3.0.0
 
237
  # via datasets
238
  pillow==11.1.0
239
  # via sentence-transformers
240
+ polars==1.24.0
241
  # via -r requirements.in
242
+ posthog==3.18.0
243
  # via chromadb
244
  propcache==0.3.0
245
  # via
 
290
  # kubernetes
291
  # transformers
292
  # uvicorn
293
+ referencing==0.36.2
294
+ # via
295
+ # jsonschema
296
+ # jsonschema-specifications
297
  regex==2024.11.6
298
  # via transformers
299
  requests==2.32.3
 
310
  # via
311
  # chromadb
312
  # typer
313
+ rpds-py==0.23.1
314
+ # via
315
+ # jsonschema
316
+ # referencing
317
  rsa==4.9
318
  # via google-auth
319
  safetensors==0.5.3
 
326
  # sentence-transformers
327
  sentence-transformers==3.4.1
328
  # via -r requirements.in
329
+ setuptools==75.8.2
330
  # via torch
331
  shellingham==1.5.4
332
  # via typer
 
367
  # transformers
368
  transformers==4.49.0
369
  # via sentence-transformers
370
+ typer==0.15.2
371
  # via chromadb
372
  typing-extensions==4.12.2
373
  # via
 
378
  # opentelemetry-sdk
379
  # pydantic
380
  # pydantic-core
381
+ # referencing
382
  # torch
383
  # typer
384
  tzdata==2025.1