チラ裏備忘録

情報整理

PyTorchで線形回帰(二次関数)

はじめに

PyTorchで単純な線形回帰を行いたいと思います.
nn.Linear()を使ってやるのもいいのですが,今回は,重みとバイアス用のTensorを自分で定義する方針で行いました.

色々読み込み

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

データの準備

N = 100
x = np.linspace(-15, 15, N)
y = -2 * x**2 + 10*x + 32*np.random.randn(N)

x = x.astype('float32')
y = y.astype('float32')

x = torch.from_numpy(x)
y = torch.from_numpy(y)

w0 = torch.tensor(1.0, requires_grad=True)
w1 = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

今回は, y = -2 x^2 + 10 x + Noiseを線形回帰させます.
よって,なるべく w_0 = -2, w_1 = 10に近づくことが理想です.
また,バイアス項は,N(0, 1)に従う乱数によってバラつくので,0付近に近づくと思います.

グラフ

f:id:spookyboogie:20200926233557p:plain

モデル定義

def model(x):
  return w0*x**2 + w1*x + b

PyTorchのモデルというと,class MyModel(nn.Module): という定義の仕方が主ですが,
今回はnnモジュールを使った複雑な計算は必要としないので,関数として定義します.

損失関数と最適化手法の定義

criterion = nn.MSELoss()
optimizer = torch.optim.SGD([w0, w1, b], lr=1.0e-5)

損失関数は平均二乗誤差,最適化手法はSGDで学習率は0.00001としました.
また,勾配計算によって更新したいパラメータは,第一引数にリストで渡してやれば良いっぽいです.
(nn.Moduleによる定義では,モデル全体のパラメータを対象にしたいため,model.parameters()としていました)

学習

losses = []
epochs = 10000
es_count = 0
patience = 3

for epoch in range(epochs):
  optimizer.zero_grad()

  pred = model(x.view(-1, 1))

  loss = criterion(pred, y.view(-1, 1))
  loss.backward()
  optimizer.step()

  losses.append(loss.item())
  if epoch % 100 == 0:
    print("epoch {}, loss: {}".format(epoch+1, losses[epoch]))

  # Early Stopping
  if epoch > 0 and losses[epoch - 1] < losses[epoch]:
    es_count += 1
    if es_count >= patience:
      break
  else:
    es_count = 0

print('loss:', loss.item())
print('w0:', w0.item())
print('w1:', w1.item())
print('b:', b.item())

100epochごとに結果を表示します.

es_countというのは,EarlyStoppingという機能を実装するための変数です.
EarlyStoppingは,指定の回数以上損失が上がり続けた段階で学習をストップさせるというものです.
具体的には,1epoch前のlossと比較し,上回っていればカウント.3(=patience)回連続で続けば学習ストップ,としています.
(今回は10000epochまでlossが上昇することはなかったため,不要でしたが・・・)

結果

epoch 1, loss: 102432.0390625
epoch 101, loss: 5998.56396484375
epoch 201, loss: 4711.43408203125
epoch 301, loss: 3763.884765625
epoch 401, loss: 3066.32275390625
epoch 501, loss: 2552.794921875
epoch 601, loss: 2174.745361328125
epoch 701, loss: 1896.4339599609375
epoch 801, loss: 1691.5440673828125
epoch 901, loss: 1540.706298828125
epoch 1001, loss: 1429.658447265625

(中略)

epoch 9001, loss: 1118.526123046875
epoch 9101, loss: 1118.5107421875
epoch 9201, loss: 1118.49560546875
epoch 9301, loss: 1118.480224609375
epoch 9401, loss: 1118.4647216796875
epoch 9501, loss: 1118.4493408203125
epoch 9601, loss: 1118.434326171875
epoch 9701, loss: 1118.4189453125
epoch 9801, loss: 1118.403564453125
epoch 9901, loss: 1118.3885498046875
loss: 1118.3734130859375
w0: -2.0004830360412598
w1: 10.306112289428711
b: 0.38533857464790344

 w_0は-2に, w_1は10に近づいているので,それなりに近似できていることがわかります.

グラフ

f:id:spookyboogie:20200926233611p:plain
青が元の関数,オレンジが予測です.
見事に重なっていますね.

lossの推移(おまけ)

f:id:spookyboogie:20200926235032p:plain

断崖絶壁です.これを見る限り,1000epochもすれば十分そうですね.