Planet simulator: Solve simultaneous second-order differential equations

Download the script from .\planet.py


import sys
import csv
import numpy as np
from math import exp, sqrt, sin, cos, pi
import matplotlib.pyplot as plt


"""
  Planet simulator: Solve simultaneous second-order differential equations
"""



#===================
# constants
#===================
# Newtonian gravitational constant [N m^2 / kg^2]
G = 6.67259e-11
# seconds in one day, and its reciprocal
DayToSecond = 24 * 60 * 60
SecondToDay = 1.0 / DayToSecond
# astronomical unit [m]; AU is the short alias used throughout the code
AU = AstronomicalUnit = 1.49597870e11
# G converted from m^3 / (kg s^2) to AU^3 / (kg day^2), the normalized
# units in which positions (AU) and velocities (AU/day) are integrated
G1 = G * DayToSecond * DayToSecond / AU / AU / AU

#===================
# parameters
#===================
# algorithm to solve differential equations: 'Euler' or 'Verlet'
solver = 'Euler'
fplot = 1 # flag to plot graph: 0: not plot, 1: plot

# planet parameter database
dbfile = 'planet_db.csv'

# step time to solve diff eq, in days (velocities are normalized to AU/day)
dt = 0.1
# maximum steps to be calculated
nt = 20000

# display output control
iprint_interval = 100
nprint_planets = 4

# graph output control
iplotdata_interval = 50
iplot_interval = 500
# graph range (AU)
xgrange = (-5.0, 5.0)
ygrange = (-5.0, 5.0)

# optional command-line overrides, in order: solver, dt, nt, fplot
argv = sys.argv
n = len(argv)
if n >= 2:
    solver = argv[1]
if n >= 3:
    dt = float(argv[2])
if n >= 4:
    nt = int(argv[3])
if n >= 5:
    fplot = int(argv[4])

# The output file names contain the solver name, so they must be built only
# after the command line has had the chance to change `solver`.
# trajectories of planets
outfile = "diffeq2nd_Planet_{}.csv".format(solver)
# conservation check: energies (U potential, K kinetic, E = U + K) and
# momenta (Px, Py, Pz, Pmsm)
outfile2 = "diffeq2nd_Planet_{}_conservation.csv".format(solver)


#===================
# functions
#===================
def readdb(dbfile):
    """Read the planet parameter database from a CSV file.

    The first line is the header. Every column except 'Name' is converted
    to float.

    Args:
        dbfile: path of the CSV file.

    Returns:
        A list with one dict per data row (keyed by the header names), in
        file order. An empty list when the file has no data rows.
    """
    # The with-block guarantees the file is closed; newline='' is what the
    # csv module expects for files it parses.
    with open(dbfile, "r", newline="") as f:
        planets = list(csv.DictReader(f))
    for row in planets:
        # iterate over a snapshot of the keys while rewriting the values
        for key in list(row):
            if key != 'Name':
                row[key] = float(row[key])
    return planets

def sum(array):
    """Return the sum of the elements of *array* as a float.

    Accumulates left to right starting from 0.0, so an empty input gives 0.0.
    NOTE: this module-level name shadows the builtin ``sum`` for the whole
    file; it is kept because other functions in this file call it.
    """
    total = 0.0
    for element in array:
        total = total + element
    return total

def Ptot(it, M, vx, vy, vz):
    """Return the total momentum and the RMS momentum component at step *it*.

    Args:
        it: step index into each planet's velocity history.
        M: planet masses.
        vx, vy, vz: per-planet velocity histories (vx[i][it] is planet i's
            x-velocity at step it).

    Returns:
        (Px, Py, Pz, Pmsm): the summed momentum components and the root
        mean square of all 3 * len(M) individual momentum components.
        For an empty system all four values are 0.0.
    """
    Px = Py = Pz = 0.0
    Psq = 0.0
    # named nplanets (not np) so the numpy alias imported by the file is
    # not shadowed
    nplanets = len(M)
    if nplanets == 0:
        # nothing to average; avoid dividing by zero below
        return Px, Py, Pz, Psq
    for i in range(nplanets):
        Pxi = M[i] * vx[i][it]
        Pyi = M[i] * vy[i][it]
        Pzi = M[i] * vz[i][it]
        Px += Pxi
        Py += Pyi
        Pz += Pzi
        Psq += Pxi*Pxi + Pyi*Pyi + Pzi*Pzi
    Pmsm = sqrt(Psq / 3.0 / nplanets)
    return Px, Py, Pz, Pmsm

