Skip to content

Conversation

@edwinjosechittilappilly
Copy link
Contributor

This pull request includes several significant changes to the quoptuna backend, focusing on enhancing model evaluation and reporting capabilities. The most important changes include the addition of new dependencies, improvements to data handling and serialization, and the integration of advanced reporting features using LangChain.

Dependencies:

  • Added new dependencies to pyproject.toml for LangChain and related libraries.

Data Handling and Serialization:

  • Introduced save_state and load_state methods in prepare.py and xai.py for class state serialization using pickle. [1] [2]

Model Evaluation Enhancements:

  • Added properties and methods in xai.py for handling test data and predictions, including x_test, y_test, predictions, and predictions_proba.
  • Enhanced get_plot method to support saving plots and returning base64-encoded images. [1] [2]
  • Added various methods for generating model evaluation metrics and plots, such as confusion matrix, classification report, ROC curve, and more.

Reporting with LangChain:

  • Implemented a comprehensive report generation method using LangChain, which includes detailed analysis and visualizations of model performance.

These changes collectively improve the functionality and usability of the quoptuna backend, providing more robust tools for model evaluation and reporting.

@dosubot dosubot bot added the size:L This PR changes 100-499 lines, ignoring generated files. label Feb 14, 2025
@dosubot dosubot bot added dependencies Pull requests that update a dependency file enhancement New feature or request labels Feb 14, 2025
@codecov
Copy link

codecov bot commented Feb 14, 2025

❌ 5 Tests Failed:

Tests completed Failed Passed Skipped
37 5 32 0
View the top 3 failed test(s) by shortest run time
tests/test_xai.py::test_get_heatmap_plot
Stack Traces | 0.539s run time
trained_model_sample = MLPClassifier()
load_data = {'x_train':         Age  Workclass  Education-Num  ...  Capital Loss  Hours per week  Country
22278  27.0          4  ...,
       1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0,
       1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0])}

    def test_get_heatmap_plot(trained_model_sample, load_data):
        # Test the get_heatmap_plot method
        xai = XAI(model=trained_model_sample, data=load_data, onsubset=True, subset_size=5)
        heatmap_plot = xai.get_heatmap_plot(class_index=0)
>       assert isinstance(heatmap_plot, plt.Figure)  # Check if a figure is returned
E       AssertionError: assert False
E        +  where False = isinstance('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAyAAAANSCAYAAACdmqqiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACiR0lEQVR4nOzdd3gU1f/28XuTkIQkJIQk9JLQe5XeQlVp0i2AICpSVEQFDCLtCwQQBekiXYoo0qS3hCYCUqTXEECkQxIgECCZ5w8e9seaDXWziZv367rmutyZM3M+s7a9OXPOmAzDMAQAAAAAduCU0gUAAAAASDsIIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAADshgACAAAAwG4IIAAAAAD...EAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaQggAAAAAExDAAEAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaQggAAAAAExDAAEAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaQggAAAAAExDAAEAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaQggAAAAAExDAAEAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaf4f+ZLc/83n+tUAAAAASUVORK5CYII=', <class 'matplotlib.figure.Figure'>)
E        +    where <class 'matplotlib.figure.Figure'> = plt.Figure

tests/test_xai.py:130: AssertionError
tests/test_xai.py::test_get_bar_plot
Stack Traces | 0.599s run time
trained_model_sample = MLPClassifier()
load_data = {'x_train':         Age  Workclass  Education-Num  ...  Capital Loss  Hours per week  Country
22278  27.0          4  ...,
       1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0,
       1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0])}

    def test_get_bar_plot(trained_model_sample, load_data):
        # Test the get_bar_plot method
        xai = XAI(model=trained_model_sample, data=load_data, onsubset=True, subset_size=5)
        bar_plot = xai.get_bar_plot(class_index=0)
>       assert isinstance(bar_plot, plt.Figure)  # Check if a figure is returned
E       AssertionError: assert False
E        +  where False = isinstance('data:image/png;base64,.../44efR4QFFxo/J+Pc954Ky99lrffeoPPq699rEYhmEIAAAAAEzglN8FAAAAACg4CCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaQggAAAAAExDAAEAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaQggAAAAAExDAAEAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBAAAAAApiGAAAAAADANAQQAAACAaQggAAAAAExDAAEAAABgGgIIAAAAANMQQAAAAACYhgACAAAAwDQEEAAAAACmIYAAAAAAMA0BBAAAAIBpCCAAAAAATEMAAQAAAGAaAggAAAAA0xBAAAAAAJiGAAIAAADANAQQAAAAAKYhgAAAAAAwDQEEAAAAgGkIIAAAAABMQwABAAAAYBoCCAAAAADTEEAAAAAAmIYAAgAAAMA0BBA...GRbrvtNn311VdasGCBbr/99goXfV+.../2VmTlzprZs2aL7779fd9xxh/z9/XX8+HG9/fbb2rFjh7p3765+/fq5HM/69et15MgRPf7445XWGTRokKZMmaKkpCTdddddLrdVkfj4eP3zn//U+PHj5e/vrwEDBjhtr4nzMXToUL344ov67W9/q+zsbDVo0EDr1q2r8Gntd911l6ZMmaIpU6aoQ4cOGjx4sEJCQnTs2DHt2LFDmZmZOnv2bJX6lpmZqaCgIHXv3r1K9QHgWkECAgA1oFevXtq+fbsSExO1fPlynThxQoGBgWrVqpXGjRundu3aSbr4jfu7776rCRMmaOnSpTp9+rQiIiK0dOlS7dq1q9oJyOjRo/XKK68oJSVF48ePv2L9hIQEpaam6qOPPtJ7772nvLw8+fj4KDw8XK+.../WPf/xDb775pry9vV1u71JxcXFq0KCB8vLyNGrUKNWrV89pe02cDz8/P2VmZmrcuHGaPn26fH19NXDgQC1fvtxxM4OyJk+erI4dO2rOnDmaNWuWTp8+rYYNGyoiIqLShyde6vTp00pLS9OYMWN4CjqA647NXO3tUgAA17Tf/e53Wr9+vfbt2+d42J508YndWVlZOnToUN0Fh6uyZMkSjRgxQgcPHlTz5s0d5aUPTPz666+rNNMFANcS1oAAwA1m6tSpOnnypBYvXlzXoaAWFBcXKzExURMnTiT5AHBd4hIsALjBNGzYUAUFBXUdBmqJt7e3jh07VtdhAIDLmAEBAAAAYBnWgAAAAACwDDMgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACxDAgIAAADAMiQgAAAAACzz/wFfTD0Li+B+sAAAAABJRU5ErkJggg==', <class 'matplotlib.figure.Figure'>)
E        +    where <class 'matplotlib.figure.Figure'> = plt.Figure

tests/test_xai.py:102: AssertionError
tests/test_xai.py::test_get_beeswarm_plot
Stack Traces | 0.67s run time
trained_model_sample = MLPClassifier()
load_data = {'x_train':         Age  Workclass  Education-Num  ...  Capital Loss  Hours per week  Country
22278  27.0          4  ...,
       1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0,
       1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0])}

    def test_get_beeswarm_plot(trained_model_sample, load_data):
        # Test the get_beeswarm_plot method
        xai = XAI(model=trained_model_sample, data=load_data, onsubset=True, subset_size=5)
        beeswarm_plot = xai.get_beeswarm_plot(class_index=0)
