diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py
index 946266bcf1d..52c0c693926 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/es_mixin.py
@@ -29,7 +29,7 @@ from typing import (
 )
 from urllib.parse import quote_plus
 
-from pydantic import Field
+from pydantic import Field, field_validator
 from typing_extensions import Annotated
 
 from metadata.generated.schema.entity.data.container import Container
@@ -78,6 +78,23 @@ class HitsModel(BaseModel):
         ),
     ]
 
+    @field_validator("sort", mode="before")
+    @classmethod
+    def normalize_sort(cls, sort_value: Optional[List[object]]) -> Optional[List[str]]:
+        """
+        Return sort as a list of strings, regardless of the actual type.
+        If sort_field is set to `_score`, sort is a list of the score and the sort value.
+
+        Input examples from ES:
+        - ["metric"]
+        - [1.234, "metric"]
+
+        """
+        if sort_value is None:
+            return None
+
+        return [str(v) for v in sort_value]
+
 
 class ESHits(BaseModel):
     """Elasticsearch hits model"""
@@ -109,7 +126,7 @@ class ESMixin(Generic[T]):
     # sort_field needs to be unique for the pagination to work, so we can use the FQN
     paginate_query = (
         "/search/query?q=&size={size}&deleted=false{filter}&index={index}{include_fields}"
-        "&sort_field=fullyQualifiedName{after}"
+        "&sort_field={sort_field}&sort_order={sort_order}{after}"
     )
 
     @functools.lru_cache(maxsize=512)
@@ -337,8 +354,30 @@ class ESMixin(Generic[T]):
         query_filter: Optional[str] = None,
         size: int = 100,
         include_fields: Optional[List[str]] = None,
+        sort_field: str = "fullyQualifiedName",
+        sort_order: str = "desc",
     ) -> Iterator[ESResponse]:
-        """Paginate through the ES results, ignoring individual errors"""
+        """
+        Paginate through the ES results, ignoring individual errors.
+
+        Args:
+            entity: The entity type to paginate
+            query_filter: Optional ES query filter in JSON format
+            size: Number of results per page (default: 100)
+            include_fields: Optional list of fields to include in ES response (optimization)
+            sort_field: Field to sort by (default: "fullyQualifiedName").
+                Special field "_score" is supported for relevance sorting.
+            sort_order: Sort order, either "asc" or "desc" (default: "desc")
+
+        Yields:
+            ESResponse objects containing paginated results
+
+        Raises:
+            ValueError: If sort_order is not "asc" or "desc"
+        """
+        if sort_order not in ("asc", "desc"):
+            raise ValueError(f"sort_order must be 'asc' or 'desc', got '{sort_order}'")
+
         after: Optional[str] = None
         error_pages = 0
         query = functools.partial(
@@ -347,6 +386,8 @@ class ESMixin(Generic[T]):
             filter="&query_filter=" + quote_plus(query_filter) if query_filter else "",
             size=size,
             include_fields=self._get_include_fields_query(include_fields),
+            sort_field=quote_plus(sort_field),  # escape so the query string stays well-formed
+            sort_order=sort_order,
         )
         while True:
             query_string = query(
@@ -377,8 +418,28 @@ class ESMixin(Generic[T]):
         query_filter: Optional[str] = None,
         size: int = 100,
         fields: Optional[List[str]] = None,
+        sort_field: str = "fullyQualifiedName",
+        sort_order: str = "desc",
     ) -> Iterator[T]:
-        for response in self._paginate_es_internal(entity, query_filter, size):
+        """
+        Paginate through Elasticsearch results and fetch full entities from the API.
+
+        Args:
+            entity: The entity type to paginate
+            query_filter: Optional ES query filter in JSON format
+            size: Number of results per page (default: 100)
+            fields: Optional list of fields to fetch from the API for each entity
+            sort_field: Field to sort by (default: "fullyQualifiedName").
+                Must be an indexed ES field. Special field "_score" is supported
+                for relevance-based sorting.
+            sort_order: Sort order, either "asc" or "desc" (default: "desc")
+
+        Yields:
+            Full entity objects fetched from the OpenMetadata API
+        """
+        for response in self._paginate_es_internal(
+            entity, query_filter, size, sort_order=sort_order, sort_field=sort_field
+        ):
             yield from self._yield_hits_from_api(
                 response=response, entity=entity, fields=fields
             )