# Normalize total momentum to zero (shift every velocity by the same amount)
def normalize_momentum(it, M, x, y, z, vx, vy, vz, fx, fy, fz):
    """Subtract Ptot / Mtot from every planet's velocity at step *it*.

    Prints the total momentum before and after the shift, followed by a
    blank line. Only M and vx/vy/vz are used; the position and force lists
    are accepted but not touched.

    Returns:
        (Px, Py, Pz) after the shift (zero up to rounding).
    """
    Mtot = sum(M)
    P_before = Ptot(it, M, vx, vy, vz)
    print("Pinitial = {}, {}, {}".format(*P_before[:3]))
    # Each component is shifted independently, so handling vx, vy, vz one
    # after another gives the same result as interleaving them per planet.
    for velocity, P in ((vx, P_before[0]), (vy, P_before[1]), (vz, P_before[2])):
        for ip in range(len(M)):
            velocity[ip][it] -= P / Mtot
    Px, Py, Pz, _ = Ptot(it, M, vx, vy, vz)
    print("Pnormalized = {}, {}, {}".format(Px, Py, Pz))
    print("")
    return Px, Py, Pz

# set initial normalized positions, velocities, forces
def initialize(planets, M, x, y, z, vx, vy, vz, fx, fy, fz):
    """Append the step-0 state of every planet to the per-planet lists.

    Each planet starts on the +x axis at its 'Revolution Radius' divided by
    AU and moves along +y with its 'Revolution Velocity' scaled by
    DayToSecond / AU (m -> AU and m/s -> AU/day if the database uses SI
    units; confirm against the database). All other components are 0.0.
    Every list is extended in place with one single-element history list
    per planet.

    Args:
        planets: rows from readdb(); need 'Mass', 'Revolution Radius' and
            'Revolution Velocity'.
        M: receives the masses.
        x, y, z, vx, vy, vz: receive [initial value] lists.
        fx, fy, fz: receive [initial acceleration] lists computed by Fi().
    """
    # AU and DayToSecond are only read, so no `global` declaration is needed.
    for planet in planets:
        M.append(planet['Mass'])
        x.append([planet['Revolution Radius'] / AU])
        y.append([0.0])
        z.append([0.0])
        vx.append([0.0])
        vy.append([planet['Revolution Velocity'] * DayToSecond / AU])
        vz.append([0.0])

    # Accelerations depend on every planet's position, hence a second pass
    # after all positions are in place.
    for i, _planet in enumerate(planets):
        acc_x, acc_y, acc_z = Fi(0, i, M, x, y, z)
        fx.append([acc_x])
        fy.append([acc_y])
        fz.append([acc_z])

# total energy: potential U, kinetic K, and their sum E
def Utot(istep, M, x, y, z, vx, vy, vz, g1=None):
    """Return (U, K, E) of the whole system at step *istep*.

    U is the gravitational potential energy -sum_{i<j} g1 M_i M_j / r_ij
    (negative: gravity is attractive), K is sum_i M_i |v_i|^2 / 2, and
    E = U + K is the total energy used to check conservation.

    Args:
        istep: step index into the position/velocity histories.
        M: planet masses.
        x, y, z, vx, vy, vz: per-planet histories (x[i][istep] etc.).
        g1: gravitational constant in the same normalized units as the
            inputs; defaults to the module-level G1.
    """
    if g1 is None:
        g1 = G1
    U = 0.0
    K = 0.0
    nbody = len(M)
    for i in range(nbody):
        K += 0.5 * M[i] \
          * (vx[i][istep]*vx[i][istep] + vy[i][istep]*vy[i][istep] + vz[i][istep]*vz[i][istep])
        # each pair (i, j) counted once
        for j in range(i + 1, nbody):
            dx = x[j][istep] - x[i][istep]
            dy = y[j][istep] - y[i][istep]
            dz = z[j][istep] - z[i][istep]
            r = sqrt(dx*dx + dy*dy + dz*dz)
            # Potential energy of an attracting pair is negative; with a
            # positive sign E = U + K would not be the conserved quantity.
            U -= g1 * M[i] * M[j] / r
    return U, K, U + K

