Skip to content

Commit

Permalink
Fix typo and deprecate plt.get_cmap function (#1829)
Browse files Browse the repository at this point in the history
* Fix typo

* Fix colormap error
  • Loading branch information
giswqs authored Nov 13, 2023
1 parent d8ba370 commit dd0069c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
13 changes: 9 additions & 4 deletions geemap/colormaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,13 @@ def get_palette(cmap_name=None, n_class=None, hashtag=False):
if cmap_name in ["ndvi", "ndwi", "dem", "dw", "esri_lulc"]:
colors = _palette_dict[cmap_name]
else:
cmap = plt.cm.get_cmap(cmap_name, n_class)
colors = [mpl.colors.rgb2hex(cmap(i))[1:] for i in range(cmap.N)]
cmap = mpl.colormaps[cmap_name] # Retrieve colormap
if n_class:
colors = [
mpl.colors.rgb2hex(cmap(i / (n_class - 1)))[1:] for i in range(n_class)
]
else:
colors = [mpl.colors.rgb2hex(cmap(i))[1:] for i in range(cmap.N)]
if hashtag:
colors = ["#" + i for i in colors]

Expand Down Expand Up @@ -174,7 +179,7 @@ def plot_colormap(
return_fig (bool, optional): Whether to return the figure. Defaults to False.
"""
fig, ax = plt.subplots(figsize=(width, height))
col_map = plt.get_cmap(cmap)
col_map = mpl.colormaps[cmap]

norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

Expand Down Expand Up @@ -210,7 +215,7 @@ def plot_colormaps(width=8.0, height=0.4):
gradient = np.vstack((gradient, gradient))

for ax, name in zip(axes, cmap_list):
ax.imshow(gradient, aspect="auto", cmap=plt.get_cmap(name))
ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name])
ax.set_axis_off()
pos = list(ax.get_position().bounds)
x_text = pos[0] - 0.01
Expand Down
2 changes: 1 addition & 1 deletion geemap/map_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
except ValueError as err:
raise ValueError("The provided min value must be scalar type.")

vmax = vis_params.get("max", kwargs.pop("mvax", 1))
vmax = vis_params.get("max", kwargs.pop("vmax", 1))
try:
vmax = float(vmax)
except ValueError as err:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_map_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_colorbar_min_max(self):
map_widgets.Colorbar(
vis_params={"palette": self.TEST_COLORS, "min": -1.5}, vmin=-1, vmax=2
)
self.normalize_class_mock.assert_called_with(vmin=-1.5, vmax=1)
self.normalize_class_mock.assert_called_with(vmin=-1.5, vmax=2)

def test_colorbar_invalid_min(self):
with self.assertRaisesRegex(ValueError, "min value must be scalar type"):
Expand Down Expand Up @@ -518,7 +518,7 @@ def _validate_row(self, row, name, checked, opacity):
def setUp(self):
self.fake_map = fake_map.FakeMap()
self.fake_map.layers = [
fake_map.FakeTileLayer(name='OpenStreetMap'), # Basemap
fake_map.FakeTileLayer(name="OpenStreetMap"), # Basemap
fake_map.FakeTileLayer(
name="GMaps", visible=False, opacity=0.5
), # Extra basemap
Expand Down

0 comments on commit dd0069c

Please sign in to comment.