njbrake commited on
Commit
6fdc19a
ยท
1 Parent(s): 79c3db3
.github/workflows/hf_spaces.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Hugging Face hub
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync-to-hub:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+ with:
15
+ fetch-depth: 0
16
+ lfs: true
17
+ - name: Push to hub
18
+ env:
19
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: git push https://HF_USERNAME:[email protected]/spaces/mozilla-ai/surf-spot-finder main
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ curl \
8
+ software-properties-common \
9
+ git \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ COPY demo/requirements.txt ./
13
+ COPY demo/ ./demo/
14
+
15
+ RUN pip3 install -r requirements.txt
16
+
17
+ EXPOSE 8501
18
+
19
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
+
21
+ ENTRYPOINT ["streamlit", "run", "demo/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
demo/app.py DELETED
File without changes
demo/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL_OPTIONS = [
2
+ # "huggingface/novita/deepseek-ai/DeepSeek-V3",
3
+ # "huggingface/novita/meta-llama/Llama-3.3-70B-Instruct",
4
+ "gemini/gemini-2.0-flash-lite",
5
+ "gemini-2.0-flash",
6
+ # "huggingface/Qwen/Qwen3-32B", # right now throwing an internal error, but novita qwen isn't supporting tool calling
7
+ ]
8
+
9
+ # Novita was the only HF based provider that worked.
10
+
11
+ # Hugginface API Provider Error:
12
+ # Must alternate between assistant/user, which meant that the 'tool' role made it puke
demo/pickers.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ import json
3
+ import requests
4
+ import streamlit as st
5
+ from any_agent import AgentFramework
6
+ from any_agent.tracing.trace import _is_tracing_supported
7
+
8
+ from constants import MODEL_OPTIONS
9
+
10
+
11
+ @st.cache_resource
12
+ def get_area(area_name: str) -> dict:
13
+ """Get the area from Nominatim.
14
+
15
+ Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
16
+
17
+ Args:
18
+ area_name (str): The name of the area.
19
+
20
+ Returns:
21
+ dict: The area found.
22
+ """
23
+ response = requests.get(
24
+ f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
25
+ headers={"User-Agent": "Mozilla/5.0"},
26
+ timeout=5,
27
+ )
28
+ response.raise_for_status()
29
+ response_json = json.loads(response.content.decode())
30
+ return response_json
31
+
32
+
33
+ def get_user_inputs() -> dict:
34
+ st.title("Surf Spot Finder")
35
+ st.write(
36
+ "This app finds the best surf spots in your area based on the weather forecast."
37
+ )
38
+ default_val = "Los Angeles California, US"
39
+
40
+ location = st.text_input("Enter a location", value=default_val)
41
+ if location:
42
+ location_check = get_area(location)
43
+ if not location_check:
44
+ st.error("Invalid location. Please enter a valid location.")
45
+ return None
46
+ else:
47
+ # display a checkmark saying that the location is valid
48
+ st.success("Valid location")
49
+ max_driving_hours = st.number_input(
50
+ "Enter the maximum driving hours", min_value=1, value=2
51
+ )
52
+ date = st.date_input(
53
+ "Select a date in the future", value=datetime.now() + timedelta(days=1)
54
+ )
55
+
56
+ supported_frameworks = [
57
+ framework for framework in AgentFramework if _is_tracing_supported(framework)
58
+ ]
59
+
60
+ framework = st.selectbox(
61
+ "Select the agent framework to use",
62
+ supported_frameworks,
63
+ index=2,
64
+ format_func=lambda x: x.name,
65
+ )
66
+
67
+ model_id = st.selectbox(
68
+ "Select the model to use",
69
+ MODEL_OPTIONS,
70
+ index=0,
71
+ format_func=lambda x: "/".join(x.split("/")[-3:]),
72
+ )
73
+
74
+ return {
75
+ "location": location,
76
+ "max_driving_hours": max_driving_hours,
77
+ "date": date,
78
+ "framework": framework,
79
+ "model_id": model_id,
80
+ }
demo/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ openai-agents>=0.0.14
3
+ any-agent @ git+https://github.com/mozilla-ai/any-agent
4
+ surf-spot-finder @ git+https://github.com/mozilla-ai/surf-spot-finder
5
+ nest_asyncio
demo/streamlit_app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import asyncio
3
+ import nest_asyncio
4
+ from pickers import get_user_inputs
5
+ from utils import run_agent
6
+
7
+
8
+ nest_asyncio.apply()
9
+
10
+
11
+ with st.sidebar:
12
+ user_inputs = get_user_inputs()
13
+ is_valid = user_inputs is not None
14
+ run_button = st.button("Run", disabled=not is_valid)
15
+
16
+
17
+ async def main():
18
+ if run_button:
19
+ await run_agent(user_inputs)
20
+ else:
21
+ st.write("Click run to start!")
22
+
23
+
24
+ if __name__ == "__main__":
25
+ loop = asyncio.new_event_loop()
26
+ loop.run_until_complete(main())
demo/utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from surf_spot_finder.tools import (
3
+ driving_hours_to_meters,
4
+ get_area_lat_lon,
5
+ get_surfing_spots,
6
+ get_wave_forecast,
7
+ get_wind_forecast,
8
+ )
9
+ from surf_spot_finder.config import Config
10
+ from any_agent import AgentConfig, AnyAgent, TracingConfig
11
+ from any_agent.evaluation import evaluate, TraceEvaluationResult
12
+
13
+
14
+ async def run_agent(user_inputs):
15
+ st.write("Running surf spot finder...")
16
+ if "huggingface" in user_inputs["model_id"]:
17
+ model_args = {
18
+ "extra_headers": {"X-HF-Bill-To": "mozilla-ai"},
19
+ }
20
+ else:
21
+ model_args = {}
22
+ agent_config = AgentConfig(
23
+ model_id=user_inputs["model_id"],
24
+ model_args=model_args,
25
+ tools=[
26
+ get_wind_forecast,
27
+ get_wave_forecast,
28
+ get_area_lat_lon,
29
+ get_surfing_spots,
30
+ driving_hours_to_meters,
31
+ ],
32
+ )
33
+ config = Config(
34
+ location=user_inputs["location"],
35
+ max_driving_hours=user_inputs["max_driving_hours"],
36
+ date=user_inputs["date"],
37
+ framework=user_inputs["framework"],
38
+ main_agent=agent_config,
39
+ managed_agents=[],
40
+ evaluation_cases=None,
41
+ )
42
+
43
+ agent = await AnyAgent.create_async(
44
+ agent_framework=config.framework,
45
+ agent_config=config.main_agent,
46
+ managed_agents=config.managed_agents,
47
+ tracing=TracingConfig(console=True, cost_info=True),
48
+ )
49
+
50
+ query = config.input_prompt_template.format(
51
+ LOCATION=config.location,
52
+ MAX_DRIVING_HOURS=config.max_driving_hours,
53
+ DATE=config.date,
54
+ )
55
+ st.write("Running agent with query:\n", query)
56
+
57
+ with st.spinner("Running..."):
58
+ agent_trace = await agent.run_async(query)
59
+ agent.exit()
60
+
61
+ st.write("Final output from agent:\n", agent_trace.final_output)
62
+
63
+ # Display the agent trace
64
+ with st.expander("Agent Trace", expanded=True):
65
+ st.write(agent_trace.spans)
66
+
67
+ if config.evaluation_cases is not None:
68
+ results = []
69
+ st.write("Found evaluation cases, running trace evaluation")
70
+ for i, case in enumerate(config.evaluation_cases):
71
+ st.write("Evaluating case: ", case)
72
+ result: TraceEvaluationResult = evaluate(
73
+ evaluation_case=case,
74
+ trace=agent_trace,
75
+ agent_framework=config.framework,
76
+ )
77
+ for list_of_checkpoints in [
78
+ result.checkpoint_results,
79
+ result.direct_results,
80
+ result.hypothesis_answer_results,
81
+ ]:
82
+ for checkpoint in list_of_checkpoints:
83
+ msg = (
84
+ f"Checkpoint: {checkpoint.criteria}\n"
85
+ f"\tPassed: {checkpoint.passed}\n"
86
+ f"\tReason: {checkpoint.reason}\n"
87
+ f"\tScore: {'%d/%d' % (checkpoint.points, checkpoint.points) if checkpoint.passed else '0/%d' % checkpoint.points}"
88
+ )
89
+ st.write(msg)
90
+ st.write("==========================")
91
+ st.write("Overall Score: %d%%", 100 * result.score)
92
+ st.write("==========================")
93
+ results.append(result)
94
+ st.write("Surf spot finder finished running.")