import os
import gc
import tempfile
import uuid
import pandas as pd

from gitingest import ingest
from llama_index.core import Settings, PromptTemplate, VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.llms.sambanovasystems import SambaNovaCloud
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding

import streamlit as st
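
# This Streamlit app ingests a GitHub repository with gitingest, indexes the
# repository dump with LlamaIndex (Mixedbread embeddings + a SambaNova-hosted
# DeepSeek model), and serves a chat UI over it. The API keys below are read
# from environment variables (e.g. Hugging Face Spaces secrets). Assuming this
# file is saved as app.py, it would typically be launched with `streamlit run app.py`.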

SAMBANOVA_API_KEY = os.getenv("SAMBANOVA_API_KEY")
MXBAI_API_KEY = os.getenv("MXBAI_API_KEY")

if not SAMBANOVA_API_KEY:
    raise ValueError("SAMBANOVA_API_KEY is not set in the Hugging Face secrets.")
if not MXBAI_API_KEY:
    raise ValueError("MXBAI_API_KEY is not set in the Hugging Face secrets.")

if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}

session_id = st.session_state.id
client = None


@st.cache_resource
def load_llm():
    llm = SambaNovaCloud(
        model="DeepSeek-R1-Distill-Llama-70B",
        context_window=100000,
        max_tokens=1024,
        temperature=0.7,
        top_k=1,
        top_p=0.01,
    )
    return llm
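
# st.cache_resource keeps a single LLM client alive across Streamlit reruns and
# sessions. With top_k=1 and top_p=0.01, decoding is effectively greedy, so the
# temperature setting has little practical effect here.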


def reset_chat():
    st.session_state.messages = []
    st.session_state.context = None
    gc.collect()


def process_with_gitingest(github_url):
    summary, tree, content = ingest(github_url)
    return summary, tree, content
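
# gitingest's ingest() returns a short summary of the repository, a text
# rendering of its directory tree, and the concatenated file contents; the
# content string is what gets written to disk and indexed below.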


with st.sidebar:
    st.header("Add your GitHub repository!")

    github_url = st.text_input("Enter GitHub repository URL", placeholder="GitHub URL")
    load_repo = st.button("Load Repository")

    if github_url and load_repo:
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                st.write("Processing your repository...")
                repo_name = github_url.split('/')[-1]
                file_key = f"{session_id}-{repo_name}"

                if file_key not in st.session_state.get('file_cache', {}):

                    if os.path.exists(temp_dir):
                        summary, tree, content = process_with_gitingest(github_url)

                        # Write the repository dump into the temp directory so that
                        # SimpleDirectoryReader can load it as a single document.
                        content_path = os.path.join(temp_dir, f"{repo_name}_content.md")
                        with open(content_path, "w", encoding="utf-8") as f:
                            f.write(content)
                        loader = SimpleDirectoryReader(
                            input_dir=temp_dir,
                        )
                    else:
                        st.error('Could not create a temporary directory, please try again...')
                        st.stop()

                    docs = loader.load_data()

                    llm = load_llm()

                    embed_model = MixedbreadAIEmbedding(
                        api_key=MXBAI_API_KEY,
                        model_name="mixedbread-ai/mxbai-embed-large-v1",
                    )

                    Settings.embed_model = embed_model
                    node_parser = MarkdownNodeParser()
                    index = VectorStoreIndex.from_documents(
                        documents=docs,
                        transformations=[node_parser],
                        show_progress=True,
                    )

                    Settings.llm = llm
                    query_engine = index.as_query_engine(streaming=True)
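                    # MarkdownNodeParser chunks the markdown dump along its heading
                    # structure before embedding, and streaming=True makes query()
                    # return a streaming response that the chat loop below renders
                    # chunk by chunk.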

                    qa_prompt_tmpl_str = """
                    You are an AI assistant specialized in analyzing GitHub repositories.

                    Repository structure:
                    {tree}
                    ---------------------

                    Context information from the repository:
                    {context_str}
                    ---------------------

                    Given the repository structure and context above, provide a clear and precise answer to the query.
                    Focus on the repository's content, code structure, and implementation details.
                    If the information is not available in the context, respond with 'I don't have enough information about that aspect of the repository.'

                    Query: {query_str}
                    Answer: """
                    # The query engine only supplies {context_str} and {query_str} at
                    # query time, so the repository tree is bound to the template here.
                    qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str).partial_format(tree=tree)

                    # "response_synthesizer:text_qa_template" is the prompt used when
                    # synthesizing the final answer from the retrieved chunks.
                    query_engine.update_prompts(
                        {"response_synthesizer:text_qa_template": qa_prompt_tmpl}
                    )

                    st.session_state.file_cache[file_key] = query_engine
                else:
                    query_engine = st.session_state.file_cache[file_key]

                st.success("Ready to Chat!")
        except Exception as e:
            st.error(f"An error occurred: {e}")
            st.stop()


col1, col2 = st.columns([6, 1])

with col1:
    st.header("Chat with GitHub using RAG </>")

with col2:
    st.button("Clear ↺", on_click=reset_chat)


if "messages" not in st.session_state:
    reset_chat()


# Replay the chat history on each Streamlit rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


if prompt := st.chat_input("What's up?"):

    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        try:
            repo_name = github_url.split('/')[-1]
            file_key = f"{session_id}-{repo_name}"

            query_engine = st.session_state.file_cache.get(file_key)

            if query_engine is None:
                st.error("Please load a repository first!")
                st.stop()

            response = query_engine.query(prompt)
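
            # Streaming responses expose a response_gen iterator of text chunks;
            # anything else is rendered in one shot as a plain string.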
            if hasattr(response, 'response_gen'):
                for chunk in response.response_gen:
                    if isinstance(chunk, str):
                        full_response += chunk
                        message_placeholder.markdown(full_response + "▌")
            else:
                full_response = str(response)
                message_placeholder.markdown(full_response)

            # Final render without the typing cursor
            message_placeholder.markdown(full_response)
        except Exception as e:
            st.error(f"An error occurred while processing your query: {str(e)}")
            full_response = "Sorry, I encountered an error while processing your request."
            message_placeholder.markdown(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})
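
# The imports above suggest roughly these pip packages (names assumed from the
# standard LlamaIndex integration naming; verify against your environment):
#   streamlit, gitingest, pandas, llama-index-core,
#   llama-index-llms-sambanovasystems, llama-index-embeddings-mixedbreadai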