Skip to content

Commit

Permalink
structural constants & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ohjeah committed Jul 31, 2017
1 parent 2f4846b commit dcf1ac9
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
27 changes: 25 additions & 2 deletions cartesian/cgp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import copy
import sys
import re
from operator import attrgetter
from collections import namedtuple

Expand All @@ -22,6 +23,7 @@ class Terminal(Primitive):
def __init__(self, name):
self.name = name


class Constant(Terminal):
pass

Expand All @@ -30,6 +32,21 @@ class Ephemeral(Primitive):
def __init__(self, name, function):
super().__init__(name, function, 0)


class Structual(Primitive):
def __init__(self, name, function, arity):
self.name = name
self._function = function
self.arity = arity

def function(self, *args):
return self._function(*map(self.get_len, args))

@staticmethod
def get_len(expr, tokens=("(,")):
regex = "|".join("\\{}".format(t) for t in tokens)
return len(re.split(regex, expr))

# class PrimitiveSet:
# def __init__(self, primitives):
# self.operators = [p for p in primitives if p.arity > 0]
Expand Down Expand Up @@ -111,6 +128,9 @@ def __copy__(self):
def clone(self):
return copy.copy(self)

def format(self, x):
return "{}".format(x)

def fit(self, x, y=None, **fit_params):
self._transform = compile(self)
self.fit_params = fit_params
Expand Down Expand Up @@ -184,17 +204,20 @@ def h(g):
gene = make_it(c[g])
primitive = primitives[next(gene)]

# refactor to primitive.format() ? side-effects?
if primitive.arity == 0:
if isinstance(primitive, Terminal):
used_arguments.add(primitive)

elif isinstance(primitive, Ephemeral):
if g not in c.memory:
c.memory[g] = "{0:.2f}".format(primitive.function())
c.memory[g] = c.format(primitive.function())
return c.memory[g]

return primitive.name

elif isinstance(primitive, Structual):
return c.format(primitive.function(*[h(a) for a, _ in zip(gene, range(primitive.arity))]))

else:
return "{}({})".format(primitive.name,
", ".join(h(a) for a, _ in zip(gene, range(primitive.arity))))
Expand Down
32 changes: 32 additions & 0 deletions examples/structural_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from sklearn.utils.validation import check_random_state

from cartesian.algorithm import oneplus
from cartesian.cgp import *

rng = check_random_state(1337)

primitives = [
Primitive("add", np.add, 2),
Primitive("mul", np.multiply, 2),
Terminal("x_0"),
Terminal("x_1"),
Structual("SC", (lambda x, y: min(x, y)/max(x, y)), 2),
]

pset = create_pset(primitives)

x = rng.normal(size=(100, 2))
y = x[:, 1] * x[:, 0] + 0.3
y += 0.05 * rng.normal(size=y.shape)


def func(individual):
f = compile(individual)
yhat = f(*x.T)
return np.sqrt(np.mean((y - yhat)**2))/(y.max() - y.min())


MyCartesian = Cartesian("MyCartesian", pset, n_rows=3, n_columns=4, n_out=1, n_back=1)
res = oneplus(func, cls=MyCartesian, f_tol=0.01, random_state=rng, max_nfev=50000, n_jobs=1)
print(res)
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ def individual(request):
code = [[[3, 1]]]
outputs = [3]
return MyCartesian(code, outputs)

@pytest.fixture
def sc():
s = Structual("SC", (lambda x, y: x/y), 2)
return s
12 changes: 12 additions & 0 deletions tests/test_cgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,15 @@ def test_ephemeral_constant():
assert s1 != s2
ind3 = point_mutation(ind1)
assert not ind3.memory # empty dict


def test_structural_constant_cls(sc):
assert 0.5 == sc.function("x", "f(x)")

def test_structural_constant_to_polish(sc):
primitives = [Terminal("x_0"), sc]
pset = create_pset(primitives)

MyClass = Cartesian("MyClass", pset)
ind = MyClass([[[1, 0, 0]], [[1, 0, 0]], [[1, 0, 0]]], [2])
assert to_polish(ind, return_args=False) == ["1.0"]

0 comments on commit dcf1ac9

Please sign in to comment.