This commit is contained in:
Gabriel Nieves 2025-03-26 16:25:08 +00:00
parent b7b2b562ce
commit be21d994c0

View File

@ -38,7 +38,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
_database_name: str _database_name: str
_container_name: str _container_name: str
_encoding: str _encoding: str
_no_id_prefixes: list[str] _no_id_prefixes: set[str]
def __init__( def __init__(
self, self,
@ -71,7 +71,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
if cosmosdb_account_url if cosmosdb_account_url
else None else None
) )
self._no_id_prefixes = [] self._no_id_prefixes = set()
log.info( log.info(
"creating cosmosdb storage with account: %s and database: %s and container: %s", "creating cosmosdb storage with account: %s and database: %s and container: %s",
self._cosmosdb_account_name, self._cosmosdb_account_name,
@ -198,6 +198,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
) -> Any: ) -> Any:
"""Fetch all items in a container that match the given key.""" """Fetch all items in a container that match the given key."""
try: try:
print(f"Looking in cosmos-db for key={key}")
if not self._database_client or not self._container_client: if not self._database_client or not self._container_client:
return None return None
if as_bytes: if as_bytes:
@ -223,8 +224,9 @@ class CosmosDBPipelineStorage(PipelineStorage):
return items_df.to_parquet() return items_df.to_parquet()
item = self._container_client.read_item(item=key, partition_key=key) item = self._container_client.read_item(item=key, partition_key=key)
print(f"Located item={item} for key={key}")
item_body = item.get("body") item_body = item.get("body")
return json.dumps(item_body) return item_body if (key == "graph.graphml") else json.dumps(item_body)
except Exception: except Exception:
log.exception("Error reading item %s", key) log.exception("Error reading item %s", key)
return None return None
@ -240,8 +242,10 @@ class CosmosDBPipelineStorage(PipelineStorage):
raise ValueError(msg) # noqa: TRY301 raise ValueError(msg) # noqa: TRY301
# value represents a parquet file # value represents a parquet file
if isinstance(value, bytes): if isinstance(value, bytes):
print(f"Setter function called for key={key}; self._no_id_prefixes={self._no_id_prefixes}")
prefix = self._get_prefix(key) prefix = self._get_prefix(key)
value_df = pd.read_parquet(BytesIO(value)) value_df = pd.read_parquet(BytesIO(value))
print(f"Setter function value_df.keys={value_df.keys()}; value_df={value_df.iloc[0]}")
value_json = value_df.to_json( value_json = value_df.to_json(
orient="records", lines=False, force_ascii=False orient="records", lines=False, force_ascii=False
) )
@ -249,21 +253,29 @@ class CosmosDBPipelineStorage(PipelineStorage):
log.exception("Error converting output %s to json", key) log.exception("Error converting output %s to json", key)
else: else:
cosmosdb_item_list = json.loads(value_json) cosmosdb_item_list = json.loads(value_json)
__i = 0
for index, cosmosdb_item in enumerate(cosmosdb_item_list): for index, cosmosdb_item in enumerate(cosmosdb_item_list):
# If the id key does not exist in the input dataframe json, create a unique id using the prefix and item index # If the id key does not exist in the input dataframe json, create a unique id using the prefix and item index
# TODO: Figure out optimal way to handle missing id keys in input dataframes # TODO: Figure out optimal way to handle missing id keys in input dataframes
if "id" not in cosmosdb_item: if "id" not in cosmosdb_item:
prefixed_id = f"{prefix}:{index}" prefixed_id = f"{prefix}:{index}"
self._no_id_prefixes.append(prefix) self._no_id_prefixes.add(prefix)
else: else:
if prefix in self._no_id_prefixes:
self._no_id_prefixes.remove(prefix)
prefixed_id = f"{prefix}:{cosmosdb_item['id']}" prefixed_id = f"{prefix}:{cosmosdb_item['id']}"
if __i==0:
print(f"index={index}; cosmosdb_item={cosmosdb_item}")
__i+=1
cosmosdb_item["id"] = prefixed_id cosmosdb_item["id"] = prefixed_id
self._container_client.upsert_item(body=cosmosdb_item) self._container_client.upsert_item(body=cosmosdb_item)
# value represents a cache output or stats.json # value represents a cache output or stats.json
else: else:
cosmosdb_item = { cosmosdb_item = {
"id": key, "id": key,
"body": json.loads(value), "body": value if (key == "graph.graphml") else json.loads(value),
} }
self._container_client.upsert_item(body=cosmosdb_item) self._container_client.upsert_item(body=cosmosdb_item)
except Exception: except Exception: