fix: collection deletion for AstraDB test (#2869)

This PR:
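- Fixes occasional collection deletion failures for AstraDB by moving the
collection deletion into the `cleanup` function that the test script runs
from its `trap ... EXIT` handler. To support this, the Python helper script
is split into click subcommands: `check` verifies the collection contents
and `down` deletes the collection.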

Testing:
- Run the AstraDB ingest destination test.
This commit is contained in:
Ahmet Melek 2024-04-11 00:08:24 +01:00 committed by GitHub
parent 23edc4ad71
commit 6fd29ea77c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 10 deletions

View File

@ -30,6 +30,11 @@ source "$SCRIPT_DIR"/cleanup.sh
function cleanup() {
  # Remove the local scratch directories used by this test run.
  local scratch_dir
  for scratch_dir in "$OUTPUT_DIR" "$WORK_DIR"; do
    cleanup_dir "$scratch_dir"
  done

  # Delete the remote collection through the helper script's `down`
  # subcommand; the connection options are shared group-level options.
  local -a astra_args=(
    --token "$ASTRA_DB_TOKEN"
    --api-endpoint "$ASTRA_DB_ENDPOINT"
    --collection-name "$COLLECTION_NAME"
  )
  python "$SCRIPT_DIR"/python/test-ingest-astra-output.py "${astra_args[@]}" down
}
trap cleanup EXIT
@ -52,4 +57,7 @@ PYTHONPATH=. ./unstructured/ingest/main.py \
--collection-name "$COLLECTION_NAME" \
--embedding-dimension "$EMBEDDING_DIMENSION"
# Verify the ingested documents with the helper script's `check` subcommand.
# The helper is a click group, so a subcommand must be named explicitly;
# invoking it without one would not run any verification.
python "$SCRIPT_DIR"/python/test-ingest-astra-output.py \
  --token "$ASTRA_DB_TOKEN" \
  --api-endpoint "$ASTRA_DB_ENDPOINT" \
  --collection-name "$COLLECTION_NAME" check

View File

@ -1,18 +1,35 @@
#!/usr/bin/env python
from typing import Tuple

import click
from astrapy.db import AstraDB, AstraDBCollection
def get_client(
    token, api_endpoint, collection_name
) -> Tuple[AstraDB, AstraDBCollection]:
    """Create an Astra DB client and a handle to one of its collections.

    This is a plain helper, not a click command: ``cli`` calls it directly
    with positional arguments and unpacks the returned pair.

    Args:
        token: Token passed to ``AstraDB``.
        api_endpoint: API endpoint passed to ``AstraDB``.
        collection_name: Name of the collection to obtain a handle for.

    Returns:
        A ``(astra_db, astra_db_collection)`` tuple: the client and the
        collection handle obtained from it.
    """
    # Initialize our vector db
    astra_db = AstraDB(token=token, api_endpoint=api_endpoint)
    astra_db_collection = astra_db.collection(collection_name)
    return astra_db, astra_db_collection
@click.group(name="astra-ingest")
@click.option("--token", type=str)
@click.option("--api-endpoint", type=str)
@click.option("--collection-name", type=str, default="collection_test")
@click.option("--embedding-dimension", type=int, default=384)
@click.pass_context
def cli(ctx, token: str, api_endpoint: str, collection_name: str, embedding_dimension: int):
    """Connect to Astra DB and share the client with the subcommands.

    The client and the collection handle are stored on the click context
    under ``ctx.obj["db"]`` and ``ctx.obj["collection"]``.
    """
    # NOTE(review): embedding_dimension is not used here; presumably a
    # subcommand reads it from ctx.parent.params -- confirm before removing.
    # ensure that ctx.obj exists and is a dict (in case `cli()` is called
    # by means other than the `if __name__ == "__main__"` block)
    ctx.ensure_object(dict)
    ctx.obj["db"], ctx.obj["collection"] = get_client(token, api_endpoint, collection_name)
@cli.command()
@click.pass_context
def check(ctx):
collection_name = ctx.parent.params["collection_name"]
print(f"Checking contents of Astra DB collection: {collection_name}")
# Initialize our vector db
astra_db = AstraDB(token=token, api_endpoint=api_endpoint)
astra_db_collection = astra_db.collection(collection_name)
astra_db_collection = ctx.obj["collection"]
# Tally up the embeddings
docs_count = astra_db_collection.count_documents()
@ -43,11 +60,16 @@ def run_check(token, api_endpoint, collection_name, embedding_dimension):
assert find_result[0]["content"] == random_text
print("Vector search complete.")
# Clean up the collection
astra_db.delete_collection(collection_name)
print("Table deletion complete")
@cli.command()
@click.pass_context
def down(ctx):
    # Subcommand that deletes the Astra DB collection named by the group's
    # --collection-name option, using the client stored on the context.
    db_client = ctx.obj["db"]
    # Group-level options live on the parent context, not on this command's.
    collection_name = ctx.parent.params["collection_name"]

    print(f"deleting collection: {collection_name}")
    db_client.delete_collection(collection_name)
    print(f"successfully deleted collection: {collection_name}")
if __name__ == "__main__":
    # Entry point: hand control to the click group, which parses the shared
    # options and dispatches to the requested subcommand.
    cli()