Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
ac80dbe
1
Parent(s):
42e0474
add parameter count filters to model search and trending endpoints
Browse files
main.py
CHANGED
@@ -366,29 +366,64 @@ async def find_similar_datasets(
|
|
366 |
@cache(ttl=CACHE_TTL)
|
367 |
async def search_models(
|
368 |
query: str,
|
369 |
-
k: int = Query(default=5, ge=1, le=100),
|
370 |
sort_by: str = Query(
|
371 |
-
default="similarity",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
),
|
373 |
-
min_likes: int = Query(default=0, ge=0),
|
374 |
-
min_downloads: int = Query(default=0, ge=0),
|
375 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
try:
|
377 |
collection = client.get_collection(
|
378 |
name="model_cards", embedding_function=get_embedding_function()
|
379 |
)
|
380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
results = collection.query(
|
382 |
query_texts=[f"search_query: {query}"],
|
383 |
n_results=k * 4 if sort_by != "similarity" else k,
|
384 |
-
where=
|
385 |
-
"$and": [
|
386 |
-
{"likes": {"$gte": min_likes}},
|
387 |
-
{"downloads": {"$gte": min_downloads}},
|
388 |
-
]
|
389 |
-
}
|
390 |
-
if min_likes > 0 or min_downloads > 0
|
391 |
-
else None,
|
392 |
)
|
393 |
|
394 |
query_results = await process_search_results(results, "model", k, sort_by)
|
@@ -404,13 +439,31 @@ async def search_models(
|
|
404 |
@cache(ttl=CACHE_TTL)
|
405 |
async def find_similar_models(
|
406 |
model_id: str,
|
407 |
-
k: int = Query(default=5, ge=1, le=100),
|
408 |
sort_by: str = Query(
|
409 |
-
default="similarity",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
),
|
411 |
-
min_likes: int = Query(default=0, ge=0),
|
412 |
-
min_downloads: int = Query(default=0, ge=0),
|
413 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
try:
|
415 |
collection = client.get_collection("model_cards")
|
416 |
|
@@ -421,17 +474,34 @@ async def find_similar_models(
|
|
421 |
status_code=404, detail=f"Model ID '{model_id}' not found"
|
422 |
)
|
423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
results = collection.query(
|
425 |
query_embeddings=[results["embeddings"][0]],
|
426 |
n_results=k * 4 if sort_by != "similarity" else k + 1,
|
427 |
-
where=
|
428 |
-
"$and": [
|
429 |
-
{"likes": {"$gte": min_likes}},
|
430 |
-
{"downloads": {"$gte": min_downloads}},
|
431 |
-
]
|
432 |
-
}
|
433 |
-
if min_likes > 0 or min_downloads > 0
|
434 |
-
else None,
|
435 |
)
|
436 |
|
437 |
query_results = await process_search_results(
|
@@ -538,6 +608,8 @@ async def get_trending_models_with_summaries(
|
|
538 |
limit: int = 10,
|
539 |
min_likes: int = 0,
|
540 |
min_downloads: int = 0,
|
|
|
|
|
541 |
) -> List[ModelQueryResult]:
|
542 |
"""Fetch trending models and combine with summaries from database"""
|
543 |
try:
|
@@ -573,13 +645,30 @@ async def get_trending_models_with_summaries(
|
|
573 |
for model in trending_models:
|
574 |
if model["modelId"] in id_to_summary:
|
575 |
metadata = id_to_metadata.get(model["modelId"], {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
576 |
result = ModelQueryResult(
|
577 |
model_id=model["modelId"],
|
578 |
similarity=1.0, # Not applicable for trending
|
579 |
summary=id_to_summary[model["modelId"]],
|
580 |
likes=model.get("likes", 0),
|
581 |
downloads=model.get("downloads", 0),
|
582 |
-
param_count=
|
583 |
)
|
584 |
results.append(result)
|
585 |
|
@@ -592,13 +681,34 @@ async def get_trending_models_with_summaries(
|
|
592 |
|
593 |
@app.get("/trending/models", response_model=ModelQueryResponse)
|
594 |
async def get_trending_models(
|
595 |
-
limit: int = Query(
|
596 |
-
|
597 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
598 |
):
|
599 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
600 |
results = await get_trending_models_with_summaries(
|
601 |
-
limit=limit,
|
|
|
|
|
|
|
|
|
602 |
)
|
603 |
return ModelQueryResponse(results=results)
|
604 |
|
|
|
366 |
@cache(ttl=CACHE_TTL)
|
367 |
async def search_models(
|
368 |
query: str,
|
369 |
+
k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
|
370 |
sort_by: str = Query(
|
371 |
+
default="similarity",
|
372 |
+
enum=["similarity", "likes", "downloads", "trending"],
|
373 |
+
description="Sort method for results",
|
374 |
+
),
|
375 |
+
min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
|
376 |
+
min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
|
377 |
+
min_param_count: int = Query(
|
378 |
+
default=0,
|
379 |
+
ge=0,
|
380 |
+
description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
|
381 |
+
),
|
382 |
+
max_param_count: Optional[int] = Query(
|
383 |
+
default=None,
|
384 |
+
ge=0,
|
385 |
+
description="Maximum parameter count (None means no upper limit)",
|
386 |
),
|
|
|
|
|
387 |
):
|
388 |
+
"""
|
389 |
+
Search for models based on a text query with optional filtering.
|
390 |
+
|
391 |
+
- When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
|
392 |
+
- param_count=0 indicates missing/unknown parameter count in the dataset
|
393 |
+
"""
|
394 |
try:
|
395 |
collection = client.get_collection(
|
396 |
name="model_cards", embedding_function=get_embedding_function()
|
397 |
)
|
398 |
|
399 |
+
where_conditions = []
|
400 |
+
if min_likes > 0:
|
401 |
+
where_conditions.append({"likes": {"$gte": min_likes}})
|
402 |
+
if min_downloads > 0:
|
403 |
+
where_conditions.append({"downloads": {"$gte": min_downloads}})
|
404 |
+
|
405 |
+
# Add parameter count filters
|
406 |
+
using_param_filters = min_param_count > 0 or max_param_count is not None
|
407 |
+
if using_param_filters:
|
408 |
+
# Always exclude zero param count when using any parameter filters
|
409 |
+
where_conditions.append({"param_count": {"$gt": 0}})
|
410 |
+
|
411 |
+
if min_param_count > 0:
|
412 |
+
where_conditions.append({"param_count": {"$gte": min_param_count}})
|
413 |
+
if max_param_count is not None:
|
414 |
+
where_conditions.append({"param_count": {"$lte": max_param_count}})
|
415 |
+
|
416 |
+
# Handle where clause creation based on number of conditions
|
417 |
+
where_clause = None
|
418 |
+
if len(where_conditions) > 1:
|
419 |
+
where_clause = {"$and": where_conditions}
|
420 |
+
elif len(where_conditions) == 1:
|
421 |
+
where_clause = where_conditions[0] # Single condition without $and
|
422 |
+
|
423 |
results = collection.query(
|
424 |
query_texts=[f"search_query: {query}"],
|
425 |
n_results=k * 4 if sort_by != "similarity" else k,
|
426 |
+
where=where_clause,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
)
|
428 |
|
429 |
query_results = await process_search_results(results, "model", k, sort_by)
|
|
|
439 |
@cache(ttl=CACHE_TTL)
|
440 |
async def find_similar_models(
|
441 |
model_id: str,
|
442 |
+
k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
|
443 |
sort_by: str = Query(
|
444 |
+
default="similarity",
|
445 |
+
enum=["similarity", "likes", "downloads", "trending"],
|
446 |
+
description="Sort method for results",
|
447 |
+
),
|
448 |
+
min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
|
449 |
+
min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
|
450 |
+
min_param_count: int = Query(
|
451 |
+
default=0,
|
452 |
+
ge=0,
|
453 |
+
description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
|
454 |
+
),
|
455 |
+
max_param_count: Optional[int] = Query(
|
456 |
+
default=None,
|
457 |
+
ge=0,
|
458 |
+
description="Maximum parameter count (None means no upper limit)",
|
459 |
),
|
|
|
|
|
460 |
):
|
461 |
+
"""
|
462 |
+
Find similar models to a specified model with optional filtering.
|
463 |
+
|
464 |
+
- When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
|
465 |
+
- param_count=0 indicates missing/unknown parameter count in the dataset
|
466 |
+
"""
|
467 |
try:
|
468 |
collection = client.get_collection("model_cards")
|
469 |
|
|
|
474 |
status_code=404, detail=f"Model ID '{model_id}' not found"
|
475 |
)
|
476 |
|
477 |
+
where_conditions = []
|
478 |
+
if min_likes > 0:
|
479 |
+
where_conditions.append({"likes": {"$gte": min_likes}})
|
480 |
+
if min_downloads > 0:
|
481 |
+
where_conditions.append({"downloads": {"$gte": min_downloads}})
|
482 |
+
|
483 |
+
# Add parameter count filters
|
484 |
+
using_param_filters = min_param_count > 0 or max_param_count is not None
|
485 |
+
if using_param_filters:
|
486 |
+
# Always exclude zero param count when using any parameter filters
|
487 |
+
where_conditions.append({"param_count": {"$gt": 0}})
|
488 |
+
|
489 |
+
if min_param_count > 0:
|
490 |
+
where_conditions.append({"param_count": {"$gte": min_param_count}})
|
491 |
+
if max_param_count is not None:
|
492 |
+
where_conditions.append({"param_count": {"$lte": max_param_count}})
|
493 |
+
|
494 |
+
# Handle where clause creation based on number of conditions
|
495 |
+
where_clause = None
|
496 |
+
if len(where_conditions) > 1:
|
497 |
+
where_clause = {"$and": where_conditions}
|
498 |
+
elif len(where_conditions) == 1:
|
499 |
+
where_clause = where_conditions[0] # Single condition without $and
|
500 |
+
|
501 |
results = collection.query(
|
502 |
query_embeddings=[results["embeddings"][0]],
|
503 |
n_results=k * 4 if sort_by != "similarity" else k + 1,
|
504 |
+
where=where_clause,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
505 |
)
|
506 |
|
507 |
query_results = await process_search_results(
|
|
|
608 |
limit: int = 10,
|
609 |
min_likes: int = 0,
|
610 |
min_downloads: int = 0,
|
611 |
+
min_param_count: int = 0,
|
612 |
+
max_param_count: Optional[int] = None,
|
613 |
) -> List[ModelQueryResult]:
|
614 |
"""Fetch trending models and combine with summaries from database"""
|
615 |
try:
|
|
|
645 |
for model in trending_models:
|
646 |
if model["modelId"] in id_to_summary:
|
647 |
metadata = id_to_metadata.get(model["modelId"], {})
|
648 |
+
param_count = metadata.get("param_count", 0)
|
649 |
+
|
650 |
+
# Apply parameter count filters
|
651 |
+
using_param_filters = min_param_count > 0 or max_param_count is not None
|
652 |
+
|
653 |
+
# Skip if param_count is 0 and we're using param filters
|
654 |
+
if using_param_filters and param_count == 0:
|
655 |
+
continue
|
656 |
+
|
657 |
+
# Skip if param_count is less than min_param_count
|
658 |
+
if min_param_count > 0 and param_count < min_param_count:
|
659 |
+
continue
|
660 |
+
|
661 |
+
# Skip if param_count is greater than max_param_count
|
662 |
+
if max_param_count is not None and param_count > max_param_count:
|
663 |
+
continue
|
664 |
+
|
665 |
result = ModelQueryResult(
|
666 |
model_id=model["modelId"],
|
667 |
similarity=1.0, # Not applicable for trending
|
668 |
summary=id_to_summary[model["modelId"]],
|
669 |
likes=model.get("likes", 0),
|
670 |
downloads=model.get("downloads", 0),
|
671 |
+
param_count=param_count,
|
672 |
)
|
673 |
results.append(result)
|
674 |
|
|
|
681 |
|
682 |
@app.get("/trending/models", response_model=ModelQueryResponse)
async def get_trending_models(
    limit: int = Query(
        default=10, ge=1, le=100, description="Number of results to return"
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Return trending models together with their summaries, optionally filtered.

    All query parameters are forwarded unchanged to
    ``get_trending_models_with_summaries``; this endpoint only wraps the
    returned list in a ``ModelQueryResponse``.

    - A parameter-count filter is considered active when min_param_count > 0
      or max_param_count is given; models with param_count=0 are then excluded.
    - param_count=0 marks a missing/unknown parameter count in the dataset.
    """
    # Collect the filters once and pass them by keyword so the call does not
    # depend on the helper's positional parameter order.
    filters = {
        "limit": limit,
        "min_likes": min_likes,
        "min_downloads": min_downloads,
        "min_param_count": min_param_count,
        "max_param_count": max_param_count,
    }
    trending = await get_trending_models_with_summaries(**filters)
    return ModelQueryResponse(results=trending)
|
714 |
|