# Calculate 1D wave function by transfer matrix method
#
# Download the script: transfer_matrix.py


import sys
import numpy as np
from numpy import sqrt, exp, log, sin, cos, tan, cosh, sinh
import numpy.linalg as LA
import csv
from matplotlib import pyplot as plt



"""
Calculate 1D wave function by transfer matrix method
"""


#===================================
# Physical constants (SI units)
#===================================
pi = 3.14159265358979323846      # circle constant
pi2 = 2.0 * pi                   # 2*pi
h = 6.6260755e-34                # Planck constant [J s]
# NOTE: hbar is its own literal, not derived from h (h / pi2 differs from it
# in the 5th significant digit).
hbar = 1.05459e-34               # reduced Planck constant [J s]
c = 2.99792458e8                 # speed of light [m/s]
e = 1.60218e-19                  # elementary charge [C]
e0 = 8.854418782e-12             # vacuum permittivity [C^2 N^-1 m^-2]
kB = 1.380658e-23                # Boltzmann constant [J/K]
me = 9.1093897e-31               # electron rest mass [kg]
R = 8.314462618                  # molar gas constant [J/K/mol]
a0 = 5.29177e-11                 # Bohr radius [m]


#========================
# Global configuration
#========================
mode = 'wf'             # run mode: 'wf' (wave function) or 'tr' (transmission)

#========================
# Simulation range (angstrom)
#========================
zmin = -20.0            # left edge, zL
zmax = 200.0            # right edge, zR
nz = 201                # number of grid points

#========================
# Potential profile
#========================
pottype = 'mqw'
wellwidth = 20.0        # well width [A]
barrierwidth = 1.0      # barrier width [A]
barrierheight = 1.0     # barrier height [eV]
nbarriers = 10          # number of barriers

#========================
# Wave-function plot
#========================
Ez0 = 0.1               # energy of the plotted state [eV]

#========================
# Transmission-probability scan
#========================
Emin = 0.01             # lowest energy [eV]
Emax = 1.0              # highest energy [eV]
nE = 1001               # number of energies

# Alternative parameter set for Si (uncomment to use):
#wellwidth = 5.4064 - 0.5 # A
#barrierwidth = 0.5 # A
#barrierheight = 10.0 # eV
#nbarriers = 10
#Emax = 9.5 # eV
#zmax = 70.0 # zR, in angstrom


#===================================
# Figure configuration
#===================================
figsize = (8, 8)
fontsize = 12
legend_fontsize = 8


#==============================================
# fundamental functions
#==============================================
# 実数値に変換できない文字列をfloat()で変換するとエラーになってプログラムが終了する
# この関数は、変換できなかったらNoneを返すが、プログラムは終了させない
def pfloat(str):
    try:
        return float(str)
    except:
        return None

# int() counterpart of pfloat().
def pint(str):
    """Convert *str* to int, or return None if it cannot be converted.

    Strings such as "1.5" are not accepted (int() rejects them).
    """
    try:
        return int(str)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float('inf')); other exceptions propagate.
        return None

# Indexing sys.argv past its end raises IndexError and would abort the
# program; getarg() returns defval in that case instead.
def getarg(position, defval = None):
    """Return command-line argument sys.argv[position], or *defval* if absent."""
    try:
        return sys.argv[position]
    except IndexError:
        # Only a missing argument falls back to the default.
        return defval

# Return a command-line argument converted to float.
def getfloatarg(position, defval = None):
    """Return argument *position* as a float.

    *defval* is used when the argument is absent. The selected value goes
    through pfloat(), so anything unconvertible yields None.
    """
    raw = getarg(position, defval)
    return pfloat(raw)

# Return a command-line argument converted to int.
def getintarg(position, defval = None):
    """Return argument *position* as an int.

    *defval* is used when the argument is absent. The selected value goes
    through pint(), so anything unconvertible yields None.
    """
    raw = getarg(position, defval)
    return pint(raw)

def usage():
    """Print command-line usage, using the current settings as examples.

    Reads the module-level nz, Ez0, Emin, Emax and nE. Read-only access
    needs no ``global`` declaration, so the stale ``global ... Ez`` (a name
    that does not exist) is dropped.
    """
    prog = sys.argv[0]
    print("")
    print("Usage: Variables in () are optional")
    print(" python {} (wf nz Ez0)".format(prog))
    print(" ex: python {} (wf {} {})".format(prog, nz, Ez0))
    print(" python {} (tr nz Ez0 Emin Emax nE)".format(prog))
    print(" ex: python {} (tr {} {} {} {} {})".format(prog, nz, Ez0, Emin, Emax, nE))

def terminate(message = None):
    """Print an optional message and the usage text, then exit the program.

    Exit status is 0 when no message is given (normal end) and 1 when a
    message is given. In this script messages are only passed for errors,
    so failures are visible to shell callers.

    ``sys.exit`` is used instead of the site-provided ``exit()``, which is
    meant for the interactive interpreter and is absent under ``python -S``.
    """
    print("")
    if message is not None:
        print("")
        print(message)
        print("")

    usage()
    print("")
    sys.exit(0 if message is None else 1)


#==============================================
# update default values by startup arguments
#==============================================
argv = sys.argv

# Override the defaults with positional command-line arguments:
#   wf [nz] [Ez0]
#   tr [nz] [Ez0] [Emin] [Emax] [nE]
mode = getarg(1, mode)
nz = getintarg(2, nz)
Ez0 = getfloatarg(3, Ez0)
if mode == 'wf':
    pass
elif mode == 'tr':
    Emin = getfloatarg(4, Emin)
    Emax = getfloatarg(5, Emax)
    nE = getintarg(6, nE)
else:
    terminate("Error: Invalid mode [{}]".format(mode))

# getintarg()/getfloatarg() return None for unparsable input. Reject that,
# and values the solver cannot use, here rather than failing later with an
# obscure TypeError / ZeroDivisionError / log-of-nonpositive error.
if nz is None or nz < 2:
    terminate("Error: nz must be an integer >= 2 [{}]".format(getarg(2)))
if Ez0 is None:
    terminate("Error: Ez0 must be a number [{}]".format(getarg(3)))
if mode == 'tr':
    # The energy grid is log-spaced, so both ends must be positive.
    if Emin is None or Emax is None or Emin <= 0.0 or Emax <= 0.0:
        terminate("Error: Emin and Emax must be positive numbers [{}, {}]".format(getarg(4), getarg(5)))
    if nE is None or nE < 2:
        terminate("Error: nE must be an integer >= 2 [{}]".format(getarg(6)))


def IsInBarrier(z):
    """Return 1 if position *z* (angstrom) is inside a barrier, else 0.

    The structure repeats with period w = wellwidth + barrierwidth starting
    at z = 0. Cell n (0 <= n < nbarriers) has its barrier on
    [n*w, n*w + barrierwidth). Every other position, including all z < 0,
    is well region.
    """
    w = wellwidth + barrierwidth
    n = int(z / w)  # truncates toward zero
    # n >= 0 is required: when z is an exact negative multiple of w, the
    # truncation gives n < 0 with a zero remainder, which would otherwise
    # report a phantom barrier at negative z.
    if 0 <= n < nbarriers and 0.0 <= z - n * w < barrierwidth:
        return 1
    else:
        return 0

def U(z): # eV
    """Potential energy at *z* in eV: barrierheight in a barrier, 0 elsewhere."""
    return barrierheight if IsInBarrier(z) else 0.0

def meff(z):
    """Effective-mass ratio m*/me at position *z* (angstrom).

    Barrier regions are identified with IsInBarrier(), the same test U()
    uses, so the mass profile always follows the potential profile. The
    earlier hard-coded ranges 0-5 A and 25-30 A did not match the
    multi-barrier geometry. With mwell == mbarrier the result is the same
    constant either way.
    """
    mwell = 1.0
    mbarrier = 1.0
    return mbarrier if IsInBarrier(z) else mwell


def cal_wf(xz, yU, ym, Ez):
    """Solve the 1D Schroedinger equation on a grid by the transfer-matrix method.

    In each grid cell the solution is A*exp(+ikz) + B*exp(-ikz). The
    amplitudes start from Ai[0] = 1, Bi[0] = 0 (only the exp(+ikz) wave at
    index 0). They are carried to index nz-1 by matching psi and psi'/m*
    at every grid point z = xz[i].

    Parameters
    ----------
    xz : sequence of float
        Grid positions [A].
    yU : sequence of float
        Potential energy on each grid point [eV].
    ym : sequence of float
        Effective-mass ratio m*/me on each grid point.
    Ez : float
        Energy of the state [eV].

    Returns
    -------
    ykz : ndarray of complex
        Wave number on each grid point [A^-1]. This is the principal complex
        square root, so it is imaginary where Ez < U. Exact zeros at
        indices >= 1 are replaced by 1e-10 to avoid division by zero.
    Ai, Bi : ndarray of complex
        Amplitudes of exp(+ikz) and exp(-ikz), scaled so that
        |Ai[-1]|^2 + |Bi[-1]|^2 = 1.
    Psi : ndarray of complex
        Wave function on the grid, scaled by the same factor.
    T : float
        Transmission probability. It is the probability current
        Re(k/m*)|A|^2 at index 0 divided by the incident current at index
        nz-1, computed before scaling. It is 0 when either end carries no
        propagating wave.
    """
    nz = len(xz)
    zz = np.asarray(xz, dtype = float)
    mass = np.asarray(ym, dtype = float)
    pot = np.asarray(yU, dtype = float)

    # k = sqrt(2 m* (E - U)) / hbar, converted from m^-1 to A^-1. Adding 0j
    # selects the complex square root (evanescent k where Ez < U).
    ykz = np.sqrt(2.0 * mass * me / hbar / hbar * (Ez - pot) * e + 0.0j) * 1.0e-10

    Ai = np.empty(nz, dtype = complex)
    Bi = np.empty(nz, dtype = complex)
    Ai[0] = 1.0
    Bi[0] = 0.0
    for i in range(1, nz):
        # ykz[i] is a denominator below; ykz[0] only ever appears in numerators.
        if ykz[i] == 0.0:
            ykz[i] = 1.0e-10
        # Continuity of psi and psi'/m* at z = xz[i] between cells i-1 and i.
        mk = mass[i] / mass[i-1] * ykz[i-1] / ykz[i]
        ap = 0.5 * (1.0 + mk)
        am = 0.5 * (1.0 - mk)
        P = np.exp(1.0j * (ykz[i-1] - ykz[i]) * zz[i])
        Q = np.exp(1.0j * (ykz[i-1] + ykz[i]) * zz[i])
        Ai[i] = ap * P * Ai[i-1] + am / Q * Bi[i-1]
        Bi[i] = am * Q * Ai[i-1] + ap / P * Bi[i-1]

    # Evaluate psi on every grid point, including index 0 (which was
    # previously left uninitialised).
    Psi = Ai * np.exp(1.0j * ykz * zz) + Bi * np.exp(-1.0j * ykz * zz)

    # The probability current of A*exp(ikz) is proportional to Re(k/m*)|A|^2.
    # The velocity ratio enters once, not squared.
    j_out = (ykz[0] / mass[0]).real
    j_in = (ykz[nz-1] / mass[nz-1]).real
    if j_out > 0.0 and j_in > 0.0:
        T = j_out / j_in * abs(Ai[0] / Ai[nz-1]) ** 2
    else:
        T = 0.0

    # Scale all amplitudes and psi by the same factor. Previously only the
    # last element was divided.
    norm = np.sqrt(abs(Ai[nz-1]) ** 2 + abs(Bi[nz-1]) ** 2)
    if norm > 0.0:
        Ai /= norm
        Bi /= norm
        Psi /= norm

    return ykz, Ai, Bi, Psi, T

def analytical_check(Ez, a):
    """Print closed-form reference values for sanity-checking the solver.

    Parameters
    ----------
    Ez : float
        Energy [eV]. The free-electron wave number at this energy is printed
        in A^-1.
    a : float
        Well width [A]. The lowest level of an infinitely deep well of this
        width, E = hbar^2 k^2 / (2 me) with k = pi / a, is printed in eV.
    """
    print("")
    print("=== Analyatical results to check ===")
    # k = sqrt(2 me E) / hbar, converted from m^-1 to A^-1.
    kz = sqrt(Ez * e / (hbar * hbar / 2.0 / me)) * 1.0e-10 # A^-1
    print("Ez = {} eV".format(Ez))
    # kz is already in A^-1; it must not be scaled by 1e-10 a second time.
    print(" kz = {} A^-1".format(kz))

    print("=== Infinit potential well ===")
    print("Well width: {} A".format(a))
    # Lowest standing wave in a well of width a: half a wavelength, k = pi / a.
    k = 2.0 * pi / (2.0 * a * 1.0e-10)
    E = hbar * hbar / 2.0 / me * k * k / e # eV
    print(" E0={:12.6g} eV".format(E))

def tr():
    """Compute and plot the transmission probability as a function of energy.

    Reads the module-level zmin, zmax, nz (spatial grid), Emin, Emax, nE
    (energy scan, log-spaced) and Ez0 (energy of the plotted wave function).
    Prints the results, shows a 2x2 figure (U and m*, kz, T(E), psi at Ez0),
    waits for ENTER and then calls terminate().
    """
    zstep = (zmax - zmin) / (nz - 1)
    # NOTE(review): Estep is the linear spacing and is only printed; the scan
    # below is uniform in log(E).
    Estep = (Emax - Emin) / (nE - 1)

    analytical_check(Ez0, wellwidth)

    print("")
    print("=== Input parameterss ===")
    print("zmin(zL)=", zmin, "A")
    print("zmax(zR)=", zmax, "A")
    print("nz=", nz)
    print("zstep=", zstep)
    print("Ez0={} eV".format(Ez0))
    print("Erange: {} - {} eV, {} step, nE={}".format(Emin, Emax, Estep, nE))

    # Grid runs from zmax (index 0, transmitted side) down to zmin.
    xz = [zmax - i * zstep for i in range(nz)]
    Elogstep = (log(Emax) - log(Emin)) / (nE - 1)
    xE = [exp(log(Emin) + i * Elogstep) for i in range(nE)]
    yU = [U(z) for z in xz]
    ym = [meff(z) for z in xz]

    print("")
    print("=== Wave function at Ez0 = {} eV ===".format(Ez0))
    ykz0, _, _, Psi0, T0 = cal_wf(xz, yU, ym, Ez0)
    print("Transmission probability at {} eV: {}".format(Ez0, T0))

    print("")
    print("=== Transmission probability vs Energy")
    yT = []
    print("{:10}\t{:10}".format("E (eV)", "T"))
    for E in xE:
        T = cal_wf(xz, yU, ym, E)[4]
        yT.append(T)
        print("{:10.6g}\t{:14.6g}".format(E, T))

    print("")
    print("plot")
    fig = plt.figure(figsize = figsize)
    ax1 = fig.add_subplot(2, 2, 1)
    # ax2 shares the x axis of ax1 (effective mass on a second y axis).
    ax2 = ax1.twinx()
    ax3 = fig.add_subplot(2, 2, 2)
    ax4 = fig.add_subplot(2, 2, 3)
    ax5 = fig.add_subplot(2, 2, 4)

    ax1.plot(xz, yU, label = 'U(z)', linewidth = 0.5, color = 'red')
    ax2.plot(xz, ym, label = 'm$_{eff}$(z)', linewidth = 0.5, color ='blue')
    ax3.plot(xz, np.real(ykz0), label = 'kz(real) (A$^{-1}$)', linewidth = 0.5, marker = 'o', markersize = 2)
    ax3.plot(xz, np.imag(ykz0), label = 'kz(imag) (A$^{-1}$)', linewidth = 0.5)

    ax4.plot(xE, yT, label = 'T', linewidth = 0.5, color = 'red', marker = 'o', markersize = 0.5)
    # Raw strings: '\P' is not a valid escape and warns on Python 3.12+.
    ax5.plot(xz, np.real(Psi0), label = r'$\Psi$(real)', linewidth = 0.5)
    ax5.plot(xz, np.imag(Psi0), label = r'$\Psi$(imag)', linewidth = 0.5)
    ax5.plot(xz, np.abs(Psi0) ** 2, label = r'$|\Psi|^2$', linewidth = 0.5)

    ax1.set_xlabel("z (A)", fontsize = fontsize)
    ax1.set_ylabel("U(z)", fontsize = fontsize)
    ax2.set_ylabel("m*(z)", fontsize = fontsize)
    ax3.set_xlabel("z (A)", fontsize = fontsize)
    ax3.set_ylabel("kz (A$^{-1}$)", fontsize = fontsize)
    ax4.set_xlabel("E (eV)", fontsize = fontsize)
    ax4.set_ylabel("Transmission probability", fontsize = fontsize)
    ax5.set_xlabel("z (A)", fontsize = fontsize)
    ax5.set_ylabel(r"$\Psi$(z)", fontsize = fontsize)

    for ax in (ax1, ax2, ax3, ax5):
        ax.set_xlim([zmin, zmax])
    ax1.set_ylim([min(yU), max(yU) * 1.1])
    ax2.set_ylim([min(ym), max(ym) * 1.3])
    ax4.set_ylim([0.0, 1.1])

    # Merge the legends of the twin axes ax1/ax2 into one box.
    handler1, label1 = ax1.get_legend_handles_labels()
    handler2, label2 = ax2.get_legend_handles_labels()
    ax1.legend(handler1 + handler2, label1 + label2, loc = 2, borderaxespad = 0.0, fontsize = legend_fontsize)
    for ax in (ax3, ax4, ax5):
        ax.legend(fontsize = legend_fontsize)

    for ax in (ax1, ax2, ax3, ax4, ax5):
        ax.tick_params(labelsize = fontsize)
    plt.tight_layout()

    plt.pause(0.1)
    print("")
    print("Press ENTER to exit>>", end = '')
    input()

    terminate()


