A2A-client / task_management.py
abrakjamson
conversion to message/send and new protocol version
2ddfb98
import uuid
import json
import requests
import time
import gradio as gr
class Session:
    """A client-side conversation: an identifier plus its ordered history.

    The send helpers in this module append dicts of the form
    ``{"role": "user" | "assistant", "content": <str>}`` to the history.
    """

    def __init__(self, session_id):
        self.session_id = session_id
        # Ordered conversation entries (role/content dicts, see class docstring).
        self.task_history = list()

    def add_task(self, task):
        """Append one history entry to the end of this session's history."""
        history = self.task_history
        history.append(task)

    def get_history(self):
        """Return the session's history list itself (not a copy)."""
        history = self.task_history
        return history
# Module-level session registry: session ID string -> Session object.
sessions = dict()
# ID of the session that new messages are sent under; None until one exists.
current_session_id = None
def new_session():
    """Create a session, register it, make it current, and return its ID.

    The ID is a random UUID4 rendered as a string. The new Session is stored
    in the module-level ``sessions`` registry before it becomes current.
    """
    global current_session_id
    fresh_id = str(uuid.uuid4())
    created = Session(fresh_id)
    sessions[fresh_id] = created
    current_session_id = fresh_id
    return fresh_id
def get_current_session_id():
    """Report which session is active.

    Returns:
        The module-level ``current_session_id`` value: a session ID, or None
        if no session has been created or selected yet.
    """
    active_id = current_session_id
    return active_id
def list_sessions():
    """Return a new list of every tracked session ID, in creation order."""
    return [tracked_id for tracked_id in sessions]
def change_session(session_id):
    """Make ``session_id`` the active session.

    Raises:
        ValueError: If ``session_id`` is not in the session registry; the
            active session is left unchanged in that case.
    """
    global current_session_id
    if session_id not in sessions:
        raise ValueError(f"Session ID {session_id} not found.")
    current_session_id = session_id
def get_task_history(session_id):
    """Return the history list of the session registered under ``session_id``.

    Raises:
        KeyError: If no session with that ID is registered.
    """
    target = sessions[session_id]
    return target.get_history()
def send_message(server_url, content):
    """
    Send a user message with the JSON-RPC ``message/send`` method.

    The message is sent under the current session; a new session is created
    first if none exists. The user's text, and then the agent's reply, are
    appended to that session's history.

    The JSON-RPC ``result`` is handled by its ``kind``:
    - a Task (``"kind": "task"``): the task is polled through
      ``poll_for_task_completion`` until it stops, which records the result;
    - anything else is treated as a Message, whose first text part is
      recorded as the assistant reply (nothing is recorded if it has none).
    A JSON-RPC ``error`` is recorded as an ``"Error: ..."`` assistant entry.
    For backward compatibility, a response with neither ``result`` nor
    ``error`` is treated as naming a task via its top-level ``id``.

    Args:
        server_url: Base URL of the A2A server; ``/jsonrpc`` is appended.
        content: Plain-text message from the user.

    Returns:
        The session's history list (role/content dicts). This is always a
        list; the polling generator is drained internally.

    Raises:
        requests.HTTPError: If the HTTP request returns a non-2xx status.
    """
    global current_session_id
    if not current_session_id:
        new_session()
    # Bind the session once so the reply lands in the session the message was
    # sent from, even if the active session changes while we wait.
    session_id = current_session_id
    session = sessions[session_id]

    # Generate a unique ID for the message (and JSON-RPC request)
    message_id = str(uuid.uuid4())
    # Log the user's message in the client-side session storage.
    session.add_task({"role": "user", "content": content})

    message_obj = {
        "role": "user",
        "parts": [{"kind": "text", "text": content}],
        "messageId": message_id,
        "kind": "message",
    }
    payload = {
        "jsonrpc": "2.0",
        "method": "message/send",
        "params": {
            "sessionId": session_id,
            "message": message_obj,
            "historyLength": 5,
        },
        "id": message_id,
    }

    json_endpoint = f"{server_url.rstrip('/')}/jsonrpc"
    response = requests.post(json_endpoint, json=payload)
    response.raise_for_status()
    result = response.json()

    rpc_result = result.get("result")
    if rpc_result is None:
        if "error" in result:
            error = result["error"]
            if isinstance(error, dict):
                error_message = error.get("message", "Unknown error")
            else:
                error_message = "Unknown error"
            session.add_task({"role": "assistant", "content": f"Error: {error_message}"})
            return session.get_history()
        # Legacy servers: no JSON-RPC result, the task is named by the top-level id.
        task_id = result.get("id")
    elif isinstance(rpc_result, dict) and rpc_result.get("kind") == "task":
        task_id = rpc_result.get("id")
    else:
        # Direct reply: a Message; use its first text part as the agent's answer.
        parts = rpc_result.get("parts", []) if isinstance(rpc_result, dict) else []
        agent_reply = next(
            (part.get("text", "") for part in parts if part.get("kind") == "text"),
            "",
        )
        if agent_reply:
            session.add_task({"role": "assistant", "content": agent_reply})
        return session.get_history()

    if not task_id:
        session.add_task({"role": "assistant", "content": "No task ID returned."})
        return session.get_history()

    # TODO allow for multiple concurrent tasks in the client
    # should store task IDs in the session and add a view into interface
    # poll_for_task_completion is a generator: drain it so this function always
    # returns a list. Pass the /jsonrpc endpoint so tasks/get is sent to the
    # same endpoint that message/send used.
    for _ in poll_for_task_completion(json_endpoint, task_id, session_id):
        pass
    return session.get_history()
def _task_reply_text(task):
    """Return the first text part of a Task's status message, else of its artifacts; "" if none."""
    status = task.get("status") or {}
    message = status.get("message") or {}
    for part in message.get("parts") or []:
        if part.get("kind") == "text":
            return part.get("text", "")
    # TODO allow for returning other parts (data types)
    # Many agents put their final output in artifacts rather than the status message.
    for artifact in task.get("artifacts") or []:
        for part in artifact.get("parts") or []:
            if part.get("kind") == "text":
                return part.get("text", "")
    return ""


def poll_for_task_completion(server_url, task_id, current_session_id):
    """
    Poll ``tasks/get`` once per second until the task stops, then record its result.

    This is a generator. It yields the session history once, after appending
    an assistant entry, in these cases:
    - the task is completed: the text comes from the status message, falling
      back to the artifacts;
    - the task is failed, canceled, rejected, or unknown;
    - the task is waiting on input-required or auth-required;
    - the server returns a JSON-RPC error.
    While the task is in any other state (for example submitted or working),
    polling continues. If the response is not a dict or carries no task
    status, polling stops without yielding.

    Args:
        server_url: URL passed through to ``get_task_status``.
        task_id: ID of the task to poll.
        current_session_id: ID of the session whose history receives the
            result (shadows the module-level global of the same name).

    Yields:
        The session's history list after the final entry is appended.
    """
    # States in which the task will not progress without outside action.
    stop_states = {
        "completed", "failed", "canceled", "rejected", "unknown",
        "input-required", "auth-required",
    }
    while True:
        # Sleep for a second to avoid hammering the server
        time.sleep(1)
        response = get_task_status(server_url, task_id)
        if not isinstance(response, dict):
            break

        if "error" in response:
            error = response["error"]
            if isinstance(error, dict):
                error_message = error.get("message", "Unknown error")
            else:
                error_message = "Unknown error"
            sessions[current_session_id].add_task(
                {"role": "assistant", "content": f"Error: {error_message}"}
            )
            yield sessions[current_session_id].get_history()
            break

        # A JSON-RPC response carries the Task under "result"; legacy servers
        # return the Task object itself.
        task = response.get("result", response)
        if not isinstance(task, dict) or "status" not in task:
            break  # Exit loop if task status is invalid

        state = (task["status"] or {}).get("state", "")
        if state not in stop_states:
            continue  # still submitted/working: keep polling

        text = _task_reply_text(task)
        if state == "completed":
            content = text or "Task completed, but no text content found."
        else:
            content = f"Task {state}: {text}" if text else f"Task {state}."
        sessions[current_session_id].add_task({"role": "assistant", "content": content})
        # Yield the final history to Gradio
        yield sessions[current_session_id].get_history()
        break
def _first_text_part(parts):
    """Return the ``text`` of the first part whose ``kind`` is "text", or ""."""
    for part in parts or []:
        if isinstance(part, dict) and part.get("kind") == "text":
            return part.get("text", "")
    return ""


def _stream_result_text(result):
    """
    Extract displayable text from one JSON-RPC ``result`` object.

    Understands the A2A result kinds (message, task, status-update,
    artifact-update) and the legacy ``{"task": {"message": ...}}`` shape, and
    otherwise falls back to a bare ``content`` field. Returns "" if no text is found.
    """
    if not isinstance(result, dict):
        return ""
    if "task" in result:
        # Legacy structured task update.
        task = result["task"] or {}
        return _first_text_part((task.get("message") or {}).get("parts"))
    kind = result.get("kind")
    if kind == "message":
        return _first_text_part(result.get("parts"))
    if kind in ("task", "status-update"):
        status = result.get("status") or {}
        return _first_text_part((status.get("message") or {}).get("parts"))
    if kind == "artifact-update":
        return _first_text_part((result.get("artifact") or {}).get("parts"))
    # Fallback: try to directly extract a 'content' field.
    return result.get("content", "")


def _rpc_error_text(error):
    """Return the ``message`` of a JSON-RPC error object, or "Unknown error"."""
    if isinstance(error, dict):
        return error.get("message", "Unknown error")
    return "Unknown error"


