Short examples showing how to extract or plot the band structure from a VASP calculation using pymatgen.
# Uncomment the subsequent lines in this cell to install dependencies for Google Colab.
# !pip install pymatgen==2022.7.19
import os
import matplotlib.pyplot as plt
from pymatgen.electronic_structure.core import Spin
from pymatgen.electronic_structure.plotter import BSDOSPlotter, BSPlotter, DosPlotter
from pymatgen.io.vasp.outputs import BSVasprun, Vasprun
%matplotlib inline
This notebook runs in a directory that contains the VASP calculation of the NiO band structure. Follow this example to learn how to obtain a band structure from VASP output files.
# List the VASP input/output files available in the current working directory.
os.listdir(".")
BandStructureNiO.ipynb KPOINTS WAVECAR CHG OSZICAR XDATCAR CHGCAR OUTCAR jNiO CONTCAR PCDAT slurm-348507.out DOSCAR POSCAR vasprun.xml EIGENVAL POTCAR INCAR PROCAR
# Parse vasprun.xml including projected eigenvalues (presumably needed for the
# element-projected band plot further down -- confirm), then build the band
# structure along the k-path described in the KPOINTS file.
run = BSVasprun(
    "vasprun.xml",
    parse_projected_eigen=True,
)
bs = run.get_band_structure(kpoints_filename="KPOINTS")
You can obtain some basic information about the band structure:
print(f"number of bands {bs.nb_bands}")  # bands per spin channel
number of bands 14
print(f"number of kpoints {len(bs.kpoints)}")  # k-points along the path
number of kpoints 200
# Ask pymatgen whether this band structure is metallic (displays False for NiO here).
bs.is_metal()
False
# True here: bands are stored separately for Spin.up and Spin.down.
bs.is_spin_polarized
True
The `bands` attribute of the `BandStructure` object is a dictionary of arrays that contains all the bands. For a spin-polarized calculation such as this one, its structure is:
{Spin.up: array of shape (nb_bands, nb_kpoints), Spin.down: array of shape (nb_bands, nb_kpoints)}
# Display the full dictionary of band energies, keyed by Spin.
bs.bands
{<Spin.up: 1>: array([[-12.6235, -12.6259, -12.6329, ..., -14.3847, -14.3914, -14.3936], [ -3.5768, -3.5725, -3.5597, ..., -1.3942, -1.3809, -1.3765], [ -2.9295, -2.9282, -2.9243, ..., -1.3815, -1.3777, -1.3765], ..., [ 20.4623, 20.4656, 20.4757, ..., 23.0714, 23.0629, 23.0559], [ 20.4623, 20.4656, 20.4757, ..., 28.6536, 28.9814, 29.1847], [ 26.6233, 26.6231, 26.6232, ..., 29.26 , 29.2519, 29.3045]]), <Spin.down: -1>: array([[-12.2547, -12.2571, -12.2642, ..., -14.0377, -14.0444, -14.0466], [ -3.0758, -3.0713, -3.0576, ..., 0.7154, 0.723 , 0.7256], [ 0.5472, 0.5477, 0.5494, ..., 0.7227, 0.7249, 0.7256], ..., [ 20.845 , 20.8494, 20.8627, ..., 23.4954, 23.5054, 23.5058], [ 20.845 , 20.8494, 20.8627, ..., 29.7286, 29.9021, 29.9973], [ 27.0945, 27.0942, 27.094 , ..., 29.9072, 29.9491, 29.9978]])}
# (number of bands, number of k-points) -> (14, 200) for this calculation.
bs.bands[Spin.up].shape
(14, 200)
The spin-down band with index 9 (the 10th band, since indices start at 0) is extracted with:
# Row 9 of the (nb_bands, nb_kpoints) array: band index 9 over every k-point.
bs.bands[Spin.down][9]
array([ 12.1302, 12.131 , 12.1334, 12.137 , 12.1416, 12.1469, 12.1521, 12.1568, 12.1601, 12.1613, 12.1594, 12.1535, 12.1428, 12.1264, 12.1034, 12.073 , 12.0347, 11.988 , 11.9326, 11.8683, 11.7951, 11.7132, 11.6228, 11.5244, 11.4186, 11.3058, 11.1867, 11.062 , 10.9324, 10.7987, 10.6617, 10.522 , 10.3805, 10.238 , 10.0954, 9.9536, 9.8135, 9.6761, 9.5424, 9.4137, 9.2911, 9.176 , 9.0696, 8.9734, 8.8888, 8.8172, 8.76 , 8.7182, 8.6928, 8.6842, 8.6842, 8.6957, 8.7301, 8.7877, 8.8689, 8.9736, 9.1009, 9.2495, 9.4174, 9.602 , 9.8008, 10.011 , 10.2301, 10.4558, 10.6859, 10.9185, 11.1519, 11.3844, 11.6146, 11.8408, 12.0616, 12.2756, 12.4812, 12.677 , 12.8614, 13.0327, 13.1893, 13.3294, 13.4513, 13.5535, 13.6345, 13.6936, 13.7303, 13.7454, 13.7402, 13.7171, 13.6793, 13.6301, 13.5731, 13.5118, 13.4493, 13.3882, 13.3306, 13.278 , 13.232 , 13.1933, 13.1626, 13.1405, 13.1271, 13.1226, 13.1226, 13.1235, 13.1262, 13.1307, 13.1371, 13.1452, 13.1551, 13.1667, 13.1802, 13.1953, 13.2121, 13.2307, 13.2509, 13.2727, 13.2961, 13.3211, 13.3475, 13.3756, 13.405 , 13.4359, 13.4681, 13.5016, 13.5365, 13.5725, 13.6098, 13.6482, 13.6876, 13.7282, 13.7697, 13.8122, 13.8555, 13.8997, 13.9447, 13.9904, 14.0368, 14.0839, 14.1315, 14.1797, 14.2283, 14.2774, 14.3268, 14.3766, 14.4267, 14.477 , 14.5274, 14.578 , 14.6285, 14.6791, 14.7295, 14.7798, 14.7798, 14.9285, 15.0712, 15.2025, 15.313 , 15.3886, 15.4144, 15.3871, 15.3166, 15.2167, 15.097 , 14.9637, 14.8204, 14.6694, 14.5122, 14.35 , 14.1833, 14.0129, 13.839 , 13.662 , 13.4821, 13.2995, 13.1142, 12.9266, 12.7366, 12.5442, 12.3498, 12.1532, 11.9547, 11.7545, 11.5529, 11.3502, 11.1468, 10.9433, 10.7401, 10.5382, 10.3384, 10.1418, 9.9497, 9.7636, 9.5851, 9.4164, 9.2595, 9.1167, 8.9907, 8.8837, 8.7979, 8.7352, 8.6971, 8.6842])
To print a band together with the corresponding k-points (fractional coordinates):
# Print band index 9 (spin down) next to the fractional coordinates of each
# k-point. Only the first and last `n_show` points are printed; the middle of
# the path is replaced by a single "..." line.
n_show = 10  # k-points shown at each end of the path
n_kpoints = len(bs.kpoints)
for n, (kpoint, e) in enumerate(
    zip(bs.kpoints, bs.bands[Spin.down][9, :]), start=1
):
    # Derive the elided range from the actual number of k-points instead of a
    # hard-coded index, so a different KPOINTS file still prints the true tail.
    elided = n_show < n < n_kpoints - n_show
    if n == n_show + 1 and elided:
        print("...")
    if elided:
        continue
    kx, ky, kz = kpoint.frac_coords
    print(f"kx = {kx:5.3f} ky = {ky:5.3f} kz = {kz:5.3f} eps(k) = {e:8.4f}")
kx = 0.500 ky = 0.500 kz = 0.500 eps(k) = 12.1302 kx = 0.490 ky = 0.490 kz = 0.490 eps(k) = 12.1310 kx = 0.480 ky = 0.480 kz = 0.480 eps(k) = 12.1334 kx = 0.469 ky = 0.469 kz = 0.469 eps(k) = 12.1370 kx = 0.459 ky = 0.459 kz = 0.459 eps(k) = 12.1416 kx = 0.449 ky = 0.449 kz = 0.449 eps(k) = 12.1469 kx = 0.439 ky = 0.439 kz = 0.439 eps(k) = 12.1521 kx = 0.429 ky = 0.429 kz = 0.429 eps(k) = 12.1568 kx = 0.418 ky = 0.418 kz = 0.418 eps(k) = 12.1601 kx = 0.408 ky = 0.408 kz = 0.408 eps(k) = 12.1613 ... kx = 0.077 ky = 0.153 kz = 0.077 eps(k) = 9.7636 kx = 0.069 ky = 0.138 kz = 0.069 eps(k) = 9.5851 kx = 0.061 ky = 0.122 kz = 0.061 eps(k) = 9.4164 kx = 0.054 ky = 0.107 kz = 0.054 eps(k) = 9.2595 kx = 0.046 ky = 0.092 kz = 0.046 eps(k) = 9.1167 kx = 0.038 ky = 0.077 kz = 0.038 eps(k) = 8.9907 kx = 0.031 ky = 0.061 kz = 0.031 eps(k) = 8.8837 kx = 0.023 ky = 0.046 kz = 0.023 eps(k) = 8.7979 kx = 0.015 ky = 0.031 kz = 0.015 eps(k) = 8.7352 kx = 0.008 ky = 0.015 kz = 0.008 eps(k) = 8.6971 kx = 0.000 ky = 0.000 kz = 0.000 eps(k) = 8.6842
# Plot the band structure shifted by the plotter's zero reference
# (zero_to_efermi=True), with the energy axis limited to -20..10.
bsplot = BSPlotter(bs)
bsplot.get_plot(zero_to_efermi=True, ylim=(-20, 10))
print(bs.efermi)

# Decorate the current axes: title, dashed line at zero energy, legend.
ax = plt.gca()
ax.set_title("NiO Band Structure", fontsize=20)
x_lo, x_hi = ax.get_xlim()
ax.hlines(0, x_lo, x_hi, linestyles="dashed", color="black")
# Empty plots only create legend entries; the styles are assumed to match
# BSPlotter's spin-up / spin-down line styles -- confirm for your version.
ax.plot([], [], "b-", label="spin up")
ax.plot([], [], "r--", label="spin down")
ax.legend(loc="upper left", fontsize=16)
5.35857687
<matplotlib.legend.Legend at 0x10eeae470>
You can get data from the plot and in particular the (x, y) coordinates of each band.
# Retrieve the data BSPlotter uses to draw the figure and list its entries.
data = bsplot.bs_plot_data()
data.keys()
dict_keys(['ticks', 'distances', 'energy', 'vbm', 'cbm', 'lattice', 'zero_energy', 'is_metal', 'band_gap'])
For example, here you print the abscissa (the distance along the k-path) and the energy of the band with index 9. Keep in mind that these are the data used to make the plot: the energies have already been shifted to the zero of energy chosen by the `BSPlotter` class (stored under the `zero_energy` key).
ibands = 9  # band index, 0-based (0 .. number of bands - 1)
spin = str(Spin.up)  # plot-data energies are keyed by the string form of Spin
for branch_dist, branch_energy in zip(data["distances"], data["energy"]):
    # One separator line per branch (segment) of the k-path.
    print("-" * 20)
    for dist, energy in zip(branch_dist, branch_energy[spin][ibands]):
        print(f"{dist:8.4f} {energy:8.4f}")
