davanstrien HF staff committed on
Commit
ac80dbe
·
1 Parent(s): 42e0474

add parameter count filters to model search and trending endpoints

Browse files
Files changed (1) hide show
  1. main.py +140 -30
main.py CHANGED
@@ -366,29 +366,64 @@ async def find_similar_datasets(
366
  @cache(ttl=CACHE_TTL)
367
  async def search_models(
368
  query: str,
369
- k: int = Query(default=5, ge=1, le=100),
370
  sort_by: str = Query(
371
- default="similarity", enum=["similarity", "likes", "downloads", "trending"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  ),
373
- min_likes: int = Query(default=0, ge=0),
374
- min_downloads: int = Query(default=0, ge=0),
375
  ):
 
 
 
 
 
 
376
  try:
377
  collection = client.get_collection(
378
  name="model_cards", embedding_function=get_embedding_function()
379
  )
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  results = collection.query(
382
  query_texts=[f"search_query: {query}"],
383
  n_results=k * 4 if sort_by != "similarity" else k,
384
- where={
385
- "$and": [
386
- {"likes": {"$gte": min_likes}},
387
- {"downloads": {"$gte": min_downloads}},
388
- ]
389
- }
390
- if min_likes > 0 or min_downloads > 0
391
- else None,
392
  )
393
 
394
  query_results = await process_search_results(results, "model", k, sort_by)
@@ -404,13 +439,31 @@ async def search_models(
404
  @cache(ttl=CACHE_TTL)
405
  async def find_similar_models(
406
  model_id: str,
407
- k: int = Query(default=5, ge=1, le=100),
408
  sort_by: str = Query(
409
- default="similarity", enum=["similarity", "likes", "downloads", "trending"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  ),
411
- min_likes: int = Query(default=0, ge=0),
412
- min_downloads: int = Query(default=0, ge=0),
413
  ):
 
 
 
 
 
 
414
  try:
415
  collection = client.get_collection("model_cards")
416
 
@@ -421,17 +474,34 @@ async def find_similar_models(
421
  status_code=404, detail=f"Model ID '{model_id}' not found"
422
  )
423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  results = collection.query(
425
  query_embeddings=[results["embeddings"][0]],
426
  n_results=k * 4 if sort_by != "similarity" else k + 1,
427
- where={
428
- "$and": [
429
- {"likes": {"$gte": min_likes}},
430
- {"downloads": {"$gte": min_downloads}},
431
- ]
432
- }
433
- if min_likes > 0 or min_downloads > 0
434
- else None,
435
  )
436
 
437
  query_results = await process_search_results(
@@ -538,6 +608,8 @@ async def get_trending_models_with_summaries(
538
  limit: int = 10,
539
  min_likes: int = 0,
540
  min_downloads: int = 0,
 
 
541
  ) -> List[ModelQueryResult]:
542
  """Fetch trending models and combine with summaries from database"""
543
  try:
@@ -573,13 +645,30 @@ async def get_trending_models_with_summaries(
573
  for model in trending_models:
574
  if model["modelId"] in id_to_summary:
575
  metadata = id_to_metadata.get(model["modelId"], {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  result = ModelQueryResult(
577
  model_id=model["modelId"],
578
  similarity=1.0, # Not applicable for trending
579
  summary=id_to_summary[model["modelId"]],
580
  likes=model.get("likes", 0),
581
  downloads=model.get("downloads", 0),
582
- param_count=metadata.get("param_count"),
583
  )
584
  results.append(result)
585
 
@@ -592,13 +681,34 @@ async def get_trending_models_with_summaries(
592
 
593
  @app.get("/trending/models", response_model=ModelQueryResponse)
594
  async def get_trending_models(
595
- limit: int = Query(default=10, ge=1, le=100),
596
- min_likes: int = Query(default=0, ge=0),
597
- min_downloads: int = Query(default=0, ge=0),
 
 
 
 
 
 
 
 
 
 
 
 
598
  ):
599
- """Get trending models with their summaries"""
 
 
 
 
 
600
  results = await get_trending_models_with_summaries(
601
- limit=limit, min_likes=min_likes, min_downloads=min_downloads
 
 
 
 
602
  )
603
  return ModelQueryResponse(results=results)
604
 
 
366
  @cache(ttl=CACHE_TTL)
367
  async def search_models(
368
  query: str,
369
+ k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
370
  sort_by: str = Query(
371
+ default="similarity",
372
+ enum=["similarity", "likes", "downloads", "trending"],
373
+ description="Sort method for results",
374
+ ),
375
+ min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
376
+ min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
377
+ min_param_count: int = Query(
378
+ default=0,
379
+ ge=0,
380
+ description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
381
+ ),
382
+ max_param_count: Optional[int] = Query(
383
+ default=None,
384
+ ge=0,
385
+ description="Maximum parameter count (None means no upper limit)",
386
  ),
 
 
387
  ):
388
+ """
389
+ Search for models based on a text query with optional filtering.
390
+
391
+ - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
392
+ - param_count=0 indicates missing/unknown parameter count in the dataset
393
+ """
394
  try:
395
  collection = client.get_collection(
396
  name="model_cards", embedding_function=get_embedding_function()
397
  )
398
 
399
+ where_conditions = []
400
+ if min_likes > 0:
401
+ where_conditions.append({"likes": {"$gte": min_likes}})
402
+ if min_downloads > 0:
403
+ where_conditions.append({"downloads": {"$gte": min_downloads}})
404
+
405
+ # Add parameter count filters
406
+ using_param_filters = min_param_count > 0 or max_param_count is not None
407
+ if using_param_filters:
408
+ # Always exclude zero param count when using any parameter filters
409
+ where_conditions.append({"param_count": {"$gt": 0}})
410
+
411
+ if min_param_count > 0:
412
+ where_conditions.append({"param_count": {"$gte": min_param_count}})
413
+ if max_param_count is not None:
414
+ where_conditions.append({"param_count": {"$lte": max_param_count}})
415
+
416
+ # Handle where clause creation based on number of conditions
417
+ where_clause = None
418
+ if len(where_conditions) > 1:
419
+ where_clause = {"$and": where_conditions}
420
+ elif len(where_conditions) == 1:
421
+ where_clause = where_conditions[0] # Single condition without $and
422
+
423
  results = collection.query(
424
  query_texts=[f"search_query: {query}"],
425
  n_results=k * 4 if sort_by != "similarity" else k,
426
+ where=where_clause,
 
 
 
 
 
 
 
427
  )
428
 
429
  query_results = await process_search_results(results, "model", k, sort_by)
 
439
  @cache(ttl=CACHE_TTL)
440
  async def find_similar_models(
441
  model_id: str,
442
+ k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
443
  sort_by: str = Query(
444
+ default="similarity",
445
+ enum=["similarity", "likes", "downloads", "trending"],
446
+ description="Sort method for results",
447
+ ),
448
+ min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
449
+ min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
450
+ min_param_count: int = Query(
451
+ default=0,
452
+ ge=0,
453
+ description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
454
+ ),
455
+ max_param_count: Optional[int] = Query(
456
+ default=None,
457
+ ge=0,
458
+ description="Maximum parameter count (None means no upper limit)",
459
  ),
 
 
460
  ):
461
+ """
462
+ Find similar models to a specified model with optional filtering.
463
+
464
+ - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
465
+ - param_count=0 indicates missing/unknown parameter count in the dataset
466
+ """
467
  try:
468
  collection = client.get_collection("model_cards")
469
 
 
474
  status_code=404, detail=f"Model ID '{model_id}' not found"
475
  )
476
 
477
+ where_conditions = []
478
+ if min_likes > 0:
479
+ where_conditions.append({"likes": {"$gte": min_likes}})
480
+ if min_downloads > 0:
481
+ where_conditions.append({"downloads": {"$gte": min_downloads}})
482
+
483
+ # Add parameter count filters
484
+ using_param_filters = min_param_count > 0 or max_param_count is not None
485
+ if using_param_filters:
486
+ # Always exclude zero param count when using any parameter filters
487
+ where_conditions.append({"param_count": {"$gt": 0}})
488
+
489
+ if min_param_count > 0:
490
+ where_conditions.append({"param_count": {"$gte": min_param_count}})
491
+ if max_param_count is not None:
492
+ where_conditions.append({"param_count": {"$lte": max_param_count}})
493
+
494
+ # Handle where clause creation based on number of conditions
495
+ where_clause = None
496
+ if len(where_conditions) > 1:
497
+ where_clause = {"$and": where_conditions}
498
+ elif len(where_conditions) == 1:
499
+ where_clause = where_conditions[0] # Single condition without $and
500
+
501
  results = collection.query(
502
  query_embeddings=[results["embeddings"][0]],
503
  n_results=k * 4 if sort_by != "similarity" else k + 1,
504
+ where=where_clause,
 
 
 
 
 
 
 
505
  )
506
 
507
  query_results = await process_search_results(
 
608
  limit: int = 10,
609
  min_likes: int = 0,
610
  min_downloads: int = 0,
611
+ min_param_count: int = 0,
612
+ max_param_count: Optional[int] = None,
613
  ) -> List[ModelQueryResult]:
614
  """Fetch trending models and combine with summaries from database"""
615
  try:
 
645
  for model in trending_models:
646
  if model["modelId"] in id_to_summary:
647
  metadata = id_to_metadata.get(model["modelId"], {})
648
+ param_count = metadata.get("param_count", 0)
649
+
650
+ # Apply parameter count filters
651
+ using_param_filters = min_param_count > 0 or max_param_count is not None
652
+
653
+ # Skip if param_count is 0 and we're using param filters
654
+ if using_param_filters and param_count == 0:
655
+ continue
656
+
657
+ # Skip if param_count is less than min_param_count
658
+ if min_param_count > 0 and param_count < min_param_count:
659
+ continue
660
+
661
+ # Skip if param_count is greater than max_param_count
662
+ if max_param_count is not None and param_count > max_param_count:
663
+ continue
664
+
665
  result = ModelQueryResult(
666
  model_id=model["modelId"],
667
  similarity=1.0, # Not applicable for trending
668
  summary=id_to_summary[model["modelId"]],
669
  likes=model.get("likes", 0),
670
  downloads=model.get("downloads", 0),
671
+ param_count=param_count,
672
  )
673
  results.append(result)
674
 
 
681
 
682
  @app.get("/trending/models", response_model=ModelQueryResponse)
683
  async def get_trending_models(
684
+ limit: int = Query(
685
+ default=10, ge=1, le=100, description="Number of results to return"
686
+ ),
687
+ min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
688
+ min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
689
+ min_param_count: int = Query(
690
+ default=0,
691
+ ge=0,
692
+ description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
693
+ ),
694
+ max_param_count: Optional[int] = Query(
695
+ default=None,
696
+ ge=0,
697
+ description="Maximum parameter count (None means no upper limit)",
698
+ ),
699
  ):
700
+ """
701
+ Get trending models with their summaries and optional filtering.
702
+
703
+ - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
704
+ - param_count=0 indicates missing/unknown parameter count in the dataset
705
+ """
706
  results = await get_trending_models_with_summaries(
707
+ limit=limit,
708
+ min_likes=min_likes,
709
+ min_downloads=min_downloads,
710
+ min_param_count=min_param_count,
711
+ max_param_count=max_param_count,
712
  )
713
  return ModelQueryResponse(results=results)
714