1D band calculation using a plane-wave basis set

Download the script: pw1d.py


import sys
import numpy as np
from numpy import sqrt, exp, sin, cos, tan, pi
import numpy.linalg as LA
import csv
from matplotlib import pyplot as plt

"""
1D band calculation by plain wave basis set
"""


#===================================
# physical constants
#===================================
pi = 3.14159265358979323846
pi2 = 2.0 * pi
h = 6.6260755e-34 # Js";
hbar = 1.05459e-34 # "Js";
c = 2.99792458e8 # m/s";
e = 1.60218e-19 # C";
e0 = 8.854418782e-12; # C2N-1m-2";
kB = 1.380658e-23 # JK-1";
me = 9.1093897e-31 # kg";
R = 8.314462618 # J/K/mol
a0 = 5.29177e-11 # m";

KE = hbar * hbar / 2.0 / me / e * 1.0e20 # coefficient of kinetic energy
#print("KE = hbar*hbar/(2m)/e*1.0e20 = {}", KE)


#========================
# global configuration
#========================
mode = 'ft' # run mode: ft|band|wf (may be replaced by the 1st command-line argument)

#========================
# Crystal definition
#========================
# Si
a = 5.4064 # angstrom, lattice parameter
na = 64 # number of sampling points in one unit cell for the FFT, must be 2^n

#========================
# Potential
#========================
pottype = 'rect' # 'rect' (rectangular barrier) or 'gauss' (Gaussian barrier)
#pottype = 'gauss'
bwidth = 0.5 # A, barrier width (rect: barrier occupies [0, bwidth] of each cell; gauss: width parameter of the Gaussian)
bpot = 10.0 # eV, barrier height

#========================
# Band
#========================
nG = 3 # number of G points (number of plane-wave basis functions); band() rounds nG >= 4 up to an odd number
kmin = -0.5 # in units of 2*pi/a (free_KE() uses (k + G) * 2*pi/a); -0.5..0.5 spans the first Brillouin zone
kmax = 0.5 # in units of 2*pi/a
nk = 21 # number of k points

# Energy range of the band plot
Erange = [0.0, 10.0] # eV

#========================
# Wave function
#========================
xwmin = 0.0 # A, start of the plotted x range
xwmax = 3.0 * a # A, end of the plotted x range (three lattice periods)
nxw = 101 # number of x points
kw = 0.0 # wave vector of the plotted state, in units of 2*pi/a
iLevel = 0 # index of the energy level whose wave function is plotted

#===================================
# figure configuration
#===================================
figsize = (6, 8) # figure size used by the band plot
fontsize = 12 # font size of axis labels and tick labels
legend_fontsize = 12 # font size of legends

#==============================================
# fundamental functions
#==============================================
# float() raises on text that is not a number, which would stop the program.
# pfloat() returns None instead, so the caller can decide how to react.
def pfloat(str):
    """Return ``float(str)``, or None if the value cannot be converted.

    The parameter name shadows the builtin ``str``; it is kept so existing
    keyword callers continue to work.
    """
    try:
        return float(str)
    except (TypeError, ValueError, OverflowError):
        # TypeError: e.g. None; ValueError: non-numeric text;
        # OverflowError: an int too large to be represented as a float.
        return None

# int() counterpart of pfloat(): returns None instead of raising.
def pint(str):
    """Return ``int(str)``, or None if the value cannot be converted.

    The parameter name shadows the builtin ``str``; it is kept so existing
    keyword callers continue to work.
    """
    try:
        return int(str)
    except (TypeError, ValueError, OverflowError):
        # TypeError: e.g. None; ValueError: non-integer text or NaN;
        # OverflowError: infinity.
        return None

# Indexing sys.argv past its end raises IndexError and would stop the program.
# getarg() returns defval instead when the requested position is absent.
def getarg(position, defval = None):
    """Return ``sys.argv[position]``, or *defval* if there is no such argument."""
    try:
        return sys.argv[position]
    except IndexError:
        return defval

# Return a command-line argument converted to float.
def getfloatarg(position, defval = None):
    """Return the argument at *position* converted with pfloat().

    When the argument is absent, *defval* is converted instead; the result is
    None whenever pfloat() cannot convert the value.
    """
    raw_value = getarg(position, defval)
    return pfloat(raw_value)

# Return a command-line argument converted to int.
def getintarg(position, defval = None):
    """Return the argument at *position* converted with pint().

    When the argument is absent, *defval* is converted instead; the result is
    None whenever pint() cannot convert the value.
    """
    raw_value = getarg(position, defval)
    return pint(raw_value)

def usage():
    """Print the command-line synopsis and one example command per mode.

    The examples are filled in with the current values of the module-level
    parameters (defaults, or values already replaced from the command line).
    """
    prog = sys.argv[0]
    synopsis = [
        "",
        "Usage: Variables in () are optional",
        " python {}".format(prog),
        " python {} (ft a na pottype bwidth bpot)".format(prog),
        " python {} (band a na pottype bwidth bpot nG kmin kmax nk)".format(prog),
        " python {} (wf a na pottype bwidth bpot nG kw iLevel xwmin xwmax nxw)".format(prog),
        " pottype: rect|gauss",
    ]
    examples = [
        ('ft', a, na, pottype, bwidth, bpot),
        ('band', a, na, pottype, bwidth, bpot, nG, kmin, kmax, nk),
        ('wf', a, na, pottype, bwidth, bpot, nG, kw, iLevel, xwmin, xwmax, nxw),
    ]

    for line in synopsis:
        print(line)
    for fields in examples:
        # One "{}" for the program name plus one per example field.
        template = " ex: python " + " ".join(["{}"] * (1 + len(fields)))
        print(template.format(prog, *fields))

def terminate(message = None):
    """Print an optional error message and the usage text, then exit.

    Args:
        message: error text to show before the usage; None for a normal exit.

    Uses sys.exit() rather than the interactive helper exit(), which is
    injected by the site module and is not guaranteed to exist (e.g. with
    ``python -S`` or in frozen executables).  The exit status is 1 when a
    message is given, 0 otherwise.
    """
    print("")
    if message is not None:
        print("")
        print(message)
        print("")

    usage()
    print("")
    sys.exit(1 if message is not None else 0)

#==============================================
# update default values by startup arguments
#==============================================
# Positional command-line arguments (see usage()):
#   1 mode, 2 a, 3 na, 4 pottype, 5 bwidth, 6 bpot, then
#   band: 7 nG, 8 kmin, 9 kmax, 10 nk
#   wf:   7 nG, 8 kw, 9 iLevel, 10 xwmin, 11 xwmax, 12 nxw
# Arguments that are not given keep the defaults defined above.
argv = sys.argv

mode = getarg(1, mode)
a = getfloatarg(2, a)
na = getintarg(3, na)
pottype = getarg(4, pottype)
bwidth = getfloatarg(5, bwidth)
bpot = getfloatarg(6, bpot)
# Numeric parameters read for the selected mode; checked for None below.
_numeric_args = {'a': a, 'na': na, 'bwidth': bwidth, 'bpot': bpot}
if mode == 'band':
    nG = getintarg(7, nG)
    kmin = getfloatarg(8, kmin)
    kmax = getfloatarg(9, kmax)
    nk = getintarg(10, nk)
    _numeric_args.update(nG = nG, kmin = kmin, kmax = kmax, nk = nk)
elif mode == 'wf':
    nG = getintarg(7, nG)
    kw = getfloatarg(8, kw)
    iLevel = getintarg(9, iLevel)
    xwmin = getfloatarg(10, xwmin)
    xwmax = getfloatarg(11, xwmax)
    nxw = getintarg(12, nxw)
    _numeric_args.update(nG = nG, kw = kw, iLevel = iLevel,
                         xwmin = xwmin, xwmax = xwmax, nxw = nxw)

# getfloatarg()/getintarg() return None for text that is not a number; stop
# here with a clear message instead of failing later with an unrelated TypeError.
_invalid_args = [name for name, value in _numeric_args.items() if value is None]
if _invalid_args:
    terminate("Error: Invalid numeric argument [{}]".format(", ".join(_invalid_args)))
if pottype not in ('rect', 'gauss'):
    terminate("Error: Invalid pottype [{}]".format(pottype))


def reduce(x, x0):
    """Fold *x* into one period: return x - n * x0 with n = floor(x / x0).

    For x0 > 0 the result lies in [0, x0) up to floating-point rounding, for
    negative x as well (e.g. reduce(-1.0, 5.0) == 4.0).  For x >= 0 this is
    identical to truncating x / x0 toward zero.
    """
    import math
    return x - x0 * math.floor(x / x0)

