davanstrien HF staff committed on
Commit
0b973e8
·
1 Parent(s): ac80dbe

refactor trending models fetching to improve parameter filtering and logging

Browse files
Files changed (1) hide show
  1. main.py +66 -21
main.py CHANGED
@@ -624,10 +624,15 @@ async def get_trending_models_with_summaries(
624
  and model.get("downloads", 0) >= min_downloads
625
  ]
626
 
627
- # Sort by trending score and limit
628
  trending_models = sorted(
629
  trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
630
- )[:limit]
 
 
 
 
 
631
 
632
  # Get model IDs
633
  model_ids = [model["modelId"] for model in trending_models]
@@ -640,27 +645,20 @@ async def get_trending_models_with_summaries(
640
  id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
641
  id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))
642
 
643
- # Combine data
644
- results = []
 
 
 
 
 
645
  for model in trending_models:
646
  if model["modelId"] in id_to_summary:
647
  metadata = id_to_metadata.get(model["modelId"], {})
648
  param_count = metadata.get("param_count", 0)
649
 
650
- # Apply parameter count filters
651
- using_param_filters = min_param_count > 0 or max_param_count is not None
652
-
653
- # Skip if param_count is 0 and we're using param filters
654
- if using_param_filters and param_count == 0:
655
- continue
656
-
657
- # Skip if param_count is less than min_param_count
658
- if min_param_count > 0 and param_count < min_param_count:
659
- continue
660
-
661
- # Skip if param_count is greater than max_param_count
662
- if max_param_count is not None and param_count > max_param_count:
663
- continue
664
 
665
  result = ModelQueryResult(
666
  model_id=model["modelId"],
@@ -670,9 +668,50 @@ async def get_trending_models_with_summaries(
670
  downloads=model.get("downloads", 0),
671
  param_count=param_count,
672
  )
673
- results.append(result)
674
 
675
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
  except Exception as e:
678
  logger.error(f"Error fetching trending models: {str(e)}")
@@ -689,7 +728,7 @@ async def get_trending_models(
689
  min_param_count: int = Query(
690
  default=0,
691
  ge=0,
692
- description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
693
  ),
694
  max_param_count: Optional[int] = Query(
695
  default=None,
@@ -703,6 +742,10 @@ async def get_trending_models(
703
  - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
704
  - param_count=0 indicates missing/unknown parameter count in the dataset
705
  """
 
 
 
 
706
  results = await get_trending_models_with_summaries(
707
  limit=limit,
708
  min_likes=min_likes,
@@ -710,6 +753,8 @@ async def get_trending_models(
710
  min_param_count=min_param_count,
711
  max_param_count=max_param_count,
712
  )
 
 
713
  return ModelQueryResponse(results=results)
714
 
715
 
 
624
  and model.get("downloads", 0) >= min_downloads
625
  ]
626
 
627
+ # Sort by trending score
628
  trending_models = sorted(
629
  trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
630
+ )
631
+
632
+ # Fetch up to 3x the limit (buffer for filtering) or all available if fewer
633
+ # This ensures we have enough models to filter from
634
+ fetch_limit = min(len(trending_models), limit * 3)
635
+ trending_models = trending_models[:fetch_limit]
636
 
637
  # Get model IDs
638
  model_ids = [model["modelId"] for model in trending_models]
 
645
  id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
646
  id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))
647
 
648
+ # Log parameters for debugging
649
+ print(
650
+ f"Filter params - min_param_count: {min_param_count}, max_param_count: {max_param_count}"
651
+ )
652
+
653
+ # Combine data - collect all results first
654
+ all_results = []
655
  for model in trending_models:
656
  if model["modelId"] in id_to_summary:
657
  metadata = id_to_metadata.get(model["modelId"], {})
658
  param_count = metadata.get("param_count", 0)
659
 
660
+ # Log model parameter counts
661
+ print(f"Model: {model['modelId']}, param_count: {param_count}")
 
 
 
 
 
 
 
 
 
 
 
 
662
 
663
  result = ModelQueryResult(
664
  model_id=model["modelId"],
 
668
  downloads=model.get("downloads", 0),
669
  param_count=param_count,
670
  )
671
+ all_results.append(result)
672
 
673
+ # Apply parameter filtering after collecting all results
674
+ filtered_results = all_results
675
+
676
+ # Check if any parameter filtering is being applied
677
+ using_param_filters = min_param_count > 0 or max_param_count is not None
678
+
679
+ # Only filter by params if we have specific parameter constraints
680
+ if using_param_filters:
681
+ filtered_results = []
682
+ for result in all_results:
683
+ should_include = True
684
+
685
+ # Always exclude models with param_count=0 when any parameter filtering is active
686
+ if result.param_count == 0:
687
+ print(
688
+ f"Filtering out {result.model_id} - has param_count=0 but parameter filtering is active"
689
+ )
690
+ should_include = False
691
+
692
+ # Apply min param filter if specified
693
+ elif min_param_count > 0 and result.param_count < min_param_count:
694
+ print(
695
+ f"Filtering out {result.model_id} - param_count {result.param_count} < min_param_count {min_param_count}"
696
+ )
697
+ should_include = False
698
+
699
+ # Apply max param filter if specified
700
+ elif (
701
+ max_param_count is not None and result.param_count > max_param_count
702
+ ):
703
+ print(
704
+ f"Filtering out {result.model_id} - param_count {result.param_count} > max_param_count {max_param_count}"
705
+ )
706
+ should_include = False
707
+
708
+ if should_include:
709
+ filtered_results.append(result)
710
+
711
+ print(f"After filtering: {len(filtered_results)} models remain")
712
+
713
+ # Finally limit to the requested number
714
+ return filtered_results[:limit]
715
 
716
  except Exception as e:
717
  logger.error(f"Error fetching trending models: {str(e)}")
 
728
  min_param_count: int = Query(
729
  default=0,
730
  ge=0,
731
+ description="Minimum parameter count (models with param_count=0 will be excluded if any parameter filter is used)",
732
  ),
733
  max_param_count: Optional[int] = Query(
734
  default=None,
 
742
  - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
743
  - param_count=0 indicates missing/unknown parameter count in the dataset
744
  """
745
+ print(
746
+ f"Request for trending models with params: limit={limit}, min_likes={min_likes}, min_downloads={min_downloads}, min_param_count={min_param_count}, max_param_count={max_param_count}"
747
+ )
748
+
749
  results = await get_trending_models_with_summaries(
750
  limit=limit,
751
  min_likes=min_likes,
 
753
  min_param_count=min_param_count,
754
  max_param_count=max_param_count,
755
  )
756
+
757
+ print(f"Returning {len(results)} trending model results")
758
  return ModelQueryResponse(results=results)
759
 
760