Minibase committed on
Commit
84c0271
·
verified ·
1 Parent(s): e322de3

Upload deid_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. deid_inference.py +187 -0
deid_inference.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ DeId-Small Inference Client
4
+
5
+ A Python client for running inference with the Minibase-DeId-Small model.
6
+ Handles text de-identification requests to the local llama.cpp server.
7
+ """
8
+
9
+ import requests
10
+ import json
11
+ from typing import Optional, Dict, Any, Tuple
12
+ import time
13
+
14
+
15
class DeIdClient:
    """
    Client for the DeId-Small de-identification model.

    This client communicates with a local llama.cpp server running the
    Minibase-DeId-Small model for text de-identification tasks.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000", timeout: int = 30):
        """
        Initialize the DeId client.

        Args:
            base_url: Base URL of the llama.cpp server
            timeout: Request timeout in seconds
        """
        # Normalize so endpoint paths can always be appended with a single "/".
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.default_instruction = "De-identify this text by replacing all personal information with placeholders."

    def _make_request(self, prompt: str, max_tokens: int = 256,
                      temperature: float = 0.1) -> Tuple[str, float]:
        """
        Make a completion request to the model.

        Args:
            prompt: The input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Tuple of (response_text, latency_ms). On failure the text is an
            "Error: ..." string rather than a raised exception, so callers
            always get a displayable result.
        """
        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature
        }

        headers = {'Content-Type': 'application/json'}

        start_time = time.time()
        try:
            response = requests.post(
                f"{self.base_url}/completion",
                json=payload,
                headers=headers,
                timeout=self.timeout
            )

            latency = (time.time() - start_time) * 1000  # Convert to ms

            if response.status_code == 200:
                result = response.json()
                # llama.cpp's /completion returns the generated text under "content".
                return result.get('content', ''), latency
            else:
                return f"Error: Server returned status {response.status_code}", latency

        except requests.exceptions.RequestException as e:
            latency = (time.time() - start_time) * 1000
            return f"Error: {e}", latency

    def deidentify_text(self, text: str, instruction: Optional[str] = None,
                        max_tokens: int = 256, temperature: float = 0.1) -> str:
        """
        De-identify a text by removing personal identifiers.

        Args:
            text: The text to de-identify
            instruction: Custom instruction (uses default if None)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (lower = more consistent)

        Returns:
            De-identified text with placeholders
        """
        if instruction is None:
            instruction = self.default_instruction

        # Alpaca-style prompt format expected by the fine-tuned model.
        prompt = f"Instruction: {instruction}\n\nInput: {text}\n\nResponse: "

        response, _ = self._make_request(prompt, max_tokens, temperature)
        return response

    def deidentify_batch(self, texts: list, instruction: Optional[str] = None,
                         max_tokens: int = 256, temperature: float = 0.1) -> list:
        """
        De-identify multiple texts in batch (sequential requests).

        Args:
            texts: List of texts to de-identify
            instruction: Custom instruction for all texts
            max_tokens: Maximum tokens per response
            temperature: Sampling temperature

        Returns:
            List of de-identified texts, in the same order as the input
        """
        return [
            self.deidentify_text(text, instruction, max_tokens, temperature)
            for text in texts
        ]

    def health_check(self) -> bool:
        """
        Check if the model server is healthy and responding.

        Returns:
            True if server is healthy, False otherwise
        """
        try:
            # A minimal 1-token completion doubles as a liveness probe.
            response = requests.post(
                f"{self.base_url}/completion",
                json={"prompt": "Hello", "max_tokens": 1},
                timeout=5
            )
            return response.status_code == 200
        # Narrowed from a bare `except:` — only network/HTTP failures mean
        # "unhealthy"; anything else (e.g. KeyboardInterrupt) propagates.
        except requests.exceptions.RequestException:
            return False

    def get_server_info(self) -> Optional[Dict[str, Any]]:
        """
        Get server information if available.

        Returns:
            Server info dict or None if unavailable
        """
        try:
            response = requests.get(f"{self.base_url}/props", timeout=5)
            if response.status_code == 200:
                return response.json()
        # Narrowed from a bare `except:` — request failures and a non-JSON
        # body both fall through to None; other exceptions propagate.
        except (requests.exceptions.RequestException, ValueError):
            pass
        return None
151
+
152
+
153
def main():
    """Example usage of the DeId client."""
    client = DeIdClient()

    # Bail out early if the local server isn't reachable.
    if not client.health_check():
        print("❌ Error: DeId-Small server not responding. Please start the server first.")
        print("   Run: ./Minibase-personal-id-masking-small.app/Contents/MacOS/run_server")
        return

    print("✅ DeId-Small server is running!")

    # Sample inputs covering several PII categories (names, dates,
    # addresses, phone numbers, salaries, order IDs).
    examples = [
        "Patient John Smith, born 1985-03-15, lives at 123 Main Street, Boston MA.",
        "Dr. Sarah Johnson called from (555) 123-4567 about the appointment.",
        "Employee Jane Doe earns $75,000 annually at TechCorp Inc.",
        "Customer Michael Brown reported issue with Order #CUST-12345."
    ]

    print("\n🔒 De-identification Examples:")
    print("=" * 50)

    for idx, sample in enumerate(examples, 1):
        print(f"\n📝 Example {idx}:")
        print(f"Input: {sample}")
        masked = client.deidentify_text(sample)
        print(f"Output: {masked}")

    print("\n✨ De-identification complete!")
184
+
185
+
186
+ if __name__ == "__main__":
187
+ main()