pedrobento988 commited on
Commit
3feb691
Β·
verified Β·
1 Parent(s): 69b3513

ui-improvements (#11)

Browse files

- feat: updates dependencies (54bdb7fbbdc4bf60daeb470f6ed2796ba387f8e5)
- feat: removes annotations (e6cfbb2fd6fe8c0429994c1bead8eb0ad8674987)
- feat: updates dependencies (f20f6154b6d31972d9b64b8f180317ab32c426ed)
- feat: improves ui (87d04296fe3766c5759390c1a70689505d02b30e)

.pre-commit-config.yaml CHANGED
@@ -65,6 +65,7 @@ repos:
65
  "--format",
66
  "requirements-txt",
67
  "--no-hashes",
 
68
  "--no-dev",
69
  "-o",
70
  "requirements.txt",
@@ -82,6 +83,7 @@ repos:
82
  "--format",
83
  "requirements-txt",
84
  "--no-hashes",
 
85
  "--group",
86
  "dev",
87
  "--group",
@@ -91,7 +93,7 @@ repos:
91
  ]
92
  - id: mypy
93
  name: Running mypy
94
- stages: [commit]
95
  language: system
96
  entry: uv run mypy
97
  args: [--install-types, --non-interactive]
 
65
  "--format",
66
  "requirements-txt",
67
  "--no-hashes",
68
+ "--no-annotate",
69
  "--no-dev",
70
  "-o",
71
  "requirements.txt",
 
83
  "--format",
84
  "requirements-txt",
85
  "--no-hashes",
86
+ "--no-annotate",
87
  "--group",
88
  "dev",
89
  "--group",
 
93
  ]
94
  - id: mypy
95
  name: Running mypy
96
+ stages: [pre-commit]
97
  language: system
98
  entry: uv run mypy
99
  args: [--install-types, --non-interactive]
pyproject.toml CHANGED
@@ -16,8 +16,10 @@ dependencies = [
16
  "gradio[mcp]~=5.31",
17
  "huggingface-hub>=0.32.3",
18
  "langchain-aws>=0.2.24",
 
19
  "langchain-mcp-adapters>=0.1.1",
20
  "langgraph>=0.4.7",
 
21
  ]
22
 
23
  [project.scripts]
 
16
  "gradio[mcp]~=5.31",
17
  "huggingface-hub>=0.32.3",
18
  "langchain-aws>=0.2.24",
19
+ "langchain-huggingface>=0.2.0",
20
  "langchain-mcp-adapters>=0.1.1",
21
  "langgraph>=0.4.7",
22
+ "openai>=1.84.0",
23
  ]
24
 
25
  [project.scripts]
requirements-dev.txt CHANGED
@@ -1,5 +1,5 @@
1
  # This file was autogenerated by uv via the following command:
2
- # uv export --format requirements-txt --no-hashes --group dev --group test -o requirements-dev.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
@@ -18,6 +18,7 @@ coverage==7.8.2
18
  cyclonedx-python-lib==9.1.0
19
  defusedxml==0.7.1
20
  distlib==0.3.9
 
21
  exceptiongroup==1.3.0 ; python_full_version < '3.11'
22
  fastapi==0.115.12
23
  ffmpy==0.6.0
@@ -36,11 +37,14 @@ identify==2.6.12
36
  idna==3.10
37
  iniconfig==2.1.0
38
  jinja2==3.1.6
 
39
  jmespath==1.0.1
 
40
  jsonpatch==1.33
41
  jsonpointer==3.0.0
42
  langchain-aws==0.2.24
43
  langchain-core==0.3.63
 
44
  langchain-mcp-adapters==0.1.1
45
  langgraph==0.4.7
46
  langgraph-checkpoint==2.0.26
@@ -52,12 +56,30 @@ markdown-it-py==3.0.0
52
  markupsafe==3.0.2
53
  mcp==1.9.0
54
  mdurl==0.1.2
 
55
  msgpack==1.1.0
56
  mypy==1.16.0
57
  mypy-extensions==1.1.0
 
 
58
  nodeenv==1.9.1
59
  numpy==1.26.4 ; python_full_version < '3.12'
60
  numpy==2.2.6 ; python_full_version >= '3.12'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  orjson==3.10.18
62
  ormsgpack==1.10.0
63
  packageurl-python==0.16.0
@@ -88,24 +110,36 @@ python-dotenv==1.1.0
88
  python-multipart==0.0.20
89
  pytz==2025.2
90
  pyyaml==6.0.2
 
91
  requests==2.32.3
92
  requests-toolbelt==1.0.0
93
  rich==14.0.0
94
  ruff==0.11.12
95
  s3transfer==0.13.0
96
  safehttpx==0.1.6
 
 
 
97
  semantic-version==2.10.0
 
 
98
  shellingham==1.5.4 ; sys_platform != 'emscripten'
99
  six==1.17.0
100
  sniffio==1.3.1
101
  sortedcontainers==2.4.0
102
  sse-starlette==2.3.6
103
  starlette==0.46.2
 
104
  tenacity==9.1.2
 
 
105
  toml==0.10.2
106
  tomli==2.2.1 ; python_full_version <= '3.11'
107
  tomlkit==0.13.2
 
108
  tqdm==4.67.1
 
 
109
  typer==0.16.0 ; sys_platform != 'emscripten'
110
  typing-extensions==4.13.2
111
  typing-inspection==0.4.1
 
1
  # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --no-hashes --no-annotate --group dev --group test -o requirements-dev.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
 
18
  cyclonedx-python-lib==9.1.0
19
  defusedxml==0.7.1
20
  distlib==0.3.9
21
+ distro==1.9.0
22
  exceptiongroup==1.3.0 ; python_full_version < '3.11'
23
  fastapi==0.115.12
24
  ffmpy==0.6.0
 
37
  idna==3.10
38
  iniconfig==2.1.0
39
  jinja2==3.1.6
40
+ jiter==0.10.0
41
  jmespath==1.0.1
42
+ joblib==1.5.1
43
  jsonpatch==1.33
44
  jsonpointer==3.0.0
45
  langchain-aws==0.2.24
46
  langchain-core==0.3.63
47
+ langchain-huggingface==0.2.0
48
  langchain-mcp-adapters==0.1.1
49
  langgraph==0.4.7
50
  langgraph-checkpoint==2.0.26
 
56
  markupsafe==3.0.2
57
  mcp==1.9.0
58
  mdurl==0.1.2
59
+ mpmath==1.3.0
60
  msgpack==1.1.0
61
  mypy==1.16.0
62
  mypy-extensions==1.1.0
63
+ networkx==3.4.2 ; python_full_version < '3.11'
64
+ networkx==3.5 ; python_full_version >= '3.11'
65
  nodeenv==1.9.1
66
  numpy==1.26.4 ; python_full_version < '3.12'
67
  numpy==2.2.6 ; python_full_version >= '3.12'
68
+ nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
69
+ nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
70
+ nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
71
+ nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
72
+ nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
73
+ nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
74
+ nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
75
+ nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
76
+ nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
77
+ nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
78
+ nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
79
+ nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
80
+ nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
81
+ nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
82
+ openai==1.84.0
83
  orjson==3.10.18
84
  ormsgpack==1.10.0
85
  packageurl-python==0.16.0
 
110
  python-multipart==0.0.20
111
  pytz==2025.2
112
  pyyaml==6.0.2
113
+ regex==2024.11.6
114
  requests==2.32.3
115
  requests-toolbelt==1.0.0
116
  rich==14.0.0
117
  ruff==0.11.12
118
  s3transfer==0.13.0
119
  safehttpx==0.1.6
120
+ safetensors==0.5.3
121
+ scikit-learn==1.6.1
122
+ scipy==1.15.3
123
  semantic-version==2.10.0
124
+ sentence-transformers==4.1.0
125
+ setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
126
  shellingham==1.5.4 ; sys_platform != 'emscripten'
127
  six==1.17.0
128
  sniffio==1.3.1
129
  sortedcontainers==2.4.0
130
  sse-starlette==2.3.6
131
  starlette==0.46.2
132
+ sympy==1.14.0
133
  tenacity==9.1.2
134
+ threadpoolctl==3.6.0
135
+ tokenizers==0.21.1
136
  toml==0.10.2
137
  tomli==2.2.1 ; python_full_version <= '3.11'
138
  tomlkit==0.13.2
139
+ torch==2.7.1
140
  tqdm==4.67.1
141
+ transformers==4.52.4
142
+ triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
143
  typer==0.16.0 ; sys_platform != 'emscripten'
