Song committed on
Commit f540a2a · 1 parent: 47bec05
Files changed (3)
  1. Dockerfile +1 -1
  2. app.py +420 -237
  3. requirements.txt +2 -5
Dockerfile CHANGED
@@ -26,7 +26,7 @@ COPY . /app
 
 # ---- Healthcheck ----
 HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
-  CMD curl -f http://localhost:${PORT:-7860}/health || exit 1  # changed to the /health route
+  CMD curl -f http://localhost:${PORT:-7860}/health || exit 1
 
 # ---- Port & CMD ----
 EXPOSE 7860
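The healthcheck curls `/health`, so the container only reports healthy once app.py answers on that route. A minimal sketch of such a handler, assuming the FastAPI `app` object defined in app.py (the real handler may report richer status, e.g. whether models finished loading):

@app.get("/health")
def health() -> dict:
    # Any 2xx response makes `curl -f` exit 0, satisfying the HEALTHCHECK.
    return {"status": "ok"}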
app.py CHANGED
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 DrugQA (ZH) — 優化版 FastAPI LINE Webhook
-結合了 Intent Detection, Answer Validation, 和更穩健的 RAG 流程。
 """
 
 # ---------- 環境與快取設定 (應置於最前) ----------
@@ -32,7 +32,7 @@ import pandas as pd
 from fastapi import FastAPI, Request, Response, HTTPException, status
 import uvicorn
 import jieba
-from fuzzywuzzy import process
 from rank_bm25 import BM25Okapi
 from sentence_transformers import SentenceTransformer, CrossEncoder
 import faiss
@@ -40,35 +40,180 @@ import torch
 from openai import OpenAI
 from tenacity import retry, stop_after_attempt, wait_fixed
 
-# ---------- 應用程式設定 (集中管理) ----------
-class AppConfig:
-    # 檔案路徑
-    CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
-    SENTENCES_PKL = os.getenv("SENTENCES_PKL", "/tmp/drug_sentences.pkl")
-    META_PKL = os.getenv("META_PKL", "/tmp/drug_meta.pkl")
-    FAISS_INDEX = os.getenv("FAISS_INDEX", "/tmp/drug_sentences.index")
-    BM25_PKL = os.getenv("BM25_PKL", "/tmp/bm25.pkl")
-
-    # LINE Bot 設定
-    CHANNEL_ACCESS_TOKEN = os.getenv("CHANNEL_ACCESS_TOKEN")
-    CHANNEL_SECRET = os.getenv("CHANNEL_SECRET")
-
-    # LLM & RAG 模型設定
-    LITELLM_API_KEY = os.getenv("LITELLM_API_KEY")
-    LITELLM_BASE_URL = os.getenv("LITELLM_BASE_URL")
-    LM_MODEL = os.getenv("LM_MODEL", "gemma-7b-it")  # 預設模型
-    EMBEDDING_MODEL_ID = "DMetaSoul/Dmeta-embedding-zh"
-    RERANKER_MODEL_ID = "BAAI/bge-reranker-v2-m3"
-
-    # RAG 搜尋參數
-    FUZZY_MATCH_THRESHOLD = 85
-    TOP_K_FAISS = 20
-    TOP_K_BM25 = 20
-    TOP_K_RERANK = 10
-    MAX_CONTEXT_CHARS = 4000
-
-    # 應用程式狀態
-    STATE = type('state', (), {})()
 
 # ---------- 日誌設定 ----------
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -80,85 +225,77 @@ app = FastAPI(
     description="提供基於RAG的台灣藥品資訊查詢服務",
     version="2.0.0"
 )
 CONFIG = AppConfig()
 
-# ---------- 核心 RAG 邏輯 (封裝成類別) ----------
 class RagPipeline:
     def __init__(self, config):
         self.config = config
-        self.state = config.STATE
-        self.llm_client = OpenAI(api_key=config.LITELLM_API_KEY, base_url=config.LITELLM_BASE_URL)
 
     def _load_data(self):
         """在啟動時載入所有必要的模型與資料"""
         log.info("開始載入資料與模型...")
         # 載入 CSV
-        if os.path.exists(self.config.CSV_PATH):
-            self.state.df_csv = pd.read_csv(self.config.CSV_PATH, dtype=str).fillna('')
-            self.state.df_csv['drug_name_norm_normalized'] = self.state.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
-            log.info(f"成功載入 CSV: {self.config.CSV_PATH} (rows={len(self.state.df_csv)})")
         else:
-            log.error(f"錯誤: 找不到 CSV 檔案於 {self.config.CSV_PATH}")
-            self.state.df_csv = None
 
         # 載入語料庫與模型
-        self.state.sentences, self.state.meta = self._ensure_pkl_exists(self.config.SENTENCES_PKL, self.config.META_PKL)
-        self.state.emb_model = SentenceTransformer(self.config.EMBEDDING_MODEL_ID)
-        self.state.reranker_model = CrossEncoder(self.config.RERANKER_MODEL_ID)
-        self.state.faiss_index = self._ensure_faiss_index()
         self.state.bm25 = self._ensure_bm25_index()
         log.info("所有模型與資料載入完成。")
 
-    def _ensure_pkl_exists(self, sentences_path, meta_path):
-        if os.path.exists(sentences_path) and os.path.exists(meta_path):
-            with open(sentences_path, "rb") as f_sent, open(meta_path, "rb") as f_meta:
-                return pickle.load(f_sent), pickle.load(f_meta)
-        log.warning(f"PKL 檔案不存在,將從 CSV 重新建立。")
-        if self.state.df_csv is None: return [], []
-
-        sentences = self.state.df_csv["content"].tolist()
-        meta = self.state.df_csv[["drug_name_zh", "section"]].to_dict(orient="records")
-
-        with open(sentences_path, "wb") as f_sent, open(meta_path, "wb") as f_meta:
-            pickle.dump(sentences, f_sent)
-            pickle.dump(meta, f_meta)
-        log.info(f"已建立並儲存 PKL 檔案於 {sentences_path} 與 {meta_path}")
-        return sentences, meta
-
-    def _ensure_faiss_index(self):
-        if os.path.exists(self.config.FAISS_INDEX):
-            log.info(f"正在從 {self.config.FAISS_INDEX} 載入 FAISS 索引...")
-            return faiss.read_index(self.config.FAISS_INDEX)
-
-        log.warning("FAISS 索引不存在,正在建立新的索引...")
-        if not self.state.sentences: return None
-        embeddings = self.state.emb_model.encode(self.state.sentences, convert_to_tensor=True, show_progress_bar=True)
-        index = faiss.IndexFlatIP(embeddings.shape[1])
-        index.add(embeddings.cpu().numpy())
-        faiss.write_index(index, self.config.FAISS_INDEX)
-        log.info(f"FAISS 索引已建立並儲存至 {self.config.FAISS_INDEX}")
-        return index
 
     def _ensure_bm25_index(self):
-        if os.path.exists(self.config.BM25_PKL):
-            with open(self.config.BM25_PKL, "rb") as f:
                 return pickle.load(f)
 
         log.warning("BM25 索引不存在,正在建立新的索引...")
         if not self.state.sentences: return None
-        tokenized_corpus = [list(jieba.cut(s)) for s in self.state.sentences]
         bm25 = BM25Okapi(tokenized_corpus)
-        with open(self.config.BM25_PKL, "wb") as f:
             pickle.dump(bm25, f)
-        log.info(f"BM25 索引已建立並儲存至 {self.config.BM25_PKL}")
         return bm25
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
-    def _llm_call(self, messages, temperature=0.2, max_tokens=1024):
        """帶有重試機制的 LLM API 呼叫"""
        try:
            response = self.llm_client.chat.completions.create(
-                model=self.config.LM_MODEL,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
@@ -168,165 +305,180 @@ class RagPipeline:
         log.error(f"LLM API 呼叫失敗: {e}")
         raise
 
-    async def answer_question(self, user_question: str) -> str:
-        """處理使用者問題的完整流程"""
-        # 1. 意圖偵測
-        intent = self._detect_intent(user_question)
-        if intent != "drug_query":
-            return "本服務僅提供藥品資訊查詢,無法提供醫療建議或回答一般性問題。若有身體不適,請立即就醫。"
-
-        # 2. 藥品名稱模糊比對
-        drug_name = self._find_drug_name(user_question)
-        if not drug_name:
-            return "抱歉,我無法從您的問題中識別出明確的藥品名稱。請試著提供更完整的藥品名稱,例如:'普拿疼'。"
-
-        # 3. 混合檢索 (Hybrid Search)
-        retrieved_indices = self._hybrid_search(user_question)
-        if not retrieved_indices:
-            return f"抱歉,在資料庫中找不到與 '{drug_name}' 相關的資訊。請確認藥品名稱是否正確。"
-
-        # 4. 重排序 (Reranking)
-        reranked_indices = self._rerank(user_question, retrieved_indices)
-
-        # 5. 建立上下文 (Context)
-        context = self._build_context(reranked_indices)
-
-        # 6. 生成答案 (Generation)
-        prompt = self._make_generation_prompt(user_question, context)
-        generated_answer = self._llm_call([{"role": "user", "content": prompt}])
-
-        # 7. 答案驗證 (Validation)
-        is_valid, reason = self._validate_answer(user_question, generated_answer, context)
-        log.info(f"答案驗證結果: {'有效' if is_valid else '無效'}, 原因: {reason}")
-
-        if not is_valid:
-            return "系統生成的答案可能不完全準確,為安全起見,不提供此回覆。請嘗試用其他方式提問,或諮詢專業藥師。"
-
-        # 8. 格式化最終回覆
-        final_answer = self._format_final_answer(generated_answer, drug_name)
-        return final_answer
-
-    def _detect_intent(self, query: str) -> str:
-        prompt = f"""
-        請判斷以下使用者問題的意圖。意圖只能是以下三者之一:'drug_query', 'medical_advice', 'general_greeting'。
-        - 如果問題在詢問某個具體藥品的資訊(如副作用、用法、成分),意圖是 'drug_query'。
-        - 如果問題在尋求診斷、治療建議或詢問症狀,意圖是 'medical_advice'。
-        - 如果只是打招呼或閒聊,意圖是 'general_greeting'。
-
-        使用者問題: "{query}"
-        意圖:
-        """
-        response = self._llm_call([{"role": "user", "content": prompt}], temperature=0.0, max_tokens=20)
-        # 簡單解析回應
-        if "drug_query" in response.lower(): return "drug_query"
-        if "medical_advice" in response.lower(): return "medical_advice"
-        return "general_greeting"
-
-    def _find_drug_name(self, query: str) -> Optional[str]:
-        unique_drug_names = self.state.df_csv['drug_name_norm'].unique()
-        normalized_query = query.lower().replace(" ", "")
-
-        best_match = process.extractOne(normalized_query, unique_drug_names, score_cutoff=self.config.FUZZY_MATCH_THRESHOLD)
-        if best_match:
-            log.info(f"模糊比對成功: '{query}' -> '{best_match[0]}' (分數: {best_match[1]})")
-            return best_match[0]
-        log.warning(f"模糊比對失敗: '{query}' 的分數低於 {self.config.FUZZY_MATCH_THRESHOLD}")
-        return None
-
-    def _hybrid_search(self, query: str) -> List[int]:
-        # Vector Search (FAISS)
-        query_embedding = self.state.emb_model.encode([query], convert_to_tensor=True).cpu().numpy()
-        _, faiss_indices = self.state.faiss_index.search(query_embedding, self.config.TOP_K_FAISS)
-        faiss_indices = faiss_indices[0].tolist()
-
-        # Keyword Search (BM25)
-        tokenized_query = list(jieba.cut(query))
-        bm25_scores = self.state.bm25.get_scores(tokenized_query)
-        bm25_indices = np.argsort(bm25_scores)[::-1][:self.config.TOP_K_BM25].tolist()
-
-        # 合併並去重
-        combined_indices = list(dict.fromkeys(faiss_indices + bm25_indices))
-        log.info(f"混合檢索找到 {len(combined_indices)} 個不重複的候選文件。")
-        return combined_indices
-
-    def _rerank(self, query: str, indices: List[int]) -> List[int]:
-        pairs = [(query, self.state.sentences[i]) for i in indices]
-        scores = self.state.reranker_model.predict(pairs, show_progress_bar=False)
-
-        scored_indices = sorted(zip(indices, scores), key=lambda x: x[1], reverse=True)
-        reranked_indices = [idx for idx, score in scored_indices[:self.config.TOP_K_RERANK]]
-        return reranked_indices
-
-    def _build_context(self, indices: List[int]) -> str:
-        context_parts = []
-        char_count = 0
-        for i in indices:
-            sentence_text = self.state.sentences[i]
-            if char_count + len(sentence_text) > self.config.MAX_CONTEXT_CHARS:
-                break
-            context_parts.append(f"來源[{i+1}]: {sentence_text}")
-            char_count += len(sentence_text)
-        return "\n\n".join(context_parts)
-
-    def _make_generation_prompt(self, query: str, context: str) -> str:
-        return f"""
-        你是一位專業且謹慎的台灣藥師。請根據以下提供的「參考資料」,用繁體中文簡潔且有條理地回答使用者的問題。
-
-        **你的任務**:
-        1. **完全基於「參考資料」**:你的回答必須完全依據下方提供的資料,禁止使用任何外部知識或進行推測。
-        2. **格式化輸出**:使用 Markdown 的標題 (##) 和點列 (*) 來組織答案,使其清晰易讀。
-        3. **引用來源**:如果資料中有多個來源,你不需要在答案中標註來源編號。
-        4. **無法回答時**:如果參考資料無法回答問題,請直接回覆「根據提供的資料,無法回答此問題。」
-
-        **參考資料**:
-        ---
-        {context}
-        ---
-
-        **使用者問題**:
-        "{query}"
-
-        **你的回答**:
-        """
-
-    def _validate_answer(self, query: str, answer: str, context: str) -> Tuple[bool, str]:
-        prompt = f"""
-        請根據以下提供的「原始資料」,評估「生成答案」是否準確地回答了「使用者問題」。
-
-        **評估標準**:
-        1. **事實一致性**:答案中的所有資訊是否都可以在原始資料中找到對應依據?
-        2. **無遺漏**:答案是否遺漏了與問題相關的關鍵警告或重要資訊?
-        3. **無捏造**:答案是否包含了原始資料中沒有的資訊?
-
-        **原始資料**:
-        ---
-        {context}
-        ---
-
-        **使用者問題**:
-        "{query}"
-
-        **生成答案**:
-        "{answer}"
-
-        **評估結果**:
-        請以 JSON 格式回覆,包含 'is_valid' (布林值) 和 'reason' (字串) 兩個鍵。
-        範例: {{"is_valid": true, "reason": "答案忠於原文,且總結了關鍵資訊。"}}
-        範例: {{"is_valid": false, "reason": "答案中提到的'每日三次'在原文中找不到依據。"}}
-        """
-        response = self._llm_call([{"role": "user", "content": prompt}], temperature=0.0, max_tokens=256)
         try:
-            result = json.loads(response)
-            return result.get("is_valid", False), result.get("reason", "無效的JSON格式")
-        except json.JSONDecodeError:
-            log.error(f"答案驗證的JSON解析失敗: {response}")
-            return False, "無法解析驗證模型的JSON回覆"
 
-    def _format_final_answer(self, answer: str, drug_name: str) -> str:
-        disclaimer = f"--- \n*免責聲明:本資訊僅供參考,無法取代專業醫療建議。用藥前請務必諮詢藥師或醫師。資料來源為藥品仿單。*"
-        header = f"## 💊 關於「{drug_name}」\n\n"
-        return header + answer + "\n\n" + disclaimer
 
 # ---------- FastAPI 事件與路由 ----------
 @app.on_event("startup")
@@ -396,7 +548,38 @@ def line_reply(reply_token: str, text: str):
     except Exception as e:
         log.error(f"LINE API 回覆失敗: {e}")
 
 # ---------- 執行 (用於本地測試) ----------
 if __name__ == "__main__":
-    port = int(os.getenv("PORT", 8080))
     uvicorn.run(app, host="0.0.0.0", port=port)
 
 # -*- coding: utf-8 -*-
 """
 DrugQA (ZH) — 優化版 FastAPI LINE Webhook
+整合 kaggle_rag.py RAG 邏輯,包括 LLM 意圖偵測、子查詢分解、Intent-aware 檢索 & Rerank。
 """
 
 # ---------- 環境與快取設定 (應置於最前) ----------
 
 from fastapi import FastAPI, Request, Response, HTTPException, status
 import uvicorn
 import jieba
+from fuzzywuzzy import fuzz, process
 from rank_bm25 import BM25Okapi
 from sentence_transformers import SentenceTransformer, CrossEncoder
 import faiss
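The import widens from `process` alone to `fuzz, process`. A quick sketch of both APIs, with a hypothetical choice list:

names = ["panadol", "fentanyl", "augmentin syrup"]        # hypothetical choices
process.extractOne("panadol extra", names)                # -> ("panadol", score)
fuzz.partial_ratio("augmentin", "augmentin syrup")        # -> 100 (substring match)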
 
 from openai import OpenAI
 from tenacity import retry, stop_after_attempt, wait_fixed
 
+# ---- 匯入 (從 kaggle_rag.py 整合) ----
+import ast
+from typing import List, Dict, Any
+
+# ==== CONFIG (從 kaggle_rag.py 整合並調整為環境變數) ====
+CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
+FAISS_INDEX = os.getenv("FAISS_INDEX", "/tmp/drug_sentences.index")
+SENTENCES_PKL = os.getenv("SENTENCES_PKL", "/tmp/drug_sentences.pkl")
+META_PKL = os.getenv("META_PKL", "/tmp/drug_meta.pkl")
+BM25_PKL = os.getenv("BM25_PKL", "/tmp/bm25.pkl")
+
+TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 30))
+PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
+MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 50))
+
+EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
+RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
+
+_SENT_SPLIT_RE = re.compile(r"[。!?\n]")
+DRUG_STOPWORDS = {"藥", "劑", "錠", "膠囊", "糖漿", "乳膏", "貼片"}
+
+SECTION_WEIGHTS = {
+    "用法及用量": 1.0,
+    "病人使用須知": 1.0,
+    "儲存條件": 1.0,
+    "警語及注意事項": 1.0,
+    "禁忌": 1.0,
+    "副作用": 1.0,
+    "藥物交互作用": 1.0,
+    "其他": 1.0,
+}
+
+RERANK_THRESHOLD = float(os.getenv("RERANK_THRESHOLD", 0.5))
+
+DRUG_NAME_MAPPING = {
+    "fentanyl patch": "fentanyl",
+    "spiriva respimat": "spiriva",
+    "augmentin for syrup": "augmentin syrup",
+    "nitrostat": "nitroglycerin",
+    "ozempic": "ozempic",
+    "niflec": "niflec",
+    "fosamax": "fosamax",
+    "humira": "humira",
+    "premarin": "premarin",
+    "smecta": "smecta",
+}
+
+LLM_API_CONFIG = {
+    "base_url": os.getenv("LITELLM_BASE_URL"),
+    "api_key": os.getenv("LITELLM_API_KEY"),
+    "model": os.getenv("LM_MODEL")
+}
+
+LLM_MODEL_CONFIG = {
+    "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 12000)),
+    "max_tokens": int(os.getenv("MAX_TOKENS", 2048)),
+    "temperature": float(os.getenv("TEMPERATURE", 0.0)),
+    "top_p": float(os.getenv("TOP_P", 0.95)),
+    "stop_tokens": ["==="],
+}
+
+# --- 修改: 意圖分類類別已更新為新的精細化分類 (從 kaggle_rag.py)
+INTENT_CATEGORIES = [
+    "操作 (Administration)",
+    "保存/攜帶 (Storage & Handling)",
+    "副作用/異常 (Side Effects / Issues)",
+    "劑型相關 (Dosage Form Concerns)",
+    "時間/併用 (Timing & Interaction)",
+    "劑量調整 (Dosage Adjustment)",
+    "禁忌症/適應症 (Contraindications/Indications)"
+]
+
+DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
+
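All SECTION_WEIGHTS start at 1.0, and the `_adjust_section_weights` method added later in this diff is a simplified pass-through. A hedged sketch of how the intent labels could bias sections before score fusion; the boost map and the 1.5 factor are assumptions, not the committed logic:

def adjust_section_weights_sketch(intents: List[str]) -> Dict[str, float]:
    weights = SECTION_WEIGHTS.copy()
    # Hypothetical intent -> section boosts; not part of this commit.
    boost = {
        "操作 (Administration)": "用法及用量",
        "保存/攜帶 (Storage & Handling)": "儲存條件",
        "副作用/異常 (Side Effects / Issues)": "副作用",
        "時間/併用 (Timing & Interaction)": "藥物交互作用",
        "禁忌症/適應症 (Contraindications/Indications)": "禁忌",
    }
    for intent in intents:
        section = boost.get(intent)
        if section:
            weights[section] *= 1.5  # assumed boost factor
    return weights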
+# ---- 工具函式 (從 kaggle_rag.py 整合) ----
+def ensure_csv_path(path: str) -> str:
+    if os.path.exists(path):
+        return path
+    log.warning(f"找不到輸入檔案:{path},嘗試自動搜尋...")
+    # 簡化搜尋邏輯,假設在工作目錄
+    return path  # 如需擴展,可添加 os.walk
+
+def pick_text_column(df: pd.DataFrame) -> str:
+    candidates = ["content", "text", "sentence", "chunk", "cleaned_text"]
+    for c in candidates:
+        if c in df.columns:
+            return c
+    raise RuntimeError(f"CSV 缺少文字欄位,至少需包含其中之一:{candidates}")
+
+def split_sentences(text: str) -> list:
+    if not isinstance(text, str):
+        return []
+    paragraphs = text.split("\n")
+    sents = []
+    for p in paragraphs:
+        if re.match(r"^\d+\.", p):
+            sents.append(p.strip())
+        else:
+            para_sents = [s.strip() for s in _SENT_SPLIT_RE.split(p) if s.strip()]
+            combined = ""
+            for s in para_sents:
+                combined += s + "。"
+                if len(combined) >= 50:
+                    sents.append(combined.strip())
+                    combined = ""
+            if combined:
+                sents.append(combined.strip())
+    return [s for s in sents if len(s) > 6]
+
+def load_embedding_model():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    log.info(f"載入 embedding 模型:{EMBEDDING_MODEL} 至 {device}...")
+    try:
+        model = SentenceTransformer(EMBEDDING_MODEL, device=device)
+    except Exception as e:
+        log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
+        device = "cpu"
+        model = SentenceTransformer(EMBEDDING_MODEL, device=device)
+    return model
+
+def load_reranker_model():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    log.info(f"載入 reranker 模型:{RERANKER_MODEL} 至 {device}...")
+    try:
+        model = CrossEncoder(RERANKER_MODEL, device=device)
+    except Exception as e:
+        log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
+        device = "cpu"
+        model = CrossEncoder(RERANKER_MODEL, device=device)
+    return model
+
+def build_sentence_index(csv_path, faiss_path, sentences_pkl, meta_pkl, embedding_model):
+    log.info("建立 chunk 向量索引...")
+    df = pd.read_csv(csv_path, dtype=str)
+    text_col = pick_text_column(df)
+    texts = df[text_col].fillna("").astype(str).tolist()
+
+    metas = []
+    for i, row in df.iterrows():
+        metas.append({
+            "row_idx": int(i),
+            "chunk_id": row.get("chunk_id"),
+            "source_file": row.get("source_file"),
+            "section": row.get("section"),
+            "drug_id": row.get("drug_id"),
+            "drug_name_norm": str(row.get("drug_name_norm", "")).lower(),
+        })
+
+    unique_texts_meta = {}
+    for t, m in zip(texts, metas):
+        t = t.strip()
+        if t and len(t) >= 6:
+            key = t[:100]
+            if key not in unique_texts_meta:
+                unique_texts_meta[key] = (t, m)
+    filtered_texts = [v[0] for v in unique_texts_meta.values()]
+    filtered_meta = [v[1] for v in unique_texts_meta.values()]
+
+    if not filtered_texts:
+        raise RuntimeError("沒有可用 chunk 建立索引")
+
+    emb = embedding_model.encode(filtered_texts, show_progress_bar=True, convert_to_numpy=True).astype("float32")
+    faiss.normalize_L2(emb)
+
+    index = faiss.IndexFlatIP(emb.shape[1])
+    index.add(emb)
+    faiss.write_index(index, faiss_path)
+
+    with open(sentences_pkl, "wb") as f:
+        pickle.dump(filtered_texts, f)
+    with open(meta_pkl, "wb") as f:
+        pickle.dump(filtered_meta, f)
+
+    return index, filtered_texts, filtered_meta
 
 # ---------- 日誌設定 ----------
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
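A quick illustration of `split_sentences` on an assumed label snippet: numbered items pass through intact, while plain sentences are glued into roughly 50-character chunks:

text = "1. 每日一次,睡前服用\n本藥須整粒吞服。請勿咀嚼。若出現皮疹請停藥。"
split_sentences(text)
# -> ['1. 每日一次,睡前服用',
#     '本藥須整粒吞服。請勿咀嚼。若出現皮疹請停藥。']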
 
     description="提供基於RAG的台灣藥品資訊查詢服務",
     version="2.0.0"
 )
+
+# ---------- 應用程式設定 (集中管理) ----------
+class AppConfig:
+    # LINE Bot 設定
+    CHANNEL_ACCESS_TOKEN = os.getenv("CHANNEL_ACCESS_TOKEN")
+    CHANNEL_SECRET = os.getenv("CHANNEL_SECRET")
+
+    # 其他設定已移至模組層級 (全域變數)
+
 CONFIG = AppConfig()
 
+# ---------- 核心 RAG 邏輯 (封裝成類別,整合 kaggle_rag.py) ----------
 class RagPipeline:
     def __init__(self, config):
         self.config = config
+        self.state = type('state', (), {})()
+        self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
+        self.embedding_model = load_embedding_model()
+        self.reranker = load_reranker_model()
+        self.csv_path = ensure_csv_path(CSV_PATH)
+        self.df_csv = pd.read_csv(self.csv_path, dtype=str)
 
     def _load_data(self):
         """在啟動時載入所有必要的模型與資料"""
         log.info("開始載入資料與模型...")
         # 載入 CSV
+        if os.path.exists(self.csv_path):
+            self.df_csv = pd.read_csv(self.csv_path, dtype=str).fillna('')
+            self.df_csv['drug_name_norm_normalized'] = self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
+            log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")
         else:
+            log.error(f"錯誤: 找不到 CSV 檔案於 {self.csv_path}")
+            self.df_csv = None
 
         # 載入語料庫與模型
+        self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
         self.state.bm25 = self._ensure_bm25_index()
         log.info("所有模型與資料載入完成。")
 
+    def _load_or_build_sentence_index(self):
+        if os.path.exists(FAISS_INDEX) and os.path.exists(SENTENCES_PKL) and os.path.exists(META_PKL):
+            log.info("載入已存在的索引...")
+            index = faiss.read_index(FAISS_INDEX)
+            with open(SENTENCES_PKL, "rb") as f:
+                sentences = pickle.load(f)
+            with open(META_PKL, "rb") as f:
+                meta = pickle.load(f)
+            return index, sentences, meta
+
+        return build_sentence_index(self.csv_path, FAISS_INDEX, SENTENCES_PKL, META_PKL, self.embedding_model)
 
     def _ensure_bm25_index(self):
+        if os.path.exists(BM25_PKL):
+            with open(BM25_PKL, "rb") as f:
                 return pickle.load(f)
 
         log.warning("BM25 索引不存在,正在建立新的索引...")
         if not self.state.sentences: return None
+        tokenized_corpus = [jieba.lcut(s) for s in self.state.sentences]
         bm25 = BM25Okapi(tokenized_corpus)
+        with open(BM25_PKL, "wb") as f:
             pickle.dump(bm25, f)
+        log.info(f"BM25 索引已建立並儲存至 {BM25_PKL}")
         return bm25
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
+    def _llm_call(self, messages, temperature=LLM_MODEL_CONFIG["temperature"], max_tokens=LLM_MODEL_CONFIG["max_tokens"]):
         """帶有重試機制的 LLM API 呼叫"""
         try:
             response = self.llm_client.chat.completions.create(
+                model=LLM_API_CONFIG["model"],
                 messages=messages,
                 temperature=temperature,
                 max_tokens=max_tokens,
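For reference, the BM25 path round-trips like this (toy corpus; assumes jieba segments 保存 as a single token):

corpus = ["請保存於室溫", "每日服用一次", "常見副作用為頭痛"]   # toy corpus
bm25 = BM25Okapi([jieba.lcut(s) for s in corpus])
scores = bm25.get_scores(jieba.lcut("如何保存"))
corpus[int(np.argmax(scores))]   # -> "請保存於室溫"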
 
         log.error(f"LLM API 呼叫失敗: {e}")
         raise
 
+    async def answer_question(self, q_orig: str) -> str:
+        """處理使用者問題的完整流程 (整合 kaggle_rag.py main 邏輯)"""
+        try:
+            drug_ids = self._find_drug_ids_from_name(q_orig, self.df_csv)
+            if not drug_ids:
+                return f"未在資料庫中找到該藥品,請檢查名稱或諮詢醫師/藥師。{DISCLAIMER}"
+
+            attempt = 0
+            max_attempts = 3
+            answer_is_good = False
+            answer, context = "", ""
+
+            while attempt < max_attempts and not answer_is_good:
+                current_prerank_k = PRE_RERANK_K + attempt * 20
+                current_rerank_threshold = RERANK_THRESHOLD if attempt == 0 else max(RERANK_THRESHOLD - 0.1, 0.3)
+                sub_queries = self._decompose_query(q_orig)
+                all_reranked_results = []
+                processed_chunk_ids = set()
+
+                relevant_indices = [i for i, m in enumerate(self.state.meta) if m.get("drug_id") in drug_ids]
+                if not relevant_indices:
+                    return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"
+                relevant_set = set(relevant_indices)
+
+                relevant_sentences = [self.state.sentences[i] for i in relevant_indices]
+                relevant_meta = [self.state.meta[i] for i in relevant_indices]
+                relevant_bm25 = BM25Okapi([jieba.lcut(s) for s in relevant_sentences])  # 優化:僅對相關 drug 計算 BM25
+
+                for sub_q in sub_queries:
+                    intents = self._detect_intent(sub_q)
+                    expanded_q = self._expand_query_with_llm(sub_q, intents)
+                    weights = self._adjust_section_weights(intents)
+
+                    # 語意搜尋 (FAISS)
+                    sim_indices, sim_scores = self._semantic_search(self.state.index, expanded_q, current_prerank_k * 5, self.embedding_model)
+
+                    tokenized_query = list(jieba.cut(expanded_q))
+
+                    # BM25
+                    bm25_scores_relevant = relevant_bm25.get_scores(tokenized_query) if len(tokenized_query) else np.zeros(len(relevant_sentences))
+                    bm25_scores_all = np.zeros(len(self.state.sentences))
+                    for rel_idx, global_idx in enumerate(relevant_indices):
+                        bm25_scores_all[global_idx] = bm25_scores_relevant[rel_idx]
+
+                    # 融合
+                    candidate_dict = {}
+                    if sim_indices:
+                        for i, sem_score in zip(sim_indices, sim_scores):
+                            if i in relevant_set and i not in candidate_dict:
+                                candidate_dict[i] = {"sem": sem_score, "bm": 0.0}
+
+                    bm25_top_indices = np.argsort(bm25_scores_all)[::-1][:current_prerank_k * 5]
+                    for i in bm25_top_indices:
+                        if i in relevant_set:
+                            bm_score = bm25_scores_all[i]
+                            if i in candidate_dict:
+                                candidate_dict[i]["bm"] = bm_score
+                            else:
+                                candidate_dict[i] = {"sem": 0.0, "bm": bm_score}
+
+                    candidates = []
+                    for i, scores in candidate_dict.items():
+                        section_name = self.state.meta[i].get("section", "其他")
+                        section_weight = weights.get(section_name, 1.0)
+                        fused_score = (scores["sem"] * 0.5 + scores["bm"] * 0.4) * section_weight
+                        candidates.append((i, fused_score, scores["sem"], scores["bm"]))
+
+                    candidates.sort(key=lambda x: x[1], reverse=True)
+
+                    # Reranker
+                    sub_reranked = self._rerank_with_crossencoder(q_orig, candidates, self.state.sentences, self.reranker, TOP_K_SENTENCES, self.state.meta, current_rerank_threshold)
+
+                    for r in sub_reranked:
+                        cid = r["meta"].get("chunk_id")  # chunk_id 保存在 meta 內
+                        if cid and cid not in processed_chunk_ids:
+                            all_reranked_results.append(r)
+                            processed_chunk_ids.add(cid)
+                        elif not cid and r["idx"] not in {res["idx"] for res in all_reranked_results}:
+                            all_reranked_results.append(r)
+
+                all_reranked_results.sort(key=lambda x: x['rerank_score'], reverse=True)
+
+                context = self._build_context(all_reranked_results, LLM_MODEL_CONFIG["max_context_chars"])
+
+                prompt = self._make_prompt(q_orig, context)
+
+                answer = self._llm_call([{"role": "user", "content": prompt}])
+
+                validation = self._validate_answer(q_orig, answer, context)
+                if validation.get("score", 0) >= 75:
+                    answer_is_good = True
+                else:
+                    attempt += 1
+
+            final_answer_formatted = self._format_final_answer(answer, DISCLAIMER)
+            return final_answer_formatted
+        except Exception as e:
+            log.error(f"處理查詢 {q_orig} 時發生錯誤: {e}")
+            return f"處理時發生錯誤,請檢查日誌。{DISCLAIMER}"
+
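To make the fusion line above concrete: for a chunk with semantic score 0.62, BM25 score 3.1, and a (hypothetical) section weight of 1.5, the fused score works out as:

fused = (0.62 * 0.5 + 3.1 * 0.4) * 1.5   # = (0.31 + 1.24) * 1.5 = 2.325
# The commit fuses raw scores, so unnormalized BM25 values can dominate
# the cosine similarities, which stay within [-1, 1].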
+    # ---- 以下為從 kaggle_rag.py 整合的輔助函式 ----
+    def _find_drug_ids_from_name(self, query: str, df: pd.DataFrame) -> List[str]:
+        candidates = extract_drug_candidates_from_query(query)
+        expanded = expand_aliases(candidates)
+        drug_ids = set()
+        for alias in expanded:
+            matched_rows = df[df['drug_name_norm'].str.lower().str.contains(alias.lower(), na=False)]
+            drug_ids.update(matched_rows['drug_id'].unique())
+        return list(drug_ids)
+
+    def _decompose_query(self, query: str) -> List[str]:
+        prompt = f"將以下問題分解成1-3個子問題:{query}"
+        response = self._llm_call([{"role": "user", "content": prompt}])
+        try:
+            return ast.literal_eval(response) if response else [query]  # 假設回應為列表字串
+        except (ValueError, SyntaxError):
+            return [query]
+
+    def _detect_intent(self, query: str) -> List[str]:
+        prompt = f"偵測以下問題的意圖類別,從 {INTENT_CATEGORIES} 中選擇:{query}"
+        response = self._llm_call([{"role": "user", "content": prompt}])
+        try:
+            return ast.literal_eval(response) if response else []
+        except (ValueError, SyntaxError):
+            return []
+
+    def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
+        prompt = f"基於意圖 {intents} 擴展查詢:{query}"
+        return self._llm_call([{"role": "user", "content": prompt}])
+
+    def _adjust_section_weights(self, intents: List[str]) -> Dict[str, float]:
+        weights = SECTION_WEIGHTS.copy()
+        # 根據 intents 調整權重邏輯 (從 kaggle_rag.py 簡化)
+        return weights
+
+    def _semantic_search(self, index, query: str, top_k: int, embedding_model) -> Tuple[List[int], List[float]]:
+        q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
+        faiss.normalize_L2(q_emb)
+        distances, indices = index.search(q_emb, top_k)
+        return indices[0].tolist(), distances[0].tolist()
+
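Because both `build_sentence_index` and `_semantic_search` call `faiss.normalize_L2` before searching an `IndexFlatIP`, the returned distances are cosine similarities. A tiny self-check:

v = np.array([[3.0, 4.0]], dtype="float32")
faiss.normalize_L2(v)                 # v is now [[0.6, 0.8]]
probe = faiss.IndexFlatIP(2)
probe.add(v)
d, i = probe.search(v, 1)             # d[0][0] == 1.0: cosine of a vector with itself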
+    def _rerank_with_crossencoder(self, query: str, candidates: List[Tuple], sentences: List[str], reranker, top_k: int, meta: List[Dict], threshold: float) -> List[Dict]:
+        pairs = [(query, sentences[i]) for i, _, _, _ in candidates]
+        scores = reranker.predict(pairs)
+        reranked = []
+        for (i, _, sem, bm), score in zip(candidates, scores):
+            if score >= threshold:
+                reranked.append({
+                    "idx": i,
+                    "rerank_score": score,
+                    "sem_score": sem,
+                    "bm_score": bm,
+                    "meta": meta[i],
+                    "text": sentences[i]
+                })
+        reranked.sort(key=lambda x: x['rerank_score'], reverse=True)
+        return reranked[:top_k]
+
+    def _build_context(self, reranked_results: List[Dict], max_chars: int) -> str:
+        context = ""
+        for res in reranked_results:
+            text = res['text']
+            if len(context) + len(text) > max_chars:
+                break
+            context += text + "\n\n"
+        return context.strip()
+
+    def _make_prompt(self, query: str, context: str) -> str:
+        return f"基於以下上下文回答問題:{context}\n問題:{query}"
+
+    def _validate_answer(self, query: str, answer: str, context: str) -> Dict:
+        prompt = f"驗證答案是否準確:問題 {query},答案 {answer},上下文 {context}"
+        response = self._llm_call([{"role": "user", "content": prompt}])
+        # 假設回應為 JSON 字串
+        try:
+            return json.loads(response)
+        except (json.JSONDecodeError, TypeError):
+            return {"score": 0, "reason": "無法解析"}
+
+    def _format_final_answer(self, answer: str, disclaimer: str) -> str:
+        return f"{answer}\n\n{disclaimer}"
 
 # ---------- FastAPI 事件與路由 ----------
 @app.on_event("startup")
 
     except Exception as e:
         log.error(f"LINE API 回覆失敗: {e}")
 
+# ---- 從 kaggle_rag.py 整合的額外工具函式 ----
+def extract_drug_candidates_from_query(query: str) -> list:
+    query = re.sub(r"[A-Za-z]+", lambda m: m.group(0).lower(), query)
+    candidates = set()
+    parts = query.split(":", 1)
+    drug_part = parts[0]
+    for m in re.finditer(r"[a-zA-Z]{3,}", drug_part):
+        candidates.add(m.group(0))
+    for token in re.split(r"[\s,/()()]+", drug_part):
+        clean_token = re.sub(r'[a-zA-Z0-9\s]+', '', token).strip()
+        if clean_token and clean_token.lower() not in DRUG_STOPWORDS:
+            candidates.add(clean_token)
+    if drug_part.strip():
+        candidates.add(drug_part.strip())
+    for query_name, dataset_name in DRUG_NAME_MAPPING.items():
+        if query_name in query.lower():
+            candidates.add(dataset_name)
+    return [c for c in candidates if len(c) > 1]
+
+def expand_aliases(candidates: list) -> list:
+    out = set()
+    for c in candidates:
+        s = c.strip()
+        if not s:
+            continue
+        out.add(s)
+        out.add(re.sub(r"[^0-9A-Za-z\u4e00-\u9fff]+", "", s))
+        out.add(s.lower())
+        out.add(s.upper())
+    return [x for x in out if x]
+
 # ---------- 執行 (用於本地測試) ----------
 if __name__ == "__main__":
+    port = int(os.getenv("PORT", 7860))
     uvicorn.run(app, host="0.0.0.0", port=port)
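An illustrative trace of the two helpers on an assumed query:

cands = extract_drug_candidates_from_query("fentanyl patch 怎麼貼?")
# -> includes "fentanyl", "patch", the Chinese remainder, and the full query
#    string; DRUG_NAME_MAPPING also maps "fentanyl patch" to the dataset-side
#    name "fentanyl".
expand_aliases(cands)
# -> adds lower/upper-case and punctuation-stripped variants of each candidate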
requirements.txt CHANGED
@@ -4,9 +4,6 @@ uvicorn[standard]
 python-multipart
 requests  # 用於 LINE reply
 
-# LINE Bot SDK (可選,此處直接使用 requests)
-# line-bot-sdk==3.11.0
-
 # NLP / RAG
 numpy
 pandas
@@ -24,5 +21,5 @@ torchaudio
 
 # LLM 連接與穩定性
 openai
-litellm  # 修正 ModuleNotFoundError
-tenacity  # 新增,用於 API 重試
+litellm
+tenacity  # 用於 API 重試