"""
NER-Small Inference Client

A Python client for running inference with the Minibase-NER-Small model.
Handles named entity recognition requests to the local llama.cpp server.
"""

import re
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

import requests
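
# The client assumes a llama.cpp HTTP server is already listening at base_url.
# A typical launch looks something like this (binary name, model file, and
# port are illustrative; adjust them to your build and checkpoint):
#
#   ./llama-server -m Minibase-NER-Small.gguf --host 127.0.0.1 --port 8000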


class NERClient:
    """
    Client for the NER-Small named entity recognition model.

    This client communicates with a local llama.cpp server running the
    Minibase-NER-Small model for named entity recognition tasks.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000", timeout: int = 30):
        """
        Initialize the NER client.

        Args:
            base_url: Base URL of the llama.cpp server
            timeout: Request timeout in seconds
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.default_instruction = (
            "Extract all named entities from the following text. "
            "List them as 1. Entity, 2. Entity, etc."
        )

    def _make_request(self, prompt: str, max_tokens: int = 512,
                      temperature: float = 0.1) -> Tuple[str, float]:
        """
        Make a completion request to the model.

        Args:
            prompt: The input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Tuple of (response_text, latency_ms)
        """
        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature
        }

        headers = {'Content-Type': 'application/json'}

        start_time = time.time()
        try:
            response = requests.post(
                f"{self.base_url}/completion",
                json=payload,
                headers=headers,
                timeout=self.timeout
            )

            latency = (time.time() - start_time) * 1000
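
            # The llama.cpp server replies with a JSON object whose generated
            # text is expected under the 'content' key. An abridged,
            # illustrative payload (not captured from a real run):
            #
            #   {"content": "1. Alice\n2. Paris", "stop": true, ...}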
            if response.status_code == 200:
                result = response.json()
                return result.get('content', ''), latency
            else:
                return f"Error: HTTP {response.status_code}", latency

        except requests.exceptions.RequestException as e:
            latency = (time.time() - start_time) * 1000
            return f"Error: {e}", latency

    def extract_entities(self, text: str, instruction: Optional[str] = None,
                         max_tokens: int = 512,
                         temperature: float = 0.1) -> List[Dict[str, Any]]:
        """
        Extract named entities from text.

        Args:
            text: Input text to analyze
            instruction: Custom instruction (uses default if None)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            List of entity dictionaries with text and metadata
            (empty if the request fails)
        """
        if instruction is None:
            instruction = self.default_instruction

        prompt = f"{instruction}\n\nInput: {text}\n\nResponse: "
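
        # With the default instruction the rendered prompt looks like this
        # (the input sentence is illustrative):
        #
        #   Extract all named entities from the following text. List them as
        #   1. Entity, 2. Entity, etc.
        #
        #   Input: Barack Obama visited Paris.
        #
        #   Response: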

        response_text, latency = self._make_request(prompt, max_tokens, temperature)

        # _make_request reports failures as an "Error..." sentinel string.
        if response_text.startswith("Error"):
            return []

        entities = self._parse_entity_response(response_text)

        # Attach request metadata to every parsed entity.
        for entity in entities:
            entity.update({
                'confidence': 1.0,
                'latency_ms': latency
            })

        return entities
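
    # A sketch of a typical call, assuming a running server; the return value
    # and the latency figure are illustrative, not a recorded run:
    #
    #   client = NERClient()
    #   client.extract_entities("Barack Obama visited Paris.")
    #   # -> [{'text': 'Barack Obama', 'type': 'ENTITY', 'start': 0, 'end': 0,
    #   #      'confidence': 1.0, 'latency_ms': 87.3}, ...]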

    def extract_entities_batch(self, texts: List[str], instruction: Optional[str] = None,
                               max_tokens: int = 512,
                               temperature: float = 0.1) -> List[List[Dict[str, Any]]]:
        """
        Extract named entities from multiple texts.

        Texts are processed sequentially, with one request per input.

        Args:
            texts: List of input texts to analyze
            instruction: Custom instruction (uses default if None)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            List of entity lists, one per input text
        """
        results = []
        for text in texts:
            entities = self.extract_entities(text, instruction, max_tokens, temperature)
            results.append(entities)

        return results

    def _parse_entity_response(self, response_text: str) -> List[Dict[str, Any]]:
        """
        Parse the model's numbered list response into structured entities.

        Args:
            response_text: Raw model response

        Returns:
            List of entity dictionaries
        """
        entities = []

        response_text = response_text.strip()
        lines = response_text.split('\n')

        for line in lines:
            line = line.strip()
            if not line:
                continue
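
            # Lines this regex is meant to accept (examples are illustrative):
            #   "1. John Smith"       -> entity text "John Smith"
            #   "2. Acme Corp - TYPE" -> "Acme Corp" (the optional trailing
            #                            " - ..." annotation is dropped)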
            numbered_match = re.match(r'^\d+\.\s*(.+?)(?:\s*-\s*.+)?$', line)
            if numbered_match:
                entity_text = numbered_match.group(1).strip()
                # Strip trailing punctuation the model sometimes appends.
                entity_text = re.sub(r'[.,;:!?]$', '', entity_text).strip()

                # Skip one-character fragments and common function words.
                if (entity_text and len(entity_text) > 1
                        and entity_text.lower() not in ['the', 'and', 'or', 'but', 'for', 'with']):
                    entities.append({
                        'text': entity_text,
                        'type': 'ENTITY',
                        'start': 0,  # character offsets are not recoverable
                        'end': 0     # from the model's numbered-list output
                    })

        return entities

    def health_check(self) -> bool:
        """
        Check if the model server is healthy and responding.

        Returns:
            True if server is healthy, False otherwise
        """
        try:
            response = requests.get(f"{self.base_url}/health", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False
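
    # Note: llama.cpp's /health endpoint typically answers 200 once the model
    # has finished loading (and an error status while it is still loading), so
    # only a 200 is treated as healthy here.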

    def get_model_info(self) -> Optional[Dict[str, Any]]:
        """
        Get information about the loaded model.

        Returns:
            Model information dictionary or None if unavailable
        """
        try:
            response = requests.get(f"{self.base_url}/v1/models", timeout=5)
            if response.status_code == 200:
                return response.json()
        except requests.exceptions.RequestException:
            pass
        return None


def main():
    """
    Command-line interface for NER inference.
    """
    import argparse

    parser = argparse.ArgumentParser(description='NER-Small Inference Client')
    parser.add_argument('text', help='Text to analyze for named entities')
    parser.add_argument('--url', default='http://127.0.0.1:8000',
                        help='Model server URL (default: http://127.0.0.1:8000)')
    parser.add_argument('--max-tokens', type=int, default=512,
                        help='Maximum tokens to generate (default: 512)')
    parser.add_argument('--temperature', type=float, default=0.1,
                        help='Sampling temperature (default: 0.1)')

    args = parser.parse_args()

    client = NERClient(args.url)

    if not client.health_check():
        print(f"❌ Error: Cannot connect to model server at {args.url}")
        print("Make sure the llama.cpp server is running with the NER-Small model.")
        return 1

    entities = client.extract_entities(
        args.text,
        max_tokens=args.max_tokens,
        temperature=args.temperature
    )

    print(f"📝 Input Text: {args.text}")
    print(f"🎯 Found {len(entities)} entities:")
    print()

    if entities:
        for i, entity in enumerate(entities, 1):
            print(f"{i}. {entity['text']} (Type: {entity['type']})")
    else:
        print("No entities found.")

    return 0


if __name__ == "__main__":
    sys.exit(main())
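
# Example invocation (the module filename is whatever this file is saved as;
# the sentence is illustrative):
#
#   python ner_client.py "Barack Obama visited Paris." --max-tokens 256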