This commit is contained in:
Gabriel Nieves 2025-03-26 16:25:08 +00:00
parent b7b2b562ce
commit be21d994c0

View File

@@ -38,7 +38,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
_database_name: str
_container_name: str
_encoding: str
_no_id_prefixes: list[str]
_no_id_prefixes: set[str]
def __init__(
self,
@@ -71,7 +71,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
if cosmosdb_account_url
else None
)
self._no_id_prefixes = []
self._no_id_prefixes = set()
log.info(
"creating cosmosdb storage with account: %s and database: %s and container: %s",
self._cosmosdb_account_name,
@@ -198,6 +198,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
) -> Any:
"""Fetch all items in a container that match the given key."""
try:
print(f"Looking in cosmos-db for key={key}")
if not self._database_client or not self._container_client:
return None
if as_bytes:
@@ -223,8 +224,9 @@ class CosmosDBPipelineStorage(PipelineStorage):
return items_df.to_parquet()
item = self._container_client.read_item(item=key, partition_key=key)
print(f"Located item={item} for key={key}")
item_body = item.get("body")
return json.dumps(item_body)
return item_body if (key == "graph.graphml") else json.dumps(item_body)
except Exception:
log.exception("Error reading item %s", key)
return None
@@ -240,8 +242,10 @@ class CosmosDBPipelineStorage(PipelineStorage):
raise ValueError(msg) # noqa: TRY301
# value represents a parquet file
if isinstance(value, bytes):
print(f"Setter function called for key={key}; self._no_id_prefixes={self._no_id_prefixes}")
prefix = self._get_prefix(key)
value_df = pd.read_parquet(BytesIO(value))
print(f"Setter function value_df.keys={value_df.keys()}; value_df={value_df.iloc[0]}")
value_json = value_df.to_json(
orient="records", lines=False, force_ascii=False
)
@@ -249,21 +253,29 @@ class CosmosDBPipelineStorage(PipelineStorage):
log.exception("Error converting output %s to json", key)
else:
cosmosdb_item_list = json.loads(value_json)
__i = 0
for index, cosmosdb_item in enumerate(cosmosdb_item_list):
# If the id key does not exist in the input dataframe json, create a unique id using the prefix and item index
# TODO: Figure out optimal way to handle missing id keys in input dataframes
if "id" not in cosmosdb_item:
prefixed_id = f"{prefix}:{index}"
self._no_id_prefixes.append(prefix)
self._no_id_prefixes.add(prefix)
else:
if prefix in self._no_id_prefixes:
self._no_id_prefixes.remove(prefix)
prefixed_id = f"{prefix}:{cosmosdb_item['id']}"
if __i==0:
print(f"index={index}; cosmosdb_item={cosmosdb_item}")
__i+=1
cosmosdb_item["id"] = prefixed_id
self._container_client.upsert_item(body=cosmosdb_item)
# value represents a cache output or stats.json
else:
cosmosdb_item = {
"id": key,
"body": json.loads(value),
"body": value if (key == "graph.graphml") else json.loads(value),
}
self._container_client.upsert_item(body=cosmosdb_item)
except Exception: