multiple-database-support-added (#3777)

* multiple-database-support-added

* FQDN_SEPARATOR-added

* code-formatted

* code-formatted
This commit is contained in:
codingwithabhi 2022-04-09 15:11:11 +05:30 committed by GitHub
parent 53e4403ccd
commit a92dff15e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 2 deletions

View File

@ -7,5 +7,5 @@ Provides metadata version information.
from incremental import Version
__version__ = Version("metadata", 0, 9, 0, dev=17)
__version__ = Version("metadata", 0, 9, 0, dev=18)
__all__ = ["__version__"]

View File

@ -9,10 +9,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import namedtuple
from typing import Iterable
import psycopg2
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.inspection import inspect
from metadata.config.common import FQDN_SEPARATOR
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
@ -24,11 +30,15 @@ from metadata.generated.schema.metadataIngestion.workflow import (
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.source import InvalidSourceException, SourceStatus
from metadata.ingestion.source.sql_source import SQLSource
from metadata.utils.engines import get_engine
TableKey = namedtuple("TableKey", ["schema", "table_name"])
logger: logging.Logger = logging.getLogger(__name__)
class PostgresSource(SQLSource):
def __init__(self, config, metadata_config):
@ -46,6 +56,34 @@ class PostgresSource(SQLSource):
return cls(config, metadata_config)
def get_databases(self) -> Iterable[Inspector]:
if self.config.database != None:
yield from super().get_databases()
else:
query = "select datname from pg_catalog.pg_database;"
results = self.connection.execute(query)
for res in results:
row = list(res)
try:
logger.info(f"Ingesting from database: {row[0]}")
self.config.database = row[0]
self.engine = get_engine(self.config)
self.connection = self.engine.connect()
yield inspect(self.engine)
except Exception as err:
logger.error(f"Failed to Connect: {row[0]} due to error {err}")
def _get_database(self, schema: str) -> Database:
return Database(
name=self.config.database + FQDN_SEPARATOR + schema,
service=EntityReference(id=self.service.id, type=self.config.service_type),
)
def get_status(self) -> SourceStatus:
return self.status
@ -61,7 +99,8 @@ class PostgresSource(SQLSource):
""",
(table_name, schema),
)
is_partition = cur.fetchone()[0]
obj = cur.fetchone()
is_partition = obj[0] if obj else False
return is_partition
def type_of_column_name(self, sa_type, table_name: str, column_name: str):