Skip to content

Commit

Permalink
allow specifying interpolation type when resizing images
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyu committed Jan 17, 2024
1 parent 30cae30 commit f1560fb
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
11 changes: 9 additions & 2 deletions alf/environments/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,12 @@ def __getattr__(self, name):

@alf.configurable
class FrameResize(BaseObservationWrapper):
def __init__(self, env, width=84, height=84, fields=None):
def __init__(self,
env,
width=84,
height=84,
fields=None,
interpolation=cv2.INTER_AREA):
"""Create a FrameResize instance
Args:
Expand All @@ -311,9 +316,11 @@ def __init__(self, env, width=84, height=84, fields=None):
height (int): resize height
fields (list[str]): fields to be resized, A field str is a multi-level
path denoted by "A.B.C". If None, then non-nested observation is resized
interpolation (int): cv2 interploation type
"""
self._width = width
self._height = height
self._interpolation = interpolation
super().__init__(env, fields=fields)

def transform_space(self, observation_space):
Expand All @@ -328,7 +335,7 @@ def transform_space(self, observation_space):
def transform_observation(self, observation):
obs = cv2.resize(
observation, (self._width, self._height),
interpolation=cv2.INTER_AREA)
interpolation=self._interpolation)
if len(obs.shape) != 3:
obs = obs[:, :, np.newaxis]
return obs
Expand Down
5 changes: 3 additions & 2 deletions alf/summary/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def data(self):
"""Return the image numpy array which is always RGB."""
return self._img

def resize(self, height=None, width=None):
def resize(self, height=None, width=None, interploation=cv2.INTER_NEAREST):
"""Resize the image in-place given the desired width and/or height.
Args:
Expand All @@ -108,6 +108,7 @@ def resize(self, height=None, width=None):
width (int): the desired output image width. If ``None``, this will
be scaled to keep the original aspect ratio if ``height`` is
provided.
interpolation (int): cv2 interpolation type
Returns:
Image: self after resizing
Expand All @@ -126,7 +127,7 @@ def resize(self, height=None, width=None):
dsize=(0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_LINEAR)
interpolation=interploation)
return self

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion alf/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def append_coordinate(im: torch.Tensor):
assert len(im.shape) == 4, "Image must have a shape of [B,C,H,W]!"
y = torch.arange(-1., 1., step=2. / im.shape[-2])
x = torch.arange(-1., 1., step=2. / im.shape[-1])
yy, xx = torch.meshgrid(y, x)
yy, xx = torch.meshgrid(y, x, indexing='ij')
# [H,W] -> [B,H,W]
yy = alf.utils.tensor_utils.tensor_extend_new_dim(yy, dim=0, n=im.shape[0])
xx = alf.utils.tensor_utils.tensor_extend_new_dim(xx, dim=0, n=im.shape[0])
Expand Down

0 comments on commit f1560fb

Please sign in to comment.