144
  typing-extensions==4.13.2
145
  typing-inspection==0.4.1
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  # This file was autogenerated by uv via the following command:
2
- # uv export --format requirements-txt --no-hashes --no-dev -o requirements.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
@@ -12,6 +12,7 @@ charset-normalizer==3.4.2
12
  click==8.2.1 ; sys_platform != 'emscripten'
13
  colorama==0.4.6 ; sys_platform == 'win32'
14
  coverage==7.8.2
 
15
  exceptiongroup==1.3.0 ; python_full_version < '3.11'
16
  fastapi==0.115.12
17
  ffmpy==0.6.0
@@ -29,11 +30,14 @@ huggingface-hub==0.32.3
29
  idna==3.10
30
  iniconfig==2.1.0
31
  jinja2==3.1.6
 
32
  jmespath==1.0.1
 
33
  jsonpatch==1.33
34
  jsonpointer==3.0.0
35
  langchain-aws==0.2.24
36
  langchain-core==0.3.63
 
37
  langchain-mcp-adapters==0.1.1
38
  langgraph==0.4.7
39
  langgraph-checkpoint==2.0.26
@@ -44,8 +48,26 @@ markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
44
  markupsafe==3.0.2
45
  mcp==1.9.0
46
  mdurl==0.1.2 ; sys_platform != 'emscripten'
 
 
 
47
  numpy==1.26.4 ; python_full_version < '3.12'
48
  numpy==2.2.6 ; python_full_version >= '3.12'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  orjson==3.10.18
50
  ormsgpack==1.10.0
51
  packaging==24.2
@@ -66,22 +88,34 @@ python-dotenv==1.1.0
66
  python-multipart==0.0.20
67
  pytz==2025.2
68
  pyyaml==6.0.2
 
69
  requests==2.32.3
70
  requests-toolbelt==1.0.0
71
  rich==14.0.0 ; sys_platform != 'emscripten'
72
  ruff==0.11.12 ; sys_platform != 'emscripten'
73
  s3transfer==0.13.0
74
  safehttpx==0.1.6
 
 
 
75
  semantic-version==2.10.0
 
 
76
  shellingham==1.5.4 ; sys_platform != 'emscripten'
77
  six==1.17.0
78
  sniffio==1.3.1
79
  sse-starlette==2.3.6
80
  starlette==0.46.2
 
81
  tenacity==9.1.2
 
 
82
  tomli==2.2.1 ; python_full_version <= '3.11'
83
  tomlkit==0.13.2
 
84
  tqdm==4.67.1
 
 
85
  typer==0.16.0 ; sys_platform != 'emscripten'
86
  typing-extensions==4.13.2
87
  typing-inspection==0.4.1
 
1
  # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --no-hashes --no-annotate --no-dev -o requirements.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
 
12
  click==8.2.1 ; sys_platform != 'emscripten'
13
  colorama==0.4.6 ; sys_platform == 'win32'
14
  coverage==7.8.2
15
+ distro==1.9.0
16
  exceptiongroup==1.3.0 ; python_full_version < '3.11'
17
  fastapi==0.115.12
18
  ffmpy==0.6.0
 
30
  idna==3.10
31
  iniconfig==2.1.0
32
  jinja2==3.1.6
33
+ jiter==0.10.0
34
  jmespath==1.0.1
35
+ joblib==1.5.1
36
  jsonpatch==1.33
37
  jsonpointer==3.0.0
38
  langchain-aws==0.2.24
39
  langchain-core==0.3.63
40
+ langchain-huggingface==0.2.0
41
  langchain-mcp-adapters==0.1.1
42
  langgraph==0.4.7
43
  langgraph-checkpoint==2.0.26
 
48
  markupsafe==3.0.2
49
  mcp==1.9.0
50
  mdurl==0.1.2 ; sys_platform != 'emscripten'
51
+ mpmath==1.3.0
52
+ networkx==3.4.2 ; python_full_version < '3.11'
53
+ networkx==3.5 ; python_full_version >= '3.11'
54
  numpy==1.26.4 ; python_full_version < '3.12'
55
  numpy==2.2.6 ; python_full_version >= '3.12'
56
+ nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
57
+ nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
58
+ nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
59
+ nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
60
+ nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
61
+ nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
62
+ nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
63
+ nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
64
+ nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
65
+ nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
66
+ nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
67
+ nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
68
+ nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
69
+ nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
70
+ openai==1.84.0
71
  orjson==3.10.18
