Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,12 @@ MANIFEST
data/
.vscode/
.devcontainer/
.github/
.github/
pdm.lock
.pdm-python

# Misc
my_cmds.txt
tests/
pdm.lock
launch.json
7 changes: 5 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ COPY pyproject.toml README.md pdm.lock ./
ENV PATH="/workspace/.venv/bin:$PATH"
RUN pdm install --no-self
COPY examples ./examples
COPY robot_search ./robot_search
COPY funsearch ./funsearch

RUN pip install --no-deps .
RUN llm install llm-ollama
RUN pip install dm_control

RUN pip install mujoco==3.2.4
RUN pip install dm_control==1.0.24

# if running the container
RUN rm -r ./funsearch ./build
CMD /bin/bash

# if debugging
# RUN pip install debugpy
# CMD ["python", "-Xfrozen_modules=off", "-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client", "funsearch", "run", "examples/inv_pendulum_spec.py", "0.6", "--sandbox_type", "ExternalProcessSandbox"]
# CMD ["python", "-Xfrozen_modules=off", "-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client", "funsearch", "run", "examples/inv_pendulum_spec.py", "0.6", "--sandbox_type", "ExternalProcessSandbox"]
5 changes: 5 additions & 0 deletions examples/dm_control_ballcup_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@

import numpy as np
import funsearch
import re
from dm_control import suite

# Matches a full generated `policy_vN` definition (header plus body lines).
# `\d+` accepts multi-digit versions, in step with METHOD_NAME_MATCHER, and the
# dot in `np.ndarray` is escaped so it only matches a literal dot.
METHOD_MATCHER = re.compile(r"def policy_v\d+\(.*?\) -> np\.ndarray:(?:\s*(?:[ \t]*(?!def|#|`|').*(?:\n|$)))+")
# Matches the versioned function name, e.g. `policy_v2`.
METHOD_NAME_MATCHER = re.compile(r"policy_v\d+")
# Substring whose presence means a response contains a full method header.
method_str = "def policy_v"


@funsearch.run
def solve(num_runs) -> float:
Expand Down
4 changes: 4 additions & 0 deletions examples/dm_control_swingup_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

import numpy as np
import funsearch
import re
from dm_control import suite

# Matches a full generated `policy_vN` definition (header plus body lines).
# `\d+` accepts multi-digit versions, in step with METHOD_NAME_MATCHER.
METHOD_MATCHER = re.compile(r"def policy_v\d+\(.*?\) -> float:(?:\s*(?:[ \t]*(?!def|#|`|').*(?:\n|$)))+")
# Matches the versioned function name, e.g. `policy_v2`.
METHOD_NAME_MATCHER = re.compile(r"policy_v\d+")
# Substring whose presence means a response contains a full method header.
method_str = "def policy_v"

@funsearch.run
def solve(num_runs) -> float:
Expand Down
12 changes: 9 additions & 3 deletions funsearch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
logging.basicConfig(level=LOGLEVEL)


def get_all_subclasses(cls):
all_subclasses = []

Expand Down Expand Up @@ -69,7 +68,7 @@ def main(ctx):
@click.argument('inputs')
# @click.option('--model_name', default="gpt-3.5-turbo-instruct", help='LLM model')
# @click.option('--model_name', default="deepseek-coder", help='LLM model')
@click.option('--model_name', default="starcoder2-control", help='LLM model')
@click.option('--model_name', default="starcoder-7b:latest", help='LLM model') # start with 7b
@click.option('--output_path', default="./data/", type=click.Path(file_okay=False), help='path for logs and data')
@click.option('--load_backup', default=None, type=click.File("rb"), help='Use existing program database')
@click.option('--iterations', default=-1, type=click.INT, help='Max iterations per sampler')
Expand Down Expand Up @@ -111,7 +110,11 @@ def run(spec_file, inputs, model_name, output_path, load_backup, iterations, sam
# model.key = model.get_key()
lm = sampler.LLM(2, model, log_path)

specification = spec_file.read()
specification = spec_file.read()
method_str = code_manipulation.extract_variable_value(specification, "method_str")
method_matcher = code_manipulation.extract_variable_value(specification, "METHOD_MATCHER")
method_name_matcher = code_manipulation.extract_variable_value(specification, "METHOD_NAME_MATCHER")

function_to_evolve, function_to_run = core._extract_function_names(specification)
template = code_manipulation.text_to_program(specification)

Expand All @@ -131,6 +134,9 @@ def run(spec_file, inputs, model_name, output_path, load_backup, iterations, sam
function_to_evolve,
function_to_run,
inputs,
method_str,
method_matcher,
method_name_matcher
) for _ in range(conf.num_evaluators)]

# We send the initial implementation to be analysed by one of the evaluators.
Expand Down
31 changes: 31 additions & 0 deletions funsearch/code_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import dataclasses
import io
import tokenize
import re

from absl import logging

Expand Down Expand Up @@ -251,3 +252,33 @@ def yield_decorated(code: str, module: str, name: str) -> Iterator[str]:
and attribute.value.id == module
and attribute.attr == name):
yield node.name

def _regex_flags_from_node(flags_node):
    """Statically resolve the flags expression of a ``.compile(...)`` call.

    Supports integer literals, ``<module>.<FLAG>`` attributes whose name is a
    flag of the ``re`` module, and ``|`` combinations of those.

    Returns:
        The resolved flag value, or ``None`` when the expression cannot be
        resolved without executing code.
    """
    if isinstance(flags_node, ast.Constant):
        literal = flags_node.value
        if isinstance(literal, int) and not isinstance(literal, bool):
            return literal
        return None
    if isinstance(flags_node, ast.Attribute):
        flag = getattr(re, flags_node.attr, None)
        return flag if isinstance(flag, re.RegexFlag) else None
    if isinstance(flags_node, ast.BinOp) and isinstance(flags_node.op, ast.BitOr):
        left = _regex_flags_from_node(flags_node.left)
        right = _regex_flags_from_node(flags_node.right)
        if left is None or right is None:
            return None
        return left | right
    return None


def _literal_or_compiled_regex(value_node):
    """Convert an assignment's value node to a string or a compiled regex.

    Returns ``None`` when the node is neither a string literal nor a
    ``<something>.compile("<literal pattern>", ...)`` call.
    """
    # Direct string assignment.
    if isinstance(value_node, ast.Constant):
        return value_node.value if isinstance(value_node.value, str) else None

    # Call to re.compile (any attribute call named `compile` is accepted).
    if not isinstance(value_node, ast.Call):
        return None
    func = value_node.func
    if not (isinstance(func, ast.Attribute) and func.attr == 'compile'):
        return None
    if not value_node.args:
        return None
    pattern_node = value_node.args[0]
    if not (isinstance(pattern_node, ast.Constant)
            and isinstance(pattern_node.value, str)):
        return None

    # Flags may be given positionally or as `flags=...`.
    if len(value_node.args) > 1:
        flags_node = value_node.args[1]
    else:
        flags_node = next(
            (kw.value for kw in value_node.keywords if kw.arg == 'flags'), None)

    flags = _regex_flags_from_node(flags_node) if flags_node is not None else None
    if flags is None:
        # No flags given, or flags that cannot be resolved statically:
        # compile the bare pattern.
        flags = 0
    return re.compile(pattern_node.value, flags)


def extract_variable_value(code_str, var_name):
    """Extract the value assigned to a variable from Python source using AST.

    The code is parsed, never executed. Supported assignment forms are:
      - Direct string assignments (e.g., my_var = "hello")
      - Calls to re.compile with a literal pattern
        (e.g., my_regex = re.compile(r"...", re.MULTILINE))
    Both plain (`x = ...`) and annotated (`x: T = ...`) assignments are
    recognised.

    Args:
        code_str: Python source code to inspect.
        var_name: Name of the variable to look up.

    Returns:
        The string literal, or the compiled regex (including any flags that
        can be resolved statically) for a re.compile assignment. Returns None
        if no supported assignment to `var_name` is found.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
        re.error: If the extracted pattern is not a valid regular expression.
    """
    tree = ast.parse(code_str)

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            targets = [node.target]
        else:
            continue

        if not any(isinstance(target, ast.Name) and target.id == var_name
                   for target in targets):
            continue

        value = _literal_or_compiled_regex(node.value)
        # An unsupported value is skipped so a later assignment can match;
        # an empty string is a valid result, hence the explicit None check.
        if value is not None:
            return value
    return None
30 changes: 20 additions & 10 deletions funsearch/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import re
from collections.abc import Sequence
import copy
from typing import Any, Tuple
from typing import Any, Tuple, Pattern

from funsearch import code_manipulation
from funsearch import programs_database
Expand All @@ -32,9 +32,9 @@
"""

# use this for pendulum swingup
METHOD_MATCHER = re.compile(r"def policy_v\d\(.*?\) -> float:(?:\s*(?:[ \t]*(?!def|#|`|').*(?:\n|$)))+")
METHOD_NAME_MATCHER = re.compile(r"policy_v\d+")
method_str = "def policy_v"
# METHOD_MATCHER = re.compile(r"def policy_v\d\(.*?\) -> float:(?:\s*(?:[ \t]*(?!def|#|`|').*(?:\n|$)))+")
# METHOD_NAME_MATCHER = re.compile(r"policy_v\d+")
# method_str = "def policy_v"

# use this for ball in cup
# METHOD_MATCHER = re.compile(r"def policy_v\d\(.*?\) -> np.ndarray:(?:\s*(?:[ \t]*(?!def|#|`|').*(?:\n|$)))+")
Expand All @@ -47,6 +47,7 @@ class _FunctionLineVisitor(ast.NodeVisitor):
def __init__(self, target_function_name: str) -> None:
    """Stores the function name to look for; no end line is recorded yet."""
    self._target_function_name: str = target_function_name
    # Stays None until an end line has been recorded for the target function.
    self._function_end_line: int | None = None


def visit_FunctionDef(self, node: Any) -> None: # pylint: disable=invalid-name
"""Collects the end line number of the target function."""
Expand All @@ -61,7 +62,7 @@ def function_end_line(self) -> int:
return self._function_end_line


def _find_method_implementation(generated_code: str) -> Tuple[str, str]:
def _find_method_implementation(generated_code: str, METHOD_MATCHER: Pattern, METHOD_NAME_MATCHER: Pattern) -> Tuple[str, str]:
"""Find the last method specified in METHOD_MATCHER within generated code.

Return the code and the name of the method.
Expand All @@ -74,7 +75,7 @@ def _find_method_implementation(generated_code: str) -> Tuple[str, str]:
return last_match, name


def _trim_function_body(generated_code: str) -> str:
def _trim_function_body(generated_code: str, method_str: str, method_matcher: Pattern, method_name_matcher: Pattern) -> str:
"""Extracts the body of the generated function, trimming anything after it."""
if not generated_code:
return ''
Expand All @@ -84,7 +85,7 @@ def _trim_function_body(generated_code: str) -> str:
method_name = "fake_function_header"
# Check is the response only a continuation for our prompt or full method implementation with header
if method_str in generated_code:
code, method_name = _find_method_implementation(generated_code)
code, method_name = _find_method_implementation(generated_code, method_matcher, method_name_matcher)
else:
code = f'def {method_name}():\n{generated_code}'

Expand All @@ -111,9 +112,12 @@ def _sample_to_program(
version_generated: int | None,
template: code_manipulation.Program,
function_to_evolve: str,
method_str: str,
method_matcher: Pattern,
method_name_matcher: Pattern
) -> tuple[code_manipulation.Function, str]:
"""Returns the compiled generated function and the full runnable program."""
body = _trim_function_body(generated_code)
body = _trim_function_body(generated_code, method_str, method_matcher, method_name_matcher)
if version_generated is not None:
body = code_manipulation.rename_function_calls(
body,
Expand Down Expand Up @@ -150,7 +154,10 @@ def __init__(
function_to_evolve: str,
function_to_run: str,
inputs: Sequence[Any],
timeout_seconds: int = 30,
method_str: str,
method_matcher: Pattern,
method_name_matcher: Pattern,
timeout_seconds: int = 30
):
self._database = database
self._template = template
Expand All @@ -159,6 +166,9 @@ def __init__(
self._inputs = inputs
self._timeout_seconds = timeout_seconds
self._sandbox = sbox
self._method_matcher = method_matcher
self._method_str = method_str
self._method_name_matcher = method_name_matcher

def analyse(
self,
Expand All @@ -168,7 +178,7 @@ def analyse(
) -> None:
"""Compiles the sample into a program and executes it on test inputs."""
new_function, program = _sample_to_program(
sample, version_generated, self._template, self._function_to_evolve)
sample, version_generated, self._template, self._function_to_evolve, self._method_str, self._method_matcher, self._method_name_matcher)

scores_per_test = {}
for current_input in self._inputs: # runs the function on all inputs provided in the launch command
Expand Down
17 changes: 17 additions & 0 deletions models/Modelfile-7B
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Ollama Modelfile for the StarCoder2 7B instruct model (Q8_0 GGUF build).
FROM hf.co/QuantFactory/starcoder2-7b-instruct-GGUF:Q8_0
# System prompt: steer the model toward bare function completions.
SYSTEM """You are an intelligent programming assistant. Most of your requests will be to complete the given function. Try not to use any additional formatting. Any explanations should be written as comments in the completed code."""

# Prompt layout: optional system text, then "### Instruction" / "### Response" sections.
TEMPLATE """
{{ if .System }}{{ .System }}

{{ end }}{{ if .Prompt }}### Instruction
{{ .Prompt }}


{{ end }}### Response
{{ .Response }}<|endoftext|>
"""
# Sampling settings.
PARAMETER temperature 1
PARAMETER top_p 0.95
# Repetition window and response-length limit (see the Ollama PARAMETER reference).
PARAMETER repeat_last_n 15
PARAMETER num_predict 200
File renamed without changes.
22 changes: 19 additions & 3 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions robot_search/robogrammar_list_search_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""On every iteration, improve robot_v1 over the robot_vX from previous iterations.
Make only small changes. Try to make the code short.

Each function should return a list of integers. The integers should be between 0 and 19 (inclusive). Duplicates are allowed. Return only one completed function with no additional text, explanation, or formatting.

An example of a good design is: [0, 7, 1, 13, 1, 2, 16, 12, 13, 6, 4, 19, 4, 17, 5, 3, 2, 16, 4, 5, 18, 9, 8, 9, 9, 8]
An example of a bad design is: [0, 1, 2, 3]
"""
import funsearch
import re
from typing import List
import requests

# Ignore these 3 variables
# Matches a full generated `robot_vN` definition (header plus body lines).
# `\d+` accepts multi-digit versions, in step with METHOD_NAME_MATCHER.
METHOD_MATCHER = re.compile(r"def robot_v\d+\(.*?\) -> List\[int\]:(?:\s*(?:[ \t]*(?!def|#|`|').*(?:\n|$)))+")
# Matches the versioned function name, e.g. `robot_v2`.
METHOD_NAME_MATCHER = re.compile(r"robot_v\d+")
# Substring whose presence means a response contains a full method header.
method_str = "def robot_v"

@funsearch.run
def evaluate_robot(task="RidgedTerrainTask") -> float:
    """Scores the current robot design through the robogrammar simulation API.

    The design returned by `robot()` is posted to the local `/simulate`
    endpoint and the reported `distance_travelled` is returned. A design that
    is not a list made up entirely of ints scores 0 without being simulated.
    """
    candidate = robot()

    # Type checking: only a list containing nothing but ints is simulated.
    is_int_list = isinstance(candidate, list) and all(
        isinstance(rule, int) for rule in candidate
    )
    if not is_int_list:
        return 0

    request_body = {
        "task": task,
        "grammar_file": "data/designs/grammar_apr30.dot",
        "rule_sequence": candidate,
        "jobs": 8,
        "optim": True,
        # using multiple episodes causes FCValueEstimator to crash apparently --- keep at 1
        "episodes": 1,
        "episode_len": 30,
    }
    reply = requests.post("http://127.0.0.1:5555/simulate", json=request_body)
    return reply.json()["distance_travelled"]

@funsearch.evolve
def robot() -> List[int]:
    """Returns a list of numbers between 0 and 19 (inclusive) that represent the robot design.
    """
    design = [0]

    return design
Loading