Step-by-step k-NN実装#

config = dict(
    date = "2022/07/09",
    author = "村上力",
)
config
{'date': '2022/07/09', 'author': '村上力'}

k-NNクラスを一から実装するのは少し大変です。このノートでは、いきなりk-NNをクラスとして実装することが難しいなぁ…と思っている人のために、一歩一歩k-NNの実装を行っていきます。

さて、k-NNを行うためには、

Important

k-NNクラスに必要な機能:

  1. 初期化: どんな条件のk-NNを行うのかを決める。

  2. 訓練: データとラベルを保存する

  3. 予測: 保存されたデータとラベルを使って未知データのクラスを予測する

の3ステップが必要でした。(これは、scikit-learnの機械学習モデルクラスの実装に倣った設計になっています。)

そこで、ここでは初期化、訓練、予測の3ステップに実装を分けて考えていきます。

データの準備#

実装したk-NNがちゃんと動いているのかを確かめるために、今回のコードではiris datasetを例として利用します。

import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
data1 = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data1.data, data1.target, test_size=0.3)
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

初期化の実装#

最終的にはk-NNをクラスとして実装することが目的ですが、まずは訓練と予測をそれぞれ関数として実装します。ただし、関数の中ではできるだけ、「関数の引数」と「関数の中で定義した変数」のみが利用可能であることを意識して下さい。

さて、まずはk-NNの初期化ステップです。ここでは、機械学習の訓練をする前に、どのような設定でこの機械学習モデルを構築するのかを定義します。

関数として必要な機能を実装していくので、クラスは必要ないのですが、あくまでも「辞書型っぽいなにか」として利用するために、クラスを用意しておきます。

class knnに必要な情報をまとめたデータ構造:
    def __init__(self):
        ...

さて、k-NNを利用するときに、最低限確認したい情報は

  • kを何にするのか

だけです。 つまりこの下で用意する「knnに必要な情報をまとめたデータ構造」は、メンバ変数としてk(近傍の何個の点を見るのか」のみを持っていることになりそうです。 これを上のクラスに追加します。

class knnに必要な情報をまとめたデータ構造:
    def __init__(self, k:int):
        self.k = k

これで、これ以降に実装する関数には、このdata_containerを引数として渡すことで、kの値も参照できるようになりました。

data_container = knnに必要な情報をまとめたデータ構造(1) # ここではkを適当に1とします。
print(data_container.k)
1

訓練関数の実装#

さて次に、この必要情報と、訓練データと訓練ラベルを受け取って、訓練を行う関数を用意します。

ただし、knnが訓練ステップで行うのは、訓練というよりはデータペアをすべて保存するだけの「暗記」です。

# 「必要情報」という引数には、必ず「knnに必要な情報をまとめたデータ構造」クラスの
# インスタンス(つまりdata_container)を渡して下さい。
def 訓練(必要情報, 教師データ, 教師ラベル):
    必要情報.教師データ = 教師データ
    必要情報.教師ラベル = 教師ラベル
    return 必要情報

ここで定義した関数は、引数として受け取ったデータを必要情報(第一引数)の新しいメンバ変数として追加して、更新した第一引数のオブジェクトをそのまま帰すだけの仕事をします。

data_container = 訓練(data_container, X_train, y_train)

訓練を行ったあとのdata_containerの中身を確認してみましょう。

data_container.k
1

data_containerが持っている「教師データ」をみてみましょう.

data_container.教師データ
Hide code cell output
array([[ 1.1934322 ,  0.08286181,  0.63753499,  0.39303303],
       [-1.12928107, -1.60890016, -0.25799911, -0.25380816],
       [ 0.14821123, -2.09226072,  0.69350587,  0.39303303],
       [-0.78087408,  2.49966462, -1.26547496, -1.41812232],
       [-1.24541674,  0.80790265, -1.04159144, -1.28875408],
       [-1.12928107,  0.08286181, -1.26547496, -1.41812232],
       [ 1.1934322 ,  0.08286181,  0.74947675,  1.42797894],
       [ 1.77411052, -0.64217903,  1.30918556,  0.91050599],
       [ 0.14821123, -0.88385931,  0.74947675,  0.52240127],
       [ 1.1934322 ,  0.08286181,  0.91738939,  1.16924246],
       [ 0.96116088,  0.08286181,  0.52559322,  0.39303303],
       [ 0.61275388,  0.32454209,  0.41365146,  0.39303303],
       [-0.89700975,  1.5329435 , -1.26547496, -1.0300176 ],
       [-1.3615524 ,  0.32454209, -1.37741672, -1.28875408],
       [-1.70995939,  0.32454209, -1.37741672, -1.28875408],
       [ 0.14821123, -0.40049875,  0.41365146,  0.39303303],
       [ 0.03207556,  0.32454209,  0.5815641 ,  0.78113775],
       [ 0.38048256, -0.64217903,  0.5815641 ,  0.78113775],
       [-0.89700975,  1.77462378, -1.26547496, -1.15938584],
       [ 0.72888955, -0.15881847,  0.97336027,  0.78113775],
       [-0.89700975,  1.77462378, -1.04159144, -1.0300176 ],
       [-0.31633143, -0.88385931,  0.24573882,  0.13429655],
       [ 0.61275388, -0.64217903,  1.02933115,  1.2986107 ],
       [-0.43246709, -1.60890016, -0.03411558, -0.25380816],
       [ 0.61275388, -0.88385931,  0.86141851,  0.91050599],
       [-0.78087408,  1.04958293, -1.26547496, -1.28875408],
       [ 0.26434689, -0.40049875,  0.52559322,  0.26366479],
       [ 0.96116088,  0.56622237,  1.08530203,  1.16924246],
       [ 2.12251751, -0.15881847,  1.30918556,  1.42797894],
       [-0.43246709, -1.12553959,  0.35768058,  0.00492831],
       [-0.20019576, -0.64217903,  0.18976794,  0.13429655],
       [-1.24541674,  0.08286181, -1.20950408, -1.28875408],
       [ 1.30956787,  0.32454209,  0.52559322,  0.26366479],
       [-0.20019576,  3.22470546, -1.26547496, -1.0300176 ],
       [-0.20019576, -0.40049875,  0.24573882,  0.13429655],
       [ 1.07729654, -0.64217903,  0.5815641 ,  0.26366479],
       [-1.12928107, -0.15881847, -1.32144584, -1.28875408],
       [-1.47768807,  0.32454209, -1.32144584, -1.28875408],
       [-0.20019576, -0.15881847,  0.24573882,  0.00492831],
       [-0.54860275,  1.5329435 , -1.26547496, -1.28875408],
       [-0.89700975,  0.80790265, -1.26547496, -1.28875408],
       [-0.54860275,  0.80790265, -1.26547496, -1.0300176 ],
       [ 0.61275388, -0.64217903,  1.02933115,  1.16924246],
       [ 1.54183919,  0.32454209,  1.25321467,  0.78113775],
       [-0.89700975,  1.04958293, -1.32144584, -1.15938584],
       [-1.12928107,  0.08286181, -1.26547496, -1.28875408],
       [ 0.26434689, -0.64217903,  0.13379706,  0.13429655],
       [ 0.96116088, -0.15881847,  0.69350587,  0.65176951],
       [ 2.12251751,  1.77462378,  1.64501084,  1.2986107 ],
       [ 0.61275388,  0.32454209,  0.86141851,  1.42797894],
       [-1.47768807,  0.08286181, -1.26547496, -1.28875408],
       [-0.0840601 , -0.88385931,  0.74947675,  0.91050599],
       [-0.54860275,  2.01630406, -1.1535332 , -1.0300176 ],
       [-0.43246709, -1.36721988,  0.13379706,  0.13429655],
       [-0.20019576, -1.12553959, -0.14605735, -0.25380816],
       [-1.82609506, -0.15881847, -1.48935849, -1.41812232],
       [-0.20019576, -1.36721988,  0.69350587,  1.03987423],
       [-1.01314541,  1.04958293, -1.37741672, -1.15938584],
       [ 1.07729654,  0.32454209,  1.19724379,  1.42797894],
       [-1.01314541, -1.85058044, -0.25799911, -0.25380816],
       [-0.54860275,  0.80790265, -1.1535332 , -1.28875408],
       [ 2.00638185, -0.15881847,  1.58903996,  1.16924246],
       [ 1.65797486, -0.40049875,  1.42112732,  0.78113775],
       [ 2.12251751, -1.12553959,  1.7569526 ,  1.42797894],
       [ 0.49661822,  0.56622237,  0.52559322,  0.52240127],
       [ 0.49661822,  0.56622237,  1.25321467,  1.68671542],
       [ 0.84502521, -0.15881847,  0.35768058,  0.26366479],
       [-0.78087408,  0.80790265, -1.32144584, -1.28875408],
       [-0.20019576,  1.77462378, -1.1535332 , -1.15938584],
       [ 0.14821123, -2.09226072,  0.13379706, -0.25380816],
       [ 0.72888955, -0.15881847,  0.80544763,  1.03987423],
       [-1.70995939, -0.40049875, -1.32144584, -1.28875408],
       [ 0.61275388, -0.40049875,  0.3017097 ,  0.13429655],
       [-1.01314541,  0.32454209, -1.43338761, -1.28875408],
       [ 0.03207556, -0.15881847,  0.24573882,  0.39303303],
       [ 0.26434689, -0.64217903,  0.52559322,  0.00492831],
       [ 0.49661822, -1.36721988,  0.63753499,  0.39303303],
       [ 0.96116088,  0.56622237,  1.08530203,  1.68671542],
       [ 0.61275388,  0.08286181,  0.97336027,  0.78113775],
       [-1.24541674, -0.15881847, -1.32144584, -1.15938584],
       [-1.3615524 ,  0.32454209, -1.20950408, -1.28875408],
       [ 0.96116088, -0.15881847,  0.80544763,  1.42797894],
       [-0.54860275,  2.01630406, -1.37741672, -1.0300176 ],
       [ 1.07729654, -0.15881847,  0.97336027,  1.16924246],
       [-0.0840601 , -0.64217903,  0.74947675,  1.55734718],
       [ 0.72888955, -0.64217903,  0.46962234,  0.39303303],
       [-0.0840601 , -0.88385931,  0.18976794, -0.25380816],
       [ 0.49661822, -0.40049875,  1.02933115,  0.78113775],
       [-0.31633143, -0.40049875, -0.09008647,  0.13429655],
       [ 0.03207556, -0.15881847,  0.74947675,  0.78113775],
       [-1.12928107, -1.36721988,  0.41365146,  0.65176951],
       [ 1.42570353, -0.15881847,  1.19724379,  1.16924246],
       [ 0.84502521, -0.40049875,  0.46962234,  0.13429655],
       [-0.31633143, -0.15881847,  0.41365146,  0.39303303],
       [-1.70995939, -0.15881847, -1.37741672, -1.28875408],
       [ 0.38048256, -2.09226072,  0.41365146,  0.39303303],
       [ 2.35478884,  1.77462378,  1.4770982 ,  1.03987423],
       [-1.47768807,  1.29126322, -1.54532937, -1.28875408],
       [ 0.96116088,  0.08286181,  1.02933115,  1.55734718],
       [-1.01314541,  0.56622237, -1.32144584, -1.28875408],
       [ 0.14821123, -0.15881847,  0.5815641 ,  0.78113775],
       [ 0.49661822, -1.85058044,  0.35768058,  0.13429655],
       [-1.01314541,  1.29126322, -1.32144584, -1.28875408],
       [-1.01314541,  0.80790265, -1.26547496, -1.28875408],
       [-1.01314541, -0.15881847, -1.20950408, -1.28875408]])
data_container.教師ラベル
array([1, 1, 2, 0, 0, 0, 2, 2, 1, 2, 1, 1, 0, 0, 0, 1, 1, 2, 0, 2, 0, 1,
       2, 1, 2, 0, 1, 2, 2, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 2, 2,
       0, 0, 1, 1, 2, 2, 0, 2, 0, 1, 1, 0, 2, 0, 2, 1, 0, 2, 2, 2, 1, 2,
       1, 0, 0, 1, 2, 0, 1, 0, 1, 1, 1, 2, 2, 0, 0, 2, 0, 2, 2, 1, 1, 2,
       1, 2, 2, 2, 1, 1, 0, 1, 2, 0, 2, 0, 2, 1, 0, 0, 0])

予測関数の実装#

k-nnでは未知データと教師データすべてのとの距離を計算して、距離の近い順番にソートします。その後に教師データのラベルがそれぞれ何個有るのかを調べて、最も多いラベルを未知データのラベルとして採用するのでした。

実装を見る前に難しそうな関数を紹介しておきます。

np.argsort#

# np.argsortとは
身長 = [180,150,140,160]
print("np.sort:",np.sort(身長))
print("np.argsort:", np.argsort(身長))
np.sort: [140 150 160 180]
np.argsort: [2 1 3 0]

np.bincount#

# 配列内にある数字がそれぞれ何個あったのかを数えてくれる関数(最大値と同じ要素数のベクトルが返ってくる)
成績 = [4,2,1,3,3,3,5]
np.bincount(成績)
array([0, 1, 1, 3, 1, 1])

距離を計算する関数#

ではまず、距離を計算する関数を実装しましょう。

def 距離を計算する関数(データ点, 教師データ全部):
    距離 = ((データ点 - 教師データ全部) ** 2).sum(axis=1)
    return 距離

与えられたデータが属するクラスを予測する関数#

def 予測(必要情報, クラスを予測したいデータ):
    予測したクラスラベル = [] # リストとして初期化

    for (ループの回数, x) in enumerate(クラスを予測したいデータ):

        # データ点xと教師データすべてとの距離を計算し、distance_vectorに格納する。
        # distance_vectorの要素数は教師データの数と同じになっているはず。
        distance_vector = 距離を計算する関数(x, 必要情報.教師データ)

        # これを小さい順にソートして、データの番号を変数に保存しておく
        sorted_indexes = np.argsort(distance_vector)

        # 先頭からk個だけ取り出して、あとは捨てる
        ご近所さん = sorted_indexes[:必要情報.k]

        # k個のご近所さんの教師ラベルを変数に保存しておく
        ご近所さんのラベル = 必要情報.教師ラベル[ご近所さん]

        # ご近所さんの中で一番多いラベルを見つける
        近所で一番人気のラベル = np.bincount(ご近所さんのラベル).argmax()

        # xのご近所で一番多いのがこのクラスなら、きっとxもこのクラスなんだろうな…
        予測したクラスラベル.append(近所で一番人気のラベル)

        # あとはこれを「クラスを予測したいデータ」すべてに対して行えばすべての予測ができる。

    return np.array(予測したクラスラベル) # 返す時はnumpyの配列としておく(おそらくy_trainもそうだったでしょ?)
pred_labels = 予測(data_container, X_test)

# 正答率
(pred_labels == y_test).sum() / len(X_test)
0.9333333333333333

これでk-NNがとりあえず完成しました。

リファクタリング#

想定しない使われ方への対応#

しかし、訓練をしていないのに予測を走らせるとエラーが出そうです。そのために訓練済みかどうかを判別できるフラグも「必要情報」の一つになりそうですね。

ここまでで実装したコード⇓

class knnに必要な情報をまとめたデータ構造:
    def __init__(self, k:int):
        self.k = k

def 訓練(必要情報, 教師データ, 教師ラベル):
    必要情報.教師データ = 教師データ
    必要情報.教師ラベル = 教師ラベル
    return 必要情報

def 距離を計算する関数(データ点, 教師データ全部):
    距離 = ((データ点 - 教師データ全部) ** 2).sum(axis=1)
    return 距離

def 予測(必要情報, クラスを予測したいデータ):
    予測したクラスラベル = []

    for (ループの回数, x) in enumerate(クラスを予測したいデータ):
        distance_vector = 距離を計算する関数(x, 必要情報.教師データ)
        sorted_indexes = np.argsort(distance_vector)
        ご近所さん = sorted_indexes[:必要情報.k]
        ご近所さんのラベル = 必要情報.教師ラベル[ご近所さん]
        近所で一番人気のラベル = np.bincount(ご近所さんのラベル).argmax()
        予測したクラスラベル.append(近所で一番人気のラベル)
    return np.array(予測したクラスラベル)

訓練済みフラグを追加して、予測関数で訓練済みかを確認するように変更したコード⇓

class knnに必要な情報をまとめたデータ構造:
    def __init__(self, k:int):
        self.k = k
        self.is_fitted = False

def 訓練(必要情報, 教師データ, 教師ラベル):
    必要情報.教師データ = 教師データ
    必要情報.教師ラベル = 教師ラベル
    必要情報.is_fitted = True
    return 必要情報

def 距離を計算する関数(データ点, 教師データ全部):
    距離 = ((データ点 - 教師データ全部) ** 2).sum(axis=1)
    return 距離

def 予測(必要情報, クラスを予測したいデータ):
    assert 必要情報.is_fitted, "先に訓練してから予測して下さい"
    予測したクラスラベル = []

    for (ループの回数, x) in enumerate(クラスを予測したいデータ):
        distance_vector = 距離を計算する関数(x, 必要情報.教師データ)
        sorted_indexes = np.argsort(distance_vector)
        ご近所さん = sorted_indexes[:必要情報.k]
        ご近所さんのラベル = 必要情報.教師ラベル[ご近所さん]
        近所で一番人気のラベル = np.bincount(ご近所さんのラベル).argmax()
        予測したクラスラベル.append(近所で一番人気のラベル)
    return np.array(予測したクラスラベル)

訓練しないで予測しようとした場合:

# 訓練しないで予測しようとした場合

data_container = knnに必要な情報をまとめたデータ構造(3)
pred_labels = 予測(data_container, X_test)

このコードを実行すると次のようなエラーが表示されるはずです.

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[31], line 2
      1 data_container = knnに必要な情報をまとめたデータ構造(3)
----> 2 pred_labels = 予測(data_container, X_test)

Cell In[21], line 17
     16 def 予測(必要情報, クラスを予測したいデータ):
---> 17     assert 必要情報.is_fitted, "先に訓練してから予測して下さい"
     18     予測したクラスラベル = []
     20     for (ループの回数, x) in enumerate(クラスを予測したいデータ):

AssertionError: 先に訓練してから予測して下さい

(自分でも試してみてください)

これに対して訓練後に予測をしようとした場合は問題なく動作するはずです.

# 訓練後に予測をしようとした場合

data_container = knnに必要な情報をまとめたデータ構造(3)
data_container = 訓練(data_container, X_train,y_train)
pred_labels = 予測(data_container, X_test)

これで例外処理ができそうです。

変数名などを英語に直していく#

さて、このコードでも十分動くのですが、変数名や関数名は英語で書くのが普通です。とりあえず、いくつか英語に直してみます。 また、この際に「必要情報」としていた部分を、全てselfに置き換えておきます。

… 教師データとそのラベルをX_train, y_train. (training dataなのでtrain)。テストデータとそのラベルをX_test, y_testとするのでした。これについてもここで直しておきます。

class knnに必要な情報をまとめたデータ構造:
    def __init__(self, k:int):
        self.k = k
        self.is_fitted = False

def fit(self, X_train, y_train):
    self.X_train = X_train
    self.y_train = y_train
    self.is_fitted = True
    return self

def compute_distance(data1, data2):
    distance = ((data1 - data2) ** 2).sum(axis=1)
    return distance

def predict(self, X_test):
    assert self.is_fitted, "先に訓練してから予測して下さい"
    pred_labels = []

    for (loop_counter, x) in enumerate(X_test):
        distance_vector = compute_distance(x, self.X_train)
        sorted_indexes = np.argsort(distance_vector)
        neighbors = sorted_indexes[:self.k]
        neighbors_label = self.y_train[neighbors]
        popular_label = np.bincount(neighbors_label).argmax()
        pred_labels.append(popular_label)
    return np.array(pred_labels)

よく見るプログラムらしくなってきましたね。

data_container = knnに必要な情報をまとめたデータ構造(3)
data_container = fit(data_container, X_train,y_train)
pred_labels = predict(data_container, X_test)

# 正答率
(pred_labels == y_test).sum() / len(X_test)
0.9333333333333333

関数からクラスへ#

さて、ここで、fit, predictは「 knnに必要な情報をまとめたデータ構造」のインスタンスを第一引数に取る関数でした。それ以外のデータ構造を渡しても、おそらくエラーが出てしまいそうです。

このような「あるデータ構造専用の関数」のことをメソッドと呼び、classの中で定義する事ができます。

def compute_distance(data1, data2):
        distance = ((data1 - data2) ** 2).sum(axis=1)
        return distance

class kNN:
    def __init__(self, k:int):
        self.k = k
        self.is_fitted = False

    def fit(self, X_train, y_train):
        self.X_train = X_train
        self.y_train = y_train
        self.is_fitted = True
        return self

    def predict(self, X_test):
        assert self.is_fitted, "先に訓練してから予測して下さい"
        pred_labels = []

        for (loop_counter, x) in enumerate(X_test):
            distance_vector = compute_distance(x, self.X_train)
            sorted_indexes = np.argsort(distance_vector)
            neighbors = sorted_indexes[:self.k]
            neighbors_label = self.y_train[neighbors]
            popular_label = np.bincount(neighbors_label).argmax()
            pred_labels.append(popular_label)
        return np.array(pred_labels)

この書き方をすることで、インスタンス.メソッド(self以外の引数)の形でメソッドの実行が可能です。

また、それぞれのメソッドは、別のメソッドからself.メソッド名で呼び出すことができます。 fitメソッドからpredictを呼び出す時は self.predict とすればいいのです。

selfはインスタンス自体を示しているので、これはインスタンス.メソッドの形になっています。

逆に、それぞれのメソッドの中で、クラス名.メソッドはエラーになります。(ただクラス名を書くだけだとインスタンスになっていませんよね。)

model = kNN(3)
model.fit(X_train,y_train)
pred_labels = model.predict(X_test)

# 正答率
(pred_labels == y_test).sum() / len(X_test)
0.9333333333333333

これで基本的なクラスの実装は完了しました。ただこうなると、仲間はずれのcompute_distanceが寂しそうです。

この関数はselfを受け取る必要のない関数なので、基本的にはクラスの外で定義しておいたほうがいいでしょう。

しかしもしも、「selfを第一引数に置く必要がない関数だけど、実はこのクラスでしか使わないだよなぁ…」とか「このクラスのメソッドとしてまとめておいたほうがわかりやすいんだよなぁ」という場合は、以下のように実装できます。

class kNN:
    def __init__(self, k:int):
        self.k = k
        self.is_fitted = False

    def fit(self, X_train, y_train):
        self.X_train = X_train
        self.y_train = y_train
        self.is_fitted = True
        return self

    def predict(self, X_test):
        assert self.is_fitted, "先に訓練してから予測して下さい"
        pred_labels = []

        for (loop_counter, x) in enumerate(X_test):
            distance_vector = kNN.compute_distance(x, self.X_train)
            sorted_indexes = np.argsort(distance_vector)
            neighbors = sorted_indexes[:self.k]
            neighbors_label = self.y_train[neighbors]
            popular_label = np.bincount(neighbors_label).argmax()
            pred_labels.append(popular_label)
        return np.array(pred_labels)

    @staticmethod
    def compute_distance(data1, data2):
        distance = ((data1 - data2) ** 2).sum(axis=1)
        return distance
model = kNN(3)
model.fit(X_train,y_train)
pred_labels = model.predict(X_test)

# 正答率
(pred_labels == y_test).sum() / len(X_test)
0.9333333333333333

ここで使った@staticmethodは、これがくっついた関数をselfを受け取らないメソッドにしてくれます。

詳しく説明すると長くなるので、ここではその程度のイメージで覚えておいて下さい。self.compute_distanceとしてもアクセスできますし、selfを使っていないので、kNN.cumpute_distanceとしてもアクセスできます。この場合の「インスタンス.」とか「クラス.」は名前空間を指定する程度のニュアンスでしかないんですね。

これで最低限のk-NNクラスの実装ができたはずです。これをもとに、課題に取り組んでみて下さい。