Skip to content

Commit 9290243

Browse files
authored
Merge pull request #331 from yuxuan-z19/zyx-fix-torch
Make max_tasks_per_child configurable to fix CUDA memory leaks
2 parents dd776ce + c339f29 commit 9290243

File tree

4 files changed: +91 -144 lines changed

openevolve/cli.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -78,12 +78,11 @@ async def main_async() -> int:
7878
print(f"Error: Evaluation file '{args.evaluation_file}' not found")
7979
return 1
8080

81+
# Load base config from file or defaults
82+
config = load_config(args.config)
83+
8184
# Create config object with command-line overrides
82-
config = None
8385
if args.api_base or args.primary_model or args.secondary_model:
84-
# Load base config from file or defaults
85-
config = load_config(args.config)
86-
8786
# Apply command-line overrides
8887
if args.api_base:
8988
config.llm.api_base = args.api_base
@@ -110,7 +109,6 @@ async def main_async() -> int:
110109
initial_program_path=args.initial_program,
111110
evaluation_file=args.evaluation_file,
112111
config=config,
113-
config_path=args.config if config is None else None,
114112
output_dir=args.output,
115113
)
116114

openevolve/config.py

Lines changed: 17 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import os
6-
from dataclasses import dataclass, field
6+
from dataclasses import asdict, dataclass, field
77
from pathlib import Path
88
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
99

@@ -41,7 +41,7 @@ class LLMModelConfig:
4141

4242
# Reproducibility
4343
random_seed: Optional[int] = None
44-
44+
4545
# Reasoning parameters
4646
reasoning_effort: Optional[str] = None
4747

@@ -75,7 +75,7 @@ class LLMConfig(LLMModelConfig):
7575
primary_model_weight: float = None
7676
secondary_model: str = None
7777
secondary_model_weight: float = None
78-
78+
7979
# Reasoning parameters (inherited from LLMModelConfig but can be overridden)
8080
reasoning_effort: Optional[str] = None
8181

@@ -146,7 +146,7 @@ def rebuild_models(self) -> None:
146146
# Clear existing models lists
147147
self.models = []
148148
self.evaluator_models = []
149-
149+
150150
# Re-run model generation logic from __post_init__
151151
if self.primary_model:
152152
# Create primary model
@@ -205,6 +205,7 @@ class PromptConfig:
205205
template_variations: Dict[str, List[str]] = field(default_factory=dict)
206206

207207
# Meta-prompting
208+
# Note: meta-prompting features not implemented
208209
use_meta_prompting: bool = False
209210
meta_prompt_weight: float = 0.1
210211

@@ -254,6 +255,7 @@ class DatabaseConfig:
254255
elite_selection_ratio: float = 0.1
255256
exploration_ratio: float = 0.2
256257
exploitation_ratio: float = 0.7
258+
# Note: diversity_metric fixed to "edit_distance"
257259
diversity_metric: str = "edit_distance" # Options: "edit_distance", "feature_based"
258260

259261
# Feature map dimensions for MAP-Elites
@@ -291,6 +293,7 @@ class DatabaseConfig:
291293
embedding_model: Optional[str] = None
292294
similarity_threshold: float = 0.99
293295

296+
294297
@dataclass
295298
class EvaluatorConfig:
296299
"""Configuration for program evaluation"""
@@ -300,6 +303,7 @@ class EvaluatorConfig:
300303
max_retries: int = 3
301304

302305
# Resource limits for evaluation
306+
# Note: resource limits not implemented
303307
memory_limit_mb: Optional[int] = None
304308
cpu_limit: Optional[float] = None
305309

@@ -309,6 +313,7 @@ class EvaluatorConfig:
309313

310314
# Parallel evaluation
311315
parallel_evaluations: int = 1
316+
# Note: distributed evaluation not implemented
312317
distributed: bool = False
313318

314319
# LLM-based feedback
@@ -323,7 +328,7 @@ class EvaluatorConfig:
323328
@dataclass
324329
class EvolutionTraceConfig:
325330
"""Configuration for evolution trace logging"""
326-
331+
327332
enabled: bool = False
328333
format: str = "jsonl" # Options: "jsonl", "json", "hdf5"
329334
include_code: bool = False
@@ -362,6 +367,9 @@ class Config:
362367
convergence_threshold: float = 0.001
363368
early_stopping_metric: str = "combined_score"
364369

