Yeetek committed on
Commit
a02884f
Β·
verified Β·
1 Parent(s): 98c2919

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -81
app.py CHANGED
@@ -1,83 +1,46 @@
1
- # ── DIAGNOSTICS ─────────────────────────────────────────────
2
- import pkgutil, sentence_transformers, bertopic, sys
 
 
3
  print("ST version:", sentence_transformers.__version__)
4
  print("BERTopic version:", bertopic.__version__)
5
- print("ST models:", [m.name for m in pkgutil.iter_modules(sentence_transformers.models.__path__)])
 
6
  sys.stdout.flush()
7
 
8
- # ── STATICEMBEDDING SHIM ────────────────────────────────────
9
- if "StaticEmbedding" not in [m.name for m in pkgutil.iter_modules(sentence_transformers.models.__path__)]:
10
  from sentence_transformers.models import Transformer
11
- import bertopic.backend._sentencetransformers as st_back
12
- st_back.StaticEmbedding = Transformer
13
-
14
- # ── regular imports and monkey-patch if you choose ───────────────────────
15
- from bertopic import BERTopic
16
- # optionally:
17
- # import bertopic.backend._sentencetransformers as st_back
18
- # from sentence_transformers.models import Transformer
19
- # st_back.StaticEmbedding = Transformer
20
-
21
- from fastapi import FastAPI, HTTPException
22
-
23
- # ---------- BEGIN app.py ----------
24
- import os, sys, json, uuid, types
25
-
26
- # ── 0. Quick env print – delete later if you like ───────────────────────
27
- print("ENV-snapshot:", json.dumps(dict(list(os.environ.items())[:25])))
28
- sys.stdout.flush()
29
-
30
- # ── 1. Ensure a writable dir (good housekeeping) ────────────────────────
31
- os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
32
- os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True)
33
-
34
- # ── 2. FINAL numba cache kill-switch ────────────────────────────────────
35
- try:
36
- import importlib, numba, types
37
- from numba.core import dispatcher, caching
38
- import numba.np.ufunc.ufuncbuilder as ufuncbuilder
39
 
40
- # 2-a UMAP path: no-op dispatcher method
41
- dispatcher.Dispatcher.enable_caching = lambda self: None
42
 
43
- # 2-b Build a stub that pretends to be a FunctionCache
44
- class _NoCache(types.SimpleNamespace):
45
- def __init__(self, *_, **__): pass
46
- load_overload = lambda *_, **__: False
47
- save_overload = lambda *_, **__: None
48
- enable_caching = lambda *_, **__: None
49
-
50
- # 2-c Patch *every* place that still holds a reference
51
- caching.FunctionCache = _NoCache # core path
52
- ufuncbuilder.FunctionCache = _NoCache # PyNNDescent path
53
-
54
- # 2-d Extra belt-and-braces flag
55
- os.environ["NUMBA_DISABLE_CACHE"] = "1"
56
-
57
- except ImportError:
58
- # numba isn't installed yet during first pip install – harmless
59
- pass
60
- # ─────────────────────────────────────────────────────────────────────────
61
-
62
-
63
- # ── 3. Heavy imports (UMAP, BERTopic, FastAPI, …) ───────────────────────
64
  from typing import List
65
  from fastapi import FastAPI, HTTPException
66
  from pydantic import BaseModel
67
  from bertopic import BERTopic
68
  from sentence_transformers import SentenceTransformer
69
- # ---------- the rest of your file (config, model init, endpoint) stays unchanged ----------
70
 
71
- ...
72
- # ---------- rest of the file unchanged ----------
 
73
 
 
 
 
 
74
 
75
- # ── 4. Configuration via env vars ─────────────────────────────────────────────
76
  MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")
77
  MIN_TOPIC = int(os.getenv("MIN_TOPIC_SIZE", "10"))
78
  MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))
79
 
80
- # ── 5. Initialise models once at container start ─────────────────────────────
81
  embeddings = SentenceTransformer(MODEL_NAME, cache_folder="/tmp/hfcache")
