Skip to content

Commit

Permalink
Aug 11
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Aug 11, 2023
1 parent d6647da commit 6460043
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2,931 deletions.
22 changes: 11 additions & 11 deletions BirdSTEM/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class AdaSTEM(BaseEstimator):
def __init__(self,base_model,
task='hurdle',
ensemble_fold=1,
min_ensemble_require = 1,
min_ensemble_required = 1,
grid_len_lon_upper_threshold=25,
grid_len_lon_lower_threshold=5,
grid_len_lat_upper_threshold=25,
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(self,base_model,
warnings.warn('You have chosen HURDLE task. The goal is to first conduct classification, and then apply regression on points with *positive values*')

self.ensemble_fold = ensemble_fold
self.min_ensemble_require = min_ensemble_require
self.min_ensemble_required = min_ensemble_required
self.grid_len_lon_upper_threshold=grid_len_lon_upper_threshold
self.grid_len_lon_lower_threshold=grid_len_lon_lower_threshold
self.grid_len_lat_upper_threshold=grid_len_lat_upper_threshold
Expand Down Expand Up @@ -160,7 +160,7 @@ def __init__(self,base_model,
# return {
# 'base_model':self.base_model,
# 'ensemble_fold':self.ensemble_fold,
# 'min_ensemble_require':self.min_ensemble_require,
# 'min_ensemble_required':self.min_ensemble_required,
# 'grid_len_lon_upper_threshold':self.grid_len_lon_upper_threshold,
# 'grid_len_lon_lower_threshold':self.grid_len_lon_lower_threshold,
# 'grid_len_lat_upper_threshold':self.grid_len_lat_upper_threshold,
Expand Down Expand Up @@ -391,10 +391,10 @@ def predict_proba(self,X_test,verbosity=0):
res_nan_count = res.isnull().sum(axis=1)
res_not_nan_count = len(round_res_list) - res_nan_count

pred_mean = np.where(res_not_nan_count.values < self.min_ensemble_require,
pred_mean = np.where(res_not_nan_count.values < self.min_ensemble_required,
np.nan,
res_mean.values)
pred_std = np.where(res_not_nan_count.values < self.min_ensemble_require,
pred_std = np.where(res_not_nan_count.values < self.min_ensemble_required,
np.nan,
res_std.values)

Expand Down Expand Up @@ -549,7 +549,7 @@ class AdaSTEMClassifier(AdaSTEM):
def __init__(self, base_model,
task='classification',
ensemble_fold=1,
min_ensemble_require=1,
min_ensemble_required=1,
grid_len_lon_upper_threshold=25,
grid_len_lon_lower_threshold=5,
grid_len_lat_upper_threshold=25,
Expand All @@ -568,7 +568,7 @@ def __init__(self, base_model,
super().__init__(base_model,
task,
ensemble_fold,
min_ensemble_require,
min_ensemble_required,
grid_len_lon_upper_threshold,
grid_len_lon_lower_threshold,
grid_len_lat_upper_threshold, grid_len_lat_lower_threshold,
Expand All @@ -592,7 +592,7 @@ class AdaSTEMRegressor(AdaSTEM):
def __init__(self, base_model,
task='regression',
ensemble_fold=1,
min_ensemble_require=1,
min_ensemble_required=1,
grid_len_lon_upper_threshold=25,
grid_len_lon_lower_threshold=5,
grid_len_lat_upper_threshold=25,
Expand All @@ -611,7 +611,7 @@ def __init__(self, base_model,
super().__init__(base_model,
task,
ensemble_fold,
min_ensemble_require,
min_ensemble_required,
grid_len_lon_upper_threshold,
grid_len_lon_lower_threshold,
grid_len_lat_upper_threshold, grid_len_lat_lower_threshold,
Expand All @@ -630,7 +630,7 @@ class AdaSTEMHurdle(AdaSTEM):
def __init__(self, base_model,
task='hurdle',
ensemble_fold=1,
min_ensemble_require=1,
min_ensemble_required=1,
grid_len_lon_upper_threshold=25,
grid_len_lon_lower_threshold=5,
grid_len_lat_upper_threshold=25,
Expand All @@ -649,7 +649,7 @@ def __init__(self, base_model,
super().__init__(base_model,
task,
ensemble_fold,
min_ensemble_require,
min_ensemble_required,
grid_len_lon_upper_threshold,
grid_len_lon_lower_threshold,
grid_len_lat_upper_threshold, grid_len_lat_lower_threshold,
Expand Down
Loading

0 comments on commit 6460043

Please sign in to comment.