Spaces: Running on T4

sparkleman committed · adb6ad5
1 Parent(s): ff3952a

UPDATE: Add frontend

Files changed:
- .gitignore +4 -1
- Dockerfile +53 -2
- README.md +1 -1
- app.py +136 -61
- config.py +82 -0
- openai_test.py +0 -78
.gitignore CHANGED
@@ -13,4 +13,7 @@ wheels/
 
 *pth
 *.pt
-*.st
+*.st
+*local*
+
+dist-frontend/
Dockerfile CHANGED
@@ -9,12 +9,23 @@ apt install --no-install-recommends -y \
 apt clean && rm -rf /var/lib/apt/lists/*
 EOF
 
+# Install Node.js and npm
+RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
+    apt-get install -y nodejs
+
+# Install pnpm
+RUN npm install -g pnpm
+
+# Clone the frontend repository and build it
+RUN git clone https://github.com/SolomonLeon/web-rwkv-realweb.git /frontend
+WORKDIR /frontend
+RUN pnpm install && pnpm run build
+
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 
 COPY . .
 
 RUN useradd -m -u 1000 user
-# Switch to the "user" user
 USER user
 
 ENV HOME=/home/user \
@@ -23,7 +34,47 @@ ENV HOME=/home/user \
 WORKDIR $HOME/app
 
 COPY --chown=user . $HOME/app
+COPY --chown=user /frontend/dist $HOME/app/dist-frontend
+
+RUN cat > $HOME/app/config.local.yaml<<EOF
+HOST: "0.0.0.0"
+PORT: 7860
+STRATEGY: "cuda fp16"
+RWKV_CUDA_ON: False
+CHUNK_LEN: 256
+MODELS:
+  - SERVICE_NAME: "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096"
+    DOWNLOAD_MODEL_FILE_NAME: "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
+    DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv-7-world"
+    DOWNLOAD_MODEL_DIR: "./"
+    REASONING: False
+    DEFAULT: True
+    DEFAULT_SAMPLER:
+      max_tokens: 512
+      temperature: 1.0
+      top_p: 0.3
+      presence_penalty: 0.5
+      count_penalty: 0.5
+      penalty_decay: 0.996
+      stop:
+        - "\n\n"
+  - SERVICE_NAME: "RWKV7-G1-0.1B-68%trained-20250303-ctx4k"
+    DOWNLOAD_MODEL_FILE_NAME: "RWKV7-G1-0.1B-68%trained-20250303-ctx4k.pth"
+    DOWNLOAD_MODEL_REPO_ID: "BlinkDL/temp-latest-training-models"
+    DOWNLOAD_MODEL_DIR: "./"
+    REASONING: True
+    DEFAULT: True
+    DEFAULT_SAMPLER:
+      max_tokens: 4096
+      temperature: 1.0
+      top_p: 0.3
+      presence_penalty: 0.5
+      count_penalty: 0.5
+      penalty_decay: 0.996
+      stop:
+        - "\n\n"
+EOF
 
 RUN uv sync --frozen --extra cu124
 
-CMD ["uv","run","app.py",
+CMD ["uv","run","app.py",]
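The heredoc above bakes a `config.local.yaml` into the image, which the new `config.py` reads at startup. A minimal sketch for sanity-checking such a file before a build, assuming PyYAML is installed and the YAML from the heredoc has been saved locally as `config.local.yaml` (a hypothetical local copy, not part of this commit):

```python
# Sketch: sanity-check the generated config.local.yaml before building the image.
# Assumes PyYAML is installed and the heredoc YAML was saved as ./config.local.yaml.
import yaml

with open("config.local.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

assert cfg["PORT"] == 7860, "Spaces expects the app on port 7860"
assert any(m["DEFAULT"] for m in cfg["MODELS"]), "at least one DEFAULT model is required"
for m in cfg["MODELS"]:
    print(m["SERVICE_NAME"], "reasoning:", m.get("REASONING", False))
```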
README.md CHANGED
@@ -25,7 +25,7 @@ python app.py --strategy "cuda fp16" --model_title "RWKV-x070-World-0.1B-v2.8-20
 python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.1B-68%trained-20250303-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
 ```
 
-`RWKV7-G1-0.
+`RWKV7-G1-0.4B-68%trained-20250303-ctx4k`
 
 ```shell
 python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.4B-32%trained-20250304-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
app.py CHANGED
@@ -1,3 +1,5 @@
+from config import CONFIG, ModelConfig
+
 import os, copy, types, gc, sys, re, time, collections, asyncio
 from huggingface_hub import hf_hub_download
 from loguru import logger
@@ -6,32 +8,11 @@ from snowflake import SnowflakeGenerator
 
 CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
 
-from typing import List, Optional, Union
-from pydantic import BaseModel, Field
+from typing import List, Optional, Union, Any, Dict
+from pydantic import BaseModel, Field, model_validator
 from pydantic_settings import BaseSettings
 
 
-class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
-    HOST: str = Field("127.0.0.1", description="Host")
-    PORT: int = Field(8000, description="Port")
-    DEBUG: bool = Field(False, description="Debug mode")
-    STRATEGY: str = Field("cpu", description="Stratergy")
-    MODEL_TITLE: str = Field("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096")
-    DOWNLOAD_REPO_ID: str = Field("BlinkDL/rwkv-7-world")
-    DOWNLOAD_MODEL_DIR: Union[str, None] = Field(None, description="Model Download Dir")
-    MODEL_FILE_PATH: Union[str, None] = Field(None, description="Model Path")
-    GEN_penalty_decay: float = Field(0.996, description="Default penalty decay")
-    CHUNK_LEN: int = Field(
-        256,
-        description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
-    )
-    VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
-    RWKV_CUDA_ON:bool = Field(False, description="`True` to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!")
-
-
-CONFIG = Config()
-
-
 import numpy as np
 import torch
 
@@ -58,9 +39,10 @@ os.environ["RWKV_CUDA_ON"] = (
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
 
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
 
 from api_types import (
     ChatMessage,
@@ -74,17 +56,50 @@ from api_types import (
 from utils import cleanMessages, parse_think_response
 
 
+class ModelStorage:
+    MODEL_CONFIG: Optional[ModelConfig] = None
+    model: Optional[RWKV] = None
+    pipeline: Optional[PIPELINE] = None
+
+
+MODEL_STORAGE: Dict[str, ModelStorage] = {}
+
+DEFALUT_MODEL_NAME = None
+DEFAULT_REASONING_MODEL_NAME = None
+
 logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
-
-
-
-
-
+
+for model_config in CONFIG.MODELS:
+    logger.info(f"Load Model - {model_config.SERVICE_NAME}")
+
+    if model_config.MODEL_FILE_PATH == None:
+        model_config.MODEL_FILE_PATH = hf_hub_download(
+            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
+            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
+            local_dir=model_config.DOWNLOAD_MODEL_DIR,
+        )
+    logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
+
+    tmp_model = RWKV(
+        model=model_config.DOWNLOAD_MODEL_FILE_NAME.replace(".pth", ""),
+        strategy=CONFIG.STRATEGY,
     )
+    tmp_pipeline = PIPELINE(tmp_model, model_config.VOCAB)
 
-
-
-
+    if model_config.DEFAULT:
+        if model_config.REASONING:
+            DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
+        else:
+            DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
+
+    MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
+    MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
+    MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
+    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = tmp_pipeline
+
+
+logger.info(f"DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
+logger.info(f"DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`")
 
 
 class ChatCompletionRequest(BaseModel):
@@ -92,16 +107,33 @@ class ChatCompletionRequest(BaseModel):
         default="rwkv-latest",
         description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
     )
-    messages: List[ChatMessage]
+    messages: Optional[List[ChatMessage]] = Field(default=None)
     prompt: Optional[str] = Field(default=None)
-    max_tokens: int = Field(default=
-    temperature: float = Field(default=
-    top_p: float = Field(default=
-
-
-
-
-
+    max_tokens: Optional[int] = Field(default=None)
+    temperature: Optional[float] = Field(default=None)
+    top_p: Optional[float] = Field(default=None)
+    presence_penalty: Optional[float] = Field(default=None)
+    count_penalty: Optional[float] = Field(default=None)
+    penalty_decay: Optional[float] = Field(default=None)
+    stream: Optional[bool] = Field(default=False)
+    state_name: Optional[str] = Field(default=None)
+    include_usage: Optional[bool] = Field(default=False)
+    stop: Optional[list[str]] = Field(["\n\n"])
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_mutual_exclusivity(cls, data: Any) -> Any:
+        if not isinstance(data, dict):
+            return data
+
+        messages_provided = "messages" in data and data["messages"] != None
+        prompt_provided = "prompt" in data and data["prompt"] != None
+
+        if messages_provided and prompt_provided:
+            raise ValueError("messages and prompt cannot coexist. Choose one.")
+        if not messages_provided and not prompt_provided:
+            raise ValueError("Either messages or prompt must be provided.")
+        return data
 
 
 app = FastAPI(title="RWKV OpenAI-Compatible API")
@@ -115,15 +147,19 @@ app.add_middleware(
 )
 
 
-async def runPrefill(
+async def runPrefill(
+    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
+):
     ctx = ctx.replace("\r\n", "\n")
 
-    tokens = pipeline.encode(ctx)
+    tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
     tokens = [int(x) for x in tokens]
     model_tokens += tokens
 
     while len(tokens) > 0:
-        out, model_state = model.forward(
+        out, model_state = MODEL_STORAGE[request.model].model.forward(
+            tokens[: CONFIG.CHUNK_LEN], model_state
+        )
        tokens = tokens[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)
 
@@ -141,8 +177,8 @@ def generate(
     args = PIPELINE_ARGS(
         temperature=max(0.2, request.temperature),
         top_p=request.top_p,
-        alpha_frequency=request.
-        alpha_presence=request.
+        alpha_frequency=request.count_penalty,
+        alpha_presence=request.presence_penalty,
         token_ban=[],  # ban the generation of some tokens
         token_stop=[0],
     )  # stop generation whenever you see any token here
@@ -158,20 +194,22 @@ def generate(
             out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
         out[0] -= 1e10  # disable END_OF_TEXT
 
-        token = pipeline.sample_logits(
+        token = MODEL_STORAGE[request.model].pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )
 
-        out, model_state = model.forward(
+        out, model_state = MODEL_STORAGE[request.model].model.forward(
+            [token], model_state
+        )
         model_tokens += [token]
 
         out_tokens += [token]
 
         for xxx in occurrence:
-            occurrence[xxx] *=
+            occurrence[xxx] *= request.penalty_decay
         occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
 
-        tmp: str = pipeline.decode(out_tokens[out_last:])
+        tmp: str = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
 
         if "\ufffd" in tmp:
             continue
@@ -210,19 +248,20 @@ def generate(
 
 
 async def chatResponse(
-    request: ChatCompletionRequest,
+    request: ChatCompletionRequest,
+    model_state: any,
+    completionId: str,
+    enableReasoning: bool,
 ) -> ChatCompletion:
     createTimestamp = time.time()
 
-    enableReasoning = request.model.endswith(":thinking")
-
     prompt = (
         f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
         if request.prompt == None
         else request.prompt.strip()
     )
 
-    out, model_tokens, model_state = await runPrefill(prompt, [], model_state)
+    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
     prefillTime = time.time()
     promptTokenCount = len(model_tokens)
@@ -291,19 +330,20 @@ async def chatResponse(
 
 
 async def chatResponseStream(
-    request: ChatCompletionRequest,
+    request: ChatCompletionRequest,
+    model_state: any,
+    completionId: str,
+    enableReasoning: bool,
 ):
     createTimestamp = int(time.time())
 
-    enableReasoning = request.model.endswith(":thinking")
-
     prompt = (
         f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
         if request.prompt == None
         else request.prompt.strip()
     )
 
-    out, model_tokens, model_state = await runPrefill(prompt, [], model_state)
+    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
     prefillTime = time.time()
     promptTokenCount = len(model_tokens)
@@ -343,7 +383,7 @@ async def chatResponseStream(
     buffer = []
 
     if enableReasoning:
-        buffer.append("
+        buffer.append("<think")
 
     streamConfig = {
         "isChecking": False,
@@ -532,6 +572,32 @@ async def chat_completions(request: ChatCompletionRequest):
     completionId = str(next(CompletionIdGenerator))
     logger.info(f"[REQ] {completionId} - {request.model_dump()}")
 
+    modelName = request.model.split(":")[0]
+    enableReasoning = ":thinking" in request.model
+
+    if "rwkv-latest" in request.model:
+        if enableReasoning:
+            if DEFAULT_REASONING_MODEL_NAME == None:
+                raise HTTPException(404, "DEFAULT_REASONING_MODEL_NAME not set")
+            defaultSamplerConfig = MODEL_STORAGE[
+                DEFAULT_REASONING_MODEL_NAME
+            ].MODEL_CONFIG.DEFAULT_SAMPLER
+            request.model = DEFAULT_REASONING_MODEL_NAME
+
+        else:
+            if DEFALUT_MODEL_NAME == None:
+                raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
+            defaultSamplerConfig = MODEL_STORAGE[
+                DEFALUT_MODEL_NAME
+            ].MODEL_CONFIG.DEFAULT_SAMPLER
+            request.model = DEFALUT_MODEL_NAME
+
+    elif modelName in MODEL_STORAGE:
+        defaultSamplerConfig = MODEL_STORAGE[modelName].MODEL_CONFIG.DEFAULT_SAMPLER
+        request.model = modelName
+    else:
+        raise f"Can not find `{modelName}`"
+
     async def chatResponseStreamDisconnect():
         if "cuda" in CONFIG.STRATEGY:
             gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
@@ -540,18 +606,27 @@ async def chat_completions(request: ChatCompletionRequest):
     )
 
     model_state = None
+    request_dict = request.model_dump()
+
+    for k, v in defaultSamplerConfig.model_dump().items():
+        if request_dict[k] == None:
+            request_dict[k] = v
+    realRequest = ChatCompletionRequest(**request_dict)
+
+    logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}")
 
     if request.stream:
         r = StreamingResponse(
-            chatResponseStream(
+            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
             media_type="text/event-stream",
             background=chatResponseStreamDisconnect,
         )
     else:
-        r = await chatResponse(
+        r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
 
     return r
 
+
+app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
+
 if __name__ == "__main__":
     import uvicorn
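With these changes, `rwkv-latest` and `rwkv-latest:thinking` resolve to the configured default (reasoning) model, and any sampler fields left unset on the request are filled from that model's `DEFAULT_SAMPLER`. A minimal client sketch, assuming the server is reachable at `http://127.0.0.1:7860/api/v1` (the prefix used by the deleted `openai_test.py`; host and port come from `config.local.yaml`) and that a default reasoning model is configured:

```python
# Sketch: exercise the multi-model routing added above with the openai client.
# Assumes `pip install openai` and a running server at the base_url below.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:7860/api/v1", api_key="sk-test")

# "rwkv-latest:thinking" resolves to DEFAULT_REASONING_MODEL_NAME; omitted sampler
# fields (temperature, top_p, ...) are filled from that model's DEFAULT_SAMPLER.
completion = client.chat.completions.create(
    model="rwkv-latest:thinking",
    messages=[
        {"role": "User", "content": "How many planets are there in our solar system?"},
    ],
    max_tokens=256,
)
print(completion.choices[0].message.content)
```

Passing a configured `SERVICE_NAME` as `model` selects that model directly; unknown names fall through to the error branch above.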
config.py ADDED
@@ -0,0 +1,82 @@
+from pydantic import BaseModel, Field
+from typing import List, Optional
+from typing import List, Optional, Union, Any
+
+import sys
+
+
+from pydantic_settings import BaseSettings
+
+
+class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
+    CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")
+
+
+CLI_CONFIG = CliConfig()
+
+
+class SamplerConfig(BaseModel):
+    """Default sampler configuration for each model."""
+
+    max_tokens: int = Field(512, description="Maximum number of tokens to generate.")
+    temperature: float = Field(1.0, description="Sampling temperature.")
+    top_p: float = Field(0.3, description="Top-p sampling threshold.")
+    presence_penalty: float = Field(0.5, description="Presence penalty.")
+    count_penalty: float = Field(0.5, description="Count penalty.")
+    penalty_decay: float = Field(0.5, description="Penalty decay factor.")
+    stop: List[str] = Field(0.996, description="List of stop sequences.")
+
+
+class ModelConfig(BaseModel):
+    """Configuration for each individual model."""
+
+    SERVICE_NAME: str = Field(..., description="Service name of the model.")
+
+    MODEL_FILE_PATH: Optional[str] = Field(None, description="Model file path.")
+
+    DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
+        None, description="Model name, should end with .pth"
+    )
+    DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
+        None, description="Model repository ID on Hugging Face Hub."
+    )
+    DOWNLOAD_MODEL_DIR: Optional[str] = Field(
+        None, description="Directory to download the model to."
+    )
+
+    REASONING: bool = Field(
+        False, description="Whether reasoning is enabled for this model."
+    )
+
+    DEFAULT: bool = Field(False, description="Whether this model is the default model.")
+    DEFAULT_SAMPLER: SamplerConfig = Field(
+        SamplerConfig(), description="Default sampler configuration for this model."
+    )
+    VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
+
+
+class RootConfig(BaseModel):
+    """Root configuration for the RWKV service."""
+
+    HOST: Optional[str] = Field(
+        "127.0.0.1", description="Host IP address to bind to."
+    )  # HOST and PORT are optional,
+    PORT: Optional[int] = Field(
+        8000, description="Port number to listen on."
+    )  # since they are commented out in the example YAML
+    STRATEGY: str = Field(
+        "cpu", description="Strategy for model execution (e.g., 'cuda fp16')."
+    )
+    RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
+    CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
+    MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
+
+
+import yaml
+
+try:
+    with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
+        CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
+except Exception as e:
+    print(f"Pydantic Model Validation Failed: {e}")
+    sys.exit(0)
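`config.py` reads the config path from `CliConfig` (CLI/env, defaulting to `./config.local.yaml`) and validates the YAML at import time, so `from config import CONFIG` is all `app.py` needs; if the file is missing, the module exits. A minimal sketch of reusing these models outside the app, assuming a `config.local.yaml` like the one written in the Dockerfile sits in the working directory (the `example-model` entry below is illustrative only, not part of this commit):

```python
# Sketch: reuse the pydantic models from config.py directly.
# Note: importing config parses CLI args and loads CONFIG_FILE
# (default ./config.local.yaml) at import time, exiting if it is missing.
from config import CONFIG, RootConfig

print(CONFIG.STRATEGY, CONFIG.CHUNK_LEN)
for m in CONFIG.MODELS:
    print(m.SERVICE_NAME, "default:", m.DEFAULT, "reasoning:", m.REASONING)

# RootConfig can also validate a dict built elsewhere (illustrative values only):
extra = RootConfig.model_validate(
    {"STRATEGY": "cpu", "MODELS": [{"SERVICE_NAME": "example-model", "DEFAULT": True}]}
)
print(extra.MODELS[0].DEFAULT_SAMPLER.max_tokens)  # sampler defaults fill in: 512
```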
openai_test.py DELETED
@@ -1,78 +0,0 @@
-"""
-uv pip install openai
-"""
-
-import os
-
-import logging
-
-# logging.basicConfig(
-#     level=logging.DEBUG,
-# )
-
-os.environ["NO_PROXY"] = "127.0.0.1"
-
-from openai import OpenAI
-
-client = OpenAI(base_url="http://127.0.0.1:8000/api/v1", api_key="sk-test")
-
-
-def completionStreamTest():
-    print("[*] Stream completion: ")
-
-    completion = client.chat.completions.create(
-        model="rwkv-latest",
-        messages=[
-            {
-                "role": "User",
-                "content": "Please tell a short story about a grey cat and a little girl.",
-            },
-        ],
-        stream=True,
-        max_tokens=2048,
-    )
-
-    isReasoning = False
-
-    for chunk in completion:
-        if chunk.choices[0].delta.reasoning_content and not isReasoning:
-            print("<- Reasoning ->")
-            isReasoning = True
-        elif chunk.choices[0].delta.content and isReasoning:
-            isReasoning = False
-            print("<- Stop Reasoning ->")
-
-        if chunk.choices[0].delta.reasoning_content:
-            print(chunk.choices[0].delta.reasoning_content, end="", flush=True)
-        if chunk.choices[0].delta.content:
-            print(chunk.choices[0].delta.content, end="", flush=True)
-
-    print("")
-
-
-def completionTest():
-    completion = client.chat.completions.create(
-        model="rwkv-latest:thinking",
-        messages=[
-            {
-                "role": "User",
-                "content": "How many planets are there in our solar system?",
-            },
-        ],
-        max_tokens=2048,
-    )
-
-    print("[*] Completion: ", completion)
-
-
-if __name__ == "__main__":
-    try:
-        # completionTest()
-
-        testRounds = input("Test rounds (Default: 10) :")
-
-        for i in range(int(testRounds) if testRounds != "" else 10):
-            print("\n", "=" * 10, i + 1, "/", testRounds, "=" * 10)
-            completionStreamTest()
-    except KeyboardInterrupt:
-        pass