diff --git a/src/main/java/fr/irit/smac/amas4dc/cluster/CalinskiHarabazIndex.java b/src/main/java/fr/irit/smac/amas4dc/cluster/CalinskiHarabazIndex.java new file mode 100644 index 0000000000000000000000000000000000000000..dc343f9313dbaeefc31153e58177162823efbab8 --- /dev/null +++ b/src/main/java/fr/irit/smac/amas4dc/cluster/CalinskiHarabazIndex.java @@ -0,0 +1,59 @@ +package fr.irit.smac.amas4dc.cluster; + +import fr.irit.smac.amas4dc.amas.MASSettings; + +import java.util.List; + +public class CalinskiHarabazIndex<T extends DataPoint> { + + public double compute(MASSettings<T> masSettings, List<ExtendedCluster<T>> clusters) { + var ssb = computeBetweenClusterVariance(clusters, masSettings); + var ssw = computeWithinClusterVariance(clusters, masSettings); + + if (ssw == 0) { + // Avoid division by zero + return Double.POSITIVE_INFINITY; + } + + return ssb / ssw; + } + + private double computeBetweenClusterVariance(List<ExtendedCluster<T>> clusters, MASSettings<T> masSettings) { + var centroid = computeGlobalCentroid(clusters, masSettings); + var ssb = 0.0; + + for (var cluster : clusters) { + var centroidDist = 1.0/masSettings.similarityScoreMethod().apply(cluster.getRepresentative(), centroid); + ssb += centroidDist * centroidDist * cluster.getContent().size(); + } + + return ssb; + } + + private double computeWithinClusterVariance(List<ExtendedCluster<T>> clusters, MASSettings<T> masSettings) { + var ssw = 0.0; + + for (var cluster : clusters) { + var clusterCentroid = cluster.getRepresentative(); + var clusterSumDist = 0.0; + + for (var point : cluster.getContent()) { + var dist = 1.0/masSettings.similarityScoreMethod().apply(point, clusterCentroid); + clusterSumDist += dist * dist; + } + + ssw += clusterSumDist; + } + + return ssw; + } + + private T computeGlobalCentroid(List<ExtendedCluster<T>> clusters, MASSettings<T> masSettings) { + var fused = clusters.get(0).getRepresentative(); + for (int i = 1; i < clusters.size(); i++) { + fused = masSettings.dataPointFuser().apply(fused, clusters.get(i).getRepresentative()); + } + return fused; + } +} +