Austin9 committed on
Commit
4720e70
·
verified ·
1 Parent(s): f0ba49a

Upload test_run.ipynb

Browse files
Files changed (1) hide show
  1. test_run.ipynb +241 -0
test_run.ipynb ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {
21
+ "id": "wDmLkBbZmJvB"
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "# ===============================\n",
26
+ "# 1. 라이브러리 설치 (Google Colab)\n",
27
+ "# ===============================\n",
28
+ "!pip install unsloth xformers faiss-gpu-cu12 -U\n",
29
+ "!pip install --no-deps --upgrade \"flash-attn>=2.6.3\"\n",
30
+ "!pip install -U hf_transfer"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "source": [
36
+ "# ===============================\n",
37
+ "# 2. 환경 설정\n",
38
+ "# ===============================\n",
39
+ "import os\n",
40
+ "import torch\n",
41
+ "import numpy as np\n",
42
+ "import faiss\n",
43
+ "import json\n",
44
+ "import ast\n",
45
+ "from transformers import TextStreamer\n",
46
+ "from sentence_transformers import SentenceTransformer\n",
47
+ "from unsloth import FastLanguageModel\n",
48
+ "from huggingface_hub import hf_hub_download\n",
49
+ "\n",
50
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\""
51
+ ],
52
+ "metadata": {
53
+ "id": "OsEBB0aKmhBy"
54
+ },
55
+ "execution_count": null,
56
+ "outputs": []
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "# ===============================\n",
62
+ "# 3. 모델 로드\n",
63
+ "# ===============================\n",
64
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
65
+ " model_name=\"Austin9/gemma-2-9b-it-Ko-RAG\",\n",
66
+ " max_seq_length=8192,\n",
67
+ " dtype=torch.float16,\n",
68
+ " load_in_4bit=True\n",
69
+ ")\n",
70
+ "FastLanguageModel.for_inference(model)"
71
+ ],
72
+ "metadata": {
73
+ "id": "ENT1FgZZmizd"
74
+ },
75
+ "execution_count": null,
76
+ "outputs": []
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "source": [
81
+ "# ===============================\n",
82
+ "# 4. FAISS ์ธ๋ฑ์Šค ๋กœ๋“œ (Hugging Face Hub์—์„œ ์ง์ ‘ ๋‹ค์šด๋กœ๋“œ)\n",
83
+ "# ===============================\n",
84
+ "repo_id = \"Austin9/gemma-2-9b-it-Ko-RAG\" # 허깅페이스 저장소 ID\n",
85
+ "filename = \"chunked_data_vectors.npz\" # 저장된 npz 파일 이름\n",
86
+ "\n",
87
+ "vector_db_path = hf_hub_download(repo_id=repo_id, filename=filename)\n",
88
+ "data = np.load(vector_db_path)\n",
89
+ "vectors, texts, titles = data[\"vectors\"], data[\"texts\"], data[\"titles\"]\n",
90
+ "\n",
91
+ "gpu_resources = faiss.StandardGpuResources()\n",
92
+ "faiss_index = faiss.GpuIndexFlatL2(gpu_resources, vectors.shape[1])\n",
93
+ "faiss_index.add(vectors)"
94
+ ],
95
+ "metadata": {
96
+ "id": "9H7Xcc9GmkQ8"
97
+ },
98
+ "execution_count": null,
99
+ "outputs": []
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "source": [
104
+ "# ===============================\n",
105
+ "# 5. 임베딩 모델 로드\n",
106
+ "# ===============================\n",
107
+ "embedding_model = SentenceTransformer(\"nlpai-lab/KURE-v1\", device=\"cuda\").to(torch.float16)"
108
+ ],
109
+ "metadata": {
110
+ "id": "EwpMV0kXmpSX"
111
+ },
112
+ "execution_count": null,
113
+ "outputs": []
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "source": [
118
+ "# ===============================\n",
119
+ "# 6. JSON 파싱 함수\n",
120
+ "# ===============================\n",
121
+ "def robust_parse_json(response_text):\n",
122
+ " response_text = response_text.strip().strip(\"'\").strip('\"').replace(\"'\", '\"')\n",
123
+ " try:\n",
124
+ " return json.loads(response_text)\n",
125
+ " except:\n",
126
+ " try:\n",
127
+ " return ast.literal_eval(response_text)\n",
128
+ " except:\n",
129
+ " return {\"search\": \"\"}\n",
130
+ "\n",
131
+ "# ===============================\n",
132
+ "# 7. 검색 쿼리 생성 (QCR 단계)\n",
133
+ "# ===============================\n",
134
+ "def generate_search_query(conversation_history, user_input):\n",
135
+ " instruction = (\n",
136
+ " \"다음은 대화 기록(Context)와 사용자의 질문(Input)입니다. \"\n",
137
+ " \"사용자의 질문에 답을 제공하기 위해 필요한 단일 문자열 검색 쿼리를 생성하세요. \"\n",
138
+ " \"검색이 필요하지 않거나 검색이 불필요한 경우(인사나, 겉치레, 농담) 빈 문자열을 반환하세요.\\n\\n\"\n",
139
+ " \"최종 출력 형식은 {'search': '<검색 쿼리>'}입니다.\"\n",
140
+ " )\n",
141
+ " prompt = f\"\"\"\n",
142
+ " # Query Rewriter\n",
143
+ " ### Instruction:\n",
144
+ " {instruction}\n",
145
+ " ### Conversation:\n",
146
+ " {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n",
147
+ " ### Input:\n",
148
+ " {user_input}\n",
149
+ " ### Response:\n",
150
+ " \"\"\"\n",
151
+ "\n",
152
+ " inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n",
153
+ " output_tokens = model.generate(**inputs, max_new_tokens=300)\n",
154
+ " response_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()\n",
155
+ " return robust_parse_json(response_text).get(\"search\", \"\")\n",
156
+ "\n",
157
+ "# ===============================\n",
158
+ "# 8. FAISS 검색\n",
159
+ "# ===============================\n",
160
+ "def search_documents(query, k=3):\n",
161
+ " if not query:\n",
162
+ " return \"\"\n",
163
+ " query_vector = embedding_model.encode([query])[0]\n",
164
+ " _, indices = faiss_index.search(np.array([query_vector]), k)\n",
165
+ " return \"\\n\\n\".join([f\"# Index [{i+1}]: {titles[idx]}\\n{texts[idx]}\" for i, idx in enumerate(indices[0])])\n",
166
+ "\n",
167
+ "# ===============================\n",
168
+ "# 9. 답변 생성\n",
169
+ "# ===============================\n",
170
+ "def generate_response(conversation_history, context, user_input):\n",
171
+ " instruction = (\n",
172
+ " \"๋‹น์‹ ์€ ์™ธ๋ถ€๊ฒ€์ƒ‰์„ ์ด์šฉํ•˜์—ฌ ์‚ฌ์šฉ์ž์—๊ฒŒ ๋„์›€์„ ์ฃผ๋Š” ์ธ๊ณต์ง€๋Šฅ ์กฐ์ˆ˜์ž…๋‹ˆ๋‹ค.\\n\"\n",
173
+ " \"- Context๋Š” ์™ธ๋ถ€๊ฒ€์ƒ‰์„ ํ†ตํ•ด ๋ฐ˜ํ™˜๋œ ์‚ฌ์šฉ์ž ์š”์ฒญ๊ณผ ๊ด€๋ จ๋œ ์ •๋ณด๋“ค์ž…๋‹ˆ๋‹ค.\\n\"\n",
174
+ " \"- Context๋ฅผ ํ™œ์šฉํ•  ๋•Œ ๋ฌธ์žฅ ๋์— ์‚ฌ์šฉํ•œ ๋ฌธ์„œ ์กฐ๊ฐ์˜ [Index]๋ฅผ ๋ถ™์ด๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•˜์„ธ์š”. (e.g. [1])\\n\"\n",
175
+ " \"- Context์˜ ์ •๋ณด๊ฐ€ ์‚ฌ์šฉ์ž ์š”์ฒญ๊ณผ ๊ด€๋ จ์ด ์—†๊ฑฐ๋‚˜ ๋„์›€์ด ์•ˆ๋ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ด€๋ จ์žˆ๋Š” ์ •๋ณด๋งŒ ํ™œ์šฉํ•˜๊ณ , ์—†๋Š” ์ •๋ณด๋ฅผ ์ ˆ๋Œ€ ์ง€์–ด๋‚ด์ง€ ๋งˆ์„ธ์š”.\\n\"\n",
176
+ " \"- ๋˜๋„๋ก์ด๋ฉด ์ผ๋ฐ˜ ์ง€์‹์œผ๋กœ ๋‹ต๋ณ€ํ•˜์ง€๋ง๊ณ , ์ตœ๋Œ€ํ•œ Context๋ฅผ ํ†ตํ•ด์„œ ๋‹ต๋ณ€์„ ํ•˜๋ ค๊ณ  ํ•˜์„ธ์š”. Context์— ์—†์„ ๊ฒฝ์šฐ์—๋Š” ์ด ์ ์„ ์–ธ๊ธ‰ํ•˜๋ฉฐ ์‚ฌ์ฃ„ํ•˜๊ณ  ๋‹ค๋ฅธ ์ฃผ์ œ๋‚˜ ์งˆ๋ฌธ์„ ์ถ”์ฒœํ•ด์ฃผ์„ธ์š”.\\n\"\n",
177
+ " \"- ์‚ฌ์šฉ์ž ์š”์ฒญ์— ์•Œ๋งž๋Š” ์ž์—ฐ์Šค๋Ÿฌ์šด ๋Œ€ํ™”๋ฅผ ํ•˜์„ธ์š”.\\n\"\n",
178
+ " \"- ํ•ญ์ƒ ์กด๋Œ“๋ง๋กœ ๋‹ต๋ณ€ํ•˜์„ธ์š”.\"\n",
179
+ " )\n",
180
+ "\n",
181
+ " prompt = f\"\"\"\n",
182
+ " # Generator\n",
183
+ " ### Instruction:\n",
184
+ " {instruction}\n",
185
+ " ### Conversation:\n",
186
+ " {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n",
187
+ " ### Context:\n",
188
+ " {context}\n",
189
+ " ### Input:\n",
190
+ " {user_input}\n",
191
+ " ### Response:\n",
192
+ " \"\"\"\n",
193
+ "\n",
194
+ " inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n",
195
+ " output_tokens = model.generate(**inputs, max_new_tokens=2500, do_sample=True)\n",
196
+ " return tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()"
197
+ ],
198
+ "metadata": {
199
+ "id": "Nsv2Xp2kmp1S"
200
+ },
201
+ "execution_count": null,
202
+ "outputs": []
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "source": [
207
+ "# ===============================\n",
208
+ "# 10. 대화 루프\n",
209
+ "# ===============================\n",
210
+ "def chat_loop():\n",
211
+ " conversation_history = []\n",
212
+ " print(\"대화를 시작합니다. 'exit' 입력 시 종료.\")\n",
213
+ "\n",
214
+ " while True:\n",
215
+ " user_input = input(\"\\nUser> \").strip()\n",
216
+ " if user_input.lower() in [\"exit\", \"quit\"]:\n",
217
+ " print(\"대화를 종료합니다.\")\n",
218
+ " break\n",
219
+ "\n",
220
+ " print(\"\\n[검색 쿼리 생성 중...]\")\n",
221
+ " search_query = generate_search_query(conversation_history, user_input)\n",
222
+ " context = search_documents(search_query, k=5) if search_query else \"\"\n",
223
+ "\n",
224
+ " print(\"\\n[답변 생성 중...]\")\n",
225
+ " response = generate_response(conversation_history, context, user_input)\n",
226
+ "\n",
227
+ " conversation_history.append((\"User\", user_input))\n",
228
+ " conversation_history.append((\"Assistant\", response))\n",
229
+ " print(f\"\\nAssistant> {response}\")\n",
230
+ "\n",
231
+ "if __name__ == \"__main__\":\n",
232
+ " chat_loop()"
233
+ ],
234
+ "metadata": {
235
+ "id": "4XD0UDZImsuE"
236
+ },
237
+ "execution_count": null,
238
+ "outputs": []
239
+ }
240
+ ]
241
+ }