ToDoAgent / tools.py
siyuwang541's picture
mvp
95bd630 verified
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import base64
import datetime
import hashlib
import hmac
import json
import math
import os
import re
import threading
import time
import traceback
from time import mktime
from urllib.parse import urlencode
from urllib.parse import urlparse
from wsgiref.handlers import format_date_time

import azure.cognitiveservices.speech as speechsdk
import requests
from urllib3 import encode_multipart_formdata
# ---------------------------------------------------------------------------
# iFlytek OST endpoints and upload settings
# ---------------------------------------------------------------------------
# NOTE(review): the upload host uses plain HTTP; check whether HTTPS is accepted.
LFASR_HOST = "http://upload-ost-api.xfyun.cn/file"  # file-upload service host
API_INIT = "/mpupload/init"  # initialise a multipart (sliced) upload
API_UPLOAD = "/upload"  # single-request upload
API_CUT = "/mpupload/upload"  # upload one slice of a multipart upload
API_CUT_COMPLETE = "/mpupload/complete"  # finish a multipart upload
API_CUT_CANCEL = "/mpupload/cancel"  # cancel a multipart upload
FILE_PIECE_SIZE = 5 * 1024 * 1024  # slice size: 5 MiB (5242880 bytes)
PRO_CREATE_URI = "/v2/ost/pro_create"  # create a transcription task
QUERY_URI = "/v2/ost/query"  # query a transcription task
class FileUploader:
    """Upload a local file to the iFlytek OST file service and return its URL.

    Files smaller than 30 MB (31457280 bytes) go up in one multipart request.
    Larger files use the sliced protocol: init -> upload each FILE_PIECE_SIZE
    slice -> complete. Every request carries an HMAC-SHA256 signed header set.
    """

    def __init__(self, app_id, api_key, api_secret, upload_file_path):
        # Credentials used to sign requests, and the file to send.
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.upload_file_path = upload_file_path

    def get_request_id(self):
        """Return a request id made from the local time, formatted YYYYMMDDHHMM."""
        return time.strftime("%Y%m%d%H%M")

    def hashlib_256(self, data):
        """Return ``"SHA-256=" + base64(sha256(utf-8 bytes of data))``."""
        raw = hashlib.sha256(data.encode("utf-8")).digest()
        return "SHA-256=" + base64.b64encode(raw).decode("utf-8")

    def assemble_auth_header(self, request_url, file_data_type, method="", body=""):
        """Build the signed header dict for a request to ``request_url``.

        Args:
            request_url: full URL; its hostname and path are signed.
            file_data_type: value sent as the ``content-type`` header.
            method: HTTP method placed in the signed request line.
            body: accepted for interface compatibility but NOT hashed (see note).

        Returns:
            dict with ``host``, ``date``, ``authorization``, ``digest`` and
            ``content-type`` keys.
        """
        parsed = urlparse(request_url)
        host, path = parsed.hostname, parsed.path
        date = format_date_time(mktime(datetime.datetime.now().timetuple()))
        # NOTE: the digest is always computed over the empty string and carries
        # a doubled "SHA256=SHA-256=" prefix. It is part of the signed string,
        # so presumably the server expects exactly this; verify before changing.
        digest = "SHA256=" + self.hashlib_256("")
        to_sign = "\n".join(
            [
                "host: {}".format(host),
                "date: {}".format(date),
                "{} {} HTTP/1.1".format(method, path),
                "digest: {}".format(digest),
            ]
        )
        mac = hmac.new(
            self.api_secret.encode("utf-8"),
            to_sign.encode("utf-8"),
            digestmod=hashlib.sha256,
        )
        signature = base64.b64encode(mac.digest()).decode("utf-8")
        authorization = 'api_key="{}", algorithm="{}", headers="{}", signature="{}"'.format(
            self.api_key,
            "hmac-sha256",
            "host date request-line digest",
            signature,
        )
        return {
            "host": host,
            "date": date,
            "authorization": authorization,
            "digest": digest,
            "content-type": file_data_type,
        }

    def call_api(self, url, file_data, file_data_type):
        """POST ``file_data`` to ``url`` with signed headers.

        Returns the decoded JSON response, or None if the request or the JSON
        decoding raised (the error is printed).
        """
        headers = self.assemble_auth_header(
            url, file_data_type, method="POST", body=file_data
        )
        try:
            response = requests.post(url, headers=headers, data=file_data, timeout=8)
            print("上传状态:", response.status_code, response.text)
            return response.json()
        except Exception as exc:
            print(f"上传失败!Exception :{exc}")
            return None

    @staticmethod
    def _lookup_data(response, key):
        """Return ``(True, response["data"][key])`` when present, else ``(False, None)``."""
        if response and "data" in response and key in response["data"]:
            return True, response["data"][key]
        return False, None

    def upload_cut_complete(self, upload_id):
        """Finish a sliced upload; return the file URL or None on failure."""
        payload = json.dumps(
            {
                "app_id": self.app_id,
                "request_id": self.get_request_id(),
                "upload_id": upload_id,
            }
        )
        response = self.call_api(
            LFASR_HOST + API_CUT_COMPLETE, payload, "application/json"
        )
        found, file_url = self._lookup_data(response, "url")
        if not found:
            print("分片上传完成失败", response)
            return None
        print("任务上传结束")
        return file_url

    def upload_file(self):
        """Upload the file, choosing single or sliced upload by size (30 MB cutoff)."""
        size = os.path.getsize(self.upload_file_path)
        use_slices = size >= 31457280  # 30 MB
        print("-----使用分块上传-----" if use_slices else "-----不使用分块上传-----")
        return self.multipart_upload() if use_slices else self.simple_upload()

    def simple_upload(self):
        """Upload the whole file in one request; return its URL or None."""
        try:
            with open(self.upload_file_path, mode="rb") as handle:
                payload = handle.read()
        except FileNotFoundError:
            print("文件未找到:", self.upload_file_path)
            return None
        fields = {
            "data": (self.upload_file_path, payload),
            "app_id": self.app_id,
            "request_id": self.get_request_id(),
        }
        body, content_type = encode_multipart_formdata(fields)
        response = self.call_api(LFASR_HOST + API_UPLOAD, body, content_type)
        found, file_url = self._lookup_data(response, "url")
        if found:
            return file_url
        print("简单上传失败", response)
        return None

    def multipart_upload(self):
        """Run init -> slice uploads -> complete; return the file URL or None."""
        upload_id = self.prepare_upload()
        if not upload_id or not self.do_upload(upload_id):
            return None
        file_url = self.upload_cut_complete(upload_id)
        print("分片上传地址:", file_url)
        return file_url

    def prepare_upload(self):
        """Start a sliced upload; return the server-issued upload_id or None."""
        payload = json.dumps(
            {"app_id": self.app_id, "request_id": self.get_request_id()}
        )
        response = self.call_api(LFASR_HOST + API_INIT, payload, "application/json")
        found, upload_id = self._lookup_data(response, "upload_id")
        if found:
            return upload_id
        print("预处理失败", response)
        return None

    def do_upload(self, upload_id):
        """Send every FILE_PIECE_SIZE slice in order (slice ids start at 1).

        Each slice is retried up to 3 more times, sleeping 1 s after each retry.
        Returns False as soon as a slice still fails, otherwise True.
        """
        total = os.path.getsize(self.upload_file_path)
        piece = FILE_PIECE_SIZE
        slice_count = math.ceil(total / piece)
        request_id = self.get_request_id()
        print(
            "文件:",
            self.upload_file_path,
            " 文件大小:",
            total,
            " 分块大小:",
            piece,
            " 分块数:",
            slice_count,
        )
        url = LFASR_HOST + API_CUT
        with open(self.upload_file_path, mode="rb") as source:
            for slice_no in range(1, slice_count + 1):
                remaining = total - (slice_no - 1) * piece
                fields = {
                    "data": (self.upload_file_path, source.read(min(piece, remaining))),
                    "app_id": self.app_id,
                    "request_id": request_id,
                    "upload_id": upload_id,
                    "slice_id": slice_no,
                }
                body, content_type = encode_multipart_formdata(fields)
                reply = self.call_api(url, body, content_type)
                retries = 0
                while not reply and retries < 3:
                    print("上传重试")
                    reply = self.call_api(url, body, content_type)
                    retries += 1
                    time.sleep(1)
                if not reply:
                    print("分片上传失败")
                    return False
        return True