# Periodic model potential (eV) as a function of position x (angstrom).
def pot(x):
    """Return the periodic model potential at position *x*.

    *x* is first folded into one unit cell with reduce(x, a); the result
    ``xred`` is then evaluated according to the global ``pottype``:

    * ``'rect'``:  ``bpot`` for 0 <= xred <= bwidth, otherwise 0.0.
    * ``'gauss'``: ``bpot * exp(-((xred - a/2) / bwidth)**2)``, a Gaussian
      centred in the middle of the cell.

    Any other ``pottype`` stops the program via terminate() instead of
    returning None (which would fail later when stored in a float array).
    """
    # Module-level parameters are only read here, so no `global` is needed.
    xred = reduce(x, a)

    if pottype == 'rect':
        return bpot if 0.0 <= xred <= bwidth else 0.0
    if pottype == 'gauss':
        xx = (xred - 0.5 * a) / bwidth
        return bpot * exp(-xx * xx)

    terminate("Error: Invalid pottype [{}]".format(pottype))

def build_potential(xmin, xstep, n):
    """Sample pot() on the grid x_i = xmin + i * xstep for i = 0 .. n-1.

    Returns:
        (xpot, ypot): float numpy arrays of length *n* with the sample
        positions and the corresponding potential values.
    """
    xpot = np.empty(n)
    ypot = np.empty(n)
    for idx in range(n):
        x = xmin + idx * xstep
        xpot[idx], ypot[idx] = x, pot(x)
    return xpot, ypot

# Return the index in Glist of the reciprocal-lattice point whose integer
# coordinate equals dij.
def find_iG(dij, Glist):
    """Return the first index i with ``Glist[i] == dij``, or None if absent."""
    for iG, G in enumerate(Glist):
        if G == dij:
            return iG
    return None

# Look up the potential's Fourier component for reciprocal-lattice
# coordinate dij.
def find_Vft(dij, Glist, Vft):
    """Return ``Vft[i]`` for the first i with ``Glist[i] == dij``.

    Components whose index is not in the basis list are treated as zero, so
    0.0 is returned when *dij* is not found.
    """
    for idx, component in enumerate(Vft):
        if Glist[idx] == dij:
            return component
    return 0.0

# Fourier transform of the potential sampled over one unit cell.
def cal_fft(na, a, ypot):
    """Fourier-transform one period of the sampled potential.

    Args:
        na: number of samples in one unit cell (len(ypot)).
        a: lattice parameter (angstrom).
        ypot: potential sampled at x_i = i * a / na.

    Returns:
        xft0: ``range(na)``, the raw FFT index.
        yft0: FFT coefficients scaled by astep / a (= 1/na); entries with
            i > na // 2 belong to negative G (see iGlist).
        xft: frequency axis in 1/angstrom, spacing 1/a, with 0 at index
            na // 2; always exactly na points.
        yft: yft0 reordered to match xft (negative G first).
        iGlist: integer G index for each raw FFT index i.
        nahalf: na // 2.
        xftstep: 1 / a, the spacing of xft.
    """
    xftstep = 1.0 / a
    astep = a / na
    nahalf = na // 2

    xft0 = range(na)
    yft0 = np.fft.fft(ypot) * astep / a

    # The FFT output is periodic: entries above nahalf are the negative-G
    # components.  fftshift moves them to the front (G = 0 ends up at index
    # nahalf, also for odd na).  The axis is built from integers so it always
    # has na points; a float arange(-xftmax, xftmax, xftstep) can gain an
    # extra point through rounding.
    yft = np.fft.fftshift(yft0)
    xft = (np.arange(na) - nahalf) * xftstep

    iGlist = [i if i <= nahalf else i - na for i in range(na)]

    return xft0, yft0, xft, yft, iGlist, nahalf, xftstep

# Extract the plane-wave basis (integer G indices) and the matching Fourier
# components of the potential for a given number of basis functions nG.
def extract_basis(yft0, nG):
    """Select nG reciprocal-lattice indices and their potential components.

    The G indices are ordered 0, 1, ..., nG // 2, -1, -2, ..., -((nG - 1) // 2).
    For odd nG this is the symmetric range -nGmax..+nGmax; for nG == 2 it is
    [0, 1]; other even nG get one more positive than negative index.

    Args:
        yft0: raw FFT coefficients from cal_fft(); negative G are read with
            negative indices, i.e. from the end of the array.
        nG: number of basis functions (>= 1).

    Returns:
        (Glist, Vftlist, nGmax): int array of G indices, complex array with
        ``yft0[G]`` for each G, and ``nG // 2``.

    Raises:
        ValueError: if nG < 1.
    """
    if nG < 1:
        raise ValueError("nG must be >= 1 [nG={}]".format(nG))

    nGmax = nG // 2
    positive = list(range(nGmax + 1))               # 0 .. nGmax
    negative = [-i for i in range(1, nG - nGmax)]   # -1 .. -((nG - 1) // 2)
    Glist = np.array(positive + negative, dtype = int)
    Vftlist = np.array([yft0[G] for G in Glist], dtype = complex)
    return Glist, Vftlist, nGmax

