ryota39 committed on
Commit
f193d86
·
verified ·
1 Parent(s): 8c3093b

Upload 3 files

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  vectorstore/ruri-large/index.faiss filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  vectorstore/ruri-large/index.faiss filter=lfs diff=lfs merge=lfs -text
37
+ vectorstore/static-embedding-japanese/index.faiss filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from huggingface_hub import hf_hub_download
4
+ from langchain.chains import RetrievalQA
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.llms import LlamaCpp
8
+
9
+
10
# Hugging Face Hub repository that hosts the quantized GGUF build of the LLM.
REPO_ID = "WariHima/sarashina2.2-1b-instruct-v0.1-Q4_K_M-GGUF"
# Name of the GGUF weight file inside REPO_ID.
FILENAME = "sarashina2.2-1b-instruct-v0.1-q4_k_m.gguf"
12
+
13
+
14
def get_model_path():
    """Return the local filesystem path of the GGUF model file.

    Delegates to ``hf_hub_download`` for ``FILENAME`` in the model
    repository ``REPO_ID`` and returns whatever path that call reports.
    """
    download_args = {
        "repo_id": REPO_ID,
        "filename": FILENAME,
        "repo_type": "model",
    }
    return hf_hub_download(**download_args)
20
+
21
+
22
# NOTE: resolved at import time, so importing this module calls
# get_model_path() (and therefore hf_hub_download) before anything else runs.
GGUF_MODEL_PATH = get_model_path()
# Directory holding the persisted FAISS index (index.faiss / index.pkl).
VECTOR_DB_PATH = "./vectorstore/static-embedding-japanese"
# Embedding model name passed to HuggingFaceEmbeddings; presumably the same
# model the persisted index was built with -- confirm before changing it.
EMBEDDING_MODEL = "hotchpotch/static-embedding-japanese"
25
+
26
+
27
class RAGSystem:
    """Retrieval-augmented question answering over a local FAISS index.

    On construction the system loads, in order, the embedding model, the
    persisted FAISS vector store and a llama.cpp LLM.  Vector store and LLM
    load failures are reported on stdout and leave the matching attribute as
    ``None``; the instance stays usable and reports the missing part through
    :meth:`answer_question_stream` and :meth:`get_system_status`.

    Attributes:
        embeddings: Embedding model used to query the vector store.
        vectorstore: Loaded FAISS store, or ``None`` if unavailable.
        llm: Loaded ``LlamaCpp`` model, or ``None`` if unavailable.
        qa_chain: ``RetrievalQA`` chain built from the two above.  Within this
            class it only serves as the "system is ready" flag; answers are
            generated by calling the vector store and LLM directly.
    """

    # Number of documents retrieved per question (shared by the QA chain
    # retriever and the streaming answer path so they cannot drift apart).
    TOP_K = 3

    def __init__(self):
        """Create the system and eagerly load all models."""
        # Define every attribute the other methods read *before* loading
        # anything: load_vectorstore() inspects self.llm, and a failed LLM
        # load must leave self.llm as None rather than undefined.
        self.embeddings = None
        self.vectorstore = None
        self.llm = None
        self.qa_chain = None
        self.setup_models()

    def setup_models(self):
        """Load the embedding model, the vector store and the LLM.

        A failure while constructing the embedding model propagates to the
        caller.  Failures while loading the vector store or the LLM are
        printed and swallowed on purpose so the UI can still start and show
        what is missing.
        """
        self.embeddings = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={"device": "cpu"},
        )

        try:
            self.load_vectorstore()
        except Exception as e:  # best effort: keep the app alive without a DB
            print(f"ベクトルDBの読み込みに失敗しました: {e}")

        try:
            self.llm = LlamaCpp(
                model_path=GGUF_MODEL_PATH,
                temperature=0.7,
                max_tokens=512,
                n_ctx=2048,  # context length
                n_threads=8,  # number of CPU threads to use
                n_gpu_layers=-1,  # offload all layers to the GPU if possible
                verbose=False,
                streaming=True,
                model_kwargs={"f16_kv": True},
            )

            if self.vectorstore:
                self.setup_qa_chain()
        except Exception as e:  # best effort: keep the app alive without an LLM
            print(f"LLMの読み込みに失敗しました: {e}")

    def load_vectorstore(self):
        """Load the FAISS index from ``VECTOR_DB_PATH``.

        Returns:
            ``True`` if the path exists and the index was loaded, ``False``
            if the path does not exist.

        Raises:
            Exception: whatever ``FAISS.load_local`` raises for an unreadable
                or incompatible index.
        """
        if not os.path.exists(VECTOR_DB_PATH):
            return False

        # NOTE(security): allow_dangerous_deserialization=True makes FAISS
        # unpickle index.pkl, which can execute arbitrary code.  Only point
        # VECTOR_DB_PATH at index files you created or fully trust.
        self.vectorstore = FAISS.load_local(
            VECTOR_DB_PATH,
            self.embeddings,
            allow_dangerous_deserialization=True,
        )
        if self.llm:
            self.setup_qa_chain()
        return True

    def setup_qa_chain(self):
        """Build the ``RetrievalQA`` chain once both parts are available.

        Returns:
            ``True`` if the chain was (re)built, ``False`` if the vector
            store or the LLM is still missing.
        """
        if not (self.vectorstore and self.llm):
            return False

        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_kwargs={"k": self.TOP_K}
            ),
        )
        return True

    def answer_question_stream(self, question):
        """Stream an answer to ``question`` grounded in retrieved documents.

        This is a generator.  Each yielded value is the *whole* answer text
        produced so far (not a delta), so a consumer can simply display the
        latest value.  If the system is not ready, or generation fails, a
        single user-facing message describing the problem is yielded instead.

        Args:
            question: The user's question text.

        Yields:
            The cumulative answer string, or one status/error message.
        """
        if not self.qa_chain:
            # Report the most specific reason the system is not ready.
            if not self.vectorstore:
                yield "ベクトルDBが読み込まれていません。"
            elif not self.llm:
                yield "LLMモデルが読み込まれていません。"
            else:
                yield "QAチェーンの初期化に失敗しました。"
            return

        try:
            docs = self.vectorstore.similarity_search(question, k=self.TOP_K)
            context = "\n\n".join(doc.page_content for doc in docs)

            prompt = (
                "与えられた文書を用いて、質問に対する適切な応答を書きなさい。\n"
                f"文書: {context}\n"
                f"質問: {question}\n"
                "応答: "
            )

            response = ""
            # Use the public streaming API rather than the private _stream().
            for chunk in self.llm.stream(prompt):
                # stream() yields plain strings; tolerate chunk objects that
                # carry their text in a .text attribute as well.
                response += chunk if isinstance(chunk, str) else chunk.text
                yield response

        except Exception as e:  # surface any failure in the answer box
            yield f"回答生成中にエラーが発生しました: {e}"

    def get_system_status(self):
        """Return a multi-line, human-readable summary of component status.

        Checks that the GGUF model file and the vector store path exist on
        disk and whether the QA chain has been initialised.  The embedding
        model line is always reported as OK.

        Returns:
            One status line per component, joined with newlines.
        """
        status = []

        if os.path.exists(GGUF_MODEL_PATH):
            model_size = os.path.getsize(GGUF_MODEL_PATH) / (1024 * 1024 * 1024)
            status.append(
                f"✅ LLMモデル: {os.path.basename(GGUF_MODEL_PATH)} ({model_size:.2f} GB)"
            )
        else:
            status.append(f"❌ LLMモデル: {GGUF_MODEL_PATH} が見つかりません")

        if os.path.exists(VECTOR_DB_PATH):
            status.append(f"✅ ベクトルDB: {VECTOR_DB_PATH}")
        else:
            status.append(f"❌ ベクトルDB: {VECTOR_DB_PATH} が見つかりません")

        status.append(f"✅ 埋め込みモデル: {EMBEDDING_MODEL}")

        if self.qa_chain:
            status.append("✅ RAGシステム: 準備完了")
        else:
            status.append("❌ RAGシステム: 初期化されていません")

        return "\n".join(status)
