# Spaces: Running
import uuid | |
import json | |
import requests | |
import time | |
import gradio as gr | |
class Session:
    """
    In-memory conversation state for a single client session.

    Attributes:
        session_id: Identifier the session was created with.
        task_history: History entries in insertion order. The functions in
            this module append ``{"role": ..., "content": ...}`` dicts.
    """

    def __init__(self, session_id):
        self.session_id = session_id
        # Entries are stored exactly as passed to add_task().
        self.task_history = list()

    def add_task(self, task):
        """Append one history entry; the object is stored as-is, not copied."""
        self.task_history += [task]

    def get_history(self):
        """Return the live history list itself (not a copy)."""
        history = self.task_history
        return history
# Registry of every session created in this process, keyed by session ID.
sessions: "dict[str, Session]" = dict()

# ID of the session new messages are logged to; None until one exists.
current_session_id: "str | None" = None
def new_session():
    """
    Start a fresh session and make it the active one.

    A random UUID4 string becomes the session ID. A new ``Session`` is
    registered under it in ``sessions`` before it is made current.

    Returns:
        The new session's ID.
    """
    global current_session_id
    fresh_id = f"{uuid.uuid4()}"
    sessions[fresh_id] = Session(fresh_id)
    current_session_id = fresh_id
    return fresh_id
def get_current_session_id():
    """Return the ID of the active session, or ``None`` if no session exists yet."""
    active_id = current_session_id
    return active_id
def list_sessions():
    """Return the IDs of all tracked sessions, in creation order, as a new list."""
    return [*sessions]
def change_session(session_id):
    """
    Make ``session_id`` the active session.

    Raises:
        ValueError: If no session with that ID is tracked in ``sessions``.
    """
    global current_session_id
    if session_id not in sessions:
        raise ValueError(f"Session ID {session_id} not found.")
    current_session_id = session_id
def get_task_history(session_id):
    """
    Return the history list of the session with the given ID.

    Raises:
        KeyError: If ``session_id`` is not a tracked session.
    """
    session = sessions[session_id]
    return session.get_history()
def send_message(server_url, content):
    """
    Send a user message to the A2A server with JSON-RPC ``message/send``.

    Uses the current session, creating one first if none is active. The user's
    message is appended to that session's client-side history, followed by the
    agent's reply or an error note.

    The JSON-RPC ``result`` is handled in two shapes:

    * a Message: its first ``"kind": "text"`` part becomes the reply;
    * a Task (``"kind": "task"``): it is polled with
      ``poll_for_task_completion`` until it finishes.

    Args:
        server_url: Base URL of the A2A server; ``/jsonrpc`` is appended.
        content: Text of the user's message.

    Returns:
        The session's history list (``{"role": ..., "content": ...}`` dicts).
        This is always a list, including when the reply arrived via polling.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    if not current_session_id:
        new_session()
    session_id = current_session_id
    # Hold on to the session object so replies land in the session that sent
    # the message, even if the current session changes meanwhile.
    session = sessions[session_id]

    # One UUID serves as both the A2A messageId and the JSON-RPC request id.
    message_id = str(uuid.uuid4())
    session.add_task({"role": "user", "content": content})

    message_obj = {
        "role": "user",
        "parts": [{"kind": "text", "text": content}],
        "messageId": message_id,
        "kind": "message",
    }
    payload = {
        "jsonrpc": "2.0",
        "method": "message/send",
        "params": {
            "sessionId": session_id,
            "message": message_obj,
            "historyLength": 5,
        },
        "id": message_id,
    }
    json_endpoint = f"{server_url.rstrip('/')}/jsonrpc"
    response = requests.post(json_endpoint, json=payload)
    response.raise_for_status()
    result = response.json()

    if "error" in result:
        error_message = result["error"].get("message", "Unknown error")
        session.add_task({"role": "assistant", "content": f"Error: {error_message}"})
        return session.get_history()

    reply = result.get("result")
    if isinstance(reply, dict) and reply.get("kind") != "task":
        # Direct Message reply: keep the first text part, ignore other kinds.
        message_parts = reply.get("parts", [])
        agent_reply = next(
            (part.get("text", "") for part in message_parts if part.get("kind") == "text"),
            "",
        )
        if agent_reply:
            session.add_task({"role": "assistant", "content": agent_reply})
        return session.get_history()

    if isinstance(reply, dict):
        task_id = reply.get("id")
    else:
        # NOTE(review): fallback kept from the earlier client for responses
        # without a usable "result". It treats the top-level JSON-RPC "id" as
        # the task ID, although a conforming server echoes the request id there.
        task_id = result.get("id")
    if not task_id:
        session.add_task({"role": "assistant", "content": "No task ID returned."})
        return session.get_history()

    # TODO allow for multiple concurrent tasks in the client
    # should store task IDs in the session and add a view into interface
    # poll_for_task_completion is a generator; drain it so this function keeps
    # returning a plain list rather than an unstarted generator object.
    for _ in poll_for_task_completion(server_url, task_id, session_id):
        pass
    return session.get_history()
def _extract_task_text(task):
    """
    Return the first non-empty text part of a task object, or "".

    The status message (``task["status"]["message"]["parts"]``) is checked
    first, then the ``parts`` of each entry in ``task["artifacts"]`` in order.
    """
    status = task.get("status")
    message = status.get("message") if isinstance(status, dict) else None
    part_lists = [message.get("parts") if isinstance(message, dict) else None]
    for artifact in task.get("artifacts") or []:
        part_lists.append(artifact.get("parts") if isinstance(artifact, dict) else None)
    for parts in part_lists:
        for part in parts or []:
            if isinstance(part, dict) and part.get("kind") == "text" and part.get("text"):
                return part["text"]
    return ""


def poll_for_task_completion(server_url, task_id, current_session_id,
                             poll_interval=1, timeout=None):
    """
    Poll ``get_task_status`` until a task stops, then record its outcome.

    This is a generator. When polling ends with an outcome, one assistant
    entry is appended to the session's history and the full history is
    yielded once. Nothing is recorded or yielded if a response carries no
    usable task status.

    Polling stops on the states ``completed``, ``failed``, ``canceled``,
    ``rejected``, ``input-required`` and ``auth-required``, and on a JSON-RPC
    ``error``. Every other state (e.g. ``submitted``, ``working``) keeps
    polling.

    Args:
        server_url: Server URL, passed unchanged to ``get_task_status``.
        task_id: ID of the task to poll.
        current_session_id: Session whose history receives the outcome.
        poll_interval: Seconds to sleep before each poll (default 1).
        timeout: Maximum total seconds to keep polling. ``None`` (default)
            polls until the task stops.

    Yields:
        The session's history list, once, after the outcome is recorded.
    """
    stop_states = {
        "completed", "failed", "canceled", "rejected",
        "input-required", "auth-required",
    }
    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        # Sleep before each request to avoid hammering the server.
        time.sleep(poll_interval)
        reply = get_task_status(server_url, task_id)
        if not isinstance(reply, dict):
            return
        if "error" in reply:
            error = reply["error"]
            detail = error.get("message", "Unknown error") if isinstance(error, dict) else error
            note = f"Error: {detail}"
            break
        # Accept a JSON-RPC envelope (task under "result") or a bare task object.
        task = reply.get("result", reply)
        status = task.get("status") if isinstance(task, dict) else None
        if not isinstance(status, dict):
            return  # no usable status: stop without recording anything
        state = status.get("state", "")
        if state in stop_states:
            # TODO allow for returning other parts (data types)
            note = _extract_task_text(task) or (
                "Task completed, but no text content found."
                if state == "completed"
                else f"Task ended with state: {state}."
            )
            break
        if deadline is not None and time.monotonic() >= deadline:
            note = f"Timed out waiting for task {task_id} (last state: {state or 'unknown'})."
            break

    sessions[current_session_id].add_task({"role": "assistant", "content": note})
    # Yield the final history to Gradio.
    yield sessions[current_session_id].get_history()
def _first_stream_text(parts):
    """Return the text of the first ``"kind": "text"`` entry in ``parts``, or ""."""
    for part in parts or []:
        if isinstance(part, dict) and part.get("kind") == "text":
            return part.get("text", "")
    return ""


def _stream_update_text(result):
    """
    Return the first text carried by a JSON-RPC ``result`` object, or "".

    First it tries the shapes this client has always read:
    ``result["task"]["message"]["parts"]``, otherwise ``result["content"]``.
    Then it tries the A2A shapes: a Message (``parts``), a Task or
    status-update event (``status.message.parts``), and an artifact-update
    event (``artifact.parts``).
    """
    task = result.get("task")
    if isinstance(task, dict):
        message = task.get("message")
        text = _first_stream_text(message.get("parts") if isinstance(message, dict) else None)
        if text:
            return text
    elif result.get("content"):
        return result["content"]

    status = result.get("status")
    status_message = status.get("message") if isinstance(status, dict) else None
    artifact = result.get("artifact")
    for parts in (
        result.get("parts"),
        status_message.get("parts") if isinstance(status_message, dict) else None,
        artifact.get("parts") if isinstance(artifact, dict) else None,
    ):
        text = _first_stream_text(parts)
        if text:
            return text
    return ""


def send_message_and_subscribe(server_url, content):
    """
    Send a user message via JSON-RPC ``message/sendSubscribe`` and stream the reply.

    Uses the current session, creating one first if none is active. This is a
    generator. The user's message is logged to the in-memory session history,
    and each time an agent update with text is appended, the full history is
    yielded. Streaming stops after an update whose result has
    ``"final": true``, or after a JSON-RPC error (recorded as
    ``"Error: ..."``).

    If the response contains no Server-Sent Events ``data:`` lines, the body is
    parsed as a single JSON-RPC response instead.

    Args:
        server_url: Base URL of the A2A server; ``/jsonrpc`` is appended.
        content: Text of the user's message.

    Yields:
        The session's history list after each recorded update.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    if not current_session_id:
        new_session()
    session_id = current_session_id
    session = sessions[session_id]

    def record(text):
        # Log an assistant entry and return the updated history for yielding.
        session.add_task({"role": "assistant", "content": text})
        return session.get_history()

    def record_error(error):
        detail = error.get("message", "Unknown error") if isinstance(error, dict) else error
        return record(f"Error: {detail}")

    # One UUID serves as both the A2A messageId and the JSON-RPC request id.
    message_id = str(uuid.uuid4())
    session.add_task({"role": "user", "content": content})
    message_obj = {
        "role": "user",
        "parts": [{"kind": "text", "text": content}],
        "messageId": message_id,
        "kind": "message",
    }
    # NOTE(review): the published A2A spec names its streaming method
    # "message/stream"; confirm the target server accepts "message/sendSubscribe".
    payload = {
        "jsonrpc": "2.0",
        "method": "message/sendSubscribe",
        "params": {
            "sessionId": session_id,
            "message": message_obj,
            "historyLength": 5,
        },
        "id": message_id,
    }
    json_endpoint = f"{server_url.rstrip('/')}/jsonrpc"

    # The with-block closes the streamed connection, including when the
    # consumer stops iterating this generator early.
    with requests.post(json_endpoint, json=payload, stream=True) as response:
        response.raise_for_status()
        # SSE is always UTF-8, but requests assumes ISO-8859-1 for text/*
        # responses without a charset, which garbles non-ASCII text.
        if "charset=" not in response.headers.get("Content-Type", "").lower():
            response.encoding = "utf-8"

        saw_sse_data = False
        other_lines = []
        for line in response.iter_lines(decode_unicode=True):
            if not line.startswith("data:"):
                # Kept in case the body turns out to be plain JSON, not SSE.
                other_lines.append(line)
                continue
            saw_sse_data = True
            try:
                update = json.loads(line[len("data:"):].strip())
            except ValueError:
                continue  # skip payloads that are not valid JSON
            if not isinstance(update, dict):
                continue
            if "error" in update:
                yield record_error(update["error"])
                return
            result = update.get("result")
            if not isinstance(result, dict):
                continue
            text = _stream_update_text(result)
            if text:
                yield record(text)
            if result.get("final", False):
                return
        if saw_sse_data:
            return
        # No SSE data lines: the server answered with a plain JSON-RPC body.
        body = "\n".join(other_lines).strip()

    try:
        result = json.loads(body)
    except ValueError:
        return  # neither SSE nor JSON: nothing to show
    if not isinstance(result, dict):
        return
    if "error" in result:
        yield record_error(result["error"])
        return
    reply = result.get("result")
    agent_reply = _stream_update_text(reply) if isinstance(reply, dict) else ""
    if agent_reply:
        yield record(agent_reply)
