# ── DIAGNOSTICS & SHIM (must come before any BERTopic import) ─────────────
import pkgutil, sentence_transformers, bertopic, sys, json, os, uuid

# 1) Print versions & model-list
print("ST version:", sentence_transformers.__version__)
print("BERTopic version:", bertopic.__version__)
models = [m.name for m in pkgutil.iter_modules(sentence_transformers.models.__path__)]
print("ST models:", models)
sys.stdout.flush()

# 2) If StaticEmbedding is missing, alias Transformer β†’ StaticEmbedding
if "StaticEmbedding" not in models:
    from sentence_transformers.models import Transformer
    import sentence_transformers.models as _st_mod
    setattr(_st_mod, "StaticEmbedding", Transformer)
    print("πŸ”§ Shim applied: StaticEmbedding β†’ Transformer")
    sys.stdout.flush()
# ──────────────────────────────────────────────────────────────────────────────

# ── REST OF APP.PY ───────────────────────────────────────────────────────────
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from stop_words import get_stop_words

# 0) Quick env dump
print("ENV-snapshot:", json.dumps({k: os.environ[k] for k in list(os.environ)[:10]}))
sys.stdout.flush()

# 1) Tidy numba cache
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True)
os.environ["NUMBA_DISABLE_CACHE"] = "1"

# 2) Config from ENV
# Read the embedding model repo ID from the environment (defaults to a Czech SimCSE model)
MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")
MIN_TOPIC  = int(os.getenv("MIN_TOPIC_SIZE", "10"))
MAX_DOCS   = int(os.getenv("MAX_DOCS", "5000"))

# 3) Set HF cache envs to a writeable folder (once at startup)
cache_dir = "/tmp/hfcache"
os.makedirs(cache_dir, exist_ok=True)
import stat
os.chmod(cache_dir, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir

# 4) Initialise embeddings once
# Use huggingface_hub to snapshot-download the model locally
from huggingface_hub import snapshot_download
print(f"Downloading model {MODEL_NAME} to {cache_dir}...")
sys.stdout.flush()
local_model_path = snapshot_download(repo_id=MODEL_NAME, cache_dir=cache_dir)

# Load SentenceTransformer from local path
embeddings = SentenceTransformer(local_model_path, cache_folder=cache_dir)

# Pre-initialize fallback global models for small-batch debugging
# Global UMAP: 2-neighbors, cosine space, random init
global_umap = UMAP(
    n_neighbors=2,
    metric="cosine",
    init="random",
    random_state=42
)

# Global HDBSCAN: min cluster size 2, min_samples 1, cosine metric
global_hdbscan = HDBSCAN(
    min_cluster_size=2,
    min_samples=1,
    metric="cosine",
    cluster_selection_method="eom",
    prediction_data=True
)
# Global Czech vectorizer: stopwords + bigrams
global_vectorizer = CountVectorizer(
    stop_words=get_stop_words("czech"),
    ngram_range=(1, 2)
)

# 5) FastAPI schemas and app
class Sentence(BaseModel):
    text: str
    start: float
    end: float
    speaker: str | None = None
    chunk_index: int | None = None

class Segment(BaseModel):
    topic_id: int
    label: str | None
    keywords: List[str]
    start: float
    end: float
    probability: float | None
    sentences: List[int]

class SegmentationResponse(BaseModel):
    run_id: str
    segments: List[Segment]
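
# Illustrative response shape only (values below are invented, not from a real run):
#   {"run_id": "<uuid>", "segments": [{"topic_id": 0, "label": "počasí dnes",
#    "keywords": ["počasí", "dnes"], "start": 0.0, "end": 3.4,
#    "probability": 0.87, "sentences": [0, 1]}]}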

app = FastAPI(title="CZ Topic Segmenter", version="1.0")

@app.get("/")
async def root():
    return {"message": "CZ Topic Segmenter is running."}

@app.post("/segment", response_model=SegmentationResponse)
def segment(sentences: List[Sentence]):
    print(f"[segment] Received {len(sentences)} sentences, chunk_indices={[s.chunk_index for s in sentences]}")
    sys.stdout.flush()
    if len(sentences) > MAX_DOCS:
        raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")

    # sort by chunk_index
    sorted_sent = sorted(
        sentences,
        key=lambda s: s.chunk_index if s.chunk_index is not None else 0
    )
    docs = [s.text for s in sorted_sent]

    # Use global UMAP/HDBSCAN/vectorizer instances for debugging
    umap_model = global_umap
    hdbscan_model = global_hdbscan
    vectorizer_model = global_vectorizer

    # instantiate BERTopic per request with global components
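    # NOTE: min_topic_size below is pinned to 2 for these small debugging batches,
    # so the MIN_TOPIC value read from the environment is not applied here.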
    topic_model = BERTopic(
        embedding_model=embeddings,
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        vectorizer_model=vectorizer_model,
        min_topic_size=2,
        calculate_probabilities=True
    )

    topics, probs = topic_model.fit_transform(docs)

    segments, cur = [], None
    for idx, (t_id, prob) in enumerate(zip(topics, probs)):
        orig_idx = sorted_sent[idx].chunk_index if sorted_sent[idx].chunk_index is not None else idx
        # With calculate_probabilities=True, `prob` is a vector of per-topic
        # probabilities for this sentence; use its maximum as the confidence.
        prob_val = float(prob.max()) if hasattr(prob, "max") else float(prob or 0.0)
        if cur is None or t_id != cur["topic_id"]:
            # Close the previous segment before starting a new one
            if cur is not None:
                segments.append(cur)
            words = [w for w, _ in topic_model.get_topic(t_id)[:5]]
            cur = dict(
                topic_id=t_id,
                label=" ".join(words) if t_id != -1 else None,
                keywords=words,
                start=sorted_sent[idx].start,
                end=sorted_sent[idx].end,
                probability=prob_val,
                sentences=[orig_idx],
            )
        else:
            cur["end"] = sorted_sent[idx].end
            cur["sentences"].append(orig_idx)
    if cur:
        segments.append(cur)

    print(f"[segment] Returning {len(segments)} segments: {segments}")
    sys.stdout.flush()
    return {"run_id": str(uuid.uuid4()), "segments": segments}