diff --git a/.idea/encodings.xml b/.idea/encodings.xml
index ed12c7eee38..c7d8cba3020 100644
--- a/.idea/encodings.xml
+++ b/.idea/encodings.xml
@@ -5,5 +5,6 @@
+
\ No newline at end of file
diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java
index 543a9fcbac0..be7715e41e8 100644
--- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java
+++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java
@@ -111,11 +111,14 @@ public class MlModelRepository extends EntityRepository {
return dao.tagDAO().getTags(fqn);
}
- private void setMlFeatureSourcesFQN(String parentFQN, List mlSources) {
+ /**
+ * Sets the fully qualified name of every feature source in the list. When the source has a data
+ * source, the name is the data source's name, a dot, then the source's own name; otherwise it is
+ * just the source's own name.
+ *
+ * NOTE(review): getDataSource().getName() is used as-is; confirm it is always populated on the
+ * incoming reference.
+ */
+ private void setMlFeatureSourcesFQN(List mlSources) {
mlSources.forEach(
s -> {
- String sourceFqn = parentFQN + "." + s.getName();
- s.setFullyQualifiedName(sourceFqn);
+ if (s.getDataSource() != null) {
+ s.setFullyQualifiedName(s.getDataSource().getName() + "." + s.getName());
+ } else {
+ s.setFullyQualifiedName(s.getName());
+ }
});
}
@@ -125,16 +128,34 @@ public class MlModelRepository extends EntityRepository {
String featureFqn = parentFQN + "." + f.getName();
f.setFullyQualifiedName(featureFqn);
if (f.getFeatureSources() != null) {
- setMlFeatureSourcesFQN(featureFqn, f.getFeatureSources());
+ setMlFeatureSourcesFQN(f.getFeatureSources());
}
});
}
+ /**
+ * Runs validateMlDataSource on every MlFeatureSource of the given features, so that each data
+ * source reference is looked up by its own type and id. Features whose source list is null or
+ * empty are skipped.
+ *
+ * @param mlFeatures features whose sources are checked
+ * @throws IOException propagated from validateMlDataSource
+ */
+ private void validateReferences(List mlFeatures) throws IOException {
+ for (MlFeature feature : mlFeatures) {
+ if (feature.getFeatureSources() != null && !feature.getFeatureSources().isEmpty()) {
+ for (MlFeatureSource source : feature.getFeatureSources()) {
+ validateMlDataSource(source);
+ }
+ }
+ }
+ }
+
+ /**
+ * Looks up the source's data source through Entity.getEntityReference, using the type and id
+ * carried by the reference. Sources without a data source are left alone.
+ *
+ * NOTE(review): presumably the lookup throws when no such entity exists (the tests added with
+ * this change expect a "not found" response) - confirm against Entity.getEntityReference.
+ *
+ * @param source feature source whose data source is checked
+ * @throws IOException declared for the lookup call
+ */
+ private void validateMlDataSource(MlFeatureSource source) throws IOException {
+ if (source.getDataSource() != null) {
+ // The returned reference is discarded; the call is made only for the lookup itself.
+ Entity.getEntityReference(source.getDataSource().getType(), source.getDataSource().getId());
+ }
+ }
+
@Override
public void prepare(MlModel mlModel) throws IOException {
mlModel.setFullyQualifiedName(getFQN(mlModel));
if (mlModel.getMlFeatures() != null && !mlModel.getMlFeatures().isEmpty()) {
+ validateReferences(mlModel.getMlFeatures());
setMlFeatureFQN(mlModel.getFullyQualifiedName(), mlModel.getMlFeatures());
}
diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json
index 662025b5fb0..a25351e1b55 100644
--- a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json
+++ b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json
@@ -84,6 +84,10 @@
"fullyQualifiedName": {
"$ref": "#/definitions/fullyQualifiedFeatureSourceName"
},
+ "dataSource": {
+ "description": "Description of the Data Source (e.g., a Table)",
+ "$ref" : "../../type/entityReference.json"
+ },
"tags": {
"description": "Tags associated with the feature source.",
"type": "array",
diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java
index 6623551cbd6..a614e0092a6 100644
--- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java
+++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java
@@ -41,20 +41,29 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.TestMethodOrder;
import org.openmetadata.catalog.Entity;
+import org.openmetadata.catalog.api.data.CreateDatabase;
import org.openmetadata.catalog.api.data.CreateMlModel;
+import org.openmetadata.catalog.api.data.CreateTable;
import org.openmetadata.catalog.api.services.CreateDashboardService;
import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType;
import org.openmetadata.catalog.entity.data.Dashboard;
+import org.openmetadata.catalog.entity.data.Database;
import org.openmetadata.catalog.entity.data.MlModel;
+import org.openmetadata.catalog.entity.data.Table;
import org.openmetadata.catalog.entity.services.DashboardService;
import org.openmetadata.catalog.jdbi3.DashboardRepository.DashboardEntityInterface;
import org.openmetadata.catalog.jdbi3.DashboardServiceRepository.DashboardServiceEntityInterface;
import org.openmetadata.catalog.jdbi3.MlModelRepository;
+import org.openmetadata.catalog.jdbi3.TableRepository.TableEntityInterface;
import org.openmetadata.catalog.resources.EntityResourceTest;
import org.openmetadata.catalog.resources.dashboards.DashboardResourceTest;
+import org.openmetadata.catalog.resources.databases.DatabaseResourceTest;
+import org.openmetadata.catalog.resources.databases.TableResourceTest;
import org.openmetadata.catalog.resources.mlmodels.MlModelResource.MlModelList;
import org.openmetadata.catalog.resources.services.DashboardServiceResourceTest;
import org.openmetadata.catalog.type.ChangeDescription;
+import org.openmetadata.catalog.type.Column;
+import org.openmetadata.catalog.type.ColumnDataType;
import org.openmetadata.catalog.type.EntityReference;
import org.openmetadata.catalog.type.FeatureSourceDataType;
import org.openmetadata.catalog.type.FieldChange;
@@ -74,6 +83,10 @@ public class MlModelResourceTest extends EntityResourceTest {
public static final String ALGORITHM = "regression";
public static Dashboard DASHBOARD;
public static EntityReference DASHBOARD_REFERENCE;
+ public static Database DATABASE;
+ public static List COLUMNS;
+ public static Table TABLE;
+ public static EntityReference TABLE_REFERENCE;
public static final URI SERVER = URI.create("http://localhost.com/mlModel");
public static final MlStore ML_STORE =
@@ -124,6 +137,18 @@ public class MlModelResourceTest extends EntityResourceTest {
dashboardResourceTest.createDashboard(
dashboardResourceTest.create(test).withService(SUPERSET_REFERENCE), adminAuthHeaders());
DASHBOARD_REFERENCE = new DashboardEntityInterface(DASHBOARD).getEntityReference();
+
+ DatabaseResourceTest databaseResourceTest = new DatabaseResourceTest();
+ CreateDatabase create = databaseResourceTest.create(test).withService(SNOWFLAKE_REFERENCE);
+ DATABASE = databaseResourceTest.createAndCheckEntity(create, adminAuthHeaders());
+
+ COLUMNS = Collections.singletonList(new Column().withName("age").withDataType(ColumnDataType.INT));
+
+ CreateTable createTable = new CreateTable().withName("myTable").withDatabase(DATABASE.getId()).withColumns(COLUMNS);
+
+ TableResourceTest tableResourceTest = new TableResourceTest();
+ TABLE = tableResourceTest.createAndCheckEntity(createTable, adminAuthHeaders());
+ TABLE_REFERENCE = new TableEntityInterface(TABLE).getEntityReference();
}
public static MlModel createMlModel(CreateMlModel create, Map authHeaders)
@@ -223,7 +248,7 @@ public class MlModelResourceTest extends EntityResourceTest {
CreateMlModel request = create(test);
// Create a made up dashboard reference by picking up a random UUID
EntityReference dashboard = new EntityReference().withId(USER1.getId()).withType("dashboard");
- // MlModel model = createAndCheckEntity(request, adminAuthHeaders());
+
assertResponse(
() -> createMlModel(request.withDashboard(dashboard), adminAuthHeaders()),
Status.NOT_FOUND,
@@ -290,6 +315,55 @@ public class MlModelResourceTest extends EntityResourceTest {
updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change);
}
+ /**
+ * Creates a model, then updates it with a single feature whose source carries TABLE_REFERENCE as
+ * its data source, and expects Status.OK with a MINOR_UPDATE.
+ */
+ @Test
+ public void put_MlModelWithDataSource_200(TestInfo test) throws IOException {
+ CreateMlModel request = create(test);
+ MlModel model = createAndCheckEntity(request, adminAuthHeaders());
+
+ // Feature "color" backed by source "age", which points at the table reference fixture
+ MlFeature newMlFeature =
+ new MlFeature()
+ .withName("color")
+ .withDataType(MlFeatureDataType.Categorical)
+ .withFeatureSources(
+ Collections.singletonList(
+ new MlFeatureSource()
+ .withName("age")
+ .withDataType(FeatureSourceDataType.INTEGER)
+ .withDataSource(TABLE_REFERENCE)));
+ List newFeatures = Collections.singletonList(newMlFeature);
+
+ // The expected change records mlFeatures both as added (new list) and deleted (ML_FEATURES);
+ // presumably ML_FEATURES is what create(test) put on the request - verify against create().
+ ChangeDescription change = getChangeDescription(model.getVersion());
+ change.getFieldsAdded().add(new FieldChange().withName("mlFeatures").withNewValue(newFeatures));
+ change.getFieldsDeleted().add(new FieldChange().withName("mlFeatures").withOldValue(ML_FEATURES));
+
+ updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change);
+ }
+
+ /**
+ * Tries to create a model whose feature source points at a table that does not exist and expects
+ * the request to be rejected.
+ *
+ * NOTE(review): the method name says 400, but the assertion below expects Status.NOT_FOUND; the
+ * name is left unchanged here - consider aligning it.
+ */
+ @Test
+ public void put_MlModelWithInvalidDataSource_400(TestInfo test) throws IOException {
+ CreateMlModel request = create(test);
+
+ // Build a bogus table reference: the id belongs to USER1, so no table carries it
+ EntityReference invalidTable = new EntityReference().withId(USER1.getId()).withType("table");
+
+ MlFeature newMlFeature =
+ new MlFeature()
+ .withName("color")
+ .withDataType(MlFeatureDataType.Categorical)
+ .withFeatureSources(
+ Collections.singletonList(
+ new MlFeatureSource()
+ .withName("age")
+ .withDataType(FeatureSourceDataType.INTEGER)
+ .withDataSource(invalidTable)));
+ List newFeatures = Collections.singletonList(newMlFeature);
+
+ assertResponse(
+ () -> createMlModel(request.withMlFeatures(newFeatures), adminAuthHeaders()),
+ Status.NOT_FOUND,
+ String.format("table instance for %s not found", USER1.getId()));
+ }
+
@Test
public void put_MlModelAddMlHyperParams_200(TestInfo test) throws IOException {
CreateMlModel request = new CreateMlModel().withName(getEntityName(test)).withAlgorithm(ALGORITHM);
@@ -425,6 +499,7 @@ public class MlModelResourceTest extends EntityResourceTest {
assertEquals(actual.getName(), expected.getName());
assertEquals(actual.getDescription(), expected.getDescription());
assertEquals(actual.getDataType(), expected.getDataType());
+ assertEquals(actual.getDataSource(), expected.getDataSource());
};
private void validateMlFeatureSources(List expected, List actual) {