370+
# Parallel controller settings
371+
max_tasks_per_child: Optional[int] = None
372+
365373
@classmethod
366374
def from_yaml(cls, path: Union[str, Path]) -> "Config":
367375
"""Load configuration from a YAML file"""
@@ -377,7 +385,9 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "Config":
377385

378386
# Update top-level fields
379387
for key, value in config_dict.items():
380-
if key not in ["llm", "prompt", "database", "evaluator", "evolution_trace"] and hasattr(config, key):
388+
if key not in ["llm", "prompt", "database", "evaluator", "evolution_trace"] and hasattr(
389+
config, key
390+
):
381391
setattr(config, key, value)
382392

383393
# Update nested configs
@@ -406,87 +416,7 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "Config":
406416
return config
407417

408418
def to_dict(self) -> Dict[str, Any]:
409-
"""Convert configuration to a dictionary"""
410-
return {
411-
# General settings
412-
"max_iterations": self.max_iterations,
413-
"checkpoint_interval": self.checkpoint_interval,
414-
"log_level": self.log_level,
415-
"log_dir": self.log_dir,
416-
"random_seed": self.random_seed,
417-
# Component configurations
418-
"llm": {
419-
"models": self.llm.models,
420-
"evaluator_models": self.llm.evaluator_models,
421-
"api_base": self.llm.api_base,
422-
"temperature": self.llm.temperature,
423-
"top_p": self.llm.top_p,
424-
"max_tokens": self.llm.max_tokens,
425-
"timeout": self.llm.timeout,
426-
"retries": self.llm.retries,
427-
"retry_delay": self.llm.retry_delay,
428-
},
429-
"prompt": {
430-
"template_dir": self.prompt.template_dir,
431-
"system_message": self.prompt.system_message,
432-
"evaluator_system_message": self.prompt.evaluator_system_message,
433-
"num_top_programs": self.prompt.num_top_programs,
434-
"num_diverse_programs": self.prompt.num_diverse_programs,
435-
"use_template_stochasticity": self.prompt.use_template_stochasticity,
436-
"template_variations": self.prompt.template_variations,
437-
# Note: meta-prompting features not implemented
438-
# "use_meta_prompting": self.prompt.use_meta_prompting,
439-
# "meta_prompt_weight": self.prompt.meta_prompt_weight,
440-
},
441-
"database": {
442-
"db_path": self.database.db_path,
443-
"in_memory": self.database.in_memory,
444-
"population_size": self.database.population_size,
445-
"archive_size": self.database.archive_size,
446-
"num_islands": self.database.num_islands,
447-
"elite_selection_ratio": self.database.elite_selection_ratio,
448-
"exploration_ratio": self.database.exploration_ratio,
449-
"exploitation_ratio": self.database.exploitation_ratio,
450-
# Note: diversity_metric fixed to "edit_distance"
451-
# "diversity_metric": self.database.diversity_metric,
452-
"feature_dimensions": self.database.feature_dimensions,
453-
"feature_bins": self.database.feature_bins,
454-
"migration_interval": self.database.migration_interval,
455-
"migration_rate": self.database.migration_rate,
456-
"random_seed": self.database.random_seed,
457-
"log_prompts": self.database.log_prompts,
458-
},
459-
"evaluator": {
460-
"timeout": self.evaluator.timeout,
461-
"max_retries": self.evaluator.max_retries,
462-
# Note: resource limits not implemented
463-
# "memory_limit_mb": self.evaluator.memory_limit_mb,
464-
# "cpu_limit": self.evaluator.cpu_limit,
465-
"cascade_evaluation": self.evaluator.cascade_evaluation,
466-
"cascade_thresholds": self.evaluator.cascade_thresholds,
467-
"parallel_evaluations": self.evaluator.parallel_evaluations,
468-
# Note: distributed evaluation not implemented
469-
# "distributed": self.evaluator.distributed,
470-
"use_llm_feedback": self.evaluator.use_llm_feedback,
471-
"llm_feedback_weight": self.evaluator.llm_feedback_weight,
472-
},
473-
"evolution_trace": {
474-
"enabled": self.evolution_trace.enabled,
475-
"format": self.evolution_trace.format,
476-
"include_code": self.evolution_trace.include_code,
477-
"include_prompts": self.evolution_trace.include_prompts,
478-
"output_path": self.evolution_trace.output_path,
479-
"buffer_size": self.evolution_trace.buffer_size,
480-
"compress": self.evolution_trace.compress,
481-
},
482-
# Evolution settings
483-
"diff_based_evolution": self.diff_based_evolution,
484-
"max_code_length": self.max_code_length,
485-
# Early stopping settings
486-
"early_stopping_patience": self.early_stopping_patience,
487-
"convergence_threshold": self.convergence_threshold,
488-
"early_stopping_metric": self.early_stopping_metric,
489-
}
419+
return asdict(self)
490420

