class MeanShift:
    """Weighted mean-shift mode finder using a flat (hard-cutoff) kernel.

    ``points`` is a 2-D array with one row per point: column 0 is the point's
    weight, the remaining columns are its coordinates (e.g. row, col).
    """

    def __init__(self, points):
        # Accept any array-like; float arrays are kept as-is (no copy).
        self.points = np.asarray(points, dtype=float)

    def _shift(self, centroids, bw):
        """Move each centroid to the weighted mean of the points in its window.

        A point is inside a centroid's window when its squared Euclidean
        distance is strictly below ``bw**2``.

        Returns ``(shifted, has_mass)``. ``has_mass`` is a boolean mask with one
        entry per input centroid, True where the window's total weight is
        non-zero. ``shifted`` holds the new positions of those centroids only.
        A window with zero total weight has no mean, so that centroid is
        dropped instead of becoming NaN.
        """
        coords = self.points[:, 1:]                         # (n, d)
        weights = self.points[:, 0]                         # (n,)
        diff = centroids[:, None, :] - coords[None, :, :]   # (k, n, d)
        sq_dist = np.sum(diff ** 2, axis=2)                 # (k, n)
        # Weight of each point if it lies inside the window, else 0.
        window = (sq_dist < bw ** 2) * weights              # (k, n)
        mass = window.sum(axis=1)                           # (k,)
        has_mass = mass != 0
        shifted = (window[has_mass] @ coords) / mass[has_mass, None]
        return shifted, has_mass

    def fit(self, random_starts=100, bw=20, *, max_iter=None, rng=None):
        """Run mean shift from randomly chosen data points and return the modes.

        Args:
            random_starts: number of distinct data points used as starting
                centroids (capped at the number of points).
            bw: bandwidth; radius of the flat kernel window.
            max_iter: optional cap on shift iterations. ``None`` (the default)
                iterates until no centroid moves. At least one iteration is
                always performed.
            rng: optional object with a numpy-style ``choice(a, size, replace)``
                method (e.g. ``np.random.default_rng(seed)``) for reproducible
                starts. ``None`` uses the global ``np.random`` state.

        Returns:
            Array of shape ``(m, n_dims)`` of unique converged centroids, sorted
            lexicographically (``np.unique`` order). Starts whose window holds
            zero total weight are discarded.
        """
        coords = self.points[:, 1:]
        n_points = coords.shape[0]
        if n_points == 0:
            # Nothing to cluster; avoid reductions over empty arrays.
            return coords.copy()

        n_starts = min(random_starts, n_points)
        chooser = np.random if rng is None else rng
        centroids = coords[chooser.choice(n_points, n_starts, replace=False)]

        iteration = 0
        while True:
            shifted, has_mass = self._shift(centroids, bw)
            # Flat-kernel mean shift reaches an exact fixed point once every
            # window's membership stops changing, so exact equality is used.
            converged = has_mass.all() and np.array_equal(centroids, shifted)
            if shifted.shape[0] == 0:
                return shifted
            # Merge centroids that have landed on the same location.
            centroids = np.unique(shifted, axis=0)
            iteration += 1
            if converged or (max_iter is not None and iteration >= max_iter):
                return centroids