Roman Isecke e0f4374386
Roman/bugfix conflicting event loop ingest (#3264)
### Description
In use cases where an external system (such as code being run in a
jupyter notebook) already has a running event loop, run the async code
in a dedicated thread pool to not conflict with the existing event loop.

This also has a variety of fixes that were found when putting together a
demo leveraging the elasticsearch destination connector
2024-06-24 18:47:37 +00:00

134 lines
4.9 KiB
Python
Executable File

#!/usr/bin/env python3
from time import sleep, time
from typing import List
import click
from elasticsearch import Elasticsearch
from es_cluster_config import (
CLUSTER_URL,
INDEX_NAME,
PASSWORD,
USER,
)
from unstructured.embed.huggingface import HuggingFaceEmbeddingConfig, HuggingFaceEmbeddingEncoder
def embeddings_for_text(text: str) -> List[float]:
embedding_encoder = HuggingFaceEmbeddingEncoder(config=HuggingFaceEmbeddingConfig())
return embedding_encoder.embed_query(text)
def query(client: Elasticsearch, search_text: str):
# Query the index using the appropriate embedding vector for given query text
search_vector = embeddings_for_text(search_text)
# Constructing the search query
query = {
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embeddings') + 1.0",
"params": {"query_vector": search_vector},
},
}
}
}
return client.search(index=INDEX_NAME, body=query)
def validate_count(client: Elasticsearch, num_elements: int):
print(f"Validating that the count of documents in index {INDEX_NAME} is {num_elements}")
count = int(client.cat.count(index=INDEX_NAME, format="json")[0]["count"])
consistent = False
consistent_count = 1
desired_consistent_count = 5
timeout = 60
sleep_interval = 1
start_time = time()
print(f"initial count returned: {count}")
while not consistent and time() - start_time < timeout:
new_count = int(client.cat.count(index=INDEX_NAME, format="json")[0]["count"])
print(f"latest count returned: {new_count}")
if new_count == count:
consistent_count += 1
else:
count = new_count
consistent_count = 1
sleep(sleep_interval)
if consistent_count >= desired_consistent_count:
consistent = True
if not consistent:
raise TimeoutError(f"failed to get consistent count after {timeout}s")
assert count == num_elements, (
f"Elasticsearch dest check failed: got {count} items in index, "
f"expected {num_elements} items in index."
)
print(f"Elasticsearch destination test was successful with {count} items being uploaded.")
def get_embeddings_len(client: Elasticsearch) -> int:
res = client.search(index=INDEX_NAME, size=1, query={"match_all": {}})
return len(res["hits"]["hits"][0]["_source"]["embeddings"])
def validate_embeddings(client: Elasticsearch, embeddings: list[float]):
# Query the index using the appropriate embedding vector for given query text
# Verify that the top 1 result matches the expected chunk by checking the start text
print("Testing query to the embedded index.")
es_embeddings_len = get_embeddings_len(client=client)
assert len(embeddings) == es_embeddings_len, (
f"length of embeddings ({len(embeddings)}) doesn't "
f"match what exists in Elasticsearch ({es_embeddings_len})"
)
query_string = {
"field": "embeddings",
"query_vector": embeddings,
"k": 10,
"num_candidates": 10,
}
query_response = client.search(index=INDEX_NAME, knn=query_string)
response_found = query_response["hits"]["hits"][0]["_source"]
assert response_found["embeddings"] == embeddings
print("Query to the embedded index was successful and returned the expected result.")
def validate(num_elements: int, embeddings: list[float]):
print(f"Checking contents of index" f"{INDEX_NAME} at {CLUSTER_URL}")
print("Connecting to the Elasticsearch cluster.")
client = Elasticsearch(CLUSTER_URL, basic_auth=(USER, PASSWORD), request_timeout=30)
print(client.info())
validate_count(client=client, num_elements=num_elements)
validate_embeddings(client=client, embeddings=embeddings)
def parse_embeddings(embeddings_str: str) -> list[float]:
if embeddings_str.startswith("["):
embeddings_str = embeddings_str[1:]
if embeddings_str.endswith("]"):
embeddings_str = embeddings_str[:-1]
embeddings_split = embeddings_str.split(",")
embeddings_split = [e.strip() for e in embeddings_split]
return [float(e) for e in embeddings_split]
@click.command()
@click.option(
"--num-elements", type=int, required=True, help="The expected number of elements to exist"
)
@click.option("--embeddings", type=str, required=True, help="List of embeddings to test")
def run_validation(num_elements: int, embeddings: str):
try:
parsed_embeddings = parse_embeddings(embeddings_str=embeddings)
except ValueError as e:
raise TypeError(
f"failed to parse embeddings string into list of float: {embeddings}"
) from e
validate(num_elements=num_elements, embeddings=parsed_embeddings)
if __name__ == "__main__":
run_validation()