class ResultExtractor:
    """Client for the iFlytek OST transcription-task API.

    ``task_create`` submits an uploaded audio URL, ``task_query`` polls a task,
    and ``extract_text`` flattens a finished task's result into plain text.
    Each request is signed with HMAC-SHA256 over host, date, request line and
    the SHA-256 digest of the JSON body.
    """

    def __init__(self, appid, apikey, apisecret):
        # Endpoint configuration for the signed POST requests.
        self.Host = "ost-api.xfyun.cn"
        self.RequestUriCreate = PRO_CREATE_URI
        self.RequestUriQuery = QUERY_URI
        # A host starting with a digit (presumably a bare IP) is reached over
        # plain HTTP, otherwise HTTPS.
        scheme = "http://" if re.match(r"^\d", self.Host) else "https://"
        self.urlCreate = scheme + self.Host + self.RequestUriCreate
        self.urlQuery = scheme + self.Host + self.RequestUriQuery
        self.HttpMethod = "POST"
        self.APPID = appid
        self.Algorithm = "hmac-sha256"
        self.HttpProto = "HTTP/1.1"
        self.UserName = apikey
        self.Secret = apisecret
        # Initial HTTP Date string; init_header refreshes it for every request.
        self.Date = self.httpdate(datetime.datetime.now(datetime.timezone.utc))
        # Business parameters sent when creating a task (Mandarin Chinese).
        self.BusinessArgsCreate = {
            "language": "zh_cn",
            "accent": "mandarin",
            "domain": "pro_ost_ed",
        }

    def img_read(self, path):
        """Return the raw bytes of the file at ``path``."""
        with open(path, "rb") as fo:
            return fo.read()

    def hashlib_256(self, res):
        """Return ``"SHA-256=" + base64(sha256(utf-8 bytes of res))``."""
        m = hashlib.sha256(res.encode(encoding="utf-8")).digest()
        return "SHA-256=" + base64.b64encode(m).decode(encoding="utf-8")

    def httpdate(self, dt):
        """Format ``dt`` as an HTTP date, e.g. ``Mon, 01 Jan 2024 00:00:00 GMT``.

        ``dt`` is assumed to already be in UTC; no conversion is applied.
        """
        weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"][dt.weekday()]
        month = [
            "Jan",
            "Feb",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        ][dt.month - 1]
        return "%s, %02d %s %04d %02d:%02d:%02d GMT" % (
            weekday,
            dt.day,
            month,
            dt.year,
            dt.hour,
            dt.minute,
            dt.second,
        )

    def generateSignature(self, digest, uri):
        """Return the base64 HMAC-SHA256 of host/date/request-line/digest for ``uri``."""
        signature_str = "host: " + self.Host + "\n"
        signature_str += "date: " + self.Date + "\n"
        signature_str += self.HttpMethod + " " + uri + " " + self.HttpProto + "\n"
        signature_str += "digest: " + digest
        signature = hmac.new(
            self.Secret.encode("utf-8"),
            signature_str.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        return base64.b64encode(signature).decode(encoding="utf-8")

    def init_header(self, data, uri):
        """Build signed request headers for JSON body ``data`` sent to ``uri``."""
        # Refresh the Date for every request. It is part of the signature, and
        # a value frozen at construction time goes stale while a task is
        # polled. NOTE(review): server clock-skew tolerance not shown here.
        self.Date = self.httpdate(datetime.datetime.now(datetime.timezone.utc))
        digest = self.hashlib_256(data)
        sign = self.generateSignature(digest, uri)
        auth_header = (
            'api_key="%s",algorithm="%s", '
            'headers="host date request-line digest", '
            'signature="%s"' % (self.UserName, self.Algorithm, sign)
        )
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Method": "POST",
            "Host": self.Host,
            "Date": self.Date,
            "Digest": digest,
            "Authorization": auth_header,
        }

    def get_create_body(self, fileurl):
        """Return the JSON body that creates a transcription task for ``fileurl``."""
        post_data = {
            "common": {"app_id": self.APPID},
            "business": self.BusinessArgsCreate,
            "data": {"audio_src": "http", "audio_url": fileurl, "encoding": "raw"},
        }
        return json.dumps(post_data)

    def get_query_body(self, task_id):
        """Return the JSON body that queries task ``task_id``."""
        post_data = {
            "common": {"app_id": self.APPID},
            "business": {
                "task_id": task_id,
            },
        }
        return json.dumps(post_data)

    def call(self, url, body, headers):
        """POST ``body`` to ``url``.

        Returns:
            the raw ``bytes`` body on a non-200 status; the decoded JSON on 200
            (or the text if it is not JSON); None if the request raised.
        """
        try:
            response = requests.post(url, data=body, headers=headers, timeout=8)
            if response.status_code != 200:
                return response.content
            try:
                return json.loads(response.text)
            except json.JSONDecodeError:
                return response.text
        except Exception as e:
            print("Exception :%s" % e)
            return None

    def task_create(self, fileurl):
        """Create a transcription task for the uploaded audio at ``fileurl``."""
        body = self.get_create_body(fileurl)
        headers_create = self.init_header(body, self.RequestUriCreate)
        return self.call(self.urlCreate, body, headers_create)

    def task_query(self, task_id):
        """Query the status/result of transcription task ``task_id``."""
        query_body = self.get_query_body(task_id)
        headers_query = self.init_header(query_body, self.RequestUriQuery)
        return self.call(self.urlQuery, query_body, headers_query)

    @staticmethod
    def _words_from_st_list(st_list):
        """Collect recognised words from a list of ``st`` entries.

        String entries are taken verbatim. Dict entries are walked through
        ``rt -> ws -> cw -> w``. Levels with an unexpected type are skipped,
        and only non-empty string words are kept.
        """
        words = []
        for st in st_list:
            if isinstance(st, str):
                words.append(st)
                continue
            if not isinstance(st, dict):
                continue
            rt = st.get("rt", [])
            if not isinstance(rt, list):
                continue
            for item in rt:
                if not isinstance(item, dict):
                    continue
                ws_list = item.get("ws", [])
                if not isinstance(ws_list, list):
                    continue
                for ws in ws_list:
                    if not isinstance(ws, dict):
                        continue
                    cw_list = ws.get("cw", [])
                    if not isinstance(cw_list, list):
                        continue
                    for cw in cw_list:
                        if isinstance(cw, dict):
                            w = cw.get("w", "")
                            # Skip empty and non-string words so join() cannot fail.
                            if isinstance(w, str) and w:
                                words.append(w)
        return words

    def extract_text(self, result):
        """Extract plain text from a task result.

        Accepts a dict or a JSON string. It handles, in this order: an error
        ``code``, a direct string ``result`` field, a ``lattice`` list, and a
        top-level ``st`` list. Unparseable strings are returned unchanged.
        Unknown dicts are returned as pretty JSON, and other types via ``str()``.
        """
        print(f"\n[DEBUG] extract_text 输入类型: {type(result)}")
        if isinstance(result, str):
            print(f"[DEBUG] 字符串内容 (前200字符): {result[:200]}")
            try:
                result = json.loads(result)
                print("[DEBUG] 成功解析字符串为JSON对象")
            except json.JSONDecodeError:
                print("[DEBUG] 无法解析为JSON,返回原始字符串")
                return result

        if not isinstance(result, dict):
            print(f"[WARNING] 非字典类型结果: {type(result)}")
            return str(result)

        print("[DEBUG] 处理字典类型结果")
        # 1. Error response.
        if "code" in result and result["code"] != 0:
            error_msg = result.get("message", "未知错误")
            print(f"[ERROR] API返回错误: code={result['code']}, message={error_msg}")
            return f"错误: {error_msg}"

        # 2. Text delivered directly.
        if "result" in result and isinstance(result["result"], str):
            print("[DEBUG] 找到直接结果字段")
            return result["result"]

        # 3. Detailed lattice structure.
        if "lattice" in result and isinstance(result["lattice"], list):
            print("[DEBUG] 解析lattice结构")
            text_parts = []
            for lattice in result["lattice"]:
                if not isinstance(lattice, dict):
                    continue
                json_1best = lattice.get("json_1best", {})
                # If json_1best arrives JSON-encoded as a string (assumed
                # possible for some API variants), decode it instead of dropping it.
                if isinstance(json_1best, str):
                    try:
                        json_1best = json.loads(json_1best)
                    except json.JSONDecodeError:
                        continue
                if not json_1best or not isinstance(json_1best, dict):
                    continue
                # st may be a single dict or a list of entries.
                st_content = json_1best.get("st")
                if isinstance(st_content, dict):
                    st_list = [st_content]
                elif isinstance(st_content, list):
                    st_list = st_content
                else:
                    st_list = []
                text_parts.extend(self._words_from_st_list(st_list))
            return "".join(text_parts)

        # 4. Simplified structure with a top-level st list.
        if "st" in result and isinstance(result["st"], list):
            print("[DEBUG] 解析st结构")
            return "".join(self._words_from_st_list(result["st"]))

        # 5. Unknown structure.
        print("[WARNING] 无法识别的结果结构")
        return json.dumps(result, indent=2, ensure_ascii=False)
def audio_to_str(appid, apikey, apisecret, file_path):
"""
调用讯飞开放平台接口,获取音频文件的转写结果。
参数:
appid (str): 讯飞开放平台的appid。
apikey (str): 讯飞开放平台的apikey。
apisecret (str): 讯飞开放平台的apisecret。
file_path (str): 音频文件路径。
返回值:
str: 转写结果文本,如果发生错误则返回None。
"""
# 检查文件是否存在
if not os.path.exists(file_path):
print(f"错误:文件 {file_path} 不存在")
return None
try:
# 1. 文件上传
file_uploader = FileUploader(
app_id=appid,
api_key=apikey,
api_secret=apisecret,
upload_file_path=file_path,
)
fileurl = file_uploader.upload_file()
if not fileurl:
print("文件上传失败")
return None
print("文件上传成功,fileurl:", fileurl)
# 2. 创建任务并查询结果
result_extractor = ResultExtractor(appid, apikey, apisecret)
print("\n------ 创建任务 -------")
create_response = result_extractor.task_create(fileurl)
# 调试输出创建响应
print(
f"[DEBUG] 创建任务响应: {json.dumps(create_response, indent=2, ensure_ascii=False)}"
)
if not isinstance(create_response, dict) or "data" not in create_response:
print("创建任务失败:", create_response)
return None
task_id = create_response["data"]["task_id"]
print(f"任务ID: {task_id}")
# 查询任务
print("\n------ 查询任务 -------")
print("任务转写中······")
max_attempts = 30
attempt = 0
while attempt < max_attempts:
result = result_extractor.task_query(task_id)
# 调试输出查询响应
print(f"\n[QUERY {attempt + 1}] 响应类型: {type(result)}")
if isinstance(result, dict):
print(
f"[QUERY {attempt + 1}] 响应内容: {json.dumps(result, indent=2, ensure_ascii=False)}"
)
else:
print(
f"[QUERY {attempt + 1}] 响应内容 (前200字符): {str(result)[:200]}"
)
# 检查响应是否有效
if not isinstance(result, dict):
print(f"无效响应类型: {type(result)}")
return None
# 检查API错误码
if "code" in result and result["code"] != 0:
error_msg = result.get("message", "未知错误")
print(f"API错误: code={result['code']}, message={error_msg}")
return None
# 获取任务状态
task_data = result.get("data", {})
task_status = task_data.get("task_status")
if not task_status:
print("响应中缺少任务状态字段")
print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
return None
# 处理不同状态
if task_status in ["3", "4"]: # 任务已完成或回调完成
print("转写完成···")
# 提取结果
result_content = task_data.get("result")
if result_content is not None:
try:
result_text = result_extractor.extract_text(result_content)
print("\n转写结果:\n", result_text)
return result_text
except Exception as e:
print(f"\n提取文本时出错: {str(e)}")
print(f"错误详情:\n{traceback.format_exc()}")
print(
"原始结果内容:",
json.dumps(result_content, indent=2, ensure_ascii=False),
)
return None
else:
print("\n响应中缺少结果字段")
print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
return None
elif task_status in ["1", "2"]: # 任务待处理或处理中
print(
f"任务状态:{task_status},等待中... (尝试 {attempt + 1}/{max_attempts})"
)
time.sleep(5)
attempt += 1
else:
print(f"未知任务状态:{task_status}")
print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
return None
else:
print(f"超过最大查询次数({max_attempts}),任务可能仍在处理中")
return None
except Exception as e:
print(f"发生异常: {str(e)}")
print(f"错误详情:\n{traceback.format_exc()}")
return None
"""
1、通用文字识别,图像数据base64编码后大小不得超过10M
2、appid、apiSecret、apiKey请到讯飞开放平台控制台获取并填写到此demo中
3、支持中英文,支持手写和印刷文字。
4、在倾斜文字上效果有提升,同时支持部分生僻字的识别
"""
# 图像识别接口地址
URL = "https://api.xf-yun.com/v1/private/sf8e6aca1"
class AssembleHeaderException(Exception):
    """Raised when a request URL cannot be split into the parts needed for signing."""

    def __init__(self, msg):
        # Forward msg to Exception so str(exc) and exc.args carry it
        # (previously both were empty). .message is kept for existing callers.
        super().__init__(msg)
        self.message = msg
class Url:
    """Plain holder for the pieces of a parsed request URL.

    Attributes:
        host: network location, e.g. ``api.xf-yun.com``.
        path: path component, starting with ``/``.
        schema: scheme including the ``://`` separator, e.g. ``https://``.
    """

    def __init__(self, host, path, schema):
        self.host, self.path, self.schema = host, path, schema
def sha256base64(data):
    """Return the base64-encoded SHA-256 digest of the bytes ``data`` as a str."""
    return base64.b64encode(hashlib.sha256(data).digest()).decode(encoding="utf-8")
def parse_url(requset_url):
    """Split ``requset_url`` into scheme, host and path.

    Args:
        requset_url: absolute URL of the form ``scheme://host/path``. The
            misspelt parameter name is kept for keyword callers.

    Returns:
        Url: ``schema`` includes the ``://`` separator; ``path`` starts with ``/``.

    Raises:
        AssembleHeaderException: if the URL has no ``://``, no path, or an
            empty host. (Previously the first two cases leaked a ValueError
            from str.index.)
    """
    stidx = requset_url.find("://")
    if stidx < 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    schema = requset_url[: stidx + 3]
    host = requset_url[stidx + 3 :]
    edidx = host.find("/")
    # find() gives -1 when there is no path and 0 when the host is empty;
    # neither can be signed.
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    return Url(host[:edidx], host[edidx:], schema)
def assemble_ws_auth_url(requset_url, method="POST", api_key="", api_secret=""):
    """Return ``requset_url`` with iFlytek HMAC-SHA256 auth query parameters appended.

    The signed text covers ``host``, ``date`` and the request line. The
    resulting ``authorization`` string is base64-encoded again and appended,
    together with ``host`` and ``date``, as a URL-encoded query string.
    """
    parsed = parse_url(requset_url)
    date = format_date_time(mktime(datetime.datetime.now().timetuple()))
    signed_lines = (
        f"host: {parsed.host}",
        f"date: {date}",
        f"{method} {parsed.path} HTTP/1.1",
    )
    mac = hmac.new(
        api_secret.encode("utf-8"),
        "\n".join(signed_lines).encode("utf-8"),
        digestmod=hashlib.sha256,
    )
    signature = base64.b64encode(mac.digest()).decode(encoding="utf-8")
    auth_plain = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(auth_plain.encode("utf-8")).decode(encoding="utf-8")
    query = urlencode({"host": parsed.host, "date": date, "authorization": authorization})
    return f"{requset_url}?{query}"
