File size: 4,799 Bytes
382c900
75c7731
9f9ab63
1f0e494
9f9ab63
 
 
 
 
 
 
382c900
75c7731
 
3c85b1d
75c7731
9f9ab63
1f0e494
2a3e737
1f0e494
 
9f9ab63
1f0e494
9f9ab63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c85b1d
9f9ab63
 
 
 
 
3c85b1d
9f9ab63
 
 
 
 
 
 
 
 
3c85b1d
75c7731
 
 
 
 
 
 
 
9f9ab63
 
75c7731
9f9ab63
 
 
 
 
 
75c7731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f9ab63
75c7731
3c85b1d
9f9ab63
75c7731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import gradio as gr
import pandas as pd
from langchain_together import ChatTogether
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_experimental.tools import PythonAstREPLTool
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter

# Global variable to store QA history
qa_history = []

    
def load_model(api_key):
    """Build the Together.ai chat model (deterministic: temperature 0)."""
    llm = ChatTogether(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        api_key=api_key,
        temperature=0,
    )
    return llm

def create_chain(df, llm):
    """Build a LangChain runnable that answers questions about dataframe `df`.

    Two LLM passes:
      1. The model, bound to a Python REPL tool, writes pandas code for the
         question; the code is executed against `df`.
      2. The model is re-prompted with the tool call and its output supplied
         as chat history, producing a natural-language answer.

    Returns a runnable mapping {"question": str} to
    {"tool_output": <code result>, "response": str}.
    """
    # REPL with the dataframe pre-bound as `df`.
    # NOTE(review): PythonAstREPLTool executes model-generated code — only
    # appropriate for trusted deployments.
    tool = PythonAstREPLTool(locals={"df": df})
    # Force a tool call on the first pass so we always get runnable code.
    llm_with_tools = llm.bind_tools([tool], tool_choice=tool.name)
    # Pull just the code argument out of the first tool call.
    parser = JsonOutputKeyToolsParser(key_name=tool.name, first_tool_only=True)
    
    # System prompt is an f-string: the head of `df` is interpolated now, at
    # chain-construction time, so the model sees real column names/values.
    system = f"""You have access to a pandas dataframe `df`. Here is the output of `df.head().to_markdown()`:
    ```
    {df.head().to_markdown()}
    ```
    Given a user question, write the Python code to answer it. Don't assume you have access to any libraries other than built-in Python ones and pandas.
    Respond directly to the question once you have enough information to answer it."""
    
    # chat_history is optional: empty on the first pass, populated with the
    # AI tool call + tool result on the second pass.
    prompt = ChatPromptTemplate.from_messages([
        ("system", system),
        ("human", "{question}"),
        MessagesPlaceholder("chat_history", optional=True),
    ])

    def _get_chat_history(x):
        # Pair the model's tool-call message with a ToolMessage carrying the
        # execution result, keyed by the tool_call_id the model emitted.
        ai_msg = x["ai_msg"]
        tool_call_id = x["ai_msg"].additional_kwargs["tool_calls"][0]["id"]
        tool_msg = ToolMessage(tool_call_id=tool_call_id, content=str(x["tool_output"]))
        return [ai_msg, tool_msg]

    # Each .assign adds a key to the running dict; order matters because each
    # step reads keys produced by the previous ones.
    chain = (
        RunnablePassthrough.assign(ai_msg=prompt | llm_with_tools)
        .assign(tool_output=itemgetter("ai_msg") | parser | tool)
        .assign(chat_history=_get_chat_history)
        .assign(response=prompt | llm | StrOutputParser())
        .pick(["tool_output", "response"])
    )
    
    return chain


def update_qa_history():
    """Render the global QA history as a markdown table.

    Returns:
        str: markdown table with columns "CSV File", "Question", "Answer".
             When `qa_history` is empty this is just the header row —
             pd.DataFrame([], columns=...) already produces the empty table,
             so no separate empty-case branch is needed.
    """
    return pd.DataFrame(
        qa_history, columns=["CSV File", "Question", "Answer"]
    ).to_markdown()


def process_query(csv_file, api_key, query):
    """Run one question against an uploaded CSV and record it in history.

    Args:
        csv_file: Gradio file object (has a `.name` path) or None.
        api_key: Together.ai API key string.
        query: natural-language question about the CSV.

    Returns:
        (response_text, history_markdown) — always a 2-tuple so the Gradio
        outputs stay populated, even on validation/runtime errors.
    """
    # Explicit guards: without them a missing upload surfaces as the cryptic
    # "Error: 'NoneType' object has no attribute 'name'" from the except below.
    if csv_file is None:
        return "Please upload a CSV file", update_qa_history()
    if not api_key.strip():
        return "Please provide an API key", update_qa_history()
    if not query or not query.strip():
        return "Please enter a question", update_qa_history()
    
    try:
        df = pd.read_csv(csv_file.name)
        llm = load_model(api_key)
        chain = create_chain(df, llm)
        result = chain.invoke({"question": query})
        
        # Combine the human-readable answer with the raw tool output.
        response = f"Analysis Result:\n{result['response']}\n\nTechnical Details:\n{result['tool_output']}"
        
        # Extract just the filename without path
        filename = os.path.basename(csv_file.name)
        
        # Add to QA history
        qa_history.append([
            filename,  # Store only the filename
            query,
            result['response']  # Store just the human-readable response
        ])
        
        return response, update_qa_history()
    except Exception as e:
        # UI boundary: surface any failure (bad CSV, auth, model error) as
        # text rather than crashing the interface.
        return f"Error: {str(e)}", update_qa_history()

# Create Gradio interface.
# Layout: inputs (left) | result (right), with a full-width history table
# below; components are created in order inside the Blocks context.
with gr.Blocks(title="CSV Analysis Assistant") as iface:
    gr.Markdown("# CSV Analysis Assistant")
    gr.Markdown("Upload a CSV file and ask questions about it using natural language.")
    
    # Top section: Split into left (inputs) and right (result)
    with gr.Row():
        # Left column for inputs
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload CSV File")
            api_key = gr.Textbox(label="Together.ai API Key", type="password")
            query = gr.Textbox(label="Your Question")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        
        # Right column for result
        with gr.Column(scale=1):
            output = gr.Textbox(label="Result", lines=10)
    
    # Bottom section: Full width for history table
    with gr.Row():
        # Seeded with the (initially empty) history table at build time.
        history = gr.Markdown(value="### Question & Answer History\n" + update_qa_history())
    
    # Handle button events
    submit_btn.click(
        fn=process_query,
        inputs=[file_input, api_key, query],
        outputs=[output, history]
    )
    
    def clear_inputs():
        """Reset file/key/question/result; re-render history (kept, not cleared)."""
        # Return order must match the outputs list of clear_btn.click below.
        return [None, "", "", "", "### Question & Answer History\n" + update_qa_history()]
    
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[file_input, api_key, query, output, history]
    )
    
# For Hugging Face Spaces deployment
iface.launch()