Fixes 24332: Enable paginate_es to sort the response (#24375)

This commit is contained in:
Mohamed Daif 2025-11-28 12:30:36 +01:00 committed by GitHub
parent fccd717fbd
commit b50aa44082
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 189 additions and 4 deletions

View File

@ -29,7 +29,7 @@ from typing import (
)
from urllib.parse import quote_plus
from pydantic import Field
from pydantic import Field, field_validator
from typing_extensions import Annotated
from metadata.generated.schema.entity.data.container import Container
@ -78,6 +78,22 @@ class HitsModel(BaseModel):
),
]
@field_validator("sort", mode="before")
def normalize_sort(cls, sort_value: list[str] | None):
"""
Return sort as a list of strings, regardless of the actual type.
if sort_field is set to `_score`, sort is a list of the score and the sort value.
Input examples from ES:
- ["metric"]
- [1.234, "metric"]
"""
if sort_value is None:
return None
return [str(v) for v in sort_value]
class ESHits(BaseModel):
"""Elasticsearch hits model"""
@ -109,7 +125,7 @@ class ESMixin(Generic[T]):
# sort_field needs to be unique for the pagination to work, so we can use the FQN
paginate_query = (
"/search/query?q=&size={size}&deleted=false{filter}&index={index}{include_fields}"
"&sort_field=fullyQualifiedName{after}"
"&sort_field={sort_field}&sort_order={sort_order}{after}"
)
@functools.lru_cache(maxsize=512)
@ -337,8 +353,30 @@ class ESMixin(Generic[T]):
query_filter: Optional[str] = None,
size: int = 100,
include_fields: Optional[List[str]] = None,
sort_field: str = "fullyQualifiedName",
sort_order: str = "desc",
) -> Iterator[ESResponse]:
"""Paginate through the ES results, ignoring individual errors"""
"""
Paginate through the ES results, ignoring individual errors.
Args:
entity: The entity type to paginate
query_filter: Optional ES query filter in JSON format
size: Number of results per page (default: 100)
include_fields: Optional list of fields to include in ES response (optimization)
sort_field: Field to sort by (default: "fullyQualifiedName").
Special field "_score" is supported for relevance sorting.
sort_order: Sort order, either "asc" or "desc" (default: "desc")
Yields:
ESResponse objects containing paginated results
Raises:
ValueError: If sort_order is not "asc" or "desc"
"""
if sort_order not in ("asc", "desc"):
raise ValueError(f"sort_order must be 'asc' or 'desc', got '{sort_order}'")
after: Optional[str] = None
error_pages = 0
query = functools.partial(
@ -347,6 +385,8 @@ class ESMixin(Generic[T]):
filter="&query_filter=" + quote_plus(query_filter) if query_filter else "",
size=size,
include_fields=self._get_include_fields_query(include_fields),
sort_field=sort_field,
sort_order=sort_order,
)
while True:
query_string = query(
@ -377,8 +417,28 @@ class ESMixin(Generic[T]):
query_filter: Optional[str] = None,
size: int = 100,
fields: Optional[List[str]] = None,
sort_field: str = "fullyQualifiedName",
sort_order: str = "desc",
) -> Iterator[T]:
for response in self._paginate_es_internal(entity, query_filter, size):
"""
Paginate through Elasticsearch results and fetch full entities from the API.
Args:
entity: The entity type to paginate
query_filter: Optional ES query filter in JSON format
size: Number of results per page (default: 100)
fields: Optional list of fields to fetch from the API for each entity
sort_field: Field to sort by (default: "fullyQualifiedName").
Must be an indexed ES field. Special field "_score" is supported
for relevance-based sorting.
sort_order: Sort order, either "asc" or "desc" (default: "desc")
Yields:
Full entity objects fetched from the OpenMetadata API
"""
for response in self._paginate_es_internal(
entity, query_filter, size, sort_order=sort_order, sort_field=sort_field
):
yield from self._yield_hits_from_api(
response=response, entity=entity, fields=fields
)

View File

@ -11,6 +11,7 @@
"""
OMeta ES Mixin integration tests. The API needs to be up
"""
import json
import logging
import time
import uuid
@ -393,3 +394,127 @@ class OMetaESTest(TestCase):
self.metadata.paginate_es(entity=Table, query_filter=query_filter, size=2)
)
assert len(assets) == 5
def test_paginate_with_sorting(self):
for name in [f"paginating_table_{i}" for i in range(5)]:
self.metadata.create_or_update(
data=get_create_entity(
entity=Table,
name=EntityName(name),
reference=self.create_schema_entity.fullyQualifiedName,
)
)
query_filter_obj = {
"query": {
"bool": {
"must": [
{
"bool": {
"must": [
{
"term": {
"service.name.keyword": (
self.service_entity.name.root
)
}
},
]
}
}
]
}
}
}
query_filter = json.dumps(query_filter_obj)
# Default sorting with fullyQualifiedName and desc order.
assets = list(
self.metadata.paginate_es(entity=Table, query_filter=query_filter, size=2)
)
returned_table_names = [
asset.name.root
for asset in assets
if asset.name.root.startswith("paginating_table_")
]
assert returned_table_names == [
"paginating_table_4",
"paginating_table_3",
"paginating_table_2",
"paginating_table_1",
"paginating_table_0",
]
# Asc order with fullyQualifiedName
assets = list(
self.metadata.paginate_es(
entity=Table, query_filter=query_filter, size=2, sort_order="asc"
)
)
returned_table_names = [
asset.name.root
for asset in assets
if asset.name.root.startswith("paginating_table_")
]
assert returned_table_names == [
"paginating_table_0",
"paginating_table_1",
"paginating_table_2",
"paginating_table_3",
"paginating_table_4",
]
# Sorting by _score should be supported without deserialization
# errors. This tests the fix for the _score bug where ES returns
# [float_score, fqn_value] instead of [fqn_value], which caused
# the HitsModel.sort field to fail validation.
# Note: With a term filter (not a search query), all items have
# the same _score, so we verify the operation succeeds and returns
# all items, not score ordering.
assets = list(
self.metadata.paginate_es(
entity=Table, query_filter=query_filter, size=2, sort_field="_score"
)
)
returned_table_names = [
asset.name.root
for asset in assets
if asset.name.root.startswith("paginating_table_")
]
# Verify all 5 tables are returned (operation didn't crash)
assert len(returned_table_names) == 5
def test_paginate_invalid_sort_order(self):
"""Test that invalid sort_order raises ValueError"""
query_filter_obj = {
"query": {
"bool": {
"must": [
{
"bool": {
"must": [
{
"term": {
"service.name.keyword": (
self.service_entity.name.root
)
}
},
]
}
}
]
}
}
}
query_filter = json.dumps(query_filter_obj)
with pytest.raises(ValueError, match="sort_order must be 'asc' or 'desc'"):
list(
self.metadata.paginate_es(
entity=Table,
query_filter=query_filter,
sort_order="invalid",
)
)