def wf():
    """Compute and plot the wave function at energy Ez0.

    Reads the module-level zmin, zmax, nz (spatial grid) and Ez0 [eV].
    Prints the transmission probability, shows a 2x2 figure (U and m*, kz,
    |Ai| and |Bi|, psi), waits for ENTER and then calls terminate().
    """
    analytical_check(Ez0, wellwidth)

    zstep = (zmax - zmin) / (nz - 1)

    print("")
    print("=== Input parameterss ===")
    print("zmin(zL)=", zmin, "A")
    print("zmax(zR)=", zmax, "A")
    print("nz=", nz)
    print("zstep=", zstep)
    print("Ez0=", Ez0, "eV")

    # Grid runs from zmax (index 0, transmitted side) down to zmin.
    xz = [zmax - i * zstep for i in range(nz)]
    yU = [U(z) for z in xz]
    ym = [meff(z) for z in xz]

    print("")
    print("=== Calculate wave function")
    ykz, Ai, Bi, Psi, T = cal_wf(xz, yU, ym, Ez0)

    print("Transmission probability from z = {} to {} A: {:14.6g}".format(xz[nz-1], xz[0], T))

    print("")
    print("plot")
    fig = plt.figure(figsize = figsize)
    ax1 = fig.add_subplot(2, 2, 1)
    # ax2 shares the x axis of ax1 (effective mass on a second y axis).
    ax2 = ax1.twinx()
    ax3 = fig.add_subplot(2, 2, 2)
    ax4 = fig.add_subplot(2, 2, 3)
    ax5 = fig.add_subplot(2, 2, 4)

    ax1.plot(xz, yU, label = 'U(z)', linewidth = 0.5, color = 'red')
    ax2.plot(xz, ym, label = 'm$_{eff}$(z)', linewidth = 0.5, color ='blue')
    ykzr = np.real(ykz)
    ykzi = np.imag(ykz)
    ax3.plot(xz, ykzr, label = 'kz(real) (A$^{-1}$)', linewidth = 0.5, marker = 'o', markersize = 2)
    ax3.plot(xz, ykzi, label = 'kz(imag) (A$^{-1}$)', linewidth = 0.5)
    ax4.plot(xz, np.abs(Ai), label = 'Ai(abs)', linewidth = 0.5)
    # Plot |Bi| to match the label; this previously plotted Im(Bi).
    ax4.plot(xz, np.abs(Bi), label = 'Bi(abs)', linewidth = 0.5)
    # Raw strings: '\P' is not a valid escape and warns on Python 3.12+.
    ax5.plot(xz, np.real(Psi), label = r'$\Psi$(real)', linewidth = 0.5)
    ax5.plot(xz, np.imag(Psi), label = r'$\Psi$(imag)', linewidth = 0.5)
    ax5.plot(xz, np.abs(Psi) ** 2, label = r'$|\Psi|^2$', linewidth = 0.5)

    ax1.set_xlabel("z (A)", fontsize = fontsize)
    ax1.set_ylabel("U(z)", fontsize = fontsize)
    ax2.set_ylabel("m*(z)", fontsize = fontsize)
    ax3.set_xlabel("z (A)", fontsize = fontsize)
    ax3.set_ylabel("kz (A$^{-1}$)", fontsize = fontsize)
    ax4.set_xlabel("z (A)", fontsize = fontsize)
    ax4.set_ylabel("Ai, Bi", fontsize = fontsize)
    ax5.set_xlabel("z (A)", fontsize = fontsize)
    ax5.set_ylabel(r"$\Psi$(z)", fontsize = fontsize)

    for ax in (ax1, ax2, ax3, ax4, ax5):
        ax.set_xlim([zmin, zmax])
    ax1.set_ylim([min(yU), max(yU) * 1.1])
    ax2.set_ylim([min(ym), max(ym) * 1.3])
    # A single y-range that covers both the real and the imaginary parts of
    # kz. Previously set_ylim was called twice, and the second call could
    # clip Re(kz).
    ax3.set_ylim([min(ykzr.min(), ykzi.min() - 0.1),
                  max(ykzr.max(), ykzi.max()) * 1.3])

    # Merge the legends of the twin axes ax1/ax2 into one box.
    handler1, label1 = ax1.get_legend_handles_labels()
    handler2, label2 = ax2.get_legend_handles_labels()
    ax1.legend(handler1 + handler2, label1 + label2, loc = 2, borderaxespad = 0.0, fontsize = legend_fontsize)
    for ax in (ax3, ax4, ax5):
        ax.legend(fontsize = legend_fontsize)

    for ax in (ax1, ax2, ax3, ax4, ax5):
        ax.tick_params(labelsize = fontsize)
    plt.tight_layout()

    plt.pause(0.1)
    print("")
    print("Press ENTER to exit>>", end = '')
    input()

    terminate()


if __name__ == "__main__":
    # Dispatch to the driver selected by the run mode.
    _runners = {'wf': wf, 'tr': tr}
    if mode in _runners:
        _runners[mode]()
    else:
        terminate("Error: Invalid mode [{}]".format(mode))