Skip to content

Commit

Permalink
Update to ML 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed May 8, 2021
1 parent c330551 commit ec5d3b5
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 139 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ composer.lock
predictions.csv
progress.csv
report.json
*.model
*.rbx
*.old
.vscode
.vs
52 changes: 20 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ $ composer create-project rubix/housing
```

## Requirements
- [PHP](https://php.net) 7.2 or above
- [PHP](https://php.net) 7.4 or above

#### Recommended
- [Tensor extension](https://github.com/RubixML/Tensor) for faster training and inference
Expand Down Expand Up @@ -85,7 +85,7 @@ use Rubix\ML\Persisters\Filesystem;

$estimator = new PersistentModel(
new GradientBoost(new RegressionTree(4), 0.1),
new Filesystem('housing.model', true)
new Filesystem('housing.rbx', true)
);
```

Expand All @@ -110,27 +110,16 @@ $estimator->train($dataset);
```

### Validation Score and Loss
During training, the learner will record the validation score and the training loss at each iteration or *epoch*. The validation score is calculated using the default [RMSE](https://docs.rubixml.com/latest/cross-validation/metrics/rmse.html) metric on a hold out portion of the training set. Contrariwise, the training loss is the value of the cost function (in this case the L2 or *quadratic* loss) computed over the training data. We can visualize the training progress by plotting these metrics. To output the scores and losses you can call the additional `scores()` and `steps()` methods on the learner instance.
During training, the learner will record the validation score and the training loss at each iteration or *epoch*. The validation score is calculated using the default [RMSE](https://docs.rubixml.com/latest/cross-validation/metrics/rmse.html) metric on a hold out portion of the training set. Contrariwise, the training loss is the value of the cost function (in this case the L2 or *quadratic* loss) computed over the training data. We can visualize the training progress by plotting these metrics. To output the scores and losses you can call the additional `steps()` method on the learner instance. Then we can export the data to a CSV file by exporting the iterator returned by `steps()` to a CSV file.

```php
$scores = $estimator->scores();

$losses = $estimator->steps();
```

Then we can export the data to a CSV file using an [Unlabeled](https://docs.rubixml.com/latest/datasets/unlabeled.html) dataset object. The `array_transpose()` method takes a 2-dimensional array and changes the rows to columns and vice versa.

```php
use Rubix\ML\Unlabeled;
use function Rubix\ML\array_transpose;
use Rubix\ML\Extractors\CSV;

Unlabeled::build(array_transpose([$scores, $losses]))
->toCSV(['scores', 'losses'])
->write('progress.csv');
$extractor = new CSV('progress.csv', true);

$extractor->export($estimator->steps());
```


Here is an example of what the validation score and training loss look like when plotted. You can plot the values yourself by importing the `progress.csv` file into your favorite plotting software.

![R Squared Score](https://raw.githubusercontent.com/RubixML/Housing/master/docs/images/validation-score.png)
Expand Down Expand Up @@ -164,14 +153,6 @@ $dataset = Unlabeled::fromIterator(new CSV('unknown.csv', true))
->apply(new NumericStringConverter());
```

This time we will need the `Id` column from the housing dataset so we can just import the unknown samples as they are. The `Id` values, however, are useless and incompatible with our model so after we dump them to a separate array, we'll drop the column from the dataset. The Id numbers will be needed to submit to the Kaggle contest later on.

```php
$ids = $dataset->column(0);

$dataset->dropColumn(0);
```

### Load Model from Storage
Now, let's load the persisted Gradient Boost estimator into our script using the static `load()` method on the Persistent Model class by passing it a [Persister](https://docs.rubixml.com/latest/persisters/api.html) instance pointing to the model in storage.

Expand All @@ -189,15 +170,22 @@ To obtain the predictions from the model, call the `predict()` method with the d
$predictions = $estimator->predict($dataset);
```

Then we'll use another [Unlabeled](https://docs.rubixml.com/latest/datasets/unlabeled.html) dataset to write the IDs and predictions to a CSV file that we'll submit to the competition.
Then we'll use the CSV extractor to export the IDs and predictions to a file that we'll submit to the competition.

```php
use Rubix\ML\Datasets\Unlabeled;
use function Rubix\ML\array_transpose;
use Rubix\ML\Extractors\ColumnPicker;
use Rubix\ML\Extractors\CSV;

$extractor = new ColumnPicker(new CSV('dataset.csv', true), ['Id']);

$ids = array_column(iterator_to_array($extractor), 'Id');

array_unshift($ids, 'Id');
array_unshift($predictions, 'SalePrice');

$extractor = new CSV('predictions.csv');

Unlabeled::build(array_transpose([$ids, $predictions]))
->toCSV(['Id', 'SalePrice'])
->write('predictions.csv');
$extractor->export(array_transpose([$ids, $predictions]));
```

Now run the prediction script by calling it from the command line.
Expand All @@ -222,4 +210,4 @@ Have a look at the [Gradient Boost](https://docs.rubixml.com/latest/regressors/g
>- D. De Cock. (2011). Ames, Iowa: Alternative to the Boston Housing Data as an End of Semester Regression Project. Journal of Statistics Education, Volume 19, Number 3.
## License
The code is licensed [MIT](LICENSE.md) and the tutorial is licensed [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
The code is licensed [MIT](LICENSE.md) and the tutorial is licensed [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
4 changes: 2 additions & 2 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
}
],
"require": {
"php": ">=7.2",
"rubix/ml": "^0.3.0"
"php": ">=7.4",
"rubix/ml": "^1.0"
},
"scripts": {
"predict": "@php predict.php",
Expand Down
Loading

0 comments on commit ec5d3b5

Please sign in to comment.