PyTorchで線形回帰(二次関数)
はじめに
PyTorchで単純な線形回帰を行いたいと思います.
nn.Linear()を使ってやるのもいいのですが,今回は,重みとバイアス用のTensorを自分で定義する方針で行いました.
色々読み込み
import torch import torch.nn as nn import matplotlib.pyplot as plt import numpy as np
データの準備
N = 100 x = np.linspace(-15, 15, N) y = -2 * x**2 + 10*x + 32*np.random.randn(N) x = x.astype('float32') y = y.astype('float32') x = torch.from_numpy(x) y = torch.from_numpy(y) w0 = torch.tensor(1.0, requires_grad=True) w1 = torch.tensor(1.0, requires_grad=True) b = torch.tensor(0.0, requires_grad=True)
今回は,を線形回帰させます.
よって,なるべくに近づくことが理想です.
また,バイアス項は,N(0, 1)に従う乱数によってバラつくので,0付近に近づくと思います.
グラフ
モデル定義
def model(x): return w0*x**2 + w1*x + b
PyTorchのモデルというと,class MyModel(nn.Module): という定義の仕方が主ですが,
今回はnnモジュールを使った複雑な計算は必要としないので,関数として定義します.
損失関数と最適化手法の定義
criterion = nn.MSELoss()
optimizer = torch.optim.SGD([w0, w1, b], lr=1.0e-5)
損失関数は平均二乗誤差,最適化手法はSGDで学習率は0.00001としました.
また,勾配計算によって更新したいパラメータは,第一引数にリストで渡してやれば良いっぽいです.
(nn.Moduleによる定義では,モデル全体のパラメータを対象にしたいため,model.parameters()としていました)
学習
losses = [] epochs = 10000 es_count = 0 patience = 3 for epoch in range(epochs): optimizer.zero_grad() pred = model(x.view(-1, 1)) loss = criterion(pred, y.view(-1, 1)) loss.backward() optimizer.step() losses.append(loss.item()) if epoch % 100 == 0: print("epoch {}, loss: {}".format(epoch+1, losses[epoch])) # Early Stopping if epoch > 0 and losses[epoch - 1] < losses[epoch]: es_count += 1 if es_count >= patience: break else: es_count = 0 print('loss:', loss.item()) print('w0:', w0.item()) print('w1:', w1.item()) print('b:', b.item())
100epochごとに結果を表示します.
es_countというのは,EarlyStoppingという機能を実装するための変数です.
EarlyStoppingは,指定の回数以上損失が上がり続けた段階で学習をストップさせるというものです.
具体的には,1epoch前のlossと比較し,上回っていればカウント.3(=patience)回連続で続けば学習ストップ,としています.
(今回は10000epochまでlossが上昇することはなかったため,不要でしたが・・・)
結果
epoch 1, loss: 102432.0390625 epoch 101, loss: 5998.56396484375 epoch 201, loss: 4711.43408203125 epoch 301, loss: 3763.884765625 epoch 401, loss: 3066.32275390625 epoch 501, loss: 2552.794921875 epoch 601, loss: 2174.745361328125 epoch 701, loss: 1896.4339599609375 epoch 801, loss: 1691.5440673828125 epoch 901, loss: 1540.706298828125 epoch 1001, loss: 1429.658447265625 (中略) epoch 9001, loss: 1118.526123046875 epoch 9101, loss: 1118.5107421875 epoch 9201, loss: 1118.49560546875 epoch 9301, loss: 1118.480224609375 epoch 9401, loss: 1118.4647216796875 epoch 9501, loss: 1118.4493408203125 epoch 9601, loss: 1118.434326171875 epoch 9701, loss: 1118.4189453125 epoch 9801, loss: 1118.403564453125 epoch 9901, loss: 1118.3885498046875
loss: 1118.3734130859375 w0: -2.0004830360412598 w1: 10.306112289428711 b: 0.38533857464790344
は-2に,は10に近づいているので,それなりに近似できていることがわかります.
グラフ
青が元の関数,オレンジが予測です.
見事に重なっていますね.
lossの推移(おまけ)
断崖絶壁です.これを見る限り,1000epochもすれば十分そうですね.