Ley_Fill7 committed
Commit 0fdded5 · 1 Parent(s): 809f031

Added the app file

Files changed (1)
app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
+ # Import modules and classes
+ from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
+ from llama_index.llms.nvidia import NVIDIA
+ from llama_index.embeddings.nvidia import NVIDIAEmbedding
+ from llama_index.core.llms import ChatMessage, MessageRole
+ from langchain_nvidia_ai_endpoints import NVIDIARerank
+ from langchain_core.documents import Document as LangDocument
+ from llama_index.core import Document as LlamaDocument
+ from llama_index.core import Settings
+ from llama_parse import LlamaParse
+ import streamlit as st
+ import os
+
+ # Read the API keys from environment variables
+ nvidia_api_key = os.getenv("NVIDIA_KEY")
+ llamaparse_api_key = os.getenv("PARSE_KEY")
+
+ # Initialize the NVIDIA LLM, NVIDIAEmbedding, and NVIDIARerank clients
+ client = NVIDIA(
+     model="meta/llama-3.1-8b-instruct",
+     api_key=nvidia_api_key,
+     temperature=0.2,
+     top_p=0.7,
+     max_tokens=1024
+ )
+
+ embed_model = NVIDIAEmbedding(
+     model="nvidia/nv-embedqa-e5-v5",
+     api_key=nvidia_api_key,
+     truncate="NONE"
+ )
+
+ reranker = NVIDIARerank(
+     model="nvidia/nv-rerankqa-mistral-4b-v3",
+     api_key=nvidia_api_key,
+ )
+
+ # Set the NVIDIA models globally for LlamaIndex
+ Settings.embed_model = embed_model
+ Settings.llm = client
+
+ # Initialize the LlamaParse parser for the local PDF document
+ parser = LlamaParse(
+     api_key=llamaparse_api_key,
+     result_type="markdown",
+     verbose=True
+ )
+
+ # Build an absolute path to the dataset from the script's directory
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ data_file = os.path.join(script_dir, "PhilDataset.pdf")
+
+ # Load and parse the PDF document using the absolute path
+ documents = parser.load_data(data_file)
+ print("Document Parsed")
+
+ # Split parsed text into chunks for the embedding model.
+ # Note: the max_tokens budget is counted in characters here,
+ # as a rough proxy for tokens.
+ def split_text(text, max_tokens=512):
+     words = text.split()
+     chunks = []
+     current_chunk = []
+     current_length = 0
+
+     for word in words:
+         word_length = len(word)
+         # The +1 accounts for the space that joins the word to the chunk
+         if current_length + word_length + 1 > max_tokens:
+             chunks.append(" ".join(current_chunk))
+             current_chunk = [word]
+             current_length = word_length + 1
+         else:
+             current_chunk.append(word)
+             current_length += word_length + 1
+
+     if current_chunk:
+         chunks.append(" ".join(current_chunk))
+
+     return chunks
+
+ # Generate embeddings for document chunks
+ all_embeddings = []
+ all_documents = []
+
+ for doc in documents:
+     text_chunks = split_text(doc.text)
+     for chunk in text_chunks:
+         embedding = embed_model.get_text_embedding(chunk)
+         all_embeddings.append(embedding)
+         all_documents.append(LlamaDocument(text=chunk))
+ print("Embeddings generated")
+
+ # Create and persist the index with the NVIDIA embedding model.
+ # from_documents embeds the chunks itself via embed_model; it does not
+ # accept the precomputed embeddings list, so that list is not passed in.
+ index = VectorStoreIndex.from_documents(all_documents, embed_model=embed_model)
+ index.set_index_id("vector_index")
+ index.storage_context.persist("./storage")
+ print("Index created")
+
+ # Load the index back from storage
+ storage_context = StorageContext.from_defaults(persist_dir="./storage")
+ index = load_index_from_storage(storage_context, index_id="vector_index")
+ print("Index loaded")
+
+ # Query the index and use the retrieved output as LLM context
+ def query_model_with_context(question):
+     # Retrieve the top 3 most similar nodes
+     retriever = index.as_retriever(similarity_top_k=3)
+     nodes = retriever.retrieve(question)
+
+     for node in nodes:
+         print(node)
+
+     # Rerank the retrieved nodes
+     ranked_documents = reranker.compress_documents(
+         query=question,
+         documents=[LangDocument(page_content=node.text) for node in nodes]
+     )
+
+     # Print the most relevant node after reranking
+     print(f"Most relevant node: {ranked_documents[0].page_content}")
+
+     # Use the most relevant node as context
+     context = ranked_documents[0].page_content
+
+     # Construct the messages using the ChatMessage class
+     messages = [
+         ChatMessage(role=MessageRole.SYSTEM, content=context),
+         ChatMessage(role=MessageRole.USER, content=str(question))
+     ]
+
+     completion = client.chat(messages)
+
+     # client.chat normally returns a ChatResponse object; the branches
+     # below defensively normalize other shapes to a plain string
+     response_text = ""
+
+     if isinstance(completion, (list, tuple)):
+         # Join the elements if a list/tuple comes back
+         response_text = " ".join(str(part) for part in completion)
+     elif isinstance(completion, str):
+         # Use a plain string directly
+         response_text = completion
+     else:
+         # ChatResponse (and anything else) stringifies to "assistant: ..."
+         response_text = str(completion)
+
+     # Relabel the assistant role prefix for display
+     response_text = response_text.replace("assistant:", "Final Response:").strip()
+
+     return response_text
+
+
+ # Streamlit UI
+ st.title("Chat with this Rerank RAG App")
+ question = st.text_input("Enter a relevant question to chat with the attached PhilDataset PDF file:")
+
+ if st.button("Submit"):
+     if question:
+         st.write("**RAG Response:**")
+         response = query_model_with_context(question)
+         st.write(response)
+     else:
+         st.warning("Please enter a question.")
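
Note on running this file: the script performs the parse, embed, and index steps at module import time, so launching it with streamlit run app.py (with the NVIDIA_KEY and PARSE_KEY environment variables exported and PhilDataset.pdf beside app.py) rebuilds the index on every start. A quick way to sanity-check the pipeline outside the UI is sketched below under the same assumptions; the smoke-test file name and the sample question are illustrative and not part of this commit.

# smoke_test.py - hypothetical helper, not part of this commit.
# Importing app executes the parse/embed/index steps at import time
# (Streamlit calls run in bare mode with warnings outside `streamlit run`),
# after which the query helper can be called directly.
import app

# Sample question is illustrative only
print(app.query_model_with_context("What does the PhilDataset cover?"))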