mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-07-03 15:11:30 +00:00
151 lines
5.3 KiB
Python
151 lines
5.3 KiB
Python
![]() |
#!/usr/bin/env python
|
||
|
import json
|
||
|
import time
|
||
|
|
||
|
import click
|
||
|
from pymongo.mongo_client import MongoClient
|
||
|
from pymongo.operations import SearchIndexModel
|
||
|
|
||
|
|
||
|
def get_client(uri: str) -> MongoClient:
    """Open a MongoDB client and confirm the server answers.

    A ``ping`` admin command is sent right away, so connection or credential
    problems surface here rather than on the first real operation.

    :param uri: MongoDB connection string.
    :return: the connected ``MongoClient``.
    """
    mongo = MongoClient(uri)
    # Round-trip to the server before handing the client out.
    mongo.admin.command("ping")
    print("Successfully connected to MongoDB")
    return mongo
|
||
|
|
||
|
|
||
|
@click.group(name="mongo-ingest")
@click.option("--uri", type=str, required=True)
@click.option("--database", type=str, required=True)
@click.option("--collection", type=str, required=True)
@click.pass_context
def cli(ctx, uri: str, database: str, collection: str):
    """Connect to MongoDB and share the client with the subcommands.

    ``--database`` and ``--collection`` are not used here directly; each
    subcommand reads them back from ``ctx.parent.params``.
    """
    # ensure that ctx.obj exists and is a dict (in case `cli()` is called
    # by means other than the `if` block below)
    ctx.ensure_object(dict)

    client = get_client(uri)
    # Release the connection pool when the command tears down instead of
    # leaving it to interpreter shutdown.
    ctx.call_on_close(client.close)
    ctx.obj["client"] = client
|
||
|
|
||
|
|
||
|
# `up`: create the target collection, then attach a search index named
# "embeddings" that maps the `embeddings` field as a 384-dimension knnVector
# compared with euclidean similarity. Index creation is skipped when an index
# with that name is already listed on the collection.
@cli.command()
@click.pass_context
def up(ctx):
    params = ctx.parent.params
    name = params["collection"]
    database = ctx.obj["client"][params["database"]]

    print(f"creating collection {name}")
    new_collection = database.create_collection(name=name)
    print(f"successfully created collection: {name}")

    existing = {index["name"] for index in new_collection.list_search_indexes()}
    if "embeddings" in existing:
        print("search index already exists, skipping creation")
        return

    vector_field = {"type": "knnVector", "dimensions": 384, "similarity": "euclidean"}
    index_definition = {
        "mappings": {
            "dynamic": True,
            "fields": {"embeddings": [vector_field]},
        },
    }
    created_name = new_collection.create_search_index(
        model=SearchIndexModel(name="embeddings", definition=index_definition)
    )
    print(f"Added search index: {created_name}")
|
||
|
|
||
|
|
||
|
@cli.command()
@click.pass_context
def down(ctx):
    """Drop the target collection if it exists; otherwise report and skip."""
    collection_name = ctx.parent.params["collection"]
    client = ctx.obj["client"]
    db = client[ctx.parent.params["database"]]
    # Fetch the names once: used for both the existence check and the message.
    existing_names = db.list_collection_names()
    if collection_name not in existing_names:
        print(
            "collection name {} does not exist amongst those in database: {}, "
            "skipping deletion".format(collection_name, ", ".join(existing_names))
        )
        return
    print(f"deleting collection: {collection_name}")
    db[collection_name].drop()
    # Report the collection's name (not the Collection object's repr).
    print(f"successfully deleted collection: {collection_name}")
|
||
|
|
||
|
|
||
|
# `check`: verify the target collection holds exactly --expected-records
# documents (counted with an empty filter, i.e. all documents).
@cli.command()
@click.option("--expected-records", type=int, required=True)
@click.pass_context
def check(ctx, expected_records: int):
    parent_params = ctx.parent.params
    database = ctx.obj["client"][parent_params["database"]]
    target = database[parent_params["collection"]]

    found = target.count_documents(filter={})
    print(f"checking the count in the db ({found}) matches what's expected: {expected_records}")
    mismatch_message = (
        f"expected count ({expected_records}) does not match how many records were found: {found}"
    )
    assert found == expected_records, mismatch_message
    print("successfully checked that the expected number of records was found in the db!")
|
||
|
|
||
|
|
||
|
def _search_index_status(collection, index_name: str):
    """Return the ``status`` field of search index ``index_name`` on ``collection``.

    Returns ``None`` if the matching index entry has no ``status`` key.
    Raises ``LookupError`` when no search index with that name is listed.
    """
    for index in collection.list_search_indexes():
        if index["name"] == index_name:
            return index.get("status")
    raise LookupError(f"no search index named {index_name!r} on collection {collection.name}")


@cli.command()
# Required: the embedding to query with is read from this file, so the command
# cannot run without it.
@click.option("--output-json", type=click.File(), required=True)
@click.pass_context
def check_vector(ctx, output_json):
    """
    Checks the functionality of the vector search index by getting a score based on the
    exact result of one of the embeddings. Makes sure that the search index itself has finished
    indexing before running a query, then validates that the first item in the returned sorted
    list has a score of 1.0 given that the exact embedding is used as a match, and all others
    have a score less than 1.0.
    """
    # Get the first embedding from the output file
    json_content = json.load(output_json)
    exact_embedding = json_content[0]["embeddings"]
    client = ctx.obj["client"]
    db = client[ctx.parent.params["database"]]
    collection = db[ctx.parent.params["collection"]]

    vector_index_name = "embeddings"
    max_attempts = 30
    wait_seconds = 5
    attempts = 0
    status = _search_index_status(collection, vector_index_name)
    while status != "READY" and attempts < max_attempts:
        print(
            f"status of search index: {status}, waiting another {wait_seconds} "
            f"seconds for it to be ready"
        )
        attempts += 1
        time.sleep(wait_seconds)
        status = _search_index_status(collection, vector_index_name)
    # Polling ran out: fail instead of querying an index that is not ready.
    if status != "READY":
        raise TimeoutError(
            f"search index {vector_index_name} was not ready after "
            f"{max_attempts * wait_seconds} seconds (last status: {status})"
        )
    print(f"search index is ready to go ({status}), checking vector content")

    pipeline = [
        {
            "$vectorSearch": {
                "index": vector_index_name,
                "path": "embeddings",
                "queryVector": exact_embedding,
                "numCandidates": 150,
                "limit": 10,
            },
        },
        {"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
    ]
    result = list(collection.aggregate(pipeline=pipeline))
    print(f"vector query result: {result}")
    # Guard the result[0] access below with a clear failure message.
    assert result, "vector query returned no results"
    # Querying with an embedding that is stored verbatim should rank that
    # document first with a perfect score.
    assert result[0]["score"] == 1.0, "score detected should be 1: {}".format(result[0]["score"])
    for r in result[1:]:
        assert r["score"] < 1.0, "score detected should be less than 1: {}".format(r["score"])
    print("successfully validated vector content!")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Seed ctx.obj with an explicit empty dict; the `cli` group stores the
    # MongoDB client in it for the subcommands.
    cli(obj={})
|