mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-12-01 10:06:25 +00:00
[issue-1746] - Add dataSource to MlModel (#1833)
* Add dataSource to featureSource * Validate DataSource in MLFeatures * Generalise EntityRef check * Format
This commit is contained in:
parent
019a948392
commit
8dfb226fde
1
.idea/encodings.xml
generated
1
.idea/encodings.xml
generated
@ -5,5 +5,6 @@
|
||||
<file url="file://$PROJECT_DIR$/catalog-rest-service/src/main/resources" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/catalog-rest-service/src/main/resources/json/data" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/common/src/main/java" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/openmetadata-ui/src/main/resources/ui/dist" charset="UTF-8" />
|
||||
</component>
|
||||
</project>
|
||||
@ -111,11 +111,14 @@ public class MlModelRepository extends EntityRepository<MlModel> {
|
||||
return dao.tagDAO().getTags(fqn);
|
||||
}
|
||||
|
||||
private void setMlFeatureSourcesFQN(String parentFQN, List<MlFeatureSource> mlSources) {
|
||||
private void setMlFeatureSourcesFQN(List<MlFeatureSource> mlSources) {
|
||||
mlSources.forEach(
|
||||
s -> {
|
||||
String sourceFqn = parentFQN + "." + s.getName();
|
||||
s.setFullyQualifiedName(sourceFqn);
|
||||
if (s.getDataSource() != null) {
|
||||
s.setFullyQualifiedName(s.getDataSource().getName() + "." + s.getName());
|
||||
} else {
|
||||
s.setFullyQualifiedName(s.getName());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -125,16 +128,34 @@ public class MlModelRepository extends EntityRepository<MlModel> {
|
||||
String featureFqn = parentFQN + "." + f.getName();
|
||||
f.setFullyQualifiedName(featureFqn);
|
||||
if (f.getFeatureSources() != null) {
|
||||
setMlFeatureSourcesFQN(featureFqn, f.getFeatureSources());
|
||||
setMlFeatureSourcesFQN(f.getFeatureSources());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/** Make sure that all the MlFeatureSources are pointing to correct EntityReferences in tha Table DAO. */
|
||||
private void validateReferences(List<MlFeature> mlFeatures) throws IOException {
|
||||
for (MlFeature feature : mlFeatures) {
|
||||
if (feature.getFeatureSources() != null && !feature.getFeatureSources().isEmpty()) {
|
||||
for (MlFeatureSource source : feature.getFeatureSources()) {
|
||||
validateMlDataSource(source);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void validateMlDataSource(MlFeatureSource source) throws IOException {
|
||||
if (source.getDataSource() != null) {
|
||||
Entity.getEntityReference(source.getDataSource().getType(), source.getDataSource().getId());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void prepare(MlModel mlModel) throws IOException {
|
||||
mlModel.setFullyQualifiedName(getFQN(mlModel));
|
||||
|
||||
if (mlModel.getMlFeatures() != null && !mlModel.getMlFeatures().isEmpty()) {
|
||||
validateReferences(mlModel.getMlFeatures());
|
||||
setMlFeatureFQN(mlModel.getFullyQualifiedName(), mlModel.getMlFeatures());
|
||||
}
|
||||
|
||||
|
||||
@ -84,6 +84,10 @@
|
||||
"fullyQualifiedName": {
|
||||
"$ref": "#/definitions/fullyQualifiedFeatureSourceName"
|
||||
},
|
||||
"dataSource": {
|
||||
"description": "Description of the Data Source (e.g., a Table)",
|
||||
"$ref" : "../../type/entityReference.json"
|
||||
},
|
||||
"tags": {
|
||||
"description": "Tags associated with the feature source.",
|
||||
"type": "array",
|
||||
|
||||
@ -41,20 +41,29 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInfo;
|
||||
import org.junit.jupiter.api.TestMethodOrder;
|
||||
import org.openmetadata.catalog.Entity;
|
||||
import org.openmetadata.catalog.api.data.CreateDatabase;
|
||||
import org.openmetadata.catalog.api.data.CreateMlModel;
|
||||
import org.openmetadata.catalog.api.data.CreateTable;
|
||||
import org.openmetadata.catalog.api.services.CreateDashboardService;
|
||||
import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType;
|
||||
import org.openmetadata.catalog.entity.data.Dashboard;
|
||||
import org.openmetadata.catalog.entity.data.Database;
|
||||
import org.openmetadata.catalog.entity.data.MlModel;
|
||||
import org.openmetadata.catalog.entity.data.Table;
|
||||
import org.openmetadata.catalog.entity.services.DashboardService;
|
||||
import org.openmetadata.catalog.jdbi3.DashboardRepository.DashboardEntityInterface;
|
||||
import org.openmetadata.catalog.jdbi3.DashboardServiceRepository.DashboardServiceEntityInterface;
|
||||
import org.openmetadata.catalog.jdbi3.MlModelRepository;
|
||||
import org.openmetadata.catalog.jdbi3.TableRepository.TableEntityInterface;
|
||||
import org.openmetadata.catalog.resources.EntityResourceTest;
|
||||
import org.openmetadata.catalog.resources.dashboards.DashboardResourceTest;
|
||||
import org.openmetadata.catalog.resources.databases.DatabaseResourceTest;
|
||||
import org.openmetadata.catalog.resources.databases.TableResourceTest;
|
||||
import org.openmetadata.catalog.resources.mlmodels.MlModelResource.MlModelList;
|
||||
import org.openmetadata.catalog.resources.services.DashboardServiceResourceTest;
|
||||
import org.openmetadata.catalog.type.ChangeDescription;
|
||||
import org.openmetadata.catalog.type.Column;
|
||||
import org.openmetadata.catalog.type.ColumnDataType;
|
||||
import org.openmetadata.catalog.type.EntityReference;
|
||||
import org.openmetadata.catalog.type.FeatureSourceDataType;
|
||||
import org.openmetadata.catalog.type.FieldChange;
|
||||
@ -74,6 +83,10 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
|
||||
public static final String ALGORITHM = "regression";
|
||||
public static Dashboard DASHBOARD;
|
||||
public static EntityReference DASHBOARD_REFERENCE;
|
||||
public static Database DATABASE;
|
||||
public static List<Column> COLUMNS;
|
||||
public static Table TABLE;
|
||||
public static EntityReference TABLE_REFERENCE;
|
||||
|
||||
public static final URI SERVER = URI.create("http://localhost.com/mlModel");
|
||||
public static final MlStore ML_STORE =
|
||||
@ -124,6 +137,18 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
|
||||
dashboardResourceTest.createDashboard(
|
||||
dashboardResourceTest.create(test).withService(SUPERSET_REFERENCE), adminAuthHeaders());
|
||||
DASHBOARD_REFERENCE = new DashboardEntityInterface(DASHBOARD).getEntityReference();
|
||||
|
||||
DatabaseResourceTest databaseResourceTest = new DatabaseResourceTest();
|
||||
CreateDatabase create = databaseResourceTest.create(test).withService(SNOWFLAKE_REFERENCE);
|
||||
DATABASE = databaseResourceTest.createAndCheckEntity(create, adminAuthHeaders());
|
||||
|
||||
COLUMNS = Collections.singletonList(new Column().withName("age").withDataType(ColumnDataType.INT));
|
||||
|
||||
CreateTable createTable = new CreateTable().withName("myTable").withDatabase(DATABASE.getId()).withColumns(COLUMNS);
|
||||
|
||||
TableResourceTest tableResourceTest = new TableResourceTest();
|
||||
TABLE = tableResourceTest.createAndCheckEntity(createTable, adminAuthHeaders());
|
||||
TABLE_REFERENCE = new TableEntityInterface(TABLE).getEntityReference();
|
||||
}
|
||||
|
||||
public static MlModel createMlModel(CreateMlModel create, Map<String, String> authHeaders)
|
||||
@ -223,7 +248,7 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
|
||||
CreateMlModel request = create(test);
|
||||
// Create a made up dashboard reference by picking up a random UUID
|
||||
EntityReference dashboard = new EntityReference().withId(USER1.getId()).withType("dashboard");
|
||||
// MlModel model = createAndCheckEntity(request, adminAuthHeaders());
|
||||
|
||||
assertResponse(
|
||||
() -> createMlModel(request.withDashboard(dashboard), adminAuthHeaders()),
|
||||
Status.NOT_FOUND,
|
||||
@ -290,6 +315,55 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
|
||||
updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void put_MlModelWithDataSource_200(TestInfo test) throws IOException {
|
||||
CreateMlModel request = create(test);
|
||||
MlModel model = createAndCheckEntity(request, adminAuthHeaders());
|
||||
|
||||
MlFeature newMlFeature =
|
||||
new MlFeature()
|
||||
.withName("color")
|
||||
.withDataType(MlFeatureDataType.Categorical)
|
||||
.withFeatureSources(
|
||||
Collections.singletonList(
|
||||
new MlFeatureSource()
|
||||
.withName("age")
|
||||
.withDataType(FeatureSourceDataType.INTEGER)
|
||||
.withDataSource(TABLE_REFERENCE)));
|
||||
List<MlFeature> newFeatures = Collections.singletonList(newMlFeature);
|
||||
|
||||
ChangeDescription change = getChangeDescription(model.getVersion());
|
||||
change.getFieldsAdded().add(new FieldChange().withName("mlFeatures").withNewValue(newFeatures));
|
||||
change.getFieldsDeleted().add(new FieldChange().withName("mlFeatures").withOldValue(ML_FEATURES));
|
||||
|
||||
updateAndCheckEntity(request.withMlFeatures(newFeatures), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void put_MlModelWithInvalidDataSource_400(TestInfo test) throws IOException {
|
||||
CreateMlModel request = create(test);
|
||||
|
||||
// Create a made up table reference by picking up a random UUID
|
||||
EntityReference invalid_table = new EntityReference().withId(USER1.getId()).withType("table");
|
||||
|
||||
MlFeature newMlFeature =
|
||||
new MlFeature()
|
||||
.withName("color")
|
||||
.withDataType(MlFeatureDataType.Categorical)
|
||||
.withFeatureSources(
|
||||
Collections.singletonList(
|
||||
new MlFeatureSource()
|
||||
.withName("age")
|
||||
.withDataType(FeatureSourceDataType.INTEGER)
|
||||
.withDataSource(invalid_table)));
|
||||
List<MlFeature> newFeatures = Collections.singletonList(newMlFeature);
|
||||
|
||||
assertResponse(
|
||||
() -> createMlModel(request.withMlFeatures(newFeatures), adminAuthHeaders()),
|
||||
Status.NOT_FOUND,
|
||||
String.format("table instance for %s not found", USER1.getId()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void put_MlModelAddMlHyperParams_200(TestInfo test) throws IOException {
|
||||
CreateMlModel request = new CreateMlModel().withName(getEntityName(test)).withAlgorithm(ALGORITHM);
|
||||
@ -425,6 +499,7 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
|
||||
assertEquals(actual.getName(), expected.getName());
|
||||
assertEquals(actual.getDescription(), expected.getDescription());
|
||||
assertEquals(actual.getDataType(), expected.getDataType());
|
||||
assertEquals(actual.getDataSource(), expected.getDataSource());
|
||||
};
|
||||
|
||||
private void validateMlFeatureSources(List<MlFeature> expected, List<MlFeature> actual) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user