diff --git a/ingestion/tests/integration/ometa/test_ometa_es_api.py b/ingestion/tests/integration/ometa/test_ometa_es_api.py
index 6de671b686e..59bfd155b37 100644
--- a/ingestion/tests/integration/ometa/test_ometa_es_api.py
+++ b/ingestion/tests/integration/ometa/test_ometa_es_api.py
@@ -11,6 +11,7 @@
 """
 OMeta ES Mixin integration tests. The API needs to be up
 """
+import json
 import logging
 import time
 import uuid
@@ -393,3 +394,127 @@ class OMetaESTest(TestCase):
             self.metadata.paginate_es(entity=Table, query_filter=query_filter, size=2)
         )
         assert len(assets) == 5
+
+    def test_paginate_with_sorting(self):
+        for name in [f"paginating_table_{i}" for i in range(5)]:
+            self.metadata.create_or_update(
+                data=get_create_entity(
+                    entity=Table,
+                    name=EntityName(name),
+                    reference=self.create_schema_entity.fullyQualifiedName,
+                )
+            )
+
+        query_filter_obj = {
+            "query": {
+                "bool": {
+                    "must": [
+                        {
+                            "bool": {
+                                "must": [
+                                    {
+                                        "term": {
+                                            "service.name.keyword": (
+                                                self.service_entity.name.root
+                                            )
+                                        }
+                                    },
+                                ]
+                            }
+                        }
+                    ]
+                }
+            }
+        }
+
+        query_filter = json.dumps(query_filter_obj)
+        # Default sorting with fullyQualifiedName and desc order.
+        assets = list(
+            self.metadata.paginate_es(entity=Table, query_filter=query_filter, size=2)
+        )
+        returned_table_names = [
+            asset.name.root
+            for asset in assets
+            if asset.name.root.startswith("paginating_table_")
+        ]
+        assert returned_table_names == [
+            "paginating_table_4",
+            "paginating_table_3",
+            "paginating_table_2",
+            "paginating_table_1",
+            "paginating_table_0",
+        ]
+
+        # Asc order with fullyQualifiedName
+
+        assets = list(
+            self.metadata.paginate_es(
+                entity=Table, query_filter=query_filter, size=2, sort_order="asc"
+            )
+        )
+        returned_table_names = [
+            asset.name.root
+            for asset in assets
+            if asset.name.root.startswith("paginating_table_")
+        ]
+        assert returned_table_names == [
+            "paginating_table_0",
+            "paginating_table_1",
+            "paginating_table_2",
+            "paginating_table_3",
+            "paginating_table_4",
+        ]
+
+        # Sorting by _score should be supported without deserialization
+        # errors. This tests the fix for the _score bug where ES returns
+        # [float_score, fqn_value] instead of [fqn_value], which caused
+        # the HitsModel.sort field to fail validation.
+        # Note: With a term filter (not a search query), all items have
+        # the same _score, so we verify the operation succeeds and returns
+        # all items, not score ordering.
+        assets = list(
+            self.metadata.paginate_es(
+                entity=Table, query_filter=query_filter, size=2, sort_field="_score"
+            )
+        )
+        returned_table_names = [
+            asset.name.root
+            for asset in assets
+            if asset.name.root.startswith("paginating_table_")
+        ]
+        # Verify all 5 tables are returned (operation didn't crash)
+        assert len(returned_table_names) == 5
+
+    def test_paginate_invalid_sort_order(self):
+        """Test that invalid sort_order raises ValueError"""
+        query_filter_obj = {
+            "query": {
+                "bool": {
+                    "must": [
+                        {
+                            "bool": {
+                                "must": [
+                                    {
+                                        "term": {
+                                            "service.name.keyword": (
+                                                self.service_entity.name.root
+                                            )
+                                        }
+                                    },
+                                ]
+                            }
+                        }
+                    ]
+                }
+            }
+        }
+        query_filter = json.dumps(query_filter_obj)
+
+        with pytest.raises(ValueError, match="sort_order must be 'asc' or 'desc'"):
+            list(
+                self.metadata.paginate_es(
+                    entity=Table,
+                    query_filter=query_filter,
+                    sort_order="invalid",
+                )
+            )