Upload test_run.ipynb
Browse files- test_run.ipynb +241 -0
test_run.ipynb
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": []
|
7 |
+
},
|
8 |
+
"kernelspec": {
|
9 |
+
"name": "python3",
|
10 |
+
"display_name": "Python 3"
|
11 |
+
},
|
12 |
+
"language_info": {
|
13 |
+
"name": "python"
|
14 |
+
}
|
15 |
+
},
|
16 |
+
"cells": [
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": null,
|
20 |
+
"metadata": {
|
21 |
+
"id": "wDmLkBbZmJvB"
|
22 |
+
},
|
23 |
+
"outputs": [],
|
24 |
+
"source": [
|
25 |
+
"# ===============================\n",
|
26 |
+
"# 1. ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ค์น (Google Colab)\n",
|
27 |
+
"# ===============================\n",
|
28 |
+
"!pip install unsloth xformers faiss-gpu-cu12 -U\n",
|
29 |
+
"!pip install --no-deps --upgrade \"flash-attn>=2.6.3\"\n",
|
30 |
+
"!pip install -U hf_transfer"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"source": [
|
36 |
+
"# ===============================\n",
|
37 |
+
"# 2. ํ๊ฒฝ ์ค์ \n",
|
38 |
+
"# ===============================\n",
|
39 |
+
"import os\n",
|
40 |
+
"import torch\n",
|
41 |
+
"import numpy as np\n",
|
42 |
+
"import faiss\n",
|
43 |
+
"import json\n",
|
44 |
+
"import ast\n",
|
45 |
+
"from transformers import TextStreamer\n",
|
46 |
+
"from sentence_transformers import SentenceTransformer\n",
|
47 |
+
"from unsloth import FastLanguageModel\n",
|
48 |
+
"from huggingface_hub import hf_hub_download\n",
|
49 |
+
"\n",
|
50 |
+
"os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\""
|
51 |
+
],
|
52 |
+
"metadata": {
|
53 |
+
"id": "OsEBB0aKmhBy"
|
54 |
+
},
|
55 |
+
"execution_count": null,
|
56 |
+
"outputs": []
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"source": [
|
61 |
+
"# ===============================\n",
|
62 |
+
"# 3. ๋ชจ๋ธ ๋ก๋\n",
|
63 |
+
"# ===============================\n",
|
64 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
65 |
+
" model_name=\"Austin9/gemma-2-9b-it-Ko-RAG\",\n",
|
66 |
+
" max_seq_length=8192,\n",
|
67 |
+
" dtype=torch.float16,\n",
|
68 |
+
" load_in_4bit=True\n",
|
69 |
+
")\n",
|
70 |
+
"FastLanguageModel.for_inference(model)"
|
71 |
+
],
|
72 |
+
"metadata": {
|
73 |
+
"id": "ENT1FgZZmizd"
|
74 |
+
},
|
75 |
+
"execution_count": null,
|
76 |
+
"outputs": []
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"source": [
|
81 |
+
"# ===============================\n",
|
82 |
+
"# 4. FAISS ์ธ๋ฑ์ค ๋ก๋ (Hugging Face Hub์์ ์ง์ ๋ค์ด๋ก๋)\n",
|
83 |
+
"# ===============================\n",
|
84 |
+
"repo_id = \"Austin9/gemma-2-9b-it-Ko-RAG\" # ํ๊น
ํ์ด์ค ์ ์ฅ์ ID\n",
|
85 |
+
"filename = \"chunked_data_vectors.npz\" # ์ ์ฅ๋ npz ํ์ผ ์ด๋ฆ\n",
|
86 |
+
"\n",
|
87 |
+
"vector_db_path = hf_hub_download(repo_id=repo_id, filename=filename)\n",
|
88 |
+
"data = np.load(vector_db_path)\n",
|
89 |
+
"vectors, texts, titles = data[\"vectors\"], data[\"texts\"], data[\"titles\"]\n",
|
90 |
+
"\n",
|
91 |
+
"gpu_resources = faiss.StandardGpuResources()\n",
|
92 |
+
"faiss_index = faiss.GpuIndexFlatL2(gpu_resources, vectors.shape[1])\n",
|
93 |
+
"faiss_index.add(vectors)"
|
94 |
+
],
|
95 |
+
"metadata": {
|
96 |
+
"id": "9H7Xcc9GmkQ8"
|
97 |
+
},
|
98 |
+
"execution_count": null,
|
99 |
+
"outputs": []
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"source": [
|
104 |
+
"# ===============================\n",
|
105 |
+
"# 5. ์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ก๋\n",
|
106 |
+
"# ===============================\n",
|
107 |
+
"embedding_model = SentenceTransformer(\"nlpai-lab/KURE-v1\", device=\"cuda\").to(torch.float16)"
|
108 |
+
],
|
109 |
+
"metadata": {
|
110 |
+
"id": "EwpMV0kXmpSX"
|
111 |
+
},
|
112 |
+
"execution_count": null,
|
113 |
+
"outputs": []
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"source": [
|
118 |
+
"# ===============================\n",
|
119 |
+
"# 6. JSON ํ์ฑ ํจ์\n",
|
120 |
+
"# ===============================\n",
|
121 |
+
"def robust_parse_json(response_text):\n",
|
122 |
+
" response_text = response_text.strip().strip(\"'\").strip('\"').replace(\"'\", '\"')\n",
|
123 |
+
" try:\n",
|
124 |
+
" return json.loads(response_text)\n",
|
125 |
+
" except:\n",
|
126 |
+
" try:\n",
|
127 |
+
" return ast.literal_eval(response_text)\n",
|
128 |
+
" except:\n",
|
129 |
+
" return {\"search\": \"\"}\n",
|
130 |
+
"\n",
|
131 |
+
"# ===============================\n",
|
132 |
+
"# 7. ๊ฒ์ ์ฟผ๋ฆฌ ์์ฑ (QCR ๋จ๊ณ)\n",
|
133 |
+
"# ===============================\n",
|
134 |
+
"def generate_search_query(conversation_history, user_input):\n",
|
135 |
+
" instruction = (\n",
|
136 |
+
" \"๋ค์์ ๋ํ ๊ธฐ๋ก(Context)์ ์ฌ์ฉ์์ ์ง๋ฌธ(Input)์
๋๋ค. \"\n",
|
137 |
+
" \"์ฌ์ฉ์์ ์ง๋ฌธ์ ๋ต์ ์ ๊ณตํ๊ธฐ ์ํด ํ์ํ ๋จ์ผ ๋ฌธ์์ด ๊ฒ์ ์ฟผ๋ฆฌ๋ฅผ ์์ฑํ์ธ์. \"\n",
|
138 |
+
" \"๊ฒ์์ด ํ์ํ์ง ์๊ฑฐ๋ ๊ฒ์์ด ๋ถํ์ํ ๊ฒฝ์ฐ(์ธ์ฌ๋, ๊ฒ์น๋ , ๋๋ด) ๋น ๋ฌธ์์ด์ ๋ฐํํ์ธ์.\\n\\n\"\n",
|
139 |
+
" \"์ต์ข
์ถ๋ ฅ ํ์์ {'search': '<๊ฒ์ ์ฟผ๋ฆฌ>'}์
๋๋ค.\"\n",
|
140 |
+
" )\n",
|
141 |
+
" prompt = f\"\"\"\n",
|
142 |
+
" # Query Rewriter\n",
|
143 |
+
" ### Instruction:\n",
|
144 |
+
" {instruction}\n",
|
145 |
+
" ### Conversation:\n",
|
146 |
+
" {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n",
|
147 |
+
" ### Input:\n",
|
148 |
+
" {user_input}\n",
|
149 |
+
" ### Response:\n",
|
150 |
+
" \"\"\"\n",
|
151 |
+
"\n",
|
152 |
+
" inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n",
|
153 |
+
" output_tokens = model.generate(**inputs, max_new_tokens=300)\n",
|
154 |
+
" response_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()\n",
|
155 |
+
" return robust_parse_json(response_text).get(\"search\", \"\")\n",
|
156 |
+
"\n",
|
157 |
+
"# ===============================\n",
|
158 |
+
"# 8. FAISS ๊ฒ์\n",
|
159 |
+
"# ===============================\n",
|
160 |
+
"def search_documents(query, k=3):\n",
|
161 |
+
" if not query:\n",
|
162 |
+
" return \"\"\n",
|
163 |
+
" query_vector = embedding_model.encode([query])[0]\n",
|
164 |
+
" _, indices = faiss_index.search(np.array([query_vector]), k)\n",
|
165 |
+
" return \"\\n\\n\".join([f\"# Index [{i+1}]: {titles[idx]}\\n{texts[idx]}\" for i, idx in enumerate(indices[0])])\n",
|
166 |
+
"\n",
|
167 |
+
"# ===============================\n",
|
168 |
+
"# 9. ๋ต๋ณ ์์ฑ\n",
|
169 |
+
"# ===============================\n",
|
170 |
+
"def generate_response(conversation_history, context, user_input):\n",
|
171 |
+
" instruction = (\n",
|
172 |
+
" \"๋น์ ์ ์ธ๋ถ๊ฒ์์ ์ด์ฉํ์ฌ ์ฌ์ฉ์์๊ฒ ๋์์ ์ฃผ๋ ์ธ๊ณต์ง๋ฅ ์กฐ์์
๋๋ค.\\n\"\n",
|
173 |
+
" \"- Context๋ ์ธ๋ถ๊ฒ์์ ํตํด ๋ฐํ๋ ์ฌ์ฉ์ ์์ฒญ๊ณผ ๊ด๋ จ๋ ์ ๋ณด๋ค์
๋๋ค.\\n\"\n",
|
174 |
+
" \"- Context๋ฅผ ํ์ฉํ ๋ ๋ฌธ์ฅ ๋์ ์ฌ์ฉํ ๋ฌธ์ ์กฐ๊ฐ์ [Index]๋ฅผ ๋ถ์ด๊ณ ์์ฐ์ค๋ฌ์ด ๋ต๋ณ์ ์์ฑํ์ธ์. (e.g. [1])\\n\"\n",
|
175 |
+
" \"- Context์ ์ ๋ณด๊ฐ ์ฌ์ฉ์ ์์ฒญ๊ณผ ๊ด๋ จ์ด ์๊ฑฐ๋ ๋์์ด ์๋ ์๋ ์์ต๋๋ค. ๊ด๋ จ์๋ ์ ๋ณด๋ง ํ์ฉํ๊ณ , ์๋ ์ ๋ณด๋ฅผ ์ ๋ ์ง์ด๋ด์ง ๋ง์ธ์.\\n\"\n",
|
176 |
+
" \"- ๋๋๋ก์ด๋ฉด ์ผ๋ฐ ์ง์์ผ๋ก ๋ต๋ณํ์ง๋ง๊ณ , ์ต๋ํ Context๋ฅผ ํตํด์ ๋ต๋ณ์ ํ๋ ค๊ณ ํ์ธ์. Context์ ์์ ๊ฒฝ์ฐ์๋ ์ด ์ ์ ์ธ๊ธํ๋ฉฐ ์ฌ์ฃํ๊ณ ๋ค๋ฅธ ์ฃผ์ ๋ ์ง๋ฌธ์ ์ถ์ฒํด์ฃผ์ธ์.\\n\"\n",
|
177 |
+
" \"- ์ฌ์ฉ์ ์์ฒญ์ ์๋ง๋ ์์ฐ์ค๋ฌ์ด ๋ํ๋ฅผ ํ์ธ์.\\n\"\n",
|
178 |
+
" \"- ํญ์ ์กด๋๋ง๋ก ๋ต๋ณํ์ธ์.\"\n",
|
179 |
+
" )\n",
|
180 |
+
"\n",
|
181 |
+
" prompt = f\"\"\"\n",
|
182 |
+
" # Generator\n",
|
183 |
+
" ### Instruction:\n",
|
184 |
+
" {instruction}\n",
|
185 |
+
" ### Conversation:\n",
|
186 |
+
" {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n",
|
187 |
+
" ### Context:\n",
|
188 |
+
" {context}\n",
|
189 |
+
" ### Input:\n",
|
190 |
+
" {user_input}\n",
|
191 |
+
" ### Response:\n",
|
192 |
+
" \"\"\"\n",
|
193 |
+
"\n",
|
194 |
+
" inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n",
|
195 |
+
" output_tokens = model.generate(**inputs, max_new_tokens=2500, do_sample=True)\n",
|
196 |
+
" return tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()"
|
197 |
+
],
|
198 |
+
"metadata": {
|
199 |
+
"id": "Nsv2Xp2kmp1S"
|
200 |
+
},
|
201 |
+
"execution_count": null,
|
202 |
+
"outputs": []
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"cell_type": "code",
|
206 |
+
"source": [
|
207 |
+
"# ===============================\n",
|
208 |
+
"# 10. ๋ํ ๋ฃจํ\n",
|
209 |
+
"# ===============================\n",
|
210 |
+
"def chat_loop():\n",
|
211 |
+
" conversation_history = []\n",
|
212 |
+
" print(\"๋ํ๋ฅผ ์์ํฉ๋๋ค. 'exit' ์
๋ ฅ ์ ์ข
๋ฃ.\")\n",
|
213 |
+
"\n",
|
214 |
+
" while True:\n",
|
215 |
+
" user_input = input(\"\\nUser> \").strip()\n",
|
216 |
+
" if user_input.lower() in [\"exit\", \"quit\"]:\n",
|
217 |
+
" print(\"๋ํ๋ฅผ ์ข
๋ฃํฉ๋๋ค.\")\n",
|
218 |
+
" break\n",
|
219 |
+
"\n",
|
220 |
+
" print(\"\\n[๊ฒ์ ์ฟผ๋ฆฌ ์์ฑ ์ค...]\")\n",
|
221 |
+
" search_query = generate_search_query(conversation_history, user_input)\n",
|
222 |
+
" context = search_documents(search_query, k=5) if search_query else \"\"\n",
|
223 |
+
"\n",
|
224 |
+
" print(\"\\n[๋ต๋ณ ์์ฑ ์ค...]\")\n",
|
225 |
+
" response = generate_response(conversation_history, context, user_input)\n",
|
226 |
+
"\n",
|
227 |
+
" conversation_history.append((\"User\", user_input))\n",
|
228 |
+
" conversation_history.append((\"Assistant\", response))\n",
|
229 |
+
" print(f\"\\nAssistant> {response}\")\n",
|
230 |
+
"\n",
|
231 |
+
"if __name__ == \"__main__\":\n",
|
232 |
+
" chat_loop()"
|
233 |
+
],
|
234 |
+
"metadata": {
|
235 |
+
"id": "4XD0UDZImsuE"
|
236 |
+
},
|
237 |
+
"execution_count": null,
|
238 |
+
"outputs": []
|
239 |
+
}
|
240 |
+
]
|
241 |
+
}
|