2024-04-11 00:08:24 +01:00
|
|
|
#!/usr/bin/env python
|
2024-02-23 12:50:50 -08:00
|
|
|
import click
from astrapy.db import AstraDB, AstraDBCollection
|
|
|
|
|
|
|
|
|
2024-04-11 00:08:24 +01:00
|
|
|
def get_client(token, api_endpoint, collection_name) -> "tuple[AstraDB, AstraDBCollection]":
    """Open an Astra DB database handle and a handle to one of its collections.

    Args:
        token: Credential forwarded to ``AstraDB`` as ``token``.
        api_endpoint: Database endpoint forwarded to ``AstraDB`` as ``api_endpoint``.
        collection_name: Name passed to ``AstraDB.collection``.

    Returns:
        A ``(database, collection)`` pair: the ``AstraDB`` instance and the
        object returned by ``database.collection(collection_name)``.
    """
    # Initialize our vector db
    astra_db = AstraDB(token=token, api_endpoint=api_endpoint)
    astra_db_collection = astra_db.collection(collection_name)
    # Callers unpack both handles, so the annotation above is a tuple, not AstraDB.
    return astra_db, astra_db_collection
|
|
|
|
|
|
|
|
|
|
|
|
@click.group(name="astra-ingest")
@click.option("--token", type=str, required=True, help="Astra DB token.")
@click.option("--api-endpoint", type=str, required=True, help="Astra DB API endpoint.")
@click.option("--collection-name", type=str, default="collection_test")
@click.option("--embedding-dimension", type=int, default=384)
@click.pass_context
def cli(ctx, token: str, api_endpoint: str, collection_name: str, embedding_dimension: int):
    """Connect to an Astra DB collection and run a maintenance subcommand.

    The database handle and the collection handle are stored on ``ctx.obj``
    under the ``"db"`` and ``"collection"`` keys for the subcommands.
    """
    # Ensure that ctx.obj exists and is a dict (in case `cli()` is called
    # by means other than the `if __name__ == "__main__"` block).
    ctx.ensure_object(dict)

    # --token and --api-endpoint are required above: without them get_client
    # would hand None to AstraDB and fail with a less helpful error.
    # NOTE(review): --embedding-dimension is accepted but not used by this
    # group; get_client only opens the named collection.
    ctx.obj["db"], ctx.obj["collection"] = get_client(token, api_endpoint, collection_name)
|
|
|
|
|
|
|
|
|
|
|
|
@cli.command()
@click.option(
    "--expected-embeddings",
    type=int,
    default=3,
    show_default=True,
    help="Number of documents the collection is expected to contain.",
)
@click.pass_context
def check(ctx, expected_embeddings: int):
    """Verify the collection's document count and run a self-similarity search.

    Counts the documents in the collection and compares the count with
    --expected-embeddings. It then takes one stored document, searches the
    collection with that document's own "$vector", and checks that the top
    hit has the same "content".

    Raises:
        click.ClickException: if the count does not match, or the top search
            hit does not carry the probed document's content.
    """
    collection_name = ctx.parent.params["collection_name"]
    print(f"Checking contents of Astra DB collection: {collection_name}")

    astra_db_collection = ctx.obj["collection"]

    # Tally up the embeddings
    docs_count = astra_db_collection.count_documents()
    number_of_embeddings = docs_count["status"]["count"]

    # Print the results
    print(
        f"# of embeddings in collection vs expected: {number_of_embeddings}/{expected_embeddings}"
    )

    # Raise instead of `assert`, so the check still runs under `python -O`.
    if number_of_embeddings != expected_embeddings:
        raise click.ClickException(
            f"Number of rows in generated table ({number_of_embeddings}) "
            f"doesn't match expected value: {expected_embeddings}"
        )

    # Nothing to probe; only reachable when --expected-embeddings is 0.
    if number_of_embeddings == 0:
        print("Collection is empty; skipping vector search.")
        return

    # Grab an embedding from the collection and search against itself.
    # The same document should come back as the most similar one.
    # NOTE(review): relies on find_one() including "$vector" in the returned
    # document; confirm the Data API version in use returns it by default.
    probe = astra_db_collection.find_one()["data"]["document"]
    probe_vector = probe["$vector"]
    probe_text = probe["content"]

    # Perform a similarity search
    find_result = astra_db_collection.vector_find(probe_vector, limit=1)

    # The top hit must be the probed document's text.
    if not find_result or find_result[0]["content"] != probe_text:
        raise click.ClickException(
            "Vector search did not return the probed document as the closest match"
        )
    print("Vector search complete.")
|
|
|
|
|
|
|
|
|
2024-04-11 00:08:24 +01:00
|
|
|
@cli.command()
@click.pass_context
def down(ctx):
    # Drop the collection named by the parent group's --collection-name,
    # using the database handle the group stored under ctx.obj["db"].
    # (Comments rather than a docstring: click would turn a docstring into help text.)
    db_handle = ctx.obj["db"]
    target = ctx.parent.params["collection_name"]

    print(f"deleting collection: {target}")
    db_handle.delete_collection(target)
    print(f"successfully deleted collection: {target}")
|
2024-02-23 12:50:50 -08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Calling the group object is click's alias for Command.main(); call it
    # directly so the entry point is explicit.
    cli.main()
|