LeoNguyen101120 committed
Commit 2692e0d · 1 Parent(s): c62df12

Refactor chat handling and model integration: update .env.example to include new API keys, modify main.py to implement a lifespan context manager for resource management, and replace the Message class with dictionary structures in chat_request.py and chat_service.py for improved flexibility. Remove unused message and response models to streamline the codebase.
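The startup/shutdown wiring added in src/main.py follows FastAPI's lifespan pattern. Below is a minimal, self-contained sketch of that pattern; only the lifespan name and the clear_resources calls mirror this commit, the import path and comments are illustrative:

from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Code before `yield` runs once at startup (the model-loading calls are
    # still commented out in this commit).
    yield
    # Code after `yield` runs once at shutdown, e.g. releasing models and
    # pipelines so memory is freed:
    #   transformer_client.clear_resources()
    #   image_pipeline.clear_resources()


app = FastAPI(lifespan=lifespan)

FastAPI drives this generator itself: everything before the yield happens before the first request is served, and everything after it runs when the application stops.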

.env.example CHANGED
@@ -1,3 +1,4 @@
 jina_api_key=
 brave_search_api_key=
-ai_url=
+ai_url=
+serp_api
src/main.py CHANGED
@@ -1,5 +1,7 @@
 import os
+from sre_parse import Tokenizer
 from fastapi import FastAPI, Request
+from fastapi.concurrency import asynccontextmanager
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
@@ -7,9 +9,26 @@ from fastapi.staticfiles import StaticFiles
 from constants.config import OUTPUT_DIR
 from models.responses.base_response import BaseResponse
 from routes import chat_routes, process_file_routes, vector_store_routes
+from utils import image_pipeline, transformer_client
 from utils.exception import CustomException
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # try:
+    #     transformer_client.load_model()
+    #     image_pipeline.load_pipeline()
+
+    # except Exception as e:
+    #     print(f"Error during startup: {str(e)}")
+
+    yield
+
+    transformer_client.clear_resources()
+    image_pipeline.clear_resources()
+
+
+app = FastAPI(lifespan=lifespan)
 
 origins = ["*"]
 app.add_middleware(
@@ -20,34 +39,43 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 @app.exception_handler(CustomException)
 async def custom_exception_handler(request: Request, exc: CustomException):
     return JSONResponse(
         status_code=exc.status_code,
-        content=BaseResponse(status_code=exc.status_code, message=exc.message).model_dump(),
+        content=BaseResponse(
+            status_code=exc.status_code, message=exc.message
+        ).model_dump(),
     )
 
+
 @app.exception_handler(Exception)
 async def global_exception_handler(request: Request, exc: Exception):
     # Default handler for errors not handled by CustomException
     return JSONResponse(
         status_code=500,
-        content=BaseResponse(status_code=500, message=str(exc)).model_dump()
+        content=BaseResponse(status_code=500, message=str(exc)).model_dump(),
    )
 
+
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     return JSONResponse(
         status_code=422,
-        content=BaseResponse(status_code=422, message="Validation error").model_dump()
+        content=BaseResponse(status_code=422, message="Validation error").model_dump(),
     )
 
+
 app.include_router(chat_routes.router, prefix="/api/v1")
 app.include_router(process_file_routes.router, prefix="/api/v1")
 app.include_router(vector_store_routes.router, prefix="/api/v1")
+
+
 @app.get("/")
 def read_root():
     return {"message": "Welcome to my API"}
 
+
 os.makedirs(OUTPUT_DIR, exist_ok=True)
-app.mount(OUTPUT_DIR, StaticFiles(directory=OUTPUT_DIR), name="outputs")
+app.mount(OUTPUT_DIR, StaticFiles(directory=OUTPUT_DIR), name="outputs")
src/models/others/message.py DELETED
@@ -1,24 +0,0 @@
-from enum import Enum
-from typing import List, Optional
-
-from pydantic import BaseModel
-
-from models.responses.tool_call_response import ToolCall
-
-
-class Role(str, Enum):
-    assistant = "assistant"
-    user = "user"
-    system = "system"
-    tool = "tool"
-
-
-class Message(BaseModel):
-    role: Role
-    content: Optional[str] = None
-    tool_calls: Optional[List[ToolCall]] = None
-
-    def to_map(self):
-        data = self.model_dump(exclude_none=True)
-        data["role"] = self.role.value
-        return data
src/models/requests/chat_request.py CHANGED
@@ -2,11 +2,10 @@ from typing import List, Optional
 from pydantic import BaseModel
 
 from constants.config import LLM_MODEL_NAME
-from models.others.message import Role, Message
 
 
 class ChatRequest(BaseModel):
-    messages: List[Message]
+    messages: List[dict]
     has_file: bool = False
     chat_session_id: str | None = None
 
@@ -16,7 +15,7 @@ class ChatRequest(BaseModel):
             {
                 "has_file": False,
                 "chat_session_id": "123",
-                "messages": [{"role": Role.user, "content": "hello"}],
+                "messages": [{"role": 'user', "content": "hello"}],
             }
         ]
     }
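To illustrate the new request shape, a ChatRequest is now built from plain role/content dicts and read with dict access, as chat_service.py does below. A minimal sketch (the fields mirror this diff; the example values are illustrative):

from typing import List, Optional

from pydantic import BaseModel


class ChatRequest(BaseModel):
    # Messages are plain dicts in the OpenAI-style chat format.
    messages: List[dict]
    has_file: bool = False
    chat_session_id: Optional[str] = None


request = ChatRequest(
    chat_session_id="123",
    messages=[{"role": "user", "content": "hello"}],
)
print(request.messages[-1].get("content"))  # -> hello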
src/models/responses/chat_response.py DELETED
@@ -1,92 +0,0 @@
-from typing import Any, List, Optional
-from click import argument
-from pydantic import BaseModel
-from models.others.message import Message, Role
-from models.responses.tool_call_response import ToolCall
-
-
-class Choice(BaseModel):
-    message: Optional[Message] = None
-    delta: Optional[Message] = None
-    function_call: Optional[ToolCall] = None
-
-
-class ChatResponse(BaseModel):
-    id: Optional[str] = None
-    choices: Optional[List[Choice]] = None
-
-    @classmethod
-    def from_stream_chunk(cls, chunk: dict, last_role: Optional[Role] = None):
-        choices = []
-        updated_role = last_role  # Default to last role
-
-        for choice in chunk.get("choices", []):
-            delta_data = choice.get("delta", {})
-
-            # Skip chunks that contain neither content nor role
-            if not delta_data.get("content") and not delta_data.get("role"):
-                continue
-
-            # Determine role
-            if "role" in delta_data and delta_data["role"] is not None:
-                try:
-                    updated_role = Role(delta_data["role"])
-                except ValueError:
-                    # Skip or log invalid role values
-                    continue
-
-            if not updated_role:
-                # Still no role available, skip
-                continue
-
-            message = Message(
-                role=updated_role,
-                content=delta_data.get("content"),
-            )
-
-            choices.append(
-                Choice(
-                    message=message,
-                    delta=message,
-                )
-            )
-
-        return (
-            cls(
-                id=chunk.get("id"),
-                choices=choices,
-            ),
-            updated_role,
-        )
-
-    @classmethod
-    def from_llm_output(cls, output: dict) -> "ChatResponse":
-        """
-        Map the output dict from llm.create_chat_completion to a ChatResponse instance.
-        """
-        choices = []
-        for choice in output.get("choices", []):
-            message_data = choice.get("message", {})
-            tool_calls_data = message_data.get("tool_calls")
-            tool_calls = None
-            if tool_calls_data:
-                tool_calls = [ToolCall(**tc) for tc in tool_calls_data]
-            message = Message(
-                role=Role(message_data["role"]),
-                content=message_data.get("content"),
-                tool_calls=tool_calls,
-            )
-            # function_call is for OpenAI compatibility, may be None
-            function_call = None
-            if "function_call" in choice:
-                function_call = ToolCall(**choice["function_call"])
-            choices.append(
-                Choice(
-                    message=message,
-                    function_call=function_call,
-                )
-            )
-        return cls(
-            id=output.get("id"),
-            choices=choices,
-        )
src/models/responses/tool_call_response.py DELETED
@@ -1,13 +0,0 @@
-from typing import Optional
-from pydantic import BaseModel
-
-
-class FunctionOfToolCall(BaseModel):
-    name: Optional[str]
-    arguments: Optional[str]
-
-
-class ToolCall(BaseModel):
-    id: Optional[str]
-    type: Optional[str]
-    function: Optional[FunctionOfToolCall]
src/routes/chat_routes.py CHANGED
@@ -70,9 +70,7 @@ async def chat(request: ChatRequest):
     try:
         response = chat_service.chat_generate(request=request)
         return BaseResponse(
-            data=json.loads(
-                response.model_dump_json(),
-            ),
+            data=response,
         )
     except Exception as e:
         raise BaseExceptionResponse(message=str(e))
src/services/chat_service.py CHANGED
@@ -6,13 +6,12 @@ from services import vector_store_service
 from utils import open_ai_client
 from utils.timing import measure_time
 from utils.tools import tools_helper, tools_define
-from models.others.message import Message, Role
 from utils.transformer_client import generate, generate_stream
 
 
-def build_context_prompt(request: ChatRequest) -> list[Message]:
+def build_context_prompt(request: ChatRequest) -> list[dict]:
     """Build system prompt with context if file is provided."""
-    messages = [Message(role=Role.system, content=system_prompts.system_prompt)]
+    messages = [{"role": "system", "content": system_prompts.system_prompt}]
 
     if not request.has_file or not vector_store_service.check_if_collection_exists(
         request.chat_session_id
@@ -21,7 +20,7 @@ def build_context_prompt(request: ChatRequest) -> list[Message]:
 
     with measure_time("Get data from vector store"):
         vectorstore = vector_store_service.get_vector_store(request.chat_session_id)
-        query = request.messages[-1].content
+        query = request.messages[-1].get("content")
         results = vectorstore.similarity_search(query=query or "", k=10)
 
     if not results:
@@ -42,7 +41,7 @@ def build_context_prompt(request: ChatRequest) -> list[Message]:
         f"CONTEXT: {context}\nQUESTION: {query}"
     )
 
-    messages.append(Message(role=Role.system, content=embedded_prompt))
+    messages.append({"role": "system", "content": embedded_prompt})
     return messages
 
 
@@ -62,21 +61,19 @@ def chat_generate_stream(
     final_tool_calls = {}
 
     for chunk in stream:
-        if chunk.choices and len(chunk.choices) > 0:
-            delta = chunk.choices[0].delta
-            if getattr(delta, "tool_calls", None):
-                final_tool_calls = tools_helper.final_tool_calls_handler(
-                    final_tool_calls, delta.tool_calls, is_stream=True
-                )
-        yield chunk
+        choices = chunk.get("choices", [])
+        if choices and choices[0].get("delta", {}).get("tool_calls"):
+            delta = choices[0]["delta"]
+            final_tool_calls = tools_helper.final_tool_calls_handler(
+                final_tool_calls, delta["tool_calls"], is_stream=True
+            )
+        yield chunk
 
     if not final_tool_calls:
         return
 
     tool_call_result = tools_helper.process_tool_calls(final_tool_calls)
-    tool_call_message = Message(
-        role=Role.tool, content=tool_call_result.get("content", "")
-    )
+    tool_call_message = {"role": "tool", "content": tool_call_result.get("content", "")}
     messages.append(tool_call_message)
 
     # new_stream = open_ai_client.chat.completions.create(
@@ -94,39 +91,26 @@ def chat_generate(request: ChatRequest):
     messages = build_context_prompt(request)
     messages.extend(request.messages)
 
-    # output = open_ai_client.open_ai_client.chat.completions.create(
-    #     messages=messages, model="my-model", tools=tools_define.tools
-    # )
-    output = generate(messages=messages)
-
-    final_tool_calls = {}
+    output = open_ai_client.open_ai_client.chat.completions.create(
+        messages=messages, model="my-model", tools=tools_define.tools
+    ).model_dump()
+    # output = generate(messages=messages)
+    choices = output.get("choices", [])
 
-    message = None
-    if output.choices and len(output.choices) > 0:
-        message = output.choices[0].message
-    if (
-        message is not None
-        and getattr(message, "tool_calls", None)
-        and message.tool_calls
-    ):
-        final_tool_calls = tools_helper.final_tool_calls_handler(
-            final_tool_calls=final_tool_calls, tool_calls=message.tool_calls
-        )
+    tool_calls = choices[0].get("message").get("tool_calls")
 
-    if not final_tool_calls:
+    if not tool_calls:
         return output
 
     tool_call_result = tools_helper.process_tool_calls(
-        final_tool_calls=final_tool_calls
-    )
-    tool_call_message = Message(
-        role=Role.tool, content=tool_call_result.get("content", "")
+        tool_calls=tool_calls
     )
+    tool_call_message = {"role": "tool", "content": tool_call_result.get("content", "")}
     messages.append(tool_call_message)
 
-    new_output = generate(messages=messages, has_tool_call=False)
-    # new_output = open_ai_client.chat.completions.create(
-    #     messages=messages, model="my-model", tools=tools_define.tools
-    # )
+    # new_output = generate(messages=messages, has_tool_call=False)
+    new_output = open_ai_client.open_ai_client.chat.completions.create(
+        messages=messages, model="my-model"
+    ).model_dump()
 
     return new_output
src/utils/image_pipeline.py CHANGED
@@ -1,34 +1,27 @@
-# import torch
-# from diffusers import StableDiffusionPipeline
-# from constants.config import IMAGE_MODEL_ID_OR_LINK, TORCH_DEVICE
+import torch
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
+    StableDiffusionPipeline,
+)
+from constants.config import IMAGE_MODEL_ID_OR_LINK, TORCH_DEVICE
+from utils.timing import measure_time
 
-# torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 for performance on CUDA
+torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 for performance on CUDA
 
-# _pipeline = None
+pipeline = None
 
 
-# def get_pipeline() -> StableDiffusionPipeline:
-#     global _pipeline
-#     if _pipeline is None:
-#         try:
-#             _pipeline = StableDiffusionPipeline.from_pretrained(
-#                 IMAGE_MODEL_ID_OR_LINK,
-#                 torch_dtype=torch.bfloat16,
-#                 variant="fp16",
-#                 # safety_checker=True,
-#                 use_safetensors=True,
-#             )
-#             # _pipeline = StableDiffusionPipeline.from_single_file(
-#             #     IMAGE_MODEL_ID_OR_LINK,
-#             #     torch_dtype=torch.bfloat16,
-#             #     variant="fp16",
-#             #     # safety_checker=True,
-#             #     use_safetensors=True,
-#             # )
-#             _pipeline.to(TORCH_DEVICE)
-#         except Exception as e:
-#             raise RuntimeError(f"Failed to load the model: {e}")
-#     return _pipeline
+def load_pipeline():
+    global pipeline
+    with measure_time("Load image pipeline"):
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            IMAGE_MODEL_ID_OR_LINK,
+            torch_dtype=torch.bfloat16,
+            variant="fp16",
+            # safety_checker=True,
+            use_safetensors=True,
+        )
 
 
-# pipeline = get_pipeline()
+def clear_resources():
+    global pipeline
+    pipeline = None
src/utils/timing.py CHANGED
@@ -1,10 +1,12 @@
 import time
 
+
 class measure_time:
     def __init__(self, label="Operation"):
         self.label = label
 
     def __enter__(self):
+        print(f"\nStart: {self.label}")
        self.start = time.time()
        return self
 
src/utils/tools/tools_helper.py CHANGED
@@ -1,6 +1,5 @@
 import json
 from typing import List
-from models.responses.tool_call_response import ToolCall
 from utils.tools.tools_define import ToolFunction
 from services import image_service, web_data_service
 
@@ -15,7 +14,7 @@ def extract_tool_args(tool_call):
     Returns:
         dict: The extracted arguments as a dictionary
     """
-    return json.loads(tool_call.function.arguments)
+    return json.loads(tool_call.get("function", {}).get("arguments", "{}"))
 
 
 def handle_web_data_tool_call(tool_call):
@@ -66,7 +65,7 @@ def handle_search_web_tool_call(tool_call):
     return search_results
 
 
-def process_tool_calls(final_tool_calls):
+def process_tool_calls(tool_calls):
     """
     Process all tool calls and execute them.
 
@@ -87,8 +86,8 @@
         ToolFunction.SEARCH_WEB.value: handle_search_web_tool_call,
     }
 
-    for tool_call in final_tool_calls.values():
-        handler = tool_handlers.get(tool_call.function.name)
+    for tool_call in tool_calls:
+        handler = tool_handlers.get(tool_call.get("function").get("name"))
         if handler:
             result = handler(tool_call)
             if isinstance(result, list):
@@ -98,14 +97,14 @@
 
     return {
         "role": "tool",
-        "tool_call_id": tool_call.id,
-        "tool_call_name": tool_call.function.name,
+        "tool_call_id": tool_call.get("id"),
+        "tool_call_name": tool_call.get("function", {}).get("name"),
         "content": content,
     }
 
 
 def final_tool_calls_handler(
-    final_tool_calls: dict, tool_calls: List[ToolCall], is_stream: bool = False
+    final_tool_calls: dict, tool_calls: List[dict], is_stream: bool = False
 ):
     """
     Handle and combine multiple tool calls.
@@ -120,13 +119,11 @@
     for index, tool_call in enumerate(tool_calls):
         if index not in final_tool_calls:
             final_tool_calls[index] = tool_call
-        elif tool_call.function is not None:
+        elif tool_call.get("function") is not None:
             if is_stream:
-                final_tool_calls[
-                    index
-                ].function.arguments += tool_call.function.arguments
+                if "function" in final_tool_calls[index] and "arguments" in final_tool_calls[index]["function"]:
+                    final_tool_calls[index]["function"]["arguments"] += tool_call.get("function", {}).get("arguments", "")
             else:
-                final_tool_calls[index].function.arguments = (
-                    tool_call.function.arguments
-                )
+                if "function" in final_tool_calls[index]:
+                    final_tool_calls[index]["function"]["arguments"] = tool_call.get("function", {}).get("arguments", "")
     return final_tool_calls
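To make the dict-based accumulation above concrete, here is a small standalone sketch of how streamed tool-call fragments are merged by index. It is simplified from final_tool_calls_handler: the chunk shapes follow the OpenAI-style dicts assumed throughout this diff, the helper name is illustrative, and the extra guard clauses are omitted.

def merge_tool_call_fragments(final_tool_calls: dict, tool_calls: list) -> dict:
    # Streamed chunks deliver partial "arguments" strings; concatenate them per index.
    for index, tool_call in enumerate(tool_calls):
        if index not in final_tool_calls:
            final_tool_calls[index] = tool_call
        elif tool_call.get("function") is not None:
            fragment = tool_call.get("function", {}).get("arguments", "")
            final_tool_calls[index]["function"]["arguments"] += fragment
    return final_tool_calls


# Two fragments of one streamed tool call merge into a complete arguments string.
acc = {}
acc = merge_tool_call_fragments(
    acc,
    [{"id": "call_1", "function": {"name": "search_web", "arguments": '{"query": "fast'}}],
)
acc = merge_tool_call_fragments(acc, [{"function": {"arguments": 'api"}'}}])
print(acc[0]["function"]["arguments"])  # -> {"query": "fastapi"}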
src/utils/transformer_client.py CHANGED
@@ -1,71 +1,122 @@
-import os
+from email import message
+import json
+import re
 from threading import Thread
 from typing import Generator, List
-import torch
+import uuid
+from numpy import append
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 from constants.config import (
-    GGUF_FILE_NAME,
     LLM_MODEL_NAME,
-    GGUF_REPO_ID,
     TORCH_DEVICE,
     USE_QUANT,
     MODEL_OPTIMIZATION,
-    IS_APPLE_SILICON,
 )
-from models.others.message import Message, Role
-from models.responses.chat_response import ChatResponse
 from transformers.generation.streamers import TextIteratorStreamer
 from utils.timing import measure_time
 from utils.tools import tools_define
 from transformers.utils.quantization_config import BitsAndBytesConfig
 
-# Configure model loading based on device
-if USE_QUANT:
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=MODEL_OPTIMIZATION["torch_dtype"],
-        bnb_4bit_use_double_quant=True,
-    )
-    model = AutoModelForCausalLM.from_pretrained(
-        LLM_MODEL_NAME,
-        torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
-        device_map="auto",
-        quantization_config=quantization_config,
-        low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
-        use_cache=MODEL_OPTIMIZATION["use_cache"],
-    )
-else:
-    model = AutoModelForCausalLM.from_pretrained(
-        LLM_MODEL_NAME,
-        torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
-        device_map="auto",
-        low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
-        use_cache=MODEL_OPTIMIZATION["use_cache"],
-    )
 
-# Configure tokenizer with appropriate settings
-tokenizer = AutoTokenizer.from_pretrained(
-    LLM_MODEL_NAME,
-    use_fast=True,  # Use fast tokenizer for better performance
-)
+def load_model():
+    global _model, _tokenizer
+
+    # Configure model loading based on device
+    try:
+        with measure_time("Load model"):
+            if USE_QUANT:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=MODEL_OPTIMIZATION["torch_dtype"],
+                    bnb_4bit_use_double_quant=True,
+                )
+                _model = AutoModelForCausalLM.from_pretrained(
+                    LLM_MODEL_NAME,
+                    torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
+                    device_map="auto",
+                    quantization_config=quantization_config,
+                    low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
+                    use_cache=MODEL_OPTIMIZATION["use_cache"],
+                )
+            else:
+                _model = AutoModelForCausalLM.from_pretrained(
+                    LLM_MODEL_NAME,
+                    torch_dtype=MODEL_OPTIMIZATION["torch_dtype"],
+                    device_map="auto",
+                    low_cpu_mem_usage=MODEL_OPTIMIZATION["low_cpu_mem_usage"],
+                    use_cache=MODEL_OPTIMIZATION["use_cache"],
+                )
+
+            # Configure tokenizer with appropriate settings
+            _tokenizer = AutoTokenizer.from_pretrained(
+                LLM_MODEL_NAME,
+                use_fast=True,  # Use fast tokenizer for better performance
+            )
+
+            _model.eval()
+    except Exception as e:
+        print(f"Failed to load model or tokenizer: {str(e)}")
+        _model = None
+        _tokenizer = None
+        raise
+
+
+def clear_resources():
+    global _model, _tokenizer
+    _model = None
+    _tokenizer = None
 
 
-def build_prompt(messages: List[Message]) -> str:
-    return "\n".join([f"{m.role.value}: {m.content}" for m in messages])
+def build_prompt(messages: List[dict]) -> str:
+    return "\n".join([f"{m.get('role')}: {m.get('content')}" for m in messages])
+
+
+def extract_tool_calls_and_reupdate_output(text: str):
+    """
+    Extracts all valid JSON objects found within <tool_call>{...}</tool_call> patterns.
+    """
+    tool_calls = []
+
+    # Match any <tool_call> JSON-like structure (greedy to match full JSON block)
+    pattern = r"<tool_call>\s*(\{.*?\})\s*</?tool_call>?"
+
+    matches = list(re.finditer(pattern, text, re.DOTALL))
+
+    for match in matches:
+        try:
+            tool_call = {}
+            tool_call["id"] = str(uuid.uuid4())
+            tool_call["type"] = "function"
+            tool_call["function"] = {
+                "name": match.group(1),
+                "arguments": json.loads(match.group(1)),
+            }
+            tool_calls.append(tool_call)
+        except json.JSONDecodeError:
+            continue
+
+    text = re.sub(pattern, "", text, flags=re.DOTALL).strip()
+    return text.strip(), tool_calls if tool_calls else None
 
 
-def generate(messages: List[Message], has_tool_call: bool = True) -> ChatResponse:
+def generate(messages: List[dict], has_tool_call: bool = True) -> dict:
+
+    if _model is None or _tokenizer is None:
+        raise RuntimeError(
+            "Model or tokenizer not initialized. Ensure load_model was called successfully."
+        )
+
     # Convert messages to prompt
-    prompt = [message.to_map() for message in messages]
+    prompt = build_prompt(messages)
 
     # Prepare tools if enabled
     tools = tools_define.tools if has_tool_call else None
     tool_choice = "auto" if has_tool_call else "none"
 
     # Apply chat template
-    formatted_prompt = tokenizer.apply_chat_template(
+    formatted_prompt = _tokenizer.apply_chat_template(
         prompt,
         tools=tools,
         tool_choice=tool_choice,
@@ -73,61 +124,71 @@ def generate(messages: List[Message], has_tool_call: bool = True) -> ChatResponse:
         add_generation_prompt=True,
     )
 
-    print("Starting create chat completion")
     try:
-        with measure_time("Starting create chat completion"):
+        with measure_time("Create chat completion"):
             # Tokenize input with optimized settings
-            inputs = tokenizer(
+            inputs = _tokenizer(
                 formatted_prompt,
                 return_tensors="pt",
                 padding=True,
                 truncation=True,
-                max_length=2048,  # Adjust based on your needs
+                max_length=4096,  # Adjust based on your needs
            ).to(TORCH_DEVICE)
 
            # Generate response with optimized settings
-            output_ids = model.generate(
+            output_ids = _model.generate(
                **inputs,
                max_new_tokens=4096,
                do_sample=True,
                temperature=0.7,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=_tokenizer.pad_token_id,
+                eos_token_id=_tokenizer.eos_token_id,
                use_cache=True,  # Enable KV cache for faster generation
                num_beams=1,  # Use greedy decoding for faster inference
            )
 
            # Decode response
-            output_text = tokenizer.decode(
+            output_text = _tokenizer.decode(
                output_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
            )
 
-            # Create ChatResponse using from_llm_output
-            return ChatResponse.from_llm_output(
-                {
-                    "choices": [
-                        {
-                            "message": {
-                                "role": Role.assistant.value,
-                                "content": output_text,
-                            }
-                        }
-                    ]
-                }
+            cleaned_output, tool_calls = extract_tool_calls_and_reupdate_output(
+                output_text
            )
+
+            # Create ChatResponse using from_llm_output
+            return {
+                "id": f"chatcmpl-{uuid.uuid4().hex}",
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": cleaned_output,
+                            "tool_calls": tool_calls,
+                        },
+                    }
+                ],
+            }
+
    except Exception as e:
        print(f"Error in create chat completion: {str(e)}")
        raise
 
 
-def generate_stream(messages: List[Message]) -> Generator[ChatResponse, None, None]:
+def generate_stream(messages: List[dict]) -> Generator[dict, None, None]:
+
+    if _model is None or _tokenizer is None:
+        raise RuntimeError(
+            "Model or tokenizer not initialized. Ensure load_model was called successfully."
+        )
+
    # Convert messages to prompt
-    prompt = [message.to_map() for message in messages]
+    prompt = build_prompt(messages)
    # Prepare tools
    tools = tools_define.tools
 
    # Apply chat template
-    formatted_prompt = tokenizer.apply_chat_template(
+    formatted_prompt = _tokenizer.apply_chat_template(
        prompt,
        tools=tools,
        tool_choice="auto",
@@ -137,7 +198,7 @@ def generate_stream(messages: List[Message]) -> Generator[ChatResponse, None, None]:
 
    try:
        # Tokenize input with optimized settings
-        inputs = tokenizer(
+        inputs = _tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
@@ -147,7 +208,7 @@ def generate_stream(messages: List[Message]) -> Generator[ChatResponse, None, None]:
 
        # Generate streaming output
        streamer = TextIteratorStreamer(
-            tokenizer,
+            _tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
@@ -158,17 +219,17 @@ def generate_stream(messages: List[Message]) -> Generator[ChatResponse, None, None]:
            do_sample=True,
            max_new_tokens=4096,
            temperature=0.7,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=_tokenizer.pad_token_id,
+            eos_token_id=_tokenizer.eos_token_id,
            use_cache=True,  # Enable KV cache for faster generation
            num_beams=1,  # Use greedy decoding for faster inference
        )
 
        # Generate in background thread
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=_model.generate, kwargs=generation_kwargs)
        thread.start()
 
-        last_role = Role.assistant
+        last_role = "assistant"
        for new_text in streamer:
            # Format the chunk to match the expected structure
            chunk = {