"""
DeId-Small Inference Client

A Python client for running inference with the Minibase-DeId-Small model.
Handles text de-identification requests to the local llama.cpp server.
"""

import time
from typing import Any, Dict, List, Optional, Tuple

import requests


class DeIdClient:
    """
    Client for the DeId-Small de-identification model.

    This client communicates with a local llama.cpp server running the
    Minibase-DeId-Small model for text de-identification tasks.
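
    Example (a minimal usage sketch; assumes a llama.cpp server is already
    listening at the default base_url below):

        client = DeIdClient(base_url="http://127.0.0.1:8000")
        if client.health_check():
            print(client.deidentify_text("Call John Smith at 555-0100."))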
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000", timeout: int = 30):
        """
        Initialize the DeId client.

        Args:
            base_url: Base URL of the llama.cpp server
            timeout: Request timeout in seconds
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.default_instruction = "De-identify this text by replacing all personal information with placeholders."

    def _make_request(self, prompt: str, max_tokens: int = 256,
                      temperature: float = 0.1) -> Tuple[str, float]:
        """
        Make a completion request to the model.

        Args:
            prompt: The input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Tuple of (response_text, latency_ms)
        """
        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature
        }

        headers = {'Content-Type': 'application/json'}

        start_time = time.time()
        try:
            response = requests.post(
                f"{self.base_url}/completion",
                json=payload,
                headers=headers,
                timeout=self.timeout
            )

            latency = (time.time() - start_time) * 1000

            if response.status_code == 200:
                result = response.json()
                return result.get('content', ''), latency
            else:
                return f"Error: Server returned status {response.status_code}", latency

        except requests.exceptions.RequestException as e:
            latency = (time.time() - start_time) * 1000
            return f"Error: {e}", latency

    def deidentify_text(self, text: str, instruction: Optional[str] = None,
                        max_tokens: int = 256, temperature: float = 0.1) -> str:
        """
        De-identify a text by replacing personal identifiers with placeholders.

        Args:
            text: The text to de-identify
            instruction: Custom instruction (uses default if None)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (lower = more consistent)

        Returns:
            De-identified text with placeholders
        """
        if instruction is None:
            instruction = self.default_instruction
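
        # Build the "Instruction: ... / Input: ... / Response:" prompt string
        # that this client sends to the completion endpoint.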
        prompt = f"Instruction: {instruction}\n\nInput: {text}\n\nResponse: "

        response, _ = self._make_request(prompt, max_tokens, temperature)
        return response

    def deidentify_batch(self, texts: List[str], instruction: Optional[str] = None,
                         max_tokens: int = 256, temperature: float = 0.1) -> List[str]:
        """
        De-identify multiple texts in batch.

        Args:
            texts: List of texts to de-identify
            instruction: Custom instruction for all texts
            max_tokens: Maximum tokens per response
            temperature: Sampling temperature

        Returns:
            List of de-identified texts
        """
        results = []
        for text in texts:
            result = self.deidentify_text(text, instruction, max_tokens, temperature)
            results.append(result)
        return results

    def health_check(self) -> bool:
        """
        Check if the model server is healthy and responding.

        Returns:
            True if server is healthy, False otherwise
        """
        try:
            response = requests.post(
                f"{self.base_url}/completion",
                json={"prompt": "Hello", "max_tokens": 1},
                timeout=5
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def get_server_info(self) -> Optional[Dict[str, Any]]:
        """
        Get server information if available.

        Returns:
            Server info dict or None if unavailable
        """
        try:
            response = requests.get(f"{self.base_url}/props", timeout=5)
            if response.status_code == 200:
                return response.json()
        except (requests.exceptions.RequestException, ValueError):
            pass
        return None


def main():
    """Example usage of the DeId client."""
    client = DeIdClient()

    if not client.health_check():
        print("❌ Error: DeId-Small server not responding. Please start the server first.")
        print("   Run: ./Minibase-personal-id-masking-small.app/Contents/MacOS/run_server")
        return

    print("✅ DeId-Small server is running!")

    examples = [
        "Patient John Smith, born 1985-03-15, lives at 123 Main Street, Boston MA.",
        "Dr. Sarah Johnson called from (555) 123-4567 about the appointment.",
        "Employee Jane Doe earns $75,000 annually at TechCorp Inc.",
        "Customer Michael Brown reported issue with Order #CUST-12345."
    ]

    print("\n🔒 De-identification Examples:")
    print("=" * 50)

    for i, text in enumerate(examples, 1):
        print(f"\n📝 Example {i}:")
        print(f"Input: {text}")

        clean_text = client.deidentify_text(text)
        print(f"Output: {clean_text}")

    print("\n✨ De-identification complete!")


if __name__ == "__main__":
    main()