lattmamb commited on
Commit
aa916fd
·
verified ·
1 Parent(s): 9a8d04a

Upload 229 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. 20240414161707_basejump-setup.sql +186 -0
  3. 20240414161947_basejump-accounts.sql +708 -0
  4. 20240414162100_basejump-invitations.sql +270 -0
  5. 20240414162131_basejump-billing.sql +236 -0
  6. 20250409211903_basejump-configure.sql +3 -0
  7. 20250409212058_initial.sql +189 -0
  8. 20250416133920_agentpress_schema.sql +382 -0
  9. 20250506000000_initial_setup.sql +85 -0
  10. 20250506000001_account_functions.sql +50 -0
  11. 20250506000002_project_functions.sql +105 -0
  12. 22dc0511fe69_add_toolsource_table.cpython-311.pyc +0 -0
  13. 2ea570019b8f_add_apikey_table.cpython-311.pyc +0 -0
  14. 2ea570019b8f_add_apikey_table.py +58 -0
  15. 4af13678b83c_add_toolsource_table.cpython-311.pyc +0 -0
  16. 4af13678b83c_add_toolsource_table.py +50 -0
  17. ActiveJobsProvider.py +57 -0
  18. AmazonProvider.py +191 -0
  19. ChatInterface.tsx +30 -0
  20. Dockerfile +19 -0
  21. Layout.tsx +41 -0
  22. LinkedinProvider.py +250 -0
  23. MANIFEST.in +17 -0
  24. README +1 -0
  25. README.md +36 -10
  26. RapidDataProviderBase.py +61 -0
  27. SettingsPanel.tsx +31 -0
  28. TwitterProvider.py +240 -0
  29. WorkflowEditor.tsx +52 -0
  30. YahooFinanceProvider.py +190 -0
  31. ZillowProvider.py +187 -0
  32. __init__.py +1 -0
  33. added_tokens.json +24 -0
  34. agent.py +41 -0
  35. alembic.ini +62 -0
  36. api.cpython-311.pyc +0 -0
  37. api.py +311 -0
  38. api.py.bak +156 -0
  39. api_keys.py +68 -0
  40. architecture_diagram.svg +0 -0
  41. auth_utils.py +177 -0
  42. base.py +33 -0
  43. billing.py +125 -0
  44. browser_api.py +2063 -0
  45. chat_template.json +3 -0
  46. cleanup.sh +34 -0
  47. compose-dev.yaml +12 -0
  48. computer_use_tool.py +624 -0
  49. config.cpython-311.pyc +0 -0
  50. config.json +495 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ diagram.png filter=lfs diff=lfs merge=lfs -text
20240414161707_basejump-setup.sql ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ ____ _
3
+ | _ \ (_)
4
+ | |_) | __ _ ___ ___ _ _ _ _ __ ___ _ __
5
+ | _ < / _` / __|/ _ \ | | | | '_ ` _ \| '_ \
6
+ | |_) | (_| \__ \ __/ | |_| | | | | | | |_) |
7
+ |____/ \__,_|___/\___| |\__,_|_| |_| |_| .__/
8
+ _/ | | |
9
+ |__/ |_|
10
+
11
+ Basejump is a starter kit for building SaaS products on top of Supabase.
12
+ Learn more at https://usebasejump.com
13
+ */
14
+
15
+
16
+ /**
17
+ * -------------------------------------------------------
18
+ * Section - Basejump schema setup and utility functions
19
+ * -------------------------------------------------------
20
+ */
21
+
22
+ -- revoke execution by default from public
23
+ ALTER DEFAULT PRIVILEGES REVOKE EXECUTE ON FUNCTIONS FROM PUBLIC;
24
+ ALTER DEFAULT PRIVILEGES IN SCHEMA PUBLIC REVOKE EXECUTE ON FUNCTIONS FROM anon, authenticated;
25
+
26
+ -- Create basejump schema
27
+ CREATE SCHEMA IF NOT EXISTS basejump;
28
+ GRANT USAGE ON SCHEMA basejump to authenticated;
29
+ GRANT USAGE ON SCHEMA basejump to service_role;
30
+
31
+ /**
32
+ * -------------------------------------------------------
33
+ * Section - Enums
34
+ * -------------------------------------------------------
35
+ */
36
+
37
+ /**
38
+ * Invitation types are either email or link. Email invitations are sent to
39
+ * a single user and can only be claimed once. Link invitations can be used multiple times
40
+ * Both expire after 24 hours
41
+ */
42
+ DO
43
+ $$
44
+ BEGIN
45
+ -- check it account_role already exists on basejump schema
46
+ IF NOT EXISTS(SELECT 1
47
+ FROM pg_type t
48
+ JOIN pg_namespace n ON n.oid = t.typnamespace
49
+ WHERE t.typname = 'invitation_type'
50
+ AND n.nspname = 'basejump') THEN
51
+ CREATE TYPE basejump.invitation_type AS ENUM ('one_time', '24_hour');
52
+ end if;
53
+ end;
54
+ $$;
55
+
56
+ /**
57
+ * -------------------------------------------------------
58
+ * Section - Basejump settings
59
+ * -------------------------------------------------------
60
+ */
61
+
62
+ CREATE TABLE IF NOT EXISTS basejump.config
63
+ (
64
+ enable_team_accounts boolean default true,
65
+ enable_personal_account_billing boolean default true,
66
+ enable_team_account_billing boolean default true,
67
+ billing_provider text default 'stripe'
68
+ );
69
+
70
+ -- create config row
71
+ INSERT INTO basejump.config (enable_team_accounts, enable_personal_account_billing, enable_team_account_billing)
72
+ VALUES (true, true, true);
73
+
74
+ -- enable select on the config table
75
+ GRANT SELECT ON basejump.config TO authenticated, service_role;
76
+
77
+ -- enable RLS on config
78
+ ALTER TABLE basejump.config
79
+ ENABLE ROW LEVEL SECURITY;
80
+
81
+ create policy "Basejump settings can be read by authenticated users" on basejump.config
82
+ for select
83
+ to authenticated
84
+ using (
85
+ true
86
+ );
87
+
88
+ /**
89
+ * -------------------------------------------------------
90
+ * Section - Basejump utility functions
91
+ * -------------------------------------------------------
92
+ */
93
+
94
+ /**
95
+ basejump.get_config()
96
+ Get the full config object to check basejump settings
97
+ This is not accessible from the outside, so can only be used inside postgres functions
98
+ */
99
+ CREATE OR REPLACE FUNCTION basejump.get_config()
100
+ RETURNS json AS
101
+ $$
102
+ DECLARE
103
+ result RECORD;
104
+ BEGIN
105
+ SELECT * from basejump.config limit 1 into result;
106
+ return row_to_json(result);
107
+ END;
108
+ $$ LANGUAGE plpgsql;
109
+
110
+ grant execute on function basejump.get_config() to authenticated, service_role;
111
+
112
+
113
+ /**
114
+ basejump.is_set("field_name")
115
+ Check a specific boolean config value
116
+ */
117
+ CREATE OR REPLACE FUNCTION basejump.is_set(field_name text)
118
+ RETURNS boolean AS
119
+ $$
120
+ DECLARE
121
+ result BOOLEAN;
122
+ BEGIN
123
+ execute format('select %I from basejump.config limit 1', field_name) into result;
124
+ return result;
125
+ END;
126
+ $$ LANGUAGE plpgsql;
127
+
128
+ grant execute on function basejump.is_set(text) to authenticated;
129
+
130
+
131
+ /**
132
+ * Automatic handling for maintaining created_at and updated_at timestamps
133
+ * on tables
134
+ */
135
+ CREATE OR REPLACE FUNCTION basejump.trigger_set_timestamps()
136
+ RETURNS TRIGGER AS
137
+ $$
138
+ BEGIN
139
+ if TG_OP = 'INSERT' then
140
+ NEW.created_at = now();
141
+ NEW.updated_at = now();
142
+ else
143
+ NEW.updated_at = now();
144
+ NEW.created_at = OLD.created_at;
145
+ end if;
146
+ RETURN NEW;
147
+ END
148
+ $$ LANGUAGE plpgsql;
149
+
150
+
151
+ /**
152
+ * Automatic handling for maintaining created_by and updated_by timestamps
153
+ * on tables
154
+ */
155
+ CREATE OR REPLACE FUNCTION basejump.trigger_set_user_tracking()
156
+ RETURNS TRIGGER AS
157
+ $$
158
+ BEGIN
159
+ if TG_OP = 'INSERT' then
160
+ NEW.created_by = auth.uid();
161
+ NEW.updated_by = auth.uid();
162
+ else
163
+ NEW.updated_by = auth.uid();
164
+ NEW.created_by = OLD.created_by;
165
+ end if;
166
+ RETURN NEW;
167
+ END
168
+ $$ LANGUAGE plpgsql;
169
+
170
+ /**
171
+ basejump.generate_token(length)
172
+ Generates a secure token - used internally for invitation tokens
173
+ but could be used elsewhere. Check out the invitations table for more info on
174
+ how it's used
175
+ */
176
+ CREATE OR REPLACE FUNCTION basejump.generate_token(length int)
177
+ RETURNS text AS
178
+ $$
179
+ select regexp_replace(replace(
180
+ replace(replace(replace(encode(gen_random_bytes(length)::bytea, 'base64'), '/', ''), '+',
181
+ ''), '\', ''),
182
+ '=',
183
+ ''), E'[\\n\\r]+', '', 'g');
184
+ $$ LANGUAGE sql;
185
+
186
+ grant execute on function basejump.generate_token(int) to authenticated;
20240414161947_basejump-accounts.sql ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ ____ _
3
+ | _ \ (_)
4
+ | |_) | __ _ ___ ___ _ _ _ _ __ ___ _ __
5
+ | _ < / _` / __|/ _ \ | | | | '_ ` _ \| '_ \
6
+ | |_) | (_| \__ \ __/ | |_| | | | | | | |_) |
7
+ |____/ \__,_|___/\___| |\__,_|_| |_| |_| .__/
8
+ _/ | | |
9
+ |__/ |_|
10
+
11
+ Basejump is a starter kit for building SaaS products on top of Supabase.
12
+ Learn more at https://usebasejump.com
13
+ */
14
+
15
+ /**
16
+ * -------------------------------------------------------
17
+ * Section - Accounts
18
+ * -------------------------------------------------------
19
+ */
20
+
21
+ /**
22
+ * Account roles allow you to provide permission levels to users
23
+ * when they're acting on an account. By default, we provide
24
+ * "owner" and "member". The only distinction is that owners can
25
+ * also manage billing and invite/remove account members.
26
+ */
27
+ DO
28
+ $$
29
+ BEGIN
30
+ -- check it account_role already exists on basejump schema
31
+ IF NOT EXISTS(SELECT 1
32
+ FROM pg_type t
33
+ JOIN pg_namespace n ON n.oid = t.typnamespace
34
+ WHERE t.typname = 'account_role'
35
+ AND n.nspname = 'basejump') THEN
36
+ CREATE TYPE basejump.account_role AS ENUM ('owner', 'member');
37
+ end if;
38
+ end;
39
+ $$;
40
+
41
+ /**
42
+ * Accounts are the primary grouping for most objects within
43
+ * the system. They have many users, and all billing is connected to
44
+ * an account.
45
+ */
46
+ CREATE TABLE IF NOT EXISTS basejump.accounts
47
+ (
48
+ id uuid unique NOT NULL DEFAULT extensions.uuid_generate_v4(),
49
+ -- defaults to the user who creates the account
50
+ -- this user cannot be removed from an account without changing
51
+ -- the primary owner first
52
+ primary_owner_user_id uuid references auth.users not null default auth.uid(),
53
+ -- Account name
54
+ name text,
55
+ slug text unique,
56
+ personal_account boolean default false not null,
57
+ updated_at timestamp with time zone,
58
+ created_at timestamp with time zone,
59
+ created_by uuid references auth.users,
60
+ updated_by uuid references auth.users,
61
+ private_metadata jsonb default '{}'::jsonb,
62
+ public_metadata jsonb default '{}'::jsonb,
63
+ PRIMARY KEY (id)
64
+ );
65
+
66
+ -- constraint that conditionally allows nulls on the slug ONLY if personal_account is true
67
+ -- remove this if you want to ignore accounts slugs entirely
68
+ ALTER TABLE basejump.accounts
69
+ ADD CONSTRAINT basejump_accounts_slug_null_if_personal_account_true CHECK (
70
+ (personal_account = true AND slug is null)
71
+ OR (personal_account = false AND slug is not null)
72
+ );
73
+
74
+ -- Open up access to accounts
75
+ GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.accounts TO authenticated, service_role;
76
+
77
+ /**
78
+ * We want to protect some fields on accounts from being updated
79
+ * Specifically the primary owner user id and account id.
80
+ * primary_owner_user_id should be updated using the dedicated function
81
+ */
82
+ CREATE OR REPLACE FUNCTION basejump.protect_account_fields()
83
+ RETURNS TRIGGER AS
84
+ $$
85
+ BEGIN
86
+ IF current_user IN ('authenticated', 'anon') THEN
87
+ -- these are protected fields that users are not allowed to update themselves
88
+ -- platform admins should be VERY careful about updating them as well.
89
+ if NEW.id <> OLD.id
90
+ OR NEW.personal_account <> OLD.personal_account
91
+ OR NEW.primary_owner_user_id <> OLD.primary_owner_user_id
92
+ THEN
93
+ RAISE EXCEPTION 'You do not have permission to update this field';
94
+ end if;
95
+ end if;
96
+
97
+ RETURN NEW;
98
+ END
99
+ $$ LANGUAGE plpgsql;
100
+
101
+ -- trigger to protect account fields
102
+ CREATE TRIGGER basejump_protect_account_fields
103
+ BEFORE UPDATE
104
+ ON basejump.accounts
105
+ FOR EACH ROW
106
+ EXECUTE FUNCTION basejump.protect_account_fields();
107
+
108
+ -- convert any character in the slug that's not a letter, number, or dash to a dash on insert/update for accounts
109
+ CREATE OR REPLACE FUNCTION basejump.slugify_account_slug()
110
+ RETURNS TRIGGER AS
111
+ $$
112
+ BEGIN
113
+ if NEW.slug is not null then
114
+ NEW.slug = lower(regexp_replace(NEW.slug, '[^a-zA-Z0-9-]+', '-', 'g'));
115
+ end if;
116
+
117
+ RETURN NEW;
118
+ END
119
+ $$ LANGUAGE plpgsql;
120
+
121
+ -- trigger to slugify the account slug
122
+ CREATE TRIGGER basejump_slugify_account_slug
123
+ BEFORE INSERT OR UPDATE
124
+ ON basejump.accounts
125
+ FOR EACH ROW
126
+ EXECUTE FUNCTION basejump.slugify_account_slug();
127
+
128
+ -- enable RLS for accounts
129
+ alter table basejump.accounts
130
+ enable row level security;
131
+
132
+ -- protect the timestamps
133
+ CREATE TRIGGER basejump_set_accounts_timestamp
134
+ BEFORE INSERT OR UPDATE
135
+ ON basejump.accounts
136
+ FOR EACH ROW
137
+ EXECUTE PROCEDURE basejump.trigger_set_timestamps();
138
+
139
+ -- set the user tracking
140
+ CREATE TRIGGER basejump_set_accounts_user_tracking
141
+ BEFORE INSERT OR UPDATE
142
+ ON basejump.accounts
143
+ FOR EACH ROW
144
+ EXECUTE PROCEDURE basejump.trigger_set_user_tracking();
145
+
146
+ /**
147
+ * Account users are the users that are associated with an account.
148
+ * They can be invited to join the account, and can have different roles.
149
+ * The system does not enforce any permissions for roles, other than restricting
150
+ * billing and account membership to only owners
151
+ */
152
+ create table if not exists basejump.account_user
153
+ (
154
+ -- id of the user in the account
155
+ user_id uuid references auth.users on delete cascade not null,
156
+ -- id of the account the user is in
157
+ account_id uuid references basejump.accounts on delete cascade not null,
158
+ -- role of the user in the account
159
+ account_role basejump.account_role not null,
160
+ constraint account_user_pkey primary key (user_id, account_id)
161
+ );
162
+
163
+ GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.account_user TO authenticated, service_role;
164
+
165
+
166
+ -- enable RLS for account_user
167
+ alter table basejump.account_user
168
+ enable row level security;
169
+
170
+ /**
171
+ * When an account gets created, we want to insert the current user as the first
172
+ * owner
173
+ */
174
+ create or replace function basejump.add_current_user_to_new_account()
175
+ returns trigger
176
+ language plpgsql
177
+ security definer
178
+ set search_path = public
179
+ as
180
+ $$
181
+ begin
182
+ if new.primary_owner_user_id = auth.uid() then
183
+ insert into basejump.account_user (account_id, user_id, account_role)
184
+ values (NEW.id, auth.uid(), 'owner');
185
+ end if;
186
+ return NEW;
187
+ end;
188
+ $$;
189
+
190
+ -- trigger the function whenever a new account is created
191
+ CREATE TRIGGER basejump_add_current_user_to_new_account
192
+ AFTER INSERT
193
+ ON basejump.accounts
194
+ FOR EACH ROW
195
+ EXECUTE FUNCTION basejump.add_current_user_to_new_account();
196
+
197
+ /**
198
+ * When a user signs up, we need to create a personal account for them
199
+ * and add them to the account_user table so they can act on it
200
+ */
201
+ create or replace function basejump.run_new_user_setup()
202
+ returns trigger
203
+ language plpgsql
204
+ security definer
205
+ set search_path = public
206
+ as
207
+ $$
208
+ declare
209
+ first_account_id uuid;
210
+ generated_user_name text;
211
+ begin
212
+
213
+ -- first we setup the user profile
214
+ -- TODO: see if we can get the user's name from the auth.users table once we learn how oauth works
215
+ if new.email IS NOT NULL then
216
+ generated_user_name := split_part(new.email, '@', 1);
217
+ end if;
218
+ -- create the new users's personal account
219
+ insert into basejump.accounts (name, primary_owner_user_id, personal_account, id)
220
+ values (generated_user_name, NEW.id, true, NEW.id)
221
+ returning id into first_account_id;
222
+
223
+ -- add them to the account_user table so they can act on it
224
+ insert into basejump.account_user (account_id, user_id, account_role)
225
+ values (first_account_id, NEW.id, 'owner');
226
+
227
+ return NEW;
228
+ end;
229
+ $$;
230
+
231
+ -- trigger the function every time a user is created
232
+ create trigger on_auth_user_created
233
+ after insert
234
+ on auth.users
235
+ for each row
236
+ execute procedure basejump.run_new_user_setup();
237
+
238
+ /**
239
+ * -------------------------------------------------------
240
+ * Section - Account permission utility functions
241
+ * -------------------------------------------------------
242
+ * These functions are stored on the basejump schema, and useful for things like
243
+ * generating RLS policies
244
+ */
245
+
246
+ /**
247
+ * Returns true if the current user has the pass in role on the passed in account
248
+ * If no role is sent, will return true if the user is a member of the account
249
+ * NOTE: This is an inefficient function when used on large query sets. You should reach for the get_accounts_with_role and lookup
250
+ * the account ID in those cases.
251
+ */
252
+ create or replace function basejump.has_role_on_account(account_id uuid, account_role basejump.account_role default null)
253
+ returns boolean
254
+ language sql
255
+ security definer
256
+ set search_path = public
257
+ as
258
+ $$
259
+ select exists(
260
+ select 1
261
+ from basejump.account_user wu
262
+ where wu.user_id = auth.uid()
263
+ and wu.account_id = has_role_on_account.account_id
264
+ and (
265
+ wu.account_role = has_role_on_account.account_role
266
+ or has_role_on_account.account_role is null
267
+ )
268
+ );
269
+ $$;
270
+
271
+ grant execute on function basejump.has_role_on_account(uuid, basejump.account_role) to authenticated, anon, public, service_role;
272
+
273
+
274
+ /**
275
+ * Returns account_ids that the current user is a member of. If you pass in a role,
276
+ * it'll only return accounts that the user is a member of with that role.
277
+ */
278
+ create or replace function basejump.get_accounts_with_role(passed_in_role basejump.account_role default null)
279
+ returns setof uuid
280
+ language sql
281
+ security definer
282
+ set search_path = public
283
+ as
284
+ $$
285
+ select account_id
286
+ from basejump.account_user wu
287
+ where wu.user_id = auth.uid()
288
+ and (
289
+ wu.account_role = passed_in_role
290
+ or passed_in_role is null
291
+ );
292
+ $$;
293
+
294
+ grant execute on function basejump.get_accounts_with_role(basejump.account_role) to authenticated;
295
+
296
+ /**
297
+ * -------------------------
298
+ * Section - RLS Policies
299
+ * -------------------------
300
+ * This is where we define access to tables in the basejump schema
301
+ */
302
+
303
+ create policy "users can view their own account_users" on basejump.account_user
304
+ for select
305
+ to authenticated
306
+ using (
307
+ user_id = auth.uid()
308
+ );
309
+
310
+ create policy "users can view their teammates" on basejump.account_user
311
+ for select
312
+ to authenticated
313
+ using (
314
+ basejump.has_role_on_account(account_id) = true
315
+ );
316
+
317
+ create policy "Account users can be deleted by owners except primary account owner" on basejump.account_user
318
+ for delete
319
+ to authenticated
320
+ using (
321
+ (basejump.has_role_on_account(account_id, 'owner') = true)
322
+ AND
323
+ user_id != (select primary_owner_user_id
324
+ from basejump.accounts
325
+ where account_id = accounts.id)
326
+ );
327
+
328
+ create policy "Accounts are viewable by members" on basejump.accounts
329
+ for select
330
+ to authenticated
331
+ using (
332
+ basejump.has_role_on_account(id) = true
333
+ );
334
+
335
+ -- Primary owner should always have access to the account
336
+ create policy "Accounts are viewable by primary owner" on basejump.accounts
337
+ for select
338
+ to authenticated
339
+ using (
340
+ primary_owner_user_id = auth.uid()
341
+ );
342
+
343
+ create policy "Team accounts can be created by any user" on basejump.accounts
344
+ for insert
345
+ to authenticated
346
+ with check (
347
+ basejump.is_set('enable_team_accounts') = true
348
+ and personal_account = false
349
+ );
350
+
351
+
352
+ create policy "Accounts can be edited by owners" on basejump.accounts
353
+ for update
354
+ to authenticated
355
+ using (
356
+ basejump.has_role_on_account(id, 'owner') = true
357
+ );
358
+
359
+ /**
360
+ * -------------------------------------------------------
361
+ * Section - Public functions
362
+ * -------------------------------------------------------
363
+ * Each of these functions exists in the public name space because they are accessible
364
+ * via the API. it is the primary way developers can interact with Basejump accounts
365
+ */
366
+
367
+ /**
368
+ * Returns the account_id for a given account slug
369
+ */
370
+
371
+ create or replace function public.get_account_id(slug text)
372
+ returns uuid
373
+ language sql
374
+ as
375
+ $$
376
+ select id
377
+ from basejump.accounts
378
+ where slug = get_account_id.slug;
379
+ $$;
380
+
381
+ grant execute on function public.get_account_id(text) to authenticated, service_role;
382
+
383
+ /**
384
+ * Returns the current user's role within a given account_id
385
+ */
386
+ create or replace function public.current_user_account_role(account_id uuid)
387
+ returns jsonb
388
+ language plpgsql
389
+ as
390
+ $$
391
+ DECLARE
392
+ response jsonb;
393
+ BEGIN
394
+
395
+ select jsonb_build_object(
396
+ 'account_role', wu.account_role,
397
+ 'is_primary_owner', a.primary_owner_user_id = auth.uid(),
398
+ 'is_personal_account', a.personal_account
399
+ )
400
+ into response
401
+ from basejump.account_user wu
402
+ join basejump.accounts a on a.id = wu.account_id
403
+ where wu.user_id = auth.uid()
404
+ and wu.account_id = current_user_account_role.account_id;
405
+
406
+ -- if the user is not a member of the account, throw an error
407
+ if response ->> 'account_role' IS NULL then
408
+ raise exception 'Not found';
409
+ end if;
410
+
411
+ return response;
412
+ END
413
+ $$;
414
+
415
+ grant execute on function public.current_user_account_role(uuid) to authenticated;
416
+
417
+ /**
418
+ * Let's you update a users role within an account if you are an owner of that account
419
+ **/
420
+ create or replace function public.update_account_user_role(account_id uuid, user_id uuid,
421
+ new_account_role basejump.account_role,
422
+ make_primary_owner boolean default false)
423
+ returns void
424
+ security definer
425
+ set search_path = public
426
+ language plpgsql
427
+ as
428
+ $$
429
+ declare
430
+ is_account_owner boolean;
431
+ is_account_primary_owner boolean;
432
+ changing_primary_owner boolean;
433
+ begin
434
+ -- check if the user is an owner, and if they are, allow them to update the role
435
+ select basejump.has_role_on_account(update_account_user_role.account_id, 'owner') into is_account_owner;
436
+
437
+ if not is_account_owner then
438
+ raise exception 'You must be an owner of the account to update a users role';
439
+ end if;
440
+
441
+ -- check if the user being changed is the primary owner, if so its not allowed
442
+ select primary_owner_user_id = auth.uid(), primary_owner_user_id = update_account_user_role.user_id
443
+ into is_account_primary_owner, changing_primary_owner
444
+ from basejump.accounts
445
+ where id = update_account_user_role.account_id;
446
+
447
+ if changing_primary_owner = true and is_account_primary_owner = false then
448
+ raise exception 'You must be the primary owner of the account to change the primary owner';
449
+ end if;
450
+
451
+ update basejump.account_user au
452
+ set account_role = new_account_role
453
+ where au.account_id = update_account_user_role.account_id
454
+ and au.user_id = update_account_user_role.user_id;
455
+
456
+ if make_primary_owner = true then
457
+ -- first we see if the current user is the owner, only they can do this
458
+ if is_account_primary_owner = false then
459
+ raise exception 'You must be the primary owner of the account to change the primary owner';
460
+ end if;
461
+
462
+ update basejump.accounts
463
+ set primary_owner_user_id = update_account_user_role.user_id
464
+ where id = update_account_user_role.account_id;
465
+ end if;
466
+ end;
467
+ $$;
468
+
469
+ grant execute on function public.update_account_user_role(uuid, uuid, basejump.account_role, boolean) to authenticated;
470
+
471
+ /**
472
+ Returns the current user's accounts
473
+ */
474
+ create or replace function public.get_accounts()
475
+ returns json
476
+ language sql
477
+ as
478
+ $$
479
+ select coalesce(json_agg(
480
+ json_build_object(
481
+ 'account_id', wu.account_id,
482
+ 'account_role', wu.account_role,
483
+ 'is_primary_owner', a.primary_owner_user_id = auth.uid(),
484
+ 'name', a.name,
485
+ 'slug', a.slug,
486
+ 'personal_account', a.personal_account,
487
+ 'created_at', a.created_at,
488
+ 'updated_at', a.updated_at
489
+ )
490
+ ), '[]'::json)
491
+ from basejump.account_user wu
492
+ join basejump.accounts a on a.id = wu.account_id
493
+ where wu.user_id = auth.uid();
494
+ $$;
495
+
496
+ grant execute on function public.get_accounts() to authenticated;
497
+
498
+ /**
499
+ Returns a specific account that the current user has access to
500
+ */
501
+ create or replace function public.get_account(account_id uuid)
502
+ returns json
503
+ language plpgsql
504
+ as
505
+ $$
506
+ BEGIN
507
+ -- check if the user is a member of the account or a service_role user
508
+ if current_user IN ('anon', 'authenticated') and
509
+ (select current_user_account_role(get_account.account_id) ->> 'account_role' IS NULL) then
510
+ raise exception 'You must be a member of an account to access it';
511
+ end if;
512
+
513
+
514
+ return (select json_build_object(
515
+ 'account_id', a.id,
516
+ 'account_role', wu.account_role,
517
+ 'is_primary_owner', a.primary_owner_user_id = auth.uid(),
518
+ 'name', a.name,
519
+ 'slug', a.slug,
520
+ 'personal_account', a.personal_account,
521
+ 'billing_enabled', case
522
+ when a.personal_account = true then
523
+ config.enable_personal_account_billing
524
+ else
525
+ config.enable_team_account_billing
526
+ end,
527
+ 'billing_status', bs.status,
528
+ 'created_at', a.created_at,
529
+ 'updated_at', a.updated_at,
530
+ 'metadata', a.public_metadata
531
+ )
532
+ from basejump.accounts a
533
+ left join basejump.account_user wu on a.id = wu.account_id and wu.user_id = auth.uid()
534
+ join basejump.config config on true
535
+ left join (select bs.account_id, status
536
+ from basejump.billing_subscriptions bs
537
+ where bs.account_id = get_account.account_id
538
+ order by created desc
539
+ limit 1) bs on bs.account_id = a.id
540
+ where a.id = get_account.account_id);
541
+ END;
542
+ $$;
543
+
544
+ grant execute on function public.get_account(uuid) to authenticated, service_role;
545
+
546
+ /**
547
+ Returns a specific account that the current user has access to
548
+ */
549
+ create or replace function public.get_account_by_slug(slug text)
550
+ returns json
551
+ language plpgsql
552
+ as
553
+ $$
554
+ DECLARE
555
+ internal_account_id uuid;
556
+ BEGIN
557
+ select a.id
558
+ into internal_account_id
559
+ from basejump.accounts a
560
+ where a.slug IS NOT NULL
561
+ and a.slug = get_account_by_slug.slug;
562
+
563
+ return public.get_account(internal_account_id);
564
+ END;
565
+ $$;
566
+
567
+ grant execute on function public.get_account_by_slug(text) to authenticated;
568
+
569
+ /**
570
+ Returns the personal account for the current user
571
+ */
572
+ create or replace function public.get_personal_account()
573
+ returns json
574
+ language plpgsql
575
+ as
576
+ $$
577
+ BEGIN
578
+ return public.get_account(auth.uid());
579
+ END;
580
+ $$;
581
+
582
+ grant execute on function public.get_personal_account() to authenticated;
583
+
584
+ /**
585
+ * Create an account
586
+ */
587
+ create or replace function public.create_account(slug text default null, name text default null)
588
+ returns json
589
+ language plpgsql
590
+ as
591
+ $$
592
+ DECLARE
593
+ new_account_id uuid;
594
+ BEGIN
595
+ insert into basejump.accounts (slug, name)
596
+ values (create_account.slug, create_account.name)
597
+ returning id into new_account_id;
598
+
599
+ return public.get_account(new_account_id);
600
+ EXCEPTION
601
+ WHEN unique_violation THEN
602
+ raise exception 'An account with that unique ID already exists';
603
+ END;
604
+ $$;
605
+
606
+ grant execute on function public.create_account(slug text, name text) to authenticated;
607
+
608
+ /**
609
+ Update an account with passed in info. None of the info is required except for account ID.
610
+ If you don't pass in a value for a field, it will not be updated.
611
+ If you set replace_meta to true, the metadata will be replaced with the passed in metadata.
612
+ If you set replace_meta to false, the metadata will be merged with the passed in metadata.
613
+ */
614
+ create or replace function public.update_account(account_id uuid, slug text default null, name text default null,
615
+ public_metadata jsonb default null,
616
+ replace_metadata boolean default false)
617
+ returns json
618
+ language plpgsql
619
+ as
620
+ $$
621
+ BEGIN
622
+
623
+ -- check if postgres role is service_role
624
+ if current_user IN ('anon', 'authenticated') and
625
+ not (select current_user_account_role(update_account.account_id) ->> 'account_role' = 'owner') then
626
+ raise exception 'Only account owners can update an account';
627
+ end if;
628
+
629
+ update basejump.accounts accounts
630
+ set slug = coalesce(update_account.slug, accounts.slug),
631
+ name = coalesce(update_account.name, accounts.name),
632
+ public_metadata = case
633
+ when update_account.public_metadata is null then accounts.public_metadata -- do nothing
634
+ when accounts.public_metadata IS NULL then update_account.public_metadata -- set metadata
635
+ when update_account.replace_metadata
636
+ then update_account.public_metadata -- replace metadata
637
+ else accounts.public_metadata || update_account.public_metadata end -- merge metadata
638
+ where accounts.id = update_account.account_id;
639
+
640
+ return public.get_account(account_id);
641
+ END;
642
+ $$;
643
+
644
+ grant execute on function public.update_account(uuid, text, text, jsonb, boolean) to authenticated, service_role;
645
+
646
+ /**
647
+ Returns a list of current account members. Only account owners can access this function.
648
+ It's a security definer because it requries us to lookup personal_accounts for existing members so we can
649
+ get their names.
650
+ */
651
+ create or replace function public.get_account_members(account_id uuid, results_limit integer default 50,
652
+ results_offset integer default 0)
653
+ returns json
654
+ language plpgsql
655
+ security definer
656
+ set search_path = basejump
657
+ as
658
+ $$
659
+ BEGIN
660
+
661
+ -- only account owners can access this function
662
+ if (select public.current_user_account_role(get_account_members.account_id) ->> 'account_role' <> 'owner') then
663
+ raise exception 'Only account owners can access this function';
664
+ end if;
665
+
666
+ return (select json_agg(
667
+ json_build_object(
668
+ 'user_id', wu.user_id,
669
+ 'account_role', wu.account_role,
670
+ 'name', p.name,
671
+ 'email', u.email,
672
+ 'is_primary_owner', a.primary_owner_user_id = wu.user_id
673
+ )
674
+ )
675
+ from basejump.account_user wu
676
+ join basejump.accounts a on a.id = wu.account_id
677
+ join basejump.accounts p on p.primary_owner_user_id = wu.user_id and p.personal_account = true
678
+ join auth.users u on u.id = wu.user_id
679
+ where wu.account_id = get_account_members.account_id
680
+ limit coalesce(get_account_members.results_limit, 50) offset coalesce(get_account_members.results_offset, 0));
681
+ END;
682
+ $$;
683
+
684
+ grant execute on function public.get_account_members(uuid, integer, integer) to authenticated;
685
+
686
+ /**
687
+ Allows an owner of the account to remove any member other than the primary owner
688
+ */
689
+
690
+ create or replace function public.remove_account_member(account_id uuid, user_id uuid)
691
+ returns void
692
+ language plpgsql
693
+ as
694
+ $$
695
+ BEGIN
696
+ -- only account owners can access this function
697
+ if basejump.has_role_on_account(remove_account_member.account_id, 'owner') <> true then
698
+ raise exception 'Only account owners can access this function';
699
+ end if;
700
+
701
+ delete
702
+ from basejump.account_user wu
703
+ where wu.account_id = remove_account_member.account_id
704
+ and wu.user_id = remove_account_member.user_id;
705
+ END;
706
+ $$;
707
+
708
+ grant execute on function public.remove_account_member(uuid, uuid) to authenticated;
20240414162100_basejump-invitations.sql ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * -------------------------------------------------------
3
+ * Section - Invitations
4
+ * -------------------------------------------------------
5
+ */
6
+
7
+ /**
8
+ * Invitations are sent to users to join a account
9
+ * They pre-define the role the user should have once they join
10
+ */
11
+ create table if not exists basejump.invitations
12
+ (
13
+ -- the id of the invitation
14
+ id uuid unique not null default extensions.uuid_generate_v4(),
15
+ -- what role should invitation accepters be given in this account
16
+ account_role basejump.account_role not null,
17
+ -- the account the invitation is for
18
+ account_id uuid references basejump.accounts (id) on delete cascade not null,
19
+ -- unique token used to accept the invitation
20
+ token text unique not null default basejump.generate_token(30),
21
+ -- who created the invitation
22
+ invited_by_user_id uuid references auth.users not null,
23
+ -- account name. filled in by a trigger
24
+ account_name text,
25
+ -- when the invitation was last updated
26
+ updated_at timestamp with time zone,
27
+ -- when the invitation was created
28
+ created_at timestamp with time zone,
29
+ -- what type of invitation is this
30
+ invitation_type basejump.invitation_type not null,
31
+ primary key (id)
32
+ );
33
+
34
+ -- Open up access to invitations
35
+ GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.invitations TO authenticated, service_role;
36
+
37
+ -- manage timestamps
38
+ CREATE TRIGGER basejump_set_invitations_timestamp
39
+ BEFORE INSERT OR UPDATE
40
+ ON basejump.invitations
41
+ FOR EACH ROW
42
+ EXECUTE FUNCTION basejump.trigger_set_timestamps();
43
+
44
+ /**
45
+ * This funciton fills in account info and inviting user email
46
+ * so that the recipient can get more info about the invitation prior to
47
+ * accepting. It allows us to avoid complex permissions on accounts
48
+ */
49
+ CREATE OR REPLACE FUNCTION basejump.trigger_set_invitation_details()
50
+ RETURNS TRIGGER AS
51
+ $$
52
+ BEGIN
53
+ NEW.invited_by_user_id = auth.uid();
54
+ NEW.account_name = (select name from basejump.accounts where id = NEW.account_id);
55
+ RETURN NEW;
56
+ END
57
+ $$ LANGUAGE plpgsql;
58
+
59
+ CREATE TRIGGER basejump_trigger_set_invitation_details
60
+ BEFORE INSERT
61
+ ON basejump.invitations
62
+ FOR EACH ROW
63
+ EXECUTE FUNCTION basejump.trigger_set_invitation_details();
64
+
65
+ -- enable RLS on invitations
66
+ alter table basejump.invitations
67
+ enable row level security;
68
+
69
+ /**
70
+ * -------------------------
71
+ * Section - RLS Policies
72
+ * -------------------------
73
+ * This is where we define access to tables in the basejump schema
74
+ */
75
+
76
+ create policy "Invitations viewable by account owners" on basejump.invitations
77
+ for select
78
+ to authenticated
79
+ using (
80
+ created_at > (now() - interval '24 hours')
81
+ and
82
+ basejump.has_role_on_account(account_id, 'owner') = true
83
+ );
84
+
85
+
86
+ create policy "Invitations can be created by account owners" on basejump.invitations
87
+ for insert
88
+ to authenticated
89
+ with check (
90
+ -- team accounts should be enabled
91
+ basejump.is_set('enable_team_accounts') = true
92
+ -- this should not be a personal account
93
+ and (SELECT personal_account
94
+ FROM basejump.accounts
95
+ WHERE id = account_id) = false
96
+ -- the inserting user should be an owner of the account
97
+ and
98
+ (basejump.has_role_on_account(account_id, 'owner') = true)
99
+ );
100
+
101
+ create policy "Invitations can be deleted by account owners" on basejump.invitations
102
+ for delete
103
+ to authenticated
104
+ using (
105
+ basejump.has_role_on_account(account_id, 'owner') = true
106
+ );
107
+
108
+
109
+
110
+ /**
111
+ * -------------------------------------------------------
112
+ * Section - Public functions
113
+ * -------------------------------------------------------
114
+ * Each of these functions exists in the public name space because they are accessible
115
+ * via the API. it is the primary way developers can interact with Basejump accounts
116
+ */
117
+
118
+
119
+ /**
120
+ Returns a list of currently active invitations for a given account
121
+ */
122
+
123
+ create or replace function public.get_account_invitations(account_id uuid, results_limit integer default 25,
124
+ results_offset integer default 0)
125
+ returns json
126
+ language plpgsql
127
+ as
128
+ $$
129
+ BEGIN
130
+ -- only account owners can access this function
131
+ if (select public.current_user_account_role(get_account_invitations.account_id) ->> 'account_role' <> 'owner') then
132
+ raise exception 'Only account owners can access this function';
133
+ end if;
134
+
135
+ return (select json_agg(
136
+ json_build_object(
137
+ 'account_role', i.account_role,
138
+ 'created_at', i.created_at,
139
+ 'invitation_type', i.invitation_type,
140
+ 'invitation_id', i.id
141
+ )
142
+ )
143
+ from basejump.invitations i
144
+ where i.account_id = get_account_invitations.account_id
145
+ and i.created_at > now() - interval '24 hours'
146
+ limit coalesce(get_account_invitations.results_limit, 25) offset coalesce(get_account_invitations.results_offset, 0));
147
+ END;
148
+ $$;
149
+
150
+ grant execute on function public.get_account_invitations(uuid, integer, integer) to authenticated;
151
+
152
+
153
+ /**
154
+ * Allows a user to accept an existing invitation and join a account
155
+ * This one exists in the public schema because we want it to be called
156
+ * using the supabase rpc method
157
+ */
158
+ create or replace function public.accept_invitation(lookup_invitation_token text)
159
+ returns jsonb
160
+ language plpgsql
161
+ security definer set search_path = public, basejump
162
+ as
163
+ $$
164
+ declare
165
+ lookup_account_id uuid;
166
+ declare new_member_role basejump.account_role;
167
+ lookup_account_slug text;
168
+ begin
169
+ select i.account_id, i.account_role, a.slug
170
+ into lookup_account_id, new_member_role, lookup_account_slug
171
+ from basejump.invitations i
172
+ join basejump.accounts a on a.id = i.account_id
173
+ where i.token = lookup_invitation_token
174
+ and i.created_at > now() - interval '24 hours';
175
+
176
+ if lookup_account_id IS NULL then
177
+ raise exception 'Invitation not found';
178
+ end if;
179
+
180
+ if lookup_account_id is not null then
181
+ -- we've validated the token is real, so grant the user access
182
+ insert into basejump.account_user (account_id, user_id, account_role)
183
+ values (lookup_account_id, auth.uid(), new_member_role);
184
+ -- email types of invitations are only good for one usage
185
+ delete from basejump.invitations where token = lookup_invitation_token and invitation_type = 'one_time';
186
+ end if;
187
+ return json_build_object('account_id', lookup_account_id, 'account_role', new_member_role, 'slug',
188
+ lookup_account_slug);
189
+ EXCEPTION
190
+ WHEN unique_violation THEN
191
+ raise exception 'You are already a member of this account';
192
+ end;
193
+ $$;
194
+
195
+ grant execute on function public.accept_invitation(text) to authenticated;
196
+
197
+
198
+ /**
199
+ * Allows a user to lookup an existing invitation and join a account
200
+ * This one exists in the public schema because we want it to be called
201
+ * using the supabase rpc method
202
+ */
203
+ create or replace function public.lookup_invitation(lookup_invitation_token text)
204
+ returns json
205
+ language plpgsql
206
+ security definer set search_path = public, basejump
207
+ as
208
+ $$
209
+ declare
210
+ name text;
211
+ invitation_active boolean;
212
+ begin
213
+ select account_name,
214
+ case when id IS NOT NULL then true else false end as active
215
+ into name, invitation_active
216
+ from basejump.invitations
217
+ where token = lookup_invitation_token
218
+ and created_at > now() - interval '24 hours'
219
+ limit 1;
220
+ return json_build_object('active', coalesce(invitation_active, false), 'account_name', name);
221
+ end;
222
+ $$;
223
+
224
+ grant execute on function public.lookup_invitation(text) to authenticated;
225
+
226
+
227
+ /**
228
+ Allows a user to create a new invitation if they are an owner of an account
229
+ */
230
+ create or replace function public.create_invitation(account_id uuid, account_role basejump.account_role,
231
+ invitation_type basejump.invitation_type)
232
+ returns json
233
+ language plpgsql
234
+ as
235
+ $$
236
+ declare
237
+ new_invitation basejump.invitations;
238
+ begin
239
+ insert into basejump.invitations (account_id, account_role, invitation_type, invited_by_user_id)
240
+ values (account_id, account_role, invitation_type, auth.uid())
241
+ returning * into new_invitation;
242
+
243
+ return json_build_object('token', new_invitation.token);
244
+ end
245
+ $$;
246
+
247
+ grant execute on function public.create_invitation(uuid, basejump.account_role, basejump.invitation_type) to authenticated;
248
+
249
+ /**
250
+ Allows an owner to delete an existing invitation
251
+ */
252
+
253
+ create or replace function public.delete_invitation(invitation_id uuid)
254
+ returns void
255
+ language plpgsql
256
+ as
257
+ $$
258
+ begin
259
+ -- verify account owner for the invitation
260
+ if basejump.has_role_on_account(
261
+ (select account_id from basejump.invitations where id = delete_invitation.invitation_id), 'owner') <>
262
+ true then
263
+ raise exception 'Only account owners can delete invitations';
264
+ end if;
265
+
266
+ delete from basejump.invitations where id = delete_invitation.invitation_id;
267
+ end
268
+ $$;
269
+
270
+ grant execute on function public.delete_invitation(uuid) to authenticated;
20240414162131_basejump-billing.sql ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * -------------------------------------------------------
3
+ * Section - Billing
4
+ * -------------------------------------------------------
5
+ */
6
+
7
+ /**
8
+ * Subscription Status
9
+ * Tracks the current status of the account subscription
10
+ */
11
+ DO
12
+ $$
13
+ BEGIN
14
+ IF NOT EXISTS(SELECT 1
15
+ FROM pg_type t
16
+ JOIN pg_namespace n ON n.oid = t.typnamespace
17
+ WHERE t.typname = 'subscription_status'
18
+ AND n.nspname = 'basejump') THEN
19
+ create type basejump.subscription_status as enum (
20
+ 'trialing',
21
+ 'active',
22
+ 'canceled',
23
+ 'incomplete',
24
+ 'incomplete_expired',
25
+ 'past_due',
26
+ 'unpaid'
27
+ );
28
+ end if;
29
+ end;
30
+ $$;
31
+
32
+
33
+ /**
34
+ * Billing customer
35
+ * This is a private table that contains a mapping of user IDs to your billing providers IDs
36
+ */
37
+ create table if not exists basejump.billing_customers
38
+ (
39
+ -- UUID from auth.users
40
+ account_id uuid references basejump.accounts (id) on delete cascade not null,
41
+ -- The user's customer ID in Stripe. User must not be able to update this.
42
+ id text primary key,
43
+ -- The email address the customer wants to use for invoicing
44
+ email text,
45
+ -- The active status of a customer
46
+ active boolean,
47
+ -- The billing provider the customer is using
48
+ provider text
49
+ );
50
+
51
+ -- Open up access to billing_customers
52
+ GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.billing_customers TO service_role;
53
+ GRANT SELECT ON TABLE basejump.billing_customers TO authenticated;
54
+
55
+
56
+ -- enable RLS for billing_customers
57
+ alter table
58
+ basejump.billing_customers
59
+ enable row level security;
60
+
61
+ /**
62
+ * Billing subscriptions
63
+ * This is a private table that contains a mapping of account IDs to your billing providers subscription IDs
64
+ */
65
+ create table if not exists basejump.billing_subscriptions
66
+ (
67
+ -- Subscription ID from Stripe, e.g. sub_1234.
68
+ id text primary key,
69
+ account_id uuid references basejump.accounts (id) on delete cascade not null,
70
+ billing_customer_id text references basejump.billing_customers (id) on delete cascade not null,
71
+ -- The status of the subscription object, one of subscription_status type above.
72
+ status basejump.subscription_status,
73
+ -- Set of key-value pairs, used to store additional information about the object in a structured format.
74
+ metadata jsonb,
75
+ -- ID of the price that created this subscription.
76
+ price_id text,
77
+ plan_name text,
78
+ -- Quantity multiplied by the unit amount of the price creates the amount of the subscription. Can be used to charge multiple seats.
79
+ quantity integer,
80
+ -- If true the subscription has been canceled by the user and will be deleted at the end of the billing period.
81
+ cancel_at_period_end boolean,
82
+ -- Time at which the subscription was created.
83
+ created timestamp with time zone default timezone('utc' :: text, now()) not null,
84
+ -- Start of the current period that the subscription has been invoiced for.
85
+ current_period_start timestamp with time zone default timezone('utc' :: text, now()) not null,
86
+ -- End of the current period that the subscription has been invoiced for. At the end of this period, a new invoice will be created.
87
+ current_period_end timestamp with time zone default timezone('utc' :: text, now()) not null,
88
+ -- If the subscription has ended, the timestamp of the date the subscription ended.
89
+ ended_at timestamp with time zone default timezone('utc' :: text, now()),
90
+ -- A date in the future at which the subscription will automatically get canceled.
91
+ cancel_at timestamp with time zone default timezone('utc' :: text, now()),
92
+ -- If the subscription has been canceled, the date of that cancellation. If the subscription was canceled with `cancel_at_period_end`, `canceled_at` will still reflect the date of the initial cancellation request, not the end of the subscription period when the subscription is automatically moved to a canceled state.
93
+ canceled_at timestamp with time zone default timezone('utc' :: text, now()),
94
+ -- If the subscription has a trial, the beginning of that trial.
95
+ trial_start timestamp with time zone default timezone('utc' :: text, now()),
96
+ -- If the subscription has a trial, the end of that trial.
97
+ trial_end timestamp with time zone default timezone('utc' :: text, now()),
98
+ provider text
99
+ );
100
+
101
+ -- Open up access to billing_subscriptions
102
+ GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE basejump.billing_subscriptions TO service_role;
103
+ GRANT SELECT ON TABLE basejump.billing_subscriptions TO authenticated;
104
+
105
+ -- enable RLS for billing_subscriptions
106
+ alter table
107
+ basejump.billing_subscriptions
108
+ enable row level security;
109
+
110
+ /**
111
+ * -------------------------
112
+ * Section - RLS Policies
113
+ * -------------------------
114
+ * This is where we define access to tables in the basejump schema
115
+ */
116
+
117
+ create policy "Can only view own billing customer data." on basejump.billing_customers for
118
+ select
119
+ using (
120
+ basejump.has_role_on_account(account_id) = true
121
+ );
122
+
123
+
124
+ create policy "Can only view own billing subscription data." on basejump.billing_subscriptions for
125
+ select
126
+ using (
127
+ basejump.has_role_on_account(account_id) = true
128
+ );
129
+
130
+ /**
131
+ * -------------------------------------------------------
132
+ * Section - Public functions
133
+ * -------------------------------------------------------
134
+ * Each of these functions exists in the public name space because they are accessible
135
+ * via the API. it is the primary way developers can interact with Basejump accounts
136
+ */
137
+
138
+
139
+ /**
140
+ * Returns the current billing status for an account
141
+ */
142
+ CREATE OR REPLACE FUNCTION public.get_account_billing_status(account_id uuid)
143
+ RETURNS jsonb
144
+ security definer
145
+ set search_path = public, basejump
146
+ AS
147
+ $$
148
+ DECLARE
149
+ result jsonb;
150
+ role_result jsonb;
151
+ BEGIN
152
+ select public.current_user_account_role(get_account_billing_status.account_id) into role_result;
153
+
154
+ select jsonb_build_object(
155
+ 'account_id', get_account_billing_status.account_id,
156
+ 'billing_subscription_id', s.id,
157
+ 'billing_enabled', case
158
+ when a.personal_account = true then config.enable_personal_account_billing
159
+ else config.enable_team_account_billing end,
160
+ 'billing_status', s.status,
161
+ 'billing_customer_id', c.id,
162
+ 'billing_provider', config.billing_provider,
163
+ 'billing_email',
164
+ coalesce(c.email, u.email) -- if we don't have a customer email, use the user's email as a fallback
165
+ )
166
+ into result
167
+ from basejump.accounts a
168
+ join auth.users u on u.id = a.primary_owner_user_id
169
+ left join basejump.billing_subscriptions s on s.account_id = a.id
170
+ left join basejump.billing_customers c on c.account_id = coalesce(s.account_id, a.id)
171
+ join basejump.config config on true
172
+ where a.id = get_account_billing_status.account_id
173
+ order by s.created desc
174
+ limit 1;
175
+
176
+ return result || role_result;
177
+ END;
178
+ $$ LANGUAGE plpgsql;
179
+
180
+ grant execute on function public.get_account_billing_status(uuid) to authenticated;
181
+
182
+ /**
183
+ * Allow service accounts to upsert the billing data for an account
184
+ */
185
+ CREATE OR REPLACE FUNCTION public.service_role_upsert_customer_subscription(account_id uuid,
186
+ customer jsonb default null,
187
+ subscription jsonb default null)
188
+ RETURNS void AS
189
+ $$
190
+ BEGIN
191
+ -- if the customer is not null, upsert the data into billing_customers, only upsert fields that are present in the jsonb object
192
+ if customer is not null then
193
+ insert into basejump.billing_customers (id, account_id, email, provider)
194
+ values (customer ->> 'id', service_role_upsert_customer_subscription.account_id, customer ->> 'billing_email',
195
+ (customer ->> 'provider'))
196
+ on conflict (id) do update
197
+ set email = customer ->> 'billing_email';
198
+ end if;
199
+
200
+ -- if the subscription is not null, upsert the data into billing_subscriptions, only upsert fields that are present in the jsonb object
201
+ if subscription is not null then
202
+ insert into basejump.billing_subscriptions (id, account_id, billing_customer_id, status, metadata, price_id,
203
+ quantity, cancel_at_period_end, created, current_period_start,
204
+ current_period_end, ended_at, cancel_at, canceled_at, trial_start,
205
+ trial_end, plan_name, provider)
206
+ values (subscription ->> 'id', service_role_upsert_customer_subscription.account_id,
207
+ subscription ->> 'billing_customer_id', (subscription ->> 'status')::basejump.subscription_status,
208
+ subscription -> 'metadata',
209
+ subscription ->> 'price_id', (subscription ->> 'quantity')::int,
210
+ (subscription ->> 'cancel_at_period_end')::boolean,
211
+ (subscription ->> 'created')::timestamptz, (subscription ->> 'current_period_start')::timestamptz,
212
+ (subscription ->> 'current_period_end')::timestamptz, (subscription ->> 'ended_at')::timestamptz,
213
+ (subscription ->> 'cancel_at')::timestamptz,
214
+ (subscription ->> 'canceled_at')::timestamptz, (subscription ->> 'trial_start')::timestamptz,
215
+ (subscription ->> 'trial_end')::timestamptz,
216
+ subscription ->> 'plan_name', (subscription ->> 'provider'))
217
+ on conflict (id) do update
218
+ set billing_customer_id = subscription ->> 'billing_customer_id',
219
+ status = (subscription ->> 'status')::basejump.subscription_status,
220
+ metadata = subscription -> 'metadata',
221
+ price_id = subscription ->> 'price_id',
222
+ quantity = (subscription ->> 'quantity')::int,
223
+ cancel_at_period_end = (subscription ->> 'cancel_at_period_end')::boolean,
224
+ current_period_start = (subscription ->> 'current_period_start')::timestamptz,
225
+ current_period_end = (subscription ->> 'current_period_end')::timestamptz,
226
+ ended_at = (subscription ->> 'ended_at')::timestamptz,
227
+ cancel_at = (subscription ->> 'cancel_at')::timestamptz,
228
+ canceled_at = (subscription ->> 'canceled_at')::timestamptz,
229
+ trial_start = (subscription ->> 'trial_start')::timestamptz,
230
+ trial_end = (subscription ->> 'trial_end')::timestamptz,
231
+ plan_name = subscription ->> 'plan_name';
232
+ end if;
233
+ end;
234
+ $$ LANGUAGE plpgsql;
235
+
236
+ GRANT EXECUTE ON FUNCTION public.service_role_upsert_customer_subscription(uuid, jsonb, jsonb) TO service_role;
20250409211903_basejump-configure.sql ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ UPDATE basejump.config SET enable_team_accounts = TRUE;
2
+ UPDATE basejump.config SET enable_personal_account_billing = TRUE;
3
+ UPDATE basejump.config SET enable_team_account_billing = TRUE;
20250409212058_initial.sql ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- Enable UUID extension
2
+ CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
3
+
4
+ -- Create devices table first
5
+ CREATE TABLE public.devices (
6
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
7
+ account_id UUID NOT NULL,
8
+ name TEXT,
9
+ last_seen TIMESTAMP WITH TIME ZONE,
10
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
11
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
12
+ is_online BOOLEAN DEFAULT FALSE,
13
+ CONSTRAINT fk_account FOREIGN KEY (account_id) REFERENCES basejump.accounts(id) ON DELETE CASCADE
14
+ );
15
+
16
+ -- Create recordings table
17
+ CREATE TABLE public.recordings (
18
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
19
+ account_id UUID NOT NULL,
20
+ device_id UUID NOT NULL,
21
+ preprocessed_file_path TEXT,
22
+ meta JSONB,
23
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
24
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
25
+ name TEXT,
26
+ ui_annotated BOOLEAN DEFAULT FALSE,
27
+ a11y_file_path TEXT,
28
+ audio_file_path TEXT,
29
+ action_annotated BOOLEAN DEFAULT FALSE,
30
+ raw_data_file_path TEXT,
31
+ metadata_file_path TEXT,
32
+ action_training_file_path TEXT,
33
+ CONSTRAINT fk_account FOREIGN KEY (account_id) REFERENCES basejump.accounts(id) ON DELETE CASCADE,
34
+ CONSTRAINT fk_device FOREIGN KEY (device_id) REFERENCES public.devices(id) ON DELETE CASCADE
35
+ );
36
+
37
+ -- Create indexes for foreign keys
38
+ CREATE INDEX idx_recordings_account_id ON public.recordings(account_id);
39
+ CREATE INDEX idx_recordings_device_id ON public.recordings(device_id);
40
+ CREATE INDEX idx_devices_account_id ON public.devices(account_id);
41
+
42
+ -- Add RLS policies (optional, can be customized as needed)
43
+ ALTER TABLE public.recordings ENABLE ROW LEVEL SECURITY;
44
+ ALTER TABLE public.devices ENABLE ROW LEVEL SECURITY;
45
+
46
+ -- Create RLS policies for devices
47
+ CREATE POLICY "Account members can delete their own devices"
48
+ ON public.devices FOR DELETE
49
+ USING (basejump.has_role_on_account(account_id));
50
+
51
+ CREATE POLICY "Account members can insert their own devices"
52
+ ON public.devices FOR INSERT
53
+ WITH CHECK (basejump.has_role_on_account(account_id));
54
+
55
+ CREATE POLICY "Account members can only access their own devices"
56
+ ON public.devices FOR ALL
57
+ USING (basejump.has_role_on_account(account_id));
58
+
59
+ CREATE POLICY "Account members can update their own devices"
60
+ ON public.devices FOR UPDATE
61
+ USING (basejump.has_role_on_account(account_id));
62
+
63
+ CREATE POLICY "Account members can view their own devices"
64
+ ON public.devices FOR SELECT
65
+ USING (basejump.has_role_on_account(account_id));
66
+
67
+ -- Create RLS policies for recordings
68
+ CREATE POLICY "Account members can delete their own recordings"
69
+ ON public.recordings FOR DELETE
70
+ USING (basejump.has_role_on_account(account_id));
71
+
72
+ CREATE POLICY "Account members can insert their own recordings"
73
+ ON public.recordings FOR INSERT
74
+ WITH CHECK (basejump.has_role_on_account(account_id));
75
+
76
+ CREATE POLICY "Account members can only access their own recordings"
77
+ ON public.recordings FOR ALL
78
+ USING (basejump.has_role_on_account(account_id));
79
+
80
+ CREATE POLICY "Account members can update their own recordings"
81
+ ON public.recordings FOR UPDATE
82
+ USING (basejump.has_role_on_account(account_id));
83
+
84
+ CREATE POLICY "Account members can view their own recordings"
85
+ ON public.recordings FOR SELECT
86
+ USING (basejump.has_role_on_account(account_id));
87
+
88
+ -- Note: For threads and messages, you might want different RLS policies
89
+ -- depending on your application's requirements
90
+
91
+
92
+ -- Also drop the old function signature
93
+ DROP FUNCTION IF EXISTS transfer_device(UUID, UUID, TEXT);
94
+
95
+
96
+ CREATE OR REPLACE FUNCTION transfer_device(
97
+ device_id UUID, -- Parameter remains UUID
98
+ new_account_id UUID, -- Changed parameter name and implies new ownership target
99
+ device_name TEXT DEFAULT NULL
100
+ )
101
+ RETURNS SETOF devices AS $$
102
+ DECLARE
103
+ device_exists BOOLEAN;
104
+ updated_device devices;
105
+ BEGIN
106
+ -- Check if a device with the specified UUID exists
107
+ SELECT EXISTS (
108
+ SELECT 1 FROM devices WHERE id = device_id
109
+ ) INTO device_exists;
110
+
111
+ IF device_exists THEN
112
+ -- Device exists: update its account ownership and last_seen timestamp
113
+ UPDATE devices
114
+ SET
115
+ account_id = new_account_id, -- Update account_id instead of user_id
116
+ name = COALESCE(device_name, name),
117
+ last_seen = NOW()
118
+ WHERE id = device_id
119
+ RETURNING * INTO updated_device;
120
+
121
+ RETURN NEXT updated_device;
122
+ ELSE
123
+ -- Device doesn't exist; return nothing so the caller can handle creation
124
+ RETURN;
125
+ END IF;
126
+ END;
127
+ $$ LANGUAGE plpgsql SECURITY DEFINER;
128
+
129
+ -- Grant execute permission so that authenticated users can call this function
130
+ -- Updated function signature
131
+ GRANT EXECUTE ON FUNCTION transfer_device(UUID, UUID, TEXT) TO authenticated;
132
+
133
+
134
+
135
+
136
+ -- Create the ui_grounding bucket
137
+ INSERT INTO storage.buckets (id, name, public)
138
+ VALUES ('ui_grounding', 'ui_grounding', false)
139
+ ON CONFLICT (id) DO NOTHING; -- Avoid error if bucket already exists
140
+
141
+ -- Create the ui_grounding_trajs bucket
142
+ INSERT INTO storage.buckets (id, name, public)
143
+ VALUES ('ui_grounding_trajs', 'ui_grounding_trajs', false)
144
+ ON CONFLICT (id) DO NOTHING; -- Avoid error if bucket already exists
145
+
146
+ -- Create the recordings bucket
147
+ INSERT INTO storage.buckets (id, name, public, file_size_limit, allowed_mime_types)
148
+ VALUES ('recordings', 'recordings', false, null, null) -- Set file size limit and mime types as needed
149
+ ON CONFLICT (id) DO NOTHING; -- Avoid error if bucket already exists
150
+
151
+
152
+ -- RLS policies for the 'recordings' bucket
153
+ -- Allow members to view files in accounts they belong to
154
+ CREATE POLICY "Account members can select recording files"
155
+ ON storage.objects FOR SELECT
156
+ TO authenticated
157
+ USING (
158
+ bucket_id = 'recordings' AND
159
+ (storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
160
+ );
161
+
162
+ -- Allow members to insert files into accounts they belong to
163
+ CREATE POLICY "Account members can insert recording files"
164
+ ON storage.objects FOR INSERT
165
+ TO authenticated
166
+ WITH CHECK (
167
+ bucket_id = 'recordings' AND
168
+ (storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
169
+ );
170
+
171
+ -- Allow members to update files in accounts they belong to
172
+ CREATE POLICY "Account members can update recording files"
173
+ ON storage.objects FOR UPDATE
174
+ TO authenticated
175
+ USING (
176
+ bucket_id = 'recordings' AND
177
+ (storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
178
+ );
179
+
180
+ -- Allow members to delete files from accounts they belong to
181
+ -- Consider restricting this further, e.g., to 'owner' role if needed:
182
+ -- (storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role('owner'))
183
+ CREATE POLICY "Account members can delete recording files"
184
+ ON storage.objects FOR DELETE
185
+ TO authenticated
186
+ USING (
187
+ bucket_id = 'recordings' AND
188
+ (storage.foldername(name))[1]::uuid IN (SELECT basejump.get_accounts_with_role())
189
+ );
20250416133920_agentpress_schema.sql ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- AGENTPRESS SCHEMA:
2
+ -- Create projects table
3
+ CREATE TABLE projects (
4
+ project_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
5
+ name TEXT NOT NULL,
6
+ description TEXT,
7
+ account_id UUID NOT NULL REFERENCES basejump.accounts(id) ON DELETE CASCADE,
8
+ sandbox JSONB DEFAULT '{}'::jsonb,
9
+ is_public BOOLEAN DEFAULT FALSE,
10
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
11
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
12
+ );
13
+
14
+ -- Create threads table
15
+ CREATE TABLE threads (
16
+ thread_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
17
+ account_id UUID REFERENCES basejump.accounts(id) ON DELETE CASCADE,
18
+ project_id UUID REFERENCES projects(project_id) ON DELETE CASCADE,
19
+ is_public BOOLEAN DEFAULT FALSE,
20
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
21
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
22
+ );
23
+
24
+ -- Create messages table
25
+ CREATE TABLE messages (
26
+ message_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
27
+ thread_id UUID NOT NULL REFERENCES threads(thread_id) ON DELETE CASCADE,
28
+ type TEXT NOT NULL,
29
+ is_llm_message BOOLEAN NOT NULL DEFAULT TRUE,
30
+ content JSONB NOT NULL,
31
+ metadata JSONB DEFAULT '{}'::jsonb,
32
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
33
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
34
+ );
35
+
36
+ -- Create agent_runs table
37
+ CREATE TABLE agent_runs (
38
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
39
+ thread_id UUID NOT NULL REFERENCES threads(thread_id),
40
+ status TEXT NOT NULL DEFAULT 'running',
41
+ started_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
42
+ completed_at TIMESTAMP WITH TIME ZONE,
43
+ responses JSONB NOT NULL DEFAULT '[]'::jsonb, -- TO BE REMOVED, NOT USED
44
+ error TEXT,
45
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL,
46
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT TIMEZONE('utc'::text, NOW()) NOT NULL
47
+ );
48
+
49
+ -- Create updated_at trigger function
50
+ CREATE OR REPLACE FUNCTION update_updated_at_column()
51
+ RETURNS TRIGGER AS $$
52
+ BEGIN
53
+ NEW.updated_at = TIMEZONE('utc'::text, NOW());
54
+ RETURN NEW;
55
+ END;
56
+ $$ language 'plpgsql';
57
+
58
+ -- Create triggers for updated_at
59
+ CREATE TRIGGER update_threads_updated_at
60
+ BEFORE UPDATE ON threads
61
+ FOR EACH ROW
62
+ EXECUTE FUNCTION update_updated_at_column();
63
+
64
+ CREATE TRIGGER update_messages_updated_at
65
+ BEFORE UPDATE ON messages
66
+ FOR EACH ROW
67
+ EXECUTE FUNCTION update_updated_at_column();
68
+
69
+ CREATE TRIGGER update_agent_runs_updated_at
70
+ BEFORE UPDATE ON agent_runs
71
+ FOR EACH ROW
72
+ EXECUTE FUNCTION update_updated_at_column();
73
+
74
+ CREATE TRIGGER update_projects_updated_at
75
+ BEFORE UPDATE ON projects
76
+ FOR EACH ROW
77
+ EXECUTE FUNCTION update_updated_at_column();
78
+
79
+ -- Create indexes for better query performance
80
+ CREATE INDEX idx_threads_created_at ON threads(created_at);
81
+ CREATE INDEX idx_threads_account_id ON threads(account_id);
82
+ CREATE INDEX idx_threads_project_id ON threads(project_id);
83
+ CREATE INDEX idx_agent_runs_thread_id ON agent_runs(thread_id);
84
+ CREATE INDEX idx_agent_runs_status ON agent_runs(status);
85
+ CREATE INDEX idx_agent_runs_created_at ON agent_runs(created_at);
86
+ CREATE INDEX idx_projects_account_id ON projects(account_id);
87
+ CREATE INDEX idx_projects_created_at ON projects(created_at);
88
+ CREATE INDEX idx_messages_thread_id ON messages(thread_id);
89
+ CREATE INDEX idx_messages_created_at ON messages(created_at);
90
+
91
+ -- Enable Row Level Security
92
+ ALTER TABLE threads ENABLE ROW LEVEL SECURITY;
93
+ ALTER TABLE messages ENABLE ROW LEVEL SECURITY;
94
+ ALTER TABLE agent_runs ENABLE ROW LEVEL SECURITY;
95
+ ALTER TABLE projects ENABLE ROW LEVEL SECURITY;
96
+
97
+ -- Project policies
98
+ CREATE POLICY project_select_policy ON projects
99
+ FOR SELECT
100
+ USING (
101
+ is_public = TRUE OR
102
+ basejump.has_role_on_account(account_id) = true
103
+ );
104
+
105
+ CREATE POLICY project_insert_policy ON projects
106
+ FOR INSERT
107
+ WITH CHECK (basejump.has_role_on_account(account_id) = true);
108
+
109
+ CREATE POLICY project_update_policy ON projects
110
+ FOR UPDATE
111
+ USING (basejump.has_role_on_account(account_id) = true);
112
+
113
+ CREATE POLICY project_delete_policy ON projects
114
+ FOR DELETE
115
+ USING (basejump.has_role_on_account(account_id) = true);
116
+
117
+ -- Thread policies based on project and account ownership
118
+ CREATE POLICY thread_select_policy ON threads
119
+ FOR SELECT
120
+ USING (
121
+ basejump.has_role_on_account(account_id) = true OR
122
+ EXISTS (
123
+ SELECT 1 FROM projects
124
+ WHERE projects.project_id = threads.project_id
125
+ AND (
126
+ projects.is_public = TRUE OR
127
+ basejump.has_role_on_account(projects.account_id) = true
128
+ )
129
+ )
130
+ );
131
+
132
+ CREATE POLICY thread_insert_policy ON threads
133
+ FOR INSERT
134
+ WITH CHECK (
135
+ basejump.has_role_on_account(account_id) = true OR
136
+ EXISTS (
137
+ SELECT 1 FROM projects
138
+ WHERE projects.project_id = threads.project_id
139
+ AND basejump.has_role_on_account(projects.account_id) = true
140
+ )
141
+ );
142
+
143
+ CREATE POLICY thread_update_policy ON threads
144
+ FOR UPDATE
145
+ USING (
146
+ basejump.has_role_on_account(account_id) = true OR
147
+ EXISTS (
148
+ SELECT 1 FROM projects
149
+ WHERE projects.project_id = threads.project_id
150
+ AND basejump.has_role_on_account(projects.account_id) = true
151
+ )
152
+ );
153
+
154
+ CREATE POLICY thread_delete_policy ON threads
155
+ FOR DELETE
156
+ USING (
157
+ basejump.has_role_on_account(account_id) = true OR
158
+ EXISTS (
159
+ SELECT 1 FROM projects
160
+ WHERE projects.project_id = threads.project_id
161
+ AND basejump.has_role_on_account(projects.account_id) = true
162
+ )
163
+ );
164
+
165
+ -- Create policies for agent_runs based on thread ownership
166
+ CREATE POLICY agent_run_select_policy ON agent_runs
167
+ FOR SELECT
168
+ USING (
169
+ EXISTS (
170
+ SELECT 1 FROM threads
171
+ LEFT JOIN projects ON threads.project_id = projects.project_id
172
+ WHERE threads.thread_id = agent_runs.thread_id
173
+ AND (
174
+ projects.is_public = TRUE OR
175
+ basejump.has_role_on_account(threads.account_id) = true OR
176
+ basejump.has_role_on_account(projects.account_id) = true
177
+ )
178
+ )
179
+ );
180
+
181
+ CREATE POLICY agent_run_insert_policy ON agent_runs
182
+ FOR INSERT
183
+ WITH CHECK (
184
+ EXISTS (
185
+ SELECT 1 FROM threads
186
+ LEFT JOIN projects ON threads.project_id = projects.project_id
187
+ WHERE threads.thread_id = agent_runs.thread_id
188
+ AND (
189
+ basejump.has_role_on_account(threads.account_id) = true OR
190
+ basejump.has_role_on_account(projects.account_id) = true
191
+ )
192
+ )
193
+ );
194
+
195
+ CREATE POLICY agent_run_update_policy ON agent_runs
196
+ FOR UPDATE
197
+ USING (
198
+ EXISTS (
199
+ SELECT 1 FROM threads
200
+ LEFT JOIN projects ON threads.project_id = projects.project_id
201
+ WHERE threads.thread_id = agent_runs.thread_id
202
+ AND (
203
+ basejump.has_role_on_account(threads.account_id) = true OR
204
+ basejump.has_role_on_account(projects.account_id) = true
205
+ )
206
+ )
207
+ );
208
+
209
+ CREATE POLICY agent_run_delete_policy ON agent_runs
210
+ FOR DELETE
211
+ USING (
212
+ EXISTS (
213
+ SELECT 1 FROM threads
214
+ LEFT JOIN projects ON threads.project_id = projects.project_id
215
+ WHERE threads.thread_id = agent_runs.thread_id
216
+ AND (
217
+ basejump.has_role_on_account(threads.account_id) = true OR
218
+ basejump.has_role_on_account(projects.account_id) = true
219
+ )
220
+ )
221
+ );
222
+
223
+ -- Create message policies based on thread ownership
224
+ CREATE POLICY message_select_policy ON messages
225
+ FOR SELECT
226
+ USING (
227
+ EXISTS (
228
+ SELECT 1 FROM threads
229
+ LEFT JOIN projects ON threads.project_id = projects.project_id
230
+ WHERE threads.thread_id = messages.thread_id
231
+ AND (
232
+ projects.is_public = TRUE OR
233
+ basejump.has_role_on_account(threads.account_id) = true OR
234
+ basejump.has_role_on_account(projects.account_id) = true
235
+ )
236
+ )
237
+ );
238
+
239
+ CREATE POLICY message_insert_policy ON messages
240
+ FOR INSERT
241
+ WITH CHECK (
242
+ EXISTS (
243
+ SELECT 1 FROM threads
244
+ LEFT JOIN projects ON threads.project_id = projects.project_id
245
+ WHERE threads.thread_id = messages.thread_id
246
+ AND (
247
+ basejump.has_role_on_account(threads.account_id) = true OR
248
+ basejump.has_role_on_account(projects.account_id) = true
249
+ )
250
+ )
251
+ );
252
+
253
+ CREATE POLICY message_update_policy ON messages
254
+ FOR UPDATE
255
+ USING (
256
+ EXISTS (
257
+ SELECT 1 FROM threads
258
+ LEFT JOIN projects ON threads.project_id = projects.project_id
259
+ WHERE threads.thread_id = messages.thread_id
260
+ AND (
261
+ basejump.has_role_on_account(threads.account_id) = true OR
262
+ basejump.has_role_on_account(projects.account_id) = true
263
+ )
264
+ )
265
+ );
266
+
267
+ CREATE POLICY message_delete_policy ON messages
268
+ FOR DELETE
269
+ USING (
270
+ EXISTS (
271
+ SELECT 1 FROM threads
272
+ LEFT JOIN projects ON threads.project_id = projects.project_id
273
+ WHERE threads.thread_id = messages.thread_id
274
+ AND (
275
+ basejump.has_role_on_account(threads.account_id) = true OR
276
+ basejump.has_role_on_account(projects.account_id) = true
277
+ )
278
+ )
279
+ );
280
+
281
+ -- Grant permissions to roles
282
+ GRANT ALL PRIVILEGES ON TABLE projects TO authenticated, service_role;
283
+ GRANT SELECT ON TABLE projects TO anon;
284
+ GRANT SELECT ON TABLE threads TO authenticated, anon, service_role;
285
+ GRANT SELECT ON TABLE messages TO authenticated, anon, service_role;
286
+ GRANT ALL PRIVILEGES ON TABLE agent_runs TO authenticated, service_role;
287
+
288
+ -- Create a function that matches the Python get_messages behavior
289
+ CREATE OR REPLACE FUNCTION get_llm_formatted_messages(p_thread_id UUID)
290
+ RETURNS JSONB
291
+ SECURITY DEFINER -- Changed to SECURITY DEFINER to allow service role access
292
+ LANGUAGE plpgsql
293
+ AS $$
294
+ DECLARE
295
+ messages_array JSONB := '[]'::JSONB;
296
+ has_access BOOLEAN;
297
+ current_role TEXT;
298
+ latest_summary_id UUID;
299
+ latest_summary_time TIMESTAMP WITH TIME ZONE;
300
+ is_project_public BOOLEAN;
301
+ BEGIN
302
+ -- Get current role
303
+ SELECT current_user INTO current_role;
304
+
305
+ -- Check if associated project is public
306
+ SELECT p.is_public INTO is_project_public
307
+ FROM threads t
308
+ LEFT JOIN projects p ON t.project_id = p.project_id
309
+ WHERE t.thread_id = p_thread_id;
310
+
311
+ -- Skip access check for service_role or public projects
312
+ IF current_role = 'authenticated' AND NOT is_project_public THEN
313
+ -- Check if thread exists and user has access
314
+ SELECT EXISTS (
315
+ SELECT 1 FROM threads t
316
+ LEFT JOIN projects p ON t.project_id = p.project_id
317
+ WHERE t.thread_id = p_thread_id
318
+ AND (
319
+ basejump.has_role_on_account(t.account_id) = true OR
320
+ basejump.has_role_on_account(p.account_id) = true
321
+ )
322
+ ) INTO has_access;
323
+
324
+ IF NOT has_access THEN
325
+ RAISE EXCEPTION 'Thread not found or access denied';
326
+ END IF;
327
+ END IF;
328
+
329
+ -- Find the latest summary message if it exists
330
+ SELECT message_id, created_at
331
+ INTO latest_summary_id, latest_summary_time
332
+ FROM messages
333
+ WHERE thread_id = p_thread_id
334
+ AND type = 'summary'
335
+ AND is_llm_message = TRUE
336
+ ORDER BY created_at DESC
337
+ LIMIT 1;
338
+
339
+ -- Log whether a summary was found (helpful for debugging)
340
+ IF latest_summary_id IS NOT NULL THEN
341
+ RAISE NOTICE 'Found latest summary message: id=%, time=%', latest_summary_id, latest_summary_time;
342
+ ELSE
343
+ RAISE NOTICE 'No summary message found for thread %', p_thread_id;
344
+ END IF;
345
+
346
+ -- Parse content if it's stored as a string and return proper JSON objects
347
+ WITH parsed_messages AS (
348
+ SELECT
349
+ message_id,
350
+ CASE
351
+ WHEN jsonb_typeof(content) = 'string' THEN content::text::jsonb
352
+ ELSE content
353
+ END AS parsed_content,
354
+ created_at,
355
+ type
356
+ FROM messages
357
+ WHERE thread_id = p_thread_id
358
+ AND is_llm_message = TRUE
359
+ AND (
360
+ -- Include the latest summary and all messages after it,
361
+ -- or all messages if no summary exists
362
+ latest_summary_id IS NULL
363
+ OR message_id = latest_summary_id
364
+ OR created_at > latest_summary_time
365
+ )
366
+ ORDER BY created_at
367
+ )
368
+ SELECT JSONB_AGG(parsed_content)
369
+ INTO messages_array
370
+ FROM parsed_messages;
371
+
372
+ -- Handle the case when no messages are found
373
+ IF messages_array IS NULL THEN
374
+ RETURN '[]'::JSONB;
375
+ END IF;
376
+
377
+ RETURN messages_array;
378
+ END;
379
+ $$;
380
+
381
+ -- Grant execute permissions
382
+ GRANT EXECUTE ON FUNCTION get_llm_formatted_messages TO authenticated, anon, service_role;
20250506000000_initial_setup.sql ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- Create required schemas
2
+ CREATE SCHEMA IF NOT EXISTS auth;
3
+ CREATE SCHEMA IF NOT EXISTS storage;
4
+ CREATE SCHEMA IF NOT EXISTS basejump;
5
+
6
+ -- Create basic roles
7
+ CREATE ROLE IF NOT EXISTS anon NOLOGIN;
8
+ GRANT USAGE ON SCHEMA public TO anon;
9
+ GRANT USAGE ON SCHEMA auth TO anon;
10
+ GRANT USAGE ON SCHEMA basejump TO anon;
11
+
12
+ -- Create a basic users table if it doesn't exist
13
+ CREATE TABLE IF NOT EXISTS auth.users (
14
+ id uuid PRIMARY KEY,
15
+ email text UNIQUE,
16
+ encrypted_password text,
17
+ created_at timestamp with time zone DEFAULT now(),
18
+ updated_at timestamp with time zone DEFAULT now()
19
+ );
20
+
21
+ -- Add Basejump configuration
22
+ CREATE TABLE IF NOT EXISTS basejump.config (
23
+ enable_team_accounts boolean DEFAULT true,
24
+ enable_personal_account_billing boolean DEFAULT true,
25
+ enable_team_account_billing boolean DEFAULT true
26
+ );
27
+
28
+ -- Insert default config if table is empty
29
+ INSERT INTO basejump.config (enable_team_accounts, enable_personal_account_billing, enable_team_account_billing)
30
+ SELECT true, true, true
31
+ WHERE NOT EXISTS (SELECT 1 FROM basejump.config);
32
+
33
+ -- Create accounts table for Suna
34
+ CREATE TABLE IF NOT EXISTS public.accounts (
35
+ id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
36
+ name text NOT NULL,
37
+ slug text UNIQUE NOT NULL,
38
+ created_at timestamptz DEFAULT now(),
39
+ updated_at timestamptz DEFAULT now()
40
+ );
41
+
42
+ -- Create projects table for Suna
43
+ CREATE TABLE IF NOT EXISTS public.projects (
44
+ project_id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
45
+ name text NOT NULL,
46
+ description text,
47
+ account_id uuid REFERENCES public.accounts(id) ON DELETE CASCADE,
48
+ sandbox jsonb DEFAULT NULL,
49
+ created_at timestamptz DEFAULT now(),
50
+ updated_at timestamptz DEFAULT now()
51
+ );
52
+
53
+ -- Create a function to create accounts
54
+ CREATE OR REPLACE FUNCTION create_account(
55
+ name TEXT,
56
+ slug TEXT
57
+ ) RETURNS json
58
+ LANGUAGE plpgsql SECURITY DEFINER
59
+ AS $$
60
+ DECLARE
61
+ account_id uuid;
62
+ existing_account_id uuid;
63
+ return_data json;
64
+ BEGIN
65
+ -- Check if slug is already taken
66
+ SELECT id INTO existing_account_id FROM public.accounts WHERE accounts.slug = create_account.slug;
67
+
68
+ IF existing_account_id IS NOT NULL THEN
69
+ RETURN json_build_object('error', 'Slug already taken');
70
+ END IF;
71
+
72
+ -- Insert account
73
+ INSERT INTO public.accounts (name, slug)
74
+ VALUES (create_account.name, create_account.slug)
75
+ RETURNING id INTO account_id;
76
+
77
+ return_data := json_build_object(
78
+ 'id', account_id,
79
+ 'name', name,
80
+ 'slug', slug
81
+ );
82
+
83
+ RETURN return_data;
84
+ END;
85
+ $$;
20250506000001_account_functions.sql ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- Add account management functions
2
+
3
+ -- Function to update an account
4
+ CREATE OR REPLACE FUNCTION update_account(
5
+ name TEXT,
6
+ account_id UUID
7
+ ) RETURNS void
8
+ LANGUAGE plpgsql SECURITY DEFINER
9
+ AS $$
10
+ BEGIN
11
+ UPDATE public.accounts
12
+ SET
13
+ name = update_account.name,
14
+ updated_at = now()
15
+ WHERE id = update_account.account_id;
16
+ END;
17
+ $$;
18
+
19
+ -- Function to get all accounts for current user
20
+ CREATE OR REPLACE FUNCTION get_accounts()
21
+ RETURNS json
22
+ LANGUAGE plpgsql SECURITY DEFINER
23
+ AS $$
24
+ DECLARE
25
+ current_user_id uuid;
26
+ account_data json;
27
+ BEGIN
28
+ -- Get the current user's ID
29
+ current_user_id := auth.uid();
30
+
31
+ -- Query for accounts
32
+ SELECT json_agg(
33
+ json_build_object(
34
+ 'id', a.id,
35
+ 'name', a.name,
36
+ 'slug', a.slug,
37
+ 'personal_account', a.id = current_user_id
38
+ )
39
+ ) INTO account_data
40
+ FROM public.accounts a
41
+ WHERE a.id = current_user_id;
42
+
43
+ -- Return empty array if no results
44
+ IF account_data IS NULL THEN
45
+ RETURN '[]'::json;
46
+ END IF;
47
+
48
+ RETURN account_data;
49
+ END;
50
+ $$;
20250506000002_project_functions.sql ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- Add project management functions
2
+
3
+ -- Function to create a project with account validation
4
+ CREATE OR REPLACE FUNCTION create_project(
5
+ name TEXT,
6
+ description TEXT,
7
+ account_id UUID
8
+ ) RETURNS json
9
+ LANGUAGE plpgsql SECURITY DEFINER
10
+ AS $$
11
+ DECLARE
12
+ new_project_id uuid;
13
+ project_data json;
14
+ BEGIN
15
+ -- Insert project
16
+ INSERT INTO public.projects (name, description, account_id)
17
+ VALUES (create_project.name, create_project.description, create_project.account_id)
18
+ RETURNING project_id INTO new_project_id;
19
+
20
+ -- Get the full project data
21
+ SELECT json_build_object(
22
+ 'project_id', p.project_id,
23
+ 'name', p.name,
24
+ 'description', p.description,
25
+ 'account_id', p.account_id,
26
+ 'sandbox', p.sandbox,
27
+ 'created_at', p.created_at,
28
+ 'updated_at', p.updated_at
29
+ ) INTO project_data
30
+ FROM public.projects p
31
+ WHERE p.project_id = new_project_id;
32
+
33
+ RETURN project_data;
34
+ END;
35
+ $$;
36
+
37
+ -- Function to update a project
38
+ CREATE OR REPLACE FUNCTION update_project(
39
+ project_id UUID,
40
+ name TEXT,
41
+ description TEXT
42
+ ) RETURNS json
43
+ LANGUAGE plpgsql SECURITY DEFINER
44
+ AS $$
45
+ DECLARE
46
+ updated_project_data json;
47
+ BEGIN
48
+ -- Update the project
49
+ UPDATE public.projects
50
+ SET
51
+ name = COALESCE(update_project.name, name),
52
+ description = COALESCE(update_project.description, description),
53
+ updated_at = now()
54
+ WHERE project_id = update_project.project_id;
55
+
56
+ -- Get the updated project data
57
+ SELECT json_build_object(
58
+ 'project_id', p.project_id,
59
+ 'name', p.name,
60
+ 'description', p.description,
61
+ 'account_id', p.account_id,
62
+ 'sandbox', p.sandbox,
63
+ 'created_at', p.created_at,
64
+ 'updated_at', p.updated_at
65
+ ) INTO updated_project_data
66
+ FROM public.projects p
67
+ WHERE p.project_id = update_project.project_id;
68
+
69
+ RETURN updated_project_data;
70
+ END;
71
+ $$;
72
+
73
+ -- Function to update a project's sandbox information
74
+ CREATE OR REPLACE FUNCTION update_project_sandbox(
75
+ project_id UUID,
76
+ sandbox_data jsonb
77
+ ) RETURNS json
78
+ LANGUAGE plpgsql SECURITY DEFINER
79
+ AS $$
80
+ DECLARE
81
+ updated_project_data json;
82
+ BEGIN
83
+ -- Update the project sandbox data
84
+ UPDATE public.projects
85
+ SET
86
+ sandbox = sandbox_data,
87
+ updated_at = now()
88
+ WHERE project_id = update_project_sandbox.project_id;
89
+
90
+ -- Get the updated project data
91
+ SELECT json_build_object(
92
+ 'project_id', p.project_id,
93
+ 'name', p.name,
94
+ 'description', p.description,
95
+ 'account_id', p.account_id,
96
+ 'sandbox', p.sandbox,
97
+ 'created_at', p.created_at,
98
+ 'updated_at', p.updated_at
99
+ ) INTO updated_project_data
100
+ FROM public.projects p
101
+ WHERE p.project_id = update_project_sandbox.project_id;
102
+
103
+ RETURN updated_project_data;
104
+ END;
105
+ $$;
22dc0511fe69_add_toolsource_table.cpython-311.pyc ADDED
Binary file (1.33 kB). View file
 
2ea570019b8f_add_apikey_table.cpython-311.pyc ADDED
Binary file (4.54 kB). View file
 
2ea570019b8f_add_apikey_table.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Add ApiKey table
2
+
3
+ Revision ID: 2ea570019b8f
4
+ Revises: 4af13678b83c
5
+ Create Date: 2025-05-03 18:56:32.989446
6
+
7
+ """
8
+ from typing import Sequence, Union
9
+
10
+ from alembic import op
11
+ import sqlalchemy as sa
12
+ from sqlalchemy.dialects import postgresql
13
+
14
+ # revision identifiers, used by Alembic.
15
+ revision: str = '2ea570019b8f'
16
+ down_revision: Union[str, None] = '4af13678b83c'
17
+ branch_labels: Union[str, Sequence[str], None] = None
18
+ depends_on: Union[str, Sequence[str], None] = None
19
+
20
+
21
+ def upgrade() -> None:
22
+ """Upgrade schema."""
23
+ # ### commands auto generated by Alembic - adjusted ###
24
+ op.create_table('api_keys',
25
+ sa.Column('id', sa.UUID(), nullable=False),
26
+ sa.Column('user_id', sa.UUID(), nullable=False),
27
+ sa.Column('name', sa.String(), nullable=False),
28
+ sa.Column('description', sa.Text(), nullable=True),
29
+ sa.Column('key_prefix', sa.String(length=8), nullable=False),
30
+ sa.Column('hashed_key', sa.String(), nullable=False),
31
+ sa.Column('scopes', postgresql.ARRAY(sa.String()), nullable=True),
32
+ sa.Column('is_active', sa.Boolean(), nullable=False),
33
+ sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
34
+ sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
35
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_api_keys_user_id_users_id')),
36
+ sa.PrimaryKeyConstraint('id', name=op.f('pk_api_keys'))
37
+ )
38
+ with op.batch_alter_table('api_keys', schema=None) as batch_op:
39
+ batch_op.create_index(batch_op.f('ix_api_keys_hashed_key'), ['hashed_key'], unique=False)
40
+ batch_op.create_index(batch_op.f('ix_api_keys_key_prefix'), ['key_prefix'], unique=True)
41
+ batch_op.create_index(batch_op.f('ix_api_keys_user_id'), ['user_id'], unique=False)
42
+
43
+ # Removed incorrect drop/alter table commands for other tables
44
+ # ### end Alembic commands ###
45
+
46
+
47
+ def downgrade() -> None:
48
+ """Downgrade schema."""
49
+ # ### commands auto generated by Alembic - adjusted ###
50
+ with op.batch_alter_table('api_keys', schema=None) as batch_op:
51
+ batch_op.drop_index(batch_op.f('ix_api_keys_user_id'))
52
+ batch_op.drop_index(batch_op.f('ix_api_keys_key_prefix'))
53
+ batch_op.drop_index(batch_op.f('ix_api_keys_hashed_key'))
54
+
55
+ op.drop_table('api_keys')
56
+ # Removed incorrect create/alter table commands for other tables
57
+ # ### end Alembic commands ###
58
+
4af13678b83c_add_toolsource_table.cpython-311.pyc ADDED
Binary file (3.61 kB). View file
 
4af13678b83c_add_toolsource_table.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Add ToolSource table
2
+
3
+ Revision ID: 4af13678b83c
4
+ Revises: e2ca2546bf71
5
+ Create Date: 2025-05-03 18:51:11.601728
6
+
7
+ """
8
+ from typing import Sequence, Union
9
+
10
+ from alembic import op
11
+ import sqlalchemy as sa
12
+ from sqlalchemy.dialects import postgresql
13
+
14
+ # revision identifiers, used by Alembic.
15
+ revision: str = '4af13678b83c'
16
+ down_revision: Union[str, None] = 'e2ca2546bf71'
17
+ branch_labels: Union[str, Sequence[str], None] = None
18
+ depends_on: Union[str, Sequence[str], None] = None
19
+
20
+
21
+ def upgrade() -> None:
22
+ """Upgrade schema."""
23
+ # ### commands auto generated by Alembic - please adjust! ###
24
+ # Manually corrected: Remove incorrect drop commands and add create_table for tool_sources
25
+ op.create_table('tool_sources',
26
+ sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
27
+ sa.Column('github_url', sa.String(), nullable=False),
28
+ sa.Column('description', sa.Text(), nullable=True),
29
+ sa.Column('status', sa.String(), nullable=False, server_default='active'), # Match default from model
30
+ sa.Column('last_checked_at', sa.DateTime(timezone=True), nullable=True),
31
+ sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
32
+ sa.PrimaryKeyConstraint('id', name=op.f('pk_tool_sources'))
33
+ )
34
+ with op.batch_alter_table('tool_sources', schema=None) as batch_op:
35
+ batch_op.create_index(batch_op.f('ix_tool_sources_github_url'), ['github_url'], unique=True)
36
+ batch_op.create_index(batch_op.f('ix_tool_sources_status'), ['status'], unique=False)
37
+
38
+ # ### end Alembic commands ###
39
+
40
+
41
+ def downgrade() -> None:
42
+ """Downgrade schema."""
43
+ # ### commands auto generated by Alembic - please adjust! ###
44
+ with op.batch_alter_table('tool_sources', schema=None) as batch_op:
45
+ batch_op.drop_index(batch_op.f('ix_tool_sources_status'))
46
+ batch_op.drop_index(batch_op.f('ix_tool_sources_github_url'))
47
+
48
+ op.drop_table('tool_sources')
49
+ # ### end Alembic commands ###
50
+
ActiveJobsProvider.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
4
+
5
+
6
+ class ActiveJobsProvider(RapidDataProviderBase):
7
+ def __init__(self):
8
+ endpoints: Dict[str, EndpointSchema] = {
9
+ "active_jobs": {
10
+ "route": "/active-ats-7d",
11
+ "method": "GET",
12
+ "name": "Active Jobs Search",
13
+ "description": "Get active job listings with various filter options.",
14
+ "payload": {
15
+ "limit": "Optional. Number of jobs per API call (10-100). Default is 100.",
16
+ "offset": "Optional. Offset for pagination. Default is 0.",
17
+ "title_filter": "Optional. Search terms for job title.",
18
+ "advanced_title_filter": "Optional. Advanced title filter with operators (can't be used with title_filter).",
19
+ "location_filter": "Optional. Filter by location(s). Use full names like 'United States' not 'US'.",
20
+ "description_filter": "Optional. Filter on job description content.",
21
+ "organization_filter": "Optional. Filter by company name(s).",
22
+ "description_type": "Optional. Return format for description: 'text' or 'html'. Leave empty to exclude descriptions.",
23
+ "source": "Optional. Filter by ATS source.",
24
+ "date_filter": "Optional. Filter by posting date (greater than).",
25
+ "ai_employment_type_filter": "Optional. Filter by employment type (FULL_TIME, PART_TIME, etc).",
26
+ "ai_work_arrangement_filter": "Optional. Filter by work arrangement (On-site, Hybrid, Remote OK, Remote Solely).",
27
+ "ai_experience_level_filter": "Optional. Filter by experience level (0-2, 2-5, 5-10, 10+).",
28
+ "li_organization_slug_filter": "Optional. Filter by LinkedIn company slug.",
29
+ "li_organization_slug_exclusion_filter": "Optional. Exclude LinkedIn company slugs.",
30
+ "li_industry_filter": "Optional. Filter by LinkedIn industry.",
31
+ "li_organization_specialties_filter": "Optional. Filter by LinkedIn company specialties.",
32
+ "li_organization_description_filter": "Optional. Filter by LinkedIn company description."
33
+ }
34
+ }
35
+ }
36
+
37
+ base_url = "https://active-jobs-db.p.rapidapi.com"
38
+ super().__init__(base_url, endpoints)
39
+
40
+
41
+ if __name__ == "__main__":
42
+ from dotenv import load_dotenv
43
+ load_dotenv()
44
+ tool = ActiveJobsProvider()
45
+
46
+ # Example for searching active jobs
47
+ jobs = tool.call_endpoint(
48
+ route="active_jobs",
49
+ payload={
50
+ "limit": "10",
51
+ "offset": "0",
52
+ "title_filter": "\"Data Engineer\"",
53
+ "location_filter": "\"United States\" OR \"United Kingdom\"",
54
+ "description_type": "text"
55
+ }
56
+ )
57
+ print("Active Jobs:", jobs)
AmazonProvider.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
4
+
5
+
6
+ class AmazonProvider(RapidDataProviderBase):
7
+ def __init__(self):
8
+ endpoints: Dict[str, EndpointSchema] = {
9
+ "search": {
10
+ "route": "/search",
11
+ "method": "GET",
12
+ "name": "Amazon Product Search",
13
+ "description": "Search for products on Amazon with various filters and parameters.",
14
+ "payload": {
15
+ "query": "Search query (supports both free-form text queries or a product asin)",
16
+ "page": "Results page to return (default: 1)",
17
+ "country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
18
+ "sort_by": "Return the results in a specific sort order (RELEVANCE, LOWEST_PRICE, HIGHEST_PRICE, REVIEWS, NEWEST, BEST_SELLERS)",
19
+ "product_condition": "Return products in a specific condition (ALL, NEW, USED, RENEWED, COLLECTIBLE)",
20
+ "is_prime": "Only return prime products (boolean)",
21
+ "deals_and_discounts": "Return deals and discounts in a specific condition (NONE, ALL_DISCOUNTS, TODAYS_DEALS)",
22
+ "category_id": "Find products in a specific category / department (optional)",
23
+ "category": "Filter by specific numeric Amazon category (optional)",
24
+ "min_price": "Only return product offers with price greater than a certain value (optional)",
25
+ "max_price": "Only return product offers with price lower than a certain value (optional)",
26
+ "brand": "Find products with a specific brand (optional)",
27
+ "seller_id": "Find products sold by specific seller (optional)",
28
+ "four_stars_and_up": "Return product listings with ratings of 4 stars & up (optional)",
29
+ "additional_filters": "Any filters available on the Amazon page but not part of this endpoint's parameters (optional)"
30
+ }
31
+ },
32
+ "product-details": {
33
+ "route": "/product-details",
34
+ "method": "GET",
35
+ "name": "Amazon Product Details",
36
+ "description": "Get detailed information about specific Amazon products by ASIN.",
37
+ "payload": {
38
+ "asin": "Product ASIN for which to get details. Supports batching of up to 10 ASINs in a single request, separated by comma.",
39
+ "country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
40
+ "more_info_query": "A query to search and get more info about the product as part of Product Information, Customer Q&As, and Customer Reviews (optional)",
41
+ "fields": "A comma separated list of product fields to include in the response (field projection). By default all fields are returned. (optional)"
42
+ }
43
+ },
44
+ "products-by-category": {
45
+ "route": "/products-by-category",
46
+ "method": "GET",
47
+ "name": "Amazon Products by Category",
48
+ "description": "Get products from a specific Amazon category.",
49
+ "payload": {
50
+ "category_id": "The Amazon category for which to return results. Multiple category values can be separated by comma.",
51
+ "page": "Page to return (default: 1)",
52
+ "country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
53
+ "sort_by": "Return the results in a specific sort order (RELEVANCE, LOWEST_PRICE, HIGHEST_PRICE, REVIEWS, NEWEST, BEST_SELLERS)",
54
+ "min_price": "Only return product offers with price greater than a certain value (optional)",
55
+ "max_price": "Only return product offers with price lower than a certain value (optional)",
56
+ "product_condition": "Return products in a specific condition (ALL, NEW, USED, RENEWED, COLLECTIBLE)",
57
+ "brand": "Only return products of a specific brand. Multiple brands can be specified as a comma separated list (optional)",
58
+ "is_prime": "Only return prime products (boolean)",
59
+ "deals_and_discounts": "Return deals and discounts in a specific condition (NONE, ALL_DISCOUNTS, TODAYS_DEALS)",
60
+ "four_stars_and_up": "Return product listings with ratings of 4 stars & up (optional)",
61
+ "additional_filters": "Any filters available on the Amazon page but not part of this endpoint's parameters (optional)"
62
+ }
63
+ },
64
+ "product-reviews": {
65
+ "route": "/product-reviews",
66
+ "method": "GET",
67
+ "name": "Amazon Product Reviews",
68
+ "description": "Get customer reviews for a specific Amazon product by ASIN.",
69
+ "payload": {
70
+ "asin": "Product asin for which to get reviews.",
71
+ "country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
72
+ "page": "Results page to return (default: 1)",
73
+ "sort_by": "Return reviews in a specific sort order (TOP_REVIEWS, MOST_RECENT)",
74
+ "star_rating": "Only return reviews with a specific star rating (ALL, 5_STARS, 4_STARS, 3_STARS, 2_STARS, 1_STARS, POSITIVE, CRITICAL)",
75
+ "verified_purchases_only": "Only return reviews by reviewers who made a verified purchase (boolean)",
76
+ "images_or_videos_only": "Only return reviews containing images and / or videos (boolean)",
77
+ "current_format_only": "Only return reviews of the current format (product variant - e.g. Color) (boolean)"
78
+ }
79
+ },
80
+ "seller-profile": {
81
+ "route": "/seller-profile",
82
+ "method": "GET",
83
+ "name": "Amazon Seller Profile",
84
+ "description": "Get detailed information about a specific Amazon seller by Seller ID.",
85
+ "payload": {
86
+ "seller_id": "The Amazon Seller ID for which to get seller profile details",
87
+ "country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
88
+ "fields": "A comma separated list of seller profile fields to include in the response (field projection). By default all fields are returned. (optional)"
89
+ }
90
+ },
91
+ "seller-reviews": {
92
+ "route": "/seller-reviews",
93
+ "method": "GET",
94
+ "name": "Amazon Seller Reviews",
95
+ "description": "Get customer reviews for a specific Amazon seller by Seller ID.",
96
+ "payload": {
97
+ "seller_id": "The Amazon Seller ID for which to get seller reviews",
98
+ "country": "Sets the Amazon domain, marketplace country, language and currency (default: US)",
99
+ "star_rating": "Only return reviews with a specific star rating or positive / negative sentiment (ALL, 5_STARS, 4_STARS, 3_STARS, 2_STARS, 1_STARS, POSITIVE, CRITICAL)",
100
+ "page": "The page of seller feedback results to retrieve (default: 1)",
101
+ "fields": "A comma separated list of seller review fields to include in the response (field projection). By default all fields are returned. (optional)"
102
+ }
103
+ }
104
+ }
105
+ base_url = "https://real-time-amazon-data.p.rapidapi.com"
106
+ super().__init__(base_url, endpoints)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ from dotenv import load_dotenv
111
+ load_dotenv()
112
+ tool = AmazonProvider()
113
+
114
+ # Example for product search
115
+ search_result = tool.call_endpoint(
116
+ route="search",
117
+ payload={
118
+ "query": "Phone",
119
+ "page": 1,
120
+ "country": "US",
121
+ "sort_by": "RELEVANCE",
122
+ "product_condition": "ALL",
123
+ "is_prime": False,
124
+ "deals_and_discounts": "NONE"
125
+ }
126
+ )
127
+ print("Search Result:", search_result)
128
+
129
+ # Example for product details
130
+ details_result = tool.call_endpoint(
131
+ route="product-details",
132
+ payload={
133
+ "asin": "B07ZPKBL9V",
134
+ "country": "US"
135
+ }
136
+ )
137
+ print("Product Details:", details_result)
138
+
139
+ # Example for products by category
140
+ category_result = tool.call_endpoint(
141
+ route="products-by-category",
142
+ payload={
143
+ "category_id": "2478868012",
144
+ "page": 1,
145
+ "country": "US",
146
+ "sort_by": "RELEVANCE",
147
+ "product_condition": "ALL",
148
+ "is_prime": False,
149
+ "deals_and_discounts": "NONE"
150
+ }
151
+ )
152
+ print("Category Products:", category_result)
153
+
154
+ # Example for product reviews
155
+ reviews_result = tool.call_endpoint(
156
+ route="product-reviews",
157
+ payload={
158
+ "asin": "B07ZPKN6YR",
159
+ "country": "US",
160
+ "page": 1,
161
+ "sort_by": "TOP_REVIEWS",
162
+ "star_rating": "ALL",
163
+ "verified_purchases_only": False,
164
+ "images_or_videos_only": False,
165
+ "current_format_only": False
166
+ }
167
+ )
168
+ print("Product Reviews:", reviews_result)
169
+
170
+ # Example for seller profile
171
+ seller_result = tool.call_endpoint(
172
+ route="seller-profile",
173
+ payload={
174
+ "seller_id": "A02211013Q5HP3OMSZC7W",
175
+ "country": "US"
176
+ }
177
+ )
178
+ print("Seller Profile:", seller_result)
179
+
180
+ # Example for seller reviews
181
+ seller_reviews_result = tool.call_endpoint(
182
+ route="seller-reviews",
183
+ payload={
184
+ "seller_id": "A02211013Q5HP3OMSZC7W",
185
+ "country": "US",
186
+ "star_rating": "ALL",
187
+ "page": 1
188
+ }
189
+ )
190
+ print("Seller Reviews:", seller_reviews_result)
191
+
ChatInterface.tsx ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // /home/ubuntu/visionos-frontend/src/components/ChatInterface.tsx
2
+ import React from 'react';
3
+
4
+ const ChatInterface: React.FC = () => {
5
+ return (
6
+ <div className="flex flex-col h-full border rounded-lg shadow-md">
7
+ {/* Message Display Area */}
8
+ <div className="flex-grow p-4 overflow-y-auto bg-gray-50">
9
+ {/* Placeholder for messages */}
10
+ <p className="text-gray-500">Chat messages will appear here...</p>
11
+ </div>
12
+
13
+ {/* Input Area */}
14
+ <div className="p-4 border-t bg-white">
15
+ <input
16
+ type="text"
17
+ placeholder="Type your message..."
18
+ className="w-full p-2 border rounded"
19
+ />
20
+ {/* Placeholder for Send Button */}
21
+ <button className="mt-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600">
22
+ Send
23
+ </button>
24
+ </div>
25
+ </div>
26
+ );
27
+ };
28
+
29
+ export default ChatInterface;
30
+
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM node:18-alpine
2
+
3
+ WORKDIR /app
4
+
5
+ # Copy package files and install dependencies
6
+ COPY package*.json ./
7
+ RUN npm install
8
+
9
+ # Copy the rest of the application code
10
+ COPY . .
11
+
12
+ # Build the application
13
+ RUN npm run build
14
+
15
+ # Expose the port
16
+ EXPOSE 3000
17
+
18
+ # Start the application
19
+ CMD ["npm", "start"]
Layout.tsx ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // /home/ubuntu/visionos-frontend/src/components/Layout.tsx
2
+ import React from 'react';
3
+ import Link from 'next/link'; // Import Link for navigation
4
+
5
+ interface LayoutProps {
6
+ children: React.ReactNode;
7
+ }
8
+
9
+ const Layout: React.FC<LayoutProps> = ({ children }) => {
10
+ return (
11
+ <div className="flex flex-col min-h-screen">
12
+ {/* Header/Navigation */}
13
+ <header className="bg-gray-800 text-white p-4">
14
+ <nav className="container mx-auto flex justify-between items-center">
15
+ <h1 className="text-xl font-bold">
16
+ <Link href="/">VisionOS UI</Link>
17
+ </h1>
18
+ <ul className="flex space-x-4">
19
+ <li><Link href="/" className="hover:text-gray-300">Chat</Link></li>
20
+ <li><Link href="/workflow" className="hover:text-gray-300">Workflow</Link></li>
21
+ <li><Link href="/settings" className="hover:text-gray-300">Settings</Link></li>
22
+ {/* Add more navigation links here */}
23
+ </ul>
24
+ </nav>
25
+ </header>
26
+
27
+ {/* Main Content Area */}
28
+ <main className="flex-grow p-4 container mx-auto">
29
+ {children}
30
+ </main>
31
+
32
+ {/* Footer */}
33
+ <footer className="bg-gray-200 p-4 text-center text-sm text-gray-600">
34
+ © 2025 VisionOS
35
+ </footer>
36
+ </div>
37
+ );
38
+ };
39
+
40
+ export default Layout;
41
+
LinkedinProvider.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
4
+
5
+
6
+ class LinkedinProvider(RapidDataProviderBase):
7
+ def __init__(self):
8
+ endpoints: Dict[str, EndpointSchema] = {
9
+ "person": {
10
+ "route": "/person",
11
+ "method": "POST",
12
+ "name": "Person Data",
13
+ "description": "Fetches any Linkedin profiles data including skills, certificates, experiences, qualifications and much more.",
14
+ "payload": {
15
+ "link": "LinkedIn Profile URL"
16
+ }
17
+ },
18
+ "person_urn": {
19
+ "route": "/person_urn",
20
+ "method": "POST",
21
+ "name": "Person Data (Using Urn)",
22
+ "description": "It takes profile urn instead of profile public identifier in input",
23
+ "payload": {
24
+ "link": "LinkedIn Profile URL or URN"
25
+ }
26
+ },
27
+ "person_deep": {
28
+ "route": "/person_deep",
29
+ "method": "POST",
30
+ "name": "Person Data (Deep)",
31
+ "description": "Fetches all experiences, educations, skills, languages, publications... related to a profile.",
32
+ "payload": {
33
+ "link": "LinkedIn Profile URL"
34
+ }
35
+ },
36
+ "profile_updates": {
37
+ "route": "/profile_updates",
38
+ "method": "GET",
39
+ "name": "Person Posts (WITH PAGINATION)",
40
+ "description": "Fetches posts of a linkedin profile alongwith reactions, comments, postLink and reposts data.",
41
+ "payload": {
42
+ "profile_url": "LinkedIn Profile URL",
43
+ "page": "Page number",
44
+ "reposts": "Include reposts (1 or 0)",
45
+ "comments": "Include comments (1 or 0)"
46
+ }
47
+ },
48
+ "profile_recent_comments": {
49
+ "route": "/profile_recent_comments",
50
+ "method": "POST",
51
+ "name": "Person Recent Activity (Comments on Posts)",
52
+ "description": "Fetches 20 most recent comments posted by a linkedin user (per page).",
53
+ "payload": {
54
+ "profile_url": "LinkedIn Profile URL",
55
+ "page": "Page number",
56
+ "paginationToken": "Token for pagination"
57
+ }
58
+ },
59
+ "comments_from_recent_activity": {
60
+ "route": "/comments_from_recent_activity",
61
+ "method": "GET",
62
+ "name": "Comments from recent activity",
63
+ "description": "Fetches recent comments posted by a person as per his recent activity tab.",
64
+ "payload": {
65
+ "profile_url": "LinkedIn Profile URL",
66
+ "page": "Page number"
67
+ }
68
+ },
69
+ "person_skills": {
70
+ "route": "/person_skills",
71
+ "method": "POST",
72
+ "name": "Person Skills",
73
+ "description": "Scraper all skills of a linkedin user",
74
+ "payload": {
75
+ "link": "LinkedIn Profile URL"
76
+ }
77
+ },
78
+ "email_to_linkedin_profile": {
79
+ "route": "/email_to_linkedin_profile",
80
+ "method": "POST",
81
+ "name": "Email to LinkedIn Profile",
82
+ "description": "Finds LinkedIn profile associated with an email address",
83
+ "payload": {
84
+ "email": "Email address to search"
85
+ }
86
+ },
87
+ "company": {
88
+ "route": "/company",
89
+ "method": "POST",
90
+ "name": "Company Data",
91
+ "description": "Fetches LinkedIn company profile data",
92
+ "payload": {
93
+ "link": "LinkedIn Company URL"
94
+ }
95
+ },
96
+ "web_domain": {
97
+ "route": "/web-domain",
98
+ "method": "POST",
99
+ "name": "Web Domain to Company",
100
+ "description": "Fetches LinkedIn company profile data from a web domain",
101
+ "payload": {
102
+ "link": "Website domain (e.g., huzzle.app)"
103
+ }
104
+ },
105
+ "similar_profiles": {
106
+ "route": "/similar_profiles",
107
+ "method": "GET",
108
+ "name": "Similar Profiles",
109
+ "description": "Fetches profiles similar to a given LinkedIn profile",
110
+ "payload": {
111
+ "profileUrl": "LinkedIn Profile URL"
112
+ }
113
+ },
114
+ "company_jobs": {
115
+ "route": "/company_jobs",
116
+ "method": "POST",
117
+ "name": "Company Jobs",
118
+ "description": "Fetches job listings from a LinkedIn company page",
119
+ "payload": {
120
+ "company_url": "LinkedIn Company URL",
121
+ "count": "Number of job listings to fetch"
122
+ }
123
+ },
124
+ "company_updates": {
125
+ "route": "/company_updates",
126
+ "method": "GET",
127
+ "name": "Company Posts",
128
+ "description": "Fetches posts from a LinkedIn company page",
129
+ "payload": {
130
+ "company_url": "LinkedIn Company URL",
131
+ "page": "Page number",
132
+ "reposts": "Include reposts (0, 1, or 2)",
133
+ "comments": "Include comments (0, 1, or 2)"
134
+ }
135
+ },
136
+ "company_employee": {
137
+ "route": "/company_employee",
138
+ "method": "GET",
139
+ "name": "Company Employees",
140
+ "description": "Fetches employees of a LinkedIn company using company ID",
141
+ "payload": {
142
+ "companyId": "LinkedIn Company ID",
143
+ "page": "Page number"
144
+ }
145
+ },
146
+ "company_updates_post": {
147
+ "route": "/company_updates",
148
+ "method": "POST",
149
+ "name": "Company Posts (POST)",
150
+ "description": "Fetches posts from a LinkedIn company page with specific count parameters",
151
+ "payload": {
152
+ "company_url": "LinkedIn Company URL",
153
+ "posts": "Number of posts to fetch",
154
+ "comments": "Number of comments to fetch per post",
155
+ "reposts": "Number of reposts to fetch"
156
+ }
157
+ },
158
+ "search_posts_with_filters": {
159
+ "route": "/search_posts_with_filters",
160
+ "method": "GET",
161
+ "name": "Search Posts With Filters",
162
+ "description": "Searches LinkedIn posts with various filtering options",
163
+ "payload": {
164
+ "query": "Keywords/Search terms (text you put in LinkedIn search bar)",
165
+ "page": "Page number (1-100, each page contains 20 results)",
166
+ "sort_by": "Sort method: 'relevance' (Top match) or 'date_posted' (Latest)",
167
+ "author_job_title": "Filter by job title of author (e.g., CEO)",
168
+ "content_type": "Type of content post contains (photos, videos, liveVideos, collaborativeArticles, documents)",
169
+ "from_member": "URN of person who posted (comma-separated for multiple)",
170
+ "from_organization": "ID of organization who posted (comma-separated for multiple)",
171
+ "author_company": "ID of company author works for (comma-separated for multiple)",
172
+ "author_industry": "URN of industry author is connected with (comma-separated for multiple)",
173
+ "mentions_member": "URN of person mentioned in post (comma-separated for multiple)",
174
+ "mentions_organization": "ID of organization mentioned in post (comma-separated for multiple)"
175
+ }
176
+ },
177
+ "search_jobs": {
178
+ "route": "/search_jobs",
179
+ "method": "GET",
180
+ "name": "Search Jobs",
181
+ "description": "Searches LinkedIn jobs with various filtering options",
182
+ "payload": {
183
+ "query": "Job search keywords (e.g., Software developer)",
184
+ "page": "Page number",
185
+ "searchLocationId": "Location ID for job search (get from Suggestion location endpoint)",
186
+ "easyApply": "Filter for easy apply jobs (true or false)",
187
+ "experience": "Experience level required (1=Internship, 2=Entry level, 3=Associate, 4=Mid senior, 5=Director, 6=Executive, comma-separated)",
188
+ "jobType": "Job type (F=Full time, P=Part time, C=Contract, T=Temporary, V=Volunteer, I=Internship, O=Other, comma-separated)",
189
+ "postedAgo": "Time jobs were posted in seconds (e.g., 3600 for past hour)",
190
+ "workplaceType": "Workplace type (1=On-Site, 2=Remote, 3=Hybrid, comma-separated)",
191
+ "sortBy": "Sort method (DD=most recent, R=most relevant)",
192
+ "companyIdsList": "List of company IDs, comma-separated",
193
+ "industryIdsList": "List of industry IDs, comma-separated",
194
+ "functionIdsList": "List of function IDs, comma-separated",
195
+ "titleIdsList": "List of job title IDs, comma-separated",
196
+ "locationIdsList": "List of location IDs within specified searchLocationId country, comma-separated"
197
+ }
198
+ },
199
+ "search_people_with_filters": {
200
+ "route": "/search_people_with_filters",
201
+ "method": "POST",
202
+ "name": "Search People With Filters",
203
+ "description": "Searches LinkedIn profiles with detailed filtering options",
204
+ "payload": {
205
+ "keyword": "General search keyword",
206
+ "page": "Page number",
207
+ "title_free_text": "Job title to filter by (e.g., CEO)",
208
+ "company_free_text": "Company name to filter by",
209
+ "first_name": "First name of person",
210
+ "last_name": "Last name of person",
211
+ "current_company_list": "List of current companies (comma-separated IDs)",
212
+ "past_company_list": "List of past companies (comma-separated IDs)",
213
+ "location_list": "List of locations (comma-separated IDs)",
214
+ "language_list": "List of languages (comma-separated)",
215
+ "service_catagory_list": "List of service categories (comma-separated)",
216
+ "school_free_text": "School name to filter by",
217
+ "industry_list": "List of industries (comma-separated IDs)",
218
+ "school_list": "List of schools (comma-separated IDs)"
219
+ }
220
+ },
221
+ "search_company_with_filters": {
222
+ "route": "/search_company_with_filters",
223
+ "method": "POST",
224
+ "name": "Search Company With Filters",
225
+ "description": "Searches LinkedIn companies with detailed filtering options",
226
+ "payload": {
227
+ "keyword": "General search keyword",
228
+ "page": "Page number",
229
+ "company_size_list": "List of company sizes (comma-separated, e.g., A,D)",
230
+ "hasJobs": "Filter companies with jobs (true or false)",
231
+ "location_list": "List of location IDs (comma-separated)",
232
+ "industry_list": "List of industry IDs (comma-separated)"
233
+ }
234
+ }
235
+ }
236
+ base_url = "https://linkedin-data-scraper.p.rapidapi.com"
237
+ super().__init__(base_url, endpoints)
238
+
239
+
240
+ if __name__ == "__main__":
241
+ from dotenv import load_dotenv
242
+ load_dotenv()
243
+ tool = LinkedinProvider()
244
+
245
+ result = tool.call_endpoint(
246
+ route="comments_from_recent_activity",
247
+ payload={"profile_url": "https://www.linkedin.com/in/adamcohenhillel/", "page": 1}
248
+ )
249
+ print(result)
250
+
MANIFEST.in ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Include all Python files in agentpress directory
2
+ recursive-include agentpress *.py
3
+
4
+ # Include example files
5
+ recursive-include agentpress/examples *
6
+
7
+ # Include any other necessary files
8
+ include LICENSE
9
+ include README.md
10
+ include pyproject.toml
11
+
12
+ # Exclude unnecessary files
13
+ global-exclude *.pyc
14
+ global-exclude __pycache__
15
+ global-exclude .DS_Store
16
+ global-exclude *.pyo
17
+ global-exclude *.pyd
README ADDED
@@ -0,0 +1 @@
 
 
1
+ Generic single-database configuration.
README.md CHANGED
@@ -1,10 +1,36 @@
1
- ---
2
- license: mit
3
- title: visionosai
4
- sdk: streamlit
5
- emoji: 📊
6
- colorFrom: red
7
- colorTo: indigo
8
- pinned: true
9
- sdk_version: 1.45.0
10
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
2
+
3
+ ## Getting Started
4
+
5
+ First, run the development server:
6
+
7
+ ```bash
8
+ npm run dev
9
+ # or
10
+ yarn dev
11
+ # or
12
+ pnpm dev
13
+ # or
14
+ bun dev
15
+ ```
16
+
17
+ Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
18
+
19
+ You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
20
+
21
+ This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
22
+
23
+ ## Learn More
24
+
25
+ To learn more about Next.js, take a look at the following resources:
26
+
27
+ - [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
28
+ - [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
29
+
30
+ You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
31
+
32
+ ## Deploy on Vercel
33
+
34
+ The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
35
+
36
+ Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
RapidDataProviderBase.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from typing import Dict, Any, Optional, TypedDict, Literal
4
+
5
+
6
+ class EndpointSchema(TypedDict):
7
+ route: str
8
+ method: Literal['GET', 'POST']
9
+ name: str
10
+ description: str
11
+ payload: Dict[str, Any]
12
+
13
+
14
+ class RapidDataProviderBase:
15
+ def __init__(self, base_url: str, endpoints: Dict[str, EndpointSchema]):
16
+ self.base_url = base_url
17
+ self.endpoints = endpoints
18
+
19
+ def get_endpoints(self):
20
+ return self.endpoints
21
+
22
+ def call_endpoint(
23
+ self,
24
+ route: str,
25
+ payload: Optional[Dict[str, Any]] = None
26
+ ):
27
+ """
28
+ Call an API endpoint with the given parameters and data.
29
+
30
+ Args:
31
+ endpoint (EndpointSchema): The endpoint configuration dictionary
32
+ params (dict, optional): Query parameters for GET requests
33
+ payload (dict, optional): JSON payload for POST requests
34
+
35
+ Returns:
36
+ dict: The JSON response from the API
37
+ """
38
+ if route.startswith("/"):
39
+ route = route[1:]
40
+
41
+ endpoint = self.endpoints.get(route)
42
+ if not endpoint:
43
+ raise ValueError(f"Endpoint {route} not found")
44
+
45
+ url = f"{self.base_url}{endpoint['route']}"
46
+
47
+ headers = {
48
+ "x-rapidapi-key": os.getenv("RAPID_API_KEY"),
49
+ "x-rapidapi-host": url.split("//")[1].split("/")[0],
50
+ "Content-Type": "application/json"
51
+ }
52
+
53
+ method = endpoint.get('method', 'GET').upper()
54
+
55
+ if method == 'GET':
56
+ response = requests.get(url, params=payload, headers=headers)
57
+ elif method == 'POST':
58
+ response = requests.post(url, json=payload, headers=headers)
59
+ else:
60
+ raise ValueError(f"Unsupported HTTP method: {method}")
61
+ return response.json()
SettingsPanel.tsx ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // /home/ubuntu/visionos-frontend/src/components/SettingsPanel.tsx
2
+ import React from 'react';
3
+
4
+ const SettingsPanel: React.FC = () => {
5
+ return (
6
+ <div className="p-4 border rounded-lg shadow-md bg-white">
7
+ <h3 className="text-lg font-semibold mb-4">System Settings</h3>
8
+ {/* Placeholder for settings form */}
9
+ <div className="space-y-4">
10
+ <div>
11
+ <label htmlFor="setting1" className="block text-sm font-medium text-gray-700">Example Setting 1</label>
12
+ <input type="text" id="setting1" className="mt-1 block w-full p-2 border border-gray-300 rounded-md shadow-sm" placeholder="Value" />
13
+ </div>
14
+ <div>
15
+ <label htmlFor="setting2" className="block text-sm font-medium text-gray-700">Example Setting 2</label>
16
+ <select id="setting2" className="mt-1 block w-full p-2 border border-gray-300 rounded-md shadow-sm">
17
+ <option>Option A</option>
18
+ <option>Option B</option>
19
+ </select>
20
+ </div>
21
+ {/* Add more settings fields as needed */}
22
+ <button className="mt-4 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600">
23
+ Save Settings
24
+ </button>
25
+ </div>
26
+ </div>
27
+ );
28
+ };
29
+
30
+ export default SettingsPanel;
31
+
TwitterProvider.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
4
+
5
+
6
+ class TwitterProvider(RapidDataProviderBase):
7
+ def __init__(self):
8
+ endpoints: Dict[str, EndpointSchema] = {
9
+ "user_info": {
10
+ "route": "/screenname.php",
11
+ "method": "GET",
12
+ "name": "Twitter User Info",
13
+ "description": "Get information about a Twitter user by screenname or user ID.",
14
+ "payload": {
15
+ "screenname": "Twitter username without the @ symbol",
16
+ "rest_id": "Optional Twitter user's ID. If provided, overwrites screenname parameter."
17
+ }
18
+ },
19
+ "timeline": {
20
+ "route": "/timeline.php",
21
+ "method": "GET",
22
+ "name": "User Timeline",
23
+ "description": "Get tweets from a user's timeline.",
24
+ "payload": {
25
+ "screenname": "Twitter username without the @ symbol",
26
+ "rest_id": "Optional parameter that overwrites the screenname",
27
+ "cursor": "Optional pagination cursor"
28
+ }
29
+ },
30
+ "following": {
31
+ "route": "/following.php",
32
+ "method": "GET",
33
+ "name": "User Following",
34
+ "description": "Get users that a specific user follows.",
35
+ "payload": {
36
+ "screenname": "Twitter username without the @ symbol",
37
+ "rest_id": "Optional parameter that overwrites the screenname",
38
+ "cursor": "Optional pagination cursor"
39
+ }
40
+ },
41
+ "followers": {
42
+ "route": "/followers.php",
43
+ "method": "GET",
44
+ "name": "User Followers",
45
+ "description": "Get followers of a specific user.",
46
+ "payload": {
47
+ "screenname": "Twitter username without the @ symbol",
48
+ "cursor": "Optional pagination cursor"
49
+ }
50
+ },
51
+ "search": {
52
+ "route": "/search.php",
53
+ "method": "GET",
54
+ "name": "Twitter Search",
55
+ "description": "Search for tweets with a specific query.",
56
+ "payload": {
57
+ "query": "Search query string",
58
+ "cursor": "Optional pagination cursor",
59
+ "search_type": "Optional search type (e.g. 'Top')"
60
+ }
61
+ },
62
+ "replies": {
63
+ "route": "/replies.php",
64
+ "method": "GET",
65
+ "name": "User Replies",
66
+ "description": "Get replies made by a user.",
67
+ "payload": {
68
+ "screenname": "Twitter username without the @ symbol",
69
+ "cursor": "Optional pagination cursor"
70
+ }
71
+ },
72
+ "check_retweet": {
73
+ "route": "/checkretweet.php",
74
+ "method": "GET",
75
+ "name": "Check Retweet",
76
+ "description": "Check if a user has retweeted a specific tweet.",
77
+ "payload": {
78
+ "screenname": "Twitter username without the @ symbol",
79
+ "tweet_id": "ID of the tweet to check"
80
+ }
81
+ },
82
+ "tweet": {
83
+ "route": "/tweet.php",
84
+ "method": "GET",
85
+ "name": "Get Tweet",
86
+ "description": "Get details of a specific tweet by ID.",
87
+ "payload": {
88
+ "id": "ID of the tweet"
89
+ }
90
+ },
91
+ "tweet_thread": {
92
+ "route": "/tweet_thread.php",
93
+ "method": "GET",
94
+ "name": "Get Tweet Thread",
95
+ "description": "Get a thread of tweets starting from a specific tweet ID.",
96
+ "payload": {
97
+ "id": "ID of the tweet",
98
+ "cursor": "Optional pagination cursor"
99
+ }
100
+ },
101
+ "retweets": {
102
+ "route": "/retweets.php",
103
+ "method": "GET",
104
+ "name": "Get Retweets",
105
+ "description": "Get users who retweeted a specific tweet.",
106
+ "payload": {
107
+ "id": "ID of the tweet",
108
+ "cursor": "Optional pagination cursor"
109
+ }
110
+ },
111
+ "latest_replies": {
112
+ "route": "/latest_replies.php",
113
+ "method": "GET",
114
+ "name": "Get Latest Replies",
115
+ "description": "Get the latest replies to a specific tweet.",
116
+ "payload": {
117
+ "id": "ID of the tweet",
118
+ "cursor": "Optional pagination cursor"
119
+ }
120
+ }
121
+ }
122
+ base_url = "https://twitter-api45.p.rapidapi.com"
123
+ super().__init__(base_url, endpoints)
124
+
125
+
126
+ if __name__ == "__main__":
127
+ from dotenv import load_dotenv
128
+ load_dotenv()
129
+ tool = TwitterProvider()
130
+
131
+ # Example for getting user info
132
+ user_info = tool.call_endpoint(
133
+ route="user_info",
134
+ payload={
135
+ "screenname": "elonmusk",
136
+ # "rest_id": "44196397" # Optional, uncomment to use user ID instead of screenname
137
+ }
138
+ )
139
+ print("User Info:", user_info)
140
+
141
+ # Example for getting user timeline
142
+ timeline = tool.call_endpoint(
143
+ route="timeline",
144
+ payload={
145
+ "screenname": "elonmusk",
146
+ # "cursor": "optional-cursor-value" # Optional for pagination
147
+ }
148
+ )
149
+ print("Timeline:", timeline)
150
+
151
+ # Example for getting user following
152
+ following = tool.call_endpoint(
153
+ route="following",
154
+ payload={
155
+ "screenname": "elonmusk",
156
+ # "cursor": "optional-cursor-value" # Optional for pagination
157
+ }
158
+ )
159
+ print("Following:", following)
160
+
161
+ # Example for getting user followers
162
+ followers = tool.call_endpoint(
163
+ route="followers",
164
+ payload={
165
+ "screenname": "elonmusk",
166
+ # "cursor": "optional-cursor-value" # Optional for pagination
167
+ }
168
+ )
169
+ print("Followers:", followers)
170
+
171
+ # Example for searching tweets
172
+ search_results = tool.call_endpoint(
173
+ route="search",
174
+ payload={
175
+ "query": "cybertruck",
176
+ "search_type": "Top" # Optional, defaults to Top
177
+ # "cursor": "optional-cursor-value" # Optional for pagination
178
+ }
179
+ )
180
+ print("Search Results:", search_results)
181
+
182
+ # Example for getting user replies
183
+ replies = tool.call_endpoint(
184
+ route="replies",
185
+ payload={
186
+ "screenname": "elonmusk",
187
+ # "cursor": "optional-cursor-value" # Optional for pagination
188
+ }
189
+ )
190
+ print("Replies:", replies)
191
+
192
+ # Example for checking if user retweeted a tweet
193
+ check_retweet = tool.call_endpoint(
194
+ route="check_retweet",
195
+ payload={
196
+ "screenname": "elonmusk",
197
+ "tweet_id": "1671370010743263233"
198
+ }
199
+ )
200
+ print("Check Retweet:", check_retweet)
201
+
202
+ # Example for getting tweet details
203
+ tweet = tool.call_endpoint(
204
+ route="tweet",
205
+ payload={
206
+ "id": "1671370010743263233"
207
+ }
208
+ )
209
+ print("Tweet:", tweet)
210
+
211
+ # Example for getting a tweet thread
212
+ tweet_thread = tool.call_endpoint(
213
+ route="tweet_thread",
214
+ payload={
215
+ "id": "1738106896777699464",
216
+ # "cursor": "optional-cursor-value" # Optional for pagination
217
+ }
218
+ )
219
+ print("Tweet Thread:", tweet_thread)
220
+
221
+ # Example for getting retweets of a tweet
222
+ retweets = tool.call_endpoint(
223
+ route="retweets",
224
+ payload={
225
+ "id": "1700199139470942473",
226
+ # "cursor": "optional-cursor-value" # Optional for pagination
227
+ }
228
+ )
229
+ print("Retweets:", retweets)
230
+
231
+ # Example for getting latest replies to a tweet
232
+ latest_replies = tool.call_endpoint(
233
+ route="latest_replies",
234
+ payload={
235
+ "id": "1738106896777699464",
236
+ # "cursor": "optional-cursor-value" # Optional for pagination
237
+ }
238
+ )
239
+ print("Latest Replies:", latest_replies)
240
+
WorkflowEditor.tsx ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // /home/ubuntu/visionos-frontend/src/components/WorkflowEditor.tsx
2
+ import React, { useCallback } from 'react';
3
+ import ReactFlow, {
4
+ MiniMap,
5
+ Controls,
6
+ Background,
7
+ useNodesState,
8
+ useEdgesState,
9
+ addEdge,
10
+ Connection,
11
+ Edge,
12
+ Node,
13
+ } from 'reactflow';
14
+
15
+ import 'reactflow/dist/style.css';
16
+
17
+ // Initial nodes and edges (example)
18
+ const initialNodes: Node[] = [
19
+ { id: '1', position: { x: 0, y: 0 }, data: { label: 'Start Node' }, type: 'input' },
20
+ { id: '2', position: { x: 0, y: 100 }, data: { label: 'Agent Task' } },
21
+ ];
22
+ const initialEdges: Edge[] = [{ id: 'e1-2', source: '1', target: '2' }];
23
+
24
+ const WorkflowEditor: React.FC = () => {
25
+ const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
26
+ const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);
27
+
28
+ const onConnect = useCallback(
29
+ (params: Edge | Connection) => setEdges((eds) => addEdge(params, eds)),
30
+ [setEdges],
31
+ );
32
+
33
+ return (
34
+ <div style={{ width: '100%', height: '100%' }} className="border rounded-lg shadow-md">
35
+ <ReactFlow
36
+ nodes={nodes}
37
+ edges={edges}
38
+ onNodesChange={onNodesChange}
39
+ onEdgesChange={onEdgesChange}
40
+ onConnect={onConnect}
41
+ fitView
42
+ >
43
+ <Controls />
44
+ <MiniMap />
45
+ <Background variant="dots" gap={12} size={1} />
46
+ </ReactFlow>
47
+ </div>
48
+ );
49
+ };
50
+
51
+ export default WorkflowEditor;
52
+
YahooFinanceProvider.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
4
+
5
+
6
+ class YahooFinanceProvider(RapidDataProviderBase):
7
+ def __init__(self):
8
+ endpoints: Dict[str, EndpointSchema] = {
9
+ "get_tickers": {
10
+ "route": "/v2/markets/tickers",
11
+ "method": "GET",
12
+ "name": "Yahoo Finance Tickers",
13
+ "description": "Get financial tickers from Yahoo Finance with various filters and parameters.",
14
+ "payload": {
15
+ "page": "Page number for pagination (optional, default: 1)",
16
+ "type": "Asset class type (required): STOCKS, ETF, MUTUALFUNDS, or FUTURES",
17
+ }
18
+ },
19
+ "search": {
20
+ "route": "/v1/markets/search",
21
+ "method": "GET",
22
+ "name": "Yahoo Finance Search",
23
+ "description": "Search for financial instruments on Yahoo Finance",
24
+ "payload": {
25
+ "search": "Search term (required)",
26
+ }
27
+ },
28
+ "get_news": {
29
+ "route": "/v2/markets/news",
30
+ "method": "GET",
31
+ "name": "Yahoo Finance News",
32
+ "description": "Get news related to specific tickers from Yahoo Finance",
33
+ "payload": {
34
+ "tickers": "Stock symbol (optional, e.g., AAPL)",
35
+ "type": "News type (optional): ALL, VIDEO, or PRESS_RELEASE",
36
+ }
37
+ },
38
+ "get_stock_module": {
39
+ "route": "/v1/markets/stock/modules",
40
+ "method": "GET",
41
+ "name": "Yahoo Finance Stock Module",
42
+ "description": "Get detailed information about a specific stock module",
43
+ "payload": {
44
+ "ticker": "Company ticker symbol (required, e.g., AAPL)",
45
+ "module": "Module to retrieve (required): asset-profile, financial-data, earnings, etc.",
46
+ }
47
+ },
48
+ "get_sma": {
49
+ "route": "/v1/markets/indicators/sma",
50
+ "method": "GET",
51
+ "name": "Yahoo Finance SMA Indicator",
52
+ "description": "Get Simple Moving Average (SMA) indicator data for a stock",
53
+ "payload": {
54
+ "symbol": "Stock symbol (required, e.g., AAPL)",
55
+ "interval": "Time interval (required): 5m, 15m, 30m, 1h, 1d, 1wk, 1mo, 3mo",
56
+ "series_type": "Series type (required): open, close, high, low",
57
+ "time_period": "Number of data points used for calculation (required)",
58
+ "limit": "Limit the number of results (optional, default: 50)",
59
+ }
60
+ },
61
+ "get_rsi": {
62
+ "route": "/v1/markets/indicators/rsi",
63
+ "method": "GET",
64
+ "name": "Yahoo Finance RSI Indicator",
65
+ "description": "Get Relative Strength Index (RSI) indicator data for a stock",
66
+ "payload": {
67
+ "symbol": "Stock symbol (required, e.g., AAPL)",
68
+ "interval": "Time interval (required): 5m, 15m, 30m, 1h, 1d, 1wk, 1mo, 3mo",
69
+ "series_type": "Series type (required): open, close, high, low",
70
+ "time_period": "Number of data points used for calculation (required)",
71
+ "limit": "Limit the number of results (optional, default: 50)",
72
+ }
73
+ },
74
+ "get_earnings_calendar": {
75
+ "route": "/v1/markets/calendar/earnings",
76
+ "method": "GET",
77
+ "name": "Yahoo Finance Earnings Calendar",
78
+ "description": "Get earnings calendar data for a specific date",
79
+ "payload": {
80
+ "date": "Calendar date in yyyy-mm-dd format (optional, e.g., 2023-11-30)",
81
+ }
82
+ },
83
+ "get_insider_trades": {
84
+ "route": "/v1/markets/insider-trades",
85
+ "method": "GET",
86
+ "name": "Yahoo Finance Insider Trades",
87
+ "description": "Get recent insider trading activity",
88
+ "payload": {}
89
+ },
90
+ }
91
+ base_url = "https://yahoo-finance15.p.rapidapi.com/api"
92
+ super().__init__(base_url, endpoints)
93
+
94
+
95
+ if __name__ == "__main__":
96
+ from dotenv import load_dotenv
97
+ load_dotenv()
98
+ tool = YahooFinanceProvider()
99
+
100
+ # Example for getting stock tickers
101
+ tickers_result = tool.call_endpoint(
102
+ route="get_tickers",
103
+ payload={
104
+ "page": 1,
105
+ "type": "STOCKS"
106
+ }
107
+ )
108
+ print("Tickers Result:", tickers_result)
109
+
110
+ # Example for searching financial instruments
111
+ search_result = tool.call_endpoint(
112
+ route="search",
113
+ payload={
114
+ "search": "AA"
115
+ }
116
+ )
117
+ print("Search Result:", search_result)
118
+
119
+ # Example for getting financial news
120
+ news_result = tool.call_endpoint(
121
+ route="get_news",
122
+ payload={
123
+ "tickers": "AAPL",
124
+ "type": "ALL"
125
+ }
126
+ )
127
+ print("News Result:", news_result)
128
+
129
+ # Example for getting stock asset profile module
130
+ stock_module_result = tool.call_endpoint(
131
+ route="get_stock_module",
132
+ payload={
133
+ "ticker": "AAPL",
134
+ "module": "asset-profile"
135
+ }
136
+ )
137
+ print("Asset Profile Result:", stock_module_result)
138
+
139
+ # Example for getting financial data module
140
+ financial_data_result = tool.call_endpoint(
141
+ route="get_stock_module",
142
+ payload={
143
+ "ticker": "AAPL",
144
+ "module": "financial-data"
145
+ }
146
+ )
147
+ print("Financial Data Result:", financial_data_result)
148
+
149
+ # Example for getting SMA indicator data
150
+ sma_result = tool.call_endpoint(
151
+ route="get_sma",
152
+ payload={
153
+ "symbol": "AAPL",
154
+ "interval": "5m",
155
+ "series_type": "close",
156
+ "time_period": "50",
157
+ "limit": "50"
158
+ }
159
+ )
160
+ print("SMA Result:", sma_result)
161
+
162
+ # Example for getting RSI indicator data
163
+ rsi_result = tool.call_endpoint(
164
+ route="get_rsi",
165
+ payload={
166
+ "symbol": "AAPL",
167
+ "interval": "5m",
168
+ "series_type": "close",
169
+ "time_period": "50",
170
+ "limit": "50"
171
+ }
172
+ )
173
+ print("RSI Result:", rsi_result)
174
+
175
+ # Example for getting earnings calendar data
176
+ earnings_calendar_result = tool.call_endpoint(
177
+ route="get_earnings_calendar",
178
+ payload={
179
+ "date": "2023-11-30"
180
+ }
181
+ )
182
+ print("Earnings Calendar Result:", earnings_calendar_result)
183
+
184
+ # Example for getting insider trades
185
+ insider_trades_result = tool.call_endpoint(
186
+ route="get_insider_trades",
187
+ payload={}
188
+ )
189
+ print("Insider Trades Result:", insider_trades_result)
190
+
ZillowProvider.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import logging
3
+
4
+ from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ class ZillowProvider(RapidDataProviderBase):
10
+ def __init__(self):
11
+ endpoints: Dict[str, EndpointSchema] = {
12
+ "search": {
13
+ "route": "/search",
14
+ "method": "GET",
15
+ "name": "Zillow Property Search",
16
+ "description": "Search for properties by neighborhood, city, or ZIP code with various filters.",
17
+ "payload": {
18
+ "location": "Location can be an address, neighborhood, city, or ZIP code (required)",
19
+ "page": "Page number for pagination (optional, default: 0)",
20
+ "output": "Output format: json, csv, xlsx (optional, default: json)",
21
+ "status": "Status of properties: forSale, forRent, recentlySold (optional, default: forSale)",
22
+ "sortSelection": "Sorting criteria (optional, default: priorityscore)",
23
+ "listing_type": "Listing type: by_agent, by_owner_other (optional, default: by_agent)",
24
+ "doz": "Days on Zillow: any, 1, 7, 14, 30, 90, 6m, 12m, 24m, 36m (optional, default: any)",
25
+ "price_min": "Minimum price (optional)",
26
+ "price_max": "Maximum price (optional)",
27
+ "sqft_min": "Minimum square footage (optional)",
28
+ "sqft_max": "Maximum square footage (optional)",
29
+ "beds_min": "Minimum number of bedrooms (optional)",
30
+ "beds_max": "Maximum number of bedrooms (optional)",
31
+ "baths_min": "Minimum number of bathrooms (optional)",
32
+ "baths_max": "Maximum number of bathrooms (optional)",
33
+ "built_min": "Minimum year built (optional)",
34
+ "built_max": "Maximum year built (optional)",
35
+ "lotSize_min": "Minimum lot size in sqft (optional)",
36
+ "lotSize_max": "Maximum lot size in sqft (optional)",
37
+ "keywords": "Keywords to search for (optional)"
38
+ }
39
+ },
40
+ "search_address": {
41
+ "route": "/search_address",
42
+ "method": "GET",
43
+ "name": "Zillow Address Search",
44
+ "description": "Search for a specific property by its full address.",
45
+ "payload": {
46
+ "address": "Full property address (required)"
47
+ }
48
+ },
49
+ "propertyV2": {
50
+ "route": "/propertyV2",
51
+ "method": "GET",
52
+ "name": "Zillow Property Details",
53
+ "description": "Get detailed information about a specific property by zpid or URL.",
54
+ "payload": {
55
+ "zpid": "Zillow property ID (optional if URL is provided)",
56
+ "url": "Property details URL (optional if zpid is provided)"
57
+ }
58
+ },
59
+ "zestimate_history": {
60
+ "route": "/zestimate_history",
61
+ "method": "GET",
62
+ "name": "Zillow Zestimate History",
63
+ "description": "Get historical Zestimate values for a specific property.",
64
+ "payload": {
65
+ "zpid": "Zillow property ID (optional if URL is provided)",
66
+ "url": "Property details URL (optional if zpid is provided)"
67
+ }
68
+ },
69
+ "similar_properties": {
70
+ "route": "/similar_properties",
71
+ "method": "GET",
72
+ "name": "Zillow Similar Properties",
73
+ "description": "Find properties similar to a specific property.",
74
+ "payload": {
75
+ "zpid": "Zillow property ID (optional if URL or address is provided)",
76
+ "url": "Property details URL (optional if zpid or address is provided)",
77
+ "address": "Property address (optional if zpid or URL is provided)"
78
+ }
79
+ },
80
+ "mortgage_rates": {
81
+ "route": "/mortgage/rates",
82
+ "method": "GET",
83
+ "name": "Zillow Mortgage Rates",
84
+ "description": "Get current mortgage rates for different loan programs and conditions.",
85
+ "payload": {
86
+ "program": "Loan program (required): Fixed30Year, Fixed20Year, Fixed15Year, Fixed10Year, ARM3, ARM5, ARM7, etc.",
87
+ "state": "State abbreviation (optional, default: US)",
88
+ "refinance": "Whether this is for refinancing (optional, default: false)",
89
+ "loanType": "Type of loan: Conventional, etc. (optional)",
90
+ "loanAmount": "Loan amount category: Micro, SmallConforming, Conforming, SuperConforming, Jumbo (optional)",
91
+ "loanToValue": "Loan to value ratio: Normal, High, VeryHigh (optional)",
92
+ "creditScore": "Credit score category: Low, High, VeryHigh (optional)",
93
+ "duration": "Duration in days (optional, default: 30)"
94
+ }
95
+ },
96
+ }
97
+ base_url = "https://zillow56.p.rapidapi.com"
98
+ super().__init__(base_url, endpoints)
99
+
100
+
101
+ if __name__ == "__main__":
102
+ from dotenv import load_dotenv
103
+ from time import sleep
104
+ load_dotenv()
105
+ tool = ZillowProvider()
106
+
107
+ # Example for searching properties in Houston
108
+ search_result = tool.call_endpoint(
109
+ route="search",
110
+ payload={
111
+ "location": "houston, tx",
112
+ "status": "forSale",
113
+ "sortSelection": "priorityscore",
114
+ "listing_type": "by_agent",
115
+ "doz": "any"
116
+ }
117
+ )
118
+ logger.debug("Search Result: %s", search_result)
119
+ logger.debug("***")
120
+ logger.debug("***")
121
+ logger.debug("***")
122
+ sleep(1)
123
+ # Example for searching by address
124
+ address_result = tool.call_endpoint(
125
+ route="search_address",
126
+ payload={
127
+ "address": "1161 Natchez Dr College Station Texas 77845"
128
+ }
129
+ )
130
+ logger.debug("Address Search Result: %s", address_result)
131
+ logger.debug("***")
132
+ logger.debug("***")
133
+ logger.debug("***")
134
+ sleep(1)
135
+ # Example for getting property details
136
+ property_result = tool.call_endpoint(
137
+ route="propertyV2",
138
+ payload={
139
+ "zpid": "7594920"
140
+ }
141
+ )
142
+ logger.debug("Property Details Result: %s", property_result)
143
+ sleep(1)
144
+ logger.debug("***")
145
+ logger.debug("***")
146
+ logger.debug("***")
147
+
148
+ # Example for getting zestimate history
149
+ zestimate_result = tool.call_endpoint(
150
+ route="zestimate_history",
151
+ payload={
152
+ "zpid": "20476226"
153
+ }
154
+ )
155
+ logger.debug("Zestimate History Result: %s", zestimate_result)
156
+ sleep(1)
157
+ logger.debug("***")
158
+ logger.debug("***")
159
+ logger.debug("***")
160
+ # Example for getting similar properties
161
+ similar_result = tool.call_endpoint(
162
+ route="similar_properties",
163
+ payload={
164
+ "zpid": "28253016"
165
+ }
166
+ )
167
+ logger.debug("Similar Properties Result: %s", similar_result)
168
+ sleep(1)
169
+ logger.debug("***")
170
+ logger.debug("***")
171
+ logger.debug("***")
172
+ # Example for getting mortgage rates
173
+ mortgage_result = tool.call_endpoint(
174
+ route="mortgage_rates",
175
+ payload={
176
+ "program": "Fixed30Year",
177
+ "state": "US",
178
+ "refinance": "false",
179
+ "loanType": "Conventional",
180
+ "loanAmount": "Conforming",
181
+ "loanToValue": "Normal",
182
+ "creditScore": "Low",
183
+ "duration": "30"
184
+ }
185
+ )
186
+ logger.debug("Mortgage Rates Result: %s", mortgage_result)
187
+
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Utility functions and constants for agent tools
added_tokens.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</tool_call>": 151658,
3
+ "<tool_call>": 151657,
4
+ "<|AUDIO|>": 151646,
5
+ "<|IMAGE|>": 151655,
6
+ "<|VIDEO|>": 151656,
7
+ "<|audio_bos|>": 151647,
8
+ "<|audio_eos|>": 151648,
9
+ "<|box_end|>": 151649,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|quad_end|>": 151651,
19
+ "<|quad_start|>": 151650,
20
+ "<|repo_name|>": 151663,
21
+ "<|vision_bos|>": 151652,
22
+ "<|vision_eos|>": 151653,
23
+ "<|vision_pad|>": 151654
24
+ }
agent.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from dataclasses import dataclass
4
+
5
+ @dataclass
6
+ class Task:
7
+ id: str
8
+ status: str = 'pending'
9
+ input_data: dict = None
10
+ result_data: dict = None
11
+
12
+ class VisionOSAgent:
13
+ def __init__(self, model_name='distilbert-base-uncased'):
14
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
16
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ self.model.to(self.device)
18
+ self.current_task: Task = None
19
+
20
+ def set_task(self, task: Task):
21
+ self.current_task = task
22
+
23
+ def execute_task(self):
24
+ if self.current_task is None or self.current_task.status != 'pending':
25
+ raise ValueError("No pending task to execute")
26
+ input_text = self.current_task.input_data.get('text', 'Default input')
27
+ inputs = self.tokenizer(input_text, return_tensors='pt').to(self.device)
28
+ with torch.no_grad():
29
+ outputs = self.model(**inputs)
30
+ prediction = torch.argmax(outputs.logits, dim=1).item()
31
+ self.current_task.result_data = {'prediction': prediction}
32
+ self.current_task.status = 'completed'
33
+ return self.current_task
34
+
35
+ # Example usage
36
+ if __name__ == "__main__":
37
+ agent = VisionOSAgent()
38
+ task = Task(id='task1', input_data={'text': 'Analyze vehicle status: operational'})
39
+ agent.set_task(task)
40
+ result = agent.execute_task()
41
+ print(f"Task ID: {result.id}, Status: {result.status}, Result: {result.result_data}")
alembic.ini ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # alembic.ini
2
+ # A generic configuration for the Alembic migration tool.
3
+
4
+ [alembic]
5
+ # path to migration scripts
6
+ script_location = alembic
7
+
8
+ # template used to generate migration files
9
+ # file_template = %%(rev)s_%%(slug)s
10
+
11
+ # timezone to use when rendering the date within the migration file
12
+ # as well as the filename.
13
+ # If specified, requires the python-dateutil library that is installable
14
+ # with pip install python-dateutil.
15
+ # Any required timezone name works, such as UTC, PST8PDT, Europe/London
16
+ # If None, the system default timezone is used.
17
+ # timezone =
18
+
19
+ # sys.path path, will be prepended to sys.path if present.
20
+ # defaults to the current working directory.
21
+ # prepend_sys_path = .
22
+
23
+ # Logging configuration
24
+ [loggers]
25
+ keys = root,sqlalchemy,alembic
26
+
27
+ [handlers]
28
+ keys = console
29
+
30
+ [formatters]
31
+ keys = generic
32
+
33
+ [logger_root]
34
+ level = WARN
35
+ handlers = console
36
+ qualname =
37
+
38
+ [logger_sqlalchemy]
39
+ level = WARN
40
+ handlers =
41
+ qualname = sqlalchemy.engine
42
+
43
+ [logger_alembic]
44
+ level = INFO
45
+ handlers =
46
+ qualname = alembic
47
+
48
+ [handler_console]
49
+ class = StreamHandler
50
+ args = (sys.stderr,)
51
+ level = NOTSET
52
+ formatter = generic
53
+
54
+ [formatter_generic]
55
+ format = %%(levelname)-5.5s [%%(name)s] %%(message)s
56
+ datefmt = %%H:%%M:%%S
57
+
58
+ # Database configuration
59
+ # Replace sqlalchemy.url with the actual database connection string from your settings
60
+ # This will be dynamically loaded in env.py
61
+ s sqlalchemy.url = driver://user:pass@localhost/dbname
62
+
api.cpython-311.pyc ADDED
Binary file (918 Bytes). View file
 
api.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional
3
+
4
+ from fastapi import FastAPI, UploadFile, File, HTTPException, APIRouter, Form, Depends, Request
5
+ from fastapi.responses import Response, JSONResponse
6
+ from pydantic import BaseModel
7
+
8
+ from utils.logger import logger
9
+ from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, get_optional_user_id
10
+ from sandbox.sandbox import get_or_start_sandbox
11
+ from services.supabase import DBConnection
12
+ from agent.api import get_or_create_project_sandbox
13
+
14
+
15
+ # Initialize shared resources
16
+ router = APIRouter(tags=["sandbox"])
17
+ db = None
18
+
19
+ def initialize(_db: DBConnection):
20
+ """Initialize the sandbox API with resources from the main API."""
21
+ global db
22
+ db = _db
23
+ logger.info("Initialized sandbox API with database connection")
24
+
25
+ class FileInfo(BaseModel):
26
+ """Model for file information"""
27
+ name: str
28
+ path: str
29
+ is_dir: bool
30
+ size: int
31
+ mod_time: str
32
+ permissions: Optional[str] = None
33
+
34
+ async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str] = None):
35
+ """
36
+ Verify that a user has access to a specific sandbox based on account membership.
37
+
38
+ Args:
39
+ client: The Supabase client
40
+ sandbox_id: The sandbox ID to check access for
41
+ user_id: The user ID to check permissions for. Can be None for public resource access.
42
+
43
+ Returns:
44
+ dict: Project data containing sandbox information
45
+
46
+ Raises:
47
+ HTTPException: If the user doesn't have access to the sandbox or sandbox doesn't exist
48
+ """
49
+ # Find the project that owns this sandbox
50
+ project_result = await client.table('projects').select('*').filter('sandbox->>id', 'eq', sandbox_id).execute()
51
+
52
+ if not project_result.data or len(project_result.data) == 0:
53
+ raise HTTPException(status_code=404, detail="Sandbox not found")
54
+
55
+ project_data = project_result.data[0]
56
+
57
+ if project_data.get('is_public'):
58
+ return project_data
59
+
60
+ # For private projects, we must have a user_id
61
+ if not user_id:
62
+ raise HTTPException(status_code=401, detail="Authentication required for this resource")
63
+
64
+ account_id = project_data.get('account_id')
65
+
66
+ # Verify account membership
67
+ if account_id:
68
+ account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
69
+ if account_user_result.data and len(account_user_result.data) > 0:
70
+ return project_data
71
+
72
+ raise HTTPException(status_code=403, detail="Not authorized to access this sandbox")
73
+
74
+ async def get_sandbox_by_id_safely(client, sandbox_id: str):
75
+ """
76
+ Safely retrieve a sandbox object by its ID, using the project that owns it.
77
+
78
+ Args:
79
+ client: The Supabase client
80
+ sandbox_id: The sandbox ID to retrieve
81
+
82
+ Returns:
83
+ Sandbox: The sandbox object
84
+
85
+ Raises:
86
+ HTTPException: If the sandbox doesn't exist or can't be retrieved
87
+ """
88
+ # Find the project that owns this sandbox
89
+ project_result = await client.table('projects').select('project_id').filter('sandbox->>id', 'eq', sandbox_id).execute()
90
+
91
+ if not project_result.data or len(project_result.data) == 0:
92
+ logger.error(f"No project found for sandbox ID: {sandbox_id}")
93
+ raise HTTPException(status_code=404, detail="Sandbox not found - no project owns this sandbox ID")
94
+
95
+ project_id = project_result.data[0]['project_id']
96
+ logger.debug(f"Found project {project_id} for sandbox {sandbox_id}")
97
+
98
+ try:
99
+ # Get the sandbox
100
+ sandbox, retrieved_sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
101
+
102
+ # Verify we got the right sandbox
103
+ if retrieved_sandbox_id != sandbox_id:
104
+ logger.warning(f"Retrieved sandbox ID {retrieved_sandbox_id} doesn't match requested ID {sandbox_id} for project {project_id}")
105
+ # Fall back to the direct method if IDs don't match (shouldn't happen but just in case)
106
+ sandbox = await get_or_start_sandbox(sandbox_id)
107
+
108
+ return sandbox
109
+ except Exception as e:
110
+ logger.error(f"Error retrieving sandbox {sandbox_id}: {str(e)}")
111
+ raise HTTPException(status_code=500, detail=f"Failed to retrieve sandbox: {str(e)}")
112
+
113
+ @router.post("/sandboxes/{sandbox_id}/files")
114
+ async def create_file(
115
+ sandbox_id: str,
116
+ path: str = Form(...),
117
+ file: UploadFile = File(...),
118
+ request: Request = None,
119
+ user_id: Optional[str] = Depends(get_optional_user_id)
120
+ ):
121
+ """Create a file in the sandbox using direct file upload"""
122
+ logger.info(f"Received file upload request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
123
+ client = await db.client
124
+
125
+ # Verify the user has access to this sandbox
126
+ await verify_sandbox_access(client, sandbox_id, user_id)
127
+
128
+ try:
129
+ # Get sandbox using the safer method
130
+ sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
131
+
132
+ # Read file content directly from the uploaded file
133
+ content = await file.read()
134
+
135
+ # Create file using raw binary content
136
+ sandbox.fs.upload_file(path, content)
137
+ logger.info(f"File created at {path} in sandbox {sandbox_id}")
138
+
139
+ return {"status": "success", "created": True, "path": path}
140
+ except Exception as e:
141
+ logger.error(f"Error creating file in sandbox {sandbox_id}: {str(e)}")
142
+ raise HTTPException(status_code=500, detail=str(e))
143
+
144
+ # For backward compatibility, keep the JSON version too
145
+ @router.post("/sandboxes/{sandbox_id}/files/json")
146
+ async def create_file_json(
147
+ sandbox_id: str,
148
+ file_request: dict,
149
+ request: Request = None,
150
+ user_id: Optional[str] = Depends(get_optional_user_id)
151
+ ):
152
+ """Create a file in the sandbox using JSON (legacy support)"""
153
+ logger.info(f"Received JSON file creation request for sandbox {sandbox_id}, user_id: {user_id}")
154
+ client = await db.client
155
+
156
+ # Verify the user has access to this sandbox
157
+ await verify_sandbox_access(client, sandbox_id, user_id)
158
+
159
+ try:
160
+ # Get sandbox using the safer method
161
+ sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
162
+
163
+ # Get file path and content
164
+ path = file_request.get("path")
165
+ content = file_request.get("content", "")
166
+
167
+ if not path:
168
+ logger.error(f"Missing file path in request for sandbox {sandbox_id}")
169
+ raise HTTPException(status_code=400, detail="File path is required")
170
+
171
+ # Convert string content to bytes
172
+ if isinstance(content, str):
173
+ content = content.encode('utf-8')
174
+
175
+ # Create file
176
+ sandbox.fs.upload_file(path, content)
177
+ logger.info(f"File created at {path} in sandbox {sandbox_id}")
178
+
179
+ return {"status": "success", "created": True, "path": path}
180
+ except Exception as e:
181
+ logger.error(f"Error creating file in sandbox {sandbox_id}: {str(e)}")
182
+ raise HTTPException(status_code=500, detail=str(e))
183
+
184
+ @router.get("/sandboxes/{sandbox_id}/files")
185
+ async def list_files(
186
+ sandbox_id: str,
187
+ path: str,
188
+ request: Request = None,
189
+ user_id: Optional[str] = Depends(get_optional_user_id)
190
+ ):
191
+ """List files and directories at the specified path"""
192
+ logger.info(f"Received list files request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
193
+ client = await db.client
194
+
195
+ # Verify the user has access to this sandbox
196
+ await verify_sandbox_access(client, sandbox_id, user_id)
197
+
198
+ try:
199
+ # Get sandbox using the safer method
200
+ sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
201
+
202
+ # List files
203
+ files = sandbox.fs.list_files(path)
204
+ result = []
205
+
206
+ for file in files:
207
+ # Convert file information to our model
208
+ # Ensure forward slashes are used for paths, regardless of OS
209
+ full_path = f"{path.rstrip('/')}/{file.name}" if path != '/' else f"/{file.name}"
210
+ file_info = FileInfo(
211
+ name=file.name,
212
+ path=full_path, # Use the constructed path
213
+ is_dir=file.is_dir,
214
+ size=file.size,
215
+ mod_time=str(file.mod_time),
216
+ permissions=getattr(file, 'permissions', None)
217
+ )
218
+ result.append(file_info)
219
+
220
+ logger.info(f"Successfully listed {len(result)} files in sandbox {sandbox_id}")
221
+ return {"files": [file.dict() for file in result]}
222
+ except Exception as e:
223
+ logger.error(f"Error listing files in sandbox {sandbox_id}: {str(e)}")
224
+ raise HTTPException(status_code=500, detail=str(e))
225
+
226
+ @router.get("/sandboxes/{sandbox_id}/files/content")
227
+ async def read_file(
228
+ sandbox_id: str,
229
+ path: str,
230
+ request: Request = None,
231
+ user_id: Optional[str] = Depends(get_optional_user_id)
232
+ ):
233
+ """Read a file from the sandbox"""
234
+ logger.info(f"Received file read request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
235
+ client = await db.client
236
+
237
+ # Verify the user has access to this sandbox
238
+ await verify_sandbox_access(client, sandbox_id, user_id)
239
+
240
+ try:
241
+ # Get sandbox using the safer method
242
+ sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
243
+
244
+ # Read file
245
+ content = sandbox.fs.download_file(path)
246
+
247
+ # Return a Response object with the content directly
248
+ filename = os.path.basename(path)
249
+ logger.info(f"Successfully read file {filename} from sandbox {sandbox_id}")
250
+ return Response(
251
+ content=content,
252
+ media_type="application/octet-stream",
253
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
254
+ )
255
+ except Exception as e:
256
+ logger.error(f"Error reading file in sandbox {sandbox_id}: {str(e)}")
257
+ raise HTTPException(status_code=500, detail=str(e))
258
+
259
+ @router.post("/project/{project_id}/sandbox/ensure-active")
260
+ async def ensure_project_sandbox_active(
261
+ project_id: str,
262
+ request: Request = None,
263
+ user_id: Optional[str] = Depends(get_optional_user_id)
264
+ ):
265
+ """
266
+ Ensure that a project's sandbox is active and running.
267
+ Checks the sandbox status and starts it if it's not running.
268
+ """
269
+ logger.info(f"Received ensure sandbox active request for project {project_id}, user_id: {user_id}")
270
+ client = await db.client
271
+
272
+ # Find the project and sandbox information
273
+ project_result = await client.table('projects').select('*').eq('project_id', project_id).execute()
274
+
275
+ if not project_result.data or len(project_result.data) == 0:
276
+ logger.error(f"Project not found: {project_id}")
277
+ raise HTTPException(status_code=404, detail="Project not found")
278
+
279
+ project_data = project_result.data[0]
280
+
281
+ # For public projects, no authentication is needed
282
+ if not project_data.get('is_public'):
283
+ # For private projects, we must have a user_id
284
+ if not user_id:
285
+ logger.error(f"Authentication required for private project {project_id}")
286
+ raise HTTPException(status_code=401, detail="Authentication required for this resource")
287
+
288
+ account_id = project_data.get('account_id')
289
+
290
+ # Verify account membership
291
+ if account_id:
292
+ account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
293
+ if not (account_user_result.data and len(account_user_result.data) > 0):
294
+ logger.error(f"User {user_id} not authorized to access project {project_id}")
295
+ raise HTTPException(status_code=403, detail="Not authorized to access this project")
296
+
297
+ try:
298
+ # Get or create the sandbox
299
+ logger.info(f"Ensuring sandbox is active for project {project_id}")
300
+ sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
301
+
302
+ logger.info(f"Successfully ensured sandbox {sandbox_id} is active for project {project_id}")
303
+
304
+ return {
305
+ "status": "success",
306
+ "sandbox_id": sandbox_id,
307
+ "message": "Sandbox is active"
308
+ }
309
+ except Exception as e:
310
+ logger.error(f"Error ensuring sandbox is active for project {project_id}: {str(e)}")
311
+ raise HTTPException(status_code=500, detail=str(e))
api.py.bak ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ from contextlib import asynccontextmanager
5
+ from agentpress.thread_manager import ThreadManager
6
+ from services.supabase import DBConnection
7
+ from datetime import datetime, timezone
8
+ from dotenv import load_dotenv
9
+ from utils.config import config, EnvMode
10
+ import asyncio
11
+ from utils.logger import logger
12
+ import uuid
13
+ import time
14
+ from collections import OrderedDict
15
+
16
+ # Import the agent API module
17
+ from agent import api as agent_api
18
+ from sandbox import api as sandbox_api
19
+
20
+ # Load environment variables (these will be available through config)
21
+ load_dotenv()
22
+
23
+ # Initialize managers
24
+ db = DBConnection()
25
+ thread_manager = None
26
+ instance_id = "single"
27
+
28
+ # Rate limiter state
29
+ ip_tracker = OrderedDict()
30
+ MAX_CONCURRENT_IPS = 25
31
+
32
+ @asynccontextmanager
33
+ async def lifespan(app: FastAPI):
34
+ # Startup
35
+ global thread_manager
36
+ logger.info(f"Starting up FastAPI application with instance ID: {instance_id} in {config.ENV_MODE.value} mode")
37
+
38
+ try:
39
+ # Initialize database
40
+ await db.initialize()
41
+ thread_manager = ThreadManager()
42
+
43
+ # Initialize the agent API with shared resources
44
+ agent_api.initialize(
45
+ thread_manager,
46
+ db,
47
+ instance_id
48
+ )
49
+
50
+ # Initialize the sandbox API with shared resources
51
+ sandbox_api.initialize(db)
52
+
53
+ # Initialize Redis connection
54
+ from services import redis
55
+ try:
56
+ await redis.initialize_async()
57
+ logger.info("Redis connection initialized successfully")
58
+ except Exception as e:
59
+ logger.error(f"Failed to initialize Redis connection: {e}")
60
+ # Continue without Redis - the application will handle Redis failures gracefully
61
+
62
+ # Start background tasks
63
+ asyncio.create_task(agent_api.restore_running_agent_runs())
64
+
65
+ yield
66
+
67
+ # Clean up agent resources
68
+ logger.info("Cleaning up agent resources")
69
+ await agent_api.cleanup()
70
+
71
+ # Clean up Redis connection
72
+ try:
73
+ logger.info("Closing Redis connection")
74
+ await redis.close()
75
+ logger.info("Redis connection closed successfully")
76
+ except Exception as e:
77
+ logger.error(f"Error closing Redis connection: {e}")
78
+
79
+ # Clean up database connection
80
+ logger.info("Disconnecting from database")
81
+ await db.disconnect()
82
+ except Exception as e:
83
+ logger.error(f"Error during application startup: {e}")
84
+ raise
85
+
86
+ app = FastAPI(lifespan=lifespan)
87
+
88
+ @app.middleware("http")
89
+ async def log_requests_middleware(request: Request, call_next):
90
+ start_time = time.time()
91
+ client_ip = request.client.host
92
+ method = request.method
93
+ url = str(request.url)
94
+ path = request.url.path
95
+ query_params = str(request.query_params)
96
+
97
+ # Log the incoming request
98
+ logger.info(f"Request started: {method} {path} from {client_ip} | Query: {query_params}")
99
+
100
+ try:
101
+ response = await call_next(request)
102
+ process_time = time.time() - start_time
103
+ logger.debug(f"Request completed: {method} {path} | Status: {response.status_code} | Time: {process_time:.2f}s")
104
+ return response
105
+ except Exception as e:
106
+ process_time = time.time() - start_time
107
+ logger.error(f"Request failed: {method} {path} | Error: {str(e)} | Time: {process_time:.2f}s")
108
+ raise
109
+
110
+ # Define allowed origins based on environment
111
+ allowed_origins = ["https://www.suna.so", "https://suna.so", "https://staging.suna.so", "http://localhost:3000"]
112
+
113
+ # Add staging-specific origins
114
+ if config.ENV_MODE == EnvMode.STAGING:
115
+ allowed_origins.append("http://localhost:3000")
116
+
117
+ # Add local-specific origins
118
+ if config.ENV_MODE == EnvMode.LOCAL:
119
+ allowed_origins.append("http://localhost:3000")
120
+
121
+ app.add_middleware(
122
+ CORSMiddleware,
123
+ allow_origins=allowed_origins,
124
+ allow_credentials=True,
125
+ allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
126
+ allow_headers=["Content-Type", "Authorization"],
127
+ )
128
+
129
+ # Include the agent router with a prefix
130
+ app.include_router(agent_api.router, prefix="/api")
131
+
132
+ # Include the sandbox router with a prefix
133
+ app.include_router(sandbox_api.router, prefix="/api")
134
+
135
+ @app.get("/api/health")
136
+ async def health_check():
137
+ """Health check endpoint to verify API is working."""
138
+ logger.info("Health check endpoint called")
139
+ return {
140
+ "status": "ok",
141
+ "timestamp": datetime.now(timezone.utc).isoformat(),
142
+ "instance_id": instance_id
143
+ }
144
+
145
+ if __name__ == "__main__":
146
+ import uvicorn
147
+
148
+ workers = 2
149
+
150
+ logger.info(f"Starting server on 0.0.0.0:8000 with {workers} workers")
151
+ uvicorn.run(
152
+ "api:app",
153
+ host="0.0.0.0",
154
+ port=8000,
155
+ workers=workers
156
+ )
api_keys.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /home/ubuntu/visionos_farm/daedalus/api/v1/endpoints/api_keys.py
2
+
3
+ import uuid
4
+ from typing import List
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException, status, Security
7
+ from sqlalchemy.orm import Session
8
+
9
+ from shared import schemas
10
+ from database import models, session
11
+ from crud import crud_api_key
12
+ from . import dependencies # Assuming dependencies.py for auth
13
+
14
+ router = APIRouter()
15
+
16
+ @router.post("/", response_model=schemas.ApiKeyCreateResponse, status_code=status.HTTP_201_CREATED)
17
+ def create_api_key(
18
+ api_key_in: schemas.ApiKeyCreate,
19
+ db: Session = Depends(session.get_db),
20
+ current_user: models.User = Depends(dependencies.get_current_active_user) # Protect key creation
21
+ ):
22
+ """Generate a new API key for the currently authenticated user."""
23
+ db_api_key, generated_key_string = crud_api_key.create_api_key(
24
+ db=db, api_key_in=api_key_in, user_id=current_user.id
25
+ )
26
+ # Important: The raw key is only returned ONCE upon creation.
27
+ return schemas.ApiKeyCreateResponse(
28
+ **db_api_key.__dict__, # Convert DB model to dict
29
+ key=generated_key_string # Add the raw key to the response
30
+ )
31
+
32
+ @router.get("/", response_model=List[schemas.ApiKey])
33
+ def read_api_keys(
34
+ skip: int = 0,
35
+ limit: int = 100,
36
+ db: Session = Depends(session.get_db),
37
+ current_user: models.User = Depends(dependencies.get_current_active_user)
38
+ ):
39
+ """Retrieve API keys for the currently authenticated user."""
40
+ api_keys = crud_api_key.get_api_keys_by_user(db=db, user_id=current_user.id, skip=skip, limit=limit)
41
+ return api_keys
42
+
43
+ @router.put("/{api_key_id}", response_model=schemas.ApiKey)
44
+ def update_api_key(
45
+ api_key_id: uuid.UUID,
46
+ api_key_in: schemas.ApiKeyUpdate,
47
+ db: Session = Depends(session.get_db),
48
+ current_user: models.User = Depends(dependencies.get_current_active_user)
49
+ ):
50
+ """Update an API key belonging to the current user."""
51
+ db_api_key = db.get(models.ApiKey, api_key_id)
52
+ if not db_api_key or db_api_key.user_id != current_user.id:
53
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="API Key not found")
54
+ updated_key = crud_api_key.update_api_key(db=db, db_api_key=db_api_key, api_key_in=api_key_in)
55
+ return updated_key
56
+
57
+ @router.delete("/{api_key_id}", status_code=status.HTTP_204_NO_CONTENT)
58
+ def delete_api_key(
59
+ api_key_id: uuid.UUID,
60
+ db: Session = Depends(session.get_db),
61
+ current_user: models.User = Depends(dependencies.get_current_active_user)
62
+ ):
63
+ """Delete an API key belonging to the current user."""
64
+ deleted_key = crud_api_key.delete_api_key(db=db, api_key_id=api_key_id, user_id=current_user.id)
65
+ if not deleted_key:
66
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="API Key not found")
67
+ return None # Return 204 No Content on success
68
+
architecture_diagram.svg ADDED
auth_utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException, Request, Depends
2
+ from typing import Optional, List, Dict, Any
3
+ import jwt
4
+ from jwt.exceptions import PyJWTError
5
+ from utils.logger import logger
6
+
7
+ # This function extracts the user ID from Supabase JWT
8
+ async def get_current_user_id(request: Request) -> str:
9
+ """
10
+ Extract and verify the user ID from the JWT in the Authorization header.
11
+
12
+ This function is used as a dependency in FastAPI routes to ensure the user
13
+ is authenticated and to provide the user ID for authorization checks.
14
+
15
+ Args:
16
+ request: The FastAPI request object
17
+
18
+ Returns:
19
+ str: The user ID extracted from the JWT
20
+
21
+ Raises:
22
+ HTTPException: If no valid token is found or if the token is invalid
23
+ """
24
+ auth_header = request.headers.get('Authorization')
25
+
26
+ if not auth_header or not auth_header.startswith('Bearer '):
27
+ raise HTTPException(
28
+ status_code=401,
29
+ detail="No valid authentication credentials found",
30
+ headers={"WWW-Authenticate": "Bearer"}
31
+ )
32
+
33
+ token = auth_header.split(' ')[1]
34
+
35
+ try:
36
+ # For Supabase JWT, we just need to decode and extract the user ID
37
+ # The actual validation is handled by Supabase's RLS
38
+ payload = jwt.decode(token, options={"verify_signature": False})
39
+
40
+ # Supabase stores the user ID in the 'sub' claim
41
+ user_id = payload.get('sub')
42
+
43
+ if not user_id:
44
+ raise HTTPException(
45
+ status_code=401,
46
+ detail="Invalid token payload",
47
+ headers={"WWW-Authenticate": "Bearer"}
48
+ )
49
+
50
+ return user_id
51
+
52
+ except PyJWTError:
53
+ raise HTTPException(
54
+ status_code=401,
55
+ detail="Invalid token",
56
+ headers={"WWW-Authenticate": "Bearer"}
57
+ )
58
+
59
+ async def get_user_id_from_stream_auth(
60
+ request: Request,
61
+ token: Optional[str] = None
62
+ ) -> str:
63
+ """
64
+ Extract and verify the user ID from either the Authorization header or query parameter token.
65
+ This function is specifically designed for streaming endpoints that need to support both
66
+ header-based and query parameter-based authentication (for EventSource compatibility).
67
+
68
+ Args:
69
+ request: The FastAPI request object
70
+ token: Optional token from query parameters
71
+
72
+ Returns:
73
+ str: The user ID extracted from the JWT
74
+
75
+ Raises:
76
+ HTTPException: If no valid token is found or if the token is invalid
77
+ """
78
+ # Try to get user_id from token in query param (for EventSource which can't set headers)
79
+ if token:
80
+ try:
81
+ # For Supabase JWT, we just need to decode and extract the user ID
82
+ payload = jwt.decode(token, options={"verify_signature": False})
83
+ user_id = payload.get('sub')
84
+ if user_id:
85
+ return user_id
86
+ except Exception:
87
+ pass
88
+
89
+ # If no valid token in query param, try to get it from the Authorization header
90
+ auth_header = request.headers.get('Authorization')
91
+ if auth_header and auth_header.startswith('Bearer '):
92
+ try:
93
+ # Extract token from header
94
+ header_token = auth_header.split(' ')[1]
95
+ payload = jwt.decode(header_token, options={"verify_signature": False})
96
+ user_id = payload.get('sub')
97
+ if user_id:
98
+ return user_id
99
+ except Exception:
100
+ pass
101
+
102
+ # If we still don't have a user_id, return authentication error
103
+ raise HTTPException(
104
+ status_code=401,
105
+ detail="No valid authentication credentials found",
106
+ headers={"WWW-Authenticate": "Bearer"}
107
+ )
108
+ async def verify_thread_access(client, thread_id: str, user_id: str):
109
+ """
110
+ Verify that a user has access to a specific thread based on account membership.
111
+
112
+ Args:
113
+ client: The Supabase client
114
+ thread_id: The thread ID to check access for
115
+ user_id: The user ID to check permissions for
116
+
117
+ Returns:
118
+ bool: True if the user has access
119
+
120
+ Raises:
121
+ HTTPException: If the user doesn't have access to the thread
122
+ """
123
+ # Query the thread to get account information
124
+ thread_result = await client.table('threads').select('*,project_id').eq('thread_id', thread_id).execute()
125
+
126
+ if not thread_result.data or len(thread_result.data) == 0:
127
+ raise HTTPException(status_code=404, detail="Thread not found")
128
+
129
+ thread_data = thread_result.data[0]
130
+
131
+ # Check if project is public
132
+ project_id = thread_data.get('project_id')
133
+ if project_id:
134
+ project_result = await client.table('projects').select('is_public').eq('project_id', project_id).execute()
135
+ if project_result.data and len(project_result.data) > 0:
136
+ if project_result.data[0].get('is_public'):
137
+ return True
138
+
139
+ account_id = thread_data.get('account_id')
140
+ # When using service role, we need to manually check account membership instead of using current_user_account_role
141
+ if account_id:
142
+ account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
143
+ if account_user_result.data and len(account_user_result.data) > 0:
144
+ return True
145
+ raise HTTPException(status_code=403, detail="Not authorized to access this thread")
146
+
147
+ async def get_optional_user_id(request: Request) -> Optional[str]:
148
+ """
149
+ Extract the user ID from the JWT in the Authorization header if present,
150
+ but don't require authentication. Returns None if no valid token is found.
151
+
152
+ This function is used for endpoints that support both authenticated and
153
+ unauthenticated access (like public projects).
154
+
155
+ Args:
156
+ request: The FastAPI request object
157
+
158
+ Returns:
159
+ Optional[str]: The user ID extracted from the JWT, or None if no valid token
160
+ """
161
+ auth_header = request.headers.get('Authorization')
162
+
163
+ if not auth_header or not auth_header.startswith('Bearer '):
164
+ return None
165
+
166
+ token = auth_header.split(' ')[1]
167
+
168
+ try:
169
+ # For Supabase JWT, we just need to decode and extract the user ID
170
+ payload = jwt.decode(token, options={"verify_signature": False})
171
+
172
+ # Supabase stores the user ID in the 'sub' claim
173
+ user_id = payload.get('sub')
174
+
175
+ return user_id
176
+ except PyJWTError:
177
+ return None
base.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /home/ubuntu/visionos_farm/shared/tools/base.py
2
+ from abc import ABC, abstractmethod
3
+ from typing import Dict, Any
4
+
5
+ class Tool(ABC):
6
+ """Abstract base class for all dynamically loaded tools."""
7
+
8
+ @property
9
+ @abstractmethod
10
+ def name(self) -> str:
11
+ """Unique name of the tool."""
12
+ pass
13
+
14
+ @property
15
+ @abstractmethod
16
+ def description(self) -> str:
17
+ """Description of what the tool does."""
18
+ pass
19
+
20
+ @abstractmethod
21
+ async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
22
+ """Executes the tool with the given parameters.
23
+
24
+ Args:
25
+ parameters: A dictionary of parameters required by the tool.
26
+
27
+ Returns:
28
+ A dictionary containing the result of the tool execution.
29
+ Should include a 'status' key ('SUCCESS' or 'FAILURE')
30
+ and potentially 'output' or 'error' keys.
31
+ """
32
+ pass
33
+
billing.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timezone
2
+ from typing import Dict, Optional, Tuple
3
+ from utils.logger import logger
4
+ from utils.config import config, EnvMode
5
+
6
+ # Define subscription tiers and their monthly limits (in minutes)
7
+ SUBSCRIPTION_TIERS = {
8
+ 'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 8},
9
+ 'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300},
10
+ 'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400}
11
+ }
12
+
13
+ async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
14
+ """Get the current subscription for an account."""
15
+ result = await client.schema('basejump').from_('billing_subscriptions') \
16
+ .select('*') \
17
+ .eq('account_id', account_id) \
18
+ .eq('status', 'active') \
19
+ .order('created', desc=True) \
20
+ .limit(1) \
21
+ .execute()
22
+
23
+ if result.data and len(result.data) > 0:
24
+ return result.data[0]
25
+ return None
26
+
27
+ async def calculate_monthly_usage(client, account_id: str) -> float:
28
+ """Calculate total agent run minutes for the current month for an account."""
29
+ # Get start of current month in UTC
30
+ now = datetime.now(timezone.utc)
31
+ start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
32
+
33
+ # First get all threads for this account
34
+ threads_result = await client.table('threads') \
35
+ .select('thread_id') \
36
+ .eq('account_id', account_id) \
37
+ .execute()
38
+
39
+ if not threads_result.data:
40
+ return 0.0
41
+
42
+ thread_ids = [t['thread_id'] for t in threads_result.data]
43
+
44
+ # Then get all agent runs for these threads in current month
45
+ runs_result = await client.table('agent_runs') \
46
+ .select('started_at, completed_at') \
47
+ .in_('thread_id', thread_ids) \
48
+ .gte('started_at', start_of_month.isoformat()) \
49
+ .execute()
50
+
51
+ if not runs_result.data:
52
+ return 0.0
53
+
54
+ # Calculate total minutes
55
+ total_seconds = 0
56
+ now_ts = now.timestamp()
57
+
58
+ for run in runs_result.data:
59
+ start_time = datetime.fromisoformat(run['started_at'].replace('Z', '+00:00')).timestamp()
60
+ if run['completed_at']:
61
+ end_time = datetime.fromisoformat(run['completed_at'].replace('Z', '+00:00')).timestamp()
62
+ else:
63
+ # For running jobs, use current time
64
+ end_time = now_ts
65
+
66
+ total_seconds += (end_time - start_time)
67
+
68
+ return total_seconds / 60 # Convert to minutes
69
+
70
+ async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
71
+ """
72
+ Check if an account can run agents based on their subscription and usage.
73
+
74
+ Returns:
75
+ Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
76
+ """
77
+ if config.ENV_MODE == EnvMode.LOCAL:
78
+ logger.info("Running in local development mode - billing checks are disabled")
79
+ return True, "Local development mode - billing disabled", {
80
+ "price_id": "local_dev",
81
+ "plan_name": "Local Development",
82
+ "minutes_limit": "no limit"
83
+ }
84
+
85
+ # For staging/production, check subscription status
86
+
87
+ # Get current subscription
88
+ subscription = await get_account_subscription(client, account_id)
89
+
90
+ # If no subscription, they can use free tier
91
+ if not subscription:
92
+ subscription = {
93
+ 'price_id': 'price_1RGJ9GG6l1KZGqIroxSqgphC', # Free tier
94
+ 'plan_name': 'free'
95
+ }
96
+
97
+ # if not subscription or subscription.get('price_id') is None or subscription.get('price_id') == 'price_1RGJ9GG6l1KZGqIroxSqgphC':
98
+ # return False, "You are not subscribed to any plan. Please upgrade your plan to continue.", subscription
99
+
100
+ # Get tier info
101
+ tier_info = SUBSCRIPTION_TIERS.get(subscription['price_id'])
102
+ if not tier_info:
103
+ return False, "Invalid subscription tier", subscription
104
+
105
+ # Calculate current month's usage
106
+ current_usage = await calculate_monthly_usage(client, account_id)
107
+
108
+ # Check if within limits
109
+ if current_usage >= tier_info['minutes']:
110
+ return False, f"Monthly limit of {tier_info['minutes']} minutes reached. Please upgrade your plan or wait until next month.", subscription
111
+
112
+ return True, "OK", subscription
113
+
114
+ # Helper function to get account ID from thread
115
+ async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
116
+ """Get the account ID associated with a thread."""
117
+ result = await client.table('threads') \
118
+ .select('account_id') \
119
+ .eq('thread_id', thread_id) \
120
+ .limit(1) \
121
+ .execute()
122
+
123
+ if result.data and len(result.data) > 0:
124
+ return result.data[0]['account_id']
125
+ return None
browser_api.py ADDED
@@ -0,0 +1,2063 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, APIRouter, HTTPException, Body
2
+ from playwright.async_api import async_playwright, Browser, Page, ElementHandle
3
+ from pydantic import BaseModel
4
+ from typing import Optional, List, Dict, Any, Union
5
+ import asyncio
6
+ import json
7
+ import logging
8
+ import re
9
+ import base64
10
+ from dataclasses import dataclass, field
11
+ from datetime import datetime
12
+ import os
13
+ import random
14
+ from functools import cached_property
15
+ import traceback
16
+ import pytesseract
17
+ from PIL import Image
18
+ import io
19
+
20
+ #######################################################
21
+ # Action model definitions
22
+ #######################################################
23
+
24
+ class Position(BaseModel):
25
+ x: int
26
+ y: int
27
+
28
+ class ClickElementAction(BaseModel):
29
+ index: int
30
+
31
+ class ClickCoordinatesAction(BaseModel):
32
+ x: int
33
+ y: int
34
+
35
+ class GoToUrlAction(BaseModel):
36
+ url: str
37
+
38
+ class InputTextAction(BaseModel):
39
+ index: int
40
+ text: str
41
+
42
+ class ScrollAction(BaseModel):
43
+ amount: Optional[int] = None
44
+
45
+ class SendKeysAction(BaseModel):
46
+ keys: str
47
+
48
+ class SearchGoogleAction(BaseModel):
49
+ query: str
50
+
51
+ class SwitchTabAction(BaseModel):
52
+ page_id: int
53
+
54
+ class OpenTabAction(BaseModel):
55
+ url: str
56
+
57
+ class CloseTabAction(BaseModel):
58
+ page_id: int
59
+
60
+ class NoParamsAction(BaseModel):
61
+ pass
62
+
63
+ class DragDropAction(BaseModel):
64
+ element_source: Optional[str] = None
65
+ element_target: Optional[str] = None
66
+ element_source_offset: Optional[Position] = None
67
+ element_target_offset: Optional[Position] = None
68
+ coord_source_x: Optional[int] = None
69
+ coord_source_y: Optional[int] = None
70
+ coord_target_x: Optional[int] = None
71
+ coord_target_y: Optional[int] = None
72
+ steps: Optional[int] = 10
73
+ delay_ms: Optional[int] = 5
74
+
75
+ class DoneAction(BaseModel):
76
+ success: bool = True
77
+ text: str = ""
78
+
79
+ #######################################################
80
+ # DOM Structure Models
81
+ #######################################################
82
+
83
+ @dataclass
84
+ class CoordinateSet:
85
+ x: int = 0
86
+ y: int = 0
87
+ width: int = 0
88
+ height: int = 0
89
+
90
+ @dataclass
91
+ class ViewportInfo:
92
+ width: int = 0
93
+ height: int = 0
94
+ scroll_x: int = 0
95
+ scroll_y: int = 0
96
+
97
+ @dataclass
98
+ class HashedDomElement:
99
+ tag_name: str
100
+ attributes: Dict[str, str]
101
+ is_visible: bool
102
+ page_coordinates: Optional[CoordinateSet] = None
103
+
104
+ @dataclass
105
+ class DOMBaseNode:
106
+ is_visible: bool
107
+ parent: Optional['DOMElementNode'] = None
108
+
109
+ @dataclass
110
+ class DOMTextNode(DOMBaseNode):
111
+ text: str = field(default="")
112
+ type: str = 'TEXT_NODE'
113
+
114
+ def has_parent_with_highlight_index(self) -> bool:
115
+ current = self.parent
116
+ while current is not None:
117
+ if current.highlight_index is not None:
118
+ return True
119
+ current = current.parent
120
+ return False
121
+
122
+ @dataclass
123
+ class DOMElementNode(DOMBaseNode):
124
+ tag_name: str = field(default="")
125
+ xpath: str = field(default="")
126
+ attributes: Dict[str, str] = field(default_factory=dict)
127
+ children: List['DOMBaseNode'] = field(default_factory=list)
128
+
129
+ is_interactive: bool = False
130
+ is_top_element: bool = False
131
+ is_in_viewport: bool = False
132
+ shadow_root: bool = False
133
+ highlight_index: Optional[int] = None
134
+ viewport_coordinates: Optional[CoordinateSet] = None
135
+ page_coordinates: Optional[CoordinateSet] = None
136
+ viewport_info: Optional[ViewportInfo] = None
137
+
138
+ def __repr__(self) -> str:
139
+ tag_str = f'<{self.tag_name}'
140
+ for key, value in self.attributes.items():
141
+ tag_str += f' {key}="{value}"'
142
+ tag_str += '>'
143
+
144
+ extras = []
145
+ if self.is_interactive:
146
+ extras.append('interactive')
147
+ if self.is_top_element:
148
+ extras.append('top')
149
+ if self.highlight_index is not None:
150
+ extras.append(f'highlight:{self.highlight_index}')
151
+
152
+ if extras:
153
+ tag_str += f' [{", ".join(extras)}]'
154
+
155
+ return tag_str
156
+
157
+ @cached_property
158
+ def hash(self) -> HashedDomElement:
159
+ return HashedDomElement(
160
+ tag_name=self.tag_name,
161
+ attributes=self.attributes,
162
+ is_visible=self.is_visible,
163
+ page_coordinates=self.page_coordinates
164
+ )
165
+
166
+ def get_all_text_till_next_clickable_element(self, max_depth: int = -1) -> str:
167
+ text_parts = []
168
+
169
+ def collect_text(node: DOMBaseNode, current_depth: int) -> None:
170
+ if max_depth != -1 and current_depth > max_depth:
171
+ return
172
+
173
+ if isinstance(node, DOMElementNode) and node != self and node.highlight_index is not None:
174
+ return
175
+
176
+ if isinstance(node, DOMTextNode):
177
+ text_parts.append(node.text)
178
+ elif isinstance(node, DOMElementNode):
179
+ for child in node.children:
180
+ collect_text(child, current_depth + 1)
181
+
182
+ collect_text(self, 0)
183
+ return '\n'.join(text_parts).strip()
184
+
185
+ def clickable_elements_to_string(self, include_attributes: list[str] | None = None) -> str:
186
+ """Convert the processed DOM content to HTML."""
187
+ formatted_text = []
188
+
189
+ def process_node(node: DOMBaseNode, depth: int) -> None:
190
+ if isinstance(node, DOMElementNode):
191
+ # Add element with highlight_index
192
+ if node.highlight_index is not None:
193
+ attributes_str = ''
194
+ text = node.get_all_text_till_next_clickable_element()
195
+
196
+ # Process attributes for display
197
+ display_attributes = []
198
+ if include_attributes:
199
+ for key, value in node.attributes.items():
200
+ if key in include_attributes and value and value != node.tag_name:
201
+ if text and value in text:
202
+ continue # Skip if attribute value is already in the text
203
+ display_attributes.append(str(value))
204
+
205
+ attributes_str = ';'.join(display_attributes)
206
+
207
+ # Build the element string
208
+ line = f'[{node.highlight_index}]<{node.tag_name}'
209
+
210
+ # Add important attributes for identification
211
+ for attr_name in ['id', 'href', 'name', 'value', 'type']:
212
+ if attr_name in node.attributes and node.attributes[attr_name]:
213
+ line += f' {attr_name}="{node.attributes[attr_name]}"'
214
+
215
+ # Add the text content if available
216
+ if text:
217
+ line += f'> {text}'
218
+ elif attributes_str:
219
+ line += f'> {attributes_str}'
220
+ else:
221
+ # If no text and no attributes, use the tag name
222
+ line += f'> {node.tag_name.upper()}'
223
+
224
+ line += ' </>'
225
+ formatted_text.append(line)
226
+
227
+ # Process children regardless
228
+ for child in node.children:
229
+ process_node(child, depth + 1)
230
+
231
+ elif isinstance(node, DOMTextNode):
232
+ # Add text only if it doesn't have a highlighted parent
233
+ if not node.has_parent_with_highlight_index() and node.is_visible:
234
+ if node.text and node.text.strip():
235
+ formatted_text.append(node.text)
236
+
237
+ process_node(self, 0)
238
+ result = '\n'.join(formatted_text)
239
+ return result if result.strip() else "No interactive elements found"
240
+
241
+ @dataclass
242
+ class DOMState:
243
+ element_tree: DOMElementNode
244
+ selector_map: Dict[int, DOMElementNode]
245
+ url: str = ""
246
+ title: str = ""
247
+ pixels_above: int = 0
248
+ pixels_below: int = 0
249
+
250
+ #######################################################
251
+ # Browser Action Result Model
252
+ #######################################################
253
+
254
+ class BrowserActionResult(BaseModel):
255
+ success: bool = True
256
+ message: str = ""
257
+ error: str = ""
258
+
259
+ # Extended state information
260
+ url: Optional[str] = None
261
+ title: Optional[str] = None
262
+ elements: Optional[str] = None # Formatted string of clickable elements
263
+ screenshot_base64: Optional[str] = None
264
+ pixels_above: int = 0
265
+ pixels_below: int = 0
266
+ content: Optional[str] = None
267
+ ocr_text: Optional[str] = None # Added field for OCR text
268
+
269
+ # Additional metadata
270
+ element_count: int = 0 # Number of interactive elements found
271
+ interactive_elements: Optional[List[Dict[str, Any]]] = None # Simplified list of interactive elements
272
+ viewport_width: Optional[int] = None
273
+ viewport_height: Optional[int] = None
274
+
275
+ class Config:
276
+ arbitrary_types_allowed = True
277
+
278
+ #######################################################
279
+ # Browser Automation Implementation
280
+ #######################################################
281
+
282
+ class BrowserAutomation:
283
+ def __init__(self):
284
+ self.router = APIRouter()
285
+ self.browser: Browser = None
286
+ self.pages: List[Page] = []
287
+ self.current_page_index: int = 0
288
+ self.logger = logging.getLogger("browser_automation")
289
+ self.include_attributes = ["id", "href", "src", "alt", "aria-label", "placeholder", "name", "role", "title", "value"]
290
+ self.screenshot_dir = os.path.join(os.getcwd(), "screenshots")
291
+ os.makedirs(self.screenshot_dir, exist_ok=True)
292
+
293
+ # Register routes
294
+ self.router.on_startup.append(self.startup)
295
+ self.router.on_shutdown.append(self.shutdown)
296
+
297
+ # Basic navigation
298
+ self.router.post("/automation/navigate_to")(self.navigate_to)
299
+ self.router.post("/automation/search_google")(self.search_google)
300
+ self.router.post("/automation/go_back")(self.go_back)
301
+ self.router.post("/automation/wait")(self.wait)
302
+
303
+ # Element interaction
304
+ self.router.post("/automation/click_element")(self.click_element)
305
+ self.router.post("/automation/click_coordinates")(self.click_coordinates)
306
+ self.router.post("/automation/input_text")(self.input_text)
307
+ self.router.post("/automation/send_keys")(self.send_keys)
308
+
309
+ # Tab management
310
+ self.router.post("/automation/switch_tab")(self.switch_tab)
311
+ self.router.post("/automation/open_tab")(self.open_tab)
312
+ self.router.post("/automation/close_tab")(self.close_tab)
313
+
314
+ # Content actions
315
+ self.router.post("/automation/extract_content")(self.extract_content)
316
+ self.router.post("/automation/save_pdf")(self.save_pdf)
317
+
318
+ # Scroll actions
319
+ self.router.post("/automation/scroll_down")(self.scroll_down)
320
+ self.router.post("/automation/scroll_up")(self.scroll_up)
321
+ self.router.post("/automation/scroll_to_text")(self.scroll_to_text)
322
+
323
+ # Dropdown actions
324
+ self.router.post("/automation/get_dropdown_options")(self.get_dropdown_options)
325
+ self.router.post("/automation/select_dropdown_option")(self.select_dropdown_option)
326
+
327
+ # Drag and drop
328
+ self.router.post("/automation/drag_drop")(self.drag_drop)
329
+
330
+ async def startup(self):
331
+ """Initialize the browser instance on startup"""
332
+ try:
333
+ print("Starting browser initialization...")
334
+ playwright = await async_playwright().start()
335
+ print("Playwright started, launching browser...")
336
+
337
+ # Use non-headless mode for testing with slower timeouts
338
+ launch_options = {
339
+ "headless": False,
340
+ "timeout": 60000
341
+ }
342
+
343
+ try:
344
+ self.browser = await playwright.chromium.launch(**launch_options)
345
+ print("Browser launched successfully")
346
+ except Exception as browser_error:
347
+ print(f"Failed to launch browser: {browser_error}")
348
+ # Try with minimal options
349
+ print("Retrying with minimal options...")
350
+ launch_options = {"timeout": 90000}
351
+ self.browser = await playwright.chromium.launch(**launch_options)
352
+ print("Browser launched with minimal options")
353
+
354
+ try:
355
+ await self.get_current_page()
356
+ print("Found existing page, using it")
357
+ self.current_page_index = 0
358
+ except Exception as page_error:
359
+ print(f"Error finding existing page, creating new one. ( {page_error})")
360
+ page = await self.browser.new_page()
361
+ print("New page created successfully")
362
+ self.pages.append(page)
363
+ self.current_page_index = 0
364
+ # Navigate to about:blank to ensure page is ready
365
+ # await page.goto("google.com", timeout=30000)
366
+ print("Navigated to google.com")
367
+
368
+ print("Browser initialization completed successfully")
369
+ except Exception as e:
370
+ print(f"Browser startup error: {str(e)}")
371
+ traceback.print_exc()
372
+ raise RuntimeError(f"Browser initialization failed: {str(e)}")
373
+
374
+ async def shutdown(self):
375
+ """Clean up browser instance on shutdown"""
376
+ if self.browser:
377
+ await self.browser.close()
378
+
379
+ async def get_current_page(self) -> Page:
380
+ """Get the current active page"""
381
+ if not self.pages:
382
+ raise HTTPException(status_code=500, detail="No browser pages available")
383
+ return self.pages[self.current_page_index]
384
+
385
+ async def get_selector_map(self) -> Dict[int, DOMElementNode]:
386
+ """Get a map of selectable elements on the page"""
387
+ page = await self.get_current_page()
388
+
389
+ # Create a selector map for interactive elements
390
+ selector_map = {}
391
+
392
+ try:
393
+ # More comprehensive JavaScript to find interactive elements
394
+ elements_js = """
395
+ (() => {
396
+ // Helper function to get all attributes as an object
397
+ function getAttributes(el) {
398
+ const attributes = {};
399
+ for (const attr of el.attributes) {
400
+ attributes[attr.name] = attr.value;
401
+ }
402
+ return attributes;
403
+ }
404
+
405
+ // Find all potentially interactive elements
406
+ const interactiveElements = Array.from(document.querySelectorAll(
407
+ 'a, button, input, select, textarea, [role="button"], [role="link"], [role="checkbox"], [role="radio"], [tabindex]:not([tabindex="-1"])'
408
+ ));
409
+
410
+ // Filter for visible elements
411
+ const visibleElements = interactiveElements.filter(el => {
412
+ const style = window.getComputedStyle(el);
413
+ const rect = el.getBoundingClientRect();
414
+ return style.display !== 'none' &&
415
+ style.visibility !== 'hidden' &&
416
+ style.opacity !== '0' &&
417
+ rect.width > 0 &&
418
+ rect.height > 0;
419
+ });
420
+
421
+ // Map to our expected structure
422
+ return visibleElements.map((el, index) => {
423
+ const rect = el.getBoundingClientRect();
424
+ const isInViewport = rect.top >= 0 &&
425
+ rect.left >= 0 &&
426
+ rect.bottom <= window.innerHeight &&
427
+ rect.right <= window.innerWidth;
428
+
429
+ return {
430
+ index: index + 1,
431
+ tagName: el.tagName.toLowerCase(),
432
+ text: el.innerText || el.value || '',
433
+ attributes: getAttributes(el),
434
+ isVisible: true,
435
+ isInteractive: true,
436
+ pageCoordinates: {
437
+ x: rect.left + window.scrollX,
438
+ y: rect.top + window.scrollY,
439
+ width: rect.width,
440
+ height: rect.height
441
+ },
442
+ viewportCoordinates: {
443
+ x: rect.left,
444
+ y: rect.top,
445
+ width: rect.width,
446
+ height: rect.height
447
+ },
448
+ isInViewport: isInViewport
449
+ };
450
+ });
451
+ })();
452
+ """
453
+
454
+ elements = await page.evaluate(elements_js)
455
+ print(f"Found {len(elements)} interactive elements in selector map")
456
+
457
+ # Create a root element for the tree
458
+ root = DOMElementNode(
459
+ is_visible=True,
460
+ tag_name="body",
461
+ is_interactive=False,
462
+ is_top_element=True
463
+ )
464
+
465
+ # Create element nodes for each element
466
+ for idx, el in enumerate(elements):
467
+ # Create coordinate sets
468
+ page_coordinates = None
469
+ viewport_coordinates = None
470
+
471
+ if 'pageCoordinates' in el:
472
+ coords = el['pageCoordinates']
473
+ page_coordinates = CoordinateSet(
474
+ x=coords.get('x', 0),
475
+ y=coords.get('y', 0),
476
+ width=coords.get('width', 0),
477
+ height=coords.get('height', 0)
478
+ )
479
+
480
+ if 'viewportCoordinates' in el:
481
+ coords = el['viewportCoordinates']
482
+ viewport_coordinates = CoordinateSet(
483
+ x=coords.get('x', 0),
484
+ y=coords.get('y', 0),
485
+ width=coords.get('width', 0),
486
+ height=coords.get('height', 0)
487
+ )
488
+
489
+ # Create the element node
490
+ element_node = DOMElementNode(
491
+ is_visible=el.get('isVisible', True),
492
+ tag_name=el.get('tagName', 'div'),
493
+ attributes=el.get('attributes', {}),
494
+ is_interactive=el.get('isInteractive', True),
495
+ is_in_viewport=el.get('isInViewport', False),
496
+ highlight_index=el.get('index', idx + 1),
497
+ page_coordinates=page_coordinates,
498
+ viewport_coordinates=viewport_coordinates
499
+ )
500
+
501
+ # Add a text node if there's text content
502
+ if el.get('text'):
503
+ text_node = DOMTextNode(is_visible=True, text=el.get('text', ''))
504
+ text_node.parent = element_node
505
+ element_node.children.append(text_node)
506
+
507
+ selector_map[el.get('index', idx + 1)] = element_node
508
+ root.children.append(element_node)
509
+ element_node.parent = root
510
+
511
+ except Exception as e:
512
+ print(f"Error getting selector map: {e}")
513
+ traceback.print_exc()
514
+ # Create a dummy element to avoid breaking tests
515
+ dummy = DOMElementNode(
516
+ is_visible=True,
517
+ tag_name="a",
518
+ attributes={'href': '#'},
519
+ is_interactive=True,
520
+ highlight_index=1
521
+ )
522
+ dummy_text = DOMTextNode(is_visible=True, text="Dummy Element")
523
+ dummy_text.parent = dummy
524
+ dummy.children.append(dummy_text)
525
+ selector_map[1] = dummy
526
+
527
+ return selector_map
528
+
529
+ async def get_current_dom_state(self) -> DOMState:
530
+ """Get the current DOM state including element tree and selector map"""
531
+ try:
532
+ page = await self.get_current_page()
533
+ selector_map = await self.get_selector_map()
534
+
535
+ # Create a root element
536
+ root = DOMElementNode(
537
+ is_visible=True,
538
+ tag_name="body",
539
+ is_interactive=False,
540
+ is_top_element=True
541
+ )
542
+
543
+ # Add all elements from selector map as children of root
544
+ for element in selector_map.values():
545
+ if element.parent is None:
546
+ element.parent = root
547
+ root.children.append(element)
548
+
549
+ # Get basic page info
550
+ url = page.url
551
+ try:
552
+ title = await page.title()
553
+ except:
554
+ title = "Unknown Title"
555
+
556
+ # Get more accurate scroll information - fix JavaScript syntax
557
+ try:
558
+ scroll_info = await page.evaluate("""
559
+ () => {
560
+ const body = document.body;
561
+ const html = document.documentElement;
562
+ const totalHeight = Math.max(
563
+ body.scrollHeight, body.offsetHeight,
564
+ html.clientHeight, html.scrollHeight, html.offsetHeight
565
+ );
566
+ const scrollY = window.scrollY || window.pageYOffset;
567
+ const windowHeight = window.innerHeight;
568
+
569
+ return {
570
+ pixelsAbove: scrollY,
571
+ pixelsBelow: Math.max(0, totalHeight - scrollY - windowHeight),
572
+ totalHeight: totalHeight,
573
+ viewportHeight: windowHeight
574
+ };
575
+ }
576
+ """)
577
+ pixels_above = scroll_info.get('pixelsAbove', 0)
578
+ pixels_below = scroll_info.get('pixelsBelow', 0)
579
+ except Exception as e:
580
+ print(f"Error getting scroll info: {e}")
581
+ pixels_above = 0
582
+ pixels_below = 0
583
+
584
+ return DOMState(
585
+ element_tree=root,
586
+ selector_map=selector_map,
587
+ url=url,
588
+ title=title,
589
+ pixels_above=pixels_above,
590
+ pixels_below=pixels_below
591
+ )
592
+ except Exception as e:
593
+ print(f"Error getting DOM state: {e}")
594
+ traceback.print_exc()
595
+ # Return a minimal valid state to avoid breaking tests
596
+ dummy_root = DOMElementNode(
597
+ is_visible=True,
598
+ tag_name="body",
599
+ is_interactive=False,
600
+ is_top_element=True
601
+ )
602
+ dummy_map = {1: dummy_root}
603
+ return DOMState(
604
+ element_tree=dummy_root,
605
+ selector_map=dummy_map,
606
+ url=page.url if 'page' in locals() else "about:blank",
607
+ title="Error page",
608
+ pixels_above=0,
609
+ pixels_below=0
610
+ )
611
+
612
+ async def take_screenshot(self) -> str:
613
+ """Take a screenshot and return as base64 encoded string"""
614
+ try:
615
+ page = await self.get_current_page()
616
+ screenshot_bytes = await page.screenshot(type='jpeg', quality=60, full_page=False)
617
+ return base64.b64encode(screenshot_bytes).decode('utf-8')
618
+ except Exception as e:
619
+ print(f"Error taking screenshot: {e}")
620
+ # Return an empty string rather than failing
621
+ return ""
622
+
623
+ async def save_screenshot_to_file(self) -> str:
624
+ """Take a screenshot and save to file, returning the path"""
625
+ try:
626
+ page = await self.get_current_page()
627
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
628
+ random_id = random.randint(1000, 9999)
629
+ filename = f"screenshot_{timestamp}_{random_id}.jpg"
630
+ filepath = os.path.join(self.screenshot_dir, filename)
631
+
632
+ await page.screenshot(path=filepath, type='jpeg', quality=60, full_page=False)
633
+ return filepath
634
+ except Exception as e:
635
+ print(f"Error saving screenshot: {e}")
636
+ return ""
637
+
638
+ async def extract_ocr_text_from_screenshot(self, screenshot_base64: str) -> str:
639
+ """Extract text from screenshot using OCR"""
640
+ if not screenshot_base64:
641
+ return ""
642
+
643
+ try:
644
+ # Decode base64 to image
645
+ image_bytes = base64.b64decode(screenshot_base64)
646
+ image = Image.open(io.BytesIO(image_bytes))
647
+
648
+ # Extract text using pytesseract
649
+ ocr_text = pytesseract.image_to_string(image)
650
+
651
+ # Clean up the text
652
+ ocr_text = ocr_text.strip()
653
+
654
+ return ocr_text
655
+ except Exception as e:
656
+ print(f"Error performing OCR: {e}")
657
+ traceback.print_exc()
658
+ return ""
659
+
660
+ async def get_updated_browser_state(self, action_name: str) -> tuple:
661
+ """Helper method to get updated browser state after any action
662
+ Returns a tuple of (dom_state, screenshot, elements, metadata)
663
+ """
664
+ try:
665
+ # Wait a moment for any potential async processes to settle
666
+ await asyncio.sleep(0.5)
667
+
668
+ # Get updated state
669
+ dom_state = await self.get_current_dom_state()
670
+ screenshot = await self.take_screenshot()
671
+
672
+ # Format elements for output
673
+ elements = dom_state.element_tree.clickable_elements_to_string(
674
+ include_attributes=self.include_attributes
675
+ )
676
+
677
+ # Collect additional metadata
678
+ page = await self.get_current_page()
679
+ metadata = {}
680
+
681
+ # Get element count
682
+ metadata['element_count'] = len(dom_state.selector_map)
683
+
684
+ # Create simplified interactive elements list
685
+ interactive_elements = []
686
+ for idx, element in dom_state.selector_map.items():
687
+ element_info = {
688
+ 'index': idx,
689
+ 'tag_name': element.tag_name,
690
+ 'text': element.get_all_text_till_next_clickable_element(),
691
+ 'is_in_viewport': element.is_in_viewport
692
+ }
693
+
694
+ # Add key attributes
695
+ for attr_name in ['id', 'href', 'src', 'alt', 'placeholder', 'name', 'role', 'title', 'type']:
696
+ if attr_name in element.attributes:
697
+ element_info[attr_name] = element.attributes[attr_name]
698
+
699
+ interactive_elements.append(element_info)
700
+
701
+ metadata['interactive_elements'] = interactive_elements
702
+
703
+ # Get viewport dimensions - Fix syntax error in JavaScript
704
+ try:
705
+ viewport = await page.evaluate("""
706
+ () => {
707
+ return {
708
+ width: window.innerWidth,
709
+ height: window.innerHeight
710
+ };
711
+ }
712
+ """)
713
+ metadata['viewport_width'] = viewport.get('width', 0)
714
+ metadata['viewport_height'] = viewport.get('height', 0)
715
+ except Exception as e:
716
+ print(f"Error getting viewport dimensions: {e}")
717
+ metadata['viewport_width'] = 0
718
+ metadata['viewport_height'] = 0
719
+
720
+ # Extract OCR text from screenshot if available
721
+ ocr_text = ""
722
+ if screenshot:
723
+ ocr_text = await self.extract_ocr_text_from_screenshot(screenshot)
724
+ metadata['ocr_text'] = ocr_text
725
+
726
+ print(f"Got updated state after {action_name}: {len(dom_state.selector_map)} elements")
727
+ return dom_state, screenshot, elements, metadata
728
+ except Exception as e:
729
+ print(f"Error getting updated state after {action_name}: {e}")
730
+ traceback.print_exc()
731
+ # Return empty values in case of error
732
+ return None, "", "", {}
733
+
734
+ def build_action_result(self, success: bool, message: str, dom_state, screenshot: str,
735
+ elements: str, metadata: dict, error: str = "", content: str = None,
736
+ fallback_url: str = None) -> BrowserActionResult:
737
+ """Helper method to build a consistent BrowserActionResult"""
738
+ # Ensure elements is never None to avoid display issues
739
+ if elements is None:
740
+ elements = ""
741
+
742
+ return BrowserActionResult(
743
+ success=success,
744
+ message=message,
745
+ error=error,
746
+ url=dom_state.url if dom_state else fallback_url or "",
747
+ title=dom_state.title if dom_state else "",
748
+ elements=elements,
749
+ screenshot_base64=screenshot,
750
+ pixels_above=dom_state.pixels_above if dom_state else 0,
751
+ pixels_below=dom_state.pixels_below if dom_state else 0,
752
+ content=content,
753
+ ocr_text=metadata.get('ocr_text', ""),
754
+ element_count=metadata.get('element_count', 0),
755
+ interactive_elements=metadata.get('interactive_elements', []),
756
+ viewport_width=metadata.get('viewport_width', 0),
757
+ viewport_height=metadata.get('viewport_height', 0)
758
+ )
759
+
760
+ # Basic Navigation Actions
761
+
762
+ async def navigate_to(self, action: GoToUrlAction = Body(...)):
763
+ """Navigate to a specified URL"""
764
+ try:
765
+ page = await self.get_current_page()
766
+ await page.goto(action.url, wait_until="domcontentloaded")
767
+ await page.wait_for_load_state("networkidle", timeout=10000)
768
+
769
+ # Get updated state after action
770
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"navigate_to({action.url})")
771
+
772
+ result = self.build_action_result(
773
+ True,
774
+ f"Navigated to {action.url}",
775
+ dom_state,
776
+ screenshot,
777
+ elements,
778
+ metadata,
779
+ error="",
780
+ content=None
781
+ )
782
+
783
+ print(f"Navigation result: success={result.success}, url={result.url}")
784
+ return result
785
+ except Exception as e:
786
+ print(f"Navigation error: {str(e)}")
787
+ traceback.print_exc()
788
+ # Try to get some state info even after error
789
+ try:
790
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("navigate_error_recovery")
791
+ return self.build_action_result(
792
+ False,
793
+ str(e),
794
+ dom_state,
795
+ screenshot,
796
+ elements,
797
+ metadata,
798
+ error=str(e),
799
+ content=None
800
+ )
801
+ except:
802
+ return self.build_action_result(
803
+ False,
804
+ str(e),
805
+ None,
806
+ "",
807
+ "",
808
+ {},
809
+ error=str(e),
810
+ content=None
811
+ )
812
+
813
+ async def search_google(self, action: SearchGoogleAction = Body(...)):
814
+ """Search Google with the provided query"""
815
+ try:
816
+ page = await self.get_current_page()
817
+ search_url = f"https://www.google.com/search?q={action.query}"
818
+ await page.goto(search_url)
819
+ await page.wait_for_load_state()
820
+
821
+ # Get updated state after action
822
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"search_google({action.query})")
823
+
824
+ return self.build_action_result(
825
+ True,
826
+ f"Searched for '{action.query}' in Google",
827
+ dom_state,
828
+ screenshot,
829
+ elements,
830
+ metadata,
831
+ error="",
832
+ content=None
833
+ )
834
+ except Exception as e:
835
+ print(f"Search error: {str(e)}")
836
+ traceback.print_exc()
837
+ # Try to get some state info even after error
838
+ try:
839
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("search_error_recovery")
840
+ return self.build_action_result(
841
+ False,
842
+ str(e),
843
+ dom_state,
844
+ screenshot,
845
+ elements,
846
+ metadata,
847
+ error=str(e),
848
+ content=None
849
+ )
850
+ except:
851
+ return self.build_action_result(
852
+ False,
853
+ str(e),
854
+ None,
855
+ "",
856
+ "",
857
+ {},
858
+ error=str(e),
859
+ content=None
860
+ )
861
+
862
+ async def go_back(self, _: NoParamsAction = Body(...)):
863
+ """Navigate back in browser history"""
864
+ try:
865
+ page = await self.get_current_page()
866
+ await page.go_back()
867
+ await page.wait_for_load_state()
868
+
869
+ # Get updated state after action
870
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("go_back")
871
+
872
+ return self.build_action_result(
873
+ True,
874
+ "Navigated back",
875
+ dom_state,
876
+ screenshot,
877
+ elements,
878
+ metadata,
879
+ error="",
880
+ content=None
881
+ )
882
+ except Exception as e:
883
+ return self.build_action_result(
884
+ False,
885
+ str(e),
886
+ None,
887
+ "",
888
+ "",
889
+ {},
890
+ error=str(e),
891
+ content=None
892
+ )
893
+
894
+ async def wait(self, seconds: int = Body(3)):
895
+ """Wait for the specified number of seconds"""
896
+ try:
897
+ await asyncio.sleep(seconds)
898
+
899
+ # Get updated state after waiting
900
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"wait({seconds} seconds)")
901
+
902
+ return self.build_action_result(
903
+ True,
904
+ f"Waited for {seconds} seconds",
905
+ dom_state,
906
+ screenshot,
907
+ elements,
908
+ metadata,
909
+ error="",
910
+ content=None
911
+ )
912
+ except Exception as e:
913
+ return self.build_action_result(
914
+ False,
915
+ str(e),
916
+ None,
917
+ "",
918
+ "",
919
+ {},
920
+ error=str(e),
921
+ content=None
922
+ )
923
+
924
+ # Element Interaction Actions
925
+
926
+ async def click_coordinates(self, action: ClickCoordinatesAction = Body(...)):
927
+ """Click at specific x,y coordinates on the page"""
928
+ try:
929
+ page = await self.get_current_page()
930
+
931
+ # Perform the click at the specified coordinates
932
+ await page.mouse.click(action.x, action.y)
933
+
934
+ # Give time for any navigation or DOM updates to occur
935
+ await page.wait_for_load_state("networkidle", timeout=5000)
936
+
937
+ # Get updated state after action
938
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"click_coordinates({action.x}, {action.y})")
939
+
940
+ return self.build_action_result(
941
+ True,
942
+ f"Clicked at coordinates ({action.x}, {action.y})",
943
+ dom_state,
944
+ screenshot,
945
+ elements,
946
+ metadata,
947
+ error="",
948
+ content=None
949
+ )
950
+ except Exception as e:
951
+ print(f"Error in click_coordinates: {e}")
952
+ traceback.print_exc()
953
+
954
+ # Try to get state even after error
955
+ try:
956
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("click_coordinates_error_recovery")
957
+ return self.build_action_result(
958
+ False,
959
+ str(e),
960
+ dom_state,
961
+ screenshot,
962
+ elements,
963
+ metadata,
964
+ error=str(e),
965
+ content=None
966
+ )
967
+ except:
968
+ return self.build_action_result(
969
+ False,
970
+ str(e),
971
+ None,
972
+ "",
973
+ "",
974
+ {},
975
+ error=str(e),
976
+ content=None
977
+ )
978
+
979
+ async def click_element(self, action: ClickElementAction = Body(...)):
980
+ """Click on an element by index"""
981
+ try:
982
+ page = await self.get_current_page()
983
+
984
+ # Get the current state and selector map *before* the click
985
+ initial_dom_state = await self.get_current_dom_state()
986
+ selector_map = initial_dom_state.selector_map
987
+
988
+ if action.index not in selector_map:
989
+ # Get updated state even if element not found initially
990
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"click_element_error (index {action.index} not found)")
991
+ return self.build_action_result(
992
+ False,
993
+ f"Element with index {action.index} not found",
994
+ dom_state, # Use the latest state
995
+ screenshot,
996
+ elements,
997
+ metadata,
998
+ error=f"Element with index {action.index} not found"
999
+ )
1000
+
1001
+ element_to_click = selector_map[action.index]
1002
+ print(f"Attempting to click element: {element_to_click}")
1003
+
1004
+ # Construct a more reliable selector using JavaScript evaluation
1005
+ # Find the element based on its properties captured in selector_map
1006
+ js_selector_script = """
1007
+ (targetElementInfo) => {
1008
+ const interactiveElements = Array.from(document.querySelectorAll(
1009
+ 'a, button, input, select, textarea, [role="button"], [role="link"], [role="checkbox"], [role="radio"], [tabindex]:not([tabindex="-1"])'
1010
+ ));
1011
+
1012
+ const visibleElements = interactiveElements.filter(el => {
1013
+ const style = window.getComputedStyle(el);
1014
+ const rect = el.getBoundingClientRect();
1015
+ return style.display !== 'none' && style.visibility !== 'hidden' && style.opacity !== '0' && rect.width > 0 && rect.height > 0;
1016
+ });
1017
+
1018
+ if (targetElementInfo.index > 0 && targetElementInfo.index <= visibleElements.length) {
1019
+ // Return the element at the specified index (1-based)
1020
+ return visibleElements[targetElementInfo.index - 1];
1021
+ }
1022
+ return null; // Element not found at the expected index
1023
+ }
1024
+ """
1025
+
1026
+ element_info = {'index': action.index} # Pass the target index to the script
1027
+
1028
+ target_element_handle = await page.evaluate_handle(js_selector_script, element_info)
1029
+
1030
+ click_success = False
1031
+ error_message = ""
1032
+
1033
+ if await target_element_handle.evaluate("node => node !== null"):
1034
+ try:
1035
+ # Use Playwright's recommended way: click the handle
1036
+ # Add timeout and wait for element to be stable
1037
+ await target_element_handle.click(timeout=5000)
1038
+ click_success = True
1039
+ print(f"Successfully clicked element handle for index {action.index}")
1040
+ except Exception as click_error:
1041
+ error_message = f"Error clicking element handle: {click_error}"
1042
+ print(error_message)
1043
+ # Optional: Add fallback methods here if needed
1044
+ # e.g., target_element_handle.dispatch_event('click')
1045
+ else:
1046
+ error_message = f"Could not locate the target element handle for index {action.index} using JS script."
1047
+ print(error_message)
1048
+
1049
+
1050
+ # Wait for potential page changes/network activity
1051
+ try:
1052
+ await page.wait_for_load_state("networkidle", timeout=5000)
1053
+ except Exception as wait_error:
1054
+ print(f"Timeout or error waiting for network idle after click: {wait_error}")
1055
+ await asyncio.sleep(1) # Fallback wait
1056
+
1057
+ # Get updated state after action
1058
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"click_element({action.index})")
1059
+
1060
+ return self.build_action_result(
1061
+ click_success,
1062
+ f"Clicked element with index {action.index}" if click_success else f"Attempted to click element {action.index} but failed. Error: {error_message}",
1063
+ dom_state,
1064
+ screenshot,
1065
+ elements,
1066
+ metadata,
1067
+ error=error_message if not click_success else "",
1068
+ content=None
1069
+ )
1070
+
1071
+ except Exception as e:
1072
+ print(f"Error in click_element: {e}")
1073
+ traceback.print_exc()
1074
+ # Try to get state even after error
1075
+ try:
1076
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("click_element_error_recovery")
1077
+ return self.build_action_result(
1078
+ False,
1079
+ str(e),
1080
+ dom_state,
1081
+ screenshot,
1082
+ elements,
1083
+ metadata,
1084
+ error=str(e),
1085
+ content=None
1086
+ )
1087
+ except:
1088
+ # Fallback if getting state also fails
1089
+ current_url = "unknown"
1090
+ try:
1091
+ current_url = page.url # Try to get at least the URL
1092
+ except:
1093
+ pass
1094
+ return self.build_action_result(
1095
+ False,
1096
+ str(e),
1097
+ None, # No DOM state available
1098
+ "", # No screenshot
1099
+ "", # No elements string
1100
+ {}, # Empty metadata
1101
+ error=str(e),
1102
+ content=None,
1103
+ fallback_url=current_url
1104
+ )
1105
+
1106
+ async def input_text(self, action: InputTextAction = Body(...)):
1107
+ """Input text into an element"""
1108
+ try:
1109
+ page = await self.get_current_page()
1110
+ selector_map = await self.get_selector_map()
1111
+
1112
+ if action.index not in selector_map:
1113
+ return self.build_action_result(
1114
+ False,
1115
+ f"Element with index {action.index} not found",
1116
+ None,
1117
+ "",
1118
+ "",
1119
+ {},
1120
+ error=f"Element with index {action.index} not found"
1121
+ )
1122
+
1123
+ # In a real implementation, we would use the selector map to get the element's
1124
+ # properties and use them to find and type into the element
1125
+ element = selector_map[action.index]
1126
+
1127
+ # Use CSS selector or XPath to locate and type into the element
1128
+ await page.wait_for_timeout(500) # Small delay before typing
1129
+
1130
+ # Demo implementation - would use proper selectors in production
1131
+ if element.attributes.get("id"):
1132
+ await page.fill(f"#{element.attributes['id']}", action.text)
1133
+ elif element.attributes.get("class"):
1134
+ class_selector = f".{element.attributes['class'].replace(' ', '.')}"
1135
+ await page.fill(class_selector, action.text)
1136
+ else:
1137
+ # Fallback to xpath
1138
+ await page.fill(f"//{element.tag_name}[{action.index}]", action.text)
1139
+
1140
+ # Get updated state after action
1141
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"input_text({action.index}, '{action.text}')")
1142
+
1143
+ return self.build_action_result(
1144
+ True,
1145
+ f"Input '{action.text}' into element with index {action.index}",
1146
+ dom_state,
1147
+ screenshot,
1148
+ elements,
1149
+ metadata,
1150
+ error="",
1151
+ content=None
1152
+ )
1153
+ except Exception as e:
1154
+ return self.build_action_result(
1155
+ False,
1156
+ str(e),
1157
+ None,
1158
+ "",
1159
+ "",
1160
+ {},
1161
+ error=str(e),
1162
+ content=None
1163
+ )
1164
+
1165
+ async def send_keys(self, action: SendKeysAction = Body(...)):
1166
+ """Send keyboard keys"""
1167
+ try:
1168
+ page = await self.get_current_page()
1169
+ await page.keyboard.press(action.keys)
1170
+
1171
+ # Get updated state after action
1172
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"send_keys({action.keys})")
1173
+
1174
+ return self.build_action_result(
1175
+ True,
1176
+ f"Sent keys: {action.keys}",
1177
+ dom_state,
1178
+ screenshot,
1179
+ elements,
1180
+ metadata,
1181
+ error="",
1182
+ content=None
1183
+ )
1184
+ except Exception as e:
1185
+ return self.build_action_result(
1186
+ False,
1187
+ str(e),
1188
+ None,
1189
+ "",
1190
+ "",
1191
+ {},
1192
+ error=str(e),
1193
+ content=None
1194
+ )
1195
+
1196
+ # Tab Management Actions
1197
+
1198
+ async def switch_tab(self, action: SwitchTabAction = Body(...)):
1199
+ """Switch to a different tab by index"""
1200
+ try:
1201
+ if 0 <= action.page_id < len(self.pages):
1202
+ self.current_page_index = action.page_id
1203
+ page = await self.get_current_page()
1204
+ await page.wait_for_load_state()
1205
+
1206
+ # Get updated state after action
1207
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"switch_tab({action.page_id})")
1208
+
1209
+ return self.build_action_result(
1210
+ True,
1211
+ f"Switched to tab {action.page_id}",
1212
+ dom_state,
1213
+ screenshot,
1214
+ elements,
1215
+ metadata,
1216
+ error="",
1217
+ content=None
1218
+ )
1219
+ else:
1220
+ return self.build_action_result(
1221
+ False,
1222
+ f"Tab {action.page_id} not found",
1223
+ None,
1224
+ "",
1225
+ "",
1226
+ {},
1227
+ error=f"Tab {action.page_id} not found"
1228
+ )
1229
+ except Exception as e:
1230
+ return self.build_action_result(
1231
+ False,
1232
+ str(e),
1233
+ None,
1234
+ "",
1235
+ "",
1236
+ {},
1237
+ error=str(e),
1238
+ content=None
1239
+ )
1240
+
1241
+ async def open_tab(self, action: OpenTabAction = Body(...)):
1242
+ """Open a new tab with the specified URL"""
1243
+ try:
1244
+ print(f"Attempting to open new tab with URL: {action.url}")
1245
+ # Create new page in same browser instance
1246
+ new_page = await self.browser.new_page()
1247
+ print(f"New page created successfully")
1248
+
1249
+ # Navigate to the URL
1250
+ await new_page.goto(action.url, wait_until="domcontentloaded")
1251
+ await new_page.wait_for_load_state("networkidle", timeout=10000)
1252
+ print(f"Navigated to URL in new tab: {action.url}")
1253
+
1254
+ # Add to page list and make it current
1255
+ self.pages.append(new_page)
1256
+ self.current_page_index = len(self.pages) - 1
1257
+ print(f"New tab added as index {self.current_page_index}")
1258
+
1259
+ # Get updated state after action
1260
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"open_tab({action.url})")
1261
+
1262
+ return self.build_action_result(
1263
+ True,
1264
+ f"Opened new tab with URL: {action.url}",
1265
+ dom_state,
1266
+ screenshot,
1267
+ elements,
1268
+ metadata,
1269
+ error="",
1270
+ content=None
1271
+ )
1272
+ except Exception as e:
1273
+ print("****"*10)
1274
+ print(f"Error opening tab: {e}")
1275
+ print(traceback.format_exc())
1276
+ print("****"*10)
1277
+ return self.build_action_result(
1278
+ False,
1279
+ str(e),
1280
+ None,
1281
+ "",
1282
+ "",
1283
+ {},
1284
+ error=str(e),
1285
+ content=None
1286
+ )
1287
+
1288
+ async def close_tab(self, action: CloseTabAction = Body(...)):
1289
+ """Close a tab by index"""
1290
+ try:
1291
+ if 0 <= action.page_id < len(self.pages):
1292
+ page = self.pages[action.page_id]
1293
+ url = page.url
1294
+ await page.close()
1295
+ self.pages.pop(action.page_id)
1296
+
1297
+ # Adjust current index if needed
1298
+ if self.current_page_index >= len(self.pages):
1299
+ self.current_page_index = max(0, len(self.pages) - 1)
1300
+ elif self.current_page_index >= action.page_id:
1301
+ self.current_page_index = max(0, self.current_page_index - 1)
1302
+
1303
+ # Get updated state after action
1304
+ page = await self.get_current_page()
1305
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"close_tab({action.page_id})")
1306
+
1307
+ return self.build_action_result(
1308
+ True,
1309
+ f"Closed tab {action.page_id} with URL: {url}",
1310
+ dom_state,
1311
+ screenshot,
1312
+ elements,
1313
+ metadata,
1314
+ error="",
1315
+ content=None
1316
+ )
1317
+ else:
1318
+ return self.build_action_result(
1319
+ False,
1320
+ f"Tab {action.page_id} not found",
1321
+ None,
1322
+ "",
1323
+ "",
1324
+ {},
1325
+ error=f"Tab {action.page_id} not found"
1326
+ )
1327
+ except Exception as e:
1328
+ return self.build_action_result(
1329
+ False,
1330
+ str(e),
1331
+ None,
1332
+ "",
1333
+ "",
1334
+ {},
1335
+ error=str(e),
1336
+ content=None
1337
+ )
1338
+
1339
+ # Content Actions
1340
+
1341
+ async def extract_content(self, goal: str = Body(...)):
1342
+ """Extract content from the current page based on the provided goal"""
1343
+ try:
1344
+ page = await self.get_current_page()
1345
+ content = await page.content()
1346
+
1347
+ # In a full implementation, we would use an LLM to extract specific content
1348
+ # based on the goal. For this example, we'll extract visible text.
1349
+ extracted_text = await page.evaluate("""
1350
+ Array.from(document.querySelectorAll('p, h1, h2, h3, h4, h5, h6, li, span, div'))
1351
+ .filter(el => {
1352
+ const style = window.getComputedStyle(el);
1353
+ return style.display !== 'none' &&
1354
+ style.visibility !== 'hidden' &&
1355
+ style.opacity !== '0' &&
1356
+ el.innerText &&
1357
+ el.innerText.trim().length > 0;
1358
+ })
1359
+ .map(el => el.innerText.trim())
1360
+ .join('\\n\\n');
1361
+ """)
1362
+
1363
+ # Get updated state
1364
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"extract_content({goal})")
1365
+
1366
+ return self.build_action_result(
1367
+ True,
1368
+ f"Content extracted based on goal: {goal}",
1369
+ dom_state,
1370
+ screenshot,
1371
+ elements,
1372
+ metadata,
1373
+ error="",
1374
+ content=extracted_text
1375
+ )
1376
+ except Exception as e:
1377
+ return self.build_action_result(
1378
+ False,
1379
+ str(e),
1380
+ None,
1381
+ "",
1382
+ "",
1383
+ {},
1384
+ error=str(e),
1385
+ content=None
1386
+ )
1387
+
1388
+ async def save_pdf(self):
1389
+ """Save the current page as a PDF"""
1390
+ try:
1391
+ page = await self.get_current_page()
1392
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
1393
+ random_id = random.randint(1000, 9999)
1394
+ filename = f"page_{timestamp}_{random_id}.pdf"
1395
+ filepath = os.path.join(self.screenshot_dir, filename)
1396
+
1397
+ await page.pdf(path=filepath)
1398
+
1399
+ # Get updated state
1400
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state("save_pdf")
1401
+
1402
+ return self.build_action_result(
1403
+ True,
1404
+ f"Saved page as PDF: {filepath}",
1405
+ dom_state,
1406
+ screenshot,
1407
+ elements,
1408
+ metadata,
1409
+ error="",
1410
+ content=None
1411
+ )
1412
+ except Exception as e:
1413
+ return self.build_action_result(
1414
+ False,
1415
+ str(e),
1416
+ None,
1417
+ "",
1418
+ "",
1419
+ {},
1420
+ error=str(e),
1421
+ content=None
1422
+ )
1423
+
1424
+ # Scroll Actions
1425
+
1426
+ async def scroll_down(self, action: ScrollAction = Body(...)):
1427
+ """Scroll down the page"""
1428
+ try:
1429
+ page = await self.get_current_page()
1430
+ if action.amount is not None:
1431
+ await page.evaluate(f"window.scrollBy(0, {action.amount});")
1432
+ amount_str = f"{action.amount} pixels"
1433
+ else:
1434
+ await page.evaluate("window.scrollBy(0, window.innerHeight);")
1435
+ amount_str = "one page"
1436
+
1437
+ await page.wait_for_timeout(500) # Wait for scroll to complete
1438
+
1439
+ # Get updated state after action
1440
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"scroll_down({amount_str})")
1441
+
1442
+ return self.build_action_result(
1443
+ True,
1444
+ f"Scrolled down by {amount_str}",
1445
+ dom_state,
1446
+ screenshot,
1447
+ elements,
1448
+ metadata,
1449
+ error="",
1450
+ content=None
1451
+ )
1452
+ except Exception as e:
1453
+ return self.build_action_result(
1454
+ False,
1455
+ str(e),
1456
+ None,
1457
+ "",
1458
+ "",
1459
+ {},
1460
+ error=str(e),
1461
+ content=None
1462
+ )
1463
+
1464
+ async def scroll_up(self, action: ScrollAction = Body(...)):
1465
+ """Scroll up the page"""
1466
+ try:
1467
+ page = await self.get_current_page()
1468
+ if action.amount is not None:
1469
+ await page.evaluate(f"window.scrollBy(0, -{action.amount});")
1470
+ amount_str = f"{action.amount} pixels"
1471
+ else:
1472
+ await page.evaluate("window.scrollBy(0, -window.innerHeight);")
1473
+ amount_str = "one page"
1474
+
1475
+ await page.wait_for_timeout(500) # Wait for scroll to complete
1476
+
1477
+ # Get updated state after action
1478
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"scroll_up({amount_str})")
1479
+
1480
+ return self.build_action_result(
1481
+ True,
1482
+ f"Scrolled up by {amount_str}",
1483
+ dom_state,
1484
+ screenshot,
1485
+ elements,
1486
+ metadata,
1487
+ error="",
1488
+ content=None
1489
+ )
1490
+ except Exception as e:
1491
+ return self.build_action_result(
1492
+ False,
1493
+ str(e),
1494
+ None,
1495
+ "",
1496
+ "",
1497
+ {},
1498
+ error=str(e),
1499
+ content=None
1500
+ )
1501
+
1502
+ async def scroll_to_text(self, text: str = Body(...)):
1503
+ """Scroll to text on the page"""
1504
+ try:
1505
+ page = await self.get_current_page()
1506
+ locators = [
1507
+ page.get_by_text(text, exact=False),
1508
+ page.locator(f"text={text}"),
1509
+ page.locator(f"//*[contains(text(), '{text}')]"),
1510
+ ]
1511
+
1512
+ found = False
1513
+ for locator in locators:
1514
+ try:
1515
+ if await locator.count() > 0 and await locator.first.is_visible():
1516
+ await locator.first.scroll_into_view_if_needed()
1517
+ await asyncio.sleep(0.5) # Wait for scroll to complete
1518
+ found = True
1519
+ break
1520
+ except Exception:
1521
+ continue
1522
+
1523
+ # Get updated state after action
1524
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"scroll_to_text({text})")
1525
+
1526
+ message = f"Scrolled to text: {text}" if found else f"Text '{text}' not found or not visible on page"
1527
+
1528
+ return self.build_action_result(
1529
+ found,
1530
+ message,
1531
+ dom_state,
1532
+ screenshot,
1533
+ elements,
1534
+ metadata,
1535
+ error="",
1536
+ content=None
1537
+ )
1538
+ except Exception as e:
1539
+ return self.build_action_result(
1540
+ False,
1541
+ str(e),
1542
+ None,
1543
+ "",
1544
+ "",
1545
+ {},
1546
+ error=str(e),
1547
+ content=None
1548
+ )
1549
+
1550
+ # Dropdown Actions
1551
+
1552
+ async def get_dropdown_options(self, index: int = Body(...)):
1553
+ """Get all options from a dropdown"""
1554
+ try:
1555
+ page = await self.get_current_page()
1556
+ selector_map = await self.get_selector_map()
1557
+
1558
+ if index not in selector_map:
1559
+ return self.build_action_result(
1560
+ False,
1561
+ f"Element with index {index} not found",
1562
+ None,
1563
+ "",
1564
+ "",
1565
+ {},
1566
+ error=f"Element with index {index} not found"
1567
+ )
1568
+
1569
+ element = selector_map[index]
1570
+ options = []
1571
+
1572
+ # Try to get the options - in a real implementation, we would use appropriate selectors
1573
+ try:
1574
+ if element.tag_name.lower() == 'select':
1575
+ # For <select> elements, get options using JavaScript
1576
+ options_js = f"""
1577
+ Array.from(document.querySelectorAll('select')[{index-1}].options)
1578
+ .map((option, index) => ({
1579
+ index: index,
1580
+ text: option.text,
1581
+ value: option.value
1582
+ }));
1583
+ """
1584
+ options = await page.evaluate(options_js)
1585
+ else:
1586
+ # For other dropdown types, try to get options using a more generic approach
1587
+ # Example for custom dropdowns - would need refinement in real implementation
1588
+ await page.click(f"#{element.attributes.get('id')}") if element.attributes.get('id') else None
1589
+ await page.wait_for_timeout(500)
1590
+
1591
+ options_js = """
1592
+ Array.from(document.querySelectorAll('.dropdown-item, [role="option"], li'))
1593
+ .filter(el => {
1594
+ const style = window.getComputedStyle(el);
1595
+ return style.display !== 'none' && style.visibility !== 'hidden';
1596
+ })
1597
+ .map((option, index) => ({
1598
+ index: index,
1599
+ text: option.innerText.trim(),
1600
+ value: option.getAttribute('value') || option.getAttribute('data-value') || option.innerText.trim()
1601
+ }));
1602
+ """
1603
+ options = await page.evaluate(options_js)
1604
+
1605
+ # Close dropdown to restore state
1606
+ await page.keyboard.press("Escape")
1607
+ except Exception as e:
1608
+ self.logger.error(f"Error getting dropdown options: {e}")
1609
+ # Fallback to dummy options if real ones cannot be retrieved
1610
+ options = [
1611
+ {"index": 0, "text": "Option 1", "value": "option1"},
1612
+ {"index": 1, "text": "Option 2", "value": "option2"},
1613
+ {"index": 2, "text": "Option 3", "value": "option3"},
1614
+ ]
1615
+
1616
+ # Get updated state
1617
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"get_dropdown_options({index})")
1618
+
1619
+ return self.build_action_result(
1620
+ True,
1621
+ f"Retrieved {len(options)} options from dropdown",
1622
+ dom_state,
1623
+ screenshot,
1624
+ elements,
1625
+ metadata,
1626
+ error="",
1627
+ content=json.dumps(options) # Include options in the content field
1628
+ )
1629
+ except Exception as e:
1630
+ return self.build_action_result(
1631
+ False,
1632
+ str(e),
1633
+ None,
1634
+ "",
1635
+ "",
1636
+ {},
1637
+ error=str(e),
1638
+ content=None
1639
+ )
1640
+
1641
+ async def select_dropdown_option(self, index: int = Body(...), option_text: str = Body(...)):
1642
+ """Select an option from a dropdown by text"""
1643
+ try:
1644
+ page = await self.get_current_page()
1645
+ selector_map = await self.get_selector_map()
1646
+
1647
+ if index not in selector_map:
1648
+ return self.build_action_result(
1649
+ False,
1650
+ f"Element with index {index} not found",
1651
+ None,
1652
+ "",
1653
+ "",
1654
+ {},
1655
+ error=f"Element with index {index} not found"
1656
+ )
1657
+
1658
+ element = selector_map[index]
1659
+
1660
+ # Try to select the option - implementation varies by dropdown type
1661
+ if element.tag_name.lower() == 'select':
1662
+ # For standard <select> elements
1663
+ selector = f"select option:has-text('{option_text}')"
1664
+ await page.select_option(
1665
+ f"#{element.attributes.get('id')}" if element.attributes.get('id') else f"//select[{index}]",
1666
+ label=option_text
1667
+ )
1668
+ else:
1669
+ # For custom dropdowns
1670
+ # First click to open the dropdown
1671
+ if element.attributes.get('id'):
1672
+ await page.click(f"#{element.attributes.get('id')}")
1673
+ else:
1674
+ await page.click(f"//{element.tag_name}[{index}]")
1675
+
1676
+ await page.wait_for_timeout(500)
1677
+
1678
+ # Then try to click the option
1679
+ await page.click(f"text={option_text}")
1680
+
1681
+ await page.wait_for_timeout(500)
1682
+
1683
+ # Get updated state after action
1684
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"select_dropdown_option({index}, '{option_text}')")
1685
+
1686
+ return self.build_action_result(
1687
+ True,
1688
+ f"Selected option '{option_text}' from dropdown with index {index}",
1689
+ dom_state,
1690
+ screenshot,
1691
+ elements,
1692
+ metadata,
1693
+ error="",
1694
+ content=None
1695
+ )
1696
+ except Exception as e:
1697
+ return self.build_action_result(
1698
+ False,
1699
+ str(e),
1700
+ None,
1701
+ "",
1702
+ "",
1703
+ {},
1704
+ error=str(e),
1705
+ content=None
1706
+ )
1707
+
1708
+ # Drag and Drop
1709
+
1710
+ async def drag_drop(self, action: DragDropAction = Body(...)):
1711
+ """Perform drag and drop operation"""
1712
+ try:
1713
+ page = await self.get_current_page()
1714
+
1715
+ # Element-based drag and drop
1716
+ if action.element_source and action.element_target:
1717
+ # In a real implementation, we would get the elements and perform the drag
1718
+ source_desc = action.element_source
1719
+ target_desc = action.element_target
1720
+
1721
+ # We would locate the elements using selectors and perform the drag
1722
+ # For this example, we'll use a simplified version
1723
+ await page.evaluate("""
1724
+ console.log("Simulating drag and drop between elements");
1725
+ """)
1726
+
1727
+ message = f"Dragged element '{source_desc}' to '{target_desc}'"
1728
+
1729
+ # Coordinate-based drag and drop
1730
+ elif all(coord is not None for coord in [
1731
+ action.coord_source_x, action.coord_source_y,
1732
+ action.coord_target_x, action.coord_target_y
1733
+ ]):
1734
+ source_x = action.coord_source_x
1735
+ source_y = action.coord_source_y
1736
+ target_x = action.coord_target_x
1737
+ target_y = action.coord_target_y
1738
+
1739
+ # Perform the drag
1740
+ await page.mouse.move(source_x, source_y)
1741
+ await page.mouse.down()
1742
+
1743
+ steps = max(1, action.steps or 10)
1744
+ delay_ms = max(0, action.delay_ms or 5)
1745
+
1746
+ for i in range(1, steps + 1):
1747
+ ratio = i / steps
1748
+ intermediate_x = int(source_x + (target_x - source_x) * ratio)
1749
+ intermediate_y = int(source_y + (target_y - source_y) * ratio)
1750
+ await page.mouse.move(intermediate_x, intermediate_y)
1751
+ if delay_ms > 0:
1752
+ await asyncio.sleep(delay_ms / 1000)
1753
+
1754
+ await page.mouse.move(target_x, target_y)
1755
+ await page.mouse.up()
1756
+
1757
+ message = f"Dragged from ({source_x}, {source_y}) to ({target_x}, {target_y})"
1758
+ else:
1759
+ return self.build_action_result(
1760
+ False,
1761
+ "Must provide either source/target selectors or coordinates",
1762
+ None,
1763
+ "",
1764
+ "",
1765
+ {},
1766
+ error="Must provide either source/target selectors or coordinates"
1767
+ )
1768
+
1769
+ # Get updated state after action
1770
+ dom_state, screenshot, elements, metadata = await self.get_updated_browser_state(f"drag_drop({action.element_source}, {action.element_target})")
1771
+
1772
+ return self.build_action_result(
1773
+ True,
1774
+ message,
1775
+ dom_state,
1776
+ screenshot,
1777
+ elements,
1778
+ metadata,
1779
+ error="",
1780
+ content=None
1781
+ )
1782
+ except Exception as e:
1783
+ return self.build_action_result(
1784
+ False,
1785
+ str(e),
1786
+ None,
1787
+ "",
1788
+ "",
1789
+ {},
1790
+ error=str(e),
1791
+ content=None
1792
+ )
1793
+
1794
+ # Create singleton instance
1795
+ automation_service = BrowserAutomation()
1796
+
1797
+ # Create API app
1798
+ api_app = FastAPI()
1799
+
1800
+ @api_app.get("/api")
1801
+ async def health_check():
1802
+ return {"status": "ok", "message": "API server is running"}
1803
+
1804
+ # Include automation service router with /api prefix
1805
+ api_app.include_router(automation_service.router, prefix="/api")
1806
+
1807
+ async def test_browser_api():
1808
+ """Test the browser automation API functionality"""
1809
+ try:
1810
+ # Initialize browser automation
1811
+ print("\n=== Starting Browser Automation Test ===")
1812
+ await automation_service.startup()
1813
+ print("✅ Browser started successfully")
1814
+
1815
+ # Navigate to a test page with interactive elements
1816
+ print("\n--- Testing Navigation ---")
1817
+ result = await automation_service.navigate_to(GoToUrlAction(url="https://www.youtube.com"))
1818
+ print(f"Navigation status: {'✅ Success' if result.success else '❌ Failed'}")
1819
+ if not result.success:
1820
+ print(f"Error: {result.error}")
1821
+ return
1822
+
1823
+ print(f"URL: {result.url}")
1824
+ print(f"Title: {result.title}")
1825
+
1826
+ # Check DOM state and elements
1827
+ print(f"\nFound {result.element_count} interactive elements")
1828
+ if result.elements and result.elements.strip():
1829
+ print("Elements:")
1830
+ print(result.elements)
1831
+ else:
1832
+ print("No formatted elements found, but DOM was processed")
1833
+
1834
+ # Display interactive elements as JSON
1835
+ if result.interactive_elements and len(result.interactive_elements) > 0:
1836
+ print("\nInteractive elements summary:")
1837
+ for el in result.interactive_elements:
1838
+ print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
1839
+
1840
+ # Screenshot info
1841
+ print(f"\nScreenshot captured: {'Yes' if result.screenshot_base64 else 'No'}")
1842
+ print(f"Viewport size: {result.viewport_width}x{result.viewport_height}")
1843
+
1844
+ # Test OCR extraction from screenshot
1845
+ print("\n--- Testing OCR Text Extraction ---")
1846
+ if result.ocr_text:
1847
+ print("OCR text extracted from screenshot:")
1848
+ print("=== OCR TEXT START ===")
1849
+ print(result.ocr_text)
1850
+ print("=== OCR TEXT END ===")
1851
+ print(f"OCR text length: {len(result.ocr_text)} characters")
1852
+ print(result.ocr_text)
1853
+ else:
1854
+ print("No OCR text extracted from screenshot")
1855
+
1856
+ await asyncio.sleep(2)
1857
+
1858
+ # Test search functionality
1859
+ print("\n--- Testing Search ---")
1860
+ result = await automation_service.search_google(SearchGoogleAction(query="browser automation"))
1861
+ print(f"Search status: {'✅ Success' if result.success else '❌ Failed'}")
1862
+ if not result.success:
1863
+ print(f"Error: {result.error}")
1864
+ else:
1865
+ print(f"Found {result.element_count} elements after search")
1866
+ print(f"Page title: {result.title}")
1867
+
1868
+ # Test OCR extraction from search results
1869
+ if result.ocr_text:
1870
+ print("\nOCR text from search results:")
1871
+ print("=== OCR TEXT START ===")
1872
+ print(result.ocr_text)
1873
+ print("=== OCR TEXT END ===")
1874
+ else:
1875
+ print("\nNo OCR text extracted from search results")
1876
+
1877
+ await asyncio.sleep(2)
1878
+
1879
+ # Test scrolling
1880
+ print("\n--- Testing Scrolling ---")
1881
+ result = await automation_service.scroll_down(ScrollAction(amount=300))
1882
+ print(f"Scroll status: {'✅ Success' if result.success else '❌ Failed'}")
1883
+ if result.success:
1884
+ print(f"Pixels above viewport: {result.pixels_above}")
1885
+ print(f"Pixels below viewport: {result.pixels_below}")
1886
+
1887
+ await asyncio.sleep(2)
1888
+
1889
+ # Test clicking on an element
1890
+ print("\n--- Testing Element Click ---")
1891
+ if result.element_count > 0:
1892
+ click_result = await automation_service.click_element(ClickElementAction(index=1))
1893
+ print(f"Click status: {'✅ Success' if click_result.success else '❌ Failed'}")
1894
+ print(f"Message: {click_result.message}")
1895
+ print(f"New URL after click: {click_result.url}")
1896
+ else:
1897
+ print("Skipping click test - no elements found")
1898
+
1899
+ await asyncio.sleep(2)
1900
+
1901
+ # Test clicking on coordinates
1902
+ print("\n--- Testing Click Coordinates ---")
1903
+ coord_click_result = await automation_service.click_coordinates(ClickCoordinatesAction(x=100, y=100))
1904
+ print(f"Coordinate click status: {'✅ Success' if coord_click_result.success else '❌ Failed'}")
1905
+ print(f"Message: {coord_click_result.message}")
1906
+ print(f"URL after coordinate click: {coord_click_result.url}")
1907
+
1908
+ await asyncio.sleep(2)
1909
+
1910
+ # Test extracting content
1911
+ print("\n--- Testing Content Extraction ---")
1912
+ content_result = await automation_service.extract_content("test goal")
1913
+ print(f"Content extraction status: {'✅ Success' if content_result.success else '❌ Failed'}")
1914
+ if content_result.content:
1915
+ content_preview = content_result.content[:100] + "..." if len(content_result.content) > 100 else content_result.content
1916
+ print(f"Content sample: {content_preview}")
1917
+ print(f"Total content length: {len(content_result.content)} chars")
1918
+ else:
1919
+ print("No content was extracted")
1920
+
1921
+ # Test tab management
1922
+ print("\n--- Testing Tab Management ---")
1923
+ tab_result = await automation_service.open_tab(OpenTabAction(url="https://www.example.org"))
1924
+ print(f"New tab status: {'✅ Success' if tab_result.success else '❌ Failed'}")
1925
+ if tab_result.success:
1926
+ print(f"New tab title: {tab_result.title}")
1927
+ print(f"Interactive elements: {tab_result.element_count}")
1928
+
1929
+ print("\n✅ All tests completed successfully!")
1930
+
1931
+ except Exception as e:
1932
+ print(f"\n❌ Test failed: {str(e)}")
1933
+ traceback.print_exc()
1934
+ finally:
1935
+ # Ensure browser is closed
1936
+ print("\n--- Cleaning up ---")
1937
+ await automation_service.shutdown()
1938
+ print("Browser closed")
1939
+
1940
+ async def test_browser_api_2():
1941
+ """Test the browser automation API functionality on the chess page"""
1942
+ try:
1943
+ # Initialize browser automation
1944
+ print("\n=== Starting Browser Automation Test 2 (Chess Page) ===")
1945
+ await automation_service.startup()
1946
+ print("✅ Browser started successfully")
1947
+
1948
+ # Navigate to the chess test page
1949
+ print("\n--- Testing Navigation to Chess Page ---")
1950
+ test_url = "https://dat-lequoc.github.io/chess-for-suna/chess.html"
1951
+ result = await automation_service.navigate_to(GoToUrlAction(url=test_url))
1952
+ print(f"Navigation status: {'✅ Success' if result.success else '❌ Failed'}")
1953
+ if not result.success:
1954
+ print(f"Error: {result.error}")
1955
+ return
1956
+
1957
+ print(f"URL: {result.url}")
1958
+ print(f"Title: {result.title}")
1959
+
1960
+ # Check DOM state and elements
1961
+ print(f"\nFound {result.element_count} interactive elements")
1962
+ if result.elements and result.elements.strip():
1963
+ print("Elements:")
1964
+ print(result.elements)
1965
+ else:
1966
+ print("No formatted elements found, but DOM was processed")
1967
+
1968
+ # Display interactive elements as JSON
1969
+ if result.interactive_elements and len(result.interactive_elements) > 0:
1970
+ print("\nInteractive elements summary:")
1971
+ for el in result.interactive_elements:
1972
+ print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
1973
+
1974
+ # Screenshot info
1975
+ print(f"\nScreenshot captured: {'Yes' if result.screenshot_base64 else 'No'}")
1976
+ print(f"Viewport size: {result.viewport_width}x{result.viewport_height}")
1977
+
1978
+ await asyncio.sleep(2)
1979
+
1980
+ # Test clicking on an element (e.g., a chess square)
1981
+ print("\n--- Testing Element Click (element 5) ---")
1982
+ if result.element_count > 4: # Ensure element 5 exists
1983
+ click_index = 5
1984
+ click_result = await automation_service.click_element(ClickElementAction(index=click_index))
1985
+ print(f"Click status for element {click_index}: {'✅ Success' if click_result.success else '❌ Failed'}")
1986
+ print(f"Message: {click_result.message}")
1987
+ print(f"URL after click: {click_result.url}")
1988
+
1989
+ # Retrieve and display elements again after click
1990
+ print(f"\n--- Retrieving elements after clicking element {click_index} ---")
1991
+ if click_result.elements and click_result.elements.strip():
1992
+ print("Updated Elements:")
1993
+ print(click_result.elements)
1994
+ else:
1995
+ print("No formatted elements found after click.")
1996
+
1997
+ if click_result.interactive_elements and len(click_result.interactive_elements) > 0:
1998
+ print("\nUpdated interactive elements summary:")
1999
+ for el in click_result.interactive_elements:
2000
+ print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
2001
+ else:
2002
+ print("No interactive elements found after click.")
2003
+
2004
+ # Test clicking element 1 after the first click
2005
+ print("\n--- Testing Element Click (element 1 after clicking 5) ---")
2006
+ if click_result.element_count > 0: # Check if there are still elements
2007
+ click_index_2 = 1
2008
+ click_result_2 = await automation_service.click_element(ClickElementAction(index=click_index_2))
2009
+ print(f"Click status for element {click_index_2}: {'✅ Success' if click_result_2.success else '❌ Failed'}")
2010
+ print(f"Message: {click_result_2.message}")
2011
+ print(f"URL after click: {click_result_2.url}")
2012
+
2013
+ # Retrieve and display elements again after the second click
2014
+ print(f"\n--- Retrieving elements after clicking element {click_index_2} ---")
2015
+ if click_result_2.elements and click_result_2.elements.strip():
2016
+ print("Elements after second click:")
2017
+ print(click_result_2.elements)
2018
+ else:
2019
+ print("No formatted elements found after second click.")
2020
+
2021
+ if click_result_2.interactive_elements and len(click_result_2.interactive_elements) > 0:
2022
+ print("\nInteractive elements summary after second click:")
2023
+ for el in click_result_2.interactive_elements:
2024
+ print(f" [{el['index']}] <{el['tag_name']}> {el.get('text', '')[:30]}")
2025
+ else:
2026
+ print("No interactive elements found after second click.")
2027
+ else:
2028
+ print("Skipping second element click test - no elements found after first click.")
2029
+
2030
+ else:
2031
+ print("Skipping element click test - fewer than 5 elements found.")
2032
+
2033
+ await asyncio.sleep(2)
2034
+
2035
+ print("\n✅ Chess Page Test Completed!")
2036
+ await asyncio.sleep(100)
2037
+
2038
+ except Exception as e:
2039
+ print(f"\n❌ Chess Page Test failed: {str(e)}")
2040
+ traceback.print_exc()
2041
+ finally:
2042
+ # Ensure browser is closed
2043
+ print("\n--- Cleaning up ---")
2044
+ await automation_service.shutdown()
2045
+ print("Browser closed")
2046
+
2047
+ if __name__ == '__main__':
2048
+ import uvicorn
2049
+ import sys
2050
+
2051
+ # Check command line arguments for test mode
2052
+ test_mode_1 = "--test" in sys.argv
2053
+ test_mode_2 = "--test2" in sys.argv
2054
+
2055
+ if test_mode_1:
2056
+ print("Running in test mode 1")
2057
+ asyncio.run(test_browser_api())
2058
+ elif test_mode_2:
2059
+ print("Running in test mode 2 (Chess Page)")
2060
+ asyncio.run(test_browser_api_2())
2061
+ else:
2062
+ print("Starting API server")
2063
+ uvicorn.run("browser_api:api_app", host="0.0.0.0", port=8002)
chat_template.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "chat_template": "{% set audio_count = namespace(value=0) %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_bos|><|IMAGE|><|vision_eos|>{% elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_audio_id %}Audio {{ audio_count.value }}: {% endif %}<|audio_bos|><|AUDIO|><|audio_eos|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_bos|><|VIDEO|><|vision_eos|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
3
+ }
cleanup.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Print colored output
4
+ GREEN='\033[0;32m'
5
+ BLUE='\033[0;34m'
6
+ RED='\033[0;31m'
7
+ NC='\033[0m' # No Color
8
+
9
+ echo -e "${RED}Cleaning up all services...${NC}"
10
+
11
+ # Determine the script and project directories
12
+ SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
13
+ PROJECT_ROOT="$SCRIPT_DIR"
14
+ if [[ "$SCRIPT_DIR" == */scripts ]]; then
15
+ PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
16
+ fi
17
+
18
+ # Stop all running background processes from previous runs
19
+ echo -e "${BLUE}Stopping background processes...${NC}"
20
+ pkill -f "uvicorn api:app"
21
+ pkill -f "npm run dev"
22
+
23
+ # Stop Redis container if running
24
+ echo -e "${BLUE}Stopping Redis container...${NC}"
25
+ docker stop agentpress-redis 2>/dev/null || true
26
+ docker rm agentpress-redis 2>/dev/null || true
27
+
28
+ # Stop Supabase
29
+ echo -e "${BLUE}Stopping Supabase...${NC}"
30
+ cd "$PROJECT_ROOT/backend/supabase"
31
+ supabase stop 2>/dev/null || true
32
+ cd "$SCRIPT_DIR"
33
+
34
+ echo -e "${GREEN}Cleanup complete. You can now start the services again.${NC}"
compose-dev.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ app:
3
+ entrypoint:
4
+ - sleep
5
+ - infinity
6
+ image: docker/dev-environments-javascript:stable-1
7
+ init: true
8
+ volumes:
9
+ - type: bind
10
+ source: /var/run/docker.sock
11
+ target: /var/run/docker.sock
12
+
computer_use_tool.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import base64
4
+ import aiohttp
5
+ import asyncio
6
+ import logging
7
+ from typing import Optional, Dict, Any, Union
8
+ from PIL import Image
9
+
10
+ from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
11
+ from sandbox.sandbox import SandboxToolsBase, Sandbox
12
+
13
+ KEYBOARD_KEYS = [
14
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
15
+ 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
16
+ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
17
+ 'enter', 'esc', 'backspace', 'tab', 'space', 'delete',
18
+ 'ctrl', 'alt', 'shift', 'win',
19
+ 'up', 'down', 'left', 'right',
20
+ 'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12',
21
+ 'ctrl+c', 'ctrl+v', 'ctrl+x', 'ctrl+z', 'ctrl+a', 'ctrl+s',
22
+ 'alt+tab', 'alt+f4', 'ctrl+alt+delete'
23
+ ]
24
+
25
+ class ComputerUseTool(SandboxToolsBase):
26
+ """Computer automation tool for controlling the sandbox browser and GUI."""
27
+
28
+ def __init__(self, sandbox: Sandbox):
29
+ """Initialize automation tool with sandbox connection."""
30
+ super().__init__(sandbox)
31
+ self.session = None
32
+ self.mouse_x = 0 # Track current mouse position
33
+ self.mouse_y = 0
34
+ # Get automation service URL using port 8000
35
+ self.api_base_url = self.sandbox.get_preview_link(8000)
36
+ logging.info(f"Initialized Computer Use Tool with API URL: {self.api_base_url}")
37
+
38
+ async def _get_session(self) -> aiohttp.ClientSession:
39
+ """Get or create aiohttp session for API requests."""
40
+ if self.session is None or self.session.closed:
41
+ self.session = aiohttp.ClientSession()
42
+ return self.session
43
+
44
+ async def _api_request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict:
45
+ """Send request to automation service API."""
46
+ try:
47
+ session = await self._get_session()
48
+ url = f"{self.api_base_url}/api{endpoint}"
49
+
50
+ logging.debug(f"API request: {method} {url} {data}")
51
+
52
+ if method.upper() == "GET":
53
+ async with session.get(url) as response:
54
+ result = await response.json()
55
+ else: # POST
56
+ async with session.post(url, json=data) as response:
57
+ result = await response.json()
58
+
59
+ logging.debug(f"API response: {result}")
60
+ return result
61
+
62
+ except Exception as e:
63
+ logging.error(f"API request failed: {str(e)}")
64
+ return {"success": False, "error": str(e)}
65
+
66
+ async def cleanup(self):
67
+ """Clean up resources."""
68
+ if self.session and not self.session.closed:
69
+ await self.session.close()
70
+ self.session = None
71
+
72
+ @openapi_schema({
73
+ "type": "function",
74
+ "function": {
75
+ "name": "move_to",
76
+ "description": "Move cursor to specified position",
77
+ "parameters": {
78
+ "type": "object",
79
+ "properties": {
80
+ "x": {
81
+ "type": "number",
82
+ "description": "X coordinate"
83
+ },
84
+ "y": {
85
+ "type": "number",
86
+ "description": "Y coordinate"
87
+ }
88
+ },
89
+ "required": ["x", "y"]
90
+ }
91
+ }
92
+ })
93
+ @xml_schema(
94
+ tag_name="move-to",
95
+ mappings=[
96
+ {"param_name": "x", "node_type": "attribute", "path": "."},
97
+ {"param_name": "y", "node_type": "attribute", "path": "."}
98
+ ],
99
+ example='''
100
+ <move-to x="100" y="200">
101
+ </move-to>
102
+ '''
103
+ )
104
+ async def move_to(self, x: float, y: float) -> ToolResult:
105
+ """Move cursor to specified position."""
106
+ try:
107
+ x_int = int(round(float(x)))
108
+ y_int = int(round(float(y)))
109
+
110
+ result = await self._api_request("POST", "/automation/mouse/move", {
111
+ "x": x_int,
112
+ "y": y_int
113
+ })
114
+
115
+ if result.get("success", False):
116
+ self.mouse_x = x_int
117
+ self.mouse_y = y_int
118
+ return ToolResult(success=True, output=f"Moved to ({x_int}, {y_int})")
119
+ else:
120
+ return ToolResult(success=False, output=f"Failed to move: {result.get('error', 'Unknown error')}")
121
+
122
+ except Exception as e:
123
+ return ToolResult(success=False, output=f"Failed to move: {str(e)}")
124
+
125
+ @openapi_schema({
126
+ "type": "function",
127
+ "function": {
128
+ "name": "click",
129
+ "description": "Click at current or specified position",
130
+ "parameters": {
131
+ "type": "object",
132
+ "properties": {
133
+ "button": {
134
+ "type": "string",
135
+ "description": "Mouse button to click",
136
+ "enum": ["left", "right", "middle"],
137
+ "default": "left"
138
+ },
139
+ "x": {
140
+ "type": "number",
141
+ "description": "Optional X coordinate"
142
+ },
143
+ "y": {
144
+ "type": "number",
145
+ "description": "Optional Y coordinate"
146
+ },
147
+ "num_clicks": {
148
+ "type": "integer",
149
+ "description": "Number of clicks",
150
+ "enum": [1, 2, 3],
151
+ "default": 1
152
+ }
153
+ }
154
+ }
155
+ }
156
+ })
157
+ @xml_schema(
158
+ tag_name="click",
159
+ mappings=[
160
+ {"param_name": "x", "node_type": "attribute", "path": "x"},
161
+ {"param_name": "y", "node_type": "attribute", "path": "y"},
162
+ {"param_name": "button", "node_type": "attribute", "path": "button"},
163
+ {"param_name": "num_clicks", "node_type": "attribute", "path": "num_clicks"}
164
+ ],
165
+ example='''
166
+ <click x="100" y="200" button="left" num_clicks="1">
167
+ </click>
168
+ '''
169
+ )
170
+ async def click(self, x: Optional[float] = None, y: Optional[float] = None,
171
+ button: str = "left", num_clicks: int = 1) -> ToolResult:
172
+ """Click at current or specified position."""
173
+ try:
174
+ x_val = x if x is not None else self.mouse_x
175
+ y_val = y if y is not None else self.mouse_y
176
+
177
+ x_int = int(round(float(x_val)))
178
+ y_int = int(round(float(y_val)))
179
+ num_clicks = int(num_clicks)
180
+
181
+ result = await self._api_request("POST", "/automation/mouse/click", {
182
+ "x": x_int,
183
+ "y": y_int,
184
+ "clicks": num_clicks,
185
+ "button": button.lower()
186
+ })
187
+
188
+ if result.get("success", False):
189
+ self.mouse_x = x_int
190
+ self.mouse_y = y_int
191
+ return ToolResult(success=True,
192
+ output=f"{num_clicks} {button} click(s) performed at ({x_int}, {y_int})")
193
+ else:
194
+ return ToolResult(success=False, output=f"Failed to click: {result.get('error', 'Unknown error')}")
195
+ except Exception as e:
196
+ return ToolResult(success=False, output=f"Failed to click: {str(e)}")
197
+
198
+ @openapi_schema({
199
+ "type": "function",
200
+ "function": {
201
+ "name": "scroll",
202
+ "description": "Scroll the mouse wheel at current position",
203
+ "parameters": {
204
+ "type": "object",
205
+ "properties": {
206
+ "amount": {
207
+ "type": "integer",
208
+ "description": "Scroll amount (positive for up, negative for down)",
209
+ "minimum": -10,
210
+ "maximum": 10
211
+ }
212
+ },
213
+ "required": ["amount"]
214
+ }
215
+ }
216
+ })
217
+ @xml_schema(
218
+ tag_name="scroll",
219
+ mappings=[
220
+ {"param_name": "amount", "node_type": "attribute", "path": "amount"}
221
+ ],
222
+ example='''
223
+ <scroll amount="-3">
224
+ </scroll>
225
+ '''
226
+ )
227
+ async def scroll(self, amount: int) -> ToolResult:
228
+ """
229
+ Scroll the mouse wheel at current position.
230
+ Positive values scroll up, negative values scroll down.
231
+ """
232
+ try:
233
+ amount = int(float(amount))
234
+ amount = max(-10, min(10, amount))
235
+
236
+ result = await self._api_request("POST", "/automation/mouse/scroll", {
237
+ "clicks": amount,
238
+ "x": self.mouse_x,
239
+ "y": self.mouse_y
240
+ })
241
+
242
+ if result.get("success", False):
243
+ direction = "up" if amount > 0 else "down"
244
+ steps = abs(amount)
245
+ return ToolResult(success=True,
246
+ output=f"Scrolled {direction} {steps} step(s) at position ({self.mouse_x}, {self.mouse_y})")
247
+ else:
248
+ return ToolResult(success=False, output=f"Failed to scroll: {result.get('error', 'Unknown error')}")
249
+ except Exception as e:
250
+ return ToolResult(success=False, output=f"Failed to scroll: {str(e)}")
251
+
252
+ @openapi_schema({
253
+ "type": "function",
254
+ "function": {
255
+ "name": "typing",
256
+ "description": "Type specified text",
257
+ "parameters": {
258
+ "type": "object",
259
+ "properties": {
260
+ "text": {
261
+ "type": "string",
262
+ "description": "Text to type"
263
+ }
264
+ },
265
+ "required": ["text"]
266
+ }
267
+ }
268
+ })
269
+ @xml_schema(
270
+ tag_name="typing",
271
+ mappings=[
272
+ {"param_name": "text", "node_type": "content", "path": "text"}
273
+ ],
274
+ example='''
275
+ <typing>Hello World!</typing>
276
+ '''
277
+ )
278
+ async def typing(self, text: str) -> ToolResult:
279
+ """Type specified text."""
280
+ try:
281
+ text = str(text)
282
+
283
+ result = await self._api_request("POST", "/automation/keyboard/write", {
284
+ "message": text,
285
+ "interval": 0.01
286
+ })
287
+
288
+ if result.get("success", False):
289
+ return ToolResult(success=True, output=f"Typed: {text}")
290
+ else:
291
+ return ToolResult(success=False, output=f"Failed to type: {result.get('error', 'Unknown error')}")
292
+ except Exception as e:
293
+ return ToolResult(success=False, output=f"Failed to type: {str(e)}")
294
+
295
+ @openapi_schema({
296
+ "type": "function",
297
+ "function": {
298
+ "name": "press",
299
+ "description": "Press and release a key",
300
+ "parameters": {
301
+ "type": "object",
302
+ "properties": {
303
+ "key": {
304
+ "type": "string",
305
+ "description": "Key to press",
306
+ "enum": KEYBOARD_KEYS
307
+ }
308
+ },
309
+ "required": ["key"]
310
+ }
311
+ }
312
+ })
313
+ @xml_schema(
314
+ tag_name="press",
315
+ mappings=[
316
+ {"param_name": "key", "node_type": "attribute", "path": "key"}
317
+ ],
318
+ example='''
319
+ <press key="enter">
320
+ </press>
321
+ '''
322
+ )
323
+ async def press(self, key: str) -> ToolResult:
324
+ """Press and release a key."""
325
+ try:
326
+ key = str(key).lower()
327
+
328
+ result = await self._api_request("POST", "/automation/keyboard/press", {
329
+ "keys": key,
330
+ "presses": 1
331
+ })
332
+
333
+ if result.get("success", False):
334
+ return ToolResult(success=True, output=f"Pressed key: {key}")
335
+ else:
336
+ return ToolResult(success=False, output=f"Failed to press key: {result.get('error', 'Unknown error')}")
337
+ except Exception as e:
338
+ return ToolResult(success=False, output=f"Failed to press key: {str(e)}")
339
+
340
+ @openapi_schema({
341
+ "type": "function",
342
+ "function": {
343
+ "name": "wait",
344
+ "description": "Wait for specified duration",
345
+ "parameters": {
346
+ "type": "object",
347
+ "properties": {
348
+ "duration": {
349
+ "type": "number",
350
+ "description": "Duration in seconds",
351
+ "default": 0.5
352
+ }
353
+ }
354
+ }
355
+ }
356
+ })
357
+ @xml_schema(
358
+ tag_name="wait",
359
+ mappings=[
360
+ {"param_name": "duration", "node_type": "attribute", "path": "duration"}
361
+ ],
362
+ example='''
363
+ <wait duration="1.5">
364
+ </wait>
365
+ '''
366
+ )
367
+ async def wait(self, duration: float = 0.5) -> ToolResult:
368
+ """Wait for specified duration."""
369
+ try:
370
+ duration = float(duration)
371
+ duration = max(0, min(10, duration))
372
+ await asyncio.sleep(duration)
373
+ return ToolResult(success=True, output=f"Waited {duration} seconds")
374
+ except Exception as e:
375
+ return ToolResult(success=False, output=f"Failed to wait: {str(e)}")
376
+
377
+ @openapi_schema({
378
+ "type": "function",
379
+ "function": {
380
+ "name": "mouse_down",
381
+ "description": "Press a mouse button",
382
+ "parameters": {
383
+ "type": "object",
384
+ "properties": {
385
+ "button": {
386
+ "type": "string",
387
+ "description": "Mouse button to press",
388
+ "enum": ["left", "right", "middle"],
389
+ "default": "left"
390
+ }
391
+ }
392
+ }
393
+ }
394
+ })
395
+ @xml_schema(
396
+ tag_name="mouse-down",
397
+ mappings=[
398
+ {"param_name": "button", "node_type": "attribute", "path": "button"}
399
+ ],
400
+ example='''
401
+ <mouse-down button="left">
402
+ </mouse-down>
403
+ '''
404
+ )
405
+ async def mouse_down(self, button: str = "left", x: Optional[float] = None, y: Optional[float] = None) -> ToolResult:
406
+ """Press a mouse button at current or specified position."""
407
+ try:
408
+ x_val = x if x is not None else self.mouse_x
409
+ y_val = y if y is not None else self.mouse_y
410
+
411
+ x_int = int(round(float(x_val)))
412
+ y_int = int(round(float(y_val)))
413
+
414
+ result = await self._api_request("POST", "/automation/mouse/down", {
415
+ "x": x_int,
416
+ "y": y_int,
417
+ "button": button.lower()
418
+ })
419
+
420
+ if result.get("success", False):
421
+ self.mouse_x = x_int
422
+ self.mouse_y = y_int
423
+ return ToolResult(success=True, output=f"{button} button pressed at ({x_int}, {y_int})")
424
+ else:
425
+ return ToolResult(success=False, output=f"Failed to press button: {result.get('error', 'Unknown error')}")
426
+ except Exception as e:
427
+ return ToolResult(success=False, output=f"Failed to press button: {str(e)}")
428
+
429
+ @openapi_schema({
430
+ "type": "function",
431
+ "function": {
432
+ "name": "mouse_up",
433
+ "description": "Release a mouse button",
434
+ "parameters": {
435
+ "type": "object",
436
+ "properties": {
437
+ "button": {
438
+ "type": "string",
439
+ "description": "Mouse button to release",
440
+ "enum": ["left", "right", "middle"],
441
+ "default": "left"
442
+ }
443
+ }
444
+ }
445
+ }
446
+ })
447
+ @xml_schema(
448
+ tag_name="mouse-up",
449
+ mappings=[
450
+ {"param_name": "button", "node_type": "attribute", "path": "button"}
451
+ ],
452
+ example='''
453
+ <mouse-up button="left">
454
+ </mouse-up>
455
+ '''
456
+ )
457
+ async def mouse_up(self, button: str = "left", x: Optional[float] = None, y: Optional[float] = None) -> ToolResult:
458
+ """Release a mouse button at current or specified position."""
459
+ try:
460
+ x_val = x if x is not None else self.mouse_x
461
+ y_val = y if y is not None else self.mouse_y
462
+
463
+ x_int = int(round(float(x_val)))
464
+ y_int = int(round(float(y_val)))
465
+
466
+ result = await self._api_request("POST", "/automation/mouse/up", {
467
+ "x": x_int,
468
+ "y": y_int,
469
+ "button": button.lower()
470
+ })
471
+
472
+ if result.get("success", False):
473
+ self.mouse_x = x_int
474
+ self.mouse_y = y_int
475
+ return ToolResult(success=True, output=f"{button} button released at ({x_int}, {y_int})")
476
+ else:
477
+ return ToolResult(success=False, output=f"Failed to release button: {result.get('error', 'Unknown error')}")
478
+ except Exception as e:
479
+ return ToolResult(success=False, output=f"Failed to release button: {str(e)}")
480
+
481
+ @openapi_schema({
482
+ "type": "function",
483
+ "function": {
484
+ "name": "drag_to",
485
+ "description": "Drag cursor to specified position",
486
+ "parameters": {
487
+ "type": "object",
488
+ "properties": {
489
+ "x": {
490
+ "type": "number",
491
+ "description": "Target X coordinate"
492
+ },
493
+ "y": {
494
+ "type": "number",
495
+ "description": "Target Y coordinate"
496
+ }
497
+ },
498
+ "required": ["x", "y"]
499
+ }
500
+ }
501
+ })
502
+ @xml_schema(
503
+ tag_name="drag-to",
504
+ mappings=[
505
+ {"param_name": "x", "node_type": "attribute", "path": "x"},
506
+ {"param_name": "y", "node_type": "attribute", "path": "y"}
507
+ ],
508
+ example='''
509
+ <drag-to x="500" y="50">
510
+ </drag-to>
511
+ '''
512
+ )
513
+ async def drag_to(self, x: float, y: float) -> ToolResult:
514
+ """Click and drag from current position to target position."""
515
+ try:
516
+ target_x = int(round(float(x)))
517
+ target_y = int(round(float(y)))
518
+ start_x = self.mouse_x
519
+ start_y = self.mouse_y
520
+
521
+ result = await self._api_request("POST", "/automation/mouse/drag", {
522
+ "x": target_x,
523
+ "y": target_y,
524
+ "duration": 0.3,
525
+ "button": "left"
526
+ })
527
+
528
+ if result.get("success", False):
529
+ self.mouse_x = target_x
530
+ self.mouse_y = target_y
531
+ return ToolResult(success=True,
532
+ output=f"Dragged from ({start_x}, {start_y}) to ({target_x}, {target_y})")
533
+ else:
534
+ return ToolResult(success=False, output=f"Failed to drag: {result.get('error', 'Unknown error')}")
535
+ except Exception as e:
536
+ return ToolResult(success=False, output=f"Failed to drag: {str(e)}")
537
+
538
+ async def get_screenshot_base64(self) -> Optional[dict]:
539
+ """Capture screen and return as base64 encoded image."""
540
+ try:
541
+ result = await self._api_request("POST", "/automation/screenshot")
542
+
543
+ if "image" in result:
544
+ base64_str = result["image"]
545
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
546
+
547
+ # Save screenshot to file
548
+ screenshots_dir = "screenshots"
549
+ if not os.path.exists(screenshots_dir):
550
+ os.makedirs(screenshots_dir)
551
+
552
+ timestamped_filename = os.path.join(screenshots_dir, f"screenshot_{timestamp}.png")
553
+ latest_filename = "latest_screenshot.png"
554
+
555
+ # Decode base64 string and save to file
556
+ img_data = base64.b64decode(base64_str)
557
+ with open(timestamped_filename, 'wb') as f:
558
+ f.write(img_data)
559
+
560
+ # Save a copy as the latest screenshot
561
+ with open(latest_filename, 'wb') as f:
562
+ f.write(img_data)
563
+
564
+ return {
565
+ "content_type": "image/png",
566
+ "base64": base64_str,
567
+ "timestamp": timestamp,
568
+ "filename": timestamped_filename
569
+ }
570
+ else:
571
+ return None
572
+
573
+ except Exception as e:
574
+ print(f"[Screenshot] Error during screenshot process: {str(e)}")
575
+ return None
576
+
577
+ @openapi_schema({
578
+ "type": "function",
579
+ "function": {
580
+ "name": "hotkey",
581
+ "description": "Press a key combination",
582
+ "parameters": {
583
+ "type": "object",
584
+ "properties": {
585
+ "keys": {
586
+ "type": "string",
587
+ "description": "Key combination to press",
588
+ "enum": KEYBOARD_KEYS
589
+ }
590
+ },
591
+ "required": ["keys"]
592
+ }
593
+ }
594
+ })
595
+ @xml_schema(
596
+ tag_name="hotkey",
597
+ mappings=[
598
+ {"param_name": "keys", "node_type": "attribute", "path": "keys"}
599
+ ],
600
+ example='''
601
+ <hotkey keys="ctrl+a">
602
+ </hotkey>
603
+ '''
604
+ )
605
+ async def hotkey(self, keys: str) -> ToolResult:
606
+ """Press a key combination."""
607
+ try:
608
+ keys = str(keys).lower().strip()
609
+ key_sequence = keys.split('+')
610
+
611
+ result = await self._api_request("POST", "/automation/keyboard/hotkey", {
612
+ "keys": key_sequence,
613
+ "interval": 0.01
614
+ })
615
+
616
+ if result.get("success", False):
617
+ return ToolResult(success=True, output=f"Pressed key combination: {keys}")
618
+ else:
619
+ return ToolResult(success=False, output=f"Failed to press keys: {result.get('error', 'Unknown error')}")
620
+ except Exception as e:
621
+ return ToolResult(success=False, output=f"Failed to press keys: {str(e)}")
622
+
623
+ if __name__ == "__main__":
624
+ print("This module should be imported, not run directly.")
config.cpython-311.pyc ADDED
Binary file (3.84 kB). View file
 
