mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-06-27 02:30:08 +00:00

### Description
In use cases where an external system (such as code running in a Jupyter notebook) already has a running event loop, run the async code in a dedicated thread pool so that it does not conflict with the existing event loop. This change also includes a variety of fixes that were found while putting together a demo using the Elasticsearch destination connector.
134 lines
4.9 KiB
Python
Executable File
134 lines
4.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
from time import sleep, time
|
|
from typing import List
|
|
|
|
import click
|
|
from elasticsearch import Elasticsearch
|
|
from es_cluster_config import (
|
|
CLUSTER_URL,
|
|
INDEX_NAME,
|
|
PASSWORD,
|
|
USER,
|
|
)
|
|
|
|
from unstructured.embed.huggingface import HuggingFaceEmbeddingConfig, HuggingFaceEmbeddingEncoder
|
|
|
|
|
|
def embeddings_for_text(text: str) -> List[float]:
    """Return the embedding vector for ``text``.

    A new ``HuggingFaceEmbeddingEncoder`` built from a default
    ``HuggingFaceEmbeddingConfig`` is created on every call.
    """
    encoder_config = HuggingFaceEmbeddingConfig()
    encoder = HuggingFaceEmbeddingEncoder(config=encoder_config)
    vector = encoder.embed_query(text)
    return vector
|
|
|
|
|
|
def query(client: Elasticsearch, search_text: str):
    """Score every document in ``INDEX_NAME`` by cosine similarity to ``search_text``.

    ``search_text`` is embedded with ``embeddings_for_text``. The resulting
    vector is compared against each document's ``embeddings`` field through a
    ``script_score`` query wrapped around ``match_all``.

    Returns:
        The raw response of ``client.search`` for that request.
    """
    query_vector = embeddings_for_text(search_text)

    # Adding 1.0 shifts cosine similarity from [-1, 1] to [0, 2]; the
    # Elasticsearch script_score query does not accept negative scores.
    similarity_script = {
        "source": "cosineSimilarity(params.query_vector, 'embeddings') + 1.0",
        "params": {"query_vector": query_vector},
    }
    request_body = {
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": similarity_script,
            }
        }
    }
    return client.search(index=INDEX_NAME, body=request_body)
|
|
|
|
|
|
def _current_doc_count(client: Elasticsearch) -> int:
    """Return the document count of ``INDEX_NAME`` reported by ``client.cat.count``."""
    return int(client.cat.count(index=INDEX_NAME, format="json")[0]["count"])


def validate_count(
    client: Elasticsearch,
    num_elements: int,
    *,
    timeout: float = 60,
    desired_consistent_count: int = 5,
    sleep_interval: float = 1,
):
    """Wait for the document count of ``INDEX_NAME`` to stabilize, then check it.

    The count is polled every ``sleep_interval`` seconds until
    ``desired_consistent_count`` consecutive readings agree, or until
    ``timeout`` seconds elapse. The initial reading counts toward the run.

    Args:
        client: Elasticsearch client used for the ``cat.count`` calls.
        num_elements: expected number of documents once the count is stable.
        timeout: maximum number of seconds to wait for a stable count.
        desired_consistent_count: consecutive identical readings required
            before the count is considered stable.
        sleep_interval: seconds to sleep between readings.

    Raises:
        TimeoutError: if the count does not stabilize within ``timeout``.
        AssertionError: if the stable count differs from ``num_elements``.
    """
    # Monotonic clock: unlike time.time(), it cannot jump backwards or forwards
    # when the system clock is adjusted, so the timeout is measured reliably.
    from time import monotonic

    print(f"Validating that the count of documents in index {INDEX_NAME} is {num_elements}")
    count = _current_doc_count(client)
    consistent_count = 1
    print(f"initial count returned: {count}")

    # NOTE(review): presumably polled because recently uploaded documents may
    # not all be visible to count queries yet, so one reading can be stale.
    deadline = monotonic() + timeout
    while monotonic() < deadline:
        new_count = _current_doc_count(client)
        print(f"latest count returned: {new_count}")
        if new_count == count:
            consistent_count += 1
        else:
            # The count moved; restart the run of identical readings from this one.
            count = new_count
            consistent_count = 1
        if consistent_count >= desired_consistent_count:
            break
        sleep(sleep_interval)
    else:
        # Loop ended without `break`: the deadline passed before the count settled.
        raise TimeoutError(f"failed to get consistent count after {timeout}s")

    # Explicit raise instead of `assert` so the check still runs under `python -O`.
    if count != num_elements:
        raise AssertionError(
            f"Elasticsearch dest check failed: got {count} items in index, "
            f"expected {num_elements} items in index."
        )
    print(f"Elasticsearch destination test was successful with {count} items being uploaded.")
|
|
|
|
|
|
def get_embeddings_len(client: Elasticsearch) -> int:
    """Return the length of the ``embeddings`` field of one document in ``INDEX_NAME``.

    A single arbitrary document is fetched (``match_all`` with ``size=1``).
    ``IndexError`` propagates if the search returns no hits.
    """
    response = client.search(index=INDEX_NAME, size=1, query={"match_all": {}})
    hits = response["hits"]["hits"]
    first_source = hits[0]["_source"]
    return len(first_source["embeddings"])
|
|
|
|
|
|
def validate_embeddings(client: Elasticsearch, embeddings: list[float]):
    """Check that a kNN search for ``embeddings`` returns a document with exactly those embeddings.

    First, ``embeddings`` must have the same length as the ``embeddings`` field
    of a sampled document in ``INDEX_NAME``. Then a kNN query on that field
    must return, as its top hit, a document whose ``embeddings`` equal the
    input exactly.

    Raises:
        AssertionError: if the lengths differ, the kNN query returns no hits,
            or the top hit's embeddings differ from ``embeddings``.
    """
    print("Testing query to the embedded index.")
    es_embeddings_len = get_embeddings_len(client=client)
    # Explicit raises instead of `assert` so the checks still run under `python -O`.
    if len(embeddings) != es_embeddings_len:
        raise AssertionError(
            f"length of embeddings ({len(embeddings)}) doesn't "
            f"match what exists in Elasticsearch ({es_embeddings_len})"
        )

    knn_query = {
        "field": "embeddings",
        "query_vector": embeddings,
        "k": 10,
        "num_candidates": 10,
    }
    query_response = client.search(index=INDEX_NAME, knn=knn_query)
    hits = query_response["hits"]["hits"]
    if not hits:
        raise AssertionError(f"kNN query against index {INDEX_NAME} returned no hits")

    response_found = hits[0]["_source"]
    if response_found["embeddings"] != embeddings:
        raise AssertionError("top kNN hit's embeddings do not match the query embeddings")
    print("Query to the embedded index was successful and returned the expected result.")
|
|
|
|
|
|
def validate(num_elements: int, embeddings: list[float]):
    """Connect to the Elasticsearch cluster and run all destination checks.

    Runs ``validate_count`` and then ``validate_embeddings`` against
    ``INDEX_NAME``. The client is closed afterwards, even if a check fails.
    """
    # A space was previously missing between "index" and the index name
    # (implicit f-string concatenation); use a single f-string instead.
    print(f"Checking contents of index {INDEX_NAME} at {CLUSTER_URL}")

    print("Connecting to the Elasticsearch cluster.")
    client = Elasticsearch(CLUSTER_URL, basic_auth=(USER, PASSWORD), request_timeout=30)
    try:
        print(client.info())
        validate_count(client=client, num_elements=num_elements)
        validate_embeddings(client=client, embeddings=embeddings)
    finally:
        # Release the client's network resources whether or not the checks pass.
        client.close()
|
|
|
|
|
|
def parse_embeddings(embeddings_str: str) -> list[float]:
    """Parse a comma-separated list of floats, optionally wrapped in ``[...]``.

    Whitespace around the whole string (e.g. a trailing newline from shell
    input) and around each element is ignored. At most one leading ``[`` and
    one trailing ``]`` are removed.

    Raises:
        ValueError: if any element is not a valid float literal. This includes
            an empty element, e.g. from ``"[]"`` or a trailing comma.
    """
    # Strip outer whitespace first: previously "[1, 2]\n" kept its "]" because
    # the string ended with "\n", and float("2]") then failed.
    cleaned = embeddings_str.strip().removeprefix("[").removesuffix("]")
    return [float(part.strip()) for part in cleaned.split(",")]
|
|
|
|
|
|
@click.command()
@click.option(
    "--num-elements", type=int, required=True, help="The expected number of elements to exist"
)
@click.option("--embeddings", type=str, required=True, help="List of embeddings to test")
def run_validation(num_elements: int, embeddings: str):
    # CLI entry point: parse --embeddings, then validate the Elasticsearch index.
    # (Kept as comments rather than a docstring: click would turn a docstring
    # into the command's --help text.)
    try:
        embedding_vector = parse_embeddings(embeddings_str=embeddings)
    except ValueError as parse_error:
        # A malformed --embeddings value is reported as TypeError, chained to
        # the underlying ValueError from float parsing.
        message = f"failed to parse embeddings string into list of float: {embeddings}"
        raise TypeError(message) from parse_error
    else:
        validate(num_elements=num_elements, embeddings=embedding_vector)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: invoke the click command, which takes its options from the command line.
    run_validation()
|