def image_to_str(
    endpoint=None,
    key=None,
    unused_param=None,
    file_path=None,
    max_polls=60,
    poll_interval=1.0,
):
    """
    Recognise the text in an image with the Azure Computer Vision Read API (v3.2).

    Args:
        endpoint (str): Azure Computer Vision endpoint URL. If None, the
            AZURE_VISION_ENDPOINT environment variable is used, falling back
            to the built-in default.
        key (str): Azure Computer Vision subscription key. If None, the
            AZURE_VISION_KEY environment variable is used, falling back to the
            built-in default.
        unused_param: ignored; kept for call compatibility.
        file_path (str): path to the image file.
        max_polls (int): maximum number of result polls before giving up
            (previously the loop could spin forever).
        poll_interval (float): seconds between polls.

    Returns:
        str: the recognised lines joined with spaces ("" if none), or None on
        any error or timeout.
    """
    # SECURITY: a subscription key committed to source control should be
    # treated as leaked and rotated. Prefer the environment variables; the
    # literals remain only as a backward-compatible fallback.
    if endpoint is None:
        endpoint = os.environ.get(
            "AZURE_VISION_ENDPOINT",
            "https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/",
        )
    if key is None:
        key = os.environ.get(
            "AZURE_VISION_KEY",
            "45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ",
        )
    try:
        with open(file_path, "rb") as f:
            image_data = f.read()

        # Submit the image; the Read API answers 202 plus an Operation-Location
        # URL to poll for the result.
        analyze_url = endpoint.rstrip("/") + "/vision/v3.2/read/analyze"
        headers = {
            "Ocp-Apim-Subscription-Key": key,
            "Content-Type": "application/octet-stream",
        }
        response = requests.post(analyze_url, headers=headers, data=image_data, timeout=30)
        if response.status_code != 202:
            print(f"分析请求失败: {response.status_code}, {response.text}")
            return None
        operation_url = response.headers["Operation-Location"]

        poll_headers = {"Ocp-Apim-Subscription-Key": key}
        for _ in range(max_polls):
            result = requests.get(operation_url, headers=poll_headers, timeout=30).json()
            status = result.get("status")
            if status == "succeeded":
                text_results = []
                for read_result in result.get("analyzeResult", {}).get("readResults", []):
                    for line in read_result.get("lines", []):
                        text_results.append(line["text"])
                return " ".join(text_results)
            # An explicit failure, or an error body without a status, ends polling.
            if status == "failed" or status is None:
                print(f"文字识别失败: {result}")
                return None
            time.sleep(poll_interval)

        print(f"文字识别超时: 轮询{max_polls}次后仍未完成")
        return None
    except Exception as e:
        print(f"发生异常: {e}")
        return None
