# colbert-xm-for-inference-api / test_endpoint.py
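"""Integration tests for the ColBERT-XM custom inference EndpointHandler.

These tests assume the endpoint is already being served locally at the URL
defined below; run them with pytest once it is up.
"""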
import os
import pytest
import requests

URL = "http://localhost:4999/"
HEADERS = {"Content-Type": "application/json"}


def test_returns_200():
    payload = {"inputs": "try me"}
    response = requests.request("POST", URL, json=payload, headers=HEADERS)
    assert response.status_code == 200
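

# A single string input should come back as a one-element list whose only item
# echoes the input and carries a per-token "query_embedding": 32 token vectors
# of 128 floats each, presumably because the ColBERT-style query encoder
# pads/expands the query to a fixed number of query tokens.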
def test_query_returns_expected_result():
    query = "try me"
    payload = {"inputs": query}
    response = requests.request("POST", URL, json=payload, headers=HEADERS)
    response_data = response.json()
    # print(response_data)
    # Check structure and input
    assert isinstance(response_data, list)
    assert len(response_data) == 1
    assert isinstance(response_data[0], dict)
    assert response_data[0].get("input") == query
    # Check the query embedding (actually a list of embeddings, one per token in the query)
    query_embedding = response_data[0].get("query_embedding")
    assert isinstance(query_embedding, list)
    assert len(query_embedding) == 32
    # Check the first token embedding
    first_token_embedding = query_embedding[0]
    assert isinstance(first_token_embedding, list)
    assert len(first_token_embedding) == 128
    assert all(isinstance(value, float) for value in first_token_embedding)
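

# When a list of strings is sent, each chunk in the response carries its own
# per-token "chunk_embedding" plus the corresponding "token_ids" and a
# human-readable "token_list" (added for easier human inspection).
# Shorter chunks are padded to the length of the longest chunk in the batch.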
def test_batch_returns_expected_result():
    chunks = ["try me", "try me again and again and again"]
    length_of_longest_chunk = 11  # Including special tokens and padding
    doc_maxlen = 512
    payload = {"inputs": chunks}
    response = requests.request("POST", URL, json=payload, headers=HEADERS)
    response_data = response.json()
    # Check structure
    assert isinstance(response_data, list)
    assert len(response_data) == len(chunks)
    for i, response_chunk in enumerate(response_data):
        # Check input
        assert response_chunk.get("input") == chunks[i]
        # Check the chunk embedding (actually a list of embeddings, one per token in the chunk)
        chunk_embedding = response_chunk.get("chunk_embedding")
        token_ids = response_chunk.get("token_ids")
        assert isinstance(chunk_embedding, list)
        assert len(chunk_embedding) == len(token_ids)
        assert len(token_ids) == length_of_longest_chunk
        assert len(token_ids) <= doc_maxlen
        # Check the first token embedding
        first_token_embedding = chunk_embedding[0]
        assert len(first_token_embedding) == 128
        assert all(isinstance(value, float) for value in first_token_embedding)
        # Check the token list
        token_list = response_chunk.get("token_list")
        assert len(token_ids) == len(token_list)
        assert all(isinstance(token, str) for token in token_list)
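

# The per-token embeddings returned by the endpoint are intended for ColBERT-style
# late interaction. As a minimal illustrative sketch (not part of the test suite,
# and assuming the vectors are L2-normalized, as ColBERT models typically output),
# a query could be scored against a chunk with MaxSim: for every query token vector,
# take the maximum dot product over all chunk token vectors, then sum over the
# query tokens.
def maxsim_score(query_embedding: list, chunk_embedding: list) -> float:
    """Sum over query tokens of the maximum dot product with any chunk token."""
    score = 0.0
    for query_vector in query_embedding:
        best = max(
            sum(q * d for q, d in zip(query_vector, doc_vector))
            for doc_vector in chunk_embedding
        )
        score += best
    return score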