config.json ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen2_5OmniModel"
4
+ ],
5
+ "enable_audio_output": true,
6
+ "enable_talker": true,
7
+ "model_type": "qwen2_5_omni",
8
+ "talker_config": {
9
+ "_attn_implementation_autoset": true,
10
+ "_name_or_path": "Qwen2.5-Omni-3B/talker",
11
+ "architectures": [
12
+ "Qwen2OmniTalkerForConditionalGeneration"
13
+ ],
14
+ "attention_dropout": 0.0,
15
+ "audio_end_token_id": 151648,
16
+ "audio_start_token_id": 151647,
17
+ "audio_token_index": 151646,
18
+ "embedding_size": 2048,
19
+ "head_dim": 64,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 896,
22
+ "image_token_index": 151655,
23
+ "init_std": 0.02,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 4864,
26
+ "max_position_embeddings": 32768,
27
+ "max_window_layers": 28,
28
+ "model_type": "qwen2_5_omni_talker",
29
+ "num_attention_heads": 14,
30
+ "num_hidden_layers": 24,
31
+ "num_key_value_heads": 2,
32
+ "position_id_per_seconds": 25,
33
+ "rms_norm_eps": 1e-06,
34
+ "rope_scaling": {
35
+ "mrope_section": [
36
+ 16,
37
+ 16,
38
+ 0
39
+ ],
40
+ "rope_type": "default",
41
+ "type": "default"
42
+ },
43
+ "rope_theta": 1000000.0,
44
+ "seconds_per_chunk": 2,
45
+ "sliding_window": 32768,
46
+ "spatial_merge_size": 2,
47
+ "torch_dtype": "bfloat16",
48
+ "tts_codec_end_token_id": 8294,
49
+ "tts_codec_mask_token_id": 8296,
50
+ "tts_codec_pad_token_id": 8292,
51
+ "tts_codec_start_token_id": 8293,
52
+ "tts_text_end_token_id": 151861,
53
+ "tts_text_pad_token_id": 151859,
54
+ "tts_text_start_token_id": 151860,
55
+ "use_cache": true,
56
+ "use_sliding_window": false,
57
+ "video_token_index": 151656,
58
+ "vision_end_token_id": 151653,
59
+ "vision_start_token_id": 151652,
60
+ "vocab_size": 8448
61
+ },
62
+ "thinker_config": {
63
+ "_attn_implementation_autoset": true,
64
+ "_name_or_path": "Qwen2.5-Omni-3B/thinker",
65
+ "architectures": [
66
+ "Qwen2OmniNaViTThinkerForConditionalGeneration"
67
+ ],
68
+ "audio_config": {
69
+ "_attn_implementation_autoset": true,
70
+ "_name_or_path": "",
71
+ "activation_dropout": 0.0,
72
+ "activation_function": "gelu",
73
+ "add_cross_attention": false,
74
+ "architectures": null,
75
+ "attention_dropout": 0.0,
76
+ "bad_words_ids": null,
77
+ "begin_suppress_tokens": null,
78
+ "bos_token_id": null,
79
+ "chunk_size_feed_forward": 0,
80
+ "cross_attention_hidden_size": null,
81
+ "d_model": 1280,
82
+ "decoder_start_token_id": null,
83
+ "diversity_penalty": 0.0,
84
+ "do_sample": false,
85
+ "dropout": 0.0,
86
+ "early_stopping": false,
87
+ "encoder_attention_heads": 20,
88
+ "encoder_ffn_dim": 5120,
89
+ "encoder_layerdrop": 0.0,
90
+ "encoder_layers": 32,
91
+ "encoder_no_repeat_ngram_size": 0,
92
+ "eos_token_id": null,
93
+ "exponential_decay_length_penalty": null,
94
+ "finetuning_task": null,
95
+ "forced_bos_token_id": null,
96
+ "forced_eos_token_id": null,
97
+ "id2label": {
98
+ "0": "LABEL_0",
99
+ "1": "LABEL_1"
100
+ },
101
+ "init_std": 0.02,
102
+ "is_decoder": false,
103
+ "is_encoder_decoder": false,
104
+ "label2id": {
105
+ "LABEL_0": 0,
106
+ "LABEL_1": 1
107
+ },
108
+ "length_penalty": 1.0,
109
+ "max_length": 20,
110
+ "max_source_positions": 1500,
111
+ "min_length": 0,
112
+ "model_type": "qwen2_5_omni_audio_encoder",
113
+ "n_window": 100,
114
+ "no_repeat_ngram_size": 0,
115
+ "num_beam_groups": 1,
116
+ "num_beams": 1,
117
+ "num_hidden_layers": 32,
118
+ "num_mel_bins": 128,
119
+ "num_return_sequences": 1,
120
+ "output_attentions": false,
121
+ "output_dim": 2048,
122
+ "output_hidden_states": false,
123
+ "output_scores": false,
124
+ "pad_token_id": null,
125
+ "prefix": null,
126
+ "problem_type": null,
127
+ "pruned_heads": {},
128
+ "remove_invalid_values": false,
129
+ "repetition_penalty": 1.0,
130
+ "return_dict": true,
131
+ "return_dict_in_generate": false,
132
+ "scale_embedding": false,
133
+ "sep_token_id": null,
134
+ "suppress_tokens": null,
135
+ "task_specific_params": null,
136
+ "temperature": 1.0,
137
+ "tf_legacy_loss": false,
138
+ "tie_encoder_decoder": false,
139
+ "tie_word_embeddings": true,
140
+ "tokenizer_class": null,
141
+ "top_k": 50,
142
+ "top_p": 1.0,
143
+ "torch_dtype": null,
144
+ "torchscript": false,
145
+ "typical_p": 1.0,
146
+ "use_bfloat16": false
147
+ },
148
+ "text_config": {
149
+ "model_type": "qwen2_5_omni_text",
150
+ "hidden_act": "silu",
151
+ "hidden_size": 2048,
152
+ "init_std": 0.02,
153
+ "intermediate_size": 11008,
154
+ "vocab_size": 151936,
155
+ "num_attention_heads": 16,
156
+ "num_hidden_layers": 36,
157
+ "num_key_value_heads": 2,
158
+ "max_position_embeddings": 32768,
159
+ "max_window_layers": 70,
160
+ "rms_norm_eps": 1e-06,
161
+ "rope_scaling": {
162
+ "mrope_section": [
163
+ 16,
164
+ 24,
165
+ 24
166
+ ],
167
+ "rope_type": "default",
168
+ "type": "default"
169
+ },
170
+ "use_cache": true,
171
+ "rope_theta": 1000000.0,
172
+ "use_sliding_window": false,
173
+ "sliding_window": 32768,
174
+ "attention_dropout": 0.0,
175
+ "tie_word_embeddings": false
176
+ },
177
+ "audio_end_token_id": 151648,
178
+ "audio_start_token_id": 151647,
179
+ "audio_token_index": 151646,
180
+ "bos_token_id": 151644,
181
+ "eos_token_id": 151645,
182
+ "ignore_index": -100,
183
+ "image_token_index": 151655,
184
+ "init_std": 0.02,
185
+ "model_type": "qwen2_5_omni_thinker",
186
+ "pad_token_id": 151643,
187
+ "position_id_per_seconds": 25,
188
+ "seconds_per_chunk": 2,
189
+ "torch_dtype": "bfloat16",
190
+ "user_token_id": 872,
191
+ "video_token_index": 151656,
192
+ "vision_config": {
193
+ "_attn_implementation_autoset": true,
194
+ "_name_or_path": "",
195
+ "add_cross_attention": false,
196
+ "architectures": null,
197
+ "bad_words_ids": null,
198
+ "begin_suppress_tokens": null,
199
+ "bos_token_id": null,
200
+ "chunk_size_feed_forward": 0,
201
+ "cross_attention_hidden_size": null,
202
+ "decoder_start_token_id": null,
203
+ "depth": 32,
204
+ "diversity_penalty": 0.0,
205
+ "do_sample": false,
206
+ "early_stopping": false,
207
+ "embed_dim": 1280,
208
+ "encoder_no_repeat_ngram_size": 0,
209
+ "eos_token_id": null,
210
+ "exponential_decay_length_penalty": null,
211
+ "finetuning_task": null,
212
+ "forced_bos_token_id": null,
213
+ "forced_eos_token_id": null,
214
+ "fullatt_block_indexes": [
215
+ 7,
216
+ 15,
217
+ 23,
218
+ 31
219
+ ],
220
+ "hidden_act": "silu",
221
+ "hidden_size": 1280,
222
+ "id2label": {
223
+ "0": "LABEL_0",
224
+ "1": "LABEL_1"
225
+ },
226
+ "in_channels": 3,
227
+ "in_chans": 3,
228
+ "init_std": 0.02,
229
+ "intermediate_size": 3420,
230
+ "is_decoder": false,
231
+ "is_encoder_decoder": false,
232
+ "label2id": {
233
+ "LABEL_0": 0,
234
+ "LABEL_1": 1
235
+ },
236
+ "length_penalty": 1.0,
237
+ "max_length": 20,
238
+ "min_length": 0,
239
+ "model_type": "qwen2_5_omni_vision_encoder",
240
+ "no_repeat_ngram_size": 0,
241
+ "num_beam_groups": 1,
242
+ "num_beams": 1,
243
+ "num_heads": 16,
244
+ "num_return_sequences": 1,
245
+ "out_hidden_size": 2048,
246
+ "output_attentions": false,
247
+ "output_hidden_states": false,
248
+ "output_scores": false,
249
+ "pad_token_id": null,
250
+ "patch_size": 14,
251
+ "prefix": null,
252
+ "problem_type": null,
253
+ "pruned_heads": {},
254
+ "remove_invalid_values": false,
255
+ "repetition_penalty": 1.0,
256
+ "return_dict": true,
257
+ "return_dict_in_generate": false,
258
+ "sep_token_id": null,
259
+ "spatial_merge_size": 2,
260
+ "spatial_patch_size": 14,
261
+ "suppress_tokens": null,
262
+ "task_specific_params": null,
263
+ "temperature": 1.0,
264
+ "temporal_patch_size": 2,
265
+ "tf_legacy_loss": false,
266
+ "tie_encoder_decoder": false,
267
+ "tie_word_embeddings": true,
268
+ "tokenizer_class": null,
269
+ "tokens_per_second": 25,
270
+ "top_k": 50,
271
+ "top_p": 1.0,
272
+ "torch_dtype": null,
273
+ "torchscript": false,
274
+ "typical_p": 1.0,
275
+ "use_bfloat16": false,
276
+ "window_size": 112
277
+ },
278
+ "vision_end_token_id": 151653,
279
+ "vision_start_token_id": 151652,
280
+ "vision_token_id": 151654
281
+ },
282
+ "token2wav_config": {
283
+ "_attn_implementation_autoset": true,
284
+ "bigvgan_config": {
285
+ "_attn_implementation_autoset": true,
286
+ "_name_or_path": "",
287
+ "add_cross_attention": false,
288
+ "architectures": null,
289
+ "bad_words_ids": null,
290
+ "begin_suppress_tokens": null,
291
+ "bos_token_id": null,
292
+ "chunk_size_feed_forward": 0,
293
+ "cross_attention_hidden_size": null,
294
+ "decoder_start_token_id": null,
295
+ "diversity_penalty": 0.0,
296
+ "do_sample": false,
297
+ "early_stopping": false,
298
+ "encoder_no_repeat_ngram_size": 0,
299
+ "eos_token_id": null,
300
+ "exponential_decay_length_penalty": null,
301
+ "finetuning_task": null,
302
+ "forced_bos_token_id": null,
303
+ "forced_eos_token_id": null,
304
+ "id2label": {
305
+ "0": "LABEL_0",
306
+ "1": "LABEL_1"
307
+ },
308
+ "is_decoder": false,
309
+ "is_encoder_decoder": false,
310
+ "label2id": {
311
+ "LABEL_0": 0,
312
+ "LABEL_1": 1
313
+ },
314
+ "length_penalty": 1.0,
315
+ "max_length": 20,
316
+ "mel_dim": 80,
317
+ "min_length": 0,
318
+ "model_type": "qwen2_5_omni_bigvgan",
319
+ "no_repeat_ngram_size": 0,
320
+ "num_beam_groups": 1,
321
+ "num_beams": 1,
322
+ "num_return_sequences": 1,
323
+ "output_attentions": false,
324
+ "output_hidden_states": false,
325
+ "output_scores": false,
326
+ "pad_token_id": null,
327
+ "prefix": null,
328
+ "problem_type": null,
329
+ "pruned_heads": {},
330
+ "remove_invalid_values": false,
331
+ "repetition_penalty": 1.0,
332
+ "resblock_dilation_sizes": [
333
+ [
334
+ 1,
335
+ 3,
336
+ 5
337
+ ],
338
+ [
339
+ 1,
340
+ 3,
341
+ 5
342
+ ],
343
+ [
344
+ 1,
345
+ 3,
346
+ 5
347
+ ]
348
+ ],
349
+ "resblock_kernel_sizes": [
350
+ 3,
351
+ 7,
352
+ 11
353
+ ],
354
+ "return_dict": true,
355
+ "return_dict_in_generate": false,
356
+ "sep_token_id": null,
357
+ "suppress_tokens": null,
358
+ "task_specific_params": null,
359
+ "temperature": 1.0,
360
+ "tf_legacy_loss": false,
361
+ "tie_encoder_decoder": false,
362
+ "tie_word_embeddings": true,
363
+ "tokenizer_class": null,
364
+ "top_k": 50,
365
+ "top_p": 1.0,
366
+ "torch_dtype": null,
367
+ "torchscript": false,
368
+ "typical_p": 1.0,
369
+ "upsample_initial_channel": 1536,
370
+ "upsample_kernel_sizes": [
371
+ 11,
372
+ 7,
373
+ 4,
374
+ 4,
375
+ 4,
376
+ 4
377
+ ],
378
+ "upsample_rates": [
379
+ 5,
380
+ 3,
381
+ 2,
382
+ 2,
383
+ 2,
384
+ 2
385
+ ],
386
+ "use_bfloat16": false,
387
+ "use_bias_at_final": false
388
+ },
389
+ "dit_config": {
390
+ "_attn_implementation_autoset": true,
391
+ "_name_or_path": "",
392
+ "add_cross_attention": false,
393
+ "architectures": null,
394
+ "bad_words_ids": null,
395
+ "begin_suppress_tokens": null,
396
+ "bos_token_id": null,
397
+ "chunk_size_feed_forward": 0,
398
+ "cross_attention_hidden_size": null,
399
+ "decoder_start_token_id": null,
400
+ "depth": 22,
401
+ "dim": 1024,
402
+ "diversity_penalty": 0.0,
403
+ "do_sample": false,
404
+ "dropout": 0.1,
405
+ "early_stopping": false,
406
+ "emb_dim": 512,
407
+ "enc_attention_channels": 64,
408
+ "enc_channels": [
409
+ 256,
410
+ 256,
411
+ 256,
412
+ 256,
413
+ 768
414
+ ],
415
+ "enc_dilations": [
416
+ 1,
417
+ 2,
418
+ 3,
419
+ 4,
420
+ 1
421
+ ],
422
+ "enc_dim": 128,
423
+ "enc_emb_dim": 192,
424
+ "enc_global_context": true,
425
+ "enc_kernel_sizes": [
426
+ 5,
427
+ 3,
428
+ 3,
429
+ 3,
430
+ 1
431
+ ],
432
+ "enc_lin_neurons": 192,
433
+ "enc_res2net_scale": 2,
434
+ "enc_se_channels": 64,
435
+ "encoder_no_repeat_ngram_size": 0,
436
+ "eos_token_id": null,
437
+ "exponential_decay_length_penalty": null,
438
+ "ff_mult": 2,
439
+ "finetuning_task": null,
440
+ "forced_bos_token_id": null,
441
+ "forced_eos_token_id": null,
442
+ "head_dim": 64,
443
+ "heads": 16,
444
+ "id2label": {
445
+ "0": "LABEL_0",
446
+ "1": "LABEL_1"
447
+ },
448
+ "is_decoder": false,
449
+ "is_encoder_decoder": false,
450
+ "label2id": {
451
+ "LABEL_0": 0,
452
+ "LABEL_1": 1
453
+ },
454
+ "length_penalty": 1.0,
455
+ "max_length": 20,
456
+ "mel_dim": 80,
457
+ "min_length": 0,
458
+ "model_type": "qwen2_5_omni_dit",
459
+ "no_repeat_ngram_size": 0,
460
+ "num_beam_groups": 1,
461
+ "num_beams": 1,
462
+ "num_embeds": 8193,
463
+ "num_return_sequences": 1,
464
+ "output_attentions": false,
465
+ "output_hidden_states": false,
466
+ "output_scores": false,
467
+ "pad_token_id": null,
468
+ "prefix": null,
469
+ "problem_type": null,
470
+ "pruned_heads": {},
471
+ "remove_invalid_values": false,
472
+ "repeats": 2,
473
+ "repetition_penalty": 1.0,
474
+ "return_dict": true,
475
+ "return_dict_in_generate": false,
476
+ "sep_token_id": null,
477
+ "suppress_tokens": null,
478
+ "task_specific_params": null,
479
+ "temperature": 1.0,
480
+ "tf_legacy_loss": false,
481
+ "tie_encoder_decoder": false,
482
+ "tie_word_embeddings": true,
483
+ "tokenizer_class": null,
484
+ "top_k": 50,
485
+ "top_p": 1.0,
486
+ "torch_dtype": "float32",
487
+ "torchscript": false,
488
+ "typical_p": 1.0,
489
+ "use_bfloat16": false
490
+ },
491
+ "model_type": "qwen2_5_omni_token2wav"
492
+ },
493
+ "torch_dtype": "bfloat16",
494
+ "transformers_version": "4.51.0.dev0"
495
+ }