Additional Helper Functions (#15)
* updating project metadata

* fixing CI yaml

* using venv

* trying again

* redoing CI

* fixing tests

* changing some settings

* updates

* fixing build

* trying to fix this

* fixing release

* bumping version

* better code organization

* updates

* adding initial cm impl

* adding unique, dispatching pattern to Py objects, renaming ext

* rustfmt

* cm dispatched

* rustfmt

* tests and benchmarks added

* bump version

* 100% test coverage

* updating readme

* Threading enabled (#9)

* bumping version

* major refactor leveraging macros

* bumping version and updating test

* adding executed notebook

* fixing performance w/ bool

* multiclass implemented and tested ready for 1.0.0

* shifting to u32 for numpy compatibility

* bumping version

* changing to i64 for better compatibility

* Major additions to Python API to include high-level helpers
zachcoleman authored Jun 1, 2022
1 parent bc0573f commit a9b812b
Showing 11 changed files with 562 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "fast-stats"
version = "1.0.2"
version = "1.1.0"
edition = "2021"

[lib]
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@
[![License](https://img.shields.io/badge/license-Apache2.0-green)](./LICENSE)

# fast-stats
`fast-stats` is a fast and simple library for calculating basic statistics such as: precision, recall, and f1-score. The library also supports the calculation of confusion matrices. For examples, please look at the `benchmarks/` folder.
`fast-stats` is a fast and simple library for calculating basic statistics such as: precision, recall, and f1-score. The library also supports the calculation of confusion matrices. For examples, please look at the `examples/` folder.

The project was developed using the [maturin](https://maturin.rs) framework.

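For orientation, a minimal usage sketch of the high-level helpers this commit adds, mirroring the calls exercised in examples/stats.ipynb below; the array shapes and values are illustrative only.

import numpy as np
import fast_stats

# toy binary masks (bool arrays of identical shape)
y_true = np.random.randint(0, 2, (512, 512)).astype(bool)
y_pred = np.random.randint(0, 2, (512, 512)).astype(bool)

# returns a dict with "precision", "recall", and "f1-score"
print(fast_stats.binary_stats(y_true, y_pred))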
File renamed without changes.
144 changes: 144 additions & 0 deletions examples/stats.ipynb
@@ -0,0 +1,144 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import fast_stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Settings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"SIZE = (10, 512, 512)\n",
"NUM_CATS = 8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Binary Statistics"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"y_true = np.random.randint(0, 2, SIZE).astype(bool)\n",
"y_pred = np.random.randint(0, 2, SIZE).astype(bool)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': 0.49939381724124243,\n",
" 'recall': 0.4994250828781588,\n",
" 'f1-score': 0.4994094495703526}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fast_stats.binary_stats(y_true, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-class Statistics"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"y_true = np.random.randint(0, NUM_CATS, SIZE)\n",
"y_pred = np.random.randint(0, NUM_CATS, SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': array([0.1256168 , 0.12500038, 0.12486642, 0.1248673 , 0.12500914,\n",
" 0.12636344, 0.12454387, 0.12488666]),\n",
" 'recall': array([0.12568051, 0.12524121, 0.1245743 , 0.12500458, 0.12535152,\n",
" 0.12616009, 0.12444887, 0.12469365]),\n",
" 'f1-score': array([0.12564865, 0.12512068, 0.12472019, 0.1249359 , 0.1251801 ,\n",
" 0.12626169, 0.12449635, 0.12479008])}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fast_stats.stats(y_true, y_pred)"
]
}
],
"metadata": {
"interpreter": {
"hash": "a3a671d63c09fb4878d313d605bf6366336b9695c04e11736a5d015abf9b1e42"
},
"kernelspec": {
"display_name": "Python 3.9.11 ('.venv39': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
File renamed without changes.
10 changes: 8 additions & 2 deletions fast_stats/__init__.py
@@ -1,3 +1,9 @@
from .binary import binary_f1_score, binary_precision, binary_recall
from .binary import (
    binary_f1_score,
    binary_precision,
    binary_recall,
    binary_stats,
    binary_tp_fp_fn,
)
from .confusion_matrix import confusion_matrix
from .multiclass import f1_score, precision, recall
from .multiclass import f1_score, precision, recall, stats
73 changes: 72 additions & 1 deletion fast_stats/binary.py
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Union
from typing import Dict, Tuple, Union

import numpy as np

@@ -128,3 +128,74 @@ def binary_f1_score(
return 0.0

return 2 * p * r / (p + r)


def binary_tp_fp_fn(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> Tuple[int, int, int]:
    """Binary calculations for TP, FP, and FN
    Args:
        y_true (np.ndarray): array of true values (must be bool or int types)
        y_pred (np.ndarray): array of pred values (must be bool or int types)
    Returns:
        Tuple[int, int, int]: counts for TP, FP, and FN
    """
    assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
    assert all(
        [
            isinstance(y_pred, np.ndarray),
            isinstance(y_true, np.ndarray),
        ]
    ), "y_true and y_pred must be numpy arrays"

    tp, tp_fp, tp_fn = _binary_f1_score_reqs(y_true, y_pred)
    fp, fn = tp_fp - tp, tp_fn - tp
    return tp, fp, fn


def binary_stats(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    zero_division: ZeroDivision = ZeroDivision.NONE,
) -> Dict[str, Result]:
    """Binary calculations for precision, recall and f1-score
    Args:
        y_true (np.ndarray): array of true values (must be bool or int types)
        y_pred (np.ndarray): array of pred values (must be bool or int types)
        zero_division (optional | str): strategy to handle division by 0
    Returns:
        Dict[str, Result]: stats for precision, recall and f1-score
    """
    assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
    assert all(
        [
            isinstance(y_pred, np.ndarray),
            isinstance(y_true, np.ndarray),
        ]
    ), "y_true and y_pred must be numpy arrays"
    zero_division = ZeroDivision(zero_division)

    tp, tp_fp, tp_fn = _binary_f1_score_reqs(y_true, y_pred)
    p, r = _precision(tp, tp_fp, zero_division), _recall(tp, tp_fn, zero_division)
    stats = dict({"precision": p, "recall": r})

    # convert p and/or r to 0 if None
    if p is None:
        p = 0.0
    if r is None:
        r = 0.0

    # handle 0 cases
    if p + r == 0:
        if zero_division == ZeroDivision.NONE:
            f1 = None
        elif zero_division == ZeroDivision.ZERO:
            f1 = 0.0
    else:
        f1 = 2 * p * r / (p + r)

    stats.update({"f1-score": f1})

    return stats
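A short usage sketch for the two new binary helpers above; the input arrays are illustrative, 1 is assumed to mark the positive class, and zero_division is left at its default.

import numpy as np
from fast_stats import binary_stats, binary_tp_fp_fn

y_true = np.array([1, 1, 0, 1, 0], dtype=np.int64)
y_pred = np.array([1, 0, 0, 1, 1], dtype=np.int64)

# raw counts: 2 true positives, 1 false positive, 1 false negative
tp, fp, fn = binary_tp_fp_fn(y_true, y_pred)

# precision = 2/3, recall = 2/3, f1-score = 2/3
print(binary_stats(y_true, y_pred))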
96 changes: 95 additions & 1 deletion fast_stats/multiclass.py
@@ -1,6 +1,6 @@
from enum import Enum
from functools import partial
from typing import List, Union
from typing import Dict, List, Union

import numpy as np

@@ -190,3 +190,97 @@ def f1_from_ext(x, y, z):
return zero_handle(f1_from_ext(x[:, 0].sum(), x[:, 1].sum(), x[:, 2].sum()))
elif average == AverageType.MACRO:
return np.nanmean(f1_from_ext(x[:, 0], x[:, 1], x[:, 2]))


def stats(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: Union[List, np.ndarray] = None,
    zero_division: ZeroDivision = ZeroDivision.NONE,
    average: AverageType = AverageType.NONE,
) -> Dict[str, Result]:
    """Multi-class calculation of precision, recall and f1-score
    Args:
        y_true (np.ndarray): array of true values (must be bool or int types)
        y_pred (np.ndarray): array of pred values (must be bool or int types)
        labels (optional | list or np.ndarray):
            labels to calculate stats for (must be bool or int types)
        zero_division (optional | str): strategy to handle division by 0
        average (optional | str): strategy for averaging across classes
    Returns:
        Dict[str, Result]: dictionary of strings to 1D array or scalar values
            depending on averaging
    """
    assert y_true.shape == y_pred.shape, "y_true and y_pred must be same shape"
    assert all(
        [
            isinstance(y_pred, np.ndarray),
            isinstance(y_true, np.ndarray),
        ]
    ), "y_true and y_pred must be numpy arrays"
    zero_division = ZeroDivision(zero_division)
    average = AverageType(average)

    if labels is None:
        labels = np.array(
            sorted(list(_unique(np.concatenate([y_true, y_pred])))), dtype=y_true.dtype
        )
    elif isinstance(labels, list):
        labels = np.array(labels, dtype=y_true.dtype)

    x = _f1_score(y_true, y_pred, labels)

    if zero_division == ZeroDivision.NONE:
        zero_handle = partial(
            np.nan_to_num, copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan
        )
    elif zero_division == ZeroDivision.ZERO:
        zero_handle = partial(
            np.nan_to_num, copy=False, nan=0.0, posinf=0.0, neginf=0.0
        )

    def f1_from_ext(x, y, z):
        p, r = x / y, x / z
        return 2 * p * r / (p + r)

    stats = dict()

    # precision
    with np.errstate(divide="ignore", invalid="ignore"):
        if average == AverageType.NONE:
            stats.update({"precision": zero_handle(x[:, 0] / x[:, 1])})
        elif average == AverageType.MICRO:
            stats.update({"precision": zero_handle(x[:, 0].sum() / x[:, 1].sum())})
        elif average == AverageType.MACRO:
            stats.update({"precision": np.nanmean(zero_handle(x[:, 0] / x[:, 1]))})

    # recall
    with np.errstate(divide="ignore", invalid="ignore"):
        if average == AverageType.NONE:
            stats.update({"recall": zero_handle(x[:, 0] / x[:, 2])})
        elif average == AverageType.MICRO:
            stats.update({"recall": zero_handle(x[:, 0].sum() / x[:, 2].sum())})
        elif average == AverageType.MACRO:
            stats.update({"recall": np.nanmean(zero_handle(x[:, 0] / x[:, 2]))})

    # f1-score
    with np.errstate(divide="ignore", invalid="ignore"):
        if average == AverageType.NONE:
            stats.update(
                {"f1-score": zero_handle(f1_from_ext(x[:, 0], x[:, 1], x[:, 2]))}
            )
        elif average == AverageType.MICRO:
            stats.update(
                {
                    "f1-score": zero_handle(
                        f1_from_ext(x[:, 0].sum(), x[:, 1].sum(), x[:, 2].sum())
                    )
                }
            )
        elif average == AverageType.MACRO:
            stats.update(
                {"f1-score": np.nanmean(f1_from_ext(x[:, 0], x[:, 1], x[:, 2]))}
            )

    return stats
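Likewise, a sketch of the new multi-class stats helper; the labels and sizes are illustrative. With the default per-class averaging it returns one array entry per label, as the notebook output above shows.

import numpy as np
from fast_stats import stats

y_true = np.random.randint(0, 3, (100,), dtype=np.int64)
y_pred = np.random.randint(0, 3, (100,), dtype=np.int64)

# per-class precision/recall/f1-score arrays, one entry per label
result = stats(y_true, y_pred)

# restrict the calculation to an explicit subset of labels
subset = stats(y_true, y_pred, labels=[0, 2])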
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "fast-stats"
version = "1.0.2"
version = "1.1.0"
description = "A fast and simple library for calculating basic statistics"
readme = "README.md"
license = {text="Apache 2.0"}
@@ -29,6 +29,7 @@ repository = "https://github.com/zachcoleman/fast-stats"

[project.optional-dependencies]
test = [
"dictdiffer",
"pytest",
"pytest-cov[all]"
]