diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py index b71f86f7..768861f7 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py @@ -212,7 +212,6 @@ class LinearCentroidsInitialisation(AbstractCentroidsInitialisation): def _calculate_centroids_for_interval(self, weight_interval, number_of_clusters_for_interval): if tf.math.less_equal(number_of_clusters_for_interval, 0): - # Return an empty array of centroids return tf.constant([]) weight_min = tf.reduce_min(weight_interval) @@ -247,6 +246,10 @@ class RandomCentroidsInitialisation(AbstractCentroidsInitialisation): def _calculate_centroids_for_interval(self, weight_interval, number_of_clusters_for_interval): + if tf.math.less_equal(number_of_clusters_for_interval, 0): + # Return an empty array of centroids + return tf.constant([]) + weight_min = tf.reduce_min(weight_interval) weight_max = tf.reduce_max(weight_interval) cluster_centroids = tf.random.uniform( diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py index 80d3d338..04e2c2fa 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py @@ -192,6 +192,7 @@ def testDensityBasedClusterCentroidsWithSparsityPreservation( @parameterized.parameters( ([0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], 5), ([0., 1., 2., 3., 3.1, 3.2, 3.3, 3.4, 3.5], 3), + ([1.0, 2.0, 3.0], 3), ([-3., -2., -1., 0., 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.], 6)) def testRandomClusterCentroidsWithSparsityPreservation( self, weights, number_of_clusters):