138
+
139
+
140
# Single shared RAG backend; constructing it loads all models at import time.
rag_system = RAGSystem()

# Components are rendered in creation order, so the statement order below
# defines the page layout: header, then a status column and a Q&A column.
with gr.Blocks(title="RAGデモアプリ") as demo:
    gr.Markdown("# 🎇 Sake RAG デモアプリ")
    gr.Markdown(
        "醸造協会誌5年分のデータをベクトルDBとして保持した1B級の小型言語モデルです"
    )

    with gr.Row():
        # Left column: system status and a button to re-evaluate it.
        with gr.Column(scale=1):
            refresh_button = gr.Button("システム状態を更新", variant="secondary")
            status_output = gr.Textbox(
                value=rag_system.get_system_status(),
                label="システム状態",
                lines=5,
                interactive=False,
            )

        # Right column: question entry and the streamed answer.
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                placeholder="質問を入力してください",
                label="質問を入力してください",
                lines=2,
            )
            submit_button = gr.Button("質問する", variant="primary")
            answer_output = gr.Textbox(label="回答", lines=10, interactive=False)

    # Re-run the status check on demand.
    refresh_button.click(
        fn=rag_system.get_system_status, inputs=[], outputs=[status_output]
    )

    # answer_question_stream is a generator: each yielded string replaces the
    # content of the answer box, which produces the streaming effect.
    submit_button.click(
        fn=rag_system.answer_question_stream,
        inputs=[question_input],
        outputs=[answer_output],
    )


if __name__ == "__main__":
    demo.launch()
vectorstore/static-embedding-japanese/index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:598f40e8715d934190631a8298a58371a957fddba4ba4572b83f6ec9676e85af
3
+ size 7405613
vectorstore/static-embedding-japanese/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae8fff1ec1e0adda6565cea84064a604ae345e751c1dc522aca9185b23bc53d7
3
+ size 826595