Skip to content

Commit

Permalink
Rewritten Composite. Now should run more smoothly. also updated callb…
Browse files Browse the repository at this point in the history
…acks and tested that example runs
perara committed Oct 21, 2024
1 parent df55ecb commit 79cb1c9
Showing 5 changed files with 358 additions and 150 deletions.
4 changes: 2 additions & 2 deletions examples/composite/TMCompositeCIFAR10Demo.py
Original file line number Diff line number Diff line change
@@ -50,10 +50,10 @@ def main(args):

class TMCompositeCheckpointCallback(TMCompositeCallback):

def on_epoch_component_begin(self, component, epoch, logs=None):
def on_epoch_component_begin(self, component, epoch, logs=None, **kwargs):
pass

def on_epoch_component_end(self, component, epoch, logs=None):
def on_epoch_component_end(self, component, epoch, logs=None, **kwargs):
component.save(component_path / f"{component}-{epoch}.pkl")

class TMCompositeEvaluationCallback(TMCompositeCallback):
34 changes: 23 additions & 11 deletions tmu/composite/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,49 @@
from dataclasses import dataclass
from enum import auto, Enum
from multiprocessing import Queue
from typing import Any, Dict


class TMCompositeCallback:
class CallbackMethod(Enum):
ON_TRAIN_COMPOSITE_BEGIN = auto()
ON_TRAIN_COMPOSITE_END = auto()
ON_EPOCH_COMPONENT_BEGIN = auto()
ON_EPOCH_COMPONENT_END = auto()
UPDATE_PROGRESS = auto()

@dataclass
class CallbackMessage:
method: CallbackMethod
kwargs: Dict[str, Any]

class TMCompositeCallback:
def __init__(self):
pass

def on_epoch_component_begin(self, component, epoch, logs=None):
def on_epoch_component_begin(self, component, epoch, logs=None, **kwargs):
pass

def on_epoch_component_end(self, component, epoch, logs=None):
def on_epoch_component_end(self, component, epoch, logs=None, **kwargs):
pass

def on_train_composite_end(self, composite, logs=None):
def on_train_composite_end(self, composite, logs=None, **kwargs):
pass

def on_train_composite_begin(self, composite, logs=None):
def on_train_composite_begin(self, composite, logs=None, **kwargs):
pass


class TMCompositeCallbackProxy:

def __init__(self, queue: Queue):
self.queue = queue

def on_epoch_component_begin(self, component, epoch, logs=None):
self.queue.put(('on_epoch_component_begin', component, epoch, logs))
self.queue.put(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_BEGIN, {'component': component, 'epoch': epoch, 'logs': logs}))

def on_epoch_component_end(self, component, epoch, logs=None):
self.queue.put(('on_epoch_component_end', component, epoch, logs))
self.queue.put(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_END, {'component': component, 'epoch': epoch, 'logs': logs}))

def on_train_composite_end(self, composite, logs=None):
self.queue.put(('on_train_composite_end', composite, logs))
self.queue.put(CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_END, {'composite': composite, 'logs': logs}))

def on_train_composite_begin(self, composite, logs=None):
self.queue.put(('on_train_composite_begin', composite, logs))
self.queue.put(CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN, {'composite': composite, 'logs': logs}))
33 changes: 20 additions & 13 deletions tmu/composite/components/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import abc
import uuid

import numpy as np
from pathlib import Path
from typing import Union, Tuple
@@ -13,6 +15,7 @@ def __init__(self, model_cls, model_config, epochs=1, **kwargs) -> None:
self.model_cls = model_cls
self.model_config = model_config
self.epochs = epochs
self.uuid = uuid.uuid4()

# Warn about unused kwargs
if kwargs:
@@ -36,19 +39,23 @@ def preprocess(self, data: dict) -> dict:
return data

def fit(self, data: dict) -> None:
x_train, y_train = data["X"], data["Y"]

# Check if type is uint32
if x_train.dtype != np.uint32:
x_train = x_train.astype(np.uint32)

