Skip to content

Commit 706e6ea

Browse files
committed
Added numba jitting
1 parent df1b359 commit 706e6ea

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ matplotlib = "*"
1414
mypy = "*"
1515
sphinx = "*"
1616
pyitlib = "*"
17+
numba = "*"
1718

1819
[dev-packages]
1920

src/epimodels/discrete/models.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
__author__ = 'fccoelho'
99

1010
import numpy as np
11-
from scipy.stats.distributions import poisson, nbinom
12-
from numpy import inf, nan, nan_to_num
13-
import sys
14-
import logging
11+
# from scipy.stats.distributions import poisson, nbinom
12+
# from numpy import inf, nan, nan_to_num
13+
# import sys
14+
# import logging
1515
from collections import OrderedDict
16-
import cython
16+
# import cython
17+
from typing import Dict, List, Iterable, Any
18+
# import numba
19+
# from numba.experimental import jitclass
1720
from epimodels import BaseModel
1821

1922
model_types = {
@@ -76,7 +79,6 @@ def run(self, *args):
7679
raise NotImplementedError
7780

7881
def __call__(self, *args, **kwargs):
79-
# args = self.get_args_from_redis()
8082
res = self.run(*args)
8183
self.traces.update(res)
8284
# return res
@@ -511,7 +513,15 @@ def model(self, inits, trange, totpop, params):
511513

512514
return {'time': tspan, 'S': S, 'I': I, 'E': E, 'R': R}
513515

514-
516+
# from numba.types import unicode_type, pyobject
517+
# spec = [
518+
# ('model_type', unicode_type),
519+
# ('state_variables', pyobject),
520+
# ('parameters', pyobject),
521+
# ('run', pyobject)
522+
# ]
523+
#
524+
# @jitclass(spec)
515525
class SIRS(DiscreteModel):
516526
def __init__(self):
517527
super().__init__()
@@ -520,11 +530,16 @@ def __init__(self):
520530
self.parameters = {'beta': r'$\beta$', 'b': 'b', 'w': 'w'}
521531
self.run = self.model
522532

523-
def model(self, inits, trange, totpop, params):
533+
534+
# @numba.jit
535+
def model(self, inits: List, trange: List, totpop: int, params: Dict) -> Dict:
524536
"""
525537
calculates the model SIRS, and return its values (no demographics)
526-
- inits = (E,I,S)
527-
- theta = infectious individuals from neighbor sites
538+
:param inits: (E,I,S)
539+
:param trange:
540+
:param totpop:
541+
:param params:
542+
:return:
528543
"""
529544
S: np.ndarray = np.zeros(trange[1] - trange[0])
530545
I: np.ndarray = np.zeros(trange[1] - trange[0])
@@ -565,7 +580,7 @@ def __init__(self):
565580

566581
self.run = self.model
567582

568-
def model(self, inits, trange, totpop, params) -> list:
583+
def model(self, inits, trange, totpop, params) -> dict:
569584
S: np.ndarray = np.zeros(trange[1] - trange[0])
570585
E: np.ndarray = np.zeros(trange[1] - trange[0])
571586
I: np.ndarray = np.zeros(trange[1] - trange[0])

tests/test_continuous_models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ def test_SIR_with_t_eval():
2222
assert len(model.traces['S']) == 500
2323
# assert len(model.traces['time']) == 50
2424

25-
26-
2725
def test_SIS():
2826
model = SIS()
2927
model([1000, 1], [0, 50], 1001, {'beta': 2, 'gamma': .1})

0 commit comments

Comments
 (0)