Yeetek committed on
Commit
a48a75f
·
verified ·
1 Parent(s): 0655a6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -5
app.py CHANGED
@@ -1,7 +1,69 @@
1
- from fastapi import FastAPI
 
 
 
 
 
2
 
3
- app = FastAPI()
 
 
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, uuid
2
+ from typing import List
3
+ from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel
5
+ from bertopic import BERTopic
6
+ from sentence_transformers import SentenceTransformer
7
 
8
# Configuration read from the environment so the container can be tuned
# without a rebuild.
MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")  # sentence-embedding model id
MIN_TOPIC = int(os.getenv("MIN_TOPIC_SIZE", "10"))  # passed to BERTopic as min_topic_size
MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))       # request-size cap enforced by /segment
11
 
12
# --- init models once at container start ---
# Constructed at import time so the embedding weights are loaded exactly
# once per process rather than per request.
embeddings = SentenceTransformer(MODEL_NAME)
topic_model = BERTopic(
    embedding_model=embeddings,
    min_topic_size=MIN_TOPIC,
    # NOTE: with calculate_probabilities=True, fit_transform returns a
    # per-topic probability vector (not a scalar) for each document.
    calculate_probabilities=True,
)
19
+
20
+ # -------- FastAPI schema ----------
21
class Sentence(BaseModel):
    """One transcript sentence of the request payload, with its time span."""
    text: str                   # sentence text; the document fed to BERTopic
    start: float                # start timestamp (presumably seconds — confirm with caller)
    end: float                  # end timestamp, same units as `start`
    speaker: str | None = None  # optional speaker label; unused by /segment
26
+
27
class Segment(BaseModel):
    """A contiguous run of sentences assigned to the same BERTopic topic."""
    topic_id: int              # BERTopic topic id; -1 is the outlier topic
    label: str | None          # space-joined top keywords; None for topic -1
    keywords: List[str]        # top-5 keywords of the topic cluster
    start: float               # start timestamp of the first sentence in the run
    end: float                 # end timestamp of the last sentence in the run
    probability: float | None  # probability recorded when the run was opened
    sentences: List[int]       # indices into the posted sentence list
35
+
36
class SegmentationResponse(BaseModel):
    """Response body for POST /segment."""
    run_id: str              # fresh UUID identifying this segmentation run
    segments: List[Segment]  # merged topic segments in transcript order
39
+
40
# Single FastAPI application object; routes are registered against it below.
app = FastAPI(title="CZ Topic Segmenter", version="1.0")
42
@app.post("/segment", response_model=SegmentationResponse)
def segment(sentences: List[Sentence]):
    """Cluster transcript sentences into contiguous topic segments.

    Fits the shared BERTopic model on the incoming sentences, then merges
    consecutive sentences that received the same topic id into one segment.

    Args:
        sentences: ordered transcript sentences (text plus start/end times).

    Returns:
        Mapping matching ``SegmentationResponse``: a fresh ``run_id`` and
        the merged segments in transcript order.

    Raises:
        HTTPException: 413 when more than MAX_DOCS sentences are posted.
    """
    if len(sentences) > MAX_DOCS:
        raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")

    # Empty transcript: fit_transform raises deep inside BERTopic on an
    # empty corpus, but the natural answer is simply no segments.
    if not sentences:
        return {"run_id": str(uuid.uuid4()), "segments": []}

    # NOTE(review): fit_transform refits the module-level model on every
    # request; the shared state makes concurrent requests unsafe — confirm
    # single-worker deployment or move to per-request model instances.
    docs = [s.text for s in sentences]
    topics, probs = topic_model.fit_transform(docs)

    def doc_probability(prob) -> float:
        # With calculate_probabilities=True BERTopic yields a per-topic
        # probability *vector* per document, on which the old `prob or 0`
        # raised "truth value of an array is ambiguous". Reduce a vector
        # to its maximum (the assigned topic's probability); pass scalars
        # (and None) through unchanged.
        if prob is None:
            return 0.0
        try:
            return float(max(prob))
        except TypeError:  # scalar probability
            return float(prob)

    segments, cur = [], None
    for idx, (t_id, prob) in enumerate(zip(topics, probs)):
        if cur is None or t_id != cur["topic_id"]:
            if cur:
                segments.append(cur)
            # Top-5 keywords for the cluster. get_topic returns False for
            # an unknown topic id, so fall back to an empty keyword list
            # instead of crashing on False[:5].
            topic_words = topic_model.get_topic(t_id) or []
            words = [w for w, _ in topic_words[:5]]
            cur = dict(topic_id=int(t_id),
                       label=" ".join(words) if t_id != -1 else None,
                       keywords=words,
                       start=sentences[idx].start,
                       end=sentences[idx].end,
                       probability=doc_probability(prob),
                       sentences=[idx])
        else:
            # Same topic as the open segment: extend its span.
            cur["end"] = sentences[idx].end
            cur["sentences"].append(idx)
    if cur:
        segments.append(cur)

    return {"run_id": str(uuid.uuid4()), "segments": segments}