FISHYA commited on
Commit
049b61f
·
verified ·
1 Parent(s): de2770d

Upload retry.py

Browse files
Files changed (1) hide show
  1. retry.py +408 -0
retry.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import logging
3
+ import functools
4
+ import requests
5
+ from abc import ABC, abstractmethod
6
+ from typing import Callable, Any, Dict, Optional, Type, Union, TypeVar, cast
7
+
8
+ # 导入配置模块
9
+ import config
10
+
11
+ # 类型变量定义
12
+ T = TypeVar('T')
13
+
14
+ class RetryStrategy(ABC):
15
+ """重试策略的抽象基类"""
16
+
17
+ @abstractmethod
18
+ def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
19
+ """
20
+ 判断是否应该重试
21
+
22
+ Args:
23
+ exception: 捕获的异常
24
+ retry_count: 当前重试次数
25
+ max_retries: 最大重试次数
26
+
27
+ Returns:
28
+ bool: 是否应该重试
29
+ """
30
+ pass
31
+
32
+ @abstractmethod
33
+ def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
34
+ """
35
+ 计算重试延迟时间
36
+
37
+ Args:
38
+ retry_count: 当前重试次数
39
+ base_delay: 基础延迟时间(秒)
40
+
41
+ Returns:
42
+ float: 重试延迟时间(秒)
43
+ """
44
+ pass
45
+
46
+ @abstractmethod
47
+ def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
48
+ retry_count: int, max_retries: int, delay: float) -> None:
49
+ """
50
+ 记录重试尝试
51
+
52
+ Args:
53
+ logger: 日志记录器
54
+ exception: 捕获的异常
55
+ retry_count: 当前重试次数
56
+ max_retries: 最大重试次数
57
+ delay: 重试延迟时间
58
+ """
59
+ pass
60
+
61
+ @abstractmethod
62
+ def on_retry(self, exception: Exception, retry_count: int) -> None:
63
+ """
64
+ 重试前的回调函数,可以执行额外操作
65
+
66
+ Args:
67
+ exception: 捕获的异常
68
+ retry_count: 当前重试次数
69
+ """
70
+ pass
71
+
72
+
73
+ class ExponentialBackoffStrategy(RetryStrategy):
74
+ """指数退避重试策略,适用于连接错误"""
75
+
76
+ def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
77
+ return (isinstance(exception, requests.exceptions.ConnectionError) and
78
+ retry_count < max_retries)
79
+
80
+ def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
81
+ # 指数退避: base_delay * 2^(retry_count)
82
+ return base_delay * (2 ** retry_count)
83
+
84
+ def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
85
+ retry_count: int, max_retries: int, delay: float) -> None:
86
+ # 检查logger是否为函数对象(如client._log)
87
+ if callable(logger) and not isinstance(logger, logging.Logger):
88
+ # 如果是函数,直接调用它
89
+ logger(f"连接错误,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}", "WARNING")
90
+ else:
91
+ # 如果是Logger对象,调用warning方法
92
+ logger.warning(f"连接错误,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}")
93
+
94
+ def on_retry(self, exception: Exception, retry_count: int) -> None:
95
+ # 连接错误不需要额外操作
96
+ pass
97
+
98
+
99
+ class LinearBackoffStrategy(RetryStrategy):
100
+ """线性退避重试策略,适用于超时错误"""
101
+
102
+ def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
103
+ return (isinstance(exception, requests.exceptions.Timeout) and
104
+ retry_count < max_retries)
105
+
106
+ def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
107
+ # 线性退避: base_delay * retry_count
108
+ return base_delay * retry_count
109
+
110
+ def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
111
+ retry_count: int, max_retries: int, delay: float) -> None:
112
+ # 检查logger是否为函数对象(如client._log)
113
+ if callable(logger) and not isinstance(logger, logging.Logger):
114
+ # 如果是函数,直接调用它
115
+ logger(f"请求超时,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}", "WARNING")
116
+ else:
117
+ # 如果是Logger对象,调用warning方法
118
+ logger.warning(f"请求超时,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}")
119
+
120
+ def on_retry(self, exception: Exception, retry_count: int) -> None:
121
+ # 超时错误不需要额外操作
122
+ pass
123
+
124
+
125
+ class ServerErrorStrategy(RetryStrategy):
126
+ """服务器错误重试策略,适用于5xx错误"""
127
+
128
+ def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
129
+ if not isinstance(exception, requests.exceptions.HTTPError):
130
+ return False
131
+
132
+ response = getattr(exception, 'response', None)
133
+ if response is None:
134
+ return False
135
+
136
+ return (500 <= response.status_code < 600 and retry_count < max_retries)
137
+
138
+ def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
139
+ # 线性退避: base_delay * retry_count
140
+ return base_delay * retry_count
141
+
142
+ def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
143
+ retry_count: int, max_retries: int, delay: float) -> None:
144
+ response = getattr(exception, 'response', None)
145
+ status_code = response.status_code if response else 'unknown'
146
+ # 检查logger是否为函数对象(如client._log)
147
+ if callable(logger) and not isinstance(logger, logging.Logger):
148
+ # 如果是函数,直接调用它
149
+ logger(f"服务器错误 {status_code},{delay:.1f}秒后重试 ({retry_count}/{max_retries})", "WARNING")
150
+ else:
151
+ # 如果是Logger对象,调用warning方法
152
+ logger.warning(f"服务器错误 {status_code},{delay:.1f}秒后重试 ({retry_count}/{max_retries})")
153
+
154
+ def on_retry(self, exception: Exception, retry_count: int) -> None:
155
+ # 服务器错误不需要额外操作
156
+ pass
157
+
158
+
159
+ class RateLimitStrategy(RetryStrategy):
160
+ """速率限制重试策略,适用于429错误,包括账号切换逻辑和延迟重试"""
161
+
162
+ def __init__(self, client=None):
163
+ """
164
+ 初始化速率限制重试策略
165
+
166
+ Args:
167
+ client: API客户端实例,用于切换账号
168
+ """
169
+ self.client = client
170
+ self.consecutive_429_count = 0 # 连续429错误计数器
171
+
172
+ def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
173
+ if not isinstance(exception, requests.exceptions.HTTPError):
174
+ return False
175
+
176
+ response = getattr(exception, 'response', None)
177
+ if response is None:
178
+ return False
179
+
180
+ is_rate_limit = response.status_code == 429
181
+ if is_rate_limit:
182
+ self.consecutive_429_count += 1
183
+ else:
184
+ self.consecutive_429_count = 0 # 重置计数器
185
+
186
+ return is_rate_limit
187
+
188
+ def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
189
+ # 根据用户反馈,429错误时不需要延迟,立即重试
190
+ return 0
191
+
192
+ def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
193
+ retry_count: int, max_retries: int, delay: float) -> None:
194
+ # 检查logger是否为函数对象(如client._log)
195
+ message = ""
196
+ if self.consecutive_429_count > 1:
197
+ message = f"连续第{self.consecutive_429_count}次速率限制错误,尝试立即重试"
198
+ else:
199
+ message = "速率限制错误,尝试切换账号"
200
+
201
+ if callable(logger) and not isinstance(logger, logging.Logger):
202
+ # 如果是函数,直接调用它
203
+ logger(message, "WARNING")
204
+ else:
205
+ # 如果是Logger对象,调用warning方法
206
+ logger.warning(message)
207
+
208
+ def on_retry(self, exception: Exception, retry_count: int) -> None:
209
+ # 新增: 获取关联信息
210
+ user_identifier = getattr(self.client, '_associated_user_identifier', None)
211
+ request_ip = getattr(self.client, '_associated_request_ip', None) # request_ip 可能在某些情况下需要
212
+
213
+ # 只有在首次429错误或账号池中有多个账号时才切换账号
214
+ if self.consecutive_429_count == 1 or (self.consecutive_429_count > 0 and self.consecutive_429_count % 3 == 0):
215
+ if self.client and hasattr(self.client, 'email'):
216
+ # 记录当前账号进入冷却期
217
+ current_email = self.client.email # 这是切换前的 email
218
+ config.set_account_cooldown(current_email)
219
+
220
+ # 获取新账号
221
+ new_email, new_password = config.get_next_ondemand_account_details()
222
+ if new_email:
223
+ # 更新客户端信息
224
+ self.client.email = new_email # 这是切换后的 email
225
+ self.client.password = new_password
226
+ self.client.token = ""
227
+ self.client.refresh_token = ""
228
+ self.client.session_id = "" # 重置会话ID,确保创建新会话
229
+
230
+ # 尝试使用新账号登录并创建会话
231
+ try:
232
+ # 获取当前请求的上下文哈希,以便在切换账号后重新登录和创建会话时使用
233
+ current_context_hash = getattr(self.client, '_current_request_context_hash', None)
234
+
235
+ self.client.sign_in(context=current_context_hash)
236
+ if self.client.create_session(external_context=current_context_hash):
237
+ # 如果成功登录并创建会话,记录日志并设置标志位
238
+ if hasattr(self.client, '_log'):
239
+ self.client._log(f"成功切换到账号 {new_email} 并使用上下文哈希 '{current_context_hash}' 重新登录和创建新会话。", "INFO")
240
+ # 设置标志位,通知调用方下次需要发送完整历史
241
+ setattr(self.client, '_new_session_requires_full_history', True)
242
+ if hasattr(self.client, '_log'):
243
+ self.client._log(f"已设置 _new_session_requires_full_history = True,下次查询应发送完整历史。", "INFO")
244
+ else:
245
+ # 会话创建失败,记录错误
246
+ if hasattr(self.client, '_log'):
247
+ self.client._log(f"切换到账号 {new_email} 后,创建新会话失败。", "WARNING")
248
+ # 确保在这种情况下不设置需要完整历史的标志,因为会话本身就没成功
249
+ setattr(self.client, '_new_session_requires_full_history', False)
250
+
251
+
252
+ # --- 新增: 更新 client_sessions ---
253
+ if not user_identifier:
254
+ if hasattr(self.client, '_log'):
255
+ self.client._log("RateLimitStrategy: _associated_user_identifier not found on client. Cannot update client_sessions.", "ERROR")
256
+ # 即使没有 user_identifier,账号切换和会话创建也已发生,只是无法更新全局会话池
257
+ else:
258
+ old_email_in_strategy = current_email # 切换前的 email
259
+ new_email_in_strategy = self.client.email # 切换后的 email (即 new_email)
260
+
261
+ with config.config_instance.client_sessions_lock:
262
+ if user_identifier in config.config_instance.client_sessions:
263
+ user_specific_sessions = config.config_instance.client_sessions[user_identifier]
264
+
265
+ # 1. 移除旧 email 的条目 (如果存在)
266
+ # 我们只移除那些 client 实例确实是当前 self.client 的条目,
267
+ # 或者更简单地,如果旧 email 存在,就移除它,因为 user_identifier
268
+ # 现在应该通过 new_email 使用这个(已被修改的)client 实例。
269
+ if old_email_in_strategy in user_specific_sessions:
270
+ # 检查 client 实例是否匹配可能不可靠,因为 client 内部状态已变。
271
+ # 直接删除旧 email 的条目,因为这个 user_identifier + client 组合现在用新 email。
272
+ del user_specific_sessions[old_email_in_strategy]
273
+ if hasattr(self.client, '_log'):
274
+ self.client._log(f"RateLimitStrategy: Removed session for old email '{old_email_in_strategy}' for user '{user_identifier}'.", "INFO")
275
+
276
+ # 2. 添加/更新新 email 的条目
277
+ # 确保它指向当前这个已被修改的 self.client 实例
278
+ # 并重置 active_context_hash。
279
+ # IP 地址应来自 self.client._associated_request_ip 或 routes.py 中设置的值。
280
+ # 由于 routes.py 在创建/分配会话时已将 IP 存入 client_sessions,
281
+ # 这里我们主要关注 client 实例和 active_context_hash。
282
+ # 如果 request_ip 在 self.client 中可用,则使用它,否则尝试保留已有的。
283
+ ip_to_use = request_ip if request_ip else user_specific_sessions.get(new_email_in_strategy, {}).get("ip", "unknown_ip_in_retry_update")
284
+
285
+ # 需要导入 datetime
286
+ from datetime import datetime
287
+
288
+ # 从 client 实例获取原始请求的上下文哈希
289
+ # 这个哈希应该由 routes.py 在调用 send_query 之前设置到 client 实例上
290
+ active_hash_for_new_session = getattr(self.client, '_current_request_context_hash', None)
291
+
292
+ user_specific_sessions[new_email_in_strategy] = {
293
+ "client": self.client, # 关键: 指向当前更新了 email/session_id 的 client 实例
294
+ "active_context_hash": active_hash_for_new_session, # 使用来自 client 实例的哈希
295
+ "last_time": datetime.now(), # 更新时间
296
+ "ip": ip_to_use
297
+ }
298
+ log_message_hash_part = f"set to '{active_hash_for_new_session}' (from client instance's _current_request_context_hash)" if active_hash_for_new_session is not None else "set to None (_current_request_context_hash not found on client instance)"
299
+ if hasattr(self.client, '_log'):
300
+ self.client._log(f"RateLimitStrategy: Updated/added session for new email '{new_email_in_strategy}' for user '{user_identifier}'. active_context_hash {log_message_hash_part}.", "INFO")
301
+ else:
302
+ if hasattr(self.client, '_log'):
303
+ self.client._log(f"RateLimitStrategy: User '{user_identifier}' not found in client_sessions during update attempt.", "WARNING")
304
+ # --- 更新 client_sessions 结束 ---
305
+
306
+ except Exception as e:
307
+ # 登录或创建会话失败,记录错误但不抛出异常
308
+ # 让后续的重试机制处理
309
+ if hasattr(self.client, '_log'):
310
+ self.client._log(f"切换到账号 {new_email} 后登录或创建会话失败: {e}", "WARNING")
311
+ # 此处不应更新 client_sessions,因为新账号的会话未成功建立
312
+
313
+
314
+ class RetryHandler:
315
+ """重试处理器,管理多个重试策略"""
316
+
317
+ def __init__(self, client=None, logger=None):
318
+ """
319
+ 初始化重试处理器
320
+
321
+ Args:
322
+ client: API客户端实例,用于切换账号
323
+ logger: 日志记录器或日志函数
324
+ """
325
+ self.client = client
326
+ # 如果logger是None,使用默认logger
327
+ # 如果logger是函数或Logger对象,直接使用
328
+ self.logger = logger or logging.getLogger(__name__)
329
+ self.strategies = [
330
+ ExponentialBackoffStrategy(),
331
+ LinearBackoffStrategy(),
332
+ ServerErrorStrategy(),
333
+ RateLimitStrategy(client)
334
+ ]
335
+
336
+ def retry_operation(self, operation: Callable[..., T], *args, **kwargs) -> T:
337
+ """
338
+ 使用重试策略执行操作
339
+
340
+ Args:
341
+ operation: 要执行的操作
342
+ *args: 操作的位置参数
343
+ **kwargs: 操作的关键字参数
344
+
345
+ Returns:
346
+ 操作的结果
347
+
348
+ Raises:
349
+ Exception: 如果所有重试都失败,则抛出最后一个异常
350
+ """
351
+ max_retries = config.get_config_value('max_retries')
352
+ base_delay = config.get_config_value('retry_delay')
353
+ retry_count = 0
354
+ last_exception = None
355
+
356
+ while True:
357
+ try:
358
+ return operation(*args, **kwargs)
359
+ except Exception as e:
360
+ last_exception = e
361
+
362
+ # 查找适用的重试策略
363
+ strategy = next((s for s in self.strategies if s.should_retry(e, retry_count, max_retries)), None)
364
+
365
+ if strategy:
366
+ retry_count += 1
367
+ delay = strategy.get_retry_delay(retry_count, base_delay)
368
+ strategy.log_retry_attempt(self.logger, e, retry_count, max_retries, delay)
369
+ strategy.on_retry(e, retry_count)
370
+
371
+ if delay > 0:
372
+ time.sleep(delay)
373
+ else:
374
+ # 没有适用的重试策略,或者已达到最大重试次数
375
+ raise
376
+
377
+
378
+ def with_retry(max_retries: Optional[int] = None, retry_delay: Optional[int] = None):
379
+ """
380
+ 重试装饰器,用于装饰需要重试的方法
381
+
382
+ Args:
383
+ max_retries: 最大重试次数,如果为None则使用配置值
384
+ retry_delay: 基础重试延迟,如果为None则使用配置值
385
+
386
+ Returns:
387
+ 装饰后的函数
388
+ """
389
+ def decorator(func):
390
+ @functools.wraps(func)
391
+ def wrapper(self, *args, **kwargs):
392
+ # 获取配置值
393
+ _max_retries = max_retries or config.get_config_value('max_retries')
394
+ _retry_delay = retry_delay or config.get_config_value('retry_delay')
395
+
396
+ # 创建重试处理器
397
+ handler = RetryHandler(client=self, logger=getattr(self, '_log', None))
398
+
399
+ # 定义要重试的操作
400
+ def operation():
401
+ return func(self, *args, **kwargs)
402
+
403
+ # 执行操作并处理重试
404
+ return handler.retry_operation(operation)
405
+
406
+ return wrapper
407
+
408
+ return decorator