71 lines
2.5 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
import sys
from typing import List
from elasticsearch import Elasticsearch
from es_cluster_config import (
CLUSTER_URL,
INDEX_NAME,
PASSWORD,
USER,
)
from unstructured.embed.huggingface import HuggingFaceEmbeddingConfig, HuggingFaceEmbeddingEncoder
N_ELEMENTS = 1404
def embeddings_for_text(text: str) -> List[float]:
embedding_encoder = HuggingFaceEmbeddingEncoder(config=HuggingFaceEmbeddingConfig())
return embedding_encoder.embed_query(text)
def query(client: Elasticsearch, search_text: str):
# Query the index using the appropriate embedding vector for given query text
search_vector = embeddings_for_text(search_text)
# Constructing the search query
query = {
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embeddings') + 1.0",
"params": {"query_vector": search_vector},
},
}
}
}
return client.search(index=INDEX_NAME, body=query)
if __name__ == "__main__":
print(f"Checking contents of index" f"{INDEX_NAME} at {CLUSTER_URL}")
print("Connecting to the Elasticsearch cluster.")
client = Elasticsearch(CLUSTER_URL, basic_auth=(USER, PASSWORD), request_timeout=30)
print(client.info())
count = int(client.cat.count(index=INDEX_NAME, format="json")[0]["count"])
try:
assert count == N_ELEMENTS
except AssertionError:
sys.exit(
"Elasticsearch dest check failed:"
f"got {count} items in index, expected {N_ELEMENTS} items in index."
)
print(f"Elasticsearch destination test was successful with {count} items being uploaded.")
# Query the index using the appropriate embedding vector for given query text
# Verify that the top 1 result matches the expected chunk by checking the start text
print("Testing query to the embedded index.")
query_text = (
"A gathering of Russian nobility and merchants in historic uniforms, "
"discussing the Emperor's manifesto with a mix of solemn anticipation "
"and everyday concerns, while Pierre, dressed in a tight nobleman's uniform, "
"ponders the French Revolution and social contracts amidst the crowd."
)
query_response = query(client, query_text)
assert query_response["hits"]["hits"][0]["_source"]["text"].startswith("CHAPTER XXII")
print("Query to the embedded index was successful and returned the expected result.")