leo-pasi commited on
Commit
3944997
·
1 Parent(s): 72f6fc2

updated main app

Browse files
Files changed (1) hide show
  1. scripts/app.py +104 -47
scripts/app.py CHANGED
@@ -1,5 +1,13 @@
 
 
 
1
  import gradio as gr
 
 
 
 
2
 
 
3
  from src.mythesis_chatbot.rag_setup import (
4
  SupportedRags,
5
  automerging_retrieval_setup,
@@ -7,68 +15,117 @@ from src.mythesis_chatbot.rag_setup import (
7
  sentence_window_retrieval_setup,
8
  )
9
 
10
- input_file = "./data/Master_Thesis.pdf"
11
- save_dir = "./data/indices/"
12
-
13
- automerging_engine = automerging_retrieval_setup(
14
- input_file=input_file,
15
- save_dir=save_dir,
16
- llm_openai_model="gpt-4o-mini",
17
- embed_model="BAAI/bge-small-en-v1.5",
18
- chunk_sizes=[2048, 512, 128],
19
- similarity_top_k=6,
20
- rerank_model="cross-encoder/ms-marco-MiniLM-L-2-v2",
21
- rerank_top_n=2,
22
- )
23
 
24
- sentence_window_engine = sentence_window_retrieval_setup(
25
- input_file=input_file,
26
- save_dir=save_dir,
27
- llm_openai_model="gpt-4o-mini",
28
- embed_model="BAAI/bge-small-en-v1.5",
29
- sentence_window_size=3,
30
- similarity_top_k=6,
31
- rerank_model="cross-encoder/ms-marco-MiniLM-L-2-v2",
32
- rerank_top_n=2,
33
- )
34
 
35
- basic_engine = basic_rag_setup(
36
- input_file=input_file,
37
- save_dir=save_dir,
38
- llm_openai_model="gpt-4o-mini",
39
- embed_model="BAAI/bge-small-en-v1.5",
40
- similarity_top_k=6,
41
- rerank_model="cross-encoder/ms-marco-MiniLM-L-2-v2",
42
- rerank_top_n=2,
43
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
45
 
46
- def chat_bot(query: str, rag_mode: SupportedRags) -> str:
47
- if rag_mode == "basic":
48
- return basic_engine.query(query).response
49
- if rag_mode == "auto-merging retrieval":
50
- return automerging_engine.query(query).response
51
- if rag_mode == "sentence window retrieval":
52
- return sentence_window_engine.query(query).response
53
 
 
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  default_message = (
56
- "Ask a about a topic that is discussed in my master thesis."
57
- " E.g., what is epistemic uncertainty?"
58
  )
59
 
 
 
 
60
  gradio_app = gr.Interface(
61
  fn=chat_bot,
62
  inputs=[
63
- gr.Textbox(placeholder=default_message),
64
  gr.Dropdown(
65
- choices=["basic", "sentence window retrieval", "auto-merging retrieval"],
66
  label="RAG mode",
67
- value="basic",
68
  ),
69
  ],
70
- outputs=["text"],
 
 
 
 
71
  )
72
 
73
- if __name__ == "__main__":
74
- gradio_app.launch()
 
1
+ import os
2
+ from pathlib import Path
3
+
4
  import gradio as gr
5
+ import nest_asyncio
6
+ import yaml
7
+ from trulens.core import TruSession
8
+ from trulens.dashboard import run_dashboard
9
 
10
+ from src.mythesis_chatbot.evaluation import get_prebuilt_trulens_recorder
11
  from src.mythesis_chatbot.rag_setup import (
12
  SupportedRags,
13
  automerging_retrieval_setup,
 
15
  sentence_window_retrieval_setup,
16
  )
17
 
18
+ input_file_dir = Path(__file__).parents[1] / "data/"
19
+ save_dir = Path(__file__).parents[1] / "data/indices/"
20
+ config_dir = Path(__file__).parents[1] / "configs/"
21
+ welcome_message_path = Path(__file__).parents[1] / "spaces/welcome_message.md"
 
 
 
 
 
 
 
 
 
22
 
23
+ # Enables running async code inside an existing event loop without crashing.
24
+ nest_asyncio.apply()
 
 
 
 
 
 
 
 
25
 
26
+ tru = TruSession(database_url=os.getenv("SUPABASE_CONNECTION_STRING"))
27
+ run_dashboard(tru)
28
+
29
+
30
+ class ChatBot:
31
+ def __init__(
32
+ self,
33
+ input_file_dir,
34
+ save_dir,
35
+ config_dir,
36
+ ):
37
+ self.recorder = None
38
+ self.previous_rag_mode = None
39
+ self.recorder = None
40
+
41
+ with open(os.path.join(config_dir, "basic.yaml")) as f:
42
+ self.basic_config = yaml.safe_load(f)
43
+ with open(os.path.join(config_dir, "auto_merging.yaml")) as f:
44
+ self.automerging_config = yaml.safe_load(f)
45
+ with open(os.path.join(config_dir, "sentence_window.yaml")) as f:
46
+ self.sentence_window_config = yaml.safe_load(f)
47
+
48
+ self.basic_engine = basic_rag_setup(
49
+ input_file=os.path.join(input_file_dir, self.basic_config["source_doc"]),
50
+ save_dir=save_dir,
51
+ **self.basic_config,
52
+ )
53
+ self.automerging_engine = automerging_retrieval_setup(
54
+ input_file=os.path.join(
55
+ input_file_dir, self.automerging_config["source_doc"]
56
+ ),
57
+ save_dir=save_dir,
58
+ **self.automerging_config,
59
+ )
60
+ self.sentence_window_engine = sentence_window_retrieval_setup(
61
+ input_file=os.path.join(
62
+ input_file_dir, self.sentence_window_config["source_doc"]
63
+ ),
64
+ save_dir=save_dir,
65
+ **self.sentence_window_config,
66
+ )
67
+
68
+ def __call__(self, query: str, rag_mode: SupportedRags):
69
 
70
+ match rag_mode:
71
+ case "classic retrieval":
72
 
73
+ if self.previous_rag_mode != rag_mode:
74
+ self.previous_rag_mode = rag_mode
75
+ self.recorder = get_prebuilt_trulens_recorder(
76
+ self.basic_engine, self.basic_config
77
+ )
 
 
78
 
79
+ with self.recorder as recording: # noqa: F841
80
+ response = self.basic_engine.query(query)
81
 
82
+ case "auto-merging retrieval":
83
+ if self.previous_rag_mode != rag_mode:
84
+ self.previous_rag_mode = rag_mode
85
+ self.recorder = get_prebuilt_trulens_recorder(
86
+ self.automerging_engine, self.automerging_config
87
+ )
88
+
89
+ with self.recorder as recording: # noqa: F841
90
+ response = self.automerging_engine.query(query)
91
+
92
+ case "sentence window retrieval":
93
+ if self.previous_rag_mode != rag_mode:
94
+ self.previous_rag_mode = rag_mode
95
+ self.recorder = get_prebuilt_trulens_recorder(
96
+ self.sentence_window_engine, self.sentence_window_config
97
+ )
98
+
99
+ with self.recorder as recording: # noqa: F841
100
+ response = self.sentence_window_engine.query(query)
101
+
102
+ return response.response
103
+
104
+
105
+ chat_bot = ChatBot(input_file_dir, save_dir, config_dir)
106
  default_message = (
107
+ "Ask about a topic that is discussed in my master thesis."
108
+ " E.g., what is this master thesis about? Or what is epistemic uncertainty?"
109
  )
110
 
111
+ with open(welcome_message_path, encoding="utf-8") as f:
112
+ description = f.read()
113
+
114
  gradio_app = gr.Interface(
115
  fn=chat_bot,
116
  inputs=[
117
+ gr.Textbox(placeholder=default_message, label="Query"),
118
  gr.Dropdown(
119
+ choices=SupportedRags.__args__,
120
  label="RAG mode",
121
+ value=SupportedRags.__args__[0],
122
  ),
123
  ],
124
+ outputs=[
125
+ gr.Textbox(label="Answer"),
126
+ ],
127
+ title="RAG powered chatbot",
128
+ description=description,
129
  )
130
 
131
+ gradio_app.launch()