72
  ormsgpack==1.10.0
73
  packaging==24.2
 
88
  python-multipart==0.0.20
89
  pytz==2025.2
90
  pyyaml==6.0.2
91
+ regex==2024.11.6
92
  requests==2.32.3
93
  requests-toolbelt==1.0.0
94
  rich==14.0.0 ; sys_platform != 'emscripten'
95
  ruff==0.11.12 ; sys_platform != 'emscripten'
96
  s3transfer==0.13.0
97
  safehttpx==0.1.6
98
+ safetensors==0.5.3
99
+ scikit-learn==1.6.1
100
+ scipy==1.15.3
101
  semantic-version==2.10.0
102
+ sentence-transformers==4.1.0
103
+ setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
104
  shellingham==1.5.4 ; sys_platform != 'emscripten'
105
  six==1.17.0
106
  sniffio==1.3.1
107
  sse-starlette==2.3.6
108
  starlette==0.46.2
109
+ sympy==1.14.0
110
  tenacity==9.1.2
111
+ threadpoolctl==3.6.0
112
+ tokenizers==0.21.1
113
  tomli==2.2.1 ; python_full_version <= '3.11'
114
  tomlkit==0.13.2
115
+ torch==2.7.1
116
  tqdm==4.67.1
117
+ transformers==4.52.4
118
+ triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
119
  typer==0.16.0 ; sys_platform != 'emscripten'
120
  typing-extensions==4.13.2
121
  typing-inspection==0.4.1
tdagent/grchat.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
 
3
  from collections.abc import Mapping, Sequence
4
  from types import MappingProxyType
5
- from typing import TYPE_CHECKING
6
 
7
  import boto3
8
  import botocore
@@ -10,8 +10,11 @@ import botocore.exceptions
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
13
  from langchain_mcp_adapters.client import MultiServerMCPClient
14
  from langgraph.prebuilt import create_react_agent
 
 
15
 
16
  from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
17
 
@@ -48,6 +51,15 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
48
  },
49
  )
50
 
 
 
 
 
 
 
 
 
 
51
 
52
  #### Shared variables ####
53
 
@@ -56,12 +68,15 @@ llm_agent: CompiledGraph | None = None
56
  #### Utility functions ####
57
 
58
 
 
59
  def create_bedrock_llm(
60
  bedrock_model_id: str,
61
  aws_access_key: str,
62
  aws_secret_key: str,
63
  aws_session_token: str,
64
  aws_region: str,
 
 
65
  ) -> tuple[ChatBedrock | None, str]:
66
  """Create a LangGraph Bedrock agent."""
