Spaces:
Running
Running
Upload 229 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- 20240414161707_basejump-setup.sql +186 -0
- 20240414161947_basejump-accounts.sql +708 -0
- 20240414162100_basejump-invitations.sql +270 -0
- 20240414162131_basejump-billing.sql +236 -0
- 20250409211903_basejump-configure.sql +3 -0
- 20250409212058_initial.sql +189 -0
- 20250416133920_agentpress_schema.sql +382 -0
- 20250506000000_initial_setup.sql +85 -0
- 20250506000001_account_functions.sql +50 -0
- 20250506000002_project_functions.sql +105 -0
- 22dc0511fe69_add_toolsource_table.cpython-311.pyc +0 -0
- 2ea570019b8f_add_apikey_table.cpython-311.pyc +0 -0
- 2ea570019b8f_add_apikey_table.py +58 -0
- 4af13678b83c_add_toolsource_table.cpython-311.pyc +0 -0
- 4af13678b83c_add_toolsource_table.py +50 -0
- ActiveJobsProvider.py +57 -0
- AmazonProvider.py +191 -0
- ChatInterface.tsx +30 -0
- Dockerfile +19 -0
- Layout.tsx +41 -0
- LinkedinProvider.py +250 -0
- MANIFEST.in +17 -0
- README +1 -0
- README.md +36 -10
- RapidDataProviderBase.py +61 -0
- SettingsPanel.tsx +31 -0
- TwitterProvider.py +240 -0
- WorkflowEditor.tsx +52 -0
- YahooFinanceProvider.py +190 -0
- ZillowProvider.py +187 -0
- __init__.py +1 -0
- added_tokens.json +24 -0
- agent.py +41 -0
- alembic.ini +62 -0
- api.cpython-311.pyc +0 -0
- api.py +311 -0
- api.py.bak +156 -0
- api_keys.py +68 -0
- architecture_diagram.svg +0 -0
- auth_utils.py +177 -0
- base.py +33 -0
- billing.py +125 -0
- browser_api.py +2063 -0
- chat_template.json +3 -0
- cleanup.sh +34 -0
- compose-dev.yaml +12 -0
- computer_use_tool.py +624 -0
- config.cpython-311.pyc +0 -0
- config.json +495 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
diagram.png filter=lfs diff=lfs merge=lfs -text
|
20240414161707_basejump-setup.sql
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
|
2 |
+
____ _
|
3 |
+
| _ \ (_)
|
4 |
+
| |_) | __ _ ___ ___ _ _ _ _ __ ___ _ __
|
5 |
+
| _ < / _` / __|/ _ \ | | | | '_ ` _ \| '_ \
|
6 |
+
| |_) | (_| \__ \ __/ | |_| | | | | | | |_) |
|
7 |
+
|____/ \__,_|___/\___| |\__,_|_| |_| |_| .__/
|
8 |
+
_/ | | |
|
9 |
+
|__/ |_|
|
10 |
+
|
11 |
+
Basejump is a starter kit for building SaaS products on top of Supabase.
|
12 |
+
Learn more at https://usebasejump.com
|
13 |
+
*/
|
14 |
+
|
15 |
+
|
16 |
+
/**
|
17 |
+
* -------------------------------------------------------
|
18 |
+
* Section - Basejump schema setup and utility functions
|
19 |
+
* -------------------------------------------------------
|
20 |
+
*/
|
21 |
+
|
22 |
+
-- revoke execution by default from public
|
23 |
+
ALTER DEFAULT PRIVILEGES REVOKE EXECUTE ON FUNCTIONS FROM PUBLIC;
|
24 |
+
ALTER DEFAULT PRIVILEGES IN SCHEMA PUBLIC REVOKE EXECUTE ON FUNCTIONS FROM anon, authenticated;
|
25 |
+
|
26 |
+
-- Create basejump schema
|
27 |
+
CREATE SCHEMA IF NOT EXISTS basejump;
|
28 |
+
GRANT USAGE ON SCHEMA basejump to authenticated;
|
29 |
+
GRANT USAGE ON SCHEMA basejump to service_role;
|
30 |
+
|
31 |
+
/**
|
32 |
+
* -------------------------------------------------------
|
33 |
+
* Section - Enums
|
34 |
+
* -------------------------------------------------------
|
35 |
+
*/
|
36 |
+
|
37 |
+
/**
|
38 |
+
* Invitation types are either email or link. Email invitations are sent to
|
39 |
+
* a single user and can only be claimed once. Link invitations can be used multiple times
|
40 |
+
* Both expire after 24 hours
|
41 |
+
*/
|
42 |
+
DO
|
43 |
+
$$
|
44 |
+
BEGIN
|
45 |
+
-- check it account_role already exists on basejump schema
|
46 |
+
IF NOT EXISTS(SELECT 1
|
47 |
+
FROM pg_type t
|
48 |
+
JOIN pg_namespace n ON n.oid = t.typnamespace
|
49 |
+
WHERE t.typname = 'invitation_type'
|
50 |
+
AND n.nspname = 'basejump') THEN
|
51 |
+
CREATE TYPE basejump.invitation_type AS ENUM ('one_time', '24_hour');
|
52 |
+
end if;
|
53 |
+
end;
|
54 |
+
$$;
|
55 |
+
|
56 |
+
/**
|
57 |
+
* -------------------------------------------------------
|
58 |
+
* Section - Basejump settings
|
59 |
+
* -------------------------------------------------------
|
60 |
+
*/
|
61 |
+
|
62 |
+
CREATE TABLE IF NOT EXISTS basejump.config
|
63 |
+
(
|
64 |
+
enable_team_accounts boolean default true,
|
65 |
+
enable_personal_account_billing boolean default true,
|
66 |
+
enable_team_account_billing boolean default true,
|
67 |
+
billing_provider text default 'stripe'
|
68 |
+
);
|
69 |
+
|
70 |
+
-- create config row
|
71 |
+
INSERT INTO basejump.config (enable_team_accounts, enable_personal_account_billing, enable_team_account_billing)
|
72 |
+
VALUES (true, true, true);
|
73 |
+
|
74 |
+
-- enable select on the config table
|
75 |
+
GRANT SELECT ON basejump.config TO authenticated, service_role;
|
76 |
+
|
77 |
+
-- enable RLS on config
|
78 |
+
ALTER TABLE basejump.config
|
79 |
+
ENABLE ROW LEVEL SECURITY;
|
80 |
+
|
81 |
+
create policy "Basejump settings can be read by authenticated users" on basejump.config
|
82 |
+
for select
|
83 |
+
to authenticated
|
84 |
+
using (
|
85 |
+
true
|
86 |
+
);
|
87 |
+
|
88 |
+
/**
|
89 |
+
* -------------------------------------------------------
|
90 |
+
* Section - Basejump utility functions
|
91 |
+
* -------------------------------------------------------
|
92 |
+
*/
|
93 |
+
|
94 |
+
/**
|
95 |
+
basejump.get_config()
|
96 |
+
Get the full config object to check basejump settings
|
97 |
+
This is not accessible from the outside, so can only be used inside postgres functions
|
98 |
+
*/
|
99 |
+
CREATE OR REPLACE FUNCTION basejump.get_config()
|
100 |
+
RETURNS json AS
|
101 |
+
$$
|
102 |
+
DECLARE
|
103 |
+
result RECORD;
|
104 |
+
BEGIN
|
105 |
+
SELECT * from basejump.config limit 1 into result;
|
106 |
+
return row_to_json(result);
|
107 |
+
END;
|
108 |
+
$$ LANGUAGE plpgsql;
|
109 |
+
|
110 |
+
grant execute on function basejump.get_config() to authenticated, service_role;
|
111 |
+
|
112 |
+
|
113 |
+
/**
|
114 |
+
basejump.is_set("field_name")
|
115 |
+
Check a specific boolean config value
|
116 |
+
*/
|
117 |
+
CREATE OR REPLACE FUNCTION basejump.is_set(field_name text)
|
118 |
+
RETURNS boolean AS
|
119 |
+
$$
|
120 |
+
DECLARE
|
121 |
+
result BOOLEAN;
|
122 |
+
BEGIN
|
123 |
+
execute format('select %I from basejump.config limit 1', field_name) into result;
|
124 |
+
return result;
|
125 |
+
END;
|
126 |
+
$$ LANGUAGE plpgsql;
|
127 |
+
|
128 |
+
grant execute on function basejump.is_set(text) to authenticated;
|
129 |
+
|
130 |
+
|
131 |
+
/**
|
132 |
+
* Automatic handling for maintaining created_at and updated_at timestamps
|
133 |
+
* on tables
|
134 |
+
*/
|
135 |
+
CREATE OR REPLACE FUNCTION basejump.trigger_set_timestamps()
|
136 |
+
RETURNS TRIGGER AS
|
137 |
+
$$
|
138 |
+
BEGIN
|
139 |
+
if TG_OP = 'INSERT' then
|
140 |
+
NEW.created_at = now();
|
141 |
+
NEW.updated_at = now();
|
142 |
+
else
|
143 |
+
NEW.updated_at = now();
|
144 |
+
NEW.created_at = OLD.created_at;
|
145 |
+
end if;
|
146 |
+
RETURN NEW;
|
147 |
+
END
|
148 |
+
$$ LANGUAGE plpgsql;
|
149 |
+
|
150 |
+
|
151 |
+
/**
|
152 |
+
* Automatic handling for maintaining created_by and updated_by timestamps
|
153 |
+
* on tables
|
154 |
+
*/
|
155 |
+
CREATE OR REPLACE FUNCTION basejump.trigger_set_user_tracking()
|
156 |
+
RETURNS TRIGGER AS
|
157 |
+
$$
|
158 |
+
BEGIN
|
159 |
+
if TG_OP = 'INSERT' then
|
160 |
+
NEW.created_by = auth.uid();
|
161 |
+
NEW.updated_by = auth.uid();
|
162 |
+
else
|
163 |
+
NEW.updated_by = auth.uid();
|
164 |
+
NEW.created_by = OLD.created_by;
|
165 |
+
end if;
|
166 |
+
RETURN NEW;
|
167 |
+
END
|
168 |
+
$$ LANGUAGE plpgsql;
|
169 |
+
|
170 |
+
/**
|
171 |
+
basejump.generate_token(length)
|
172 |
+
Generates a secure token - used internally for invitation tokens
|
173 |
+
but could be used elsewhere. Check out the invitations table for more info on
|
174 |
+
how it's used
|
175 |
+
*/
|
176 |
+
CREATE OR REPLACE FUNCTION basejump.generate_token(length int)
|
177 |
+
RETURNS text AS
|
178 |
+
$$
|
179 |
+
select regexp_replace(replace(
|
180 |
+
replace(replace(replace(encode(gen_random_bytes(length)::bytea, 'base64'), '/', ''), '+',
|
181 |
+
''), '\', ''),
|
182 |
+
'=',
|
183 |
+
''), E'[\\n\\r]+', '', 'g');
|
184 |
+
$$ LANGUAGE sql;
|
185 |
+
|
186 |
+
grant execute on function basejump.generate_token(int) to authenticated;
|
20240414161947_basejump-accounts.sql
ADDED
@@ -0,0 +1,708 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
|
2 |
+
____ _
|
3 |
+
| _ \ (_)
|
4 |
+
| |_) | __ _ ___ ___ _ _ _ _ __ ___ _ __
|
5 |
+
| _ < / _` / __|/ _ \ | | | | '_ ` _ \| '_ \
|
6 |
+
| |_) | (_| \__ \ __/ | |_| | | | | | | |_) |
|
7 |
+
|____/ \__,_|___/\___| |\__,_|_| |_| |_| .__/
|
8 |
+
_/ | | |
|
9 |
+
|__/ |_|
|
10 |
+
|
11 |
+
Basejump is a starter kit for building SaaS products on top of Supabase.
|
12 |
+
Learn more at https://usebasejump.com
|
13 |
+
*/
|
14 |
+
|
15 |
+
/**
|
16 |
+
* -------------------------------------------------------
|
17 |
+
* Section - Accounts
|
18 |
+
* -------------------------------------------------------
|
19 |
+
*/
|
20 |
+
|
21 |
+
/**
|
22 |
+
* Account roles allow you to provide permission levels to users
|
23 |
+
* when they're acting on an account. By default, we provide
|
24 |
+
* "owner" and "member". The only distinction is that owners can
|
25 |
+
* also manage billing and invite/remove account members.
|
26 |
+
*/
|
27 |
+
DO
|
28 |
+
$$
|
29 |
+
BEGIN
|
30 |
+
-- check it account_role already exists on basejump schema
|
31 |
+
IF NOT EXISTS(SELECT 1
|
32 |
+
FROM pg_type t
|
33 |
+
JOIN pg_namespace n ON n.oid = t.typnamespace
|
34 |
+
WHERE t.typname = 'account_role'
|
35 |
+
AND n.nspname = 'basejump') THEN
|
36 |
+
CREATE TYPE basejump.account_role AS ENUM ('owner', 'member');
|
37 |
+
end if;
|
38 |
+
end;
|
39 |
+
$$;
|
40 |
+
|
41 |
+
/**
|
42 |
+
* Accounts are the primary grouping for most objects within
|
43 |
+
* the system. They have many users, and all billing is connected to
|
44 |
+
* an account.
|
45 |
+
*/
|
46 |
+
CREATE TABLE IF NOT EXISTS basejump.accounts
|
47 |
+
(
|
48 |
+
id uuid unique NOT NULL DEFAULT extensions.uuid_generate_v4(),
|
49 |
+
-- defaults to the user who creates the account
|
50 |
+
-- this user cannot be removed from an account without changing
|
51 |
+
-- the primary owner first
|
52 |
+
primary_owner_user_id uuid references auth.users not null default auth.uid(),
|
53 |
+
-- Account name
|
54 |
+
name text,
|
55 |
+
slug text unique,
|
56 |
+
personal_account boolean default false not null,
|
57 |
+
updated_at timestamp with time zone,
|
58 |
+
created_at timestamp with time zone,
|
59 |
+
created_by uuid references auth.users,
|
60 |
+
updated_by uuid references auth.users,
|
61 |
+
private_metadata jsonb default '{}'::jsonb,
|
62 |
+
public_metadata jsonb default '{}'::jsonb,
|
63 |
+
PRIMARY KEY (id)
|
64 |
+
);
|
65 |
+
|
66 |
+
-- constraint that conditionally allows nulls on the slug ONLY if personal_account is true
|
67 |
+
-- remove this if you want to ignore accounts slugs entirely
|
68 |
+
ALTER TABLE basejump.accounts
|
69 |
+
ADD CONSTRAINT basejump_accounts_slug_null_if_personal_account_true CHECK (
|
70 |
+
(personal_account = true AND slug is null)
|
71 |
+
OR (personal_account = false AND slug is not null)
|
72 |
+
);
|
73 |
+
|
74 |
+
-- Open up access to accounts
|
75 |
+
GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.accounts TO authenticated, service_role;
|
76 |
+
|
77 |
+
/**
|
78 |
+
* We want to protect some fields on accounts from being updated
|
79 |
+
* Specifically the primary owner user id and account id.
|
80 |
+
* primary_owner_user_id should be updated using the dedicated function
|
81 |
+
*/
|
82 |
+
CREATE OR REPLACE FUNCTION basejump.protect_account_fields()
|
83 |
+
RETURNS TRIGGER AS
|
84 |
+
$$
|
85 |
+
BEGIN
|
86 |
+
IF current_user IN ('authenticated', 'anon') THEN
|
87 |
+
-- these are protected fields that users are not allowed to update themselves
|
88 |
+
-- platform admins should be VERY careful about updating them as well.
|
89 |
+
if NEW.id <> OLD.id
|
90 |
+
OR NEW.personal_account <> OLD.personal_account
|
91 |
+
OR NEW.primary_owner_user_id <> OLD.primary_owner_user_id
|
92 |
+
THEN
|
93 |
+
RAISE EXCEPTION 'You do not have permission to update this field';
|
94 |
+
end if;
|
95 |
+
end if;
|
96 |
+
|
97 |
+
RETURN NEW;
|
98 |
+
END
|
99 |
+
$$ LANGUAGE plpgsql;
|
100 |
+
|
101 |
+
-- trigger to protect account fields
|
102 |
+
CREATE TRIGGER basejump_protect_account_fields
|
103 |
+
BEFORE UPDATE
|
104 |
+
ON basejump.accounts
|
105 |
+
FOR EACH ROW
|
106 |
+
EXECUTE FUNCTION basejump.protect_account_fields();
|
107 |
+
|
108 |
+
-- convert any character in the slug that's not a letter, number, or dash to a dash on insert/update for accounts
|
109 |
+
CREATE OR REPLACE FUNCTION basejump.slugify_account_slug()
|
110 |
+
RETURNS TRIGGER AS
|
111 |
+
$$
|
112 |
+
BEGIN
|
113 |
+
if NEW.slug is not null then
|
114 |
+
NEW.slug = lower(regexp_replace(NEW.slug, '[^a-zA-Z0-9-]+', '-', 'g'));
|
115 |
+
end if;
|
116 |
+
|
117 |
+
RETURN NEW;
|
118 |
+
END
|
119 |
+
$$ LANGUAGE plpgsql;
|
120 |
+
|
121 |
+
-- trigger to slugify the account slug
|
122 |
+
CREATE TRIGGER basejump_slugify_account_slug
|
123 |
+
BEFORE INSERT OR UPDATE
|
124 |
+
ON basejump.accounts
|
125 |
+
FOR EACH ROW
|
126 |
+
EXECUTE FUNCTION basejump.slugify_account_slug();
|
127 |
+
|
128 |
+
-- enable RLS for accounts
|
129 |
+
alter table basejump.accounts
|
130 |
+
enable row level security;
|
131 |
+
|
132 |
+
-- protect the timestamps
|
133 |
+
CREATE TRIGGER basejump_set_accounts_timestamp
|
134 |
+
BEFORE INSERT OR UPDATE
|
135 |
+
ON basejump.accounts
|
136 |
+
FOR EACH ROW
|
137 |
+
EXECUTE PROCEDURE basejump.trigger_set_timestamps();
|
138 |
+
|
139 |
+
-- set the user tracking
|
140 |
+
CREATE TRIGGER basejump_set_accounts_user_tracking
|
141 |
+
BEFORE INSERT OR UPDATE
|
142 |
+
ON basejump.accounts
|
143 |
+
FOR EACH ROW
|
144 |
+
EXECUTE PROCEDURE basejump.trigger_set_user_tracking();
|
145 |
+
|
146 |
+
/**
|
147 |
+
* Account users are the users that are associated with an account.
|
148 |
+
* They can be invited to join the account, and can have different roles.
|
149 |
+
* The system does not enforce any permissions for roles, other than restricting
|
150 |
+
* billing and account membership to only owners
|
151 |
+
*/
|
152 |
+
create table if not exists basejump.account_user
|
153 |
+
(
|
154 |
+
-- id of the user in the account
|
155 |
+
user_id uuid references auth.users on delete cascade not null,
|
156 |
+
-- id of the account the user is in
|
157 |
+
account_id uuid references basejump.accounts on delete cascade not null,
|
158 |
+
-- role of the user in the account
|
159 |
+
account_role basejump.account_role not null,
|
160 |
+
constraint account_user_pkey primary key (user_id, account_id)
|
161 |
+
);
|
162 |
+
|
163 |
+
GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.account_user TO authenticated, service_role;
|
164 |
+
|
165 |
+
|
166 |
+
-- enable RLS for account_user
|
167 |
+
alter table basejump.account_user
|
168 |
+
enable row level security;
|
169 |
+
|
170 |
+
/**
|
171 |
+
* When an account gets created, we want to insert the current user as the first
|
172 |
+
* owner
|
173 |
+
*/
|
174 |
+
create or replace function basejump.add_current_user_to_new_account()
|
175 |
+
returns trigger
|
176 |
+
language plpgsql
|
177 |
+
security definer
|
178 |
+
set search_path = public
|
179 |
+
as
|
180 |
+
$$
|
181 |
+
begin
|
182 |
+
if new.primary_owner_user_id = auth.uid() then
|
183 |
+
insert into basejump.account_user (account_id, user_id, account_role)
|
184 |
+
values (NEW.id, auth.uid(), 'owner');
|
185 |
+
end if;
|
186 |
+
return NEW;
|
187 |
+
end;
|
188 |
+
$$;
|
189 |
+
|
190 |
+
-- trigger the function whenever a new account is created
|
191 |
+
CREATE TRIGGER basejump_add_current_user_to_new_account
|
192 |
+
AFTER INSERT
|
193 |
+
ON basejump.accounts
|
194 |
+
FOR EACH ROW
|
195 |
+
EXECUTE FUNCTION basejump.add_current_user_to_new_account();
|
196 |
+
|
197 |
+
/**
|
198 |
+
* When a user signs up, we need to create a personal account for them
|
199 |
+
* and add them to the account_user table so they can act on it
|
200 |
+
*/
|
201 |
+
create or replace function basejump.run_new_user_setup()
|
202 |
+
returns trigger
|
203 |
+
language plpgsql
|
204 |
+
security definer
|
205 |
+
set search_path = public
|
206 |
+
as
|
207 |
+
$$
|
208 |
+
declare
|
209 |
+
first_account_id uuid;
|
210 |
+
generated_user_name text;
|
211 |
+
begin
|
212 |
+
|
213 |
+
-- first we setup the user profile
|
214 |
+
-- TODO: see if we can get the user's name from the auth.users table once we learn how oauth works
|
215 |
+
if new.email IS NOT NULL then
|
216 |
+
generated_user_name := split_part(new.email, '@', 1);
|
217 |
+
end if;
|
218 |
+
-- create the new users's personal account
|
219 |
+
insert into basejump.accounts (name, primary_owner_user_id, personal_account, id)
|
220 |
+
values (generated_user_name, NEW.id, true, NEW.id)
|
221 |
+
returning id into first_account_id;
|
222 |
+
|
223 |
+
-- add them to the account_user table so they can act on it
|
224 |
+
insert into basejump.account_user (account_id, user_id, account_role)
|
225 |
+
values (first_account_id, NEW.id, 'owner');
|
226 |
+
|
227 |
+
return NEW;
|
228 |
+
end;
|
229 |
+
$$;
|
230 |
+
|
231 |
+
-- trigger the function every time a user is created
|
232 |
+
create trigger on_auth_user_created
|
233 |
+
after insert
|
234 |
+
on auth.users
|
235 |
+
for each row
|
236 |
+
execute procedure basejump.run_new_user_setup();
|
237 |
+
|
238 |
+
/**
|
239 |
+
* -------------------------------------------------------
|
240 |
+
* Section - Account permission utility functions
|
241 |
+
* -------------------------------------------------------
|
242 |
+
* These functions are stored on the basejump schema, and useful for things like
|
243 |
+
* generating RLS policies
|
244 |
+
*/
|
245 |
+
|
246 |
+
/**
|
247 |
+
* Returns true if the current user has the pass in role on the passed in account
|
248 |
+
* If no role is sent, will return true if the user is a member of the account
|
249 |
+
* NOTE: This is an inefficient function when used on large query sets. You should reach for the get_accounts_with_role and lookup
|
250 |
+
* the account ID in those cases.
|
251 |
+
*/
|
252 |
+
create or replace function basejump.has_role_on_account(account_id uuid, account_role basejump.account_role default null)
|
253 |
+
returns boolean
|
254 |
+
language sql
|
255 |
+
security definer
|
256 |
+
set search_path = public
|
257 |
+
as
|
258 |
+
$$
|
259 |
+
select exists(
|
260 |
+
select 1
|
261 |
+
from basejump.account_user wu
|
262 |
+
where wu.user_id = auth.uid()
|
263 |
+
and wu.account_id = has_role_on_account.account_id
|
264 |
+
and (
|
265 |
+
wu.account_role = has_role_on_account.account_role
|
266 |
+
or has_role_on_account.account_role is null
|
267 |
+
)
|
268 |
+
);
|
269 |
+
$$;
|
270 |
+
|
271 |
+
grant execute on function basejump.has_role_on_account(uuid, basejump.account_role) to authenticated, anon, public, service_role;
|
272 |
+
|
273 |
+
|
274 |
+
/**
|
275 |
+
* Returns account_ids that the current user is a member of. If you pass in a role,
|
276 |
+
* it'll only return accounts that the user is a member of with that role.
|
277 |
+
*/
|
278 |
+
create or replace function basejump.get_accounts_with_role(passed_in_role basejump.account_role default null)
|
279 |
+
returns setof uuid
|
280 |
+
language sql
|
281 |
+
security definer
|
282 |
+
set search_path = public
|
283 |
+
as
|
284 |
+
$$
|
285 |
+
select account_id
|
286 |
+
from basejump.account_user wu
|
287 |
+
where wu.user_id = auth.uid()
|
288 |
+
and (
|
289 |
+
wu.account_role = passed_in_role
|
290 |
+
or passed_in_role is null
|
291 |
+
);
|
292 |
+
$$;
|
293 |
+
|
294 |
+
grant execute on function basejump.get_accounts_with_role(basejump.account_role) to authenticated;
|
295 |
+
|
296 |
+
/**
|
297 |
+
* -------------------------
|
298 |
+
* Section - RLS Policies
|
299 |
+
* -------------------------
|
300 |
+
* This is where we define access to tables in the basejump schema
|
301 |
+
*/
|
302 |
+
|
303 |
+
create policy "users can view their own account_users" on basejump.account_user
|
304 |
+
for select
|
305 |
+
to authenticated
|
306 |
+
using (
|
307 |
+
user_id = auth.uid()
|
308 |
+
);
|
309 |
+
|
310 |
+
create policy "users can view their teammates" on basejump.account_user
|
311 |
+
for select
|
312 |
+
to authenticated
|
313 |
+
using (
|
314 |
+
basejump.has_role_on_account(account_id) = true
|
315 |
+
);
|
316 |
+
|
317 |
+
create policy "Account users can be deleted by owners except primary account owner" on basejump.account_user
|
318 |
+
for delete
|
319 |
+
to authenticated
|
320 |
+
using (
|
321 |
+
(basejump.has_role_on_account(account_id, 'owner') = true)
|
322 |
+
AND
|
323 |
+
user_id != (select primary_owner_user_id
|
324 |
+
from basejump.accounts
|
325 |
+
where account_id = accounts.id)
|
326 |
+
);
|
327 |
+
|
328 |
+
create policy "Accounts are viewable by members" on basejump.accounts
|
329 |
+
for select
|
330 |
+
to authenticated
|
331 |
+
using (
|
332 |
+
basejump.has_role_on_account(id) = true
|
333 |
+
);
|
334 |
+
|
335 |
+
-- Primary owner should always have access to the account
|
336 |
+
create policy "Accounts are viewable by primary owner" on basejump.accounts
|
337 |
+
for select
|
338 |
+
to authenticated
|
339 |
+
using (
|
340 |
+
primary_owner_user_id = auth.uid()
|
341 |
+
);
|
342 |
+
|
343 |
+
create policy "Team accounts can be created by any user" on basejump.accounts
|
344 |
+
for insert
|
345 |
+
to authenticated
|
346 |
+
with check (
|
347 |
+
basejump.is_set('enable_team_accounts') = true
|
348 |
+
and personal_account = false
|
349 |
+
);
|
350 |
+
|
351 |
+
|
352 |
+
create policy "Accounts can be edited by owners" on basejump.accounts
|
353 |
+
for update
|
354 |
+
to authenticated
|
355 |
+
using (
|
356 |
+
basejump.has_role_on_account(id, 'owner') = true
|
357 |
+
);
|
358 |
+
|
359 |
+
/**
|
360 |
+
* -------------------------------------------------------
|
361 |
+
* Section - Public functions
|
362 |
+
* -------------------------------------------------------
|
363 |
+
* Each of these functions exists in the public name space because they are accessible
|
364 |
+
* via the API. it is the primary way developers can interact with Basejump accounts
|
365 |
+
*/
|
366 |
+
|
367 |
+
/**
|
368 |
+
* Returns the account_id for a given account slug
|
369 |
+
*/
|
370 |
+
|
371 |
+
create or replace function public.get_account_id(slug text)
|
372 |
+
returns uuid
|
373 |
+
language sql
|
374 |
+
as
|
375 |
+
$$
|
376 |
+
select id
|
377 |
+
from basejump.accounts
|
378 |
+
where slug = get_account_id.slug;
|
379 |
+
$$;
|
380 |
+
|
381 |
+
grant execute on function public.get_account_id(text) to authenticated, service_role;
|
382 |
+
|
383 |
+
/**
|
384 |
+
* Returns the current user's role within a given account_id
|
385 |
+
*/
|
386 |
+
create or replace function public.current_user_account_role(account_id uuid)
|
387 |
+
returns jsonb
|
388 |
+
language plpgsql
|
389 |
+
as
|
390 |
+
$$
|
391 |
+
DECLARE
|
392 |
+
response jsonb;
|
393 |
+
BEGIN
|
394 |
+
|
395 |
+
select jsonb_build_object(
|
396 |
+
'account_role', wu.account_role,
|
397 |
+
'is_primary_owner', a.primary_owner_user_id = auth.uid(),
|
398 |
+
'is_personal_account', a.personal_account
|
399 |
+
)
|
400 |
+
into response
|
401 |
+
from basejump.account_user wu
|
402 |
+
join basejump.accounts a on a.id = wu.account_id
|
403 |
+
where wu.user_id = auth.uid()
|
404 |
+
and wu.account_id = current_user_account_role.account_id;
|
405 |
+
|
406 |
+
-- if the user is not a member of the account, throw an error
|
407 |
+
if response ->> 'account_role' IS NULL then
|
408 |
+
raise exception 'Not found';
|
409 |
+
end if;
|
410 |
+
|
411 |
+
return response;
|
412 |
+
END
|
413 |
+
$$;
|
414 |
+
|
415 |
+
grant execute on function public.current_user_account_role(uuid) to authenticated;
|
416 |
+
|
417 |
+
/**
|
418 |
+
* Let's you update a users role within an account if you are an owner of that account
|
419 |
+
**/
|
420 |
+
create or replace function public.update_account_user_role(account_id uuid, user_id uuid,
|
421 |
+
new_account_role basejump.account_role,
|
422 |
+
make_primary_owner boolean default false)
|
423 |
+
returns void
|
424 |
+
security definer
|
425 |
+
set search_path = public
|
426 |
+
language plpgsql
|
427 |
+
as
|
428 |
+
$$
|
429 |
+
declare
|
430 |
+
is_account_owner boolean;
|
431 |
+
is_account_primary_owner boolean;
|
432 |
+
changing_primary_owner boolean;
|
433 |
+
begin
|
434 |
+
-- check if the user is an owner, and if they are, allow them to update the role
|
435 |
+
select basejump.has_role_on_account(update_account_user_role.account_id, 'owner') into is_account_owner;
|
436 |
+
|
437 |
+
if not is_account_owner then
|
438 |
+
raise exception 'You must be an owner of the account to update a users role';
|
439 |
+
end if;
|
440 |
+
|
441 |
+
-- check if the user being changed is the primary owner, if so its not allowed
|
442 |
+
select primary_owner_user_id = auth.uid(), primary_owner_user_id = update_account_user_role.user_id
|
443 |
+
into is_account_primary_owner, changing_primary_owner
|
444 |
+
from basejump.accounts
|
445 |
+
where id = update_account_user_role.account_id;
|
446 |
+
|
447 |
+
if changing_primary_owner = true and is_account_primary_owner = false then
|
448 |
+
raise exception 'You must be the primary owner of the account to change the primary owner';
|
449 |
+
end if;
|
450 |
+
|
451 |
+
update basejump.account_user au
|
452 |
+
set account_role = new_account_role
|
453 |
+
where au.account_id = update_account_user_role.account_id
|
454 |
+
and au.user_id = update_account_user_role.user_id;
|
455 |
+
|
456 |
+
if make_primary_owner = true then
|
457 |
+
-- first we see if the current user is the owner, only they can do this
|
458 |
+
if is_account_primary_owner = false then
|
459 |
+
raise exception 'You must be the primary owner of the account to change the primary owner';
|
460 |
+
end if;
|
461 |
+
|
462 |
+
update basejump.accounts
|
463 |
+
set primary_owner_user_id = update_account_user_role.user_id
|
464 |
+
where id = update_account_user_role.account_id;
|
465 |
+
end if;
|
466 |
+
end;
|
467 |
+
$$;
|
468 |
+
|
469 |
+
grant execute on function public.update_account_user_role(uuid, uuid, basejump.account_role, boolean) to authenticated;
|
470 |
+
|
471 |
+
/**
|
472 |
+
Returns the current user's accounts
|
473 |
+
*/
|
474 |
+
create or replace function public.get_accounts()
|
475 |
+
returns json
|
476 |
+
language sql
|
477 |
+
as
|
478 |
+
$$
|
479 |
+
select coalesce(json_agg(
|
480 |
+
json_build_object(
|
481 |
+
'account_id', wu.account_id,
|
482 |
+
'account_role', wu.account_role,
|
483 |
+
'is_primary_owner', a.primary_owner_user_id = auth.uid(),
|
484 |
+
'name', a.name,
|
485 |
+
'slug', a.slug,
|
486 |
+
'personal_account', a.personal_account,
|
487 |
+
'created_at', a.created_at,
|
488 |
+
'updated_at', a.updated_at
|
489 |
+
)
|
490 |
+
), '[]'::json)
|
491 |
+
from basejump.account_user wu
|
492 |
+
join basejump.accounts a on a.id = wu.account_id
|
493 |
+
where wu.user_id = auth.uid();
|
494 |
+
$$;
|
495 |
+
|
496 |
+
grant execute on function public.get_accounts() to authenticated;
|
497 |
+
|
498 |
+
/**
|
499 |
+
Returns a specific account that the current user has access to
|
500 |
+
*/
|
501 |
+
create or replace function public.get_account(account_id uuid)
|
502 |
+
returns json
|
503 |
+
language plpgsql
|
504 |
+
as
|
505 |
+
$$
|
506 |
+
BEGIN
|
507 |
+
-- check if the user is a member of the account or a service_role user
|
508 |
+
if current_user IN ('anon', 'authenticated') and
|
509 |
+
(select current_user_account_role(get_account.account_id) ->> 'account_role' IS NULL) then
|
510 |
+
raise exception 'You must be a member of an account to access it';
|
511 |
+
end if;
|
512 |
+
|
513 |
+
|
514 |
+
return (select json_build_object(
|
515 |
+
'account_id', a.id,
|
516 |
+
'account_role', wu.account_role,
|
517 |
+
'is_primary_owner', a.primary_owner_user_id = auth.uid(),
|
518 |
+
'name', a.name,
|
519 |
+
'slug', a.slug,
|
520 |
+
'personal_account', a.personal_account,
|
521 |
+
'billing_enabled', case
|
522 |
+
when a.personal_account = true then
|
523 |
+
config.enable_personal_account_billing
|
524 |
+
else
|
525 |
+
config.enable_team_account_billing
|
526 |
+
end,
|
527 |
+
'billing_status', bs.status,
|
528 |
+
'created_at', a.created_at,
|
529 |
+
'updated_at', a.updated_at,
|
530 |
+
'metadata', a.public_metadata
|
531 |
+
)
|
532 |
+
from basejump.accounts a
|
533 |
+
left join basejump.account_user wu on a.id = wu.account_id and wu.user_id = auth.uid()
|
534 |
+
join basejump.config config on true
|
535 |
+
left join (select bs.account_id, status
|
536 |
+
from basejump.billing_subscriptions bs
|
537 |
+
where bs.account_id = get_account.account_id
|
538 |
+
order by created desc
|
539 |
+
limit 1) bs on bs.account_id = a.id
|
540 |
+
where a.id = get_account.account_id);
|
541 |
+
END;
|
542 |
+
$$;
|
543 |
+
|
544 |
+
grant execute on function public.get_account(uuid) to authenticated, service_role;
|
545 |
+
|
546 |
+
/**
|
547 |
+
Returns a specific account that the current user has access to
|
548 |
+
*/
|
549 |
+
create or replace function public.get_account_by_slug(slug text)
|
550 |
+
returns json
|
551 |
+
language plpgsql
|
552 |
+
as
|
553 |
+
$$
|
554 |
+
DECLARE
|
555 |
+
internal_account_id uuid;
|
556 |
+
BEGIN
|
557 |
+
select a.id
|
558 |
+
into internal_account_id
|
559 |
+
from basejump.accounts a
|
560 |
+
where a.slug IS NOT NULL
|
561 |
+
and a.slug = get_account_by_slug.slug;
|
562 |
+
|
563 |
+
return public.get_account(internal_account_id);
|
564 |
+
END;
|
565 |
+
$$;
|
566 |
+
|
567 |
+
grant execute on function public.get_account_by_slug(text) to authenticated;
|
568 |
+
|
569 |
+
/**
|
570 |
+
Returns the personal account for the current user
|
571 |
+
*/
|
572 |
+
create or replace function public.get_personal_account()
|
573 |
+
returns json
|
574 |
+
language plpgsql
|
575 |
+
as
|
576 |
+
$$
|
577 |
+
BEGIN
|
578 |
+
return public.get_account(auth.uid());
|
579 |
+
END;
|
580 |
+
$$;
|
581 |
+
|
582 |
+
grant execute on function public.get_personal_account() to authenticated;
|
583 |
+
|
584 |
+
/**
|
585 |
+
* Create an account
|
586 |
+
*/
|
587 |
+
create or replace function public.create_account(slug text default null, name text default null)
|
588 |
+
returns json
|
589 |
+
language plpgsql
|
590 |
+
as
|
591 |
+
$$
|
592 |
+
DECLARE
|
593 |
+
new_account_id uuid;
|
594 |
+
BEGIN
|
595 |
+
insert into basejump.accounts (slug, name)
|
596 |
+
values (create_account.slug, create_account.name)
|
597 |
+
returning id into new_account_id;
|
598 |
+
|
599 |
+
return public.get_account(new_account_id);
|
600 |
+
EXCEPTION
|
601 |
+
WHEN unique_violation THEN
|
602 |
+
raise exception 'An account with that unique ID already exists';
|
603 |
+
END;
|
604 |
+
$$;
|
605 |
+
|
606 |
+
grant execute on function public.create_account(slug text, name text) to authenticated;
|
607 |
+
|
608 |
+
/**
|
609 |
+
Update an account with passed in info. None of the info is required except for account ID.
|
610 |
+
If you don't pass in a value for a field, it will not be updated.
|
611 |
+
If you set replace_meta to true, the metadata will be replaced with the passed in metadata.
|
612 |
+
If you set replace_meta to false, the metadata will be merged with the passed in metadata.
|
613 |
+
*/
|
614 |
+
create or replace function public.update_account(account_id uuid, slug text default null, name text default null,
|
615 |
+
public_metadata jsonb default null,
|
616 |
+
replace_metadata boolean default false)
|
617 |
+
returns json
|
618 |
+
language plpgsql
|
619 |
+
as
|
620 |
+
$$
|
621 |
+
BEGIN
|
622 |
+
|
623 |
+
-- check if postgres role is service_role
|
624 |
+
if current_user IN ('anon', 'authenticated') and
|
625 |
+
not (select current_user_account_role(update_account.account_id) ->> 'account_role' = 'owner') then
|
626 |
+
raise exception 'Only account owners can update an account';
|
627 |
+
end if;
|
628 |
+
|
629 |
+
update basejump.accounts accounts
|
630 |
+
set slug = coalesce(update_account.slug, accounts.slug),
|
631 |
+
name = coalesce(update_account.name, accounts.name),
|
632 |
+
public_metadata = case
|
633 |
+
when update_account.public_metadata is null then accounts.public_metadata -- do nothing
|
634 |
+
when accounts.public_metadata IS NULL then update_account.public_metadata -- set metadata
|
635 |
+
when update_account.replace_metadata
|
636 |
+
then update_account.public_metadata -- replace metadata
|
637 |
+
else accounts.public_metadata || update_account.public_metadata end -- merge metadata
|
638 |
+
where accounts.id = update_account.account_id;
|
639 |
+
|
640 |
+
return public.get_account(account_id);
|
641 |
+
END;
|
642 |
+
$$;
|
643 |
+
|
644 |
+
grant execute on function public.update_account(uuid, text, text, jsonb, boolean) to authenticated, service_role;
|
645 |
+
|
646 |
+
/**
|
647 |
+
Returns a list of current account members. Only account owners can access this function.
|
648 |
+
It's a security definer because it requries us to lookup personal_accounts for existing members so we can
|
649 |
+
get their names.
|
650 |
+
*/
|
651 |
+
create or replace function public.get_account_members(account_id uuid, results_limit integer default 50,
|
652 |
+
results_offset integer default 0)
|
653 |
+
returns json
|
654 |
+
language plpgsql
|
655 |
+
security definer
|
656 |
+
set search_path = basejump
|
657 |
+
as
|
658 |
+
$$
|
659 |
+
BEGIN
|
660 |
+
|
661 |
+
-- only account owners can access this function
|
662 |
+
if (select public.current_user_account_role(get_account_members.account_id) ->> 'account_role' <> 'owner') then
|
663 |
+
raise exception 'Only account owners can access this function';
|
664 |
+
end if;
|
665 |
+
|
666 |
+
return (select json_agg(
|
667 |
+
json_build_object(
|
668 |
+
'user_id', wu.user_id,
|
669 |
+
'account_role', wu.account_role,
|
670 |
+
'name', p.name,
|
671 |
+
'email', u.email,
|
672 |
+
'is_primary_owner', a.primary_owner_user_id = wu.user_id
|
673 |
+
)
|
674 |
+
)
|
675 |
+
from basejump.account_user wu
|
676 |
+
join basejump.accounts a on a.id = wu.account_id
|
677 |
+
join basejump.accounts p on p.primary_owner_user_id = wu.user_id and p.personal_account = true
|
678 |
+
join auth.users u on u.id = wu.user_id
|
679 |
+
where wu.account_id = get_account_members.account_id
|
680 |
+
limit coalesce(get_account_members.results_limit, 50) offset coalesce(get_account_members.results_offset, 0));
|
681 |
+
END;
|
682 |
+
$$;
|
683 |
+
|
684 |
+
grant execute on function public.get_account_members(uuid, integer, integer) to authenticated;
|
685 |
+
|
686 |
+
/**
|
687 |
+
Allows an owner of the account to remove any member other than the primary owner
|
688 |
+
*/
|
689 |
+
|
690 |
+
create or replace function public.remove_account_member(account_id uuid, user_id uuid)
|
691 |
+
returns void
|
692 |
+
language plpgsql
|
693 |
+
as
|
694 |
+
$$
|
695 |
+
BEGIN
|
696 |
+
-- only account owners can access this function
|
697 |
+
if basejump.has_role_on_account(remove_account_member.account_id, 'owner') <> true then
|
698 |
+
raise exception 'Only account owners can access this function';
|
699 |
+
end if;
|
700 |
+
|
701 |
+
delete
|
702 |
+
from basejump.account_user wu
|
703 |
+
where wu.account_id = remove_account_member.account_id
|
704 |
+
and wu.user_id = remove_account_member.user_id;
|
705 |
+
END;
|
706 |
+
$$;
|
707 |
+
|
708 |
+
grant execute on function public.remove_account_member(uuid, uuid) to authenticated;
|
20240414162100_basejump-invitations.sql
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
|
2 |
+
* -------------------------------------------------------
|
3 |
+
* Section - Invitations
|
4 |
+
* -------------------------------------------------------
|
5 |
+
*/
|
6 |
+
|
7 |
+
/**
|
8 |
+
* Invitations are sent to users to join a account
|
9 |
+
* They pre-define the role the user should have once they join
|
10 |
+
*/
|
11 |
+
create table if not exists basejump.invitations
|
12 |
+
(
|
13 |
+
-- the id of the invitation
|
14 |
+
id uuid unique not null default extensions.uuid_generate_v4(),
|
15 |
+
-- what role should invitation accepters be given in this account
|
16 |
+
account_role basejump.account_role not null,
|
17 |
+
-- the account the invitation is for
|
18 |
+
account_id uuid references basejump.accounts (id) on delete cascade not null,
|
19 |
+
-- unique token used to accept the invitation
|
20 |
+
token text unique not null default basejump.generate_token(30),
|
21 |
+
-- who created the invitation
|
22 |
+
invited_by_user_id uuid references auth.users not null,
|
23 |
+
-- account name. filled in by a trigger
|
24 |
+
account_name text,
|
25 |
+
-- when the invitation was last updated
|
26 |
+
updated_at timestamp with time zone,
|
27 |
+
-- when the invitation was created
|
28 |
+
created_at timestamp with time zone,
|
29 |
+
-- what type of invitation is this
|
30 |
+
invitation_type basejump.invitation_type not null,
|
31 |
+
primary key (id)
|
32 |
+
);
|
33 |
+
|
34 |
+
-- Open up access to invitations
|
35 |
+
GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.invitations TO authenticated, service_role;
|
36 |
+
|
37 |
+
-- manage timestamps
|
38 |
+
CREATE TRIGGER basejump_set_invitations_timestamp
|
39 |
+
BEFORE INSERT OR UPDATE
|
40 |
+
ON basejump.invitations
|
41 |
+
FOR EACH ROW
|
42 |
+
EXECUTE FUNCTION basejump.trigger_set_timestamps();
|
43 |
+
|
44 |
+
/**
|
45 |
+
* This funciton fills in account info and inviting user email
|
46 |
+
* so that the recipient can get more info about the invitation prior to
|
47 |
+
* accepting. It allows us to avoid complex permissions on accounts
|
48 |
+
*/
|
49 |
+
CREATE OR REPLACE FUNCTION basejump.trigger_set_invitation_details()
|
50 |
+
RETURNS TRIGGER AS
|
51 |
+
$$
|
52 |
+
BEGIN
|
53 |
+
NEW.invited_by_user_id = auth.uid();
|
54 |
+
NEW.account_name = (select name from basejump.accounts where id = NEW.account_id);
|
55 |
+
RETURN NEW;
|
56 |
+
END
|
57 |
+
$$ LANGUAGE plpgsql;
|
58 |
+
|
59 |
+
CREATE TRIGGER basejump_trigger_set_invitation_details
|
60 |
+
BEFORE INSERT
|
61 |
+
ON basejump.invitations
|
62 |
+
FOR EACH ROW
|
63 |
+
EXECUTE FUNCTION basejump.trigger_set_invitation_details();
|
64 |
+
|
65 |
+
-- enable RLS on invitations
|
66 |
+
alter table basejump.invitations
|
67 |
+
enable row level security;
|
68 |
+
|
69 |
+
/**
|
70 |
+
* -------------------------
|
71 |
+
* Section - RLS Policies
|
72 |
+
* -------------------------
|
73 |
+
* This is where we define access to tables in the basejump schema
|
74 |
+
*/
|
75 |
+
|
76 |
+
create policy "Invitations viewable by account owners" on basejump.invitations
|
77 |
+
for select
|
78 |
+
to authenticated
|
79 |
+
using (
|
80 |
+
created_at > (now() - interval '24 hours')
|
81 |
+
and
|
82 |
+
basejump.has_role_on_account(account_id, 'owner') = true
|
83 |
+
);
|
84 |
+
|
85 |
+
|
86 |
+
create policy "Invitations can be created by account owners" on basejump.invitations
|
87 |
+
for insert
|
88 |
+
to authenticated
|
89 |
+
with check (
|
90 |
+
-- team accounts should be enabled
|
91 |
+
basejump.is_set('enable_team_accounts') = true
|
92 |
+
-- this should not be a personal account
|
93 |
+
and (SELECT personal_account
|
94 |
+
FROM basejump.accounts
|
95 |
+
WHERE id = account_id) = false
|
96 |
+
-- the inserting user should be an owner of the account
|
97 |
+
and
|
98 |
+
(basejump.has_role_on_account(account_id, 'owner') = true)
|
99 |
+
);
|
100 |
+
|
101 |
+
create policy "Invitations can be deleted by account owners" on basejump.invitations
|
102 |
+
for delete
|
103 |
+
to authenticated
|
104 |
+
using (
|
105 |
+
basejump.has_role_on_account(account_id, 'owner') = true
|
106 |
+
);
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
/**
|
111 |
+
* -------------------------------------------------------
|
112 |
+
* Section - Public functions
|
113 |
+
* -------------------------------------------------------
|
114 |
+
* Each of these functions exists in the public name space because they are accessible
|
115 |
+
* via the API. it is the primary way developers can interact with Basejump accounts
|
116 |
+
*/
|
117 |
+
|
118 |
+
|
119 |
+
/**
|
120 |
+
Returns a list of currently active invitations for a given account
|
121 |
+
*/
|
122 |
+
|
123 |
+
create or replace function public.get_account_invitations(account_id uuid, results_limit integer default 25,
|
124 |
+
results_offset integer default 0)
|
125 |
+
returns json
|
126 |
+
language plpgsql
|
127 |
+
as
|
128 |
+
$$
|
129 |
+
BEGIN
|
130 |
+
-- only account owners can access this function
|
131 |
+
if (select public.current_user_account_role(get_account_invitations.account_id) ->> 'account_role' <> 'owner') then
|
132 |
+
raise exception 'Only account owners can access this function';
|
133 |
+
end if;
|
134 |
+
|
135 |
+
return (select json_agg(
|
136 |
+
json_build_object(
|
137 |
+
'account_role', i.account_role,
|
138 |
+
'created_at', i.created_at,
|
139 |
+
'invitation_type', i.invitation_type,
|
140 |
+
'invitation_id', i.id
|
141 |
+
)
|
142 |
+
)
|
143 |
+
from basejump.invitations i
|
144 |
+
where i.account_id = get_account_invitations.account_id
|
145 |
+
and i.created_at > now() - interval '24 hours'
|
146 |
+
limit coalesce(get_account_invitations.results_limit, 25) offset coalesce(get_account_invitations.results_offset, 0));
|
147 |
+
END;
|
148 |
+
$$;
|
149 |
+
|
150 |
+
grant execute on function public.get_account_invitations(uuid, integer, integer) to authenticated;
|
151 |
+
|
152 |
+
|
153 |
+
/**
|
154 |
+
* Allows a user to accept an existing invitation and join a account
|
155 |
+
* This one exists in the public schema because we want it to be called
|
156 |
+
* using the supabase rpc method
|
157 |
+
*/
|
158 |
+
create or replace function public.accept_invitation(lookup_invitation_token text)
|
159 |
+
returns jsonb
|
160 |
+
language plpgsql
|
161 |
+
security definer set search_path = public, basejump
|
162 |
+
as
|
163 |
+
$$
|
164 |
+
declare
|
165 |
+
lookup_account_id uuid;
|
166 |
+
declare new_member_role basejump.account_role;
|
167 |
+
lookup_account_slug text;
|
168 |
+
begin
|
169 |
+
select i.account_id, i.account_role, a.slug
|
170 |
+
into lookup_account_id, new_member_role, lookup_account_slug
|
171 |
+
from basejump.invitations i
|
172 |
+
join basejump.accounts a on a.id = i.account_id
|
173 |
+
where i.token = lookup_invitation_token
|
174 |
+
and i.created_at > now() - interval '24 hours';
|
175 |
+
|
176 |
+
if lookup_account_id IS NULL then
|
177 |
+
raise exception 'Invitation not found';
|
178 |
+
end if;
|
179 |
+
|
180 |
+
if lookup_account_id is not null then
|
181 |
+
-- we've validated the token is real, so grant the user access
|
182 |
+
insert into basejump.account_user (account_id, user_id, account_role)
|
183 |
+
values (lookup_account_id, auth.uid(), new_member_role);
|
184 |
+
-- email types of invitations are only good for one usage
|
185 |
+
delete from basejump.invitations where token = lookup_invitation_token and invitation_type = 'one_time';
|
186 |
+
end if;
|
187 |
+
return json_build_object('account_id', lookup_account_id, 'account_role', new_member_role, 'slug',
|
188 |
+
lookup_account_slug);
|
189 |
+
EXCEPTION
|
190 |
+
WHEN unique_violation THEN
|
191 |
+
raise exception 'You are already a member of this account';
|
192 |
+
end;
|
193 |
+
$$;
|
194 |
+
|
195 |
+
grant execute on function public.accept_invitation(text) to authenticated;
|
196 |
+
|
197 |
+
|
198 |
+
/**
|
199 |
+
* Allows a user to lookup an existing invitation and join a account
|
200 |
+
* This one exists in the public schema because we want it to be called
|
201 |
+
* using the supabase rpc method
|
202 |
+
*/
|
203 |
+
create or replace function public.lookup_invitation(lookup_invitation_token text)
|
204 |
+
returns json
|
205 |
+
language plpgsql
|
206 |
+
security definer set search_path = public, basejump
|
207 |
+
as
|
208 |
+
$$
|
209 |
+
declare
|
210 |
+
name text;
|
211 |
+
invitation_active boolean;
|
212 |
+
begin
|
213 |
+
select account_name,
|
214 |
+
case when id IS NOT NULL then true else false end as active
|
215 |
+
into name, invitation_active
|
216 |
+
from basejump.invitations
|
217 |
+
where token = lookup_invitation_token
|
218 |
+
and created_at > now() - interval '24 hours'
|
219 |
+
limit 1;
|
220 |
+
return json_build_object('active', coalesce(invitation_active, false), 'account_name', name);
|
221 |
+
end;
|
222 |
+
$$;
|
223 |
+
|
224 |
+
grant execute on function public.lookup_invitation(text) to authenticated;
|
225 |
+
|
226 |
+
|
227 |
+
/**
|
228 |
+
Allows a user to create a new invitation if they are an owner of an account
|
229 |
+
*/
|
230 |
+
create or replace function public.create_invitation(account_id uuid, account_role basejump.account_role,
|
231 |
+
invitation_type basejump.invitation_type)
|
232 |
+
returns json
|
233 |
+
language plpgsql
|
234 |
+
as
|
235 |
+
$$
|
236 |
+
declare
|
237 |
+
new_invitation basejump.invitations;
|
238 |
+
begin
|
239 |
+
insert into basejump.invitations (account_id, account_role, invitation_type, invited_by_user_id)
|
240 |
+
values (account_id, account_role, invitation_type, auth.uid())
|
241 |
+
returning * into new_invitation;
|
242 |
+
|
243 |
+
return json_build_object('token', new_invitation.token);
|
244 |
+
end
|
245 |
+
$$;
|
246 |
+
|
247 |
+
grant execute on function public.create_invitation(uuid, basejump.account_role, basejump.invitation_type) to authenticated;
|
248 |
+
|
249 |
+
/**
|
250 |
+
Allows an owner to delete an existing invitation
|
251 |
+
*/
|
252 |
+
|
253 |
+
create or replace function public.delete_invitation(invitation_id uuid)
|
254 |
+
returns void
|
255 |
+
language plpgsql
|
256 |
+
as
|
257 |
+
$$
|
258 |
+
begin
|
259 |
+
-- verify account owner for the invitation
|
260 |
+
if basejump.has_role_on_account(
|
261 |
+
(select account_id from basejump.invitations where id = delete_invitation.invitation_id), 'owner') <>
|
262 |
+
true then
|
263 |
+
raise exception 'Only account owners can delete invitations';
|
264 |
+
end if;
|
265 |
+
|
266 |
+
delete from basejump.invitations where id = delete_invitation.invitation_id;
|
267 |
+
end
|
268 |
+
$$;
|
269 |
+
|
270 |
+
grant execute on function public.delete_invitation(uuid) to authenticated;
|
20240414162131_basejump-billing.sql
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
|
2 |
+
* -------------------------------------------------------
|
3 |
+
* Section - Billing
|
4 |
+
* -------------------------------------------------------
|
5 |
+
*/
|
6 |
+
|
7 |
+
/**
|
8 |
+
* Subscription Status
|
9 |
+
* Tracks the current status of the account subscription
|
10 |
+
*/
|
11 |
+
DO
|
12 |
+
$$
|
13 |
+
BEGIN
|
14 |
+
IF NOT EXISTS(SELECT 1
|
15 |
+
FROM pg_type t
|
16 |
+
JOIN pg_namespace n ON n.oid = t.typnamespace
|
17 |
+
WHERE t.typname = 'subscription_status'
|
18 |
+
AND n.nspname = 'basejump') THEN
|
19 |
+
create type basejump.subscription_status as enum (
|
20 |
+
'trialing',
|
21 |
+
'active',
|
22 |
+
'canceled',
|
23 |
+
'incomplete',
|
24 |
+
'incomplete_expired',
|
25 |
+
'past_due',
|
26 |
+
'unpaid'
|
27 |
+
);
|
28 |
+
end if;
|
29 |
+
end;
|
30 |
+
$$;
|
31 |
+
|
32 |
+
|
33 |
+
/**
|
34 |
+
* Billing customer
|
35 |
+
* This is a private table that contains a mapping of user IDs to your billing providers IDs
|
36 |
+
*/
|
37 |
+
create table if not exists basejump.billing_customers
|
38 |
+
(
|
39 |
+
-- UUID from auth.users
|
40 |
+
account_id uuid references basejump.accounts (id) on delete cascade not null,
|
41 |
+
-- The user's customer ID in Stripe. User must not be able to update this.
|
42 |
+
id text primary key,
|
43 |
+
-- The email address the customer wants to use for invoicing
|
44 |
+
email text,
|
45 |
+
-- The active status of a customer
|
46 |
+
active boolean,
|
47 |
+
-- The billing provider the customer is using
|
48 |
+
provider text
|
49 |
+
);
|
50 |
+
|
51 |
+
-- Open up access to billing_customers
|
52 |
+
GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.billing_customers TO service_role;
|
53 |
+
GRANT SELECT ON TABLE basejump.billing_customers TO authenticated;
|
54 |
+
|
55 |
+
|
56 |
+
-- enable RLS for billing_customers
|
57 |
+
alter table
|
58 |
+
basejump.billing_customers
|
59 |
+
enable row level security;
|
60 |
+
|
61 |
+
/**
|
62 |
+
* Billing subscriptions
|
63 |
+
* This is a private table that contains a mapping of account IDs to your billing providers subscription IDs
|
64 |
+
*/
|
65 |
+
create table if not exists basejump.billing_subscriptions
|
66 |
+
(
|
67 |
+
-- Subscription ID from Stripe, e.g. sub_1234.
|
68 |
+
id text primary key,
|
69 |
+
account_id uuid references basejump.accounts (id) on delete cascade not null,
|
70 |
+
billing_customer_id text references basejump.billing_customers (id) on delete cascade not null,
|
71 |
+
-- The status of the subscription object, one of subscription_status type above.
|
72 |
+
status basejump.subscription_status,
|
73 |
+
-- Set of key-value pairs, used to store additional information about the object in a structured format.
|
74 |
+
metadata jsonb,
|
75 |
+
-- ID of the price that created this subscription.
|
76 |
+
price_id text,
|
77 |
+
plan_name text,
|
78 |
+
-- Quantity multiplied by the unit amount of the price creates the amount of the subscription. Can be used to charge multiple seats.
|
79 |
+
quantity integer,
|
80 |
+
-- If true the subscription has been canceled by the user and will be deleted at the end of the billing period.
|
81 |
+
cancel_at_period_end boolean,
|
82 |
+
-- Time at which the subscription was created.
|
83 |
+
created timestamp with time zone default timezone('utc' :: text, now()) not null,
|
84 |
+
-- Start of the current period that the subscription has been invoiced for.
|
85 |
+
current_period_start timestamp with time zone default timezone('utc' :: text, now()) not null,
|
86 |
+
-- End of the current period that the subscription has been invoiced for. At the end of this period, a new invoice will be created.
|
87 |
+
current_period_end timestamp with time zone default timezone('utc' :: text, now()) not null,
|
88 |
+
-- If the subscription has ended, the timestamp of the date the subscription ended.
|
89 |
+
ended_at timestamp with time zone default timezone('utc' :: text, now()),
|
90 |
+
-- A date in the future at which the subscription will automatically get canceled.
|
91 |
+
cancel_at timestamp with time zone default timezone('utc' :: text, now()),
|
92 |
+
-- If the subscription has been canceled, the date of that cancellation. If the subscription was canceled with `cancel_at_period_end`, `canceled_at` will still reflect the date of the initial cancellation request, not the end of the subscription period when the subscription is automatically moved to a canceled state.
|
93 |
+
canceled_at timestamp with time zone default timezone('utc' :: text, now()),
|
94 |
+
-- If the subscription has a trial, the beginning of that trial.
|
95 |
+
trial_start timestamp with time zone default timezone('utc' :: text, now()),
|
96 |
+
-- If the subscription has a trial, the end of that trial.
|
97 |
+
trial_end timestamp with time zone default timezone('utc' :: text, now()),
|
98 |
+
provider text
|
99 |
+
);
|
100 |
+
|
101 |
+
-- Open up access to billing_subscriptions
|
102 |
+
GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.billing_subscriptions TO service_role;
|
103 |
+
GRANT SELECT ON TABLE basejump.billing_subscriptions TO authenticated;
|
104 |
+
|
105 |
+
-- enable RLS for billing_subscriptions
|
106 |
+
alter table
|
107 |
+
basejump.billing_subscriptions
|
108 |
+
enable row level security;
|
109 |
+
|
110 |
+
/**
|
111 |
+
* -------------------------
|
112 |
+
* Section - RLS Policies
|
113 |
+
* -------------------------
|
114 |
+
* This is where we define access to tables in the basejump schema
|
115 |
+
*/
|
116 |
+
|
117 |
+
create policy "Can only view own billing customer data." on basejump.billing_customers for
|
118 |
+
select
|
119 |
+
using (
|
120 |
+
basejump.has_role_on_account(account_id) = true
|
121 |
+
);
|
122 |
+
|
123 |
+
|
124 |
+
create policy "Can only view own billing subscription data." on basejump.billing_subscriptions for
|
125 |
+
select
|
126 |
+
using (
|
127 |
+
basejump.has_role_on_account(account_id) = true
|
128 |
+
);
|
129 |
+
|
130 |
+
/**
|
131 |
+
* -------------------------------------------------------
|
132 |
+
* Section - Public functions
|
133 |
+
* -------------------------------------------------------
|
134 |
+
* Each of these functions exists in the public name space because they are accessible
|
135 |
+
* via the API. it is the primary way developers can interact with Basejump accounts
|
136 |
+
*/
|
137 |
+
|
138 |
+
|
139 |
+
/**
|
140 |
+
* Returns the current billing status for an account
|
141 |
+
*/
|
142 |
+
CREATE OR REPLACE FUNCTION public.get_account_billing_status(account_id uuid)
|
143 |
+
RETURNS jsonb
|
144 |
+
security definer
|
145 |
+
set search_path = public, basejump
|
146 |
+
AS
|
147 |
+
$$
|
148 |
+
DECLARE
|
149 |
+
result jsonb;
|
150 |
+
role_result jsonb;
|
151 |
+
BEGIN
|
152 |
+
select public.current_user_account_role(get_account_billing_status.account_id) into role_result;
|
153 |
+
|
154 |
+
select jsonb_build_object(
|
155 |
+
'account_id', get_account_billing_status.account_id,
|
156 |
+
'billing_subscription_id', s.id,
|
157 |
+
'billing_enabled', case
|
158 |
+
when a.personal_account = true then config.enable_personal_account_billing
|
159 |
+
else config.enable_team_account_billing end,
|
160 |
+
'billing_status', s.status,
|
161 |
+
'billing_customer_id', c.id,
|
162 |
+
'billing_provider', config.billing_provider,
|
163 |
+
'billing_email',
|
164 |
+
coalesce(c.email, u.email) -- if we don't have a customer email, use the user's email as a fallback
|
165 |
+
)
|
166 |
+
into result
|
167 |
+
from basejump.accounts a
|
168 |
+
join auth.users u on u.id = a.primary_owner_user_id
|
169 |
+
left join basejump.billing_subscriptions s on s.account_id = a.id
|
170 |
+
left join basejump.billing_customers c on c.account_id = coalesce(s.account_id, a.id)
|
171 |
+
join basejump.config config on true
|
172 |
+
where a.id = get_account_billing_status.account_id
|
173 |
+
order by s.created desc
|
174 |
+
limit 1;
|
175 |
+
|
176 |
+
return result || role_result;
|
177 |
+
END;
|
178 |
+
$$ LANGUAGE plpgsql;
|
179 |
+
|
180 |
+
grant execute on function public.get_account_billing_status(uuid) to authenticated;
|
181 |
+
|
182 |
+
/**
|
183 |
+
* Allow service accounts to upsert the billing data for an account
|
184 |
+
*/
|
185 |
+
CREATE OR REPLACE FUNCTION public.service_role_upsert_customer_subscription(account_id uuid,
|
186 |
+
customer jsonb default null,
|
187 |
+
subscription jsonb default null)
|
188 |
+
RETURNS void AS
|
189 |
+
$$
|
190 |
+
BEGIN
|
191 |
+
-- if the customer is not null, upsert the data into billing_customers, only upsert fields that are present in the jsonb object
|
192 |
+
if customer is not null then
|
193 |
+
insert into basejump.billing_customers (id, account_id, email, provider)
|
194 |
+
values (customer ->> 'id', service_role_upsert_customer_subscription.account_id, customer ->> 'billing_email',
|
195 |
+
(customer ->> 'provider'))
|
196 |
+
on conflict (id) do update
|
197 |
+
set email = customer ->> 'billing_email';
|
198 |
+
end if;
|
199 |
+
|
200 |
+
-- if the subscription is not null, upsert the data into billing_subscriptions, only upsert fields that are present in the jsonb object
|
201 |
+
if subscription is not null then
|
202 |
+
insert into basejump.billing_subscriptions (id, account_id, billing_customer_id, status, metadata, price_id,
|
203 |
+
quantity, cancel_at_period_end, created, current_period_start,
|
204 |
+
current_period_end, ended_at, cancel_at, canceled_at, trial_start,
|
205 |
+
trial_end, plan_name, provider)
|
206 |
+
values (subscription ->> 'id', service_role_upsert_customer_subscription.account_id,
|
207 |
+
subscription ->> 'billing_customer_id', (subscription ->> 'status')::basejump.subscription_status,
|
208 |
+
subscription -> 'metadata',
|
209 |
+
subscription ->> 'price_id', (subscription ->> 'quantity')::int,
|
210 |
+
(subscription ->> 'cancel_at_period_end')::boolean,
|
211 |
+
(subscription ->> 'created')::timestamptz, (subscription ->> 'current_period_start')::timestamptz,
|
212 |
+
(subscription ->> 'current_period_end')::timestamptz, (subscription ->> 'ended_at')::timestamptz,
|
213 |
+
(subscription ->> 'cancel_at')::timestamptz,
|
214 |
+
(subscription ->> 'canceled_at')::timestamptz, (subscription ->> 'trial_start')::timestamptz,
|
215 |
+
(subscription ->> 'trial_end')::timestamptz,
|
216 |
+
subscription ->> 'plan_name', (subscription ->> 'provider'))
|
217 |
+
on conflict (id) do update
|
218 |
+
set billing_customer_id = subscription ->> 'billing_customer_id',
|
219 |
+
status = (subscription ->> 'status')::basejump.subscription_status,
|
220 |
+
metadata = subscription -> 'metadata',
|
221 |
+
price_id = subscription ->> 'price_id',
|
222 |
+
quantity = (subscription ->> 'quantity')::int,
|
223 |
+
cancel_at_period_end = (subscription ->> 'cancel_at_period_end')::boolean,
|
224 |
+
current_period_start = (subscription ->> 'current_period_start')::timestamptz,
|
225 |
+
current_period_end = (subscription ->> 'current_period_end')::timestamptz,
|
226 |
+
ended_at = (subscription ->> 'ended_at')::timestamptz,
|
227 |
+
cancel_at = (subscription ->> 'cancel_at')::timestamptz,
|
228 |
+
canceled_at = (subscription ->> 'canceled_at')::timestamptz,
|
229 |
+
trial_start = (subscription ->> 'trial_start')::timestamptz,
|
230 |
+
trial_end = (subscription ->> 'trial_end')::timestamptz,
|
231 |
+
plan_name = subscription ->> 'plan_name';
|
232 |
+
end if;
|
233 |
+
end;
|
234 |
+
$$ LANGUAGE plpgsql;
|
235 |
+
|
236 |
+
GRANT EXECUTE ON FUNCTION public.service_role_upsert_customer_subscription(uuid, jsonb, jsonb) TO service_role;
|
20250409211903_basejump-configure.sql
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
UPDATE basejump.config SET enable_team_accounts = TRUE;
|
2 |
+
UPDATE basejump.config SET enable_personal_account_billing = TRUE;
|
3 |
+
UPDATE basejump.config SET enable_team_account_billing = TRUE;
|
20250409212058_initial.sql
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-- Enable UUID extension
|
2 |
+
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
3 |
+
|
4 |
+
-- Create devices table first
|
5 |
+
CREATE TABLE public.devices (
|
6 |
+
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
7 |
+
account_id UUID NOT NULL,
|
8 |
+
name TEXT,
|
9 |
+
last_seen TIMESTAMP WITH TIME ZONE,
|
10 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
|
11 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
|
12 |
+
is_online BOOLEAN DEFAULT FALSE,
|
13 |
+
CONSTRAINT fk_account FOREIGN KEY (account_id) REFERENCES basejump.accounts(id) ON DELETE CASCADE
|
14 |
+
);
|
15 |
+
|
16 |
+
-- Create recordings table
|
17 |
+
CREATE TABLE public.recordings (
|
18 |
+
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
19 |
+
account_id UUID NOT NULL,
|
20 |
+
device_id UUID NOT NULL,
|
21 |
+
preprocessed_file_path TEXT,
|
22 |
+
meta JSONB,
|
23 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
|
24 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
|
25 |
+
name TEXT,
|
26 |
+
ui_annotated BOOLEAN DEFAULT FALSE,
|
27 |
+
a11y_file_path TEXT,
|
28 |
+
audio_file_path TEXT,
|
29 |
+
action_annotated BOOLEAN DEFAULT FALSE,
|
30 |
+
raw_data_file_path TEXT,
|
31 |
+
metadata_file_path TEXT,
|
32 |
+
action_training_file_path TEXT,
|
33 |
+
CONSTRAINT fk_account FOREIGN KEY (account_id) REFERENCES basejump.accounts(id) ON DELETE CASCADE,
|
34 |
+
CONSTRAINT fk_device FOREIGN KEY (device_id) REFERENCES public.devices(id) ON DELETE CASCADE
|
35 |
+
);
|
36 |
+
|
37 |
+
-- Create indexes for foreign keys
|
38 |
+
CREATE INDEX idx_recordings_account_id ON public.recordings(account_id);
|
39 |
+
CREATE INDEX idx_recordings_device_id ON public.recordings(device_id);
|
40 |
+
CREATE INDEX idx_devices_account_id ON public.devices(account_id);
|
41 |
+
|
42 |
+
-- Add RLS policies (optional, can be customized as needed)
|
43 |
+
ALTER TABLE public.recordings ENABLE ROW LEVEL SECURITY;
|
44 |
+
ALTER TABLE public.devices ENABLE ROW LEVEL SECURITY;
|
45 |
+
|
46 |
+
-- Create RLS policies for devices
|
47 |
+
CREATE POLICY "Account members can delete their own devices"
|
48 |
+
ON public.devices FOR DELETE
|
49 |
+
USING (basejump.has_role_on_account(account_id));
|
50 |
+
|
51 |
+
CREATE POLICY "Account members can insert their own devices"
|
52 |
+
ON public.devices FOR INSERT
|
53 |
+
WITH CHECK (basejump.has_role_on_account(account_id));
|
54 |
+
|
55 |
+
CREATE POLICY "Account members can only access their own devices"
|
56 |
+
ON public.devices FOR ALL
|
57 |
+
USING (basejump.has_role_on_account(account_id));
|
58 |
+
|
59 |
+
CREATE POLICY "Account members can update their own devices"
|
60 |
+
ON public.devices FOR UPDATE
|
61 |
+
USING (basejump.has_role_on_account(account_id));
|
62 |
+
|
63 |
+
CREATE POLICY "Account members can view their own devices"
|
64 |
+
ON public.devices FOR SELECT
|
65 |
+
USING (basejump.has_role_on_account(account_id));
|
66 |
+
|
67 |
+
-- Create RLS policies for recordings
|
68 |
+
CREATE POLICY "Account members can delete their own recordings"
|
69 |
+
ON public.recordings FOR DELETE
|
70 |
+
USING (basejump.has_role_on_account(account_id));
|
71 |
+
|
72 |
+
CREATE POLICY "Account members can insert their own recordings"
|
73 |
+
ON public.recordings FOR INSERT
|
74 |
+
WITH CHECK (basejump.has_role_on_account(account_id));
|
75 |
+
|
76 |
+
CREATE POLICY "Account members can only access their own recordings"
|
77 |
+
ON public.recordings FOR ALL
|
78 |
+
USING (basejump.has_role_on_account(account_id));
|
79 |
+
|
80 |
+
CREATE POLICY "Account members can update their own recordings"
|
81 |
+
ON public.recordings FOR UPDATE
|
82 |
+
USING (basejump.has_role_on_account(account_id));
|
83 |
+
|
84 |
+
CREATE POLICY "Account members can view their own recordings"
|
85 |
+
ON public.recordings FOR SELECT
|
86 |
+
USING (basejump.has_role_on_account(account_id));
|
87 |
+
|
88 |
+
-- Note: For threads and messages, you might want different RLS policies
|
89 |
+
-- depending on your application's requirements
|
90 |
+
|
91 |
+
|
92 |
+
-- Also drop the old function signature
|
93 |
+
DROP FUNCTION IF EXISTS transfer_device(UUID, UUID, TEXT);
|
94 |
+
|
95 |
+
|
96 |
+
CREATE OR REPLACE FUNCTION transfer_device(
|
97 |
+
device_id UUID, -- Parameter remains UUID
|
98 |
+
new_account_id UUID, -- Changed parameter name and implies new ownership target
|
99 |
+
device_name TEXT DEFAULT NULL
|
100 |
+
)
|
101 |
+
RETURNS SETOF devices AS $$
|
102 |
+
DECLARE
|
103 |
+
device_exists BOOLEAN;
|
104 |
+
updated_device devices;
|
105 |
+
BEGIN
|
106 |
+
-- Check if a device with the specified UUID exists
|
107 |
+
SELECT EXISTS (
|
108 |
+
SELECT 1 FROM devices WHERE id = device_id
|
109 |
+
) INTO device_exists;
|
110 |
+
|
111 |
+
IF device_exists THEN
|
112 |
+
-- Device exists: update its account ownership and last_seen timestamp
|
113 |
+
UPDATE devices
|
114 |
+
SET
|
115 |
+
account_id = new_account_id, -- Update account_id instead of user_id
|
116 |
+
name = COALESCE(device_name, name),
|
117 |
+
last_seen = NOW()
|
118 |
+
WHERE id = device_id
|
119 |
+
RETURNING * INTO updated_device;
|
120 |
+
|
121 |
+
RETURN NEXT updated_device;
|
122 |
+
ELSE
|
123 |
+
-- Device doesn't exist; return nothing so the caller can handle creation
|
124 |
+
RETURN;
|
125 |
+
END IF;
|
126 |
+
END;
|
127 |
+
$$ LANGUAGE plpgsql SECURITY DEFINER;
|
128 |
+
|
129 |
+
-- Grant execute permission so that authenticated users can call this function
|
130 |
+
-- Updated function signature
|
131 |
+
GRANT EXECUTE ON FUNCTION transfer_device(UUID, UUID, TEXT) TO authenticated;
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
-- Create the ui_grounding bucket
|
137 |
+
INSERT INTO storage.buckets (id, name, public)
|
138 |
+
VALUES ('ui_grounding', 'ui_grounding', false)
|
139 |
+
ON CONFLICT (id) DO NOTHING; -- Avoid error if bucket already exists
|
140 |
+
|
141 |
+
-- Create the ui_grounding_trajs bucket
|
142 |
+
INSERT INTO storage.buckets (id, name, public)
|
143 |
+
VALUES ('ui_grounding_trajs', 'ui_grounding_trajs', false)
|
144 |
+
ON CONFLICT (id) DO NOTHING; -- Avoid error if bucket already exists
|
145 |
+
|
146 |
+
-- Create the recordings bucket
|
147 |
+
INSERT INTO storage.buckets (id, name, public, file_size_limit, allowed_mime_types)
|
148 |
+
VALUES ('recordings', 'recordings', false, null, null) -- Set file size limit and mime types as needed
|
149 |
+
ON CONFLICT (id) DO NOTHING; -- Avoid error if bucket already exists
|
150 |
+
|
151 |
+
|
152 |
+
-- RLS policies for the 'recordings' bucket
|
153 |
+
-- Allow members to view files in accounts they belong to
|
154 |
+
CREATE POLICY "Account members can select recording files"
|
155 |
+
ON storage.objects FOR SELECT
|
156 |
+
TO authenticated
|
157 |
+
USING (
|
158 |
+
bucket_id = 'recordings' AND
|
159 |
+
(storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
|
160 |
+
);
|
161 |
+
|
162 |
+
-- Allow members to insert files into accounts they belong to
|
163 |
+
CREATE POLICY "Account members can insert recording files"
|
164 |
+
ON storage.objects FOR INSERT
|
165 |
+
TO authenticated
|
166 |
+
WITH CHECK (
|
167 |
+
bucket_id = 'recordings' AND
|
168 |
+
(storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
|
169 |
+
);
|
170 |
+
|
171 |
+
-- Allow members to update files in accounts they belong to
|
172 |
+
CREATE POLICY "Account members can update recording files"
|
173 |
+
ON storage.objects FOR UPDATE
|
174 |
+
TO authenticated
|
175 |
+
USING (
|
176 |
+
bucket_id = 'recordings' AND
|
177 |
+
(storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
|
178 |
+
);
|
179 |
+
|
180 |
+
-- Allow members to delete files from accounts they belong to
|
181 |
+
-- Consider restricting this further, e.g., to 'owner' role if needed:
|
182 |
+
-- (storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role('owner'))
|
183 |
+
CREATE POLICY "Account members can delete recording files"
|
184 |
+
ON storage.objects FOR DELETE
|
185 |
+
TO authenticated
|
186 |
+
USING (
|
187 |
+
bucket_id = 'recordings' AND
|
188 |
+
(storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
|
189 |
+
);
|
20250416133920_agentpress_schema.sql
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-- AGENTPRESS SCHEMA:
|
2 |
+
-- Create projects table
|
3 |
+
CREATE TABLE projects (
|
4 |
+
project_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
5 |
+
name TEXT NOT NULL,
|
6 |
+
description TEXT,
|
7 |
+
account_id UUID NOT NULL REFERENCES basejump.accounts(id) ON DELETE CASCADE,
|
8 |
+
sandbox JSONB DEFAULT '{}'::jsonb,
|
9 |
+
is_public BOOLEAN DEFAULT FALSE,
|
10 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
|
11 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
|
12 |
+
);
|
13 |
+
|
14 |
+
-- Create threads table
|
15 |
+
CREATE TABLE threads (
|
16 |
+
thread_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
17 |
+
account_id UUID REFERENCES basejump.accounts(id) ON DELETE CASCADE,
|
18 |
+
project_id UUID REFERENCES projects(project_id) ON DELETE CASCADE,
|
19 |
+
is_public BOOLEAN DEFAULT FALSE,
|
20 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
|
21 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
|
22 |
+
);
|
23 |
+
|
24 |
+
-- Create messages table
|
25 |
+
CREATE TABLE messages (
|
26 |
+
message_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
27 |
+
thread_id UUID NOT NULL REFERENCES threads(thread_id) ON DELETE CASCADE,
|
28 |
+
type TEXT NOT NULL,
|
29 |
+
is_llm_message BOOLEAN NOT NULL DEFAULT TRUE,
|
30 |
+
content JSONB NOT NULL,
|
31 |
+
metadata JSONB DEFAULT '{}'::jsonb,
|
32 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
|
33 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
|
34 |
+
);
|
35 |
+
|
36 |
+
-- Create agent_runs table
|
37 |
+
CREATE TABLE agent_runs (
|
38 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
39 |
+
thread_id UUID NOT NULL REFERENCES threads(thread_id),
|
40 |
+
status TEXT NOT NULL DEFAULT 'running',
|
41 |
+
started_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
|
42 |
+
completed_at TIMESTAMP WITH TIME ZONE,
|
43 |
+
responses JSONB NOT NULL DEFAULT '[]'::jsonb, -- TO BE REMOVED, NOT USED
|
44 |
+
error TEXT,
|
45 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
|
46 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
|
47 |
+
);
|
48 |
+
|
49 |
+
-- Create updated_at trigger function
|
50 |
+
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
51 |
+
RETURNS TRIGGER AS $$
|
52 |
+
BEGIN
|
53 |
+
NEW.updated_at = TIMEZONE('utc'::text, NOW());
|
54 |
+
RETURN NEW;
|
55 |
+
END;
|
56 |
+
$$ language 'plpgsql';
|
57 |
+
|
58 |
+
-- Create triggers for updated_at
|
59 |
+
CREATE TRIGGER update_threads_updated_at
|
60 |
+
BEFORE UPDATE ON threads
|
61 |
+
FOR EACH ROW
|
62 |
+
EXECUTE FUNCTION update_updated_at_column();
|
63 |
+
|
64 |
+
CREATE TRIGGER update_messages_updated_at
|
65 |
+
BEFORE UPDATE ON messages
|
66 |
+
FOR EACH ROW
|
67 |
+
EXECUTE FUNCTION update_updated_at_column();
|
68 |
+
|
69 |
+
CREATE TRIGGER update_agent_runs_updated_at
|
70 |
+
BEFORE UPDATE ON agent_runs
|
71 |
+
FOR EACH ROW
|
72 |
+
EXECUTE FUNCTION update_updated_at_column();
|
73 |
+
|
74 |
+
CREATE TRIGGER update_projects_updated_at
|
75 |
+
BEFORE UPDATE ON projects
|
76 |
+
FOR EACH ROW
|
77 |
+
EXECUTE FUNCTION update_updated_at_column();
|
78 |
+
|
79 |
+
-- Create indexes for better query performance
|
80 |
+
CREATE INDEX idx_threads_created_at ON threads(created_at);
|
81 |
+
CREATE INDEX idx_threads_account_id ON threads(account_id);
|
82 |
+
CREATE INDEX idx_threads_project_id ON threads(project_id);
|
83 |
+
CREATE INDEX idx_agent_runs_thread_id ON agent_runs(thread_id);
|
84 |
+
CREATE INDEX idx_agent_runs_status ON agent_runs(status);
|
85 |
+
CREATE INDEX idx_agent_runs_created_at ON agent_runs(created_at);
|
86 |
+
CREATE INDEX idx_projects_account_id ON projects(account_id);
|
87 |
+
CREATE INDEX idx_projects_created_at ON projects(created_at);
|
88 |
+
CREATE INDEX idx_messages_thread_id ON messages(thread_id);
|
89 |
+
CREATE INDEX idx_messages_created_at ON messages(created_at);
|
90 |
+
|
91 |
+
-- Enable Row Level Security
|
92 |
+
ALTER TABLE threads ENABLE ROW LEVEL SECURITY;
|
93 |
+
ALTER TABLE messages ENABLE ROW LEVEL SECURITY;
|
94 |
+
ALTER TABLE agent_runs ENABLE ROW LEVEL SECURITY;
|
95 |
+
ALTER TABLE projects ENABLE ROW LEVEL SECURITY;
|
96 |
+
|
97 |
+
-- Project policies
|
98 |
+
CREATE POLICY project_select_policy ON projects
|
99 |
+
FOR SELECT
|
100 |
+
USING (
|
101 |
+
is_public = TRUE OR
|
102 |
+
basejump.has_role_on_account(account_id) = true
|
103 |
+
);
|
104 |
+
|
105 |
+
CREATE POLICY project_insert_policy ON projects
|
106 |
+
FOR INSERT
|
107 |
+
WITH CHECK (basejump.has_role_on_account(account_id) = true);
|
108 |
+
|
109 |
+
CREATE POLICY project_update_policy ON projects
|
110 |
+
FOR UPDATE
|
111 |
+
USING (basejump.has_role_on_account(account_id) = true);
|
112 |
+
|
113 |
+
CREATE POLICY project_delete_policy ON projects
|
114 |
+
FOR DELETE
|
115 |
+
USING (basejump.has_role_on_account(account_id) = true);
|
116 |
+
|
117 |
+
-- Thread policies based on project and account ownership
|
118 |
+
CREATE POLICY thread_select_policy ON threads
|
119 |
+
FOR SELECT
|
120 |
+
USING (
|
121 |
+
basejump.has_role_on_account(account_id) = true OR
|
122 |
+
EXISTS (
|
123 |
+
SELECT 1 FROM projects
|
124 |
+
WHERE projects.project_id = threads.project_id
|
125 |
+
AND (
|
126 |
+
projects.is_public = TRUE OR
|
127 |
+
basejump.has_role_on_account(projects.account_id) = true
|
128 |
+
)
|
129 |
+
)
|
130 |
+
);
|
131 |
+
|
132 |
+
CREATE POLICY thread_insert_policy ON threads
|
133 |
+
FOR INSERT
|
134 |
+
WITH CHECK (
|
135 |
+
basejump.has_role_on_account(account_id) = true OR
|
136 |
+
EXISTS (
|
137 |
+
SELECT 1 FROM projects
|
138 |
+
WHERE projects.project_id = threads.project_id
|
139 |
+
AND basejump.has_role_on_account(projects.account_id) = true
|
140 |
+
)
|
141 |
+
);
|
142 |
+
|
143 |
+
CREATE POLICY thread_update_policy ON threads
|
144 |
+
FOR UPDATE
|
145 |
+
USING (
|
146 |
+
basejump.has_role_on_account(account_id) = true OR
|
147 |
+
EXISTS (
|
148 |
+
SELECT 1 FROM projects
|
149 |
+
WHERE projects.project_id = threads.project_id
|
150 |
+
AND basejump.has_role_on_account(projects.account_id) = true
|
151 |
+
)
|
152 |
+
);
|
153 |
+
|
154 |
+
CREATE POLICY thread_delete_policy ON threads
|
155 |
+
FOR DELETE
|
156 |
+
USING (
|
157 |
+
basejump.has_role_on_account(account_id) = true OR
|
158 |
+
EXISTS (
|
159 |
+
SELECT 1 FROM projects
|
160 |
+
WHERE projects.project_id = threads.project_id
|
161 |
+
AND basejump.has_role_on_account(projects.account_id) = true
|
162 |
+
)
|
163 |
+
);
|
164 |
+
|
165 |
+
-- Create policies for agent_runs based on thread ownership
|
166 |
+
CREATE POLICY agent_run_select_policy ON agent_runs
|
167 |
+
FOR SELECT
|
168 |
+
USING (
|
169 |
+
EXISTS (
|
170 |
+
SELECT 1 FROM threads
|
171 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
172 |
+
WHERE threads.thread_id = agent_runs.thread_id
|
173 |
+
AND (
|
174 |
+
projects.is_public = TRUE OR
|
175 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
176 |
+
basejump.has_role_on_account(projects.account_id) = true
|
177 |
+
)
|
178 |
+
)
|
179 |
+
);
|
180 |
+
|
181 |
+
CREATE POLICY agent_run_insert_policy ON agent_runs
|
182 |
+
FOR INSERT
|
183 |
+
WITH CHECK (
|
184 |
+
EXISTS (
|
185 |
+
SELECT 1 FROM threads
|
186 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
187 |
+
WHERE threads.thread_id = agent_runs.thread_id
|
188 |
+
AND (
|
189 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
190 |
+
basejump.has_role_on_account(projects.account_id) = true
|
191 |
+
)
|
192 |
+
)
|
193 |
+
);
|
194 |
+
|
195 |
+
CREATE POLICY agent_run_update_policy ON agent_runs
|
196 |
+
FOR UPDATE
|
197 |
+
USING (
|
198 |
+
EXISTS (
|
199 |
+
SELECT 1 FROM threads
|
200 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
201 |
+
WHERE threads.thread_id = agent_runs.thread_id
|
202 |
+
AND (
|
203 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
204 |
+
basejump.has_role_on_account(projects.account_id) = true
|
205 |
+
)
|
206 |
+
)
|
207 |
+
);
|
208 |
+
|
209 |
+
CREATE POLICY agent_run_delete_policy ON agent_runs
|
210 |
+
FOR DELETE
|
211 |
+
USING (
|
212 |
+
EXISTS (
|
213 |
+
SELECT 1 FROM threads
|
214 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
215 |
+
WHERE threads.thread_id = agent_runs.thread_id
|
216 |
+
AND (
|
217 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
218 |
+
basejump.has_role_on_account(projects.account_id) = true
|
219 |
+
)
|
220 |
+
)
|
221 |
+
);
|
222 |
+
|
223 |
+
-- Create message policies based on thread ownership
|
224 |
+
CREATE POLICY message_select_policy ON messages
|
225 |
+
FOR SELECT
|
226 |
+
USING (
|
227 |
+
EXISTS (
|
228 |
+
SELECT 1 FROM threads
|
229 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
230 |
+
WHERE threads.thread_id = messages.thread_id
|
231 |
+
AND (
|
232 |
+
projects.is_public = TRUE OR
|
233 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
234 |
+
basejump.has_role_on_account(projects.account_id) = true
|
235 |
+
)
|
236 |
+
)
|
237 |
+
);
|
238 |
+
|
239 |
+
CREATE POLICY message_insert_policy ON messages
|
240 |
+
FOR INSERT
|
241 |
+
WITH CHECK (
|
242 |
+
EXISTS (
|
243 |
+
SELECT 1 FROM threads
|
244 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
245 |
+
WHERE threads.thread_id = messages.thread_id
|
246 |
+
AND (
|
247 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
248 |
+
basejump.has_role_on_account(projects.account_id) = true
|
249 |
+
)
|
250 |
+
)
|
251 |
+
);
|
252 |
+
|
253 |
+
CREATE POLICY message_update_policy ON messages
|
254 |
+
FOR UPDATE
|
255 |
+
USING (
|
256 |
+
EXISTS (
|
257 |
+
SELECT 1 FROM threads
|
258 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
259 |
+
WHERE threads.thread_id = messages.thread_id
|
260 |
+
AND (
|
261 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
262 |
+
basejump.has_role_on_account(projects.account_id) = true
|
263 |
+
)
|
264 |
+
)
|
265 |
+
);
|
266 |
+
|
267 |
+
CREATE POLICY message_delete_policy ON messages
|
268 |
+
FOR DELETE
|
269 |
+
USING (
|
270 |
+
EXISTS (
|
271 |
+
SELECT 1 FROM threads
|
272 |
+
LEFT JOIN projects ON threads.project_id = projects.project_id
|
273 |
+
WHERE threads.thread_id = messages.thread_id
|
274 |
+
AND (
|
275 |
+
basejump.has_role_on_account(threads.account_id) = true OR
|
276 |
+
basejump.has_role_on_account(projects.account_id) = true
|
277 |
+
)
|
278 |
+
)
|
279 |
+
);
|
280 |
+
|
281 |
+
-- Grant permissions to roles
|
282 |
+
GRANT ALL PRIVILEGES ON TABLE projects TO authenticated, service_role;
|
283 |
+
GRANT SELECT ON TABLE projects TO anon;
|
284 |
+
GRANT SELECT ON TABLE threads TO authenticated, anon, service_role;
|
285 |
+
GRANT SELECT ON TABLE messages TO authenticated, anon, service_role;
|
286 |
+
GRANT ALL PRIVILEGES ON TABLE agent_runs TO authenticated, service_role;
|
287 |
+
|
288 |
+
-- Create a function that matches the Python get_messages behavior
|
289 |
+
CREATE OR REPLACE FUNCTION get_llm_formatted_messages(p_thread_id UUID)
|
290 |
+
RETURNS JSONB
|
291 |
+
SECURITY DEFINER -- Changed to SECURITY DEFINER to allow service role access
|
292 |
+
LANGUAGE plpgsql
|
293 |
+
AS $$
|
294 |
+
DECLARE
|
295 |
+
messages_array JSONB := '[]'::JSONB;
|
296 |
+
has_access BOOLEAN;
|
297 |
+
current_role TEXT;
|
298 |
+
latest_summary_id UUID;
|
299 |
+
latest_summary_time TIMESTAMP WITH TIME ZONE;
|
300 |
+
is_project_public BOOLEAN;
|
301 |
+
BEGIN
|
302 |
+
-- Get current role
|
303 |
+
SELECT current_user INTO current_role;
|
304 |
+
|
305 |
+
-- Check if associated project is public
|
306 |
+
SELECT p.is_public INTO is_project_public
|
307 |
+
FROM threads t
|
308 |
+
LEFT JOIN projects p ON t.project_id = p.project_id
|
309 |
+
WHERE t.thread_id = p_thread_id;
|
310 |
+
|
311 |
+
-- Skip access check for service_role or public projects
|
312 |
+
IF current_role = 'authenticated' AND NOT is_project_public THEN
|
313 |
+
-- Check if thread exists and user has access
|
314 |
+
SELECT EXISTS (
|
315 |
+
SELECT 1 FROM threads t
|
316 |
+
LEFT JOIN projects p ON t.project_id = p.project_id
|
317 |
+
WHERE t.thread_id = p_thread_id
|
318 |
+
AND (
|
319 |
+
basejump.has_role_on_account(t.account_id) = true OR
|
320 |
+
basejump.has_role_on_account(p.account_id) = true
|
321 |
+
)
|
322 |
+
) INTO has_access;
|
323 |
+
|
324 |
+
IF NOT has_access THEN
|
325 |
+
RAISE EXCEPTION 'Thread not found or access denied';
|
326 |
+
END IF;
|
327 |
+
END IF;
|
328 |
+
|
329 |
+
-- Find the latest summary message if it exists
|
330 |
+
SELECT message_id, created_at
|
331 |
+
INTO latest_summary_id, latest_summary_time
|
332 |
+
FROM messages
|
333 |
+
WHERE thread_id = p_thread_id
|
334 |
+
AND type = 'summary'
|
335 |
+
AND is_llm_message = TRUE
|
336 |
+
ORDER BY created_at DESC
|
337 |
+
LIMIT 1;
|
338 |
+
|
339 |
+
-- Log whether a summary was found (helpful for debugging)
|
340 |
+
IF latest_summary_id IS NOT NULL THEN
|
341 |
+
RAISE NOTICE 'Found latest summary message: id=%, time=%', latest_summary_id, latest_summary_time;
|
342 |
+
ELSE
|
343 |
+
RAISE NOTICE 'No summary message found for thread %', p_thread_id;
|
344 |
+
END IF;
|
345 |
+
|
346 |
+
-- Parse content if it's stored as a string and return proper JSON objects
|
347 |
+
WITH parsed_messages AS (
|
348 |
+
SELECT
|
349 |
+
message_id,
|
350 |
+
CASE
|
351 |
+
WHEN jsonb_typeof(content) = 'string' THEN content::text::jsonb
|
352 |
+
ELSE content
|
353 |
+
END AS parsed_content,
|
354 |
+
created_at,
|
355 |
+
type
|
356 |
+
FROM messages
|
357 |
+
WHERE thread_id = p_thread_id
|
358 |
+
AND is_llm_message = TRUE
|
359 |
+
AND (
|
360 |
+
-- Include the latest summary and all messages after it,
|
361 |
+
-- or all messages if no summary exists
|
362 |
+
latest_summary_id IS NULL
|
363 |
+
OR message_id = latest_summary_id
|
364 |
+
OR created_at > latest_summary_time
|
365 |
+
)
|
366 |
+
ORDER BY created_at
|
367 |
+
)
|
368 |
+
SELECT JSONB_AGG(parsed_content)
|
369 |
+
INTO messages_array
|
370 |
+
FROM parsed_messages;
|
371 |
+
|
372 |
+
-- Handle the case when no messages are found
|
373 |
+
IF messages_array IS NULL THEN
|
374 |
+
RETURN '[]'::JSONB;
|
375 |
+
END IF;
|
376 |
+
|
377 |
+
RETURN messages_array;
|
378 |
+
END;
|
379 |
+
$$;
|
380 |
+
|
381 |
+
-- Grant execute permissions
|
382 |
+
GRANT EXECUTE ON FUNCTION get_llm_formatted_messages TO authenticated, anon, service_role;
|
20250506000000_initial_setup.sql
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-- Create required schemas
|
2 |
+
CREATE SCHEMA IF NOT EXISTS auth;
|
3 |
+
CREATE SCHEMA IF NOT EXISTS storage;
|
4 |
+
CREATE SCHEMA IF NOT EXISTS basejump;
|
5 |
+
|
6 |
+
-- Create basic roles
|
7 |
+
CREATE ROLE IF NOT EXISTS anon NOLOGIN;
|
8 |
+
GRANT USAGE ON SCHEMA public TO anon;
|
9 |
+
GRANT USAGE ON SCHEMA auth TO anon;
|
10 |
+
GRANT USAGE ON SCHEMA basejump TO anon;
|
11 |
+
|
12 |
+
-- Create a basic users table if it doesn't exist
|
13 |
+
CREATE TABLE IF NOT EXISTS auth.users (
|
14 |
+
id uuid PRIMARY KEY,
|
15 |
+
email text UNIQUE,
|
16 |
+
encrypted_password text,
|
17 |
+
created_at timestamp with time zone DEFAULT now(),
|
18 |
+
updated_at timestamp with time zone DEFAULT now()
|
19 |
+
);
|
20 |
+
|
21 |
+
-- Add Basejump configuration
|
22 |
+
CREATE TABLE IF NOT EXISTS basejump.config (
|
23 |
+
enable_team_accounts boolean DEFAULT true,
|
24 |
+
enable_personal_account_billing boolean DEFAULT true,
|
25 |
+
enable_team_account_billing boolean DEFAULT true
|
26 |
+
);
|
27 |
+
|
28 |
+
-- Insert default config if table is empty
|
29 |
+
INSERT INTO basejump.config (enable_team_accounts, enable_personal_account_billing, enable_team_account_billing)
|
30 |
+
SELECT true, true, true
|
31 |
+
WHERE NOT EXISTS (SELECT 1 FROM basejump.config);
|
32 |
+
|
33 |
+
-- Create accounts table for Suna
|
34 |
+
CREATE TABLE IF NOT EXISTS public.accounts (
|
35 |
+
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
36 |
+
name text NOT NULL,
|
37 |
+
slug text UNIQUE NOT NULL,
|
38 |
+
created_at timestamptz DEFAULT now(),
|
39 |
+
updated_at timestamptz DEFAULT now()
|
40 |
+
);
|
41 |
+
|
42 |
+
-- Create projects table for Suna
|
43 |
+
CREATE TABLE IF NOT EXISTS public.projects (
|
44 |
+
project_id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
45 |
+
name text NOT NULL,
|
46 |
+
description text,
|
47 |
+
account_id uuid REFERENCES public.accounts(id) ON DELETE CASCADE,
|
48 |
+
sandbox jsonb DEFAULT NULL,
|
49 |
+
created_at timestamptz DEFAULT now(),
|
50 |
+
updated_at timestamptz DEFAULT now()
|
51 |
+
);
|
52 |
+
|
53 |
+
-- Create a function to create accounts
|
54 |
+
CREATE OR REPLACE FUNCTION create_account(
|
55 |
+
name TEXT,
|
56 |
+
slug TEXT
|
57 |
+
) RETURNS json
|
58 |
+
LANGUAGE plpgsql SECURITY DEFINER
|
59 |
+
AS $$
|
60 |
+
DECLARE
|
61 |
+
account_id uuid;
|
62 |
+
existing_account_id uuid;
|
63 |
+
return_data json;
|
64 |
+
BEGIN
|
65 |
+
-- Check if slug is already taken
|
66 |
+
SELECT id INTO existing_account_id FROM public.accounts WHERE accounts.slug = create_account.slug;
|
67 |
+
|
68 |
+
IF existing_account_id IS NOT NULL THEN
|
69 |
+
RETURN json_build_object('error', 'Slug already taken');
|
70 |
+
END IF;
|
71 |
+
|
72 |
+
-- Insert account
|
73 |
+
INSERT INTO public.accounts (name, slug)
|
74 |
+
VALUES (create_account.name, create_account.slug)
|
75 |
+
RETURNING id INTO account_id;
|
76 |
+
|
77 |
+
return_data := json_build_object(
|
78 |
+
'id', account_id,
|
79 |
+
'name', name,
|
80 |
+
'slug', slug
|
81 |
+
);
|
82 |
+
|
83 |
+
RETURN return_data;
|
84 |
+
END;
|
85 |
+
$$;
|
20250506000001_account_functions.sql
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-- Add account management functions
|
2 |
+
|
3 |
+
-- Function to update an account
|
4 |
+
CREATE OR REPLACE FUNCTION update_account(
|
5 |
+
name TEXT,
|
6 |
+
account_id UUID
|
7 |
+
) RETURNS void
|
8 |
+
LANGUAGE plpgsql SECURITY DEFINER
|
9 |
+
AS $$
|
10 |
+
BEGIN
|
11 |
+
UPDATE public.accounts
|
12 |
+
SET
|
13 |
+
name = update_account.name,
|
14 |
+
updated_at = now()
|
15 |
+
WHERE id = update_account.account_id;
|
16 |
+
END;
|
17 |
+
$$;
|
18 |
+
|
19 |
+
-- Function to get all accounts for current user
|
20 |
+
CREATE OR REPLACE FUNCTION get_accounts()
|
21 |
+
RETURNS json
|
22 |
+
LANGUAGE plpgsql SECURITY DEFINER
|
23 |
+
AS $$
|
24 |
+
DECLARE
|
25 |
+
current_user_id uuid;
|
26 |
+
account_data json;
|
27 |
+
BEGIN
|
28 |
+
-- Get the current user's ID
|
29 |
+
current_user_id := auth.uid();
|
30 |
+
|
31 |
+
-- Query for accounts
|
32 |
+
SELECT json_agg(
|
33 |
+
json_build_object(
|
34 |
+
'id', a.id,
|
35 |
+
'name', a.name,
|
36 |
+
'slug', a.slug,
|
37 |
+
'personal_account', a.id = current_user_id
|
38 |
+
)
|
39 |
+
) INTO account_data
|
40 |
+
FROM public.accounts a
|
41 |
+
WHERE a.id = current_user_id;
|
42 |
+
|
43 |
+
-- Return empty array if no results
|
44 |
+
IF account_data IS NULL THEN
|
45 |
+
RETURN '[]'::json;
|
46 |
+
END IF;
|
47 |
+
|
48 |
+
RETURN account_data;
|
49 |
+
END;
|
50 |
+
$$;
|
20250506000002_project_functions.sql
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-- Add project management functions
|
2 |
+
|
3 |
+
-- Function to create a project with account validation
|
4 |
+
CREATE OR REPLACE FUNCTION create_project(
|
5 |
+
name TEXT,
|
6 |
+
description TEXT,
|
7 |
+
account_id UUID
|
8 |
+
) RETURNS json
|
9 |
+
LANGUAGE plpgsql SECURITY DEFINER
|
10 |
+
AS $$
|
11 |
+
DECLARE
|
12 |
+
new_project_id uuid;
|
13 |
+
project_data json;
|
14 |
+
BEGIN
|
15 |
+
-- Insert project
|
16 |
+
INSERT INTO public.projects (name, description, account_id)
|
17 |
+
VALUES (create_project.name, create_project.description, create_project.account_id)
|
18 |
+
RETURNING project_id INTO new_project_id;
|
19 |
+
|
20 |
+
-- Get the full project data
|
21 |
+
SELECT json_build_object(
|
22 |
+
'project_id', p.project_id,
|
23 |
+
'name', p.name,
|
24 |
+
'description', p.description,
|
25 |
+
'account_id', p.account_id,
|
26 |
+
'sandbox', p.sandbox,
|
27 |
+
'created_at', p.created_at,
|
28 |
+
'updated_at', p.updated_at
|
29 |
+
) INTO project_data
|
30 |
+
FROM public.projects p
|
31 |
+
WHERE p.project_id = new_project_id;
|
32 |
+
|
33 |
+
RETURN project_data;
|
34 |
+
END;
|
35 |
+
$$;
|
36 |
+
|
37 |
+
-- Function to update a project
|
38 |
+
CREATE OR REPLACE FUNCTION update_project(
|
39 |
+
project_id UUID,
|
40 |
+
name TEXT,
|
41 |
+
description TEXT
|
42 |
+
) RETURNS json
|
43 |
+
LANGUAGE plpgsql SECURITY DEFINER
|
44 |
+
AS $$
|
45 |
+
DECLARE
|
46 |
+
updated_project_data json;
|
47 |
+
BEGIN
|
48 |
+
-- Update the project
|
49 |
+
UPDATE public.projects
|
50 |
+
SET
|
51 |
+
name = COALESCE(update_project.name, name),
|
52 |
+
description = COALESCE(update_project.description, description),
|
53 |
+
updated_at = now()
|
54 |
+
WHERE project_id = update_project.project_id;
|
55 |
+
|
56 |
+
-- Get the updated project data
|
57 |
+
SELECT json_build_object(
|
58 |
+
'project_id', p.project_id,
|
59 |
+
'name', p.name,
|
60 |
+
'description', p.description,
|
61 |
+
'account_id', p.account_id,
|
62 |
+
'sandbox', p.sandbox,
|
63 |
+
'created_at', p.created_at,
|
64 |
+
'updated_at', p.updated_at
|
65 |
+
) INTO updated_project_data
|
66 |
+
FROM public.projects p
|
67 |
+
WHERE p.project_id = update_project.project_id;
|
68 |
+
|
69 |
+
RETURN updated_project_data;
|
70 |
+
END;
|
71 |
+
$$;
|
72 |
+
|
73 |
+
-- Function to update a project's sandbox information
|
74 |
+
CREATE OR REPLACE FUNCTION update_project_sandbox(
|
75 |
+
project_id UUID,
|
76 |
+
sandbox_data jsonb
|
77 |
+
) RETURNS json
|
78 |
+
LANGUAGE plpgsql SECURITY DEFINER
|
79 |
+
AS $$
|
80 |
+
DECLARE
|
81 |
+
updated_project_data json;
|
82 |
+
BEGIN
|
83 |
+
-- Update the project sandbox data
|
84 |
+
UPDATE public.projects
|
85 |
+
SET
|
86 |
+
sandbox = sandbox_data,
|
87 |
+
updated_at = now()
|
88 |
+
WHERE project_id = update_project_sandbox.project_id;
|
89 |
+
|
90 |
+
-- Get the updated project data
|
91 |
+
SELECT json_build_object(
|
92 |
+
'project_id', p.project_id,
|
93 |
+
'name', p.name,
|
94 |
+
'description', p.description,
|
95 |
+
'account_id', p.account_id,
|
96 |
+
'sandbox', p.sandbox,
|
97 |
+
'created_at', p.created_at,
|
98 |
+
'updated_at', p.updated_at
|
99 |
+
) INTO updated_project_data
|
100 |
+
FROM public.projects p
|
101 |
+
WHERE p.project_id = update_project_sandbox.project_id;
|
102 |
+
|
103 |
+
RETURN updated_project_data;
|
104 |
+
END;
|
105 |
+
$$;
|
22dc0511fe69_add_toolsource_table.cpython-311.pyc
ADDED
Binary file (1.33 kB). View file
|
|
2ea570019b8f_add_apikey_table.cpython-311.pyc
ADDED
Binary file (4.54 kB). View file
|
|
2ea570019b8f_add_apikey_table.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Add ApiKey table
|
2 |
+
|
3 |
+
Revision ID: 2ea570019b8f
|
4 |
+
Revises: 4af13678b83c
|
5 |
+
Create Date: 2025-05-03 18:56:32.989446
|
6 |
+
|
7 |
+
"""
|
8 |
+
from typing import Sequence, Union
|
9 |
+
|
10 |
+
from alembic import op
|
11 |
+
import sqlalchemy as sa
|
12 |
+
from sqlalchemy.dialects import postgresql
|
13 |
+
|
14 |
+
# revision identifiers, used by Alembic.
|
15 |
+
revision: str = '2ea570019b8f'
|
16 |
+
down_revision: Union[str, None] = '4af13678b83c'
|
17 |
+
branch_labels: Union[str, Sequence[str], None] = None
|
18 |
+
depends_on: Union[str, Sequence[str], None] = None
|
19 |
+
|
20 |
+
|
21 |
+
def upgrade() -> None:
|
22 |
+
"""Upgrade schema."""
|
23 |
+
# ### commands auto generated by Alembic - adjusted ###
|
24 |
+
op.create_table('api_keys',
|
25 |
+
sa.Column('id', sa.UUID(), nullable=False),
|
26 |
+
sa.Column('user_id', sa.UUID(), nullable=False),
|
27 |
+
sa.Column('name', sa.String(), nullable=False),
|
28 |
+
sa.Column('description', sa.Text(), nullable=True),
|
29 |
+
sa.Column('key_prefix', sa.String(length=8), nullable=False),
|
30 |
+
sa.Column('hashed_key', sa.String(), nullable=False),
|
31 |
+
sa.Column('scopes', postgresql.ARRAY(sa.String()), nullable=True),
|
32 |
+
sa.Column('is_active', sa.Boolean(), nullable=False),
|
33 |
+
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
34 |
+
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
|
35 |
+
sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_api_keys_user_id_users_id')),
|
36 |
+
sa.PrimaryKeyConstraint('id', name=op.f('pk_api_keys'))
|
37 |
+
)
|
38 |
+
with op.batch_alter_table('api_keys', schema=None) as batch_op:
|
39 |
+
batch_op.create_index(batch_op.f('ix_api_keys_hashed_key'), ['hashed_key'], unique=False)
|
40 |
+
batch_op.create_index(batch_op.f('ix_api_keys_key_prefix'), ['key_prefix'], unique=True)
|
41 |
+
batch_op.create_index(batch_op.f('ix_api_keys_user_id'), ['user_id'], unique=False)
|
42 |
+
|
43 |
+
# Removed incorrect drop/alter table commands for other tables
|
44 |
+
# ### end Alembic commands ###
|
45 |
+
|
46 |
+
|
47 |
+
def downgrade() -> None:
|
48 |
+
"""Downgrade schema."""
|
49 |
+
# ### commands auto generated by Alembic - adjusted ###
|
50 |
+
with op.batch_alter_table('api_keys', schema=None) as batch_op:
|
51 |
+
batch_op.drop_index(batch_op.f('ix_api_keys_user_id'))
|
52 |
+
batch_op.drop_index(batch_op.f('ix_api_keys_key_prefix'))
|
53 |
+
batch_op.drop_index(batch_op.f('ix_api_keys_hashed_key'))
|
54 |
+
|
55 |
+
op.drop_table('api_keys')
|
56 |
+
# Removed incorrect create/alter table commands for other tables
|
57 |
+
# ### end Alembic commands ###
|
58 |
+
|
4af13678b83c_add_toolsource_table.cpython-311.pyc
ADDED
Binary file (3.61 kB). View file
|
|
4af13678b83c_add_toolsource_table.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Add ToolSource table
|
2 |
+
|
3 |
+
Revision ID: 4af13678b83c
|
4 |
+
Revises: e2ca2546bf71
|
5 |
+
Create Date: 2025-05-03 18:51:11.601728
|
6 |
+
|
7 |
+
"""
|
8 |
+
from typing import Sequence, Union
|
9 |
+
|
10 |
+
from alembic import op
|
11 |
+
import sqlalchemy as sa
|
12 |
+
from sqlalchemy.dialects import postgresql
|
13 |
+
|
14 |
+
# revision identifiers, used by Alembic.
|
15 |
+
revision: str = '4af13678b83c'
|
16 |
+
down_revision: Union[str, None] = 'e2ca2546bf71'
|
17 |
+
branch_labels: Union[str, Sequence[str], None] = None
|
18 |
+
depends_on: Union[str, Sequence[str], None] = None
|
19 |
+
|
20 |
+
|
21 |
+
def upgrade() -> None:
|
22 |
+
"""Upgrade schema."""
|
23 |
+
# ### commands auto generated by Alembic - please adjust! ###
|
24 |
+
# Manually corrected: Remove incorrect drop commands and add create_table for tool_sources
|
25 |
+
op.create_table('tool_sources',
|
26 |
+
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
27 |
+
sa.Column('github_url', sa.String(), nullable=False),
|
28 |
+
sa.Column('description', sa.Text(), nullable=True),
|
29 |
+
sa.Column('status', sa.String(), nullable=False, server_default='active'), # Match default from model
|
30 |
+
sa.Column('last_checked_at', sa.DateTime(timezone=True), nullable=True),
|
31 |
+
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
32 |
+
sa.PrimaryKeyConstraint('id', name=op.f('pk_tool_sources'))
|
33 |
+
)
|
34 |
+
with op.batch_alter_table('tool_sources', schema=None) as batch_op:
|
35 |
+
batch_op.create_index(batch_op.f('ix_tool_sources_github_url'), ['github_url'], unique=True)
|
36 |
+
batch_op.create_index(batch_op.f('ix_tool_sources_status'), ['status'], unique=False)
|
37 |
+
|
38 |
+
# ### end Alembic commands ###
|
39 |
+
|
40 |
+
|
41 |
+
def downgrade() -> None:
|
42 |
+
"""Downgrade schema."""
|
43 |
+
# ### commands auto generated by Alembic - please adjust! ###
|
44 |
+
with op.batch_alter_table('tool_sources', schema=None) as batch_op:
|
45 |
+
batch_op.drop_index(batch_op.f('ix_tool_sources_status'))
|
46 |
+
batch_op.drop_index(batch_op.f('ix_tool_sources_github_url'))
|
47 |
+
|
48 |
+
op.drop_table('tool_sources')
|
49 |
+
# ### end Alembic commands ###
|
50 |
+
|
ActiveJobsProvider.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
|
4 |
+
|
5 |
+
|
6 |
+
class ActiveJobsProvider(RapidDataProviderBase):
|
7 |
+
def __init__(self):
|
8 |
+
endpoints: Dict[str, EndpointSchema] = {
|
9 |
+
"active_jobs": {
|
10 |
+
"route": "/active-ats-7d",
|
11 |
+
"method": "GET",
|
12 |
+
"name": "Active Jobs Search",
|
13 |
+
"description": "Get active job listings with various filter options.",
|
14 |
+
"payload": {
|
15 |
+
"limit": "Optional. Number of jobs per API call (10-100). Default is 100.",
|
16 |
+
"offset": "Optional. Offset for pagination. Default is 0.",
|
17 |
+
"title_filter": "Optional. Search terms for job title.",
|
18 |
+
"advanced_title_filter": "Optional. Advanced title filter with operators (can't be used with title_filter).",
|
19 |
+
"location_filter": "Optional. Filter by location(s). Use full names like 'United States' not 'US'.",
|
20 |
+
"description_filter": "Optional. Filter on job description content.",
|
21 |
+
"organization_filter": "Optional. Filter by company name(s).",
|
22 |
+
"description_type": "Optional. Return format for description: 'text' or 'html'. Leave empty to exclude descriptions.",
|
23 |
+
"source": "Optional. Filter by ATS source.",
|
24 |
+
"date_filter": "Optional. Filter by posting date (greater than).",
|
25 |
+
"ai_employment_type_filter": "Optional. Filter by employment type (FULL_TIME, PART_TIME, etc).",
|
26 |
+
"ai_work_arrangement_filter": "Optional. Filter by work arrangement (On-site, Hybrid, Remote OK, Remote Solely).",
|
27 |
+
"ai_experience_level_filter": "Optional. Filter by experience level (0-2, 2-5, 5-10, 10+).",
|
28 |
+
"li_organization_slug_filter": "Optional. Filter by LinkedIn company slug.",
|
29 |
+
"li_organization_slug_exclusion_filter": "Optional. Exclude LinkedIn company slugs.",
|
30 |
+
"li_industry_filter": "Optional. Filter by LinkedIn industry.",
|
31 |
+
"li_organization_specialties_filter": "Optional. Filter by LinkedIn company specialties.",
|
32 |
+
"li_organization_description_filter": "Optional. Filter by LinkedIn company description."
|
33 |
+
}
|
34 |
+
}
|
35 |
+
}
|
36 |
+
|
37 |
+
base_url = "https://active-jobs-db.p.rapidapi.com"
|
38 |
+
super().__init__(base_url, endpoints)
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
from dotenv import load_dotenv
|
43 |
+
load_dotenv()
|
44 |
+
tool = ActiveJobsProvider()
|
45 |
+
|
46 |
+
# Example for searching active jobs
|
47 |
+
jobs = tool.call_endpoint(
|
48 |
+
route="active_jobs",
|
49 |
+
payload={
|
50 |
+
"limit": "10",
|
51 |
+
"offset": "0",
|
52 |
+
"title_filter": "\"Data Engineer\"",
|
53 |
+
"location_filter": "\"United States\" OR \"United Kingdom\"",
|
54 |
+
"description_type": "text"
|
55 |
+
}
|
56 |
+
)
|
57 |
+
print("Active Jobs:", jobs)
|
AmazonProvider.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional
|
2 |
+
|
3 |
+
from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
|
4 |
+
|
5 |
+
|
6 |
+
class AmazonProvider(RapidDataProviderBase):
|
7 |
+
def __init__(self):
|
8 |
+
endpoints: Dict[str, EndpointSchema] = {
|
9 |
+
"search": {
|
10 |
+
"route": "/search",
|
11 |
+
"method": "GET",
|
12 |
+
"name": "Amazon Product Search",
|
13 |
+
"description": "Search for products on Amazon with various filters and parameters.",
|
14 |
+
"payload": {
|
15 |
+
"query": "Search query (supports both free-form text queries or a product asin)",
|
16 |
+
"page": "Results page to return (default: 1)",
|
17 |
+
"country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
|
18 |
+
"sort_by": "Return the results in a specific sort order (RELEVANCE, LOWEST_PRICE, HIGHEST_PRICE, REVIEWS, NEWEST, BEST_SELLERS)",
|
19 |
+
"product_condition": "Return products in a specific condition (ALL, NEW, USED, RENEWED, COLLECTIBLE)",
|
20 |
+
"is_prime": "Only return prime products (boolean)",
|
21 |
+
"deals_and_discounts": "Return deals and discounts in a specific condition (NONE, ALL_DISCOUNTS, TODAYS_DEALS)",
|
22 |
+
"category_id": "Find products in a specific category / department (optional)",
|
23 |
+
"category": "Filter by specific numeric Amazon category (optional)",
|
24 |
+
"min_price": "Only return product offers with price greater than a certain value (optional)",
|
25 |
+
"max_price": "Only return product offers with price lower than a certain value (optional)",
|
26 |
+
"brand": "Find products with a specific brand (optional)",
|
27 |
+
"seller_id": "Find products sold by specific seller (optional)",
|
28 |
+
"four_stars_and_up": "Return product listings with ratings of 4 stars & up (optional)",
|
29 |
+
"additional_filters": "Any filters available on the Amazon page but not part of this endpoint's parameters (optional)"
|
30 |
+
}
|
31 |
+
},
|
32 |
+
"product-details": {
|
33 |
+
"route": "/product-details",
|
34 |
+
"method": "GET",
|
35 |
+
"name": "Amazon Product Details",
|
36 |
+
"description": "Get detailed information about specific Amazon products by ASIN.",
|
37 |
+
"payload": {
|
38 |
+
"asin": "Product ASIN for which to get details. Supports batching of up to 10 ASINs in a single request, separated by comma.",
|
39 |
+
"country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
|
40 |
+
"more_info_query": "A query to search and get more info about the product as part of Product Information, Customer Q&As, and Customer Reviews (optional)",
|
41 |
+
"fields": "A comma separated list of product fields to include in the response (field projection). By default all fields are returned. (optional)"
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"products-by-category": {
|
45 |
+
"route": "/products-by-category",
|
46 |
+
"method": "GET",
|
47 |
+
"name": "Amazon Products by Category",
|
48 |
+
"description": "Get products from a specific Amazon category.",
|
49 |
+
"payload": {
|
50 |
+
"category_id": "The Amazon category for which to return results. Multiple category values can be separated by comma.",
|
51 |
+
"page": "Page to return (default: 1)",
|
52 |
+
"country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
|
53 |
+
"sort_by": "Return the results in a specific sort order (RELEVANCE, LOWEST_PRICE, HIGHEST_PRICE, REVIEWS, NEWEST, BEST_SELLERS)",
|
54 |
+
"min_price": "Only return product offers with price greater than a certain value (optional)",
|
55 |
+
"max_price": "Only return product offers with price lower than a certain value (optional)",
|
56 |
+
"product_condition": "Return products in a specific condition (ALL, NEW, USED, RENEWED, COLLECTIBLE)",
|
57 |
+
"brand": "Only return products of a specific brand. Multiple brands can be specified as a comma separated list (optional)",
|
58 |
+
"is_prime": "Only return prime products (boolean)",
|
59 |
+
"deals_and_discounts": "Return deals and discounts in a specific condition (NONE, ALL_DISCOUNTS, TODAYS_DEALS)",
|
60 |
+
"four_stars_and_up": "Return product listings with ratings of 4 stars & up (optional)",
|
61 |
+
"additional_filters": "Any filters available on the Amazon page but not part of this endpoint's parameters (optional)"
|
62 |
+
}
|
63 |
+
},
|
64 |
+
"product-reviews": {
|
65 |
+
"route": "/product-reviews",
|
66 |
+
"method": "GET",
|
67 |
+
"name": "Amazon Product Reviews",
|
68 |
+
"description": "Get customer reviews for a specific Amazon product by ASIN.",
|
69 |
+
"payload": {
|
70 |
+
"asin": "Product asin for which to get reviews.",
|
71 |
+
"country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
|
72 |
+
"page": "Results page to return (default: 1)",
|
73 |
+
"sort_by": "Return reviews in a specific sort order (TOP_REVIEWS, MOST_RECENT)",
|
74 |
+
"star_rating": "Only return reviews with a specific star rating (ALL, 5_STARS, 4_STARS, 3_STARS, 2_STARS, 1_STARS, POSITIVE, CRITICAL)",
|
75 |
+
"verified_purchases_only": "Only return reviews by reviewers who made a verified purchase (boolean)",
|
76 |
+
"images_or_videos_only": "Only return reviews containing images and / or videos (boolean)",
|
77 |
+
"current_format_only": "Only return reviews of the current format (product variant - e.g. Color) (boolean)"
|
78 |
+
}
|
79 |
+
},
|
80 |
+
"seller-profile": {
|
81 |
+
"route": "/seller-profile",
|
82 |
+
"method": "GET",
|
83 |
+
"name": "Amazon Seller Profile",
|
84 |
+
"description": "Get detailed information about a specific Amazon seller by Seller ID.",
|
85 |
+
"payload": {
|
86 |
+
"seller_id": "The Amazon Seller ID for which to get seller profile details",
|
87 |
+
"country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
|
88 |
+
"fields": "A comma separated list of seller profile fields to include in the response (field projection). By default all fields are returned. (optional)"
|
89 |
+
}
|
90 |
+
},
|
91 |
+
"seller-reviews": {
|
92 |
+
"route": "/seller-reviews",
|
93 |
+
"method": "GET",
|
94 |
+
"name": "Amazon Seller Reviews",
|
95 |
+
"description": "Get customer reviews for a specific Amazon seller by Seller ID.",
|
96 |
+
"payload": {
|
97 |
+
"seller_id": "The Amazon Seller ID for which to get seller reviews",
|
98 |
+
"country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
|
99 |
+
"star_rating": "Only return reviews with a specific star rating or positive / negative sentiment (ALL, 5_STARS, 4_STARS, 3_STARS, 2_STARS, 1_STARS, POSITIVE, CRITICAL)",
|
100 |
+
"page": "The page of seller feedback results to retrieve (default: 1)",
|
101 |
+
"fields": "A comma separated list of seller review fields to include in the response (field projection). By default all fields are returned. (optional)"
|
102 |
+
}
|
103 |
+
}
|
104 |
+
}
|
105 |
+
base_url = "https://real-time-amazon-data.p.rapidapi.com"
|
106 |
+
super().__init__(base_url, endpoints)
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
from dotenv import load_dotenv
|
111 |
+
load_dotenv()
|
112 |
+
tool = AmazonProvider()
|
113 |
+
|
114 |
+
# Example for product search
|
115 |
+
search_result = tool.call_endpoint(
|
116 |
+
route="search",
|
117 |
+
payload={
|
118 |
+
"query": "Phone",
|
119 |
+
"page": 1,
|
120 |
+
"country": "US",
|
121 |
+
"sort_by": "RELEVANCE",
|
122 |
+
"product_condition": "ALL",
|
123 |
+
"is_prime": False,
|
124 |
+
"deals_and_discounts": "NONE"
|
125 |
+
}
|
126 |
+
)
|
127 |
+
print("Search Result:", search_result)
|
128 |
+
|
129 |
+
# Example for product details
|
130 |
+
details_result = tool.call_endpoint(
|
131 |
+
route="product-details",
|
132 |
+
payload={
|
133 |
+
"asin": "B07ZPKBL9V",
|
134 |
+
"country": "US"
|
135 |
+
}
|
136 |
+
)
|
137 |
+
print("Product Details:", details_result)
|
138 |
+
|
139 |
+
# Example for products by category
|
140 |
+
category_result = tool.call_endpoint(
|
141 |
+
route="products-by-category",
|
142 |
+
payload={
|
143 |
+
"category_id": "2478868012",
|
144 |
+
"page": 1,
|
145 |
+
"country": "US",
|
146 |
+
"sort_by": "RELEVANCE",
|
147 |
+
"product_condition": "ALL",
|
148 |
+
"is_prime": False,
|
149 |
+
"deals_and_discounts": "NONE"
|
150 |
+
}
|
151 |
+
)
|
152 |
+
print("Category Products:", category_result)
|
153 |
+
|
154 |
+
# Example for product reviews
|
155 |
+
reviews_result = tool.call_endpoint(
|
156 |
+
route="product-reviews",
|
157 |
+
payload={
|
158 |
+
"asin": "B07ZPKN6YR",
|
159 |
+
"country": "US",
|
160 |
+
"page": 1,
|
161 |
+
"sort_by": "TOP_REVIEWS",
|
162 |
+
"star_rating": "ALL",
|
163 |
+
"verified_purchases_only": False,
|
164 |
+
"images_or_videos_only": False,
|
165 |
+
"current_format_only": False
|
166 |
+
}
|
167 |
+
)
|
168 |
+
print("Product Reviews:", reviews_result)
|
169 |
+
|
170 |
+
# Example for seller profile
|
171 |
+
seller_result = tool.call_endpoint(
|
172 |
+
route="seller-profile",
|
173 |
+
payload={
|
174 |
+
"seller_id": "A02211013Q5HP3OMSZC7W",
|
175 |
+
"country": "US"
|
176 |
+
}
|
177 |
+
)
|
178 |
+
print("Seller Profile:", seller_result)
|
179 |
+
|
180 |
+
# Example for seller reviews
|
181 |
+
seller_reviews_result = tool.call_endpoint(
|
182 |
+
route="seller-reviews",
|
183 |
+
payload={
|
184 |
+
"seller_id": "A02211013Q5HP3OMSZC7W",
|
185 |
+
"country": "US",
|
186 |
+
"star_rating": "ALL",
|
187 |
+
"page": 1
|
188 |
+
}
|
189 |
+
)
|
190 |
+
print("Seller Reviews:", seller_reviews_result)
|
191 |
+
|
ChatInterface.tsx
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// /home/ubuntu/visionos-frontend/src/components/ChatInterface.tsx
|
2 |
+
import React from 'react';
|
3 |
+
|
4 |
+
const ChatInterface: React.FC = () => {
|
5 |
+
return (
|
6 |
+
<div className="flex flex-col h-full border rounded-lg shadow-md">
|
7 |
+
{/* Message Display Area */}
|
8 |
+
<div className="flex-grow p-4 overflow-y-auto bg-gray-50">
|
9 |
+
{/* Placeholder for messages */}
|
10 |
+
<p className="text-gray-500">Chat messages will appear here...</p>
|
11 |
+
</div>
|
12 |
+
|
13 |
+
{/* Input Area */}
|
14 |
+
<div className="p-4 border-t bg-white">
|
15 |
+
<input
|
16 |
+
type="text"
|
17 |
+
placeholder="Type your message..."
|
18 |
+
className="w-full p-2 border rounded"
|
19 |
+
/>
|
20 |
+
{/* Placeholder for Send Button */}
|
21 |
+
<button className="mt-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600">
|
22 |
+
Send
|
23 |
+
</button>
|
24 |
+
</div>
|
25 |
+
</div>
|
26 |
+
);
|
27 |
+
};
|
28 |
+
|
29 |
+
export default ChatInterface;
|
30 |
+
|
Dockerfile
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM node:18-alpine
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
# Copy package files and install dependencies
|
6 |
+
COPY package*.json ./
|
7 |
+
RUN npm install
|
8 |
+
|
9 |
+
# Copy the rest of the application code
|
10 |
+
COPY . .
|
11 |
+
|
12 |
+
# Build the application
|
13 |
+
RUN npm run build
|
14 |
+
|
15 |
+
# Expose the port
|
16 |
+
EXPOSE 3000
|
17 |
+
|
18 |
+
# Start the application
|
19 |
+
CMD ["npm", "start"]
|
Layout.tsx
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// /home/ubuntu/visionos-frontend/src/components/Layout.tsx
|
2 |
+
import React from 'react';
|
3 |
+
import Link from 'next/link'; // Import Link for navigation
|
4 |
+
|
5 |
+
interface LayoutProps {
|
6 |
+
children: React.ReactNode;
|
7 |
+
}
|
8 |
+
|
9 |
+
const Layout: React.FC<LayoutProps> = ({ children }) => {
|
10 |
+
return (
|
11 |
+
<div className="flex flex-col min-h-screen">
|
12 |
+
{/* Header/Navigation */}
|
13 |
+
<header className="bg-gray-800 text-white p-4">
|
14 |
+
<nav className="container mx-auto flex justify-between items-center">
|
15 |
+
<h1 className="text-xl font-bold">
|
16 |
+
<Link href="/">VisionOS UI</Link>
|
17 |
+
</h1>
|
18 |
+
<ul className="flex space-x-4">
|
19 |
+
<li><Link href="/" className="hover:text-gray-300">Chat</Link></li>
|
20 |
+
<li><Link href="/workflow" className="hover:text-gray-300">Workflow</Link></li>
|
21 |
+
<li><Link href="/settings" className="hover:text-gray-300">Settings</Link></li>
|
22 |
+
{/* Add more navigation links here */}
|
23 |
+
</ul>
|
24 |
+
</nav>
|
25 |
+
</header>
|
26 |
+
|
27 |
+
{/* Main Content Area */}
|
28 |
+
<main className="flex-grow p-4 container mx-auto">
|
29 |
+
{children}
|
30 |
+
</main>
|
31 |
+
|
32 |
+
{/* Footer */}
|
33 |
+
<footer className="bg-gray-200 p-4 text-center text-sm text-gray-600">
|
34 |
+
© 2025 VisionOS
|
35 |
+
</footer>
|
36 |
+
</div>
|
37 |
+
);
|
38 |
+
};
|
39 |
+
|
40 |
+
export default Layout;
|
41 |
+
|
LinkedinProvider.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
|
4 |
+
|
5 |
+
|
6 |
+
class LinkedinProvider(RapidDataProviderBase):
|
7 |
+
def __init__(self):
|
8 |
+
endpoints: Dict[str, EndpointSchema] = {
|
9 |
+
"person": {
|
10 |
+
"route": "/person",
|
11 |
+
"method": "POST",
|
12 |
+
"name": "Person Data",
|
13 |
+
"description": "Fetches any Linkedin profiles data including skills, certificates, experiences, qualifications and much more.",
|
14 |
+
"payload": {
|
15 |
+
"link": "LinkedIn Profile URL"
|
16 |
+
}
|
17 |
+
},
|
18 |
+
"person_urn": {
|
19 |
+
"route": "/person_urn",
|
20 |
+
"method": "POST",
|
21 |
+
"name": "Person Data (Using Urn)",
|
22 |
+
"description": "It takes profile urn instead of profile public identifier in input",
|
23 |
+
"payload": {
|
24 |
+
"link": "LinkedIn Profile URL or URN"
|
25 |
+
}
|
26 |
+
},
|
27 |
+
"person_deep": {
|
28 |
+
"route": "/person_deep",
|
29 |
+
"method": "POST",
|
30 |
+
"name": "Person Data (Deep)",
|
31 |
+
"description": "Fetches all experiences, educations, skills, languages, publications... related to a profile.",
|
32 |
+
"payload": {
|
33 |
+
"link": "LinkedIn Profile URL"
|
34 |
+
}
|
35 |
+
},
|
36 |
+
"profile_updates": {
|
37 |
+
"route": "/profile_updates",
|
38 |
+
"method": "GET",
|
39 |
+
"name": "Person Posts (WITH PAGINATION)",
|
40 |
+
"description": "Fetches posts of a linkedin profile alongwith reactions, comments, postLink and reposts data.",
|
41 |
+
"payload": {
|
42 |
+
"profile_url": "LinkedIn Profile URL",
|
43 |
+
"page": "Page number",
|
44 |
+
"reposts": "Include reposts (1 or 0)",
|
45 |
+
"comments": "Include comments (1 or 0)"
|
46 |
+
}
|
47 |
+
},
|
48 |
+
"profile_recent_comments": {
|
49 |
+
"route": "/profile_recent_comments",
|
50 |
+
"method": "POST",
|
51 |
+
"name": "Person Recent Activity (Comments on Posts)",
|
52 |
+
"description": "Fetches 20 most recent comments posted by a linkedin user (per page).",
|
53 |
+
"payload": {
|
54 |
+
"profile_url": "LinkedIn Profile URL",
|
55 |
+
"page": "Page number",
|
56 |
+
"paginationToken": "Token for pagination"
|
57 |
+
}
|
58 |
+
},
|
59 |
+
"comments_from_recent_activity": {
|
60 |
+
"route": "/comments_from_recent_activity",
|
61 |
+
"method": "GET",
|
62 |
+
"name": "Comments from recent activity",
|
63 |
+
"description": "Fetches recent comments posted by a person as per his recent activity tab.",
|
64 |
+
"payload": {
|
65 |
+
"profile_url": "LinkedIn Profile URL",
|
66 |
+
"page": "Page number"
|
67 |
+
}
|
68 |
+
},
|
69 |
+
"person_skills": {
|
70 |
+
"route": "/person_skills",
|
71 |
+
"method": "POST",
|
72 |
+
"name": "Person Skills",
|
73 |
+
"description": "Scraper all skills of a linkedin user",
|
74 |
+
"payload": {
|
75 |
+
"link": "LinkedIn Profile URL"
|
76 |
+
}
|
77 |
+
},
|
78 |
+
"email_to_linkedin_profile": {
|
79 |
+
"route": "/email_to_linkedin_profile",
|
80 |
+
"method": "POST",
|
81 |
+
"name": "Email to LinkedIn Profile",
|
82 |
+
"description": "Finds LinkedIn profile associated with an email address",
|
83 |
+
"payload": {
|
84 |
+
"email": "Email address to search"
|
85 |
+
}
|
86 |
+
},
|
87 |
+
"company": {
|
88 |
+
"route": "/company",
|
89 |
+
"method": "POST",
|
90 |
+
"name": "Company Data",
|
91 |
+
"description": "Fetches LinkedIn company profile data",
|
92 |
+
"payload": {
|
93 |
+
"link": "LinkedIn Company URL"
|
94 |
+
}
|
95 |
+
},
|
96 |
+
"web_domain": {
|
97 |
+
"route": "/web-domain",
|
98 |
+
"method": "POST",
|
99 |
+
"name": "Web Domain to Company",
|
100 |
+
"description": "Fetches LinkedIn company profile data from a web domain",
|
101 |
+
"payload": {
|
102 |
+
"link": "Website domain (e.g., huzzle.app)"
|
103 |
+
}
|
104 |
+
},
|
105 |
+
"similar_profiles": {
|
106 |
+
"route": "/similar_profiles",
|
107 |
+
"method": "GET",
|
108 |
+
"name": "Similar Profiles",
|
109 |
+
"description": "Fetches profiles similar to a given LinkedIn profile",
|
110 |
+
"payload": {
|
111 |
+
"profileUrl": "LinkedIn Profile URL"
|
112 |
+
}
|
113 |
+
},
|
114 |
+
"company_jobs": {
|
115 |
+
"route": "/company_jobs",
|
116 |
+
"method": "POST",
|
117 |
+
"name": "Company Jobs",
|
118 |
+
"description": "Fetches job listings from a LinkedIn company page",
|
119 |
+
"payload": {
|
120 |
+
"company_url": "LinkedIn Company URL",
|
121 |
+
"count": "Number of job listings to fetch"
|
122 |
+
}
|
123 |
+
},
|
124 |
+
"company_updates": {
|
125 |
+
"route": "/company_updates",
|
126 |
+
"method": "GET",
|
127 |
+
"name": "Company Posts",
|
128 |
+
"description": "Fetches posts from a LinkedIn company page",
|
129 |
+
"payload": {
|
130 |
+
"company_url": "LinkedIn Company URL",
|
131 |
+
"page": "Page number",
|
132 |
+
"reposts": "Include reposts (0, 1, or 2)",
|
133 |
+
"comments": "Include comments (0, 1, or 2)"
|
134 |
+
}
|
135 |
+
},
|
136 |
+
"company_employee": {
|
137 |
+
"route": "/company_employee",
|
138 |
+
"method": "GET",
|
139 |
+
"name": "Company Employees",
|
140 |
+
"description": "Fetches employees of a LinkedIn company using company ID",
|
141 |
+
"payload": {
|
142 |
+
"companyId": "LinkedIn Company ID",
|
143 |
+
"page": "Page number"
|
144 |
+
}
|
145 |
+
},
|
146 |
+
"company_updates_post": {
|
147 |
+
"route": "/company_updates",
|
148 |
+
"method": "POST",
|
149 |
+
"name": "Company Posts (POST)",
|
150 |
+
"description": "Fetches posts from a LinkedIn company page with specific count parameters",
|
151 |
+
"payload": {
|
152 |
+
"company_url": "LinkedIn Company URL",
|
153 |
+
"posts": "Number of posts to fetch",
|
154 |
+
"comments": "Number of comments to fetch per post",
|
155 |
+
"reposts": "Number of reposts to fetch"
|
156 |
+
}
|
157 |
+
},
|
158 |
+
"search_posts_with_filters": {
|
159 |
+
"route": "/search_posts_with_filters",
|
160 |
+
"method": "GET",
|
161 |
+
"name": "Search Posts With Filters",
|
162 |
+
"description": "Searches LinkedIn posts with various filtering options",
|
163 |
+
"payload": {
|
164 |
+
"query": "Keywords/Search terms (text you put in LinkedIn search bar)",
|
165 |
+
"page": "Page number (1-100, each page contains 20 results)",
|
166 |
+
"sort_by": "Sort method: 'relevance' (Top match) or 'date_posted' (Latest)",
|
167 |
+
"author_job_title": "Filter by job title of author (e.g., CEO)",
|
168 |
+
"content_type": "Type of content post contains (photos, videos, liveVideos, collaborativeArticles, documents)",
|
169 |
+
"from_member": "URN of person who posted (comma-separated for multiple)",
|
170 |
+
"from_organization": "ID of organization who posted (comma-separated for multiple)",
|
171 |
+
"author_company": "ID of company author works for (comma-separated for multiple)",
|
172 |
+
"author_industry": "URN of industry author is connected with (comma-separated for multiple)",
|
173 |
+
"mentions_member": "URN of person mentioned in post (comma-separated for multiple)",
|
174 |
+
"mentions_organization": "ID of organization mentioned in post (comma-separated for multiple)"
|
175 |
+
}
|
176 |
+
},
|
177 |
+
"search_jobs": {
|
178 |
+
"route": "/search_jobs",
|
179 |
+
"method": "GET",
|
180 |
+
"name": "Search Jobs",
|
181 |
+
"description": "Searches LinkedIn jobs with various filtering options",
|
182 |
+
"payload": {
|
183 |
+
"query": "Job search keywords (e.g., Software developer)",
|
184 |
+
"page": "Page number",
|
185 |
+
"searchLocationId": "Location ID for job search (get from Suggestion location endpoint)",
|
186 |
+
"easyApply": "Filter for easy apply jobs (true or false)",
|
187 |
+
"experience": "Experience level required (1=Internship, 2=Entry level, 3=Associate, 4=Mid senior, 5=Director, 6=Executive, comma-separated)",
|
188 |
+
"jobType": "Job type (F=Full time, P=Part time, C=Contract, T=Temporary, V=Volunteer, I=Internship, O=Other, comma-separated)",
|
189 |
+
"postedAgo": "Time jobs were posted in seconds (e.g., 3600 for past hour)",
|
190 |
+
"workplaceType": "Workplace type (1=On-Site, 2=Remote, 3=Hybrid, comma-separated)",
|
191 |
+
"sortBy": "Sort method (DD=most recent, R=most relevant)",
|
192 |
+
"companyIdsList": "List of company IDs, comma-separated",
|
193 |
+
"industryIdsList": "List of industry IDs, comma-separated",
|
194 |
+
"functionIdsList": "List of function IDs, comma-separated",
|
195 |
+
"titleIdsList": "List of job title IDs, comma-separated",
|
196 |
+
"locationIdsList": "List of location IDs within specified searchLocationId country, comma-separated"
|
197 |
+
}
|
198 |
+
},
|
199 |
+
"search_people_with_filters": {
|
200 |
+
"route": "/search_people_with_filters",
|
201 |
+
"method": "POST",
|
202 |
+
"name": "Search People With Filters",
|
203 |
+
"description": "Searches LinkedIn profiles with detailed filtering options",
|
204 |
+
"payload": {
|
205 |
+
"keyword": "General search keyword",
|
206 |
+
"page": "Page number",
|
207 |
+
"title_free_text": "Job title to filter by (e.g., CEO)",
|
208 |
+
"company_free_text": "Company name to filter by",
|
209 |
+
"first_name": "First name of person",
|
210 |
+
"last_name": "Last name of person",
|
211 |
+
"current_company_list": "List of current companies (comma-separated IDs)",
|
212 |
+
"past_company_list": "List of past companies (comma-separated IDs)",
|
213 |
+
"location_list": "List of locations (comma-separated IDs)",
|
214 |
+
"language_list": "List of languages (comma-separated)",
|
215 |
+
"service_catagory_list": "List of service categories (comma-separated)",
|
216 |
+
"school_free_text": "School name to filter by",
|
217 |
+
"industry_list": "List of industries (comma-separated IDs)",
|
218 |
+
"school_list": "List of schools (comma-separated IDs)"
|
219 |
+
}
|
220 |
+
},
|
221 |
+
"search_company_with_filters": {
|
222 |
+
"route": "/search_company_with_filters",
|
223 |
+
"method": "POST",
|
224 |
+
"name": "Search Company With Filters",
|
225 |
+
"description": "Searches LinkedIn companies with detailed filtering options",
|
226 |
+
"payload": {
|
227 |
+
"keyword": "General search keyword",
|
228 |
+
"page": "Page number",
|
229 |
+
"company_size_list": "List of company sizes (comma-separated, e.g., A,D)",
|
230 |
+
"hasJobs": "Filter companies with jobs (true or false)",
|
231 |
+
"location_list": "List of location IDs (comma-separated)",
|
232 |
+
"industry_list": "List of industry IDs (comma-separated)"
|
233 |
+
}
|
234 |
+
}
|
235 |
+
}
|
236 |
+
base_url = "https://linkedin-data-scraper.p.rapidapi.com"
|
237 |
+
super().__init__(base_url, endpoints)
|
238 |
+
|
239 |
+
|
240 |
+
if __name__ == "__main__":
|
241 |
+
from dotenv import load_dotenv
|
242 |
+
load_dotenv()
|
243 |
+
tool = LinkedinProvider()
|
244 |
+
|
245 |
+
result = tool.call_endpoint(
|
246 |
+
route="comments_from_recent_activity",
|
247 |
+
payload={"profile_url": "https://www.linkedin.com/in/adamcohenhillel/", "page": 1}
|
248 |
+
)
|
249 |
+
print(result)
|
250 |
+
|
MANIFEST.in
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Include all Python files in agentpress directory
|
2 |
+
recursive-include agentpress *.py
|
3 |
+
|
4 |
+
# Include example files
|
5 |
+
recursive-include agentpress/examples *
|
6 |
+
|
7 |
+
# Include any other necessary files
|
8 |
+
include LICENSE
|
9 |
+
include README.md
|
10 |
+
include pyproject.toml
|
11 |
+
|
12 |
+
# Exclude unnecessary files
|
13 |
+
global-exclude *.pyc
|
14 |
+
global-exclude __pycache__
|
15 |
+
global-exclude .DS_Store
|
16 |
+
global-exclude *.pyo
|
17 |
+
global-exclude *.pyd
|
README
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Generic single-database configuration.
|
README.md
CHANGED
@@ -1,10 +1,36 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
|
2 |
+
|
3 |
+
## Getting Started
|
4 |
+
|
5 |
+
First, run the development server:
|
6 |
+
|
7 |
+
```bash
|
8 |
+
npm run dev
|
9 |
+
# or
|
10 |
+
yarn dev
|
11 |
+
# or
|
12 |
+
pnpm dev
|
13 |
+
# or
|
14 |
+
bun dev
|
15 |
+
```
|
16 |
+
|
17 |
+
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
|
18 |
+
|
19 |
+
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
|
20 |
+
|
21 |
+
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
|
22 |
+
|
23 |
+
## Learn More
|
24 |
+
|
25 |
+
To learn more about Next.js, take a look at the following resources:
|
26 |
+
|
27 |
+
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
|
28 |
+
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
|
29 |
+
|
30 |
+
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
|
31 |
+
|
32 |
+
## Deploy on Vercel
|
33 |
+
|
34 |
+
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
|
35 |
+
|
36 |
+
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
|
RapidDataProviderBase.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
from typing import Dict, Any, Optional, TypedDict, Literal
|
4 |
+
|
5 |
+
|
6 |
+
class EndpointSchema(TypedDict):
|
7 |
+
route: str
|
8 |
+
method: Literal['GET', 'POST']
|
9 |
+
name: str
|
10 |
+
description: str
|
11 |
+
payload: Dict[str, Any]
|
12 |
+
|
13 |
+
|
14 |
+
class RapidDataProviderBase:
|
15 |
+
def __init__(self, base_url: str, endpoints: Dict[str, EndpointSchema]):
|
16 |
+
self.base_url = base_url
|
17 |
+
self.endpoints = endpoints
|
18 |
+
|
19 |
+
def get_endpoints(self):
|
20 |
+
return self.endpoints
|
21 |
+
|
22 |
+
def call_endpoint(
|
23 |
+
self,
|
24 |
+
route: str,
|
25 |
+
payload: Optional[Dict[str, Any]] = None
|
26 |
+
):
|
27 |
+
"""
|
28 |
+
Call an API endpoint with the given parameters and data.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
endpoint (EndpointSchema): The endpoint configuration dictionary
|
32 |
+
params (dict, optional): Query parameters for GET requests
|
33 |
+
payload (dict, optional): JSON payload for POST requests
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
dict: The JSON response from the API
|
37 |
+
"""
|
38 |
+
if route.startswith("/"):
|
39 |
+
route = route[1:]
|
40 |
+
|
41 |
+
endpoint = self.endpoints.get(route)
|
42 |
+
if not endpoint:
|
43 |
+
raise ValueError(f"Endpoint {route} not found")
|
44 |
+
|
45 |
+
url = f"{self.base_url}{endpoint['route']}"
|
46 |
+
|
47 |
+
headers = {
|
48 |
+
"x-rapidapi-key": os.getenv("RAPID_API_KEY"),
|
49 |
+
"x-rapidapi-host": url.split("//")[1].split("/")[0],
|
50 |
+
"Content-Type": "application/json"
|
51 |
+
}
|
52 |
+
|
53 |
+
method = endpoint.get('method', 'GET').upper()
|
54 |
+
|
55 |
+
if method == 'GET':
|
56 |
+
response = requests.get(url, params=payload, headers=headers)
|
57 |
+
elif method == 'POST':
|
58 |
+
response = requests.post(url, json=payload, headers=headers)
|
59 |
+
else:
|
60 |
+
raise ValueError(f"Unsupported HTTP method: {method}")
|
61 |
+
return response.json()
|
SettingsPanel.tsx
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// /home/ubuntu/visionos-frontend/src/components/SettingsPanel.tsx
|
2 |
+
import React from 'react';
|
3 |
+
|
4 |
+
const SettingsPanel: React.FC = () => {
|
5 |
+
return (
|
6 |
+
<div className="p-4 border rounded-lg shadow-md bg-white">
|
7 |
+
<h3 className="text-lg font-semibold mb-4">System Settings</h3>
|
8 |
+
{/* Placeholder for settings form */}
|
9 |
+
<div className="space-y-4">
|
10 |
+
<div>
|
11 |
+
<label htmlFor="setting1" className="block text-sm font-medium text-gray-700">Example Setting 1</label>
|
12 |
+
<input type="text" id="setting1" className="mt-1 block w-full p-2 border border-gray-300 rounded-md shadow-sm" placeholder="Value" />
|
13 |
+
</div>
|
14 |
+
<div>
|
15 |
+
<label htmlFor="setting2" className="block text-sm font-medium text-gray-700">Example Setting 2</label>
|
16 |
+
<select id="setting2" className="mt-1 block w-full p-2 border border-gray-300 rounded-md shadow-sm">
|
17 |
+
<option>Option A</option>
|
18 |
+
<option>Option B</option>
|
19 |
+
</select>
|
20 |
+
</div>
|
21 |
+
{/* Add more settings fields as needed */}
|
22 |
+
<button className="mt-4 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600">
|
23 |
+
Save Settings
|
24 |
+
</button>
|
25 |
+
</div>
|
26 |
+
</div>
|
27 |
+
);
|
28 |
+
};
|
29 |
+
|
30 |
+
export default SettingsPanel;
|
31 |
+
|
TwitterProvider.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
|
4 |
+
|
5 |
+
|
6 |
+
class TwitterProvider(RapidDataProviderBase):
|
7 |
+
def __init__(self):
|
8 |
+
endpoints: Dict[str, EndpointSchema] = {
|
9 |
+
"user_info": {
|
10 |
+
"route": "/screenname.php",
|
11 |
+
"method": "GET",
|
12 |
+
"name": "Twitter User Info",
|
13 |
+
"description": "Get information about a Twitter user by screenname or user ID.",
|
14 |
+
"payload": {
|
15 |
+
"screenname": "Twitter username without the @ symbol",
|
16 |
+
"rest_id": "Optional Twitter user's ID. If provided, overwrites screenname parameter."
|
17 |
+
}
|
18 |
+
},
|
19 |
+
"timeline": {
|
20 |
+
"route": "/timeline.php",
|
21 |
+
"method": "GET",
|
22 |
+
"name": "User Timeline",
|
23 |
+
"description": "Get tweets from a user's timeline.",
|
24 |
+
"payload": {
|
25 |
+
"screenname": "Twitter username without the @ symbol",
|
26 |
+
"rest_id": "Optional parameter that overwrites the screenname",
|
27 |
+
"cursor": "Optional pagination cursor"
|
28 |
+
}
|
29 |
+
},
|
30 |
+
"following": {
|
31 |
+
"route": "/following.php",
|
32 |
+
"method": "GET",
|
33 |
+
"name": "User Following",
|
34 |
+
"description": "Get users that a specific user follows.",
|
35 |
+
"payload": {
|
36 |
+
"screenname": "Twitter username without the @ symbol",
|
37 |
+
"rest_id": "Optional parameter that overwrites the screenname",
|
38 |
+
"cursor": "Optional pagination cursor"
|
39 |
+
}
|
40 |
+
},
|
41 |
+
"followers": {
|
42 |
+
"route": "/followers.php",
|
43 |
+
"method": "GET",
|
44 |
+
"name": "User Followers",
|
45 |
+
"description": "Get followers of a specific user.",
|
46 |
+
"payload": {
|
47 |
+
"screenname": "Twitter username without the @ symbol",
|
48 |
+
"cursor": "Optional pagination cursor"
|
49 |
+
}
|
50 |
+
},
|
51 |
+
"search": {
|
52 |
+
"route": "/search.php",
|
53 |
+
"method": "GET",
|
54 |
+
"name": "Twitter Search",
|
55 |
+
"description": "Search for tweets with a specific query.",
|
56 |
+
"payload": {
|
57 |
+
"query": "Search query string",
|
58 |
+
"cursor": "Optional pagination cursor",
|
59 |
+
"search_type": "Optional search type (e.g. 'Top')"
|
60 |
+
}
|
61 |
+
},
|
62 |
+
"replies": {
|
63 |
+
"route": "/replies.php",
|
64 |
+
"method": "GET",
|
65 |
+
"name": "User Replies",
|
66 |
+
"description": "Get replies made by a user.",
|
67 |
+
"payload": {
|
68 |
+
"screenname": "Twitter username without the @ symbol",
|
69 |
+
"cursor": "Optional pagination cursor"
|
70 |
+
}
|
71 |
+
},
|
72 |
+
"check_retweet": {
|
73 |
+
"route": "/checkretweet.php",
|
74 |
+
"method": "GET",
|
75 |
+
"name": "Check Retweet",
|
76 |
+
"description": "Check if a user has retweeted a specific tweet.",
|
77 |
+
"payload": {
|
78 |
+
"screenname": "Twitter username without the @ symbol",
|
79 |
+
"tweet_id": "ID of the tweet to check"
|
80 |
+
}
|
81 |
+
},
|
82 |
+
"tweet": {
|
83 |
+
"route": "/tweet.php",
|
84 |
+
"method": "GET",
|
85 |
+
"name": "Get Tweet",
|
86 |
+
"description": "Get details of a specific tweet by ID.",
|
87 |
+
"payload": {
|
88 |
+
"id": "ID of the tweet"
|
89 |
+
}
|
90 |
+
},
|
91 |
+
"tweet_thread": {
|
92 |
+
"route": "/tweet_thread.php",
|
93 |
+
"method": "GET",
|
94 |
+
"name": "Get Tweet Thread",
|
95 |
+
"description": "Get a thread of tweets starting from a specific tweet ID.",
|
96 |
+
"payload": {
|
97 |
+
"id": "ID of the tweet",
|
98 |
+
"cursor": "Optional pagination cursor"
|
99 |
+
}
|
100 |
+
},
|
101 |
+
"retweets": {
|
102 |
+
"route": "/retweets.php",
|
103 |
+
"method": "GET",
|
104 |
+
"name": "Get Retweets",
|
105 |
+
"description": "Get users who retweeted a specific tweet.",
|
106 |
+
"payload": {
|
107 |
+
"id": "ID of the tweet",
|
108 |
+
"cursor": "Optional pagination cursor"
|
109 |
+
}
|
110 |
+
},
|
111 |
+
"latest_replies": {
|
112 |
+
"route": "/latest_replies.php",
|
113 |
+
"method": "GET",
|
114 |
+
"name": "Get Latest Replies",
|
115 |
+
"description": "Get the latest replies to a specific tweet.",
|
116 |
+
"payload": {
|
117 |
+
"id": "ID of the tweet",
|
118 |
+
"cursor": "Optional pagination cursor"
|
119 |
+
}
|
120 |
+
}
|
121 |
+
}
|
122 |
+
base_url = "https://twitter-api45.p.rapidapi.com"
|
123 |
+
super().__init__(base_url, endpoints)
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
from dotenv import load_dotenv
|
128 |
+
load_dotenv()
|
129 |
+
tool = TwitterProvider()
|
130 |
+
|
131 |
+
# Example for getting user info
|
132 |
+
user_info = tool.call_endpoint(
|
133 |
+
route="user_info",
|
134 |
+
payload={
|
135 |
+
"screenname": "elonmusk",
|
136 |
+
# "rest_id": "44196397" # Optional, uncomment to use user ID instead of screenname
|
137 |
+
}
|
138 |
+
)
|
139 |
+
print("User Info:", user_info)
|
140 |
+
|
141 |
+
# Example for getting user timeline
|
142 |
+
timeline = tool.call_endpoint(
|
143 |
+
route="timeline",
|
144 |
+
payload={
|
145 |
+
"screenname": "elonmusk",
|
146 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
147 |
+
}
|
148 |
+
)
|
149 |
+
print("Timeline:", timeline)
|
150 |
+
|
151 |
+
# Example for getting user following
|
152 |
+
following = tool.call_endpoint(
|
153 |
+
route="following",
|
154 |
+
payload={
|
155 |
+
"screenname": "elonmusk",
|
156 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
157 |
+
}
|
158 |
+
)
|
159 |
+
print("Following:", following)
|
160 |
+
|
161 |
+
# Example for getting user followers
|
162 |
+
followers = tool.call_endpoint(
|
163 |
+
route="followers",
|
164 |
+
payload={
|
165 |
+
"screenname": "elonmusk",
|
166 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
167 |
+
}
|
168 |
+
)
|
169 |
+
print("Followers:", followers)
|
170 |
+
|
171 |
+
# Example for searching tweets
|
172 |
+
search_results = tool.call_endpoint(
|
173 |
+
route="search",
|
174 |
+
payload={
|
175 |
+
"query": "cybertruck",
|
176 |
+
"search_type": "Top" # Optional, defaults to Top
|
177 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
178 |
+
}
|
179 |
+
)
|
180 |
+
print("Search Results:", search_results)
|
181 |
+
|
182 |
+
# Example for getting user replies
|
183 |
+
replies = tool.call_endpoint(
|
184 |
+
route="replies",
|
185 |
+
payload={
|
186 |
+
"screenname": "elonmusk",
|
187 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
188 |
+
}
|
189 |
+
)
|
190 |
+
print("Replies:", replies)
|
191 |
+
|
192 |
+
# Example for checking if user retweeted a tweet
|
193 |
+
check_retweet = tool.call_endpoint(
|
194 |
+
route="check_retweet",
|
195 |
+
payload={
|
196 |
+
"screenname": "elonmusk",
|
197 |
+
"tweet_id": "1671370010743263233"
|
198 |
+
}
|
199 |
+
)
|
200 |
+
print("Check Retweet:", check_retweet)
|
201 |
+
|
202 |
+
# Example for getting tweet details
|
203 |
+
tweet = tool.call_endpoint(
|
204 |
+
route="tweet",
|
205 |
+
payload={
|
206 |
+
"id": "1671370010743263233"
|
207 |
+
}
|
208 |
+
)
|
209 |
+
print("Tweet:", tweet)
|
210 |
+
|
211 |
+
# Example for getting a tweet thread
|
212 |
+
tweet_thread = tool.call_endpoint(
|
213 |
+
route="tweet_thread",
|
214 |
+
payload={
|
215 |
+
"id": "1738106896777699464",
|
216 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
217 |
+
}
|
218 |
+
)
|
219 |
+
print("Tweet Thread:", tweet_thread)
|
220 |
+
|
221 |
+
# Example for getting retweets of a tweet
|
222 |
+
retweets = tool.call_endpoint(
|
223 |
+
route="retweets",
|
224 |
+
payload={
|
225 |
+
"id": "1700199139470942473",
|
226 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
227 |
+
}
|
228 |
+
)
|
229 |
+
print("Retweets:", retweets)
|
230 |
+
|
231 |
+
# Example for getting latest replies to a tweet
|
232 |
+
latest_replies = tool.call_endpoint(
|
233 |
+
route="latest_replies",
|
234 |
+
payload={
|
235 |
+
"id": "1738106896777699464",
|
236 |
+
# "cursor": "optional-cursor-value" # Optional for pagination
|
237 |
+
}
|
238 |
+
)
|
239 |
+
print("Latest Replies:", latest_replies)
|
240 |
+
|
WorkflowEditor.tsx
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// /home/ubuntu/visionos-frontend/src/components/WorkflowEditor.tsx
|
2 |
+
import React, { useCallback } from 'react';
|
3 |
+
import ReactFlow, {
|
4 |
+
MiniMap,
|
5 |
+
Controls,
|
6 |
+
Background,
|
7 |
+
useNodesState,
|
8 |
+
useEdgesState,
|
9 |
+
addEdge,
|
10 |
+
Connection,
|
11 |
+
Edge,
|
12 |
+
Node,
|
13 |
+
} from 'reactflow';
|
14 |
+
|
15 |
+
import 'reactflow/dist/style.css';
|
16 |
+
|
17 |
+
// Initial nodes and edges (example)
|
18 |
+
const initialNodes: Node[] = [
|
19 |
+
{ id: '1', position: { x: 0, y: 0 }, data: { label: 'Start Node' }, type: 'input' },
|
20 |
+
{ id: '2', position: { x: 0, y: 100 }, data: { label: 'Agent Task' } },
|
21 |
+
];
|
22 |
+
const initialEdges: Edge[] = [{ id: 'e1-2', source: '1', target: '2' }];
|
23 |
+
|
24 |
+
const WorkflowEditor: React.FC = () => {
|
25 |
+
const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
|
26 |
+
const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);
|
27 |
+
|
28 |
+
const onConnect = useCallback(
|
29 |
+
(params: Edge | Connection) => setEdges((eds) => addEdge(params, eds)),
|
30 |
+
[setEdges],
|
31 |
+
);
|
32 |
+
|
33 |
+
return (
|
34 |
+
<div style={{ width: '100%', height: '100%' }} className="border rounded-lg shadow-md">
|
35 |
+
<ReactFlow
|
36 |
+
nodes={nodes}
|
37 |
+
edges={edges}
|
38 |
+
onNodesChange={onNodesChange}
|
39 |
+
onEdgesChange={onEdgesChange}
|
40 |
+
onConnect={onConnect}
|
41 |
+
fitView
|
42 |
+
>
|
43 |
+
<Controls />
|
44 |
+
<MiniMap />
|
45 |
+
<Background variant="dots" gap={12} size={1} />
|
46 |
+
</ReactFlow>
|
47 |
+
</div>
|
48 |
+
);
|
49 |
+
};
|
50 |
+
|
51 |
+
export default WorkflowEditor;
|
52 |
+
|
YahooFinanceProvider.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
|
4 |
+
|
5 |
+
|
6 |
+
class YahooFinanceProvider(RapidDataProviderBase):
|
7 |
+
def __init__(self):
|
8 |
+
endpoints: Dict[str, EndpointSchema] = {
|
9 |
+
"get_tickers": {
|
10 |
+
"route": "/v2/markets/tickers",
|
11 |
+
"method": "GET",
|
12 |
+
"name": "Yahoo Finance Tickers",
|
13 |
+
"description": "Get financial tickers from Yahoo Finance with various filters and parameters.",
|
14 |
+
"payload": {
|
15 |
+
"page": "Page number for pagination (optional, default: 1)",
|
16 |
+
"type": "Asset class type (required): STOCKS, ETF, MUTUALFUNDS, or FUTURES",
|
17 |
+
}
|
18 |
+
},
|
19 |
+
"search": {
|
20 |
+
"route": "/v1/markets/search",
|
21 |
+
"method": "GET",
|
22 |
+
"name": "Yahoo Finance Search",
|
23 |
+
"description": "Search for financial instruments on Yahoo Finance",
|
24 |
+
"payload": {
|
25 |
+
"search": "Search term (required)",
|
26 |
+
}
|
27 |
+
},
|
28 |
+
"get_news": {
|
29 |
+
"route": "/v2/markets/news",
|
30 |
+
"method": "GET",
|
31 |
+
"name": "Yahoo Finance News",
|
32 |
+
"description": "Get news related to specific tickers from Yahoo Finance",
|
33 |
+
"payload": {
|
34 |
+
"tickers": "Stock symbol (optional, e.g., AAPL)",
|
35 |
+
"type": "News type (optional): ALL, VIDEO, or PRESS_RELEASE",
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"get_stock_module": {
|
39 |
+
"route": "/v1/markets/stock/modules",
|
40 |
+
"method": "GET",
|
41 |
+
"name": "Yahoo Finance Stock Module",
|
42 |
+
"description": "Get detailed information about a specific stock module",
|
43 |
+
"payload": {
|
44 |
+
"ticker": "Company ticker symbol (required, e.g., AAPL)",
|
45 |
+
"module": "Module to retrieve (required): asset-profile, financial-data, earnings, etc.",
|
46 |
+
}
|
47 |
+
},
|
48 |
+
"get_sma": {
|
49 |
+
"route": "/v1/markets/indicators/sma",
|
50 |
+
"method": "GET",
|
51 |
+
"name": "Yahoo Finance SMA Indicator",
|
52 |
+
"description": "Get Simple Moving Average (SMA) indicator data for a stock",
|
53 |
+
"payload": {
|
54 |
+
"symbol": "Stock symbol (required, e.g., AAPL)",
|
55 |
+
"interval": "Time interval (required): 5m, 15m, 30m, 1h, 1d, 1wk, 1mo, 3mo",
|
56 |
+
"series_type": "Series type (required): open, close, high, low",
|
57 |
+
"time_period": "Number of data points used for calculation (required)",
|
58 |
+
"limit": "Limit the number of results (optional, default: 50)",
|
59 |
+
}
|
60 |
+
},
|
61 |
+
"get_rsi": {
|
62 |
+
"route": "/v1/markets/indicators/rsi",
|
63 |
+
"method": "GET",
|
64 |
+
"name": "Yahoo Finance RSI Indicator",
|
65 |
+
"description": "Get Relative Strength Index (RSI) indicator data for a stock",
|
66 |
+
"payload": {
|
67 |
+
"symbol": "Stock symbol (required, e.g., AAPL)",
|
68 |
+
"interval": "Time interval (required): 5m, 15m, 30m, 1h, 1d, 1wk, 1mo, 3mo",
|
69 |
+
"series_type": "Series type (required): open, close, high, low",
|
70 |
+
"time_period": "Number of data points used for calculation (required)",
|
71 |
+
"limit": "Limit the number of results (optional, default: 50)",
|
72 |
+
}
|
73 |
+
},
|
74 |
+
"get_earnings_calendar": {
|
75 |
+
"route": "/v1/markets/calendar/earnings",
|
76 |
+
"method": "GET",
|
77 |
+
"name": "Yahoo Finance Earnings Calendar",
|
78 |
+
"description": "Get earnings calendar data for a specific date",
|
79 |
+
"payload": {
|
80 |
+
"date": "Calendar date in yyyy-mm-dd format (optional, e.g., 2023-11-30)",
|
81 |
+
}
|
82 |
+
},
|
83 |
+
"get_insider_trades": {
|
84 |
+
"route": "/v1/markets/insider-trades",
|
85 |
+
"method": "GET",
|
86 |
+
"name": "Yahoo Finance Insider Trades",
|
87 |
+
"description": "Get recent insider trading activity",
|
88 |
+
"payload": {}
|
89 |
+
},
|
90 |
+
}
|
91 |
+
base_url = "https://yahoo-finance15.p.rapidapi.com/api"
|
92 |
+
super().__init__(base_url, endpoints)
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
from dotenv import load_dotenv
|
97 |
+
load_dotenv()
|
98 |
+
tool = YahooFinanceProvider()
|
99 |
+
|
100 |
+
# Example for getting stock tickers
|
101 |
+
tickers_result = tool.call_endpoint(
|
102 |
+
route="get_tickers",
|
103 |
+
payload={
|
104 |
+
"page": 1,
|
105 |
+
"type": "STOCKS"
|
106 |
+
}
|
107 |
+
)
|
108 |
+
print("Tickers Result:", tickers_result)
|
109 |
+
|
110 |
+
# Example for searching financial instruments
|
111 |
+
search_result = tool.call_endpoint(
|
112 |
+
route="search",
|
113 |
+
payload={
|
114 |
+
"search": "AA"
|
115 |
+
}
|
116 |
+
)
|
117 |
+
print("Search Result:", search_result)
|
118 |
+
|
119 |
+
# Example for getting financial news
|
120 |
+
news_result = tool.call_endpoint(
|
121 |
+
route="get_news",
|
122 |
+
payload={
|
123 |
+
"tickers": "AAPL",
|
124 |
+
"type": "ALL"
|
125 |
+
}
|
126 |
+
)
|
127 |
+
print("News Result:", news_result)
|
128 |
+
|
129 |
+
# Example for getting stock asset profile module
|
130 |
+
stock_module_result = tool.call_endpoint(
|
131 |
+
route="get_stock_module",
|
132 |
+
payload={
|
133 |
+
"ticker": "AAPL",
|
134 |
+
"module": "asset-profile"
|
135 |
+
}
|
136 |
+
)
|
137 |
+
print("Asset Profile Result:", stock_module_result)
|
138 |
+
|
139 |
+
# Example for getting financial data module
|
140 |
+
financial_data_result = tool.call_endpoint(
|
141 |
+
route="get_stock_module",
|
142 |
+
payload={
|
143 |
+
"ticker": "AAPL",
|
144 |
+
"module": "financial-data"
|
145 |
+
}
|
146 |
+
)
|
147 |
+
print("Financial Data Result:", financial_data_result)
|
148 |
+
|
149 |
+
# Example for getting SMA indicator data
|
150 |
+
sma_result = tool.call_endpoint(
|
151 |
+
route="get_sma",
|
152 |
+
payload={
|
153 |
+
"symbol": "AAPL",
|
154 |
+
"interval": "5m",
|
155 |
+
"series_type": "close",
|
156 |
+
"time_period": "50",
|
157 |
+
"limit": "50"
|
158 |
+
}
|
159 |
+
)
|
160 |
+
print("SMA Result:", sma_result)
|
161 |
+
|
162 |
+
# Example for getting RSI indicator data
|
163 |
+
rsi_result = tool.call_endpoint(
|
164 |
+
route="get_rsi",
|
165 |
+
payload={
|
166 |
+
"symbol": "AAPL",
|
167 |
+
"interval": "5m",
|
168 |
+
"series_type": "close",
|
169 |
+
"time_period": "50",
|
170 |
+
"limit": "50"
|
171 |
+
}
|
172 |
+
)
|
173 |
+
print("RSI Result:", rsi_result)
|
174 |
+
|
175 |
+
# Example for getting earnings calendar data
|
176 |
+
earnings_calendar_result = tool.call_endpoint(
|
177 |
+
route="get_earnings_calendar",
|
178 |
+
payload={
|
179 |
+
"date": "2023-11-30"
|
180 |
+
}
|
181 |
+
)
|
182 |
+
print("Earnings Calendar Result:", earnings_calendar_result)
|
183 |
+
|
184 |
+
# Example for getting insider trades
|
185 |
+
insider_trades_result = tool.call_endpoint(
|
186 |
+
route="get_insider_trades",
|
187 |
+
payload={}
|
188 |
+
)
|
189 |
+
print("Insider Trades Result:", insider_trades_result)
|
190 |
+
|
ZillowProvider.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
class ZillowProvider(RapidDataProviderBase):
|
10 |
+
def __init__(self):
|
11 |
+
endpoints: Dict[str, EndpointSchema] = {
|
12 |
+
"search": {
|
13 |
+
"route": "/search",
|
14 |
+
"method": "GET",
|
15 |
+
"name": "Zillow Property Search",
|
16 |
+
"description": "Search for properties by neighborhood, city, or ZIP code with various filters.",
|
17 |
+
"payload": {
|
18 |
+
"location": "Location can be an address, neighborhood, city, or ZIP code (required)",
|
19 |
+
"page": "Page number for pagination (optional, default: 0)",
|
20 |
+
"output": "Output format: json, csv, xlsx (optional, default: json)",
|
21 |
+
"status": "Status of properties: forSale, forRent, recentlySold (optional, default: forSale)",
|
22 |
+
"sortSelection": "Sorting criteria (optional, default: priorityscore)",
|
23 |
+
"listing_type": "Listing type: by_agent, by_owner_other (optional, default: by_agent)",
|
24 |
+
"doz": "Days on Zillow: any, 1, 7, 14, 30, 90, 6m, 12m, 24m, 36m (optional, default: any)",
|
25 |
+
"price_min": "Minimum price (optional)",
|
26 |
+
"price_max": "Maximum price (optional)",
|
27 |
+
"sqft_min": "Minimum square footage (optional)",
|
28 |
+
"sqft_max": "Maximum square footage (optional)",
|
29 |
+
"beds_min": "Minimum number of bedrooms (optional)",
|
30 |
+
"beds_max": "Maximum number of bedrooms (optional)",
|
31 |
+
"baths_min": "Minimum number of bathrooms (optional)",
|
32 |
+
"baths_max": "Maximum number of bathrooms (optional)",
|
33 |
+
"built_min": "Minimum year built (optional)",
|
34 |
+
"built_max": "Maximum year built (optional)",
|
35 |
+
"lotSize_min": "Minimum lot size in sqft (optional)",
|
36 |
+
"lotSize_max": "Maximum lot size in sqft (optional)",
|
37 |
+
"keywords": "Keywords to search for (optional)"
|
38 |
+
}
|
39 |
+
},
|
40 |
+
"search_address": {
|
41 |
+
"route": "/search_address",
|
42 |
+
"method": "GET",
|
43 |
+
"name": "Zillow Address Search",
|
44 |
+
"description": "Search for a specific property by its full address.",
|
45 |
+
"payload": {
|
46 |
+
"address": "Full property address (required)"
|
47 |
+
}
|
48 |
+
},
|
49 |
+
"propertyV2": {
|
50 |
+
"route": "/propertyV2",
|
51 |
+
"method": "GET",
|
52 |
+
"name": "Zillow Property Details",
|
53 |
+
"description": "Get detailed information about a specific property by zpid or URL.",
|
54 |
+
"payload": {
|
55 |
+
"zpid": "Zillow property ID (optional if URL is provided)",
|
56 |
+
"url": "Property details URL (optional if zpid is provided)"
|
57 |
+
}
|
58 |
+
},
|
59 |
+
"zestimate_history": {
|
60 |
+
"route": "/zestimate_history",
|
61 |
+
"method": "GET",
|
62 |
+
"name": "Zillow Zestimate History",
|
63 |
+
"description": "Get historical Zestimate values for a specific property.",
|
64 |
+
"payload": {
|
65 |
+
"zpid": "Zillow property ID (optional if URL is provided)",
|
66 |
+
"url": "Property details URL (optional if zpid is provided)"
|
67 |
+
}
|
68 |
+
},
|
69 |
+
"similar_properties": {
|
70 |
+
"route": "/similar_properties",
|
71 |
+
"method": "GET",
|
72 |
+
"name": "Zillow Similar Properties",
|
73 |
+
"description": "Find properties similar to a specific property.",
|
74 |
+
"payload": {
|
75 |
+
"zpid": "Zillow property ID (optional if URL or address is provided)",
|
76 |
+
"url": "Property details URL (optional if zpid or address is provided)",
|
77 |
+
"address": "Property address (optional if zpid or URL is provided)"
|
78 |
+
}
|
79 |
+
},
|
80 |
+
"mortgage_rates": {
|
81 |
+
"route": "/mortgage/rates",
|
82 |
+
"method": "GET",
|
83 |
+
"name": "Zillow Mortgage Rates",
|
84 |
+
"description": "Get current mortgage rates for different loan programs and conditions.",
|
85 |
+
"payload": {
|
86 |
+
"program": "Loan program (required): Fixed30Year, Fixed20Year, Fixed15Year, Fixed10Year, ARM3, ARM5, ARM7, etc.",
|
87 |
+
"state": "State abbreviation (optional, default: US)",
|
88 |
+
"refinance": "Whether this is for refinancing (optional, default: false)",
|
89 |
+
"loanType": "Type of loan: Conventional, etc. (optional)",
|
90 |
+
"loanAmount": "Loan amount category: Micro, SmallConforming, Conforming, SuperConforming, Jumbo (optional)",
|
91 |
+
"loanToValue": "Loan to value ratio: Normal, High, VeryHigh (optional)",
|
92 |
+
"creditScore": "Credit score category: Low, High, VeryHigh (optional)",
|
93 |
+
"duration": "Duration in days (optional, default: 30)"
|
94 |
+
}
|
95 |
+
},
|
96 |
+
}
|
97 |
+
base_url = "https://zillow56.p.rapidapi.com"
|
98 |
+
super().__init__(base_url, endpoints)
|
99 |
+
|
100 |
+
|
101 |
+
if __name__ == "__main__":
|
102 |
+
from dotenv import load_dotenv
|
103 |
+
from time import sleep
|
104 |
+
load_dotenv()
|
105 |
+
tool = ZillowProvider()
|
106 |
+
|
107 |
+
# Example for searching properties in Houston
|
108 |
+
search_result = tool.call_endpoint(
|
109 |
+
route="search",
|
110 |
+
payload={
|
111 |
+
"location": "houston, tx",
|
112 |
+
"status": "forSale",
|
113 |
+
"sortSelection": "priorityscore",
|
114 |
+
"listing_type": "by_agent",
|
115 |
+
"doz": "any"
|
116 |
+
}
|
117 |
+
)
|
118 |
+
logger.debug("Search Result: %s", search_result)
|
119 |
+
logger.debug("***")
|
120 |
+
logger.debug("***")
|
121 |
+
logger.debug("***")
|
122 |
+
sleep(1)
|
123 |
+
# Example for searching by address
|
124 |
+
address_result = tool.call_endpoint(
|
125 |
+
route="search_address",
|
126 |
+
payload={
|
127 |
+
"address": "1161 Natchez Dr College Station Texas 77845"
|
128 |
+
}
|
129 |
+
)
|
130 |
+
logger.debug("Address Search Result: %s", address_result)
|
131 |
+
logger.debug("***")
|
132 |
+
logger.debug("***")
|
133 |
+
logger.debug("***")
|
134 |
+
sleep(1)
|
135 |
+
# Example for getting property details
|
136 |
+
property_result = tool.call_endpoint(
|
137 |
+
route="propertyV2",
|
138 |
+
payload={
|
139 |
+
"zpid": "7594920"
|
140 |
+
}
|
141 |
+
)
|
142 |
+
logger.debug("Property Details Result: %s", property_result)
|
143 |
+
sleep(1)
|
144 |
+
logger.debug("***")
|
145 |
+
logger.debug("***")
|
146 |
+
logger.debug("***")
|
147 |
+
|
148 |
+
# Example for getting zestimate history
|
149 |
+
zestimate_result = tool.call_endpoint(
|
150 |
+
route="zestimate_history",
|
151 |
+
payload={
|
152 |
+
"zpid": "20476226"
|
153 |
+
}
|
154 |
+
)
|
155 |
+
logger.debug("Zestimate History Result: %s", zestimate_result)
|
156 |
+
sleep(1)
|
157 |
+
logger.debug("***")
|
158 |
+
logger.debug("***")
|
159 |
+
logger.debug("***")
|
160 |
+
# Example for getting similar properties
|
161 |
+
similar_result = tool.call_endpoint(
|
162 |
+
route="similar_properties",
|
163 |
+
payload={
|
164 |
+
"zpid": "28253016"
|
165 |
+
}
|
166 |
+
)
|
167 |
+
logger.debug("Similar Properties Result: %s", similar_result)
|
168 |
+
sleep(1)
|
169 |
+
logger.debug("***")
|
170 |
+
logger.debug("***")
|
171 |
+
logger.debug("***")
|
172 |
+
# Example for getting mortgage rates
|
173 |
+
mortgage_result = tool.call_endpoint(
|
174 |
+
route="mortgage_rates",
|
175 |
+
payload={
|
176 |
+
"program": "Fixed30Year",
|
177 |
+
"state": "US",
|
178 |
+
"refinance": "false",
|
179 |
+
"loanType": "Conventional",
|
180 |
+
"loanAmount": "Conforming",
|
181 |
+
"loanToValue": "Normal",
|
182 |
+
"creditScore": "Low",
|
183 |
+
"duration": "30"
|
184 |
+
}
|
185 |
+
)
|
186 |
+
logger.debug("Mortgage Rates Result: %s", mortgage_result)
|
187 |
+
|
__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Utility functions and constants for agent tools
|
added_tokens.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"</tool_call>": 151658,
|
3 |
+
"<tool_call>": 151657,
|
4 |
+
"<|AUDIO|>": 151646,
|
5 |
+
"<|IMAGE|>": 151655,
|
6 |
+
"<|VIDEO|>": 151656,
|
7 |
+
"<|audio_bos|>": 151647,
|
8 |
+
"<|audio_eos|>": 151648,
|
9 |
+
"<|box_end|>": 151649,
|
10 |
+
"<|endoftext|>": 151643,
|
11 |
+
"<|file_sep|>": 151664,
|
12 |
+
"<|fim_middle|>": 151660,
|
13 |
+
"<|fim_pad|>": 151662,
|
14 |
+
"<|fim_prefix|>": 151659,
|
15 |
+
"<|fim_suffix|>": 151661,
|
16 |
+
"<|im_end|>": 151645,
|
17 |
+
"<|im_start|>": 151644,
|
18 |
+
"<|quad_end|>": 151651,
|
19 |
+
"<|quad_start|>": 151650,
|
20 |
+
"<|repo_name|>": 151663,
|
21 |
+
"<|vision_bos|>": 151652,
|
22 |
+
"<|vision_eos|>": 151653,
|
23 |
+
"<|vision_pad|>": 151654
|
24 |
+
}
|
agent.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class Task:
|
7 |
+
id: str
|
8 |
+
status: str = 'pending'
|
9 |
+
input_data: dict = None
|
10 |
+
result_data: dict = None
|
11 |
+
|
12 |
+
class VisionOSAgent:
|
13 |
+
def __init__(self, model_name='distilbert-base-uncased'):
|
14 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
15 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
16 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
self.model.to(self.device)
|
18 |
+
self.current_task: Task = None
|
19 |
+
|
20 |
+
def set_task(self, task: Task):
|
21 |
+
self.current_task = task
|
22 |
+
|
23 |
+
def execute_task(self):
|
24 |
+
if self.current_task is None or self.current_task.status != 'pending':
|
25 |
+
raise ValueError("No pending task to execute")
|
26 |
+
input_text = self.current_task.input_data.get('text', 'Default input')
|
27 |
+
inputs = self.tokenizer(input_text, return_tensors='pt').to(self.device)
|
28 |
+
with torch.no_grad():
|
29 |
+
outputs = self.model(**inputs)
|
30 |
+
prediction = torch.argmax(outputs.logits, dim=1).item()
|
31 |
+
self.current_task.result_data = {'prediction': prediction}
|
32 |
+
self.current_task.status = 'completed'
|
33 |
+
return self.current_task
|
34 |
+
|
35 |
+
# Example usage
|
36 |
+
if __name__ == "__main__":
|
37 |
+
agent = VisionOSAgent()
|
38 |
+
task = Task(id='task1', input_data={'text': 'Analyze vehicle status: operational'})
|
39 |
+
agent.set_task(task)
|
40 |
+
result = agent.execute_task()
|
41 |
+
print(f"Task ID: {result.id}, Status: {result.status}, Result: {result.result_data}")
|
alembic.ini
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# alembic.ini
|
2 |
+
# A generic configuration for the Alembic migration tool.
|
3 |
+
|
4 |
+
[alembic]
|
5 |
+
# path to migration scripts
|
6 |
+
script_location = alembic
|
7 |
+
|
8 |
+
# template used to generate migration files
|
9 |
+
# file_template = %%(rev)s_%%(slug)s
|
10 |
+
|
11 |
+
# timezone to use when rendering the date within the migration file
|
12 |
+
# as well as the filename.
|
13 |
+
# If specified, requires the python-dateutil library that is installable
|
14 |
+
# with pip install python-dateutil.
|
15 |
+
# Any required timezone name works, such as UTC, PST8PDT, Europe/London
|
16 |
+
# If None, the system default timezone is used.
|
17 |
+
# timezone =
|
18 |
+
|
19 |
+
# sys.path path, will be prepended to sys.path if present.
|
20 |
+
# defaults to the current working directory.
|
21 |
+
# prepend_sys_path = .
|
22 |
+
|
23 |
+
# Logging configuration
|
24 |
+
[loggers]
|
25 |
+
keys = root,sqlalchemy,alembic
|
26 |
+
|
27 |
+
[handlers]
|
28 |
+
keys = console
|
29 |
+
|
30 |
+
[formatters]
|
31 |
+
keys = generic
|
32 |
+
|
33 |
+
[logger_root]
|
34 |
+
level = WARN
|
35 |
+
handlers = console
|
36 |
+
qualname =
|
37 |
+
|
38 |
+
[logger_sqlalchemy]
|
39 |
+
level = WARN
|
40 |
+
handlers =
|
41 |
+
qualname = sqlalchemy.engine
|
42 |
+
|
43 |
+
[logger_alembic]
|
44 |
+
level = INFO
|
45 |
+
handlers =
|
46 |
+
qualname = alembic
|
47 |
+
|
48 |
+
[handler_console]
|
49 |
+
class = StreamHandler
|
50 |
+
args = (sys.stderr,)
|
51 |
+
level = NOTSET
|
52 |
+
formatter = generic
|
53 |
+
|
54 |
+
[formatter_generic]
|
55 |
+
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
|
56 |
+
datefmt = %%H:%%M:%%S
|
57 |
+
|
58 |
+
# Database configuration
|
59 |
+
# Replace sqlalchemy.url with the actual database connection string from your settings
|
60 |
+
# This will be dynamically loaded in env.py
|
61 |
+
s sqlalchemy.url = driver://user:pass@localhost/dbname
|
62 |
+
|
api.cpython-311.pyc
ADDED
Binary file (918 Bytes). View file
|
|
api.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException, APIRouter, Form, Depends, Request
|
5 |
+
from fastapi.responses import Response, JSONResponse
|
6 |
+
from pydantic import BaseModel
|
7 |
+
|
8 |
+
from utils.logger import logger
|
9 |
+
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, get_optional_user_id
|
10 |
+
from sandbox.sandbox import get_or_start_sandbox
|
11 |
+
from services.supabase import DBConnection
|
12 |
+
from agent.api import get_or_create_project_sandbox
|
13 |
+
|
14 |
+
|
15 |
+
# Initialize shared resources
|
16 |
+
router = APIRouter(tags=["sandbox"])
|
17 |
+
db = None
|
18 |
+
|
19 |
+
def initialize(_db: DBConnection):
|
20 |
+
"""Initialize the sandbox API with resources from the main API."""
|
21 |
+
global db
|
22 |
+
db = _db
|
23 |
+
logger.info("Initialized sandbox API with database connection")
|
24 |
+
|
25 |
+
class FileInfo(BaseModel):
|
26 |
+
"""Model for file information"""
|
27 |
+
name: str
|
28 |
+
path: str
|
29 |
+
is_dir: bool
|
30 |
+
size: int
|
31 |
+
mod_time: str
|
32 |
+
permissions: Optional[str] = None
|
33 |
+
|
34 |
+
async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str] = None):
|
35 |
+
"""
|
36 |
+
Verify that a user has access to a specific sandbox based on account membership.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
client: The Supabase client
|
40 |
+
sandbox_id: The sandbox ID to check access for
|
41 |
+
user_id: The user ID to check permissions for. Can be None for public resource access.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
dict: Project data containing sandbox information
|
45 |
+
|
46 |
+
Raises:
|
47 |
+
HTTPException: If the user doesn't have access to the sandbox or sandbox doesn't exist
|
48 |
+
"""
|
49 |
+
# Find the project that owns this sandbox
|
50 |
+
project_result = await client.table('projects').select('*').filter('sandbox->>id', 'eq', sandbox_id).execute()
|
51 |
+
|
52 |
+
if not project_result.data or len(project_result.data) == 0:
|
53 |
+
raise HTTPException(status_code=404, detail="Sandbox not found")
|
54 |
+
|
55 |
+
project_data = project_result.data[0]
|
56 |
+
|
57 |
+
if project_data.get('is_public'):
|
58 |
+
return project_data
|
59 |
+
|
60 |
+
# For private projects, we must have a user_id
|
61 |
+
if not user_id:
|
62 |
+
raise HTTPException(status_code=401, detail="Authentication required for this resource")
|
63 |
+
|
64 |
+
account_id = project_data.get('account_id')
|
65 |
+
|
66 |
+
# Verify account membership
|
67 |
+
if account_id:
|
68 |
+
account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
|
69 |
+
if account_user_result.data and len(account_user_result.data) > 0:
|
70 |
+
return project_data
|
71 |
+
|
72 |
+
raise HTTPException(status_code=403, detail="Not authorized to access this sandbox")
|
73 |
+
|
74 |
+
async def get_sandbox_by_id_safely(client, sandbox_id: str):
|
75 |
+
"""
|
76 |
+
Safely retrieve a sandbox object by its ID, using the project that owns it.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
client: The Supabase client
|
80 |
+
sandbox_id: The sandbox ID to retrieve
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
Sandbox: The sandbox object
|
84 |
+
|
85 |
+
Raises:
|
86 |
+
HTTPException: If the sandbox doesn't exist or can't be retrieved
|
87 |
+
"""
|
88 |
+
# Find the project that owns this sandbox
|
89 |
+
project_result = await client.table('projects').select('project_id').filter('sandbox->>id', 'eq', sandbox_id).execute()
|
90 |
+
|
91 |
+
if not project_result.data or len(project_result.data) == 0:
|
92 |
+
logger.error(f"No project found for sandbox ID: {sandbox_id}")
|
93 |
+
raise HTTPException(status_code=404, detail="Sandbox not found - no project owns this sandbox ID")
|
94 |
+
|
95 |
+
project_id = project_result.data[0]['project_id']
|
96 |
+
logger.debug(f"Found project {project_id} for sandbox {sandbox_id}")
|
97 |
+
|
98 |
+
try:
|
99 |
+
# Get the sandbox
|
100 |
+
sandbox, retrieved_sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
|
101 |
+
|
102 |
+
# Verify we got the right sandbox
|
103 |
+
if retrieved_sandbox_id != sandbox_id:
|
104 |
+
logger.warning(f"Retrieved sandbox ID {retrieved_sandbox_id} doesn't match requested ID {sandbox_id} for project {project_id}")
|
105 |
+
# Fall back to the direct method if IDs don't match (shouldn't happen but just in case)
|
106 |
+
sandbox = await get_or_start_sandbox(sandbox_id)
|
107 |
+
|
108 |
+
return sandbox
|
109 |
+
except Exception as e:
|
110 |
+
logger.error(f"Error retrieving sandbox {sandbox_id}: {str(e)}")
|
111 |
+
raise HTTPException(status_code=500, detail=f"Failed to retrieve sandbox: {str(e)}")
|
112 |
+
|
113 |
+
@router.post("/sandboxes/{sandbox_id}/files")
|
114 |
+
async def create_file(
|
115 |
+
sandbox_id: str,
|
116 |
+
path: str = Form(...),
|
117 |
+
file: UploadFile = File(...),
|
118 |
+
request: Request = None,
|
119 |
+
user_id: Optional[str] = Depends(get_optional_user_id)
|
120 |
+
):
|
121 |
+
"""Create a file in the sandbox using direct file upload"""
|
122 |
+
logger.info(f"Received file upload request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
|
123 |
+
client = await db.client
|
124 |
+
|
125 |
+
# Verify the user has access to this sandbox
|
126 |
+
await verify_sandbox_access(client, sandbox_id, user_id)
|
127 |
+
|
128 |
+
try:
|
129 |
+
# Get sandbox using the safer method
|
130 |
+
sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
|
131 |
+
|
132 |
+
# Read file content directly from the uploaded file
|
133 |
+
content = await file.read()
|
134 |
+
|
135 |
+
# Create file using raw binary content
|
136 |
+
sandbox.fs.upload_file(path, content)
|
137 |
+
logger.info(f"File created at {path} in sandbox {sandbox_id}")
|
138 |
+
|
139 |
+
return {"status": "success", "created": True, "path": path}
|
140 |
+
except Exception as e:
|
141 |
+
logger.error(f"Error creating file in sandbox {sandbox_id}: {str(e)}")
|
142 |
+
raise HTTPException(status_code=500, detail=str(e))
|
143 |
+
|
144 |
+
# For backward compatibility, keep the JSON version too
|
145 |
+
@router.post("/sandboxes/{sandbox_id}/files/json")
|
146 |
+
async def create_file_json(
|
147 |
+
sandbox_id: str,
|
148 |
+
file_request: dict,
|
149 |
+
request: Request = None,
|
150 |
+
user_id: Optional[str] = Depends(get_optional_user_id)
|
151 |
+
):
|
152 |
+
"""Create a file in the sandbox using JSON (legacy support)"""
|
153 |
+
logger.info(f"Received JSON file creation request for sandbox {sandbox_id}, user_id: {user_id}")
|
154 |
+
client = await db.client
|
155 |
+
|
156 |
+
# Verify the user has access to this sandbox
|
157 |
+
await verify_sandbox_access(client, sandbox_id, user_id)
|
158 |
+
|
159 |
+
try:
|
160 |
+
# Get sandbox using the safer method
|
161 |
+
sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
|
162 |
+
|
163 |
+
# Get file path and content
|
164 |
+
path = file_request.get("path")
|
165 |
+
content = file_request.get("content", "")
|
166 |
+
|
167 |
+
if not path:
|
168 |
+
logger.error(f"Missing file path in request for sandbox {sandbox_id}")
|
169 |
+
raise HTTPException(status_code=400, detail="File path is required")
|
170 |
+
|
171 |
+
# Convert string content to bytes
|
172 |
+
if isinstance(content, str):
|
173 |
+
content = content.encode('utf-8')
|
174 |
+
|
175 |
+
# Create file
|
176 |
+
sandbox.fs.upload_file(path, content)
|
177 |
+
logger.info(f"File created at {path} in sandbox {sandbox_id}")
|
178 |
+
|
179 |
+
return {"status": "success", "created": True, "path": path}
|
180 |
+
except Exception as e:
|
181 |
+
logger.error(f"Error creating file in sandbox {sandbox_id}: {str(e)}")
|
182 |
+
raise HTTPException(status_code=500, detail=str(e))
|
183 |
+
|
184 |
+
@router.get("/sandboxes/{sandbox_id}/files")
|
185 |
+
async def list_files(
|
186 |
+
sandbox_id: str,
|
187 |
+
path: str,
|
188 |
+
request: Request = None,
|
189 |
+
user_id: Optional[str] = Depends(get_optional_user_id)
|
190 |
+
):
|
191 |
+
"""List files and directories at the specified path"""
|
192 |
+
logger.info(f"Received list files request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
|
193 |
+
client = await db.client
|
194 |
+
|
195 |
+
# Verify the user has access to this sandbox
|
196 |
+
await verify_sandbox_access(client, sandbox_id, user_id)
|
197 |
+
|
198 |
+
try:
|
199 |
+
# Get sandbox using the safer method
|
200 |
+
sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
|
201 |
+
|
202 |
+
# List files
|
203 |
+
files = sandbox.fs.list_files(path)
|
204 |
+
result = []
|
205 |
+
|
206 |
+
for file in files:
|
207 |
+
# Convert file information to our model
|
208 |
+
# Ensure forward slashes are used for paths, regardless of OS
|
209 |
+
full_path = f"{path.rstrip('/')}/{file.name}" if path != '/' else f"/{file.name}"
|
210 |
+
file_info = FileInfo(
|
211 |
+
name=file.name,
|
212 |
+
path=full_path, # Use the constructed path
|
213 |
+
is_dir=file.is_dir,
|
214 |
+
size=file.size,
|
215 |
+
mod_time=str(file.mod_time),
|
216 |
+
permissions=getattr(file, 'permissions', None)
|
217 |
+
)
|
218 |
+
result.append(file_info)
|
219 |
+
|
220 |
+
logger.info(f"Successfully listed {len(result)} files in sandbox {sandbox_id}")
|
221 |
+
return {"files": [file.dict() for file in result]}
|
222 |
+
except Exception as e:
|
223 |
+
logger.error(f"Error listing files in sandbox {sandbox_id}: {str(e)}")
|
224 |
+
raise HTTPException(status_code=500, detail=str(e))
|
225 |
+
|
226 |
+
@router.get("/sandboxes/{sandbox_id}/files/content")
|
227 |
+
async def read_file(
|
228 |
+
sandbox_id: str,
|
229 |
+
path: str,
|
230 |
+
request: Request = None,
|
231 |
+
user_id: Optional[str] = Depends(get_optional_user_id)
|
232 |
+
):
|
233 |
+
"""Read a file from the sandbox"""
|
234 |
+
logger.info(f"Received file read request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
|
235 |
+
client = await db.client
|
236 |
+
|
237 |
+
# Verify the user has access to this sandbox
|
238 |
+
await verify_sandbox_access(client, sandbox_id, user_id)
|
239 |
+
|
240 |
+
try:
|
241 |
+
# Get sandbox using the safer method
|
242 |
+
sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
|
243 |
+
|
244 |
+
# Read file
|
245 |
+
content = sandbox.fs.download_file(path)
|
246 |
+
|
247 |
+
# Return a Response object with the content directly
|
248 |
+
filename = os.path.basename(path)
|
249 |
+
logger.info(f"Successfully read file {filename} from sandbox {sandbox_id}")
|
250 |
+
return Response(
|
251 |
+
content=content,
|
252 |
+
media_type="application/octet-stream",
|
253 |
+
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
254 |
+
)
|
255 |
+
except Exception as e:
|
256 |
+
logger.error(f"Error reading file in sandbox {sandbox_id}: {str(e)}")
|
257 |
+
raise HTTPException(status_code=500, detail=str(e))
|
258 |
+
|
259 |
+
@router.post("/project/{project_id}/sandbox/ensure-active")
|
260 |
+
async def ensure_project_sandbox_active(
|
261 |
+
project_id: str,
|
262 |
+
request: Request = None,
|
263 |
+
user_id: Optional[str] = Depends(get_optional_user_id)
|
264 |
+
):
|
265 |
+
"""
|
266 |
+
Ensure that a project's sandbox is active and running.
|
267 |
+
Checks the sandbox status and starts it if it's not running.
|
268 |
+
"""
|
269 |
+
logger.info(f"Received ensure sandbox active request for project {project_id}, user_id: {user_id}")
|
270 |
+
client = await db.client
|
271 |
+
|
272 |
+
# Find the project and sandbox information
|
273 |
+
project_result = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
274 |
+
|
275 |
+
if not project_result.data or len(project_result.data) == 0:
|
276 |
+
logger.error(f"Project not found: {project_id}")
|
277 |
+
raise HTTPException(status_code=404, detail="Project not found")
|
278 |
+
|
279 |
+
project_data = project_result.data[0]
|
280 |
+
|
281 |
+
# For public projects, no authentication is needed
|
282 |
+
if not project_data.get('is_public'):
|
283 |
+
# For private projects, we must have a user_id
|
284 |
+
if not user_id:
|
285 |
+
logger.error(f"Authentication required for private project {project_id}")
|
286 |
+
raise HTTPException(status_code=401, detail="Authentication required for this resource")
|
287 |
+
|
288 |
+
account_id = project_data.get('account_id')
|
289 |
+
|
290 |
+
# Verify account membership
|
291 |
+
if account_id:
|
292 |
+
account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
|
293 |
+
if not (account_user_result.data and len(account_user_result.data) > 0):
|
294 |
+
logger.error(f"User {user_id} not authorized to access project {project_id}")
|
295 |
+
raise HTTPException(status_code=403, detail="Not authorized to access this project")
|
296 |
+
|
297 |
+
try:
|
298 |
+
# Get or create the sandbox
|
299 |
+
logger.info(f"Ensuring sandbox is active for project {project_id}")
|
300 |
+
sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
|
301 |
+
|
302 |
+
logger.info(f"Successfully ensured sandbox {sandbox_id} is active for project {project_id}")
|
303 |
+
|
304 |
+
return {
|
305 |
+
"status": "success",
|
306 |
+
"sandbox_id": sandbox_id,
|
307 |
+
"message": "Sandbox is active"
|
308 |
+
}
|
309 |
+
except Exception as e:
|
310 |
+
logger.error(f"Error ensuring sandbox is active for project {project_id}: {str(e)}")
|
311 |
+
raise HTTPException(status_code=500, detail=str(e))
|
api.py.bak
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, Request
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
from fastapi.responses import JSONResponse
|
4 |
+
from contextlib import asynccontextmanager
|
5 |
+
from agentpress.thread_manager import ThreadManager
|
6 |
+
from services.supabase import DBConnection
|
7 |
+
from datetime import datetime, timezone
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
from utils.config import config, EnvMode
|
10 |
+
import asyncio
|
11 |
+
from utils.logger import logger
|
12 |
+
import uuid
|
13 |
+
import time
|
14 |
+
from collections import OrderedDict
|
15 |
+
|
16 |
+
# Import the agent API module
|
17 |
+
from agent import api as agent_api
|
18 |
+
from sandbox import api as sandbox_api
|
19 |
+
|
20 |
+
# Load environment variables (these will be available through config)
|
21 |
+
load_dotenv()
|
22 |
+
|
23 |
+
# Initialize managers
|
24 |
+
db = DBConnection()
|
25 |
+
thread_manager = None
|
26 |
+
instance_id = "single"
|
27 |
+
|
28 |
+
# Rate limiter state
|
29 |
+
ip_tracker = OrderedDict()
|
30 |
+
MAX_CONCURRENT_IPS = 25
|
31 |
+
|
32 |
+
@asynccontextmanager
|
33 |
+
async def lifespan(app: FastAPI):
|
34 |
+
# Startup
|
35 |
+
global thread_manager
|
36 |
+
logger.info(f"Starting up FastAPI application with instance ID: {instance_id} in {config.ENV_MODE.value} mode")
|
37 |
+
|
38 |
+
try:
|
39 |
+
# Initialize database
|
40 |
+
await db.initialize()
|
41 |
+
thread_manager = ThreadManager()
|
42 |
+
|
43 |
+
# Initialize the agent API with shared resources
|
44 |
+
agent_api.initialize(
|
45 |
+
thread_manager,
|
46 |
+
db,
|
47 |
+
instance_id
|
48 |
+
)
|
49 |
+
|
50 |
+
# Initialize the sandbox API with shared resources
|
51 |
+
sandbox_api.initialize(db)
|
52 |
+
|
53 |
+
# Initialize Redis connection
|
54 |
+
from services import redis
|
55 |
+
try:
|
56 |
+
await redis.initialize_async()
|
57 |
+
logger.info("Redis connection initialized successfully")
|
58 |
+
except Exception as e:
|
59 |
+
logger.error(f"Failed to initialize Redis connection: {e}")
|
60 |
+
# Continue without Redis - the application will handle Redis failures gracefully
|
61 |
+
|
62 |
+
# Start background tasks
|
63 |
+
asyncio.create_task(agent_api.restore_running_agent_runs())
|
64 |
+
|
65 |
+
yield
|
66 |
+
|
67 |
+
# Clean up agent resources
|
68 |
+
logger.info("Cleaning up agent resources")
|
69 |
+
await agent_api.cleanup()
|
70 |
+
|
71 |
+
# Clean up Redis connection
|
72 |
+
try:
|
73 |
+
logger.info("Closing Redis connection")
|
74 |
+
await redis.close()
|
75 |
+
logger.info("Redis connection closed successfully")
|
76 |
+
except Exception as e:
|
77 |
+
logger.error(f"Error closing Redis connection: {e}")
|
78 |
+
|
79 |
+
# Clean up database connection
|
80 |
+
logger.info("Disconnecting from database")
|
81 |
+
await db.disconnect()
|
82 |
+
except Exception as e:
|
83 |
+
logger.error(f"Error during application startup: {e}")
|
84 |
+
raise
|
85 |
+
|
86 |
+
app = FastAPI(lifespan=lifespan)
|
87 |
+
|
88 |
+
@app.middleware("http")
|
89 |
+
async def log_requests_middleware(request: Request, call_next):
|
90 |
+
start_time = time.time()
|
91 |
+
client_ip = request.client.host
|
92 |
+
method = request.method
|
93 |
+
url = str(request.url)
|
94 |
+
path = request.url.path
|
95 |
+
query_params = str(request.query_params)
|
96 |
+
|
97 |
+
# Log the incoming request
|
98 |
+
logger.info(f"Request started: {method} {path} from {client_ip} | Query: {query_params}")
|
99 |
+
|
100 |
+
try:
|
101 |
+
response = await call_next(request)
|
102 |
+
process_time = time.time() - start_time
|
103 |
+
logger.debug(f"Request completed: {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}s")
|
104 |
+
return response
|
105 |
+
except Exception as e:
|
106 |
+
process_time = time.time() - start_time
|
107 |
+
logger.error(f"Request failed: {method} {path} | Error: {str(e)} | Time: {process_time:.2f}s")
|
108 |
+
raise
|
109 |
+
|
110 |
+
# Define allowed origins based on environment
|
111 |
+
allowed_origins = ["https://www.suna.so", "https://suna.so", "https://staging.suna.so", "http://localhost:3000"]
|
112 |
+
|
113 |
+
# Add staging-specific origins
|
114 |
+
if config.ENV_MODE == EnvMode.STAGING:
|
115 |
+
allowed_origins.append("http://localhost:3000")
|
116 |
+
|
117 |
+
# Add local-specific origins
|
118 |
+
if config.ENV_MODE == EnvMode.LOCAL:
|
119 |
+
allowed_origins.append("http://localhost:3000")
|
120 |
+
|
121 |
+
app.add_middleware(
|
122 |
+
CORSMiddleware,
|
123 |
+
allow_origins=allowed_origins,
|
124 |
+
allow_credentials=True,
|
125 |
+
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
126 |
+
allow_headers=["Content-Type", "Authorization"],
|
127 |
+
)
|
128 |
+
|
129 |
+
# Include the agent router with a prefix
|
130 |
+
app.include_router(agent_api.router, prefix="/api")
|
131 |
+
|
132 |
+
# Include the sandbox router with a prefix
|
133 |
+
app.include_router(sandbox_api.router, prefix="/api")
|
134 |
+
|
135 |
+
@app.get("/api/health")
|
136 |
+
async def health_check():
|
137 |
+
"""Health check endpoint to verify API is working."""
|
138 |
+
logger.info("Health check endpoint called")
|
139 |
+
return {
|
140 |
+
"status": "ok",
|
141 |
+
"timestamp": datetime.now(timezone.utc).isoformat(),
|
142 |
+
"instance_id": instance_id
|
143 |
+
}
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
import uvicorn
|
147 |
+
|
148 |
+
workers = 2
|
149 |
+
|
150 |
+
logger.info(f"Starting server on 0.0.0.0:8000 with {workers} workers")
|
151 |
+
uvicorn.run(
|
152 |
+
"api:app",
|
153 |
+
host="0.0.0.0",
|
154 |
+
port=8000,
|
155 |
+
workers=workers
|
156 |
+
)
|
api_keys.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /home/ubuntu/visionos_farm/daedalus/api/v1/endpoints/api_keys.py
|
2 |
+
|
3 |
+
import uuid
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
7 |
+
from sqlalchemy.orm import Session
|
8 |
+
|
9 |
+
from shared import schemas
|
10 |
+
from database import models, session
|
11 |
+
from crud import crud_api_key
|
12 |
+
from . import dependencies # Assuming dependencies.py for auth
|
13 |
+
|
14 |
+
router = APIRouter()
|
15 |
+
|
16 |
+
@router.post("/", response_model=schemas.ApiKeyCreateResponse, status_code=status.HTTP_201_CREATED)
|
17 |
+
def create_api_key(
|
18 |
+
api_key_in: schemas.ApiKeyCreate,
|
19 |
+
db: Session = Depends(session.get_db),
|
20 |
+
current_user: models.User = Depends(dependencies.get_current_active_user) # Protect key creation
|
21 |
+
):
|
22 |
+
"""Generate a new API key for the currently authenticated user."""
|
23 |
+
db_api_key, generated_key_string = crud_api_key.create_api_key(
|
24 |
+
db=db, api_key_in=api_key_in, user_id=current_user.id
|
25 |
+
)
|
26 |
+
# Important: The raw key is only returned ONCE upon creation.
|
27 |
+
return schemas.ApiKeyCreateResponse(
|
28 |
+
**db_api_key.__dict__, # Convert DB model to dict
|
29 |
+
key=generated_key_string # Add the raw key to the response
|
30 |
+
)
|
31 |
+
|
32 |
+
@router.get("/", response_model=List[schemas.ApiKey])
|
33 |
+
def read_api_keys(
|
34 |
+
skip: int = 0,
|
35 |
+
limit: int = 100,
|
36 |
+
db: Session = Depends(session.get_db),
|
37 |
+
current_user: models.User = Depends(dependencies.get_current_active_user)
|
38 |
+
):
|
39 |
+
"""Retrieve API keys for the currently authenticated user."""
|
40 |
+
api_keys = crud_api_key.get_api_keys_by_user(db=db, user_id=current_user.id, skip=skip, limit=limit)
|
41 |
+
return api_keys
|
42 |
+
|
43 |
+
@router.put("/{api_key_id}", response_model=schemas.ApiKey)
|
44 |
+
def update_api_key(
|
45 |
+
api_key_id: uuid.UUID,
|
46 |
+
api_key_in: schemas.ApiKeyUpdate,
|
47 |
+
db: Session = Depends(session.get_db),
|
48 |
+
current_user: models.User = Depends(dependencies.get_current_active_user)
|
49 |
+
):
|
50 |
+
"""Update an API key belonging to the current user."""
|
51 |
+
db_api_key = db.get(models.ApiKey, api_key_id)
|
52 |
+
if not db_api_key or db_api_key.user_id != current_user.id:
|
53 |
+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="API Key not found")
|
54 |
+
updated_key = crud_api_key.update_api_key(db=db, db_api_key=db_api_key, api_key_in=api_key_in)
|
55 |
+
return updated_key
|
56 |
+
|
57 |
+
@router.delete("/{api_key_id}", status_code=status.HTTP_204_NO_CONTENT)
|
58 |
+
def delete_api_key(
|
59 |
+
api_key_id: uuid.UUID,
|
60 |
+
db: Session = Depends(session.get_db),
|
61 |
+
current_user: models.User = Depends(dependencies.get_current_active_user)
|
62 |
+
):
|
63 |
+
"""Delete an API key belonging to the current user."""
|
64 |
+
deleted_key = crud_api_key.delete_api_key(db=db, api_key_id=api_key_id, user_id=current_user.id)
|
65 |
+
if not deleted_key:
|
66 |
+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="API Key not found")
|
67 |
+
return None # Return 204 No Content on success
|
68 |
+
|
architecture_diagram.svg
ADDED
|
auth_utils.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import HTTPException, Request, Depends
|
2 |
+
from typing import Optional, List, Dict, Any
|
3 |
+
import jwt
|
4 |
+
from jwt.exceptions import PyJWTError
|
5 |
+
from utils.logger import logger
|
6 |
+
|
7 |
+
# This function extracts the user ID from Supabase JWT
|
8 |
+
async def get_current_user_id(request: Request) -> str:
|
9 |
+
"""
|
10 |
+
Extract and verify the user ID from the JWT in the Authorization header.
|
11 |
+
|
12 |
+
This function is used as a dependency in FastAPI routes to ensure the user
|
13 |
+
is authenticated and to provide the user ID for authorization checks.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
request: The FastAPI request object
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
str: The user ID extracted from the JWT
|
20 |
+
|
21 |
+
Raises:
|
22 |
+
HTTPException: If no valid token is found or if the token is invalid
|
23 |
+
"""
|
24 |
+
auth_header = request.headers.get('Authorization')
|
25 |
+
|
26 |
+
if not auth_header or not auth_header.startswith('Bearer '):
|
27 |
+
raise HTTPException(
|
28 |
+
status_code=401,
|
29 |
+
detail="No valid authentication credentials found",
|
30 |
+
headers={"WWW-Authenticate": "Bearer"}
|
31 |
+
)
|
32 |
+
|
33 |
+
token = auth_header.split(' ')[1]
|
34 |
+
|
35 |
+
try:
|
36 |
+
# For Supabase JWT, we just need to decode and extract the user ID
|
37 |
+
# The actual validation is handled by Supabase's RLS
|
38 |
+
payload = jwt.decode(token, options={"verify_signature": False})
|
39 |
+
|
40 |
+
# Supabase stores the user ID in the 'sub' claim
|
41 |
+
user_id = payload.get('sub')
|
42 |
+
|
43 |
+
if not user_id:
|
44 |
+
raise HTTPException(
|
45 |
+
status_code=401,
|
46 |
+
detail="Invalid token payload",
|
47 |
+
headers={"WWW-Authenticate": "Bearer"}
|
48 |
+
)
|
49 |
+
|
50 |
+
return user_id
|
51 |
+
|
52 |
+
except PyJWTError:
|
53 |
+
raise HTTPException(
|
54 |
+
status_code=401,
|
55 |
+
detail="Invalid token",
|
56 |
+
headers={"WWW-Authenticate": "Bearer"}
|
57 |
+
)
|
58 |
+
|
59 |
+
async def get_user_id_from_stream_auth(
|
60 |
+
request: Request,
|
61 |
+
token: Optional[str] = None
|
62 |
+
) -> str:
|
63 |
+
"""
|
64 |
+
Extract and verify the user ID from either the Authorization header or query parameter token.
|
65 |
+
This function is specifically designed for streaming endpoints that need to support both
|
66 |
+
header-based and query parameter-based authentication (for EventSource compatibility).
|
67 |
+
|
68 |
+
Args:
|
69 |
+
request: The FastAPI request object
|
70 |
+
token: Optional token from query parameters
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
str: The user ID extracted from the JWT
|
74 |
+
|
75 |
+
Raises:
|
76 |
+
HTTPException: If no valid token is found or if the token is invalid
|
77 |
+
"""
|
78 |
+
# Try to get user_id from token in query param (for EventSource which can't set headers)
|
79 |
+
if token:
|
80 |
+
try:
|
81 |
+
# For Supabase JWT, we just need to decode and extract the user ID
|
82 |
+
payload = jwt.decode(token, options={"verify_signature": False})
|
83 |
+
user_id = payload.get('sub')
|
84 |
+
if user_id:
|
85 |
+
return user_id
|
86 |
+
except Exception:
|
87 |
+
pass
|
88 |
+
|
89 |
+
# If no valid token in query param, try to get it from the Authorization header
|
90 |
+
auth_header = request.headers.get('Authorization')
|
91 |
+
if auth_header and auth_header.startswith('Bearer '):
|
92 |
+
try:
|
93 |
+
# Extract token from header
|
94 |
+
header_token = auth_header.split(' ')[1]
|
95 |
+
payload = jwt.decode(header_token, options={"verify_signature": False})
|
96 |
+
user_id = payload.get('sub')
|
97 |
+
if user_id:
|
98 |
+
return user_id
|
99 |
+
except Exception:
|
100 |
+
pass
|
101 |
+
|
102 |
+
# If we still don't have a user_id, return authentication error
|
103 |
+
raise HTTPException(
|
104 |
+
status_code=401,
|
105 |
+
detail="No valid authentication credentials found",
|
106 |
+
headers={"WWW-Authenticate": "Bearer"}
|
107 |
+
)
|
108 |
+
async def verify_thread_access(client, thread_id: str, user_id: str):
|
109 |
+
"""
|
110 |
+
Verify that a user has access to a specific thread based on account membership.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
client: The Supabase client
|
114 |
+
thread_id: The thread ID to check access for
|
115 |
+
user_id: The user ID to check permissions for
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
bool: True if the user has access
|
119 |
+
|
120 |
+
Raises:
|
121 |
+
HTTPException: If the user doesn't have access to the thread
|
122 |
+
"""
|
123 |
+
# Query the thread to get account information
|
124 |
+
thread_result = await client.table('threads').select('*,project_id').eq('thread_id', thread_id).execute()
|
125 |
+
|
126 |
+
if not thread_result.data or len(thread_result.data) == 0:
|
127 |
+
raise HTTPException(status_code=404, detail="Thread not found")
|
128 |
+
|
129 |
+
thread_data = thread_result.data[0]
|
130 |
+
|
131 |
+
# Check if project is public
|
132 |
+
project_id = thread_data.get('project_id')
|
133 |
+
if project_id:
|
134 |
+
project_result = await client.table('projects').select('is_public').eq('project_id', project_id).execute()
|
135 |
+
if project_result.data and len(project_result.data) > 0:
|
136 |
+
if project_result.data[0].get('is_public'):
|
137 |
+
return True
|
138 |
+
|
139 |
+
account_id = thread_data.get('account_id')
|
140 |
+
# When using service role, we need to manually check account membership instead of using current_user_account_role
|
141 |
+
if account_id:
|
142 |
+
account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
|
143 |
+
if account_user_result.data and len(account_user_result.data) > 0:
|
144 |
+
return True
|
145 |
+
raise HTTPException(status_code=403, detail="Not authorized to access this thread")
|
146 |
+
|
147 |
+
async def get_optional_user_id(request: Request) -> Optional[str]:
|
148 |
+
"""
|
149 |
+
Extract the user ID from the JWT in the Authorization header if present,
|
150 |
+
but don't require authentication. Returns None if no valid token is found.
|
151 |
+
|
152 |
+
This function is used for endpoints that support both authenticated and
|
153 |
+
unauthenticated access (like public projects).
|
154 |
+
|
155 |
+
Args:
|
156 |
+
request: The FastAPI request object
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
Optional[str]: The user ID extracted from the JWT, or None if no valid token
|
160 |
+
"""
|
161 |
+
auth_header = request.headers.get('Authorization')
|
162 |
+
|
163 |
+
if not auth_header or not auth_header.startswith('Bearer '):
|
164 |
+
return None
|
165 |
+
|
166 |
+
token = auth_header.split(' ')[1]
|
167 |
+
|
168 |
+
try:
|
169 |
+
# For Supabase JWT, we just need to decode and extract the user ID
|
170 |
+
payload = jwt.decode(token, options={"verify_signature": False})
|
171 |
+
|
172 |
+
# Supabase stores the user ID in the 'sub' claim
|
173 |
+
user_id = payload.get('sub')
|
174 |
+
|
175 |
+
return user_id
|
176 |
+
except PyJWTError:
|
177 |
+
return None
|
base.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /home/ubuntu/visionos_farm/shared/tools/base.py
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from typing import Dict, Any
|
4 |
+
|
5 |
+
class Tool(ABC):
|
6 |
+
"""Abstract base class for all dynamically loaded tools."""
|
7 |
+
|
8 |
+
@property
|
9 |
+
@abstractmethod
|
10 |
+
def name(self) -> str:
|
11 |
+
"""Unique name of the tool."""
|
12 |
+
pass
|
13 |
+
|
14 |
+
@property
|
15 |
+
@abstractmethod
|
16 |
+
def description(self) -> str:
|
17 |
+
"""Description of what the tool does."""
|
18 |
+
pass
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
22 |
+
"""Executes the tool with the given parameters.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
parameters: A dictionary of parameters required by the tool.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
A dictionary containing the result of the tool execution.
|
29 |
+
Should include a 'status' key ('SUCCESS' or 'FAILURE')
|
30 |
+
and potentially 'output' or 'error' keys.
|
31 |
+
"""
|
32 |
+
pass
|
33 |
+
|
billing.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime, timezone
|
2 |
+
from typing import Dict, Optional, Tuple
|
3 |
+
from utils.logger import logger
|
4 |
+
from utils.config import config, EnvMode
|
5 |
+
|
6 |
+
# Define subscription tiers and their monthly limits (in minutes)
|
7 |
+
SUBSCRIPTION_TIERS = {
|
8 |
+
'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 8},
|
9 |
+
'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300},
|
10 |
+
'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400}
|
11 |
+
}
|
12 |
+
|
13 |
+
async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
|
14 |
+
"""Get the current subscription for an account."""
|
15 |
+
result = await client.schema('basejump').from_('billing_subscriptions') \
|
16 |
+
.select('*') \
|
17 |
+
.eq('account_id', account_id) \
|
18 |
+
.eq('status', 'active') \
|
19 |
+
.order('created', desc=True) \
|
20 |
+
.limit(1) \
|
21 |
+
.execute()
|
22 |
+
|
23 |
+
if result.data and len(result.data) > 0:
|
24 |
+
return result.data[0]
|
25 |
+
return None
|
26 |
+
|
27 |
+
async def calculate_monthly_usage(client, account_id: str) -> float:
|
28 |
+
"""Calculate total agent run minutes for the current month for an account."""
|
29 |
+
# Get start of current month in UTC
|
30 |
+
now = datetime.now(timezone.utc)
|
31 |
+
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
|
32 |
+
|
33 |
+
# First get all threads for this account
|
34 |
+
threads_result = await client.table('threads') \
|
35 |
+
.select('thread_id') \
|
36 |
+
.eq('account_id', account_id) \
|
37 |
+
.execute()
|
38 |
+
|
39 |
+
if not threads_result.data:
|
40 |
+
return 0.0
|
41 |
+
|
42 |
+
thread_ids = [t['thread_id'] for t in threads_result.data]
|
43 |
+
|
44 |
+
# Then get all agent runs for these threads in current month
|
45 |
+
runs_result = await client.table('agent_runs') \
|
46 |
+
.select('started_at, completed_at') \
|
47 |
+
.in_('thread_id', thread_ids) \
|
48 |
+
.gte('started_at', start_of_month.isoformat()) \
|
49 |
+
.execute()
|
50 |
+
|
51 |
+
if not runs_result.data:
|
52 |
+
return 0.0
|
53 |
+
|
54 |
+
# Calculate total minutes
|
55 |
+
total_seconds = 0
|
56 |
+
now_ts = now.timestamp()
|
57 |
+
|
58 |
+
for run in runs_result.data:
|
59 |
+
start_time = datetime.fromisoformat(run['started_at'].replace('Z', '+00:00')).timestamp()
|
60 |
+
if run['completed_at']:
|
61 |
+
end_time = datetime.fromisoformat(run['completed_at'].replace('Z', '+00:00')).timestamp()
|
62 |
+
else:
|
63 |
+
# For running jobs, use current time
|
64 |
+
end_time = now_ts
|
65 |
+
|
66 |
+
total_seconds += (end_time - start_time)
|
67 |
+
|
68 |
+
return total_seconds / 60 # Convert to minutes
|
69 |
+
|
70 |
+
async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
|
71 |
+
"""
|
72 |
+
Check if an account can run agents based on their subscription and usage.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
|
76 |
+
"""
|
77 |
+
if config.ENV_MODE == EnvMode.LOCAL:
|
78 |
+
logger.info("Running in local development mode - billing checks are disabled")
|
79 |
+
return True, "Local development mode - billing disabled", {
|
80 |
+
"price_id": "local_dev",
|
81 |
+
"plan_name": "Local Development",
|
82 |
+
"minutes_limit": "no limit"
|
83 |
+
}
|
84 |
+
|
85 |
+
# For staging/production, check subscription status
|
86 |
+
|
87 |
+
# Get current subscription
|
88 |
+
subscription = await get_account_subscription(client, account_id)
|
89 |
+
|
90 |
+
# If no subscription, they can use free tier
|
91 |
+
if not subscription:
|
92 |
+
subscription = {
|
93 |
+
'price_id': 'price_1RGJ9GG6l1KZGqIroxSqgphC', # Free tier
|
94 |
+
'plan_name': 'free'
|
95 |
+
}
|
96 |
+
|
97 |
+
# if not subscription or subscription.get('price_id') is None or subscription.get('price_id') == 'price_1RGJ9GG6l1KZGqIroxSqgphC':
|
98 |
+
# return False, "You are not subscribed to any plan. Please upgrade your plan to continue.", subscription
|
99 |
+
|
100 |
+
# Get tier info
|
101 |
+
tier_info = SUBSCRIPTION_TIERS.get(subscription['price_id'])
|
102 |
+
if not tier_info:
|
103 |
+
return False, "Invalid subscription tier", subscription
|
104 |
+
|
105 |
+
# Calculate current month's usage
|
106 |
+
current_usage = await calculate_monthly_usage(client, account_id)
|
107 |
+
|
108 |
+
# Check if within limits
|
109 |
+
if current_usage >= tier_info['minutes']:
|
110 |
+
return False, f"Monthly limit of {tier_info['minutes']} minutes reached. Please upgrade your plan or wait until next month.", subscription
|
111 |
+
|
112 |
+
return True, "OK", subscription
|
113 |
+
|
114 |
+
# Helper function to get account ID from thread
|
115 |
+
async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
|
116 |
+
"""Get the account ID associated with a thread."""
|
117 |
+
result = await client.table('threads') \
|
118 |
+
.select('account_id') \
|
119 |
+
.eq('thread_id', thread_id) \
|
120 |
+
.limit(1) \
|
121 |
+
.execute()
|
122 |
+
|
123 |
+
if result.data and len(result.data) > 0:
|
124 |
+
return result.data[0]['account_id']
|
125 |
+
return None
|
browser_api.py
ADDED
@@ -0,0 +1,2063 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, APIRouter, HTTPException, Body
|
2 |
+
from playwright.async_api import async_playwright, Browser, Page, ElementHandle
|
3 |
+
from pydantic import BaseModel
|
4 |
+
from typing import Optional, List, Dict, Any, Union
|
5 |
+
import asyncio
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import re
|
9 |
+
import base64
|
10 |
+
from dataclasses import dataclass, field
|
11 |
+
from datetime import datetime
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
from functools import cached_property
|
15 |
+
import traceback
|
16 |
+
import pytesseract
|
17 |
+
from PIL import Image
|
18 |
+
import io
|
19 |
+
|
20 |
+
#######################################################
|
21 |
+
# Action model definitions
|
22 |
+
#######################################################
|
23 |
+
|
24 |
+
class Position(BaseModel):
|
25 |
+
x: int
|
26 |
+
y: int
|
27 |
+
|
28 |
+
class ClickElementAction(BaseModel):
|
29 |
+
index: int
|
30 |
+
|
31 |
+
class ClickCoordinatesAction(BaseModel):
|
32 |
+
x: int
|
33 |
+
y: int
|
34 |
+
|
35 |
+
class GoToUrlAction(BaseModel):
|
36 |
+
url: str
|
37 |
+
|
38 |
+
class InputTextAction(BaseModel):
|
39 |
+
index: int
|
40 |
+
text: str
|
41 |
+
|
42 |
+
class ScrollAction(BaseModel):
|
43 |
+
amount: Optional[int] = None
|
44 |
+
|
45 |
+
class SendKeysAction(BaseModel):
|
46 |
+
keys: str
|
47 |
+
|
48 |
+
class SearchGoogleAction(BaseModel):
|
49 |
+
query: str
|
50 |
+
|
51 |
+
class SwitchTabAction(BaseModel):
|
52 |
+
page_id: int
|
53 |
+
|
54 |
+
class OpenTabAction(BaseModel):
|
55 |
+
url: str
|
56 |
+
|
57 |
+
class CloseTabAction(BaseModel):
|
58 |
+
page_id: int
|
59 |
+
|
60 |
+
class NoParamsAction(BaseModel):
|
61 |
+
pass
|
62 |
+
|
63 |
+
class DragDropAction(BaseModel):
|
64 |
+
element_source: Optional[str] = None
|
65 |
+
element_target: Optional[str] = None
|
66 |
+
element_source_offset: Optional[Position] = None
|
67 |
+
element_target_offset: Optional[Position] = None
|
68 |
+
coord_source_x: Optional[int] = None
|
69 |
+
coord_source_y: Optional[int] = None
|
70 |
+
coord_target_x: Optional[int] = None
|
71 |
+
coord_target_y: Optional[int] = None
|
72 |
+
steps: Optional[int] = 10
|
73 |
+
delay_ms: Optional[int] = 5
|
74 |
+
|
75 |
+
class DoneAction(BaseModel):
|
76 |
+
success: bool = True
|
77 |
+
text: str = ""
|
78 |
+
|
79 |
+
#######################################################
|
80 |
+
# DOM Structure Models
|
81 |
+
#######################################################
|
82 |
+
|
83 |
+
@dataclass
|
84 |
+
class CoordinateSet:
|
85 |
+
x: int = 0
|
86 |
+
y: int = 0
|
87 |
+
width: int = 0
|
88 |
+
height: int = 0
|
89 |
+
|
90 |
+
@dataclass
|
91 |
+
class ViewportInfo:
|
92 |
+
width: int = 0
|
93 |
+
height: int = 0
|
94 |
+
scroll_x: int = 0
|
95 |
+
scroll_y: int = 0
|
96 |
+
|
97 |
+
@dataclass
|
98 |
+
class HashedDomElement:
|
99 |
+
tag_name: str
|
100 |
+
attributes: Dict[str, str]
|
101 |
+
is_visible: bool
|
102 |
+
page_coordinates: Optional[CoordinateSet] = None
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class DOMBaseNode:
|
106 |
+
is_visible: bool
|
107 |
+
parent: Optional['DOMElementNode'] = None
|
108 |
+
|
109 |
+
@dataclass
|
110 |
+
class DOMTextNode(DOMBaseNode):
|
111 |
+
text: str = field(default="")
|
112 |
+
type: str = 'TEXT_NODE'
|
113 |
+
|
114 |
+
def has_parent_with_highlight_index(self) -> bool:
|
115 |
+
current = self.parent
|
116 |
+
while current is not None:
|
117 |
+
if current.highlight_index is not None:
|
118 |
+
return True
|
119 |
+
current = current.parent
|
120 |
+
return False
|
121 |
+
|
122 |
+
@dataclass
|
123 |
+
class DOMElementNode(DOMBaseNode):
|
124 |
+
tag_name: str = field(default="")
|
125 |
+
xpath: str = field(default="")
|
126 |
+
attributes: Dict[str, str] = field(default_factory=dict)
|
127 |
+
children: List['DOMBaseNode'] = field(default_factory=list)
|
128 |
+
|
129 |
+
is_interactive: bool = False
|
130 |
+
is_top_element: bool = False
|
131 |
+
is_in_viewport: bool = False
|
132 |
+
shadow_root: bool = False
|
133 |
+
highlight_index: Optional[int] = None
|
134 |
+
viewport_coordinates: Optional[CoordinateSet] = None
|
135 |
+
page_coordinates: Optional[CoordinateSet] = None
|
136 |
+
viewport_info: Optional[ViewportInfo] = None
|
137 |
+
|
138 |
+
def __repr__(self) -> str:
|
139 |
+
tag_str = f'<{self.tag_name}'
|
140 |
+
for key, value in self.attributes.items():
|
141 |
+
tag_str += f' {key}="{value}"'
|
142 |
+
tag_str += '>'
|
143 |
+
|
144 |
+
extras = []
|
145 |
+
if self.is_interactive:
|
146 |
+
extras.append('interactive')
|
147 |
+
if self.is_top_element:
|
148 |
+
extras.append('top')
|
149 |
+
if self.highlight_index is not None:
|
150 |
+
extras.append(f'highlight:{self.highlight_index}')
|
151 |
+
|
152 |
+
if extras:
|
153 |
+
tag_str += f' [{", ".join(extras)}]'
|
154 |
+
|
155 |
+
return tag_str
|
156 |
+
|
157 |
+
@cached_property
|
158 |
+
def hash(self) -> HashedDomElement:
|
159 |
+
return HashedDomElement(
|
160 |
+
tag_name=self.tag_name,
|
161 |
+
attributes=self.attributes,
|
162 |
+
is_visible=self.is_visible,
|
163 |
+
page_coordinates=self.page_coordinates
|
164 |
+
)
|
165 |
+
|
166 |
+
def get_all_text_till_next_clickable_element(self, max_depth: int = -1) -> str:
|
167 |
+
text_parts = []
|
168 |
+
|
169 |
+
def collect_text(node: DOMBaseNode, current_depth: int) -> None:
|
170 |
+
if max_depth != -1 and current_depth > max_depth:
|
171 |
+
return
|
172 |
+
|
173 |
+
if isinstance(node, DOMElementNode) and node != self and node.highlight_index is not None:
|
174 |
+
return
|
175 |
+
|
176 |
+
if isinstance(node, DOMTextNode):
|
177 |
+
text_parts.append(node.text)
|
178 |
+
elif isinstance(node, DOMElementNode):
|
179 |
+
for child in node.children:
|
180 |
+
collect_text(child, current_depth + 1)
|
181 |
+
|
182 |
+
collect_text(self, 0)
|
183 |
+
return '\n'.join(text_parts).strip()
|
184 |
+
|
185 |
+
def clickable_elements_to_string(self, include_attributes: list[str] | None = None) -> str:
|
186 |
+
"""Convert the processed DOM content to HTML."""
|
187 |
+
formatted_text = []
|
188 |
+
|
189 |
+
def process_node(node: DOMBaseNode, depth: int) -> None:
|
190 |
+
if isinstance(node, DOMElementNode):
|
191 |
+
# Add element with highlight_index
|
192 |
+
if node.highlight_index is not None:
|
193 |
+
attributes_str = ''
|
194 |
+
text = node.get_all_text_till_next_clickable_element()
|
195 |
+
|
196 |
+
# Process attributes for display
|
197 |
+
display_attributes = []
|
198 |
+
if include_attributes:
|
199 |
+
for key, value in node.attributes.items():
|
200 |
+
if key in include_attributes and value and value != node.tag_name:
|
201 |
+
if text and value in text:
|
202 |
+
continue # Skip if attribute value is already in the text
|
203 |
+
display_attributes.append(str(value))
|
204 |
+
|
205 |
+
attributes_str = ';'.join(display_attributes)
|
206 |
+
|
207 |
+
# Build the element string
|
208 |
+
line = f'[{node.highlight_index}]<{node.tag_name}'
|
209 |
+
|
210 |
+
# Add important attributes for identification
|
211 |
+
for attr_name in ['id', 'href', 'name', 'value', 'type']:
|
212 |
+
if attr_name in node.attributes and node.attributes[attr_name]:
|
213 |
+
line += f' {attr_name}="{node.attributes[attr_name]}"'
|
214 |
+
|
215 |
+
# Add the text content if available
|
216 |
+
if text:
|
217 |
+
line += f'> {text}'
|
218 |
+
elif attributes_str:
|
219 |
+
line += f'> {attributes_str}'
|
220 |
+
else:
|
221 |
+
# If no text and no attributes, use the tag name
|
222 |
+
line += f'> {node.tag_name.upper()}'
|
223 |
+
|
224 |
+
line += ' </>'
|
225 |
+
formatted_text.append(line)
|
226 |
+
|
227 |
+
# Process children regardless
|
228 |
+
for child in node.children:
|
229 |
+
process_node(child, depth + 1)
|
230 |
+
|
231 |
+
elif isinstance(node, DOMTextNode):
|
232 |
+
# Add text only if it doesn't have a highlighted parent
|
233 |
+
if not node.has_parent_with_highlight_index() and node.is_visible:
|
234 |
+
if node.text and node.text.strip():
|
235 |
+
formatted_text.append(node.text)
|
236 |
+
|
237 |
+
process_node(self, 0)
|
238 |
+
result = '\n'.join(formatted_text)
|
239 |
+
return result if result.strip() else "No interactive elements found"
|
240 |
+
|
241 |
+
@dataclass
|
242 |
+
class DOMState:
|
243 |
+
element_tree: DOMElementNode
|
244 |
+
selector_map: Dict[int, DOMElementNode]
|
245 |
+
url: str = ""
|
246 |
+
title: str = ""
|
247 |
+
pixels_above: int = 0
|
248 |
+
pixels_below: int = 0
|
249 |
+
|
250 |
+
#######################################################
|
251 |
+
# Browser Action Result Model
|
252 |
+
#######################################################
|
253 |
+
|
254 |
+
class BrowserActionResult(BaseModel):
|
255 |
+
success: bool = True
|
256 |
+
message: str = ""
|
257 |
+
error: str = ""
|
258 |
+
|
259 |
+
# Extended state information
|
260 |
+
url: Optional[str] = None
|
261 |
+
title: Optional[str] = None
|
262 |
+
elements: Optional[str] = None # Formatted string of clickable elements
|
263 |
+
screenshot_base64: Optional[str] = None
|
264 |
+
pixels_above: int = 0
|
265 |
+
pixels_below: int = 0
|
266 |
+
content: Optional[str] = None
|
267 |
+
ocr_text: Optional[str] = None # Added field for OCR text
|
268 |
+
|
269 |
+
# Additional metadata
|
270 |
+
element_count: int = 0 # Number of interactive elements found
|
271 |
+
interactive_elements: Optional[List[Dict[str, Any]]] = None # Simplified list of interactive elements
|
272 |
+
viewport_width: Optional[int] = None
|
273 |
+
viewport_height: Optional[int] = None
|
274 |
+
|
275 |
+
class Config:
|
276 |
+
arbitrary_types_allowed = True
|
277 |
+
|
278 |
+
#######################################################
|
279 |
+
# Browser Automation Implementation
|
280 |
+
#######################################################
|
281 |
+
|
282 |
+
class BrowserAutomation:
|
283 |
+
def __init__(self):
|
284 |
+
self.router = APIRouter()
|
285 |
+
self.browser: Browser = None
|
286 |
+
self.pages: List[Page] = []
|
287 |
+
self.current_page_index: int = 0
|
288 |
+
self.logger = logging.getLogger("browser_automation")
|
289 |
+
self.include_attributes = ["id", "href", "src", "alt", "aria-label", "placeholder", "name", "role", "title", "value"]
|
290 |
+
self.screenshot_dir = os.path.join(os.getcwd(), "screenshots")
|
291 |
+
os.makedirs(self.screenshot_dir, exist_ok=True)
|
292 |
+
|
293 |
+
# Register routes
|
294 |
+
self.router.on_startup.append(self.startup)
|
295 |
+
self.router.on_shutdown.append(self.shutdown)
|
296 |
+
|
297 |
+
# Basic navigation
|
298 |
+
self.router.post("/automation/navigate_to")(self.navigate_to)
|
299 |
+
self.router.post("/automation/search_google")(self.search_google)
|
300 |
+
self.router.post("/automation/go_back")(self.go_back)
|
301 |
+
self.router.post("/automation/wait")(self.wait)
|
302 |
+
|
303 |
+
# Element interaction
|
304 |
+
self.router.post("/automation/click_element")(self.click_element)
|
305 |
+
self.router.post("/automation/click_coordinates")(self.click_coordinates)
|
306 |
+
self.router.post("/automation/input_text")(self.input_text)
|
307 |
+
self.router.post("/automation/send_keys")(self.send_keys)
|
308 |
+
|
309 |
+
# Tab management
|
310 |
+
self.router.post("/automation/switch_tab")(self.switch_tab)
|
311 |
+
self.router.post("/automation/open_tab")(self.open_tab)
|
312 |
+
self.router.post("/automation/close_tab")(self.close_tab)
|
313 |
+
|
314 |
+
# Content actions
|
315 |
+
self.router.post("/automation/extract_content")(self.extract_content)
|
316 |
+
self.router.post("/automation/save_pdf")(self.save_pdf)
|
317 |
+
|
318 |
+
# Scroll actions
|
319 |
+
self.router.post("/automation/scroll_down")(self.scroll_down)
|
320 |
+
self.router.post("/automation/scroll_up")(self.scroll_up)
|
321 |
+
self.router.post("/automation/scroll_to_text")(self.scroll_to_text)
|
322 |
+
|
323 |
+
# Dropdown actions
|
324 |
+
self.router.post("/automation/get_dropdown_options")(self.get_dropdown_options)
|
325 |
+
self.router.post("/automation/select_dropdown_option")(self.select_dropdown_option)
|
326 |
+
|
327 |
+
# Drag and drop
|
328 |
+
self.router.post("/automation/drag_drop")(self.drag_drop)
|
329 |
+
|
330 |
+
async def startup(self):
|
331 |
+
"""Initialize the browser instance on startup"""
|
332 |
+
try:
|
333 |
+
print("Starting browser initialization...")
|
334 |
+
playwright = await async_playwright().start()
|
335 |
+
print("Playwright started, launching browser...")
|
336 |
+
|
337 |
+
# Use non-headless mode for testing with slower timeouts
|
338 |
+
launch_options = {
|
339 |
+
"headless": False,
|
340 |
+
"timeout": 60000
|
341 |
+
}
|
342 |
+
|
343 |
+
try:
|
344 |
+
self.browser = await playwright.chromium.launch(**launch_options)
|
345 |
+
print("Browser launched successfully")
|
346 |
+
except Exception as browser_error:
|
347 |
+
print(f"Failed to launch browser: {browser_error}")
|
348 |
+
# Try with minimal options
|
349 |
+
print("Retrying with minimal options...")
|
350 |
+
launch_options = {"timeout": 90000}
|
351 |
+
self.browser = await playwright.chromium.launch(**launch_options)
|
352 |
+
print("Browser launched with minimal options")
|
353 |
+
|
354 |
+
try:
|
355 |
+
await self.get_current_page()
|
356 |
+
print("Found existing page, using it")
|
357 |
+
self.current_page_index = 0
|
358 |
+
except Exception as page_error:
|
359 |
+
print(f"Error finding existing page, creating new one. ( {page_error})")
|
360 |
+
page = await self.browser.new_page()
|
361 |
+
print("New page created successfully")
|
362 |
+
self.pages.append(page)
|
363 |
+
self.current_page_index = 0
|
364 |
+
# Navigate to about:blank to ensure page is ready
|
365 |
+
# await page.goto("google.com", timeout=30000)
|
366 |
+
print("Navigated to google.com")
|
367 |
+
|
368 |
+
print("Browser initialization completed successfully")
|
369 |
+
except Exception as e:
|
370 |
+
print(f"Browser startup error: {str(e)}")
|
371 |
+
traceback.print_exc()
|
372 |
+
raise RuntimeError(f"Browser initialization failed: {str(e)}")
|
373 |
+
|
374 |
+
async def shutdown(self):
|
375 |
+
"""Clean up browser instance on shutdown"""
|
376 |
+
if self.browser:
|
377 |
+
await self.browser.close()
|
378 |
+
|
379 |
+
async def get_current_page(self) -> Page:
|
380 |
+
"""Get the current active page"""
|
381 |
+
if not self.pages:
|
382 |
+
raise HTTPException(status_code=500, detail="No browser pages available")
|
383 |
+
return self.pages[self.current_page_index]
|
384 |
+
|
385 |
+
async def get_selector_map(self) -> Dict[int, DOMElementNode]:
|
386 |
+
"""Get a map of selectable elements on the page"""
|
387 |
+
page = await self.get_current_page()
|
388 |
+
|
389 |
+
# Create a selector map for interactive elements
|
390 |
+
selector_map = {}
|
391 |
+
|
392 |
+
try:
|
393 |
+
# More comprehensive JavaScript to find interactive elements
|
394 |
+
elements_js = """
|
395 |
+
(() => {
|
396 |
+
// Helper function to get all attributes as an object
|
397 |
+
function getAttributes(el) {
|
398 |
+
const attributes = {};
|
399 |
+
for (const attr of el.attributes) {
|
400 |
+
attributes[attr.name] = attr.value;
|
401 |
+
}
|
402 |
+
return attributes;
|
403 |
+
}
|
404 |
+
|
405 |
+
// Find all potentially interactive elements
|
406 |
+
const interactiveElements = Array.from(document.querySelectorAll(
|
407 |
+
'a, button, input, select, textarea, [role="button"], [role="link"], [role="checkbox"], [role="radio"], [tabindex]:not([tabindex="-1"])'
|
408 |
+
));
|
409 |
+
|
410 |
+
// Filter for visible elements
|
411 |
+
const visibleElements = interactiveElements.filter(el => {
|
412 |
+
const style = window.getComputedStyle(el);
|
413 |
+
const rect = el.getBoundingClientRect();
|
414 |
+
return style.display !== 'none' &&
|
415 |
+
style.visibility !== 'hidden' &&
|
416 |
+
style.opacity !== '0' &&
|
417 |
+
rect.width > 0 &&
|
418 |
+
rect.height > 0;
|
419 |
+
});
|
420 |
+
|
421 |
+
// Map to our expected structure
|
422 |
+
return visibleElements.map((el, index) => {
|
423 |
+
const rect = el.getBoundingClientRect();
|
424 |
+
const isInViewport = rect.top >= 0 &&
|
425 |
+
rect.left >= 0 &&
|
426 |
+
rect.bottom <= window.innerHeight &&
|
427 |
+
rect.right <= window.innerWidth;
|
428 |
+
|
429 |
+
return {
|
430 |
+
index: index + 1,
|
431 |
+
tagName: el.tagName.toLowerCase(),
|
432 |
+
text: el.innerText || el.value || '',
|
433 |
+
attributes: getAttributes(el),
|
434 |
+
isVisible: true,
|
435 |
+
isInteractive: true,
|
436 |
+
pageCoordinates: {
|
437 |
+
x: rect.left + window.scrollX,
|
438 |
+
y: rect.top + window.scrollY,
|
439 |
+
width: rect.width,
|
440 |
+
height: rect.height
|
441 |
+
},
|
442 |
+
viewportCoordinates: {
|
443 |
+
x: rect.left,
|
444 |
+
y: rect.top,
|
445 |
+
width: rect.width,
|
446 |
+
height: rect.height
|
447 |
+
},
|
448 |
+
isInViewport: isInViewport
|
449 |
+
};
|
450 |
+
});
|
451 |
+
})();
|
452 |
+
"""
|
453 |
+
|
454 |
+
elements = await page.evaluate(elements_js)
|
455 |
+
print(f"Found {len(elements)} interactive elements in selector map")
|
456 |
+
|
457 |
+
# Create a root element for the tree
|
458 |
+
root = DOMElementNode(
|
459 |
+
is_visible=True,
|
460 |
+
tag_name="body",
|
461 |
+
is_interactive=False,
|
462 |
+
is_top_element=True
|
463 |
+
)
|
464 |
+
|
465 |
+
# Create element nodes for each element
|
466 |
+
for idx, el in enumerate(elements):
|
467 |
+
# Create coordinate sets
|
468 |
+
page_coordinates = None
|
469 |
+
viewport_coordinates = None
|
470 |
+
|
471 |
+
if 'pageCoordinates' in el:
|
472 |
+
coords = el['pageCoordinates']
|
473 |
+
page_coordinates = CoordinateSet(
|
474 |
+
x=coords.get('x', 0),
|
475 |
+
y=coords.get('y', 0),
|
476 |
+
width=coords.get('width', 0),
|
477 |
+
height=coords.get('height', 0)
|
478 |
+
)
|
479 |
+
|
480 |
+
if 'viewportCoordinates' in el:
|
481 |
+
coords = el['viewportCoordinates']
|
482 |
+
viewport_coordinates = CoordinateSet(
|
483 |
+
x=coords.get('x', 0),
|
484 |
+
y=coords.get('y', 0),
|
485 |
+
width=coords.get('width', 0),
|
486 |
+
height=coords.get('height', 0)
|
487 |
+
)
|
488 |
+
|
489 |
+
# Create the element node
|
490 |
+
element_node = DOMElementNode(
|
491 |
+
is_visible=el.get('isVisible', True),
|
492 |
+
tag_name=el.get('tagName', 'div'),
|
493 |
+
attributes=el.get('attributes', {}),
|
494 |
+
is_interactive=el.get('isInteractive', True),
|
495 |
+
is_in_viewport=el.get('isInViewport', False),
|
496 |
+
highlight_index=el.get('index', idx + 1),
|
497 |
+
page_coordinates=page_coordinates,
|
498 |
+
viewport_coordinates=viewport_coordinates
|
499 |
+
)
|
500 |
+
|
501 |
+
# Add a text node if there's text content
|
502 |
+
if el.get('text'):
|
503 |
+
text_node = DOMTextNode(is_visible=True, text=el.get('text', ''))
|
504 |
+
text_node.parent = element_node
|
505 |
+
element_node.children.append(text_node)
|
506 |
+
|
507 |
+
selector_map[el.get('index', idx + 1)] = element_node
|
508 |
+
root.children.append(element_node)
|
509 |
+
element_node.parent = root
|
510 |
+
|
511 |
+
except Exception as e:
|
512 |
+
print(f"Error getting selector map: {e}")
|
513 |
+
traceback.print_exc()
|
514 |
+
# Create a dummy element to avoid breaking tests
|
515 |
+
dummy = DOMElementNode(
|
516 |
+
is_visible=True,
|
517 |
+
tag_name="a",
|
518 |
+
attributes={'href': '#'},
|
519 |
+
is_interactive=True,
|
520 |
+
highlight_index=1
|
521 |
+
)
|
522 |
+
dummy_text = DOMTextNode(is_visible=True, text="Dummy Element")
|
523 |
+
dummy_text.parent = dummy
|
524 |
+
dummy.children.append(dummy_text)
|
525 |
+
selector_map[1] = dummy
|
526 |
+
|
527 |
+
return selector_map
|
528 |
+
|
529 |
+
async def get_current_dom_state(self) -> DOMState:
|
530 |
+
"""Get the current DOM state including element tree and selector map"""
|
531 |
+
try:
|
532 |
+
page = await self.get_current_page()
|
533 |
+
selector_map = await self.get_selector_map()
|
534 |
+
|
535 |
+
# Create a root element
|
536 |
+
root = DOMElementNode(
|
537 |
+
is_visible=True,
|
538 |
+
tag_name="body",
|
539 |
+
is_interactive=False,
|
540 |
+
is_top_element=True
|
541 |
+
)
|
542 |
+
|
543 |
+
# Add all elements from selector map as children of root
|
544 |
+
for element in selector_map.values():
|
545 |
+
if element.parent is None:
|
546 |
+
element.parent = root
|
547 |
+
root.children.append(element)
|
548 |
+
|
549 |
+
# Get basic page info
|
550 |
+
url = page.url
|
551 |
+
try:
|
552 |
+
title = await page.title()
|
553 |
+
except:
|
554 |
+
title = "Unknown Title"
|
555 |
+
|
556 |
+
# Get more accurate scroll information - fix JavaScript syntax
|
557 |
+
try:
|
558 |
+
scroll_info = await page.evaluate("""
|
559 |
+
() => {
|
560 |
+
const body = document.body;
|
561 |
+
const html = document.documentElement;
|
562 |
+
const totalHeight = Math.max(
|
563 |
+
body.scrollHeight, body.offsetHeight,
|
564 |
+
html.clientHeight, html.scrollHeight, html.offsetHeight
|
565 |
+
);
|
566 |
+
const scrollY = window.scrollY || window.pageYOffset;
|
567 |
+
const windowHeight = window.innerHeight;
|
568 |
+
|
569 |
+
return {
|
570 |
+
pixelsAbove: scrollY,
|
571 |
+
pixelsBelow: Math.max(0, totalHeight - scrollY - windowHeight),
|
572 |
+
totalHeight: totalHeight,
|
573 |
+
viewportHeight: windowHeight
|
574 |
+
};
|
575 |
+
}
|
576 |
+
""")
|
577 |
+
pixels_above = scroll_info.get('pixelsAbove', 0)
|
578 |
+
pixels_below = scroll_info.get('pixelsBelow', 0)
|
579 |
+
except Exception as e:
|
580 |
+
print(f"Error getting scroll info: {e}")
|
581 |
+
pixels_above = 0
|
582 |
+
pixels_below = 0
|
583 |
+
|
584 |
+
return DOMState(
|
585 |
+
element_tree=root,
|
586 |
+
selector_map=selector_map,
|
587 |
+
url=url,
|
588 |
+
title=title,
|
589 |
+
pixels_above=pixels_above,
|
590 |
+
pixels_below=pixels_below
|
591 |
+
)
|
592 |
+
except Exception as e:
|
593 |
+
print(f"Error getting DOM state: {e}")
|
594 |
+
traceback.print_exc()
|
595 |
+
# Return a minimal valid state to avoid breaking tests
|
596 |
+
dummy_root = DOMElementNode(
|
597 |
+
is_visible=True,
|
598 |
+
tag_name="body",
|
599 |
+
is_interactive=False,
|
600 |
+
is_top_element=True
|
601 |
+
)
|
602 |
+
dummy_map = {1: dummy_root}
|
603 |
+
return DOMState(
|
604 |
+
element_tree=dummy_root,
|
605 |
+
selector_map=dummy_map,
|
606 |
+
url=page.url if 'page' in locals() else "about:blank",
|
607 |
+
title="Error page",
|
608 |
+
pixels_above=0,
|
609 |
+
pixels_below=0
|
610 |
+
)
|
611 |
+
|
612 |
+
async def take_screenshot(self) -> str:
|
613 |
+
"""Take a screenshot and return as base64 encoded string"""
|
614 |
+
try:
|
615 |
+
page = await self.get_current_page()
|
616 |
+
screenshot_bytes = await page.screenshot(type='jpeg', quality=60, full_page=False)
|
617 |
+
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
618 |
+
except Exception as e:
|
619 |
+
print(f"Error taking screenshot: {e}")
|
620 |
+
# Return an empty string rather than failing
|
621 |
+
return ""
|
622 |
+
|
623 |
+
async def save_screenshot_to_file(self) -> str:
|
624 |
+
"""Take a screenshot and save to file, returning the path"""
|
625 |
+
try:
|
626 |
+
page = await self.get_current_page()
|
627 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
628 |
+
random_id = random.randint(1000, 9999)
|
629 |
+
filename = f"screenshot_{timestamp}_{random_id}.jpg"
|
630 |
+
filepath = os.path.join(self.screenshot_dir, filename)
|
631 |
+
|
632 |
+
await page.screenshot(path=filepath, type='jpeg', quality=60, full_page=False)
|
633 |
+
return filepath
|
634 |
+
except Exception as e:
|
635 |
+
print(f"Error saving screenshot: {e}")
|
636 |
+
return ""
|
637 |
+
|
638 |
+
async def extract_ocr_text_from_screenshot(self, screenshot_base64: str) -> str:
|
639 |
+
"""Extract text from screenshot using OCR"""
|
640 |
+
if not screenshot_base64:
|
641 |
+
return ""
|
642 |
+
|
643 |
+
try:
|
644 |
+
# Decode base64 to image
|
645 |
+
image_bytes = base64.b64decode(screenshot_base64)
|
646 |
+
image = Image.open(io.BytesIO(image_bytes))
|
647 |
+
|
648 |
+
# Extract text using pytesseract
|
649 |
+
ocr_text = pytesseract.image_to_string(image)
|
650 |
+
|
651 |
+
# Clean up the text
|
652 |
+
ocr_text = ocr_text.strip()
|
653 |
+
|
654 |
+
return ocr_text
|
655 |
+
except Exception as e:
|
656 |
+
print(f"Error performing OCR: {e}")
|
657 |
+
traceback.print_exc()
|
658 |
+
return ""
|
659 |
+
|
660 |
+
async def get_updated_browser_state(self, action_name: str) -> tuple:
|
661 |
+
"""Helper method to get updated browser state after any action
|
662 |
+
Returns a tuple of (dom_state, screenshot, elements, metadata)
|
663 |
+
"""
|
664 |
+
try:
|
665 |
+
# Wait a moment for any potential async processes to settle
|
666 |
+
await asyncio.sleep(0.5)
|
667 |
+
|
668 |
+
# Get updated state
|
669 |
+
dom_state = await self.get_current_dom_state()
|
670 |
+
screenshot = await self.take_screenshot()
|
671 |
+
|
672 |
+
# Format elements for output
|
673 |
+
elements = dom_state.element_tree.clickable_elements_to_string(
|
674 |
+
include_attributes=self.include_attributes
|
675 |
+
)
|
676 |
+
|
677 |
+
# Collect additional metadata
|
678 |
+
page = await self.get_current_page()
|
679 |
+
metadata = {}
|
680 |
+
|
681 |
+
# Get element count
|
682 |
+
metadata['element_count'] = len(dom_state.selector_map)
|
683 |
+
|
684 |
+
# Create simplified interactive elements list
|
685 |
+
interactive_elements = []
|
686 |
+
for idx, element in dom_state.selector_map.items():
|
687 |
+
element_info = {
|
688 |
+
'index': idx,
|
689 |
+
'tag_name': element.tag_name,
|
690 |
+
'text': element.get_all_text_till_next_clickable_element(),
|
691 |
+
'is_in_viewport': element.is_in_viewport
|
692 |
+
}
|
693 |
+
|
694 |
+
# Add key attributes
|
695 |
+
for attr_name in ['id', 'href', 'src', 'alt', 'placeholder', 'name', 'role', 'title', 'type']:
|
696 |
+
if attr_name in element.attributes:
|
697 |
+
element_info[attr_name] = element.attributes[attr_name]
|
698 |
+
|
699 |
+
interactive_elements.append(element_info)
|
700 |
+
|
701 |
+
metadata['interactive_elements'] = interactive_elements
|
702 |
+
|
703 |
+
# Get viewport dimensions - Fix syntax error in JavaScript
|
704 |
+
try:
|
705 |
+
viewport = await page.evaluate("""
|
706 |
+
() => {
|
707 |
+
return {
|
708 |
+
width: window.innerWidth,
|
709 |
+
height: window.innerHeight
|
710 |
+
};
|
711 |
+
}
|
712 |
+
""")
|
713 |
+
metadata['viewport_width'] = viewport.get('width', 0)
|
714 |
+
metadata['viewport_height'] = viewport.get('height', 0)
|
715 |
+
except Exception as e:
|
716 |
+
print(f"Error getting viewport dimensions: {e}")
|
717 |
+
metadata['viewport_width'] = 0
|
718 |
+
metadata['viewport_height'] = 0
|
719 |
+
|
720 |
+
# Extract OCR text from screenshot if available
|
721 |
+
ocr_text = ""
|
722 |
+
if screenshot:
|
723 |
+
ocr_text = await self.extract_ocr_text_from_screenshot(screenshot)
|
724 |
+
metadata['ocr_text'] = ocr_text
|
725 |
+
|
726 |
+
print(f"Got updated state after {action_name}: {len(dom_state.selector_map)} elements")
|
727 |
+
return dom_state, screenshot, elements, metadata
|
728 |
+
except Exception as e:
|
729 |
+
print(f"Error getting updated state after {action_name}: {e}")
|
730 |
+
traceback.print_exc()
|
731 |
+
# Return empty values in case of error
|
732 |
+
return None, "", "", {}
|
733 |
+
|
734 |
+
def build_action_result(self, success: bool, message: str, dom_state, screenshot: str,
|
735 |
+
elements: str, metadata: dict, error: str = "", content: str = None,
|
736 |
+
fallback_url: str = None) -> BrowserActionResult:
|
737 |
+
"""Helper method to build a consistent BrowserActionResult"""
|
738 |
+
# Ensure elements is never None to avoid display issues
|
739 |
+
if elements is None:
|
740 |
+
elements = ""
|
741 |
+
|
742 |
+
return BrowserActionResult(
|
743 |
+
success=success,
|
744 |
+
message=message,
|
745 |
+
error=error,
|
746 |
+
url=dom_state.url if dom_state else fallback_url or "",
|
747 |
+
title=dom_state.title if dom_state else "",
|
748 |
+
elements=elements,
|
749 |
+
screenshot_base64=screenshot,
|
750 |
+
pixels_above=dom_state.pixels_above if dom_state else 0,
|
751 |
+
pixels_below=dom_state.pixels_below if dom_state else 0,
|
752 |
+
content=content,
|
753 |
+
ocr_text=metadata.get('ocr_text', ""),
|
754 |
+
element_count=metadata.get('element_count', 0),
|
755 |
+
interactive_elements=metadata.get('interactive_elements', []),
|
756 |
+
viewport_width=metadata.get('viewport_width', 0),
|
757 |
+
viewport_height=metadata.get('viewport_height', 0)
|
758 |
+
)
|
759 |
+
|
760 |
+
# Basic Navigation Actions
|
761 |
+
|
762 |
+
async def navigate_to(self, action: GoToUrlAction = Body(...)):
|
763 |
+
"""Navigate to a specified URL"""
|
764 |
+
try:
|
765 |
+
page = await self.get_current_page()
|
766 |
+
await page.goto(action.url, wait_until="domcontentloaded")
|
767 |
+
await page.wait_for_load_state("networkidle", timeout=10000)
|
768 |
+
|
769 |
+
# Get updated state after action
|
770 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"navigate_to({action.url})")
|
771 |
+
|
772 |
+
result = self.build_action_result(
|
773 |
+
True,
|
774 |
+
f"Navigated to {action.url}",
|
775 |
+
dom_state,
|
776 |
+
screenshot,
|
777 |
+
elements,
|
778 |
+
metadata,
|
779 |
+
error="",
|
780 |
+
content=None
|
781 |
+
)
|
782 |
+
|
783 |
+
print(f"Navigation result: success={result.success}, url={result.url}")
|
784 |
+
return result
|
785 |
+
except Exception as e:
|
786 |
+
print(f"Navigation error: {str(e)}")
|
787 |
+
traceback.print_exc()
|
788 |
+
# Try to get some state info even after error
|
789 |
+
try:
|
790 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("navigate_error_recovery")
|
791 |
+
return self.build_action_result(
|
792 |
+
False,
|
793 |
+
str(e),
|
794 |
+
dom_state,
|
795 |
+
screenshot,
|
796 |
+
elements,
|
797 |
+
metadata,
|
798 |
+
error=str(e),
|
799 |
+
content=None
|
800 |
+
)
|
801 |
+
except:
|
802 |
+
return self.build_action_result(
|
803 |
+
False,
|
804 |
+
str(e),
|
805 |
+
None,
|
806 |
+
"",
|
807 |
+
"",
|
808 |
+
{},
|
809 |
+
error=str(e),
|
810 |
+
content=None
|
811 |
+
)
|
812 |
+
|
813 |
+
async def search_google(self, action: SearchGoogleAction = Body(...)):
|
814 |
+
"""Search Google with the provided query"""
|
815 |
+
try:
|
816 |
+
page = await self.get_current_page()
|
817 |
+
search_url = f"https://www.google.com/search?q={action.query}"
|
818 |
+
await page.goto(search_url)
|
819 |
+
await page.wait_for_load_state()
|
820 |
+
|
821 |
+
# Get updated state after action
|
822 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"search_google({action.query})")
|
823 |
+
|
824 |
+
return self.build_action_result(
|
825 |
+
True,
|
826 |
+
f"Searched for '{action.query}' in Google",
|
827 |
+
dom_state,
|
828 |
+
screenshot,
|
829 |
+
elements,
|
830 |
+
metadata,
|
831 |
+
error="",
|
832 |
+
content=None
|
833 |
+
)
|
834 |
+
except Exception as e:
|
835 |
+
print(f"Search error: {str(e)}")
|
836 |
+
traceback.print_exc()
|
837 |
+
# Try to get some state info even after error
|
838 |
+
try:
|
839 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("search_error_recovery")
|
840 |
+
return self.build_action_result(
|
841 |
+
False,
|
842 |
+
str(e),
|
843 |
+
dom_state,
|
844 |
+
screenshot,
|
845 |
+
elements,
|
846 |
+
metadata,
|
847 |
+
error=str(e),
|
848 |
+
content=None
|
849 |
+
)
|
850 |
+
except:
|
851 |
+
return self.build_action_result(
|
852 |
+
False,
|
853 |
+
str(e),
|
854 |
+
None,
|
855 |
+
"",
|
856 |
+
"",
|
857 |
+
{},
|
858 |
+
error=str(e),
|
859 |
+
content=None
|
860 |
+
)
|
861 |
+
|
862 |
+
async def go_back(self, _: NoParamsAction = Body(...)):
|
863 |
+
"""Navigate back in browser history"""
|
864 |
+
try:
|
865 |
+
page = await self.get_current_page()
|
866 |
+
await page.go_back()
|
867 |
+
await page.wait_for_load_state()
|
868 |
+
|
869 |
+
# Get updated state after action
|
870 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("go_back")
|
871 |
+
|
872 |
+
return self.build_action_result(
|
873 |
+
True,
|
874 |
+
"Navigated back",
|
875 |
+
dom_state,
|
876 |
+
screenshot,
|
877 |
+
elements,
|
878 |
+
metadata,
|
879 |
+
error="",
|
880 |
+
content=None
|
881 |
+
)
|
882 |
+
except Exception as e:
|
883 |
+
return self.build_action_result(
|
884 |
+
False,
|
885 |
+
str(e),
|
886 |
+
None,
|
887 |
+
"",
|
888 |
+
"",
|
889 |
+
{},
|
890 |
+
error=str(e),
|
891 |
+
content=None
|
892 |
+
)
|
893 |
+
|
894 |
+
async def wait(self, seconds: int = Body(3)):
|
895 |
+
"""Wait for the specified number of seconds"""
|
896 |
+
try:
|
897 |
+
await asyncio.sleep(seconds)
|
898 |
+
|
899 |
+
# Get updated state after waiting
|
900 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"wait({seconds} seconds)")
|
901 |
+
|
902 |
+
return self.build_action_result(
|
903 |
+
True,
|
904 |
+
f"Waited for {seconds} seconds",
|
905 |
+
dom_state,
|
906 |
+
screenshot,
|
907 |
+
elements,
|
908 |
+
metadata,
|
909 |
+
error="",
|
910 |
+
content=None
|
911 |
+
)
|
912 |
+
except Exception as e:
|
913 |
+
return self.build_action_result(
|
914 |
+
False,
|
915 |
+
str(e),
|
916 |
+
None,
|
917 |
+
"",
|
918 |
+
"",
|
919 |
+
{},
|
920 |
+
error=str(e),
|
921 |
+
content=None
|
922 |
+
)
|
923 |
+
|
924 |
+
# Element Interaction Actions
|
925 |
+
|
926 |
+
async def click_coordinates(self, action: ClickCoordinatesAction = Body(...)):
|
927 |
+
"""Click at specific x,y coordinates on the page"""
|
928 |
+
try:
|
929 |
+
page = await self.get_current_page()
|
930 |
+
|
931 |
+
# Perform the click at the specified coordinates
|
932 |
+
await page.mouse.click(action.x, action.y)
|
933 |
+
|
934 |
+
# Give time for any navigation or DOM updates to occur
|
935 |
+
await page.wait_for_load_state("networkidle", timeout=5000)
|
936 |
+
|
937 |
+
# Get updated state after action
|
938 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"click_coordinates({action.x}, {action.y})")
|
939 |
+
|
940 |
+
return self.build_action_result(
|
941 |
+
True,
|
942 |
+
f"Clicked at coordinates ({action.x}, {action.y})",
|
943 |
+
dom_state,
|
944 |
+
screenshot,
|
945 |
+
elements,
|
946 |
+
metadata,
|
947 |
+
error="",
|
948 |
+
content=None
|
949 |
+
)
|
950 |
+
except Exception as e:
|
951 |
+
print(f"Error in click_coordinates: {e}")
|
952 |
+
traceback.print_exc()
|
953 |
+
|
954 |
+
# Try to get state even after error
|
955 |
+
try:
|
956 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("click_coordinates_error_recovery")
|
957 |
+
return self.build_action_result(
|
958 |
+
False,
|
959 |
+
str(e),
|
960 |
+
dom_state,
|
961 |
+
screenshot,
|
962 |
+
elements,
|
963 |
+
metadata,
|
964 |
+
error=str(e),
|
965 |
+
content=None
|
966 |
+
)
|
967 |
+
except:
|
968 |
+
return self.build_action_result(
|
969 |
+
False,
|
970 |
+
str(e),
|
971 |
+
None,
|
972 |
+
"",
|
973 |
+
"",
|
974 |
+
{},
|
975 |
+
error=str(e),
|
976 |
+
content=None
|
977 |
+
)
|
978 |
+
|
979 |
+
async def click_element(self, action: ClickElementAction = Body(...)):
|
980 |
+
"""Click on an element by index"""
|
981 |
+
try:
|
982 |
+
page = await self.get_current_page()
|
983 |
+
|
984 |
+
# Get the current state and selector map *before* the click
|
985 |
+
initial_dom_state = await self.get_current_dom_state()
|
986 |
+
selector_map = initial_dom_state.selector_map
|
987 |
+
|
988 |
+
if action.index not in selector_map:
|
989 |
+
# Get updated state even if element not found initially
|
990 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"click_element_error (index {action.index} not found)")
|
991 |
+
return self.build_action_result(
|
992 |
+
False,
|
993 |
+
f"Element with index {action.index} not found",
|
994 |
+
dom_state, # Use the latest state
|
995 |
+
screenshot,
|
996 |
+
elements,
|
997 |
+
metadata,
|
998 |
+
error=f"Element with index {action.index} not found"
|
999 |
+
)
|
1000 |
+
|
1001 |
+
element_to_click = selector_map[action.index]
|
1002 |
+
print(f"Attempting to click element: {element_to_click}")
|
1003 |
+
|
1004 |
+
# Construct a more reliable selector using JavaScript evaluation
|
1005 |
+
# Find the element based on its properties captured in selector_map
|
1006 |
+
js_selector_script = """
|
1007 |
+
(targetElementInfo) => {
|
1008 |
+
const interactiveElements = Array.from(document.querySelectorAll(
|
1009 |
+
'a, button, input, select, textarea, [role="button"], [role="link"], [role="checkbox"], [role="radio"], [tabindex]:not([tabindex="-1"])'
|
1010 |
+
));
|
1011 |
+
|
1012 |
+
const visibleElements = interactiveElements.filter(el => {
|
1013 |
+
const style = window.getComputedStyle(el);
|
1014 |
+
const rect = el.getBoundingClientRect();
|
1015 |
+
return style.display !== 'none' && style.visibility !== 'hidden' && style.opacity !== '0' && rect.width > 0 && rect.height > 0;
|
1016 |
+
});
|
1017 |
+
|
1018 |
+
if (targetElementInfo.index > 0 && targetElementInfo.index <= visibleElements.length) {
|
1019 |
+
// Return the element at the specified index (1-based)
|
1020 |
+
return visibleElements[targetElementInfo.index - 1];
|
1021 |
+
}
|
1022 |
+
return null; // Element not found at the expected index
|
1023 |
+
}
|
1024 |
+
"""
|
1025 |
+
|
1026 |
+
element_info = {'index': action.index} # Pass the target index to the script
|
1027 |
+
|
1028 |
+
target_element_handle = await page.evaluate_handle(js_selector_script, element_info)
|
1029 |
+
|
1030 |
+
click_success = False
|
1031 |
+
error_message = ""
|
1032 |
+
|
1033 |
+
if await target_element_handle.evaluate("node => node !== null"):
|
1034 |
+
try:
|
1035 |
+
# Use Playwright's recommended way: click the handle
|
1036 |
+
# Add timeout and wait for element to be stable
|
1037 |
+
await target_element_handle.click(timeout=5000)
|
1038 |
+
click_success = True
|
1039 |
+
print(f"Successfully clicked element handle for index {action.index}")
|
1040 |
+
except Exception as click_error:
|
1041 |
+
error_message = f"Error clicking element handle: {click_error}"
|
1042 |
+
print(error_message)
|
1043 |
+
# Optional: Add fallback methods here if needed
|
1044 |
+
# e.g., target_element_handle.dispatch_event('click')
|
1045 |
+
else:
|
1046 |
+
error_message = f"Could not locate the target element handle for index {action.index} using JS script."
|
1047 |
+
print(error_message)
|
1048 |
+
|
1049 |
+
|
1050 |
+
# Wait for potential page changes/network activity
|
1051 |
+
try:
|
1052 |
+
await page.wait_for_load_state("networkidle", timeout=5000)
|
1053 |
+
except Exception as wait_error:
|
1054 |
+
print(f"Timeout or error waiting for network idle after click: {wait_error}")
|
1055 |
+
await asyncio.sleep(1) # Fallback wait
|
1056 |
+
|
1057 |
+
# Get updated state after action
|
1058 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"click_element({action.index})")
|
1059 |
+
|
1060 |
+
return self.build_action_result(
|
1061 |
+
click_success,
|
1062 |
+
f"Clicked element with index {action.index}" if click_success else f"Attempted to click element {action.index} but failed. Error: {error_message}",
|
1063 |
+
dom_state,
|
1064 |
+
screenshot,
|
1065 |
+
elements,
|
1066 |
+
metadata,
|
1067 |
+
error=error_message if not click_success else "",
|
1068 |
+
content=None
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
except Exception as e:
|
1072 |
+
print(f"Error in click_element: {e}")
|
1073 |
+
traceback.print_exc()
|
1074 |
+
# Try to get state even after error
|
1075 |
+
try:
|
1076 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("click_element_error_recovery")
|
1077 |
+
return self.build_action_result(
|
1078 |
+
False,
|
1079 |
+
str(e),
|
1080 |
+
dom_state,
|
1081 |
+
screenshot,
|
1082 |
+
elements,
|
1083 |
+
metadata,
|
1084 |
+
error=str(e),
|
1085 |
+
content=None
|
1086 |
+
)
|
1087 |
+
except:
|
1088 |
+
# Fallback if getting state also fails
|
1089 |
+
current_url = "unknown"
|
1090 |
+
try:
|
1091 |
+
current_url = page.url # Try to get at least the URL
|
1092 |
+
except:
|
1093 |
+
pass
|
1094 |
+
return self.build_action_result(
|
1095 |
+
False,
|
1096 |
+
str(e),
|
1097 |
+
None, # No DOM state available
|
1098 |
+
"", # No screenshot
|
1099 |
+
"", # No elements string
|
1100 |
+
{}, # Empty metadata
|
1101 |
+
error=str(e),
|
1102 |
+
content=None,
|
1103 |
+
fallback_url=current_url
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
async def input_text(self, action: InputTextAction = Body(...)):
|
1107 |
+
"""Input text into an element"""
|
1108 |
+
try:
|
1109 |
+
page = await self.get_current_page()
|
1110 |
+
selector_map = await self.get_selector_map()
|
1111 |
+
|
1112 |
+
if action.index not in selector_map:
|
1113 |
+
return self.build_action_result(
|
1114 |
+
False,
|
1115 |
+
f"Element with index {action.index} not found",
|
1116 |
+
None,
|
1117 |
+
"",
|
1118 |
+
"",
|
1119 |
+
{},
|
1120 |
+
error=f"Element with index {action.index} not found"
|
1121 |
+
)
|
1122 |
+
|
1123 |
+
# In a real implementation, we would use the selector map to get the element's
|
1124 |
+
# properties and use them to find and type into the element
|
1125 |
+
element = selector_map[action.index]
|
1126 |
+
|
1127 |
+
# Use CSS selector or XPath to locate and type into the element
|
1128 |
+
await page.wait_for_timeout(500) # Small delay before typing
|
1129 |
+
|
1130 |
+
# Demo implementation - would use proper selectors in production
|
1131 |
+
if element.attributes.get("id"):
|
1132 |
+
await page.fill(f"#{element.attributes['id']}", action.text)
|
1133 |
+
elif element.attributes.get("class"):
|
1134 |
+
class_selector = f".{element.attributes['class'].replace(' ', '.')}"
|
1135 |
+
await page.fill(class_selector, action.text)
|
1136 |
+
else:
|
1137 |
+
# Fallback to xpath
|
1138 |
+
await page.fill(f"//{element.tag_name}[{action.index}]", action.text)
|
1139 |
+
|
1140 |
+
# Get updated state after action
|
1141 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"input_text({action.index}, '{action.text}')")
|
1142 |
+
|
1143 |
+
return self.build_action_result(
|
1144 |
+
True,
|
1145 |
+
f"Input '{action.text}' into element with index {action.index}",
|
1146 |
+
dom_state,
|
1147 |
+
screenshot,
|
1148 |
+
elements,
|
1149 |
+
metadata,
|
1150 |
+
error="",
|
1151 |
+
content=None
|
1152 |
+
)
|
1153 |
+
except Exception as e:
|
1154 |
+
return self.build_action_result(
|
1155 |
+
False,
|
1156 |
+
str(e),
|
1157 |
+
None,
|
1158 |
+
"",
|
1159 |
+
"",
|
1160 |
+
{},
|
1161 |
+
error=str(e),
|
1162 |
+
content=None
|
1163 |
+
)
|
1164 |
+
|
1165 |
+
async def send_keys(self, action: SendKeysAction = Body(...)):
|
1166 |
+
"""Send keyboard keys"""
|
1167 |
+
try:
|
1168 |
+
page = await self.get_current_page()
|
1169 |
+
await page.keyboard.press(action.keys)
|
1170 |
+
|
1171 |
+
# Get updated state after action
|
1172 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"send_keys({action.keys})")
|
1173 |
+
|
1174 |
+
return self.build_action_result(
|
1175 |
+
True,
|
1176 |
+
f"Sent keys: {action.keys}",
|
1177 |
+
dom_state,
|
1178 |
+
screenshot,
|
1179 |
+
elements,
|
1180 |
+
metadata,
|
1181 |
+
error="",
|
1182 |
+
content=None
|
1183 |
+
)
|
1184 |
+
except Exception as e:
|
1185 |
+
return self.build_action_result(
|
1186 |
+
False,
|
1187 |
+
str(e),
|
1188 |
+
None,
|
1189 |
+
"",
|
1190 |
+
"",
|
1191 |
+
{},
|
1192 |
+
error=str(e),
|
1193 |
+
content=None
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
# Tab Management Actions
|
1197 |
+
|
1198 |
+
async def switch_tab(self, action: SwitchTabAction = Body(...)):
|
1199 |
+
"""Switch to a different tab by index"""
|
1200 |
+
try:
|
1201 |
+
if 0 <= action.page_id < len(self.pages):
|
1202 |
+
self.current_page_index = action.page_id
|
1203 |
+
page = await self.get_current_page()
|
1204 |
+
await page.wait_for_load_state()
|
1205 |
+
|
1206 |
+
# Get updated state after action
|
1207 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"switch_tab({action.page_id})")
|
1208 |
+
|
1209 |
+
return self.build_action_result(
|
1210 |
+
True,
|
1211 |
+
f"Switched to tab {action.page_id}",
|
1212 |
+
dom_state,
|
1213 |
+
screenshot,
|
1214 |
+
elements,
|
1215 |
+
metadata,
|
1216 |
+
error="",
|
1217 |
+
content=None
|
1218 |
+
)
|
1219 |
+
else:
|
1220 |
+
return self.build_action_result(
|
1221 |
+
False,
|
1222 |
+
f"Tab {action.page_id} not found",
|
1223 |
+
None,
|
1224 |
+
"",
|
1225 |
+
"",
|
1226 |
+
{},
|
1227 |
+
error=f"Tab {action.page_id} not found"
|
1228 |
+
)
|
1229 |
+
except Exception as e:
|
1230 |
+
return self.build_action_result(
|
1231 |
+
False,
|
1232 |
+
str(e),
|
1233 |
+
None,
|
1234 |
+
"",
|
1235 |
+
"",
|
1236 |
+
{},
|
1237 |
+
error=str(e),
|
1238 |
+
content=None
|
1239 |
+
)
|
1240 |
+
|
1241 |
+
async def open_tab(self, action: OpenTabAction = Body(...)):
|
1242 |
+
"""Open a new tab with the specified URL"""
|
1243 |
+
try:
|
1244 |
+
print(f"Attempting to open new tab with URL: {action.url}")
|
1245 |
+
# Create new page in same browser instance
|
1246 |
+
new_page = await self.browser.new_page()
|
1247 |
+
print(f"New page created successfully")
|
1248 |
+
|
1249 |
+
# Navigate to the URL
|
1250 |
+
await new_page.goto(action.url, wait_until="domcontentloaded")
|
1251 |
+
await new_page.wait_for_load_state("networkidle", timeout=10000)
|
1252 |
+
print(f"Navigated to URL in new tab: {action.url}")
|
1253 |
+
|
1254 |
+
# Add to page list and make it current
|
1255 |
+
self.pages.append(new_page)
|
1256 |
+
self.current_page_index = len(self.pages) - 1
|
1257 |
+
print(f"New tab added as index {self.current_page_index}")
|
1258 |
+
|
1259 |
+
# Get updated state after action
|
1260 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"open_tab({action.url})")
|
1261 |
+
|
1262 |
+
return self.build_action_result(
|
1263 |
+
True,
|
1264 |
+
f"Opened new tab with URL: {action.url}",
|
1265 |
+
dom_state,
|
1266 |
+
screenshot,
|
1267 |
+
elements,
|
1268 |
+
metadata,
|
1269 |
+
error="",
|
1270 |
+
content=None
|
1271 |
+
)
|
1272 |
+
except Exception as e:
|
1273 |
+
print("****"*10)
|
1274 |
+
print(f"Error opening tab: {e}")
|
1275 |
+
print(traceback.format_exc())
|
1276 |
+
print("****"*10)
|
1277 |
+
return self.build_action_result(
|
1278 |
+
False,
|
1279 |
+
str(e),
|
1280 |
+
None,
|
1281 |
+
"",
|
1282 |
+
"",
|
1283 |
+
{},
|
1284 |
+
error=str(e),
|
1285 |
+
content=None
|
1286 |
+
)
|
1287 |
+
|
1288 |
+
async def close_tab(self, action: CloseTabAction = Body(...)):
|
1289 |
+
"""Close a tab by index"""
|
1290 |
+
try:
|
1291 |
+
if 0 <= action.page_id < len(self.pages):
|
1292 |
+
page = self.pages[action.page_id]
|
1293 |
+
url = page.url
|
1294 |
+
await page.close()
|
1295 |
+
self.pages.pop(action.page_id)
|
1296 |
+
|
1297 |
+
# Adjust current index if needed
|
1298 |
+
if self.current_page_index >= len(self.pages):
|
1299 |
+
self.current_page_index = max(0, len(self.pages) - 1)
|
1300 |
+
elif self.current_page_index >= action.page_id:
|
1301 |
+
self.current_page_index = max(0, self.current_page_index - 1)
|
1302 |
+
|
1303 |
+
# Get updated state after action
|
1304 |
+
page = await self.get_current_page()
|
1305 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"close_tab({action.page_id})")
|
1306 |
+
|
1307 |
+
return self.build_action_result(
|
1308 |
+
True,
|
1309 |
+
f"Closed tab {action.page_id} with URL: {url}",
|
1310 |
+
dom_state,
|
1311 |
+
screenshot,
|
1312 |
+
elements,
|
1313 |
+
metadata,
|
1314 |
+
error="",
|
1315 |
+
content=None
|
1316 |
+
)
|
1317 |
+
else:
|
1318 |
+
return self.build_action_result(
|
1319 |
+
False,
|
1320 |
+
f"Tab {action.page_id} not found",
|
1321 |
+
None,
|
1322 |
+
"",
|
1323 |
+
"",
|
1324 |
+
{},
|
1325 |
+
error=f"Tab {action.page_id} not found"
|
1326 |
+
)
|
1327 |
+
except Exception as e:
|
1328 |
+
return self.build_action_result(
|
1329 |
+
False,
|
1330 |
+
str(e),
|
1331 |
+
None,
|
1332 |
+
"",
|
1333 |
+
"",
|
1334 |
+
{},
|
1335 |
+
error=str(e),
|
1336 |
+
content=None
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
# Content Actions
|
1340 |
+
|
1341 |
+
async def extract_content(self, goal: str = Body(...)):
|
1342 |
+
"""Extract content from the current page based on the provided goal"""
|
1343 |
+
try:
|
1344 |
+
page = await self.get_current_page()
|
1345 |
+
content = await page.content()
|
1346 |
+
|
1347 |
+
# In a full implementation, we would use an LLM to extract specific content
|
1348 |
+
# based on the goal. For this example, we'll extract visible text.
|
1349 |
+
extracted_text = await page.evaluate("""
|
1350 |
+
Array.from(document.querySelectorAll('p, h1, h2, h3, h4, h5, h6, li, span, div'))
|
1351 |
+
.filter(el => {
|
1352 |
+
const style = window.getComputedStyle(el);
|
1353 |
+
return style.display !== 'none' &&
|
1354 |
+
style.visibility !== 'hidden' &&
|
1355 |
+
style.opacity !== '0' &&
|
1356 |
+
el.innerText &&
|
1357 |
+
el.innerText.trim().length > 0;
|
1358 |
+
})
|
1359 |
+
.map(el => el.innerText.trim())
|
1360 |
+
.join('\\n\\n');
|
1361 |
+
""")
|
1362 |
+
|
1363 |
+
# Get updated state
|
1364 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"extract_content({goal})")
|
1365 |
+
|
1366 |
+
return self.build_action_result(
|
1367 |
+
True,
|
1368 |
+
f"Content extracted based on goal: {goal}",
|
1369 |
+
dom_state,
|
1370 |
+
screenshot,
|
1371 |
+
elements,
|
1372 |
+
metadata,
|
1373 |
+
error="",
|
1374 |
+
content=extracted_text
|
1375 |
+
)
|
1376 |
+
except Exception as e:
|
1377 |
+
return self.build_action_result(
|
1378 |
+
False,
|
1379 |
+
str(e),
|
1380 |
+
None,
|
1381 |
+
"",
|
1382 |
+
"",
|
1383 |
+
{},
|
1384 |
+
error=str(e),
|
1385 |
+
content=None
|
1386 |
+
)
|
1387 |
+
|
1388 |
+
async def save_pdf(self):
|
1389 |
+
"""Save the current page as a PDF"""
|
1390 |
+
try:
|
1391 |
+
page = await self.get_current_page()
|
1392 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
1393 |
+
random_id = random.randint(1000, 9999)
|
1394 |
+
filename = f"page_{timestamp}_{random_id}.pdf"
|
1395 |
+
filepath = os.path.join(self.screenshot_dir, filename)
|
1396 |
+
|
1397 |
+
await page.pdf(path=filepath)
|
1398 |
+
|
1399 |
+
# Get updated state
|
1400 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("save_pdf")
|
1401 |
+
|
1402 |
+
return self.build_action_result(
|
1403 |
+
True,
|
1404 |
+
f"Saved page as PDF: {filepath}",
|
1405 |
+
dom_state,
|
1406 |
+
screenshot,
|
1407 |
+
elements,
|
1408 |
+
metadata,
|
1409 |
+
error="",
|
1410 |
+
content=None
|
1411 |
+
)
|
1412 |
+
except Exception as e:
|
1413 |
+
return self.build_action_result(
|
1414 |
+
False,
|
1415 |
+
str(e),
|
1416 |
+
None,
|
1417 |
+
"",
|
1418 |
+
"",
|
1419 |
+
{},
|
1420 |
+
error=str(e),
|
1421 |
+
content=None
|
1422 |
+
)
|
1423 |
+
|
1424 |
+
# Scroll Actions
|
1425 |
+
|
1426 |
+
async def scroll_down(self, action: ScrollAction = Body(...)):
|
1427 |
+
"""Scroll down the page"""
|
1428 |
+
try:
|
1429 |
+
page = await self.get_current_page()
|
1430 |
+
if action.amount is not None:
|
1431 |
+
await page.evaluate(f"window.scrollBy(0, {action.amount});")
|
1432 |
+
amount_str = f"{action.amount} pixels"
|
1433 |
+
else:
|
1434 |
+
await page.evaluate("window.scrollBy(0, window.innerHeight);")
|
1435 |
+
amount_str = "one page"
|
1436 |
+
|
1437 |
+
await page.wait_for_timeout(500) # Wait for scroll to complete
|
1438 |
+
|
1439 |
+
# Get updated state after action
|
1440 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"scroll_down({amount_str})")
|
1441 |
+
|
1442 |
+
return self.build_action_result(
|
1443 |
+
True,
|
1444 |
+
f"Scrolled down by {amount_str}",
|
1445 |
+
dom_state,
|
1446 |
+
screenshot,
|
1447 |
+
elements,
|
1448 |
+
metadata,
|
1449 |
+
error="",
|
1450 |
+
content=None
|
1451 |
+
)
|
1452 |
+
except Exception as e:
|
1453 |
+
return self.build_action_result(
|
1454 |
+
False,
|
1455 |
+
str(e),
|
1456 |
+
None,
|
1457 |
+
"",
|
1458 |
+
"",
|
1459 |
+
{},
|
1460 |
+
error=str(e),
|
1461 |
+
content=None
|
1462 |
+
)
|
1463 |
+
|
1464 |
+
async def scroll_up(self, action: ScrollAction = Body(...)):
|
1465 |
+
"""Scroll up the page"""
|
1466 |
+
try:
|
1467 |
+
page = await self.get_current_page()
|
1468 |
+
if action.amount is not None:
|
1469 |
+
await page.evaluate(f"window.scrollBy(0, -{action.amount});")
|
1470 |
+
amount_str = f"{action.amount} pixels"
|
1471 |
+
else:
|
1472 |
+
await page.evaluate("window.scrollBy(0, -window.innerHeight);")
|
1473 |
+
amount_str = "one page"
|
1474 |
+
|
1475 |
+
await page.wait_for_timeout(500) # Wait for scroll to complete
|
1476 |
+
|
1477 |
+
# Get updated state after action
|
1478 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"scroll_up({amount_str})")
|
1479 |
+
|
1480 |
+
return self.build_action_result(
|
1481 |
+
True,
|
1482 |
+
f"Scrolled up by {amount_str}",
|
1483 |
+
dom_state,
|
1484 |
+
screenshot,
|
1485 |
+
elements,
|
1486 |
+
metadata,
|
1487 |
+
error="",
|
1488 |
+
content=None
|
1489 |
+
)
|
1490 |
+
except Exception as e:
|
1491 |
+
return self.build_action_result(
|
1492 |
+
False,
|
1493 |
+
str(e),
|
1494 |
+
None,
|
1495 |
+
"",
|
1496 |
+
"",
|
1497 |
+
{},
|
1498 |
+
error=str(e),
|
1499 |
+
content=None
|
1500 |
+
)
|
1501 |
+
|
1502 |
+
async def scroll_to_text(self, text: str = Body(...)):
|
1503 |
+
"""Scroll to text on the page"""
|
1504 |
+
try:
|
1505 |
+
page = await self.get_current_page()
|
1506 |
+
locators = [
|
1507 |
+
page.get_by_text(text, exact=False),
|
1508 |
+
page.locator(f"text={text}"),
|
1509 |
+
page.locator(f"//*[contains(text(), '{text}')]"),
|
1510 |
+
]
|
1511 |
+
|
1512 |
+
found = False
|
1513 |
+
for locator in locators:
|
1514 |
+
try:
|
1515 |
+
if await locator.count() > 0 and await locator.first.is_visible():
|
1516 |
+
await locator.first.scroll_into_view_if_needed()
|
1517 |
+
await asyncio.sleep(0.5) # Wait for scroll to complete
|
1518 |
+
found = True
|
1519 |
+
break
|
1520 |
+
except Exception:
|
1521 |
+
continue
|
1522 |
+
|
1523 |
+
# Get updated state after action
|
1524 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"scroll_to_text({text})")
|
1525 |
+
|
1526 |
+
message = f"Scrolled to text: {text}" if found else f"Text '{text}' not found or not visible on page"
|
1527 |
+
|
1528 |
+
return self.build_action_result(
|
1529 |
+
found,
|
1530 |
+
message,
|
1531 |
+
dom_state,
|
1532 |
+
screenshot,
|
1533 |
+
elements,
|
1534 |
+
metadata,
|
1535 |
+
error="",
|
1536 |
+
content=None
|
1537 |
+
)
|
1538 |
+
except Exception as e:
|
1539 |
+
return self.build_action_result(
|
1540 |
+
False,
|
1541 |
+
str(e),
|
1542 |
+
None,
|
1543 |
+
"",
|
1544 |
+
"",
|
1545 |
+
{},
|
1546 |
+
error=str(e),
|
1547 |
+
content=None
|
1548 |
+
)
|
1549 |
+
|
1550 |
+
# Dropdown Actions
|
1551 |
+
|
1552 |
+
async def get_dropdown_options(self, index: int = Body(...)):
|
1553 |
+
"""Get all options from a dropdown"""
|
1554 |
+
try:
|
1555 |
+
page = await self.get_current_page()
|
1556 |
+
selector_map = await self.get_selector_map()
|
1557 |
+
|
1558 |
+
if index not in selector_map:
|
1559 |
+
return self.build_action_result(
|
1560 |
+
False,
|
1561 |
+
f"Element with index {index} not found",
|
1562 |
+
None,
|
1563 |
+
"",
|
1564 |
+
"",
|
1565 |
+
{},
|
1566 |
+
error=f"Element with index {index} not found"
|
1567 |
+
)
|
1568 |
+
|
1569 |
+
element = selector_map[index]
|
1570 |
+
options = []
|
1571 |
+
|
1572 |
+
# Try to get the options - in a real implementation, we would use appropriate selectors
|
1573 |
+
try:
|
1574 |
+
if element.tag_name.lower() == 'select':
|
1575 |
+
# For <select> elements, get options using JavaScript
|
1576 |
+
options_js = f"""
|
1577 |
+
Array.from(document.querySelectorAll('select')[{index-1}].options)
|
1578 |
+
.map((option, index) => ({
|
1579 |
+
index: index,
|
1580 |
+
text: option.text,
|
1581 |
+
value: option.value
|
1582 |
+
}));
|
1583 |
+
"""
|
1584 |
+
options = await page.evaluate(options_js)
|
1585 |
+
else:
|
1586 |
+
# For other dropdown types, try to get options using a more generic approach
|
1587 |
+
# Example for custom dropdowns - would need refinement in real implementation
|
1588 |
+
await page.click(f"#{element.attributes.get('id')}") if element.attributes.get('id') else None
|
1589 |
+
await page.wait_for_timeout(500)
|
1590 |
+
|
1591 |
+
options_js = """
|
1592 |
+
Array.from(document.querySelectorAll('.dropdown-item, [role="option"], li'))
|
1593 |
+
.filter(el => {
|
1594 |
+
const style = window.getComputedStyle(el);
|
1595 |
+
return style.display !== 'none' && style.visibility !== 'hidden';
|
1596 |
+
})
|
1597 |
+
.map((option, index) => ({
|
1598 |
+
index: index,
|
1599 |
+
text: option.innerText.trim(),
|
1600 |
+
value: option.getAttribute('value') || option.getAttribute('data-value') || option.innerText.trim()
|
1601 |
+
}));
|
1602 |
+
"""
|
1603 |
+
options = await page.evaluate(options_js)
|
1604 |
+
|
1605 |
+
# Close dropdown to restore state
|
1606 |
+
await page.keyboard.press("Escape")
|
1607 |
+
except Exception as e:
|
1608 |
+
self.logger.error(f"Error getting dropdown options: {e}")
|
1609 |
+
# Fallback to dummy options if real ones cannot be retrieved
|
1610 |
+
options = [
|
1611 |
+
{"index": 0, "text": "Option 1", "value": "option1"},
|
1612 |
+
{"index": 1, "text": "Option 2", "value": "option2"},
|
1613 |
+
{"index": 2, "text": "Option 3", "value": "option3"},
|
1614 |
+
]
|
1615 |
+
|
1616 |
+
# Get updated state
|
1617 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"get_dropdown_options({index})")
|
1618 |
+
|
1619 |
+
return self.build_action_result(
|
1620 |
+
True,
|
1621 |
+
f"Retrieved {len(options)} options from dropdown",
|
1622 |
+
dom_state,
|
1623 |
+
screenshot,
|
1624 |
+
elements,
|
1625 |
+
metadata,
|
1626 |
+
error="",
|
1627 |
+
content=json.dumps(options) # Include options in the content field
|
1628 |
+
)
|
1629 |
+
except Exception as e:
|
1630 |
+
return self.build_action_result(
|
1631 |
+
False,
|
1632 |
+
str(e),
|
1633 |
+
None,
|
1634 |
+
"",
|
1635 |
+
"",
|
1636 |
+
{},
|
1637 |
+
error=str(e),
|
1638 |
+
content=None
|
1639 |
+
)
|
1640 |
+
|
1641 |
+
async def select_dropdown_option(self, index: int = Body(...), option_text: str = Body(...)):
|
1642 |
+
"""Select an option from a dropdown by text"""
|
1643 |
+
try:
|
1644 |
+
page = await self.get_current_page()
|
1645 |
+
selector_map = await self.get_selector_map()
|
1646 |
+
|
1647 |
+
if index not in selector_map:
|
1648 |
+
return self.build_action_result(
|
1649 |
+
False,
|
1650 |
+
f"Element with index {index} not found",
|
1651 |
+
None,
|
1652 |
+
"",
|
1653 |
+
"",
|
1654 |
+
{},
|
1655 |
+
error=f"Element with index {index} not found"
|
1656 |
+
)
|
1657 |
+
|
1658 |
+
element = selector_map[index]
|
1659 |
+
|
1660 |
+
# Try to select the option - implementation varies by dropdown type
|
1661 |
+
if element.tag_name.lower() == 'select':
|
1662 |
+
# For standard <select> elements
|
1663 |
+
selector = f"select option:has-text('{option_text}')"
|
1664 |
+
await page.select_option(
|
1665 |
+
f"#{element.attributes.get('id')}" if element.attributes.get('id') else f"//select[{index}]",
|
1666 |
+
label=option_text
|
1667 |
+
)
|
1668 |
+
else:
|
1669 |
+
# For custom dropdowns
|
1670 |
+
# First click to open the dropdown
|
1671 |
+
if element.attributes.get('id'):
|
1672 |
+
await page.click(f"#{element.attributes.get('id')}")
|
1673 |
+
else:
|
1674 |
+
await page.click(f"//{element.tag_name}[{index}]")
|
1675 |
+
|
1676 |
+
await page.wait_for_timeout(500)
|
1677 |
+
|
1678 |
+
# Then try to click the option
|
1679 |
+
await page.click(f"text={option_text}")
|
1680 |
+
|
1681 |
+
await page.wait_for_timeout(500)
|
1682 |
+
|
1683 |
+
# Get updated state after action
|
1684 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"select_dropdown_option({index}, '{option_text}')")
|
1685 |
+
|
1686 |
+
return self.build_action_result(
|
1687 |
+
True,
|
1688 |
+
f"Selected option '{option_text}' from dropdown with index {index}",
|
1689 |
+
dom_state,
|
1690 |
+
screenshot,
|
1691 |
+
elements,
|
1692 |
+
metadata,
|
1693 |
+
error="",
|
1694 |
+
content=None
|
1695 |
+
)
|
1696 |
+
except Exception as e:
|
1697 |
+
return self.build_action_result(
|
1698 |
+
False,
|
1699 |
+
str(e),
|
1700 |
+
None,
|
1701 |
+
"",
|
1702 |
+
"",
|
1703 |
+
{},
|
1704 |
+
error=str(e),
|
1705 |
+
content=None
|
1706 |
+
)
|
1707 |
+
|
1708 |
+
# Drag and Drop
|
1709 |
+
|
1710 |
+
async def drag_drop(self, action: DragDropAction = Body(...)):
|
1711 |
+
"""Perform drag and drop operation"""
|
1712 |
+
try:
|
1713 |
+
page = await self.get_current_page()
|
1714 |
+
|
1715 |
+
# Element-based drag and drop
|
1716 |
+
if action.element_source and action.element_target:
|
1717 |
+
# In a real implementation, we would get the elements and perform the drag
|
1718 |
+
source_desc = action.element_source
|
1719 |
+
target_desc = action.element_target
|
1720 |
+
|
1721 |
+
# We would locate the elements using selectors and perform the drag
|
1722 |
+
# For this example, we'll use a simplified version
|
1723 |
+
await page.evaluate("""
|
1724 |
+
console.log("Simulating drag and drop between elements");
|
1725 |
+
""")
|
1726 |
+
|
1727 |
+
message = f"Dragged element '{source_desc}' to '{target_desc}'"
|
1728 |
+
|
1729 |
+
# Coordinate-based drag and drop
|
1730 |
+
elif all(coord is not None for coord in [
|
1731 |
+
action.coord_source_x, action.coord_source_y,
|
1732 |
+
action.coord_target_x, action.coord_target_y
|
1733 |
+
]):
|
1734 |
+
source_x = action.coord_source_x
|
1735 |
+
source_y = action.coord_source_y
|
1736 |
+
target_x = action.coord_target_x
|
1737 |
+
target_y = action.coord_target_y
|
1738 |
+
|
1739 |
+
# Perform the drag
|
1740 |
+
await page.mouse.move(source_x, source_y)
|
1741 |
+
await page.mouse.down()
|
1742 |
+
|
1743 |
+
steps = max(1, action.steps or 10)
|
1744 |
+
delay_ms = max(0, action.delay_ms or 5)
|
1745 |
+
|
1746 |
+
for i in range(1, steps + 1):
|
1747 |
+
ratio = i / steps
|
1748 |
+
intermediate_x = int(source_x + (target_x - source_x) * ratio)
|
1749 |
+
intermediate_y = int(source_y + (target_y - source_y) * ratio)
|
1750 |
+
await page.mouse.move(intermediate_x, intermediate_y)
|
1751 |
+
if delay_ms > 0:
|
1752 |
+
await asyncio.sleep(delay_ms / 1000)
|
1753 |
+
|
1754 |
+
await page.mouse.move(target_x, target_y)
|
1755 |
+
await page.mouse.up()
|
1756 |
+
|
1757 |
+
message = f"Dragged from ({source_x}, {source_y}) to ({target_x}, {target_y})"
|
1758 |
+
else:
|
1759 |
+
return self.build_action_result(
|
1760 |
+
False,
|
1761 |
+
"Must provide either source/target selectors or coordinates",
|
1762 |
+
None,
|
1763 |
+
"",
|
1764 |
+
"",
|
1765 |
+
{},
|
1766 |
+
error="Must provide either source/target selectors or coordinates"
|
1767 |
+
)
|
1768 |
+
|
1769 |
+
# Get updated state after action
|
1770 |
+
dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"drag_drop({action.element_source}, {action.element_target})")
|
1771 |
+
|
1772 |
+
return self.build_action_result(
|
1773 |
+
True,
|
1774 |
+
message,
|
1775 |
+
dom_state,
|
1776 |
+
screenshot,
|
1777 |
+
elements,
|
1778 |
+
metadata,
|
1779 |
+
error="",
|
1780 |
+
content=None
|
1781 |
+
)
|
1782 |
+
except Exception as e:
|
1783 |
+
return self.build_action_result(
|
1784 |
+
False,
|
1785 |
+
str(e),
|
1786 |
+
None,
|
1787 |
+
"",
|
1788 |
+
"",
|
1789 |
+
{},
|
1790 |
+
error=str(e),
|
1791 |
+
content=None
|
1792 |
+
)
|
1793 |
+
|
1794 |
+
# Create singleton instance
|
1795 |
+
automation_service = BrowserAutomation()
|
1796 |
+
|
1797 |
+
# Create API app
|
1798 |
+
api_app = FastAPI()
|
1799 |
+
|
1800 |
+
@api_app.get("/api")
|
1801 |
+
async def health_check():
|
1802 |
+
return {"status": "ok", "message": "API server is running"}
|
1803 |
+
|
1804 |
+
# Include automation service router with /api prefix
|
1805 |
+
api_app.include_router(automation_service.router, prefix="/api")
|
1806 |
+
|
1807 |
+
async def test_browser_api():
|
1808 |
+
"""Test the browser automation API functionality"""
|
1809 |
+
try:
|
1810 |
+
# Initialize browser automation
|
1811 |
+
print("\n=== Starting Browser Automation Test ===")
|
1812 |
+
await automation_service.startup()
|
1813 |
+
print("✅ Browser started successfully")
|
1814 |
+
|
1815 |
+
# Navigate to a test page with interactive elements
|
1816 |
+
print("\n--- Testing Navigation ---")
|
1817 |
+
result = await automation_service.navigate_to(GoToUrlAction(url="https://www.youtube.com"))
|
1818 |
+
print(f"Navigation status: {'✅ Success' if result.success else '❌ Failed'}")
|
1819 |
+
if not result.success:
|
1820 |
+
print(f"Error: {result.error}")
|
1821 |
+
return
|
1822 |
+
|
1823 |
+
print(f"URL: {result.url}")
|
1824 |
+
print(f"Title: {result.title}")
|
1825 |
+
|
1826 |
+
# Check DOM state and elements
|
1827 |
+
print(f"\nFound {result.element_count} interactive elements")
|
1828 |
+
if result.elements and result.elements.strip():
|
1829 |
+
print("Elements:")
|
1830 |
+
print(result.elements)
|
1831 |
+
else:
|
1832 |
+
print("No formatted elements found, but DOM was processed")
|
1833 |
+
|
1834 |
+
# Display interactive elements as JSON
|
1835 |
+
if result.interactive_elements and len(result.interactive_elements) > 0:
|
1836 |
+
print("\nInteractive elements summary:")
|
1837 |
+
for el in result.interactive_elements:
|
1838 |
+
print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
|
1839 |
+
|
1840 |
+
# Screenshot info
|
1841 |
+
print(f"\nScreenshot captured: {'Yes' if result.screenshot_base64 else 'No'}")
|
1842 |
+
print(f"Viewport size: {result.viewport_width}x{result.viewport_height}")
|
1843 |
+
|
1844 |
+
# Test OCR extraction from screenshot
|
1845 |
+
print("\n--- Testing OCR Text Extraction ---")
|
1846 |
+
if result.ocr_text:
|
1847 |
+
print("OCR text extracted from screenshot:")
|
1848 |
+
print("=== OCR TEXT START ===")
|
1849 |
+
print(result.ocr_text)
|
1850 |
+
print("=== OCR TEXT END ===")
|
1851 |
+
print(f"OCR text length: {len(result.ocr_text)} characters")
|
1852 |
+
print(result.ocr_text)
|
1853 |
+
else:
|
1854 |
+
print("No OCR text extracted from screenshot")
|
1855 |
+
|
1856 |
+
await asyncio.sleep(2)
|
1857 |
+
|
1858 |
+
# Test search functionality
|
1859 |
+
print("\n--- Testing Search ---")
|
1860 |
+
result = await automation_service.search_google(SearchGoogleAction(query="browser automation"))
|
1861 |
+
print(f"Search status: {'✅ Success' if result.success else '❌ Failed'}")
|
1862 |
+
if not result.success:
|
1863 |
+
print(f"Error: {result.error}")
|
1864 |
+
else:
|
1865 |
+
print(f"Found {result.element_count} elements after search")
|
1866 |
+
print(f"Page title: {result.title}")
|
1867 |
+
|
1868 |
+
# Test OCR extraction from search results
|
1869 |
+
if result.ocr_text:
|
1870 |
+
print("\nOCR text from search results:")
|
1871 |
+
print("=== OCR TEXT START ===")
|
1872 |
+
print(result.ocr_text)
|
1873 |
+
print("=== OCR TEXT END ===")
|
1874 |
+
else:
|
1875 |
+
print("\nNo OCR text extracted from search results")
|
1876 |
+
|
1877 |
+
await asyncio.sleep(2)
|
1878 |
+
|
1879 |
+
# Test scrolling
|
1880 |
+
print("\n--- Testing Scrolling ---")
|
1881 |
+
result = await automation_service.scroll_down(ScrollAction(amount=300))
|
1882 |
+
print(f"Scroll status: {'✅ Success' if result.success else '❌ Failed'}")
|
1883 |
+
if result.success:
|
1884 |
+
print(f"Pixels above viewport: {result.pixels_above}")
|
1885 |
+
print(f"Pixels below viewport: {result.pixels_below}")
|
1886 |
+
|
1887 |
+
await asyncio.sleep(2)
|
1888 |
+
|
1889 |
+
# Test clicking on an element
|
1890 |
+
print("\n--- Testing Element Click ---")
|
1891 |
+
if result.element_count > 0:
|
1892 |
+
click_result = await automation_service.click_element(ClickElementAction(index=1))
|
1893 |
+
print(f"Click status: {'✅ Success' if click_result.success else '❌ Failed'}")
|
1894 |
+
print(f"Message: {click_result.message}")
|
1895 |
+
print(f"New URL after click: {click_result.url}")
|
1896 |
+
else:
|
1897 |
+
print("Skipping click test - no elements found")
|
1898 |
+
|
1899 |
+
await asyncio.sleep(2)
|
1900 |
+
|
1901 |
+
# Test clicking on coordinates
|
1902 |
+
print("\n--- Testing Click Coordinates ---")
|
1903 |
+
coord_click_result = await automation_service.click_coordinates(ClickCoordinatesAction(x=100, y=100))
|
1904 |
+
print(f"Coordinate click status: {'✅ Success' if coord_click_result.success else '❌ Failed'}")
|
1905 |
+
print(f"Message: {coord_click_result.message}")
|
1906 |
+
print(f"URL after coordinate click: {coord_click_result.url}")
|
1907 |
+
|
1908 |
+
await asyncio.sleep(2)
|
1909 |
+
|
1910 |
+
# Test extracting content
|
1911 |
+
print("\n--- Testing Content Extraction ---")
|
1912 |
+
content_result = await automation_service.extract_content("test goal")
|
1913 |
+
print(f"Content extraction status: {'✅ Success' if content_result.success else '❌ Failed'}")
|
1914 |
+
if content_result.content:
|
1915 |
+
content_preview = content_result.content[:100] + "..." if len(content_result.content) > 100 else content_result.content
|
1916 |
+
print(f"Content sample: {content_preview}")
|
1917 |
+
print(f"Total content length: {len(content_result.content)} chars")
|
1918 |
+
else:
|
1919 |
+
print("No content was extracted")
|
1920 |
+
|
1921 |
+
# Test tab management
|
1922 |
+
print("\n--- Testing Tab Management ---")
|
1923 |
+
tab_result = await automation_service.open_tab(OpenTabAction(url="https://www.example.org"))
|
1924 |
+
print(f"New tab status: {'✅ Success' if tab_result.success else '❌ Failed'}")
|
1925 |
+
if tab_result.success:
|
1926 |
+
print(f"New tab title: {tab_result.title}")
|
1927 |
+
print(f"Interactive elements: {tab_result.element_count}")
|
1928 |
+
|
1929 |
+
print("\n✅ All tests completed successfully!")
|
1930 |
+
|
1931 |
+
except Exception as e:
|
1932 |
+
print(f"\n❌ Test failed: {str(e)}")
|
1933 |
+
traceback.print_exc()
|
1934 |
+
finally:
|
1935 |
+
# Ensure browser is closed
|
1936 |
+
print("\n--- Cleaning up ---")
|
1937 |
+
await automation_service.shutdown()
|
1938 |
+
print("Browser closed")
|
1939 |
+
|
1940 |
+
async def test_browser_api_2():
|
1941 |
+
"""Test the browser automation API functionality on the chess page"""
|
1942 |
+
try:
|
1943 |
+
# Initialize browser automation
|
1944 |
+
print("\n=== Starting Browser Automation Test 2 (Chess Page) ===")
|
1945 |
+
await automation_service.startup()
|
1946 |
+
print("✅ Browser started successfully")
|
1947 |
+
|
1948 |
+
# Navigate to the chess test page
|
1949 |
+
print("\n--- Testing Navigation to Chess Page ---")
|
1950 |
+
test_url = "https://dat-lequoc.github.io/chess-for-suna/chess.html"
|
1951 |
+
result = await automation_service.navigate_to(GoToUrlAction(url=test_url))
|
1952 |
+
print(f"Navigation status: {'✅ Success' if result.success else '❌ Failed'}")
|
1953 |
+
if not result.success:
|
1954 |
+
print(f"Error: {result.error}")
|
1955 |
+
return
|
1956 |
+
|
1957 |
+
print(f"URL: {result.url}")
|
1958 |
+
print(f"Title: {result.title}")
|
1959 |
+
|
1960 |
+
# Check DOM state and elements
|
1961 |
+
print(f"\nFound {result.element_count} interactive elements")
|
1962 |
+
if result.elements and result.elements.strip():
|
1963 |
+
print("Elements:")
|
1964 |
+
print(result.elements)
|
1965 |
+
else:
|
1966 |
+
print("No formatted elements found, but DOM was processed")
|
1967 |
+
|
1968 |
+
# Display interactive elements as JSON
|
1969 |
+
if result.interactive_elements and len(result.interactive_elements) > 0:
|
1970 |
+
print("\nInteractive elements summary:")
|
1971 |
+
for el in result.interactive_elements:
|
1972 |
+
print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
|
1973 |
+
|
1974 |
+
# Screenshot info
|
1975 |
+
print(f"\nScreenshot captured: {'Yes' if result.screenshot_base64 else 'No'}")
|
1976 |
+
print(f"Viewport size: {result.viewport_width}x{result.viewport_height}")
|
1977 |
+
|
1978 |
+
await asyncio.sleep(2)
|
1979 |
+
|
1980 |
+
# Test clicking on an element (e.g., a chess square)
|
1981 |
+
print("\n--- Testing Element Click (element 5) ---")
|
1982 |
+
if result.element_count > 4: # Ensure element 5 exists
|
1983 |
+
click_index = 5
|
1984 |
+
click_result = await automation_service.click_element(ClickElementAction(index=click_index))
|
1985 |
+
print(f"Click status for element {click_index}: {'✅ Success' if click_result.success else '❌ Failed'}")
|
1986 |
+
print(f"Message: {click_result.message}")
|
1987 |
+
print(f"URL after click: {click_result.url}")
|
1988 |
+
|
1989 |
+
# Retrieve and display elements again after click
|
1990 |
+
print(f"\n--- Retrieving elements after clicking element {click_index} ---")
|
1991 |
+
if click_result.elements and click_result.elements.strip():
|
1992 |
+
print("Updated Elements:")
|
1993 |
+
print(click_result.elements)
|
1994 |
+
else:
|
1995 |
+
print("No formatted elements found after click.")
|
1996 |
+
|
1997 |
+
if click_result.interactive_elements and len(click_result.interactive_elements) > 0:
|
1998 |
+
print("\nUpdated interactive elements summary:")
|
1999 |
+
for el in click_result.interactive_elements:
|
2000 |
+
print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
|
2001 |
+
else:
|
2002 |
+
print("No interactive elements found after click.")
|
2003 |
+
|
2004 |
+
# Test clicking element 1 after the first click
|
2005 |
+
print("\n--- Testing Element Click (element 1 after clicking 5) ---")
|
2006 |
+
if click_result.element_count > 0: # Check if there are still elements
|
2007 |
+
click_index_2 = 1
|
2008 |
+
click_result_2 = await automation_service.click_element(ClickElementAction(index=click_index_2))
|
2009 |
+
print(f"Click status for element {click_index_2}: {'✅ Success' if click_result_2.success else '❌ Failed'}")
|
2010 |
+
print(f"Message: {click_result_2.message}")
|
2011 |
+
print(f"URL after click: {click_result_2.url}")
|
2012 |
+
|
2013 |
+
# Retrieve and display elements again after the second click
|
2014 |
+
print(f"\n--- Retrieving elements after clicking element {click_index_2} ---")
|
2015 |
+
if click_result_2.elements and click_result_2.elements.strip():
|
2016 |
+
print("Elements after second click:")
|
2017 |
+
print(click_result_2.elements)
|
2018 |
+
else:
|
2019 |
+
print("No formatted elements found after second click.")
|
2020 |
+
|
2021 |
+
if click_result_2.interactive_elements and len(click_result_2.interactive_elements) > 0:
|
2022 |
+
print("\nInteractive elements summary after second click:")
|
2023 |
+
for el in click_result_2.interactive_elements:
|
2024 |
+
print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
|
2025 |
+
else:
|
2026 |
+
print("No interactive elements found after second click.")
|
2027 |
+
else:
|
2028 |
+
print("Skipping second element click test - no elements found after first click.")
|
2029 |
+
|
2030 |
+
else:
|
2031 |
+
print("Skipping element click test - fewer than 5 elements found.")
|
2032 |
+
|
2033 |
+
await asyncio.sleep(2)
|
2034 |
+
|
2035 |
+
print("\n✅ Chess Page Test Completed!")
|
2036 |
+
await asyncio.sleep(100)
|
2037 |
+
|
2038 |
+
except Exception as e:
|
2039 |
+
print(f"\n❌ Chess Page Test failed: {str(e)}")
|
2040 |
+
traceback.print_exc()
|
2041 |
+
finally:
|
2042 |
+
# Ensure browser is closed
|
2043 |
+
print("\n--- Cleaning up ---")
|
2044 |
+
await automation_service.shutdown()
|
2045 |
+
print("Browser closed")
|
2046 |
+
|
2047 |
+
if __name__ == '__main__':
|
2048 |
+
import uvicorn
|
2049 |
+
import sys
|
2050 |
+
|
2051 |
+
# Check command line arguments for test mode
|
2052 |
+
test_mode_1 = "--test" in sys.argv
|
2053 |
+
test_mode_2 = "--test2" in sys.argv
|
2054 |
+
|
2055 |
+
if test_mode_1:
|
2056 |
+
print("Running in test mode 1")
|
2057 |
+
asyncio.run(test_browser_api())
|
2058 |
+
elif test_mode_2:
|
2059 |
+
print("Running in test mode 2 (Chess Page)")
|
2060 |
+
asyncio.run(test_browser_api_2())
|
2061 |
+
else:
|
2062 |
+
print("Starting API server")
|
2063 |
+
uvicorn.run("browser_api:api_app", host="0.0.0.0", port=8002)
|
chat_template.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"chat_template": "{% set audio_count = namespace(value=0) %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_bos|><|IMAGE|><|vision_eos|>{% elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_audio_id %}Audio {{ audio_count.value }}: {% endif %}<|audio_bos|><|AUDIO|><|audio_eos|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_bos|><|VIDEO|><|vision_eos|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
3 |
+
}
|
cleanup.sh
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Print colored output
|
4 |
+
GREEN='\033[0;32m'
|
5 |
+
BLUE='\033[0;34m'
|
6 |
+
RED='\033[0;31m'
|
7 |
+
NC='\033[0m' # No Color
|
8 |
+
|
9 |
+
echo -e "${RED}Cleaning up all services...${NC}"
|
10 |
+
|
11 |
+
# Determine the script and project directories
|
12 |
+
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
13 |
+
PROJECT_ROOT="$SCRIPT_DIR"
|
14 |
+
if [[ "$SCRIPT_DIR" == */scripts ]]; then
|
15 |
+
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
16 |
+
fi
|
17 |
+
|
18 |
+
# Stop all running background processes from previous runs
|
19 |
+
echo -e "${BLUE}Stopping background processes...${NC}"
|
20 |
+
pkill -f "uvicorn api:app"
|
21 |
+
pkill -f "npm run dev"
|
22 |
+
|
23 |
+
# Stop Redis container if running
|
24 |
+
echo -e "${BLUE}Stopping Redis container...${NC}"
|
25 |
+
docker stop agentpress-redis 2>/dev/null || true
|
26 |
+
docker rm agentpress-redis 2>/dev/null || true
|
27 |
+
|
28 |
+
# Stop Supabase
|
29 |
+
echo -e "${BLUE}Stopping Supabase...${NC}"
|
30 |
+
cd "$PROJECT_ROOT/backend/supabase"
|
31 |
+
supabase stop 2>/dev/null || true
|
32 |
+
cd "$SCRIPT_DIR"
|
33 |
+
|
34 |
+
echo -e "${GREEN}Cleanup complete. You can now start the services again.${NC}"
|
compose-dev.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
app:
|
3 |
+
entrypoint:
|
4 |
+
- sleep
|
5 |
+
- infinity
|
6 |
+
image: docker/dev-environments-javascript:stable-1
|
7 |
+
init: true
|
8 |
+
volumes:
|
9 |
+
- type: bind
|
10 |
+
source: /var/run/docker.sock
|
11 |
+
target: /var/run/docker.sock
|
12 |
+
|
computer_use_tool.py
ADDED
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import base64
|
4 |
+
import aiohttp
|
5 |
+
import asyncio
|
6 |
+
import logging
|
7 |
+
from typing import Optional, Dict, Any, Union
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
11 |
+
from sandbox.sandbox import SandboxToolsBase, Sandbox
|
12 |
+
|
13 |
+
KEYBOARD_KEYS = [
|
14 |
+
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
|
15 |
+
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
|
16 |
+
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
|
17 |
+
'enter', 'esc', 'backspace', 'tab', 'space', 'delete',
|
18 |
+
'ctrl', 'alt', 'shift', 'win',
|
19 |
+
'up', 'down', 'left', 'right',
|
20 |
+
'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12',
|
21 |
+
'ctrl+c', 'ctrl+v', 'ctrl+x', 'ctrl+z', 'ctrl+a', 'ctrl+s',
|
22 |
+
'alt+tab', 'alt+f4', 'ctrl+alt+delete'
|
23 |
+
]
|
24 |
+
|
25 |
+
class ComputerUseTool(SandboxToolsBase):
|
26 |
+
"""Computer automation tool for controlling the sandbox browser and GUI."""
|
27 |
+
|
28 |
+
def __init__(self, sandbox: Sandbox):
|
29 |
+
"""Initialize automation tool with sandbox connection."""
|
30 |
+
super().__init__(sandbox)
|
31 |
+
self.session = None
|
32 |
+
self.mouse_x = 0 # Track current mouse position
|
33 |
+
self.mouse_y = 0
|
34 |
+
# Get automation service URL using port 8000
|
35 |
+
self.api_base_url = self.sandbox.get_preview_link(8000)
|
36 |
+
logging.info(f"Initialized Computer Use Tool with API URL: {self.api_base_url}")
|
37 |
+
|
38 |
+
async def _get_session(self) -> aiohttp.ClientSession:
|
39 |
+
"""Get or create aiohttp session for API requests."""
|
40 |
+
if self.session is None or self.session.closed:
|
41 |
+
self.session = aiohttp.ClientSession()
|
42 |
+
return self.session
|
43 |
+
|
44 |
+
async def _api_request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict:
|
45 |
+
"""Send request to automation service API."""
|
46 |
+
try:
|
47 |
+
session = await self._get_session()
|
48 |
+
url = f"{self.api_base_url}/api{endpoint}"
|
49 |
+
|
50 |
+
logging.debug(f"API request: {method} {url} {data}")
|
51 |
+
|
52 |
+
if method.upper() == "GET":
|
53 |
+
async with session.get(url) as response:
|
54 |
+
result = await response.json()
|
55 |
+
else: # POST
|
56 |
+
async with session.post(url, json=data) as response:
|
57 |
+
result = await response.json()
|
58 |
+
|
59 |
+
logging.debug(f"API response: {result}")
|
60 |
+
return result
|
61 |
+
|
62 |
+
except Exception as e:
|
63 |
+
logging.error(f"API request failed: {str(e)}")
|
64 |
+
return {"success": False, "error": str(e)}
|
65 |
+
|
66 |
+
async def cleanup(self):
|
67 |
+
"""Clean up resources."""
|
68 |
+
if self.session and not self.session.closed:
|
69 |
+
await self.session.close()
|
70 |
+
self.session = None
|
71 |
+
|
72 |
+
@openapi_schema({
|
73 |
+
"type": "function",
|
74 |
+
"function": {
|
75 |
+
"name": "move_to",
|
76 |
+
"description": "Move cursor to specified position",
|
77 |
+
"parameters": {
|
78 |
+
"type": "object",
|
79 |
+
"properties": {
|
80 |
+
"x": {
|
81 |
+
"type": "number",
|
82 |
+
"description": "X coordinate"
|
83 |
+
},
|
84 |
+
"y": {
|
85 |
+
"type": "number",
|
86 |
+
"description": "Y coordinate"
|
87 |
+
}
|
88 |
+
},
|
89 |
+
"required": ["x", "y"]
|
90 |
+
}
|
91 |
+
}
|
92 |
+
})
|
93 |
+
@xml_schema(
|
94 |
+
tag_name="move-to",
|
95 |
+
mappings=[
|
96 |
+
{"param_name": "x", "node_type": "attribute", "path": "."},
|
97 |
+
{"param_name": "y", "node_type": "attribute", "path": "."}
|
98 |
+
],
|
99 |
+
example='''
|
100 |
+
<move-to x="100" y="200">
|
101 |
+
</move-to>
|
102 |
+
'''
|
103 |
+
)
|
104 |
+
async def move_to(self, x: float, y: float) -> ToolResult:
|
105 |
+
"""Move cursor to specified position."""
|
106 |
+
try:
|
107 |
+
x_int = int(round(float(x)))
|
108 |
+
y_int = int(round(float(y)))
|
109 |
+
|
110 |
+
result = await self._api_request("POST", "/automation/mouse/move", {
|
111 |
+
"x": x_int,
|
112 |
+
"y": y_int
|
113 |
+
})
|
114 |
+
|
115 |
+
if result.get("success", False):
|
116 |
+
self.mouse_x = x_int
|
117 |
+
self.mouse_y = y_int
|
118 |
+
return ToolResult(success=True, output=f"Moved to ({x_int}, {y_int})")
|
119 |
+
else:
|
120 |
+
return ToolResult(success=False, output=f"Failed to move: {result.get('error', 'Unknown error')}")
|
121 |
+
|
122 |
+
except Exception as e:
|
123 |
+
return ToolResult(success=False, output=f"Failed to move: {str(e)}")
|
124 |
+
|
125 |
+
@openapi_schema({
|
126 |
+
"type": "function",
|
127 |
+
"function": {
|
128 |
+
"name": "click",
|
129 |
+
"description": "Click at current or specified position",
|
130 |
+
"parameters": {
|
131 |
+
"type": "object",
|
132 |
+
"properties": {
|
133 |
+
"button": {
|
134 |
+
"type": "string",
|
135 |
+
"description": "Mouse button to click",
|
136 |
+
"enum": ["left", "right", "middle"],
|
137 |
+
"default": "left"
|
138 |
+
},
|
139 |
+
"x": {
|
140 |
+
"type": "number",
|
141 |
+
"description": "Optional X coordinate"
|
142 |
+
},
|
143 |
+
"y": {
|
144 |
+
"type": "number",
|
145 |
+
"description": "Optional Y coordinate"
|
146 |
+
},
|
147 |
+
"num_clicks": {
|
148 |
+
"type": "integer",
|
149 |
+
"description": "Number of clicks",
|
150 |
+
"enum": [1, 2, 3],
|
151 |
+
"default": 1
|
152 |
+
}
|
153 |
+
}
|
154 |
+
}
|
155 |
+
}
|
156 |
+
})
|
157 |
+
@xml_schema(
|
158 |
+
tag_name="click",
|
159 |
+
mappings=[
|
160 |
+
{"param_name": "x", "node_type": "attribute", "path": "x"},
|
161 |
+
{"param_name": "y", "node_type": "attribute", "path": "y"},
|
162 |
+
{"param_name": "button", "node_type": "attribute", "path": "button"},
|
163 |
+
{"param_name": "num_clicks", "node_type": "attribute", "path": "num_clicks"}
|
164 |
+
],
|
165 |
+
example='''
|
166 |
+
<click x="100" y="200" button="left" num_clicks="1">
|
167 |
+
</click>
|
168 |
+
'''
|
169 |
+
)
|
170 |
+
async def click(self, x: Optional[float] = None, y: Optional[float] = None,
|
171 |
+
button: str = "left", num_clicks: int = 1) -> ToolResult:
|
172 |
+
"""Click at current or specified position."""
|
173 |
+
try:
|
174 |
+
x_val = x if x is not None else self.mouse_x
|
175 |
+
y_val = y if y is not None else self.mouse_y
|
176 |
+
|
177 |
+
x_int = int(round(float(x_val)))
|
178 |
+
y_int = int(round(float(y_val)))
|
179 |
+
num_clicks = int(num_clicks)
|
180 |
+
|
181 |
+
result = await self._api_request("POST", "/automation/mouse/click", {
|
182 |
+
"x": x_int,
|
183 |
+
"y": y_int,
|
184 |
+
"clicks": num_clicks,
|
185 |
+
"button": button.lower()
|
186 |
+
})
|
187 |
+
|
188 |
+
if result.get("success", False):
|
189 |
+
self.mouse_x = x_int
|
190 |
+
self.mouse_y = y_int
|
191 |
+
return ToolResult(success=True,
|
192 |
+
output=f"{num_clicks} {button} click(s) performed at ({x_int}, {y_int})")
|
193 |
+
else:
|
194 |
+
return ToolResult(success=False, output=f"Failed to click: {result.get('error', 'Unknown error')}")
|
195 |
+
except Exception as e:
|
196 |
+
return ToolResult(success=False, output=f"Failed to click: {str(e)}")
|
197 |
+
|
198 |
+
@openapi_schema({
|
199 |
+
"type": "function",
|
200 |
+
"function": {
|
201 |
+
"name": "scroll",
|
202 |
+
"description": "Scroll the mouse wheel at current position",
|
203 |
+
"parameters": {
|
204 |
+
"type": "object",
|
205 |
+
"properties": {
|
206 |
+
"amount": {
|
207 |
+
"type": "integer",
|
208 |
+
"description": "Scroll amount (positive for up, negative for down)",
|
209 |
+
"minimum": -10,
|
210 |
+
"maximum": 10
|
211 |
+
}
|
212 |
+
},
|
213 |
+
"required": ["amount"]
|
214 |
+
}
|
215 |
+
}
|
216 |
+
})
|
217 |
+
@xml_schema(
|
218 |
+
tag_name="scroll",
|
219 |
+
mappings=[
|
220 |
+
{"param_name": "amount", "node_type": "attribute", "path": "amount"}
|
221 |
+
],
|
222 |
+
example='''
|
223 |
+
<scroll amount="-3">
|
224 |
+
</scroll>
|
225 |
+
'''
|
226 |
+
)
|
227 |
+
async def scroll(self, amount: int) -> ToolResult:
|
228 |
+
"""
|
229 |
+
Scroll the mouse wheel at current position.
|
230 |
+
Positive values scroll up, negative values scroll down.
|
231 |
+
"""
|
232 |
+
try:
|
233 |
+
amount = int(float(amount))
|
234 |
+
amount = max(-10, min(10, amount))
|
235 |
+
|
236 |
+
result = await self._api_request("POST", "/automation/mouse/scroll", {
|
237 |
+
"clicks": amount,
|
238 |
+
"x": self.mouse_x,
|
239 |
+
"y": self.mouse_y
|
240 |
+
})
|
241 |
+
|
242 |
+
if result.get("success", False):
|
243 |
+
direction = "up" if amount > 0 else "down"
|
244 |
+
steps = abs(amount)
|
245 |
+
return ToolResult(success=True,
|
246 |
+
output=f"Scrolled {direction} {steps} step(s) at position ({self.mouse_x}, {self.mouse_y})")
|
247 |
+
else:
|
248 |
+
return ToolResult(success=False, output=f"Failed to scroll: {result.get('error', 'Unknown error')}")
|
249 |
+
except Exception as e:
|
250 |
+
return ToolResult(success=False, output=f"Failed to scroll: {str(e)}")
|
251 |
+
|
252 |
+
@openapi_schema({
|
253 |
+
"type": "function",
|
254 |
+
"function": {
|
255 |
+
"name": "typing",
|
256 |
+
"description": "Type specified text",
|
257 |
+
"parameters": {
|
258 |
+
"type": "object",
|
259 |
+
"properties": {
|
260 |
+
"text": {
|
261 |
+
"type": "string",
|
262 |
+
"description": "Text to type"
|
263 |
+
}
|
264 |
+
},
|
265 |
+
"required": ["text"]
|
266 |
+
}
|
267 |
+
}
|
268 |
+
})
|
269 |
+
@xml_schema(
|
270 |
+
tag_name="typing",
|
271 |
+
mappings=[
|
272 |
+
{"param_name": "text", "node_type": "content", "path": "text"}
|
273 |
+
],
|
274 |
+
example='''
|
275 |
+
<typing>Hello World!</typing>
|
276 |
+
'''
|
277 |
+
)
|
278 |
+
async def typing(self, text: str) -> ToolResult:
|
279 |
+
"""Type specified text."""
|
280 |
+
try:
|
281 |
+
text = str(text)
|
282 |
+
|
283 |
+
result = await self._api_request("POST", "/automation/keyboard/write", {
|
284 |
+
"message": text,
|
285 |
+
"interval": 0.01
|
286 |
+
})
|
287 |
+
|
288 |
+
if result.get("success", False):
|
289 |
+
return ToolResult(success=True, output=f"Typed: {text}")
|
290 |
+
else:
|
291 |
+
return ToolResult(success=False, output=f"Failed to type: {result.get('error', 'Unknown error')}")
|
292 |
+
except Exception as e:
|
293 |
+
return ToolResult(success=False, output=f"Failed to type: {str(e)}")
|
294 |
+
|
295 |
+
@openapi_schema({
|
296 |
+
"type": "function",
|
297 |
+
"function": {
|
298 |
+
"name": "press",
|
299 |
+
"description": "Press and release a key",
|
300 |
+
"parameters": {
|
301 |
+
"type": "object",
|
302 |
+
"properties": {
|
303 |
+
"key": {
|
304 |
+
"type": "string",
|
305 |
+
"description": "Key to press",
|
306 |
+
"enum": KEYBOARD_KEYS
|
307 |
+
}
|
308 |
+
},
|
309 |
+
"required": ["key"]
|
310 |
+
}
|
311 |
+
}
|
312 |
+
})
|
313 |
+
@xml_schema(
|
314 |
+
tag_name="press",
|
315 |
+
mappings=[
|
316 |
+
{"param_name": "key", "node_type": "attribute", "path": "key"}
|
317 |
+
],
|
318 |
+
example='''
|
319 |
+
<press key="enter">
|
320 |
+
</press>
|
321 |
+
'''
|
322 |
+
)
|
323 |
+
async def press(self, key: str) -> ToolResult:
|
324 |
+
"""Press and release a key."""
|
325 |
+
try:
|
326 |
+
key = str(key).lower()
|
327 |
+
|
328 |
+
result = await self._api_request("POST", "/automation/keyboard/press", {
|
329 |
+
"keys": key,
|
330 |
+
"presses": 1
|
331 |
+
})
|
332 |
+
|
333 |
+
if result.get("success", False):
|
334 |
+
return ToolResult(success=True, output=f"Pressed key: {key}")
|
335 |
+
else:
|
336 |
+
return ToolResult(success=False, output=f"Failed to press key: {result.get('error', 'Unknown error')}")
|
337 |
+
except Exception as e:
|
338 |
+
return ToolResult(success=False, output=f"Failed to press key: {str(e)}")
|
339 |
+
|
340 |
+
@openapi_schema({
|
341 |
+
"type": "function",
|
342 |
+
"function": {
|
343 |
+
"name": "wait",
|
344 |
+
"description": "Wait for specified duration",
|
345 |
+
"parameters": {
|
346 |
+
"type": "object",
|
347 |
+
"properties": {
|
348 |
+
"duration": {
|
349 |
+
"type": "number",
|
350 |
+
"description": "Duration in seconds",
|
351 |
+
"default": 0.5
|
352 |
+
}
|
353 |
+
}
|
354 |
+
}
|
355 |
+
}
|
356 |
+
})
|
357 |
+
@xml_schema(
|
358 |
+
tag_name="wait",
|
359 |
+
mappings=[
|
360 |
+
{"param_name": "duration", "node_type": "attribute", "path": "duration"}
|
361 |
+
],
|
362 |
+
example='''
|
363 |
+
<wait duration="1.5">
|
364 |
+
</wait>
|
365 |
+
'''
|
366 |
+
)
|
367 |
+
async def wait(self, duration: float = 0.5) -> ToolResult:
|
368 |
+
"""Wait for specified duration."""
|
369 |
+
try:
|
370 |
+
duration = float(duration)
|
371 |
+
duration = max(0, min(10, duration))
|
372 |
+
await asyncio.sleep(duration)
|
373 |
+
return ToolResult(success=True, output=f"Waited {duration} seconds")
|
374 |
+
except Exception as e:
|
375 |
+
return ToolResult(success=False, output=f"Failed to wait: {str(e)}")
|
376 |
+
|
377 |
+
@openapi_schema({
|
378 |
+
"type": "function",
|
379 |
+
"function": {
|
380 |
+
"name": "mouse_down",
|
381 |
+
"description": "Press a mouse button",
|
382 |
+
"parameters": {
|
383 |
+
"type": "object",
|
384 |
+
"properties": {
|
385 |
+
"button": {
|
386 |
+
"type": "string",
|
387 |
+
"description": "Mouse button to press",
|
388 |
+
"enum": ["left", "right", "middle"],
|
389 |
+
"default": "left"
|
390 |
+
}
|
391 |
+
}
|
392 |
+
}
|
393 |
+
}
|
394 |
+
})
|
395 |
+
@xml_schema(
|
396 |
+
tag_name="mouse-down",
|
397 |
+
mappings=[
|
398 |
+
{"param_name": "button", "node_type": "attribute", "path": "button"}
|
399 |
+
],
|
400 |
+
example='''
|
401 |
+
<mouse-down button="left">
|
402 |
+
</mouse-down>
|
403 |
+
'''
|
404 |
+
)
|
405 |
+
async def mouse_down(self, button: str = "left", x: Optional[float] = None, y: Optional[float] = None) -> ToolResult:
|
406 |
+
"""Press a mouse button at current or specified position."""
|
407 |
+
try:
|
408 |
+
x_val = x if x is not None else self.mouse_x
|
409 |
+
y_val = y if y is not None else self.mouse_y
|
410 |
+
|
411 |
+
x_int = int(round(float(x_val)))
|
412 |
+
y_int = int(round(float(y_val)))
|
413 |
+
|
414 |
+
result = await self._api_request("POST", "/automation/mouse/down", {
|
415 |
+
"x": x_int,
|
416 |
+
"y": y_int,
|
417 |
+
"button": button.lower()
|
418 |
+
})
|
419 |
+
|
420 |
+
if result.get("success", False):
|
421 |
+
self.mouse_x = x_int
|
422 |
+
self.mouse_y = y_int
|
423 |
+
return ToolResult(success=True, output=f"{button} button pressed at ({x_int}, {y_int})")
|
424 |
+
else:
|
425 |
+
return ToolResult(success=False, output=f"Failed to press button: {result.get('error', 'Unknown error')}")
|
426 |
+
except Exception as e:
|
427 |
+
return ToolResult(success=False, output=f"Failed to press button: {str(e)}")
|
428 |
+
|
429 |
+
@openapi_schema({
|
430 |
+
"type": "function",
|
431 |
+
"function": {
|
432 |
+
"name": "mouse_up",
|
433 |
+
"description": "Release a mouse button",
|
434 |
+
"parameters": {
|
435 |
+
"type": "object",
|
436 |
+
"properties": {
|
437 |
+
"button": {
|
438 |
+
"type": "string",
|
439 |
+
"description": "Mouse button to release",
|
440 |
+
"enum": ["left", "right", "middle"],
|
441 |
+
"default": "left"
|
442 |
+
}
|
443 |
+
}
|
444 |
+
}
|
445 |
+
}
|
446 |
+
})
|
447 |
+
@xml_schema(
|
448 |
+
tag_name="mouse-up",
|
449 |
+
mappings=[
|
450 |
+
{"param_name": "button", "node_type": "attribute", "path": "button"}
|
451 |
+
],
|
452 |
+
example='''
|
453 |
+
<mouse-up button="left">
|
454 |
+
</mouse-up>
|
455 |
+
'''
|
456 |
+
)
|
457 |
+
async def mouse_up(self, button: str = "left", x: Optional[float] = None, y: Optional[float] = None) -> ToolResult:
|
458 |
+
"""Release a mouse button at current or specified position."""
|
459 |
+
try:
|
460 |
+
x_val = x if x is not None else self.mouse_x
|
461 |
+
y_val = y if y is not None else self.mouse_y
|
462 |
+
|
463 |
+
x_int = int(round(float(x_val)))
|
464 |
+
y_int = int(round(float(y_val)))
|
465 |
+
|
466 |
+
result = await self._api_request("POST", "/automation/mouse/up", {
|
467 |
+
"x": x_int,
|
468 |
+
"y": y_int,
|
469 |
+
"button": button.lower()
|
470 |
+
})
|
471 |
+
|
472 |
+
if result.get("success", False):
|
473 |
+
self.mouse_x = x_int
|
474 |
+
self.mouse_y = y_int
|
475 |
+
return ToolResult(success=True, output=f"{button} button released at ({x_int}, {y_int})")
|
476 |
+
else:
|
477 |
+
return ToolResult(success=False, output=f"Failed to release button: {result.get('error', 'Unknown error')}")
|
478 |
+
except Exception as e:
|
479 |
+
return ToolResult(success=False, output=f"Failed to release button: {str(e)}")
|
480 |
+
|
481 |
+
@openapi_schema({
|
482 |
+
"type": "function",
|
483 |
+
"function": {
|
484 |
+
"name": "drag_to",
|
485 |
+
"description": "Drag cursor to specified position",
|
486 |
+
"parameters": {
|
487 |
+
"type": "object",
|
488 |
+
"properties": {
|
489 |
+
"x": {
|
490 |
+
"type": "number",
|
491 |
+
"description": "Target X coordinate"
|
492 |
+
},
|
493 |
+
"y": {
|
494 |
+
"type": "number",
|
495 |
+
"description": "Target Y coordinate"
|
496 |
+
}
|
497 |
+
},
|
498 |
+
"required": ["x", "y"]
|
499 |
+
}
|
500 |
+
}
|
501 |
+
})
|
502 |
+
@xml_schema(
|
503 |
+
tag_name="drag-to",
|
504 |
+
mappings=[
|
505 |
+
{"param_name": "x", "node_type": "attribute", "path": "x"},
|
506 |
+
{"param_name": "y", "node_type": "attribute", "path": "y"}
|
507 |
+
],
|
508 |
+
example='''
|
509 |
+
<drag-to x="500" y="50">
|
510 |
+
</drag-to>
|
511 |
+
'''
|
512 |
+
)
|
513 |
+
async def drag_to(self, x: float, y: float) -> ToolResult:
|
514 |
+
"""Click and drag from current position to target position."""
|
515 |
+
try:
|
516 |
+
target_x = int(round(float(x)))
|
517 |
+
target_y = int(round(float(y)))
|
518 |
+
start_x = self.mouse_x
|
519 |
+
start_y = self.mouse_y
|
520 |
+
|
521 |
+
result = await self._api_request("POST", "/automation/mouse/drag", {
|
522 |
+
"x": target_x,
|
523 |
+
"y": target_y,
|
524 |
+
"duration": 0.3,
|
525 |
+
"button": "left"
|
526 |
+
})
|
527 |
+
|
528 |
+
if result.get("success", False):
|
529 |
+
self.mouse_x = target_x
|
530 |
+
self.mouse_y = target_y
|
531 |
+
return ToolResult(success=True,
|
532 |
+
output=f"Dragged from ({start_x}, {start_y}) to ({target_x}, {target_y})")
|
533 |
+
else:
|
534 |
+
return ToolResult(success=False, output=f"Failed to drag: {result.get('error', 'Unknown error')}")
|
535 |
+
except Exception as e:
|
536 |
+
return ToolResult(success=False, output=f"Failed to drag: {str(e)}")
|
537 |
+
|
538 |
+
async def get_screenshot_base64(self) -> Optional[dict]:
|
539 |
+
"""Capture screen and return as base64 encoded image."""
|
540 |
+
try:
|
541 |
+
result = await self._api_request("POST", "/automation/screenshot")
|
542 |
+
|
543 |
+
if "image" in result:
|
544 |
+
base64_str = result["image"]
|
545 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
546 |
+
|
547 |
+
# Save screenshot to file
|
548 |
+
screenshots_dir = "screenshots"
|
549 |
+
if not os.path.exists(screenshots_dir):
|
550 |
+
os.makedirs(screenshots_dir)
|
551 |
+
|
552 |
+
timestamped_filename = os.path.join(screenshots_dir, f"screenshot_{timestamp}.png")
|
553 |
+
latest_filename = "latest_screenshot.png"
|
554 |
+
|
555 |
+
# Decode base64 string and save to file
|
556 |
+
img_data = base64.b64decode(base64_str)
|
557 |
+
with open(timestamped_filename, 'wb') as f:
|
558 |
+
f.write(img_data)
|
559 |
+
|
560 |
+
# Save a copy as the latest screenshot
|
561 |
+
with open(latest_filename, 'wb') as f:
|
562 |
+
f.write(img_data)
|
563 |
+
|
564 |
+
return {
|
565 |
+
"content_type": "image/png",
|
566 |
+
"base64": base64_str,
|
567 |
+
"timestamp": timestamp,
|
568 |
+
"filename": timestamped_filename
|
569 |
+
}
|
570 |
+
else:
|
571 |
+
return None
|
572 |
+
|
573 |
+
except Exception as e:
|
574 |
+
print(f"[Screenshot] Error during screenshot process: {str(e)}")
|
575 |
+
return None
|
576 |
+
|
577 |
+
@openapi_schema({
|
578 |
+
"type": "function",
|
579 |
+
"function": {
|
580 |
+
"name": "hotkey",
|
581 |
+
"description": "Press a key combination",
|
582 |
+
"parameters": {
|
583 |
+
"type": "object",
|
584 |
+
"properties": {
|
585 |
+
"keys": {
|
586 |
+
"type": "string",
|
587 |
+
"description": "Key combination to press",
|
588 |
+
"enum": KEYBOARD_KEYS
|
589 |
+
}
|
590 |
+
},
|
591 |
+
"required": ["keys"]
|
592 |
+
}
|
593 |
+
}
|
594 |
+
})
|
595 |
+
@xml_schema(
|
596 |
+
tag_name="hotkey",
|
597 |
+
mappings=[
|
598 |
+
{"param_name": "keys", "node_type": "attribute", "path": "keys"}
|
599 |
+
],
|
600 |
+
example='''
|
601 |
+
<hotkey keys="ctrl+a">
|
602 |
+
</hotkey>
|
603 |
+
'''
|
604 |
+
)
|
605 |
+
async def hotkey(self, keys: str) -> ToolResult:
|
606 |
+
"""Press a key combination."""
|
607 |
+
try:
|
608 |
+
keys = str(keys).lower().strip()
|
609 |
+
key_sequence = keys.split('+')
|
610 |
+
|
611 |
+
result = await self._api_request("POST", "/automation/keyboard/hotkey", {
|
612 |
+
"keys": key_sequence,
|
613 |
+
"interval": 0.01
|
614 |
+
})
|
615 |
+
|
616 |
+
if result.get("success", False):
|
617 |
+
return ToolResult(success=True, output=f"Pressed key combination: {keys}")
|
618 |
+
else:
|
619 |
+
return ToolResult(success=False, output=f"Failed to press keys: {result.get('error', 'Unknown error')}")
|
620 |
+
except Exception as e:
|
621 |
+
return ToolResult(success=False, output=f"Failed to press keys: {str(e)}")
|
622 |
+
|
623 |
+
if __name__ == "__main__":
|
624 |
+
print("This module should be imported, not run directly.")
|
config.cpython-311.pyc
ADDED
Binary file (3.84 kB). View file
|
|
config.json
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Qwen2_5OmniModel"
|
4 |
+
],
|
5 |
+
"enable_audio_output": true,
|
6 |
+
"enable_talker": true,
|
7 |
+
"model_type": "qwen2_5_omni",
|
8 |
+
"talker_config": {
|
9 |
+
"_attn_implementation_autoset": true,
|
10 |
+
"_name_or_path": "Qwen2.5-Omni-3B/talker",
|
11 |
+
"architectures": [
|
12 |
+
"Qwen2OmniTalkerForConditionalGeneration"
|
13 |
+
],
|
14 |
+
"attention_dropout": 0.0,
|
15 |
+
"audio_end_token_id": 151648,
|
16 |
+
"audio_start_token_id": 151647,
|
17 |
+
"audio_token_index": 151646,
|
18 |
+
"embedding_size": 2048,
|
19 |
+
"head_dim": 64,
|
20 |
+
"hidden_act": "silu",
|
21 |
+
"hidden_size": 896,
|
22 |
+
"image_token_index": 151655,
|
23 |
+
"init_std": 0.02,
|
24 |
+
"initializer_range": 0.02,
|
25 |
+
"intermediate_size": 4864,
|
26 |
+
"max_position_embeddings": 32768,
|
27 |
+
"max_window_layers": 28,
|
28 |
+
"model_type": "qwen2_5_omni_talker",
|
29 |
+
"num_attention_heads": 14,
|
30 |
+
"num_hidden_layers": 24,
|
31 |
+
"num_key_value_heads": 2,
|
32 |
+
"position_id_per_seconds": 25,
|
33 |
+
"rms_norm_eps": 1e-06,
|
34 |
+
"rope_scaling": {
|
35 |
+
"mrope_section": [
|
36 |
+
16,
|
37 |
+
16,
|
38 |
+
0
|
39 |
+
],
|
40 |
+
"rope_type": "default",
|
41 |
+
"type": "default"
|
42 |
+
},
|
43 |
+
"rope_theta": 1000000.0,
|
44 |
+
"seconds_per_chunk": 2,
|
45 |
+
"sliding_window": 32768,
|
46 |
+
"spatial_merge_size": 2,
|
47 |
+
"torch_dtype": "bfloat16",
|
48 |
+
"tts_codec_end_token_id": 8294,
|
49 |
+
"tts_codec_mask_token_id": 8296,
|
50 |
+
"tts_codec_pad_token_id": 8292,
|
51 |
+
"tts_codec_start_token_id": 8293,
|
52 |
+
"tts_text_end_token_id": 151861,
|
53 |
+
"tts_text_pad_token_id": 151859,
|
54 |
+
"tts_text_start_token_id": 151860,
|
55 |
+
"use_cache": true,
|
56 |
+
"use_sliding_window": false,
|
57 |
+
"video_token_index": 151656,
|
58 |
+
"vision_end_token_id": 151653,
|
59 |
+
"vision_start_token_id": 151652,
|
60 |
+
"vocab_size": 8448
|
61 |
+
},
|
62 |
+
"thinker_config": {
|
63 |
+
"_attn_implementation_autoset": true,
|
64 |
+
"_name_or_path": "Qwen2.5-Omni-3B/thinker",
|
65 |
+
"architectures": [
|
66 |
+
"Qwen2OmniNaViTThinkerForConditionalGeneration"
|
67 |
+
],
|
68 |
+
"audio_config": {
|
69 |
+
"_attn_implementation_autoset": true,
|
70 |
+
"_name_or_path": "",
|
71 |
+
"activation_dropout": 0.0,
|
72 |
+
"activation_function": "gelu",
|
73 |
+
"add_cross_attention": false,
|
74 |
+
"architectures": null,
|
75 |
+
"attention_dropout": 0.0,
|
76 |
+
"bad_words_ids": null,
|
77 |
+
"begin_suppress_tokens": null,
|
78 |
+
"bos_token_id": null,
|
79 |
+
"chunk_size_feed_forward": 0,
|
80 |
+
"cross_attention_hidden_size": null,
|
81 |
+
"d_model": 1280,
|
82 |
+
"decoder_start_token_id": null,
|
83 |
+
"diversity_penalty": 0.0,
|
84 |
+
"do_sample": false,
|
85 |
+
"dropout": 0.0,
|
86 |
+
"early_stopping": false,
|
87 |
+
"encoder_attention_heads": 20,
|
88 |
+
"encoder_ffn_dim": 5120,
|
89 |
+
"encoder_layerdrop": 0.0,
|
90 |
+
"encoder_layers": 32,
|
91 |
+
"encoder_no_repeat_ngram_size": 0,
|
92 |
+
"eos_token_id": null,
|
93 |
+
"exponential_decay_length_penalty": null,
|
94 |
+
"finetuning_task": null,
|
95 |
+
"forced_bos_token_id": null,
|
96 |
+
"forced_eos_token_id": null,
|
97 |
+
"id2label": {
|
98 |
+
"0": "LABEL_0",
|
99 |
+
"1": "LABEL_1"
|
100 |
+
},
|
101 |
+
"init_std": 0.02,
|
102 |
+
"is_decoder": false,
|
103 |
+
"is_encoder_decoder": false,
|
104 |
+
"label2id": {
|
105 |
+
"LABEL_0": 0,
|
106 |
+
"LABEL_1": 1
|
107 |
+
},
|
108 |
+
"length_penalty": 1.0,
|
109 |
+
"max_length": 20,
|
110 |
+
"max_source_positions": 1500,
|
111 |
+
"min_length": 0,
|
112 |
+
"model_type": "qwen2_5_omni_audio_encoder",
|
113 |
+
"n_window": 100,
|
114 |
+
"no_repeat_ngram_size": 0,
|
115 |
+
"num_beam_groups": 1,
|
116 |
+
"num_beams": 1,
|
117 |
+
"num_hidden_layers": 32,
|
118 |
+
"num_mel_bins": 128,
|
119 |
+
"num_return_sequences": 1,
|
120 |
+
"output_attentions": false,
|
121 |
+
"output_dim": 2048,
|
122 |
+
"output_hidden_states": false,
|
123 |
+
"output_scores": false,
|
124 |
+
"pad_token_id": null,
|
125 |
+
"prefix": null,
|
126 |
+
"problem_type": null,
|
127 |
+
"pruned_heads": {},
|
128 |
+
"remove_invalid_values": false,
|
129 |
+
"repetition_penalty": 1.0,
|
130 |
+
"return_dict": true,
|
131 |
+
"return_dict_in_generate": false,
|
132 |
+
"scale_embedding": false,
|
133 |
+
"sep_token_id": null,
|
134 |
+
"suppress_tokens": null,
|
135 |
+
"task_specific_params": null,
|
136 |
+
"temperature": 1.0,
|
137 |
+
"tf_legacy_loss": false,
|
138 |
+
"tie_encoder_decoder": false,
|
139 |
+
"tie_word_embeddings": true,
|
140 |
+
"tokenizer_class": null,
|
141 |
+
"top_k": 50,
|
142 |
+
"top_p": 1.0,
|
143 |
+
"torch_dtype": null,
|
144 |
+
"torchscript": false,
|
145 |
+
"typical_p": 1.0,
|
146 |
+
"use_bfloat16": false
|
147 |
+
},
|
148 |
+
"text_config": {
|
149 |
+
"model_type": "qwen2_5_omni_text",
|
150 |
+
"hidden_act": "silu",
|
151 |
+
"hidden_size": 2048,
|
152 |
+
"init_std": 0.02,
|
153 |
+
"intermediate_size": 11008,
|
154 |
+
"vocab_size": 151936,
|
155 |
+
"num_attention_heads": 16,
|
156 |
+
"num_hidden_layers": 36,
|
157 |
+
"num_key_value_heads": 2,
|
158 |
+
"max_position_embeddings": 32768,
|
159 |
+
"max_window_layers": 70,
|
160 |
+
"rms_norm_eps": 1e-06,
|
161 |
+
"rope_scaling": {
|
162 |
+
"mrope_section": [
|
163 |
+
16,
|
164 |
+
24,
|
165 |
+
24
|
166 |
+
],
|
167 |
+
"rope_type": "default",
|
168 |
+
"type": "default"
|
169 |
+
},
|
170 |
+
"use_cache": true,
|
171 |
+
"rope_theta": 1000000.0,
|
172 |
+
"use_sliding_window": false,
|
173 |
+
"sliding_window": 32768,
|
174 |
+
"attention_dropout": 0.0,
|
175 |
+
"tie_word_embeddings": false
|
176 |
+
},
|
177 |
+
"audio_end_token_id": 151648,
|
178 |
+
"audio_start_token_id": 151647,
|
179 |
+
"audio_token_index": 151646,
|
180 |
+
"bos_token_id": 151644,
|
181 |
+
"eos_token_id": 151645,
|
182 |
+
"ignore_index": -100,
|
183 |
+
"image_token_index": 151655,
|
184 |
+
"init_std": 0.02,
|
185 |
+
"model_type": "qwen2_5_omni_thinker",
|
186 |
+
"pad_token_id": 151643,
|
187 |
+
"position_id_per_seconds": 25,
|
188 |
+
"seconds_per_chunk": 2,
|
189 |
+
"torch_dtype": "bfloat16",
|
190 |
+
"user_token_id": 872,
|
191 |
+
"video_token_index": 151656,
|
192 |
+
"vision_config": {
|
193 |
+
"_attn_implementation_autoset": true,
|
194 |
+
"_name_or_path": "",
|
195 |
+
"add_cross_attention": false,
|
196 |
+
"architectures": null,
|
197 |
+
"bad_words_ids": null,
|
198 |
+
"begin_suppress_tokens": null,
|
199 |
+
"bos_token_id": null,
|
200 |
+
"chunk_size_feed_forward": 0,
|
201 |
+
"cross_attention_hidden_size": null,
|
202 |
+
"decoder_start_token_id": null,
|
203 |
+
"depth": 32,
|
204 |
+
"diversity_penalty": 0.0,
|
205 |
+
"do_sample": false,
|
206 |
+
"early_stopping": false,
|
207 |
+
"embed_dim": 1280,
|
208 |
+
"encoder_no_repeat_ngram_size": 0,
|
209 |
+
"eos_token_id": null,
|
210 |
+
"exponential_decay_length_penalty": null,
|
211 |
+
"finetuning_task": null,
|
212 |
+
"forced_bos_token_id": null,
|
213 |
+
"forced_eos_token_id": null,
|
214 |
+
"fullatt_block_indexes": [
|
215 |
+
7,
|
216 |
+
15,
|
217 |
+
23,
|
218 |
+
31
|
219 |
+
],
|
220 |
+
"hidden_act": "silu",
|
221 |
+
"hidden_size": 1280,
|
222 |
+
"id2label": {
|
223 |
+
"0": "LABEL_0",
|
224 |
+
"1": "LABEL_1"
|
225 |
+
},
|
226 |
+
"in_channels": 3,
|
227 |
+
"in_chans": 3,
|
228 |
+
"init_std": 0.02,
|
229 |
+
"intermediate_size": 3420,
|
230 |
+
"is_decoder": false,
|
231 |
+
"is_encoder_decoder": false,
|
232 |
+
"label2id": {
|
233 |
+
"LABEL_0": 0,
|
234 |
+
"LABEL_1": 1
|
235 |
+
},
|
236 |
+
"length_penalty": 1.0,
|
237 |
+
"max_length": 20,
|
238 |
+
"min_length": 0,
|
239 |
+
"model_type": "qwen2_5_omni_vision_encoder",
|
240 |
+
"no_repeat_ngram_size": 0,
|
241 |
+
"num_beam_groups": 1,
|
242 |
+
"num_beams": 1,
|
243 |
+
"num_heads": 16,
|
244 |
+
"num_return_sequences": 1,
|
245 |
+
"out_hidden_size": 2048,
|
246 |
+
"output_attentions": false,
|
247 |
+
"output_hidden_states": false,
|
248 |
+
"output_scores": false,
|
249 |
+
"pad_token_id": null,
|
250 |
+
"patch_size": 14,
|
251 |
+
"prefix": null,
|
252 |
+
"problem_type": null,
|
253 |
+
"pruned_heads": {},
|
254 |
+
"remove_invalid_values": false,
|
255 |
+
"repetition_penalty": 1.0,
|
256 |
+
"return_dict": true,
|
257 |
+
"return_dict_in_generate": false,
|
258 |
+
"sep_token_id": null,
|
259 |
+
"spatial_merge_size": 2,
|
260 |
+
"spatial_patch_size": 14,
|
261 |
+
"suppress_tokens": null,
|
262 |
+
"task_specific_params": null,
|
263 |
+
"temperature": 1.0,
|
264 |
+
"temporal_patch_size": 2,
|
265 |
+
"tf_legacy_loss": false,
|
266 |
+
"tie_encoder_decoder": false,
|
267 |
+
"tie_word_embeddings": true,
|
268 |
+
"tokenizer_class": null,
|
269 |
+
"tokens_per_second": 25,
|
270 |
+
"top_k": 50,
|
271 |
+
"top_p": 1.0,
|
272 |
+
"torch_dtype": null,
|
273 |
+
"torchscript": false,
|
274 |
+
"typical_p": 1.0,
|
275 |
+
"use_bfloat16": false,
|
276 |
+
"window_size": 112
|
277 |
+
},
|
278 |
+
"vision_end_token_id": 151653,
|
279 |
+
"vision_start_token_id": 151652,
|
280 |
+
"vision_token_id": 151654
|
281 |
+
},
|
282 |
+
"token2wav_config": {
|
283 |
+
"_attn_implementation_autoset": true,
|
284 |
+
"bigvgan_config": {
|
285 |
+
"_attn_implementation_autoset": true,
|
286 |
+
"_name_or_path": "",
|
287 |
+
"add_cross_attention": false,
|
288 |
+
"architectures": null,
|
289 |
+
"bad_words_ids": null,
|
290 |
+
"begin_suppress_tokens": null,
|
291 |
+
"bos_token_id": null,
|
292 |
+
"chunk_size_feed_forward": 0,
|
293 |
+
"cross_attention_hidden_size": null,
|
294 |
+
"decoder_start_token_id": null,
|
295 |
+
"diversity_penalty": 0.0,
|
296 |
+
"do_sample": false,
|
297 |
+
"early_stopping": false,
|
298 |
+
"encoder_no_repeat_ngram_size": 0,
|
299 |
+
"eos_token_id": null,
|
300 |
+
"exponential_decay_length_penalty": null,
|
301 |
+
"finetuning_task": null,
|
302 |
+
"forced_bos_token_id": null,
|
303 |
+
"forced_eos_token_id": null,
|
304 |
+
"id2label": {
|
305 |
+
"0": "LABEL_0",
|
306 |
+
"1": "LABEL_1"
|
307 |
+
},
|
308 |
+
"is_decoder": false,
|
309 |
+
"is_encoder_decoder": false,
|
310 |
+
"label2id": {
|
311 |
+
"LABEL_0": 0,
|
312 |
+
"LABEL_1": 1
|
313 |
+
},
|
314 |
+
"length_penalty": 1.0,
|
315 |
+
"max_length": 20,
|
316 |
+
"mel_dim": 80,
|
317 |
+
"min_length": 0,
|
318 |
+
"model_type": "qwen2_5_omni_bigvgan",
|
319 |
+
"no_repeat_ngram_size": 0,
|
320 |
+
"num_beam_groups": 1,
|
321 |
+
"num_beams": 1,
|
322 |
+
"num_return_sequences": 1,
|
323 |
+
"output_attentions": false,
|
324 |
+
"output_hidden_states": false,
|
325 |
+
"output_scores": false,
|
326 |
+
"pad_token_id": null,
|
327 |
+
"prefix": null,
|
328 |
+
"problem_type": null,
|
329 |
+
"pruned_heads": {},
|
330 |
+
"remove_invalid_values": false,
|
331 |
+
"repetition_penalty": 1.0,
|
332 |
+
"resblock_dilation_sizes": [
|
333 |
+
[
|
334 |
+
1,
|
335 |
+
3,
|
336 |
+
5
|
337 |
+
],
|
338 |
+
[
|
339 |
+
1,
|
340 |
+
3,
|
341 |
+
5
|
342 |
+
],
|
343 |
+
[
|
344 |
+
1,
|
345 |
+
3,
|
346 |
+
5
|
347 |
+
]
|
348 |
+
],
|
349 |
+
"resblock_kernel_sizes": [
|
350 |
+
3,
|
351 |
+
7,
|
352 |
+
11
|
353 |
+
],
|
354 |
+
"return_dict": true,
|
355 |
+
"return_dict_in_generate": false,
|
356 |
+
"sep_token_id": null,
|
357 |
+
"suppress_tokens": null,
|
358 |
+
"task_specific_params": null,
|
359 |
+
"temperature": 1.0,
|
360 |
+
"tf_legacy_loss": false,
|
361 |
+
"tie_encoder_decoder": false,
|
362 |
+
"tie_word_embeddings": true,
|
363 |
+
"tokenizer_class": null,
|
364 |
+
"top_k": 50,
|
365 |
+
"top_p": 1.0,
|
366 |
+
"torch_dtype": null,
|
367 |
+
"torchscript": false,
|
368 |
+
"typical_p": 1.0,
|
369 |
+
"upsample_initial_channel": 1536,
|
370 |
+
"upsample_kernel_sizes": [
|
371 |
+
11,
|
372 |
+
7,
|
373 |
+
4,
|
374 |
+
4,
|
375 |
+
4,
|
376 |
+
4
|
377 |
+
],
|
378 |
+
"upsample_rates": [
|
379 |
+
5,
|
380 |
+
3,
|
381 |
+
2,
|
382 |
+
2,
|
383 |
+
2,
|
384 |
+
2
|
385 |
+
],
|
386 |
+
"use_bfloat16": false,
|
387 |
+
"use_bias_at_final": false
|
388 |
+
},
|
389 |
+
"dit_config": {
|
390 |
+
"_attn_implementation_autoset": true,
|
391 |
+
"_name_or_path": "",
|
392 |
+
"add_cross_attention": false,
|
393 |
+
"architectures": null,
|
394 |
+
"bad_words_ids": null,
|
395 |
+
"begin_suppress_tokens": null,
|
396 |
+
"bos_token_id": null,
|
397 |
+
"chunk_size_feed_forward": 0,
|
398 |
+
"cross_attention_hidden_size": null,
|
399 |
+
"decoder_start_token_id": null,
|
400 |
+
"depth": 22,
|
401 |
+
"dim": 1024,
|
402 |
+
"diversity_penalty": 0.0,
|
403 |
+
"do_sample": false,
|
404 |
+
"dropout": 0.1,
|
405 |
+
"early_stopping": false,
|
406 |
+
"emb_dim": 512,
|
407 |
+
"enc_attention_channels": 64,
|
408 |
+
"enc_channels": [
|
409 |
+
256,
|
410 |
+
256,
|
411 |
+
256,
|
412 |
+
256,
|
413 |
+
768
|
414 |
+
],
|
415 |
+
"enc_dilations": [
|
416 |
+
1,
|
417 |
+
2,
|
418 |
+
3,
|
419 |
+
4,
|
420 |
+
1
|
421 |
+
],
|
422 |
+
"enc_dim": 128,
|
423 |
+
"enc_emb_dim": 192,
|
424 |
+
"enc_global_context": true,
|
425 |
+
"enc_kernel_sizes": [
|
426 |
+
5,
|
427 |
+
3,
|
428 |
+
3,
|
429 |
+
3,
|
430 |
+
1
|
431 |
+
],
|
432 |
+
"enc_lin_neurons": 192,
|
433 |
+
"enc_res2net_scale": 2,
|
434 |
+
"enc_se_channels": 64,
|
435 |
+
"encoder_no_repeat_ngram_size": 0,
|
436 |
+
"eos_token_id": null,
|
437 |
+
"exponential_decay_length_penalty": null,
|
438 |
+
"ff_mult": 2,
|
439 |
+
"finetuning_task": null,
|
440 |
+
"forced_bos_token_id": null,
|
441 |
+
"forced_eos_token_id": null,
|
442 |
+
"head_dim": 64,
|
443 |
+
"heads": 16,
|
444 |
+
"id2label": {
|
445 |
+
"0": "LABEL_0",
|
446 |
+
"1": "LABEL_1"
|
447 |
+
},
|
448 |
+
"is_decoder": false,
|
449 |
+
"is_encoder_decoder": false,
|
450 |
+
"label2id": {
|
451 |
+
"LABEL_0": 0,
|
452 |
+
"LABEL_1": 1
|
453 |
+
},
|
454 |
+
"length_penalty": 1.0,
|
455 |
+
"max_length": 20,
|
456 |
+
"mel_dim": 80,
|
457 |
+
"min_length": 0,
|
458 |
+
"model_type": "qwen2_5_omni_dit",
|
459 |
+
"no_repeat_ngram_size": 0,
|
460 |
+
"num_beam_groups": 1,
|
461 |
+
"num_beams": 1,
|
462 |
+
"num_embeds": 8193,
|
463 |
+
"num_return_sequences": 1,
|
464 |
+
"output_attentions": false,
|
465 |
+
"output_hidden_states": false,
|
466 |
+
"output_scores": false,
|
467 |
+
"pad_token_id": null,
|
468 |
+
"prefix": null,
|
469 |
+
"problem_type": null,
|
470 |
+
"pruned_heads": {},
|
471 |
+
"remove_invalid_values": false,
|
472 |
+
"repeats": 2,
|
473 |
+
"repetition_penalty": 1.0,
|
474 |
+
"return_dict": true,
|
475 |
+
"return_dict_in_generate": false,
|
476 |
+
"sep_token_id": null,
|
477 |
+
"suppress_tokens": null,
|
478 |
+
"task_specific_params": null,
|
479 |
+
"temperature": 1.0,
|
480 |
+
"tf_legacy_loss": false,
|
481 |
+
"tie_encoder_decoder": false,
|
482 |
+
"tie_word_embeddings": true,
|
483 |
+
"tokenizer_class": null,
|
484 |
+
"top_k": 50,
|
485 |
+
"top_p": 1.0,
|
486 |
+
"torch_dtype": "float32",
|
487 |
+
"torchscript": false,
|
488 |
+
"typical_p": 1.0,
|
489 |
+
"use_bfloat16": false
|
490 |
+
},
|
491 |
+
"model_type": "qwen2_5_omni_token2wav"
|
492 |
+
},
|
493 |
+
"torch_dtype": "bfloat16",
|
494 |
+
"transformers_version": "4.51.0.dev0"
|
495 |
+
}
|