-------------------- 0.0000 6.8926 0.0273 6.8929 0.0546 6.8941 0.0819 6.8956 0.1092 6.8974 0.1365 6.8990 0.1638 6.8999 0.1911 6.8994 0.2184 6.8970 0.2457 6.8919 0.2730 6.8832 0.3003 6.8702 0.3276 6.8520 0.3549 6.8278 0.3822 6.7970 0.4095 6.7589 0.4368 6.7131 0.4641 6.6591 0.4914 6.5967 0.5187 6.5259 0.5460 6.4468 0.5733 6.3595 0.6006 6.2643 0.6279 6.1617 0.6552 6.0522 0.6825 5.9364 0.7097 5.8148 0.7370 5.6881 0.7643 5.5569 0.7916 5.4220 0.8189 5.2841 0.8462 5.1440 0.8735 5.0024 0.9008 4.8600 0.9281 4.7178 0.9554 4.5765 0.9827 4.4371 1.0100 4.3006 1.0373 4.1680 1.0646 4.0403 1.0919 3.9189 1.1192 3.8049 1.1465 3.6996 1.1738 3.6045 1.2011 3.5209 1.2284 3.4503 1.2557 3.3938 1.2830 3.3526 1.3103 3.3275 1.3376 3.3191 -------------------- 1.3376 3.3191 1.3691 3.3303 1.4006 3.3638 1.4322 3.4189 1.4637 3.4946 1.4952 3.5897 1.5267 3.7025 1.5583 3.8313 1.5898 3.9745 1.6213 4.1303 1.6528 4.2969 1.6843 4.4727 1.7159 4.6560 1.7474 4.8454 1.7789 5.0394 1.8104 5.2363 1.8419 5.4352 1.8735 5.6347 1.9050 5.8335 1.9365 6.0304 1.9680 6.2241 1.9995 6.4135 2.0311 6.5974 2.0626 6.7746 2.0941 6.9436 2.1256 7.1033 2.1571 7.2523 2.1887 7.3891 2.2202 7.5125 2.2517 7.6211 2.2832 7.7139 2.3148 7.7899 2.3463 7.8485 2.3778 7.8898 2.4093 7.9143 2.4408 7.9231 2.4724 7.9179 2.5039 7.9007 2.5354 7.8740 2.5669 7.8405 2.5984 7.8025 2.6300 7.7623 2.6615 7.7225 2.6930 7.6841 2.7245 7.6493 2.7560 7.6192 2.7876 7.5948 2.8191 7.5769 2.8506 7.5659 2.8821 7.5622 -------------------- 2.8821 7.5622 2.8933 7.5631 2.9044 7.5660 2.9156 7.5708 2.9267 7.5775 2.9379 7.5862 2.9490 7.5966 2.9601 7.6090 2.9713 7.6232 2.9824 7.6392 2.9936 7.6570 3.0047 7.6765 3.0159 7.6978 3.0270 7.7208 3.0382 7.7454 3.0493 7.7716 3.0604 7.7994 3.0716 7.8288 3.0827 7.8596 3.0939 7.8918 3.1050 7.9254 3.1162 7.9604 3.1273 7.9966 3.1385 8.0340 3.1496 8.0727 3.1607 8.1124 3.1719 8.1533 3.1830 8.1951 3.1942 8.2379 3.2053 8.2817 3.2165 8.3263 3.2276 8.3718 3.2388 8.4180 3.2499 8.4649 3.2610 8.5125 3.2722 8.5608 3.2833 8.6096 3.2945 8.6590 
3.3056 8.7088 3.3168 8.7591 3.3279 8.8099 3.3391 8.8609 3.3502 8.9124 3.3613 8.9641 3.3725 9.0160 3.3836 9.0682 3.3948 9.1205 3.4059 9.1730 3.4171 9.2255 3.4282 9.2780 -------------------- 3.4282 9.2780 3.4616 9.4353 3.4951 9.5903 3.5285 9.7395 3.5619 9.8760 3.5954 9.9843 3.6288 10.0379 3.6622 10.0174 3.6957 9.9370 3.7291 9.8218 3.7625 9.6868 3.7960 9.5394 3.8294 9.3836 3.8628 9.2214 3.8963 9.0544 3.9297 8.8833 3.9631 8.7089 3.9966 8.5317 4.0300 8.3521 4.0634 8.1703 4.0969 7.9866 4.1303 7.8011 4.1637 7.6139 4.1972 7.4253 4.2306 7.2352 4.2640 7.0438 4.2975 6.8511 4.3309 6.6573 4.3643 6.4624 4.3978 6.2667 4.4312 6.0704 4.4646 5.8738 4.4981 5.6773 4.5315 5.4813 4.5649 5.2862 4.5984 5.0929 4.6318 4.9021 4.6652 4.7147 4.6987 4.5319 4.7321 4.3549 4.7655 4.1852 4.7990 4.0246 4.8324 3.8750 4.8658 3.7385 4.8993 3.6174 4.9327 3.5139 4.9661 3.4305 4.9996 3.3692 5.0330 3.3317 5.0664 3.3191
Same as above, but as a plot (each branch of the k-path is drawn as a separate line).
ibands = 9  # band index, 0-based (0 .. number of bands - 1)
spin = str(Spin.up)
# Draw every branch of the k-path as its own line on the current axes.
for branch_dist, branch_energy in zip(data["distances"], data["energy"]):
    plt.plot(branch_dist, branch_energy[spin][ibands])
