Commit 2692e0d
Parent(s): c62df12

Refactor chat handling and model integration: Update .env.example to include new API keys, modify main.py to implement a lifespan context manager for resource management, and replace Message class with dictionary structures in chat_request.py and chat_service.py for improved flexibility. Remove unused message and response models to streamline codebase.

Files changed:
- .env.example +2 -1
- src/main.py +33 -5
- src/models/others/message.py +0 -24
- src/models/requests/chat_request.py +2 -3
- src/models/responses/chat_response.py +0 -92
- src/models/responses/tool_call_response.py +0 -13
- src/routes/chat_routes.py +1 -3
- src/services/chat_service.py +25 -41
- src/utils/image_pipeline.py +21 -28
- src/utils/timing.py +2 -0
- src/utils/tools/tools_helper.py +12 -15
- src/utils/transformer_client.py +131 -70

.env.example CHANGED
@@ -1,3 +1,4 @@
 jina_api_key=
 brave_search_api_key=
-ai_url=
+ai_url=
+serp_api

src/main.py CHANGED
@@ -1,5 +1,7 @@
 import os
+from sre_parse import Tokenizer
 from fastapi import FastAPI, Request
+from fastapi.concurrency import asynccontextmanager
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
@@ -7,9 +9,26 @@ from fastapi.staticfiles import StaticFiles
 from constants.config import OUTPUT_DIR
 from models.responses.base_response import BaseResponse
 from routes import chat_routes, process_file_routes, vector_store_routes
+from utils import image_pipeline, transformer_client
 from utils.exception import CustomException
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # try:
+    #     transformer_client.load_model()
+    #     image_pipeline.load_pipeline()
+
+    # except Exception as e:
+    #     print(f"Error during startup: {str(e)}")
+
+    yield
+
+    transformer_client.clear_resources()
+    image_pipeline.clear_resources()
+
+
+app = FastAPI(lifespan=lifespan)
 
 origins = ["*"]
 app.add_middleware(
@@ -20,34 +39,43 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 @app.exception_handler(CustomException)
 async def custom_exception_handler(request: Request, exc: CustomException):
     return JSONResponse(
         status_code=exc.status_code,
-        content=BaseResponse(status_code=exc.status_code, message=exc.message).model_dump()
+        content=BaseResponse(
+            status_code=exc.status_code, message=exc.message
+        ).model_dump(),
     )
 
+
 @app.exception_handler(Exception)
 async def global_exception_handler(request: Request, exc: Exception):
     # Default handling for errors not handled by CustomException
     return JSONResponse(
         status_code=500,
-        content=BaseResponse(status_code=500, message=str(exc)).model_dump()
+        content=BaseResponse(status_code=500, message=str(exc)).model_dump(),
     )
 
+
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     return JSONResponse(
         status_code=422,
-        content=BaseResponse(status_code=422, message="Validation error").model_dump()
+        content=BaseResponse(status_code=422, message="Validation error").model_dump(),
     )
 
+
 app.include_router(chat_routes.router, prefix="/api/v1")
 app.include_router(process_file_routes.router, prefix="/api/v1")
 app.include_router(vector_store_routes.router, prefix="/api/v1")
+
+
 @app.get("/")
 def read_root():
     return {"message": "Welcome to my API"}
 
+
 os.makedirs(OUTPUT_DIR, exist_ok=True)
-app.mount(OUTPUT_DIR, StaticFiles(directory=OUTPUT_DIR), name="outputs")
+app.mount(OUTPUT_DIR, StaticFiles(directory=OUTPUT_DIR), name="outputs")
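
For readers unfamiliar with the pattern introduced in main.py above: a lifespan context manager runs its pre-yield code once at startup and its post-yield code once at shutdown. A minimal, self-contained sketch of the same idea follows; the load/clear helpers are hypothetical stand-ins for transformer_client and image_pipeline, and while the diff imports asynccontextmanager from fastapi.concurrency, the standard-library contextlib decorator used below serves the same purpose here.

    from contextlib import asynccontextmanager

    from fastapi import FastAPI


    def load_resources() -> dict:
        # Stand-in for transformer_client.load_model() / image_pipeline.load_pipeline()
        return {"model": "loaded"}


    def clear_resources(state: dict) -> None:
        # Stand-in for the clear_resources() calls made after yield
        state.clear()


    @asynccontextmanager
    async def lifespan(app: FastAPI):
        app.state.resources = load_resources()  # runs once, before serving requests
        yield
        clear_resources(app.state.resources)    # runs once, on shutdown


    app = FastAPI(lifespan=lifespan)


    @app.get("/")
    def read_root():
        return {"message": "Welcome to my API"}

Run with uvicorn main:app; the startup half executes before the first request is accepted, and the cleanup half runs after the server stops serving.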

src/models/others/message.py DELETED
@@ -1,24 +0,0 @@
-from enum import Enum
-from typing import List, Optional
-
-from pydantic import BaseModel
-
-from models.responses.tool_call_response import ToolCall
-
-
-class Role(str, Enum):
-    assistant = "assistant"
-    user = "user"
-    system = "system"
-    tool = "tool"
-
-
-class Message(BaseModel):
-    role: Role
-    content: Optional[str] = None
-    tool_calls: Optional[List[ToolCall]] = None
-
-    def to_map(self):
-        data = self.model_dump(exclude_none=True)
-        data["role"] = self.role.value
-        return data

src/models/requests/chat_request.py CHANGED
@@ -2,11 +2,10 @@ from typing import List, Optional
 from pydantic import BaseModel
 
 from constants.config import LLM_MODEL_NAME
-from models.others.message import Role, Message
 
 
 class ChatRequest(BaseModel):
-    messages: List[Message]
+    messages: List[dict]
     has_file: bool = False
     chat_session_id: str | None = None
 
@@ -16,7 +15,7 @@ class ChatRequest(BaseModel):
                 {
                     "has_file": False,
                     "chat_session_id": "123",
-                    "messages": [{"role": "user", "content": "hello"}]
+                    "messages": [{"role": 'user', "content": "hello"}],
                 }
             ]
         }
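
To see the shape change in isolation: messages are now plain dicts rather than Message model instances. A small standalone sketch, with the model re-declared here purely for illustration (only the field names and defaults come from the diff):

    from typing import List

    from pydantic import BaseModel


    class ChatRequestSketch(BaseModel):
        # Plain dict messages ({"role": ..., "content": ...}) mean new keys such as
        # tool_calls can ride along without a schema change.
        messages: List[dict]
        has_file: bool = False
        chat_session_id: str | None = None


    req = ChatRequestSketch(
        messages=[{"role": "user", "content": "hello"}],
        chat_session_id="123",
    )
    print(req.model_dump())
    # {'messages': [{'role': 'user', 'content': 'hello'}], 'has_file': False, 'chat_session_id': '123'}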

src/models/responses/chat_response.py DELETED
@@ -1,92 +0,0 @@
-from typing import Any, List, Optional
-from click import argument
-from pydantic import BaseModel
-from models.others.message import Message, Role
-from models.responses.tool_call_response import ToolCall
-
-
-class Choice(BaseModel):
-    message: Optional[Message] = None
-    delta: Optional[Message] = None
-    function_call: Optional[ToolCall] = None
-
-
-class ChatResponse(BaseModel):
-    id: Optional[str] = None
-    choices: Optional[List[Choice]] = None
-
-    @classmethod
-    def from_stream_chunk(cls, chunk: dict, last_role: Optional[Role] = None):
-        choices = []
-        updated_role = last_role  # Default to last role
-
-        for choice in chunk.get("choices", []):
-            delta_data = choice.get("delta", {})
-
-            # Skip chunks that contain neither content nor role
-            if not delta_data.get("content") and not delta_data.get("role"):
-                continue
-
-            # Determine role
-            if "role" in delta_data and delta_data["role"] is not None:
-                try:
-                    updated_role = Role(delta_data["role"])
-                except ValueError:
-                    # Skip or log invalid role values
-                    continue
-
-            if not updated_role:
-                # Still no role available, skip
-                continue
-
-            message = Message(
-                role=updated_role,
-                content=delta_data.get("content"),
-            )
-
-            choices.append(
-                Choice(
-                    message=message,
-                    delta=message,
-                )
-            )
-
-        return (
-            cls(
-                id=chunk.get("id"),
-                choices=choices,
-            ),
-            updated_role,
-        )
-
-    @classmethod
-    def from_llm_output(cls, output: dict) -> "ChatResponse":
-        """
-        Map the output dict from llm.create_chat_completion to a ChatResponse instance.
-        """
-        choices = []
-        for choice in output.get("choices", []):
-            message_data = choice.get("message", {})
-            tool_calls_data = message_data.get("tool_calls")
-            tool_calls = None
-            if tool_calls_data:
-                tool_calls = [ToolCall(**tc) for tc in tool_calls_data]
-            message = Message(
-                role=Role(message_data["role"]),
-                content=message_data.get("content"),
-                tool_calls=tool_calls,
-            )
-            # function_call is for OpenAI compatibility, may be None
-            function_call = None
-            if "function_call" in choice:
-                function_call = ToolCall(**choice["function_call"])
-            choices.append(
-                Choice(
-                    message=message,
-                    function_call=function_call,
-                )
-            )
-        return cls(
-            id=output.get("id"),
-            choices=choices,
-        )

src/models/responses/tool_call_response.py DELETED
@@ -1,13 +0,0 @@
-from typing import Optional
-from pydantic import BaseModel
-
-
-class FunctionOfToolCall(BaseModel):
-    name: Optional[str]
-    arguments: Optional[str]
-
-
-class ToolCall(BaseModel):
-    id: Optional[str]
-    type: Optional[str]
-    function: Optional[FunctionOfToolCall]

src/routes/chat_routes.py CHANGED
@@ -70,9 +70,7 @@ async def chat(request: ChatRequest):
     try:
         response = chat_service.chat_generate(request=request)
         return BaseResponse(
-            data=
-                response.model_dump_json(),
-            ),
+            data=response,
         )
     except Exception as e:
         raise BaseExceptionResponse(message=str(e))

src/services/chat_service.py CHANGED
@@ -6,13 +6,12 @@ from services import vector_store_service
 from utils import open_ai_client
 from utils.timing import measure_time
 from utils.tools import tools_helper, tools_define
-from models.others.message import Message, Role
 from utils.transformer_client import generate, generate_stream
 
 
-def build_context_prompt(request: ChatRequest) -> list[Message]:
+def build_context_prompt(request: ChatRequest) -> list[dict]:
     """Build system prompt with context if file is provided."""
-    messages = [Message(role=Role.system, content=system_prompts.system_prompt)]
+    messages = [{"role": "system", "content": system_prompts.system_prompt}]
 
     if not request.has_file or not vector_store_service.check_if_collection_exists(
         request.chat_session_id
@@ -21,7 +20,7 @@ def build_context_prompt(request: ChatRequest) -> list[Message]:
 
     with measure_time("Get data from vector store"):
         vectorstore = vector_store_service.get_vector_store(request.chat_session_id)
-        query = request.messages[-1].content
+        query = request.messages[-1].get("content")
        results = vectorstore.similarity_search(query=query or "", k=10)
 
     if not results:
@@ -42,7 +41,7 @@ def build_context_prompt(request: ChatRequest) -> list[Message]:
         f"CONTEXT: {context}\nQUESTION: {query}"
     )
 
-    messages.append(Message(role=Role.system, content=embedded_prompt))
+    messages.append({"role": "system", "content": embedded_prompt})
     return messages
 
 
@@ -62,21 +61,19 @@ def chat_generate_stream(
     final_tool_calls = {}
 
     for chunk in stream:
-
-
-
-
-
-
-
+        choices = chunk.get("choices", [])
+        if choices and choices[0].get("delta", {}).get("tool_calls"):
+            delta = choices[0]["delta"]
+            final_tool_calls = tools_helper.final_tool_calls_handler(
+                final_tool_calls, delta["tool_calls"], is_stream=True
+            )
+        yield chunk
 
     if not final_tool_calls:
         return
 
     tool_call_result = tools_helper.process_tool_calls(final_tool_calls)
-    tool_call_message = Message(
-        role=Role.tool, content=tool_call_result.get("content", "")
-    )
+    tool_call_message = {"role": "tool", "content": tool_call_result.get("content", "")}
     messages.append(tool_call_message)
 
     # new_stream = open_ai_client.chat.completions.create(
@@ -94,39 +91,26 @@ def chat_generate(request: ChatRequest):
     messages = build_context_prompt(request)
     messages.extend(request.messages)
 
-
-
-
-    output = generate(messages=messages)
-
-    final_tool_calls = {}
+    output = open_ai_client.open_ai_client.chat.completions.create(
+        messages=messages, model="my-model", tools=tools_define.tools
+    ).model_dump()
+    # output = generate(messages=messages)
+    choices = output.get("choices", [])
 
-
-    if output.choices and len(output.choices) > 0:
-        message = output.choices[0].message
-        if (
-            message is not None
-            and getattr(message, "tool_calls", None)
-            and message.tool_calls
-        ):
-            final_tool_calls = tools_helper.final_tool_calls_handler(
-                final_tool_calls=final_tool_calls, tool_calls=message.tool_calls
-            )
+    tool_calls = choices[0].get("message").get("tool_calls")
 
-    if not final_tool_calls:
+    if not tool_calls:
         return output
 
     tool_call_result = tools_helper.process_tool_calls(
-
-    )
-    tool_call_message = Message(
-        role=Role.tool, content=tool_call_result.get("content", "")
+        tool_calls=tool_calls
     )
+    tool_call_message = {"role": "tool", "content": tool_call_result.get("content", "")}
     messages.append(tool_call_message)
 
-    new_output = generate(messages=messages, has_tool_call=False)
-
-
-
+    # new_output = generate(messages=messages, has_tool_call=False)
+    new_output = open_ai_client.open_ai_client.chat.completions.create(
+        messages=messages, model="my-model"
+    ).model_dump()
 
     return new_output
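
The prompt assembly in build_context_prompt now works entirely with role/content dicts: a system prompt first, an optional context-carrying system message when retrieval returns results, then the conversation itself. A standalone sketch of that flow (the system prompt text and context string are placeholders; the vector-store lookup is omitted):

    def build_messages(user_messages: list[dict], context: str | None = None) -> list[dict]:
        messages = [{"role": "system", "content": "You are a helpful assistant."}]
        if context:
            # Mirror the CONTEXT/QUESTION framing used above, keyed off the latest user turn.
            query = user_messages[-1].get("content")
            messages.append(
                {"role": "system", "content": f"CONTEXT: {context}\nQUESTION: {query}"}
            )
        messages.extend(user_messages)
        return messages


    print(build_messages(
        [{"role": "user", "content": "What does the report say?"}],
        context="Q3 revenue grew 12%.",
    ))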

src/utils/image_pipeline.py CHANGED
@@ -1,34 +1,27 @@
-
-
-
+import torch
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
+    StableDiffusionPipeline,
+)
+from constants.config import IMAGE_MODEL_ID_OR_LINK, TORCH_DEVICE
+from utils.timing import measure_time
 
-
+torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 for performance on CUDA
 
-
+pipeline = None
 
 
-
-
-
-
-
-
-
-#
-
-
-# )
-# # _pipeline = StableDiffusionPipeline.from_single_file(
-# #     IMAGE_MODEL_ID_OR_LINK,
-# #     torch_dtype=torch.bfloat16,
-# #     variant="fp16",
-# #     # safety_checker=True,
-# #     use_safetensors=True,
-# # )
-#     _pipeline.to(TORCH_DEVICE)
-# except Exception as e:
-#     raise RuntimeError(f"Failed to load the model: {e}")
-# return _pipeline
+def load_pipeline():
+    global pipeline
+    with measure_time("Load image pipeline"):
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            IMAGE_MODEL_ID_OR_LINK,
+            torch_dtype=torch.bfloat16,
+            variant="fp16",
+            # safety_checker=True,
+            use_safetensors=True,
+        )
 
 
-
+def clear_resources():
+    global pipeline
+    pipeline = None

src/utils/timing.py CHANGED
@@ -1,10 +1,12 @@
 import time
 
+
 class measure_time:
     def __init__(self, label="Operation"):
         self.label = label
 
     def __enter__(self):
+        print(f"\nStart: {self.label}")
         self.start = time.time()
         return self
 

src/utils/tools/tools_helper.py CHANGED
@@ -1,6 +1,5 @@
 import json
 from typing import List
-from models.responses.tool_call_response import ToolCall
 from utils.tools.tools_define import ToolFunction
 from services import image_service, web_data_service
 
@@ -15,7 +14,7 @@ def extract_tool_args(tool_call):
     Returns:
         dict: The extracted arguments as a dictionary
     """
-    return json.loads(tool_call.function.arguments)
+    return json.loads(tool_call.get("function", {}).get("arguments", "{}"))
 
 
 def handle_web_data_tool_call(tool_call):
@@ -66,7 +65,7 @@ def handle_search_web_tool_call(tool_call):
     return search_results
 
 
-def process_tool_calls(final_tool_calls):
+def process_tool_calls(tool_calls):
     """
     Process all tool calls and execute them.
 
@@ -87,8 +86,8 @@ def process_tool_calls(final_tool_calls):
         ToolFunction.SEARCH_WEB.value: handle_search_web_tool_call,
     }
 
-    for tool_call in final_tool_calls.values():
-        handler = tool_handlers.get(tool_call.function.name)
+    for tool_call in tool_calls:
+        handler = tool_handlers.get(tool_call.get("function").get("name"))
         if handler:
             result = handler(tool_call)
             if isinstance(result, list):
@@ -98,14 +97,14 @@ def process_tool_calls(final_tool_calls):
 
     return {
         "role": "tool",
-        "tool_call_id": tool_call.id,
-        "tool_call_name": tool_call.function.name,
+        "tool_call_id": tool_call.get("id"),
+        "tool_call_name": tool_call.get("function", {}).get("name"),
         "content": content,
     }
 
 
 def final_tool_calls_handler(
-    final_tool_calls: dict, tool_calls: List[ToolCall], is_stream: bool = False
+    final_tool_calls: dict, tool_calls: List[dict], is_stream: bool = False
 ):
     """
     Handle and combine multiple tool calls.
@@ -120,13 +119,11 @@ def final_tool_calls_handler(
     for index, tool_call in enumerate(tool_calls):
         if index not in final_tool_calls:
             final_tool_calls[index] = tool_call
-        elif tool_call.function is not None:
+        elif tool_call.get("function") is not None:
             if is_stream:
-                final_tool_calls[
-                    index
-                ].function.arguments += tool_call.function.arguments
+                if "function" in final_tool_calls[index] and "arguments" in final_tool_calls[index]["function"]:
+                    final_tool_calls[index]["function"]["arguments"] += tool_call.get("function", {}).get("arguments", "")
             else:
-                final_tool_calls[index].function.arguments = (
-                    tool_call.function.arguments
-                )
+                if "function" in final_tool_calls[index]:
+                    final_tool_calls[index]["function"]["arguments"] = tool_call.get("function", {}).get("arguments", "")
     return final_tool_calls
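
For intuition on final_tool_calls_handler's streaming branch: the first fragment seen for an index seeds the entry, and later fragments append to the accumulated JSON arguments string. A simplified standalone sketch with dict tool calls (the fragment shapes follow the OpenAI-style delta format and are assumptions here):

    def merge_tool_call_fragments(final_tool_calls: dict, tool_calls: list[dict]) -> dict:
        for index, tool_call in enumerate(tool_calls):
            if index not in final_tool_calls:
                final_tool_calls[index] = tool_call
            elif tool_call.get("function"):
                final_tool_calls[index]["function"]["arguments"] += \
                    tool_call["function"].get("arguments", "")
        return final_tool_calls


    acc: dict = {}
    for delta in (
        [{"function": {"name": "search_web", "arguments": '{"query": "fast'}}],
        [{"function": {"arguments": 'api lifespan"}'}}],
    ):
        acc = merge_tool_call_fragments(acc, delta)
    print(acc[0]["function"]["arguments"])  # {"query": "fastapi lifespan"}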

src/utils/transformer_client.py CHANGED
@@ -1,71 +1,122 @@
-import
+from email import message
+import json
+import re
 from threading import Thread
 from typing import Generator, List
-import
+import uuid
+from numpy import append
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 from constants.config import (
-    GGUF_FILE_NAME,
     LLM_MODEL_NAME,
-    GGUF_REPO_ID,
     TORCH_DEVICE,
     USE_QUANT,
     MODEL_OPTIMIZATION,
-    IS_APPLE_SILICON,
 )
-from models.others.message import Message, Role
-from models.responses.chat_response import ChatResponse
 from transformers.generation.streamers import TextIteratorStreamer
 from utils.timing import measure_time
 from utils.tools import tools_define
 from transformers.utils.quantization_config import BitsAndBytesConfig
 
-# Configure model loading based on device
-if USE_QUANT:
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=MODEL_OPTIMIZATION["torch_dtype"],
-        bnb_4bit_use_double_quant=True,
-    )
-    model = AutoModelForCausalLM.from_pretrained(
-        LLM_MODEL_NAME,
-        torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
-        device_map="auto",
-        quantization_config=quantization_config,
-        low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
-        use_cache=MODEL_OPTIMIZATION["use_cache"],
-    )
-else:
-    model = AutoModelForCausalLM.from_pretrained(
-        LLM_MODEL_NAME,
-        torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
-        device_map="auto",
-        low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
-        use_cache=MODEL_OPTIMIZATION["use_cache"],
-    )
 
-
-
-
-
-
+def load_model():
+    global _model, _tokenizer
+
+    # Configure model loading based on device
+    try:
+        with measure_time("Load model"):
+            if USE_QUANT:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=MODEL_OPTIMIZATION["torch_dtype"],
+                    bnb_4bit_use_double_quant=True,
+                )
+                _model = AutoModelForCausalLM.from_pretrained(
+                    LLM_MODEL_NAME,
+                    torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
+                    device_map="auto",
+                    quantization_config=quantization_config,
+                    low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
+                    use_cache=MODEL_OPTIMIZATION["use_cache"],
+                )
+            else:
+                _model = AutoModelForCausalLM.from_pretrained(
+                    LLM_MODEL_NAME,
+                    torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
+                    device_map="auto",
+                    low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
+                    use_cache=MODEL_OPTIMIZATION["use_cache"],
+                )
+
+            # Configure tokenizer with appropriate settings
+            _tokenizer = AutoTokenizer.from_pretrained(
+                LLM_MODEL_NAME,
+                use_fast=True,  # Use fast tokenizer for better performance
+            )
+
+            _model.eval()
+    except Exception as e:
+        print(f"Failed to load model or tokenizer: {str(e)}")
+        _model = None
+        _tokenizer = None
+        raise
+
+
+def clear_resources():
+    global _model, _tokenizer
+    _model = None
+    _tokenizer = None
+
+
+def build_prompt(messages: List[dict]) -> str:
+    return "\n".join([f"{m.get('role')}: {m.get('content')}" for m in messages])
+
 
+def extract_tool_calls_and_reupdate_output(text: str):
+    """
+    Extracts all valid JSON objects found within <tool_call>{...}</tool_call> patterns.
+    """
+    tool_calls = []
 
-
-
+    # Match any <tool_call> JSON-like structure (greedy to match full JSON block)
+    pattern = r"<tool_call>\s*(\{.*?\})\s*</?tool_call>?"
 
+    matches = list(re.finditer(pattern, text, re.DOTALL))
+
+    for match in matches:
+        try:
+            tool_call = {}
+            tool_call["id"] = str(uuid.uuid4())
+            tool_call["type"] = "function"
+            tool_call["function"] = {
+                "name": match.group(1),
+                "arguments": json.loads(match.group(1)),
+            }
+            tool_calls.append(tool_call)
+        except json.JSONDecodeError:
+            continue
+
+    text = re.sub(pattern, "", text, flags=re.DOTALL).strip()
+    return text.strip(), tool_calls if tool_calls else None
+
+
+def generate(messages: List[dict], has_tool_call: bool = True) -> dict:
+
+    if _model is None or _tokenizer is None:
+        raise RuntimeError(
+            "Model or tokenizer not initialized. Ensure load_model was called successfully."
+        )
 
-def generate(messages: List[Message], has_tool_call: bool = True) -> ChatResponse:
     # Convert messages to prompt
-    prompt =
+    prompt = build_prompt(messages)
 
     # Prepare tools if enabled
     tools = tools_define.tools if has_tool_call else None
     tool_choice = "auto" if has_tool_call else "none"
 
     # Apply chat template
-    formatted_prompt =
+    formatted_prompt = _tokenizer.apply_chat_template(
         prompt,
         tools=tools,
         tool_choice=tool_choice,
@@ -73,61 +124,71 @@ def generate(messages: List[Message], has_tool_call: bool = True) -> ChatRespons
         add_generation_prompt=True,
     )
 
-    print("Starting create chat completion")
     try:
-        with measure_time("
+        with measure_time("Create chat completion"):
            # Tokenize input with optimized settings
-            inputs =
+            inputs = _tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
-                max_length=
+                max_length=4096,  # Adjust based on your needs
            ).to(TORCH_DEVICE)

            # Generate response with optimized settings
-            output_ids =
+            output_ids = _model.generate(
                **inputs,
                max_new_tokens=4096,
                do_sample=True,
                temperature=0.7,
-                pad_token_id=
-                eos_token_id=
+                pad_token_id=_tokenizer.pad_token_id,
+                eos_token_id=_tokenizer.eos_token_id,
                use_cache=True,  # Enable KV cache for faster generation
                num_beams=1,  # Use greedy decoding for faster inference
            )

            # Decode response
-            output_text =
+            output_text = _tokenizer.decode(
                output_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
            )

-
-
-            {
-                "choices": [
-                    {
-                        "message": {
-                            "role": Role.assistant.value,
-                            "content": output_text,
-                        }
-                    }
-                ]
-            }
+            cleaned_output, tool_calls = extract_tool_calls_and_reupdate_output(
+                output_text
            )
+
+            # Create ChatResponse using from_llm_output
+            return {
+                "id": f"chatcmpl-{uuid.uuid4().hex}",
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": cleaned_output,
+                            "tool_calls": tool_calls,
+                        },
+                    }
+                ],
+            }
+
    except Exception as e:
        print(f"Error in create chat completion: {str(e)}")
        raise
 
 
-def generate_stream(messages: List[
+def generate_stream(messages: List[dict]) -> Generator[dict, None, None]:
+
+    if _model is None or _tokenizer is None:
+        raise RuntimeError(
+            "Model or tokenizer not initialized. Ensure load_model was called successfully."
+        )
+
     # Convert messages to prompt
-    prompt =
+    prompt = build_prompt(messages)
     # Prepare tools
     tools = tools_define.tools
 
     # Apply chat template
-    formatted_prompt =
+    formatted_prompt = _tokenizer.apply_chat_template(
         prompt,
         tools=tools,
         tool_choice="auto",
@@ -137,7 +198,7 @@ def generate_stream(messages: List[Message]) -> Generator[ChatResponse, None, No
 
     try:
         # Tokenize input with optimized settings
-        inputs =
+        inputs = _tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
@@ -147,7 +208,7 @@
 
         # Generate streaming output
         streamer = TextIteratorStreamer(
-
+            _tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
@@ -158,17 +219,17 @@
            do_sample=True,
            max_new_tokens=4096,
            temperature=0.7,
-            pad_token_id=
-            eos_token_id=
+            pad_token_id=_tokenizer.pad_token_id,
+            eos_token_id=_tokenizer.eos_token_id,
            use_cache=True,  # Enable KV cache for faster generation
            num_beams=1,  # Use greedy decoding for faster inference
        )

        # Generate in background thread
-        thread = Thread(target=
+        thread = Thread(target=_model.generate, kwargs=generation_kwargs)
        thread.start()

-        last_role =
+        last_role = "assistant"
        for new_text in streamer:
            # Format the chunk to match the expected structure
            chunk = {