def get_task_status(server_url, task_id):
    """
    Fetch a task with the JSON-RPC ``tasks/get`` method.

    Args:
        server_url: Base URL of the A2A server. ``/jsonrpc`` is appended, the
            same endpoint the send functions use, unless the URL already ends
            with it.
        task_id: ID of the task to fetch; also reused as the JSON-RPC request id.

    Returns:
        The decoded JSON response body, i.e. the JSON-RPC envelope. A
        conforming server puts the task under ``"result"`` or reports
        ``"error"``.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    base_url = server_url.rstrip("/")
    json_endpoint = base_url if base_url.endswith("/jsonrpc") else f"{base_url}/jsonrpc"
    payload = {
        "jsonrpc": "2.0",
        "method": "tasks/get",
        "params": {"id": task_id},
        "id": task_id,
    }
    response = requests.post(json_endpoint, json=payload)
    response.raise_for_status()
    return response.json()
def cancel_task(server_url, task_id):
    """
    Ask the server to cancel a task with the JSON-RPC ``tasks/cancel`` method.

    Args:
        server_url: Base URL of the A2A server. ``/jsonrpc`` is appended, the
            same endpoint the send functions use, unless the URL already ends
            with it.
        task_id: ID of the task to cancel; also reused as the JSON-RPC request id.

    Returns:
        The decoded JSON response body (the JSON-RPC envelope).

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    base_url = server_url.rstrip("/")
    json_endpoint = base_url if base_url.endswith("/jsonrpc") else f"{base_url}/jsonrpc"
    payload = {
        "jsonrpc": "2.0",
        "method": "tasks/cancel",
        "params": {"id": task_id},
        "id": task_id,
    }
    response = requests.post(json_endpoint, json=payload)
    response.raise_for_status()
    return response.json()