Yeetek committed
Commit e972751 · verified · 1 Parent(s): 7faf4df

Update app.py

Files changed (1)
  app.py  +119 -14
app.py CHANGED
@@ -1,7 +1,7 @@
 # ── DIAGNOSTICS & SHIM (must come before any BERTopic import) ─────────────
-import pkgutil, sentence_transformers, bertopic, sys, json
+import pkgutil, sentence_transformers, bertopic, sys, json, os, uuid
 
-# 1) Print versions & model‐list
+# 1) Print versions & model-list
 print("ST version:", sentence_transformers.__version__)
 print("BERTopic version:", bertopic.__version__)
 models = [m.name for m in pkgutil.iter_modules(sentence_transformers.models.__path__)]
@@ -17,8 +17,7 @@ if "StaticEmbedding" not in models:
 sys.stdout.flush()
 # ──────────────────────────────────────────────────────────────────────────────
 
-# ── REST OF YOUR APP.PY ──────────────────────────────────────────────────────
-import os, uuid
+# ── REST OF APP.PY ───────────────────────────────────────────────────────────
 from typing import List
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
@@ -26,6 +25,8 @@ from bertopic import BERTopic
 from sentence_transformers import SentenceTransformer
 from umap import UMAP
 from hdbscan import HDBSCAN
+from sklearn.feature_extraction.text import CountVectorizer
+from stop_words import get_stop_words
 
 # 0) Quick env dump
 print("ENV-snapshot:", json.dumps({k: os.environ[k] for k in list(os.environ)[:10]}))
@@ -37,18 +38,122 @@ os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True)
 os.environ["NUMBA_DISABLE_CACHE"] = "1"
 
 # 2) Config from ENV
-MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")
+MODEL_NAME = os.getenv("EMBED_MODEL", "seznam/simcse-small-e-czech")
 MIN_TOPIC = int(os.getenv("MIN_TOPIC_SIZE", "10"))
 MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))
 
-# 3) Initialise embeddings once
-# Reuse the build-time cache folder which is writeable
-home_cache = "/tmp/hfcache"
-# Already created and permission-fixed at build time
-os.environ["HF_HOME"] = home_cache
-os.environ["TRANSFORMERS_CACHE"] = home_cache
-os.environ["SENTENCE_TRANSFORMERS_HOME"] = home_cache
+# 3) Set HF cache envs to a writeable folder (once at startup)
+cache_dir = "/tmp/hfcache"
+os.makedirs(cache_dir, exist_ok=True)
+import stat
+os.chmod(cache_dir, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
+os.environ["HF_HOME"] = cache_dir
+os.environ["TRANSFORMERS_CACHE"] = cache_dir
+os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
 
-embeddings = SentenceTransformer(MODEL_NAME)
-(MODEL_NAME)
+# 4) Initialise embeddings once
+embeddings = SentenceTransformer(MODEL_NAME, cache_folder=cache_dir)
 
+# Pre-initialize fallback global models for small-batch debugging
+# Global UMAP: 2-neighbors, cosine space, random init
+global_umap = UMAP(
+    n_neighbors=2,
+    metric="cosine",
+    init="random",
+    random_state=42
+)
+# Global HDBSCAN: min cluster size 2, min_samples 1, cosine metric
+global_hdbscan = HDBSCAN(
+    min_cluster_size=2,
+    min_samples=1,
+    metric="cosine",
+    cluster_selection_method="eom",
+    prediction_data=True
+)
+# Global Czech vectorizer: stopwords + bigrams
+global_vectorizer = CountVectorizer(
+    stop_words=get_stop_words("czech"),
+    ngram_range=(1, 2)
+)
+
+# 5) FastAPI schemas and app
+class Sentence(BaseModel):
+    text: str
+    start: float
+    end: float
+    speaker: str | None = None
+    chunk_index: int | None = None
+
+class Segment(BaseModel):
+    topic_id: int
+    label: str | None
+    keywords: List[str]
+    start: float
+    end: float
+    probability: float | None
+    sentences: List[int]
+
+class SegmentationResponse(BaseModel):
+    run_id: str
+    segments: List[Segment]
+
+app = FastAPI(title="CZ Topic Segmenter", version="1.0")
+
+@app.get("/")
+async def root():
+    return {"message": "CZ Topic Segmenter is running."}
+
+@app.post("/segment", response_model=SegmentationResponse)
+def segment(sentences: List[Sentence]):
+    print(f"[segment] Received {len(sentences)} sentences, chunk_indices={[s.chunk_index for s in sentences]}")
+    sys.stdout.flush()
+    if len(sentences) > MAX_DOCS:
+        raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")
+
+    # sort by chunk_index
+    sorted_sent = sorted(
+        sentences,
+        key=lambda s: s.chunk_index if s.chunk_index is not None else 0
+    )
+    docs = [s.text for s in sorted_sent]
+
+    # Use global UMAP/HDBSCAN/vectorizer instances for debugging
+    umap_model = global_umap
+    hdbscan_model = global_hdbscan
+    vectorizer_model = global_vectorizer
+
+    # instantiate BERTopic per request with global components
+    topic_model = BERTopic(
+        embedding_model=embeddings,
+        umap_model=umap_model,
+        hdbscan_model=hdbscan_model,
+        vectorizer_model=vectorizer_model,
+        min_topic_size=2,
+        calculate_probabilities=True
+    )
+
+    topics, probs = topic_model.fit_transform(docs)
+
+    segments, cur = [], None
+    for idx, (t_id, prob) in enumerate(zip(topics, probs)):
+        orig_idx = sorted_sent[idx].chunk_index if sorted_sent[idx].chunk_index is not None else idx
+        if cur is None or t_id != cur["topic_id"]:
+            words = [w for w, _ in topic_model.get_topic(t_id)[:5]]
+            cur = dict(
+                topic_id=t_id,
                label=" ".join(words) if t_id != -1 else None,
+                keywords=words,
+                start=sorted_sent[idx].start,
+                end=sorted_sent[idx].end,
+                probability=float(prob or 0),
+                sentences=[orig_idx],
+            )
+        else:
+            cur["end"] = sorted_sent[idx].end
+            cur["sentences"].append(orig_idx)
+    if cur:
+        segments.append(cur)
+
+    print(f"[segment] Returning {len(segments)} segments: {segments}")
+    sys.stdout.flush()
+    return {"run_id": str(uuid.uuid4()), "segments": segments}