Skip to content

Commit

Permalink
resolve sklearn update issues
Browse files Browse the repository at this point in the history
  • Loading branch information
nwoyecid committed Sep 26, 2022
1 parent d1eda62 commit 824c408
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ This takes an argument `num_class` which is default to `100`

The following function are possible with the `Recognition` class:


Name | Description
:--- | :---
update(`targets, predictions`)|takes in a (batch of) vector predictions and their corresponding groundtruth. vector size must match `num_class` in the class initialization.
Expand Down Expand Up @@ -138,7 +139,9 @@ This takes an argument `num_class` which is default to `100` and `num_tool` whic
The following function are possible with the `Detection` class:

Name | Description

:--- | :---

update(`targets, predictions, format`)|input: takes in a (batch of) list/dict predictions and their corresponding groundtruth. Each frame prediction/groundtruth can be either as a `list of list` or `list of dict`. (more details below).
video_end()|Call to make the end of one video sequence.
reset()|Reset current records. Useful during training and can be called at the begining of each epoch to avoid overlapping epoch performances.
Expand Down
8 changes: 8 additions & 0 deletions ivtmetrics/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def __init__(self, num_class=100, ignore_null=False):
self.num_class = num_class
self.ignore_null = ignore_null
self.reset_global()

def resolve_nan(self, classwise):
classwise[classwise==-0.0] = np.nan
return classwise

##%%%%%%%%%%%%%%%%%%% RESET OP #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
def reset(self):
Expand Down Expand Up @@ -114,6 +118,7 @@ def compute_AP(self, component="ivt", ignore_null=False):
with warnings.catch_warnings():
warnings.filterwarnings(action='ignore') #, message='[info] triplet classes not represented in this test sample will be reported as nan values.')
classwise = average_precision_score(targets, predicts, average=None)
classwise = self.resolve_nan(classwise)
if (ignore_null and component=="ivt"): classwise = classwise[:-6]
mean = np.nanmean(classwise)
return {"AP":classwise, "mAP":mean}
Expand Down Expand Up @@ -147,6 +152,7 @@ def compute_global_AP(self, component="ivt", ignore_null=False):
with warnings.catch_warnings():
warnings.filterwarnings(action='ignore') #, message='[info] triplet classes not represented in this test sample will be reported as nan values.')
classwise = average_precision_score(targets, predicts, average=None)
classwise = self.resolve_nan(classwise)
if (ignore_null and component=="ivt"): classwise = classwise[:-6]
mean = np.nanmean(classwise)
return {"AP":classwise, "mAP":mean}
Expand Down Expand Up @@ -180,6 +186,7 @@ def compute_video_AP(self, component="ivt", ignore_null=False):
else:
sys.exit("Function filtering {} not yet supported!".format(component))
classwise = average_precision_score(targets, predicts, average=None)
classwise = self.resolve_nan(classwise)
video_log.append( classwise.reshape([1,-1]) )
video_log = np.concatenate(video_log, axis=0)
videowise = np.nanmean(video_log, axis=0)
Expand Down Expand Up @@ -250,6 +257,7 @@ def topClass(self, k=10, component="ivt"):
else:
sys.exit("Function filtering {} not supported yet!".format(component))
classwise = average_precision_score(targets, predicts, average=None)
classwise = self.resolve_nan(classwise)
pd_idx = (-classwise).argsort()[:k]
output = {x:classwise[x] for x in pd_idx}
return output
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name='ivtmetrics',
version='0.1.0',
version='0.1.1',
packages=['ivtmetrics'],
author='Chinedu Nwoye',
author_email='nwoye.chinedu@gmail.com',
Expand All @@ -15,10 +15,10 @@
long_description = long_description,
long_description_content_type ='text/x-rst',
url='https://github.com/CAMMA-public/ivtmetrics',
download_url = 'https://github.com/CAMMA-public/ivtmetrics/archive/refs/tags/v0.1.0.tar.gz',
download_url = 'https://github.com/CAMMA-public/ivtmetrics/archive/refs/tags/v0.1.1.tar.gz',
include_package_data=True,
license='BSD 2-clause', # Chose a license from here: https://help.github.com/articles/licensing-a-repository
install_requires=['scikit-learn',
install_requires=['scikit-learn>=1.0.2',
'numpy>=1.21',
],

Expand Down

0 comments on commit 824c408

Please sign in to comment.