scipy.optimize.leastsqを使って多項式最小二乗の結果をグラフにプロット
フィッティング範囲を起動時引数で変える
import sys
import csv
from pprint import pprint
from math import sqrt
import numpy as np
from scipy import optimize
from matplotlib import pyplot as plt
"""
scipy.optimize.leastsqを使って多項式最小二乗の結果をグラフにプロット
フィッティング範囲を起動時引数で変える
"""
#=============================
# 大域変数の定義
#=============================
# CSVファイル
infile = 'data.csv'
# フィッティングパラメータ初期値。線形最小二乗の場合は適当
ai0 = [0, 0, 0]
fitrange = [-1e10, 1e10]
# グラフのフォントサイズ
fontsize = 24
# 起動時引数でフィッティング範囲を変更
argv = sys.argv
print("argv=", argv)
if len(argv) >= 2:
fitrange[0] = float(argv[1])
if len(argv) >= 3:
fitrange[1] = float(argv[2])
print("")
print("fitting range: ", fitrange)
#=============================
# 最小化する関数の定義
#=============================
def ycal(ai, x):
return ai[0] + ai[1] * x + ai[2] * x * x
def residual(ai, x, y):
res = []
for i in range(len(x)):
res.append(y[i] - ycal(ai, x[i]))
return res
#=============================
# csvファイルの読み込み
#=============================
i = 0
x = []
y = []
with open(infile, "r") as f:
reader = csv.reader(f)
for row in reader:
if i == 0:
header = row
else:
xi = float(row[0])
if fitrange[0] <= xi <= fitrange[1]:
x.append(xi)
y.append(float(row[1]))
i = i + 1
print("")
print("CSV data:")
print(" header:", header)
print(" x:", x)
print(" y:", y)
#=============================
# scipy.optimize()による最小化
#=============================
print("")
print("polynomial fit by scipy.optimize() start:")
# leastsqの戻り値は、最適化したパラメータのリストと、最適化の結果
ai, cov_x, inf, mesg, ier = optimize.leastsq(residual, ai0, args= (x, y), full_output = True)
print(" lsq result: ai=", ai)
res = sqrt(sum(inf['fvec']*inf['fvec']) / len(x))
print(" residual=", res)
print(" y = {} + {} * x + {} * x^2".format(ai[0], ai[1], ai[2]))
#=============================
# グラフの表示
#=============================
#表示データの作成
xmin = min(x)
xmax = max(x)
ncal = 100
xstep = (xmax - xmin) / (ncal - 1)
xc = []
yc = []
for i in range(ncal):
xi = xmin + i * xstep
yi = ycal(ai, xi)
xc.append(xi)
yc.append(yi)
#グラフの作成、表示
plt.plot(x, y, label='raw data', marker = 'o', linestyle = 'None')
plt.plot(xc, yc, label='fitted', linestyle = 'dashed')
plt.title(infile, fontsize = fontsize)
plt.xlabel(header[0], fontsize = fontsize)
plt.ylabel(header[1], fontsize = fontsize)
plt.legend(fontsize = fontsize)
plt.tick_params(labelsize = fontsize)
plt.tight_layout()
plt.show()