Skip to content

Commit 1c0df99

Browse files
authored
Merge pull request #59 from lyle-morrisBV/main
Added dictionary and binary search lookups
2 parents f4707d9 + bd5434d commit 1c0df99

File tree

1 file changed

+40
-42
lines changed

1 file changed

+40
-42
lines changed

src/epss_api/epss.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import cached_property
55
from gzip import GzipFile
66
from urllib.request import urlopen
7+
from bisect import bisect_left
78

89

910
class Score(object):
@@ -17,7 +18,21 @@ def __init__(self, cve: str, epss: str, percentile: str):
1718

1819
class EPSS(object):
1920
def __init__(self) -> None:
20-
pass
21+
22+
url = 'https://epss.cyentia.com/epss_scores-current.csv.gz'
23+
24+
with urlopen(url) as res:
25+
dec = GzipFile(fileobj=res)
26+
epss_scores_str: str = dec.read().decode("utf-8")
27+
epss_scores_list = epss_scores_str.split('\n')
28+
29+
self._download = epss_scores_list
30+
31+
scores = [row for row in csv.DictReader(self._download[1:])]
32+
33+
self._byCVE = {row['cve'] : Score(row['cve'], row['epss'], row['percentile']) for row in scores}
34+
35+
self._sortedScores = sorted(self._byCVE.values(),key=lambda x:x.percentile)
2136

2237
def scores(self) -> list[Score]:
2338
"""Get all CVE's EPSS scores (downloaded data is cached in memory)
@@ -33,9 +48,7 @@ def scores(self) -> list[Score]:
3348
Returns:
3449
list[Score]: EPSS score's csv list
3550
"""
36-
scores = [row for row in csv.DictReader(self._download[1:])]
37-
return [Score(row['cve'], row['epss'], row['percentile'])
38-
for row in scores]
51+
return list(self._sortedScores)
3952

4053
def score(self, cve_id: str) -> Score:
4154
"""Get EPSS score and percentile
@@ -49,11 +62,8 @@ def score(self, cve_id: str) -> Score:
4962
Returns:
5063
Score | None: EPSS score percentile
5164
"""
52-
rows = self._filter_by_cve_id(cve_id)
53-
if len(rows) == 1:
54-
return rows[0]
55-
else:
56-
return None
65+
66+
return self._byCVE.get(cve_id,None)
5767

5868
def epss(self, cve_id: str) -> float:
5969
"""Get EPSS score
@@ -64,11 +74,12 @@ def epss(self, cve_id: str) -> float:
6474
Returns:
6575
float | None: EPSS score (0.0-1.0)
6676
"""
67-
rows = self._filter_by_cve_id(cve_id)
68-
if len(rows) == 1:
69-
return rows[0].epss
70-
else:
77+
78+
score = self._byCVE.get(cve_id,None)
79+
if score is None:
7180
return None
81+
else:
82+
return score.epss
7283

7384
def percentile(self, cve_id: str) -> float:
7485
"""Get EPSS percentile
@@ -79,11 +90,11 @@ def percentile(self, cve_id: str) -> float:
7990
Returns:
8091
float | None: EPSS percentile (0.0-1.0)
8192
"""
82-
rows = self._filter_by_cve_id(cve_id)
83-
if len(rows) == 1:
84-
return rows[0].percentile
85-
else:
93+
score = self._byCVE.get(cve_id,None)
94+
if score is None:
8695
return None
96+
else:
97+
return score.percentile
8798

8899
def epss_gt(self, max: float) -> list[Score]:
89100
"""Get CVEs with EPSS score greater or equal than the parameter
@@ -94,8 +105,9 @@ def epss_gt(self, max: float) -> list[Score]:
94105
Returns:
95106
list[Score] | None: EPSS score object list
96107
"""
97-
rows = [r for r in filter(lambda x: x.epss >= max, self.scores())]
98-
return rows
108+
i = bisect_left(self._sortedScores,min,key=lambda x:x.epss)
109+
110+
return list(self._sortedScores[i:])
99111

100112
def percentile_gt(self, max: float) -> list[Score]:
101113
"""Get CVEs with percentile greater or equal than the parameter
@@ -106,9 +118,9 @@ def percentile_gt(self, max: float) -> list[Score]:
106118
Returns:
107119
list[Score] | None: EPSS score object list
108120
"""
109-
rows = [r for r in
110-
filter(lambda x: x.percentile >= max, self.scores())]
111-
return rows
121+
i = bisect_left(self._sortedScores,min,key=lambda x:x.percentile)
122+
123+
return list(self._sortedScores[i:])
112124

113125
def epss_lt(self, min: float) -> list[Score]:
114126
"""Get CVEs with EPSS score lower or equal than the parameter
@@ -119,8 +131,9 @@ def epss_lt(self, min: float) -> list[Score]:
119131
Returns:
120132
list[Score] | None: EPSS score object list
121133
"""
122-
rows = [r for r in filter(lambda x: x.epss <= min, self.scores())]
123-
return rows
134+
i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.epss)
135+
136+
return list(self._sortedScores[:len(self.sortedScores)-i])
124137

125138
def percentile_lt(self, min: float) -> list[Score]:
126139
"""Get CVEs with percentile lower or equal than the parameter
@@ -131,9 +144,9 @@ def percentile_lt(self, min: float) -> list[Score]:
131144
Returns:
132145
list[Score] | None: EPSS score object list
133146
"""
134-
rows = [r for r in
135-
filter(lambda x: x.percentile <= min, self.scores())]
136-
return rows
147+
i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.percentile)
148+
149+
return list(self._sortedScores[:len(self.sortedScores)-i])
137150

138151
def csv(self) -> list[str]:
139152
"""Get csv data containing all epss scores.
@@ -150,18 +163,3 @@ def csv(self) -> list[str]:
150163
"""
151164
return self._download
152165

153-
def _filter_by_cve_id(self, cve_id: str):
154-
cve_filter = filter(lambda x: x.cve == cve_id, self.scores())
155-
rows = [row for row in cve_filter]
156-
return rows
157-
158-
@cached_property
159-
def _download(self):
160-
url = 'https://epss.cyentia.com/epss_scores-current.csv.gz'
161-
162-
with urlopen(url) as res:
163-
dec = GzipFile(fileobj=res)
164-
epss_scores_str: str = dec.read().decode("utf-8")
165-
epss_scores_list = epss_scores_str.split('\n')
166-
167-
return epss_scores_list

0 commit comments

Comments
 (0)