diff --git a/.idea/encodings.xml b/.idea/encodings.xml index ed12c7eee38..c7d8cba3020 100644 --- a/.idea/encodings.xml +++ b/.idea/encodings.xml @@ -5,5 +5,6 @@ + \ No newline at end of file diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java index 543a9fcbac0..be7715e41e8 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java @@ -111,11 +111,14 @@ public class MlModelRepository extends EntityRepository { return dao.tagDAO().getTags(fqn); } - private void setMlFeatureSourcesFQN(String parentFQN, List mlSources) { + private void setMlFeatureSourcesFQN(List mlSources) { mlSources.forEach( s -> { - String sourceFqn = parentFQN + "." + s.getName(); - s.setFullyQualifiedName(sourceFqn); + if (s.getDataSource() != null) { + s.setFullyQualifiedName(s.getDataSource().getName() + "." + s.getName()); + } else { + s.setFullyQualifiedName(s.getName()); + } }); } @@ -125,16 +128,34 @@ public class MlModelRepository extends EntityRepository { String featureFqn = parentFQN + "." + f.getName(); f.setFullyQualifiedName(featureFqn); if (f.getFeatureSources() != null) { - setMlFeatureSourcesFQN(featureFqn, f.getFeatureSources()); + setMlFeatureSourcesFQN(f.getFeatureSources()); } }); } + /** Make sure that all the MlFeatureSources are pointing to correct EntityReferences in tha Table DAO. */ + private void validateReferences(List mlFeatures) throws IOException { + for (MlFeature feature : mlFeatures) { + if (feature.getFeatureSources() != null && !feature.getFeatureSources().isEmpty()) { + for (MlFeatureSource source : feature.getFeatureSources()) { + validateMlDataSource(source); + } + } + } + } + + private void validateMlDataSource(MlFeatureSource source) throws IOException { + if (source.getDataSource() != null) { + Entity.getEntityReference(source.getDataSource().getType(), source.getDataSource().getId()); + } + } + @Override public void prepare(MlModel mlModel) throws IOException { mlModel.setFullyQualifiedName(getFQN(mlModel)); if (mlModel.getMlFeatures() != null && !mlModel.getMlFeatures().isEmpty()) { + validateReferences(mlModel.getMlFeatures()); setMlFeatureFQN(mlModel.getFullyQualifiedName(), mlModel.getMlFeatures()); } diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json index 662025b5fb0..a25351e1b55 100644 --- a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json +++ b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json @@ -84,6 +84,10 @@ "fullyQualifiedName": { "$ref": "#/definitions/fullyQualifiedFeatureSourceName" }, + "dataSource": { + "description": "Description of the Data Source (e.g., a Table)", + "$ref" : "../../type/entityReference.json" + }, "tags": { "description": "Tags associated with the feature source.", "type": "array", diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java index 6623551cbd6..a614e0092a6 100644 --- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java +++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java @@ -41,20 +41,29 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.TestMethodOrder; import org.openmetadata.catalog.Entity; +import org.openmetadata.catalog.api.data.CreateDatabase; import org.openmetadata.catalog.api.data.CreateMlModel; +import org.openmetadata.catalog.api.data.CreateTable; import org.openmetadata.catalog.api.services.CreateDashboardService; import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType; import org.openmetadata.catalog.entity.data.Dashboard; +import org.openmetadata.catalog.entity.data.Database; import org.openmetadata.catalog.entity.data.MlModel; +import org.openmetadata.catalog.entity.data.Table; import org.openmetadata.catalog.entity.services.DashboardService; import org.openmetadata.catalog.jdbi3.DashboardRepository.DashboardEntityInterface; import org.openmetadata.catalog.jdbi3.DashboardServiceRepository.DashboardServiceEntityInterface; import org.openmetadata.catalog.jdbi3.MlModelRepository; +import org.openmetadata.catalog.jdbi3.TableRepository.TableEntityInterface; import org.openmetadata.catalog.resources.EntityResourceTest; import org.openmetadata.catalog.resources.dashboards.DashboardResourceTest; +import org.openmetadata.catalog.resources.databases.DatabaseResourceTest; +import org.openmetadata.catalog.resources.databases.TableResourceTest; import org.openmetadata.catalog.resources.mlmodels.MlModelResource.MlModelList; import org.openmetadata.catalog.resources.services.DashboardServiceResourceTest; import org.openmetadata.catalog.type.ChangeDescription; +import org.openmetadata.catalog.type.Column; +import org.openmetadata.catalog.type.ColumnDataType; import org.openmetadata.catalog.type.EntityReference; import org.openmetadata.catalog.type.FeatureSourceDataType; import org.openmetadata.catalog.type.FieldChange; @@ -74,6 +83,10 @@ public class MlModelResourceTest extends EntityResourceTest { public static final String ALGORITHM = "regression"; public static Dashboard DASHBOARD; public static EntityReference DASHBOARD_REFERENCE; + public static Database DATABASE; + public static List COLUMNS; + public static Table TABLE; + public static EntityReference TABLE_REFERENCE; public static final URI SERVER = URI.create("http://localhost.com/mlModel"); public static final MlStore ML_STORE = @@ -124,6 +137,18 @@ public class MlModelResourceTest extends EntityResourceTest { dashboardResourceTest.createDashboard( dashboardResourceTest.create(test).withService(SUPERSET_REFERENCE), adminAuthHeaders()); DASHBOARD_REFERENCE = new DashboardEntityInterface(DASHBOARD).getEntityReference(); + + DatabaseResourceTest databaseResourceTest = new DatabaseResourceTest(); + CreateDatabase create = databaseResourceTest.create(test).withService(SNOWFLAKE_REFERENCE); + DATABASE = databaseResourceTest.createAndCheckEntity(create, adminAuthHeaders()); + + COLUMNS = Collections.singletonList(new Column().withName("age").withDataType(ColumnDataType.INT)); + + CreateTable createTable = new CreateTable().withName("myTable").withDatabase(DATABASE.getId()).withColumns(COLUMNS); + + TableResourceTest tableResourceTest = new TableResourceTest(); + TABLE = tableResourceTest.createAndCheckEntity(createTable, adminAuthHeaders()); + TABLE_REFERENCE = new TableEntityInterface(TABLE).getEntityReference(); } public static MlModel createMlModel(CreateMlModel create, Map authHeaders) @@ -223,7 +248,7 @@ public class MlModelResourceTest extends EntityResourceTest { CreateMlModel request = create(test); // Create a made up dashboard reference by picking up a random UUID EntityReference dashboard = new EntityReference().withId(USER1.getId()).withType("dashboard"); - // MlModel model = createAndCheckEntity(request, adminAuthHeaders()); + assertResponse( () -> createMlModel(request.withDashboard(dashboard), adminAuthHeaders()), Status.NOT_FOUND, @@ -290,6 +315,55 @@ public class MlModelResourceTest extends EntityResourceTest { updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change); } + @Test + public void put_MlModelWithDataSource_200(TestInfo test) throws IOException { + CreateMlModel request = create(test); + MlModel model = createAndCheckEntity(request, adminAuthHeaders()); + + MlFeature newMlFeature = + new MlFeature() + .withName("color") + .withDataType(MlFeatureDataType.Categorical) + .withFeatureSources( + Collections.singletonList( + new MlFeatureSource() + .withName("age") + .withDataType(FeatureSourceDataType.INTEGER) + .withDataSource(TABLE_REFERENCE))); + List newFeatures = Collections.singletonList(newMlFeature); + + ChangeDescription change = getChangeDescription(model.getVersion()); + change.getFieldsAdded().add(new FieldChange().withName("mlFeatures").withNewValue(newFeatures)); + change.getFieldsDeleted().add(new FieldChange().withName("mlFeatures").withOldValue(ML_FEATURES)); + + updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change); + } + + @Test + public void put_MlModelWithInvalidDataSource_400(TestInfo test) throws IOException { + CreateMlModel request = create(test); + + // Create a made up table reference by picking up a random UUID + EntityReference invalid_table = new EntityReference().withId(USER1.getId()).withType("table"); + + MlFeature newMlFeature = + new MlFeature() + .withName("color") + .withDataType(MlFeatureDataType.Categorical) + .withFeatureSources( + Collections.singletonList( + new MlFeatureSource() + .withName("age") + .withDataType(FeatureSourceDataType.INTEGER) + .withDataSource(invalid_table))); + List newFeatures = Collections.singletonList(newMlFeature); + + assertResponse( + () -> createMlModel(request.withMlFeatures(newFeatures), adminAuthHeaders()), + Status.NOT_FOUND, + String.format("table instance for %s not found", USER1.getId())); + } + @Test public void put_MlModelAddMlHyperParams_200(TestInfo test) throws IOException { CreateMlModel request = new CreateMlModel().withName(getEntityName(test)).withAlgorithm(ALGORITHM); @@ -425,6 +499,7 @@ public class MlModelResourceTest extends EntityResourceTest { assertEquals(actual.getName(), expected.getName()); assertEquals(actual.getDescription(), expected.getDescription()); assertEquals(actual.getDataType(), expected.getDataType()); + assertEquals(actual.getDataSource(), expected.getDataSource()); }; private void validateMlFeatureSources(List expected, List actual) {