Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sklearn plot saver #467

Merged
merged 19 commits into from
Oct 23, 2023
Merged

add sklearn plot saver #467

merged 19 commits into from
Oct 23, 2023

Conversation

VPraharsha03
Copy link
Contributor

@VPraharsha03 VPraharsha03 commented Oct 17, 2023

Addresses issue #456, still WIP

sklearn plot types reference:
https://scikit-learn.org/stable/modules/classes.html#id7

I'm taking the help of matplotlib to save the plot as a .PNG

I'm working on tests and would soon upload them

List of display classes to be supported:

  • 'CalibrationDisplay', <class 'sklearn.calibration.CalibrationDisplay'>
  • 'ConfusionMatrixDisplay', <class 'sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay'>
  • 'DecisionBoundaryDisplay', <class 'sklearn.inspection._plot.decision_boundary.DecisionBoundaryDisplay'>
  • 'DetCurveDisplay', <class 'sklearn.metrics._plot.det_curve.DetCurveDisplay'>
  • 'LearningCurveDisplay', <class 'sklearn.model_selection._plot.LearningCurveDisplay'>
  • 'PartialDependenceDisplay', <class 'sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay'>
  • 'PrecisionRecallDisplay', <class 'sklearn.metrics._plot.precision_recall_curve.PrecisionRecallDisplay'>
  • 'PredictionErrorDisplay', <class 'sklearn.metrics._plot.regression.PredictionErrorDisplay'>
  • 'RocCurveDisplay', <class 'sklearn.metrics._plot.roc_curve.RocCurveDisplay'>
  • 'ValidationCurveDisplay', <class 'sklearn.model_selection._plot.ValidationCurveDisplay'>

please do let me know about any changes required @skrawcz

@sweep-ai
Copy link
Contributor

sweep-ai bot commented Oct 17, 2023

Apply Sweep Rules to your PR?

  • Apply: Leftover TODOs in the code should be handled.
  • Apply: All new business logic should have corresponding unit tests in the tests/ directory.
  • Apply: Any clearly inefficient or repeated code should be optimized or refactored.

Copy link
Collaborator

@skrawcz skrawcz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're on the right track!

hamilton/plugins/sklearn_plot_extensions.py Outdated Show resolved Hide resolved
hamilton/plugins/sklearn_plot_extensions.py Outdated Show resolved Hide resolved
hamilton/plugins/sklearn_plot_extensions.py Outdated Show resolved Hide resolved
hamilton/plugins/sklearn_plot_extensions.py Outdated Show resolved Hide resolved
@VPraharsha03
Copy link
Contributor Author

If there are any sklearn types that need to be dropped or not needed, let me know. Is CalibrationDisplay required ? @skrawcz

@skrawcz
Copy link
Collaborator

skrawcz commented Oct 18, 2023

If there are any sklearn types that need to be dropped or not needed, let me know. Is CalibrationDisplay required ? @skrawcz

If they all have a figure_ attribute then it should be supported. So we should validate that assumption for all the listed classes.

@VPraharsha03
Copy link
Contributor Author

@skrawcz Python 3.7 build seems to be failing on being unable to import prediction error plot type

@zilto
Copy link
Collaborator

zilto commented Oct 19, 2023

@skrawcz Python 3.7 build seems to be failing on being unable to import prediction error plot type

The Display classes didn't exist in scikit-learn 0.21, the last version supporting versions Python 3.7. (0.21 docs). Meanwhile, Python 3.8 is supported by the latest scikit-learn version (1.3).

I think it would make sense to try/except for Python versions above 3.7 when importing this extension

try:
    assert sys.version_info >= (3, 7)  # needs to be a tuple
except ImportError as e:
    raise e("sklearn_plot_extension is not available for version below Python 3.8")

Otherwise, the available Display API changed a lot (e.g., PredictionErrorDisplay isn't available before 1.2). To cover the edge cases we could:

display_classes = [
  "ConfusionMatrixDisplay",
  "DetCurveDisplay",
  "PrecisionRecallDisplay",
  "PredictionErrorDisplay",
  "RocCurveDisplay",
]
SKLEARN_PLOT_TYPES = []
for class_name in display_classes:
  # get the attribute via string from sklearn.metrics; if not found return None
  if (class_ := getattr(sklearn.metrics, class_name, None):
    SKLEARN_PLOT_TYPES.append(class_)

# Union will accept a tuple for dynamic definition but not a list
SKLEARN_PLOT_TYPES_ANNOTATION = Union[tuple(SKLEARN_PLOT_TYPES)]

Finally, I think a few display functions are missing (for scikit-learn 1.3):

  • DecisionBoundaryDisplay,
  • PartialDependenceDisplay,
  • CalibrationDisplay
  • LearningCurveDisplay
  • ValidationCurverDisplay

Since 1.2, there is sklearn.utils.discovery.all_displays() that you can use!

(sorry for the long comment 😅 You are doing a great job!)

@VPraharsha03
Copy link
Contributor Author

Thanks a lot for clarifying!

Copy link
Collaborator

@skrawcz skrawcz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't checked everything -- but this is looking good. Will add more comments later.

tests/plugins/test_sklearn_plot_extensions.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@skrawcz skrawcz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's failing because of metrics.PredictionErrorDisplay in the signature.

One way to get around this would be:

if hasattr(metrics, PredictionErrorDisplay):
    PredictionErrorDisplay = metrics.PredictionErrorDisplay
else:
    PredictionErrorDisplay = Type

and then use PredictionErrorDisplay in place of metrics.PredictionErrorDisplay as a type hint.

Copy link
Collaborator

@skrawcz skrawcz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two more things:

  1. need to add a variable to the module -- see my suggestion.
  2. need to add sklearn_plot to the plugins_modules list in hamilton/function_modifiers/base.py.

Otherwise I ran the code -- it worked!

hamilton/plugins/sklearn_plot_extensions.py Show resolved Hide resolved
@skrawcz
Copy link
Collaborator

skrawcz commented Oct 21, 2023

For reference I modified an example to output a confusion plot and this is how I saved things.

to.png(
        id="cf_matrix_test",
        path="cf_matrix_test.png",
        dependencies=["cm_test_display"]
    ),

@skrawcz skrawcz closed this Oct 21, 2023
@skrawcz skrawcz reopened this Oct 21, 2023
@skrawcz
Copy link
Collaborator

skrawcz commented Oct 21, 2023

(oops sorry I hit the wrong button; didn't mean to close it)

@VPraharsha03
Copy link
Contributor Author

Two more things:

  1. need to add a variable to the module -- see my suggestion.
  2. need to add sklearn_plot to the plugins_modules list in hamilton/function_modifiers/base.py.

Otherwise I ran the code -- it worked!

Done

@elijahbenizzy
Copy link
Collaborator

Hey @VPraharsha03 -- looks like all this needs is a rebase (to fix the list) -- otherwise this is great!

I'll actually be using this in an upcoming blog post (I copied/pasted this locally, for now, but I'll swap it out when this is merged.) Will give you credit, of course! Nice work!

@skrawcz
Copy link
Collaborator

skrawcz commented Oct 23, 2023

yep just a rebase required and I'll merge. Thanks!

@VPraharsha03
Copy link
Contributor Author

@skrawcz @elijahbenizzy Done!

@skrawcz skrawcz self-requested a review October 23, 2023 04:40
@skrawcz skrawcz merged commit 2c81bc4 into DAGWorks-Inc:main Oct 23, 2023
21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants