diff --git a/streamlit/streamlit_app.py b/streamlit/streamlit_app.py index 679445d..7ebc3d8 100644 --- a/streamlit/streamlit_app.py +++ b/streamlit/streamlit_app.py @@ -1,32 +1,28 @@ -import mols2grid -from rdkit import Chem -from rdkit import DataStructs -import prolif as plf -import MDAnalysis as mda -import numpy as np -import streamlit as st +import os import tempfile -import streamlit.components.v1 as components -from moldrug import utils -from meeko import RDKitMolCreate, PDBQTMolecule +import time +import urllib.parse from io import StringIO -import seaborn as sns + import matplotlib.pyplot as plt +import MDAnalysis as mda +import mols2grid +import numpy as np import pandas as pd -from stmol import showmol -import os +import prolif as plf import pubchempy as pcp import requests -import urllib.parse +import seaborn as sns from bs4 import BeautifulSoup -import time +from meeko import PDBQTMolecule, RDKitMolCreate +from pandas.api.types import (is_categorical_dtype, is_datetime64_any_dtype, + is_numeric_dtype, is_object_dtype) +from rdkit import Chem, DataStructs +from stmol import showmol -from pandas.api.types import ( - is_categorical_dtype, - is_datetime64_any_dtype, - is_numeric_dtype, - is_object_dtype, -) +import streamlit as st +import streamlit.components.v1 as components +from moldrug import utils # TODO # add SyGma for metabolic prediction @@ -146,11 +142,14 @@ def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame: def convert(number): if isinstance(number, np.floating): return float(number) - if isinstance(number, np.integer): + elif isinstance(number, np.integer): return int(number) + # Neither a NumPy float nor a NumPy integer: return the value unchanged (this fixes an issue on macOS) + else: + return number -def plot_dist(individuals: list[utils.Individual], properties: list[str], every_gen: int = 1): +def plot_dist(individuals: list[utils.Individual], properties: list[str], every_gen: int = 1, figsize=(25, 25)): """Create the violin plot for the MolDrug run Parameters @@ -161,6 +160,8 
@@ def plot_dist(individuals: list[utils.Individual], properties: list[str], every_ A list of the properties to be graph (must be attributes of the provided individuals) every_gen : int, optional Frequency to plot the distribution: every how many generations, by default 1 + figsize : tuple(int, int), optional + Width and height of the figure, forwarded to plt.subplots, by default (25, 25) Returns ------- @@ -168,8 +169,12 @@ def plot_dist(individuals: list[utils.Individual], properties: list[str], every_ fig, axes """ # Set up the matplotlib figure + if len(properties) <= 1: + extra_plot_kwargs = dict(sharex=True, gridspec_kw={'hspace': 0.05}) + else: + extra_plot_kwargs = dict() sns.set_theme(style="whitegrid") - fig, axes = plt.subplots(nrows=len(properties), figsize=(25, 25)) + fig, axes = plt.subplots(nrows=len(properties), figsize=figsize, **extra_plot_kwargs) SawIndividuals = utils.to_dataframe(individuals).drop(['pdbqt'], axis=1).replace([np.inf, -np.inf], np.nan).dropna() SawIndividuals = SawIndividuals[SawIndividuals['kept_gens'].map(len) != 0].reset_index(drop=True) @@ -186,15 +191,19 @@ def plot_dist(individuals: list[utils.Individual], properties: list[str], every_ # Draw a violinplot with a narrow bandwidth than the default pops = pops.loc[pops['genID'].isin([gen for gen in range(0, NumGens + every_gen, every_gen)])] - if len(properties) <= 1: - sns.violinplot(hue='genID', y=properties[0], data=pops, palette="Set3", bw_adjust=.2, cut=0, linewidth=1, ax=axes, legend=False) - else: - for i, prop in enumerate(properties): + for i, prop in enumerate(properties): + if len(properties) <= 1: + sns.violinplot(hue='genID', x='genID', y=prop, data=pops, palette="Set3", bw_adjust=.2, cut=0, linewidth=1, ax=axes, legend=False) + else: sns.violinplot(hue='genID', x='genID', y=prop, data=pops, palette="Set3", bw_adjust=.2, cut=0, linewidth=1, ax=axes[i], legend=False) + # Remove x-axis labels and ticks from all but the bottom subplot + if i < len(properties) - 1: + axes[i].xaxis.set_visible(False) return 
fig, axes + @st.cache_data def ProtPdbBlockToProlifMol(protein_pdb_string): with tempfile.NamedTemporaryFile(prefix='.pro', suffix='.pdb', mode='w+') as tmp: @@ -611,6 +620,7 @@ def get_pubchem_dataframe(df: pd.DataFrame) -> pd.DataFrame: st.info('Nothing to show') with tab3: + st.info('🧪This feautre is still experimental. Sometimes the query just fail 🫤') PubChemCheck = st.checkbox("Explore PubChem") download_button = st.empty() if PubChemCheck: @@ -688,8 +698,13 @@ def get_pubchem_dataframe(df: pd.DataFrame) -> pd.DataFrame: with tab2: try: every_gen = st.number_input("Every how many generations:", min_value=1, max_value=moldrug_result.NumGens, value=10) + col1, col2 = st.columns(2) + + # Add widgets to each column + fig_size_x = col1.number_input("Size of the figure in x:", value=10) + fig_size_y = col2.number_input("Size of the figure in y:", value=10) properties_to_plot = [prop for prop in properties if prop not in ['genID']] - fig, axes = plot_dist(moldrug_result.SawIndividuals, properties=properties_to_plot, every_gen=every_gen) + fig, axes = plot_dist(moldrug_result.SawIndividuals, properties=properties_to_plot, every_gen=every_gen, figsize=(fig_size_x, fig_size_y)) st.pyplot(fig) except Exception: if is_GA: