# get_cell_image.py
# Forked from kaist-amsg/imatgen.
import sys
import math
import os
import random
import json
from ase.io import read,write
from ase import Atom, Atoms
import argparse
import numpy
import numpy as np
from numpy.linalg import norm
from tqdm import tqdm
from joblib import Parallel, delayed
import multiprocessing
import pickle
def get_atomlist_atomindex():
    """Return the element list and the symbol -> channel-index lookup.

    Returns:
        A tuple ``(atomlist, atomindex)`` where ``atomlist`` is the list of
        element symbols in use and ``atomindex`` maps each symbol to its
        position in that list.
    """
    # Reference element sets kept for convenience; uncomment to use them.
    # cod_atomlist = ['Ru', 'Re', 'Ra', 'Rb', 'Rn', 'Rh', 'Be', 'Ba', 'Bi', 'Bk', 'Br', 'H', 'P', 'Os', 'Ge', 'Gd', 'Ga', 'Pr', 'Pt', 'Pu', 'C', 'Pb', 'Pa', 'Pd', 'Cd', 'Po', 'Pm', 'Ho', 'Hf', 'Hg', 'He', 'Mg', 'K', 'Mn', 'O', 'S', 'W', 'Zn', 'Eu', 'Zr', 'Er', 'Ni', 'Na', 'Nb', 'Nd', 'Ne', 'Np', 'Fe', 'B', 'F', 'Sr', 'N', 'Kr', 'Si', 'Sn', 'Sm', 'V', 'Sc', 'Sb', 'Se', 'Co', 'Cm', 'Cl', 'Ca', 'Cf', 'Ce', 'Xe', 'Tm', 'Cs', 'Cr', 'Cu', 'La', 'Li', 'Tl', 'Lu', 'Th', 'Ti', 'Te', 'Tb', 'Tc', 'Ta', 'Yb', 'Dy', 'I', 'U', 'Y', 'Ac', 'Ag', 'Ir', 'Am', 'Al', 'As', 'Ar', 'Au', 'In', 'Mo'] # 96
    # mp_atomlist = ['Ru', 'Re', 'Rb', 'Rh', 'Be', 'Ba', 'Bi', 'Br', 'H', 'P', 'Os', 'Ge', 'Gd', 'Ga', 'Pr', 'Pt', 'Pu', 'Mg', 'Pb','Pa', 'Pd', 'Cd', 'Pm', 'Ho', 'Hf', 'Hg', 'He', 'C', 'K', 'Mn', 'O', 'S', 'W', 'Zn', 'Eu', 'Zr', 'Er', 'Ni', 'Na','Nb', 'Nd', 'Ne', 'Np', 'Fe', 'B', 'F', 'Sr', 'N', 'Kr', 'Si', 'Sn', 'Sm', 'V', 'Sc', 'Sb', 'Se', 'Co', 'Cl', 'Ca','Ce', 'Xe', 'Tm', 'Cs', 'Cr', 'Cu', 'La', 'Li', 'Tl', 'Lu', 'Th', 'Ti', 'Te', 'Tb', 'Tc', 'Ta', 'Yb', 'Dy', 'I','U', 'Y', 'Ac', 'Ag', 'Ir', 'Al', 'As', 'Ar', 'Au', 'In', 'Mo'] #89
    # all_atomlist = list(set(cod_atomlist+mp_atomlist))
    # You can specify your own element lists (if you don't want to use the above list)
    all_atomlist = ['V']
    cod_atomindex = {symbol: i for i, symbol in enumerate(all_atomlist)}
    # The original returned `cod_atomlist`, which only exists in the
    # commented-out line above and therefore raised NameError on every call.
    return all_atomlist, cod_atomindex
def get_scale(sigma):
    """Return ``1 / (2 * sigma**2)`` for the given Gaussian width ``sigma``.

    The result is the coefficient that multiplies the squared distance in
    ``exp(-scale * d**2)``.
    """
    two_sigma_squared = 2 * sigma ** 2
    return 1.0 / two_sigma_squared
def get_image_one_atom(atom,fakeatoms_grid,nbins,scale):
    """Build the Gaussian image of one atom on an ``nbins**3`` grid.

    Args:
        atom: the atom to render; it is appended to a copy of the grid.
        fakeatoms_grid: grid object holding at least ``nbins**3`` points; it
            is copied first, so the caller's object is not modified.
        nbins: number of grid points along each axis.
        scale: coefficient applied to the squared distance in the exponent.

    Returns:
        Array of shape ``(nbins, nbins, nbins)`` with
        ``exp(-scale * d**2)`` at every grid point, where ``d`` is the
        distance reported by ``get_distances(..., mic=True)``.
    """
    n_points = nbins ** 3
    # Work on a private copy: the atom becomes the last entry (index -1).
    probe = fakeatoms_grid.copy()
    probe.append(atom)
    distances = probe.get_distances(-1, range(n_points), mic=True)
    weights = numpy.exp(-scale * distances ** 2)
    flat_image = numpy.zeros((1, n_points))
    flat_image[:, :] = weights.flatten()
    return flat_image.reshape(nbins, nbins, nbins)
def get_fakeatoms_grid(atoms,nbins):
    """Create a periodic pseudo-crystal of H atoms on a regular grid.

    The grid has ``nbins`` points per axis at fractional coordinates
    ``0, 1/nbins, ..., (nbins-1)/nbins`` and uses the cell of ``atoms``.

    Args:
        atoms: structure whose cell (``atoms.get_cell()``) is copied.
        nbins: number of grid points along each axis.

    Returns:
        An ``Atoms`` object holding ``nbins**3`` H atoms with periodic
        boundary conditions enabled.
    """
    n_points = nbins ** 3
    ticks = numpy.array([float(k) / float(nbins) for k in range(nbins)])
    # With meshgrid's default indexing, the second output varies along the
    # first array axis, so it is the one used as the x column below.
    yv, xv, zv = numpy.meshgrid(ticks, ticks, ticks)
    fractional = numpy.zeros((n_points, 3))
    for column, axis_values in enumerate((xv, yv, zv)):
        fractional[:, column] = axis_values.flatten()
    # Pseudo-crystal of H atoms placed at the pre-defined fractional coordinates.
    fakeatoms_grid = Atoms('H' + str(n_points))
    fakeatoms_grid.set_cell(atoms.get_cell())
    fakeatoms_grid.set_pbc(True)
    fakeatoms_grid.set_scaled_positions(fractional)
    return fakeatoms_grid
def get_image_all_atoms(atoms,nbins,scale,norm,num_cores):
    """Render every atom of ``atoms`` into a multi-channel 3D image.

    Args:
        atoms: iterable of atoms exposing ``.symbol``; also passed to
            ``get_fakeatoms_grid`` to build the sampling grid.
        nbins: number of grid points along each axis.
        scale: Gaussian exponent coefficient forwarded to
            ``get_image_one_atom``.
        norm: accepted for interface compatibility; not used here.
        num_cores: ``n_jobs`` value handed to joblib ``Parallel``.

    Returns:
        ``(image, channellist)`` where ``image`` has shape
        ``(nbins, nbins, nbins, len(channellist))`` and ``channellist`` holds
        the distinct element indices present in ``atoms``.
    """
    grid = get_fakeatoms_grid(atoms, nbins)
    render_jobs = (
        delayed(get_image_one_atom)(atom, grid, nbins, scale) for atom in atoms
    )
    per_atom_images = list(Parallel(n_jobs=num_cores)(render_jobs))

    _, symbol_to_index = get_atomlist_atomindex()
    # One output channel per distinct element index found in the structure.
    channellist = list(set([symbol_to_index[atom.symbol] for atom in atoms]))

    image = numpy.zeros((nbins, nbins, nbins, len(channellist)))
    for atom, atom_image in zip(atoms, per_atom_images):
        slot = channellist.index(symbol_to_index[atom.symbol])
        # Contributions below 0.02 are zeroed before accumulation.
        image[:, :, :, slot] += atom_image * (atom_image >= 0.02)
    return image, channellist
def image2pickle(image,channellist,savefilename):
    """Write ``image`` to ``savefilename`` as JSON under the key ``'image'``.

    Despite the function name, the output is JSON text, not a pickle.

    Args:
        image: array-like object providing ``tolist()``.
        channellist: accepted but not written (see note below).
        savefilename: path of the file to create or overwrite.
    """
    payload = {}
    payload['image'] = image.tolist()
    with open(savefilename, 'w') as handle:
        json.dump(payload, handle)
    # The channel vector is not saved: it is not needed for cell images.
def extract_cell(atoms):
    """Return a single-atom structure that carries only the cell of ``atoms``.

    The result contains one 'V' atom placed at fractional coordinates
    (0.5, 0.5, 0.5) inside a cell taken from ``atoms.cell``.
    """
    source_cell = atoms.cell
    lattice_only = Atoms('V')
    lattice_only.cell = source_cell
    centre = [0.5, 0.5, 0.5]
    lattice_only.set_scaled_positions(centre)
    return lattice_only
def _touch_marker(path):
    """Best-effort creation of an empty marker file (like ``touch``)."""
    try:
        with open(path, 'a'):
            pass
    except OSError:
        # Same outcome as a failed `touch`: carry on without a marker.
        pass


def _remove_quietly(path):
    """Best-effort removal of ``path``; a missing file is not an error."""
    try:
        os.remove(path)
    except OSError:
        pass


def file2image(args):
    """Convert each input structure file into a cell image saved on disk.

    For every path in ``args.input_file`` the structure is read, reduced to
    its cell with ``extract_cell``, rendered with ``get_image_all_atoms`` and
    written with ``image2pickle`` to ``./<stem>_32.npy`` (the file name is
    fixed and does not depend on ``args.nbins``).

    A ``<stem>.touchtouch`` marker next to the input file signals that the
    file is being processed. Files that already have an output, or whose
    marker exists, are skipped. A marker left behind by a killed process must
    be deleted by hand before that file is processed again.

    Args:
        args: namespace with ``input_file`` (list of paths), ``filetype``
            (format passed to ``read``), ``nbins`` and ``nproc``.

    Returns:
        Always ``1``.
    """
    # Shuffle a copy so the caller's list keeps its order; None means no files.
    inputfiles = list(args.input_file or [])
    random.shuffle(inputfiles)
    for inputfile in inputfiles:
        stem = os.path.basename(inputfile).split('.')[0]
        # Build the marker from directory + stem. Splitting the whole path on
        # '.' yields an empty prefix for paths such as './x.cif'.
        touchfile = os.path.join(os.path.dirname(inputfile), stem) + '.touchtouch'
        filename2 = './' + stem + '_64.pickle'
        savefilename = './' + stem + '_32.npy'

        if os.path.isfile(filename2) or os.path.isfile(savefilename):
            print('already made pickle')
            continue
        if os.path.isfile(touchfile):
            # The marker says this file is already being handled.
            continue

        _touch_marker(touchfile)
        try:
            try:
                atoms = read(inputfile, format=args.filetype)
            except Exception:
                # NOTE: destructive by design - unreadable inputs are deleted.
                # Only ordinary errors trigger this; Ctrl-C no longer does.
                _remove_quietly(inputfile)
                continue
            scale = get_scale(sigma=0.26)  # Gaussian width
            num_cores = args.nproc
            nbins = args.nbins  # grid points per axis of the output image
            image, channellist = get_image_all_atoms(
                extract_cell(atoms), nbins, scale, norm, num_cores)
            image2pickle(image, channellist, savefilename)
        finally:
            # Always release the marker, including on failure.
            _remove_quietly(touchfile)
    return 1
def main():
    """Parse the command line and convert the given files to cell images.

    ``--input_file`` is mandatory: without it there is nothing to process,
    and the original crashed later with a TypeError instead of printing
    usage information.

    Returns:
        Always ``1``.
    """
    parser = argparse.ArgumentParser(
        description='mapping POSCAR or cif structure into a box image')
    parser.add_argument('--input_file', type=str, nargs='+', required=True,
                        help='a file path with the poscar or cif structure')
    parser.add_argument('--filetype', type=str, default='cif',
                        help='filetype : cif,vasp')
    parser.add_argument('--nbins', type=int, default=32,
                        help='number of bins in one dimension')
    parser.add_argument('--nproc', type=int, default=1,
                        help='number of process')
    args = parser.parse_args()
    file2image(args)
    return 1


if __name__ == '__main__':
    # The return value is deliberately not used as the exit status:
    # main() returns 1 on success.
    main()