forked from AnupamKhare/code-Repository
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVariable selection
53 lines (47 loc) · 1.95 KB
/
Variable selection
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
######### Select important variables using VIF #############
#functions
def remove_by_vif(X, vif=5):
    """Drop columns from X, highest VIF first, until no VIF exceeds 'vif'.

    On every pass the variance inflation factor of each remaining column is
    recomputed and the column with the largest value is dropped if that value
    is greater than 'vif'. The loop stops as soon as the largest VIF is within
    the limit, or when fewer than two columns remain.

    Parameters:
        X: pandas DataFrame of predictors, excluding the target variable.
            A bare array will not work: the function relies on ``X.columns``
            and ``X.drop``.
        vif: int or float of limiting value of VIF.

    Returns:
        list of the dropped column labels, in the order they were removed.

    Note:
        This function changes X inplace; pass ``X.copy()`` to keep the
        caller's frame intact.
    """
    removed = []
    # VIF relates one column to the *other* columns, so with fewer than two
    # columns there is nothing to compare against.
    if X.shape[1] < 2:
        return removed

    # Function-local import (as in the original), placed after the guard so
    # statsmodels is only needed when there is real work to do.
    from statsmodels.stats.outliers_influence import variance_inflation_factor

    while X.shape[1] >= 2:
        vif_values = [
            variance_inflation_factor(X.values, col_idx)
            for col_idx in range(X.shape[1])
        ]
        # Largest VIF first; sort_values() places NaN entries last by default.
        ranked = pd.Series(index=X.columns, data=vif_values).sort_values(ascending=False)
        worst_col, worst_vif = ranked.index[0], ranked.iloc[0]
        # Written as `not (a > b)` rather than `a <= b` so that a NaN value
        # ends the loop instead of triggering a drop.
        if not worst_vif > vif:
            break
        X.drop(worst_col, axis=1, inplace=True)
        print(worst_col, ', VIF: ', worst_vif)
        removed.append(worst_col)
    return removed
def plot_validation_curve(scores, param_range, param_name, scoring='r2'):
    """Plot the mean train and test scores against a parameter's values.

    Parameters:
        scores: scores obtained from validation_curve() method, used as a
            pair: ``scores[0]`` holds the train scores and ``scores[1]`` the
            test scores, each indexed first by parameter value.
        param_range: list of range of parameters passed as 'param_range' in
            validation_curve() method; used for the x axis and its ticks.
        param_name: name of the parameter, shown in the plot title.
        scoring: str. For 'neg_mean_squared_error' the curve shows the square
            root of the negated mean score; for any other value (including
            the default 'r2') the mean score is plotted unchanged.

    Returns:
        None. The curve is drawn on a new 8x6 matplotlib figure.
    """
    n = len(param_range)
    # Only the first n rows are used, one per entry of param_range.
    train_score = np.array([np.mean(row) for row in scores[0][:n]])
    test_score = np.array([np.mean(row) for row in scores[1][:n]])
    if scoring == 'neg_mean_squared_error':
        # Negate and take the root so the mean squared error is plotted as
        # a root mean squared error.
        train_score = np.sqrt(-train_score)
        test_score = np.sqrt(-test_score)

    plt.figure(figsize=(8, 6))
    plt.plot(param_range, train_score, label='Train')
    plt.plot(param_range, test_score, label='Test')
    # Must be a call: assigning to plt.xticks would replace the pyplot
    # function itself and leave the ticks unchanged.
    plt.xticks(param_range)
    plt.title("Validation curve of {}".format(param_name), size=12)
    plt.legend()
# 4. Check multicollinearity
# remove_by_vif() drops columns from the frame it is given, so it is run on a
# copy; X itself only changes in the explicit drop below.
removed_by_vif = remove_by_vif(X.copy())
# NOTE(review): presumably a hard-coded result kept to skip the VIF run - confirm
# removed_by_vif = ['p', 'j', 'c']
X = X.drop(columns=removed_by_vif)