>       assert isinstance(beeswarm_plot, plt.Figure)  # Check if a figure is returned
E       AssertionError: assert False
E        +  where False = isinstance('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAyAAAAJ2CAYAAACn0OY2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACPy0lEQVR4nOzdeXxM1//H8fdkFUkk9iWxr7VVia1IqV1qLaWttbUXVbRa2orSzbfVUlSotarW2tJSO6WotUpbqqglUktEEmSd+/vDzzDNxBozmeT1fDzm8XDPXc7nZoT5zOecc02GYRgCAAAAADtwcXQAAAAAALIOEhAAAAAAdkMCAgAAAMBuSEAAAAAA2A0JCAAAAAC7IQEBAAAAYDckIAAAAADshgQEAAAAgN2QgAAAAACwGxIQAAAAAHZDAgIAAADAbkhAAAAAANgNCQgAAAAAuyEBAQAAAGA3JCAAAAAA7IYEBAAAAIDdkIAAAAAAsBsSEAAAAAB2QwICAAAAwG5IQAAAAADYDQkIAAAAALshAQEAAABgNyQgAAAAAOyGBAQAAACA3ZCAAAAAALAbEhAAAAAAdkMCAgAAAMBuSEAAAAAA2A0JCAAAAAC7IQEBAAAAYDckIAAAAADshgQEAAAAgN2QgAAAAACwGxIQAAAAAHZDAgIAAADAbkhAAAAAANgNCQgAAAAAuyEBAQAAAGA3JCAAAAAA7IYEBAAAAIDdkIAAAAAAsBsSEAAAAAB2QwICAAAAwG5IQAAAAADYDQkIAAAAALshAQEAAABgNyQgAAAAAOyGBAQAAACA3ZCAAAAAALAbEhAAAAAAdkMCAgAAAMBuSEAAAAAA2A0JCAAAAAC7IQEBAAAAYDckIAAAAADshgQEAAAAgN2QgAAAAACwGxIQAAAAAHZDAgIAAADAbkhAAAAAANgNCQgAAAAAuyEBAQAAAGA3JCAAAAAA7IYEBAAAAIDdkIAAAAAAsBsSEAAAAAB2QwICAAAAwG5IQAAAAADYDQkIAAAAALshAQEAAMhCQkND5ePjc9d9J0+elMlk0pIlS+7r+g96HrION0cHAAAAgIynYMGC2rFjh8qUKePoUJDJkIAAAAAgFU9PT9WqVcvRYSATYggWAAAAUrE1lCoxMVGDBg1Srly55O/vrz59+mj+/PkymUw6efKk1fn...nV1ZXMzc1JS0uLzMzMyM/Pj9LS0ujRo0ei/BKJhJydnXstT77EpnwZ3okTJ5K2trbCcrhP+ueff8jY2JgcHByENPz/cqjPqqGhgbS1tQkAbdy4UWmeK1eu0JtvvkmmpqZkZGREU6ZMoTNnzihdLrS3JUSPHDlC48aNI11dXbKxsaEVK1ZQRUVFr0uIZmZmko+PDxkbG5Oenh5JJBIKCQmhvXv39uu85EvX5uTkiNKftgyvsiVFJRIJTZkyRSFdviRtTU2NkCZfxrS6upqCgoLI2NiYjIyMKCgoiKqqqhTK+OGHH8jR0ZH09PTI2tqaIiMjqampSWGpVbns7Gzy9vYmIyMjMjAwoDFjxlBsbKxoOdsjR46Qm5sb6enpEQClx97TzZs3acGCBWRpaUk6OjpkZ2dHq1evFi1b29s599VOPcmX4X1y6Vu53s67t+9Ubm4ueXl5kYGBARkYGJCXlxfl5eUprXfHjh3k4OBAurq6NGrUKNq6dauwXHPPY5FKpbRhwwZ69dVXSV9fn4yMjGjs2LG0ePFiKikpEfKpuuxxf9uZiKikpIS8vb1JT0+PzM3NKTIyklpaWpS20YULF8jb25sMDAwIgLCU7pPL52ZnZ5OLiwvp6uqSra0tffbZZ9Td3a1QryrX6dO+a6dOnSIAdPjw4X61DWOMPQ8aRAOcmckYY89JYGAgpFIpzp49q5b6/Pz8UFtbi9raWrXUx9jT1NbWws7ODuvXrxf1vqlDSEgIbt26hYsXL/5nFk9gjL38eA4IY+xfl5iYiPPnzw/o2Q2MsYEpLy9Hfn4+EhMTOfhgjKkVzwFhjP3rnJ2dX/jSpYwxMTc3N4VlpBljTB24B4QxxhhjjDGmNjwHhDHGGGOMMaY23APCGGOMMcYYUxsOQBhjjDHGGGNqwwEIY4wxxhhjTG04AGGMMcYYY4ypDQcgjDHGGGOMMbXhAIQxxhhjjDGmNhyAMMYYY4wxxtSGAxDGGGOMMcaY2nAAwhhjjDHGGFMbDkAYY4wxxhhjasMBCGOMMcYYY0xtOABhjDHGGGOMqQ0HIIwxxhhjjDG14QCEMcYYY4wxpjYcgDDGGGOMMcbUhgMQxhhjjDHGmNpwAMIYY4wxxhhTGw5AGGOMMcYYY2rDAQhjjDHGGGNMbTgAYYwxxhhjjKnN/wHa2WS+/MtcdQAAAABJRU5ErkJggg==', <class 'matplotlib.figure.Figure'>)
E        +    where <class 'matplotlib.figure.Figure'> = plt.Figure

tests/test_xai.py:109: AssertionError

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR Overview

This PR introduces optimizer improvements and backend refactorings to enhance model evaluation and reporting capabilities through added serialization support, refined plotting, and LangChain‐based report generation. Key changes include:

  • Addition of pickle-based save_state/load_state methods in both prepare.py and xai.py.
  • Enhancements to the get_plot method and new properties for managing test data, predictions, and various evaluation metrics in xai.py.
  • Integration of LangChain for comprehensive model evaluation report generation along with dependency updates in pyproject.toml.

Reviewed Changes

File Description
src/quoptuna/backend/utils/data_utils/prepare.py Added serialization methods for class state persistence using pickle.
src/quoptuna/backend/xai/xai.py Enhanced properties, plotting functions (returning base64 images), added evaluation metrics, and LangChain-based report generation.
pyproject.toml Introduced new dependencies for LangChain integration.

Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.

Comment on lines +263 to 265
return values
self._handle_plot_error(plot_type, e)

Copy link

Copilot AI Mar 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'return values' statement in the exception block prevents the subsequent error handling call from executing. Consider removing or repositioning it to ensure errors are handled appropriately.

Suggested change
return values
self._handle_plot_error(plot_type, e)
self._handle_plot_error(plot_type, e)
return values

Copilot uses AI. Check for mistakes.
Comment on lines +348 to +351
def get_classification_report(self):
"""Get the classification report of the model."""
return classification_report(self.y_test, self.predictions)

Copy link

Copilot AI Mar 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate definition of 'get_classification_report' detected. Remove one of the definitions to avoid unintended overrides.

Suggested change
def get_classification_report(self):
"""Get the classification report of the model."""
return classification_report(self.y_test, self.predictions)

Copilot uses AI. Check for mistakes.
"""Get the f1 score of the model."""
return f1_score(self.y_test, self.predictions)

def get_average_precision_score(self):
Copy link

Copilot AI Mar 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate definition of 'get_average_precision_score' found. Consolidate the implementations to maintain consistent behavior.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dependencies Pull requests that update a dependency file enhancement New feature or request size:L This PR changes 100-499 lines, ignoring generated files.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants