[issue-1746] - Add dataSource to MlModel (#1833)

* Add dataSource to featureSource

* Validate DataSource in MLFeatures

* Generalise EntityRef check

* Format
This commit is contained in:
Pere Miquel Brull 2021-12-20 18:57:16 +01:00 committed by GitHub
parent 019a948392
commit 8dfb226fde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 5 deletions

1
.idea/encodings.xml generated
View File

@ -5,5 +5,6 @@
<file url="file://$PROJECT_DIR$/catalog-rest-service/src/main/resources" charset="UTF-8" />
<file url="file://$PROJECT_DIR$/catalog-rest-service/src/main/resources/json/data" charset="UTF-8" />
<file url="file://$PROJECT_DIR$/common/src/main/java" charset="UTF-8" />
<file url="file://$PROJECT_DIR$/openmetadata-ui/src/main/resources/ui/dist" charset="UTF-8" />
</component>
</project>

View File

@ -111,11 +111,14 @@ public class MlModelRepository extends EntityRepository<MlModel> {
return dao.tagDAO().getTags(fqn);
}
private void setMlFeatureSourcesFQN(String parentFQN, List<MlFeatureSource> mlSources) {
private void setMlFeatureSourcesFQN(List<MlFeatureSource> mlSources) {
mlSources.forEach(
s -> {
String sourceFqn = parentFQN + "." + s.getName();
s.setFullyQualifiedName(sourceFqn);
if (s.getDataSource() != null) {
s.setFullyQualifiedName(s.getDataSource().getName() + "." + s.getName());
} else {
s.setFullyQualifiedName(s.getName());
}
});
}
@ -125,16 +128,34 @@ public class MlModelRepository extends EntityRepository<MlModel> {
String featureFqn = parentFQN + "." + f.getName();
f.setFullyQualifiedName(featureFqn);
if (f.getFeatureSources() != null) {
setMlFeatureSourcesFQN(featureFqn, f.getFeatureSources());
setMlFeatureSourcesFQN(f.getFeatureSources());
}
});
}
/** Make sure that all the MlFeatureSources are pointing to correct EntityReferences in tha Table DAO. */
private void validateReferences(List<MlFeature> mlFeatures) throws IOException {
for (MlFeature feature : mlFeatures) {
if (feature.getFeatureSources() != null && !feature.getFeatureSources().isEmpty()) {
for (MlFeatureSource source : feature.getFeatureSources()) {
validateMlDataSource(source);
}
}
}
}
private void validateMlDataSource(MlFeatureSource source) throws IOException {
if (source.getDataSource() != null) {
Entity.getEntityReference(source.getDataSource().getType(), source.getDataSource().getId());
}
}
@Override
public void prepare(MlModel mlModel) throws IOException {
mlModel.setFullyQualifiedName(getFQN(mlModel));
if (mlModel.getMlFeatures() != null && !mlModel.getMlFeatures().isEmpty()) {
validateReferences(mlModel.getMlFeatures());
setMlFeatureFQN(mlModel.getFullyQualifiedName(), mlModel.getMlFeatures());
}

View File

@ -84,6 +84,10 @@
"fullyQualifiedName": {
"$ref": "#/definitions/fullyQualifiedFeatureSourceName"
},
"dataSource": {
"description": "Description of the Data Source (e.g., a Table)",
"$ref" : "../../type/entityReference.json"
},
"tags": {
"description": "Tags associated with the feature source.",
"type": "array",

View File

@ -41,20 +41,29 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.TestMethodOrder;
import org.openmetadata.catalog.Entity;
import org.openmetadata.catalog.api.data.CreateDatabase;
import org.openmetadata.catalog.api.data.CreateMlModel;
import org.openmetadata.catalog.api.data.CreateTable;
import org.openmetadata.catalog.api.services.CreateDashboardService;
import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType;
import org.openmetadata.catalog.entity.data.Dashboard;
import org.openmetadata.catalog.entity.data.Database;
import org.openmetadata.catalog.entity.data.MlModel;
import org.openmetadata.catalog.entity.data.Table;
import org.openmetadata.catalog.entity.services.DashboardService;
import org.openmetadata.catalog.jdbi3.DashboardRepository.DashboardEntityInterface;
import org.openmetadata.catalog.jdbi3.DashboardServiceRepository.DashboardServiceEntityInterface;
import org.openmetadata.catalog.jdbi3.MlModelRepository;
import org.openmetadata.catalog.jdbi3.TableRepository.TableEntityInterface;
import org.openmetadata.catalog.resources.EntityResourceTest;
import org.openmetadata.catalog.resources.dashboards.DashboardResourceTest;
import org.openmetadata.catalog.resources.databases.DatabaseResourceTest;
import org.openmetadata.catalog.resources.databases.TableResourceTest;
import org.openmetadata.catalog.resources.mlmodels.MlModelResource.MlModelList;
import org.openmetadata.catalog.resources.services.DashboardServiceResourceTest;
import org.openmetadata.catalog.type.ChangeDescription;
import org.openmetadata.catalog.type.Column;
import org.openmetadata.catalog.type.ColumnDataType;
import org.openmetadata.catalog.type.EntityReference;
import org.openmetadata.catalog.type.FeatureSourceDataType;
import org.openmetadata.catalog.type.FieldChange;
@ -74,6 +83,10 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
public static final String ALGORITHM = "regression";
public static Dashboard DASHBOARD;
public static EntityReference DASHBOARD_REFERENCE;
public static Database DATABASE;
public static List<Column> COLUMNS;
public static Table TABLE;
public static EntityReference TABLE_REFERENCE;
public static final URI SERVER = URI.create("http://localhost.com/mlModel");
public static final MlStore ML_STORE =
@ -124,6 +137,18 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
dashboardResourceTest.createDashboard(
dashboardResourceTest.create(test).withService(SUPERSET_REFERENCE), adminAuthHeaders());
DASHBOARD_REFERENCE = new DashboardEntityInterface(DASHBOARD).getEntityReference();
DatabaseResourceTest databaseResourceTest = new DatabaseResourceTest();
CreateDatabase create = databaseResourceTest.create(test).withService(SNOWFLAKE_REFERENCE);
DATABASE = databaseResourceTest.createAndCheckEntity(create, adminAuthHeaders());
COLUMNS = Collections.singletonList(new Column().withName("age").withDataType(ColumnDataType.INT));
CreateTable createTable = new CreateTable().withName("myTable").withDatabase(DATABASE.getId()).withColumns(COLUMNS);
TableResourceTest tableResourceTest = new TableResourceTest();
TABLE = tableResourceTest.createAndCheckEntity(createTable, adminAuthHeaders());
TABLE_REFERENCE = new TableEntityInterface(TABLE).getEntityReference();
}
public static MlModel createMlModel(CreateMlModel create, Map<String, String> authHeaders)
@ -223,7 +248,7 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
CreateMlModel request = create(test);
// Create a made up dashboard reference by picking up a random UUID
EntityReference dashboard = new EntityReference().withId(USER1.getId()).withType("dashboard");
// MlModel model = createAndCheckEntity(request, adminAuthHeaders());
assertResponse(
() -> createMlModel(request.withDashboard(dashboard), adminAuthHeaders()),
Status.NOT_FOUND,
@ -290,6 +315,55 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change);
}
@Test
public void put_MlModelWithDataSource_200(TestInfo test) throws IOException {
CreateMlModel request = create(test);
MlModel model = createAndCheckEntity(request, adminAuthHeaders());
MlFeature newMlFeature =
new MlFeature()
.withName("color")
.withDataType(MlFeatureDataType.Categorical)
.withFeatureSources(
Collections.singletonList(
new MlFeatureSource()
.withName("age")
.withDataType(FeatureSourceDataType.INTEGER)
.withDataSource(TABLE_REFERENCE)));
List<MlFeature> newFeatures = Collections.singletonList(newMlFeature);
ChangeDescription change = getChangeDescription(model.getVersion());
change.getFieldsAdded().add(new FieldChange().withName("mlFeatures").withNewValue(newFeatures));
change.getFieldsDeleted().add(new FieldChange().withName("mlFeatures").withOldValue(ML_FEATURES));
updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change);
}
@Test
public void put_MlModelWithInvalidDataSource_400(TestInfo test) throws IOException {
CreateMlModel request = create(test);
// Create a made up table reference by picking up a random UUID
EntityReference invalid_table = new EntityReference().withId(USER1.getId()).withType("table");
MlFeature newMlFeature =
new MlFeature()
.withName("color")
.withDataType(MlFeatureDataType.Categorical)
.withFeatureSources(
Collections.singletonList(
new MlFeatureSource()
.withName("age")
.withDataType(FeatureSourceDataType.INTEGER)
.withDataSource(invalid_table)));
List<MlFeature> newFeatures = Collections.singletonList(newMlFeature);
assertResponse(
() -> createMlModel(request.withMlFeatures(newFeatures), adminAuthHeaders()),
Status.NOT_FOUND,
String.format("table instance for %s not found", USER1.getId()));
}
@Test
public void put_MlModelAddMlHyperParams_200(TestInfo test) throws IOException {
CreateMlModel request = new CreateMlModel().withName(getEntityName(test)).withAlgorithm(ALGORITHM);
@ -425,6 +499,7 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
assertEquals(actual.getName(), expected.getName());
assertEquals(actual.getDescription(), expected.getDescription());
assertEquals(actual.getDataType(), expected.getDataType());
assertEquals(actual.getDataSource(), expected.getDataSource());
};
private void validateMlFeatureSources(List<MlFeature> expected, List<MlFeature> actual) {