def send_message_and_subscribe(server_url, content):
    """
    Send a user message with the streaming ``message/stream`` method and yield updates.

    The message is sent under the current session; a new session is created
    if none exists. All replies are recorded in that session.

    If the server answers with Server-Sent Events, each event that carries
    text is appended as an assistant entry and the full history is yielded.
    Events without text are skipped. The stream stops at an event whose
    result has ``"final": true``, or at a JSON-RPC error (recorded as
    ``"Error: ..."``).

    If the server answers with a plain JSON body instead, that single result
    or error is recorded and yielded.

    Args:
        server_url: Base URL of the A2A server; ``/jsonrpc`` is appended.
        content: Plain-text message from the user.

    Yields:
        The session's history list (role/content dicts) after each recorded update.

    Raises:
        requests.HTTPError: If the HTTP request returns a non-2xx status.
    """
    global current_session_id
    if not current_session_id:
        new_session()
    # Bind the session once so streamed replies stay in the originating session.
    session = sessions[current_session_id]

    # Generate a unique ID for the message (and JSON-RPC request)
    message_id = str(uuid.uuid4())
    # Log the user's message in the client-side session storage.
    session.add_task({"role": "user", "content": content})

    message_obj = {
        "role": "user",
        "parts": [{"kind": "text", "text": content}],
        "messageId": message_id,
        "kind": "message",
    }
    payload = {
        "jsonrpc": "2.0",
        # A2A's streaming method (successor of tasks/sendSubscribe);
        # "message/sendSubscribe" is not defined by the protocol.
        "method": "message/stream",
        "params": {
            "sessionId": current_session_id,
            "message": message_obj,
            "historyLength": 5,
        },
        "id": message_id,
    }
    json_endpoint = f"{server_url.rstrip('/')}/jsonrpc"

    # stream=True lets us yield each event as it arrives; the with-block
    # releases the connection even if the consumer stops iterating early.
    with requests.post(json_endpoint, json=payload, stream=True) as response:
        response.raise_for_status()
        saw_event = False
        body_lines = []
        # Decode raw bytes as UTF-8 (the SSE encoding) ourselves: requests
        # would decode text/event-stream without a charset as ISO-8859-1.
        for raw_line in response.iter_lines():
            line = raw_line.decode("utf-8", errors="replace")
            if not line.startswith("data:"):
                # Keep non-event lines in case this is a plain JSON response.
                if line and not saw_event:
                    body_lines.append(line)
                continue
            saw_event = True
            try:
                update = json.loads(line[len("data:"):].strip())
            except ValueError:
                continue  # skip lines that can't be parsed
            if not isinstance(update, dict):
                continue
            if "error" in update:
                error_message = _rpc_error_text(update["error"])
                session.add_task({"role": "assistant", "content": f"Error: {error_message}"})
                yield session.get_history()
                return
            if "result" not in update:
                continue
            result = update["result"]
            update_text = _stream_result_text(result)
            if update_text:
                session.add_task({"role": "assistant", "content": update_text})
                yield session.get_history()
            # If the update indicates the final step, stop reading the stream.
            if isinstance(result, dict) and result.get("final", False):
                return

        if saw_event or not body_lines:
            return
        # No SSE events arrived: the server answered with a single JSON body.
        try:
            body = json.loads("\n".join(body_lines))
        except ValueError:
            return
        if not isinstance(body, dict):
            return
        if "error" in body:
            error_message = _rpc_error_text(body["error"])
            session.add_task({"role": "assistant", "content": f"Error: {error_message}"})
            yield session.get_history()
            return
        agent_reply = _stream_result_text(body.get("result"))
        if agent_reply:
            session.add_task({"role": "assistant", "content": agent_reply})
            yield session.get_history()
def get_task_status(server_url, task_id):
    """
    Fetch a task's current state with the JSON-RPC ``tasks/get`` method.

    Args:
        server_url: Base server URL or the full ``.../jsonrpc`` endpoint.
            ``/jsonrpc`` is appended when missing, so the request goes to the
            same endpoint that the send helpers in this module post to.
        task_id: ID of the task; it is also reused as the JSON-RPC request ID.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: If the HTTP request returns a non-2xx status.
    """
    endpoint = server_url.rstrip("/")
    if not endpoint.endswith("/jsonrpc"):
        endpoint = f"{endpoint}/jsonrpc"
    payload = {
        "jsonrpc": "2.0",
        "method": "tasks/get",
        "params": {
            "id": task_id
        },
        "id": task_id
    }
    response = requests.post(endpoint, json=payload)
    response.raise_for_status()
    return response.json()
def cancel_task(server_url, task_id):
    """
    Ask the server to cancel a task via the JSON-RPC ``tasks/cancel`` method.

    Args:
        server_url: Base server URL or the full ``.../jsonrpc`` endpoint.
            ``/jsonrpc`` is appended when missing, so the request goes to the
            same endpoint that the send helpers in this module post to.
        task_id: ID of the task to cancel; it is also reused as the JSON-RPC
            request ID.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: If the HTTP request returns a non-2xx status.
    """
    endpoint = server_url.rstrip("/")
    if not endpoint.endswith("/jsonrpc"):
        endpoint = f"{endpoint}/jsonrpc"
    payload = {
        "jsonrpc": "2.0",
        "method": "tasks/cancel",
        "params": {
            "id": task_id
        },
        "id": task_id
    }
    response = requests.post(endpoint, json=payload)
    response.raise_for_status()
    return response.json()