if y_train.dtype != np.uint32:
y_train = y_train.astype(np.uint32)

self.model_instance.fit(
x_train,
y_train,
)
try:
x_train, y_train = data["X"], data["Y"]

# Check if type is uint32
if x_train.dtype != np.uint32:
x_train = x_train.astype(np.uint32)

if y_train.dtype != np.uint32:
y_train = y_train.astype(np.uint32)

self.model_instance.fit(
x_train,
y_train,
)
except Exception as e:
print(f"Error: {e}")
raise e

def predict(self, data: dict) -> Tuple[np.array, np.array]:
X_test = data["X"]
47 changes: 35 additions & 12 deletions tmu/composite/components/color_thermometer_scoring.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,49 @@
import numpy as np
from typing import Dict, Any
from tmu.composite.components.base import TMComponent


class ColorThermometerComponent(TMComponent):

def __init__(self, model_cls, model_config, resolution=8, **kwargs) -> None:
super().__init__(model_cls=model_cls, model_config=model_config, **kwargs)
if resolution < 2 or resolution > 255:
raise ValueError("Resolution must be between 2 and 255")
self.resolution = resolution
self._thresholds = None

def _create_thresholds(self) -> None:
self._thresholds = np.linspace(0, 255, self.resolution + 1)[1:-1]

def preprocess(self, data: dict):
def preprocess(self, data: dict) -> Dict[str, Any]:
super().preprocess(data=data)
X_org = data.get("X")
Y = data.get("Y")

if X_org is None:
raise ValueError("Input data 'X' is missing")

if X_org.ndim != 4:
raise ValueError(f"Expected 4D input, got {X_org.ndim}D")

if X_org.shape[-1] != 3:
raise ValueError(f"Expected 3 color channels, got {X_org.shape[-1]}")

if self._thresholds is None:
self._create_thresholds()

X_org = data["X"]
Y = data["Y"]
# Use broadcasting for efficient computation
X = (X_org[:, :, :, :, np.newaxis] >= self._thresholds).astype(np.uint8)

X = np.empty((X_org.shape[0], X_org.shape[1], X_org.shape[2], X_org.shape[3], self.resolution), dtype=np.uint8)
for z in range(self.resolution):
X[:, :, :, :, z] = X_org[:, :, :, :] >= (z + 1) * 255 / (self.resolution + 1)
# Reshape correctly
batch_size, height, width, channels, _ = X.shape
X = X.transpose(0, 1, 2, 4, 3).reshape(batch_size, height, width, channels * (self.resolution - 1))

X = X.reshape((X_org.shape[0], X_org.shape[1], X_org.shape[2], 3 * self.resolution))
return {
"X": X,
"Y": Y
}

return dict(
X=X,
Y=Y,
)
def get_output_shape(self, input_shape: tuple) -> tuple:
if len(input_shape) != 4:
raise ValueError(f"Expected 4D input shape, got {len(input_shape)}D")
return (*input_shape[:-1], input_shape[-1] * (self.resolution - 1))
390 changes: 278 additions & 112 deletions tmu/composite/composite.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
import threading
from collections import defaultdict
from os import cpu_count
from typing import Optional, Type, Union, List
from pathlib import Path
from multiprocessing import Pool, Manager
import concurrent.futures
import multiprocessing
import traceback
import uuid
from functools import partial
from multiprocessing import Manager, Queue, cpu_count, Pool
import numpy as np
from tqdm import tqdm

from typing import Optional, List, Dict, Any, Tuple, Union
from dataclasses import dataclass
from pathlib import Path
import threading
from tmu.composite.callbacks.base import TMCompositeCallbackProxy, TMCompositeCallback
from tmu.composite.components.base import TMComponent
from tmu.composite.gating.base import BaseGate
from tmu.composite.gating.linear_gate import LinearGate

from tmu.composite.callbacks.base import CallbackMessage, CallbackMethod


