Gopal2002 committed
Commit bf97bdc · verified · 1 Parent(s): 4255893

Create app.py

Files changed (1): app.py (+133 −0)
app.py ADDED
@@ -0,0 +1,133 @@
# app.py (FastAPI server to host the Jina Embedding model)

# Must be set before importing Hugging Face libs
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float16
).to(device)
model.eval()


# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"  # "retrieval", "text-matching", "code", etc.
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True  # False → pooled embedding (for queries)


class EmbedResponse(BaseModel):
    embeddings: List[List[float]]  # (num_tokens, hidden_dim) if token-level,
                                   # (1, hidden_dim) if pooled query


class TokenizeRequest(BaseModel):
    text: str


class TokenizeResponse(BaseModel):
    input_ids: List[int]


class DecodeRequest(BaseModel):
    input_ids: List[int]


class DecodeResponse(BaseModel):
    text: str


# -----------------------------
# Embedding Endpoint
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    text = req.text

    # -----------------------------
    # Case 1: Query → directly pooled embedding
    # -----------------------------
    if not req.return_token_embeddings:
        with torch.no_grad():
            emb = model.encode_text(
                texts=[text],
                task=req.task,
                prompt_name=req.prompt_name or "query",
                return_multivector=False,
            )
        return {"embeddings": emb.tolist()}  # shape: (1, hidden_dim)

    # -----------------------------
    # Case 2: Long passages → sliding window token embeddings
    # -----------------------------
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0).to(device)  # (total_tokens,)
    total_tokens = input_ids.size(0)

    max_len = model.config.max_position_embeddings  # e.g., 32k for v4
    stride = 50  # overlap for sliding window
    embeddings = []
    position = 0

    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        window_ids = input_ids[position:end].unsqueeze(0)

        with torch.no_grad():
            outputs = model.encode_text(
                texts=[tokenizer.decode(window_ids[0])],
                task=req.task,
                prompt_name=req.prompt_name or "passage",
                return_multivector=True,
            )

        window_embeds = outputs.squeeze(0).cpu()  # (window_len, hidden_dim)

        # Drop overlapping tokens except in the first window
        if position > 0:
            window_embeds = window_embeds[stride:]

        embeddings.append(window_embeds)

        # Advance window
        position += max_len - stride

    full_embeddings = torch.cat(embeddings, dim=0)  # (total_tokens, hidden_dim)
    return {"embeddings": full_embeddings.tolist()}
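# A worked example of the window arithmetic above (hypothetical numbers):
# with max_len=8, stride=2, total_tokens=18 the loop produces
#   window 1: tokens [0, 8)   → keep all 8
#   window 2: tokens [6, 14)  → drop first 2, keep tokens [8, 14)
#   window 3: tokens [12, 18) → drop first 2, keep tokens [14, 18)
# i.e. 18 token embeddings in total, each token covered exactly once.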


# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    enc = tokenizer(req.text, add_special_tokens=False)
    return {"input_ids": enc["input_ids"]}


# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    decoded = tokenizer.decode(req.input_ids)
    return {"text": decoded}
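
For quick testing, a minimal client sketch (not part of the commit): it assumes the server above is running locally, e.g. via `uvicorn app:app --host 0.0.0.0 --port 8000`, and uses the `requests` library; the base URL and example texts are illustrative assumptions.

# client_example.py — hypothetical client for the server above.
# Assumes the app is served at http://localhost:8000.
import requests

BASE_URL = "http://localhost:8000"  # assumption: local deployment

# Pooled query embedding (return_token_embeddings=False → single vector)
resp = requests.post(
    f"{BASE_URL}/embed",
    json={
        "text": "what is a multivector embedding?",
        "task": "retrieval",
        "return_token_embeddings": False,
    },
)
resp.raise_for_status()
query_vec = resp.json()["embeddings"][0]
print("query embedding dim:", len(query_vec))

# Token-level embeddings for a passage (default return_token_embeddings=True)
resp = requests.post(
    f"{BASE_URL}/embed",
    json={"text": "Jina embeddings v4 supports multivector output.", "task": "retrieval"},
)
resp.raise_for_status()
token_vecs = resp.json()["embeddings"]
print("token embeddings:", len(token_vecs))

# Tokenize / decode roundtrip
ids = requests.post(f"{BASE_URL}/tokenize", json={"text": "hello world"}).json()["input_ids"]
text = requests.post(f"{BASE_URL}/decode", json={"input_ids": ids}).json()["text"]
print(ids, "→", text)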