import os
import re
import operator
import tempfile
import traceback
from pathlib import Path
from typing import List, Annotated, Any

import chromadb
import cohere
import streamlit as st
from pydantic import BaseModel
from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langgraph.graph import StateGraph, START, END, add_messages
from langgraph.constants import Send
from langgraph.checkpoint.memory import MemorySaver

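# Clear Chroma's shared client cache so Streamlit reruns don't reuse a stale client.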
chromadb.api.client.SharedSystemClient.clear_system_cache()

COHERE_API_KEY = os.environ["COHERE_API_KEY"]
co = cohere.Client(COHERE_API_KEY)

# Keep the Chroma index in a throwaway temp directory scoped to this session.
persist_dir = tempfile.mkdtemp()

def prepare_vectorstore(uploaded_files=None):
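    """Load uploaded .txt/.pdf files, split them into chunks, embed them with
    Cohere, and return a session-scoped Chroma vector store."""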
    documents = []

    if uploaded_files and any(file.size > 0 for file in uploaded_files):
        st.write("πŸ“ Uploaded files:")
        for file in uploaded_files:
            st.write(f"β€’ {file.name} ({file.size} bytes)")
            file_path = Path(tempfile.gettempdir()) / file.name
            try:
                with open(file_path, "wb") as f:
                    f.write(file.getbuffer())
                st.write(f"βœ… Saved to: {file_path}")

                if file.name.endswith(".pdf"):
                    st.write(f"πŸ“„ Loading PDF: {file.name}")
                    loader = PyPDFLoader(str(file_path))
                elif file.name.endswith(".txt"):
                    st.write(f"πŸ“ƒ Loading TXT: {file.name}")
                    loader = TextLoader(str(file_path))
                else:
                    st.warning(f"Unsupported file type: {file.name}")
                    continue

                loaded = loader.load()
                st.write(f"Loaded {len(loaded)} pages from {file.name}")
                documents.extend(loaded)

            except Exception as e:
                st.error(f"Error loading {file.name}:")
                st.exception(e)
                st.text(traceback.format_exc())

    else:
        st.warning("No uploaded files found or all were empty.")
        st.stop()

    if not documents:
        st.error("No content could be loaded from the uploaded files.")
        st.stop()

    st.write("Splitting documents into chunks...")
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    docs = splitter.split_documents(documents)
    st.write(f"Total chunks created: {len(docs)}")

    if not docs:
        st.error("No content found in the documents after splitting.")
        st.stop()

    st.write("Embedding documents...")
    embedding = CohereEmbeddings(
        model="embed-multilingual-light-v3.0",
        cohere_api_key=COHERE_API_KEY,
        user_agent="langgraph-app"
    )

    try:
        vectorstore = Chroma.from_documents(
            documents=docs,  # pass the plain list; a tqdm wrapper can be exhausted on re-iteration
            embedding=embedding,
            persist_directory=persist_dir
        )
        # chromadb >= 0.4 persists automatically when persist_directory is set.
        st.success("Document embedding complete.")
        return vectorstore
    except Exception as e:
        st.error("Embedding failed:")
        st.exception(e)
        st.text(traceback.format_exc())
        st.stop()

class State(BaseModel):
    # Parent graph state. Fields annotated with a reducer (add_messages /
    # operator.add) accumulate node outputs instead of being overwritten.
    # Defaults keep the model valid when the graph is invoked with only
    # "topic" and "context".
    messages: Annotated[list[AnyMessage], add_messages] = []
    topic: List[str] = []
    context: List[str] = []
    sub_topic_list: List[str] = []
    sub_topics: Annotated[list[AnyMessage], add_messages] = []
    stories: Annotated[list[AnyMessage], add_messages] = []
    stories_lst: Annotated[list, operator.add] = []

class StoryState(BaseModel):
    # Subgraph state for one story: retrieve -> rerank -> generate.
    story_topic: str = ""
    retrieved_docs: List[Any] = []
    reranked_docs: List[str] = []
    stories: Annotated[list[AnyMessage], add_messages] = []
    stories_lst: Annotated[list, operator.add] = []

def extract_topics(messages):
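    """Pull bolded bullet items ("- **Topic**") out of the LLM's markdown replies."""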
    topics = []
    for message in messages:
        topics.extend(re.findall(r'- \*\*(.*?)\*\*', message.content))
    return topics

embedding_llm = CohereEmbeddings(
    model="embed-multilingual-light-v3.0",
    cohere_api_key=COHERE_API_KEY,
    user_agent="langgraph-app"
)

llm = ChatCohere(
    temperature=0.7,
    model="command-r-plus-08-2024",
    cohere_api_key=COHERE_API_KEY
)

beginner_topic_sys_msg = SystemMessage(content="Suppose you are a middle-school student who constantly wants to learn about new topics to get a good score in exams.")
middle_topic_sys_msg = SystemMessage(content="Suppose you are a college student who constantly wants to learn about new topics to get a good score in exams.")
advanced_topic_sys_msg = SystemMessage(content="Suppose you are a teacher who constantly wants to learn about new topics to teach your students.")

def retrieve_node(state):
    topic = state.story_topic
    query = f"information about {topic}"
    retriever = Chroma(persist_directory=persist_dir, embedding_function=embedding_llm).as_retriever(search_kwargs={"k": 20})
    docs = retriever.invoke(query)
    # Return only keys defined on StoryState; LangGraph rejects unknown channels.
    return {"retrieved_docs": docs}

def rerank_node(state):
    topic = state.story_topic
    query = f"Rerank documents based on how good they explain the topic {topic}"
    docs = state.retrieved_docs
    texts = [doc.page_content for doc in docs]
    rerank_results = co.rerank(query=query, documents=texts, top_n=5, model="rerank-v3.5")
    top_docs = [texts[result.index] for result in rerank_results.results]
    return {"reranked_docs": top_docs, "question": query}

def generate_story_node(state):
    context = "\n\n".join(state.reranked_docs)
    topic = state.story_topic
    system_message = """
    Suppose you're an amazing story writer and scientific thinker. 
    You've written hundreds of story books explaining scientific topics in a childlike manner.
    You add a subtle humor to your stories to make them more engaging.
    """
    prompt = f"""
    Use the following context to generate a simple engaging story that explains {topic} in such a way a middle schooler can understand it.

    Context:
    {context}

    Story:
    """
    response = llm.invoke([SystemMessage(system_message), HumanMessage(prompt)])
    return {"stories": response}

def beginner_topic(state: State):
    prompt = f"What are the beginner-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}?"
    response = llm.invoke([beginner_topic_sys_msg, HumanMessage(prompt)])
    return {"messages": [response], "sub_topics": [response]}

def middle_topic(state: State):
    covered = "\n".join(message.content for message in state.sub_topics)
    prompt = f"What are the middle-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}? Don't include the topics below:\n\n{covered}"
    response = llm.invoke([middle_topic_sys_msg, HumanMessage(prompt)])
    return {"messages": [response], "sub_topics": [response]}

def advanced_topic(state: State):
    covered = "\n".join(message.content for message in state.sub_topics)
    prompt = f"What are the advanced-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}? Don't include the topics below:\n\n{covered}"
    response = llm.invoke([advanced_topic_sys_msg, HumanMessage(prompt)])
    return {"messages": [response], "sub_topics": [response]}

def topic_extractor(state: State):
    return {"sub_topic_list": extract_topics(state.sub_topics)}

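# Fan-out: emit one Send per extracted sub-topic so story_generator runs once per topic.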
def dynamic_topic_edges(state: State):
    return [Send("story_generator", {"story_topic": topic}) for topic in state.sub_topic_list]

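# Story subgraph: retrieve context, rerank it, then generate one story.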
story_builder = StateGraph(StoryState)
story_builder.add_node("Retrieve", retrieve_node)
story_builder.add_node("Rerank", rerank_node)
story_builder.add_node("Generate", generate_story_node)
story_builder.add_edge(START, "Retrieve")
story_builder.add_edge("Retrieve", "Rerank")
story_builder.add_edge("Rerank", "Generate")
story_builder.add_edge("Generate", END)

story_graph = story_builder.compile()

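# Parent graph: beginner -> middle -> advanced topic brainstorming, topic
# extraction, then a fan-out into the story subgraph.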
main_builder = StateGraph(State)
main_builder.add_node("beginner_topic", beginner_topic)
main_builder.add_node("middle_topic", middle_topic)
main_builder.add_node("advanced_topic", advanced_topic)
main_builder.add_node("topic_extractor", topic_extractor)
main_builder.add_node("story_generator", story_graph)
main_builder.add_edge(START, "beginner_topic")
main_builder.add_edge("beginner_topic", "middle_topic")
main_builder.add_edge("middle_topic", "advanced_topic")
main_builder.add_edge("advanced_topic", "topic_extractor")
main_builder.add_conditional_edges("topic_extractor", dynamic_topic_edges, ["story_generator"])
main_builder.add_edge("story_generator", END)

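# Checkpoint in memory and pause after topic_extractor so the sub-topic list
# can be reviewed (or overridden) before any stories are generated.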
memory = MemorySaver()
react_graph = main_builder.compile(checkpointer=memory, interrupt_after=["topic_extractor"])

st.title("LangGraph Topic Story Generator")

uploaded_files = st.file_uploader(
    "Upload .txt or .pdf files",
    type=["txt", "pdf"],
    accept_multiple_files=True,
    key="file_uploader"
)

if uploaded_files:
    st.session_state["files"] = uploaded_files
    st.success(f"{len(uploaded_files)} file(s) uploaded:")
    for file in uploaded_files:
        st.write(f"β€’ {file.name} ({file.size} bytes)")
elif "files" in st.session_state:
    st.info("Using previously uploaded files:")
    for file in st.session_state["files"]:
        st.write(f"β€’ {file.name} ({file.size} bytes)")
else:
    st.info("No files uploaded yet.")

topic = st.text_input("Enter a topic", "Human Evolution")
context = st.text_input("Enter a context", "Science")

if st.button("Generate Stories"):
    uploaded = st.session_state.get("files")
    if not uploaded or all(file.size == 0 for file in uploaded):
        st.warning("You uploaded files, but they appear to be empty.")
        st.stop()

    try:
        prepare_vectorstore(uploaded)
        thread = {"configurable": {"thread_id": "1"}}
        react_graph.invoke({"topic": [topic], "context": [context]}, thread)
        react_graph.update_state(thread, {"sub_topic_list": ['Early Hominins', 'Fossil Evidence', "Darwin's Theory of Evolution"]})
        result = react_graph.invoke(None, thread, stream_mode="values")
        for story in result["stories"]:
            st.markdown(story.content)
    except Exception as e:
        st.error("Something went wrong during story generation.")
        st.exception(e)
        st.text(traceback.format_exc())