@dataclass
class ComponentTask:
component: TMComponent
data: Any
epochs: int
progress: int = 0
result: Any = None
@property
def component_id(self) -> uuid.UUID:
return self.component.uuid

@dataclass
class FitResult:
component: TMComponent
success: bool
error: Optional[Exception] = None

class TMCompositeBase:

@@ -35,157 +55,302 @@ def _component_predict(self, component, data):
return votes


class TMCompositeMP(TMCompositeBase):
class TMCompositeMP:
def __init__(self, composite: 'TMComposite', **kwargs) -> None:
self.composite = composite
self.max_workers = min(cpu_count(), len(composite.components))
self.remove_data_after_preprocess = kwargs.get('remove_data_after_preprocess', False)
multiprocessing.set_start_method('spawn', force=True)

def _process_callbacks(self, callbacks: List[TMCompositeCallback], message: CallbackMessage) -> None:
method_name = message.method.name.lower()
for callback in callbacks:
try:
getattr(callback, method_name)(**message.kwargs, composite=self.composite)
except Exception as e:
print(f"Error in callback {callback.__class__.__name__}.{method_name}: {e}")
traceback.print_exc()

def _fit_component(self, task, callback_queue) -> FitResult:
try:
data_preprocessed = task.component.preprocess(task.data)
callbacks = []

# remove task.data
if self.remove_data_after_preprocess:
task.data = None

for epoch in range(task.epochs):

if callback_queue:
callbacks.append(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_BEGIN, {'component': task.component, 'epoch': epoch}))

task.component.fit(data=data_preprocessed)
task.progress += 1

if callback_queue:
callbacks.append(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_END, {'component': task.component, 'epoch': epoch}))

if callback_queue:
callback_queue.put(callbacks) # Send all callbacks at once
callbacks.clear()

return FitResult(component=task.component, success=True)
except Exception as e:
print(f"Error in _fit_component for {task.component.__class__.__name__}: {e}")
traceback.print_exc()
return FitResult(component=task.component, success=False, error=e)

def fit(self, data: Dict[str, Any], callbacks: Optional[List[TMCompositeCallback]] = None) -> None:
with Manager() as manager:
callback_queue: Optional[Queue] = manager.Queue() if callbacks else None
error_queue: Queue = manager.Queue()

def __init__(self, composite) -> None:
super().__init__(composite=composite)

def _listener(self, queue, callbacks):
while True:
item = queue.get()
if item == 'DONE':
break
method, *args = item
for callback in callbacks:
getattr(callback, method)(*args)

@staticmethod
def _mp_fit(args: tuple) -> None:
idx, component, data_preprocessed, proxy_callback = args

if proxy_callback:
proxy_callback.on_train_composite_begin(composite=component)

epochs = component.epochs
pbar = tqdm(total=epochs, position=idx)
pbar.set_description(f"Component {idx}: {type(component).__name__}")
for epoch in range(epochs):
if proxy_callback:
proxy_callback.on_epoch_component_begin(component=component, epoch=epoch)
component.fit(data=data_preprocessed)
pbar.update(1)
if proxy_callback:
proxy_callback.on_epoch_component_end(component=component, epoch=epoch)

if proxy_callback:
proxy_callback.on_train_composite_end(composite=component)
return component

def fit(self, data: dict, callbacks: Optional[list[TMCompositeCallback]] = None) -> None:
self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN, {}))

with Manager() as manager:
tasks = [ComponentTask(component=component, data=data, epochs=component.epochs)
for component in self.composite.components]

callback_thread = None
if callbacks:
callback_queue = manager.Queue() # Create a queue with the manager
callback_proxy = TMCompositeCallbackProxy(callback_queue)
callback_thread = self._start_callback_handler(callbacks, callback_queue, error_queue)

# Start listener thread
listener_thread = threading.Thread(target=self._listener, args=(callback_queue, callbacks))
listener_thread.start()
else:
callback_proxy = None

with Pool() as pool:
data_preprocessed = [component.preprocess(data) for component in self.composite.components]
self.composite.components = pool.map(TMCompositeMP._mp_fit,
((idx, component, data_preprocessed[idx], callback_proxy) for
idx, component in
enumerate(self.composite.components)))
results = self._execute_tasks(tasks, callback_queue)

if callbacks:
callback_queue.put('DONE') # Send done signal to listener
listener_thread.join() # Wait for listener to process all logs
self._process_results(results, error_queue)

def predict(self, data: dict, votes: dict, gating_mask: np.ndarray) -> np.array:
# Determine number of processes based on available CPU cores
n_processes = min(cpu_count(), len(self.composite.components))
with Pool(n_processes) as pool:
results = pool.starmap(self._component_predict, [
(component, data) for i, component in enumerate(self.composite.components)
])
self._cleanup(callback_queue, callback_thread)

# Aggregate results from each process
for i, result in enumerate(results):
for key, score in result.items():
self._check_errors(error_queue)

# Apply gating mask
masked_score = score * gating_mask[:, i]
self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_END, {}))

if key not in votes:
votes[key] = masked_score
else:
votes[key] += masked_score
def _start_callback_handler(self, callbacks: List[TMCompositeCallback], callback_queue: Queue, error_queue: Queue):
def callback_handler() -> None:
while True:
try:
message = callback_queue.get()
if message == 'DONE':
break
if isinstance(message, list): # Handle batch of callbacks
for callback_message in message:
if isinstance(callback_message, CallbackMessage):
self._process_callbacks(callbacks, callback_message)
except Exception as e:
print(f"Error in callback handler: {e}")
traceback.print_exc()
error_queue.put(('callback_handler', e))

callback_thread = threading.Thread(target=callback_handler)
callback_thread.start()
return callback_thread

class TMCompositeSingleCPU(TMCompositeBase):

def __init__(self, composite) -> None:
super().__init__(composite=composite)

def fit(self, data: dict, callbacks: Optional[list[TMCompositeCallback]] = None) -> None:


def _execute_tasks(self, tasks: List[ComponentTask], callback_queue: Optional[Queue]) -> List[FitResult]:
results = []

# Create a partial function with the callback_queue
fit_component_partial = partial(self._fit_component, callback_queue=callback_queue)

with Pool(processes=self.max_workers) as pool:
# Map the tasks to the pool
async_results = [
pool.apply_async(fit_component_partial, (task,)) for task in tasks
]

# Collect results as they complete
for async_result, task in zip(async_results, tasks):
try:
result = async_result.get() # This will wait for the task to complete
results.append(result)
except Exception as e:
print(f"Exception when processing results for {task.component.__class__.__name__}: {e}")
traceback.print_exc()
results.append(
FitResult(component=task.component, success=False, error=e)
)

return results

def _process_results(self, results: List[FitResult], error_queue: Queue) -> None:
for result in results:
if result.success:
matching_component = next(
(c for c in self.composite.components if c.uuid == result.component.uuid),
None
)

if matching_component is not None:
idx = self.composite.components.index(matching_component)
self.composite.components[idx] = result.component
else:
error_message = f"Could not find a matching component for {result.component}"
print(error_message)
error_queue.put(('process_results', ValueError(error_message)))
else:
error_queue.put(('fit_component', result.error))

def _cleanup(self, callback_queue: Optional[Queue], callback_thread: Optional[threading.Thread]) -> None:
if callback_queue:
callback_queue.put('DONE')
if callback_thread:
callback_thread.join()


def _check_errors(self, error_queue: Queue) -> None:
if not error_queue.empty():
print("Errors occurred during fitting:")
while not error_queue.empty():
error_source, error = error_queue.get()
print(f"Error in {error_source}: {error}")

def predict(self, data: Dict[str, Any], votes: Dict[str, np.ndarray], gating_mask: np.ndarray) -> None:
with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
future_to_component = {
executor.submit(self._predict_component, component, data, gating_mask, i): component
for i, component in enumerate(self.composite.components)
}