if __name__ == "__main__":
# 输入讯飞开放平台的 appid,secret、key 和文件路径
appid = "33c1b63d"
apikey = "40bf7cd82e31ace30a9cfb76309a43a3"
apisecret = "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4"
audio_path = r"audio_sample_little.wav" # 确保文件路径正确
image_path = r"1.png" # 确保文件路径正确
# 音频转文字
audio_text = audio_to_str(appid, apikey, apisecret, audio_path)
# 图片转文字
image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_path)
print("-"* 20)
print("\n音频转文字结果:", audio_text)
print("\n图片转文字结果:", image_text)
def _azure_recognize_continuous(speech_recognizer):
    """Run continuous recognition until the recognizer's input ends.

    Returns all recognised utterances joined into one string, or None when
    nothing was recognised or the session was cancelled with an error.
    """
    finished = threading.Event()
    utterances = []
    failures = []

    def on_recognized(evt):
        # Only final results that contain speech contribute text.
        if evt.result.reason == speechsdk.ResultReason.RecognizedSpeech:
            utterances.append(evt.result.text)

    def on_canceled(evt):
        details = evt.result.cancellation_details
        if details.reason == speechsdk.CancellationReason.Error:
            failures.append(details)
        finished.set()

    speech_recognizer.recognized.connect(on_recognized)
    speech_recognizer.canceled.connect(on_canceled)
    speech_recognizer.session_stopped.connect(lambda evt: finished.set())

    speech_recognizer.start_continuous_recognition()
    finished.wait()
    speech_recognizer.stop_continuous_recognition()

    if failures:
        print(f"Azure Speech识别被取消: {failures[0].reason}")
        print(f"错误详情: {failures[0].error_details}")
        return None
    if not utterances:
        print("Azure Speech未识别到语音")
        return None
    text = "".join(utterances)
    print(f"Azure Speech识别成功: {text}")
    return text