The same again, but merging the branches of the band into a single line.
# Plot the same band as in the previous cells, concatenating all k-path
# branches into one continuous line.
ibands = 9  # band index, 0-based; matches the band shown in the cells above
spin = str(Spin.up)
x = []
y = []
for xpath, epath in zip(data["distances"], data["energy"]):
    # Branch endpoints are shared, so joined branches connect seamlessly.
    x.extend(xpath)
    y.extend(epath[spin][ibands])
plt.plot(x, y)
[<matplotlib.lines.Line2D at 0x10f259160>]
Read the DOS from another calculation (the separate DOS run in ../DOS_SMEAR).
# Load the DOS from the separate DOS calculation one directory up.
dosrun = Vasprun("../DOS_SMEAR/vasprun.xml", parse_dos=True)
dos = dosrun.complete_dos
# The Vasprun and the CompleteDos report the same Fermi level here; note that
# it differs from the band-structure run's value printed earlier.
print(dosrun.efermi)
print(dos.efermi)
5.24546925 5.24546925
# Plot the total DOS together with the element-projected DOS
# (sigma=0.1 is the smearing width passed to DosPlotter).
dosplot = DosPlotter(sigma=0.1)
dosplot.add_dos("Total DOS", dos)
dosplot.add_dos_dict(dos.get_element_dos())
# Keep the plot handle under its own name: rebinding `plt` would shadow the
# matplotlib.pyplot module for every later cell.
# NOTE(review): get_plot returns the pyplot module in older pymatgen and an
# Axes in newer releases; both provide .grid().
dos_plot = dosplot.get_plot()
dos_plot.grid()
# Rebuild the band structure using the Fermi level from the DOS calculation, so
# that the band structure and the DOS share one energy reference in the
# combined plot.
bs = run.get_band_structure("KPOINTS", efermi=dos.efermi)
# Element-projected bands and DOS side by side; the band projection presumably
# relies on parse_projected_eigen=True used when loading `run` -- confirm.
bsdosplot = BSDOSPlotter(
    bs_projection="elements",
    dos_projection="elements",
    vb_energy_range=22,
    egrid_interval=2.5,
)
# Store the result under its own name instead of rebinding (and shadowing) the
# matplotlib.pyplot module `plt`.
bsdos_plot = bsdosplot.get_plot(bs, dos=dos)