for future in concurrent.futures.as_completed(future_to_component):
component = future_to_component[future]
try:
result = future.result()
for key, score in result.items():
if key not in votes:
votes[key] = score
else:
votes[key] += score
except Exception as e:
print(f"Exception when processing results for {component.__class__.__name__}: {e}")
traceback.print_exc()

def _predict_component(
self,
component: TMComponent,
data: Dict[str, Any],
gating_mask: np.ndarray,
component_idx: int
) -> Dict[str, np.ndarray]:
try:
# Preprocess data and get scores
_, scores = component.predict(component.preprocess(data))
scores = scores.reshape(scores.shape[0], -1) # Ensure 2D

# Normalize scores
denominator = np.maximum(np.ptp(scores, axis=1), 1e-8) # Avoid division by zero
normalized_scores = scores / denominator[:, np.newaxis]

# Apply gating mask
mask = gating_mask[:, component_idx].reshape(-1, 1) if gating_mask.ndim > 1 else gating_mask.reshape(-1, 1)
masked_scores = normalized_scores * mask

# Create and return votes
return {
"composite": masked_scores,
str(component): masked_scores
}
except Exception as e:
print(f"Error in predict_component for {component}: {str(e)}")
print(f"Shapes - scores: {scores.shape}, gating_mask: {gating_mask.shape}, mask: {mask.shape}")
traceback.print_exc()
return {}


class TMCompositeSingleCPU:
def __init__(self, composite, **kwargs) -> None:
self.composite = composite

def _component_predict(self, component, data):
data_preprocessed = component.preprocess(data)
_, scores = component.predict(data_preprocessed)

votes = dict()
votes["composite"] = np.zeros_like(scores, dtype=np.float32)
votes[str(component)] = np.zeros_like(scores, dtype=np.float32)

for i in range(scores.shape[0]):
denominator = np.max(scores[i]) - np.min(scores[i])
score = 1.0 * scores[i] / denominator if denominator != 0 else 0
votes["composite"][i] += score
votes[str(component)][i] += score

return votes

def _process_callbacks(self, callbacks: List[TMCompositeCallback], message: CallbackMessage) -> None:
method_name = message.method.name.lower()
for callback in callbacks:
try:
getattr(callback, method_name)(**message.kwargs)
except Exception as e:
print(f"Error in callback {callback.__class__.__name__}.{method_name}: {e}")
import traceback
traceback.print_exc()

def fit(self, data: Dict[str, Any], callbacks: Optional[List[TMCompositeCallback]] = None) -> None:
if callbacks is None:
callbacks = []

data_preprocessed = [component.preprocess(data) for component in self.composite.components]
epochs_left = [component.epochs for component in self.composite.components]
pbars = [tqdm(total=component.epochs) for component in self.composite.components]
for idx, (pbar, component) in enumerate(zip(pbars, self.composite.components)):
pbar.set_description(f"Component {idx}: {type(component).__name__}")

[callback.on_train_composite_begin(composite=self) for callback in callbacks]
self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN, {'composite': self.composite}))

epoch = 0
while any(epochs_left):
for idx, component in enumerate(self.composite.components):
if epochs_left[idx] > 0:
[callback.on_epoch_component_begin(component=component, epoch=epoch) for callback in callbacks]
self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_BEGIN, {'component': component, 'epoch': epoch}))

component.fit(data=data_preprocessed[idx])
[callback.on_epoch_component_end(component=component, epoch=epoch) for callback in callbacks]
pbars[idx].update(1)

self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_END, {'component': component, 'epoch': epoch}))

epochs_left[idx] -= 1

epoch += 1

[callback.on_train_composite_end(composite=self) for callback in callbacks]
self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_END, {'composite': self.composite}))

