Spaces:
Running
Running
hf space
Browse files- .github/workflows/hf_spaces.yaml +20 -0
- Dockerfile +21 -0
- demo/app.py +0 -0
- demo/constants.py +12 -0
- demo/pickers.py +80 -0
- demo/requirements.txt +5 -0
- demo/streamlit_app.py +26 -0
- demo/utils.py +94 -0
.github/workflows/hf_spaces.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync to Hugging Face hub
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [main]
|
5 |
+
|
6 |
+
# to run this workflow manually from the Actions tab
|
7 |
+
workflow_dispatch:
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
sync-to-hub:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
fetch-depth: 0
|
16 |
+
lfs: true
|
17 |
+
- name: Push to hub
|
18 |
+
env:
|
19 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
20 |
+
run: git push https://HF_USERNAME:[email protected]/spaces/mozilla-ai/surf-spot-finder main
|
Dockerfile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.12-slim
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
RUN apt-get update && apt-get install -y \
|
6 |
+
build-essential \
|
7 |
+
curl \
|
8 |
+
software-properties-common \
|
9 |
+
git \
|
10 |
+
&& rm -rf /var/lib/apt/lists/*
|
11 |
+
|
12 |
+
COPY demo/requirements.txt ./
|
13 |
+
COPY demo/ ./demo/
|
14 |
+
|
15 |
+
RUN pip3 install -r requirements.txt
|
16 |
+
|
17 |
+
EXPOSE 8501
|
18 |
+
|
19 |
+
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
20 |
+
|
21 |
+
ENTRYPOINT ["streamlit", "run", "demo/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
demo/app.py
DELETED
File without changes
|
demo/constants.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL_OPTIONS = [
|
2 |
+
# "huggingface/novita/deepseek-ai/DeepSeek-V3",
|
3 |
+
# "huggingface/novita/meta-llama/Llama-3.3-70B-Instruct",
|
4 |
+
"gemini/gemini-2.0-flash-lite",
|
5 |
+
"gemini-2.0-flash",
|
6 |
+
# "huggingface/Qwen/Qwen3-32B", # right now throwing an internal error, but novita qwen isn't supporting tool calling
|
7 |
+
]
|
8 |
+
|
9 |
+
# Novita was the only HF based provider that worked.
|
10 |
+
|
11 |
+
# Hugginface API Provider Error:
|
12 |
+
# Must alternate between assistant/user, which meant that the 'tool' role made it puke
|
demo/pickers.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime, timedelta
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
+
import streamlit as st
|
5 |
+
from any_agent import AgentFramework
|
6 |
+
from any_agent.tracing.trace import _is_tracing_supported
|
7 |
+
|
8 |
+
from constants import MODEL_OPTIONS
|
9 |
+
|
10 |
+
|
11 |
+
@st.cache_resource
|
12 |
+
def get_area(area_name: str) -> dict:
|
13 |
+
"""Get the area from Nominatim.
|
14 |
+
|
15 |
+
Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
|
16 |
+
|
17 |
+
Args:
|
18 |
+
area_name (str): The name of the area.
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
dict: The area found.
|
22 |
+
"""
|
23 |
+
response = requests.get(
|
24 |
+
f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
|
25 |
+
headers={"User-Agent": "Mozilla/5.0"},
|
26 |
+
timeout=5,
|
27 |
+
)
|
28 |
+
response.raise_for_status()
|
29 |
+
response_json = json.loads(response.content.decode())
|
30 |
+
return response_json
|
31 |
+
|
32 |
+
|
33 |
+
def get_user_inputs() -> dict:
|
34 |
+
st.title("Surf Spot Finder")
|
35 |
+
st.write(
|
36 |
+
"This app finds the best surf spots in your area based on the weather forecast."
|
37 |
+
)
|
38 |
+
default_val = "Los Angeles California, US"
|
39 |
+
|
40 |
+
location = st.text_input("Enter a location", value=default_val)
|
41 |
+
if location:
|
42 |
+
location_check = get_area(location)
|
43 |
+
if not location_check:
|
44 |
+
st.error("Invalid location. Please enter a valid location.")
|
45 |
+
return None
|
46 |
+
else:
|
47 |
+
# display a checkmark saying that the location is valid
|
48 |
+
st.success("Valid location")
|
49 |
+
max_driving_hours = st.number_input(
|
50 |
+
"Enter the maximum driving hours", min_value=1, value=2
|
51 |
+
)
|
52 |
+
date = st.date_input(
|
53 |
+
"Select a date in the future", value=datetime.now() + timedelta(days=1)
|
54 |
+
)
|
55 |
+
|
56 |
+
supported_frameworks = [
|
57 |
+
framework for framework in AgentFramework if _is_tracing_supported(framework)
|
58 |
+
]
|
59 |
+
|
60 |
+
framework = st.selectbox(
|
61 |
+
"Select the agent framework to use",
|
62 |
+
supported_frameworks,
|
63 |
+
index=2,
|
64 |
+
format_func=lambda x: x.name,
|
65 |
+
)
|
66 |
+
|
67 |
+
model_id = st.selectbox(
|
68 |
+
"Select the model to use",
|
69 |
+
MODEL_OPTIONS,
|
70 |
+
index=0,
|
71 |
+
format_func=lambda x: "/".join(x.split("/")[-3:]),
|
72 |
+
)
|
73 |
+
|
74 |
+
return {
|
75 |
+
"location": location,
|
76 |
+
"max_driving_hours": max_driving_hours,
|
77 |
+
"date": date,
|
78 |
+
"framework": framework,
|
79 |
+
"model_id": model_id,
|
80 |
+
}
|
demo/requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
openai-agents>=0.0.14
|
3 |
+
any-agent @ git+https://github.com/mozilla-ai/any-agent
|
4 |
+
surf-spot-finder @ git+https://github.com/mozilla-ai/surf-spot-finder
|
5 |
+
nest_asyncio
|
demo/streamlit_app.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import asyncio
|
3 |
+
import nest_asyncio
|
4 |
+
from pickers import get_user_inputs
|
5 |
+
from utils import run_agent
|
6 |
+
|
7 |
+
|
8 |
+
nest_asyncio.apply()
|
9 |
+
|
10 |
+
|
11 |
+
with st.sidebar:
|
12 |
+
user_inputs = get_user_inputs()
|
13 |
+
is_valid = user_inputs is not None
|
14 |
+
run_button = st.button("Run", disabled=not is_valid)
|
15 |
+
|
16 |
+
|
17 |
+
async def main():
|
18 |
+
if run_button:
|
19 |
+
await run_agent(user_inputs)
|
20 |
+
else:
|
21 |
+
st.write("Click run to start!")
|
22 |
+
|
23 |
+
|
24 |
+
if __name__ == "__main__":
|
25 |
+
loop = asyncio.new_event_loop()
|
26 |
+
loop.run_until_complete(main())
|
demo/utils.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from surf_spot_finder.tools import (
|
3 |
+
driving_hours_to_meters,
|
4 |
+
get_area_lat_lon,
|
5 |
+
get_surfing_spots,
|
6 |
+
get_wave_forecast,
|
7 |
+
get_wind_forecast,
|
8 |
+
)
|
9 |
+
from surf_spot_finder.config import Config
|
10 |
+
from any_agent import AgentConfig, AnyAgent, TracingConfig
|
11 |
+
from any_agent.evaluation import evaluate, TraceEvaluationResult
|
12 |
+
|
13 |
+
|
14 |
+
async def run_agent(user_inputs):
|
15 |
+
st.write("Running surf spot finder...")
|
16 |
+
if "huggingface" in user_inputs["model_id"]:
|
17 |
+
model_args = {
|
18 |
+
"extra_headers": {"X-HF-Bill-To": "mozilla-ai"},
|
19 |
+
}
|
20 |
+
else:
|
21 |
+
model_args = {}
|
22 |
+
agent_config = AgentConfig(
|
23 |
+
model_id=user_inputs["model_id"],
|
24 |
+
model_args=model_args,
|
25 |
+
tools=[
|
26 |
+
get_wind_forecast,
|
27 |
+
get_wave_forecast,
|
28 |
+
get_area_lat_lon,
|
29 |
+
get_surfing_spots,
|
30 |
+
driving_hours_to_meters,
|
31 |
+
],
|
32 |
+
)
|
33 |
+
config = Config(
|
34 |
+
location=user_inputs["location"],
|
35 |
+
max_driving_hours=user_inputs["max_driving_hours"],
|
36 |
+
date=user_inputs["date"],
|
37 |
+
framework=user_inputs["framework"],
|
38 |
+
main_agent=agent_config,
|
39 |
+
managed_agents=[],
|
40 |
+
evaluation_cases=None,
|
41 |
+
)
|
42 |
+
|
43 |
+
agent = await AnyAgent.create_async(
|
44 |
+
agent_framework=config.framework,
|
45 |
+
agent_config=config.main_agent,
|
46 |
+
managed_agents=config.managed_agents,
|
47 |
+
tracing=TracingConfig(console=True, cost_info=True),
|
48 |
+
)
|
49 |
+
|
50 |
+
query = config.input_prompt_template.format(
|
51 |
+
LOCATION=config.location,
|
52 |
+
MAX_DRIVING_HOURS=config.max_driving_hours,
|
53 |
+
DATE=config.date,
|
54 |
+
)
|
55 |
+
st.write("Running agent with query:\n", query)
|
56 |
+
|
57 |
+
with st.spinner("Running..."):
|
58 |
+
agent_trace = await agent.run_async(query)
|
59 |
+
agent.exit()
|
60 |
+
|
61 |
+
st.write("Final output from agent:\n", agent_trace.final_output)
|
62 |
+
|
63 |
+
# Display the agent trace
|
64 |
+
with st.expander("Agent Trace", expanded=True):
|
65 |
+
st.write(agent_trace.spans)
|
66 |
+
|
67 |
+
if config.evaluation_cases is not None:
|
68 |
+
results = []
|
69 |
+
st.write("Found evaluation cases, running trace evaluation")
|
70 |
+
for i, case in enumerate(config.evaluation_cases):
|
71 |
+
st.write("Evaluating case: ", case)
|
72 |
+
result: TraceEvaluationResult = evaluate(
|
73 |
+
evaluation_case=case,
|
74 |
+
trace=agent_trace,
|
75 |
+
agent_framework=config.framework,
|
76 |
+
)
|
77 |
+
for list_of_checkpoints in [
|
78 |
+
result.checkpoint_results,
|
79 |
+
result.direct_results,
|
80 |
+
result.hypothesis_answer_results,
|
81 |
+
]:
|
82 |
+
for checkpoint in list_of_checkpoints:
|
83 |
+
msg = (
|
84 |
+
f"Checkpoint: {checkpoint.criteria}\n"
|
85 |
+
f"\tPassed: {checkpoint.passed}\n"
|
86 |
+
f"\tReason: {checkpoint.reason}\n"
|
87 |
+
f"\tScore: {'%d/%d' % (checkpoint.points, checkpoint.points) if checkpoint.passed else '0/%d' % checkpoint.points}"
|
88 |
+
)
|
89 |
+
st.write(msg)
|
90 |
+
st.write("==========================")
|
91 |
+
st.write("Overall Score: %d%%", 100 * result.score)
|
92 |
+
st.write("==========================")
|
93 |
+
results.append(result)
|
94 |
+
st.write("Surf spot finder finished running.")
|