def azure_speech_to_text(speech_key, speech_region, audio_file_path, continuous=False):
    """
    Convert an audio file to text with the Azure Speech service (zh-CN).

    Args:
        speech_key (str): Azure Speech API key.
        speech_region (str): Azure Speech region.
        audio_file_path (str): path to the audio file.
        continuous (bool): False (default) keeps the single-shot
            ``recognize_once()`` behaviour, which returns only the first
            recognised utterance. True runs continuous recognition over the
            whole file and joins every recognised utterance.

    Returns:
        str: the recognised text, or None on no match, cancellation or error.
    """
    try:
        speech_config = speechsdk.SpeechConfig(subscription=speech_key, region=speech_region)
        speech_config.speech_recognition_language = "zh-CN"  # Chinese
        audio_config = speechsdk.audio.AudioConfig(filename=audio_file_path)
        speech_recognizer = speechsdk.SpeechRecognizer(
            speech_config=speech_config, audio_config=audio_config
        )

        if continuous:
            return _azure_recognize_continuous(speech_recognizer)

        result = speech_recognizer.recognize_once()
        if result.reason == speechsdk.ResultReason.RecognizedSpeech:
            print(f"Azure Speech识别成功: {result.text}")
            return result.text
        if result.reason == speechsdk.ResultReason.NoMatch:
            print("Azure Speech未识别到语音")
            return None
        if result.reason == speechsdk.ResultReason.Canceled:
            cancellation_details = result.cancellation_details
            print(f"Azure Speech识别被取消: {cancellation_details.reason}")
            if cancellation_details.reason == speechsdk.CancellationReason.Error:
                print(f"错误详情: {cancellation_details.error_details}")
            return None
        return None
    except Exception as e:
        print(f"Azure Speech识别出错: {str(e)}")
        return None