Austin9 commited on
Commit
f0ba49a
·
verified ·
1 Parent(s): fb6f407

Delete test_run.ipynb

Browse files
Files changed (1) hide show
  1. test_run.ipynb +0 -237
test_run.ipynb DELETED
@@ -1,237 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": []
7
- },
8
- "kernelspec": {
9
- "name": "python3",
10
- "display_name": "Python 3"
11
- },
12
- "language_info": {
13
- "name": "python"
14
- }
15
- },
16
- "cells": [
17
- {
18
- "cell_type": "code",
19
- "execution_count": null,
20
- "metadata": {
21
- "id": "wDmLkBbZmJvB"
22
- },
23
- "outputs": [],
24
- "source": [
25
- "# ===============================\n",
26
- "# 1. 라이브러리 설치 (Google Colab)\n",
27
- "# ===============================\n",
28
- "!pip install unsloth xformers faiss-gpu-cu12 -U\n",
29
- "!pip install --no-deps --upgrade \"flash-attn>=2.6.3\"\n",
30
- "!pip install -U hf_transfer"
31
- ]
32
- },
33
- {
34
- "cell_type": "code",
35
- "source": [
36
- "# ===============================\n",
37
- "# 2. 환경 설정\n",
38
- "# ===============================\n",
39
- "import os\n",
40
- "import torch\n",
41
- "import numpy as np\n",
42
- "import faiss\n",
43
- "import json\n",
44
- "import ast\n",
45
- "from transformers import TextStreamer\n",
46
- "from sentence_transformers import SentenceTransformer\n",
47
- "from unsloth import FastLanguageModel\n",
48
- "\n",
49
- "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\""
50
- ],
51
- "metadata": {
52
- "id": "OsEBB0aKmhBy"
53
- },
54
- "execution_count": null,
55
- "outputs": []
56
- },
57
- {
58
- "cell_type": "code",
59
- "source": [
60
- "# ===============================\n",
61
- "# 3. 모델 로드\n",
62
- "# ===============================\n",
63
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
64
- " model_name=\"Austin9/gemma-2-9b-it-Ko-RAG\",\n",
65
- " max_seq_length=8192,\n",
66
- " dtype=torch.float16,\n",
67
- " load_in_4bit=True\n",
68
- ")\n",
69
- "FastLanguageModel.for_inference(model)"
70
- ],
71
- "metadata": {
72
- "id": "ENT1FgZZmizd"
73
- },
74
- "execution_count": null,
75
- "outputs": []
76
- },
77
- {
78
- "cell_type": "code",
79
- "source": [
80
- "# ===============================\n",
81
- "# 4. FAISS 인덱스 로드\n",
82
- "# ===============================\n",
83
- "vector_db_path = \"/content/chunked_data_vectors.npz\"\n",
84
- "data = np.load(vector_db_path)\n",
85
- "vectors, texts, titles = data[\"vectors\"], data[\"texts\"], data[\"titles\"]\n",
86
- "\n",
87
- "gpu_resources = faiss.StandardGpuResources()\n",
88
- "faiss_index = faiss.GpuIndexFlatL2(gpu_resources, vectors.shape[1])\n",
89
- "faiss_index.add(vectors)"
90
- ],
91
- "metadata": {
92
- "id": "9H7Xcc9GmkQ8"
93
- },
94
- "execution_count": null,
95
- "outputs": []
96
- },
97
- {
98
- "cell_type": "code",
99
- "source": [
100
- "# ===============================\n",
101
- "# 5. 임베딩 모델 로드\n",
102
- "# ===============================\n",
103
- "embedding_model = SentenceTransformer(\"nlpai-lab/KURE-v1\", device=\"cuda\").to(torch.float16)"
104
- ],
105
- "metadata": {
106
- "id": "EwpMV0kXmpSX"
107
- },
108
- "execution_count": null,
109
- "outputs": []
110
- },
111
- {
112
- "cell_type": "code",
113
- "source": [
114
- "# ===============================\n",
115
- "# 6. JSON 파싱 함수\n",
116
- "# ===============================\n",
117
- "def robust_parse_json(response_text):\n",
118
- " response_text = response_text.strip().strip(\"'\").strip('\"').replace(\"'\", '\"')\n",
119
- " try:\n",
120
- " return json.loads(response_text)\n",
121
- " except:\n",
122
- " try:\n",
123
- " return ast.literal_eval(response_text)\n",
124
- " except:\n",
125
- " return {\"search\": \"\"}\n",
126
- "\n",
127
- "# ===============================\n",
128
- "# 7. 검색 쿼리 생성 (QCR 단계)\n",
129
- "# ===============================\n",
130
- "def generate_search_query(conversation_history, user_input):\n",
131
- " instruction = (\n",
132
- " \"다음은 대화 기록(Context)와 사용자의 질문(Input)입니다. \"\n",
133
- " \"사용자의 질문에 답을 제공하기 위해 필요한 단일 문자열 검색 쿼리를 생성하세요. \"\n",
134
- " \"검색이 필요하지 않거나 검색이 불필요한 경우(인사나, 겉치레, 농담) 빈 문자열을 반환하세요.\\n\\n\"\n",
135
- " \"최종 출력 형식은 {'search': '<검색 쿼리>'}입니다.\"\n",
136
- " )\n",
137
- " prompt = f\"\"\"\n",
138
- " # Query Rewriter\n",
139
- " ### Instruction:\n",
140
- " {instruction}\n",
141
- " ### Conversation:\n",
142
- " {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n",
143
- " ### Input:\n",
144
- " {user_input}\n",
145
- " ### Response:\n",
146
- " \"\"\"\n",
147
- "\n",
148
- " inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n",
149
- " output_tokens = model.generate(**inputs, max_new_tokens=300)\n",
150
- " response_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()\n",
151
- " return robust_parse_json(response_text).get(\"search\", \"\")\n",
152
- "\n",
153
- "# ===============================\n",
154
- "# 8. FAISS 검색\n",
155
- "# ===============================\n",
156
- "def search_documents(query, k=3):\n",
157
- " if not query:\n",
158
- " return \"\"\n",
159
- " query_vector = embedding_model.encode([query])[0]\n",
160
- " _, indices = faiss_index.search(np.array([query_vector]), k)\n",
161
- " return \"\\n\\n\".join([f\"# Index [{i+1}]: {titles[idx]}\\n{texts[idx]}\" for i, idx in enumerate(indices[0])])\n",
162
- "\n",
163
- "# ===============================\n",
164
- "# 9. 답변 생성\n",
165
- "# ===============================\n",
166
- "def generate_response(conversation_history, context, user_input):\n",
167
- " instruction = (\n",
168
- " \"당신은 외부검색을 이용하여 사용자에게 도움을 주는 인공지능 조수입니다.\\n\"\n",
169
- " \"- Context는 외부검색을 통해 반환된 사용자 요청과 관련된 정보들입니다.\\n\"\n",
170
- " \"- Context를 활용할 때 문장 끝에 사용한 문서 조각의 [Index]를 붙이고 자연스러운 답변을 작성하세요. (e.g. [1])\\n\"\n",
171
- " \"- Context의 정보가 사용자 요청과 관련이 없거나 도움이 안될수도 있습니다. 관련있는 정보만 활용하고, 없는 정보를 절대 지어내지 마세요.\\n\"\n",
172
- " \"- 되도록이면 일반 지식으로 답변하지말고, 최대한 Context를 통해서 답변을 하려고 하세요. Context에 없을 경우에는 이 점을 언급하며 사죄하고 다른 주제나 질문을 추천해주세요.\\n\"\n",
173
- " \"- 사용자 요청에 알맞는 자연스러운 대화를 하세요.\\n\"\n",
174
- " \"- 항상 존댓말로 답변하세요.\"\n",
175
- " )\n",
176
- "\n",
177
- " prompt = f\"\"\"\n",
178
- " # Generator\n",
179
- " ### Instruction:\n",
180
- " {instruction}\n",
181
- " ### Conversation:\n",
182
- " {'\\n'.join([f'{role}: {msg}' for role, msg in conversation_history])}\n",
183
- " ### Context:\n",
184
- " {context}\n",
185
- " ### Input:\n",
186
- " {user_input}\n",
187
- " ### Response:\n",
188
- " \"\"\"\n",
189
- "\n",
190
- " inputs = tokenizer([prompt], return_tensors=\"pt\").to(\"cuda\")\n",
191
- " output_tokens = model.generate(**inputs, max_new_tokens=2500, do_sample=True)\n",
192
- " return tokenizer.decode(output_tokens[0], skip_special_tokens=True).split(\"### Response:\")[-1].strip()"
193
- ],
194
- "metadata": {
195
- "id": "Nsv2Xp2kmp1S"
196
- },
197
- "execution_count": null,
198
- "outputs": []
199
- },
200
- {
201
- "cell_type": "code",
202
- "source": [
203
- "# ===============================\n",
204
- "# 10. 대화 루프\n",
205
- "# ===============================\n",
206
- "def chat_loop():\n",
207
- " conversation_history = []\n",
208
- " print(\"대화를 시작합니다. 'exit' 입력 시 종료.\")\n",
209
- "\n",
210
- " while True:\n",
211
- " user_input = input(\"\\nUser> \").strip()\n",
212
- " if user_input.lower() in [\"exit\", \"quit\"]:\n",
213
- " print(\"대화를 종료합니다.\")\n",
214
- " break\n",
215
- "\n",
216
- " print(\"\\n[검색 쿼리 생성 중...]\")\n",
217
- " search_query = generate_search_query(conversation_history, user_input)\n",
218
- " context = search_documents(search_query, k=5) if search_query else \"\"\n",
219
- "\n",
220
- " print(\"\\n[답변 생성 중...]\")\n",
221
- " response = generate_response(conversation_history, context, user_input)\n",
222
- "\n",
223
- " conversation_history.append((\"User\", user_input))\n",
224
- " conversation_history.append((\"Assistant\", response))\n",
225
- " print(f\"\\nAssistant> {response}\")\n",
226
- "\n",
227
- "if __name__ == \"__main__\":\n",
228
- " chat_loop()"
229
- ],
230
- "metadata": {
231
- "id": "4XD0UDZImsuE"
232
- },
233
- "execution_count": null,
234
- "outputs": []
235
- }
236
- ]
237
- }