-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add sklearn plot saver #467
Conversation
Apply Sweep Rules to your PR?
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you're on the right track!
If there are any sklearn types that need to be dropped or not needed, let me know. Is |
If they all have a |
@skrawcz Python 3.7 build seems to be failing on being unable to import prediction error plot type |
The I think it would make sense to try/except for Python versions above 3.7 when importing this extension try:
assert sys.version_info >= (3, 7) # needs to be a tuple
except ImportError as e:
raise e("sklearn_plot_extension is not available for version below Python 3.8") Otherwise, the available display_classes = [
"ConfusionMatrixDisplay",
"DetCurveDisplay",
"PrecisionRecallDisplay",
"PredictionErrorDisplay",
"RocCurveDisplay",
]
SKLEARN_PLOT_TYPES = []
for class_name in display_classes:
# get the attribute via string from sklearn.metrics; if not found return None
if (class_ := getattr(sklearn.metrics, class_name, None):
SKLEARN_PLOT_TYPES.append(class_)
# Union will accept a tuple for dynamic definition but not a list
SKLEARN_PLOT_TYPES_ANNOTATION = Union[tuple(SKLEARN_PLOT_TYPES)] Finally, I think a few display functions are missing (for scikit-learn 1.3):
Since 1.2, there is (sorry for the long comment 😅 You are doing a great job!) |
Thanks a lot for clarifying! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't checked everything -- but this is looking good. Will add more comments later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's failing because of metrics.PredictionErrorDisplay
in the signature.
One way to get around this would be:
if hasattr(metrics, PredictionErrorDisplay):
PredictionErrorDisplay = metrics.PredictionErrorDisplay
else:
PredictionErrorDisplay = Type
and then use PredictionErrorDisplay
in place of metrics.PredictionErrorDisplay
as a type hint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two more things:
- need to add a variable to the module -- see my suggestion.
- need to add
sklearn_plot
to theplugins_modules
list in hamilton/function_modifiers/base.py.
Otherwise I ran the code -- it worked!
For reference I modified an example to output a confusion plot and this is how I saved things. to.png(
id="cf_matrix_test",
path="cf_matrix_test.png",
dependencies=["cm_test_display"]
), |
(oops sorry I hit the wrong button; didn't mean to close it) |
Done |
Hey @VPraharsha03 -- looks like all this needs is a rebase (to fix the list) -- otherwise this is great! I'll actually be using this in an upcoming blog post (I copied/pasted this locally, for now, but I'll swap it out when this is merged.) Will give you credit, of course! Nice work! |
yep just a rebase required and I'll merge. Thanks! |
@skrawcz @elijahbenizzy Done! |
Addresses issue #456, still WIP
sklearn plot types reference:
https://scikit-learn.org/stable/modules/classes.html#id7
I'm taking the help of matplotlib to save the plot as a .PNG
✅
I'm working on tests and would soon upload themList of display classes to be supported:
please do let me know about any changes required @skrawcz