MINOR - DQ incident severity classifier (#15149)

* feat: added severity classifier for DQ incidents

* feat: added severity classifier tests

* style: ran java linting
Teddy authored 2024-02-13 07:02:09 +01:00; committed by GitHub
parent 3cc074ba5a
commit e6cacb19f1
9 changed files with 262 additions and 1 deletion


@@ -370,3 +370,6 @@ web:
permission-policy:
enabled: ${WEB_CONF_PERMISSION_POLICY_ENABLED:-false}
option: ${WEB_CONF_PERMISSION_POLICY_OPTION:-""}
dataQualityConfiguration:
severityIncidentClassifier: ${DATA_QUALITY_SEVERITY_INCIDENT_CLASSIFIER:-"org.openmetadata.service.util.incidentSeverityClassifier.LogisticRegressionIncidentSeverityClassifier"}
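For reference, this makes the classifier pluggable per deployment: pointing the DATA_QUALITY_SEVERITY_INCIDENT_CLASSIFIER environment variable (or the severityIncidentClassifier key itself) at another fully qualified class name swaps the implementation at startup. A minimal sketch with a hypothetical class name:

  DATA_QUALITY_SEVERITY_INCIDENT_CLASSIFIER="com.example.dq.CustomSeverityClassifier"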


@@ -113,6 +113,7 @@ import org.openmetadata.service.socket.OpenMetadataAssetServlet;
import org.openmetadata.service.socket.SocketAddressFilter;
import org.openmetadata.service.socket.WebSocketManager;
import org.openmetadata.service.util.MicrometerBundleSingleton;
import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverityClassifierInterface;
import org.openmetadata.service.util.jdbi.DatabaseAuthenticationProviderFactory;
import org.quartz.SchedulerException;
@@ -136,6 +137,9 @@ public class OpenMetadataApplication extends Application<OpenMetadataApplication
NoSuchAlgorithmException {
validateConfiguration(catalogConfig);
// Instantiate incident severity classifier
IncidentSeverityClassifierInterface.createInstance(catalogConfig.getDataQualityConfiguration());
// init for dataSourceFactory
DatasourceConfig.initialize(catalogConfig.getDataSourceFactory().getDriverClass());


@@ -23,6 +23,7 @@ import javax.validation.constraints.NotNull;
import lombok.Getter;
import lombok.Setter;
import org.openmetadata.schema.api.configuration.apps.AppsPrivateConfiguration;
import org.openmetadata.schema.api.configuration.dataQuality.DataQualityConfiguration;
import org.openmetadata.schema.api.configuration.events.EventHandlerConfiguration;
import org.openmetadata.schema.api.configuration.pipelineServiceClient.PipelineServiceClientConfiguration;
import org.openmetadata.schema.api.fernet.FernetConfiguration;
@@ -94,6 +95,9 @@ public class OpenMetadataApplicationConfig extends Configuration {
@JsonProperty("web")
private OMWebConfiguration webConfiguration = new OMWebConfiguration();
@JsonProperty("dataQualityConfiguration")
private DataQualityConfiguration dataQualityConfiguration;
@JsonProperty("applications")
private AppsPrivateConfiguration appsPrivateConfiguration;


@@ -13,12 +13,14 @@ import java.util.UUID;
import javax.json.JsonPatch;
import javax.ws.rs.core.Response;
import org.jdbi.v3.sqlobject.transaction.Transaction;
import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.api.feed.ResolveTask;
import org.openmetadata.schema.entity.feed.Thread;
import org.openmetadata.schema.entity.teams.User;
import org.openmetadata.schema.tests.TestCase;
import org.openmetadata.schema.tests.type.Assigned;
import org.openmetadata.schema.tests.type.Resolved;
import org.openmetadata.schema.tests.type.Severity;
import org.openmetadata.schema.tests.type.TestCaseResolutionStatus;
import org.openmetadata.schema.tests.type.TestCaseResolutionStatusTypes;
import org.openmetadata.schema.type.EntityReference;
@@ -35,6 +37,7 @@ import org.openmetadata.service.util.EntityUtil;
import org.openmetadata.service.util.JsonUtils;
import org.openmetadata.service.util.RestUtil;
import org.openmetadata.service.util.ResultList;
import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverityClassifierInterface;
public class TestCaseResolutionStatusRepository
extends EntityTimeSeriesRepository<TestCaseResolutionStatus> {
@@ -168,6 +171,8 @@ public class TestCaseResolutionStatusRepository
: recordEntity.getSeverity());
}
inferIncidentSeverity(recordEntity);
switch (recordEntity.getTestCaseResolutionStatusType()) {
case New -> {
// If there is already an existing New incident we'll return it
@@ -300,4 +305,29 @@ public class TestCaseResolutionStatusRepository
FeedRepository feedRepository = Entity.getFeedRepository();
feedRepository.patchThread(null, originalTask.getId(), user, patch);
}
public void inferIncidentSeverity(TestCaseResolutionStatus incident) {
if (incident.getSeverity() != null) {
// If the severity is already set, we don't need to infer it
return;
}
IncidentSeverityClassifierInterface incidentSeverityClassifier =
IncidentSeverityClassifierInterface.getInstance();
EntityReference testCaseReference = incident.getTestCaseReference();
TestCase testCase =
Entity.getEntityByName(
testCaseReference.getType(),
testCaseReference.getFullyQualifiedName(),
"",
Include.ALL);
MessageParser.EntityLink entityLink = MessageParser.EntityLink.parse(testCase.getEntityLink());
EntityInterface entity =
Entity.getEntityByName(
entityLink.getEntityType(),
entityLink.getEntityFQN(),
"followers,owner,tags,votes",
Include.ALL);
Severity severity = incidentSeverityClassifier.classifyIncidentSeverity(entity);
incident.setSeverity(severity);
}
}
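A small illustration of the early return in inferIncidentSeverity (a hypothetical usage sketch, not part of the diff; repository stands in for an existing TestCaseResolutionStatusRepository): a resolution status that already carries a severity keeps it, so manually assigned severities are never overwritten by the classifier.

  // Hypothetical sketch: an already-set severity short-circuits inference.
  void severityIsPreserved(TestCaseResolutionStatusRepository repository) {
    TestCaseResolutionStatus status = new TestCaseResolutionStatus();
    status.setSeverity(Severity.Severity3);
    repository.inferIncidentSeverity(status);
    // No classifier call or entity lookup happens; the severity is unchanged.
    assert status.getSeverity() == Severity.Severity3;
  }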


@@ -0,0 +1,51 @@
package org.openmetadata.service.util.incidentSeverityClassifier;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.api.configuration.dataQuality.DataQualityConfiguration;
import org.openmetadata.schema.tests.type.Severity;
@Slf4j
public abstract class IncidentSeverityClassifierInterface {
protected static IncidentSeverityClassifierInterface instance;
public static IncidentSeverityClassifierInterface getInstance() {
if (instance == null) {
LOG.info(
"Incident severity classifier instance is null. Default to LogisticRegressionClassifier");
instance = new LogisticRegressionIncidentSeverityClassifier();
}
return instance;
}
public static void createInstance(DataQualityConfiguration dataQualityConfiguration) {
instance = getClassifierClass(dataQualityConfiguration.getSeverityIncidentClassifier());
}
private static IncidentSeverityClassifierInterface getClassifierClass(
String severityClassifierClassString) {
IncidentSeverityClassifierInterface incidentSeverityClassifier;
try {
Class<?> severityClassifierClass = Class.forName(severityClassifierClassString);
Constructor<?> severityClassifierConstructor = severityClassifierClass.getConstructor();
incidentSeverityClassifier =
(IncidentSeverityClassifierInterface) severityClassifierConstructor.newInstance();
} catch (ClassNotFoundException
| NoSuchMethodException
| IllegalAccessException
| InstantiationException
| InvocationTargetException e) {
LOG.error(
"Error occurred while initializing the incident severity classifier. Default to LogisticRegressionClassifier",
e);
// If we encounter an error while initializing the incident severity classifier, we default to
// the logistic regression classifier
incidentSeverityClassifier = new LogisticRegressionIncidentSeverityClassifier();
}
return incidentSeverityClassifier;
}
public abstract Severity classifyIncidentSeverity(EntityInterface entity);
}
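To illustrate the pluggable design (not part of this change): any implementation only needs to extend the abstract class and expose a public no-argument constructor, since createInstance loads it reflectively through Class.forName and getConstructor().newInstance(). The class name and heuristic below are hypothetical.

  package org.openmetadata.service.util.incidentSeverityClassifier;

  import org.openmetadata.schema.EntityInterface;
  import org.openmetadata.schema.tests.type.Severity;

  // Hypothetical example: rank incidents on owned assets as more severe than
  // incidents on unowned ones. It would be selected by setting
  // severityIncidentClassifier to this class's fully qualified name.
  public class OwnerBasedIncidentSeverityClassifier extends IncidentSeverityClassifierInterface {

    public OwnerBasedIncidentSeverityClassifier() {
      // public no-arg constructor, required for reflective instantiation
    }

    @Override
    public Severity classifyIncidentSeverity(EntityInterface entity) {
      return entity.getOwner() != null ? Severity.Severity2 : Severity.Severity4;
    }
  }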


@@ -0,0 +1,127 @@
package org.openmetadata.service.util.incidentSeverityClassifier;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.tests.type.Severity;
import org.openmetadata.schema.type.TagLabel;
@Slf4j
public class LogisticRegressionIncidentSeverityClassifier
extends IncidentSeverityClassifierInterface {
// The coef. matrix represents the weights of the logistic regression model. It has shape
// (n_features, n_classes), where the feature rows are, respectively:
// - row 0: 'Tier' (1, 2, 3, 4, 5) of the asset
// - row 1: 'HasOwner' 1 if the asset has an owner, 0 otherwise
// - row 2: 'Followers' number of followers of the asset
// - row 3: 'Votes' number of upvotes of the asset
// Coef. matrix was generated using synthetic data.
static final double[][] coefMatrix = {
new double[] {-39.7199427, -3.16664212, 7.52955733, 16.7600252, 18.5970022},
new double[] {65.6563864, 9.33015912, -3.11353307, -13.7841793, -58.0888332},
new double[] {0.0102508192, 0.00490356651, -0.00162766138, -0.00622724217, -0.0072994822},
new double[] {0.0784018717, -0.01140259, -0.00911123152, -0.0237962385, -0.0340918118},
};
@Override
public Severity classifyIncidentSeverity(EntityInterface entity) {
double[] vectorX = getVectorX(entity);
if (vectorX.length == 0) {
return null;
}
try {
double[] vectorZ = dotProduct(vectorX);
double[] softmaxVector = softmax(vectorZ);
int predictedClass = argmax(softmaxVector);
switch (predictedClass) {
case 0:
return Severity.Severity1;
case 1:
return Severity.Severity2;
case 2:
return Severity.Severity3;
case 3:
return Severity.Severity4;
case 4:
return Severity.Severity5;
}
} catch (Exception e) {
LOG.error("Error occurred while classifying incident severity", e);
}
return null;
}
private double[] dotProduct(double[] vectorX) {
  // compute the dot product of the input vector (1 x n_features) and the
  // coef. matrix (n_features x n_classes), yielding one logit per class
  double[] result = new double[coefMatrix[0].length];
  for (int i = 0; i < coefMatrix[0].length; i++) {
    double sum = 0;
    for (int j = 0; j < vectorX.length; j++) {
      sum += vectorX[j] * coefMatrix[j][i];
    }
    result[i] = sum;
  }
  return result;
}
private double[] softmax(double[] vectorZ) {
// compute the softmax of the z vector
double expSum = Arrays.stream(vectorZ).map(Math::exp).sum();
double[] softmax = new double[vectorZ.length];
for (int i = 0; i < vectorZ.length; i++) {
softmax[i] = Math.exp(vectorZ[i]) / expSum;
}
return softmax;
}
private int argmax(double[] softmaxVector) {
  // return the index of the max value in the softmax vector
  // (i.e. the predicted class)
  int maxIndex = 0;
  double maxValue = 0;
  for (int i = 0; i < softmaxVector.length; i++) {
    if (softmaxVector[i] > maxValue) {
      maxValue = softmaxVector[i];
      maxIndex = i;
    }
  }
  return maxIndex;
}
private double[] getVectorX(EntityInterface entity) {
// get the input vector for the logistic regression model
double hasOwner = entity.getOwner() != null ? 1 : 0;
double followers = entity.getFollowers() != null ? entity.getFollowers().size() : 0;
double votes = entity.getVotes() != null ? entity.getVotes().getUpVotes() : 0;
double tier = entity.getTags() != null ? getTier(entity.getTags()) : 0;
if (tier == 0) {
// if we don't have a tier set we can't run the classifier
return new double[] {};
}
return new double[] {tier, hasOwner, followers, votes};
}
private double getTier(List<TagLabel> tags) {
// get the tier of the asset
for (TagLabel tag : tags) {
if (tag.getName().contains("Tier")) {
switch (tag.getName()) {
case "Tier1":
return 1;
case "Tier2":
return 2;
case "Tier3":
return 3;
case "Tier4":
return 4;
case "Tier5":
return 5;
}
}
}
return 0;
}
}
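As a rough worked example of the model above (a sketch, not part of the change): a Tier1 asset with an owner, no followers and no votes yields the feature vector [1, 1, 0, 0], so only the first two rows of the coefficient matrix contribute to the logits. The first logit (about 25.94) dominates, and since softmax preserves ordering the predicted class is 0, which maps to Severity1.

  public class SeverityWorkedExample {
    public static void main(String[] args) {
      // Feature vector [tier, hasOwner, followers, votes] for a hypothetical
      // Tier1 asset that has an owner, no followers and no votes.
      double[] x = {1, 1, 0, 0};
      double[][] w = {
        {-39.7199427, -3.16664212, 7.52955733, 16.7600252, 18.5970022},
        {65.6563864, 9.33015912, -3.11353307, -13.7841793, -58.0888332},
        {0.0102508192, 0.00490356651, -0.00162766138, -0.00622724217, -0.0072994822},
        {0.0784018717, -0.01140259, -0.00911123152, -0.0237962385, -0.0340918118},
      };
      // z = x . W: one logit per severity class
      double[] z = new double[w[0].length];
      for (int k = 0; k < z.length; k++) {
        for (int j = 0; j < x.length; j++) {
          z[k] += x[j] * w[j][k];
        }
      }
      // z is approximately [25.94, 6.16, 4.42, 2.98, -39.49]; softmax keeps the
      // ordering, so argmax(softmax(z)) = 0, i.e. Severity.Severity1.
      System.out.println(java.util.Arrays.toString(z));
    }
  }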


@@ -69,6 +69,7 @@ import org.openmetadata.schema.tests.type.TestSummary;
import org.openmetadata.schema.type.ChangeDescription;
import org.openmetadata.schema.type.Column;
import org.openmetadata.schema.type.ColumnDataType;
import org.openmetadata.schema.type.TagLabel;
import org.openmetadata.schema.type.TaskStatus;
import org.openmetadata.service.Entity;
import org.openmetadata.service.resources.EntityResourceTest;
@@ -77,6 +78,7 @@ import org.openmetadata.service.resources.feeds.FeedResourceTest;
import org.openmetadata.service.util.JsonUtils;
import org.openmetadata.service.util.ResultList;
import org.openmetadata.service.util.TestUtils;
import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverityClassifierInterface;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@Slf4j
@@ -1565,6 +1567,24 @@ public class TestCaseResourceTest extends EntityResourceTest<TestCase, CreateTes
"Incident with status [Assigned] cannot be moved to [Ack]");
}
@Test
public void testInferSeverity(TestInfo test) {
IncidentSeverityClassifierInterface severityClassifier =
IncidentSeverityClassifierInterface.getInstance();
// TEST_TABLE1 has no tier information, hence severity should be null as the classifier
// won't be able to infer it
Severity severity = severityClassifier.classifyIncidentSeverity(TEST_TABLE1);
assertNull(severity);
List<TagLabel> tags = new ArrayList<>();
tags.add(new TagLabel().withTagFQN("Tier.Tier1").withName("Tier1"));
TEST_TABLE1.setTags(tags);
// With tier set to Tier1, the severity should be inferred
severity = severityClassifier.classifyIncidentSeverity(TEST_TABLE1);
assertNotNull(severity);
}
public void deleteTestCaseResult(String fqn, Long timestamp, Map<String, String> authHeaders)
throws HttpResponseException {
WebTarget target = getCollection().path("/" + fqn + "/testCaseResult/" + timestamp);


@@ -222,4 +222,8 @@ email:
serverPort: ""
username: ""
password: ""
transportationStrategy: "SMTP_TLS"
transportationStrategy: "SMTP_TLS"
dataQualityConfiguration:
severityIncidentClassifier: "org.openmetadata.service.util.incidentSeverityClassifier.LogisticRegressionIncidentSeverityClassifier"


@@ -0,0 +1,18 @@
{
"$id": "https://open-metadata.org/schema/entity/configuration/dataQualityConfiguration.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "DataQualityConfiguration",
"description": "This schema defines the Data Quality Configuration",
"type": "object",
"javaType": "org.openmetadata.schema.api.configuration.dataQuality.DataQualityConfiguration",
"properties": {
"severityIncidentClassifier": {
"description": "Class Name for the severity incident classifier.",
"type": "string"
}
},
"required": [
"severityIncidentClassifier"
],
"additionalProperties": false
}