MINOR - Add ES pagination with error handling (#17776)

* MINOR - Add ES pagination with error handling

* format

* format

* add nullable

* prepare API

* fix pagination

* format
This commit is contained in:
Pere Miquel Brull 2024-09-12 07:14:56 +02:00
parent 1b6d0fb915
commit 69fcd6a572
5 changed files with 202 additions and 3 deletions

View File

@ -16,12 +16,15 @@ To be used by OpenMetadata class
import functools
import json
import traceback
from typing import Generic, Iterable, List, Optional, Set, Type, TypeVar
from typing import Generic, Iterable, Iterator, List, Optional, Set, Type, TypeVar
from urllib.parse import quote_plus
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import Annotated
from metadata.generated.schema.entity.data.container import Container
from metadata.generated.schema.entity.data.query import Query
from metadata.ingestion.models.custom_pydantic import BaseModel
from metadata.ingestion.ometa.client import REST, APIError
from metadata.ingestion.ometa.utils import quote
from metadata.utils.elasticsearch import ES_INDEX_MAP
@ -32,6 +35,42 @@ logger = ometa_logger()
T = TypeVar("T", bound=BaseModel)
class TotalModel(BaseModel):
    """Total-count section of an Elasticsearch search response"""

    # Qualifier reported next to the count (Elasticsearch documents values
    # such as "eq" and "gte" for it)
    relation: str
    # Number of matched documents reported by the search
    value: int
class HitsModel(BaseModel):
    """Single matched document inside an Elasticsearch search response"""

    index: Annotated[str, Field(description="Index name", alias="_index")]
    # `_type` is not part of the hit payload on every search backend version
    # (mapping types were removed from Elasticsearch), so it must not be
    # required: a missing `_type` would otherwise invalidate the whole page.
    type: Annotated[
        Optional[str], Field(description="Type of the document", alias="_type")
    ] = None
    id: Annotated[str, Field(description="Document ID", alias="_id")]
    # The key is required, but its value may be null
    score: Annotated[
        Optional[float], Field(description="Score of the document", alias="_score")
    ]
    source: Annotated[dict, Field(description="Document source", alias="_source")]
    # Sort values of this hit; the last hit's values are sent back as
    # `search_after` to request the following page
    sort: Annotated[
        List[str],
        Field(description="Sort field. Used internally to get the next page FQN"),
    ]
class ESHits(BaseModel):
    """Container of the matched documents plus the total count"""

    # Total section of the response (count and its relation)
    total: Annotated[TotalModel, Field(description="Total matched elements")]
    # Documents returned for the requested page
    hits: Annotated[List[HitsModel], Field(description="List of matched elements")]
class ESResponse(BaseModel):
    """Top-level Elasticsearch search response; only the `hits` section is modelled"""

    # Matched documents and total count
    hits: ESHits
class ESMixin(Generic[T]):
"""
OpenMetadata API methods related to Elasticsearch.
@ -46,6 +85,12 @@ class ESMixin(Generic[T]):
"&size={size}&index={index}"
)
# Pagination only works when the sort field is unique across results, which is
# why the query always sorts on the fullyQualifiedName. `{filter}` and `{after}`
# are either empty or a complete, already URL-encoded "&key=value" fragment.
paginate_query = (
    "/search/query"
    "?q=&size={size}&deleted=false{filter}&index={index}"
    "&sort_field=fullyQualifiedName"
    "{after}"
)
@functools.lru_cache(maxsize=512)
def _search_es_entity(
self,
@ -252,3 +297,65 @@ class ESMixin(Generic[T]):
logger.debug(traceback.format_exc())
logger.warning(f"Unknown error extracting results from ES query [{err}]")
return None
def paginate_es(
    self,
    entity: Type[T],
    query_filter: Optional[str] = None,
    size: int = 100,
    fields: Optional[List[str]] = None,
) -> Iterator[T]:
    """Paginate through the ES results, ignoring individual errors.

    Every hit of every page is resolved through `get_by_name`. A hit that
    cannot be resolved is logged and skipped. A page that cannot be fetched is
    retried; pagination stops once 3 page errors have accumulated.

    Args:
        entity: Entity class to paginate; its name selects the ES index.
        query_filter: Optional query filter, URL-encoded before being sent.
        size: Number of hits requested per page.
        fields: Optional list of fields forwarded to `get_by_name`.

    Yields:
        The entities resolved from the ES hits, one at a time.
    """
    after: Optional[str] = None
    error_pages = 0
    query = functools.partial(
        self.paginate_query.format,
        index=ES_INDEX_MAP[entity.__name__],
        filter="&query_filter=" + quote_plus(query_filter) if query_filter else "",
        size=size,
    )
    while True:
        query_string = query(
            after="&search_after=" + quote_plus(after) if after else ""
        )
        response = self._get_es_response(query_string)

        # Allow 3 errors getting pages before getting out of the loop
        if response is None:
            error_pages += 1
            if error_pages < 3:
                continue
            logger.warning(
                f"Stopping ES pagination after {error_pages} errors getting pages"
            )
            break

        # Get the data
        for hit in response.hits.hits:
            # Read the FQN once and defensively: indexing a missing key inside
            # the except handler would raise again and kill the generator.
            hit_fqn = hit.source.get("fullyQualifiedName")
            if not hit_fqn:
                logger.warning(
                    f"Skipping ES hit [{hit.id}] without fullyQualifiedName"
                )
                continue
            try:
                asset = self.get_by_name(
                    entity=entity,
                    fqn=hit_fqn,
                    fields=fields,
                    nullable=False,  # Raise an error if we don't find the Entity
                )
            except Exception as exc:
                logger.warning(f"Error while getting {hit_fqn} - {exc}")
                continue
            # Yield outside the try block so that exceptions thrown into the
            # generator by the consumer are not mistaken for fetch errors.
            yield asset

        # Get next page
        last_hit = response.hits.hits[-1] if response.hits.hits else None
        if not last_hit or not last_hit.sort:
            logger.info("No more pages to fetch")
            break

        # search_after is sent as the comma-separated sort values of the last hit
        after = ",".join(last_hit.sort)
def _get_es_response(self, query_string: str) -> Optional[ESResponse]:
    """Run the given search query and parse the payload into an ESResponse.

    Returns None, after logging a warning, if either the request or the
    validation of the payload raises.
    """
    try:
        raw_payload = self.client.get(query_string)
        parsed = ESResponse.model_validate(raw_payload)
    except Exception as err:
        logger.debug(traceback.format_exc())
        logger.warning(f"Error while getting ES response: {err}")
        return None
    return parsed

View File

@ -15,7 +15,9 @@ import logging
import time
import uuid
from unittest import TestCase
from unittest.mock import patch
import pytest
from requests.utils import quote
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
@ -46,10 +48,12 @@ from metadata.generated.schema.entity.services.databaseService import (
from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
OpenMetadataJWTClientConfig,
)
from metadata.generated.schema.type.basic import SqlQuery
from metadata.generated.schema.type.basic import EntityName, SqlQuery
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils import fqn
from ..integration_base import get_create_entity
class OMetaESTest(TestCase):
"""
@ -295,3 +299,67 @@ class OMetaESTest(TestCase):
"""Check the payload from ES"""
res = self.metadata.es_get_queries_with_lineage(self.service.name.root)
self.assertIn(self.checksum, res)
def test_paginate_no_filter(self):
    """We can paginate all the data"""
    # Since the test can run in parallel with other tables being there, we can't
    # check an exact count: we just want to check we are actually getting some
    # results. Pulling the first element explicitly makes the test fail when
    # nothing is returned, instead of silently passing on an empty iteration.
    first_asset = next(iter(self.metadata.paginate_es(entity=Table, size=2)), None)
    self.assertIsNotNone(first_asset, "paginate_es did not return any Table")
    assert first_asset
def test_paginate_with_errors(self):
    """A failure on a single Entity must not stop the rest of the ES results from being yielded"""
    # 1. Seed ten tables (table_0 .. table_9) under the test schema
    schema_fqn = self.create_schema_entity.fullyQualifiedName
    for idx in range(10):
        create_request = get_create_entity(
            entity=Table,
            name=EntityName(f"table_{idx}"),
            reference=schema_fqn,
        )
        self.metadata.create_or_update(data=create_request)

    # 2. Pick one table whose fetch will be forced to fail and one that keeps
    #    working, so we can check that pagination recovers from the failure
    parent_names = (
        self.service_entity.name.root,
        self.create_db_entity.name.root,
        self.create_schema_entity.name.root,
    )
    error_name = fqn._build(*parent_names, "table_5")
    ok_name = fqn._build(*parent_names, "table_6")

    rest_client = self.metadata.client
    original_get = rest_client.get

    def side_effect(path: str, data=None):
        # Substring match rather than equality, since the path may also carry filters
        if f"/tables/name/{error_name}" in path:
            raise RuntimeError("Error")
        return original_get(path, data)

    with patch.object(rest_client, "get", wraps=rest_client.get) as mock_get:
        mock_get.side_effect = side_effect

        # Fetching the broken table raises...
        with pytest.raises(RuntimeError):
            self.metadata.get_by_name(entity=Table, fqn=error_name)

        # ...while any other table is still reachable
        self.metadata.get_by_name(entity=Table, fqn=ok_name)

        # Restrict the search to the tables of the test service
        query_filter = (
            '{"query":{"bool":{"must":[{"bool":{"should":[{"term":'
            f'{{"service.displayName.keyword":"{self.service_entity.name.root}"}}}}]}}}}]}}}}}}'
        )
        paginated = self.metadata.paginate_es(
            entity=Table, query_filter=query_filter, size=2
        )
        assets = list(paginated)
        assert len(assets) == 10

View File

@ -127,6 +127,11 @@ public class SearchResource {
@DefaultValue("10")
@QueryParam("size")
int size,
// Comma-separated search_after values used to request the next page of results.
@Parameter(
    description =
        "When paginating, specify the search_after values. Use it as search_after=<val1>,<val2>,...")
@QueryParam("search_after")
String searchAfter,
@Parameter(
description =
"Sort the search results by field, available fields to "
@ -196,6 +201,7 @@ public class SearchResource {
.domains(domains)
.applyDomainFilter(
!subjectContext.isAdmin() && subjectContext.hasAnyRole(DOMAIN_ONLY_ACCESS_ROLE))
.searchAfter(searchAfter)
.build();
return searchRepository.search(request);
}

View File

@ -1,7 +1,10 @@
package org.openmetadata.service.search;
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.Getter;
import lombok.Setter;
import org.openmetadata.schema.type.EntityReference;
@ -25,6 +28,7 @@ public class SearchRequest {
private final boolean applyDomainFilter;
private final List<String> domains;
private final boolean getHierarchy;
private final Object[] searchAfter;
public SearchRequest(ElasticSearchRequestBuilder builder) {
this.query = builder.query;
@ -43,6 +47,7 @@ public class SearchRequest {
this.getHierarchy = builder.getHierarchy;
this.domains = builder.domains;
this.applyDomainFilter = builder.applyDomainFilter;
this.searchAfter = builder.searchAfter;
}
// Builder class for ElasticSearchRequest
@ -64,6 +69,7 @@ public class SearchRequest {
private boolean getHierarchy;
private boolean applyDomainFilter;
private List<String> domains;
private Object[] searchAfter;
public ElasticSearchRequestBuilder(String query, int size, String index) {
this.query = query;
@ -139,6 +145,14 @@ public class SearchRequest {
return this;
}
/**
 * Sets the search_after values from their comma-separated representation.
 *
 * @param searchAfter comma-separated values; a null or empty string clears the values
 * @return this builder
 */
public ElasticSearchRequestBuilder searchAfter(String searchAfter) {
  if (nullOrEmpty(searchAfter)) {
    // Nothing to resume from
    this.searchAfter = null;
    return this;
  }
  String[] values = searchAfter.split(",");
  this.searchAfter = Stream.of(values).toArray(Object[]::new);
  return this;
}
public SearchRequest build() {
return new SearchRequest(this);
}

View File

@ -381,6 +381,10 @@ public class ElasticSearchClient implements SearchClient {
}
}
// Pagination: forward the search_after values to the search source only when the
// request carries them.
if (!nullOrEmpty(request.getSearchAfter())) {
  searchSourceBuilder.searchAfter(request.getSearchAfter());
}
/* For backward-compatibility we continue supporting the deleted argument, this should be removed in future versions */
if (request
.getIndex()