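# Streamlit app: upload .txt/.pdf documents, index them in a Chroma vector
# store with Cohere embeddings, then drive a LangGraph pipeline that derives
# beginner/middle/advanced sub-topics for a subject and writes a short
# explanatory story for each one via a retrieve -> rerank -> generate subgraph.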
import os
import re
import traceback
import operator
import tempfile
from pathlib import Path
from typing import List, Annotated, Any

import streamlit as st
import chromadb
import cohere
from pydantic import BaseModel

# Partner/community import paths; adjust if your installed langchain version
# still exposes these under the legacy `langchain.*` modules.
from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langgraph.graph import StateGraph, START, END, add_messages
from langgraph.constants import Send
from langgraph.checkpoint.memory import MemorySaver
# Reset any Chroma client cached from a previous Streamlit rerun so the same
# process can open a fresh store without stale-client conflicts.
chromadb.api.client.SharedSystemClient.clear_system_cache()

COHERE_API_KEY = os.environ["COHERE_API_KEY"]
co = cohere.Client(COHERE_API_KEY)

# Each session embeds into a fresh temporary directory.
persist_dir = tempfile.mkdtemp()
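# Build the vector store from the uploaded files: save each upload to a temp
# path, load it with the matching loader, split it into overlapping chunks,
# and embed the chunks into Chroma under `persist_dir`.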
def prepare_vectorstore(uploaded_files=None):
    documents = []
    if uploaded_files and any(file.size > 0 for file in uploaded_files):
        st.write("📂 Uploaded files:")
        for file in uploaded_files:
            st.write(f"• {file.name} ({file.size} bytes)")
            file_path = Path(tempfile.gettempdir()) / file.name
            try:
                with open(file_path, "wb") as f:
                    f.write(file.getbuffer())
                st.write(f"✅ Saved to: {file_path}")
                if file.name.endswith(".pdf"):
                    st.write(f"📄 Loading PDF: {file.name}")
                    loader = PyPDFLoader(str(file_path))
                elif file.name.endswith(".txt"):
                    st.write(f"📄 Loading TXT: {file.name}")
                    loader = TextLoader(str(file_path))
                else:
                    st.warning(f"Unsupported file type: {file.name}")
                    continue
                loaded = loader.load()
                st.write(f"Loaded {len(loaded)} pages from {file.name}")
                documents.extend(loaded)
            except Exception as e:
                st.error(f"Error loading {file.name}:")
                st.exception(e)
                st.text(traceback.format_exc())
    else:
        st.warning("No uploaded files found, or all of them were empty.")
        st.stop()
    if not documents:
        st.error("No content could be loaded from the uploaded files.")
        st.stop()

    st.write("Splitting documents into chunks...")
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    docs = splitter.split_documents(documents)
    st.write(f"Total chunks created: {len(docs)}")
    if not docs:
        st.error("No content found in the documents after splitting.")
        st.stop()

    st.write("Embedding documents...")
    embedding = CohereEmbeddings(
        model="embed-multilingual-light-v3.0",
        cohere_api_key=COHERE_API_KEY,
        user_agent="langgraph-app"
    )
    try:
        # Recent Chroma versions persist automatically when a
        # persist_directory is given, so no explicit .persist() is needed.
        vectorstore = Chroma.from_documents(
            documents=docs,
            embedding=embedding,
            persist_directory=persist_dir
        )
        st.success("Document embedding complete.")
        return vectorstore
    except Exception as e:
        st.error("Embedding failed:")
        st.exception(e)
        st.text(traceback.format_exc())
        st.stop()
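# --- Graph state schemas ---
# LangGraph merges each node's return dict into these models. Fields annotated
# with `add_messages` append messages instead of overwriting, and fields
# annotated with `operator.add` concatenate lists, so the parallel story
# branches spawned later can all accumulate into the same channels.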
class State(BaseModel):
    messages: Annotated[list[AnyMessage], add_messages] = []
    topic: List[str] = []
    context: List[str] = []
    sub_topic_list: List[str] = []
    sub_topics: Annotated[list[AnyMessage], add_messages] = []
    stories: Annotated[list[AnyMessage], add_messages] = []
    stories_lst: Annotated[list, operator.add] = []

class StoryState(BaseModel):
    retrieved_docs: List[Any] = []
    stories: Annotated[list[AnyMessage], add_messages] = []
    reranked_docs: List[str] = []
    story_topic: str = ""
    stories_lst: Annotated[list, operator.add] = []
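# The topic-generation prompts below tend to produce markdown bullet lists of
# the form "- **Topic name**"; this helper pulls the bolded names out of each
# accumulated message.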
def extract_topics(messages):
    topics = []
    for message in messages:
        topics.extend(re.findall(r'- \*\*(.*?)\*\*', message.content))
    return topics
embedding_llm = CohereEmbeddings(
    model="embed-multilingual-light-v3.0",
    cohere_api_key=COHERE_API_KEY,
    user_agent="langgraph-app"
)

llm = ChatCohere(
    temperature=0.7,
    model="command-r-plus-08-2024",
    cohere_api_key=COHERE_API_KEY
)
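# Persona system prompts: the same question is asked three times at increasing
# difficulty, with each persona framing what "level" of topics the model lists.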
beginner_topic_sys_msg = SystemMessage(content="Suppose you are a middle schooler who constantly wants to learn about new topics to get a good score in exams.")
middle_topic_sys_msg = SystemMessage(content="Suppose you are a college student who constantly wants to learn about new topics to get a good score in exams.")
advanced_topic_sys_msg = SystemMessage(content="Suppose you are a teacher who constantly wants to learn about new topics to teach your students.")
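# --- Story subgraph nodes: retrieve -> rerank -> generate ---
# Each spawned branch receives its own StoryState carrying a single sub-topic.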
def retrieve_node(state):
    topic = state.story_topic
    query = f"information about {topic}"
    retriever = Chroma(
        persist_directory=persist_dir,
        embedding_function=embedding_llm
    ).as_retriever(search_kwargs={"k": 20})
    docs = retriever.invoke(query)
    # Only return keys that exist on StoryState; unknown keys would make
    # LangGraph reject the update.
    return {"retrieved_docs": docs}
def rerank_node(state):
    topic = state.story_topic
    query = f"Rerank documents based on how well they explain the topic {topic}"
    docs = state.retrieved_docs
    texts = [doc.page_content for doc in docs]
    rerank_results = co.rerank(query=query, documents=texts, top_n=5, model="rerank-v3.5")
    top_docs = [texts[result.index] for result in rerank_results.results]
    return {"reranked_docs": top_docs}
def generate_story_node(state):
    context = "\n\n".join(state.reranked_docs)
    topic = state.story_topic
    system_message = """
    Suppose you're an amazing story writer and scientific thinker.
    You've written hundreds of story books explaining scientific topics in a childlike manner.
    You add subtle humor to your stories to make them more engaging.
    """
    prompt = f"""
    Use the following context to generate a simple, engaging story that explains {topic} in such a way that a middle schooler can understand it.

    Context:
    {context}

    Story:
    """
    response = llm.invoke([SystemMessage(content=system_message), HumanMessage(content=prompt)])
    return {"stories": response}
def beginner_topic(state: State):
    prompt = f"What are the beginner-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}?"
    response = llm.invoke([beginner_topic_sys_msg, HumanMessage(content=prompt)])
    return {"messages": response, "sub_topics": response}

def middle_topic(state: State):
    seen = "\n".join(m.content for m in state.sub_topics)
    prompt = f"What are the middle-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}? Don't include the topics below:\n\n{seen}"
    response = llm.invoke([middle_topic_sys_msg, HumanMessage(content=prompt)])
    return {"messages": response, "sub_topics": response}

def advanced_topic(state: State):
    seen = "\n".join(m.content for m in state.sub_topics)
    prompt = f"What are the advanced-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}? Don't include the topics below:\n\n{seen}"
    response = llm.invoke([advanced_topic_sys_msg, HumanMessage(content=prompt)])
    return {"messages": response, "sub_topics": response}
def topic_extractor(state: State):
    return {"sub_topic_list": extract_topics(state.sub_topics)}

def dynamic_topic_edges(state: State):
    return [Send("story_generator", {"story_topic": topic}) for topic in state.sub_topic_list]
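# --- Graph wiring ---
# The story subgraph (Retrieve -> Rerank -> Generate) is compiled on its own
# and mounted as a single node in the main graph. After topic_extractor, the
# conditional Send edges spawn one subgraph run per sub-topic; their `stories`
# channels merge back into the parent state via add_messages.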
story_builder = StateGraph(StoryState)
story_builder.add_node("Retrieve", retrieve_node)
story_builder.add_node("Rerank", rerank_node)
story_builder.add_node("Generate", generate_story_node)
story_builder.set_entry_point("Retrieve")
story_builder.add_edge("Retrieve", "Rerank")
story_builder.add_edge("Rerank", "Generate")
story_builder.set_finish_point("Generate")
story_graph = story_builder.compile()

main_builder = StateGraph(State)
main_builder.add_node("beginner_topic", beginner_topic)
main_builder.add_node("middle_topic", middle_topic)
main_builder.add_node("advanced_topic", advanced_topic)
main_builder.add_node("topic_extractor", topic_extractor)
main_builder.add_node("story_generator", story_graph)
main_builder.add_edge(START, "beginner_topic")
main_builder.add_edge("beginner_topic", "middle_topic")
main_builder.add_edge("middle_topic", "advanced_topic")
main_builder.add_edge("advanced_topic", "topic_extractor")
main_builder.add_conditional_edges("topic_extractor", dynamic_topic_edges, ["story_generator"])
main_builder.add_edge("story_generator", END)

memory = MemorySaver()
react_graph = main_builder.compile(checkpointer=memory, interrupt_after=["topic_extractor"])
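# `interrupt_after=["topic_extractor"]` pauses the run once sub-topics are
# extracted, so the caller can inspect or override `sub_topic_list` through
# the checkpointer before the story branches are spawned.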
st.title("LangGraph Topic Story Generator")

uploaded_files = st.file_uploader(
    "Upload .txt or .pdf files",
    type=["txt", "pdf"],
    accept_multiple_files=True,
    key="file_uploader"
)

if uploaded_files:
    st.session_state["files"] = uploaded_files
    st.success(f"{len(uploaded_files)} file(s) uploaded:")
    for file in uploaded_files:
        st.write(f"• {file.name} ({file.size} bytes)")
elif "files" in st.session_state:
    st.info("Using previously uploaded files:")
    for file in st.session_state["files"]:
        st.write(f"• {file.name} ({file.size} bytes)")
else:
    st.info("No files uploaded yet.")

topic = st.text_input("Enter a topic", "Human Evolution")
context = st.text_input("Enter a context", "Science")
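# Generating stories is a two-phase run: invoke until the graph interrupts
# after topic_extractor, optionally edit the checkpointed state, then resume
# the same thread with `invoke(None, ...)`.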
if st.button("Generate Stories"):
    uploaded = st.session_state.get("files")
    if not uploaded or all(file.size == 0 for file in uploaded):
        st.warning("No usable files found; please upload at least one non-empty file.")
        st.stop()
    try:
        prepare_vectorstore(uploaded)
        thread = {"configurable": {"thread_id": "1"}}
        # First pass: runs up to the interrupt after topic_extractor.
        react_graph.invoke({"topic": [topic], "context": [context]}, thread)
        # Demo shortcut: override the extracted sub-topics with a fixed list.
        # Drop this line to keep the topics the LLM actually extracted.
        react_graph.update_state(thread, {"sub_topic_list": ['Early Hominins', 'Fossil Evidence', "Darwin's Theory of Evolution"]})
        # Resume from the checkpoint; `None` means "continue, no new input".
        result = react_graph.invoke(None, thread, stream_mode="values")
        for story in result["stories"]:
            st.markdown(story.content)
    except Exception as e:
        st.error("Something went wrong during story generation.")
        st.exception(e)
        st.text(traceback.format_exc())