"""Download band-structure and DOS data from the Materials Project.

For every material matching the chemical formula given on the command line,
the script fetches the band structure and the density of states through the
``mp_api`` client and writes them out as JSON, CSV and PNG files in the
current directory.

Usage: python <script> formula [kpath]

The API key is read from the environment variable ``MP_APIKEY``.
"""
import json
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
from emmet.core.electronic_structure import BSPathType
from monty.json import MontyEncoder
from mp_api.client import MPRester
from pymatgen.core.structure import Structure
from pymatgen.electronic_structure.core import Spin
from pymatgen.electronic_structure.plotter import BSPlotter, DosPlotter

API_KEY = os.getenv('MP_APIKEY')
if API_KEY is None:
    print("\nError: Can not get MP API Key from the environment var MP_APIKEY\n")
    sys.exit()

# Command-line parameters, filled in by update_vars().
formula = None
kpath = None

# Extra keyword arguments passed to MPRester.get_bandstructure_by_material_id()
# for each recognised ``kpath`` value. Any other value (including None) uses
# the client's defaults.
# NOTE(review): "Setyawan-Curtarolo" selects the *uniform* band structure
# (line_mode=False) while the default branch was commented as the line-mode
# Setyawan-Curtarolo path -- mapping kept as in the original; confirm intent.
_BS_QUERY_OPTIONS = {
    "Hinuma": {"path_type": BSPathType.hinuma},                # line-mode, Hinuma et al.
    "Latimer-Munro": {"path_type": BSPathType.latimer_munro},  # line-mode, Latimer-Munro
    "Setyawan-Curtarolo": {"line_mode": False},                # uniform
}


def usage():
    """Print a short command-line usage message."""
    print()
    print(f"Usage: python {sys.argv[0]} formula [kpath]")
    print(f"   kpath: one of {', '.join(_BS_QUERY_OPTIONS)} (optional)\n")


def update_vars():
    """Read ``formula`` and the optional ``kpath`` from ``sys.argv``.

    Terminates the program (after waiting for ENTER) when no formula is given.
    """
    # Both names must be declared global; otherwise the assignment below
    # would only create a local ``kpath`` and the option would be ignored.
    global formula, kpath

    argv = sys.argv
    narg = len(argv)
    if narg <= 1:
        print("\nError: Chemical formula must be given as the first arg\n")
        input("Press ENTER to terminate>>\n")
        sys.exit()

    formula = argv[1]
    if narg >= 3:
        kpath = argv[2]


def get_MPReser(API_KEY=None):
    """Create an ``MPRester`` client.

    Args:
        API_KEY: Materials Project API key. When empty or None, the key is
            read from the environment variable ``MP_APIKEY``.

    Returns:
        The ``MPRester`` instance, or None when no API key is available.
    """
    if not API_KEY:
        API_KEY = os.getenv('MP_APIKEY')
    if API_KEY is None:
        print("\nError: Can not get MP API Key from the environment var MP_APIKEY\n")
        return None

    # A constructor call never evaluates to None, so no further check is
    # needed here (the former check would also have echoed the API key).
    return MPRester(API_KEY)


def get_material_ids(formula, mpr):
    """Return the material IDs matching ``formula``.

    Args:
        formula: Chemical formula passed to ``mpr.materials.search``.
        mpr: ``MPRester`` client.

    Returns:
        A list of material IDs, or None when the search returns nothing.
    """
    search_results = mpr.materials.search(formula=formula, fields=[])
    if not search_results:
        print(f"No data found for {formula}")
        return None

    return [res.material_id for res in search_results]


def _fetch_band_structure(mpr, mid):
    """Fetch the band structure of ``mid`` for the selected ``kpath``; None on failure."""
    options = _BS_QUERY_OPTIONS.get(kpath, {})
    try:
        return mpr.get_bandstructure_by_material_id(mid, **options)
    except Exception as exc:  # best effort: report and let the caller skip
        print(f"  Warning: band structure request failed for {mid}: {exc}")
        return None


def _fetch_dos(mpr, mid):
    """Fetch the density of states of ``mid``; None on failure."""
    try:
        return mpr.get_dos_by_material_id(mid)
    except Exception as exc:  # best effort: report and let the caller skip
        print(f"  Warning: DOS request failed for {mid}: {exc}")
        return None


def _save_json(bs, dos, tag):
    """Dump the band structure and DOS as ``band_<tag>.json`` / ``dos_<tag>.json``."""
    with open(f'band_{tag}.json', 'w') as f:
        json.dump(bs.as_dict(), f, cls=MontyEncoder, indent=2)
    with open(f'dos_{tag}.json', 'w') as f:
        json.dump(dos.as_dict(), f, cls=MontyEncoder, indent=2)


def _save_csv(bs, dos, tag):
    """Write the spin-up bands and spin-up DOS as ``band_<tag>.csv`` / ``dos_<tag>.csv``."""
    # One row per (band, k-point) pair.
    bs_data = [
        [kpoint.frac_coords, energy]
        for band in bs.bands[Spin.up]
        for kpoint, energy in zip(bs.kpoints, band)
    ]
    bs_df = pd.DataFrame(bs_data, columns=['K-Point', 'Energy'])
    bs_df.to_csv(f'band_{tag}.csv', index=False)

    dos_data = [
        [energy, dos_value]
        for energy, dos_value in zip(dos.energies, dos.densities[Spin.up])
    ]
    dos_df = pd.DataFrame(dos_data, columns=['Energy', 'Density of States'])
    dos_df.to_csv(f'dos_{tag}.csv', index=False)


def _save_figures(bs, dos, tag):
    """Plot the band structure and DOS to ``band_<tag>.png`` / ``dos_<tag>.png``."""
    fig_bs = BSPlotter(bs).get_plot().figure
    fig_bs.savefig(f'band_{tag}.png')
    plt.close(fig_bs)  # release the figure; one is created per material

    dos_plotter = DosPlotter()
    dos_plotter.add_dos("Total DOS", dos)
    fig_dos = dos_plotter.get_plot().figure
    fig_dos.savefig(f'dos_{tag}.png')
    plt.close(fig_dos)


def get_band_data():
    """Download and save band/DOS data for every material matching ``formula``.

    Uses the module-level ``formula`` and ``kpath``. Materials whose
    structure, band structure or DOS cannot be obtained are skipped.
    Terminates the program when no ``MPRester`` client can be created.
    """
    mpr = get_MPReser()
    if mpr is None:
        sys.exit()

    mid_list = get_material_ids(formula, mpr)
    print("Material IDs:", mid_list)
    print()
    if not mid_list:
        # get_material_ids() returns None when nothing matched.
        return

    for mid in mid_list:
        material_data = mpr.materials.summary.search(material_ids=[mid], fields=[])
        if not material_data:
            print(f"No structure data found for material_id {mid}")
            continue

        structure_dict = material_data[0].structure.as_dict()
        structure = Structure.from_dict(structure_dict)
        cformula = structure.reduced_formula

        # Each fetch starts fresh for this material, so a failed request can
        # never reuse the data of the previous one.
        bs = _fetch_band_structure(mpr, mid)
        if bs is None:
            print(" Band data is not available: Skip")
            continue

        dos = _fetch_dos(mpr, mid)
        if dos is None:
            print(" DOS data is not available: Skip")
            continue

        tag = f'{cformula}_{mid}'

        print("Saving JSON files...")
        _save_json(bs, dos, tag)

        print(" CSV files...")
        _save_csv(bs, dos, tag)

        print(" figure files...")
        _save_figures(bs, dos, tag)


def main():
    """Script entry point: parse arguments, download the data, then wait for ENTER."""
    update_vars()
    get_band_data()

    usage()
    print()
    input("Press ENTER to terminate>>\n")
    sys.exit()


if __name__ == "__main__":
    main()