|
from pathlib import Path |
|
import duckdb |
|
from datasets import load_dataset |
|
import os |
|
from config.settings import PARQUET_FILE |
|
|
|
|
|
class DatabaseService:
    """Query layer over the Open Food Facts product database.

    Locates a local parquet copy of the dataset (downloading it from
    Hugging Face when needed) and serves product searches through an
    in-memory DuckDB connection.
    """

    def __init__(self):
        """Locate (or download) the dataset and open the DuckDB connection.

        Raises:
            RuntimeError: if the dataset can neither be found locally nor
                downloaded from Hugging Face, or DuckDB setup fails.
        """
        self.parquet_path = None  # resolved Path to the parquet file
        self.conn = None          # DuckDB connection (set by _setup_duckdb)

        os.makedirs('./data', exist_ok=True)

        # Fast path: a parquet file already exists locally.
        parquet_file = self._find_parquet_file()
        if parquet_file and Path(parquet_file).exists():
            self.parquet_path = Path(parquet_file)
            self._setup_duckdb()
            print(f"🗜️ Dataset : {self.parquet_path} \n\n")
            return

        # Slow path: fetch the dataset from Hugging Face, then retry.
        if not self._download_from_huggingface():
            print("❌ Can't download dataset")
            raise RuntimeError("Dataset download failed")

        parquet_file = self._find_parquet_file()
        if not parquet_file:
            print("❌ Can't load dataset")
            raise RuntimeError("No dataset")

        self.parquet_path = Path(parquet_file)
        self._setup_duckdb()
        print(f"✅ Dataset downloaded : {self.parquet_path}")

    def close(self):
        """Release the DuckDB connection. Safe to call more than once."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    @staticmethod
    def _download_from_huggingface():
        """Download the Open Food Facts dataset from Hugging Face.

        Saves the "food" split to ./data/food.parquet.

        Returns:
            bool: True if the download and export succeeded, False otherwise.
        """
        try:
            print("🔄 Downloading dataset from Hugging Face...")

            dataset_name = "openfoodfacts/product-database"
            dataset = load_dataset(dataset_name, split="food")
            dataset.to_parquet("./data/food.parquet")

            print("✅ Dataset downloaded & saved in ./data/food.parquet")
            return True
        except Exception as e:
            # Network / Hugging Face errors land here; report and signal
            # failure explicitly (previously fell through returning None).
            print(f"❌ Erreur lors du téléchargement depuis Hugging Face: {e}")
            return False

    @staticmethod
    def _find_parquet_file():
        """Return the first existing parquet file among the known locations.

        Returns:
            str | None: a path usable by DuckDB, or None when nothing exists.
        """
        candidates = [PARQUET_FILE, "./data/food.parquet", "./food.parquet", "../data/food.parquet"]
        for candidate in candidates:
            if Path(candidate).exists():
                return candidate
        return None

    def _setup_duckdb(self):
        """Open an in-memory DuckDB connection and sanity-check the dataset.

        Raises:
            RuntimeError: if the connection or the COUNT(*) probe fails.
        """
        try:
            self.conn = duckdb.connect()
            # Probe the parquet file so a corrupt/missing dataset fails fast.
            result = self.conn.execute(f"SELECT COUNT(*) FROM '{self.parquet_path}'").fetchone()
            total = result[0] if result else 0
            print(f"🦆 DuckDB: {total:,} products \n\n")
        except Exception as e:
            print(f"❌ Error DuckDB: {e}")
            raise RuntimeError(f"DuckDB can not be configured: {e}") from e

    @staticmethod
    def escape_sql_string(text):
        """Escape a value for interpolation inside a single-quoted SQL literal.

        Doubles single quotes (standard SQL escaping). '%' is doubled as
        well; NOTE(review): '%%' does not actually escape LIKE wildcards in
        plain SQL (that needs an ESCAPE clause) — kept as-is for
        compatibility with the existing query behavior.

        Args:
            text: raw user-provided string (may be falsy).

        Returns:
            str: the escaped text, or "" for falsy input.
        """
        if not text:
            return ""
        return text.replace("'", "''").replace("%", "%%")

    @staticmethod
    def _clean_tags(tags_raw):
        """Normalize a raw tags value into a short list of plain tags.

        Accepts a stringified list, a plain string or a real list; strips
        the 'en:'/'fr:' language prefixes and keeps at most 3 tags.

        Args:
            tags_raw: tags as stored in the dataset (str, list or None).

        Returns:
            list[str]: cleaned tags (possibly empty).
        """
        def _strip(tag):
            # Drop the language prefixes used by Open Food Facts.
            return str(tag).replace('en:', '').replace('fr:', '')

        if not tags_raw:
            return []

        if isinstance(tags_raw, list):
            return [_strip(tag) for tag in tags_raw[:3]]

        if isinstance(tags_raw, str):
            if tags_raw.startswith('['):
                try:
                    import ast
                    parsed = ast.literal_eval(tags_raw)
                    if isinstance(parsed, list):
                        return [_strip(tag) for tag in parsed[:3]]
                except (ValueError, SyntaxError):
                    # Malformed literal: fall through to the plain-string case.
                    pass
            # BUG FIX: plain strings previously kept their language prefix,
            # unlike the two list branches above.
            return [_strip(tags_raw)]

        return [str(tags_raw)]

    @staticmethod
    def clean_product_name(raw_name):
        """Extract a readable product name from the raw dataset value.

        product_name is often stored as a stringified list of
        {'lang': ..., 'text': ...} dicts; pull out the first 'text' value,
        falling back to a best-effort cleanup of the raw string.

        Args:
            raw_name: raw product_name value (str or falsy).

        Returns:
            str: a display name, 'N/A' when nothing usable exists.
        """
        if not raw_name or raw_name == 'N/A':
            return 'N/A'

        if raw_name.startswith('[') and 'text' in raw_name:
            try:
                import re
                # Try the single-quoted then double-quoted "text" variants.
                for pattern in (r"'text':\s*'([^']*)'", r'"text":\s*"([^"]*)"'):
                    match = re.search(pattern, raw_name)
                    if match:
                        return match.group(1)
            except Exception as e:
                print(f"⚠️ Error cleaning: {e}")

        # Fallback: strip structural characters and try to isolate whatever
        # fragment followed a 'text' key.
        clean = raw_name.replace('[', '').replace(']', '').replace('{', '').replace('}', '')

        if 'text' in clean:
            parts = clean.split('text')
            if len(parts) > 1:
                text_part = parts[-1]
                text_part = text_part.replace('"', '').replace("'", '').replace(':', '').replace(',', '')
                text_part = text_part.strip()
                if text_part and len(text_part) > 3:
                    return text_part

        return clean[:100]

    def search_products(self, analysis, limit=1):
        """Search products in the dataset.

        Args:
            analysis: dict with optional "product", "brand" and
                "explanation" keys (typically produced upstream by the
                LLM analysis step — TODO confirm against caller).
            limit: maximum number of products to return.

        Returns:
            list[dict]: matching products sorted by relevance score
                (best first), or [] when nothing matches or on error.
        """
        if not self.conn or not self.parquet_path:
            print("❌ Database not initialized")
            return []

        try:
            # Harden the only numeric value interpolated into the SQL below.
            limit = int(limit)

            product = analysis.get("product", "").strip()
            brand = analysis.get("brand", "").strip()

            # NOTE(security): the query is assembled by string interpolation;
            # escape_sql_string() doubles single quotes so values cannot break
            # out of their literals, but prepared statements would be safer.
            safe_product = self.escape_sql_string(product)
            safe_brand = self.escape_sql_string(brand)

            conditions = []
            scores = []

            if product and brand:
                # Both terms known: require both, reward exact-ish matches.
                product_condition = f"LOWER(CAST(product_name AS VARCHAR)) LIKE LOWER('%{safe_product}%')"
                brand_condition = f"LOWER(CAST(brands AS VARCHAR)) LIKE LOWER('%{safe_brand}%')"
                conditions.append(f"({product_condition} AND {brand_condition})")
                scores.append(f"CASE WHEN {product_condition} AND {brand_condition} THEN 100 ELSE 0 END")
                scores.append(
                    f"CASE WHEN LOWER(CAST(product_name AS VARCHAR)) LIKE LOWER('{safe_product}%') THEN 20 ELSE 0 END")
                scores.append(f"CASE WHEN LOWER(CAST(brands AS VARCHAR)) = LOWER('{safe_brand}') THEN 30 ELSE 0 END")

            elif product:
                # Product only: prefix matches score higher; any known brand
                # adds a small bonus.
                product_condition = f"LOWER(CAST(product_name AS VARCHAR)) LIKE LOWER('%{safe_product}%')"
                conditions.append(product_condition)
                scores.append(
                    f"CASE WHEN LOWER(CAST(product_name AS VARCHAR)) LIKE LOWER('{safe_product}%') THEN 80 ELSE 50 END")
                scores.append("CASE WHEN brands IS NOT NULL AND LENGTH(CAST(brands AS VARCHAR)) > 3 THEN 10 ELSE 0 END")

            elif brand:
                # Brand only: exact brand match scores highest.
                brand_condition = f"LOWER(CAST(brands AS VARCHAR)) LIKE LOWER('%{safe_brand}%')"
                conditions.append(brand_condition)
                scores.append(f"CASE WHEN LOWER(CAST(brands AS VARCHAR)) = LOWER('{safe_brand}') THEN 90 ELSE 60 END")

            else:
                # Nothing extracted: broad match on either column (the empty
                # escaped term degenerates to '%%', i.e. match-anything).
                conditions.append(
                    f"(LOWER(CAST(product_name AS VARCHAR)) LIKE LOWER('%{safe_product}%') OR LOWER(CAST(brands AS VARCHAR)) LIKE LOWER('%{safe_product}%'))")
                scores.append(
                    f"CASE WHEN LOWER(CAST(product_name AS VARCHAR)) LIKE LOWER('%{safe_product}%') THEN 40 ELSE 20 END")

            # Restrict results to products sold in France.
            conditions.append("LOWER(CAST(countries_tags AS VARCHAR)) LIKE '%france%'")

            where_clause = " AND ".join(conditions)
            score_calc = " + ".join(scores) if scores else "1"

            # Drop weak matches below this relevance score.
            min_score_threshold = 30
            where_clause = f"({where_clause}) AND (({score_calc}) >= {min_score_threshold})"

            sql = f"""
                SELECT DISTINCT
                    product_name,        -- 0
                    brands,              -- 1
                    nutriscore_grade,    -- 2
                    nova_group,          -- 3
                    categories,          -- 4
                    ingredients_n,       -- 5
                    additives_n,         -- 6
                    allergens_tags,      -- 7
                    ingredients_text,    -- 8
                    nutriments,          -- 9
                    serving_size,        -- 10
                    quantity,            -- 11
                    ({score_calc}) AS score  -- 12
                FROM '{self.parquet_path}'
                WHERE {where_clause}
                ORDER BY score DESC, product_name ASC
                LIMIT {limit}
            """

            results = self.conn.execute(sql).fetchall()

            products = []
            for row in results:
                raw_name = str(row[0]) if row[0] else 'N/A'
                clean_name = self.clean_product_name(raw_name)

                products.append({
                    'product_name': clean_name[:100],
                    'brands': str(row[1])[:50] if row[1] else 'N/A',
                    'nutriscore_grade': str(row[2]).lower() if row[2] else '',
                    'nova_group': row[3] if row[3] else None,
                    'categories': str(row[4])[:100] if row[4] else 'N/A',
                    'ingredients_count': row[5] if row[5] else 0,
                    'additives_count': row[6] if row[6] else 0,
                    'allergens': self._clean_tags(row[7]),
                    'ingredients_text': str(row[8])[:500] if row[8] else 'N/A',
                    'nutriments': (row[9]) if row[9] else 'N/A',
                    'serving_size': str(row[10])[:20] if row[10] else 'N/A',
                    'quantity': str(row[11])[:30] if row[11] else 'N/A',
                    'score': float(row[12]) if row[12] else 0.0,
                    'mistral_analysis': analysis.get('explanation', '')
                })

            return products

        except Exception as e:
            print(f"❌ Search error: {e}")
            return []