# i - j interplanet normalized force divided by the i-th planet's mass,
# i.e. the acceleration of planet i caused by planet j
def Fij(istep, i, j, M, x, y, z):
    """Return (ax, ay, az): acceleration of planet i due to planet j.

    Magnitude G1 * M[j] / r^2, directed from planet i towards planet j,
    using positions at step *istep*. Coincident planets (r == 0) raise
    ZeroDivisionError.
    """
    # separation vector pointing from planet i to planet j
    sep = (x[j][istep] - x[i][istep],
           y[j][istep] - y[i][istep],
           z[j][istep] - z[i][istep])
    dist2 = sep[0]*sep[0] + sep[1]*sep[1] + sep[2]*sep[2]
    dist = sqrt(dist2)
    accel = G1 * M[j] / dist2
    # project the magnitude onto the unit vector sep / dist
    return tuple(accel * d / dist for d in sep)

# normalized force on i-th planet divided by its mass (its acceleration)
def Fi(istep, i, M, x, y, z):
    """Return (ax, ay, az): total acceleration of planet i at step *istep*.

    Sums Fij over every other planet j, in increasing j order.
    """
    acc = [0.0, 0.0, 0.0]
    for j in range(len(M)):
        if j != i:
            for k, component in enumerate(Fij(istep, i, j, M, x, y, z)):
                acc[k] += component
    return acc[0], acc[1], acc[2]

#===================
# main routine
#===================
def _euler_step(it, i, M, x, y, z, vx, vy, vz):
    """Return planet i's (x, y, z, vx, vy, vz) at step it+1 by explicit Euler.

    Positions advance with the velocities at step it; velocities advance
    with the acceleration from the positions at step it.
    """
    accx, accy, accz = Fi(it, i, M, x, y, z)
    return (x[i][it] + dt * vx[i][it],
            y[i][it] + dt * vy[i][it],
            z[i][it] + dt * vz[i][it],
            vx[i][it] + dt * accx,
            vy[i][it] + dt * accy,
            vz[i][it] + dt * accz)


def _verlet_step(it, i, M, x, y, z, vx, vy, vz):
    """Return planet i's (x, y, z, vx, vy, vz) at step it+1 by position Verlet.

    Uses x[it+1] = 2 x[it] - x[it-1] + dt^2 a[it], so it needs it >= 1.
    The central difference (x[it+1] - x[it-1]) / (2 dt) is the velocity AT
    step it, so it is written back into vx/vy/vz[i][it]. The returned
    velocity for step it+1 is only a provisional backward difference that
    the next Verlet step overwrites.
    """
    accx, accy, accz = Fi(it, i, M, x, y, z)
    x1 = 2.0 * x[i][it] - x[i][it-1] + dt*dt * accx
    y1 = 2.0 * y[i][it] - y[i][it-1] + dt*dt * accy
    z1 = 2.0 * z[i][it] - z[i][it-1] + dt*dt * accz
    # store the velocity at the instant it belongs to, so energy and
    # momentum at step it pair x(t) with v(t) rather than v(t - dt)
    vx[i][it] = (x1 - x[i][it-1]) / 2.0 / dt
    vy[i][it] = (y1 - y[i][it-1]) / 2.0 / dt
    vz[i][it] = (z1 - z[i][it-1]) / 2.0 / dt
    return (x1, y1, z1,
            (x1 - x[i][it]) / dt,
            (y1 - y[i][it]) / dt,
            (z1 - z[i][it]) / dt)


def _print_positions(t, it, x, y, nshow, fmt):
    """Print one display row: time t and (x, y) at step it of planets 0..nshow-1."""
    print("{:^5}".format(t), end = '')
    for ip in range(nshow):
        print(fmt.format(x[ip][it], y[ip][it]), end = '')
    print("")


