perceptron.py

perceptron.py#

先ほどのノートブックのコードから学習と簡単な可視化だけ取り出したPython scriptを以下に示します.(このプログラムをperceptron.pyとして保存してください.)

# packageのimport
from typing import Any, Union, Callable, Type, TypeVar
import numpy as np 
import numpy.typing as npt
import pandas as pd 
import matplotlib.pyplot as plt 
import japanize_matplotlib
import argparse
plt.style.use("bmh")


def parse_args():
    parser = argparse.ArgumentParser(description="Perceptronで論理ゲートを再現する")
    parser.add_argument("--gate_type", type=str, default="or")
    parser.add_argument('--learning_rate', type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save_path", type=str, default="./perceptron.png")
    return parser.parse_args()


def sign(x:float)->float:
    if x > 0:
        return 1.
    return -1

def perceptron(
        X:list[float,float], 
        W:list[float,float],
        b:float)->np.ndarray:
    h = 0.
    for _x, _w in zip(X,W):
        h += _x * _w
    h += b
    return sign(h)

def plot_perceptron(w,b, truth_table, gate_type="or", need_output=False):
    linear = lambda x1,w,b: -(w[0]*x1 + b)/w[1]
    x1_sample = np.linspace(-2,2,100)

    fig, ax = plt.subplots()
    ax.scatter(truth_table.x1, truth_table.x2, c=truth_table[gate_type], label="入力データ")
    ax.plot(x1_sample, linear(x1_sample, W,b), label="閾値, 識別境界")
    ax.set_title(f"$x_1$-$x_2$平面上の{gate_type}問題")
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    plt.legend()
    #ax.set_aspect('equal')
    if need_output:
        return fig
    
def update_perceptron(X,Y, lr=1.0, rng = np.random.default_rng(10)):
    W = rng.normal(0,1, [X.shape[-1]])
    b = float(rng.normal(0,1, 1))
    diff = np.inf
    while diff>0:
        diff = 0
        for (x,y) in zip(X,Y):
            y_hat = perceptron(x, W, b)
            W += lr*(y-y_hat)*x
            b += lr*(y-y_hat)
            diff += y != y_hat
    return W, b

if __name__ == "__main__":
    # 真偽値表の定義
    args = parse_args()
    truth_table = pd.DataFrame(
        np.array([[1,1,0,0],[1,0,1,0]]).T,
        columns=["x1","x2"]
    )
    truth_table["or"] = truth_table.x1 | truth_table.x2
    truth_table["and"] = truth_table.x1 & truth_table.x2
    truth_table["xor"] = np.logical_xor(truth_table.x1,truth_table.x2).astype(truth_table.x1.dtype)
    truth_table["nor"] = [0,0,0,1]

    truth_table[truth_table == 0] = -1

    W,b = update_perceptron(
        truth_table[["x1","x2"]].to_numpy(),
        truth_table[args.gate_type].to_numpy(),
        rng=np.random.default_rng(args.seed),
        lr=args.learning_rate
        )
    fig = plot_perceptron(W,b,truth_table=truth_table, gate_type=args.gate_type,need_output=True)

    Y_hat = []
    for x in truth_table[["x1","x2"]].to_numpy():
        y_hat = perceptron(x, W,b)
        Y_hat.append(y_hat)
    print("正解:",truth_table[args.gate_type].to_numpy(),)
    print("予測:",Y_hat)
    fig.savefig(args.save_path)

perceptron.pyは例えば以下のようにして,端末エミュレータから実行することができます.

python perceptron.py --gate_type=nor