491421
def to_yaml(self, path: Union[str, Path]) -> None:
492422
"""Save configuration to a YAML file"""

openevolve/controller.py

Lines changed: 18 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -16,15 +16,10 @@
1616
from openevolve.evaluator import Evaluator
1717
from openevolve.evolution_trace import EvolutionTracer
1818
from openevolve.llm.ensemble import LLMEnsemble
19-
from openevolve.prompt.sampler import PromptSampler
2019
from openevolve.process_parallel import ProcessParallelController
21-
from openevolve.utils.code_utils import (
22-
extract_code_language,
23-
)
24-
from openevolve.utils.format_utils import (
25-
format_metrics_safe,
26-
format_improvement_safe,
27-
)
20+
from openevolve.prompt.sampler import PromptSampler
21+
from openevolve.utils.code_utils import extract_code_language
22+
from openevolve.utils.format_utils import format_improvement_safe, format_metrics_safe
2823

2924
logger = logging.getLogger(__name__)
3025

@@ -75,17 +70,11 @@ def __init__(
7570
self,
7671
initial_program_path: str,
7772
evaluation_file: str,
78-
config_path: Optional[str] = None,
79-
config: Optional[Config] = None,
73+
config: Config,
8074
output_dir: Optional[str] = None,
8175
):
82-
# Load configuration
83-
if config is not None:
84-
# Use provided Config object directly
85-
self.config = config
86-
else:
87-
# Load from file or use defaults
88-
self.config = load_config(config_path)
76+
# Load configuration (loaded in main_async)
77+
self.config = config
8978

9079
# Set up output directory
9180
self.output_dir = output_dir or os.path.join(
@@ -98,9 +87,10 @@ def __init__(
9887

9988
# Set random seed for reproducibility if specified
10089
if self.config.random_seed is not None:
90+
import hashlib
10191
import random
92+
10293
import numpy as np
103-
import hashlib
10494

10595
# Set global random seeds
10696
random.seed(self.config.random_seed)
@@ -139,7 +129,7 @@ def __init__(
139129
self.file_extension = f".{self.file_extension}"
140130

141131
# Set the file_suffix in config (can be overridden in YAML)
142-
if not hasattr(self.config, 'file_suffix') or self.config.file_suffix == ".py":
132+
if not hasattr(self.config, "file_suffix") or self.config.file_suffix == ".py":
143133
self.config.file_suffix = self.file_extension
144134

145135
# Initialize components
@@ -175,18 +165,17 @@ def __init__(
175165
if not trace_output_path:
176166
# Default to output_dir/evolution_trace.{format}
177167
trace_output_path = os.path.join(
178-
self.output_dir,
179-
f"evolution_trace.{self.config.evolution_trace.format}"
168+
self.output_dir, f"evolution_trace.{self.config.evolution_trace.format}"
180169
)
181-
170+
182171
self.evolution_tracer = EvolutionTracer(
183172
output_path=trace_output_path,
184173
format=self.config.evolution_trace.format,
185174
include_code=self.config.evolution_trace.include_code,
186175
include_prompts=self.config.evolution_trace.include_prompts,
187176
enabled=True,
188177
buffer_size=self.config.evolution_trace.buffer_size,
189-
compress=self.config.evolution_trace.compress
178+
compress=self.config.evolution_trace.compress,
190179
)
191180
logger.info(f"Evolution tracing enabled: {trace_output_path}")
192181
else:
@@ -305,8 +294,11 @@ async def run(
305294
# Initialize improved parallel processing
306295
try:
307296
self.parallel_controller = ProcessParallelController(
308-
self.config, self.evaluation_file, self.database, self.evolution_tracer,
309-
file_suffix=self.config.file_suffix
297+
self.config,
298+
self.evaluation_file,
299+
self.database,
300+
self.evolution_tracer,
301+
file_suffix=self.config.file_suffix,
310302
)
311303

312304
# Set up signal handlers for graceful shutdown
@@ -349,7 +341,7 @@ def force_exit_handler(signum, frame):
349341
if self.parallel_controller:
350342
self.parallel_controller.stop()
351343
self.parallel_controller = None
352-
344+
353345
# Close evolution tracer
354346
if self.evolution_tracer:
355347
self.evolution_tracer.close()

0 commit comments

Comments
 (0)