# Free-electron kinetic energy of the plane wave with wave vector
# (k + iG) * 2*pi/a, where k and iG are given in units of 2*pi/a.
def free_KE(k, iG):
    """Return KE * ((k + iG) * 2*pi/a)**2, the free-electron energy in eV."""
    q = k + iG
    coefficient = KE * pow(2.0 * pi / a, 2.0)   # eV, for a in angstrom
    return coefficient * q * q

# Solve the plane-wave eigenvalue problem at wave vector k.
def solve_pw(Glist, Vftlist, nGmax, nG, k, IsPrint = 0):
    """Diagonalize the plane-wave Hamiltonian at wave vector *k*.

    H[i][j] = free_KE(k, G_i) * delta_ij + V(G_i - G_j), where V(G) is looked
    up in (Glist, Vftlist) with find_Vft().  Components whose index is not in
    Glist (e.g. |G_i - G_j| > nGmax for odd nG) are treated as zero.

    The matrix is Hermitian by construction, so numpy.linalg.eigh is used.
    It returns real eigenvalues in ascending order, so column n of ``ci``
    belongs to the n-th lowest level.

    Args:
        Glist, Vftlist, nGmax: basis as returned by extract_basis().
        nG: number of basis functions (len(Glist)).
        k: wave vector in units of 2*pi/a.
        IsPrint: >= 1 prints matrix elements and eigenvalues,
            >= 2 also prints the eigenvectors.

    Returns:
        (ei, ci): eigenvalues in eV (ascending) and the plane-wave
        coefficients, ``ci[:, n]`` belonging to ``ei[n]`` (unit-norm vectors;
        no real-space normalization is applied).
    """
    # The diagonal potential term V(G = 0) is the same for every row.
    V0 = find_Vft(0, Glist, Vftlist)

    # Build the Hamiltonian matrix in the plane-wave basis.
    Hij = np.empty([nG, nG], dtype = complex)
    for i in range(nG):
        iG = Glist[i]
        Hij[i][i] = free_KE(k, iG) + V0
        for j in range(i+1, nG):
            jG = Glist[j]
            dij = iG - jG
            Vij = find_Vft(dij, Glist, Vftlist)
            if IsPrint >= 1:
                print(" iG,jG,dij=", iG, jG, dij, " nG(base)=[{}, {}]".format(-nGmax, nGmax), " Vij=", Vij)
            Hij[i][j] = Vij
            Hij[j][i] = Hij[i][j].conjugate()

    ei, ci = LA.eigh(Hij)
    if IsPrint >= 1:
        print(" ei=", ei)
    if IsPrint >= 2:
        print(" ci=", ci)

    return ei, ci

# Evaluate the (unnormalized) wave function from plane-wave coefficients.
def cal_wf(xwmin, xwmax, nxw, kw, ci, Glist):
    """Evaluate psi(x) = sum_i ci[i] * exp(1j * (kw + Glist[i]) * 2*pi/a * x).

    Args:
        xwmin, xwmax: x range in angstrom (both ends included).
        nxw: number of grid points.
        kw: wave vector in units of 2*pi/a.
        ci: plane-wave coefficients; ci[i] belongs to Glist[i].
        Glist: integer G indices of the basis.

    Returns:
        (xwf, ywf, ywf2): grid positions (float array), psi(x) (complex
        array) and |psi(x)|^2 (float array).  No normalization is applied.
    """
    ci = np.asarray(ci, dtype = complex)
    G = np.asarray(Glist, dtype = float)
    # A single point has no step; avoid dividing by zero.
    xwstep = (xwmax - xwmin) / (nxw - 1) if nxw > 1 else 0.0

    xwf = xwmin + np.arange(nxw) * xwstep
    # phase[i, g] = exp(1j * (kw + G_g) * 2*pi/a * x_i); psi = phase @ ci
    phase = np.exp(1.0j * np.outer(xwf, (kw + G) * pi2 / a))
    ywf = phase @ ci
    ywf2 = (ywf * ywf.conjugate()).real
    return xwf, ywf, ywf2

def wf():
    """Compute and plot one wave function at k = kw.

    Uses the module-level parameters (a, na, pottype, bwidth, bpot, nG, kw,
    iLevel, xwmin, xwmax, nxw).  It builds the plane-wave basis from the FFT
    of the potential and solves the eigenvalue problem at k = kw.  It then
    plots Re(psi), Im(psi) and |psi|^2 of level iLevel, together with V(x).
    Levels are sorted by energy, so iLevel = 0 is the lowest one.  The
    function ends by calling terminate() after the user presses ENTER.
    """
    global nG   # normalized below in the same way as in band()

    if nxw < 2:
        terminate("Error: nxw must be >= 2 [nxw={}]".format(nxw))
    if nG <= 0 or 64 < nG:
        terminate("Error: nG must be between 1 and 64 [nG={}]".format(nG))
    if nG >= 4:
        nG = int(nG / 2) * 2 + 1 # convert to an odd number for nG >= 4

    xwstep = (xwmax - xwmin) / (nxw - 1)

    print("")
    print("=== Input parameters ===")
    print("mode:", mode)
    print("a=", a, "A")
    print(" na=", na)
    print("potential: {} w={} A h={} eV".format(pottype, bwidth, bpot))
    print("Wave function to be plotted: k = {} iLevel = {}".format(kw, iLevel))
    print("x range: {} - {} at {} step, {} points".format(xwmin, xwmax, xwstep, nxw))

    # Potential on one unit cell (input of the FFT) and on the plotting grid.
    astep = a / na
    xpot, ypot = build_potential(0.0, astep, na)
    xplot, yplot = build_potential(xwmin, xwstep, nxw)

    xft0, yft0, xft, yft, iGlist, nahalf, xftstep = cal_fft(na, a, ypot)

    print("")
    print("=== FT result ===")
    print("{:4} {:4} {}".format("i", "iG", "ci"))
    for i in range(na):
        print("{:4d} {:4d} {}".format(i, iGlist[i], yft0[i]))

    # Extract G basis set for given nG
    Glist, Vftlist, nGmax = extract_basis(yft0, nG)

    print("")
    print("=== G basis extracted ===")
    print("nG =", nG)
    print("{:4} {}".format("iG", "Vft"))
    for i in range(nG):
        print("{:4d} {}".format(Glist[i], Vftlist[i]))

    print("")
    print("=== Solve eigen equations ===")
    k = kw
    print("at k = {}".format(k))

    yE, ci = solve_pw(Glist, Vftlist, nGmax, nG, k, IsPrint = 1)
    # Order the levels by energy so that iLevel = 0 is the lowest one,
    # whatever order the eigensolver returned them in.
    order = np.argsort(np.real(yE))
    yE = np.real(yE)[order]
    ci = np.asarray(ci)[:, order]
    print(" ei=", yE)
    print(" ci=", ci)

    print("")
    print("=== Calculate wave function ===")
    print("Energy levels:")
    for i in range(len(yE)):
        print(" {} {:12.6g} eV".format(i, yE[i].real))

    if not 0 <= iLevel < len(yE):
        terminate("Error: iLevel must be between 0 and {} [iLevel={}]".format(len(yE) - 1, iLevel))

    print("")
    print("at k = {}".format(k))
    print(" selected for {}-th energy level:".format(iLevel))
    print(" E = {:12.6g} eV".format(yE[iLevel].real))
    print(" Wave function coefficients")
    ciLevel = [ci[i][iLevel] for i in range(nG)]
    for i in range(nG):
        print(" at iG={:3d}".format(Glist[i]), " ci={:8.4f}".format(ci[i][iLevel]))

    xwf, ywf, _ = cal_wf(xwmin, xwmax, nxw, kw, ciLevel, Glist)
    # |psi|^2 on the plotting grid, computed directly from psi.
    charge = (ywf * np.conj(ywf)).real

    fig = plt.figure(figsize = (8, 8))
    ax2 = fig.add_subplot(1, 1, 1)
    # ax1 shares the x axis with ax2 and shows V(x) on a second y axis.
    ax1 = ax2.twinx()

    ax1.set_xlim([xwmin, xwmax])
    ax1.plot(xplot, yplot, linewidth = 0.5, label = '$V$($x$)')
    ax1.plot(ax1.get_xlim(), [0.0, 0.0], color = 'r', linestyle = 'dashed', linewidth = 0.5)
    ax2.set_xlim([xwmin, xwmax])
    ax2.plot(xplot, ywf.real, linewidth = 1.5, label = r"$\Psi$(real)")
    ax2.plot(xplot, ywf.imag, linewidth = 1.5, label = r"$\Psi$(imaginary)")
    ax2.plot(xplot, charge, linewidth = 0.5, label = r"|$\Psi$|$^2$")
    ax2.plot(ax1.get_xlim(), [0.0, 0.0], color = 'r', linestyle = 'dashed', linewidth = 0.5)

    ax1.set_ylabel("$V$($x$)", fontsize = fontsize)
    ax2.set_xlabel(r"$x$ ($\AA$)", fontsize = fontsize)
    ax2.set_ylabel(r"$\Psi$", fontsize = fontsize)

    # Merge the legend entries of both axes into a single legend.
    handler1, label1 = ax1.get_legend_handles_labels()
    handler2, label2 = ax2.get_legend_handles_labels()
    ax2.legend(handler1 + handler2, label1 + label2, loc = 2, borderaxespad = 0.0, fontsize = legend_fontsize)
    ax1.tick_params(labelsize = fontsize)
    ax2.tick_params(labelsize = fontsize)
    plt.tight_layout()

    plt.pause(0.1)
    print("Press ENTER to exit>>", end = '')
    input()

    terminate()

def band():
    """Compute and plot the band structure E(k) of the model potential.

    Uses the module-level parameters (a, na, pottype, bwidth, bpot, nG, kmin,
    kmax, nk, Erange).  For each of the nk k points between kmin and kmax
    (units of 2*pi/a) it solves the plane-wave eigenvalue problem and plots
    the energies (circles) together with the free-electron parabolas (red
    lines).  The function ends by calling terminate() after ENTER is pressed.
    """
    global nG   # rounded to an odd number below

    # Check the range before rounding, so values above 64 are rejected
    # rather than silently turned into a larger odd number.
    if nG <= 0 or 64 < nG:
        terminate("Error: nG must be between 1 and 64 [nG={}]".format(nG))
    if nk < 2:
        terminate("Error: nk must be >= 2 [nk={}]".format(nk))

    kstep = (kmax - kmin) / (nk - 1)

    print("")
    print("=== Input parameters ===")
    print("mode:", mode)
    print("a=", a, "A")
    print(" na=", na)
    print("potential: {} w={} A h={} eV".format(pottype, bwidth, bpot))
    print("Basis function: nG=", nG)
    if nG >= 4:
        nG = int(nG / 2) * 2 + 1 # convert to an odd number for nG >= 4
    print("k range: {} - {} at {} step, {} points".format(kmin, kmax, kstep, nk))

    astep = a / na
    xpot, ypot = build_potential(0.0, astep, na)

    xft0, yft0, xft, yft, iGlist, nahalf, xftstep = cal_fft(na, a, ypot)

    print("")
    print("=== FT result ===")
    print("{:4} {:4} {}".format("i", "iG", "ci"))
    for i in range(na):
        print("{:4d} {:4d} {}".format(i, iGlist[i], yft0[i]))

    # Extract G basis set for given nG
    Glist, Vftlist, nGmax = extract_basis(yft0, nG)

    print("")
    print("=== FTed potential ===")
    print("nG =", nG)
    print("{:4} {}".format("iG", "Vft"))
    for i in range(nG):
        print("{:4d} {}".format(Glist[i], Vftlist[i]))

    print("")
    print("=== Solve eigen equations ===")
    xk = [kmin + i * kstep for i in range(nk)]
    yE = np.empty([nG, nk])       # row n: n-th lowest band energy at each k
    yEfree = np.empty([nG, nk])   # row i: free-electron energy of basis G_i
    for ik, k in enumerate(xk):
        print("at k = {}".format(k))

        ei, ci = solve_pw(Glist, Vftlist, nGmax, nG, k, IsPrint = 1)
        # Sorted real parts: row n always holds the n-th lowest band, and no
        # complex values are stored into the float array.
        yE[:, ik] = np.sort(np.real(ei))

        for i in range(nG):
            iG = Glist[i]
            print("k=", k, iG, k+iG)
            yEfree[i][ik] = free_KE(k, iG)

    fig = plt.figure(figsize = figsize)
    ax1 = fig.add_subplot(1, 1, 1)

    ax1.set_xlim([kmin, kmax])
    ax1.set_ylim(Erange)
    for n in range(nG):
        # Label only the first curve of each kind: one legend entry per kind.
        ax1.plot(xk, yE[n], linestyle = '', marker = 'o', markersize = 2.0,
                 markerfacecolor = 'none', markeredgecolor = 'black', markeredgewidth = 0.5,
                 label = 'with V(x)' if n == 0 else '_nolegend_')
        ax1.plot(xk, yEfree[n], linestyle = '-', linewidth = 0.3, color = 'red',
                 label = 'free e' if n == 0 else '_nolegend_')

    # k is in units of 2*pi/a: free_KE() uses (k + G) * 2*pi/a.
    ax1.set_xlabel(r"k (2$\pi$/a)", fontsize = fontsize)
    ax1.set_ylabel("E (eV)", fontsize = fontsize)
    ax1.legend(fontsize = legend_fontsize)
    plt.tight_layout()

    plt.pause(0.1)
    print("Press ENTER to exit>>", end = '')
    input()

    terminate()

