mirror of
https://github.com/autistic-symposium/tensorflow-for-deep-learning-py.git
synced 2025-05-10 10:45:04 -04:00
23 lines
No EOL
676 B
Python
23 lines
No EOL
676 B
Python
import numpy as np
|
|
|
|
class NearestNeighbor(object):
    """Nearest-neighbour classifier that memorises its training set.

    ``train`` stores the training samples and labels; ``predict`` assigns each
    test row the label of the closest stored training row.
    """

    def __init__(self):
        # Nothing to configure up front: all state is set by ``train``.
        pass

    def train(self, X, y):
        """Memorise the training data.

        Args:
            X: training samples, one sample per row. Converted with
                ``np.asarray`` (no copy is made if it is already an ndarray).
            y: training labels, one per row of ``X``. Also converted with
                ``np.asarray`` so that ``predict`` can read its dtype.
        """
        self.Xtr = np.asarray(X)
        self.ytr = np.asarray(y)

    def predict(self, X, distance="L1"):
        """Predict a label for every row of ``X``.

        Args:
            X: test samples, one per row (2-D, same number of columns as the
                training samples passed to ``train``).
            distance: ``"L1"`` (sum of absolute differences, the default) or
                ``"L2"`` (Euclidean distance).

        Returns:
            1-D array of length ``X.shape[0]`` with the dtype of the training
            labels; entry ``i`` is the label of the training row nearest to
            ``X[i]``. Ties go to the earliest training row (``np.argmin``
            returns the first minimum).

        Raises:
            ValueError: if ``distance`` is not ``"L1"`` or ``"L2"``.
        """
        if distance not in ("L1", "L2"):
            raise ValueError(
                "distance must be 'L1' or 'L2', got %r" % (distance,))

        X = np.asarray(X)
        Xtr = self.Xtr
        # Subtracting unsigned ints (e.g. uint8 pixels) wraps around and numpy
        # refuses to subtract booleans, so compute differences in a signed type.
        if np.result_type(Xtr, X).kind in "ub":
            Xtr = Xtr.astype(np.int64)
            X = X.astype(np.int64)

        num_test = X.shape[0]
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # loop over all test rows (one row at a time keeps memory at
        # O(num_train * num_features) instead of a full 3-D broadcast)
        for i in range(num_test):
            diff = Xtr - X[i, :]
            if distance == "L1":
                distances = np.sum(np.abs(diff), axis=1)
            else:
                distances = np.sqrt(np.sum(np.square(diff), axis=1))
            min_index = np.argmin(distances)  # index of the nearest training row
            Ypred[i] = self.ytr[min_index]  # copy that row's label

        return Ypred