Yeetek committed
Commit 6948a45 · verified · 1 Parent(s): 1570c6c

Update app.py

Files changed (1)
  1. app.py +18 -11
app.py CHANGED

@@ -25,6 +25,7 @@ from pydantic import BaseModel
 from bertopic import BERTopic
 from sentence_transformers import SentenceTransformer
 from umap import UMAP
+from hdbscan import HDBSCAN
 
 # 0) Quick env dump
 print("ENV-snapshot:", json.dumps({k: os.environ[k] for k in list(os.environ)[:10]}))
@@ -40,8 +41,8 @@ MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")
 MIN_TOPIC = int(os.getenv("MIN_TOPIC_SIZE", "10"))
 MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))
 
-# 3) Initialise once
-embeddings = SentenceTransformer(MODEL_NAME, cache_folder="/tmp/hfcache")
+# 3) Initialise embeddings once
+embeddings = SentenceTransformer(MODEL_NAME, cache_folder="/tmp/hfcache")
 
 # 4) Schemas
 class Sentence(BaseModel):
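Note (not part of the diff): the model name and thresholds above are read from environment variables, so a deployment can change them without touching the code. A small illustration of overriding the defaults before the module is imported; the override values here are assumptions, not from the commit:

# Illustration only: override the env-driven defaults before importing app
import os
os.environ["EMBED_MODEL"] = "Seznam/simcse-small-e-czech"  # the default shown in the diff
os.environ["MIN_TOPIC_SIZE"] = "5"
os.environ["MAX_DOCS"] = "2000"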
@@ -69,21 +70,20 @@ app = FastAPI(title="CZ Topic Segmenter", version="1.0")
 
 @app.post("/segment", response_model=SegmentationResponse)
 def segment(sentences: List[Sentence]):
-    # Guardrail: avoid oversize requests
+    # Guardrail
     if len(sentences) > MAX_DOCS:
         raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")
 
-    # Sort by chunk_index if available, else maintain original order
+    # Sort by chunk_index if present
     sorted_sent = sorted(
         sentences,
         key=lambda s: s.chunk_index if s.chunk_index is not None else 0
     )
     docs = [s.text for s in sorted_sent]
 
-    # Choose dynamic n_neighbors <= n_samples-1
+    # UMAP with cosine, init random, dynamic neighbors
     n_samples = len(docs)
     n_neighbors = min(15, max(2, n_samples - 1))
-    # UMAP with cosine and random init to avoid spectral errors on tiny N
     umap_model = UMAP(
         n_neighbors=n_neighbors,
         metric="cosine",
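Note (not part of the diff): the clamp above caps n_neighbors at 15 and scales it down with the batch size so UMAP still runs on very small requests. A quick illustration of how the same expression evaluates:

# Illustration only; mirrors n_neighbors = min(15, max(2, n_samples - 1))
for n_samples in (2, 3, 10, 16, 500):
    print(n_samples, "->", min(15, max(2, n_samples - 1)))
# 2 -> 2, 3 -> 2, 10 -> 9, 16 -> 15, 500 -> 15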
@@ -91,12 +91,21 @@ def segment(sentences: List[Sentence]):
         random_state=42
     )
 
-    # Build BERTopic per request with dynamic UMAP
+    # HDBSCAN with dynamic cluster sizes
+    cluster_size = min(MIN_TOPIC, n_samples) if n_samples >= 2 else 2
+    hdbscan_model = HDBSCAN(
+        min_cluster_size=cluster_size,
+        min_samples=min(cluster_size, n_samples),
+        metric="euclidean",
+        cluster_selection_method="eom"
+    )
+
+    # Build BERTopic per request
     topic_model = BERTopic(
         embedding_model=embeddings,
         umap_model=umap_model,
-        min_topic_size=MIN_TOPIC,
-        calculate_probabilities=True,
+        hdbscan_model=hdbscan_model,
+        calculate_probabilities=True
     )
 
     # Fit-transform
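Note (not part of the diff): the new HDBSCAN block above clamps min_cluster_size and min_samples to the batch size, so a small request cannot ask for clusters larger than the data. A short sketch of the resulting values, assuming the default MIN_TOPIC_SIZE of 10 from the config above:

# Illustration only; mirrors the cluster_size / min_samples logic in the hunk above
MIN_TOPIC = 10
for n_samples in (1, 2, 5, 10, 50):
    cluster_size = min(MIN_TOPIC, n_samples) if n_samples >= 2 else 2
    print(n_samples, "->", cluster_size, min(cluster_size, n_samples))
# 1 -> 2 1, 2 -> 2 2, 5 -> 5 5, 10 -> 10 10, 50 -> 10 10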
@@ -106,7 +115,6 @@ def segment(sentences: List[Sentence]):
     segments = []
     cur = None
     for idx, (t_id, prob) in enumerate(zip(topics, probs)):
-        # Map back to original chunk_index or positional idx
         orig_idx = (
             sorted_sent[idx].chunk_index
             if sorted_sent[idx].chunk_index is not None
@@ -115,7 +123,6 @@ def segment(sentences: List[Sentence]):
         if cur is None or t_id != cur["topic_id"]:
             if cur:
                 segments.append(cur)
-            # Top-5 keywords
             words = [w for w, _ in topic_model.get_topic(t_id)[:5]]
             cur = {
                 "topic_id": t_id,
 
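Note (not part of the diff): the loop in the last two hunks walks the topic assignments in order, starts a new segment whenever the topic id changes, and attaches the topic's top-5 keywords to each segment. A minimal, self-contained sketch of the same consecutive-grouping idea on hypothetical data:

# Illustration only: group adjacent items that share a topic id
from itertools import groupby

topics = [0, 0, 1, 1, 1, 0]
docs = ["a", "b", "c", "d", "e", "f"]
segments, pos = [], 0
for t_id, grp in groupby(topics):
    size = len(list(grp))
    segments.append({"topic_id": t_id, "sentences": docs[pos:pos + size]})
    pos += size
print(segments)  # three segments: [0: a,b], [1: c,d,e], [0: f]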
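Note (not part of the diff): a minimal sketch of calling the updated endpoint. Only the text and optional chunk_index fields come from the Sentence schema shown in the diff; the host, port, and sample sentences are assumptions, and the response shape is whatever SegmentationResponse defines in app.py.

# Illustration only: POST a small batch of sentences to /segment
import requests  # assumed to be available on the client side

payload = [
    {"text": "Dobrý den, dnes probereme rozpočet.", "chunk_index": 0},
    {"text": "Nejprve se podíváme na příjmy.", "chunk_index": 1},
    {"text": "Nyní přejdeme k plánu na příští rok.", "chunk_index": 2},
]
resp = requests.post("http://localhost:7860/segment", json=payload)  # port 7860 is an assumption
resp.raise_for_status()
print(resp.json())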