Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: new input format as numpy array representing the image #106

Merged
9 commits merged on Sep 10, 2024
79 changes: 61 additions & 18 deletions DECIMER/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PIL import Image
from PIL import ImageEnhance
from pillow_heif import register_heif_opener
from typing import Union

import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder
import DECIMER.Transformer_decoder as Transformer_decoder
Expand Down Expand Up @@ -95,26 +96,68 @@ def HEIF_to_pillow(image_path: str):
return heif_file


def remove_transparent(image_path: str):
def remove_transparent(image: Union[str, np.ndarray]) -> Image.Image:
"""
Removes the transparent layer from a PNG image with an alpha channel
Args: image_path (str): path of input image
Returns: PIL.Image
Removes the transparent layer from a PNG image with an alpha channel.

Args:
image (Union[str, np.ndarray]): Path of the input image or a numpy array representing the image.

Returns:
PIL.Image.Image: The image with transparency removed.
"""
try:
png = Image.open(image_path).convert("RGBA")
except Exception as e:
if type(e).__name__ == "UnidentifiedImageError":
png = HEIF_to_pillow(image_path)
else:
print(e)
raise Exception
def process_image(png: Image.Image) -> Image.Image:
"""
Helper function to remove transparency from a single image.

Args:
png (PIL.Image.Image): The input PIL image with transparency.

Returns:
PIL.Image.Image: The image with transparency removed.
"""
background = Image.new("RGBA", png.size, (255, 255, 255))
alpha_composite = Image.alpha_composite(background, png)
return alpha_composite

background = Image.new("RGBA", png.size, (255, 255, 255))
def handle_image_path(image_path: str) -> Image.Image:
"""
Helper function to handle image paths.

Args:
image_path (str): The path to the input image.

Returns:
PIL.Image.Image: The image with transparency removed.
"""
try:
png = Image.open(image_path).convert("RGBA")
except Exception as e:
if type(e).__name__ == "UnidentifiedImageError":
png = HEIF_to_pillow(image_path)
else:
print(e)
raise Exception
return process_image(png)

def handle_numpy_array(array: np.ndarray) -> Image.Image:
"""
Helper function to handle a numpy array.

Args:
array (np.ndarray): The numpy array representing the image.

Returns:
PIL.Image.Image: The image with transparency removed.
"""
png = Image.fromarray(array).convert("RGBA")
return process_image(png)

alpha_composite = Image.alpha_composite(background, png)
# Check if input is a numpy array
if isinstance(image, np.ndarray):
return handle_numpy_array(array=image)

return alpha_composite
return handle_image_path(image_path=image)


def get_bnw_image(image):
Expand Down Expand Up @@ -185,12 +228,12 @@ def increase_brightness(image):
return image


def decode_image(image_path: str):
def decode_image(image_path: Union[str, np.ndarray]):
"""Loads an image and preprocesses the input image in several steps to get
the image ready for DECIMER input.

Args:
image_path (str): path of input image
image_path (Union[str, np.ndarray]): path of input image or numpy array representing the image.

Returns:
Processed image
Expand Down Expand Up @@ -237,7 +280,7 @@ def initialize_encoder_config(
backbone_fn (method): Calls Efficient-Net V2 as backbone for encoder
image_shape (int): Shape of the input image
do_permute (bool, optional): . Defaults to False.
pretrained_weights (keras weights, optional): Use pretrainined efficient net weights or not. Defaults to None.
pretrained_weights (keras weights, optional): Use pretrained efficient net weights or not. Defaults to None.
"""
self.encoder_config = dict(
image_embedding_dim=image_embedding_dim,
Expand Down
13 changes: 7 additions & 6 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List
from typing import Tuple

import numpy as np
import pystow
import tensorflow as tf

Expand Down Expand Up @@ -122,19 +123,19 @@ def detokenize_output_add_confidence(


def predict_SMILES(
image_path: str, confidence: bool = False, hand_drawn: bool = False
image_input: [str, np.ndarray], confidence: bool = False, hand_drawn: bool = False
) -> str:
"""Predicts SMILES representation of a molecule depicted in the given image.

Args:
image_path (str): Path of chemical structure depiction image
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn
image_input (str or np.ndarray): Path of chemical structure depiction image or a numpy array representing the image.
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction.
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn.

Returns:
str: SMILES representation of the molecule in the input image, optionally with confidence values
str: SMILES representation of the molecule in the input image, optionally with confidence values.
"""
chemical_structure = config.decode_image(image_path)
chemical_structure = config.decode_image(image_input)

model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))
Expand Down
Loading
Loading