Skip to content

Commit

Permalink
Second fix for regression
Browse files Browse the repository at this point in the history
Fix for regression where the response is numerical values
  • Loading branch information
Nabeel committed Jun 13, 2019
1 parent 3300511 commit 046a5ab
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 19 deletions.
14 changes: 2 additions & 12 deletions core/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,7 @@ def _sklearn(request, context):
# Get labels for clustering
response = model.fit_transform(load_script=False)

# Set the correct data type for the response
if is_numeric_dtype(response):
dtypes = ["num"]
else:
dtypes = ["str"]
dtypes = ["str"]

elif function in (15, 17, 28):
if function == 15:
Expand All @@ -512,13 +508,7 @@ def _sklearn(request, context):
# Provide labels for clustering
response = model.fit_transform(load_script=True)

# Set the correct data type for the response
if is_numeric_dtype(response.iloc[:,2]):
dt = "num"
else:
dt = "str"

dtypes = ["str", "str", dt]
dtypes = ["str", "str", "str"]

elif function in (18, 22):
if function == 18:
Expand Down
8 changes: 1 addition & 7 deletions core/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,13 +1530,7 @@ def _send_table_description(self, variant):
elif variant == "predict":
self.table.fields.add(name="model_name")
self.table.fields.add(name="key")

# We return numerical predictions for regression and text for classification
if self.model.estimator_type == "regressor":
dt = 1
else:
dt = 0
self.table.fields.add(name="prediction", dataType=dt)
self.table.fields.add(name="prediction")
elif variant == "expression":
self.table.fields.add(name="result")
elif variant == "best_params":
Expand Down

0 comments on commit 046a5ab

Please sign in to comment.