def predict(self, data: dict, votes: dict, gating_mask: np.ndarray):
pbar = tqdm(total=len(self.composite.components))
def predict(self, data: Dict[str, Any], votes: Dict[str, np.ndarray], gating_mask: np.ndarray) -> None:
for i, component in enumerate(self.composite.components):
pbar.set_description(f"Component {i}: {type(component).__name__}")
component_votes = self._component_predict(component, data)
for key, score in component_votes.items():

# Apply gating mask
masked_score = score * gating_mask[:, i]

if key not in votes:
votes[key] = masked_score
else:
votes[key] += masked_score
pbar.update(1)


class TMComposite:

def __init__(
self,
components: Optional[list[TMComponent]] = None,
gate_function: Optional[Type[BaseGate]] = None,
gate_function_params: Optional[dict] = None,
use_multiprocessing: bool = False
components: Optional[List[TMComponent]] = None,
gate_function: Optional[type[BaseGate]] = None,
gate_function_params: Optional[Dict[str, Any]] = None,
use_multiprocessing: bool = False,
**kwargs
) -> None:
self.components: List[TMComponent] = components or []
self.use_multiprocessing = use_multiprocessing

if gate_function_params is None:
gate_function_params = dict()
gate_function_params = {}

self.gate_function_instance = gate_function(self, **gate_function_params) if gate_function else LinearGate(self,
**gate_function_params)
self.gate_function_instance: BaseGate = gate_function(self, **gate_function_params) if gate_function else LinearGate(self, **gate_function_params)

self.logic = TMCompositeSingleCPU(composite=self) if not use_multiprocessing else TMCompositeMP(composite=self)
self.logic: Union[TMCompositeSingleCPU, TMCompositeMP] = TMCompositeMP(composite=self, **kwargs) if use_multiprocessing else TMCompositeSingleCPU(composite=self, **kwargs)

def fit(self, data: dict, callbacks: Optional[list[TMCompositeCallback]] = None) -> None:
def fit(self, data: Dict[str, Any], callbacks: Optional[List[TMCompositeCallback]] = None) -> None:
self.logic.fit(data, callbacks)

def predict(self, data: dict) -> np.array:
votes = dict()
def predict(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
votes: Dict[str, np.ndarray] = {}

# Gating Mechanism
gating_mask: np.ndarray = self.gate_function_instance.predict(data)
@@ -196,7 +361,7 @@ def predict(self, data: dict) -> np.array:

return {k: v.argmax(axis=1) for k, v in votes.items()}

def save_model(self, path: Union[Path, str], format="pkl") -> None:
def save_model(self, path: Union[Path, str], format: str = "pkl") -> None:
path = Path(path) if isinstance(path, str) else path

if format == "pkl":
@@ -206,7 +371,7 @@ def save_model(self, path: Union[Path, str], format="pkl") -> None:
else:
raise NotImplementedError(f"Format {format} not supported")

def load_model(self, path: Union[Path, str], format="pkl") -> None:
def load_model(self, path: Union[Path, str], format: str = "pkl") -> None:
path = Path(path) if isinstance(path, str) else path

if format == "pkl":
@@ -217,7 +382,7 @@ def load_model(self, path: Union[Path, str], format="pkl") -> None:
else:
raise NotImplementedError(f"Format {format} not supported")

def load_model_from_components(self, path: Union[Path, str], format="pkl") -> None:
def load_model_from_components(self, path: Union[Path, str], format: str = "pkl") -> None:
path = Path(path) if isinstance(path, str) else path

if not path.is_dir():
@@ -227,7 +392,8 @@ def load_model_from_components(self, path: Union[Path, str], format="pkl") -> No
files = [f for f in path.iterdir() if f.is_file() and f.suffix == f".{format}"]

# Group files by component details
component_groups = defaultdict(list)
from collections import defaultdict
component_groups: Dict[str, List[Tuple[int, Path]]] = defaultdict(list)
for file in files:
parts = file.stem.split('-')
epoch = int(parts[-1])

0 comments on commit 79cb1c9

Please sign in to comment.