67
  boto3_config = {
@@ -70,7 +85,6 @@ def create_bedrock_llm(
70
  "aws_session_token": aws_session_token if aws_session_token else None,
71
  "region_name": aws_region,
72
  }
73
-
74
  # Verify credentials
75
  try:
76
  sts = boto3.client("sts", **boto3_config)
@@ -83,7 +97,7 @@ def create_bedrock_llm(
83
  llm = ChatBedrock(
84
  model_id=bedrock_model_id,
85
  client=bedrock_client,
86
- model_kwargs={"temperature": 0.8},
87
  )
88
  except Exception as e: # noqa: BLE001
89
  return None, str(e)
@@ -91,20 +105,59 @@ def create_bedrock_llm(
91
  return llm, ""
92
 
93
 
94
- #### UI functionality ####
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
- async def gr_connect_to_bedrock(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  model_id: str,
99
  access_key: str,
100
  secret_key: str,
101
  session_token: str,
102
  region: str,
103
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
 
 
104
  ) -> str:
105
  """Initialize Bedrock agent."""
106
  global llm_agent # noqa: PLW0603
107
-
108
  if not access_key or not secret_key:
109
  return "❌ Please provide both Access Key ID and Secret Access Key"
110
 
@@ -114,6 +167,8 @@ async def gr_connect_to_bedrock(
114
  secret_key.strip(),
115
  session_token.strip(),
116
  region,
 
 
117
  )
118
 
119
  if llm is None:
@@ -128,7 +183,6 @@ async def gr_connect_to_bedrock(
128
  # }
129
  # )
130
  # tools = await client.get_tools()
131
-
132
  if mcp_servers:
133
  client = MultiServerMCPClient(
134
  {
@@ -142,7 +196,6 @@ async def gr_connect_to_bedrock(
142
  tools = await client.get_tools()
143
  else:
144
  tools = []
145
-
146
  llm_agent = create_react_agent(
147
  model=llm,
148
  tools=tools,
@@ -152,6 +205,73 @@ async def gr_connect_to_bedrock(
152
  return "βœ… Successfully connected to AWS Bedrock!"
153
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  async def gr_chat_function( # noqa: D103
156
  message: str,
157
  history: list[Mapping[str, str]],
@@ -178,49 +298,110 @@ async def gr_chat_function( # noqa: D103
178
 
179
  ## UI components ##
180
 
181
- with gr.Blocks() as gr_app:
182
- gr.Markdown("# πŸ” Secure Bedrock Chatbot")
183
-
184
- ### MCP Servers ###
185
- with gr.Accordion():
186
- mcp_list = MutableCheckBoxGroup(
187
- values=[
188
- MutableCheckBoxGroupEntry(
189
- name="TDAgent tools",
190
- value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
191
- ),
192
- ],
193
- label="MCP Servers",
194
- )
195
 
196
- # Credentials section (collapsible)
197
- with gr.Accordion("πŸ”‘ Bedrock Configuration", open=True):
198
- gr.Markdown(
199
- "**Note**: Credentials are only stored in memory during your session.",
200
- )
201
- with gr.Row():
202
- bedrock_model_id_textbox = gr.Textbox(
203
- label="Bedrock Model Id",
204
- value="eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  )
206
- with gr.Row():
207
  aws_access_key_textbox = gr.Textbox(
208
  label="AWS Access Key ID",
209
  type="password",
210
  placeholder="Enter your AWS Access Key ID",
 
211
  )
212
  aws_secret_key_textbox = gr.Textbox(
213
  label="AWS Secret Access Key",
214
  type="password",
215
  placeholder="Enter your AWS Secret Access Key",
 
216
  )
217
- with gr.Row():
218
- aws_session_token_textbox = gr.Textbox(
219
- label="AWS Session Token",
220
- type="password",
221
- placeholder="Enter your AWS session token",
222
- )
223
- with gr.Row():
224
  aws_region_dropdown = gr.Dropdown(
225
  label="AWS Region",
226
  choices=[
@@ -231,31 +412,83 @@ with gr.Blocks() as gr_app:
231
  "ap-southeast-1",
232
  ],
233
  value="eu-west-1",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  )
235
- connect_btn = gr.Button("πŸ”Œ Connect to Bedrock", variant="primary")
236
 
 
237
  status_textbox = gr.Textbox(label="Connection Status", interactive=False)
238
 
239
  connect_btn.click(
240
- gr_connect_to_bedrock,
241
  inputs=[
242
- bedrock_model_id_textbox,
 
 
243
  aws_access_key_textbox,
244
  aws_secret_key_textbox,
245
  aws_session_token_textbox,
246
  aws_region_dropdown,
247
- mcp_list.state,
 
 
248
  ],
249
  outputs=[status_textbox],
250
  )
251
 
252
- chat_interface = gr.ChatInterface(
253
- fn=gr_chat_function,
254
- type="messages",
255
- examples=[],
256
- title="Agent with MCP Tools",
257
- description="This is a simple agent that uses MCP tools.",
258
- )
 
259
 
260
 
261
  if __name__ == "__main__":
 
2
 
3
  from collections.abc import Mapping, Sequence
4
  from types import MappingProxyType
5
+ from typing import TYPE_CHECKING, Any
6
 
7
  import boto3
8
  import botocore
 
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13
+ from langchain_huggingface import HuggingFaceEndpoint
14
  from langchain_mcp_adapters.client import MultiServerMCPClient
15
  from langgraph.prebuilt import create_react_agent
16
+ from openai import OpenAI
17
+ from openai.types.chat import ChatCompletion
18
 
19
  from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
20
 
 
51
  },
52
  )
53
 
54
+ MODEL_OPTIONS = {
55
+ "AWS Bedrock": {
56
+ "Anthropic Claude 3.5 Sonnet": "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
57
+ # "Anthropic Claude 3.7 Sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
58
+ },
59
+ "HuggingFace": {
60
+ "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct",
61
+ },
62
+ }
63
 
64
  #### Shared variables ####
65
 
 
68
  #### Utility functions ####
69
 
70
 
71
+ ## Bedrock LLM creation ##
72
  def create_bedrock_llm(
73
  bedrock_model_id: str,
74
  aws_access_key: str,
75
  aws_secret_key: str,
76
  aws_session_token: str,
77
  aws_region: str,
78
+ temperature: float = 0.8,
79
+ max_tokens: int = 512,
80
  ) -> tuple[ChatBedrock | None, str]:
81
  """Create a LangGraph Bedrock agent."""
82
  boto3_config = {
 
85
  "aws_session_token": aws_session_token if aws_session_token else None,
86
  "region_name": aws_region,
87
  }
 
88
  # Verify credentials
89
  try:
90
  sts = boto3.client("sts", **boto3_config)
 
97
  llm = ChatBedrock(
98
  model_id=bedrock_model_id,
99
  client=bedrock_client,
100
+ model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
101
  )
102
  except Exception as e: # noqa: BLE001
103
  return None, str(e)
 
105
  return llm, ""
106
 
107
 
108
+ ## Hugging Face LLM creation ##
109
+ def create_hf_llm(
110
+ hf_model_id: str,
111
+ huggingfacehub_api_token: str | None = None,
112
+ ) -> tuple[HuggingFaceEndpoint | None, str]:
113
+ """Create a LangGraph Hugging Face agent."""
114
+ try:
115
+ llm = HuggingFaceEndpoint(
116
+ model=hf_model_id,
117
+ huggingfacehub_api_token=huggingfacehub_api_token,
118
+ temperature=0.8,
119
+ )
120
+ except Exception as e: # noqa: BLE001
121
+ return None, str(e)
122
+
123
+ return llm, ""
124
 
125
 
126
+ ## OpenAI LLM creation ##
127
+ def create_openai_llm(
128
+ model_id: str,
129
+ token_id: str,
130
+ ) -> tuple[ChatCompletion | None, str]:
131
+ """Create a LangGraph OpenAI agent."""
132
+ try:
133
+ client = OpenAI(
134
+ base_url="https://api.studio.nebius.com/v1/",
135
+ api_key=token_id,
136
+ )
137
+ llm = client.chat.completions.create(
138
+ messages=[], # needs to be fixed
139
+ model=model_id,
140
+ max_tokens=512,
141
+ temperature=0.8,
142
+ )
143
+ except Exception as e: # noqa: BLE001
144
+ return None, str(e)
145
+ return llm, ""
146
+
147
+
148
+ #### UI functionality ####
149
+ async def gr_connect_to_bedrock( # noqa: PLR0913
150
  model_id: str,
151
  access_key: str,
152
  secret_key: str,
153
  session_token: str,
154
  region: str,
155
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
156
+ temperature: float = 0.8,
157
+ max_tokens: int = 512,
158
  ) -> str:
159
  """Initialize Bedrock agent."""
160
  global llm_agent # noqa: PLW0603
 
161
  if not access_key or not secret_key:
162
  return "❌ Please provide both Access Key ID and Secret Access Key"
163
 
 
167
  secret_key.strip(),
168
  session_token.strip(),
169
  region,
170
+ temperature=temperature,
171
+ max_tokens=max_tokens,
172
  )
173
 
174
  if llm is None:
 
183
  # }
184
  # )
185
  # tools = await client.get_tools()
 
186
  if mcp_servers:
187
  client = MultiServerMCPClient(
188
  {
 
196
  tools = await client.get_tools()
197
  else:
198
  tools = []
 
199
  llm_agent = create_react_agent(
200
  model=llm,
201
  tools=tools,
 
205
  return "βœ… Successfully connected to AWS Bedrock!"
206
 
207
 
208
+ async def gr_connect_to_hf(
209
+ model_id: str,
210
+ hf_access_token_textbox: str | None,
211
+ mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
212
+ ) -> str:
213
+ """Initialize Hugging Face agent."""
214
+ global llm_agent # noqa: PLW0603
215
+
216
+ llm, error = create_hf_llm(model_id, hf_access_token_textbox)
217
+
218
+ if llm is None:
219
+ return f"❌ Connection failed: {error}"
220
+ tools = []
221
+ if mcp_servers:
222
+ client = MultiServerMCPClient(
223
+ {
224
+ server.name.replace(" ", "-"): {
225
+ "url": server.value,
226
+ "transport": "sse",
227
+ }
228
+ for server in mcp_servers
229
+ },
230
+ )
231
+ tools = await client.get_tools()
232
+
233
+ llm_agent = create_react_agent(
234
+ model=llm,
235
+ tools=tools,
236
+ prompt=SYSTEM_MESSAGE,
237
+ )
238
+
239
+ return "βœ… Successfully connected to Hugging Face!"
240
+
241
+
242
+ async def gr_connect_to_nebius(
243
+ model_id: str,
244
+ nebius_access_token_textbox: str,
245
+ mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
246
+ ) -> str:
247
+ """Initialize Hugging Face agent."""
248
+ global llm_agent # noqa: PLW0603
249
+
250
+ llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
251
+
252
+ if llm is None:
253
+ return f"❌ Connection failed: {error}"
254
+ tools = []
255
+ if mcp_servers:
256
+ client = MultiServerMCPClient(
257
+ {
258
+ server.name.replace(" ", "-"): {
259
+ "url": server.value,
260
+ "transport": "sse",
261
+ }
262
+ for server in mcp_servers
263
+ },
264
+ )
265
+ tools = await client.get_tools()
266
+
267
+ llm_agent = create_react_agent(
268
+ model=str(llm),
269
+ tools=tools,
270
+ prompt=SYSTEM_MESSAGE,
271
+ )
272
+ return "βœ… Successfully connected to nebius!"
273
+
274
+
275
  async def gr_chat_function( # noqa: D103
276
  message: str,
277
  history: list[Mapping[str, str]],
 
298
 
299
  ## UI components ##
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
+ # Function to toggle visibility and set model IDs
303
+ def toggle_model_fields(
304
+ provider: str,
305
+ ) -> tuple[
306
+ dict[str, Any],
307
+ dict[str, Any],
308
+ dict[str, Any],
309
+ dict[str, Any],
310
+ dict[str, Any],
311
+ dict[str, Any],
312
+ ]: # ignore: F821
313
+ """Toggle visibility of model fields based on the selected provider."""
314
+ # Update model choices based on the selected provider
315
+ if provider in MODEL_OPTIONS:
316
+ model_choices = list(MODEL_OPTIONS[provider].keys())
317
+ model_pretty = gr.update(choices=model_choices, visible=True, interactive=True)
318
+ else:
319
+ model_pretty = gr.update(choices=[], visible=False)
320
+
321
+ # Visibility settings for fields specific to each provider
322
+ is_aws = provider == "AWS Bedrock"
323
+ is_hf = provider == "HuggingFace"
324
+ return (
325
+ model_pretty,
326
+ gr.update(visible=is_aws, interactive=is_aws),
327
+ gr.update(visible=is_aws, interactive=is_aws),
328
+ gr.update(visible=is_aws, interactive=is_aws),
329
+ gr.update(visible=is_aws, interactive=is_aws),
330
+ gr.update(visible=is_hf, interactive=is_hf),
331
+ )
332
+
333
+
334
+ async def update_connection_status( # noqa: PLR0913
335
+ provider: str,
336
+ pretty_model: str,
337
+ mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
338
+ aws_access_key_textbox: str,
339
+ aws_secret_key_textbox: str,
340
+ aws_session_token_textbox: str,
341
+ aws_region_dropdown: str,
342
+ hf_token: str,
343
+ temperature: float,
344
+ max_tokens: int,
345
+ ) -> str:
346
+ """Update the connection status based on the selected provider and model."""
347
+ if not provider or not pretty_model:
348
+ return "❌ Please select a provider and model."
349
+ model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
350
+ if model_id:
351
+ if provider == "AWS Bedrock":
352
+ connection = await gr_connect_to_bedrock(
353
+ model_id,
354
+ aws_access_key_textbox,
355
+ aws_secret_key_textbox,
356
+ aws_session_token_textbox,
357
+ aws_region_dropdown,
358
+ mcp_list_state,
359
+ temperature,
360
+ max_tokens,
361
+ )
362
+ elif provider == "HuggingFace":
363
+ connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
364
+ elif provider == "Nebius":
365
+ connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
366
+ else:
367
+ return "❌ Invalid provider"
368
+ return connection if connection else "❌ Invalid provider"
369
+
370
+
371
+ with gr.Blocks(
372
+ theme=gr.themes.Origin(primary_hue="teal", spacing_size="sm", font="sans-serif"),
373
+ title="TDAgent",
374
+ ) as gr_app, gr.Row():
375
+ with gr.Column(scale=1):
376
+ with gr.Accordion("πŸ”Œ MCP Servers", open=False):
377
+ mcp_list = MutableCheckBoxGroup(
378
+ values=[
379
+ MutableCheckBoxGroupEntry(
380
+ name="TDAgent tools",
381
+ value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
382
+ ),
383
+ ],
384
+ label="MCP Servers",
385
+ )
386
+
387
+ with gr.Accordion("βš™οΈ Provider Configuration", open=True):
388
+ model_provider = gr.Dropdown(
389
+ choices=list(MODEL_OPTIONS.keys()),
390
+ value=None,
391
+ label="Select Model Provider",
392
  )
 
393
  aws_access_key_textbox = gr.Textbox(
394
  label="AWS Access Key ID",
395
  type="password",
396
  placeholder="Enter your AWS Access Key ID",
397
+ visible=False,
398
  )
399
  aws_secret_key_textbox = gr.Textbox(
400
  label="AWS Secret Access Key",
401
  type="password",
402
  placeholder="Enter your AWS Secret Access Key",
403
+ visible=False,
404
  )
 
 
 
 
 
 
 
405
  aws_region_dropdown = gr.Dropdown(
406
  label="AWS Region",
407
  choices=[
 
412
  "ap-southeast-1",
413
  ],
414
  value="eu-west-1",
415
+ visible=False,
416
+ )
417
+ aws_session_token_textbox = gr.Textbox(
418
+ label="AWS Session Token",
419
+ type="password",
420
+ placeholder="Enter your AWS session token",
421
+ visible=False,
422
+ )
423
+ hf_token = gr.Textbox(
424
+ label="HuggingFace Token",
425
+ type="password",
426
+ placeholder="Enter your Hugging Face Access Token",
427
+ visible=False,
428
+ )
429
+
430
+ with gr.Accordion("🧠 Model Configuration", open=True):
431
+ model_display_id = gr.Dropdown(
432
+ label="Select Model ID",
433
+ choices=[],
434
+ visible=False,
435
+ )
436
+ model_provider.change(
437
+ toggle_model_fields,
438
+ inputs=[model_provider],
439
+ outputs=[
440
+ model_display_id,
441
+ aws_access_key_textbox,
442
+ aws_secret_key_textbox,
443
+ aws_session_token_textbox,
444
+ aws_region_dropdown,
445
+ hf_token,
446
+ ],
447
+ )
448
+ # Initialize the temperature and max tokens based on model specifications
449
+ temperature = gr.Slider(
450
+ label="Temperature",
451
+ minimum=0.0,
452
+ maximum=1.0,
453
+ value=0.8,
454
+ step=0.1,
455
+ )
456
+ max_tokens = gr.Slider(
457
+ label="Max Tokens",
458
+ minimum=64,
459
+ maximum=4096,
460
+ value=512,
461
+ step=64,
462
  )
 
463
 
464
+ connect_btn = gr.Button("πŸ”Œ Connect to Model", variant="primary")
465
  status_textbox = gr.Textbox(label="Connection Status", interactive=False)
466
 
467
  connect_btn.click(
468
+ update_connection_status,
469
  inputs=[
470
+ model_provider,
471
+ model_display_id,
472
+ mcp_list.state,
473
  aws_access_key_textbox,
474
  aws_secret_key_textbox,
475
  aws_session_token_textbox,
476
  aws_region_dropdown,
477
+ hf_token,
478
+ temperature,
479
+ max_tokens,
480
  ],
481
  outputs=[status_textbox],
482
  )
483
 
484
+ with gr.Column(scale=2):
485
+ chat_interface = gr.ChatInterface(
486
+ fn=gr_chat_function,
487
+ type="messages",
488
+ examples=[], # Add examples if needed
489
+ title="πŸ‘©β€πŸ’» TDAgent",
490
+ description="This is a simple agent that uses MCP tools.",
491
+ )
492
 
493
 
494
  if __name__ == "__main__":
uv.lock CHANGED
The diff for this file is too large to render. See raw diff