Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2025-10-28 17:23:28 +00:00)
MINOR - Add ES pagination with error handling (#17776)
* MINOR - Add ES pagination with error handling
* format
* format
* add nullable
* prepare API
* fix pagination
* format
parent 1b6d0fb915
commit 69fcd6a572
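For orientation, the new paginate_es helper added to ESMixin in this commit yields entities page by page while tolerating individual failures. A minimal usage sketch, assuming an already-configured OpenMetadata client instance named metadata (the page size is illustrative):

# Sketch only: iterate over all Table entities through the new ES pagination helper.
# `metadata` is an assumed, already-configured OpenMetadata client instance.
from metadata.generated.schema.entity.data.table import Table

for table in metadata.paginate_es(entity=Table, size=100):
    # Each item is a full Table entity fetched by FQN; failed lookups are logged
    # and skipped instead of aborting the whole iteration.
    print(table.fullyQualifiedName.root)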
@@ -16,12 +16,15 @@ To be used by OpenMetadata class
import functools
import json
import traceback
from typing import Generic, Iterable, List, Optional, Set, Type, TypeVar
from typing import Generic, Iterable, Iterator, List, Optional, Set, Type, TypeVar
from urllib.parse import quote_plus

from pydantic import BaseModel
from pydantic import Field
from typing_extensions import Annotated

from metadata.generated.schema.entity.data.container import Container
from metadata.generated.schema.entity.data.query import Query
from metadata.ingestion.models.custom_pydantic import BaseModel
from metadata.ingestion.ometa.client import REST, APIError
from metadata.ingestion.ometa.utils import quote
from metadata.utils.elasticsearch import ES_INDEX_MAP
@@ -32,6 +35,42 @@ logger = ometa_logger()
T = TypeVar("T", bound=BaseModel)


class TotalModel(BaseModel):
    """Elasticsearch total model"""

    relation: str
    value: int


class HitsModel(BaseModel):
    """Elasticsearch hits model"""

    index: Annotated[str, Field(description="Index name", alias="_index")]
    type: Annotated[str, Field(description="Type of the document", alias="_type")]
    id: Annotated[str, Field(description="Document ID", alias="_id")]
    score: Annotated[
        Optional[float], Field(description="Score of the document", alias="_score")
    ]
    source: Annotated[dict, Field(description="Document source", alias="_source")]
    sort: Annotated[
        List[str],
        Field(description="Sort field. Used internally to get the next page FQN"),
    ]


class ESHits(BaseModel):
    """Elasticsearch hits model"""

    total: Annotated[TotalModel, Field(description="Total matched elements")]
    hits: Annotated[List[HitsModel], Field(description="List of matched elements")]


class ESResponse(BaseModel):
    """Elasticsearch response model"""

    hits: ESHits


class ESMixin(Generic[T]):
    """
    OpenMetadata API methods related to Elasticsearch.
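To show the shape these response models expect, a hand-crafted payload (index name, document id, and FQN are made up) validates like this sketch:

# Sketch only: a minimal, hand-crafted Elasticsearch payload that the models above accept.
sample = {
    "hits": {
        "total": {"relation": "eq", "value": 1},
        "hits": [
            {
                "_index": "table_search_index",
                "_type": "_doc",
                "_id": "abc-123",
                "_score": None,
                "_source": {"fullyQualifiedName": "svc.db.schema.users"},
                "sort": ["svc.db.schema.users"],
            }
        ],
    }
}

response = ESResponse.model_validate(sample)
assert response.hits.total.value == 1
assert response.hits.hits[0].source["fullyQualifiedName"] == "svc.db.schema.users"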
@@ -46,6 +85,12 @@ class ESMixin(Generic[T]):
        "&size={size}&index={index}"
    )

    # sort_field needs to be unique for the pagination to work, so we can use the FQN
    paginate_query = (
        "/search/query?q=&size={size}&deleted=false{filter}&index={index}"
        "&sort_field=fullyQualifiedName{after}"
    )

    @functools.lru_cache(maxsize=512)
    def _search_es_entity(
        self,
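The sort field must be unique so the last hit of a page can serve as an unambiguous cursor, which is why the fully qualified name is used. A sketch of how the template expands (index name and FQN are illustrative):

# Sketch only: what the formatted pagination URL looks like (index and FQN are illustrative).
from urllib.parse import quote_plus

first_page = (
    "/search/query?q=&size=100&deleted=false"
    "&index=table_search_index&sort_field=fullyQualifiedName"
)
# After a page is read, the last hit's sort value (its FQN) becomes the cursor:
next_page = first_page + "&search_after=" + quote_plus("svc.db.schema.users")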
@@ -252,3 +297,65 @@ class ESMixin(Generic[T]):
            logger.debug(traceback.format_exc())
            logger.warning(f"Unknown error extracting results from ES query [{err}]")
            return None

    def paginate_es(
        self,
        entity: Type[T],
        query_filter: Optional[str] = None,
        size: int = 100,
        fields: Optional[List[str]] = None,
    ) -> Iterator[T]:
        """Paginate through the ES results, ignoring individual errors"""
        after: Optional[str] = None
        error_pages = 0
        query = functools.partial(
            self.paginate_query.format,
            index=ES_INDEX_MAP[entity.__name__],
            filter="&query_filter=" + quote_plus(query_filter) if query_filter else "",
            size=size,
        )
        while True:
            query_string = query(
                after="&search_after=" + quote_plus(after) if after else ""
            )
            response = self._get_es_response(query_string)

            # Allow 3 errors getting pages before getting out of the loop
            if not response:
                error_pages += 1
                if error_pages < 3:
                    continue
                else:
                    break

            # Get the data
            for hit in response.hits.hits:
                try:
                    yield self.get_by_name(
                        entity=entity,
                        fqn=hit.source["fullyQualifiedName"],
                        fields=fields,
                        nullable=False,  # Raise an error if we don't find the Entity
                    )
                except Exception as exc:
                    logger.warning(
                        f"Error while getting {hit.source['fullyQualifiedName']} - {exc}"
                    )

            # Get next page
            last_hit = response.hits.hits[-1] if response.hits.hits else None
            if not last_hit or not last_hit.sort:
                logger.info("No more pages to fetch")
                break

            after = ",".join(last_hit.sort)

    def _get_es_response(self, query_string: str) -> Optional[ESResponse]:
        """Get the Elasticsearch response"""
        try:
            response = self.client.get(query_string)
            return ESResponse.model_validate(response)
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error while getting ES response: {exc}")
            return None
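Because paginate_es is a generator, callers can stop early without fetching the remaining pages. A consumption sketch, again assuming a configured OpenMetadata client named metadata:

# Sketch only: two ways of consuming the generator; `metadata` is an assumed,
# already-configured OpenMetadata client.
from metadata.generated.schema.entity.data.table import Table

# Lazy: take only the first result; no further ES pages are requested.
first_table = next(metadata.paginate_es(entity=Table, size=50), None)

# Eager: walk every page; a failed page fetch is retried and the loop gives up
# after three failures, while per-entity lookup errors are logged and skipped.
all_tables = list(metadata.paginate_es(entity=Table, size=50))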
@@ -15,7 +15,9 @@ import logging
import time
import uuid
from unittest import TestCase
from unittest.mock import patch

import pytest
from requests.utils import quote

from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
@@ -46,10 +48,12 @@ from metadata.generated.schema.entity.services.databaseService import (
from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
    OpenMetadataJWTClientConfig,
)
from metadata.generated.schema.type.basic import SqlQuery
from metadata.generated.schema.type.basic import EntityName, SqlQuery
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils import fqn

from ..integration_base import get_create_entity


class OMetaESTest(TestCase):
    """
@@ -295,3 +299,67 @@ class OMetaESTest(TestCase):
        """Check the payload from ES"""
        res = self.metadata.es_get_queries_with_lineage(self.service.name.root)
        self.assertIn(self.checksum, res)

    def test_paginate_no_filter(self):
        """We can paginate all the data"""
        # Since the test can run in parallel with other tables being there, we just
        # want to check we are actually getting some results
        for asset in self.metadata.paginate_es(entity=Table, size=2):
            assert asset
            break

    def test_paginate_with_errors(self):
        """We don't want to stop the ES yields just because a single Entity has an error"""
        # 1. First, prepare some tables
        for name in [f"table_{i}" for i in range(10)]:
            self.metadata.create_or_update(
                data=get_create_entity(
                    entity=Table,
                    name=EntityName(name),
                    reference=self.create_schema_entity.fullyQualifiedName,
                )
            )

        # 2. We'll fetch the entities, but we need to force a failure to ensure we can recover
        error_name = fqn._build(
            self.service_entity.name.root,
            self.create_db_entity.name.root,
            self.create_schema_entity.name.root,
            "table_5",
        )
        ok_name = fqn._build(
            self.service_entity.name.root,
            self.create_db_entity.name.root,
            self.create_schema_entity.name.root,
            "table_6",
        )

        rest_client = self.metadata.client
        original_get = rest_client.get
        with patch.object(rest_client, "get", wraps=rest_client.get) as mock_get:

            def side_effect(path: str, data=None):
                # In case we pass filters as well, use `in path` rather than ==
                if f"/tables/name/{error_name}" in path:
                    raise RuntimeError("Error")
                return original_get(path, data)

            mock_get.side_effect = side_effect

            # Validate we are raising the error
            with pytest.raises(RuntimeError):
                self.metadata.get_by_name(entity=Table, fqn=error_name)

            # This works
            self.metadata.get_by_name(entity=Table, fqn=ok_name)

            query_filter = (
                '{"query":{"bool":{"must":[{"bool":{"should":[{"term":'
                f'{{"service.displayName.keyword":"{self.service_entity.name.root}"}}}}]}}}}]}}}}}}'
            )
            assets = list(
                self.metadata.paginate_es(
                    entity=Table, query_filter=query_filter, size=2
                )
            )
            assert len(assets) == 10
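The one-line query_filter above is hard to read; a functionally equivalent filter built with json.dumps (the service name is a placeholder) looks like this sketch:

# Sketch only: the same boolean filter, built with json.dumps for readability.
import json

service_name = "test-service-es"  # placeholder for self.service_entity.name.root
query_filter = json.dumps(
    {
        "query": {
            "bool": {
                "must": [
                    {
                        "bool": {
                            "should": [
                                {"term": {"service.displayName.keyword": service_name}}
                            ]
                        }
                    }
                ]
            }
        }
    }
)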
@@ -127,6 +127,11 @@ public class SearchResource {
          @DefaultValue("10")
          @QueryParam("size")
          int size,
      @Parameter(
              description =
                  "When paginating, specify the search_after values. Use it as search_after=<val1>,<val2>,...")
          @QueryParam("search_after")
          String searchAfter,
      @Parameter(
              description =
                  "Sort the search results by field, available fields to "
@@ -196,6 +201,7 @@ public class SearchResource {
            .domains(domains)
            .applyDomainFilter(
                !subjectContext.isAdmin() && subjectContext.hasAnyRole(DOMAIN_ONLY_ACCESS_ROLE))
            .searchAfter(searchAfter)
            .build();
    return searchRepository.search(request);
  }
@@ -1,7 +1,10 @@
package org.openmetadata.service.search;

import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.Getter;
import lombok.Setter;
import org.openmetadata.schema.type.EntityReference;
@@ -25,6 +28,7 @@ public class SearchRequest {
  private final boolean applyDomainFilter;
  private final List<String> domains;
  private final boolean getHierarchy;
  private final Object[] searchAfter;

  public SearchRequest(ElasticSearchRequestBuilder builder) {
    this.query = builder.query;
@@ -43,6 +47,7 @@ public class SearchRequest {
    this.getHierarchy = builder.getHierarchy;
    this.domains = builder.domains;
    this.applyDomainFilter = builder.applyDomainFilter;
    this.searchAfter = builder.searchAfter;
  }

  // Builder class for ElasticSearchRequest
@@ -64,6 +69,7 @@ public class SearchRequest {
    private boolean getHierarchy;
    private boolean applyDomainFilter;
    private List<String> domains;
    private Object[] searchAfter;

    public ElasticSearchRequestBuilder(String query, int size, String index) {
      this.query = query;
@@ -139,6 +145,14 @@ public class SearchRequest {
      return this;
    }

    public ElasticSearchRequestBuilder searchAfter(String searchAfter) {
      this.searchAfter = null;
      if (!nullOrEmpty(searchAfter)) {
        this.searchAfter = Stream.of(searchAfter.split(",")).toArray(Object[]::new);
      }
      return this;
    }

    public SearchRequest build() {
      return new SearchRequest(this);
    }
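On the wire the cursor is simply the comma-joined sort values of the last hit, which this builder splits back into the array Elasticsearch expects. A round-trip sketch in Python with an illustrative value:

# Sketch only: the cursor round-trip between the Python client and this builder.
last_hit_sort = ["svc.db.schema.table_4"]  # illustrative sort values from the last ES hit
cursor = ",".join(last_hit_sort)           # what the client appends as search_after=...
search_after = cursor.split(",")           # what the server turns back into an array
assert search_after == last_hit_sort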
@@ -381,6 +381,10 @@ public class ElasticSearchClient implements SearchClient {
      }
    }

    if (!nullOrEmpty(request.getSearchAfter())) {
      searchSourceBuilder.searchAfter(request.getSearchAfter());
    }

    /* For backward-compatibility we continue supporting the deleted argument, this should be removed in future versions */
    if (request
        .getIndex()