def ft():
    """Plot the model potential and its Fourier transform.

    The figure has three panels:

    * Top-left: V(x) over three unit cells.
    * Bottom-left: the raw FFT output against the FFT index.
    * Bottom-right: the FFT reordered to negative/positive frequencies
      (1/angstrom), limited to +-10/a.

    The function ends by calling terminate() after the user presses ENTER.
    """
    astep = a / na

    print("")
    print("=== Input parameters ===")
    print("mode:", mode)
    print("a=", a, "A")
    print(" na=", na)
    print(" astep=", astep)
    print("potential: {} w={} A h={} eV".format(pottype, bwidth, bpot))
    print("")

    # One unit cell is transformed; three cells are plotted.
    xpot, ypot = build_potential(0.0, astep, na)
    xplot, yplot = build_potential(0.0, astep, 3 * na)

    xft0, yft0, xft, yft, iGlist, nahalf, xftstep = cal_fft(na, a, ypot)

    fig = plt.figure(figsize = (8, 8))
    ax1 = fig.add_subplot(2, 2, 1)
    ax2 = fig.add_subplot(2, 2, 3)
    ax3 = fig.add_subplot(2, 2, 4)

    ax1.plot(xplot, yplot, linewidth = 0.5)
    ax2.plot(xft0, yft0.real, label = "real", linewidth = 0.5)
    ax2.plot(xft0, yft0.imag, label = "imaginary", linewidth = 0.5)
    ax3.plot(xft, yft.real, marker = 'o', markersize = 1.0, linewidth = 0.5, label = "real")
    ax3.plot(xft, yft.imag, marker = 'o', markersize = 1.0, linewidth = 0.5, label = "imaginary")
    ax3.plot(xft, abs(yft), marker = 'o', markersize = 1.0, linewidth = 0.5, label = "absolute")
    ax1.set_xlabel(r"x ($\AA$)", fontsize = fontsize)
    ax1.set_ylabel("pot (x)", fontsize = fontsize)
    ax2.set_xlabel("i", fontsize = fontsize)
    ax2.set_ylabel("FTed pot, raw x", fontsize = fontsize)
    ax3.set_xlabel("1/x, normalized (1/A)", fontsize = fontsize)
    ax3.set_ylabel("FTed pot", fontsize = fontsize)
    ax3.set_xlim([-10.0 * xftstep, 10.0 * xftstep])
    ax2.legend(fontsize = legend_fontsize)
    ax3.legend(fontsize = legend_fontsize)
    ax1.tick_params(labelsize = fontsize)
    ax2.tick_params(labelsize = fontsize)
    ax3.tick_params(labelsize = fontsize)
    plt.tight_layout()

    plt.pause(0.1)
    print("Press ENTER to exit>>", end = '')
    input()

    terminate()

def main():
    """Run the calculation selected by the module-level ``mode``.

    Unknown modes stop the program through terminate().
    """
    dispatch = {'ft': ft, 'band': band, 'wf': wf}
    runner = dispatch.get(mode)
    if runner is not None:
        runner()
    else:
        terminate("Error: Invalid mode [{}]".format(mode))


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()