Skip to content

Commit

Permalink
refactor web-app, present settings and another test script
Browse files Browse the repository at this point in the history
  • Loading branch information
blokhin committed Feb 14, 2018
1 parent e571e42 commit 355a58e
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 37 deletions.
5 changes: 5 additions & 0 deletions data/settings.ini.sample
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[mpds_ml_labs]
serve_ui = true
ml_models =
/path_to_models/model_one.pkl
/path_to_models/model_two.pkl
45 changes: 27 additions & 18 deletions mpds_ml_labs/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

from flask import Flask, Blueprint, Response, request, send_from_directory

from cors import crossdomain
from struct_utils import detect_format, poscar_to_ase, symmetrize, get_formula
from cif_utils import cif_to_ase, ase_to_eq_cif
from prediction import ase_to_ml_model, get_legend, load_ml_model
from common import SERVE_UI, ML_MODELS


app_labs = Blueprint('app_labs', __name__)
ml_model = None
static_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../webassets'))
active_ml_model = None

def fmt_msg(msg, http_code=400):
return Response('{"error":"%s"}' % msg, content_type='application/json', status=http_code)
Expand All @@ -37,20 +38,23 @@ def html_formula(string):
if sub: html_formula += '</sub>'
return html_formula

@app_labs.route('/', methods=['GET'])
def index():
return send_from_directory(os.path.dirname(__file__), 'index.html')

@app_labs.route('/index.css', methods=['GET'])
def style():
return send_from_directory(os.path.dirname(__file__), 'index.css')

@app_labs.route('/player.html', methods=['GET'])
def player():
return send_from_directory(os.path.dirname(__file__), 'player.html')
if SERVE_UI:
@app_labs.route('/', methods=['GET'])
def index():
return send_from_directory(static_path, 'index.html')
@app_labs.route('/index.css', methods=['GET'])
def style():
return send_from_directory(static_path, 'index.css')
@app_labs.route('/player.html', methods=['GET'])
def player():
return send_from_directory(static_path, 'player.html')

@app_labs.after_request
def add_cors_header(response):
response.headers['Access-Control-Allow-Origin'] = '*'
return response

@app_labs.route("/predict", methods=['POST'])
@crossdomain(origin='*')
def predict():
if 'structure' not in request.values:
return fmt_msg('Invalid request')
Expand Down Expand Up @@ -81,7 +85,7 @@ def predict():
if error:
return fmt_msg(error)

prediction, error = ase_to_ml_model(ase_obj, ml_model)
prediction, error = ase_to_ml_model(ase_obj, active_ml_model)
if error:
return fmt_msg(error)

Expand All @@ -105,10 +109,15 @@ def predict():

if __name__ == '__main__':
if sys.argv[1:]:
ml_model = load_ml_model(sys.argv[1:])
print("Loaded models: " + " ".join(sys.argv[1:]))
print("Models to load:\n" + "\n".join(sys.argv[1:]))
active_ml_model = load_ml_model(sys.argv[1:])

elif ML_MODELS:
print("Models to load:\n" + "\n".join(ML_MODELS))
active_ml_model = load_ml_model(ML_MODELS)

else:
print("No model loaded")
print("No models to load")

app = Flask(__name__)
app.debug = False
Expand Down
5 changes: 5 additions & 0 deletions mpds_ml_labs/cif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def cif_to_ase(cif_string):
else:
return None, 'Absent space group info in CIF'

disordered = False
try:
cellpar = (
float( parsed_cif['_cell_length_a'][0].split('(')[0] ),
Expand All @@ -61,9 +62,13 @@ def cif_to_ase(cif_string):
[ char.split('(')[0] for char in parsed_cif['_atom_site_fract_z'] ]
]).astype(np.float)
)
disordered = any([float(occ) != 1 for occ in parsed_cif.get('_atom_site_occupancy', [])])
except:
return None, 'Unexpected non-numerical values occured in CIF'

if disordered:
return None, 'This is disordered structure (not yet supported)'

symbols = parsed_cif.get('_atom_site_type_symbol')
if not symbols:
symbols = parsed_cif.get('_atom_site_label')
Expand Down
19 changes: 19 additions & 0 deletions mpds_ml_labs/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

import os
from ConfigParser import ConfigParser


DATA_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), '../data'))
config = ConfigParser()
config_path = os.path.join(DATA_PATH, 'settings.ini')

if os.path.exists(config_path):
config.read(config_path)
SERVE_UI = config.get('mpds_ml_labs', 'serve_ui')
ML_MODELS = [path.strip() for path in filter(
None,
config.get('mpds_ml_labs', 'ml_models').split()
)]
else:
SERVE_UI = True
ML_MODELS = []
54 changes: 35 additions & 19 deletions mpds_ml_labs/struct_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

import re
import fractions
import cStringIO

from ase.atoms import Atoms
Expand Down Expand Up @@ -87,30 +88,45 @@ def symmetrize(ase_obj, accuracy=1E-03):
return None, 'Unrecognized sites or invalid site symmetry in structure'


def get_formula(ase_obj):
formula_sequence = ['Fr','Cs','Rb','K','Na','Li', 'Be','Mg','Ca','Sr','Ba','Ra', 'Sc','Y','La','Ce','Pr','Nd','Pm','Sm','Eu','Gd','Tb','Dy','Ho','Er','Tm','Yb', 'Ac','Th','Pa','U','Np','Pu', 'Ti','Zr','Hf', 'V','Nb','Ta', 'Cr','Mo','W', 'Fe','Ru','Os', 'Co','Rh','Ir', 'Mn','Tc','Re', 'Ni','Pd','Pt', 'Cu','Ag','Au', 'Zn','Cd','Hg', 'B','Al','Ga','In','Tl', 'Pb','Sn','Ge','Si','C', 'N','P','As','Sb','Bi', 'H', 'Po','Te','Se','S','O', 'At','I','Br','Cl','F', 'He','Ne','Ar','Kr','Xe','Rn']
FORMULA_SEQUENCE = ['Fr','Cs','Rb','K','Na','Li', 'Be','Mg','Ca','Sr','Ba','Ra', 'Sc','Y','La','Ce','Pr','Nd','Pm','Sm','Eu','Gd','Tb','Dy','Ho','Er','Tm','Yb', 'Ac','Th','Pa','U','Np','Pu', 'Ti','Zr','Hf', 'V','Nb','Ta', 'Cr','Mo','W', 'Fe','Ru','Os', 'Co','Rh','Ir', 'Mn','Tc','Re', 'Ni','Pd','Pt', 'Cu','Ag','Au', 'Zn','Cd','Hg', 'B','Al','Ga','In','Tl', 'Pb','Sn','Ge','Si','C', 'N','P','As','Sb','Bi', 'H', 'Po','Te','Se','S','O', 'At','I','Br','Cl','F', 'He','Ne','Ar','Kr','Xe','Rn']

labels = {}
types = []
count = 0
def get_formula(ase_obj, find_gcd=True):
parsed_formula = {}

for k, label in enumerate(ase_obj.get_chemical_symbols()):
if label not in labels:
labels[label] = count
types.append([k+1])
count += 1
for label in ase_obj.get_chemical_symbols():
if label not in parsed_formula:
parsed_formula[label] = 1
else:
types[ labels[label] ].append(k+1)
parsed_formula[label] += 1

atoms = labels.keys()
atoms = [x for x in formula_sequence if x in atoms] + [x for x in atoms if x not in formula_sequence]
expanded = reduce(fractions.gcd, parsed_formula.values()) if find_gcd else 1
if expanded > 1:
parsed_formula = {el: int(content / float(expanded))
for el, content in parsed_formula.items()}

atoms = parsed_formula.keys()
atoms = [x for x in FORMULA_SEQUENCE if x in atoms] + [x for x in atoms if x not in FORMULA_SEQUENCE]
formula = ''
for atom in atoms:
n = len(types[labels[atom]])
if n == 1:
n = ''
else:
n = str(n)
formula += atom + n
index = parsed_formula[atom]
index = '' if index == 1 else str(index)
formula += atom + index

return formula


def sgn_to_crsystem(number):
if 195 <= number <= 230:
return 'cubic'
elif 168 <= number <= 194:
return 'hexagonal'
elif 143 <= number <= 167:
return 'trigonal'
elif 75 <= number <= 142:
return 'tetragonal'
elif 16 <= number <= 74:
return 'orthorhombic'
elif 3 <= number <= 15:
return 'monoclinic'
else:
return 'triclinic'
64 changes: 64 additions & 0 deletions mpds_ml_labs/test_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@

import os, sys

from struct_utils import detect_format, poscar_to_ase, symmetrize
from cif_utils import cif_to_ase
from prediction import ase_to_ml_model, load_ml_model, human_names
from common import ML_MODELS, DATA_PATH


models, structures = [], []

if sys.argv[1:]:
inputs = [f for f in sys.argv[1:] if os.path.isfile(f)]
models, structures = \
[f for f in inputs if f.endswith('.pkl')], [f for f in inputs if not f.endswith('.pkl')]

if not models:
models = ML_MODELS

if not structures:
structures = [os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH) if os.path.isfile(os.path.join(DATA_PATH, f))]

active_ml_model = load_ml_model(models)

for fname in structures:
print
print(fname)
structure = open(fname).read()

fmt = detect_format(structure)

if fmt == 'cif':
ase_obj, error = cif_to_ase(structure)
if error:
print(error)
continue

elif fmt == 'poscar':
ase_obj, error = poscar_to_ase(structure)
if error:
print(error)
continue

else:
print('Error: %s is not a crystal structure' % fname)
continue

ase_obj, error = symmetrize(ase_obj)
if error:
print(error)
continue

prediction, error = ase_to_ml_model(ase_obj, active_ml_model)
if error:
print(error)
continue

for prop_id, pdata in prediction.items():
print("{0:40} = {1:6} (MAE = {2:4}), {3}".format(
human_names[prop_id]['name'],
pdata['value'],
pdata['mae'],
human_names[prop_id]['units']
))

0 comments on commit 355a58e

Please sign in to comment.