dfa32412 commited on
Commit
f7d13af
·
verified ·
1 Parent(s): 1f45439

Upload 2 files

Browse files
Files changed (2) hide show
  1. augment2api_server.py +749 -0
  2. requirements.txt +5 -0
augment2api_server.py ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OpenAI to Augment API Adapter
4
+
5
+ 这个FastAPI应用程序将OpenAI API请求格式转换为Augment API格式,
6
+ 允许OpenAI客户端直接与Augment服务通信。
7
+ 所有配置参数都通过命令行参数提供,不依赖于环境变量或配置文件。
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import uuid
13
+ import time
14
+ import logging
15
+ import argparse
16
+ from typing import List, Optional, Dict, Any, Literal, Union
17
+ from datetime import datetime
18
+
19
+ import httpx
20
+ from fastapi import FastAPI, Header, HTTPException, Depends, Request
21
+ from fastapi.responses import StreamingResponse, JSONResponse
22
+ from fastapi.middleware.cors import CORSMiddleware
23
+ from pydantic import BaseModel, Field
24
+ import uvicorn
25
+
26
+ # 配置日志
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
30
+ )
31
+ logger = logging.getLogger(__name__)
32
+
33
+ #################################################
34
+ # 模型定义
35
+ #################################################
36
+
37
+ # OpenAI API 请求模型
38
+ class ChatMessage(BaseModel):
39
+ """表示OpenAI聊天API中的单条消息"""
40
+ role: Literal["system", "user", "assistant", "function"]
41
+ content: Optional[str] = None
42
+ name: Optional[str] = None
43
+
44
+ class ChatCompletionRequest(BaseModel):
45
+ """OpenAI聊天完成API请求模型"""
46
+ model: str
47
+ messages: List[ChatMessage]
48
+ temperature: Optional[float] = 1.0
49
+ top_p: Optional[float] = 1.0
50
+ n: Optional[int] = 1
51
+ stream: Optional[bool] = False
52
+ max_tokens: Optional[int] = None
53
+ presence_penalty: Optional[float] = 0
54
+ frequency_penalty: Optional[float] = 0
55
+ user: Optional[str] = None
56
+
57
+ # OpenAI API 响应模型
58
+ class ChatCompletionResponseChoice(BaseModel):
59
+ """OpenAI聊天完成API响应中的单个选择"""
60
+ index: int
61
+ message: ChatMessage
62
+ finish_reason: Optional[str] = None
63
+
64
+ class Usage(BaseModel):
65
+ """OpenAI API响应中的token使用信息"""
66
+ prompt_tokens: int
67
+ completion_tokens: int
68
+ total_tokens: int
69
+
70
+ class ChatCompletionResponse(BaseModel):
71
+ """OpenAI聊天完成API响应模型"""
72
+ id: str
73
+ object: str = "chat.completion"
74
+ created: int
75
+ model: str
76
+ choices: List[ChatCompletionResponseChoice]
77
+ usage: Usage
78
+
79
+ # OpenAI API 流式响应模型
80
+ class ChatCompletionStreamResponseChoice(BaseModel):
81
+ """OpenAI聊天完成流式API响应中的单个选择"""
82
+ index: int
83
+ delta: Dict[str, Any]
84
+ finish_reason: Optional[str] = None
85
+
86
+ class ChatCompletionStreamResponse(BaseModel):
87
+ """OpenAI聊天完成流式API响应模型"""
88
+ id: str
89
+ object: str = "chat.completion.chunk"
90
+ created: int
91
+ model: str
92
+ choices: List[ChatCompletionStreamResponseChoice]
93
+
94
+ # 模型信息响应
95
+ class ModelInfo(BaseModel):
96
+ """OpenAI模型信息"""
97
+ id: str
98
+ object: str = "model"
99
+ created: int
100
+ owned_by: str = "augment"
101
+
102
+ class ModelListResponse(BaseModel):
103
+ """OpenAI模型列表响应"""
104
+ object: str = "list"
105
+ data: List[ModelInfo]
106
+
107
+ # Augment API 请求相关模型
108
+ class AugmentResponseNode(BaseModel):
109
+ """Augment API响应节点"""
110
+ id: int
111
+ type: int
112
+ content: str
113
+ tool_use: Optional[Any] = None
114
+
115
+ class AugmentChatHistoryItem(BaseModel):
116
+ """Augment API聊天历史记录条目"""
117
+ request_message: str
118
+ response_text: str
119
+ request_id: Optional[str] = None
120
+ request_nodes: List[Any] = []
121
+ response_nodes: List[AugmentResponseNode] = []
122
+
123
+ class AugmentBlobs(BaseModel):
124
+ """Augment API Blobs对象"""
125
+ checkpoint_id: Optional[str] = None
126
+ added_blobs: List[Any] = []
127
+ deleted_blobs: List[Any] = []
128
+
129
+ class AugmentVcsChange(BaseModel):
130
+ """Augment API VCS更改"""
131
+ working_directory_changes: List[Any] = []
132
+
133
+ class AugmentFeatureFlags(BaseModel):
134
+ """Augment API功能标志"""
135
+ support_raw_output: bool = True
136
+
137
+ # 完整的Augment API请求模型
138
+ class AugmentChatRequest(BaseModel):
139
+ """Augment API聊天请求模型 - 基于抓包分析更新"""
140
+ model: Optional[str] = None
141
+ path: Optional[str] = None
142
+ prefix: Optional[str] = None
143
+ selected_code: Optional[str] = None
144
+ suffix: Optional[str] = None
145
+ message: str
146
+ chat_history: List[AugmentChatHistoryItem] = []
147
+ lang: Optional[str] = None
148
+ blobs: AugmentBlobs = AugmentBlobs()
149
+ user_guided_blobs: List[Any] = []
150
+ context_code_exchange_request_id: Optional[str] = None
151
+ vcs_change: AugmentVcsChange = AugmentVcsChange()
152
+ recency_info_recent_changes: List[Any] = []
153
+ external_source_ids: List[Any] = []
154
+ disable_auto_external_sources: Optional[bool] = None
155
+ user_guidelines: str = ""
156
+ workspace_guidelines: str = ""
157
+ feature_detection_flags: AugmentFeatureFlags = AugmentFeatureFlags()
158
+ tool_definitions: List[Any] = []
159
+ nodes: List[Any] = []
160
+ mode: str = "CHAT"
161
+ agent_memories: Optional[Any] = None
162
+ system_prompt: Optional[str] = None # 保留此字段以兼容之前的代码
163
+
164
+ # Augment API响应模型
165
+ class AugmentResponseChunk(BaseModel):
166
+ """Augment API响应块"""
167
+ text: str
168
+ unknown_blob_names: List[Any] = []
169
+ checkpoint_not_found: bool = False
170
+ workspace_file_chunks: List[Any] = []
171
+ incorporated_external_sources: List[Any] = []
172
+ nodes: List[AugmentResponseNode] = []
173
+
174
+ #################################################
175
+ # 辅助函数
176
+ #################################################
177
+
178
+ def generate_id():
179
+ """生成唯一ID,类似于OpenAI的格式"""
180
+ return str(uuid.uuid4()).replace("-", "")[:24]
181
+
182
+ def estimate_tokens(text):
183
+ """
184
+ 估计文本的token数量
185
+ 这是一个简单的估算,实际数量可能有所不同
186
+ """
187
+ if not text:
188
+ return 0
189
+ # 简单估算:假设每个单词约等于1.3个token
190
+ # 中文字符每个字约等于1个token
191
+ words = len(text.split()) if text else 0
192
+ chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff') if text else 0
193
+ return int(words * 1.3 + chinese_chars)
194
+
195
+ def convert_to_augment_request(openai_request: ChatCompletionRequest) -> AugmentChatRequest:
196
+ """
197
+ 将OpenAI API请求转换为Augment API请求
198
+
199
+ Args:
200
+ openai_request: OpenAI API请求对象
201
+
202
+ Returns:
203
+ 转换后的Augment API请求对象
204
+
205
+ Raises:
206
+ HTTPException: 如果请求格式无效
207
+ """
208
+ chat_history = []
209
+ system_message = None
210
+
211
+ # 处理消息历史记录
212
+ for i in range(len(openai_request.messages) - 1):
213
+ msg = openai_request.messages[i]
214
+ if msg.role == "system":
215
+ system_message = msg.content
216
+ elif msg.role == "user" and i + 1 < len(openai_request.messages) and openai_request.messages[i + 1].role == "assistant":
217
+ user_msg = msg.content
218
+ assistant_msg = openai_request.messages[i + 1].content
219
+
220
+ # 创建历史记录条目,格式符合Augment API
221
+ history_item = AugmentChatHistoryItem(
222
+ request_message=user_msg,
223
+ response_text=assistant_msg,
224
+ request_id=generate_id(),
225
+ response_nodes=[
226
+ AugmentResponseNode(
227
+ id=0,
228
+ type=0,
229
+ content=assistant_msg,
230
+ tool_use=None
231
+ )
232
+ ]
233
+ )
234
+ chat_history.append(history_item)
235
+
236
+ # 获取当前用户消息
237
+ current_message = None
238
+ for msg in reversed(openai_request.messages):
239
+ if msg.role == "user":
240
+ current_message = msg.content
241
+ break
242
+
243
+ # 如果没有用户消息,则返回错误
244
+ if current_message is None:
245
+ raise HTTPException(
246
+ status_code=400,
247
+ detail="At least one user message is required"
248
+ )
249
+
250
+ # 准备Augment请求体
251
+ augment_request = AugmentChatRequest(
252
+ message=current_message,
253
+ chat_history=chat_history,
254
+ mode="CHAT"
255
+ )
256
+
257
+ # 如果有系统消息,设置为用户指南
258
+ if system_message:
259
+ augment_request.user_guidelines = system_message
260
+
261
+ return augment_request
262
+
263
+ #################################################
264
+ # FastAPI应用
265
+ #################################################
266
+
267
+ def create_app(augment_base_url, chat_endpoint, timeout):
268
+ """
269
+ 创建并配置FastAPI应用
270
+
271
+ Args:
272
+ augment_base_url: Augment API基础URL
273
+ chat_endpoint: 聊天端点路径
274
+ timeout: 请求超时时间
275
+
276
+ Returns:
277
+ 配置好的FastAPI应用
278
+ """
279
+ app = FastAPI(
280
+ title="OpenAI to Augment API Adapter",
281
+ description="A FastAPI adapter that converts OpenAI API requests to Augment API format",
282
+ version="1.0.0"
283
+ )
284
+
285
+ # 添加CORS中间件
286
+ app.add_middleware(
287
+ CORSMiddleware,
288
+ allow_origins=["*"],
289
+ allow_credentials=True,
290
+ allow_methods=["*"],
291
+ allow_headers=["*"],
292
+ )
293
+
294
+ #################################################
295
+ # 中间件和依赖项
296
+ #################################################
297
+
298
+ @app.middleware("http")
299
+ async def catch_exceptions_middleware(request: Request, call_next):
300
+ """捕获所有未处理的异常,返回适当的错误响应"""
301
+ try:
302
+ return await call_next(request)
303
+ except Exception as e:
304
+ logger.exception("Unhandled exception")
305
+ return JSONResponse(
306
+ status_code=500,
307
+ content={
308
+ "error": {
309
+ "message": str(e),
310
+ "type": "internal_server_error",
311
+ "param": None,
312
+ "code": "internal_server_error"
313
+ }
314
+ }
315
+ )
316
+
317
+ async def verify_api_key(authorization: str = Header(...)):
318
+ """
319
+ 验证API密钥
320
+
321
+ Args:
322
+ authorization: Authorization头部值
323
+
324
+ Returns:
325
+ 提取的API密钥
326
+
327
+ Raises:
328
+ HTTPException: 如果API密钥格式无效或为空
329
+ """
330
+ if not authorization.startswith("Bearer "):
331
+ raise HTTPException(
332
+ status_code=401,
333
+ detail={
334
+ "error": {
335
+ "message": "Invalid API key format. Expected 'Bearer YOUR_API_KEY'",
336
+ "type": "invalid_request_error",
337
+ "param": "authorization",
338
+ "code": "invalid_api_key"
339
+ }
340
+ }
341
+ )
342
+ api_key = authorization.replace("Bearer ", "")
343
+ if not api_key:
344
+ raise HTTPException(
345
+ status_code=401,
346
+ detail={
347
+ "error": {
348
+ "message": "API key cannot be empty",
349
+ "type": "invalid_request_error",
350
+ "param": "authorization",
351
+ "code": "invalid_api_key"
352
+ }
353
+ }
354
+ )
355
+ return api_key
356
+
357
+ #################################################
358
+ # API端点
359
+ #################################################
360
+
361
+ @app.get("/health")
362
+ async def health_check():
363
+ """健康检查端点"""
364
+ return {"status": "ok", "timestamp": datetime.now().isoformat()}
365
+
366
+ @app.get("/v1/models")
367
+ async def list_models():
368
+ """列出支持的模型"""
369
+ # 返回一个虚拟的模型列表
370
+ models = [
371
+ ModelInfo(id="gpt-3.5-turbo", created=int(time.time())),
372
+ ModelInfo(id="gpt-4", created=int(time.time())),
373
+ ModelInfo(id="augment-default", created=int(time.time())),
374
+ ]
375
+ return ModelListResponse(data=models)
376
+
377
+ @app.get("/v1/models/{model_id}")
378
+ async def get_model(model_id: str):
379
+ """获取特定模型的信息"""
380
+ return ModelInfo(id=model_id, created=int(time.time()))
381
+
382
+ @app.post("/v1/chat/completions")
383
+ async def chat_completions(
384
+ request: ChatCompletionRequest,
385
+ api_key: str = Depends(verify_api_key)
386
+ ):
387
+ """
388
+ 聊天完成端点 - 将OpenAI API请求转换为Augment API请求
389
+
390
+ Args:
391
+ request: OpenAI格式的聊天完成请求
392
+ api_key: 通过验证的API密钥
393
+
394
+ Returns:
395
+ OpenAI格式的聊天完成响应或流式响应
396
+ """
397
+ try:
398
+ # 转换为Augment请求格式
399
+ augment_request = convert_to_augment_request(request)
400
+ logger.debug(f"Converted request: {augment_request.dict()}")
401
+
402
+
403
+ if ":" in api_key:
404
+ tenant_id, api_key = api_key.split(":")
405
+ augment_base_url = f"https://{tenant_id}.api.augmentcode.com/"
406
+
407
+ # 决定是否使用流式响应
408
+ if request.stream:
409
+ return StreamingResponse(
410
+ stream_augment_response(augment_base_url, api_key, augment_request, request.model, chat_endpoint, timeout),
411
+ media_type="text/event-stream"
412
+ )
413
+ else:
414
+ # 同步请求处理
415
+ return await handle_sync_request(augment_base_url, api_key, augment_request, request.model, chat_endpoint, timeout)
416
+
417
+ except httpx.TimeoutException:
418
+ logger.error("Request to Augment API timed out")
419
+ raise HTTPException(
420
+ status_code=504,
421
+ detail={
422
+ "error": {
423
+ "message": "Request to Augment API timed out",
424
+ "type": "timeout_error",
425
+ "param": None,
426
+ "code": "timeout"
427
+ }
428
+ }
429
+ )
430
+ except httpx.HTTPError as e:
431
+ logger.error(f"HTTP error: {str(e)}")
432
+ raise HTTPException(
433
+ status_code=502,
434
+ detail={
435
+ "error": {
436
+ "message": f"Error communicating with Augment API: {str(e)}",
437
+ "type": "api_error",
438
+ "param": None,
439
+ "code": "api_error"
440
+ }
441
+ }
442
+ )
443
+ except HTTPException:
444
+ # 重新抛出HTTPException,以保持原始状态码和详细信息
445
+ raise
446
+ except Exception as e:
447
+ logger.exception("Unexpected error")
448
+ raise HTTPException(
449
+ status_code=500,
450
+ detail={
451
+ "error": {
452
+ "message": f"Internal server error: {str(e)}",
453
+ "type": "internal_server_error",
454
+ "param": None,
455
+ "code": "internal_server_error"
456
+ }
457
+ }
458
+ )
459
+
460
+ return app
461
+
462
+ async def handle_sync_request(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
463
+ """
464
+ 处理同步请求
465
+
466
+ Args:
467
+ base_url: Augment API基础URL
468
+ api_key: API密钥
469
+ augment_request: Augment API请求对象
470
+ model_name: 模型名称
471
+ chat_endpoint: 聊天端点
472
+ timeout: 请求超时时间
473
+
474
+ Returns:
475
+ OpenAI格式的聊天完成响应
476
+ """
477
+ async with httpx.AsyncClient(timeout=timeout) as client:
478
+ response = await client.post(
479
+ f"{base_url.rstrip('/')}/{chat_endpoint}",
480
+ json=augment_request.dict(),
481
+ headers={
482
+ "Content-Type": "application/json",
483
+ "Authorization": f"Bearer {api_key}",
484
+ "User-Agent": "Augment.openai-adapter/1.0.0",
485
+ "Accept": "*/*"
486
+ }
487
+ )
488
+
489
+ if response.status_code != 200:
490
+ logger.error(f"Augment API error: {response.status_code} - {response.text}")
491
+ raise HTTPException(
492
+ status_code=response.status_code,
493
+ detail={
494
+ "error": {
495
+ "message": f"Augment API error: {response.text}",
496
+ "type": "api_error",
497
+ "param": None,
498
+ "code": "api_error"
499
+ }
500
+ }
501
+ )
502
+
503
+ # 处理流式响应,合并为完整响应
504
+ full_response = ""
505
+ for line in response.text.split("\n"):
506
+ if line.strip():
507
+ try:
508
+ data = json.loads(line)
509
+ if "text" in data and data["text"]:
510
+ full_response += data["text"]
511
+ except json.JSONDecodeError:
512
+ logger.warning(f"Failed to parse JSON: {line}")
513
+
514
+ # 估算token使用情况
515
+ prompt_tokens = estimate_tokens(augment_request.message)
516
+ completion_tokens = estimate_tokens(full_response)
517
+
518
+ # 构建OpenAI格式响应
519
+ return ChatCompletionResponse(
520
+ id=f"chatcmpl-{generate_id()}",
521
+ created=int(time.time()),
522
+ model=model_name,
523
+ choices=[
524
+ ChatCompletionResponseChoice(
525
+ index=0,
526
+ message=ChatMessage(
527
+ role="assistant",
528
+ content=full_response
529
+ ),
530
+ finish_reason="stop"
531
+ )
532
+ ],
533
+ usage=Usage(
534
+ prompt_tokens=prompt_tokens,
535
+ completion_tokens=completion_tokens,
536
+ total_tokens=prompt_tokens + completion_tokens
537
+ )
538
+ )
539
+
540
+ async def stream_augment_response(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
541
+ """
542
+ 处理流式响应
543
+
544
+ Args:
545
+ base_url: Augment API基础URL
546
+ api_key: API密钥
547
+ augment_request: Augment API请求对象
548
+ model_name: 模型名称
549
+ chat_endpoint: 聊天端点
550
+ timeout: 请求超时时间
551
+
552
+ Yields:
553
+ 流式响应的数据块
554
+ """
555
+ async with httpx.AsyncClient(timeout=timeout) as client:
556
+ try:
557
+ async with client.stream(
558
+ "POST",
559
+ f"{base_url.rstrip('/')}/{chat_endpoint}",
560
+ json=augment_request.dict(),
561
+ headers={
562
+ "Content-Type": "application/json",
563
+ "Authorization": f"Bearer {api_key}",
564
+ "User-Agent": "chrome",
565
+ "Accept": "*/*"
566
+ }
567
+ ) as response:
568
+
569
+ if response.status_code != 200:
570
+ error_detail = await response.aread()
571
+ logger.error(f"Augment API error: {response.status_code} - {error_detail}")
572
+ error_message = f"Error from Augment API: {error_detail.decode('utf-8', errors='replace')}"
573
+ yield f"data: {json.dumps({'error': error_message})}\n\n"
574
+ return
575
+
576
+ # 生成唯一ID
577
+ chat_id = f"chatcmpl-{generate_id()}"
578
+ created_time = int(time.time())
579
+
580
+ # 初始化响应
581
+ init_response = ChatCompletionStreamResponse(
582
+ id=chat_id,
583
+ created=created_time,
584
+ model=model_name,
585
+ choices=[
586
+ ChatCompletionStreamResponseChoice(
587
+ index=0,
588
+ delta={"role": "assistant"},
589
+ finish_reason=None
590
+ )
591
+ ]
592
+ )
593
+ init_data = json.dumps(init_response.dict())
594
+ yield f"data: {init_data}\n\n"
595
+
596
+ # 处理流式响应
597
+ buffer = ""
598
+ async for line in response.aiter_lines():
599
+ if not line.strip():
600
+ continue
601
+
602
+ try:
603
+ # 解析Augment响应格式
604
+ chunk = json.loads(line)
605
+ if "text" in chunk and chunk["text"]:
606
+ content = chunk["text"]
607
+
608
+ # 发送增量更新
609
+ stream_response = ChatCompletionStreamResponse(
610
+ id=chat_id,
611
+ created=created_time,
612
+ model=model_name,
613
+ choices=[
614
+ ChatCompletionStreamResponseChoice(
615
+ index=0,
616
+ delta={"content": content},
617
+ finish_reason=None
618
+ )
619
+ ]
620
+ )
621
+ response_data = json.dumps(stream_response.dict())
622
+ yield f"data: {response_data}\n\n"
623
+ except json.JSONDecodeError:
624
+ logger.warning(f"Failed to parse JSON: {line}")
625
+
626
+ # 发送完成信号
627
+ final_response = ChatCompletionStreamResponse(
628
+ id=chat_id,
629
+ created=created_time,
630
+ model=model_name,
631
+ choices=[
632
+ ChatCompletionStreamResponseChoice(
633
+ index=0,
634
+ delta={},
635
+ finish_reason="stop"
636
+ )
637
+ ]
638
+ )
639
+ final_data = json.dumps(final_response.dict())
640
+ yield f"data: {final_data}\n\n"
641
+
642
+ # 发送[DONE]标记
643
+ yield "data: [DONE]\n\n"
644
+
645
+ except httpx.TimeoutException:
646
+ logger.error("Request to Augment API timed out")
647
+ yield f"data: {json.dumps({'error': 'Request to Augment API timed out'})}\n\n"
648
+ except httpx.HTTPError as e:
649
+ logger.error(f"HTTP error: {str(e)}")
650
+ yield f"data: {json.dumps({'error': f'Error communicating with Augment API: {str(e)}'})}\n\n"
651
+ except Exception as e:
652
+ logger.exception("Unexpected error")
653
+ yield f"data: {json.dumps({'error': f'Internal server error: {str(e)}'})}\n\n"
654
+
655
+ def parse_args():
656
+ """解析命令行参数"""
657
+ parser = argparse.ArgumentParser(
658
+ description="OpenAI to Augment API Adapter",
659
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
660
+ )
661
+
662
+ parser.add_argument(
663
+ "--augment-url",
664
+ default="https://d6.api.augmentcode.com/",
665
+ help="Augment API基础URL"
666
+ )
667
+
668
+ parser.add_argument(
669
+ "--chat-endpoint",
670
+ default="chat-stream",
671
+ help="Augment聊天端点路径"
672
+ )
673
+
674
+ parser.add_argument(
675
+ "--host",
676
+ default="0.0.0.0",
677
+ help="服务器主机地址"
678
+ )
679
+
680
+ parser.add_argument(
681
+ "--port",
682
+ type=int,
683
+ default=8686,
684
+ help="服务器端口"
685
+ )
686
+
687
+ parser.add_argument(
688
+ "--timeout",
689
+ type=int,
690
+ default=120,
691
+ help="API请求超时时间(秒)"
692
+ )
693
+
694
+ parser.add_argument(
695
+ "--debug",
696
+ action="store_true",
697
+ help="启用调试模式"
698
+ )
699
+
700
+ parser.add_argument(
701
+ "--tenant-id",
702
+ default="d18",
703
+ help="Augment API租户ID (域名前缀)"
704
+ )
705
+
706
+ return parser.parse_args()
707
+
708
+ #################################################
709
+ # 主程序
710
+ #################################################
711
+
712
+ def main():
713
+ """主函数"""
714
+ args = parse_args()
715
+
716
+ # 配置日志级别
717
+ if args.debug:
718
+ logging.getLogger().setLevel(logging.DEBUG)
719
+
720
+ # 构建完整的Augment URL
721
+ if args.augment_url == "https://d18.api.augmentcode.com/":
722
+ # 如果使用默认URL,则应用tenant-id参数
723
+ augment_base_url = f"https://{args.tenant_id}.api.augmentcode.com/"
724
+ logger.info(f"Using tenant ID: {args.tenant_id}")
725
+ else:
726
+ # 否则使用提供的URL
727
+ augment_base_url = args.augment_url
728
+
729
+ # 创建应用
730
+ app = create_app(
731
+ augment_base_url=augment_base_url,
732
+ chat_endpoint=args.chat_endpoint,
733
+ timeout=args.timeout
734
+ )
735
+
736
+ # 启动应用
737
+ logger.info(f"Starting server on {args.host}:7860")
738
+ logger.info(f"Using Augment base URL: {augment_base_url}")
739
+ logger.info(f"Using Augment chat endpoint: {args.chat_endpoint}")
740
+
741
+ uvicorn.run(
742
+ app,
743
+ host=args.host,
744
+ port=7860,
745
+ log_level="info" if not args.debug else "debug"
746
+ )
747
+
748
+ if __name__ == "__main__":
749
+ main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi==0.115.12
2
+ httpx==0.28.1
3
+ pydantic==1.10.19
4
+ Requests==2.32.3
5
+ uvicorn==0.34.0