def main():
    """Run the simulation configured by the module-level parameters.

    Reads the planet database and shifts velocities to zero total momentum.
    Then it integrates nt steps of size dt with the selected solver; the
    first step always uses Euler because Verlet needs two earlier positions.
    Every step writes the (x, y) positions to `outfile` and (U, K, E, Px,
    Py, Pz, Pmsm) to `outfile2`. Every iprint_interval steps the first
    nprint_planets positions are printed. When fplot == 1 the orbits are
    animated every iplot_interval steps.
    """
    # An unknown solver would otherwise run with stale / wrong update rules.
    if solver not in ('Euler', 'Verlet'):
        sys.exit("Unknown solver '{}': use 'Euler' or 'Verlet'".format(solver))
    do_plot = (fplot == 1)

    print("Planet simulator: Solve simulataneous second order diffrential equations by {} method".format(solver))
    print("G = {} Nm2/kg2".format(G))
    print("AU = {:e} m".format(AU))
    print("G1 = {}".format(G1))
    print("")

    # read planet database
    print("Planets:")
    planets = readdb(dbfile)
    if not planets:
        sys.exit("No planets found in [{}]".format(dbfile))
    keys = list(planets[0].keys())
    for d in planets:
        print(" ", d['Name'])
        for key in keys:
            if key != 'Name':
                print(" {}: {}".format(key, d[key]))
    print("")
    nplanets = len(planets)
    # never try to display more planets than the database provides
    nshow = min(nprint_planets, nplanets)

    # per-planet histories: outer index = planet, inner index = step
    M = []
    x, y, z = [], [], []
    vx, vy, vz = [], [], []
    fx, fy, fz = [], [], []
    initialize(planets, M, x, y, z, vx, vy, vz, fx, fy, fz)
    normalize_momentum(0, M, x, y, z, vx, vy, vz, fx, fy, fz)
    print("")

    # label list for display / csv output: t, then x(name), y(name) per planet
    labellist = ['t']
    for planet in planets:
        labellist.append("x({})".format(planet['Name']))
        labellist.append("y({})".format(planet['Name']))

    print("Write to [{}]".format(outfile))
    # both csv files are closed (and flushed) even if the run fails
    with open(outfile, 'w') as f, open(outfile2, 'w') as f2:
        fout = csv.writer(f, lineterminator='\n')
        fout.writerow(labellist)
        fout2 = csv.writer(f2, lineterminator='\n')
        fout2.writerow(['t', 'U', 'K', 'E', 'Px', 'Py', 'Pz', 'Pmsm'])

        print("{:^5}".format('t'), end = '')
        for ip in range(nshow):
            print(" {:^12} {:^12}".format(labellist[2*ip + 1], labellist[2*ip + 2]), end = '')
        print("")

        # figure, one line per planet, and thinned (x, y) data for the graph
        ax = None
        plots, xg, yg = [], [], []
        if do_plot:
            _fig, ax = plt.subplots(1, 1)

        for it in range(nt + 1):
            t = it * dt
            # Verlet needs the positions of two earlier steps, so the very
            # first step is always taken with Euler.
            if it == 0 or solver == 'Euler':
                step = _euler_step
            else:
                step = _verlet_step
            datalist = [t]
            for i in range(nplanets):
                x1, y1, z1, vx1, vy1, vz1 = step(it, i, M, x, y, z, vx, vy, vz)
                datalist.append(x[i][it])
                datalist.append(y[i][it])
                x[i].append(x1)
                y[i].append(y1)
                z[i].append(z1)
                vx[i].append(vx1)
                vy[i].append(vy1)
                vz[i].append(vz1)
                if do_plot:
                    if it == 0:
                        xg.append([x1])
                        yg.append([y1])
                        curve, = ax.plot(x[i], y[i], linewidth = 0.3)
                        plots.append(curve)
                    elif it % iplotdata_interval == 0:
                        xg[i].append(x1)
                        yg[i].append(y1)
            # display output every iprint_interval steps (always at t = 0)
            if it % iprint_interval == 0:
                fmt = " {:>12.4f} {:>12.4f}" if it == 0 else " {:>12.4g} {:>12.4g}"
                _print_positions(t, it, x, y, nshow, fmt)
            # write to trajectory csv file
            fout.writerow(datalist)
            # write to conservation csv file
            U, K, E = Utot(it, M, x, y, z, vx, vy, vz)
            Px, Py, Pz, Pmsm = Ptot(it, M, vx, vy, vz)
            fout2.writerow([t, U, K, E, Px, Py, Pz, Pmsm])
            # Update the graph every iplot_interval steps. Line data is only
            # pushed here, right before drawing, instead of every step.
            if do_plot and it > 0 and it % iplot_interval == 0:
                for curve, xs, ys in zip(plots, xg, yg):
                    curve.set_data(xs, ys)
                ax.set_xlim(xgrange)
                ax.set_ylim(ygrange)
                plt.pause(1.e-10)

    print("Press ENTER to exit>>", end = '')
    input()

    sys.exit()


if __name__ == '__main__':
    # Run the simulation only when this file is executed as a script,
    # not when it is imported.
    main()