Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal for plugin analytics #12

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@
'tensorflow-macos; platform_system=="Darwin" and platform_machine=="arm64"',
"napari>=0.4.13",
"magicgui>=0.3.0",
"plausible-events",
],
)
112 changes: 112 additions & 0 deletions stardist_napari/_analytics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import platform

import csbdeep
import magicgui
import napari
import numpy as np
import stardist
import tensorflow
from csbdeep.utils.tf import keras_import
from stardist.models import StarDist2D, StarDist3D

from . import __version__

keras = keras_import()


try:
from plausible_events import PlausibleEvents

PE = PlausibleEvents(domain="stardist-napari")
except:
PE = None


def consent_event(consent):
if PE is None:
return
PE.event("Consent", dict(value=consent))


def launch_event():
if PE is None:
return
# TODO: check for gpus?
# gpus = len(tensorflow.config.list_physical_devices('GPU'))
props = {
"platform": platform.platform().strip(),
"python": platform.python_version(),
"stardist-napari": __version__,
}
props.update(
{
p.__name__: p.__version__
for p in (
napari,
magicgui,
tensorflow,
keras,
csbdeep,
stardist,
)
}
)
PE.event("Launch", props)


def run_event(
model,
model_selected,
models_reg,
models_reg_public,
x_shape,
axes,
perc_low,
perc_high,
norm_image,
input_scale,
n_tiles,
prob_thresh,
nms_thresh,
cnn_output,
norm_axes,
output_type,
timelapse,
):
if PE is None:
return

def _model_name():
# only disclose model names of "public" registered/pre-trained models
model_type, model_name = model_selected
if model_type in models_reg:
return (
model_name
if (model_name in models_reg_public.get(model_type, {}))
else "Custom (registered)"
)
else:
return "Custom (folder)"

def _shape_pow2(shape, axes):
return tuple(
s if a == "C" else int(2 ** np.ceil(np.log2(s))) for s, a in zip(shape, axes)
)

props = {
"model": _model_name(),
"image_shape": _shape_pow2(x_shape, axes),
"image_axes": axes,
"image_norm": (perc_low, perc_high) if norm_image else False,
"image_scale": input_scale,
"image_tiles": n_tiles,
"thresh_prob": prob_thresh,
"thresh_nms": nms_thresh,
"output_type": output_type,
"output_cnn": cnn_output,
"norm_axes": norm_axes,
}
if "T" in axes:
props["timelapse"] = timelapse
name = {StarDist2D: "Run 2D", StarDist3D: "Run 3D"}[type(model)]
PE.event(name, props)
78 changes: 75 additions & 3 deletions stardist_napari/_dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def surface_from_polys(polys):
def plugin_wrapper():
# delay imports until plugin is requested by user
# -> especially those importing tensorflow (csbdeep.models.*, stardist.models.*)
from csbdeep.models.pretrained import get_model_folder, get_registered_models
from csbdeep.models.pretrained import (
get_model_details,
get_model_folder,
get_registered_models,
)
from csbdeep.utils import (
_raise,
axes_check_and_normalize,
Expand Down Expand Up @@ -102,7 +106,7 @@ def wrapper(*args):
# -------------------------------------------------------------------------

_models, _aliases = {}, {}
models_reg = {}
models_reg, models_reg_public = {}, {}
for cls in (StarDist2D, StarDist3D):
# get available models for class
_models[cls], _aliases[cls] = get_registered_models(cls)
Expand All @@ -111,6 +115,15 @@ def wrapper(*args):
((_aliases[cls][m][0] if len(_aliases[cls][m]) > 0 else m), m)
for m in _models[cls]
]
# keys of registered (i.e. pre-trained) models that are "public",
# i.e. transmitting their names for analytics is safe (not personally identifiable)
models_reg_public[cls] = [
key
for name, key in models_reg[cls]
if get_model_details(cls, key)[2]["url"].startswith(
"https://github.com/stardist/stardist-models/releases/"
)
]

model_configs = dict()
model_threshs = dict()
Expand All @@ -136,6 +149,10 @@ def get_model(model_type, model):

# -------------------------------------------------------------------------

from ._analytics import consent_event, launch_event, run_event

# -------------------------------------------------------------------------

class Output(Enum):
Labels = "Label Image"
Polys = "Polygons / Polyhedra"
Expand Down Expand Up @@ -281,6 +298,15 @@ class TimelapseLabels(Enum):
text="Set optimized postprocessing thresholds (for selected model)",
),
defaults_button=dict(widget_type="PushButton", text="Restore Defaults"),
label_analytics=dict(
widget_type="Label",
label="<b>Analytics:</b>",
),
analytics=dict(
widget_type="CheckBox",
text="Share anonymous usage data",
value=False,
),
progress_bar=dict(label=" ", min=0, max=0, visible=False),
layout="vertical",
persist=True,
Expand Down Expand Up @@ -312,6 +338,8 @@ def plugin(
cnn_output,
set_thresholds,
defaults_button,
label_analytics,
analytics,
progress_bar: mw.ProgressBar,
) -> List[napari.types.LayerDataTuple]:

Expand Down Expand Up @@ -347,7 +375,7 @@ def plugin(
) # relevant axes present in input image
assert len(axes_norm) > 0
# always jointly normalize channels for RGB images
if ("C" in axes and image.rgb == True) and ("C" not in axes_norm):
if ("C" in axes and image.rgb is True) and ("C" not in axes_norm):
axes_norm = axes_norm + "C"
warn("jointly normalizing channels of RGB input image")
ax = axes_dict(axes)
Expand All @@ -370,6 +398,31 @@ def plugin(
# _axis = tuple(i for i in range(x.ndim) if i not in (ax['C'],))
x = normalize(x, perc_low, perc_high, axis=_axis)

if analytics:
try:
run_event(
model,
model_selected,
models_reg,
models_reg_public,
x.shape,
axes,
perc_low,
perc_high,
norm_image,
input_scale,
n_tiles,
prob_thresh,
nms_thresh,
cnn_output,
norm_axes,
output_type={t.value: t.name for t in Output}[output_type],
timelapse={t.value: t.name for t in TimelapseLabels}[timelapse_opts],
)
except Exception as e:
if DEBUG:
raise e

# TODO: progress bar (labels) often don't show up. events not processed?
if "T" in axes:
app = use_app()
Expand Down Expand Up @@ -616,6 +669,7 @@ def progress(it, **kwargs):
plugin.n_tiles.value = DEFAULTS["n_tiles"]
plugin.input_scale.value = DEFAULTS["input_scale"]
plugin.label_head.value = '<small>Star-convex object detection for 2D and 3D images.<br>If you are using this in your research please <a href="https://github.com/stardist/stardist#how-to-cite" style="color:gray;">cite us</a>.</small><br><br><tt><a href="https://stardist.net" style="color:gray;">https://stardist.net</a></tt>'
plugin.label_analytics.value = '<small>Help improve this plugin by sharing usage data.<br></small><a href="https://github.com/stardist/stardist-napari/pull/12" style="color:gray;">What is shared and why.</a>'

# make labels prettier (https://doc.qt.io/qt-5/qsizepolicy.html#Policy-enum)
for w in (plugin.label_head, plugin.label_nn, plugin.label_nms, plugin.label_adv):
Expand Down Expand Up @@ -1179,6 +1233,16 @@ def restore_defaults():
for k, v in DEFAULTS.items():
getattr(plugin, k).value = v

# analytics
@change_handler(plugin.analytics, init=True)
def _analytics_change():
widgets_inactive(plugin.label_analytics, active=(not plugin.analytics.value))
try:
consent_event(plugin.analytics.value)
except Exception as e:
if DEBUG:
raise e

# -------------------------------------------------------------------------

# allow some widgets to shrink because their size depends on user input
Expand All @@ -1188,6 +1252,7 @@ def restore_defaults():
plugin.timelapse_opts.native.setMinimumWidth(240)

plugin.label_head.native.setOpenExternalLinks(True)
plugin.label_analytics.native.setOpenExternalLinks(True)
# make reset button smaller
# plugin.defaults_button.native.setMaximumWidth(150)

Expand All @@ -1207,4 +1272,11 @@ def restore_defaults():
# for i in range(layout.count()):
# print(i, layout.itemAt(i).widget())

if plugin.analytics.value:
try:
launch_event()
except Exception as e:
if DEBUG:
raise e

return plugin