militarybearz committed
Commit 69928c1 · verified · 1 Parent(s): 30abb01

Update app.py

Files changed (1)
  1. app.py +69 -117
app.py CHANGED
@@ -1,119 +1,71 @@
-import torch
-from fastapi import FastAPI
-from pydantic import BaseModel
-from typing import List, Optional, Union
-# !!! FIX: use the correct class for Seq2Seq models such as T5/FRIDA !!!
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
-
-# --- 1. Configuration and model loading ---
-
-# Model name on Hugging Face
-MODEL_NAME = "ai-forever/FRIDA"
-# Name returned in API responses (can be anything)
-MODEL_ALIAS = "frida-v1"
-
-print("Starting model loading process...")
-
-# Quantization config to save memory (REQUIRED on the free tier)
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.float16
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field
+from typing import List, Optional
+import numpy as np
+
+from sentence_transformers import SentenceTransformer
+
+app = FastAPI(title="FRIDA Embedding API", version="1.0")
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
 )
 
-# Load the tokenizer
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
-
-# !!! FIX: load the model with AutoModelForSeq2SeqLM !!!
-model = AutoModelForSeq2SeqLM.from_pretrained(
-    MODEL_NAME,
-    quantization_config=quantization_config,
-    device_map="auto",  # automatically distributes the model across available resources
-    trust_remote_code=True
-)
-model.eval()  # put the model into inference mode
-
-print("Model and tokenizer loaded successfully.")
-
-# --- 2. Pydantic models that mimic the OpenAI API ---
-
-class ChatMessage(BaseModel):
-    role: str
-    content: str
-
-class ChatCompletionRequest(BaseModel):
-    model: str  # ignored, but required for compatibility
-    messages: List[ChatMessage]
-    temperature: Optional[float] = 0.7
-    max_tokens: Optional[int] = 1024
-
-class ChatCompletionChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: str = "stop"
-
-class UsageInfo(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-class ChatCompletionResponse(BaseModel):
-    id: str = "chatcmpl-mock"
-    object: str = "chat.completion"
-    created: int = 0
-    model: str = MODEL_ALIAS
-    choices: List[ChatCompletionChoice]
-    usage: UsageInfo
-
-# --- 3. Create the FastAPI application ---
-app = FastAPI()
-
-# --- 4. Implement the /v1/chat/completions endpoint ---
-
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
-async def create_chat_completion(request: ChatCompletionRequest):
-    """
-    The main function: accepts a request and generates a response.
-    """
-    print(f"Received request: {request.dict()}")
-
-    # Convert the OpenAI-format messages into a single string for the T5 model
-    prompt_text = "\n".join([f"{msg.role}: {msg.content}" for msg in request.messages])
-
-    print(f"Formatted prompt for FRIDA:\n{prompt_text}")
-
-    # Encode the text into tokens
-    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
-    prompt_tokens_count = inputs["input_ids"].shape[1]
-
-    # Generate the model's response
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=request.max_tokens,
-        temperature=request.temperature,
-        do_sample=True,
-        eos_token_id=tokenizer.eos_token_id,
-        pad_token_id=tokenizer.pad_token_id
-    )
-
-    # Decode the generated tokens back into text
-    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Count the completion tokens
-    completion_tokens_count = outputs[0].shape[0]
-
-    print(f"Generated response: {response_text}")
-
-    # Build the response in OpenAI format
-    response_message = ChatMessage(role="assistant", content=response_text)
-    choice = ChatCompletionChoice(index=0, message=response_message)
-    usage = UsageInfo(
-        prompt_tokens=prompt_tokens_count,
-        completion_tokens=completion_tokens_count,
-        total_tokens=prompt_tokens_count + completion_tokens_count
-    )
-
-    return ChatCompletionResponse(choices=[choice], usage=usage)
-
-@app.get("/")
-def health_check():
-    return {"status": "ok", "model_name": MODEL_ALIAS}
+MODEL_NAME = "ai-forever/FRIDA"
+model = SentenceTransformer(MODEL_NAME)
+EMBED_DIM = model.get_sentence_embedding_dimension()
+
+SUPPORTED_PROMPTS = [
+    "search_query",
+    "search_document",
+    "paraphrase",
+    "categorize",
+    "categorize_sentiment",
+    "categorize_topic",
+    "categorize_entailment",
+]
+
+class EmbedRequest(BaseModel):
+    texts: List[str] = Field(..., description="List of texts")
+    prompt_name: Optional[str] = Field("search_document", description="FRIDA prompt_name")
+
+class EmbedResponse(BaseModel):
+    embeddings: List[List[float]]
+    dim: int
+
+@app.get("/health")
+def health():
+    return {"status": "ok"}
+
+@app.get("/metadata")
+def metadata():
+    return {
+        "model": MODEL_NAME,
+        "embedding_dim": EMBED_DIM,
+        "pooling": "cls",
+        "prompts_supported": SUPPORTED_PROMPTS,
+    }
+
+@app.post("/embed", response_model=EmbedResponse)
+def embed(req: EmbedRequest):
+    if not req.texts:
+        raise HTTPException(status_code=400, detail="texts must be non-empty")
+    prompt = req.prompt_name or "search_document"
+    if prompt not in SUPPORTED_PROMPTS:
+        raise HTTPException(status_code=400, detail=f"Unsupported prompt_name: {prompt}")
+
+    vectors = model.encode(
+        req.texts,
+        convert_to_numpy=True,
+        prompt_name=prompt,
+        normalize_embeddings=True,
+        batch_size=min(16, max(1, len(req.texts))),
+        show_progress_bar=False,
+    ).astype(np.float32)
+
+    return {"embeddings": vectors.tolist(), "dim": int(vectors.shape[1])}