82
  topic_model = BERTopic(
83
  embedding_model=embeddings,
@@ -85,7 +48,7 @@ topic_model = BERTopic(
85
  calculate_probabilities=True,
86
  )
87
 
88
- # ── 6. Pydantic schemas ──────────────────────────────────────────────────────
89
  class Sentence(BaseModel):
90
  text: str
91
  start: float
@@ -105,17 +68,13 @@ class SegmentationResponse(BaseModel):
105
  run_id: str
106
  segments: List[Segment]
107
 
108
- # ── 7. FastAPI app and endpoint ──────────────────────────────────────────────
109
  app = FastAPI(title="CZ Topic Segmenter", version="1.0")
110
 
111
  @app.post("/segment", response_model=SegmentationResponse)
112
  def segment(sentences: List[Sentence]):
113
- # Guardrail: avoid oversize requests
114
  if len(sentences) > MAX_DOCS:
115
- raise HTTPException(
116
- status_code=413,
117
- detail=f"Too many sentences ({len(sentences)} > {MAX_DOCS})"
118
- )
119
 
120
  docs = [s.text for s in sentences]
121
  topics, probs = topic_model.fit_transform(docs)
@@ -125,25 +84,21 @@ def segment(sentences: List[Sentence]):
125
  if cur is None or t_id != cur["topic_id"]:
126
  if cur:
127
  segments.append(cur)
128
-
129
- # Top-5 keywords for this topic
130
- words = [w for w, _ in topic_model.get_topic(t_id)[:5]]
131
-
132
  cur = dict(
133
- topic_id=t_id,
134
- label=" ".join(words) if t_id != -1 else None, # βœ“ fixed β€˜=’
135
- keywords=words,
136
- start=sentences[idx].start,
137
- end=sentences[idx].end,
138
- probability=float(prob or 0),
139
- sentences=[idx],
140
  )
141
  else:
142
  cur["end"] = sentences[idx].end
143
  cur["sentences"].append(idx)
144
-
145
  if cur:
146
  segments.append(cur)
147
 
148
  return {"run_id": str(uuid.uuid4()), "segments": segments}
149
- # ---------- END app.py ----------
 
1
# ── DIAGNOSTICS & SHIM (must come before any BERTopic import) ─────────────
import pkgutil
import sys
import json

import sentence_transformers
import bertopic

# 1) Report library versions and the embedding-model modules that ship with
#    this sentence-transformers install, so container logs show exactly what
#    is deployed.
print("ST version:", sentence_transformers.__version__)
print("BERTopic version:", bertopic.__version__)
models = [mod.name for mod in pkgutil.iter_modules(sentence_transformers.models.__path__)]
print("ST models:", models)
sys.stdout.flush()

# 2) Older sentence-transformers releases have no StaticEmbedding model
#    module; alias the plain Transformer class under that name so code that
#    looks it up on sentence_transformers.models still resolves.
if "StaticEmbedding" not in models:
    from sentence_transformers.models import Transformer
    import sentence_transformers.models as _st_mod
    setattr(_st_mod, "StaticEmbedding", Transformer)
    print("🔧 Shim applied: StaticEmbedding → Transformer")
    sys.stdout.flush()
# ──────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
 
 
20
 
21
# ── REST OF YOUR APP.PY ──────────────────────────────────────────────────────
# Stdlib (json/sys are imported here explicitly — this section uses them and
# must not depend on the diagnostics block above having run first).
import json
import os
import sys
import uuid

from typing import List

# Third-party
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer

# 0) Quick env dump — first 10 variables only, for startup debugging.
# NOTE(review): environment variables can hold secrets (tokens, keys);
# consider deleting this print before production use.
print("ENV-snapshot:", json.dumps({k: os.environ[k] for k in list(os.environ)[:10]}))
sys.stdout.flush()

# 1) Tidy numba cache: point numba at a writable directory (container app
#    dirs are often read-only) and disable on-disk caching outright.
# NOTE(review): bertopic is already imported above, so numba may be loaded
# before these env vars are set — confirm they still take effect.
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True)
os.environ["NUMBA_DISABLE_CACHE"] = "1"

# 2) Config from ENV
MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")  # embedding model id
MIN_TOPIC = int(os.getenv("MIN_TOPIC_SIZE", "10"))   # BERTopic min cluster size
MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))        # request-size guardrail

# 3) Initialise once at container start so every request reuses the model.
embeddings = SentenceTransformer(MODEL_NAME, cache_folder="/tmp/hfcache")
45
  topic_model = BERTopic(
46
  embedding_model=embeddings,
 
48
  calculate_probabilities=True,
49
  )
50
 
51
+ # 4) Schemas
52
  class Sentence(BaseModel):
53
  text: str
54
  start: float
 
68
  run_id: str
69
  segments: List[Segment]
70
 
# 5) FastAPI
# Application object served by the ASGI server; the /segment endpoint is
# registered on it below.
app = FastAPI(title="CZ Topic Segmenter", version="1.0")
74
  @app.post("/segment", response_model=SegmentationResponse)
75
  def segment(sentences: List[Sentence]):
 
76
  if len(sentences) > MAX_DOCS:
77
+ raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")
 
 
 
78
 
79
  docs = [s.text for s in sentences]
80
  topics, probs = topic_model.fit_transform(docs)
 
84
  if cur is None or t_id != cur["topic_id"]:
85
  if cur:
86
  segments.append(cur)
87
+ words = [w for w,_ in topic_model.get_topic(t_id)[:5]]
 
 
 
88
  cur = dict(
89
+ topic_id = t_id,
90
+ label = None if t_id == -1 else " ".join(words),
91
+ keywords = words,
92
+ start = sentences[idx].start,
93
+ end = sentences[idx].end,
94
+ probability= float(prob or 0),
95
+ sentences = [idx],
96
  )
97
  else:
98
  cur["end"] = sentences[idx].end
99
  cur["sentences"].append(idx)
 
100
  if cur:
101
  segments.append(cur)
102
 
103
  return {"run_id": str(uuid.uuid4()), "segments": segments}
104
+ # ──────────────────────────────────────────────────────────────────────────────