Yeetek committed on
Commit 65c2b1d · verified · 1 Parent(s): 8fc16d0

Update app.py

Files changed (1)
app.py +46 -32
app.py CHANGED
@@ -42,11 +42,8 @@ MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))
 
 # 3) Initialise once
 embeddings = SentenceTransformer(MODEL_NAME, cache_folder="/tmp/hfcache")
-topic_model = BERTopic(
-    embedding_model=embeddings,
-    min_topic_size=MIN_TOPIC,
-    calculate_probabilities=True,
-)
+
+# We will not build a global topic_model because we need sorting per-request
 
 # 4) Schemas
 class Sentence(BaseModel):
@@ -74,40 +71,57 @@ app = FastAPI(title="CZ Topic Segmenter", version="1.0")
 
 @app.post("/segment", response_model=SegmentationResponse)
 def segment(sentences: List[Sentence]):
+    # Guardrail: avoid oversize requests
     if len(sentences) > MAX_DOCS:
         raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")
 
-    # If chunk_index is present, use it to preserve original ordering
-    sorted_sent = sorted(
-        sentences,
-        key=lambda s: s.chunk_index if s.chunk_index is not None else 0
-    )
-    docs = [s.text for s in sorted_sent]
+    # Sort by chunk_index if available, else maintain original order
+    sorted_sent = sorted(
+        sentences,
+        key=lambda s: s.chunk_index if s.chunk_index is not None else 0
+    )
+    docs = [s.text for s in sorted_sent]
+
+    # Build topic model per request to preserve order mapping
+    from bertopic import BERTopic
+    topic_model = BERTopic(
+        embedding_model=embeddings,
+        min_topic_size=MIN_TOPIC,
+        calculate_probabilities=True,
+    )
+
+    # Fit-transform
     topics, probs = topic_model.fit_transform(docs)
 
-    segments, cur = [], None
+    # Assemble segments
+    segments = []
+    cur = None
     for idx, (t_id, prob) in enumerate(zip(topics, probs)):
-        # compute the original index
-        orig_idx = (
-            sorted_sent[idx].chunk_index
-            if sorted_sent[idx].chunk_index is not None
-            else idx
-        )
-
-        if cur is None or t_id != cur["topic_id"]:
-            # …
-            cur = dict(
-                topic_id=t_id,
-                label=" ".join(words) if t_id != -1 else None,
-                keywords=words,
-                start=sorted_sent[idx].start,
-                end=sorted_sent[idx].end,
-                probability=float(prob or 0),
-                sentences=[orig_idx],  # use chunk_index here
+        # Map back to original chunk_index or positional idx
+        orig_idx = (
+            sorted_sent[idx].chunk_index
+            if sorted_sent[idx].chunk_index is not None
+            else idx
         )
-        else:
-            cur["end"] = sorted_sent[idx].end
-            cur["sentences"].append(orig_idx)
+        # When topic changes, push previous segment
+        if cur is None or t_id != cur["topic_id"]:
+            if cur:
+                segments.append(cur)
+            # Top-5 keywords for this topic
+            words = [w for w, _ in topic_model.get_topic(t_id)[:5]]
+            cur = {
+                "topic_id": t_id,
+                "label": " ".join(words) if t_id != -1 else None,
+                "keywords": words,
+                "start": sorted_sent[idx].start,
+                "end": sorted_sent[idx].end,
+                "probability": float(prob or 0),
+                "sentences": [orig_idx],
+            }
+        else:
+            cur["end"] = sorted_sent[idx].end
+            cur["sentences"].append(orig_idx)
+    # Append last
     if cur:
         segments.append(cur)
 
127