sparkleman committed
Commit adb6ad5 · 1 Parent(s): ff3952a

UPDATE: Add frontend

Files changed (6)
  1. .gitignore +4 -1
  2. Dockerfile +53 -2
  3. README.md +1 -1
  4. app.py +136 -61
  5. config.py +82 -0
  6. openai_test.py +0 -78
.gitignore CHANGED
@@ -13,4 +13,7 @@ wheels/
13
 
14
  *pth
15
  *.pt
16
- *.st
 
 
 
 
13
 
14
  *pth
15
  *.pt
16
+ *.st
17
+ *local*
18
+
19
+ dist-frontend/
Dockerfile CHANGED
@@ -9,12 +9,23 @@ apt install --no-install-recommends -y \
9
  apt clean && rm -rf /var/lib/apt/lists/*
10
  EOF
11
 
12
  COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
13
 
14
  COPY . .
15
 
16
  RUN useradd -m -u 1000 user
17
- # Switch to the "user" user
18
  USER user
19
 
20
  ENV HOME=/home/user \
@@ -23,7 +34,47 @@ ENV HOME=/home/user \
23
  WORKDIR $HOME/app
24
 
25
  COPY --chown=user . $HOME/app
26
 
27
  RUN uv sync --frozen --extra cu124
28
 
29
- CMD ["uv","run","app.py","--strategy","cuda fp16","--model_title","RWKV-x070-World-0.1B-v2.8-20241210-ctx4096","--download_repo_id","BlinkDL/rwkv-7-world","--host","0.0.0.0","--port","7860","--RWKV_CUDA_ON","True"]
 
9
  apt clean && rm -rf /var/lib/apt/lists/*
10
  EOF
11
 
12
+ # Install Node.js and npm
13
+ RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
14
+ apt-get install -y nodejs
15
+
16
+ # Install pnpm
17
+ RUN npm install -g pnpm
18
+
19
+ # Clone and build the frontend repository
20
+ RUN git clone https://github.com/SolomonLeon/web-rwkv-realweb.git /frontend
21
+ WORKDIR /frontend
22
+ RUN pnpm install && pnpm run build
23
+
24
  COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
25
 
26
  COPY . .
27
 
28
  RUN useradd -m -u 1000 user
 
29
  USER user
30
 
31
  ENV HOME=/home/user \
 
34
  WORKDIR $HOME/app
35
 
36
  COPY --chown=user . $HOME/app
37
+ RUN cp -r /frontend/dist $HOME/app/dist-frontend
38
+
39
+ RUN cat > $HOME/app/config.local.yaml<<EOF
40
+ HOST: "0.0.0.0"
41
+ PORT: 7860
42
+ STRATEGY: "cuda fp16"
43
+ RWKV_CUDA_ON: False
44
+ CHUNK_LEN: 256
45
+ MODELS:
46
+   - SERVICE_NAME: "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096"
47
+     DOWNLOAD_MODEL_FILE_NAME: "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
48
+     DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv-7-world"
49
+     DOWNLOAD_MODEL_DIR: "./"
50
+     REASONING: False
51
+     DEFAULT: True
52
+     DEFAULT_SAMPLER:
53
+       max_tokens: 512
54
+       temperature: 1.0
55
+       top_p: 0.3
56
+       presence_penalty: 0.5
57
+       count_penalty: 0.5
58
+       penalty_decay: 0.996
59
+       stop:
60
+         - "\n\n"
61
+   - SERVICE_NAME: "RWKV7-G1-0.1B-68%trained-20250303-ctx4k"
62
+     DOWNLOAD_MODEL_FILE_NAME: "RWKV7-G1-0.1B-68%trained-20250303-ctx4k.pth"
63
+     DOWNLOAD_MODEL_REPO_ID: "BlinkDL/temp-latest-training-models"
64
+     DOWNLOAD_MODEL_DIR: "./"
65
+     REASONING: True
66
+     DEFAULT: True
67
+     DEFAULT_SAMPLER:
68
+       max_tokens: 4096
69
+       temperature: 1.0
70
+       top_p: 0.3
71
+       presence_penalty: 0.5
72
+       count_penalty: 0.5
73
+       penalty_decay: 0.996
74
+       stop:
75
+         - "\n\n"
76
+ EOF
77
 
78
  RUN uv sync --frozen --extra cu124
79
 
80
+ CMD ["uv","run","app.py",]
README.md CHANGED
@@ -25,7 +25,7 @@ python app.py --strategy "cuda fp16" --model_title "RWKV-x070-World-0.1B-v2.8-20
25
  python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.1B-68%trained-20250303-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
26
  ```
27
 
28
- `RWKV7-G1-0.1B-68%trained-20250303-ctx4k`
29
 
30
  ```shell
31
  python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.4B-32%trained-20250304-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
 
25
  python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.1B-68%trained-20250303-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
26
  ```
27
 
28
+ `RWKV7-G1-0.4B-32%trained-20250304-ctx4k`
29
 
30
  ```shell
31
  python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.4B-32%trained-20250304-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os, copy, types, gc, sys, re, time, collections, asyncio
2
  from huggingface_hub import hf_hub_download
3
  from loguru import logger
@@ -6,32 +8,11 @@ from snowflake import SnowflakeGenerator
6
 
7
  CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
8
 
9
- from typing import List, Optional, Union
10
- from pydantic import BaseModel, Field
11
  from pydantic_settings import BaseSettings
12
 
13
 
14
- class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
15
- HOST: str = Field("127.0.0.1", description="Host")
16
- PORT: int = Field(8000, description="Port")
17
- DEBUG: bool = Field(False, description="Debug mode")
18
- STRATEGY: str = Field("cpu", description="Stratergy")
19
- MODEL_TITLE: str = Field("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096")
20
- DOWNLOAD_REPO_ID: str = Field("BlinkDL/rwkv-7-world")
21
- DOWNLOAD_MODEL_DIR: Union[str, None] = Field(None, description="Model Download Dir")
22
- MODEL_FILE_PATH: Union[str, None] = Field(None, description="Model Path")
23
- GEN_penalty_decay: float = Field(0.996, description="Default penalty decay")
24
- CHUNK_LEN: int = Field(
25
- 256,
26
- description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
27
- )
28
- VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
29
- RWKV_CUDA_ON:bool = Field(False, description="`True` to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!")
30
-
31
-
32
- CONFIG = Config()
33
-
34
-
35
  import numpy as np
36
  import torch
37
 
@@ -58,9 +39,10 @@ os.environ["RWKV_CUDA_ON"] = (
58
  from rwkv.model import RWKV
59
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
60
 
61
- from fastapi import FastAPI
62
  from fastapi.responses import StreamingResponse
63
  from fastapi.middleware.cors import CORSMiddleware
 
64
 
65
  from api_types import (
66
  ChatMessage,
@@ -74,17 +56,50 @@ from api_types import (
74
  from utils import cleanMessages, parse_think_response
75
 
76
 
 
77
  logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
78
- if CONFIG.MODEL_FILE_PATH == None:
79
- CONFIG.MODEL_FILE_PATH = hf_hub_download(
80
- repo_id=CONFIG.DOWNLOAD_REPO_ID,
81
- filename=f"{CONFIG.MODEL_TITLE}.pth",
82
- local_dir=CONFIG.DOWNLOAD_MODEL_DIR,
 
83
  )
 
84
 
85
- logger.info(f"Load Model - {CONFIG.MODEL_FILE_PATH}")
86
- model = RWKV(model=CONFIG.MODEL_FILE_PATH.replace(".pth", ""), strategy=CONFIG.STRATEGY)
87
- pipeline = PIPELINE(model, CONFIG.VOCAB)
88
 
89
 
90
  class ChatCompletionRequest(BaseModel):
@@ -92,16 +107,33 @@ class ChatCompletionRequest(BaseModel):
92
  default="rwkv-latest",
93
  description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
94
  )
95
- messages: List[ChatMessage]
96
  prompt: Optional[str] = Field(default=None)
97
- max_tokens: int = Field(default=512)
98
- temperature: float = Field(default=1.0)
99
- top_p: float = Field(default=0.3)
100
- presencePenalty: float = Field(default=0.5)
101
- countPenalty: float = Field(default=0.5)
102
- stream: bool = Field(default=False)
103
- state_name: str = Field(default=None)
104
- include_usage: bool = Field(default=False)
105
 
106
 
107
  app = FastAPI(title="RWKV OpenAI-Compatible API")
@@ -115,15 +147,19 @@ app.add_middleware(
115
  )
116
 
117
 
118
- async def runPrefill(ctx: str, model_tokens: List[int], model_state):
 
 
119
  ctx = ctx.replace("\r\n", "\n")
120
 
121
- tokens = pipeline.encode(ctx)
122
  tokens = [int(x) for x in tokens]
123
  model_tokens += tokens
124
 
125
  while len(tokens) > 0:
126
- out, model_state = model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
 
 
127
  tokens = tokens[CONFIG.CHUNK_LEN :]
128
  await asyncio.sleep(0)
129
 
@@ -141,8 +177,8 @@ def generate(
141
  args = PIPELINE_ARGS(
142
  temperature=max(0.2, request.temperature),
143
  top_p=request.top_p,
144
- alpha_frequency=request.countPenalty,
145
- alpha_presence=request.presencePenalty,
146
  token_ban=[], # ban the generation of some tokens
147
  token_stop=[0],
148
  ) # stop generation whenever you see any token here
@@ -158,20 +194,22 @@ def generate(
158
  out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
159
  out[0] -= 1e10 # disable END_OF_TEXT
160
 
161
- token = pipeline.sample_logits(
162
  out, temperature=args.temperature, top_p=args.top_p
163
  )
164
 
165
- out, model_state = model.forward([token], model_state)
 
 
166
  model_tokens += [token]
167
 
168
  out_tokens += [token]
169
 
170
  for xxx in occurrence:
171
- occurrence[xxx] *= CONFIG.GEN_penalty_decay
172
  occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
173
 
174
- tmp: str = pipeline.decode(out_tokens[out_last:])
175
 
176
  if "\ufffd" in tmp:
177
  continue
@@ -210,19 +248,20 @@ def generate(
210
 
211
 
212
  async def chatResponse(
213
- request: ChatCompletionRequest, model_state: any, completionId: str
214
  ) -> ChatCompletion:
215
  createTimestamp = time.time()
216
 
217
- enableReasoning = request.model.endswith(":thinking")
218
-
219
  prompt = (
220
  f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
221
  if request.prompt == None
222
  else request.prompt.strip()
223
  )
224
 
225
- out, model_tokens, model_state = await runPrefill(prompt, [], model_state)
226
 
227
  prefillTime = time.time()
228
  promptTokenCount = len(model_tokens)
@@ -291,19 +330,20 @@ async def chatResponse(
291
 
292
 
293
  async def chatResponseStream(
294
- request: ChatCompletionRequest, model_state: any, completionId: str
295
  ):
296
  createTimestamp = int(time.time())
297
 
298
- enableReasoning = request.model.endswith(":thinking")
299
-
300
  prompt = (
301
  f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
302
  if request.prompt == None
303
  else request.prompt.strip()
304
  )
305
 
306
- out, model_tokens, model_state = await runPrefill(prompt, [], model_state)
307
 
308
  prefillTime = time.time()
309
  promptTokenCount = len(model_tokens)
@@ -343,7 +383,7 @@ async def chatResponseStream(
343
  buffer = []
344
 
345
  if enableReasoning:
346
- buffer.append(" <think")
347
 
348
  streamConfig = {
349
  "isChecking": False,
@@ -532,6 +572,32 @@ async def chat_completions(request: ChatCompletionRequest):
532
  completionId = str(next(CompletionIdGenerator))
533
  logger.info(f"[REQ] {completionId} - {request.model_dump()}")
534
 
 
535
  async def chatResponseStreamDisconnect():
536
  if "cuda" in CONFIG.STRATEGY:
537
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
@@ -540,18 +606,27 @@ async def chat_completions(request: ChatCompletionRequest):
540
  )
541
 
542
  model_state = None
543
 
544
  if request.stream:
545
  r = StreamingResponse(
546
- chatResponseStream(request, model_state, completionId),
547
  media_type="text/event-stream",
548
  background=chatResponseStreamDisconnect,
549
  )
550
  else:
551
- r = await chatResponse(request, model_state, completionId)
552
 
553
  return r
554
 
 
555
 
556
  if __name__ == "__main__":
557
  import uvicorn
 
1
+ from config import CONFIG, ModelConfig
2
+
3
  import os, copy, types, gc, sys, re, time, collections, asyncio
4
  from huggingface_hub import hf_hub_download
5
  from loguru import logger
 
8
 
9
  CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
10
 
11
+ from typing import List, Optional, Union, Any, Dict
12
+ from pydantic import BaseModel, Field, model_validator
13
  from pydantic_settings import BaseSettings
14
 
15
 
 
16
  import numpy as np
17
  import torch
18
 
 
39
  from rwkv.model import RWKV
40
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
41
 
42
+ from fastapi import FastAPI, HTTPException
43
  from fastapi.responses import StreamingResponse
44
  from fastapi.middleware.cors import CORSMiddleware
45
+ from fastapi.staticfiles import StaticFiles
46
 
47
  from api_types import (
48
  ChatMessage,
 
56
  from utils import cleanMessages, parse_think_response
57
 
58
 
59
+ class ModelStorage:
60
+ MODEL_CONFIG: Optional[ModelConfig] = None
61
+ model: Optional[RWKV] = None
62
+ pipeline: Optional[PIPELINE] = None
63
+
64
+
65
+ MODEL_STORAGE: Dict[str, ModelStorage] = {}
66
+
67
+ DEFAULT_MODEL_NAME = None
68
+ DEFAULT_REASONING_MODEL_NAME = None
69
+
70
  logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
71
+
72
+ for model_config in CONFIG.MODELS:
73
+ logger.info(f"Load Model - {model_config.SERVICE_NAME}")
74
+
75
+ if model_config.MODEL_FILE_PATH == None:
76
+ model_config.MODEL_FILE_PATH = hf_hub_download(
77
+ repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
78
+ filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
79
+ local_dir=model_config.DOWNLOAD_MODEL_DIR,
80
+ )
81
+ logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
82
+
83
+ tmp_model = RWKV(
84
+ model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
85
+ strategy=CONFIG.STRATEGY,
86
  )
87
+ tmp_pipeline = PIPELINE(tmp_model, model_config.VOCAB)
88
 
89
+ if model_config.DEFAULT:
90
+ if model_config.REASONING:
91
+ DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
92
+ else:
93
+ DEFAULT_MODEL_NAME = model_config.SERVICE_NAME
94
+
95
+ MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
96
+ MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
97
+ MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
98
+ MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = tmp_pipeline
99
+
100
+
101
+ logger.info(f"DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
102
+ logger.info(f"DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`")
103
 
104
 
105
  class ChatCompletionRequest(BaseModel):
 
107
  default="rwkv-latest",
108
  description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
109
  )
110
+ messages: Optional[List[ChatMessage]] = Field(default=None)
111
  prompt: Optional[str] = Field(default=None)
112
+ max_tokens: Optional[int] = Field(default=None)
113
+ temperature: Optional[float] = Field(default=None)
114
+ top_p: Optional[float] = Field(default=None)
115
+ presence_penalty: Optional[float] = Field(default=None)
116
+ count_penalty: Optional[float] = Field(default=None)
117
+ penalty_decay: Optional[float] = Field(default=None)
118
+ stream: Optional[bool] = Field(default=False)
119
+ state_name: Optional[str] = Field(default=None)
120
+ include_usage: Optional[bool] = Field(default=False)
121
+ stop: Optional[list[str]] = Field(["\n\n"])
122
+
123
+ @model_validator(mode="before")
124
+ @classmethod
125
+ def validate_mutual_exclusivity(cls, data: Any) -> Any:
126
+ if not isinstance(data, dict):
127
+ return data
128
+
129
+ messages_provided = "messages" in data and data["messages"] != None
130
+ prompt_provided = "prompt" in data and data["prompt"] != None
131
+
132
+ if messages_provided and prompt_provided:
133
+ raise ValueError("messages and prompt cannot coexist. Choose one.")
134
+ if not messages_provided and not prompt_provided:
135
+ raise ValueError("Either messages or prompt must be provided.")
136
+ return data
137
 
138
 
139
  app = FastAPI(title="RWKV OpenAI-Compatible API")
 
147
  )
148
 
149
 
150
+ async def runPrefill(
151
+ request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
152
+ ):
153
  ctx = ctx.replace("\r\n", "\n")
154
 
155
+ tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
156
  tokens = [int(x) for x in tokens]
157
  model_tokens += tokens
158
 
159
  while len(tokens) > 0:
160
+ out, model_state = MODEL_STORAGE[request.model].model.forward(
161
+ tokens[: CONFIG.CHUNK_LEN], model_state
162
+ )
163
  tokens = tokens[CONFIG.CHUNK_LEN :]
164
  await asyncio.sleep(0)
165
 
 
177
  args = PIPELINE_ARGS(
178
  temperature=max(0.2, request.temperature),
179
  top_p=request.top_p,
180
+ alpha_frequency=request.count_penalty,
181
+ alpha_presence=request.presence_penalty,
182
  token_ban=[], # ban the generation of some tokens
183
  token_stop=[0],
184
  ) # stop generation whenever you see any token here
 
194
  out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
195
  out[0] -= 1e10 # disable END_OF_TEXT
196
 
197
+ token = MODEL_STORAGE[request.model].pipeline.sample_logits(
198
  out, temperature=args.temperature, top_p=args.top_p
199
  )
200
 
201
+ out, model_state = MODEL_STORAGE[request.model].model.forward(
202
+ [token], model_state
203
+ )
204
  model_tokens += [token]
205
 
206
  out_tokens += [token]
207
 
208
  for xxx in occurrence:
209
+ occurrence[xxx] *= request.penalty_decay
210
  occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
211
 
212
+ tmp: str = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
213
 
214
  if "\ufffd" in tmp:
215
  continue
 
248
 
249
 
250
  async def chatResponse(
251
+ request: ChatCompletionRequest,
252
+ model_state: any,
253
+ completionId: str,
254
+ enableReasoning: bool,
255
  ) -> ChatCompletion:
256
  createTimestamp = time.time()
257
 
 
 
258
  prompt = (
259
  f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
260
  if request.prompt == None
261
  else request.prompt.strip()
262
  )
263
 
264
+ out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
265
 
266
  prefillTime = time.time()
267
  promptTokenCount = len(model_tokens)
 
330
 
331
 
332
  async def chatResponseStream(
333
+ request: ChatCompletionRequest,
334
+ model_state: any,
335
+ completionId: str,
336
+ enableReasoning: bool,
337
  ):
338
  createTimestamp = int(time.time())
339
 
 
 
340
  prompt = (
341
  f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
342
  if request.prompt == None
343
  else request.prompt.strip()
344
  )
345
 
346
+ out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
347
 
348
  prefillTime = time.time()
349
  promptTokenCount = len(model_tokens)
 
383
  buffer = []
384
 
385
  if enableReasoning:
386
+ buffer.append("<think")
387
 
388
  streamConfig = {
389
  "isChecking": False,
 
572
  completionId = str(next(CompletionIdGenerator))
573
  logger.info(f"[REQ] {completionId} - {request.model_dump()}")
574
 
575
+ modelName = request.model.split(":")[0]
576
+ enableReasoning = ":thinking" in request.model
577
+
578
+ if "rwkv-latest" in request.model:
579
+ if enableReasoning:
580
+ if DEFAULT_REASONING_MODEL_NAME == None:
581
+ raise HTTPException(404, "DEFAULT_REASONING_MODEL_NAME not set")
582
+ defaultSamplerConfig = MODEL_STORAGE[
583
+ DEFAULT_REASONING_MODEL_NAME
584
+ ].MODEL_CONFIG.DEFAULT_SAMPLER
585
+ request.model = DEFAULT_REASONING_MODEL_NAME
586
+
587
+ else:
588
+ if DEFAULT_MODEL_NAME == None:
589
+ raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
590
+ defaultSamplerConfig = MODEL_STORAGE[
591
+ DEFAULT_MODEL_NAME
592
+ ].MODEL_CONFIG.DEFAULT_SAMPLER
593
+ request.model = DEFAULT_MODEL_NAME
594
+
595
+ elif modelName in MODEL_STORAGE:
596
+ defaultSamplerConfig = MODEL_STORAGE[modelName].MODEL_CONFIG.DEFAULT_SAMPLER
597
+ request.model = modelName
598
+ else:
599
+ raise f"Can not find `{modelName}`"
600
+
601
  async def chatResponseStreamDisconnect():
602
  if "cuda" in CONFIG.STRATEGY:
603
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
 
606
  )
607
 
608
  model_state = None
609
+ request_dict = request.model_dump()
610
+
611
+ for k, v in defaultSamplerConfig.model_dump().items():
612
+ if request_dict[k] == None:
613
+ request_dict[k] = v
614
+ realRequest = ChatCompletionRequest(**request_dict)
615
+
616
+ logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}")
617
 
618
  if request.stream:
619
  r = StreamingResponse(
620
+ chatResponseStream(realRequest, model_state, completionId, enableReasoning),
621
  media_type="text/event-stream",
622
  background=chatResponseStreamDisconnect,
623
  )
624
  else:
625
+ r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
626
 
627
  return r
628
 
629
+ app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
630
 
631
  if __name__ == "__main__":
632
  import uvicorn
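
With the per-model routing added above, a client can either pin one of the configured `SERVICE_NAME`s or keep the `rwkv-latest` / `rwkv-latest:thinking` aliases and let the server resolve the default (reasoning) model, with unset sampler fields filled from that model's `DEFAULT_SAMPLER`. A minimal sketch using the `openai` client, assuming the server is reachable at the default host/port from config.py and still serves the OpenAI-compatible route under `/api/v1` as the removed openai_test.py did:

```python
from openai import OpenAI

# Assumed endpoint: default HOST/PORT from config.py (the Docker image listens on 7860 instead).
client = OpenAI(base_url="http://127.0.0.1:8000/api/v1", api_key="sk-test")

# Alias routing: ":thinking" selects the default reasoning model (REASONING: True, DEFAULT: True).
thinking = client.chat.completions.create(
    model="rwkv-latest:thinking",
    messages=[{"role": "User", "content": "How many planets are there in our solar system?"}],
)
print(thinking.choices[0].message.content)

# Explicit routing: a SERVICE_NAME from config.local.yaml; fields left unset
# fall back to that model's DEFAULT_SAMPLER.
pinned = client.chat.completions.create(
    model="RWKV-x070-World-0.1B-v2.8-20241210-ctx4096",
    messages=[{"role": "User", "content": "Tell a short story about a gray cat."}],
    temperature=0.7,
)
print(pinned.choices[0].message.content)
```
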
config.py ADDED
@@ -0,0 +1,82 @@
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Optional, Union, Any
4
+
5
+ import sys
6
+
7
+
8
+ from pydantic_settings import BaseSettings
9
+
10
+
11
+ class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
12
+ CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")
13
+
14
+
15
+ CLI_CONFIG = CliConfig()
16
+
17
+
18
+ class SamplerConfig(BaseModel):
19
+ """Default sampler configuration for each model."""
20
+
21
+ max_tokens: int = Field(512, description="Maximum number of tokens to generate.")
22
+ temperature: float = Field(1.0, description="Sampling temperature.")
23
+ top_p: float = Field(0.3, description="Top-p sampling threshold.")
24
+ presence_penalty: float = Field(0.5, description="Presence penalty.")
25
+ count_penalty: float = Field(0.5, description="Count penalty.")
26
+ penalty_decay: float = Field(0.996, description="Penalty decay factor.")
27
+ stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
28
+
29
+
30
+ class ModelConfig(BaseModel):
31
+ """Configuration for each individual model."""
32
+
33
+ SERVICE_NAME: str = Field(..., description="Service name of the model.")
34
+
35
+ MODEL_FILE_PATH: Optional[str] = Field(None, description="Model file path.")
36
+
37
+ DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
38
+ None, description="Model name, should end with .pth"
39
+ )
40
+ DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
41
+ None, description="Model repository ID on Hugging Face Hub."
42
+ )
43
+ DOWNLOAD_MODEL_DIR: Optional[str] = Field(
44
+ None, description="Directory to download the model to."
45
+ )
46
+
47
+ REASONING: bool = Field(
48
+ False, description="Whether reasoning is enabled for this model."
49
+ )
50
+
51
+ DEFAULT: bool = Field(False, description="Whether this model is the default model.")
52
+ DEFAULT_SAMPLER: SamplerConfig = Field(
53
+ SamplerConfig(), description="Default sampler configuration for this model."
54
+ )
55
+ VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
56
+
57
+
58
+ class RootConfig(BaseModel):
59
+ """Root configuration for the RWKV service."""
60
+
61
+ HOST: Optional[str] = Field(
62
+ "127.0.0.1", description="Host IP address to bind to."
63
+ ) # HOST and PORT are optional;
64
+ PORT: Optional[int] = Field(
65
+ 8000, description="Port number to listen on."
66
+ ) # they are commented out in the YAML example
67
+ STRATEGY: str = Field(
68
+ "cpu", description="Strategy for model execution (e.g., 'cuda fp16')."
69
+ )
70
+ RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
71
+ CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
72
+ MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
73
+
74
+
75
+ import yaml
76
+
77
+ try:
78
+ with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
79
+ CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
80
+ except Exception as e:
81
+ print(f"Pydantic Model Validation Failed: {e}")
82
+ sys.exit(1)
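
For context on how these defaults are consumed: in app.py, every sampler field the client leaves unset (None) on the request is filled from the selected model's DEFAULT_SAMPLER before generation. A self-contained sketch of that merge, with plain dicts standing in for the pydantic models and values mirroring SamplerConfig's defaults:

```python
# Illustrative sketch of the DEFAULT_SAMPLER fallback performed in app.py:
# fields the client left as None are filled from the model's sampler defaults.
sampler_defaults = {
    "max_tokens": 512,
    "temperature": 1.0,
    "top_p": 0.3,
    "presence_penalty": 0.5,
    "count_penalty": 0.5,
    "penalty_decay": 0.996,
    "stop": ["\n\n"],
}

# A request that only pins temperature; everything else is left unset.
request_dict = {"temperature": 0.7, "max_tokens": None, "top_p": None,
                "presence_penalty": None, "count_penalty": None,
                "penalty_decay": None, "stop": None}

for key, default in sampler_defaults.items():
    if request_dict.get(key) is None:  # only fill fields the client did not set
        request_dict[key] = default

print(request_dict["temperature"], request_dict["max_tokens"])  # 0.7 512
```
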
openai_test.py DELETED
@@ -1,78 +0,0 @@
1
- """
2
- uv pip install openai
3
- """
4
-
5
- import os
6
-
7
- import logging
8
-
9
- # logging.basicConfig(
10
- # level=logging.DEBUG,
11
- # )
12
-
13
- os.environ["NO_PROXY"] = "127.0.0.1"
14
-
15
- from openai import OpenAI
16
-
17
- client = OpenAI(base_url="http://127.0.0.1:8000/api/v1", api_key="sk-test")
18
-
19
-
20
- def completionStreamTest():
21
- print("[*] Stream completion: ")
22
-
23
- completion = client.chat.completions.create(
24
- model="rwkv-latest",
25
- messages=[
26
- {
27
- "role": "User",
28
- "content": "请讲个关于一只灰猫和一个小女孩之间的简短故事。",
29
- },
30
- ],
31
- stream=True,
32
- max_tokens=2048,
33
- )
34
-
35
- isReasoning = False
36
-
37
- for chunk in completion:
38
- if chunk.choices[0].delta.reasoning_content and not isReasoning:
39
- print("<- Reasoning ->")
40
- isReasoning = True
41
- elif chunk.choices[0].delta.content and isReasoning:
42
- isReasoning = False
43
- print("<- Stop Reasoning ->")
44
-
45
- if chunk.choices[0].delta.reasoning_content:
46
- print(chunk.choices[0].delta.reasoning_content, end="", flush=True)
47
- if chunk.choices[0].delta.content:
48
- print(chunk.choices[0].delta.content, end="", flush=True)
49
-
50
- print("")
51
-
52
-
53
- def completionTest():
54
- completion = client.chat.completions.create(
55
- model="rwkv-latest:thinking",
56
- messages=[
57
- {
58
- "role": "User",
59
- "content": "How many planets are there in our solar system?",
60
- },
61
- ],
62
- max_tokens=2048,
63
- )
64
-
65
- print("[*] Completion: ", completion)
66
-
67
-
68
- if __name__ == "__main__":
69
- try:
70
- # completionTest()
71
-
72
- testRounds = input("Test rounds (Default: 10) :")
73
-
74
- for i in range(int(testRounds) if testRounds != "" else 10):
75
- print("\n", "=" * 10, i + 1, "/", testRounds, "=" * 10)
76
- completionStreamTest()
77
- except KeyboardInterrupt:
78
- pass