import duckdb | |
import yaml | |
import time | |
def extract_base_models(tags):
    """Return the distinct base-model ids referenced by a list of tag strings.

    A tag contributes only if it starts with ``"base_model:"``. The id is the
    text after the *last* colon (so ``"base_model:x:org/name"`` yields
    ``"org/name"``), with every character removed that is neither
    alphanumeric nor one of ``/_.-``.

    Args:
        tags: Iterable of tag strings. ``None`` (and ``None`` elements) are
            treated as "no tag".

    Returns:
        A list of unique, non-empty ids in first-seen order. A deterministic
        order keeps the written parquet reproducible across runs, unlike
        iterating a ``set`` whose order depends on string hash randomization.
    """
    if tags is None:
        return []
    base_models = {}  # dict used as an insertion-ordered set
    for tag in tags:
        if not tag or not tag.startswith("base_model:"):
            continue
        raw_id = tag.rsplit(":", 1)[-1]
        base_model = "".join(c for c in raw_id if c.isalnum() or c in "/_.-")
        # A bare "base_model:" tag (or one with only disallowed characters)
        # would otherwise contribute an empty-string id.
        if base_model:
            base_models[base_model] = None
    return list(base_models)
# Create a DuckDB connection (no database path given)
con = duckdb.connect()
# Register the Python UDF with explicit parameter AND return types: the
# function has no type annotations, so DuckDB has nothing to infer them from.
con.create_function(
    "extract_base_models",
    extract_base_models,
    parameters=["VARCHAR[]"],
    return_type="VARCHAR[]",
)
# Query to extract base models using the Python UDF
query = """
SELECT
    _id,
    id,
    extract_base_models(tags) AS base_models
FROM parquet_scan('public/models.parquet')
"""
# perf_counter is monotonic, so the elapsed time is unaffected by clock changes
start_time = time.perf_counter()
# Materialize the result once as a TABLE. A VIEW would be re-evaluated by every
# later query that reads parent_models, re-running the row-by-row Python UDF
# over the whole input file each time.
con.execute(f"CREATE TABLE parent_models AS {query}")
# Write the materialized rows to a parquet file
con.execute("COPY parent_models TO 'public/parents.parquet' (FORMAT 'parquet')")
end_time = time.perf_counter()
execution_time = end_time - start_time
print(f"Query execution time: {execution_time:.2f} seconds")
# Filter rows with non-empty base_models
con.execute("""
CREATE VIEW non_empty_base_models AS
SELECT *
FROM parent_models
WHERE ARRAY_LENGTH(base_models) > 0
""")
# Write a random sample of 10 rows with non-empty base_models to yaml file for inspection
cursor = con.execute("""
SELECT _id, id, base_models
FROM non_empty_base_models
ORDER BY RANDOM()
LIMIT 10
""")
# Read column names before fetching; fetchall() yields bare tuples, which
# safe_dump would write as unlabeled lists.
columns = [desc[0] for desc in cursor.description]
result = cursor.fetchall()
records = [dict(zip(columns, row)) for row in result]
with open("public/parents.example.yaml", "w", encoding="utf-8") as f:
    # sort_keys=False keeps the SELECT column order; allow_unicode writes
    # non-ASCII ids verbatim instead of as escape sequences.
    yaml.safe_dump(
        records,
        f,
        default_flow_style=False,
        sort_keys=False,
        allow_unicode=True,
    )