Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add wrapper for sklearn #4

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,26 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [

{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"cwd": "${fileDirname}",
"justMyCode": true,
"terminal.integrated.inheritEnv": true
},

{
"name": "Python: Debug Unit Tests",
"type": "python",
"program": "${file}",
"request": "launch",
"purpose": ["debug-test"],
"console": "integratedTerminal",
"justMyCode": false,
}

]
}
2 changes: 2 additions & 0 deletions batterylearn/elements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LinearMechanical,
LinearTimeInvariant,
Mechanical,
SkEstimator,
)

__author__ = "Xiaojun Li tonylee2016@gmail.com"
Expand All @@ -16,4 +17,5 @@
"Dynamical",
"Base",
"Container",
"SkEstimator",
]
4 changes: 4 additions & 0 deletions batterylearn/elements/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
from abc import abstractmethod

import numpy as np
from sklearn import base

# from ahkab import circuit

class SkEstimator(base.BaseEstimator):
def __class_name(self):
return 'SkEstimator'

class Base:
def __init__(self, type, name=""):
Expand Down
4 changes: 2 additions & 2 deletions batterylearn/simulations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .learns import Learner
from .learns import Learner,LearnerSk
from .simulations import Current, Data, Simulator

__author__ = "Xiaojun Li tonylee2016@gmail.com"

__all__ = ["Simulator", "Data", "Current", "Learner"]
__all__ = ["Simulator", "Data", "Current", "Learner",'LearnerSk']
27 changes: 24 additions & 3 deletions batterylearn/simulations/learns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
minimize,
shgo,
)
from sklearn.base import BaseEstimator
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted

from .simulations import Simulator
from batterylearn.elements import SkEstimator


class Learner(Simulator, BaseEstimator):
class Learner(Simulator):
"""
a wrapper for model training
a wrapper for model training, using scipy.optimize

"""

Expand Down Expand Up @@ -171,3 +172,23 @@ def residuals(self, p0, names, config, x0, method, bounds):
print("rmse", res, len(meas_vt), len(sim_vt))
return res
return meas_vt - sim_vt

class LearnerSk(SkEstimator):
"""wrapper for model training, using sklearn

Args:
SkEstimator (_type_): _description_
"""

def __init__(self,parameter = 'demo_param'):
self.parameter = parameter

def fit(self,X,y):
X, y = self._validate_data(X, y, accept_sparse=False)
self.is_fitted_ = True
return self

def predict(self,X):
X = check_array(X, accept_sparse=False)
check_is_fitted(self, 'is_fitted_')
return np.ones(X.shape[0], dtype=np.int64)
1 change: 0 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pandas as pd
import os
from batterylearn.models import OCV, EcmCell
from batterylearn.utilities import ivp
from batterylearn.simulations import Simulator, Data, Current
import matplotlib.pyplot as plt
from scipy import optimize
Expand Down
10 changes: 10 additions & 0 deletions tests/test_new_estimator
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from batterylearn.models import OCV, EcmCell
from batterylearn.simulations import LearnerSk
import numpy as np
from sklearn.utils.estimator_checks import check_estimator
from sklearn.base import BaseEstimator


a = LearnerSk()
b = check_estimator(a)
pass