diff --git a/geemap/colormaps.py b/geemap/colormaps.py index d56efa01f5..895de93dbc 100644 --- a/geemap/colormaps.py +++ b/geemap/colormaps.py @@ -86,8 +86,13 @@ def get_palette(cmap_name=None, n_class=None, hashtag=False): if cmap_name in ["ndvi", "ndwi", "dem", "dw", "esri_lulc"]: colors = _palette_dict[cmap_name] else: - cmap = plt.cm.get_cmap(cmap_name, n_class) - colors = [mpl.colors.rgb2hex(cmap(i))[1:] for i in range(cmap.N)] + cmap = mpl.colormaps[cmap_name] # Retrieve colormap + if n_class: + colors = [ + mpl.colors.rgb2hex(cmap(i / (n_class - 1)))[1:] for i in range(n_class) + ] + else: + colors = [mpl.colors.rgb2hex(cmap(i))[1:] for i in range(cmap.N)] if hashtag: colors = ["#" + i for i in colors] @@ -174,7 +179,7 @@ def plot_colormap( return_fig (bool, optional): Whether to return the figure. Defaults to False. """ fig, ax = plt.subplots(figsize=(width, height)) - col_map = plt.get_cmap(cmap) + col_map = mpl.colormaps[cmap] norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) @@ -210,7 +215,7 @@ def plot_colormaps(width=8.0, height=0.4): gradient = np.vstack((gradient, gradient)) for ax, name in zip(axes, cmap_list): - ax.imshow(gradient, aspect="auto", cmap=plt.get_cmap(name)) + ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name]) ax.set_axis_off() pos = list(ax.get_position().bounds) x_text = pos[0] - 0.01 diff --git a/geemap/map_widgets.py b/geemap/map_widgets.py index ae3b3f292f..6e72d529c7 100644 --- a/geemap/map_widgets.py +++ b/geemap/map_widgets.py @@ -145,7 +145,7 @@ def __init__( except ValueError as err: raise ValueError("The provided min value must be scalar type.") - vmax = vis_params.get("max", kwargs.pop("mvax", 1)) + vmax = vis_params.get("max", kwargs.pop("vmax", 1)) try: vmax = float(vmax) except ValueError as err: diff --git a/tests/test_map_widgets.py b/tests/test_map_widgets.py index c93834fbb6..4613220e5a 100644 --- a/tests/test_map_widgets.py +++ b/tests/test_map_widgets.py @@ -194,7 +194,7 @@ def test_colorbar_min_max(self): map_widgets.Colorbar( vis_params={"palette": self.TEST_COLORS, "min": -1.5}, vmin=-1, vmax=2 ) - self.normalize_class_mock.assert_called_with(vmin=-1.5, vmax=1) + self.normalize_class_mock.assert_called_with(vmin=-1.5, vmax=2) def test_colorbar_invalid_min(self): with self.assertRaisesRegex(ValueError, "min value must be scalar type"): @@ -518,7 +518,7 @@ def _validate_row(self, row, name, checked, opacity): def setUp(self): self.fake_map = fake_map.FakeMap() self.fake_map.layers = [ - fake_map.FakeTileLayer(name='OpenStreetMap'), # Basemap + fake_map.FakeTileLayer(name="OpenStreetMap"), # Basemap fake_map.FakeTileLayer( name="GMaps", visible=False, opacity=0.5 ), # Extra basemap