From e823fcf4facf0b4627c9a2ccec8c6f218412d4d7 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 13 Mar 2024 18:47:59 +0100
Subject: [PATCH 01/44] enable registering implementations for custom ops
 (PR2435)

---
 thunder/executors/passes.py  |  6 +++++-
 thunder/tests/test_extend.py | 36 ++++++++++++++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py
index ab6b35abba..144aa30974 100644
--- a/thunder/executors/passes.py
+++ b/thunder/executors/passes.py
@@ -57,12 +57,13 @@ def preserve_bsym(bsym: BoundSymbol) -> Any:
     # If the executor has an execution transform, it's called and True is returned
     # If no executor can execute the BoundSymbol, False is returned
     def visit_helper_(bsym: BoundSymbol) -> None | bool:
-        if bsym.sym.executor is not None or bsym.sym.python_impl is not None:
+        if bsym.sym.python_impl is not None:
             return None
 
         ex: Executor
         for ex in executors_list:
             # TODO Consider allowing operator executors to claim portions of operations
+            # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor?
             if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or (
                 isinstance(ex, FusionExecutor) and ex.can_fuse(bsym)
             ):
@@ -87,6 +88,9 @@ def visit_helper_(bsym: BoundSymbol) -> None | bool:
             safe_map_flat(update_swapmap, bsym.output, out)
             return True
 
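+        # If we reach this point, no executor in executors_list claimed the bsym;
+        # a symbol that already carries its own executor is preserved below rather
+        # than being reported as unexecutable.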
+        if bsym.sym.executor is not None:
+            return None
+
         return False
 
     def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE:

diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py
index 9e77f8f496..00dfe71fd7 100644
--- a/thunder/tests/test_extend.py
+++ b/thunder/tests/test_extend.py
@@ -131,3 +131,39 @@ def test_get_all_executors_includes_all_native_executors():
     if torch.cuda.is_available():
         expected.update({"nvfuser"})
     assert actual == expected
+
+
+def test_register_implementation_custom_op():
+    myex = OperatorExecutor("myex", version="0.1")
+    register_executor(myex)
+
+    def _myadd(a, b):
+        return a + b
+
+    myadd1 = myex.register_operator("myadd1", like=_myadd, fn=_myadd)
+    myadd2 = myex.register_operator("myadd2", like=_myadd, fn=_myadd)
+
+    def fn(a, b):
+        return myadd1(a, b)
+
+    cfn = thunder.jit(fn, executors=[myex])
+
+    a = torch.randn(2, 2)
+    b = torch.randn(2, 2)
+
+    res = cfn(a, b)
+
+    assert "myadd1" in str(thunder.last_traces(cfn)[-1])
+
+    def myadd_trafo(a, b):
+        return myadd2(a, b)
+
+    myex.register_implementation(myadd1, execution_transform=myadd_trafo)
+
+    cfn = thunder.jit(fn, executors=[myex])
+    res = cfn(a, b)
+
+    s = str(thunder.last_traces(cfn)[-1])
+    assert "myadd2" in s and "myadd1" not in s
+
+    deregister_executor(myex)
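+
+
+# A rough sketch of the pattern exercised above (comments only; `myadd1`,
+# `myadd2` and `myadd_trafo` are the names defined in the test):
+#
+#     myex.register_implementation(myadd1, execution_transform=myadd_trafo)
+#
+# lets `myex` claim `myadd1` bound symbols and rewrite them via `myadd_trafo`,
+# which is why the final trace contains "myadd2" and no "myadd1".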
From 7c916c13675bb05b1a5522a9c797b33e997e4f19 Mon Sep 17 00:00:00 2001
From: nikitaved
Date: Wed, 13 Mar 2024 21:13:02 +0100
Subject: [PATCH 02/44] proxy rename: make sure prologue/epilogue contain
 proper re-named proxies (PR2432)

---
 thunder/core/jit_ext.py | 40 +++++++++++++++++++++++++++++++---------
 1 file changed, 31 insertions(+), 9 deletions(-)

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index e1c1f92826..3a6f17a2f4 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -917,9 +917,15 @@ def _general_jit_const_callback(value: Any) -> WrappedValue:
 
 # TODO(nikitaved): maybe call it upon Frame creation
 def _maybe_update_proxy_name(orig_value: Any, name: str):
+    # Names that we do not re-name proxies into as these are reserved
+    proxy_rename_ignore_names = {
+        "fn",  # For example, `fn = globals()['__function_obj']` in prologue
+        "obj",  # For example, `obj = fn.forward` in prologue
+    }
+
     uvalue = unwrap(orig_value)
 
-    if isinstance(uvalue, Proxy) and is_proxy_name_available(name):
+    if isinstance(uvalue, Proxy) and (name not in proxy_rename_ignore_names) and is_proxy_name_available(name):
         uvalue_var = variableify(uvalue)
         rename_proxy_swapmap = get_general_jit_ctx()._proxy_swapmap
         if uvalue_var not in rename_proxy_swapmap:
@@ -927,8 +933,11 @@ def _maybe_update_proxy_name(orig_value: Any, name: str):
             rename_proxy_swapmap[uvalue_var] = uvalue_renamed
 
 
-def _apply_trace_proxy_rename(trace: TraceCtx, name: None | str = None) -> TraceCtx:
-    rename_proxy_swapmap = get_general_jit_ctx()._proxy_swapmap
+def _apply_trace_proxy_rename(
+    trace: TraceCtx, rename_proxy_swapmap: None | dict[Variable, Proxy] = None, name: str | None = None
+) -> TraceCtx:
+    if rename_proxy_swapmap is None:
+        rename_proxy_swapmap = get_general_jit_ctx()._proxy_swapmap
 
     new_trace = from_trace(trace)
 
@@ -1424,11 +1433,24 @@ def thunder_general_jit(
     if epilogue_trace:
         bind_inputs("epilogue", epilogue_trace, pro_to_epi + comp_to_epi, pro_to_epi_proxies + comp_to_epi_proxies)
 
-    with general_jit_ctx(ctx):
-        # TODO(nikitaved): update prologue/epilogue as well
-        computation_trace = _apply_trace_proxy_rename(computation_trace, "computation")
-        if epilogue_trace:
-            # TODO: is it safe to use current swapdict here?
-            epilogue_trace = _apply_trace_proxy_rename(epilogue_trace, "epilogue")
+    # Returns a new swapmap dictionary which has the keys (ctx._proxy_swapmap.key() & variableify(proxies))
+    def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:
+        proxy_swapmap = ctx._proxy_swapmap
+        proxy_vars = {variableify(p) for p in proxies}
+        common_vars = proxy_swapmap.keys() & proxy_vars
+        restricted_proxy_swapmap = {v: proxy_swapmap[v] for v in common_vars}
+        return restricted_proxy_swapmap
+
+    # Update prologue trace by renaming proxies which are passed from prologue to the computation trace
+    prologue_trace = _apply_trace_proxy_rename(prologue_trace, restrict_proxy_swapmap(pro_to_comp_proxies))
+
+    # Update computation trace by renaming proxies which are in the ctx._proxy_swapmap
+    computation_trace = _apply_trace_proxy_rename(computation_trace, ctx._proxy_swapmap, "computation")
+
+    # Update epilogue trace by renaming proxies which are passed to the epilogue trace from prologue and computation traces
+    if epilogue_trace:
+        epilogue_trace = _apply_trace_proxy_rename(
+            epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue"
+        )
 
     return prologue_trace, computation_trace, epilogue_trace

From 9dbb7053480ea1c9f37bb715988f41285345bb2a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 13 Mar 2024 21:13:37 +0100
Subject: [PATCH 03/44] LitGPT benchmarking script fixes (PR2436)

---
 examples/lit-gpt/test_parametrized.py  | 35 ++++++++++++++------------
 thunder/benchmarks/benchmark_litgpt.py | 22 +++++++----------
 2 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/examples/lit-gpt/test_parametrized.py b/examples/lit-gpt/test_parametrized.py
index 20ddaa9278..bca55173fa 100644
--- a/examples/lit-gpt/test_parametrized.py
+++ b/examples/lit-gpt/test_parametrized.py
@@ -14,13 +14,10 @@
 from absl.testing import parameterized
 from absl.testing import absltest
 import os
-import pickle
 import subprocess
-import warnings
 import json
 import pandas as pd
 from datetime import datetime
-import threading
 
 class Runner:
     '''
@@ -65,11 +62,11 @@ def complete_dataframe(self, is_teardown):
         if self.output_format not in ('none', 'print'):
             output_ext = {'xlsx': '.xlsx', }[self.output_format]
             if not is_teardown:
-                filename = '/scratch/lightning-thunder/examples/lit-gpt/mid_output_parameterized_results' + str(output_ext)
+                filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext)
             else:
                 current_time = datetime.now().strftime('%Y-%m-%d_%H-%M')
                 filename = f"{current_time}_litgpt_benchmark" + str(output_ext)
-                filename = '/scratch/lightning-thunder/examples/lit-gpt/' + str(filename)
+                filename = 'examples/lit-gpt/' + str(filename)
 
             with pd.ExcelWriter(filename, engine='xlsxwriter') as writer:
                 self.iter_time_df.to_excel(writer, sheet_name='Average Iter Time (ms)')
@@ -87,19 +84,24 @@ def complete_dataframe(self, is_teardown):
             print(self.memory_used_GB_df)
 
     def run_benchmark(self, kwargs):
-        # benchmark_file = '/scratch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py'
+        # benchmark_file = 'thunder/benchmarks/benchmark_litgpt.py'
         command_list = []
         for key, val in kwargs.items():
            command_list.append("--" + str(key) + "=" + str(val))
 
        if kwargs['distributed_mode'] != 'none':
-            subprocess_cmd = ["torchrun", "--nproc_per_node=8", "--nnodes=1", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)]
+            nproc_per_node = torch.cuda.device_count()
+            subprocess_cmd = ["torchrun", f"--nproc_per_node={nproc_per_node}", "--nnodes=1", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)]
             subprocess_cmd.extend(command_list)
         else:
-            subprocess_cmd = ["python", "{}".format(benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)]
+            subprocess_cmd = ["python", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)]
             subprocess_cmd.extend(command_list)
 
         print(f'Running {" ".join(subprocess_cmd)!r}')
         proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True)
+        if proc_output.returncode:
+            print(proc_output.stdout)
+            print(proc_output.stderr)
+            proc_output.check_returncode()
 
         with open(self.json_file_path, 'r') as file:
             self.perf_metrics_dict = json.load(file)
 
                 pass_str = "TestCase did not finish reporting metrics due to CUDA out of memory error. Reporting OOM and triggering test success."
                 return True, pass_str
             else:
+                print(proc_output.stdout)
+                print(proc_output.stderr)
                 fail_str = "Testcase did not finish reporting metrics due to an unknown error. Triggering test failure."
                 return False, fail_str
         else:
             return True, "Test passed successfully."
- # print(proc_output.stdout) - # print(proc_output.stderr) + class Test(parameterized.TestCase): @@ -152,12 +155,12 @@ def tearDownClass(cls): # dict(distributed_mode = "none", shard_mode = "none")), # (dict(model_name = 'Llama-2-7b-hf', micro_batch_size=1), # dict(model_name = 'Llama-2-7b-hf', micro_batch_size=2), - # dict(model_name = 'Llama-2-13b{}-hf', micro_batch_size=1), - # dict(model_name = 'Llama-2-13b{}-hf', micro_batch_size=2), + # dict(model_name = 'Llama-2-13b-hf', micro_batch_size=1), + # dict(model_name = 'Llama-2-13b-hf', micro_batch_size=2), # dict(model_name = 'stablecode-completion-alpha-3b', micro_batch_size=1), # dict(model_name = 'stablecode-completion-alpha-3b', micro_batch_size=2), - # dict(model_name = 'Mistral-7B-{}v0.1', micro_batch_size=1), - # dict(model_name = 'Mistral-7B-{}v0.1', micro_batch_size=2), + # dict(model_name = 'Mistral-7B-v0.1', micro_batch_size=1), + # dict(model_name = 'Mistral-7B-v0.1', micro_batch_size=2), # dict(model_name = 'open_llama_3b', micro_batch_size=1), # dict(model_name = 'open_llama_3b', micro_batch_size=2), # dict(model_name = 'open_llama_3b', micro_batch_size=4), @@ -178,8 +181,8 @@ def tearDownClass(cls): # dict(model_name = 'pythia-6.9b', micro_batch_size=2), # dict(model_name = 'pythia-12b', micro_batch_size=1), # dict(model_name = 'pythia-12b', micro_batch_size=2), - # dict(model_name = 'falcon-7b{}', micro_batch_size=1), - # dict(model_name = 'falcon-7b{}', micro_batch_size=2)), + # dict(model_name = 'falcon-7b', micro_batch_size=1), + # dict(model_name = 'falcon-7b', micro_batch_size=2)), # compile = ("eager", "inductor", "thunder", "thunder_inductor",) # ) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 3bad00892d..40db2db5de 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -1,7 +1,5 @@ import os -import copy import time -import pprint import torch import functools @@ -19,18 +17,14 @@ except: LIGHTNING_AVAILABLE = False -world_size, local_rank, global_rank = None, None, None -if "WORLD_SIZE" in os.environ and "LOCAL_RANK" in os.environ: +world_size = int(os.environ.get("WORLD_SIZE", 1)) +local_rank = int(os.environ.get("LOCAL_RANK", 0)) +global_rank = int(os.environ.get("RANK", 0)) +if world_size > 1: torch_dist.init_process_group(backend="nccl") - world_size = int(os.environ["WORLD_SIZE"]) - local_rank = int(os.environ["LOCAL_RANK"]) - global_rank = int(os.environ["RANK"]) pg = torch_dist.distributed_c10d._get_default_group() - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - use_ddp = True -else: - device = torch.device("cuda", 0) +device = torch.device("cuda", local_rank) +torch.cuda.set_device(device) def configure_optimizers(model, weight_decay, learning_rate, betas, device_type): @@ -38,7 +32,9 @@ def configure_optimizers(model, weight_decay, learning_rate, betas, device_type) fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters use_fused = fused_available and device_type == "cuda" - optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=betas, fused=use_fused) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=betas, fused=use_fused + ) return optimizer From 033c6fb3af8179e58927ae13cbfacc47ff411eab Mon Sep 17 00:00:00 2001 From: apaz Date: Wed, 13 Mar 2024 19:07:39 -0500 Subject: [PATCH 04/44] Add `torch.t()` and `x.T` (PR2441) --- thunder/core/proxies.py | 5 +++ thunder/tests/opinfos.py | 66 
+++++++++++++++++++++++++++++++++++++++ thunder/torch/__init__.py | 28 +++++++++++++++++ 3 files changed, 99 insertions(+) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index e4c0ba01e6..f8016f5ea3 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1461,6 +1461,11 @@ def __rmatmul__(self, other): # Transposes # + @property + def T(self): + method = resolve_method("T", self) + return method(self) + @property def mT(self): method = resolve_method("mT", self) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 2809c9af8a..9a3a7687e0 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -3871,6 +3871,72 @@ def torch_permute_reference(a, *dims): shape_ops.append(permute_opinfo) +def t_sample_generator(op, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape + cases = ( + (), + (1), + (4), + (4, 5), + ) + + for shape in cases: + yield SampleInput(make(shape)) + + +def t_error_generator(op, device, dtype=torch.float32, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype) + + # shape, error type, error message + cases = ( + ((4, 5, 6), RuntimeError, r"t\(\) expects a tensor with <= 2 dimensions, but self is 3D"), + ( + (4, 5, 6, 7), + RuntimeError, + r"t\(\) expects a tensor with <= 2 dimensions, but self is 4D", + ), + ) + + for shape, err_type, err_msg in cases: + yield SampleInput(make(shape)), err_type, err_msg + + +t_opinfo = OpInfo( + ltorch.t, + sample_input_generator=t_sample_generator, + error_input_generator=t_error_generator, + torch_reference=lambda x: torch.Tensor.t(x), +) +shape_ops.append(t_opinfo) + + +def reverse_dims_T_sample_generator(op, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape + cases = ( + (), + (1), + (4), + (4, 5), + (4, 5, 6), + (4, 5, 6, 7), + ) + + for shape in cases: + yield SampleInput(make(shape)) + + +reverse_dims_T_opinfo = OpInfo( + ltorch.reverse_dims_T, + sample_input_generator=reverse_dims_T_sample_generator, + torch_reference=lambda x: x.T, +) +shape_ops.append(reverse_dims_T_opinfo) + + def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index a4388d9903..055df74f1e 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -29,6 +29,7 @@ from thunder.core.symbol import Symbol from thunder.core.transforms import register_grad, put_grads from thunder.core.prims import get_grad, put_grad +from thunder.core.baseutils import run_once __all__ = [ "is_available", @@ -884,6 +885,33 @@ def squeeze(a: TensorLike, /, dim: None | int | Sequence[int] = None) -> TensorL return clang.squeeze(a, dims) +@torchsymbol(torch.t, is_method=True) +def t(a: TensorLike, /) -> TensorLike: + utils.check( + a.ndim <= 2, + lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D", + RuntimeError, + ) + return prims.transpose(a, (1, 0)) if a.ndim == 2 else a + + +@run_once +def warn_ndim_not_2(): + warnings.warn( + "The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and will throw an error in a future release." + "Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor." 
+ ) + + +def reverse_dims_T(a: TensorLike, /) -> TensorLike: + if a.ndim != 2: + warn_ndim_not_2() + return a if a.ndim < 2 else prims.transpose(a, tuple(reversed(range(a.ndim)))) + + +register_method("T", reverse_dims_T) + + # TODO Add type annotations # See https://pytorch.org/docs/master/generated/torch.tensor_split.html @torchsymbol(torch.tensor_split, is_method=True) From 3b68934ace8d8a56f9dae63c1e897d413f1bd5da Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 14 Mar 2024 01:08:25 +0100 Subject: [PATCH 05/44] disable preprocessing in thunder.compile (PR2442) --- thunder/__init__.py | 67 +++++++++++++++++++--------- thunder/common.py | 21 +++------ thunder/core/transforms.py | 16 +++++++ thunder/tests/test_cudnn_executor.py | 4 +- thunder/tests/test_grad.py | 16 +++---- thunder/tests/test_nvfuser.py | 14 +++--- 6 files changed, 85 insertions(+), 53 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index df817a6e43..ae252b9836 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -1,6 +1,6 @@ from functools import wraps from typing import Any -from collections import defaultdict +from collections import defaultdict, namedtuple from collections.abc import Callable from collections.abc import Sequence from contextlib import contextmanager @@ -277,6 +277,21 @@ def _recursive_jit_call_warning() -> None: ) +CacheEntry = namedtuple( + "CacheEntry", + [ + "prologue_fn", + "prologue_traces", + "computation_fn", + "computation_traces", + "epilogue_fn", + "epilogue_traces", + "backward_fn", + "backward_traces", + ], +) + + # This function will replace compile() (below) before RC1 # TODO RC1 Consider adding a debug_log parameter to control debug printing # TODO RC1 Consider renaming compile_options to additional_compile_options @@ -386,9 +401,17 @@ def get_computation_and_inputs(*args, **kwargs): # Checks cache cs.last_trace_cache_start = time.time_ns() if (cd.cache_option is CACHE_OPTIONS.CONSTANT_VALUES) or (cd.cache_option is CACHE_OPTIONS.SYMBOLIC_VALUES): - for pro, pro_traces, comp, comp_traces, epilogue, epilogue_traces, backward_fn, backward_traces in reversed( - cs.interpreter_cache - ): + for cache_entry in reversed(cs.interpreter_cache): + ( + pro, + pro_traces, + comp, + comp_traces, + epilogue, + epilogue_traces, + backward_fn, + backward_traces, + ) = cache_entry try: cs.last_prologue_execution_start = time.time_ns() if epilogue: @@ -415,10 +438,11 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_computation_transformation_start = 0 cs.last_computation_transformation_stop = 0 - return inps, pro_to_epi, comp, epilogue, backward_fn + return cache_entry, inps, pro_to_epi if cd.cache_option is CACHE_OPTIONS.SAME_INPUT: if len(cs.interpreter_cache): + cache_entry = cs.interpreter_cache[0] ( pro, pro_traces, @@ -428,7 +452,7 @@ def get_computation_and_inputs(*args, **kwargs): epilogue_traces, backward_fn, backward_traces, - ) = cs.interpreter_cache[0] + ) = cache_entry cs.last_prologue_execution_start = time.time_ns() if epilogue: @@ -449,7 +473,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_prologue_traces = pro_traces cs.last_prologue = pro - return inps, pro_to_epi, comp, epilogue, backward_fn + return cache_entry, inps, pro_to_epi cs.cache_misses += 1 cs.last_trace_cache_stop = time.time_ns() @@ -553,17 +577,20 @@ def get_computation_and_inputs(*args, **kwargs): backward_traces = [] # TODO RC1 Update the cache + cache_entry = CacheEntry( + pro, protraces, comp, extraces, epilogue, epilogue_traces, backward_fn, backward_traces 
+ ) if cd.cache_option is not CACHE_OPTIONS.NO_CACHING: - cs.interpreter_cache.append( - (pro, protraces, comp, extraces, epilogue, epilogue_traces, backward_fn, backward_traces) - ) + cs.interpreter_cache.append(cache_entry) cs.last_computation_transformation_stop = time.time_ns() cs.last_traces = [computation_trc] + extraces cs.last_prologue_traces = [prologue_trc] + protraces cs.last_prologue = pro - return inps, pro_to_epi, comp, epilogue, backward_fn + return cache_entry, inps, pro_to_epi + + cd.get_computation_and_inputs = get_computation_and_inputs @wraps(fn) def fn_(*args, **kwargs) -> Any: @@ -575,18 +602,18 @@ def fn_(*args, **kwargs) -> Any: cs.last_trace_host_start = time.time_ns() cs.calls += 1 - inps, pro_to_epi, comp, epilogue, backward_fn = get_computation_and_inputs(*args, **kwargs) + cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) cs.last_trace_host_execution_start = time.time_ns() - result = comp(*inps) + result = cache_entry.computation_fn(*inps) - if backward_fn: + if cache_entry.backward_fn: # Run the compiled forward function data_for_autograd, (saved_tensors, saved_other) = result # Connect produced tensors with PyTorch's autograd graph ThunderFunction.apply( - backward_fn, + cache_entry.backward_fn, saved_tensors, saved_other, data_for_autograd["flat_output"], @@ -594,21 +621,17 @@ def fn_(*args, **kwargs) -> Any: ) result = data_for_autograd["output"] - if epilogue: + if cache_entry.epilogue_fn: result, comp_to_epi = result - epilogue(*pro_to_epi, *comp_to_epi) + cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) cs.last_trace_host_execution_stop = time.time_ns() cs.last_computation_execution_stop = cs.last_trace_host_execution_stop - cs.last_executed = comp + cs.last_executed = cache_entry.computation_fn cs.last_trace_cache_stop = time.time_ns() cs.last_trace_host_stop = time.time_ns() - # Updates statistics - cs.last_executed = comp - cs.last_trace_host_stop = time.time_ns() - return result if isinstance(fn, pytorch.nn.Module): diff --git a/thunder/common.py b/thunder/common.py index fbdfaa0db7..74ca9f6473 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -189,6 +189,7 @@ def __init__( use_rematerialization: bool = False, debug_log: None | StringIO = None, compile_options: dict[str, Any] = {}, + get_computation_and_inputs: Callable | None = None, ): # Records whether we're using the thunder.jit() entrypoint or not # The thunder.jit() entrypoint introduces important architectural updates, @@ -196,6 +197,9 @@ def __init__( # and are being temporarily maintained to facilitate their development. 
self.using_jit = using_jit + # runs prologues to get the compute/backward/epilogue function and inputs + self.get_computation_and_inputs = get_computation_and_inputs + # Resolves cache option self.cache_option = resolve_cache_option(cache_option) @@ -262,20 +266,9 @@ def __init__( self.num_constant_args = 0 self._processed_function: Callable - if disable_preprocessing: - self._processed_function = fn - else: - warnings.warn( - "please use thunder.jit if possible and upgrade and use thunder.jit if it is not yet possible" - ) - self._processed_function = preprocess(fn, is_module=self.is_module) - - # TODO Revisit assuming parameters are const - if self.is_module: - self.additional_param_names = self.processed_function._additional_param_names - self.additional_param_values = self.processed_function._additional_param_values - self.additional_return_names = self.processed_function._additional_return_names - self.num_constant_args = len(self.additional_param_values) + + assert disable_preprocessing, "please use thunder.jit if you need preprocessing" + self._processed_function = fn # Disallows overwriting processed_function @property diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 9922fede58..42dcddced4 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -546,6 +546,22 @@ def _flatten(bsym: BoundSymbol): def populate_grads(grads: list[TensorProxy], tom: None | torch.nn.Module = None, args=None, kwargs=None) -> None: idx: int = 0 from thunder.common import ThunderOptimizedModule + from thunder import ThunderModule, compile_data + + if isinstance(tom, ThunderModule) or thunder.compile_data(tom).using_jit: + assert args is not None, "populate grad needs args (and possibly kwargs) to work with ThunderModules" + if kwargs is None: + kwargs = {} + _, computation_inputs, _ = compile_data(tom).get_computation_and_inputs(*args, **kwargs) + for p in computation_inputs: + if isinstance(p, torch.Tensor) and p.requires_grad: + # Supports grad accumulation (like when weight tying) + if p.grad is not None: + p.grad += grads[idx] + else: + p.grad = grads[idx] + idx += 1 + return if tom is not None and isinstance(tom, ThunderOptimizedModule) and tom._additional_param_values is not None: for p in tom._additional_param_values: diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index f90d2a231c..5ab26bd5af 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -127,7 +127,7 @@ def test(query, key, value, is_causal=False, attn_mask=None): query, key, value, is_causal=is_causal, attn_mask=attn_mask ) - ctest = thunder.compile(test, executors_list=[cudnn_ex]) + ctest = thunder.jit(test, executors=[cudnn_ex]) actual = ctest(query, key, value, is_causal=is_causal, attn_mask=attn_mask) torch.testing.assert_close(actual, expected, atol=2e-2, rtol=1e-2) last_trace = thunder.last_traces(ctest)[-1] @@ -184,7 +184,7 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): pytest.xfail("Only interleaved layout is supported pre 8.9.2.") for sample in op.reference_inputs(device, dtype, requires_grad=False): - cfn = thunder.compile(op_name_to_fn[op.name], executors_list=[cudnn_ex, cudnn_layernorm_ex]) + cfn = thunder.jit(op_name_to_fn[op.name], executors=[cudnn_ex, cudnn_layernorm_ex]) result = run_snippet( snippet_torch_consistency, diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 931c365fb9..659dcb6523 100644 --- a/thunder/tests/test_grad.py +++ 
b/thunder/tests/test_grad.py @@ -1250,11 +1250,11 @@ def test_populate_grads_mlp(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) tom_grad = grad(tom) thunder_grads = tom_grad(x) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=(x,)) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-3, rtol=1e-5) @@ -1277,11 +1277,11 @@ def test_populate_grads_csa(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) tom_grad = grad(tom) thunder_grads = tom_grad(x) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=[x]) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-2, rtol=1e-2) @@ -1304,11 +1304,11 @@ def test_populate_grads_block(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) tom_grad = grad(tom) thunder_grads = tom_grad(x) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=[x]) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-2, rtol=1e-2) @@ -1340,7 +1340,7 @@ def test_populate_grads_nanogpt(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) def grad_specifier(out) -> None: logits, loss = out @@ -1349,7 +1349,7 @@ def grad_specifier(out) -> None: tom_grad = grad(tom, grad_specifier=grad_specifier) thunder_grads = tom_grad(x, targets) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=[x, targets]) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-2, rtol=1e-2) diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 9ec09989bd..166c4ae3c6 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -338,10 +338,10 @@ def test_cse_rematerialization(executor, device, _): x = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) y = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) - compiled_func = thunder.compile( + compiled_func = thunder.jit( model.eval(), - disable_torch_autograd_support=True, - executors_list=executor.executors_list(), + disable_torch_autograd=True, + executors=executor.executors_list(), nv_enable_bookend=False, ) compiled_func(x, y) @@ -357,10 +357,10 @@ def test_cse_rematerialization(executor, device, _): # fusion groups 1 and 7 correspond with the apply_rotary_emb function # Nvfuser with recomputation should use precomputed cos and sin values. 
     assert len(fusion_bsyms[1].args) == len(fusion_bsyms[7].args)
-    assert fusion_bsyms[1].args[0].name == "freqs_cos"
-    assert fusion_bsyms[1].args[1].name == "freqs_sin"
-    assert fusion_bsyms[7].args[0].name == "freqs_cos"
-    assert fusion_bsyms[7].args[1].name == "freqs_sin"
+    assert fusion_bsyms[1].subsymbols[0].output.name == "freqs_cos"
+    assert fusion_bsyms[1].subsymbols[1].output.name == "freqs_sin"
+    assert fusion_bsyms[7].subsymbols[0].output.name == "freqs_cos"
+    assert fusion_bsyms[7].subsymbols[1].output.name == "freqs_sin"
 
 
 # Tests that two separated nvFuser regions can be merged when they don't depend

From b7cc7ac9861fe4f47a022e03000afc7033747913 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 14 Mar 2024 01:09:49 +0100
Subject: [PATCH 06/44] Use single quotes for zsh support (PR2440)

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 9593c3aea8..b119d4860b 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ The main goal for Lightning Thunder is to allow optimizing user programs in the
 Install the nvFuser nightly, which will also install the matching PyTorch nightly:
 
 ```bash
-pip install --pre "nvfuser-cu121[torch]" --extra-index-url https://pypi.nvidia.com
+pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
 ```
 
 Install Thunder:

From 326187c7496105815761ae32ed2be7b7ac998747 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 14 Mar 2024 10:16:21 +0100
Subject: [PATCH 07/44] substitute preprocess in FSDP tutorial (PR2445)

---
 docs/source/index.rst                       |    2 +-
 notebooks/dev_tutorials/fsdp_tutorial.ipynb | 2006 +++++++++++++++++++
 notebooks/fsdp_tutorial.ipynb               | 1489 --------------
 3 files changed, 2007 insertions(+), 1490 deletions(-)
 create mode 100644 notebooks/dev_tutorials/fsdp_tutorial.ipynb
 delete mode 100644 notebooks/fsdp_tutorial.ipynb

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 804d5393b3..6011f29193 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -99,7 +99,7 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik
    Additional executors
    Distributed Data Parallel
    What's next
-   FSDP Tutorial
+   FSDP Under the Hood Tutorial
 
 .. toctree::
    :maxdepth: 1

diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb
new file mode 100644
index 0000000000..c5c46d960a
--- /dev/null
+++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb
@@ -0,0 +1,2006 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## FSDP Tutorial\n",
+    "\n",
+    "In this tutorial, we will walk through the implementation of Fully Sharded Data Parallel (FSDP) with the Zero2 sharding strategy in `thunder`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Introduction\n",
+    "\n",
+    "In recent times, LLM models have grown so large that all the model parameters don't fit on a single GPU. To circumvent this problem, there are various strategies like Tensor Parallel, Pipeline Parallel, Fully Sharded Data Parallel, etc. to train these large models. In this tutorial, we discuss and implement the Zero2 strategy for Fully Sharded Data Parallel (FSDP).\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### What is Zero2 strategy for FSDP?\n",
+    "\n",
+    "In this strategy, we shard the model parameters across all the available GPUs. That is, each GPU holds onto only a chunk of the parameter. 
During the forward pass, all GPUs call `all_gather` communication primitive to gather the parameters from other GPUs. Unlike Zero3 strategy which frees the parameter after forward pass, we save these unsharded parameters for backward pass. This is to save the overhead of extra communication. In the backward pass, we utilize the saved parameters and compute the gradients. Once the gradients are computed, we use `reduce_scatter` communication primitive to reduce (average) the gradients across all GPUs and scatter those gradients so that a given GPU holds only a chunk of gradient.\n", + "\n", + "For more information on FSDP, we recommend reading\n", + "\n", + "1. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel - [Link](https://arxiv.org/abs/2304.11277)\n", + "2. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models - [Link](https://arxiv.org/abs/1910.02054)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Example Model\n", + "\n", + "For this example we will have a simple model `Linear(Tanh(Linear(x)))` which will be sharded over 2 GPUs\n", + "\n", + "**NOTE**: We are generating the abstract trace so we don't actually need a system with 2 GPUs for this. It is only required when we execute this trace." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.distributed\n", + "import thunder\n", + "import thunder.distributed\n", + "from IPython.display import Code" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device='cuda'\n", + "dim = 64\n", + "def create_model():\n", + " layers = [torch.nn.Linear(dim, dim, bias=False),\n", + " torch.nn.Tanh(),\n", + " torch.nn.Linear(dim, dim, bias=False)]\n", + " return torch.nn.Sequential(*layers).to(device)\n", + "\n", + "# Model\n", + "model = create_model()\n", + "# Input\n", + "x = torch.randn(dim, dim, device=device)\n", + "\n", + "\n", + "# we want to obtain a functional version of our model. The JIT does that internally and we reach into those\n", + "# internals here\n", + "thunder_model = thunder.jit(model)\n", + "cache_rec, i_, _ = thunder.compile_data(thunder_model).get_computation_and_inputs(x)\n", + "computation_trace = cache_rec.computation_traces[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def wrap_as_highlighted_code(trace):\n", + " return Code(str(trace), language=\"python\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can show the functional version:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       "import torch\n",
+       "import torch.nn.functional\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n",
+       "  # input: "cuda:0 f32[64, 64]" \n",
+       "  # t_0_weight: "cuda:0 f32[64, 64]" \n",
+       "  # t_2_weight: "cuda:0 f32[64, 64]" \n",
+       "  t0 = torch.nn.functional.linear(input, t_0_weight, None)  # t0: "cuda:0 f32[64, 64]"\n",
+       "    # t0 = ltorch.linear(input, t_0_weight, None)  # t0: "cuda:0 f32[64, 64]"\n",
+       "      # t0 = prims.linear(input, t_0_weight, None)  # t0: "cuda:0 f32[64, 64]"\n",
+       "  [t1] = nvFusion0(t0)\n",
+       "    # t1 = prims.tanh(t0)  # t1: "cuda:0 f32[64, 64]"\n",
+       "  t2 = torch.nn.functional.linear(t1, t_2_weight, None)  # t2: "cuda:0 f32[64, 64]"\n",
+       "    # t2 = ltorch.linear(t1, t_2_weight, None)  # t2: "cuda:0 f32[64, 64]"\n",
+       "      # t2 = prims.linear(t1, t_2_weight, None)  # t2: "cuda:0 f32[64, 64]"\n",
+       "  return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} input: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}0\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}2\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{n}{t0} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t0 = ltorch.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t0 = prims.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{p}{[}\\PY{n}{t1}\\PY{p}{]} \\PY{o}{=} \\PY{n}{nvFusion0}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{)}\n", + " \\PY{c+c1}{\\PYZsh{} t1 = prims.tanh(t0) \\PYZsh{} t1: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t2 = ltorch.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t2 = prims.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t2}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t2}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n", + " # input: \"cuda:0 f32[64, 64]\" \n", + " # t_0_weight: \"cuda:0 f32[64, 64]\" \n", + " 
# t_2_weight: \"cuda:0 f32[64, 64]\" \n", + " t0 = torch.nn.functional.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", + " # t0 = ltorch.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", + " # t0 = prims.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", + " [t1] = nvFusion0(t0)\n", + " # t1 = prims.tanh(t0) # t1: \"cuda:0 f32[64, 64]\"\n", + " t2 = torch.nn.functional.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", + " # t2 = ltorch.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", + " # t2 = prims.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", + " return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wrap_as_highlighted_code(computation_trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 1 : Configuration\n", + "\n", + "For our implementation of FSDP, we will generate the trace where we are sharding our model over 2 GPU" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# FSDP Config \n", + "# Usually these values are set in the environment by `torchrun` but for this example\n", + "# we will set them ourselves\n", + "world_size = 2 # We have two processes.\n", + "global_rank = 0 # Current process is the very first process." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 2: Function to shard parameters\n", + "\n", + "Next step is to write a function which will actually shard the parameters over 0-dim." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: We shard over 0th dimension of the param.\n", + "def shard_param(param: torch.Tensor, rank: int, world_size: int, name: str) -> None:\n", + " # We will keep it simple and error if param's 0th dim is not divisible by ``world_size``.\n", + " # Alternative is that we can pad our parameters so that they are divisible by `world_size`.\n", + " assert param.shape[0] % world_size == 0,(\n", + " f\"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})\"\n", + " f\" to be divisible by the world size ({world_size})\"\n", + " )\n", + " chunk_size = param.shape[0] // world_size\n", + "\n", + " # rank helps us determine which chunk of the parameter we will hold.\n", + " shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()\n", + " param.data = shard\n", + "\n", + "# Shard each parameter of the model\n", + "for param_name, param in model.named_parameters():\n", + " shard_param(param, global_rank, world_size, param_name)\n", + " # Mark the param to denote that it is sharded.\n", + " # This is required by the synchronization primitive we will use below.\n", + " param.ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential(\n", + " (0): Linear(in_features=64, out_features=64, bias=False)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=64, out_features=64, bias=False)\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Verify our model looks as expected\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + 
"metadata": {}, + "outputs": [], + "source": [ + "# Let us verify that we have actually sharded the parameters.\n", + "# Checking if the weight of 1st Linear layer is sharded over 0th dim.\n", + "assert model[0].weight.shape == (dim / world_size, dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 3: Add an operation to synchronize the parameters before calling the model.forward." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have to create a process group. This is needed because the synchronization primitive `synchronize` that we will use to gather and scatter our weights in forward and backward requires a process group." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a process group\n", + "options = torch.distributed.distributed_c10d.ProcessGroup.Options(backend=\"nccl\")\n", + "process_group = torch.distributed.distributed_c10d.ProcessGroup(torch.distributed.distributed_c10d.Store(),\n", + " global_rank, world_size, options)\n", + "torch.distributed.distributed_c10d.GroupMember.WORLD = process_group" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Because we are trying to play tricks with the traces and skip the part that inserts the synchronization automatically but also does the translation from PyTorch to thunder, we need to drop one layer of the trace to apply this manually.\n", + "(This is really hacky, don't try it at home!)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       "import thunder\n",
+       "import thunder.core.prims as prims\n",
+       "import thunder.torch as ltorch\n",
+       "import torch\n",
+       "import torch.nn.functional\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n",
+       "  # input: "cuda:0 f32[64, 64]" \n",
+       "  # t_0_weight: "cuda:0 f32[64, 64]" \n",
+       "  # t_2_weight: "cuda:0 f32[64, 64]" \n",
+       "  t0 = ltorch.linear(input, t_0_weight, None)  # t0: "cuda:0 f32[64, 64]"\n",
+       "    # t0 = ltorch.linear(input, t_0_weight, None)  # t0: "cuda:0 f32[64, 64]"\n",
+       "      # t0 = prims.linear(input, t_0_weight, None)  # t0: "cuda:0 f32[64, 64]"\n",
+       "  t1 = prims.tanh(t0)  # t1: "cuda:0 f32[64, 64]"\n",
+       "  t2 = ltorch.linear(t1, t_2_weight, None)  # t2: "cuda:0 f32[64, 64]"\n",
+       "    # t2 = ltorch.linear(t1, t_2_weight, None)  # t2: "cuda:0 f32[64, 64]"\n",
+       "      # t2 = prims.linear(t1, t_2_weight, None)  # t2: "cuda:0 f32[64, 64]"\n",
+       "  return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} input: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}0\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}2\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{n}{t0} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t0 = ltorch.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t0 = prims.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t1} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t1: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t2 = ltorch.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t2 = prims.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t2}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t2}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import thunder\n", + "import thunder.core.prims as prims\n", + "import thunder.torch as ltorch\n", + "import torch\n", + "import torch.nn.functional\n", + "from 
thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n", + " # input: \"cuda:0 f32[64, 64]\" \n", + " # t_0_weight: \"cuda:0 f32[64, 64]\" \n", + " # t_2_weight: \"cuda:0 f32[64, 64]\" \n", + " t0 = ltorch.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", + " # t0 = ltorch.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", + " # t0 = prims.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", + " t1 = prims.tanh(t0) # t1: \"cuda:0 f32[64, 64]\"\n", + " t2 = ltorch.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", + " # t2 = ltorch.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", + " # t2 = prims.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", + " return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "### DON'T TRY THIS AT HOME\n", + "computation_trace.bound_symbols[3].sym = cache_rec.computation_traces[0].bound_symbols[3].subsymbols[0].sym\n", + "if cache_rec.computation_traces[0].bound_symbols[4].subsymbols:\n", + " computation_trace.bound_symbols[4] = cache_rec.computation_traces[0].bound_symbols[4].subsymbols[0]\n", + "computation_trace.bound_symbols[5].sym = cache_rec.computation_traces[0].bound_symbols[5].subsymbols[0].sym\n", + "\n", + "wrap_as_highlighted_code(computation_trace)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# now we have a functional version of the model which\n", + "# takes as inputs the expected arguments and all the parameters.\n", + "functional_forward = computation_trace.python_callable()\n", + "\n", + "# This function creates a model with synchronization\n", + "# before calling the forward pass.\n", + "def model_with_syncs(x, *params):\n", + " # We call `prims.synchronize` on all the parameters.\n", + " # This is essentially calling `all_gather` so that we have the complete\n", + " # parameter before we actually to the forward computation.\n", + " unsharded_params = []\n", + " for param in params:\n", + " unsharded_params.append(thunder.distributed.prims.synchronize(param, process_group))\n", + "\n", + " return functional_forward(x, *unsharded_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us now see what the trace of our model looks like with all the synchronization.\n", + "\n", + "Two main observations regarding the below trace \n", + "1. We can observe the `prims.synchronize` that we inserted using `model_with_syncs`.\n", + "2. Output of the `prims.synchronize` have the shape of unsharded (original) parameter.\n", + "\n", + "With this, we have implemented the FSDP for the forward pass of our model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       "import thunder\n",
+       "import thunder.core.prims as prims\n",
+       "import thunder.distributed.prims\n",
+       "import thunder.torch as ltorch\n",
+       "import torch\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def model_with_syncs(x, *params):\n",
+       "  # x: "cuda:0 f32[64, 64]" \n",
+       "  # params: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  = params\n",
+       "  t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t2: "cuda:0 f32[64, 64]"\n",
+       "  t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  t4 = ltorch.linear(x, t2, None)  # t4: "cuda:0 f32[64, 64]"\n",
+       "    # t4 = prims.linear(x, t2, None)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = prims.tanh(t4)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  t6 = ltorch.linear(t5, t3, None)  # t6: "cuda:0 f32[64, 64]"\n",
+       "    # t6 = prims.linear(t5, t3, None)  # t6: "cuda:0 f32[64, 64]"\n",
+       "  return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t5, t3), ()))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{model\\PYZus{}with\\PYZus{}syncs}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{o}{*}\\PY{n}{params}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} x: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{c+c1}{\\PYZsh{} params: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{params}\n", + " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = prims.linear(x, t2, None) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t6 = prims.linear(t5, t3, None) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t6}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t6}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + 
"import thunder\n", + "import thunder.core.prims as prims\n", + "import thunder.distributed.prims\n", + "import thunder.torch as ltorch\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def model_with_syncs(x, *params):\n", + " # x: \"cuda:0 f32[64, 64]\" \n", + " # params: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " = params\n", + " t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0) # t2: \"cuda:0 f32[64, 64]\"\n", + " t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = ltorch.linear(x, t2, None) # t4: \"cuda:0 f32[64, 64]\"\n", + " # t4 = prims.linear(x, t2, None) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = prims.tanh(t4) # t5: \"cuda:0 f32[64, 64]\"\n", + " t6 = ltorch.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = prims.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", + " return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t5, t3), ()))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trace = thunder.trace()(model_with_syncs, x, *model.parameters())\n", + "\n", + "wrap_as_highlighted_code(trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For backward, we don't have to do anything because `thunder` already knows how to compute the backward of `prims.synchronize`. We can verify that by using the `value_and_grad` transform to generate the complete forward and backward trace together.\n", + "\n", + "Observations for the trace below:\n", + "1. `prims.synchronize` from previous trace is now decomposed into `prims.all_gather` and `prims.wait`. So, we can clearly see that we make a communication call to gather the parameter (which is asynchronous) and wait till we have the complete parameter.\n", + "2. At the end of the trace (after the forward and the backward computation), we see calls to `prims.reduce_scatter` and `prims.wait`. This takes care of reducing the gradients across all the GPUs and sharding them. One thing to note, for averaging gradients with low dynamic range dtype like `float16`, if we naively sum the gradients across GPUs before dividing by `world_size`, it can lead to overflows. So we scale the gradient tensor with `world_size`, before calling `reduce_scatter` with `sum` reduction to effectively average the gradients without overflow." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       "import thunder\n",
+       "import thunder.core.devices as devices\n",
+       "import thunder.core.dtypes as dtypes\n",
+       "import thunder.core.prims as prims\n",
+       "import thunder.distributed.prims\n",
+       "import thunder.torch as ltorch\n",
+       "import torch\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def _value_and_grad(*args):\n",
+       "  # args: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  t2, \\\n",
+       "  = args\n",
+       "  t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "  t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "  t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "  t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "  t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "  p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p11: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  t12 = thunder.distributed.prims.wait(p11)  # t12: "cuda:0 f32[64, 64]"\n",
+       "  p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p13: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  t14 = thunder.distributed.prims.wait(p13)  # t14: "cuda:0 f32[64, 64]"\n",
+       "  t15 = prims.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "  t16 = prims.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "  t17 = prims.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "  t18 = prims.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
+       "  t19 = prims.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
+       "  t20 = prims.add(t5, t10)  # t20: "cuda:0 f32[64, 64]"\n",
+       "  t21 = ltorch.reshape(t18, -1, 64)  # t21: "cuda:0 f32[64, 64]"\n",
+       "    # t21 = prims.reshape(t18, (64, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "  t22 = ltorch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "    # t22 = prims.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "  t23 = ltorch.reshape(t18, -1, 64)  # t23: "cuda:0 f32[64, 64]"\n",
+       "    # t23 = prims.reshape(t18, (64, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "  t24 = prims.transpose(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "  t25 = ltorch.reshape(t16, -1, 64)  # t25: "cuda:0 f32[64, 64]"\n",
+       "    # t25 = prims.reshape(t16, (64, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "  t26 = ltorch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "    # t26 = prims.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "  t27 = prims.add(t9, t22)  # t27: "cuda:0 f32[64, 64]"\n",
+       "  t28 = prims.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
+       "  t29 = ltorch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "    # t29 = prims.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "  t30 = ltorch.sub(1.0, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
+       "    # t30 = prims.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
+       "  t31 = ltorch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "    # t31 = prims.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "  t32 = ltorch.reshape(t31, -1, 64)  # t32: "cuda:0 f32[64, 64]"\n",
+       "    # t32 = prims.reshape(t31, (64, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "  t33 = ltorch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "    # t33 = prims.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "  t34 = ltorch.reshape(t31, -1, 64)  # t34: "cuda:0 f32[64, 64]"\n",
+       "    # t34 = prims.reshape(t31, (64, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "  t35 = prims.transpose(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "  t36 = ltorch.reshape(t0, -1, 64)  # t36: "cuda:0 f32[64, 64]"\n",
+       "    # t36 = prims.reshape(t0, (64, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "  t37 = ltorch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "    # t37 = prims.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "  t38 = prims.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
+       "  t39 = prims.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
+       "  t40 = ltorch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
+       "    # _ = prims.convert_element_type(2, float)\n",
+       "    # t40 = prims.div(t28, 2.0)  # t40: "cuda:0 f32[64, 64]"\n",
+       "  p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p41: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  t42 = thunder.distributed.prims.wait(p41)  # t42: "cuda:0 f32[32, 64]"\n",
+       "  t43 = ltorch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
+       "    # _ = prims.convert_element_type(2, float)\n",
+       "    # t43 = prims.div(t39, 2.0)  # t43: "cuda:0 f32[64, 64]"\n",
+       "  p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p44: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  t45 = thunder.distributed.prims.wait(p44)  # t45: "cuda:0 f32[32, 64]"\n",
+       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{devices} \\PY{k}{as} \\PY{n+nn}{devices}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{dtypes} \\PY{k}{as} \\PY{n+nn}{dtypes}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t7} \\PY{o}{=} 
\\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t17} \\PY{o}{=} 
\\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t10}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, (64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: 
\\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} 
p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import thunder\n", + "import thunder.core.devices as devices\n", + "import thunder.core.dtypes as dtypes\n", + "import thunder.core.prims as prims\n", + "import thunder.distributed.prims\n", + "import thunder.torch as ltorch\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def _value_and_grad(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 
64]\"\n", + " t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n", + " t12 = thunder.distributed.prims.wait(p11) # t12: \"cuda:0 f32[64, 64]\"\n", + " p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n", + " t14 = thunder.distributed.prims.wait(p13) # t14: \"cuda:0 f32[64, 64]\"\n", + " t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " t16 = prims.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", + " t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", + " t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", + " t20 = prims.add(t5, t10) # t20: \"cuda:0 f32[64, 64]\"\n", + " t21 = ltorch.reshape(t18, -1, 64) # t21: \"cuda:0 f32[64, 64]\"\n", + " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " t23 = ltorch.reshape(t18, -1, 64) # t23: \"cuda:0 f32[64, 64]\"\n", + " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " t25 = ltorch.reshape(t16, -1, 64) # t25: \"cuda:0 f32[64, 64]\"\n", + " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " t27 = prims.add(t9, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", + " t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " t30 = ltorch.sub(1.0, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", + " t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " t32 = ltorch.reshape(t31, -1, 64) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", + " t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " t34 = ltorch.reshape(t31, -1, 64) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " t36 = ltorch.reshape(t0, -1, 64) # t36: \"cuda:0 f32[64, 64]\"\n", + " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", + " t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", + " t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n", + " p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p41: 
\"FUTURE cuda:0 f32[32, 64]\"\n", + " t42 = thunder.distributed.prims.wait(p41) # t42: \"cuda:0 f32[32, 64]\"\n", + " t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n", + " p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n", + " t45 = thunder.distributed.prims.wait(p44) # t45: \"cuda:0 f32[32, 64]\"\n", + " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from thunder.core.transforms import value_and_grad\n", + "\n", + "forward_and_backward_model = value_and_grad(model_with_syncs)\n", + "\n", + "forward_backward_trace = thunder.trace()(forward_and_backward_model, x, *model.parameters())\n", + "\n", + "wrap_as_highlighted_code(forward_backward_trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above trace, only contains primitive which specifies the semantic of an operation abstractly but doesn't perform the actual computation.\n", + "\n", + "Now we will generate the execution trace which can actually perform the compute.\n", + "\n", + "In the execution trace generated below, we can see that all the primitives have been replaced with actually PyTorch operations. Also, our synchronization primitives have been replaced with PyTorch implementation provided by thunder i.e. `torch_all_gather_prim_impl`, `torch_reduce_scatter_prim_impl`, `torch_wait_prim_impl`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Delete Last Used (took 0 milliseconds)\n",
+       "import torch\n",
+       "import torch.nn.functional\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def _value_and_grad(*args):\n",
+       "  # args: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  t2, \\\n",
+       "  = args\n",
+       "  del args\n",
+       "  t3 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "    # t3 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "      # t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  t4 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "    # t4 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "      # t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "    # t5 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "      # t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  t6 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "    # t6 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "      # t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "  t7 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "    # t7 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "      # t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "  t8 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "    # t8 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "      # t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "  t9 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "    # t9 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "      # t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "  t10 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "    # t10 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "      # t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "  p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p11: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  del t1\n",
+       "  t12 = torch_wait_prim_impl(p11)  # t12: "cuda:0 f32[64, 64]"\n",
+       "  del p11\n",
+       "  p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p13: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  del t2\n",
+       "  t14 = torch_wait_prim_impl(p13)  # t14: "cuda:0 f32[64, 64]"\n",
+       "  del p13\n",
+       "  t15 = torch.nn.functional.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "    # t15 = ltorch.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "      # t15 = prims.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "  t16 = torch.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "    # t16 = ltorch.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "      # t16 = prims.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "  del t15\n",
+       "  t17 = torch.nn.functional.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "    # t17 = ltorch.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "      # t17 = prims.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "  t18 = torch.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
+       "    # t18 = ltorch.add(t6, t7, alpha=None)  # t18: "cuda:0 f32[64, 64]"\n",
+       "      # t18 = prims.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
+       "  del t6, t7\n",
+       "  t19 = torch.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
+       "    # t19 = ltorch.add(t3, t8, alpha=None)  # t19: "cuda:0 f32[64, 64]"\n",
+       "      # t19 = prims.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
+       "  del t3, t8\n",
+       "  t20 = torch.add(t5, t10)  # t20: "cuda:0 f32[64, 64]"\n",
+       "    # t20 = ltorch.add(t5, t10, alpha=None)  # t20: "cuda:0 f32[64, 64]"\n",
+       "      # t20 = prims.add(t5, t10)  # t20: "cuda:0 f32[64, 64]"\n",
+       "  del t5, t10\n",
+       "  t21 = torch.reshape(t18, (-1, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "    # t21 = ltorch.reshape(t18, (-1, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "      # t21 = prims.reshape(t18, (64, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "  t22 = torch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "    # t22 = ltorch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "      # t22 = prims.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "  del t21\n",
+       "  t23 = torch.reshape(t18, (-1, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "    # t23 = ltorch.reshape(t18, (-1, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "      # t23 = prims.reshape(t18, (64, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "  del t18\n",
+       "  t24 = torch.permute(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "    # t24 = ltorch.permute(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "      # t24 = prims.transpose(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "  del t23\n",
+       "  t25 = torch.reshape(t16, (-1, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "    # t25 = ltorch.reshape(t16, (-1, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "      # t25 = prims.reshape(t16, (64, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "  t26 = torch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "    # t26 = ltorch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "      # t26 = prims.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "  del t24, t25\n",
+       "  t27 = torch.add(t9, t22)  # t27: "cuda:0 f32[64, 64]"\n",
+       "    # t27 = ltorch.add(t9, t22, alpha=None)  # t27: "cuda:0 f32[64, 64]"\n",
+       "      # t27 = prims.add(t9, t22)  # t27: "cuda:0 f32[64, 64]"\n",
+       "  del t9, t22\n",
+       "  t28 = torch.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
+       "    # t28 = ltorch.add(t20, t26, alpha=None)  # t28: "cuda:0 f32[64, 64]"\n",
+       "      # t28 = prims.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
+       "  del t20, t26\n",
+       "  t29 = torch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "    # t29 = ltorch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "      # t29 = prims.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "  t30 = torch.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
+       "    # t30 = ltorch.sub(1.0, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
+       "      # t30 = prims.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
+       "  del t29\n",
+       "  t31 = torch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "    # t31 = ltorch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "      # t31 = prims.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "  del t27, t30\n",
+       "  t32 = torch.reshape(t31, (-1, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "    # t32 = ltorch.reshape(t31, (-1, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "      # t32 = prims.reshape(t31, (64, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "  t33 = torch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "    # t33 = ltorch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "      # t33 = prims.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "  del t32\n",
+       "  t34 = torch.reshape(t31, (-1, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "    # t34 = ltorch.reshape(t31, (-1, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "      # t34 = prims.reshape(t31, (64, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "  del t31\n",
+       "  t35 = torch.permute(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "    # t35 = ltorch.permute(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "      # t35 = prims.transpose(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "  del t34\n",
+       "  t36 = torch.reshape(t0, (-1, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "    # t36 = ltorch.reshape(t0, (-1, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "      # t36 = prims.reshape(t0, (64, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "  t37 = torch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "    # t37 = ltorch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "      # t37 = prims.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "  del t35, t36\n",
+       "  t38 = torch.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
+       "    # t38 = ltorch.add(t19, t33, alpha=None)  # t38: "cuda:0 f32[64, 64]"\n",
+       "      # t38 = prims.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
+       "  del t19, t33\n",
+       "  t39 = torch.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
+       "    # t39 = ltorch.add(t4, t37, alpha=None)  # t39: "cuda:0 f32[64, 64]"\n",
+       "      # t39 = prims.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
+       "  del t4, t37\n",
+       "  t40 = torch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
+       "    # t40 = ltorch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
+       "      # _ = prims.convert_element_type(2, float)\n",
+       "      # t40 = prims.div(t28, 2.0)  # t40: "cuda:0 f32[64, 64]"\n",
+       "  del t28\n",
+       "  p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p41: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  del t40\n",
+       "  t42 = torch_wait_prim_impl(p41)  # t42: "cuda:0 f32[32, 64]"\n",
+       "  del p41\n",
+       "  t43 = torch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
+       "    # t43 = ltorch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
+       "      # _ = prims.convert_element_type(2, float)\n",
+       "      # t43 = prims.div(t39, 2.0)  # t43: "cuda:0 f32[64, 64]"\n",
+       "  del t39\n",
+       "  p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p44: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  del t43\n",
+       "  t45 = torch_wait_prim_impl(p44)  # t45: "cuda:0 f32[32, 64]"\n",
+       "  del p44\n",
+       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Delete Last Used (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{k}{del} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t6} \\PY{o}{=} 
\\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t6 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t6 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t7 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t7 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t8 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t8 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t9 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t9 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} 
\\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t10 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t10 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t1}\n", + " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p11}\n", + " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t2}\n", + " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p13}\n", + " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t15 = ltorch.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t15 = prims.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t16 = ltorch.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t16 = prims.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t15}\n", + " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t17 = ltorch.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t17 = prims.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t18 = ltorch.add(t6, t7, alpha=None) \\PYZsh{} t18: \\PYZdq{}cuda:0 
f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t18 = prims.add(t6, t7) \\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\n", + " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t19 = ltorch.add(t3, t8, alpha=None) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t19 = prims.add(t3, t8) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\n", + " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t10}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t20 = ltorch.add(t5, t10, alpha=None) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t20 = prims.add(t5, t10) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t10}\n", + " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t21 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t22 = ltorch.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t21}\n", + " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t23 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t18}\n", + " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t24 = ltorch.permute(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t24 = prims.transpose(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t23}\n", + " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t25 = ltorch.reshape(t16, (\\PYZhy{}1, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, 
(64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t26 = ltorch.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\n", + " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t27 = ltorch.add(t9, t22, alpha=None) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t27 = prims.add(t9, t22) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t9}\\PY{p}{,} \\PY{n}{t22}\n", + " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t28 = ltorch.add(t20, t26, alpha=None) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t28 = prims.add(t20, t26) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\n", + " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t29 = ltorch.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t30 = ltorch.sub(1.0, t29, alpha=None) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t29}\n", + " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t31 = ltorch.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\n", + " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t32 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t33 = ltorch.matmul(t32, t12) 
\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t32}\n", + " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t34 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t31}\n", + " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t35 = ltorch.permute(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t35 = prims.transpose(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t34}\n", + " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t36 = ltorch.reshape(t0, (\\PYZhy{}1, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t37 = ltorch.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\n", + " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t38 = ltorch.add(t19, t33, alpha=None) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t38 = prims.add(t19, t33) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\n", + " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t39 = ltorch.add(t4, t37, alpha=None) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t39 = prims.add(t4, t37) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\n", + " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t40 = ltorch.true\\PYZus{}divide(t28, 2) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = 
prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t28}\n", + " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t40}\n", + " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p41}\n", + " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t43 = ltorch.true\\PYZus{}divide(t39, 2) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t39}\n", + " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t43}\n", + " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p44}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def _value_and_grad(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " del args\n", + " t3 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = prims.full((64, 64), 1, 
device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " # t4 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " # t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " t6 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " t7 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " # t7 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " # t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " t8 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " # t8 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " # t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " t9 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n", + " # t9 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n", + " # t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 64]\"\n", + " t10 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " # t10 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " # t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n", + " del t1\n", + " t12 = torch_wait_prim_impl(p11) # t12: \"cuda:0 f32[64, 64]\"\n", + " del p11\n", + " p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n", + " del t2\n", + " t14 = torch_wait_prim_impl(p13) # t14: \"cuda:0 f32[64, 64]\"\n", + " del p13\n", + " t15 = torch.nn.functional.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " # t15 = ltorch.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " # t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " t16 = torch.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", + " # t16 = ltorch.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", + " # t16 = prims.tanh(t15) # t16: 
\"cuda:0 f32[64, 64]\"\n", + " del t15\n", + " t17 = torch.nn.functional.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " # t17 = ltorch.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " # t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " t18 = torch.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", + " # t18 = ltorch.add(t6, t7, alpha=None) # t18: \"cuda:0 f32[64, 64]\"\n", + " # t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", + " del t6, t7\n", + " t19 = torch.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", + " # t19 = ltorch.add(t3, t8, alpha=None) # t19: \"cuda:0 f32[64, 64]\"\n", + " # t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", + " del t3, t8\n", + " t20 = torch.add(t5, t10) # t20: \"cuda:0 f32[64, 64]\"\n", + " # t20 = ltorch.add(t5, t10, alpha=None) # t20: \"cuda:0 f32[64, 64]\"\n", + " # t20 = prims.add(t5, t10) # t20: \"cuda:0 f32[64, 64]\"\n", + " del t5, t10\n", + " t21 = torch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " # t21 = ltorch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " t22 = torch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " # t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " del t21\n", + " t23 = torch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " # t23 = ltorch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " del t18\n", + " t24 = torch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " # t24 = ltorch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " # t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " del t23\n", + " t25 = torch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " # t25 = ltorch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " t26 = torch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " # t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " del t24, t25\n", + " t27 = torch.add(t9, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " # t27 = ltorch.add(t9, t22, alpha=None) # t27: \"cuda:0 f32[64, 64]\"\n", + " # t27 = prims.add(t9, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " del t9, t22\n", + " t28 = torch.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", + " # t28 = ltorch.add(t20, t26, alpha=None) # t28: \"cuda:0 f32[64, 64]\"\n", + " # t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", + " del t20, t26\n", + " t29 = torch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " # t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " t30 = torch.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = ltorch.sub(1.0, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", + " del t29\n", + " t31 = torch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " del t27, t30\n", + " t32 = torch.reshape(t31, (-1, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = ltorch.reshape(t31, (-1, 64)) # t32: \"cuda:0 
f32[64, 64]\"\n", + " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", + " t33 = torch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " del t32\n", + " t34 = torch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = ltorch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " del t31\n", + " t35 = torch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " # t35 = ltorch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " # t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " del t34\n", + " t36 = torch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " # t36 = ltorch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " t37 = torch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " # t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " del t35, t36\n", + " t38 = torch.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", + " # t38 = ltorch.add(t19, t33, alpha=None) # t38: \"cuda:0 f32[64, 64]\"\n", + " # t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", + " del t19, t33\n", + " t39 = torch.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", + " # t39 = ltorch.add(t4, t37, alpha=None) # t39: \"cuda:0 f32[64, 64]\"\n", + " # t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", + " del t4, t37\n", + " t40 = torch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", + " # t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n", + " del t28\n", + " p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p41: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t40\n", + " t42 = torch_wait_prim_impl(p41) # t42: \"cuda:0 f32[32, 64]\"\n", + " del p41\n", + " t43 = torch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", + " # t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n", + " del t39\n", + " p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t43\n", + " t45 = torch_wait_prim_impl(p44) # t45: \"cuda:0 f32[32, 64]\"\n", + " del p44\n", + " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimized_trace = thunder.transform_for_execution(forward_backward_trace, executors_list=thunder.get_always_executors())\n", + "\n", + "# Grab the final trace\n", + "exec_trace = optimized_trace[-1]\n", + "wrap_as_highlighted_code(exec_trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 4 : Running the actual computation\n", + "\n", + "Running the actual computation will require setting up 2 processes and running our above code in both those processes (which can be tricky with 
Jupyter Notebook). Instead, we will write a small script and run it with `torchrun`, which takes care of setting up the processes and the relevant state.\n", + "\n", + "**NOTE**: This requires the device running this notebook to have at least 2 GPUs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the example below, we will use `thunder.distributed.fsdp`, which does the same as what we did above (with some extra checks). The code below should look familiar, as it is roughly all of the above pieces in a single script." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting thunder_fsdp_simple_example.py\n" + ] + } + ], + "source": [ + "%%writefile thunder_fsdp_simple_example.py\n", + "\n", + "# imports\n", + "from thunder.tests.lit_gpt_model import GPT, Config\n", + "import torch\n", + "import torch.distributed\n", + "import thunder\n", + "import thunder.distributed\n", + "import os\n", + "\n", + "# # # # # # # #\n", + "# Create Model\n", + "# # # # # # # #\n", + "\n", + "# NOTE: We create the model on CPU.\n", + "device='cpu'\n", + "dim = 64\n", + "def create_model():\n", + " layers = []\n", + " layers.append(torch.nn.Linear(dim, dim))\n", + " layers.append(torch.nn.ReLU())\n", + " layers.append(torch.nn.Linear(dim, dim))\n", + " return torch.nn.Sequential(*layers).to(device)\n", + "\n", + "# Model\n", + "model = create_model()\n", + "# Input\n", + "x = torch.randn(dim, dim, device=device)\n", + "\n", + "# # # # # # # #\n", + "# Setup for distributed\n", + "# # # # # # # #\n", + "torch.distributed.init_process_group(backend='nccl')\n", + "\n", + "rank = int(os.environ[\"LOCAL_RANK\"])\n", + "\n", + "device = f\"cuda:{rank}\"\n", + "\n", + "# # # # # # # #\n", + "# Move inputs to correct device\n", + "# # # # # # # #\n", + "x = x.to(device)\n", + "\n", + "# # # # # # # #\n", + "# Wrap the model in thunder.distributed.fsdp\n", + "# # # # # # # #\n", + "\n", + "# thunder.distributed.fsdp takes care of moving the parameter\n", + "# shard to the correct GPU for the current process.\n", + "cmodel = thunder.jit(thunder.distributed.fsdp(model))\n", + "\n", + "# Run the forward pass.\n", + "cmodel(x)\n", + "\n", + "# # # # # # # #\n", + "# Check the traces\n", + "# # # # # # # #\n", + "fwd_traces, bwd_traces = thunder.last_traces(cmodel)\n", + "\n", + "# # # # # # # #\n", + "# Print and check to see if they match ours\n", + "# # # # # # # #\n", + "if rank == 0:\n", + " print(fwd_traces[-1])\n", + " print(\"*******\" * 8)\n", + " print(bwd_traces[-1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us run the above script and check what the traces look like.\n", + "\n", + "We can observe that the forward trace has `torch_all_gather_prim_impl` to gather the parameters before the forward pass, and the backward trace has `torch_reduce_scatter_prim_impl` to reduce and scatter the gradients back across the different GPUs. This is similar to our implementation above.\n",
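+ "\n", + "As a quick sanity check, you could also print the parameter shapes on each rank right after `cmodel(x)` in the script above. This snippet is illustrative and not part of the original script; it assumes, as our manual implementation earlier in this tutorial did, that `thunder.distributed.fsdp` shards each parameter along its 0th dimension:\n", + "\n", + "```python\n", + "# Illustrative only: each rank should hold a 1/world_size shard of every parameter.\n", + "world_size = torch.distributed.get_world_size()\n", + "for name, param in model.named_parameters():\n", + "    print(f\"rank {rank}/{world_size}: {name} shard shape = {tuple(param.shape)}\")\n", + "```"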
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] \n", + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n", + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(input, t_0_bias, t_2_bias, t_0_weight, t_2_weight):\n", + " # input: \"cuda:0 f32[64, 64]\" \n", + " # t_0_bias: \"cuda:0 f32[32]\" \n", + " p0 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p0: \"FUTURE cuda:0 f32[64]\"\n", + " # t_2_bias: \"cuda:0 f32[32]\" \n", + " p2 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p2: \"FUTURE cuda:0 f32[64]\"\n", + " # t_0_weight: \"cuda:0 f32[32, 64]\" \n", + " p4 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p4: \"FUTURE cuda:0 f32[64, 64]\"\n", + " # t_2_weight: \"cuda:0 f32[32, 64]\" \n", + " p9 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p9: \"FUTURE cuda:0 f32[64, 64]\"\n", + " t1 = torch_wait_prim_impl(p0) # t1: \"cuda:0 f32[64]\"\n", + " del p0\n", + " t3 = torch_wait_prim_impl(p2) # t3: \"cuda:0 f32[64]\"\n", + " del p2\n", + " t5 = torch_wait_prim_impl(p4) # t5: \"cuda:0 f32[64, 64]\"\n", + " del p4\n", + " t6 = torch.nn.functional.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = ltorch.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = prims.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", + " del t5, t1\n", + " [t7, t8] = nvFusion0(t6)\n", + " # t7 = prims.gt(t6, 0.0) # t7: \"cuda:0 b8[64, 64]\"\n", + " # t8 = prims.where(t7, t6, 0.0) # t8: \"cuda:0 f32[64, 64]\"\n", + " del t6\n", + " t10 = torch_wait_prim_impl(p9) # t10: \"cuda:0 f32[64, 64]\"\n", + " del p9\n", + " t11 = torch.nn.functional.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", + " # t11 = ltorch.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", + " # t11 = prims.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", + " del t3\n", + " return {'output': t11, 'flat_args': [input, t_0_bias, t_2_bias, t_0_weight, t_2_weight], 'flat_output': (t11,)}, ((input, t10, t7, t8), ())\n", + "********************************************************\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def backward_fn(saved_for_backward, cotangents):\n", + " # saved_for_backward: \"Collection\" \n", + " # cotangents: \"Collection\" \n", + " C0, \\\n", + " _, \\\n", + " = saved_for_backward\n", + " clear_collection(saved_for_backward)\n", + " del saved_for_backward\n", + " t0, \\\n", + " = 
cotangents\n", + " clear_collection(cotangents)\n", + " del cotangents\n", + " input, \\\n", + " t10, \\\n", + " t7, \\\n", + " t8, \\\n", + " = C0\n", + " clear_collection(C0)\n", + " del C0\n", + " t31 = torch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = ltorch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = prims.reshape(t0, (64, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", + " t32 = torch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = ltorch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = prims.transpose(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", + " t33 = torch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = ltorch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = prims.reshape(t8, (64, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", + " del t8\n", + " t45 = torch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", + " # t45 = ltorch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", + " # t45 = prims.reshape(input, (64, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", + " del input\n", + " [t51] = nvFusion0(t0)\n", + " # t35 = prims.sum(t0, (0,)) # t35: \"cuda:0 f32[64]\"\n", + " # t51 = prims.div(t35, 2.0) # t51: \"cuda:0 f32[64]\"\n", + " del t0\n", + " p52 = torch_reduce_scatter_prim_impl(t51, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p52: \"FUTURE cuda:0 f32[32]\"\n", + " del t51\n", + " t30 = torch.matmul(t31, t10) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = ltorch.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = prims.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n", + " del t31, t10\n", + " t34 = torch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = ltorch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = prims.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", + " del t32, t33\n", + " [t36, t39, t54] = nvFusion1(t30, t34, t7)\n", + " # t39 = prims.where(t7, t30, 0.0) # t39: \"cuda:0 f32[64, 64]\"\n", + " # t47 = prims.sum(t39, (0,)) # t47: \"cuda:0 f32[64]\"\n", + " # t54 = prims.div(t47, 2.0) # t54: \"cuda:0 f32[64]\"\n", + " # t36 = prims.div(t34, 2.0) # t36: \"cuda:0 f32[64, 64]\"\n", + " del t30, t34, t7\n", + " p37 = torch_reduce_scatter_prim_impl(t36, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p37: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t36\n", + " p55 = torch_reduce_scatter_prim_impl(t54, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p55: \"FUTURE cuda:0 f32[32]\"\n", + " del t54\n", + " t43 = torch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", + " # t43 = ltorch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", + " # t43 = prims.reshape(t39, (64, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", + " del t39\n", + " t44 = torch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", + " # t44 = ltorch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", + " # t44 = prims.transpose(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", + " del t43\n", + " t46 = torch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", + " # t46 = ltorch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", + " # t46 = prims.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", + " del t44, t45\n", + " [t48] = nvFusion2(t46)\n", + " # t48 = prims.div(t46, 2.0) # t48: \"cuda:0 f32[64, 64]\"\n", + " del t46\n", + " p49 = torch_reduce_scatter_prim_impl(t48, _DistributedReduceOps_0, 
_torch_distributed_distributed_c10d_ProcessGroup_1, True) # p49: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t48\n", + " t53 = torch_wait_prim_impl(p52) # t53: \"cuda:0 f32[32]\"\n", + " del p52\n", + " t38 = torch_wait_prim_impl(p37) # t38: \"cuda:0 f32[32, 64]\"\n", + " del p37\n", + " t56 = torch_wait_prim_impl(p55) # t56: \"cuda:0 f32[32]\"\n", + " del p55\n", + " t50 = torch_wait_prim_impl(p49) # t50: \"cuda:0 f32[32, 64]\"\n", + " del p49\n", + " return (None, t56, t53, t50, t38)\n" + ] + } + ], + "source": [ + "!torchrun --nproc_per_node=2 thunder_fsdp_simple_example.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Conclusion\n", + "\n", + "We have created our own implementation of FSDP to shard our model across multiple GPUs. In the process, we also learned that:\n", + "\n", + "1. `thunder` provides us with primitives for synchronization across multiple GPUs.\n", + "2. `thunder` also takes care of implementing the backward support for the synchronization primitives, so we don't have to explicitly do anything to get the backward working.\n", + "3. We can simply apply `thunder.distributed.fsdp` to our model, and it will take care of sharding the parameters and adding the required synchronizations. We can also verify these modifications by inspecting the traces." + ] + } ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/fsdp_tutorial.ipynb b/notebooks/fsdp_tutorial.ipynb deleted file mode 100644 index 71ed1b1005..0000000000 --- a/notebooks/fsdp_tutorial.ipynb +++ /dev/null @@ -1,1489 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## FSDP Tutorial\n", - "\n", - "In this tutorial, we will walk through the implementation of Fully Sharded Data Parallel (FSDP) with Zero2 sharding strategy in `thunder`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Introduction\n", - "\n", - "In recent times, the LLM models have grown so large that all the model parameters don't fit on a single GPU. To circumvent this problem, there are various strategies like Tensor Parallel, Pipeline Parallel, Fully Sharded Data Parallel, etc to train these large models. In this tutorial, we discuss and implement Zero2 strategy for Fully Sharded Data Parallel (FSDP).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### What is Zero2 strategy for FSDP?\n", - "\n", - "In this strategy, we shard the model parameters across all the availabe GPUs. That is each GPU holds onto only a chunk of the parameter. During the forward pass, all GPUs call `all_gather` communication primitive to gather the parameters from other GPUs. Unlike Zero3 strategy which frees the parameter after forward pass, we save these unsharded parameters for backward pass. This is to save the overhead of extra communication. In the backward pass, we utilize the saved parameters and compute the gradients. 
Once the gradients are computed, we use `reduce_scatter` communication primitive to reduce (average) the gradients across all GPUs and scatter those gradients so that a given GPU holds only a chunk of gradient.\n", - "\n", - "For more information on FSDP, we recommend reading\n", - "\n", - "1. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel - [Link](https://arxiv.org/abs/2304.11277)\n", - "2. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models - [Link](https://arxiv.org/abs/1910.02054)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Example Model\n", - "\n", - "For this example we will have a simple model `Linear(Tanh(Linear(x)))` which will be sharded over 2 GPUs\n", - "\n", - "**NOTE**: We are generating the abstract trace so we don't actually need a system with 2 GPUs for this. It is only required when we execute this trace." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.distributed\n", - "import thunder\n", - "import thunder.distributed\n", - "from IPython.display import Code" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "device='cuda'\n", - "dim = 64\n", - "def create_model():\n", - " layers = [torch.nn.Linear(dim, dim, bias=False),\n", - " torch.nn.Tanh(),\n", - " torch.nn.Linear(dim, dim, bias=False)]\n", - " return torch.nn.Sequential(*layers).to(device)\n", - "\n", - "# Model\n", - "model = create_model()\n", - "# Input\n", - "x = torch.randn(dim, dim, device=device)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def wrap_as_highlighted_code(trace):\n", - " return Code(str(trace), language=\"python\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Step 1 : Configuration\n", - "\n", - "For our implementation of FSDP, we will generate the trace where we are sharding our model over 2 GPU" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# FSDP Config \n", - "# Usually these values are set in the environment by `torchrun` but for this example\n", - "# we will set them ourselves\n", - "world_size = 2 # We have two processes.\n", - "global_rank = 0 # Current process is the very first process." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Step 2: Function to shard parameters\n", - "\n", - "Next step is to write a function which will actually shard the parameters over 0-dim." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "# NOTE: We shard over 0th dimension of the param.\n", - "def shard_param(param: torch.Tensor, rank: int, world_size: int, name: str) -> None:\n", - " # We will keep it simple and error if param's 0th dim is not divisible by ``world_size``.\n", - " # Alternative is that we can pad our parameters so that they are divisible by `world_size`.\n", - " assert param.shape[0] % world_size == 0,(\n", - " f\"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})\"\n", - " f\" to be divisible by the world size ({world_size})\"\n", - " )\n", - " chunk_size = param.shape[0] // world_size\n", - "\n", - " # rank helps us determine which chunk of the parameter we will hold.\n", - " shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()\n", - " param.data = shard\n", - "\n", - "# Shard each parameter of the model\n", - "for param_name, param in model.named_parameters():\n", - " shard_param(param, global_rank, world_size, param_name)\n", - " # Mark the param to denote that it is sharded.\n", - " # This is required by the synchronization primitive we will use below.\n", - " param.ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Sequential(\n", - " (0): Linear(in_features=64, out_features=64, bias=False)\n", - " (1): Tanh()\n", - " (2): Linear(in_features=64, out_features=64, bias=False)\n", - ")" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Verify our model looks as expected\n", - "model" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# Let us verify that we have actually sharded the parameters.\n", - "# Checking if the weight of 1st Linear layer is sharded over 0th dim.\n", - "assert model[0].weight.shape == (dim / world_size, dim)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Step 3: Add an operation to synchronize the parameters before calling the model.forward." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We have to create a process group. This is needed because the synchronization primitive `synchronize` that we will use to gather and scatter our weights in forward and backward requires a process group." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# Create a process group\n", - "options = torch.distributed.distributed_c10d.ProcessGroup.Options(backend=\"nccl\")\n", - "process_group = torch.distributed.distributed_c10d.ProcessGroup(torch.distributed.distributed_c10d.Store(),\n", - " global_rank, world_size, options)\n", - "torch.distributed.distributed_c10d.GroupMember.WORLD = process_group" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# `preprocess` gives us the functional version of the model which\n", - "# takes as inputs all the parameters and the expected arguments.\n", - "# NOTE: `thunder.common.preprocess` is not meant for general use\n", - "# and used only for brevity of code. It will be updated\n", - "# to a newer mechanism which is meant to be public facing. 
\n", - "functional_forward = thunder.common.preprocess(model, is_module=True)\n", - "\n", - "# This function creates a model with synchronization\n", - "# before calling the forward pass.\n", - "def model_with_syncs(*params, x):\n", - " # We call `prims.synchronize` on all the parameters.\n", - " # This is essentially calling `all_gather` so that we have the complete\n", - " # parameter before we actually to the forward computation.\n", - " unsharded_params = []\n", - " for param in params:\n", - " unsharded_params.append(thunder.distributed.prims.synchronize(param, process_group))\n", - "\n", - " return functional_forward(*unsharded_params, x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let us now see what the trace of our model looks like with all the synchronization.\n", - "\n", - "Two main observations regarding the below trace \n", - "1. We can observe the `prims.synchronize` that we inserted using `model_with_syncs`.\n", - "2. Output of the `prims.synchronize` have the shape of unsharded (original) parameter.\n", - "\n", - "With this, we have implemented the FSDP for the forward pass of our model." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
-       "import thunder\n",
-       "import thunder.distributed.prims\n",
-       "import thunder.torch as ltorch\n",
-       "import torch\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def model_with_syncs(*params, x):\n",
-       "  # params \n",
-       "  # x \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  = params\n",
-       "  t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t2\n",
-       "  t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t3\n",
-       "  t4 = ltorch.linear(x, t2, None)  # t4\n",
-       "    # t4 = prims.linear(x, t2, None)  # t4\n",
-       "  t5 = ltorch.tanh(t4)  # t5\n",
-       "    # t5 = prims.tanh(t4)  # t5\n",
-       "  t6 = ltorch.linear(t5, t3, None)  # t6\n",
-       "    # t6 = prims.linear(t5, t3, None)  # t6\n",
-       "  return t6\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{model\\PYZus{}with\\PYZus{}syncs}\\PY{p}{(}\\PY{o}{*}\\PY{n}{params}\\PY{p}{,} \\PY{n}{x}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} params }\n", - " \\PY{c+c1}{\\PYZsh{} x }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{params}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3}\n", - " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4}\n", - " \\PY{c+c1}{\\PYZsh{} t4 = prims.linear(x, t2, None) \\PYZsh{} t4}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = prims.tanh(t4) \\PYZsh{} t5}\n", - " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6}\n", - " \\PY{c+c1}{\\PYZsh{} t6 = prims.linear(t5, t3, None) \\PYZsh{} t6}\n", - " \\PY{k}{return} \\PY{n}{t6}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", - "import thunder\n", - "import thunder.distributed.prims\n", - "import thunder.torch as ltorch\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def model_with_syncs(*params, x):\n", - " # params \n", - " # x \n", - " t0, \\\n", - " t1, \\\n", - " = params\n", - " t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0) # t2\n", - " t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0) # t3\n", - " t4 = ltorch.linear(x, t2, None) # t4\n", - " # t4 = prims.linear(x, t2, None) # t4\n", - " t5 = ltorch.tanh(t4) # t5\n", - " # t5 = prims.tanh(t4) # t5\n", - " t6 = ltorch.linear(t5, t3, None) # t6\n", - " # t6 = prims.linear(t5, t3, None) # t6\n", - " return t6" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "trace = thunder.trace()(model_with_syncs, 
*model.parameters(), x=x)\n", - "\n", - "wrap_as_highlighted_code(trace)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For backward, we don't have to do anything because `thunder` already knows how to compute the backward of `prims.synchronize`. We can verify that by using the `value_and_grad` transform to generate the complete forward and backward trace together.\n", - "\n", - "Observations for the trace below:\n", - "1. `prims.synchronize` from previous trace is now decomposed into `prims.all_gather` and `prims.wait`. So, we can clearly see that we make a communication call to gather the parameter (which is asynchronous) and wait till we have the complete parameter.\n", - "2. At the end of the trace (after the forward and the backward computation), we see calls to `prims.reduce_scatter` and `prims.wait`. This takes care of reducing the gradients across all the GPUs and sharding them. One thing to note, for averaging gradients with low dynamic range dtype like `float16`, if we naively sum the gradients across GPUs before dividing by `world_size`, it can lead to overflows. So we scale the gradient tensor with `world_size`, before calling `reduce_scatter` with `sum` reduction to effectively average the gradients without overflow." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
-       "import thunder\n",
-       "import thunder.core.devices as devices\n",
-       "import thunder.core.dtypes as dtypes\n",
-       "import thunder.core.prims as prims\n",
-       "import thunder.distributed.prims\n",
-       "import thunder.torch as ltorch\n",
-       "import torch\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def _value_and_grad(*args, **kwargs):\n",
-       "  # args \n",
-       "  # kwargs \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  = args\n",
-       "  t2 = kwargs['x']\n",
-       "  t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3\n",
-       "  p4 = thunder.distributed.prims.all_gather(t0, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p4\n",
-       "  t5 = thunder.distributed.prims.wait(p4)  # t5\n",
-       "  p6 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p6\n",
-       "  t7 = thunder.distributed.prims.wait(p6)  # t7\n",
-       "  t8 = prims.linear(t2, t5, None)  # t8\n",
-       "  t9 = prims.tanh(t8)  # t9\n",
-       "  t10 = prims.linear(t9, t7, None)  # t10\n",
-       "  t11 = ltorch.reshape(t3, -1, 64)  # t11\n",
-       "    # t11 = prims.reshape(t3, (64, 64))  # t11\n",
-       "  t12 = ltorch.matmul(t11, t7)  # t12\n",
-       "    # t12 = prims.matmul(t11, t7)  # t12\n",
-       "  t13 = ltorch.reshape(t3, -1, 64)  # t13\n",
-       "    # t13 = prims.reshape(t3, (64, 64))  # t13\n",
-       "  t14 = prims.transpose(t13, (1, 0))  # t14\n",
-       "  t15 = ltorch.reshape(t9, -1, 64)  # t15\n",
-       "    # t15 = prims.reshape(t9, (64, 64))  # t15\n",
-       "  t16 = ltorch.matmul(t14, t15)  # t16\n",
-       "    # t16 = prims.matmul(t14, t15)  # t16\n",
-       "  t17 = ltorch.mul(t9, t9)  # t17\n",
-       "    # t17 = prims.mul(t9, t9)  # t17\n",
-       "  t18 = ltorch.sub(1.0, t17, alpha=None)  # t18\n",
-       "    # t18 = prims.sub(1.0, t17)  # t18\n",
-       "  t19 = ltorch.mul(t12, t18)  # t19\n",
-       "    # t19 = prims.mul(t12, t18)  # t19\n",
-       "  t20 = ltorch.reshape(t19, -1, 64)  # t20\n",
-       "    # t20 = prims.reshape(t19, (64, 64))  # t20\n",
-       "  t21 = ltorch.matmul(t20, t5)  # t21\n",
-       "    # t21 = prims.matmul(t20, t5)  # t21\n",
-       "  t22 = ltorch.reshape(t19, -1, 64)  # t22\n",
-       "    # t22 = prims.reshape(t19, (64, 64))  # t22\n",
-       "  t23 = prims.transpose(t22, (1, 0))  # t23\n",
-       "  t24 = ltorch.reshape(t2, -1, 64)  # t24\n",
-       "    # t24 = prims.reshape(t2, (64, 64))  # t24\n",
-       "  t25 = ltorch.matmul(t23, t24)  # t25\n",
-       "    # t25 = prims.matmul(t23, t24)  # t25\n",
-       "  t26 = ltorch.true_divide(t16, 2)  # t26\n",
-       "    # _ = prims.convert_element_type(2, float)\n",
-       "    # t26 = prims.div(t16, 2.0)  # t26\n",
-       "  p27 = thunder.distributed.prims.reduce_scatter(t26, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p27\n",
-       "  t28 = thunder.distributed.prims.wait(p27)  # t28\n",
-       "  t29 = ltorch.true_divide(t25, 2)  # t29\n",
-       "    # _ = prims.convert_element_type(2, float)\n",
-       "    # t29 = prims.div(t25, 2.0)  # t29\n",
-       "  p30 = thunder.distributed.prims.reduce_scatter(t29, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p30\n",
-       "  t31 = thunder.distributed.prims.wait(p30)  # t31\n",
-       "  return (t10, (t31, t28, {'x': t21}))\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{devices} \\PY{k}{as} \\PY{n+nn}{devices}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{dtypes} \\PY{k}{as} \\PY{n+nn}{dtypes}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{,} \\PY{o}{*}\\PY{o}{*}\\PY{n}{kwargs}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} args }\n", - " \\PY{c+c1}{\\PYZsh{} kwargs }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{args}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{kwargs}\\PY{p}{[}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{x}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{]}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3}\n", - " \\PY{n}{p4} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p4}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5}\n", - " \\PY{n}{p6} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p6}\n", - " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p6}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7}\n", - " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8}\n", - " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9}\n", - " \\PY{n}{t10} \\PY{o}{=} 
\\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10}\n", - " \\PY{n}{t11} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t11}\n", - " \\PY{c+c1}{\\PYZsh{} t11 = prims.reshape(t3, (64, 64)) \\PYZsh{} t11}\n", - " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t11}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12}\n", - " \\PY{c+c1}{\\PYZsh{} t12 = prims.matmul(t11, t7) \\PYZsh{} t12}\n", - " \\PY{n}{t13} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t13}\n", - " \\PY{c+c1}{\\PYZsh{} t13 = prims.reshape(t3, (64, 64)) \\PYZsh{} t13}\n", - " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t13}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14}\n", - " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15}\n", - " \\PY{c+c1}{\\PYZsh{} t15 = prims.reshape(t9, (64, 64)) \\PYZsh{} t15}\n", - " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t14}\\PY{p}{,} \\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16}\n", - " \\PY{c+c1}{\\PYZsh{} t16 = prims.matmul(t14, t15) \\PYZsh{} t16}\n", - " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17}\n", - " \\PY{c+c1}{\\PYZsh{} t17 = prims.mul(t9, t9) \\PYZsh{} t17}\n", - " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{,} \\PY{n}{t17}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18}\n", - " \\PY{c+c1}{\\PYZsh{} t18 = prims.sub(1.0, t17) \\PYZsh{} t18}\n", - " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t12}\\PY{p}{,} \\PY{n}{t18}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19}\n", - " \\PY{c+c1}{\\PYZsh{} t19 = prims.mul(t12, t18) \\PYZsh{} t19}\n", - " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = prims.reshape(t19, (64, 64)) \\PYZsh{} t20}\n", - " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21}\n", - " \\PY{c+c1}{\\PYZsh{} t21 = prims.matmul(t20, t5) \\PYZsh{} t21}\n", - " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22}\n", - " \\PY{c+c1}{\\PYZsh{} t22 = prims.reshape(t19, (64, 64)) \\PYZsh{} t22}\n", - " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t22}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23}\n", - " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} 
t24}\n", - " \\PY{c+c1}{\\PYZsh{} t24 = prims.reshape(t2, (64, 64)) \\PYZsh{} t24}\n", - " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{n}{t24}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25}\n", - " \\PY{c+c1}{\\PYZsh{} t25 = prims.matmul(t23, t24) \\PYZsh{} t25}\n", - " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t26 = prims.div(t16, 2.0) \\PYZsh{} t26}\n", - " \\PY{n}{p27} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t26}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p27}\n", - " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p27}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28}\n", - " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t25}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t29 = prims.div(t25, 2.0) \\PYZsh{} t29}\n", - " \\PY{n}{p30} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t29}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p30}\n", - " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{n}{t28}\\PY{p}{,} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{x}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t21}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", - "import thunder\n", - "import thunder.core.devices as devices\n", - "import thunder.core.dtypes as dtypes\n", - "import thunder.core.prims as prims\n", - "import thunder.distributed.prims\n", - "import thunder.torch as ltorch\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def _value_and_grad(*args, **kwargs):\n", - " # args \n", - " # kwargs \n", - " t0, \\\n", - " t1, \\\n", - " = args\n", - " t2 = kwargs['x']\n", - " t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3\n", - " p4 = thunder.distributed.prims.all_gather(t0, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p4\n", - " t5 = thunder.distributed.prims.wait(p4) # t5\n", - " p6 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p6\n", - " t7 = thunder.distributed.prims.wait(p6) # t7\n", - " t8 = prims.linear(t2, t5, None) # t8\n", - " t9 = prims.tanh(t8) # 
t9\n", - " t10 = prims.linear(t9, t7, None) # t10\n", - " t11 = ltorch.reshape(t3, -1, 64) # t11\n", - " # t11 = prims.reshape(t3, (64, 64)) # t11\n", - " t12 = ltorch.matmul(t11, t7) # t12\n", - " # t12 = prims.matmul(t11, t7) # t12\n", - " t13 = ltorch.reshape(t3, -1, 64) # t13\n", - " # t13 = prims.reshape(t3, (64, 64)) # t13\n", - " t14 = prims.transpose(t13, (1, 0)) # t14\n", - " t15 = ltorch.reshape(t9, -1, 64) # t15\n", - " # t15 = prims.reshape(t9, (64, 64)) # t15\n", - " t16 = ltorch.matmul(t14, t15) # t16\n", - " # t16 = prims.matmul(t14, t15) # t16\n", - " t17 = ltorch.mul(t9, t9) # t17\n", - " # t17 = prims.mul(t9, t9) # t17\n", - " t18 = ltorch.sub(1.0, t17, alpha=None) # t18\n", - " # t18 = prims.sub(1.0, t17) # t18\n", - " t19 = ltorch.mul(t12, t18) # t19\n", - " # t19 = prims.mul(t12, t18) # t19\n", - " t20 = ltorch.reshape(t19, -1, 64) # t20\n", - " # t20 = prims.reshape(t19, (64, 64)) # t20\n", - " t21 = ltorch.matmul(t20, t5) # t21\n", - " # t21 = prims.matmul(t20, t5) # t21\n", - " t22 = ltorch.reshape(t19, -1, 64) # t22\n", - " # t22 = prims.reshape(t19, (64, 64)) # t22\n", - " t23 = prims.transpose(t22, (1, 0)) # t23\n", - " t24 = ltorch.reshape(t2, -1, 64) # t24\n", - " # t24 = prims.reshape(t2, (64, 64)) # t24\n", - " t25 = ltorch.matmul(t23, t24) # t25\n", - " # t25 = prims.matmul(t23, t24) # t25\n", - " t26 = ltorch.true_divide(t16, 2) # t26\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t26 = prims.div(t16, 2.0) # t26\n", - " p27 = thunder.distributed.prims.reduce_scatter(t26, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p27\n", - " t28 = thunder.distributed.prims.wait(p27) # t28\n", - " t29 = ltorch.true_divide(t25, 2) # t29\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t29 = prims.div(t25, 2.0) # t29\n", - " p30 = thunder.distributed.prims.reduce_scatter(t29, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p30\n", - " t31 = thunder.distributed.prims.wait(p30) # t31\n", - " return (t10, (t31, t28, {'x': t21}))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from thunder.core.transforms import value_and_grad\n", - "\n", - "forward_and_backward_model = value_and_grad(model_with_syncs)\n", - "\n", - "forward_backward_trace = thunder.trace()(forward_and_backward_model, *model.parameters(), x=x)\n", - "\n", - "wrap_as_highlighted_code(forward_backward_trace)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above trace, only contains primitive which specifies the semantic of an operation abstractly but doesn't perform the actual computation.\n", - "\n", - "Now we will generate the execution trace which can actually perform the compute.\n", - "\n", - "In the execution trace generated below, we can see that all the primitives have been replaced with actually PyTorch operations. Also, our synchronization primitives have been replaced with PyTorch implementation provided by thunder i.e. `torch_all_gather_prim_impl`, `torch_reduce_scatter_prim_impl`, `torch_wait_prim_impl`." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
-       "# Constructed by Delete Last Used (took 0 milliseconds)\n",
-       "import torch\n",
-       "import torch.nn.functional\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def _value_and_grad(*args, **kwargs):\n",
-       "  # args \n",
-       "  # kwargs \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  = args\n",
-       "  del args\n",
-       "  t2 = kwargs['x']\n",
-       "  del kwargs\n",
-       "  t3 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3\n",
-       "    # t3 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3\n",
-       "      # t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3\n",
-       "  p4 = torch_all_gather_prim_impl(t0, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p4\n",
-       "  del t0\n",
-       "  t5 = torch_wait_prim_impl(p4)  # t5\n",
-       "  del p4\n",
-       "  p6 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p6\n",
-       "  del t1\n",
-       "  t7 = torch_wait_prim_impl(p6)  # t7\n",
-       "  del p6\n",
-       "  t8 = torch.nn.functional.linear(t2, t5, None)  # t8\n",
-       "    # t8 = ltorch.linear(t2, t5, None)  # t8\n",
-       "      # t8 = prims.linear(t2, t5, None)  # t8\n",
-       "  t9 = torch.tanh(t8)  # t9\n",
-       "    # t9 = ltorch.tanh(t8)  # t9\n",
-       "      # t9 = prims.tanh(t8)  # t9\n",
-       "  del t8\n",
-       "  t10 = torch.nn.functional.linear(t9, t7, None)  # t10\n",
-       "    # t10 = ltorch.linear(t9, t7, None)  # t10\n",
-       "      # t10 = prims.linear(t9, t7, None)  # t10\n",
-       "  t11 = torch.reshape(t3, (-1, 64))  # t11\n",
-       "    # t11 = ltorch.reshape(t3, (-1, 64))  # t11\n",
-       "      # t11 = prims.reshape(t3, (64, 64))  # t11\n",
-       "  t12 = torch.matmul(t11, t7)  # t12\n",
-       "    # t12 = ltorch.matmul(t11, t7)  # t12\n",
-       "      # t12 = prims.matmul(t11, t7)  # t12\n",
-       "  del t11, t7\n",
-       "  t13 = torch.reshape(t3, (-1, 64))  # t13\n",
-       "    # t13 = ltorch.reshape(t3, (-1, 64))  # t13\n",
-       "      # t13 = prims.reshape(t3, (64, 64))  # t13\n",
-       "  del t3\n",
-       "  t14 = torch.permute(t13, (1, 0))  # t14\n",
-       "    # t14 = ltorch.permute(t13, (1, 0))  # t14\n",
-       "      # t14 = prims.transpose(t13, (1, 0))  # t14\n",
-       "  del t13\n",
-       "  t15 = torch.reshape(t9, (-1, 64))  # t15\n",
-       "    # t15 = ltorch.reshape(t9, (-1, 64))  # t15\n",
-       "      # t15 = prims.reshape(t9, (64, 64))  # t15\n",
-       "  t16 = torch.matmul(t14, t15)  # t16\n",
-       "    # t16 = ltorch.matmul(t14, t15)  # t16\n",
-       "      # t16 = prims.matmul(t14, t15)  # t16\n",
-       "  del t14, t15\n",
-       "  t17 = torch.mul(t9, t9)  # t17\n",
-       "    # t17 = ltorch.mul(t9, t9)  # t17\n",
-       "      # t17 = prims.mul(t9, t9)  # t17\n",
-       "  del t9\n",
-       "  t18 = torch.sub(1.0, t17)  # t18\n",
-       "    # t18 = ltorch.sub(1.0, t17, alpha=None)  # t18\n",
-       "      # t18 = prims.sub(1.0, t17)  # t18\n",
-       "  del t17\n",
-       "  t19 = torch.mul(t12, t18)  # t19\n",
-       "    # t19 = ltorch.mul(t12, t18)  # t19\n",
-       "      # t19 = prims.mul(t12, t18)  # t19\n",
-       "  del t12, t18\n",
-       "  t20 = torch.reshape(t19, (-1, 64))  # t20\n",
-       "    # t20 = ltorch.reshape(t19, (-1, 64))  # t20\n",
-       "      # t20 = prims.reshape(t19, (64, 64))  # t20\n",
-       "  t21 = torch.matmul(t20, t5)  # t21\n",
-       "    # t21 = ltorch.matmul(t20, t5)  # t21\n",
-       "      # t21 = prims.matmul(t20, t5)  # t21\n",
-       "  del t20, t5\n",
-       "  t22 = torch.reshape(t19, (-1, 64))  # t22\n",
-       "    # t22 = ltorch.reshape(t19, (-1, 64))  # t22\n",
-       "      # t22 = prims.reshape(t19, (64, 64))  # t22\n",
-       "  del t19\n",
-       "  t23 = torch.permute(t22, (1, 0))  # t23\n",
-       "    # t23 = ltorch.permute(t22, (1, 0))  # t23\n",
-       "      # t23 = prims.transpose(t22, (1, 0))  # t23\n",
-       "  del t22\n",
-       "  t24 = torch.reshape(t2, (-1, 64))  # t24\n",
-       "    # t24 = ltorch.reshape(t2, (-1, 64))  # t24\n",
-       "      # t24 = prims.reshape(t2, (64, 64))  # t24\n",
-       "  del t2\n",
-       "  t25 = torch.matmul(t23, t24)  # t25\n",
-       "    # t25 = ltorch.matmul(t23, t24)  # t25\n",
-       "      # t25 = prims.matmul(t23, t24)  # t25\n",
-       "  del t23, t24\n",
-       "  t26 = torch.true_divide(t16, 2)  # t26\n",
-       "    # t26 = ltorch.true_divide(t16, 2)  # t26\n",
-       "      # _ = prims.convert_element_type(2, float)\n",
-       "      # t26 = prims.div(t16, 2.0)  # t26\n",
-       "  del t16\n",
-       "  p27 = torch_reduce_scatter_prim_impl(t26, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p27\n",
-       "  del t26\n",
-       "  t28 = torch_wait_prim_impl(p27)  # t28\n",
-       "  del p27\n",
-       "  t29 = torch.true_divide(t25, 2)  # t29\n",
-       "    # t29 = ltorch.true_divide(t25, 2)  # t29\n",
-       "      # _ = prims.convert_element_type(2, float)\n",
-       "      # t29 = prims.div(t25, 2.0)  # t29\n",
-       "  del t25\n",
-       "  p30 = torch_reduce_scatter_prim_impl(t29, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p30\n",
-       "  del t29\n",
-       "  t31 = torch_wait_prim_impl(p30)  # t31\n",
-       "  del p30\n",
-       "  return (t10, (t31, t28, {'x': t21}))\n",
-      ]
-     },
-     "execution_count": 12,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "optimized_trace = thunder.transform_for_execution(forward_backward_trace, executors_list=thunder.get_always_executors())\n",
-    "\n",
-    "# Grab the final trace\n",
-    "exec_trace = optimized_trace[-1]\n",
-    "wrap_as_highlighted_code(exec_trace)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "#### Step 4: Running the actual computation\n",
-    "\n",
-    "Running the actual computation requires setting up two processes and running the code above in both of them (which can be tricky from a Jupyter notebook). Instead, we will write a small script and run it with `torchrun`, which takes care of setting up the processes and the relevant state.\n",
-    "\n",
-    "**NOTE**: This requires the machine running this notebook to have at least 2 GPUs."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "In the example below, we will use `thunder.distributed.fsdp`, which does the same as what we did above (with some extra checks). 
The code below should look familiar as it is roughly all the above pieces in a single script. " - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Overwriting thunder_fsdp_simple_example.py\n" - ] - } - ], - "source": [ - "%%writefile thunder_fsdp_simple_example.py\n", - "\n", - "# imports\n", - "from thunder.tests.lit_gpt_model import GPT, Config\n", - "import torch\n", - "import torch.distributed\n", - "import thunder\n", - "import thunder.distributed\n", - "import os\n", - "\n", - "# # # # # # # #\n", - "# Create Model\n", - "# # # # # # # #\n", - "\n", - "# NOTE: We create the model on CPU.\n", - "device='cpu'\n", - "dim = 64\n", - "def create_model():\n", - " layers = []\n", - " layers.append(torch.nn.Linear(dim, dim))\n", - " layers.append(torch.nn.ReLU())\n", - " layers.append(torch.nn.Linear(dim, dim))\n", - " return torch.nn.Sequential(*layers).to(device)\n", - "\n", - "# Model\n", - "model = create_model()\n", - "# Input\n", - "x = torch.randn(dim, dim, device=device)\n", - "\n", - "# # # # # # # #\n", - "# Setup for distributed\n", - "# # # # # # # #\n", - "torch.distributed.init_process_group(backend='nccl')\n", - "\n", - "rank = int(os.environ[\"LOCAL_RANK\"])\n", - "\n", - "device = f\"cuda:{rank}\"\n", - "\n", - "# # # # # # # #\n", - "# Move inputs to correct device\n", - "# # # # # # # #\n", - "x = x.to(device)\n", - "\n", - "# # # # # # # #\n", - "# Wrap the model in thunder.distributed.fsdp\n", - "# # # # # # # #\n", - "\n", - "# thunder.distributed.fsdp takes care of moving the parameter\n", - "# shard to the correct GPU for the current process.\n", - "cmodel = thunder.jit(thunder.distributed.fsdp(model))\n", - "\n", - "# Run the forward pass.\n", - "cmodel(x)\n", - "\n", - "# # # # # # # #\n", - "# Check the traces\n", - "# # # # # # # #\n", - "fwd_traces, bwd_traces = thunder.last_traces(cmodel)\n", - "\n", - "# # # # # # # #\n", - "# Print and check to see if they match ours\n", - "# # # # # # # #\n", - "if rank == 0:\n", - " print(fwd_traces[-1])\n", - " print(\"*******\"* 8)\n", - " print(bwd_traces[-1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let us run the above script and check what the trace looks like.\n", - "\n", - "We can observe that forward trace has `torch_all_gather_prim_impl` to gather the parameter before forward pass and the backward trace has `torch_reduce_scatter_prim_impl` to reduce and scatter the gradients back to different GPUs. This is similar to our implementation above." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] \n", - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] *****************************************\n", - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] *****************************************\n", - "/home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. 
Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "/home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "import torch.nn.functional\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def augmented_forward_fn(t_0_weight, t_0_bias, t_0, t_2_weight, t_2_bias):\n", - " # t_0_weight \n", - " p0 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p0\n", - " # t_0_bias \n", - " p2 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p2\n", - " # t_0 \n", - " # t_2_weight \n", - " p7 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p7\n", - " # t_2_bias \n", - " p9 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p9\n", - " t1 = torch_wait_prim_impl(p0) # t1\n", - " del p0\n", - " t3 = torch_wait_prim_impl(p2) # t3\n", - " del p2\n", - " t4 = torch.nn.functional.linear(t_0, t1, t3) # t4\n", - " # t4 = ltorch.linear(t_0, t1, t3) # t4\n", - " # t4 = prims.linear(t_0, t1, t3) # t4\n", - " del t1, t3\n", - " [t5, t6] = nvFusion0(t4)\n", - " # t5 = prims.gt(t4, 0.0) # t5\n", - " # t6 = prims.where(t5, t4, 0.0) # t6\n", - " del t4\n", - " t8 = torch_wait_prim_impl(p7) # t8\n", - " del p7\n", - " t10 = torch_wait_prim_impl(p9) # t10\n", - " del p9\n", - " t11 = torch.nn.functional.linear(t6, t8, t10) # t11\n", - " # t11 = ltorch.linear(t6, t8, t10) # t11\n", - " # t11 = prims.linear(t6, t8, t10) # t11\n", - " del t10\n", - " return {'output': (t11, ()), 'flat_args': [t_0_weight, t_0_bias, t_0, t_2_weight, t_2_bias], 'flat_output': (t11,)}, ((t5, t6, t8, t_0), ())\n", - "********************************************************\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def backward_fn(saved_for_backward, cotangents):\n", - " # saved_for_backward \n", - " # cotangents \n", - " C0, \\\n", - " _, \\\n", - " = saved_for_backward\n", - " clear_collection(saved_for_backward)\n", - " del saved_for_backward\n", - " t0, \\\n", - " = cotangents\n", - " clear_collection(cotangents)\n", - " del cotangents\n", - " t5, \\\n", - " t6, \\\n", - " t8, \\\n", - " t_0, \\\n", - " = C0\n", - " clear_collection(C0)\n", - " del C0\n", - " t31 = torch.reshape(t0, (-1, 64)) # t31\n", - " # t31 = ltorch.reshape(t0, (-1, 64)) # t31\n", - " # t31 = prims.reshape(t0, (64, 64)) # t31\n", - " t32 = torch.permute(t31, (1, 0)) # t32\n", - " # t32 = ltorch.permute(t31, (1, 0)) # t32\n", - " # t32 = prims.transpose(t31, (1, 0)) # t32\n", - " t33 = torch.reshape(t6, (-1, 64)) # t33\n", - " # t33 = ltorch.reshape(t6, (-1, 64)) # t33\n", - " # t33 = prims.reshape(t6, (64, 64)) # t33\n", - " del t6\n", - " t48 = torch.reshape(t_0, (-1, 64)) # t48\n", - " # t48 = ltorch.reshape(t_0, (-1, 64)) # t48\n", - " # t48 = prims.reshape(t_0, (64, 64)) # t48\n", - " del t_0\n", - " [t36] = nvFusion0(t0)\n", - " # t35 = prims.sum(t0, (0,)) # t35\n", - " # t36 = 
prims.div(t35, 2.0) # t36\n", - " del t0\n", - " p37 = torch_reduce_scatter_prim_impl(t36, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p37\n", - " del t36\n", - " t30 = torch.matmul(t31, t8) # t30\n", - " # t30 = ltorch.matmul(t29, t8) # t30\n", - " # t30 = prims.matmul(t29, t8) # t30\n", - " del t31, t8\n", - " t34 = torch.matmul(t32, t33) # t34\n", - " # t34 = ltorch.matmul(t32, t33) # t34\n", - " # t34 = prims.matmul(t32, t33) # t34\n", - " del t32, t33\n", - " [t39, t42, t51] = nvFusion1(t30, t34, t5)\n", - " # t42 = prims.where(t5, t30, 0.0) # t42\n", - " # t50 = prims.sum(t42, (0,)) # t50\n", - " # t51 = prims.div(t50, 2.0) # t51\n", - " # t39 = prims.div(t34, 2.0) # t39\n", - " del t30, t34, t5\n", - " p40 = torch_reduce_scatter_prim_impl(t39, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p40\n", - " del t39\n", - " p52 = torch_reduce_scatter_prim_impl(t51, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p52\n", - " del t51\n", - " t46 = torch.reshape(t42, (-1, 64)) # t46\n", - " # t46 = ltorch.reshape(t42, (-1, 64)) # t46\n", - " # t46 = prims.reshape(t42, (64, 64)) # t46\n", - " del t42\n", - " t47 = torch.permute(t46, (1, 0)) # t47\n", - " # t47 = ltorch.permute(t46, (1, 0)) # t47\n", - " # t47 = prims.transpose(t46, (1, 0)) # t47\n", - " del t46\n", - " t49 = torch.matmul(t47, t48) # t49\n", - " # t49 = ltorch.matmul(t47, t48) # t49\n", - " # t49 = prims.matmul(t47, t48) # t49\n", - " del t47, t48\n", - " [t54] = nvFusion2(t49)\n", - " # t54 = prims.div(t49, 2.0) # t54\n", - " del t49\n", - " p55 = torch_reduce_scatter_prim_impl(t54, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p55\n", - " del t54\n", - " t38 = torch_wait_prim_impl(p37) # t38\n", - " del p37\n", - " t41 = torch_wait_prim_impl(p40) # t41\n", - " del p40\n", - " t53 = torch_wait_prim_impl(p52) # t53\n", - " del p52\n", - " t56 = torch_wait_prim_impl(p55) # t56\n", - " del p55\n", - " return (t56, t53, None, t41, t38)\n" - ] - } - ], - "source": [ - "!torchrun --nproc_per_node=2 thunder_fsdp_simple_example.py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Conclusion\n", - "\n", - "We have created our implementation of FSDP to shard our model across multiple GPUs. In the process, we also learned that:\n", - "\n", - "1. `thunder` provides us with primitives for synchronization across mutiple GPUs.\n", - "2. `thunder` also takes care of implementing the backward support for the synchronization primitives, so we don't have to explicitly do anything to get the backward working.\n", - "3. We can just easily apply `thunder.distributed.fsdp` to our model and it will take care of sharding the parameters and also adding synchronizations to our model. Also, we can easily check the modifications by inspecting the traces." 
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "pytorch-dev",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.13"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}

From a7ef2902125ab05db8e3c97455aaf900a9569d4b Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 14 Mar 2024 11:03:22 +0100
Subject: [PATCH 08/44] drop preprocessing (PR2447)

---
 docs/source/reference/common/index.rst        |   1 -
 examples/lit-gpt/_fsdp_thunder.py             |   3 +-
 thunder/clang/__init__.py                     |   3 -
 thunder/common.py                             | 123 ---
 thunder/core/script/__init__.py               |   0
 thunder/core/script/algorithms.py             | 199 ----
 thunder/core/script/frontend.py               | 684 -------------
 thunder/core/script/graph.py                  | 819 ---------------
 thunder/core/script/instrumentation.py        | 145 ---
 thunder/core/script/mypy-strict.ini           |  52 -
 thunder/core/script/noinline.py               |  38 -
 thunder/core/script/overview.ipynb            | 331 -------
 thunder/core/script/parse/__init__.py         |   7 -
 thunder/core/script/parse/disassemble.py      | 166 ----
 thunder/core/script/parse/functionalize.py    | 287 ------
 thunder/core/script/parse/instructions.py     | 117 ---
 thunder/core/script/parse/stack_effect.py     | 234 -----
 thunder/core/script/passes.py                 | 932 ------------------
 thunder/core/script/protograph.py             | 518 ----------
 thunder/core/script/protograph_passes.py      |  74 --
 thunder/core/script/python_ir.py              | 507 ----------
 thunder/core/script/python_ir_data.py         |  68 --
 thunder/core/script/values/__init__.py        |   4 -
 thunder/core/script/values/base.py            | 200 ----
 thunder/core/script/values/composite.py       | 177 ----
 thunder/core/script/values/materialization.py | 171 ----
 thunder/core/script/values/symbolic.py        | 175 ----
 thunder/core/transforms.py                    |  14 +-
 thunder/numpy/__init__.py                     |   5 +-
 thunder/tests/framework.py                    |  18 +-
 30 files changed, 5 insertions(+), 6067 deletions(-)
 delete mode 100644 thunder/core/script/__init__.py
 delete mode 100644 thunder/core/script/algorithms.py
 delete mode 100644 thunder/core/script/frontend.py
 delete mode 100644 thunder/core/script/graph.py
 delete mode 100644 thunder/core/script/instrumentation.py
 delete mode 100644 thunder/core/script/mypy-strict.ini
 delete mode 100644 thunder/core/script/noinline.py
 delete mode 100644 thunder/core/script/overview.ipynb
 delete mode 100644 thunder/core/script/parse/__init__.py
 delete mode 100644 thunder/core/script/parse/disassemble.py
 delete mode 100644 thunder/core/script/parse/functionalize.py
 delete mode 100644 thunder/core/script/parse/instructions.py
 delete mode 100644 thunder/core/script/parse/stack_effect.py
 delete mode 100644 thunder/core/script/passes.py
 delete mode 100644 thunder/core/script/protograph.py
 delete mode 100644 thunder/core/script/protograph_passes.py
 delete mode 100644 thunder/core/script/python_ir.py
 delete mode 100644 thunder/core/script/python_ir_data.py
 delete mode 100644 thunder/core/script/values/__init__.py
 delete mode 100644 thunder/core/script/values/base.py
 delete mode 100644 thunder/core/script/values/composite.py
 delete mode 100644 thunder/core/script/values/materialization.py
 delete mode 100644 thunder/core/script/values/symbolic.py
diff --git a/docs/source/reference/common/index.rst b/docs/source/reference/common/index.rst
index 0011c1b21f..20e0144d42 100644
--- a/docs/source/reference/common/index.rst
+++ b/docs/source/reference/common/index.rst
@@ 
-9,4 +9,3 @@ Common functions and classes for Thunder. :toctree: generated/ CACHE_OPTIONS - preprocess diff --git a/examples/lit-gpt/_fsdp_thunder.py b/examples/lit-gpt/_fsdp_thunder.py index 77ad8cfba0..133c40b1f2 100644 --- a/examples/lit-gpt/_fsdp_thunder.py +++ b/examples/lit-gpt/_fsdp_thunder.py @@ -414,8 +414,7 @@ def _get_state_dict( def _unwrap_tom(obj: object) -> object: # TODO: this unwrap won't be required when Fabric's `_unwrap_objects` supports Thunder from thunder import ThunderModule - from thunder.common import ThunderOptimizedModule - if isinstance(obj, (ThunderOptimizedModule, ThunderModule)): + if isinstance(obj, ThunderModule): return obj._model return obj diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 6032034cc8..c0c711f5b8 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -19,7 +19,6 @@ import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, pyval, pytype, proxy, AnyProxy, Proxy import thunder.core.devices as devices -from thunder.core.script.noinline import noinline # This file defines the operations in lightning.compile's "core" language. # @@ -34,7 +33,6 @@ _clang_fn_set: set = set() -# TODO RC1 Remove noinline # Decorator that sets the core language context and registers the function class clangop: def __init__(self, *, method_name: None | str = None): @@ -42,7 +40,6 @@ def __init__(self, *, method_name: None | str = None): def __call__(self, fn: Callable) -> Callable: _fn = langctx(Languages.CLANG)(fn) - _fn = noinline(_fn) _clang_fn_set.add(_fn) if self.method_name is not None: diff --git a/thunder/common.py b/thunder/common.py index 74ca9f6473..860b6b522a 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -133,40 +133,6 @@ def last_computation_execution_time(self, /) -> int: return self._time_template(start, stop, "computation execution") -import thunder.core.script.frontend as script_frontend -import thunder.core.script.instrumentation as script_instrumentation -import thunder.core.script.passes as passes -import thunder.core.script.python_ir as python_ir - - -# Preprocesses function -# Currently tries to map torch.foo lookups to thunder.torch.foo lookups -@script_instrumentation.record -def preprocess(fn, is_module): - gr = script_frontend.acquire_method(fn.forward if is_module else fn) - passes.unroll_for_loops_and_inline_modules(gr) - if is_module: - ( - additional_param_names, - additional_param_values, - additional_return_names, - ) = passes.module_to_function(gr) - passes.strongly_inline_functions(gr) - passes.torch_to_thunder(gr) - - thunder_fn = python_ir.generate_function(gr) - if is_module: - thunder_fn._additional_param_names = additional_param_names - thunder_fn._additional_param_values = additional_param_values - thunder_fn._additional_return_names = additional_return_names - else: - thunder_fn._additional_param_names = None - thunder_fn._additional_param_values = None - thunder_fn._additional_return_names = None - - return thunder_fn - - # A class that holds data about the compiled object, including statistics about how it's been called # TODO Better document the module-related data the preprocessing harvests, # like additional_param_names @@ -364,85 +330,6 @@ def translate(x: Any, *, name: str | None = None) -> Any: return proxyargs, proxykwargs -class ThunderOptimizedModule(torch.nn.Module): # TOM - # todo: subclass nn.Module or forward things like .state_dict() to the - # model - def __init__(self, model, fn, tfn, additional_param_names, additional_param_values, 
additional_return_names): - super().__init__() - self._model = model - self._forward_fn = fn - self._tfn = tfn - - self._additional_param_values = additional_param_values - self._additional_param_names = additional_param_names - self._additional_return_names = additional_return_names - d = {k: i for i, k in enumerate(additional_param_names)} - self._additional_return_param_idxes = [d[k] for k in additional_return_names] - - def __call__(self, *args, **kwargs): - all_args = (*self._additional_param_values, *args) - res = self._forward_fn(*all_args, **kwargs) - if self._additional_return_names: - res, *additional_returns = res - assert len(self._additional_return_names) == len( - additional_returns - ), f"Number of expected additional return args {len(self._additional_return_names)=} does not match the actual number {len(additional_returns)=}" - for k, v, idx in zip( - self._additional_return_names, additional_returns, self._additional_return_param_idxes - ): - m = self._model - parts = k.split(".") - for p in parts[:-1]: - m = getattr(m, p) - setattr(m, parts[-1], v) - self._additional_param_values[idx] = v - return res - - @contextmanager - def no_sync(self): - """Context manager to disable gradient synchronization in data parallel mode. - - This context manager is intended to be used in conjunction with - :class:`torch.nn.parallel.DistributedDataParallel` to disable gradient - synchronization in the backward pass. It will not have any effect when - used with other modules. - - .. note:: - - This could lead to different accumulated gradients with ``torch.nn.parallel.distributed.DistributedDataParallel.no_sync``. - PyTorch's gradient synchronization is implemented by applying all-reduce to gradient buckets of ``torch.nn.Parameter.grad``. - Thus the ``no_sync`` context leads to :math:`\text{AllReduce} \\left( \\sum_{i = 0}^{\rm{num_grad_accum_steps}} g_i \right)`. - In contrast, this synchronizes accumulated gradients when exiting, leading to - :math:`\text{AllReduce} \\left( \\sum_{i = 0}^{\rm{num_grad_accum_steps - 1}} g_i \right) + \text{AllReduce}(g_{\rm{num_grad_accum_steps}})`. - - .. warning:: - - You must reuse this context manager in each group of gradient accumulation iterations since gradients will get synchronized - on context manager exit. For example: - - .. 
code-block:: python - - with model.no_sync(): - for _ in range(len(gradient_accumulation_iters)): - loss(model(x)).backward() # uses no-sync-backward trace - loss(model(x)).backward() # uses the regular backward trace - optimizer.step() - - """ - from thunder.distributed import ( - set_skip_data_parallel_grad_sync, - reset_skip_data_parallel_grad_sync, - _sync_grads, - ) - - token = set_skip_data_parallel_grad_sync(True) - try: - yield - finally: - reset_skip_data_parallel_grad_sync(token) - _sync_grads(self) - - # # Caching objects and functions # @@ -947,16 +834,6 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_host_stop = time.time_ns() return result - if cd.is_module: - _fn = ThunderOptimizedModule( - cd.fn, - _fn, - cd.processed_function, - cd.additional_param_names, - cd.additional_param_values, - cd.additional_return_names, - ) - # NOTE is_module is False _fn._pfn = cd.processed_function _fn._lc_cd = cd diff --git a/thunder/core/script/__init__.py b/thunder/core/script/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/thunder/core/script/algorithms.py b/thunder/core/script/algorithms.py deleted file mode 100644 index bced598172..0000000000 --- a/thunder/core/script/algorithms.py +++ /dev/null @@ -1,199 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, Mapping -import itertools -import textwrap -from typing import Generic, ParamSpec, TypeVar, cast - -import networkx as nx -from typing_extensions import Self - -from thunder.core.utils import OrderedSet - -__all__ = ("flatten_map", "sort_adjacent", "compute_condense_map") -P = ParamSpec("P") -T = TypeVar("T") - - -# ============================================================================= -# == nx.(Di)Graph, but with more safety ======================================= -# ============================================================================= -class TypedGraph(nx.Graph, Generic[T]): # type: ignore[misc, no-any-unimported] - def __init__(self, edgelist: Iterable[tuple[T, T]] = ()) -> None: - super().__init__() - self.add_edges_from(edgelist) - - @property - def nodes(self) -> Iterable[T]: - return cast(Iterable[T], super().nodes) - - @property - def edges(self) -> Iterable[tuple[T, T]]: - return cast(Iterable[tuple[T, T]], super().edges) - - @property - def connected_components(self) -> Iterable[set[T]]: - return cast(Iterable[set[T]], nx.connected_components(self)) - - def subgraph(self, nodes: Iterable[T]) -> Self: - return cast(Self, super().subgraph(nodes)) - - def to_undirected_class(self) -> type: - return TypedGraph[T] - - def to_directed_class(self) -> type: - return TypedDiGraph[T] - - -class TypedDiGraph(TypedGraph[T], nx.DiGraph): # type: ignore[misc, no-any-unimported] - def assert_directed_acyclic(self) -> None: - if not nx.is_directed_acyclic_graph(self): - cycle = "\n".join(f"{node}" for node, _ in nx.find_cycle(self)) - raise AssertionError(f"Cycle detected:\n{textwrap.indent(cycle, ' ' * 4)}") - - def to_undirected(self, *args: P.args, **kwargs: P.kwargs) -> TypedGraph[T]: - G = super().to_undirected(*args, **kwargs) - assert isinstance(G, TypedGraph) - return G - - def predecessors(self, n: T) -> Iterable[T]: - return cast(Iterable[T], super().predecessors(n)) - - -# ============================================================================= -# == Graph algorithms ========================================================= -# ============================================================================= -def 
flatten_map(mapping: Mapping[T, T]) -> Iterable[tuple[T, T]]: - """If one interprets the items as edges of a tree (or forest), return items for a tree of at most depth two. - - For example, `{1: 2, 2: 3, 4: 3, 5: 6}` flattens to `{1: 3, 2: 3, 4: 3, 5: 6}`. - """ - G = TypedDiGraph[T](((j, i) for i, j in mapping.items() if i != j)) - assert nx.is_directed_acyclic_graph(G) - for cluster in nx.connected_components(G.to_undirected()): - (root,) = (i for i in cluster if not G.in_degree(i)) - yield from ((i, root) for i in cluster if i != root) - - -def _extract_paths(G: TypedDiGraph[T]) -> Iterable[tuple[T, ...]]: - assert nx.is_connected(G.to_undirected()) - subgraph = TypedDiGraph[T]() - subgraph.add_nodes_from(G.nodes) - subgraph.add_edges_from(edge for edge, adjacent in nx.get_edge_attributes(G, "adjacent").items() if adjacent) - assert len(subgraph) == len(G) - subgraph.assert_directed_acyclic() - - for nodes in subgraph.to_undirected().connected_components: - path = subgraph.subgraph(nodes) - sorted_nodes = tuple(nx.topological_sort(path)) - assert len(sorted_nodes) == 1 or nx.is_simple_path(path, sorted_nodes) - yield sorted_nodes - - -def sort_adjacent(G: TypedDiGraph[T]) -> Iterable[T]: - """Sort nodes, respecting strong adjacency requirements and trying to sort return blocks to the end. - - If edges are annotated with `is_return` the annotation will be used; otherwise - terminal nodes will be inferred. - """ - (root,) = (node for node in G.nodes if not G.in_degree(node)) - is_return = nx.get_node_attributes(G, "is_return") or {node: not G.out_degree(node) for node in G.nodes} - assert len(is_return) == len(G) and any(is_return.values()) - - sort_map = {} - for primary_key, sorted_nodes in enumerate(_extract_paths(G)): - if sorted_nodes[0] is root: - primary_key = -1 - elif is_return[sorted_nodes[-1]]: - primary_key += len(G) - - sort_map.update({node: (primary_key, idx) for idx, node in enumerate(sorted_nodes)}) - - assert len(sort_map) == len(G) - yield from sorted(sort_map, key=lambda node: sort_map[node]) - - -def sort_adjacent_dfs(G: TypedDiGraph[T]) -> Iterable[T]: - """Alternate sorting formulation. Prioritizes program order over moving returns to the end. - - Unlike `sort_adjacent`, this order guarantees that at least one dependency will have - appeared before the current block. (`undo_ssa` seems to depend on this invariant.) - """ - paths = {sorted_nodes[0]: sorted_nodes for sorted_nodes in _extract_paths(G)} - - condensed = {} - for path_root, path in paths.items(): - condensed.update({node: path_root for node in path}) - - G_traverse = TypedDiGraph[T]((condensed[source], condensed[sink]) for source, sink in G.edges) - G_traverse.add_nodes_from(paths) - for i in nx.dfs_preorder_nodes(G_traverse): - yield from paths.pop(i) - - assert not paths - - -def compute_condense_map(edges: Iterable[tuple[T, T]]) -> dict[T, OrderedSet[T]]: - """Given a graph of identity relations (including unions and cycles), determine a minumum basis. - - A common construct that emerges from program loops is the statement "A is either A or B". However - if we eliminate the vacuous "A is A" component we reach the much more useful "A is B", which - allows us to replace a thorny union with a simple value. Similarly, we can eliminate chains of - equality expressions. ("C is B, B is A" becomes "C is A, B is A") - - At first this seems as simple as finding the roots of the graph, but consider the following: - "B is A, C is either B or D". B can be replaced with A, but C is the union of A and D. 
Critically, - B is NOT D, so simply assigning all non-roots the union of the roots is incorrect. - - This function uses an iterative method to distil the graph. Note that there is significant - simplification; the input can be an arbitrary directed **cyclic** graph (as long as at least one - node is not part of a cycle), but the output constituents are trees of at most depth two. - """ - G = TypedDiGraph(edges) - G.remove_edges_from(nx.selfloop_edges(G)) - - condense_map: dict[T, OrderedSet[T]] = {node: OrderedSet() for node in G} - for subgraph_nodes in G.to_undirected().connected_components: - subgraph = cast(TypedDiGraph[T], G.subgraph(subgraph_nodes)) - roots = OrderedSet(node for node in subgraph_nodes if not subgraph.in_degree(node)) - assert roots, subgraph.edges - - equality_edges = OrderedSet((node, node) for node in subgraph.nodes) - while True: - # Condense pairs in `equality_edges`. For example, given the - # following graph and `equality_edges`: - # 0 → 1 → 2 → 3 → 4 → 5 - # ↑┄──┘ - # - # equality_edges = {(0, 1), (3, 4)} - # - # After grouping we're left with: - # {0, 1} → 2 → {3, 4} → 5 - clusters: dict[T, T] = {} - for cluster in TypedGraph(equality_edges).connected_components: - # The choice of "canonical" value is arbitrary as long as it is consistent. - canonical = next(iter(cluster)) - clusters.update((i, canonical) for i in cluster) - - assert len(clusters) == len(subgraph) - reduced_edges = ((clusters[i], clusters[j]) for i, j in subgraph.edges) - reduced_subgraph = cast(TypedDiGraph[T], TypedDiGraph[T](reduced_edges)) # MyPy can't figure this out... - reduced_subgraph.remove_edges_from(nx.selfloop_edges(reduced_subgraph)) - num_equality_edges = len(equality_edges) - - # Condense chains. - equality_edges.update(reduced_subgraph.edges) - - # Condense loops. - for cycle in nx.simple_cycles(reduced_subgraph): - equality_edges.update(zip(cycle, itertools.chain(cycle[1:], cycle[:1]))) - - if len(equality_edges) == num_equality_edges: - # No progress has been made, exit loop. - break - - for root in roots: - for reachable in itertools.chain([root], *nx.dfs_successors(subgraph, root).values()): - condense_map[reachable].add(root) - - return condense_map diff --git a/thunder/core/script/frontend.py b/thunder/core/script/frontend.py deleted file mode 100644 index 62980449bf..0000000000 --- a/thunder/core/script/frontend.py +++ /dev/null @@ -1,684 +0,0 @@ -import collections -import functools -import dis -import inspect -import itertools -import sys -from typing import Optional, TypeVar -from collections.abc import Callable -from collections.abc import Iterable - -import networkx as nx - -from thunder.core.script.graph import ( - check_graph, - replace_values, - Block, - Graph, - MROAwareObjectRef, - Node, - NULL, - PhiValue, - SourceInformation, - Value, -) -from thunder.core.script.instrumentation import record -from thunder.core.script import parse, values -from thunder.core.script.protograph import ProtoBlock, ProtoGraph, ProtoGraphTransform -from thunder.core.script.protograph_passes import apply_protograph_passes -from thunder.core.script.python_ir_data import get_instruction, SUPPORTS_PREPROCESSING -from thunder.core.utils import debug_asserts_enabled, OrderedSet - -T = TypeVar("T") - - -class Super: - pass - - -class PruneEpilogues(ProtoGraphTransform): - """Remove the `POP_TOP, ..., JUMP_ABSOLUTE` blocks introduced during parsing. - - NOTE: This is only for `_bind_to_graph`. The reason is that it produces a - ProtoGraph with mismatched stacks. 
(Since we've pruned POP_TOP ops.) - This isn't a problem since `_bind_to_graph` is value based, however - it does make `_inter_block_edges` unsafe. - """ - - def _apply(self) -> ProtoGraph | None: - retain: dict[ProtoBlock, ProtoBlock] = {} - for protoblock in self.protograph: - if isinstance(protoblock, ProtoGraph): - breakpoint() - instructions = tuple(i for i, _ in protoblock.flow.symbolic) - if all(isinstance(i, parse.EpilogueFixup) for i in instructions): - assert all(i.opname == parse.POP_TOP for i in instructions[:-1]) - assert instructions[-1].opname == parse.JUMP_ABSOLUTE, instructions[-1] - continue - - retain[protoblock] = new_protoblock = ProtoBlock(protoblock.flow) - new_protoblock.uses.update(protoblock.uses) - - for old, new in retain.items(): - for target, jump in old.jump_targets: - if target not in retain: - ((target, _),) = target.jump_targets - assert target in retain - new.add_jump_target(retain[target], jump) - - if len(retain) != len(tuple(self.protograph)): - return ProtoGraph(retain.values(), provenance=(self.__class__, self.protograph)) - return None - - -def _bind_to_graph( - proto_graph: ProtoGraph, - func: Callable, - method_self: object | None = None, - mro_klass: type | None = None, -) -> Graph: - """Convert abstract value graph into a concrete Graph. - - The key nuance of this conversion is that the mapping from `AbstractValue` - to `Value` is contextual. The first time we "see" an `AbstractValue` it - maps to a `Value`. If we encounter it in any other block it maps to a - PhiValue and we need to set the proper connectivity. - - This is perhaps clearer with an example. Suppose you have an argument `x` - which is used by the root block and passed to the next block, and suppose - you have another value `y` which is created in the root block and passed to - the next block. In the abstract flow this is represented as: - ________ ___________ - `x` -> | Root | -`x`-> | Block 1 | -> ... - | `y` | -`y`-> | | - -------- ----------- - - On the other hand, `Graph` represents the same connectivity as: - ________ ___________ - `x` ←┈┈→ `𝜙x_0` -> | Root | -`𝜙x_0` ←┈┈→ `𝜙x_1` -> | Block 1 | -> ... - | `y` | -`y` ←┈┈┈┈┈→ `𝜙y_0` -> | | - -------- ----------- - - (This diagram does not show the reason for PhiValues: to accept multiple inputs.) - """ - # Peek at the signature and live objects to create Values. This is the - # *only* region where this is permitted. - # ========================================================================= - # TODO(robieta): Lazily generate specializations during runtime. - signature = inspect.signature(func) - func_globals = {**func.__builtins__, **func.__globals__, **{"super": Super()}} - - # NOTE: - # `inspect.signature` will expose parameters in intuitive order. However that - # is not necessarily how Python represents them internally. Specifically, varargs - # and varkwargs are moved to the end. This convention is load bearing (since it - # allows the interpreter index into a flat args array) so we must respect it - # here. (`func.__code__.co_varnames` is the canonical ordering.) 
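-    # As a concrete illustration (hypothetical function, not from this codebase):
-    # for `def f(a, *args, b=1, **kw): ...`, `inspect.signature(f).parameters`
-    # is ordered (a, args, b, kw), while `f.__code__.co_varnames[:4]` is
-    # ('a', 'b', 'args', 'kw') -- the varargs and varkwargs slots come last.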
- arg_ordered_parameters = func.__code__.co_varnames[: len(signature.parameters)] - source_file_name = inspect.getsourcefile(func) - source_start_line = func.__code__.co_firstlineno - if set(arg_ordered_parameters) != set(signature.parameters): - assert hasattr(func, "__wrapped__") - msg = f"({', '.join(arg_ordered_parameters)}) != ({', '.join(signature.parameters.keys())})" - raise NotImplementedError(msg) - - co_name = func.__code__.co_name - self_key: parse.VariableKey | None = None - self_value: Value | None = None - if method_self is not None: - self_key = parse.VariableKey(arg_ordered_parameters[0], parse.VariableScope.LOCAL) - self_value = Value(value=method_self, name=self_key.identifier, is_function_arg=True) - - get_initial_value_cache = {} - - def get_initial_value(key: parse.VariableKey, block: Block | None = None) -> Value: - if key in get_initial_value_cache: - v = get_initial_value_cache[key] - assert not ((block is None or block != v.block) and not (v.is_global or v.is_const or v.is_function_arg)) - return v - if key.is_const: - v = Value(value=key.identifier, is_const=True) - get_initial_value_cache[key] = v - return v - - elif key == self_key: - v = self_value - get_initial_value_cache[key] = v - return v - - name = key.identifier - assert isinstance(name, str) - if key.scope == parse.VariableScope.LOCAL: - if (p := signature.parameters.get(name)) is not None: - v = Value(typ=p.annotation, name=name, is_function_arg=True) - get_initial_value_cache[key] = v - return v - v = Value(value=NULL, name=name, block=block) - get_initial_value_cache[key] = v - return v - - if key.scope == parse.VariableScope.NONLOCAL: - msg = f"nonlocal variables are not supported but (key, name) = ({key}, {name}) found" - raise RuntimeError(msg) - - if key.scope == parse.VariableScope.GLOBAL: - try: - val = func_globals[name] - except KeyError: - raise ValueError(f"Could not resolve global variable: {name=}.") - v = Value(name=name, value=val, is_global=True) - get_initial_value_cache[key] = v - return v - - raise ValueError(f"Unhandled key: {key=}, name: {name=}") - - del func - # End live inspection region. - # ========================================================================= - assert proto_graph is proto_graph.link() - proto_graph = PruneEpilogues(proto_graph).apply(or_default=True) - blocks = {protoblock: Block() for protoblock in proto_graph} - blocks[proto_graph.root].jump_sources.append(None) - - # Block inputs require special handling since we may need to create `PhiValue`s. 
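-    # In outline (a summary of the loop below, not additional logic): root-block
-    # arguments become single-source PhiValues, inputs a block actually uses
-    # become empty PhiValues that are linked up in a later pass, and everything
-    # else becomes a NULL placeholder Value.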
- input_conversions = {} - for protoblock, block in blocks.items(): - for key, abstract_value in protoblock.flow.begin_state: - abstract_value = abstract_value.identity - if protoblock is proto_graph.root: - value = get_initial_value(key, block=block) - if key.scope == parse.VariableScope.LOCAL and value.value is not NULL: - assert isinstance(abstract_value, values.ExternalRef), abstract_value - value = PhiValue([value], [None], block) - - elif key in protoblock.uses: - value = PhiValue([], [], block) - - else: - value = Value(value=NULL, block=block) - - input_conversions[(abstract_value, protoblock)] = value - - convert_cache = {} - - def convert(value: values.AbstractValue, protoblock: ProtoBlock, block: Block) -> Value: - value = value.identity - v = convert_cache.get((value, protoblock)) - if v is not None: - if ( - v.block != block - and block is not None - and not (v.is_global or v.is_function_arg or v.is_const or v.value == NULL) - ): - raise AssertionError("ohoh, this should not happen") - return v - - def _convert(value: values.AbstractValue, protoblock: ProtoBlock) -> Value: - assert not value.is_detail, value - if (out := input_conversions.get((value, protoblock), missing := object())) is not missing: - return out - - if isinstance(value, values.NonPyObject): - assert value.tag == values.NonPyObject.Tag.MISSING - return Value(value=NULL, block=block) - - elif isinstance(value, (values.IntermediateValue, values.CompositeValue, values.AbstractPhiValue)): - # For now we discard any information and just treat them as opaque. - # TODO(robieta): refine - return Value(block=block) - - elif isinstance(value, values.ExternalRef) and value.key.is_const: - return get_initial_value(value.key, block=block) - - raise ValueError(f"Cannot convert abstract value: {value}, {protoblock} {protoblock is proto_graph.root=}") - - v = _convert(value, protoblock) - convert_cache[(value, protoblock)] = v - return v - - def make_nodes(protoblock: ProtoBlock, block: Block) -> Iterable[Node]: - for instruction, node_flow in protoblock.flow.materialized.items(): - node = Node( - i=instruction, - inputs=[convert(v, protoblock, block) for v in node_flow.inputs], - outputs=[convert(v, protoblock, block) for v in node_flow.outputs], - ) - node.source_infos = [ - SourceInformation( - orig_file_name=source_file_name, - orig_line_no=instruction.line_no + source_start_line, - orig_end_line_no=instruction.line_no + source_start_line, - gen_line_no=instruction.line_no, - gen_end_line_no=instruction.line_no, - col_offset=0, - end_col_offset=999, - ), - ] - - for output in OrderedSet(node.outputs).difference(node.inputs): - if not (output.node or output.is_const or output.is_global): - # output.node can be populated when we deconstruct a previously constructed value (e.g. binary_idx into a tuple from build_tuple) - output.node = node - - if node.i.opname in ("LOAD_ATTR", "LOAD_METHOD"): - # Once we set `parent` (so PhiValue can traverse through it) - # we can prune these just like all other load instructions. - node.outputs[0].parent = node.inputs[0] - node.outputs[0].name = node.i.argrepr - continue - - elif node.i.opname == "CALL_FUNCTION": - # Note: `super` handling is not currently generic. Corner cases - # such as `super(**{})` or `super_alias = super; super_alias()` - # will not be correctly handled. - # TODO(robieta): handle `super` without load bearing names. 
- if node.i.arg == 0 and isinstance(node.inputs[0].value, Super): - assert self_value is not None, "super() called in free context" - node.outputs[0].value = MROAwareObjectRef(self_value, start_klass=mro_klass) - - elif node.i.opname == "FOR_ITER": - node.outputs[1].node = node - node.outputs[1].name = ".for_item_iter" - - yield node - - # First pass: populate nodes and jump targets. - for protoblock, block in blocks.items(): - block.nodes = list(make_nodes(protoblock, block)) - for target, _ in protoblock.jump_targets: - jump_target = blocks[target] - last_node = block.nodes[-1] - jump_target.jump_sources.append(last_node) - last_node.jump_targets.append(jump_target) - - # Second pass: link blocks. - for protoblock, block in blocks.items(): - block_values = { - k: v - for k, abstract_v in protoblock.flow.begin_state - if isinstance(v := convert(abstract_v, protoblock, block), PhiValue) - } - - block.block_inputs = list(OrderedSet(block_values.values())) - for parent in proto_graph.parents[protoblock]: - parent_state = dict(parent.flow.end_state) - for key, sink in block_values.items(): - source = convert( - parent_state.get(key, values.NonPyObject(values.NonPyObject.Tag.MISSING)), - parent, - block=blocks[parent], - ) - if source.value is not NULL and source not in sink.values: - sink.add_missing_value(v=source, jump_source=blocks[parent].nodes[-1]) - - # Third pass: specify block outputs once we know which Values are passed to another Block. - for protoblock, block in blocks.items(): - outputs = (convert(abstract_value, protoblock, block) for k, abstract_value in protoblock.flow.end_state) - block.block_outputs.update(v for v in outputs if v.phi_values) - - param_keys = tuple(parse.VariableKey(p, parse.VariableScope.LOCAL) for p in arg_ordered_parameters) - missing = { - k: v - for k in proto_graph.root.uses.difference(param_keys) - if k.scope == parse.VariableScope.LOCAL and (v := get_initial_value(k)).value is not NULL - } - assert not missing, f"missing params {missing}" - - gr = Graph(list(blocks.values())) - gr.local_variables_at_start = [get_initial_value(k) for k in param_keys] - - gr.co_name = co_name - # bound_args = [module.forward.__self__] - gr.self_value = self_value - gr.ismethod = self_value is not None - # deal with other flags? - # NESTED, GENERATOR, NOFREE, COROUTINE, ITERABLE_COROUTINE, ASYNC_GENERATOR - gr.co_flags = inspect.CO_OPTIMIZED | inspect.CO_NEWLOCALS - gr.co_argcount = 0 - gr.co_posonlyargcount = 0 - gr.co_kwonlyargcount = 0 - gr.func_defaults = [] - gr.func_kwdefaults = {} - for p in signature.parameters.values(): - if p.kind == inspect.Parameter.POSITIONAL_ONLY: - gr.co_argcount += 1 - gr.co_posonlyargcount += 1 - elif p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: - gr.co_argcount += 1 - elif p.kind == inspect.Parameter.KEYWORD_ONLY: - gr.co_kwonlyargcount += 1 - elif p.kind == inspect.Parameter.VAR_POSITIONAL: - gr.co_flags |= inspect.CO_VARARGS - elif p.kind == inspect.Parameter.VAR_KEYWORD: - gr.co_flags |= inspect.CO_VARKEYWORDS - else: - assert False, f"unknown parameter kind {p.kind}" - - if p.default is not inspect._empty: - if p.kind == inspect.Parameter.KEYWORD_ONLY: - gr.func_kwdefaults[p.name] = p.default - else: - gr.func_defaults.append(p.default) - return gr - - -def acquire_partial( - pfunc: functools.partial, - module: object | None = None, - mro_klass: type | None = None, -) -> Graph: - # This is complicated due to the semantics of calling Python functions. 
-    # The partial wrapper does the following:
-    # def pfunc.__call__(*args, **kwargs):
-    #     kw = pfunc.keywords.copy()
-    #     kw.update(kwargs)
-    #     return pfunc.func(*pfunc.args, *args, **kw)
-
-    # This means:
-    # - positional partial_args are applied from the front and once
-    #   they are bound, they are removed from the signature,
-    # - keyword only args get new defaults,
-    # - binding a positional arg as a keyword arg effectively (i.e. in how
-    #   it can be set in calls) makes that arg and all args to the right
-    #   keyword only.
-    # - things that cannot be bound to parameters may show up in varargs
-    #   or kwargs parameters of the function.
-
-    gr = acquire_method(pfunc.func, module, mro_klass)
-    gr.ensure_links()
-
-    # first we shuffle positional args to kw only if they are in the kwargs of the partial
-    pos_param_names = [v.name for v in gr.local_variables_at_start[: gr.co_argcount]]
-    pos_param_names_to_idx = {n: i for i, n in enumerate(pos_param_names)}
-    kw_pos_param_idx = [pos_param_names_to_idx[k] for k in pfunc.keywords if k in pos_param_names_to_idx]
-    if kw_pos_param_idx:
-        # convert positional default args to kw ones
-        kw_pos_param_min = min(kw_pos_param_idx)
-        if kw_pos_param_min < gr.co_posonlyargcount:
-            raise TypeError(
-                f"cannot bind positional-only argument {pos_param_names[kw_pos_param_min]} as keyword in partial"
-            )
-
-        num_to_kw = gr.co_argcount - kw_pos_param_min
-        if gr.func_defaults:
-            to_kw = gr.func_defaults[-num_to_kw:]
-            del gr.func_defaults[-num_to_kw:]
-            to_kw_names = pos_param_names[-num_to_kw:]
-            gr.func_kwdefaults.update(zip(to_kw_names, to_kw))
-        # convert positional args to kw only
-        gr.co_kwonlyargcount += num_to_kw
-        gr.co_argcount -= num_to_kw
-
-    # deal with positional args. some will be mapped to concrete positional args, some might be added to varargs (*args)
-    if gr.ismethod:
-        arg_start = 1
-        arg_count = gr.co_argcount - 1
-    else:
-        arg_start = 0
-        arg_count = gr.co_argcount
-
-    args_to_bind = pfunc.args[:arg_count]
-    args_for_varargs = pfunc.args[arg_count:]
-
-    # do we need to drop positional default args?
-    posarg_default_start = gr.co_argcount - len(gr.func_defaults)
-    posarg_default_to_delete = len(args_to_bind) + arg_start - posarg_default_start
-    if posarg_default_to_delete > 0:
-        gr.func_defaults = gr.func_defaults[posarg_default_to_delete:]
-
-    bound_values = gr.local_variables_at_start[arg_start : arg_start + len(args_to_bind)]
-    del gr.local_variables_at_start[arg_start : arg_start + len(args_to_bind)]
-
-    for bound_value, arg in zip(bound_values, args_to_bind):
-        bound_value.is_function_arg = False
-        bound_value.is_const = True
-        # TODO: check type?
-        bound_value.value = arg
-        gr.co_argcount -= 1
-        if gr.co_posonlyargcount > 0:
-            gr.co_posonlyargcount -= 1
-
-    # handle keyword arguments to concrete parameters, collect in kwargs those for kw-varargs (**kwargs)
-    param_names_to_idx = {
-        v.name: i for i, v in enumerate(gr.local_variables_at_start[: gr.co_argcount + gr.co_kwonlyargcount])
-    }
-    kwargs = {}
-    for argname, argvalue in pfunc.keywords.items():
-        idx = param_names_to_idx.get(argname, -1)
-        if idx == -1:
-            kwargs[argname] = argvalue
-            continue
-        gr.func_kwdefaults[argname] = argvalue
-
-    # for varargs and kwargs fed from partial we need the following prelude:
-    # TODO: (but maybe we should just have a prelude always for the consts, too...)
- # if it has *varargs: - # TMP1 = LOAD_CONST partial_args_for_varargs (needs to be a tuple) - # varargs = TMP1 + varargs - # if it has **kwargs: - # TMP2 = LOAD_CONST partial_kwargs - # kwargs = partial_kwargs | kwargs - - if args_for_varargs or kwargs: - prelude = Block() - prelude.graph = gr - jump_node = Node(i=parse.ThunderInstruction.make_jump_absolute(None), inputs=[], outputs=[]) - jump_node.source_infos = [ - SourceInformation( - orig_file_name="", # filename? - orig_line_no=0, - orig_end_line_no=0, - gen_line_no=0, - gen_end_line_no=0, - col_offset=0, - end_col_offset=999, - ), - ] - - prelude.nodes.append(jump_node) - jump_target = gr.blocks[0] - assert jump_target.jump_sources[0] is None - jump_target.jump_sources[0] = jump_node - jump_node.jump_targets.append(jump_target) - prelude.jump_sources.append(None) - for i in jump_target.block_inputs: - assert i.jump_sources[0] is None - i.jump_sources[0] = jump_node - else: - prelude = None - - # handle *args (varargs) - if args_for_varargs: - if kw_pos_param_idx: - raise TypeError( - f"partial tried to bind {len(pfunc.args)} positional arguments, but only {arg_count} are allowed after keyword binding" - ) - if not (gr.co_flags & inspect.CO_VARARGS): - raise TypeError( - f"partial tried to bind {len(pfunc.args)} positional arguments, but only {arg_count} are allowed" - ) - # the variable for varargs is at gr.co_argcount + gr.co_kwonlyargcount - v_vararg_param = gr.local_variables_at_start[gr.co_argcount + gr.co_kwonlyargcount] - v_partial_varargs = Value(name="partial_varargs", value=tuple(args_for_varargs), is_const=True) - v_varargs_new = Value(name="varargs_with_partial", block=prelude) # type is tuple - pv = PhiValue([v_vararg_param], [None], block=prelude) - new_n = Node( - i=get_instruction(opname="BINARY_ADD", arg=None), - inputs=[v_partial_varargs, pv], - outputs=[v_varargs_new], - ) - # line number? - new_n.source_infos = [ - SourceInformation( - orig_file_name="", # filename? - orig_line_no=0, - orig_end_line_no=0, - gen_line_no=0, - gen_end_line_no=0, - col_offset=0, - end_col_offset=999, - ), - ] - prelude.nodes.insert(0, new_n) - prelude.block_outputs.add(v_varargs_new) - # replace v_vararg_param with v_varargs_new in remainder - replace_values(gr, {v_vararg_param: v_varargs_new}) - prelude.block_inputs.append(pv) - - # handle **kwargs - if kwargs: - if not (gr.co_flags & inspect.CO_VARKEYWORDS): - raise TypeError( - f"function does not have **kwargs but partial tries to bind unknown keywords {tuple(kwargs)}." - ) - - # the variable for varargs is at gr.co_argcount + gr.co_kwonlyargcount - v_kwvararg_param = gr.local_variables_at_start[ - gr.co_argcount + gr.co_kwonlyargcount + (1 if gr.co_flags & inspect.CO_VARARGS else 0) - ] - v_partial_kwvarargs = Value(name="partial_kwvarargs", value=kwargs, is_const=True) - v_kwvarargs_new = Value(name="kwvarargs_with_partial", block=prelude) # type is dict - pv = PhiValue([v_kwvararg_param], [None], block=prelude) - new_n = Node( - i=get_instruction(opname="BINARY_OR", arg=None), - inputs=[v_partial_kwvarargs, pv], - outputs=[v_kwvarargs_new], - ) - # line number? - new_n.source_infos = [ - SourceInformation( - orig_file_name="", # filename? 
- orig_line_no=0, - orig_end_line_no=0, - gen_line_no=0, - gen_end_line_no=0, - col_offset=0, - end_col_offset=999, - ), - ] - prelude.nodes.insert(-1, new_n) - prelude.block_outputs.add(v_kwvarargs_new) - # replace v_vararg_param with v_varargs_new in remainder - replace_values(gr, {v_kwvararg_param: v_kwvarargs_new}) - prelude.block_inputs.append(pv) - - if prelude: - gr.blocks.insert(0, prelude) - return gr - - -@functools.cache -def _construct_protograph(func): - """Protoblocks are parse level constructs, so it is safe to reuse them.""" - return apply_protograph_passes(ProtoGraph.from_code(func.__code__)) - - -@record -def acquire_method( - method: Callable, - module: object | None = None, - mro_klass: type | None = None, -) -> Graph: - assert SUPPORTS_PREPROCESSING, sys.version_info - if isinstance(method, functools.partial): - return acquire_partial(method, module, mro_klass) - if callable(method) and not inspect.ismethod(method) and not inspect.isfunction(method): - method = method.__call__ - - method_self, func = (method.__self__, method.__func__) if inspect.ismethod(method) else (None, method) - assert not inspect.ismethod(func) - - module = module or method_self - if mro_klass is None and module is not None: - mro_klass = type(module) - - gr = _bind_to_graph(_construct_protograph(func), func, method_self, mro_klass) - gr.source_start_line = 1 - try: - gr.source_lines, _ = inspect.getsourcelines(method) - except OSError: - gr.source_lines = ["# Failed to extract source."] - - gr.method = method - gr.module = module - gr.mro_klass = mro_klass - if debug_asserts_enabled(): - check_graph(gr) - return gr - - -def remove_unused_values(gr: Graph) -> None: - gr.ensure_links() - - def remove_value(v: Value) -> None: - for pv in v.phi_values: - bl = pv.block - pv.remove_value(v) - if not pv.values: - remove_value(pv) - bl.block_inputs.remove(pv) - if pv in bl.block_outputs: - bl.block_outputs.remove(pv) - - for i in gr.blocks[0].block_inputs: - if len(i.values) == 1 and i.values[0] is None: - remove_value(i) - - gr.blocks[0].block_inputs = [i for i in gr.blocks[0].block_inputs if len(i.values) != 1 or i.values[0] is not None] - - values_used = set() - - INDEX_OPS = {"BINARY_SUBSCR"} - - def mark_used(v: Value) -> None: - if v in values_used: - return - values_used.add(v) - if v.node and v.node.i.opname in INDEX_OPS: - for i in v.node.inputs: - mark_used(i) - if v.parent is not None: - mark_used(v.parent) - if isinstance(v, PhiValue): - for w in v.values: - mark_used(w) - - for bl in gr.blocks: - for n in bl.nodes: - if n.i.opname not in INDEX_OPS: - for i in n.inputs: - mark_used(i) - - for bl in gr.blocks: - for i in bl.block_inputs[:]: - if i not in values_used: - for v in i.values[:]: - if v is not None: - i.remove_value(v) - bl.block_inputs.remove(i) - bl.block_outputs = OrderedSet(o for o in bl.block_outputs if o in values_used) - for n in bl.nodes[:]: - if n.i.opname in INDEX_OPS and not any((o in values_used) for o in n.outputs): - bl.nodes.remove(n) - for i in gr.local_variables_at_start: - if i is not None: - i.phi_values = [pv for pv in i.phi_values if pv in values_used] - - for bl in gr.blocks: - for n in bl.nodes: - for o in n.outputs: - o.phi_values = [pv for pv in o.phi_values if pv in values_used] - - # remove things only used in current block (and not in own phi) from outputs - # TODO: think if this would obsolete the above - outputs_used = set() - for bl in gr.blocks: - for i in bl.block_inputs: - assert isinstance(i, PhiValue) - for v in i.values: - outputs_used.add(v) - 
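-    # Descriptive note: at this point `outputs_used` holds every value consumed
-    # by some block's phi nodes; outputs referenced nowhere else are dropped below.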
for bl in gr.blocks: - bl.block_outputs = OrderedSet(o for o in bl.block_outputs if o in outputs_used) - - if debug_asserts_enabled(): - check_graph(gr) diff --git a/thunder/core/script/graph.py b/thunder/core/script/graph.py deleted file mode 100644 index 24c22746df..0000000000 --- a/thunder/core/script/graph.py +++ /dev/null @@ -1,819 +0,0 @@ -# This is a "TorchScript-like" graph representation of Python IR. -# The idea is that blocks are "simple blocks" in terms of the code flow graph, -# i.e. without branches -import collections -import copy -import enum -import inspect -import linecache -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Set, Union -from collections.abc import Iterable, Iterator, Sequence - -from thunder.core.script.instrumentation import InstrumentingBase -from thunder.core.script.parse import ThunderInstruction -from thunder.core.script.noinline import noinline -from thunder.core.utils import OrderedSet - -if TYPE_CHECKING: - import graphviz - -GraphObject = Union["Value", "Node", "Block"] - - -def assert_value(v: GraphObject | None) -> "Value": - assert isinstance(v, Value) - return v - - -def assert_node(n: GraphObject | None) -> "Node": - assert isinstance(n, Node) - return n - - -def assert_block(bl: GraphObject | None) -> "Block": - assert isinstance(bl, Block) - return bl - - -class GraphSummaryCallback: - def node(self, n: "Node") -> tuple[list[str], list[str]]: - return [], [] - - def finish(self) -> list[str]: - return [] - - -class NULL: - """marker for non-existant object.""" - - pass - - -@dataclass -class SourceInformation: - orig_line_no: int - orig_end_line_no: int - - gen_line_no: int - gen_end_line_no: int - # gen_file_name? --> could be interesting when passing SourceInfo to traces - - col_offset: int - end_col_offset: int - orig_file_name: str = "" - source: Any | None = None - - -class MROAwareObjectRef: # or as they call it super - def __init__(self, obj: Any, start_klass: type | None = None): - self.obj = obj - self.start_klass = start_klass - - def __getattr__(self, name: str) -> Any: - ## handle non-methods... - i = 0 - mro = inspect.getmro(self.obj.value.__class__) - if self.start_klass is not None: - while i < len(mro) and not mro[i] == self.start_klass: - i += 1 - i += 1 - while i < len(mro) and not hasattr(mro[i], name): - i += 1 - if i >= len(mro): - raise AttributeError(f"{name} not a member") - return getattr(mro[i], name) - - -# Represent undefined values e.g. non-existent attrs etc. -# this can be inserted as a (const) value and will then be -# translated into raising an error at runtime -class _Undefined: - def __init__(self, value, attr): - self.value = value - self.attr = attr - - -# Values are -# - function arguments as inputs to the graph (including self) -# - constants and globals -# - intermediate results / local variables -# - attributes of other values given in .parent -# they can be used -# - as inputs and outputs of nodes (but inplace is still tricky) -# - as block_outputs (note that block_outputs can be either outputs of nodes -# or attribute lookups). -# block_outputs (and only these) typically have .phi_values recorded. -# PhiValues are the block_inputs. -# - they have (one or multiple) block_outputs as .values, these are set at the -# .jump_sources (TODO: .jump_sources records None for non-node-generated). -# - There must be a 1-1 correspondence between .phi_values-> and .values->. 
-# All block_inputs (at least before an optimization pass towards the un-ssa-ing) -# are expected to be PhiValues and all PhiValues are expected to show up as -# block_inputs. -class Value(InstrumentingBase): - def __init__( - self, - *, - node: Optional["Node"] = None, - block: Optional["Block"] = None, - nr: int | None = None, - typ: type | None = None, - value: Any = None, - name: str | None = None, - parent: Optional["Value"] = None, - is_global: bool = False, - is_const: bool = False, - is_function_arg: bool = False, - ): - self.node = node - self.block = block - self.nr = nr - self.typ = typ if typ is not None or value in (None, NULL) else type(value) - self.value = value - self.name = name - self.parent = parent - self.is_global = is_global - self.is_const = is_const - self.is_function_arg = is_function_arg - self.phi_values: list["PhiValue"] = [] - assert not (block is None and not (is_global or is_const or is_function_arg)) - - def resolve(self) -> tuple["Value", ...]: - return (self,) - - def clone(self, translation_dict: dict[GraphObject, GraphObject] | None = None) -> "Value": - # clones a value, including (recursively) parent value - # uses translation_dict to look up parent value - # updates translation_dict - # does not register phi_values on the clone - # always clone parents? - if translation_dict is None: - translation_dict = {} - if self in translation_dict: - return assert_value(translation_dict[self]) - parent = self.parent - if parent: - if parent in translation_dict: - parent = assert_value(translation_dict[parent]) - else: - parent = parent.clone(translation_dict=translation_dict) - v = Value( - node=self.node, - block=self.block, - nr=self.nr, - typ=self.typ, - value=self.value, - name=self.name, - parent=parent, - is_global=self.is_global, - is_const=self.is_const, - is_function_arg=self.is_function_arg, - ) - if translation_dict is not None: - translation_dict[self] = v - return v - - def __str__(self, _value_printer=str) -> str: - parts = [] - if self.is_function_arg: - parts.append("funcarg") - if self.name: - parts.append(f"name={self.name}") - if self.typ is not None: - parts.append(f"typ={self.typ}") - if self.value is not None: - parts.append(f"value of type {type(self.value)}") - if self.is_const: - parts.append("const") - if self.is_global: - parts.append("global") - # if self.block is None: - # parts.append("block-None") - if self.parent is not None: - parts.append(f"parent={_value_printer(self.parent)}") - return f"""{type(self).__name__} {hex(id(self))} ({' '.join(parts)})""" - - def __repr__(self) -> str: - return f"{super().__repr__()[:-1]} {self}>" - - -class PhiValue(Value): - # node? - def __init__( - self, - values: list[Value], - jump_sources: Sequence[Optional["Node"]], - block: "Block", - _unfinished_clone: bool = False, - ): - super().__init__(block=block) - self.block: Block = block # duplicate assignment / declaration? 
- self._unfinished_clone = _unfinished_clone - self._set_values_jump_sourcess(values, jump_sources) - - def _set_values_jump_sourcess(self, values: list[Value], jump_sources: Sequence[Optional["Node"]]) -> None: - assert len(values) == len(jump_sources) - self.values = list(values) - if not self._unfinished_clone: - for v in self.values: - if v is not None: - v.phi_values.append(self) - self.jump_sources = list(jump_sources) - - def resolve(self) -> tuple[Value, ...]: - to_process = [self] - seen: OrderedSet[Value] = OrderedSet() - while to_process: - seen.add(v := to_process.pop()) - if isinstance(v, PhiValue): - to_process.extend(vi for vi in v.values if vi not in seen) - - return tuple(i for i in seen if not isinstance(i, PhiValue)) - - def clone(self, translation_dict: dict[GraphObject, GraphObject] | None = None) -> "PhiValue": - # due to loops in the Graph, this is complicated: - # we do not translate values or jump_sources here, but do - # translate blocks. - if translation_dict is None: - translation_dict = {} - if self in translation_dict: - v = translation_dict[self] - assert isinstance(v, PhiValue) - return v - v = PhiValue(self.values, self.jump_sources, assert_block(translation_dict[self.block]), _unfinished_clone=True) - translation_dict[self] = v - return v - - def post_process_clone(self, *, translation_dict: dict[GraphObject, GraphObject]) -> None: - assert self._unfinished_clone - self._unfinished_clone = False - self._set_values_jump_sourcess( - [assert_value(translation_dict.get(v, v)) for v in self.values], - [(assert_node(translation_dict.get(js, js)) if js is not None else None) for js in self.jump_sources], - ) - - def add_missing_value( - self, v: Value, idx: int | None = None, jump_source: Optional["Node"] = None - ) -> None: # None: append - if idx is None: - assert v not in self.values - self.values.append(v) - v.phi_values.append(self) - self.jump_sources.append(jump_source) - else: - assert 0 <= idx < len(self.values) - assert self.values[idx] is None - assert jump_source is None - self.values[idx] = v - v.phi_values.append(self) - - def remove_value(self, v: Value) -> None: - idx = self.values.index(v) - v.phi_values.remove(self) - del self.values[idx] - del self.jump_sources[idx] - - def replace_value(self, v_old: Value, v_new: Value) -> None: - if v_old is v_new: - return - - assert v_new not in self.values - idx = self.values.index(v_old) - self.values[idx] = v_new - assert (v_new.is_function_arg or v_new.is_const) or v_new.block.graph is self.block.graph # v_old.block.graph - if v_new.is_function_arg or v_new.is_const: - # TV-TODO: this is actually dubious for constants and we should avoid it - self.jump_sources[idx] = None - else: - self.jump_sources[idx] = v_new.block.nodes[-1] - - v_old.phi_values.remove(self) - v_new.phi_values.append(self) - - -# A node corresponds to one Python bytecode instruction given in .i -# it has Values as .inputs and .outputs -class Node(InstrumentingBase): - def __init__( - self, - *, - i: ThunderInstruction, - inputs: list[Value] | None = None, - outputs: list[Value] | None = None, - source_infos: list[SourceInformation] | None = None, - ): - self.i = i - self.inputs: list[Value] = inputs if inputs is not None else [] - self.outputs: list[Value] = outputs if outputs is not None else [] - self.jump_targets: list[Block] = [] - self.source_infos: list[SourceInformation] = source_infos if source_infos is not None else [] - self.block: Block | None = None - - def clone(self, translation_dict: dict[GraphObject, GraphObject] | 
None = None) -> "Node":
-        """.block of the clone will be None if block is not in translation dict."""
-        if translation_dict is None:
-            translation_dict = {}
-        if self in translation_dict:
-            return assert_node(translation_dict[self])
-        inputs = [i.clone(translation_dict=translation_dict) for i in self.inputs]
-        outputs = [o.clone(translation_dict=translation_dict) for o in self.outputs]
-        i = copy.copy(self.i)
-        n2 = Node(i=i, inputs=inputs, outputs=outputs)
-        n2.source_infos = copy.deepcopy(self.source_infos)
-        n2.jump_targets = [assert_block(translation_dict.get(bl, bl)) for bl in self.jump_targets]
-        if self.block is None:
-            n2.block = None
-        else:
-            bl2 = translation_dict.get(self.block)
-            assert bl2 is None or isinstance(bl2, Block)
-            n2.block = bl2
-        translation_dict[self] = n2
-        return n2
-
-    def set_jump_target(self, jt: "Block", idx: int | None = None) -> None:
-        # TODO: more validation?
-        # is_jump = (self.i.opname not in unconditional_jump_names) or (idx == 1) or (idx is None and self.jump_targets)
-        # assert is_jump
-
-        if idx is None:
-            assert len(self.jump_targets) <= 1
-            self.jump_targets.append(jt)
-        else:
-            old_jt = self.jump_targets[idx]
-            old_jt.jump_sources.remove(self)
-            self.jump_targets[idx] = jt
-            jt.jump_sources.append(self)
-
-    def __str__(self) -> str:
-        # i.i.offset // 2, i.i.opname, i.i.arg, "(", i.i.argval, ")"
-        if self.i.opname in {"CALL_METHOD", "CALL_FUNCTION"}:
-            return f"{self.i.opname}({self.inputs})"
-        return f"{self.i.opname} {self.i.arg} ({self.i.argval})"  # str(self.i)
-
-    def __repr__(self) -> str:
-        return f"{super().__repr__()[:-1]} {self}>"
-
-
-# Blocks have the first instruction (only) as the jump target
-# (or the function entry point)
-# Blocks always have a single final instruction that jumps (or RETURN)
-# conditional jumps (including e.g. FOR_ITER) always have the non-jumping
-# target first and then the jumping target.
-# The jump targets are other blocks and are attributes of the jump instruction.
-class Block:
-    def __init__(self):
-        self.jump_sources: list[Node | None] = []
-        self.nodes: list[Node] = []
-        self.block_inputs: list[Value] = []
-        self.block_outputs = OrderedSet([])
-
-    def __str__(self) -> str:
-        return "\n".join([f" Block (reached from {self.jump_sources})"] + [" " + str(n) for n in self.nodes])
-
-    def __repr__(self) -> str:
-        return f"{super().__repr__()[:-1]} {self}>"
-
-    def insert_node(self, n: Node, insert_after: Node | None = None, insert_before: Node | None = None) -> None:
-        assert n.block is None
-        assert (insert_after is None) != (insert_before is None), f"{insert_after=} {insert_before=}"
-        to_find = insert_after or insert_before
-        for idx, n2 in enumerate(self.nodes):
-            if n2 is to_find:
-                break
-        if n2 is not to_find:
-            raise ValueError(f"could not find node {n}")
-
-        # validity checks? (also above)
-        n.block = self
-        if insert_after:
-            self.nodes.insert(idx + 1, n)
-        else:
-            self.nodes.insert(idx, n)
-
-
-# A graph contains Blocks.
-# The first block (.blocks[0]) is the entry point. Other blocks are connected
-# through jump instructions.
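-# A sketch of how these pieces are typically walked (illustrative only, given
-# a Graph `gr` built as below):
-#     for block in gr.blocks:
-#         for node in block.nodes:
-#             print(node.i.opname, [str(v) for v in node.inputs])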
-class Graph(InstrumentingBase): - def __init__(self, blocks: list[Block] | None = None): - self.blocks = [] if blocks is None else blocks[:] - - def __str__(self) -> str: - return "\n".join(["Graph of"] + [str(b) for b in self.blocks]) - - def __repr__(self) -> str: - return f"{super().__repr__()[:-1]} {self}>" - - def nodes(self) -> Iterator[Node]: - for b in self.blocks: - yield from b.nodes - - def ensure_links(self) -> None: - for bl in self.blocks: - bl.graph = self - for n in bl.nodes: - n.block = bl - inps = set(n.inputs) - for o in n.outputs: - if o not in inps: # not for inplace - o.block = bl - o.node = n - for o in bl.block_outputs: - if not (o.is_const or o.is_function_arg): - o.block = bl - for i in bl.block_inputs: - i.block = bl - - def clone(self) -> tuple["Graph", dict[GraphObject, GraphObject]]: - bls2, translation_dict = clone_blocks(self.blocks) - g2 = Graph(blocks=bls2) - g2.local_variables_at_start = [v.clone() for v in self.local_variables_at_start] - replace_values(g2, {k: v for k, v in zip(self.local_variables_at_start, g2.local_variables_at_start)}) - g2.ismethod = self.ismethod - g2.co_name = self.co_name - g2.co_argcount = self.co_argcount - g2.co_flags = self.co_flags - g2.co_posonlyargcount = self.co_posonlyargcount - g2.co_kwonlyargcount = self.co_kwonlyargcount - g2.func_defaults = self.func_defaults[:] - g2.func_kwdefaults = self.func_kwdefaults.copy() - g2.method = self.method - g2.module = self.module - g2.mro_klass = self.mro_klass - g2.self_value = self.self_value - g2.source_start_line = self.source_start_line - g2.source_lines = self.source_lines[:] - - return g2, translation_dict - - def print(self) -> None: - value_counter = 1 - print(self.local_variables_at_start) - for bl in self.blocks: - for n in bl.nodes: - for o in n.outputs: - o.print_name = f"{o.name}:{value_counter}" if o.name is not None else f":{value_counter}" - value_counter += 1 - for i in n.inputs: - if not hasattr(i, "print_name"): - i.print_name = f"{i.name}:{value_counter}" if i.name is not None else f":{value_counter}" - value_counter += 1 - av = f"[{n.i.argval}]" if n.i.argval is not None else "" - print( - ",".join(o.print_name for o in n.outputs), - "=", - n.i.opname, - f"{av}(", - ", ".join([i.print_name for i in n.inputs]) + ")", - ) - - def summary(self, print_lines: bool = False, callback=GraphSummaryCallback()) -> None: - type_count = collections.Counter() - results = {} - - def get_name(v): - if v not in results: - idx = type_count[type(v)] - type_count[type(v)] += 1 - prefix = {PhiValue: "𝜙", Value: "V"}.get(type(v), type(v).__name__) - results[v] = (prefix, idx) - - # Populate cache - if isinstance(v, PhiValue): - _ = [get_name(vi) for vi in v.values] - if v.parent is not None: - _ = get_name(v.parent) - - return "{}_{}".format(*results[v]) - - graph_lines = [] - legend_lines = [] - - block_indices = {bl: i for i, bl in enumerate(self.blocks)} - block_jump_indices = {bl.nodes[-1]: i for i, bl in enumerate(self.blocks)} - block_jump_indices[None] = None - - for block in self.blocks: - graph_lines.extend( - ( - f"Block {block_indices[block]} reached from blocks {[block_jump_indices.get(js, 'unknown') for js in block.jump_sources]}", - f"Block inputs: {[get_name(i) for i in block.block_inputs]}", - f"Block outputs: {[get_name(i) for i in block.block_outputs]}", - ) - ) - for i, node in enumerate(block.nodes): - if ( - i == 0 - or node.source_infos - and ( - (not block.nodes[i - 1].source_infos) - or node.source_infos[-1] != block.nodes[i - 1].source_infos[-1] - ) - ): - 
                    line_no = node.source_infos[-1].gen_line_no
-                    line = f"# l{line_no + self.source_start_line:3d} {self.source_lines[line_no].rstrip()}"
-                else:
-                    line = ""
-                lines_before, lines_after = callback.node(node)
-                graph_lines.extend(lines_before)
-                graph_lines.append(
-                    f" {node.i.opname:<20} {f'{[get_name(v) for v in node.inputs]} -> {[get_name(v) for v in node.outputs]}':<80} {line}"
-                )
-                graph_lines.extend(lines_after)
-            graph_lines.append("")
-        graph_lines.extend(callback.finish())
-
-        for v, (prefix, idx) in sorted(results.items(), key=lambda x: x[1]):
-            values = f"[{', '.join(get_name(vi) for vi in v.values)}]" if isinstance(v, PhiValue) else ""
-            legend_lines.append(f"{prefix}_{idx} {v.__str__(_value_printer=get_name):<16} {values}")
-
-        if print_lines:
-            print("\n".join(graph_lines) + "\n" + "\n".join(legend_lines))
-
-        return tuple(graph_lines), tuple(legend_lines)
-
-
-def unify_values(values: list[Value], jump_sources: list[Node], bl: Block, all_predecessors_done: bool = True) -> Value:
-    if all_predecessors_done:
-        if len(values) == 1:
-            return values[0]
-        val = values[0]
-        if all(v is val for v in values[1:]):
-            return val
-    # different values
-    return PhiValue(values, jump_sources, bl)
-
-
-def insert_before(new_n: Node, n: Node) -> None:
-    bl = assert_block(n.block)
-    idx = bl.nodes.index(n)
-    bl.nodes.insert(idx, new_n)
-    new_n.block = n.block
-
-
-def insert_after(new_n: Node, n: Node) -> None:
-    bl = assert_block(n.block)
-    idx = bl.nodes.index(n)
-    bl.nodes.insert(idx + 1, new_n)
-    new_n.block = n.block
-
-
-def replace_values(gr_or_bl: Graph | Block, value_map: dict[Value, Value], follow_phi_values: bool = False) -> None:
-    ### Replacing a value:
-    # - as inputs/outputs of nodes
-    # - value.parent for other values
-    # - phi nodes
-    # - graph input (?) / initial vars
-    processed = set()
-
-    def map_values(v: Value) -> Value:
-        # do not call map_values without guarding for infinite recursion
-        if v in processed:
-            return value_map.get(v, v)
-        processed.add(v)
-
-        if v in value_map:
-            if follow_phi_values:
-                for pv in v.phi_values[:]:
-                    pv.replace_value(v, value_map[v])
-                    assert len(pv.values) == len(pv.jump_sources)
-            return value_map[v]
-
-        if isinstance(v.value, MROAwareObjectRef):
-            v.value.obj = map_values(v.value.obj)
-        if v.parent is not None:
-            v.parent = map_values(v.parent)
-        if isinstance(v, PhiValue):
-            for ov in v.values:
-                nv = map_values(ov)
-                v.replace_value(ov, nv)
-                assert len(v.values) == len(v.jump_sources)
-        return v
-
-    def process_block(bl: Block) -> None:
-        bl.block_inputs = [map_values(vv) for vv in bl.block_inputs]
-        for n in bl.nodes:
-            n.inputs = [map_values(vv) for vv in n.inputs]
-            n.outputs = [map_values(vv) for vv in n.outputs]
-        bl.block_outputs = OrderedSet(map_values(vv) for vv in bl.block_outputs)
-
-    if isinstance(gr_or_bl, Graph):
-        for bl in gr_or_bl.blocks:
-            process_block(bl)
-    elif isinstance(gr_or_bl, Block):
-        process_block(gr_or_bl)
-    else:
-        raise TypeError("replace_values works on Graph or Block objects")
-
-
-## TODO: or should this be a method?
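-# Illustrative use of replace_values (with hypothetical values old_v/new_v):
-#     replace_values(gr, {old_v: new_v})
-# rewrites node inputs/outputs, block inputs/outputs, phi members and value
-# parents throughout the graph; passing a single Block restricts the rewrite.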
-def make_dot(gr: Graph, format: str = "png", add_names: bool = False) -> "graphviz.Digraph": - import graphviz - - dot = graphviz.Digraph(name="thunder_graph", format=format) - - block_idxes = {} - - value_idxes: dict[Value, int] = {} - - for i_bl, bl in enumerate(gr.blocks): - block_idxes[bl] = i_bl - with dot.subgraph(name=f"cluster_bl_{i_bl}") as sub_dot: - for i_i, i in enumerate(bl.block_inputs): - i_nr = len(value_idxes) - value_idxes[i] = i_nr - i_name = f"bi %{i_nr}" - if add_names: - i.name = i_name - v_color = "black" if i not in bl.block_outputs else "red" - sub_dot.node(f"v {i_nr}", label=i_name, color=v_color) - - for i_n, n in enumerate(bl.nodes): - label = n.i.opname - if n.i.opname == "CALL_METHOD": - assert n.inputs[0].name is not None - label = "CM " + n.inputs[0].name - elif n.i.opname == "CALL_FUNCTION" and n.inputs[0].name: - label = "CF " + n.inputs[0].name - sub_dot.node(f"i {i_bl} {i_n}", label, shape="box") - for o in n.outputs: - if o not in value_idxes: - o_nr = len(value_idxes) - value_idxes[o] = o_nr - o_name = o.name or f"%{o_nr}" - if add_names: - o.name = o_name - v_color = "black" if o not in bl.block_outputs else "red" - sub_dot.node(f"v {o_nr}", label=o_name, color=v_color) - else: - o_nr = value_idxes[o] - sub_dot.edge(f"i {i_bl} {i_n}", f"v {o_nr}", color="blue") - if i_n > 0: - sub_dot.edge(f"i {i_bl} {i_n - 1}", f"i {i_bl} {i_n}") - - for i_bl, bl in enumerate(gr.blocks): - for jt_bl in bl.nodes[-1].jump_targets: - dot.edge(f"i {i_bl} {len(bl.nodes) - 1}", f"i {block_idxes[jt_bl]} {0}") - for i in bl.block_inputs: - i_idx = value_idxes[i] - if isinstance(i, PhiValue): - for v in i.values: - if v in value_idxes: - dot.edge(f"v {value_idxes[v]}", f"v {i_idx}", color="green") - - for i_n, n in enumerate(bl.nodes): - for i in n.inputs: - if i in value_idxes: - dot.edge(f"v {value_idxes[i]}", f"i {i_bl} {i_n}", color="blue") - elif isinstance(i, PhiValue): - assert False, "This should be removed?" 
- for v in i.values: - if v in value_idxes: - dot.edge(f"v {value_idxes[v]}", f"i {i_bl} {i_n}", color="red") - - return dot - - -def clone_blocks( - blocks_to_clone: list[Block], translation_dict: dict[GraphObject, GraphObject] | None = None -) -> tuple[list[Block], dict[GraphObject, GraphObject]]: - if translation_dict is None: - translation_dict = {} - - blocks_todo = [] - for obl in blocks_to_clone: - if obl not in translation_dict: - bl = Block() - translation_dict[obl] = bl - blocks_todo.append(obl) - - for obl in blocks_todo: - bl = assert_block(translation_dict[obl]) - bl.block_inputs = [i.clone(translation_dict=translation_dict) for i in obl.block_inputs] - bl.block_outputs = OrderedSet(o.clone(translation_dict=translation_dict) for o in obl.block_outputs) - bl.nodes = [n.clone(translation_dict=translation_dict) for n in obl.nodes] - for obl in blocks_todo: - bl = assert_block(translation_dict[obl]) - for js in obl.jump_sources: - if js is None: - bl.jump_sources.append(None) - elif js in translation_dict: - bl.jump_sources.append(assert_node(translation_dict[js])) - - for i in bl.block_inputs: - i.post_process_clone(translation_dict=translation_dict) - return [assert_block(translation_dict[bl]) for bl in blocks_to_clone], translation_dict - - -def _check_graph(gr: Graph) -> None: - # some sanity checks for the values - import collections - - phi_value_refs: dict[PhiValue, list[Value | tuple[Value, Node | None]]] = collections.defaultdict(list) - v: Value - known_nodes: set[Node] = set() - for bl in gr.blocks: - known_values: set[Value] = set(bl.block_inputs) - for i in bl.block_inputs: - for v in i.phi_values: - phi_value_refs[v].append(i) - for n in bl.nodes: - known_nodes.add(n) - assert n.source_infos, f"{n}({n.inputs}) does not have source infos" - n.block = bl - for i in n.inputs: - i_or_p = i - while not (i_or_p in known_values or i_or_p.is_const or i_or_p.is_global): - if i_or_p.parent is not None: - i_or_p = i_or_p.parent - else: - raise RuntimeError(f"unknown value {repr(i_or_p)} needed in {n}") - - for o in n.outputs: - known_values.add(o) - # inplace modified values are not re-assigned. 
should they, likely: yes
-                if o not in n.inputs:
-                    for v in o.phi_values:
-                        phi_value_refs[v].append((o, n))
-        for o in bl.block_outputs:
-            is_attr = False
-            o_or_parent = o
-            while o_or_parent not in known_values and o_or_parent.parent is not None:
-                o_or_parent = o_or_parent.parent
-                is_attr = True
-            if is_attr:
-                for v in o.phi_values:
-                    phi_value_refs[v].append((o, None))
-            assert (
-                o_or_parent in known_values or o_or_parent.is_const or o_or_parent.is_global
-            ), f"{o_or_parent} (from {o}) unknown {known_values=}"
-
-    for bl in gr.blocks:
-        for i in bl.block_inputs:
-            assert isinstance(i, PhiValue)
-            assert len(i.jump_sources) == len(i.values)
-            assert len(i.values) > 0
-            # assert i.block is bl
-            pvr = phi_value_refs.get(i, [])
-            assert len([v for v in i.values if not (v.is_function_arg or v.is_const or v.is_global)]) == len(
-                pvr
-            ), f"phi value {repr(i)} source count {len(i.values)} does not match sets {pvr}, {i.values}"
-            if i in phi_value_refs:  # not for function args in first block
-                del phi_value_refs[i]
-            for v in i.values:
-                assert i in v.phi_values, f"phi value {repr(i)} not in phi_values of {repr(v)}"
-            for js in i.jump_sources:
-                assert js is None or js in known_nodes, f"phi value {repr(i)} jump source not found in graph {repr(js)}"
-
-    assert not phi_value_refs, f"phi_values not found {phi_value_refs}"
-
-    jump_targets: dict[Node | None, set[Block]] = {}
-    jump_targets[None] = {gr.blocks[0]}  # function entry point
-
-    for bl in gr.blocks:
-        for n in bl.nodes[:-1]:
-            assert not n.jump_targets
-        n = bl.nodes[-1]
-        if n.i.opname in {"RETURN_VALUE", "RAISE_VARARGS", "RERAISE"}:
-            assert not n.jump_targets
-        else:
-            assert 1 <= len(n.jump_targets) <= 2, f"{n} should have one or two jump targets, but has {n.jump_targets}"
-            jump_targets[n] = {jt for jt in n.jump_targets}
-            assert len(n.jump_targets) == len(jump_targets[n])
-
-    for bl in gr.blocks:
-        for js in bl.jump_sources:
-            js_jt = jump_targets[js]
-            js_jt.remove(bl)
-
-    assert not any(jump_targets.values()), f"{jump_targets} should be all empty"
-    assert tuple(gr.blocks[0].jump_sources) == (None,), gr.blocks[0].jump_sources
-
-
-def repr_source_location(gr: Graph, source_infos: list[SourceInformation]):
-    l = []
-    for si in source_infos:
-        l.append(f"file: {si.orig_file_name}, line {si.orig_line_no}:")
-        ls = linecache.getlines(si.orig_file_name)
-        l.append(ls[max(si.orig_line_no - 1, 0)].rstrip())
-    return "\n".join(l)
-
-
-def check_graph(gr: Graph) -> None:
-    try:
-        _check_graph(gr)
-        cloned, _ = gr.clone()
-        _check_graph(cloned)
-    except BaseException:
-        print()
-        gr.summary(print_lines=True)
-        raise
-
-
-def _generate_raises(msg):
-    @noinline
-    def _raise():
-        raise AttributeError(msg)
-
-    return _raise
diff --git a/thunder/core/script/instrumentation.py b/thunder/core/script/instrumentation.py
deleted file mode 100644
index d7c7ef182b..0000000000
--- a/thunder/core/script/instrumentation.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import contextlib
-import functools
-import inspect
-import logging
-import threading
-import typing
-
-from thunder.core.utils import debug_asserts_enabled
-
-
-T = typing.TypeVar("T")
-_STORAGE = threading.local()
-
-
-def _lookup_state(name: str, factory: typing.Callable[[], T]) -> T:
-    if not hasattr(_STORAGE, name):
-        setattr(_STORAGE, name, factory())
-    return getattr(_STORAGE, name)
-
-
-get_stack = functools.partial(_lookup_state, "stack", list)
-get_init_ctx = functools.partial(_lookup_state, "init_ctx", dict)
-get_error_ctx = functools.partial(_lookup_state, "error_ctx", list)
-get_logger = functools.partial(_lookup_state, "logger", lambda: logging.error) - - -class InstrumentingBase: - def __new__(cls, *_, **__) -> "InstrumentingBase": - self = super().__new__(cls) - if stack := get_stack(): - get_init_ctx()[id(self)] = (self, tuple(stack)) - - return self - - def _concise_repr(self) -> str: - return f"<{self.__class__.__name__} object at {hex(id(self))}>" - - -def emit_ctx(v, follow_delegates: bool): - for f, args, kwargs, delegate_to in reversed(get_init_ctx()[id(v)][1]): - signature = inspect.signature(f) - bound = signature.bind(*args, **kwargs) - bound.apply_defaults() - - def fmt_arg(k, v): - if v is signature.parameters[k].default: - return "..." - - if isinstance(v, InstrumentingBase): - v_repr = v._concise_repr() - elif callable(v) and hasattr(v, "__name__"): - v_repr = v.__name__ - else: - v_repr = repr(v) - - return v_repr - - if delegate_to is None or not follow_delegates: - arg_str = ", ".join(fmt_arg(k, v) for k, v in bound.arguments.items()) - yield f" {f.__name__:<30} {arg_str}" - - else: - x = bound.arguments[delegate_to] - yield f" {f.__name__:<30} {fmt_arg(delegate_to, x)}" - yield from emit_ctx(x, follow_delegates) - break - - -def maybe_flush_errors(): - if not get_stack() and (error_ctx := get_error_ctx()): - get_logger()("\n".join(reversed(error_ctx)) + "\n") - error_ctx.clear() - - -@contextlib.contextmanager -def intercept_errors(): - prior_logger = get_logger() - errors = [] - try: - _STORAGE.logger = lambda s: errors.append(s) - yield errors - finally: - _STORAGE.logger = prior_logger - - -def verbose_error(f): - @functools.wraps(f) - def wrapped(*args, **kwargs): - if not debug_asserts_enabled(): - return f(*args, **kwargs) - - try: - return f(*args, **kwargs) - except BaseException as e: - bound = inspect.signature(f).bind(*args, **kwargs) - bound.apply_defaults() - - f_name = f"| f.__name__ = {f.__name__} |" - lines = [f"\n{'-' * len(f_name)}\n{f_name}\n{'-' * len(f_name)}\n"] - for k, v in bound.arguments.items(): - lines.append(f"Argument(`{k}`):\n {v}\n") - if id(v) in get_init_ctx(): - lines.extend( - [ - "Context (raw):", - *reversed(tuple(emit_ctx(v, follow_delegates=False))), - "\nContext (augmented):", - *reversed(tuple(emit_ctx(v, follow_delegates=True))), - ] - ) - - get_error_ctx().append("\n".join(lines)) - maybe_flush_errors() - raise - - return wrapped - - -def record(delegate_to: str | None | typing.Callable = None): - # Hack to allow you to to decorate with `@record` instead of `@record()`. - if callable(delegate_to): - return record()(delegate_to) - - def wrapper(f): - f_verbose = verbose_error(f) - - @functools.wraps(f) - def wrapped(*args, **kwargs): - if not debug_asserts_enabled(): - return f(*args, **kwargs) - - stack = get_stack() - try: - stack.append((f, args, kwargs, delegate_to)) - return f_verbose(*args, **kwargs) - - finally: - _ = stack.pop() - maybe_flush_errors() - if not stack: - get_init_ctx().clear() - - return wrapped - - return wrapper diff --git a/thunder/core/script/mypy-strict.ini b/thunder/core/script/mypy-strict.ini deleted file mode 100644 index 1637b7df92..0000000000 --- a/thunder/core/script/mypy-strict.ini +++ /dev/null @@ -1,52 +0,0 @@ -# Forked from PyTorch's mypy-strict.ini file. -# It enforces very strict typing rules. 
- -[mypy] -python_version = 3.10 - -cache_dir = .mypy_cache/strict -allow_redefinition = True -strict_optional = True -show_error_codes = True -show_column_numbers = True -warn_no_return = True -disallow_any_unimported = True - -# Across versions of mypy, the flags toggled by --strict vary. To ensure -# we have reproducible type check, we instead manually specify the flags -warn_unused_configs = True -disallow_any_generics = True -disallow_subclassing_any = True -disallow_untyped_calls = True -disallow_untyped_defs = True -disallow_incomplete_defs = True -check_untyped_defs = True -disallow_untyped_decorators = True -no_implicit_optional = True -warn_redundant_casts = True -warn_return_any = True -implicit_reexport = False -strict_equality = True - -# do not re-enable this: -# https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657 -warn_unused_ignores = False - -files = - thunder/core/script/algorithms.py, - thunder/core/script/protograph.py, - thunder/core/script/protograph_passes.py, - thunder/core/script/parse, - thunder/core/script/values - -[mypy-thunder.core.utils] -follow_imports = silent - -[mypy-thunder] -follow_imports = skip - -[mypy-thunder.*] -follow_imports = skip - -[mypy-networkx] -ignore_missing_imports = True diff --git a/thunder/core/script/noinline.py b/thunder/core/script/noinline.py deleted file mode 100644 index 03dbee5e7b..0000000000 --- a/thunder/core/script/noinline.py +++ /dev/null @@ -1,38 +0,0 @@ -from contextvars import ContextVar -from collections.abc import Callable - - -NOINLINE_METHODS: ContextVar[set[Callable]] = ContextVar("NOINLINE_METHODS", default=set()) - - -def noinline(f: Callable) -> Callable: - """ - Function/Decorator to prevent preprocessing from inlining the function. - - Example: - >>> @noinline - >>> def foo(x): - >>> return x + 1 - >>> def bar(x): - >>> return foo(x) + 1 - >>> thunder.compile(bar) - """ - - NOINLINE_METHODS.get().add(f) - return f - - -@noinline -def invoke_noinline(f: Callable) -> Callable: - """ - Function to prevent preprocessing from inlining a single invocation of a function. 
- - Example: - >>> def foo(x): - >>> return x + 1 - >>> def bar(x): - >>> return invoke_noinline(foo)(x) + 1 - >>> thunder.compile(bar) - """ - - return f diff --git a/thunder/core/script/overview.ipynb b/thunder/core/script/overview.ipynb deleted file mode 100644 index 152d98a6e2..0000000000 --- a/thunder/core/script/overview.ipynb +++ /dev/null @@ -1,331 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "5c236b7e-9191-4ead-8d76-cd789c09f810", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "thunder_path = os.path.abspath(os.path.join(os.path.abspath(\"\"), \"..\", \"..\", \"..\"))\n", - "if thunder_path not in sys.path:\n", - " sys.path.append(thunder_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "332a649c-cf11-4803-ae91-fff31d31d359", - "metadata": {}, - "outputs": [], - "source": [ - "import thunder" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "02373250-c7b0-446b-ad21-6e847d58feec", - "metadata": {}, - "outputs": [], - "source": [ - "def masked_apply(x, mask, layer_0, layer_1):\n", - " x = layer_0(x, mask)\n", - " x = layer_1(x, mask if mask is not None else 1)\n", - " return x, mask is None" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9c345326-89ae-4efa-9657-545ae3eeb215", - "metadata": {}, - "outputs": [], - "source": [ - "from thunder.core.script.frontend import _construct_protograph, acquire_method\n", - "proto_graph = _construct_protograph(masked_apply)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "2205c64a-caa8-4913-81ce-a43b10ffb1d1", - "metadata": {}, - "outputs": [], - "source": [ - "from thunder.core.script import parse, values\n", - "\n", - "provenance = []\n", - "last = proto_graph.provenance\n", - "while isinstance(last, tuple):\n", - " provenance.append(last)\n", - " transform, prior_proto_graph = last\n", - " last = prior_proto_graph.provenance\n", - "assert isinstance(last, values.ParsedSymbolic)\n", - "provenance.extend((last, last.provenance, last.provenance.provenance))\n", - "provenance.reverse()\n", - "\n", - "disassembled, parsed, parsed_symbolic, *protograph_transforms = provenance\n", - "assert isinstance(disassembled, parse.Disassembled)\n", - "assert isinstance(parsed, parse.ParsedFunctional)\n", - "assert isinstance(parsed_symbolic, values.ParsedSymbolic)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "b0178f36-4774-483b-8d37-e2271c42f2de", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "\n", - "\n", - " 2 0 LOAD_FAST 2 (layer_0) || 0) LOAD_FAST\n", - " 2 LOAD_FAST 0 (x) || LOAD_FAST\n", - " 4 LOAD_FAST 1 (mask) || LOAD_FAST\n", - " 6 CALL_FUNCTION 2 || CALL_FUNCTION\n", - " 8 STORE_FAST 0 (x) || STORE_FAST\n", - " || LOAD_FAST\n", - " 3 10 LOAD_FAST 3 (layer_1) || LOAD_FAST\n", - " 12 LOAD_FAST 0 (x) || LOAD_FAST\n", - " 14 LOAD_FAST 1 (mask) || LOAD_CONST\n", - " 16 LOAD_CONST 0 (None) || IS_OP\n", - " 18 IS_OP 1 || POP_JUMP_IF_FALSE\n", - " 20 POP_JUMP_IF_FALSE 13 (to 26) || \n", - " 22 LOAD_FAST 1 (mask) || 1) LOAD_FAST\n", - " 24 JUMP_FORWARD 1 (to 28) || JUMP_FORWARD\n", - " >> 26 LOAD_CONST 1 (1) || \n", - " >> 28 CALL_FUNCTION 2 || 2) LOAD_CONST\n", - " 30 STORE_FAST 0 (x) || JUMP_ABSOLUTE\n", - " || \n", - " 4 32 LOAD_FAST 0 (x) || 3) CALL_FUNCTION\n", - " 34 LOAD_FAST 1 (mask) || STORE_FAST\n", - " 36 LOAD_CONST 0 (None) || LOAD_FAST\n", - " 38 IS_OP 0 || LOAD_FAST\n", - " 40 BUILD_TUPLE 2 
|| LOAD_CONST\n", - " 42 RETURN_VALUE || IS_OP\n", - " || BUILD_TUPLE\n", - " || RETURN_VALUE\n", - " || \n" - ] - } - ], - "source": [ - "import dis\n", - "import io\n", - "import itertools\n", - "\n", - "print(disassembled.code, \"\\n\\n\")\n", - "dis.dis(disassembled.code, file=(buffer := io.StringIO()))\n", - "buffer.seek(0)\n", - "dis_lines = buffer.read().splitlines(False)\n", - "\n", - "block_lines = []\n", - "for idx, block in enumerate(disassembled.raw.values()):\n", - " block_lines.extend(f\"{'' if idy else f'{idx})':<4}{instruction.opname}\" for idy, instruction in enumerate(block))\n", - " block_lines.append(\"\")\n", - "\n", - "pad = max(len(l) for l in dis_lines)\n", - "for dis_line, block_line in itertools.zip_longest(dis_lines, block_lines, fillvalue=\"\"):\n", - " print(f\"{dis_line:<{pad + 10}} ||{' ' * 10}{block_line}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "dec0f942-0915-4cf5-a79a-9e37fd2f4e6f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Block 0: [] => [layer_1, v0]\n", - " LOAD[layer_0, x, mask]\n", - " CALL_FUNCTION . . . . . . . . . . . . . . (layer_0, x, mask) -> v0\n", - " STORE[x]\n", - " LOAD[layer_1, x, mask, None: CONST]\n", - " IS_OP . . . . . . . . . . . . . . . . . . (mask, None) -> v1\n", - " POP_JUMP_IF_FALSE . . . . . . . . . . . . (v1) -> \n", - " -> 1, 2(Jump)\n", - "\n", - "Block 1: [⓵ , ⓶ ] => [⓵ , ⓶ , mask]\n", - " LOAD[mask]\n", - " JUMP_FORWARD\n", - " -> 3(Jump)\n", - "\n", - "Block 2: [⓵ , ⓶ ] => [⓵ , ⓶ , 1]\n", - " LOAD[1: CONST]\n", - " JUMP_ABSOLUTE*\n", - " -> 3(Jump)\n", - "\n", - "Block 3: [⓵ , ⓶ , ⓷ ] => []\n", - " CALL_FUNCTION . . . . . . . . . . . . . . (⓵ , ⓶ , ⓷ ) -> v0\n", - " STORE[x]\n", - " LOAD[x, mask, None: CONST]\n", - " IS_OP . . . . . . . . . . . . . . . . . . (mask, None) -> v1\n", - " BUILD_TUPLE . . . . . . . . . . . . . . . (v0, v1) -> v2\n", - " RETURN_VALUE . . . . . . . . . . . . . . 
(v2) -> \n", - "\n" - ] - } - ], - "source": [ - "print(parsed.summary)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "3a1cdbfe-996f-48c2-a5d4-b05389d2316d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CALL_FUNCTION (layer_0(LOCAL), x(LOCAL), mask(LOCAL)) -> (IntermediateValue(at 0x7f064f0c26b0),)\n", - "IS_OP (mask(LOCAL), None(CONST)) -> (IntermediateValue(at 0x7f064f0c3070),)\n", - "POP_JUMP_IF_FALSE (OutputRef(IS_OP, idx=0)) -> ()\n", - "\n", - "JUMP_FORWARD () -> ()\n", - "\n", - "JUMP_ABSOLUTE () -> ()\n", - "\n", - "CALL_FUNCTION (0(STACK), 1(STACK), 2(STACK)) -> (IntermediateValue(at 0x7f064f0f0040),)\n", - "IS_OP (mask(LOCAL), None(CONST)) -> (IntermediateValue(at 0x7f064f0f0250),)\n", - "BUILD_TUPLE (OutputRef(CALL_FUNCTION, idx=0), OutputRef(IS_OP, idx=0)) -> (IntermediateValue(at 0x7f064f0f0460),)\n", - "RETURN_VALUE (OutputRef(BUILD_TUPLE, idx=0)) -> ()\n", - "\n" - ] - } - ], - "source": [ - "def pretty_repr(x) -> str:\n", - " if isinstance(x, parse.VariableKey):\n", - " return f\"{x.identifier}({x.scope.name})\"\n", - " if isinstance(x, values.OutputRef):\n", - " return f\"OutputRef({x.instruction.opname}, idx={x.idx})\"\n", - " return repr(x)\n", - "\n", - "for block, begin, end in parsed_symbolic.blocks:\n", - " # At this point `begin` isn't very interesting as it's all just placeholders.\n", - " for instruction, symbolic in block.items():\n", - " \n", - " inputs = \", \".join(pretty_repr(i) for i in symbolic.inputs.ordered)\n", - " print(f\"{instruction.opname:<25} ({inputs}) -> {symbolic.outputs}\")\n", - " print()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "8daca4b6-31b3-4d8e-8be2-9744b44d8f4f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Unlink\n", - "MarkTuples\n", - "AddTransitive\n", - "Connect\n", - "\n", - "================================================================================\n", - " ProtoBlock: 0x7f064f0f1cf0\n", - " CALL_FUNCTION\n", - " IS_OP\n", - " POP_JUMP_IF_FALSE\n", - "\n", - "ProtoBlock: 0x7f064ef1e560\n", - " JUMP_FORWARD\n", - "\n", - "ProtoBlock: 0x7f064ef81b10\n", - " JUMP_ABSOLUTE\n", - "\n", - "ProtoBlock: 0x7f064ef9cfd0\n", - " CALL_FUNCTION\n", - " IS_OP\n", - " BUILD_TUPLE\n", - " RETURN_VALUE\n" - ] - } - ], - "source": [ - "for transform, _ in protograph_transforms:\n", - " print(transform.__name__)\n", - "\n", - "print(f\"\\n{'=' * 80}\\n\", proto_graph)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "5c56fbcf-13db-4749-940b-ecbb0dfd0af5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Graph of\n", - " Block (reached from [None])\n", - " CALL_FUNCTION([, , ])\n", - " IS_OP 1 (1)\n", - " POP_JUMP_IF_FALSE 13 (26)\n", - " Block (reached from [])\n", - " JUMP_FORWARD 1 (28)\n", - " Block (reached from [])\n", - " JUMP_ABSOLUTE 14 (None)\n", - " Block (reached from [, ])\n", - " CALL_FUNCTION([, , ])\n", - " IS_OP 0 (0)\n", - " BUILD_TUPLE 2 (2)\n", - " RETURN_VALUE None (None)\n" - ] - } - ], - "source": [ - "g = acquire_method(masked_apply)\n", - "print(g)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22ef39c9-04aa-40ff-acc2-1387a3f372f4", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - 
"version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/thunder/core/script/parse/__init__.py b/thunder/core/script/parse/__init__.py deleted file mode 100644 index c29adb91a6..0000000000 --- a/thunder/core/script/parse/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from thunder.core.script.parse.disassemble import * -from thunder.core.script.parse.functionalize import * -from thunder.core.script.parse.instructions import * -from thunder.core.script.parse.stack_effect import * - -# This will be populated as parse-time narrowing is introduced. -FORBIDDEN_INSTRUCTIONS = InstructionSet() diff --git a/thunder/core/script/parse/disassemble.py b/thunder/core/script/parse/disassemble.py deleted file mode 100644 index 3ae5ae1bbe..0000000000 --- a/thunder/core/script/parse/disassemble.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Convert a `CodeType` object into a series of simple blocks.""" -from __future__ import annotations - -import dataclasses -import dis -import itertools -from types import CodeType -from typing import Any, NewType, TypeVar -from collections.abc import Iterable, Mapping - -from thunder.core.script.parse import stack_effect -from thunder.core.script.parse.instructions import * # There are a lot of constants, and it defines `__all__` - - -__all__ = ("Disassembled", "ParseDetailInstruction", "EpilogueFixup", "Jump") - -BlockIdx = NewType("BlockIdx", int) -Jump = NewType("Jump", bool) - - -@dataclasses.dataclass -class Disassembled: - code: CodeType - - _StartIndex = NewType("_StartIndex", int) - _RawBlocks = dict[_StartIndex, tuple[ThunderInstruction, ...]] - raw: RawBlocks - - _Blocks = tuple[tuple["ThunderInstruction", ...], ...] - blocks: _Blocks - - _Edges = tuple[tuple[BlockIdx, BlockIdx, Jump], ...] - edges: _Edges - - @classmethod - def make(cls, co: CodeType) -> Disassembled: - raw_blocks, last_line_no = partition(co) - blocks, edges = connect_blocks(consolidate_returns(raw_blocks)) - for instruction in itertools.chain(*blocks): - instruction.line_no = getattr(instruction, "line_no", last_line_no) - return cls(code=co, raw=raw_blocks, blocks=blocks, edges=edges) - - -class ParseDetailInstruction(ThunderInstruction): - """Allow us to distinguish instructions that are added during parsing.""" - - pass - - -class EpilogueFixup(ParseDetailInstruction): - pass - - -def compute_jump(instruction: ThunderInstruction, position: int) -> int | None: - if instruction in ABSOLUTE_JUMP_INSTRUCTIONS: - return instruction.oparg - - elif instruction in UNCONDITIONAL_BACKWARD: - return position + 1 - instruction.oparg - - elif "BACKWARD" in instruction.opname: - # TODO: POP_JUMP_BACKWARD_IF_... variants - raise NotImplementedError(instruction.opname) - - elif instruction in RELATIVE_JUMP_INSTRUCTIONS: - return position + 1 + instruction.oparg - - return None - - -IntT = TypeVar("IntT", bound=int, covariant=True) -StartIndex = NewType("StartIndex", int) -RawBlocks = dict[StartIndex, tuple[ThunderInstruction, ...]] - - -def get_free_key(x: Mapping[IntT, Any]) -> int: - key = -len(x) - while key in x: - key -= 1 - return key - - -def partition(co: CodeType) -> tuple[RawBlocks, int]: - bytecode = tuple(ThunderInstruction(*i) for i in dis.get_instructions(co, first_line=0)) - - # Determine the boundaries for the simple blocks. 
- split_after = JUMP_INSTRUCTIONS | RETURN_INSTRUCTIONS - follows_jump = itertools.chain([0], (int(i in split_after) for i in bytecode)) - new_block = (int(i or j.is_jump_target) for i, j in zip(follows_jump, bytecode)) - - # Split the bytecode (and instruction number) into groups - group_indices = tuple(itertools.accumulate(new_block)) - groups = itertools.groupby(enumerate(bytecode), key=lambda args: group_indices[args[0]]) - - # Drop the group index, copy from the groupby iter, and unzip `enumerate`. - groups = (zip(*tuple(i)) for _, i in groups) - blocks: dict[StartIndex, list[ThunderInstruction]] = { - StartIndex(start): list(block) for (start, *_), block in groups - } - - # If the last instruction is not a jump or return (which means we split - # because the next instruction was a jump target) then we need to tell - # the current block how to advance. - for start, block in blocks.items(): - if block[-1] not in split_after: - next_start = StartIndex(start + len(block)) - assert bytecode[next_start].is_jump_target - block.append(ParseDetailInstruction.make_jump_absolute(next_start)) - - line_no = 1 - for instruction in itertools.chain(*[block for block in blocks.values()]): - instruction.line_no = line_no = instruction.starts_line or line_no - - return {k: tuple(v) for k, v in blocks.items()}, line_no - - -def consolidate_returns(blocks: RawBlocks) -> RawBlocks: - def is_return(block: tuple[ThunderInstruction, ...]) -> bool: - assert block and not any(i.opname == RETURN_VALUE for i in block[:-1]) - return block[-1].opname == RETURN_VALUE - - blocks = blocks.copy() - return_blocks = {k: v for k, v in blocks.items() if is_return(v)} - if len(return_blocks) > 1: - new_return_start = StartIndex(get_free_key(blocks)) - for start, (*body, prior_return) in return_blocks.items(): - assert is_return((prior_return,)), prior_return - blocks[start] = (*body, ParseDetailInstruction.make_jump_absolute(new_return_start)) - return_blocks = {new_return_start: (ParseDetailInstruction.make_return(is_jump_target=True),)} - - # Move return block to the end. This isn't always valid (since a block might - # expect to fall through and reach it), but that will be resolved by the - # sort in `ProtoGraph`'s ctor. 
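# [Editorial aside -- not part of the original patch] After this pass the
# graph has exactly one RETURN_VALUE block; schematically, two return sites
#   block A: ... RETURN_VALUE          block A: ... JUMP_ABSOLUTE -> R
#   block B: ... RETURN_VALUE   ==>    block B: ... JUMP_ABSOLUTE -> R
#                                      block R: RETURN_VALUE
# with R marked as a jump target via the synthetic instructions made above.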
- blocks = {k: v for k, v in blocks.items() if k not in return_blocks} - blocks.update(return_blocks) - return blocks - - -def connect_blocks(blocks: RawBlocks) -> tuple[Disassembled._Blocks, Disassembled._Edges]: - def iter_raw_edges(blocks: RawBlocks) -> Iterable[tuple[StartIndex, StartIndex, Jump, int, int]]: - for start, block in tuple(blocks.items()): - raw_block_len = sum(not isinstance(i, ParseDetailInstruction) for i in block) - *_, last_i = block - if last_i in JUMP_INSTRUCTIONS: - end = start + raw_block_len - 1 - _, (push_nojump, push_jump) = stack_effect.stack_effect_detail(last_i) - if last_i not in UNCONDITIONAL_JUMP_INSTRUCTIONS: - yield start, StartIndex(end + 1), Jump(False), max(push_jump - push_nojump, 0), last_i.line_no - - if (jump_offset := compute_jump(last_i, end)) is not None: - yield start, StartIndex(jump_offset), Jump(True), max(push_nojump - push_jump, 0), last_i.line_no - - blocks = blocks.copy() - edges: list[tuple[StartIndex, StartIndex, Jump]] = [] - for source, destination, jump, pop_suffix, line_no in iter_raw_edges(blocks): - if pop_suffix: - blocks[epilogue := StartIndex(get_free_key(blocks))] = ( - *(EpilogueFixup.make(POP_TOP, None, line_no=line_no) for _ in range(pop_suffix)), - EpilogueFixup.make_jump_absolute(destination, line_no=line_no), - ) - edges.extend(((source, epilogue, jump), (epilogue, destination, jump))) - else: - edges.append((source, destination, jump)) - - to_idx = {k: BlockIdx(idx) for idx, k in enumerate(blocks.keys())} - return tuple(blocks.values()), tuple((to_idx[source], to_idx[sink], jump) for source, sink, jump in edges) diff --git a/thunder/core/script/parse/functionalize.py b/thunder/core/script/parse/functionalize.py deleted file mode 100644 index 9664271918..0000000000 --- a/thunder/core/script/parse/functionalize.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Replay the CPython stack machine to determine data flow within a simple block.""" -from __future__ import annotations - -import dataclasses -import enum -import inspect -import itertools -import marshal -import textwrap -from types import CodeType -from typing import Any, NamedTuple, NewType - -import networkx as nx - -from thunder.core.script import algorithms -from thunder.core.script.parse import disassemble, instructions, stack_effect -from thunder.core.utils import safe_zip, FrozenDict, InferringDict - -__all__ = ("VariableScope", "VariableKey", "ParsedFunctional", "FunctionalizedBlock", "PlaceholderValue") - - -class VariableScope(enum.Enum): - CONST = enum.auto() - LOCAL = enum.auto() - NONLOCAL = enum.auto() - GLOBAL = enum.auto() - STACK = enum.auto() - - -class VariableKey(NamedTuple): - """Denotes the location of a variable. - For example, `x = 5` assigns the variable stored in `VariableKey(5, VariableScope.CONST)` - to the location `VariableKey("x", VariableScope.LOCAL)`. (Provided `x` is a local variable.) 
- The type of `identifier` varies based on `scope`: - `marshal`able VariableScope.CONST - str VariableScope.LOCAL / NONLOCAL / GLOBAL - int VariableScope.STACK - Any VariableScope.BOUNDARY - """ - - identifier: Any - scope: VariableScope - - def __repr__(self) -> str: - return f"VariableKey({self.identifier}, scope={self.scope.name})" - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, VariableKey) - and self.scope == other.scope - and type(self.identifier) is (_ := type(other.identifier)) # Conflict between `ruff` and `yesqa` - and self.identifier == other.identifier - ) - - def __lt__(self, other: tuple[Any, ...]) -> bool: - assert isinstance(other, VariableKey), (self, other) - try: - return (self.scope.value, self.identifier) < (other.scope.value, other.identifier) - except TypeError: - assert self.scope == other.scope, (self, other) - if self.scope == VariableScope.CONST: - # We prefer to use native ordering. However for unorderable types (e.g. CodeType) - # `marshal` at least provides a consistent ordering. - return marshal.dumps(self.identifier) < marshal.dumps(other.identifier) - raise - - @property - def is_const(self) -> bool: - return self.scope == VariableScope.CONST - - -def _compute_stack_offsets(disassembled: disassemble.Disassembled) -> tuple[int, ...]: - # If we convert the stack indices to a common basis then we can ignore stack effect and - # treat VariableScope.STACK variables just like any other local. - G = algorithms.TypedDiGraph[disassemble.BlockIdx]((i, j) for i, j, _ in disassembled.edges) - G.add_nodes_from(range(len(disassembled.blocks))) - offsets: dict[disassemble.BlockIdx, int] = {i: 0 for i in G.nodes if not G.in_degree(i)} # type: ignore[misc] - assert len(offsets) == 1, G - - for source, sink in nx.edge_dfs(G): - net_stack_effect = 0 - for instruction in disassembled.blocks[source]: - pop, push_by_branch = stack_effect.stack_effect_detail(instruction) - net_stack_effect += max(push_by_branch) - pop - expected = offsets[source] + net_stack_effect - actual = offsets.setdefault(sink, expected) - assert actual == expected, (actual, expected) - - assert all(v >= 0 for v in offsets.values()), offsets - return tuple(offsets[disassemble.BlockIdx(i)] for i in range(len(disassembled.blocks))) - - -LOAD_OPNAMES = FrozenDict[str, VariableScope]( - LOAD_CONST=VariableScope.CONST, - LOAD_FAST=VariableScope.LOCAL, - LOAD_DEREF=VariableScope.NONLOCAL, - LOAD_CLOSURE=VariableScope.NONLOCAL, - LOAD_GLOBAL=VariableScope.GLOBAL, -) - -STORE_OPNAMES = FrozenDict[str, VariableScope]( - STORE_FAST=VariableScope.LOCAL, - STORE_DEREF=VariableScope.NONLOCAL, - STORE_GLOBAL=VariableScope.GLOBAL, -) - -DEL_OPNAMES = FrozenDict[str, VariableScope]( - DELETE_FAST=VariableScope.LOCAL, - DELETE_DEREF=VariableScope.NONLOCAL, - DELETE_GLOBAL=VariableScope.GLOBAL, -) - -PlaceholderValue = NewType("PlaceholderValue", str) -Inputs = NewType("Inputs", tuple[PlaceholderValue, ...]) -Outputs = NewType("Outputs", tuple[PlaceholderValue, ...]) -BeginState = FrozenDict["VariableKey", PlaceholderValue] -EndState = FrozenDict["VariableKey", PlaceholderValue | None] -FunctionalNode = tuple[instructions.ThunderInstruction, Inputs, Outputs] -FunctionalizedBlock = NewType("FunctionalizedBlock", tuple[tuple[FunctionalNode, ...], BeginState, EndState]) - - -@dataclasses.dataclass(frozen=True) -class ParsedFunctional: - blocks: tuple[FunctionalizedBlock, ...] 
- provenance: disassemble.Disassembled - - @staticmethod - def make(co: CodeType) -> ParsedFunctional: - disassembled = disassemble.Disassembled.make(co) - return ParsedFunctional(_functionalize_blocks(disassembled), disassembled) - - @property - def summary(self) -> str: - return _summarize(self) - - -def _functionalize_blocks(disassembled: disassemble.Disassembled) -> tuple[FunctionalizedBlock, ...]: - code = disassembled.code - errors: list[str] = [] - if code.co_cellvars: - errors.append( - "Nonlocal variables are not supported but\n" - f" {code.co_name}() defined in {code.co_filename}:{code.co_firstlineno}\n" - f" defines nonlocal variable{'s' if len(code.co_cellvars) > 1 else ''}: {', '.join(code.co_cellvars)}" - ) - - def report_unsupported(msg: str, instruction: instructions.ThunderInstruction) -> None: - source_lines, _ = inspect.getsourcelines(code) - errors.append( - f"{msg}{instruction} found\n" - f" {code.co_name}() defined in {code.co_filename}:{code.co_firstlineno}\n" - f" line {instruction.line_no + code.co_firstlineno}: {source_lines[instruction.line_no].rstrip()}" - ) - - name_arrays = FrozenDict[VariableScope, tuple[str, ...]]( - { - VariableScope.CONST: code.co_consts, - VariableScope.LOCAL: code.co_varnames, - VariableScope.NONLOCAL: (*code.co_cellvars, *code.co_freevars), - VariableScope.GLOBAL: code.co_names, - } - ) - - def to_key(instruction: instructions.ThunderInstruction, scope: VariableScope) -> VariableKey: - assert scope != VariableScope.STACK, "Indexing into the stack is not permitted." - assert scope in name_arrays, f"Unknown variable scope: {scope}" - if scope == VariableScope.NONLOCAL: - report_unsupported("nonlocal variables are not supported but instruction = ", instruction) - return VariableKey(name_arrays[scope][instruction.oparg], scope) - - def convert(block: tuple[instructions.ThunderInstruction, ...], stack_offset: int) -> FunctionalizedBlock: - stack: list[PlaceholderValue] = [PlaceholderValue(f"Initial_stack_{i}") for i in range(stack_offset)] - begin_variables = {VariableKey(idx, VariableScope.STACK): v for idx, v in enumerate(stack)} - end_variables = InferringDict[VariableKey, PlaceholderValue | None]( - lambda key: begin_variables.setdefault(key, PlaceholderValue(f"Initial: ({key.identifier} {key.scope})")) - ) - - assert block - functionalized: list[tuple[instructions.ThunderInstruction, Inputs, Outputs]] = [] - for idx, instruction in enumerate(block): - # These are already reflected in the next opcode's argument - if instruction.opname == instructions.EXTENDED_ARG: - continue - - elif instruction in instructions.UNSAFE_OPCODES: - # These are unsafe to run, but we should still be able to parse them. - report_unsupported("Unsupported instruction = ", instruction) - - pop, push_by_branch = stack_effect.stack_effect_detail(instruction) - push = max(push_by_branch) - - def assert_expected_stack_effects(pop_i: int, push_i: int) -> None: - assert (pop, push) == (pop_i, push_i), f"{instruction=} {pop=} {push=}" - - # Peek at the stack to track variable mutations. 
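# [Editorial aside -- not part of the original patch] The branches below all
# follow a single replay pattern; a stripped-down sketch, with a hypothetical
# `effects` table of {opname: (pop, push)} standing in for
# `stack_effect_detail`:
#   >>> def replay(block, effects):
#   ...     stack, nodes = [], []
#   ...     for idx, ins in enumerate(block):
#   ...         pop, push = effects[ins.opname]
#   ...         inputs = tuple(stack.pop() for _ in range(pop))[::-1]
#   ...         outputs = tuple(f"{idx}_{ins.opname}__{j}" for j in range(push))
#   ...         stack.extend(outputs)
#   ...         nodes.append((ins, inputs, outputs))
#   ...     return nodes, stack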
- if (store_scope := STORE_OPNAMES.get(instruction.opname)) is not None: - assert_expected_stack_effects(1, 0) - end_variables[to_key(instruction, store_scope)] = stack.pop() - - elif (del_scope := DEL_OPNAMES.get(instruction.opname)) is not None: - assert_expected_stack_effects(1, 0) - end_variables[to_key(instruction, del_scope)] = None - - elif (load_scope := LOAD_OPNAMES.get(instruction.opname)) is not None: - assert_expected_stack_effects(0, 1) - loaded = end_variables[load_key := to_key(instruction, load_scope)] - assert loaded is not None, f"Access to deleted variable: {load_key}, {instruction}" - stack.append(loaded) - - else: - # We have already functionalized variable accesses, so we can prune loads and stores. - inputs = tuple(stack.pop() for _ in range(pop)) - outputs = Outputs(tuple(PlaceholderValue(f"{idx}_{instruction.opname}__{idy}") for idy in range(push))) - stack.extend(outputs) - functionalized.append((instruction, Inputs(tuple(reversed(inputs))), outputs)) - - end_stack = {VariableKey(idx, VariableScope.STACK): v for idx, v in enumerate(stack)} - end_state: EndState = FrozenDict({**end_variables, **end_stack}) - return FunctionalizedBlock((tuple(functionalized), FrozenDict(begin_variables), end_state)) - - stack_offsets = _compute_stack_offsets(disassembled) - functionalized = tuple(convert(block, offset) for block, offset in safe_zip(disassembled.blocks, stack_offsets)) - if errors: - raise RuntimeError("Preprocessing issues detected:\n" + textwrap.indent("\n\n".join(errors), " " * 4)) - - return functionalized - - -# ============================================================================= -# == Summary for debugging and testing ======================================== -# ============================================================================= -def _summarize(parsed: ParsedFunctional) -> str: - # Clear identifiers for input stack values. - to_symbol = FrozenDict[int, str](enumerate("⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ Ⓐ Ⓑ Ⓒ Ⓓ Ⓔ Ⓕ".split())) - - # Group output edges. - grouped_edges: dict[int, str] = { - source: ", ".join(f"{sink}{'(Jump)' if jump else ''}" for _, sink, jump in sinks) - for source, sinks in itertools.groupby(parsed.provenance.edges, lambda e: e[0]) - } - - # Best effort to apply descriptive names. - inputs_outputs = {} - block_headers: list[str] = [] - for idx, (functionalized_block, begin, end) in enumerate(parsed.blocks): - begin_stack = tuple(v for k, v in begin.items() if k.scope == VariableScope.STACK) - stack_names: dict[str, str] = {v: to_symbol.get(idx, f"S{idx}") + "\u2009" for idx, v in enumerate(begin_stack)} - names: dict[str, str] = {**{v: f"{k.identifier}" for k, v in begin.items()}, **stack_names} - names.update({v: f"v{idx}" for idx, v in enumerate(itertools.chain(*[o for _, _, o in functionalized_block]))}) - for instruction, inputs, outputs in functionalized_block: - inputs_outputs[instruction] = (tuple(names[i] for i in inputs), tuple(names[o] for o in outputs)) - - end_stack = {k: v for k, v in end.items() if k.scope == VariableScope.STACK} - assert tuple(k.identifier for k in end_stack) == tuple(range(len(end_stack))), end_stack - end_stack_str = ", ".join(names[i or ""] for i in end_stack.values()) - block_headers.append(f"Block {idx}: [{', '.join(stack_names.values())}] => [{end_stack_str}]") - - # Group loads and stores. 
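# [Editorial aside -- not part of the original patch] `itertools.groupby`
# batches consecutive loads/stores so the summary can print the condensed
# `LOAD[layer_0, x, mask]` / `STORE[x]` lines seen in the overview notebook:
#   >>> import itertools
#   >>> [(k, len(list(g))) for k, g in itertools.groupby("AAB")]
#   [('A', 2), ('B', 1)]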
- prefix = {opname: opname.split("_")[0] for opname in itertools.chain(STORE_OPNAMES, LOAD_OPNAMES, DEL_OPNAMES)} - condensed: list[list[tuple[str, instructions.ThunderInstruction | None]]] = [] - for raw_block in parsed.provenance.blocks: - condensed.append([]) - for prefix_or_i, group in itertools.groupby(raw_block, lambda i: prefix.get(i.opname, i)): - if isinstance(prefix_or_i, str): - name = ", ".join(f"{i.argval}: {i.opname[len(prefix_or_i) + 1:]}" for i in group) - condensed[-1].append((f"{prefix_or_i}[{name.replace(': FAST', '')}]", None)) - else: - opname = f"{prefix_or_i.opname}{'' if type(prefix_or_i) is instructions.ThunderInstruction else '*'}" - condensed[-1].append((opname, prefix_or_i)) - - # Write lines. - lines: list[str] = [] - width = max(len(name) for name, _ in itertools.chain(*condensed)) - width = max(width, max(len(i) for i in block_headers)) + 5 - for idx, (condensed_block, (_, _, end)) in enumerate(safe_zip(condensed, parsed.blocks)): - lines.append(block_headers[idx]) - for name, maybe_i in condensed_block: - inputs, outputs = inputs_outputs.get(maybe_i, ((), ())) # type: ignore[assignment, arg-type] - if inputs or outputs: - name = f"{name} ".ljust(width, ".").replace("..", ". ") - name = f"{name} ({', '.join(inputs)}) -> {', '.join(outputs)}" - lines.append(f" {name}") - if idx in grouped_edges: - lines.append(f" -> {grouped_edges[idx]}") - lines.append("") - - return "\n".join(lines) diff --git a/thunder/core/script/parse/instructions.py b/thunder/core/script/parse/instructions.py deleted file mode 100644 index 34fb7fed07..0000000000 --- a/thunder/core/script/parse/instructions.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Extension of the builtin `dis` module.""" -from __future__ import annotations - -import dis -from typing import Any - -from typing_extensions import Self - -from thunder.core.utils import _OrderedSet - -__all__ = ( - "ThunderInstruction", - "InstructionSet", - "JUMP_ABSOLUTE", - "RETURN_VALUE", - "POP_TOP", - "EXTENDED_ARG", - "UNCONDITIONAL_BACKWARD", - "UNCONDITIONAL_JUMP_INSTRUCTIONS", - "ABSOLUTE_JUMP_INSTRUCTIONS", - "RELATIVE_JUMP_INSTRUCTIONS", - "JUMP_INSTRUCTIONS", - "RAISE_RETURN_INSTRUCTIONS", - "RETURN_INSTRUCTIONS", - "UNSAFE_OPCODES", -) - - -class ThunderInstruction(dis.Instruction): - """Thin wrapper on top of dis.Instruction to implement thunder specific logic.""" - - line_no: int - - def __hash__(self) -> int: - # We sometimes want to use an instruction as a key so we can map back to nodes. - # `dis.Instruction` is a named tuple and therefore implements recursive constituent - # hashing which can lead to unwanted collisions. We instead override this behavior - # to instead use identity hashing. 
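# [Editorial aside -- not part of the original patch] Concretely, two distinct
# namedtuple instances with equal fields hash alike, so value hashing would
# merge duplicated instructions when used as dict keys:
#   >>> import dis
#   >>> ins = next(dis.get_instructions(compile("x", "<s>", "eval")))
#   >>> hash(ins._replace()) == hash(ins)
#   True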
- return id(self) - - def __eq__(self, other: object) -> bool: - return self is other - - @property - def oparg(self) -> int: - assert self.arg is not None, self - return self.arg - - def modify_copy(self, **kwargs: Any) -> ThunderInstruction: - assert type(self) is ThunderInstruction, self - if "opname" in kwargs: - kwargs.setdefault("opcode", dis.opmap.get(kwargs["opname"], -1)) - result = ThunderInstruction(**{**self._asdict(), **kwargs}) - result.line_no = self.line_no - return result - - @classmethod - def make(cls, opname: str, arg: int | None, line_no: int, **kwargs: Any) -> Self: - ctor_kwargs = dict( - opname=opname, - opcode=dis.opmap.get(opname, -1), - arg=arg, - argval=None, - argrepr="None", - offset=-999, - starts_line=None, - is_jump_target=False, - ) - ctor_kwargs.update(kwargs) - result = cls(**ctor_kwargs) # type: ignore - result.line_no = line_no - return result - - @classmethod - def make_jump_absolute(cls, arg: int, line_no: int = -1) -> ThunderInstruction: - return cls.make(JUMP_ABSOLUTE, arg, argrepr=f"{arg}", line_no=line_no) - - @classmethod - def make_return(cls, is_jump_target: bool, line_no: int = -1) -> ThunderInstruction: - return cls.make(RETURN_VALUE, arg=None, argrepr="", is_jump_target=is_jump_target, line_no=line_no) - - -class InstructionSet(_OrderedSet[str, int | ThunderInstruction]): - """Convenience class for checking opcode properties.""" - - def canonicalize(self, i: str | int | ThunderInstruction) -> str: - if isinstance(i, str): - return i - - elif isinstance(i, int): - return dis.opname[i] - - else: - assert isinstance(i, ThunderInstruction) - return i.opname - - -# Special opcodes -JUMP_ABSOLUTE = "JUMP_ABSOLUTE" -RETURN_VALUE = "RETURN_VALUE" -POP_TOP = "POP_TOP" -EXTENDED_ARG = "EXTENDED_ARG" - - -UNCONDITIONAL_BACKWARD = InstructionSet(("JUMP_BACKWARD", "JUMP_BACKWARD_NO_INTERRUPT")) -UNCONDITIONAL_JUMP_INSTRUCTIONS = InstructionSet((JUMP_ABSOLUTE, "JUMP_FORWARD", *UNCONDITIONAL_BACKWARD)) - -ABSOLUTE_JUMP_INSTRUCTIONS = InstructionSet(dis.hasjabs) -RELATIVE_JUMP_INSTRUCTIONS = InstructionSet(dis.hasjrel) -JUMP_INSTRUCTIONS = InstructionSet((*dis.hasjabs, *dis.hasjrel, *UNCONDITIONAL_JUMP_INSTRUCTIONS)) - -RAISE_RETURN_INSTRUCTIONS = InstructionSet(("RAISE_VARARGS", "RERAISE")) -RETURN_INSTRUCTIONS = InstructionSet((RETURN_VALUE, *RAISE_RETURN_INSTRUCTIONS)) - - -# https://github.com/Lightning-AI/lightning-thunder/issues/1075 -UNSAFE_OPCODES = InstructionSet(("SETUP_WITH", "SETUP_FINALLY")) diff --git a/thunder/core/script/parse/stack_effect.py b/thunder/core/script/parse/stack_effect.py deleted file mode 100644 index 1ddeff412d..0000000000 --- a/thunder/core/script/parse/stack_effect.py +++ /dev/null @@ -1,234 +0,0 @@ -import dis -import opcode -import sys -from typing import NewType, TypeAlias, TypeVar -from collections.abc import Callable -from collections.abc import Iterable - -from types import EllipsisType - -from thunder.core.utils import FrozenDict - -__all__ = ("stack_effect_detail", "fill_ellipses") - -T = TypeVar("T") -Pop = NewType("Pop", int) -Push = NewType("Push", int) -StackEffect: TypeAlias = tuple[Pop, Push] | tuple[Pop, tuple[Push, Push]] - -# Aliases for common cases -NoStackEffect = (Pop(0), Push(0)) -PushTOS = (Pop(0), Push(1)) -PopTOS = (Pop(1), Push(0)) -ReplaceTOS = (Pop(1), Push(1)) -BinaryOp = (Pop(2), Push(1)) - - -def make_function_detail(*args: int) -> Callable[[int], StackEffect]: - return lambda oparg: (Pop(2 + sum((oparg & flag) != 0 for flag in args)), Push(1)) - - -def fill_ellipses(**kwargs: T | 
EllipsisType) -> Iterable[tuple[str, T]]: - prior_effect: T | EllipsisType = Ellipsis - for opname, effect in kwargs.items(): - if effect is Ellipsis: - effect = prior_effect - assert effect is not Ellipsis - prior_effect = effect - yield opname, effect - - -__EFFECTS = dict[str, StackEffect | Callable[[int], StackEffect] | EllipsisType]( - NOP=NoStackEffect, # ∅ -> ∅ - EXTENDED_ARG=NoStackEffect, - # - # Stack manipulation - POP_TOP=PopTOS, # A -> ∅ - ROT_TWO=(Pop(2), Push(2)), # A,B -> B,A - ROT_THREE=(Pop(3), Push(3)), # A,B,C -> C,A,B - ROT_FOUR=(Pop(4), Push(4)), # A,B,C,D -> D,A,B,C - ROT_N=lambda oparg: (Pop(oparg), Push(oparg)), # A,B,...,Z -> Z,A,B,... - DUP_TOP=(Pop(1), Push(2)), # A -> A,A - DUP_TOP_TWO=(Pop(2), Push(4)), # A,B -> A,B,A,B - UNPACK_SEQUENCE=lambda oparg: (Pop(1), Push(oparg)), # A -> B,C,... - # - # Jumps & return - JUMP_FORWARD=NoStackEffect, # ∅ -> ∅ - JUMP_ABSOLUTE=..., - POP_JUMP_IF_FALSE=PopTOS, # A -> ∅ - POP_JUMP_IF_TRUE=..., - RETURN_VALUE=..., - JUMP_IF_NOT_EXC_MATCH=BinaryOp, # A,B -> ∅ - # - # Exceptions and context managers: - POP_BLOCK=NoStackEffect, # ∅ -> ∅ - POP_EXCEPT=(Pop(3), Push(0)), # A, B, C -> ∅ - RERAISE=..., - RAISE_VARARGS=lambda oparg: (Pop(oparg), Push(0)), # A,B,... -> ∅ - WITH_EXCEPT_START=(Pop(7), Push(8)), # ??!? - LOAD_ASSERTION_ERROR=PushTOS, # ∅ -> A - # - # Variable manipulation - LOAD_CONST=PushTOS, # ∅ -> A - LOAD_FAST=..., - LOAD_GLOBAL=..., - LOAD_NAME=..., - STORE_FAST=PopTOS, # A -> ∅ - STORE_GLOBAL=..., - STORE_NAME=..., - DELETE_FAST=NoStackEffect, # ∅ -> ∅ - DELETE_GLOBAL=..., - DELETE_NAME=..., - # - # Attributes - LOAD_METHOD=(Pop(1), Push(2)), # A -> B,A - LOAD_ATTR=ReplaceTOS, # A -> B - STORE_ATTR=(Pop(2), Push(0)), # A, B -> ∅ - DELETE_ATTR=PopTOS, # A -> ∅ - # - # Closures - LOAD_CLOSURE=PushTOS, # ∅ -> A - LOAD_DEREF=..., - LOAD_CLASSDEREF=..., - STORE_DEREF=PopTOS, # A -> ∅ - DELETE_DEREF=NoStackEffect, # ∅ -> ∅ - # - # Functions and calls A,B,... -> Z - CALL_FUNCTION=lambda x: (Pop(x + 1), Push(1)), - CALL_METHOD=lambda x: (Pop(x + 2), Push(1)), - CALL_FUNCTION_KW=..., - CALL_FUNCTION_EX=make_function_detail(0x01), - MAKE_FUNCTION=make_function_detail(0x01, 0x02, 0x04, 0x08), - # - # Build containers A,B,... -> Z - BUILD_TUPLE=lambda oparg: (Pop(oparg), Push(1)), - BUILD_LIST=..., - BUILD_SET=..., - BUILD_STRING=..., - BUILD_MAP=lambda oparg: (Pop(oparg * 2), Push(1)), - BUILD_CONST_KEY_MAP=lambda x: (Pop(x + 1), Push(1)), - LIST_TO_TUPLE=ReplaceTOS, # A -> B - # - # Insertion leaves container on the stack A,B -> A - SET_ADD=BinaryOp, - SET_UPDATE=..., - LIST_APPEND=..., - LIST_EXTEND=..., - DICT_MERGE=..., - DICT_UPDATE=..., - MAP_ADD=(Pop(3), Push(1)), # A,B,C -> A - COPY_DICT_WITHOUT_KEYS=(Pop(2), Push(2)), # A,B -> A,C (I am unsure...) 
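# [Editorial aside -- not part of the original patch] The `...` entries in
# this table inherit the previous entry's effect via `fill_ellipses` above:
#   >>> dict(fill_ellipses(A=(1, 1), B=..., C=(2, 1)))
#   {'A': (1, 1), 'B': (1, 1), 'C': (2, 1)}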
- # - # Unary operators A -> B - UNARY_POSITIVE=ReplaceTOS, - UNARY_NEGATIVE=..., - UNARY_NOT=..., - UNARY_INVERT=..., - # - # Binary operators A,B -> C - BINARY_POWER=BinaryOp, - BINARY_MULTIPLY=..., - BINARY_MATRIX_MULTIPLY=..., - BINARY_MODULO=..., - BINARY_ADD=..., - BINARY_SUBTRACT=..., - BINARY_SUBSCR=..., - BINARY_FLOOR_DIVIDE=..., - BINARY_TRUE_DIVIDE=..., - INPLACE_FLOOR_DIVIDE=..., - INPLACE_TRUE_DIVIDE=..., - INPLACE_ADD=..., - INPLACE_SUBTRACT=..., - INPLACE_MULTIPLY=..., - INPLACE_MATRIX_MULTIPLY=..., - INPLACE_MODULO=..., - BINARY_LSHIFT=..., - BINARY_RSHIFT=..., - BINARY_AND=..., - BINARY_XOR=..., - BINARY_OR=..., - COMPARE_OP=..., - IS_OP=..., - CONTAINS_OP=..., - # - # Binary operators (inplace) - # https://docs.python.org/3/reference/datamodel.html?highlight=iadd#object.__iadd__ - # "... and return the result (which could be, but does not have to be, self)." - INPLACE_POWER=BinaryOp, - INPLACE_LSHIFT=..., - INPLACE_RSHIFT=..., - INPLACE_AND=..., - INPLACE_XOR=..., - INPLACE_OR=..., - # - # Indexing operators - STORE_SUBSCR=(Pop(3), Push(0)), # A,B,C -> ∅ - DELETE_SUBSCR=(Pop(2), Push(0)), # A,B -> ∅ - BUILD_SLICE=lambda x: (Pop(x), Push(1)), # A,B,... -> Z - UNPACK_EX=lambda x: (Pop(1), Push((x & 0xFF) + (x >> 8) + 1)), # A -> B,C,... - # - # Iterators - GET_ITER=ReplaceTOS, # A -> B - GET_YIELD_FROM_ITER=ReplaceTOS, - # - # Misc. - FORMAT_VALUE=lambda oparg: (Pop(1 + bool(oparg & 0x04)), Push(1)), # (A?),B -> C - PRINT_EXPR=PopTOS, # A -> ∅ - IMPORT_STAR=..., - LOAD_BUILD_CLASS=PushTOS, - SETUP_ANNOTATIONS=NoStackEffect, - GET_LEN=(Pop(1), Push(2)), - IMPORT_NAME=BinaryOp, - IMPORT_FROM=(Pop(1), Push(2)), - MATCH_CLASS=(Pop(3), Push(1)), - MATCH_MAPPING=(Pop(1), Push(2)), - MATCH_SEQUENCE=..., - MATCH_KEYS=(Pop(2), Push(3 + bool(sys.version_info < (3, 11)))), - # - # Jump dependent - FOR_ITER=(Pop(1), (Push(2), Push(0))), - SETUP_WITH=(Pop(1), (Push(2), Push(7))), - SETUP_FINALLY=(Pop(0), (Push(0), Push(6))), - SETUP_ASYNC_WITH=(Pop(0), (Push(0), Push(6))), - # - # NOTE: These instructions have been removed since they are extraneous special cases. - # https://github.com/faster-cpython/ideas/issues/567 - # https://github.com/python/cpython/issues/102859 - JUMP_IF_TRUE_OR_POP=(Pop(1), (Push(0), Push(1))), - JUMP_IF_FALSE_OR_POP=..., - # - # TODO(robieta, t-vi): Iterators and generators - # "GEN_START": PopTOS, # Where does TOS for this come from? - # "YIELD_VALUE": ReplaceTOS, # I think - # "YIELD_FROM": (2, PushNew), # I am very unsure - # "GET_AWAITABLE": (1, 1), - # "BEFORE_ASYNC_WITH": (1, 2), - # "GET_AITER": (1, 1), - # "GET_ANEXT": (1, 2), - # "END_ASYNC_FOR": (7, 0), -) - - -# Split so MyPy can type check `__EFFECTS` without having to go through `fill_ellipses`. -_RAW_STACK_EFFECTS = FrozenDict[str, StackEffect | Callable[[int], StackEffect]](fill_ellipses(**__EFFECTS)) -del __EFFECTS - - -def stack_effect_detail(instruction: dis.Instruction) -> tuple[Pop, tuple[Push, Push]]: - assert isinstance(instruction, dis.Instruction), instruction - if callable(effect := _RAW_STACK_EFFECTS[instruction.opname]): - assert instruction.arg is not None - effect = effect(instruction.arg) - - assert isinstance(effect, tuple) and len(effect) == 2 and isinstance(effect[0], int) - if isinstance(effect[1], int): - effect = (effect[0], (effect[1],) * 2) - - # Python exposes a method to compute stack effect, so while it's not part - # of the public API we may as well use it to check our bookkeeping. 
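# [Editorial aside -- not part of the original patch] `opcode.stack_effect` is
# the interpreter's own accounting; e.g. on CPython 3.10:
#   >>> import dis, opcode
#   >>> opcode.stack_effect(dis.opmap["BINARY_ADD"], None)
#   -1
# i.e. two operands popped, one result pushed, consistent with BinaryOp above.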
- pop, (push_nojump, push_jump) = effect - for jump, push in ((False, push_nojump), (True, push_jump)): - expected = opcode.stack_effect(instruction.opcode, instruction.arg, jump=jump) - assert expected == push - pop, (expected, push, pop, jump) - - return Pop(pop), (Push(push_nojump), Push(push_jump)) diff --git a/thunder/core/script/passes.py b/thunder/core/script/passes.py deleted file mode 100644 index a53581ec01..0000000000 --- a/thunder/core/script/passes.py +++ /dev/null @@ -1,932 +0,0 @@ -import dis -import copy -import inspect -import opcode -import sys -import types -from typing import Any, Dict, List, Tuple, Union -from collections.abc import Callable -from collections.abc import Hashable -from contextvars import ContextVar - -import networkx as nx -import torch # # aehem. - -import thunder -from thunder.core.script.frontend import acquire_method, remove_unused_values -from thunder.core.script.graph import ( - assert_block, - assert_node, - assert_value, - Graph, - Block, - clone_blocks, - _generate_raises, - GraphObject, - Node, - PhiValue, - replace_values, - SourceInformation, - _Undefined, - Value, - repr_source_location, -) -from thunder.core.script.instrumentation import verbose_error, record -from thunder.core.script.parse import ThunderInstruction, JUMP_ABSOLUTE -from thunder.core.script.python_ir_data import get_instruction, X_THUNDER_STORE_ATTR -from thunder.torch import _torch_to_thunder_complete_map -from thunder.core.script.noinline import NOINLINE_METHODS -from thunder.core.utils import debug_asserts_enabled, debug_asserts_level, OrderedSet - -MAX_INLINE_ITERS = 50 - - -def split_block(gr: "Graph", bl: "Block", n: "Node") -> Block: - # The admin involved: - # - create a new "bottom block", the input block is the "top block" - # - split the .nodes - # - block_inputs of the top block and block_outputs of the bottom are the original - # block_inputs and block_outputs - # - scan all the node inputs and block_outputs of the lower part to see - # which need to be block_inputs of the lower block and thus outputs of the top one - # - define outputs of the "top block" to be the required inputs - # - add the input PhiValues and replace the outputs of the top block with them in the - # uses in the bottom block - # - add unconditional jump from top to bottom part - - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - i = 0 - while i < len(gr.blocks) and gr.blocks[i] is not bl: - i += 1 - assert i < len(gr.blocks), "block not found" - j = 0 - while j < len(bl.nodes) and bl.nodes[j] is not n: - j += 1 - assert j < len(bl.nodes), "node not found" - nbl = Block() - nbl.nodes = bl.nodes[j:] - del bl.nodes[j:] - old_block_outputs = bl.block_outputs - nbl.block_outputs = OrderedSet() - bl.block_outputs = OrderedSet() - nbl.block_inputs = [] - - bl_jump_node = Node(i=ThunderInstruction.make_jump_absolute(arg=None), inputs=[], outputs=[]) - bl_jump_node.jump_targets = [nbl] - if bl.nodes: - bl_jump_node.source_infos = copy.deepcopy(bl.nodes[-1].source_infos) - else: - bl_jump_node.source_infos = copy.deepcopy(nbl.nodes[0].source_infos) - bl.nodes.append(bl_jump_node) - nbl.jump_sources.append(bl_jump_node) - nbl.graph = gr - gr.blocks.insert(i + 1, nbl) - - potential_bl_outputs = {i for i in bl.block_inputs} - for n in bl.nodes: - for o in n.outputs: - potential_bl_outputs.add(o) - for i in bl.block_inputs: - potential_bl_outputs.add(i) - value_map: dict[GraphObject, GraphObject] = {} - - def get_or_create_phi(v: Value) -> Value: - if v in value_map: - return 
assert_value(value_map[v])
-        if v.is_const or v.is_global:
-            return v
-        if v in potential_bl_outputs:  # priority follow parent vs. phi_value?
-            phi_value = PhiValue([v], [bl_jump_node], nbl)
-            nbl.block_inputs.append(phi_value)
-            bl.block_outputs.add(v)
-            value_map[v] = phi_value
-            return phi_value
-        if v.parent is not None:
-            # this adds v.parent to the value_map, so that is used
-            # for the clone's parent
-            get_or_create_phi(v.parent)
-            v_new = v.clone(translation_dict=value_map)
-            v_new.block = nbl
-            return v_new
-        raise ValueError(f"unknown value {v}")
-
-    for n in nbl.nodes:
-        n.inputs = [get_or_create_phi(i) for i in n.inputs]
-        for o in n.outputs:
-            o.block = nbl
-            value_map[o] = o
-
-    for o in old_block_outputs:
-        if o not in value_map:
-            bl.block_outputs.add(o)
-        else:
-            assert value_map[o].block is nbl or (
-                value_map[o].is_function_arg or value_map[o].is_global
-            ), f"value {repr(o)} mapped to {repr(value_map[o])} has block {gr.blocks.index(value_map[o].block)} instead of {gr.blocks.index(nbl)}"
-            nbl.block_outputs.add(value_map[o])
-            if o is not value_map[o]:
-                for pv in o.phi_values[:]:
-                    if pv.block is not nbl:
-                        pv.replace_value(o, value_map[o])
-
-    if debug_asserts_level() > 1:
-        thunder.core.script.graph.check_graph(gr)
-
-    return nbl
-
-
-@verbose_error
-def find_method_through_phi_parent(fn_value: Value) -> tuple[Value, list[str]]:
-    Point = tuple[Value, tuple[str, ...]]
-    to_process: list[Point] = [(v, ()) for v in fn_value.resolve()]
-    edges: OrderedSet[tuple[Point, Point]] = OrderedSet(((fn_value, ()), i) for i in to_process)
-    while to_process:
-        v, attr = to_process.pop()
-        destination = (v, attr)
-        if (parent := v.parent) is not None and (name := v.name) is not None:
-            destination = (parent, (name, *attr))
-
-        elif (node := v.node) is not None and node.i.opname == "BINARY_SUBSCR" and node.inputs[1].is_const:
-            destination = (node.inputs[0], (repr(node.inputs[1].value), *attr))
-
-        for vi in destination[0].resolve():
-            edge = ((v, attr), (vi, destination[1]))
-            if edge not in edges:
-                edges.add(edge)
-                to_process.append(edge[1])
-
-    G = nx.from_edgelist(edges, nx.DiGraph)
-    G.remove_edges_from(nx.selfloop_edges(G))
-    assert nx.is_connected(G.to_undirected())
-    assert nx.is_directed_acyclic_graph(G)
-
-    # A size one topological generation means all flow must pass through that node. Thus, the latest
-    # generation with that property is the farthest we can resolve attributes.
-    *_, (fn_value, attr_lookups) = (i for i, *other in nx.topological_generations(G) if not other)
-    return fn_value, list(attr_lookups)
-
-
-def find_and_evaluate_method_through_phi_parent(v: Value) -> object | Callable:
-    fn_parent_value, attr_lookups = find_method_through_phi_parent(v)
-    if fn_parent_value.value is None:
-        return None
-    fn_value = fn_parent_value.value
-    for al in attr_lookups:
-        value = getattr(fn_value, al, _Undefined)
-        if value is _Undefined:
-            return _Undefined(fn_value, al)
-        fn_value = value
-    return fn_value
-
-
-class SkipInlineError(NotImplementedError):
-    pass
-
-
-@record(delegate_to="n")
-def inline_method_call(gr: "Graph", n: "Node") -> None:
-    gr.ensure_links()
-    if debug_asserts_level() > 1:
-        thunder.core.script.graph.check_graph(gr)
-    found_block = False
-    for i_bl, bl in enumerate(gr.blocks):
-        for i_n, n1 in enumerate(bl.nodes):
-            if n1 is n:  # is?
- found_block = True - break - if found_block: - break - assert found_block - if n.i.opname == "CALL_METHOD": - fn_value: Callable = find_and_evaluate_method_through_phi_parent(n.inputs[0]) # type: ignore - assert not isinstance(fn_value, _Undefined) - if fn_value is None: - raise NotImplementedError("cannot inline non-explicit function") - - ## TODO: value for self arg in Method calls? - ### in general: What is with callables here? - if isinstance(fn_value, torch.nn.Module): - mod1: object = fn_value - value_for_self1 = n.inputs[0] - fn_value = fn_value.forward - elif isinstance(fn_value, types.MethodType): - mod1 = fn_value.__self__ - value_for_self1 = n.inputs[1] - else: - mod1 = None - value_for_self1 = None - - if inspect.isbuiltin(fn_value): - raise NotImplementedError("cannot inline built-in (C-implemented) function") - elif n.i.opname in {"CALL_FUNCTION", "CALL_FUNCTION_KW"}: - fn_value = find_and_evaluate_method_through_phi_parent(n.inputs[0]) # type: ignore - assert not isinstance(fn_value, _Undefined) - if fn_value is None: - raise NotImplementedError("cannot inline non-explicit function") - - if isinstance(fn_value, torch.nn.Module): - mod1 = fn_value - value_for_self1 = n.inputs[0] - fn_value = fn_value.forward - else: - if isinstance(fn_value, types.FunctionType): - mod1 = None - value_for_self1 = None - elif isinstance(fn_value, types.MethodType): - mod1 = fn_value.__self__ - value_for_self1 = n.inputs[0].parent - assert value_for_self1 is not None - else: - source_str = repr_source_location(gr, n.source_infos) - raise NotImplementedError(f"inlining {fn_value} in instruction {n} at\n{source_str}") - else: - raise NotImplementedError(f"inlining {n}") - - # splitting must be done before replacing values, but this is changed even if we don't inline... - nbl = split_block(gr, bl, bl.nodes[i_n + 1]) - - gr1 = acquire_method(fn_value, module=mod1, mro_klass=gr.mro_klass if mod1 == gr.module else None) - for gr1_n in gr1.nodes(): - assert gr1_n.source_infos - have_generated = False - for si in gr1_n.source_infos: - si.gen_line_no = si.gen_line_no + len(gr.source_lines) + 1 - si.gen_end_line_no = si.gen_end_line_no + len(gr.source_lines) + 1 - # prepend - gr1_n.source_infos[:0] = copy.deepcopy(n.source_infos) - gr.source_lines.append("\n") - gr.source_lines += gr1.source_lines - - if gr1.ismethod: - sig1 = inspect.signature(gr1.method.__func__) - else: - sig1 = inspect.signature(gr1.method) - # transform defaults - sig1 = sig1.replace( - parameters=[ - p - if p.default is inspect._empty - else p.replace(default=Value(name=p.name, typ=type(p.default), value=p.default, is_const=True)) - for p in sig1.parameters.values() - ] - ) - - if gr1.ismethod: - call_args = [value_for_self1] - else: - call_args = [] - - if n.i.opname == "CALL_METHOD": - call_args += n.inputs[2:] - call_kwargs: dict[str, Any] = {} - elif n.i.opname == "CALL_FUNCTION": - call_args += n.inputs[1:] - call_kwargs = {} - elif n.i.opname == "CALL_FUNCTION_KW": - assert n.inputs[-1].is_const - num_kwargs = len(n.inputs[-1].value) - call_kwargs = {k: v for k, v in zip(n.inputs[-1].value, n.inputs[-1 - num_kwargs : -1])} - call_args += n.inputs[1 : -1 - num_kwargs] - else: - raise NotImplementedError() - - # TODO: catch and translate error messages, check types(?) 
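# [Editorial aside -- not part of the original patch] `Signature.bind` mimics
# CPython's call-time argument binding, which is what lets the inliner map
# call-site Values onto the callee's parameters:
#   >>> import inspect
#   >>> def f(a, b=1): pass
#   >>> ba = inspect.signature(f).bind(10)
#   >>> ba.apply_defaults()
#   >>> ba.arguments
#   {'a': 10, 'b': 1}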
- bound_args = sig1.bind(*call_args, **call_kwargs) - bound_args.apply_defaults() - - gr1_varargs = [n for n, p in sig1.parameters.items() if p.kind == p.kind.VAR_POSITIONAL] - gr1_varkwargs = [n for n, p in sig1.parameters.items() if p.kind == p.kind.VAR_KEYWORD] - ## TODO: TRANSLATE args (=tuple of Values) and kwargs (=dict str->Value) to a Value to something Value of ... (probably needs at least BUILD_TUPLE etc) - if gr1_varargs or gr1_varkwargs: - raise SkipInlineError("varargs and kwargs are currently not implemented") - - n1 = bl.nodes.pop(i_n) - assert n1 is n - - # there should be exactly one - (ret_bl,) = (bl for bl in gr1.blocks if len(bl.nodes) > 0 and bl.nodes[-1].i.opname == "RETURN_VALUE") - - ret_node = ret_bl.nodes[-1] - ret_node.i = ThunderInstruction.make( - JUMP_ABSOLUTE, - arg=-1, - argrepr="None", - offset=ret_node.i.offset, - starts_line=ret_node.i.starts_line, - is_jump_target=ret_node.i.is_jump_target, - line_no=ret_node.i.line_no, - ) - bl.nodes[-1].jump_targets = [gr1.blocks[0]] - assert len(gr1.blocks[0].jump_sources) == 1 - gr1.blocks[0].jump_sources = [bl.nodes[-1]] - for pv in gr1.blocks[0].block_inputs: - assert pv.jump_sources == [None] - pv.jump_sources = [bl.nodes[-1]] - ret_node.jump_targets = [nbl] - nbl.jump_sources = [ret_node if js == bl.nodes[-1] else js for js in nbl.jump_sources] - for pv in nbl.block_inputs: - pv.jump_sources = [ret_node if js == bl.nodes[-1] else js for js in pv.jump_sources] - - for bl1 in gr1.blocks: - bl1.graph = gr - gr.blocks[i_bl + 1 : i_bl + 1] = gr1.blocks - - assert len(n.outputs) == 1 - inp_map = {p: bound_args.arguments[p.name] for p in gr1.local_variables_at_start if p.name in bound_args.arguments} - if n.outputs[0] in bl.block_outputs: # it may legitimately happen that we don't use the output - bl.block_outputs.remove(n.outputs[0]) # TODO: what with inplace!! - bl.block_outputs.update(inp_map.values()) # Note: This includes default args - gr.ensure_links() - replace_values(gr1, inp_map) - - # output value - rv = ret_node.inputs.pop() - assert not ret_node.inputs - (orv,) = n.outputs - replace_values(gr, {orv: rv}) - ret_bl.block_outputs.add(rv) - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - - -def inline_submodule_calls(gr: "Graph") -> bool: - # inlines submodule calls - # returns whether something has changed - # TODO: recursively and not from nested structures (ModuleList etc.) - changed = False - gr.ensure_links() - for bl in gr.blocks[:]: - for n in bl.nodes[:]: - if n.i.opname in {"CALL_METHOD", "CALL_FUNCTION", "CALL_FUNCTION_KW"}: - fn_value = find_and_evaluate_method_through_phi_parent(n.inputs[0]) - if isinstance(fn_value, _Undefined): - # TODO: We could insert a RAISE here if we then delete the return - # value and all (direct or indirect) uses. 
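# [Editorial aside -- not part of the original patch] `_generate_raises`
# (defined in graph.py, not shown in this hunk) is assumed to return a const
# callable that raises at run time; a sketch of the idea (exception type
# assumed here):
#   >>> def _generate_raises(msg):
#   ...     def raiser(*args, **kwargs):
#   ...         raise AttributeError(msg)
#   ...     return raiser
# so a failed attribute lookup surfaces only if the call is actually reached.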
- methval = Value( - value=_generate_raises( - f"attribute error '{type(fn_value.value)}' object has no attribute '{fn_value.attr}'" - ), - is_const=True, - ) - n.i = n.i.modify_copy(opname="CALL_FUNCTION", arg=0, opcode=None) - n.inputs = [methval] - if isinstance(fn_value, torch.nn.Module) or ( - inspect.ismethod(fn_value) - and isinstance(fn_value.__self__, torch.nn.Module) - and (fn_value not in NOINLINE_METHODS.get()) - ): - inline_method_call(gr, n) - changed = True - - return changed - - -def strongly_inline_functions(gr: "Graph") -> None: - for _ in range(MAX_INLINE_ITERS): - loop = False - gr.ensure_links() - for bl in gr.blocks[:]: - for n in bl.nodes[:]: - if n.i.opname in {"CALL_METHOD", "CALL_FUNCTION", "CALL_FUNCTION_KW"}: - fn_value = find_and_evaluate_method_through_phi_parent(n.inputs[0]) - if ( - fn_value is not None - and not inspect.isbuiltin(fn_value) - and isinstance(fn_value, types.FunctionType) - and fn_value not in _torch_to_thunder_complete_map - and fn_value not in NOINLINE_METHODS.get() - ): - ## handle methods or nn.Modules / other classes? - try: - inline_method_call(gr, n) - loop = True - except SkipInlineError: - pass - except RuntimeError as e: - (msg,) = e.args - source_str = repr_source_location(gr, n.source_infos) - msg = f"{msg}\nwhile inlining:\n{source_str}" - e.args = (msg,) - raise e - if not loop: - return - - raise AssertionError(f"Inlining did not complete after {MAX_INLINE_ITERS} passes.") - - -def torch_to_thunder(gr: "Graph", fallback: bool = False) -> None: - """replaces calls to torch.foo functions with calls into thunder's torch language.""" - - def fill_in_value(v: Value, seen: OrderedSet[Value]) -> None: - if v in seen: - return - seen.add(v) - parent = v.parent - if parent is None and isinstance(v, PhiValue): - for vv in v.values: - fill_in_value(vv, seen) - for vv in v.values[1:]: - if vv.value is not v.values[0].value: - return - v.value = v.values[0].value - if v.value is None and parent is not None: - fill_in_value(parent, seen) - if v.name is None and isinstance(v, PhiValue) and parent is not None and parent.name is not None: - v.name = parent.name - if v.value is None and parent is not None and parent.value is not None and v.name is not None: - v.value = getattr(parent.value, v.name, None) - - for bl in gr.blocks: - for n in bl.nodes: - for idx, i in enumerate(n.inputs): - done = False - fill_in_value(i, OrderedSet()) - i_or_parent = i - while ( - not isinstance(i_or_parent.value, Hashable) - or i_or_parent.value not in _torch_to_thunder_complete_map - ) and i_or_parent.parent is not None: - i_or_parent = i_or_parent.parent - - if isinstance(i_or_parent.value, Hashable) and i_or_parent.value in _torch_to_thunder_complete_map: - i_or_parent.value = _torch_to_thunder_complete_map[i.value] - # we reinstantiate because we don't want a PhiValue here - i_new = Value( - value=i_or_parent.value, - typ=type(i_or_parent.value), - parent=None, - is_const=True, - is_global=False, - name=i_or_parent.name, - ) - n.inputs[idx] = i_new - if n.i.opname == "CALL_METHOD" and idx == 0: - # todo get others, too - n.i = get_instruction(opname="CALL_FUNCTION", arg=n.i.arg) - del n.inputs[1] - done = True - - if (not done) and fallback: # fallback - # todo: change name?, deeper nesting? 
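# [Editorial aside -- not part of the original patch] Schematically, the main
# path above consults a mapping like {torch.add: thunder.torch.add, ...}
# (entries hypothetical); the fallback below instead rewrites the module
# object itself, `torch` -> `thunder.langs.torch`, then re-resolves the
# attribute by name.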
- if i.value == torch: - i.value = thunder.langs.torch - if i.parent is not None and i.parent.value == torch: - i.parent.value = thunder.langs.torch - assert i.name is not None - i.value = getattr(thunder.langs.torch, i.name) - - # replace other things by checking against torch module (make dict at startup?) - name = getattr(i.value, "__name__", None) - tf = None - if name is not None: - tf = getattr(torch, name, None) - if tf is not None and i.value == tf: - i.value = getattr(thunder.langs.torch, name) - i.is_global = False - i.is_const = True - - -def merge_two_blocks(gr: "Graph", bl1: "Block") -> None: - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - jt = bl1.nodes[-1].jump_targets - if len(jt) != 1: - raise RuntimeError("can only fuse blocks with deterministic connection") - bl2 = jt[0] - if len(bl2.jump_sources) != 1 or bl2.jump_sources[0] != bl1.nodes[-1]: - raise RuntimeError("second block to be fused must only have first block as jump source") - - replacements: dict[Value, Value] = {} - for i in bl2.block_inputs: - assert isinstance(i, PhiValue) and len(i.values) == 1, (i, getattr(i, "values", None)) - (iv,) = i.values - if iv in bl1.block_outputs: - replacements[i] = iv - else: - if i.jump_sources == [bl1.nodes[-1]]: - i.jump_sources = [iv.block.nodes[-1]] - bl1.block_inputs.append(i) - i.block = bl1 - - replace_values(bl2, replacements, follow_phi_values=True) - # TODO: Should this happen automatically in replace_values? - # Should we also replace values in bl1? - for o in bl1.block_outputs: - for pv in o.phi_values[:]: - if pv in replacements: - pv.remove_value(o) - else: - pv.jump_sources = [js if js != bl1.nodes[-1] else bl2.nodes[-1] for js in pv.jump_sources] - - bl1_jump = bl1.nodes[-1] - bl2_jump = bl2.nodes[-1] - - bl1.block_outputs = OrderedSet(o for o in bl1.block_outputs if o.phi_values) - bl1.block_outputs.update(bl2.block_outputs) - - bl1.nodes[-1:] = bl2.nodes - gr.blocks.remove(bl2) - - gr.ensure_links() - - # fix jump sources in other blocks - for bl in gr.blocks: - for i in bl.block_inputs: - i.jump_sources = [(bl2_jump if js is bl1_jump else js) for js in i.jump_sources] - - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - - -def merge_blocks_where_possible(gr: "Graph") -> None: - i_bl = 0 - while i_bl < len(gr.blocks): - bl1 = gr.blocks[i_bl] - jt = bl1.nodes[-1].jump_targets - if len(jt) == 1: - bl2 = jt[0] - else: - bl2 = None - if bl2 is not None and len(bl2.jump_sources) == 1 and bl2.jump_sources[0] == bl1.nodes[-1]: - merge_two_blocks(gr, bl1) - else: - i_bl += 1 - - -def find_blocks_of_for(gr: "Graph", for_block: "Block") -> list[Block]: - assert for_block.nodes[-1].i.opname == "FOR_ITER" - - blocks_of_for_loop = OrderedSet({for_block}) - currently_looking_at = set() - - def find_blocks_of_for_rec(for_block: "Block", start_block: "Block") -> bool: - if for_block == start_block: - return True - if start_block in currently_looking_at: - return False - currently_looking_at.add(start_block) - found = False - for jt in start_block.nodes[-1].jump_targets: - found |= find_blocks_of_for_rec(for_block, jt) - currently_looking_at.remove(start_block) - if found: - blocks_of_for_loop.add(start_block) - return found - - find_blocks_of_for_rec(for_block, for_block.nodes[-1].jump_targets[0]) - return list(blocks_of_for_loop) - - -def unroll_for_over_modules(gr: "Graph", for_iter_node: "Node") -> None: - gr.ensure_links() - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - get_iter_node = 
for_iter_node.inputs[0].values[0].node - assert get_iter_node.i.opname == "GET_ITER" - - iterated_module_list_parent, attr_lookups = find_method_through_phi_parent(get_iter_node.inputs[0]) - assert iterated_module_list_parent.value is not None - iterated_module_list = iterated_module_list_parent.value - for al in attr_lookups: - iterated_module_list = getattr(iterated_module_list, al) - - # what about more complex things? - assert isinstance(iterated_module_list, (torch.nn.Sequential, torch.nn.ModuleList)) - - for_loop_len = len(iterated_module_list) - for_iter_block = for_iter_node.block - assert for_iter_block is not None - get_iter_block = get_iter_node.block - - (iter_v,) = get_iter_node.outputs - (iter_phi,) = for_iter_node.inputs - - assert isinstance(iter_phi, PhiValue) - assert iter_v in iter_phi.values - - ### first we find the blocks of the for loop - bls = find_blocks_of_for(gr, for_iter_block) - - jmp_nodes = {bl.nodes[-1] for bl in bls} - assert all((v is iter_v or js in jmp_nodes) for v, js in zip(iter_phi.values, iter_phi.jump_sources)) - - for_iter_node.i = get_instruction(opname="BINARY_SUBSCR", arg=None) - iter_phi.remove_value(iter_v) - assert len(iter_v.phi_values) == 0 - get_iter_block.block_outputs.remove(iter_v) - - get_iter_block.block_outputs.add(get_iter_node.inputs[0]) - - seen = set() - - def delete_value_and_sources(v: Value) -> None: - # check that it is possible? - if v in seen: - return - seen.add(v) - if isinstance(v, PhiValue): - for vv, js in zip(v.values, v.jump_sources): - delete_value_and_sources(vv) - assert js is not None and js.block is not None - js.block.block_outputs.remove(vv) - v.block.block_inputs.remove(v) - - delete_value_and_sources(iter_phi) - seq_phi = PhiValue(values=[get_iter_node.inputs[0]], jump_sources=[get_iter_block.nodes[-1]], block=for_iter_block) - get_iter_block.nodes.remove(get_iter_node) - for_iter_block.block_inputs.append(seq_phi) - - idx = Value(value=0, is_const=True) - for_iter_node.inputs = [seq_phi, idx] - for_iter_node.outputs = [for_iter_node.outputs[1]] - - for_iter_block_jmp = Node(i=get_instruction(opname="JUMP_ABSOLUTE", arg=None)) - for_iter_block_jmp.source_infos = copy.deepcopy(for_iter_node.source_infos) - for_iter_block.nodes.append(for_iter_block_jmp) - for_iter_block_jmp.jump_targets = [for_iter_node.jump_targets[0]] - for_iter_node_exit_jump_target = for_iter_node.jump_targets[1] - for_iter_node.jump_targets = [] - for_iter_block_jmp.jump_targets[0].jump_sources = [ - (js if js is not for_iter_node else for_iter_block_jmp) - for js in for_iter_block_jmp.jump_targets[0].jump_sources - ] - - exit_block = Block() - gr.blocks.append(exit_block) - exit_node = Node(i=get_instruction(opname="JUMP_ABSOLUTE", arg=None)) - exit_node.source_infos = copy.deepcopy(for_iter_node.source_infos) - exit_node.jump_targets = [for_iter_node_exit_jump_target] - target_after_iter = exit_node.jump_targets[0] - exit_node.jump_targets[0].jump_sources = [ - (js if js is not for_iter_node else exit_node) for js in exit_node.jump_targets[0].jump_sources - ] - exit_block.nodes.append(exit_node) - for i in for_iter_block.block_inputs: - exit_block.block_inputs.append(PhiValue([], [], exit_block)) - - unroll_blocks: list[tuple[list[Block], dict[GraphObject, GraphObject]]] = [(list(bls), {})] - unroll_blocks += [clone_blocks(bls) for _ in range(1, for_loop_len)] - for idx, (nbls, td) in enumerate(unroll_blocks): - if idx > 0: - gr.blocks += nbls - v_idx = Value(value=idx, is_const=True) - assert_node(td[for_iter_node]).inputs[1] = v_idx 
- fin_o = assert_node(td[for_iter_node]).outputs[0] - assert fin_o.name is not None - fin_o.name += f"_{idx}" - else: - assert for_iter_node.outputs[0].name is not None - for_iter_node.outputs[0].name += "_0" - - gr.ensure_links() - - fixup_data = [] - for idx, (nbls, td) in enumerate(unroll_blocks): - if idx == 0: - fib_i = for_iter_block - jump_sources_to_fix = [js for js in for_iter_block.jump_sources if js is not get_iter_block.nodes[-1]] - else: - fib_i = assert_block(td[for_iter_block]) - jump_sources_to_fix = fib_i.jump_sources[:] - if idx + 1 < len(unroll_blocks): - _, td_next = unroll_blocks[idx + 1] - fib_next = assert_block(td_next[for_iter_block]) - else: - fib_next = exit_block - - fixup_data.append((fib_i, jump_sources_to_fix, fib_next, nbls)) - - for idx_it, (fib_i, jump_sources_to_fix, fib_next, nbls) in enumerate(fixup_data): - for js in jump_sources_to_fix: - assert js is not None - for idx, jt in enumerate(js.jump_targets): - if jt == fib_i: - js.set_jump_target(fib_next, idx=idx) - - for idx_i, i in enumerate(fib_i.block_inputs): - if any((js.block in nbls) for js in i.jump_sources): - ## if this is a variable updated in the loop: - ## - instead of looping back, point the update to the phi value of the next block (or the exit block) - ## - if idx > 0: remove external (before the loop) value - for v, js in zip(i.values[:], i.jump_sources[:]): - if js is not None and js.block not in nbls and idx_it > 0: - i.remove_value(v) - - for idx_it, (fib_i, jump_sources_to_fix, fib_next, nbls) in enumerate(fixup_data): - for idx_i, i in enumerate(fib_i.block_inputs): - if any((js is not None and js.block in nbls) for js in i.jump_sources): - for v, js in zip(i.values[:], i.jump_sources[:]): - if js is not None and assert_block(assert_node(js).block) in nbls: - i.remove_value(v) - assert_block(fib_next).block_inputs[idx_i].add_missing_value(v, jump_source=js) - if idx_it == 0: - for pv in i.phi_values[:]: - if pv.block is target_after_iter: - pv.remove_value(i) - pv.add_missing_value(exit_block.block_inputs[idx_i], jump_source=exit_node) - - for i in exit_block.block_inputs[:]: - if i.phi_values: - exit_block.block_outputs.add(i) - else: - assert isinstance(i, PhiValue) - for v in i.values[:]: - i.remove_value(v) - exit_block.block_inputs.remove(i) - if debug_asserts_enabled(): - thunder.core.script.graph.check_graph(gr) - - -def find_and_unroll_for_loop(gr: "Graph") -> bool: - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - gr.ensure_links() - - for bl in gr.blocks[:]: - for n in bl.nodes[:]: - if n.i.opname == "FOR_ITER": - for_iter_node = n - get_iter_node = for_iter_node.inputs[0].values[0].node - if get_iter_node.i.opname == "GET_ITER": - ( - iterated_module_list_parent, - attr_lookups, - ) = find_method_through_phi_parent(get_iter_node.inputs[0]) - if iterated_module_list_parent.value is None: - continue - iterated_module_list = iterated_module_list_parent.value - for al in attr_lookups: - iterated_module_list = getattr(iterated_module_list, al) - # what about more complex things? in particular enumerate, but zip, ... 
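-                    # For illustration (hypothetical example): with
-                    #     self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(2)])
-                    # a loop
-                    #     for m in self.layers:
-                    #         x = m(x)
-                    # is unrolled by the pass below into the equivalent of
-                    #     x = self.layers[0](x)
-                    #     x = self.layers[1](x)
-                    # (FOR_ITER becomes BINARY_SUBSCR and the body blocks are cloned once per element).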
- if isinstance(iterated_module_list, (torch.nn.Sequential, torch.nn.ModuleList)): - thunder.core.script.passes.unroll_for_over_modules(gr, for_iter_node) - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - thunder.core.script.passes.merge_blocks_where_possible(gr) - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - return True - if debug_asserts_enabled(): - thunder.core.script.graph.check_graph(gr) - return False - - -def unroll_for_loops_and_inline_modules(gr: "Graph") -> None: - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - iterate = True - while iterate: - iterate = find_and_unroll_for_loop(gr) - if not iterate: - iterate = inline_submodule_calls(gr) - if iterate: - thunder.core.script.passes.merge_blocks_where_possible(gr) - - -def module_to_function(gr: "Graph") -> tuple[list[str], list[torch.Tensor]]: - attr_dict: dict[str, int] = {} - attr_list: list[str] = [] - attr_values = [] - return_values: dict[str, Value] = {} # PhiValues in the return block - - if debug_asserts_enabled(): - thunder.core.script.graph.check_graph(gr) - - def functionalize_value_if_possible(i): - # TODO: inefficient because it looks twice - v = find_and_evaluate_method_through_phi_parent(i) - # assert not isinstance(v, _Undefined), f"undefined: {v.value} {v.attr}" - if isinstance(v, _Undefined): - return Value(value=v, is_const=True) - maybe_self, attrs = find_method_through_phi_parent(i) - - attr_string = ".".join(attrs) - if maybe_self.value is gr.module and (isinstance(v, torch.Tensor) or (attr_string in return_values)): - # the new attributes come directly after the self argument - idx = attr_dict.setdefault(attr_string, len(attr_list) + 1) - if idx == len(attr_list) + 1: - func_arg = Value(name=attr_string, is_function_arg=True) - gr.local_variables_at_start.insert(idx, func_arg) - attr_list.append(attr_string) - attr_values.append(v) - gr.co_argcount += 1 - # we need a default argument to be able to put the things at the end (but this will have to change for *args, **kwargs anyway... - # gr.func_defaults.append(None) - if attr_string in return_values: - return_values[attr_string].add_missing_value(func_arg) - else: - func_arg = gr.local_variables_at_start[idx] - - pvs = [pv for pv in func_arg.phi_values if pv.block is bl] - if not pvs: - pv = PhiValue([func_arg], [None], bl) - bl.block_inputs.append(pv) - else: - (pv,) = pvs - ## remove old input from phi_values etc? - return pv - if maybe_self.value is gr.module and ( - n.i.opname not in {"BINARY_SUBSCR"} and not isinstance(v, torch.nn.Module) - ): - ## inline to const... 
- i.value = v - i.typ = type(i.value) - i.parent = None - i.is_const = True - i.is_global = False - return None - return None - - return_block = None - for bl in gr.blocks: - if bl.nodes[-1].i.opname == "RETURN_VALUE": - assert return_block is None, "multiple return statements should not happen here" - return_block = bl - assert return_block is not None, "could not find return block" - - for bl in gr.blocks: - for n in bl.nodes: - if n.i.opname == "STORE_ATTR": - v = find_and_evaluate_method_through_phi_parent(n.inputs[1]) - if isinstance(v, _Undefined): - n.inputs[1] = Value(value=v, is_const=True) - continue - # assert not isinstance(v, _Undefined), f"undefined: {v.value} {v.attr}" - maybe_self, attrs = find_method_through_phi_parent(n.inputs[1]) - attrs.append(n.i.argval) - if maybe_self.value is gr.module: - attr_string = ".".join(attrs) - n.i = n.i.modify_copy(opname=X_THUNDER_STORE_ATTR, opcode=None, argval=attr_string) - pv = return_values.get(attr_string) - if pv is None: - pv = PhiValue([], [], return_block) - pv.name = attr_string - return_values[attr_string] = pv - return_block.block_inputs.append(pv) - v = Value(node=n, name=attr_string, block=bl) # disambiguate? - pv.add_missing_value(v, jump_source=bl.nodes[-1]) - n.outputs = [v] - bl.block_outputs.add(v) - del n.inputs[1] - - for bl in gr.blocks: - for n in bl.nodes: - if n.i.opname == "CALL_METHOD": - if n.inputs[0].parent == n.inputs[1]: - v = find_and_evaluate_method_through_phi_parent(n.inputs[0]) - if not isinstance(v, types.MethodType) or v.__self__ != find_and_evaluate_method_through_phi_parent( - n.inputs[1] - ): - # this case (not a proper method call is usually handled in executing the LOAD_METHOD opcode) - n.i = n.i.modify_copy(opname="CALL_FUNCTION", opcode=None) - del n.inputs[1] - - for idx_i, i in enumerate(n.inputs): - v = functionalize_value_if_possible(i) - if v is not None: - n.inputs[idx_i] = v - - bl.block_outputs = OrderedSet( - [v if (v := functionalize_value_if_possible(o)) is not None else o for o in bl.block_outputs] - ) - - if return_values: - bt_extra = Node( - i=get_instruction(opname="BUILD_TUPLE", arg=1 + len(return_values)), - source_infos=copy.deepcopy(return_block.nodes[-1].source_infos), - ) - bt_extra.inputs = return_block.nodes[-1].inputs + list(return_values.values()) - v_tuple_extra = Value(node=bt_extra, block=return_block) - bt_extra.outputs = [v_tuple_extra] - return_block.nodes.insert(-1, bt_extra) - return_block.nodes[-1].inputs = [v_tuple_extra] - - remove_unused_values(gr) - if gr.local_variables_at_start[0].phi_values: - gr.summary(print_lines=True) - raise RuntimeError( - """could not eliminate self argument - this most likely means that you are setting attributes in forward or using them - in an unexpected way that thunder does not yet support. - The problem lies in (indirect) uses of V_0 in the graph above.""" - ) - - # check to avoid assignments for both a.b and a.b.c - sorted_keys = sorted(return_values.keys()) # this uses that '.' sorts before other things - for i in range(len(sorted_keys) - 1): - kbase = sorted_keys[i] - knext = sorted_keys[i + 1] - if knext.startswith(kbase) and knext[len(kbase)] == ".": - # N.B. we know that knext is longer if kbase is a prefix so the knext[len(kbase)] above will not be out of bounds. 
-            raise RuntimeError(f"Assigning to members of assigned members ('{kbase}' and '{knext}') is not supported.")
-
-    del gr.local_variables_at_start[0]
-    gr.co_argcount -= 1
-    if gr.co_posonlyargcount > 0:
-        gr.co_posonlyargcount -= 1
-
-    # thunder.core.script.graph.check_graph(gr)
-    # gr.summary(print_lines=True)
-
-    return attr_list, attr_values, list(return_values.keys())
diff --git a/thunder/core/script/protograph.py b/thunder/core/script/protograph.py
deleted file mode 100644
index 169e4bc60c..0000000000
--- a/thunder/core/script/protograph.py
+++ /dev/null
@@ -1,518 +0,0 @@
-from __future__ import annotations
-
-
-import abc
-import collections
-import dataclasses
-import functools
-import inspect
-import itertools
-from types import CodeType
-from typing import cast, overload, Any, Literal, NewType
-from collections.abc import Iterable, Iterator, Mapping
-
-from thunder.core.script import algorithms, instrumentation, parse, values
-from thunder.core.utils import debug_asserts_enabled, FrozenDict, OrderedSet
-
-__all__ = ("ProtoBlock", "ProtoGraph")
-
-# =============================================================================
-# == Inter-ProtoBlock abstract value flow =====================================
-# =============================================================================
-#
-# ProtoBlocks are weakly coupled by design. The `VariableKey` slots allow edges
-# to be deduced (e.g. `x` at the start of one block must be the same as `x` at
-# the end of the prior block), but there's no strong requirement. (And indeed,
-# the ProtoGraph immediately after parsing has all unconnected `AbstractRef`s
-# for input values.) Similarly, ProtoGraph serves only to record and organize
-# the block topology, check invariants, and provide various helper methods.
-#
-# This weak coupling exists to facilitate graph rewrites and reduce the surface
-# area for self-inconsistent representation. By readily discarding (deduced)
-# information, we don't need to carry invariants through complex passes; we can
-# simply decouple the graph, perform whatever local modifications we like, and
-# then reconnect everything. This representation is immutable (notwithstanding
-# a few implementation details), so "decouple" means emitting a new erased
-# graph. (Though simple value replacements can be done directly.)
-JumpTarget = NewType("JumpTarget", tuple["ProtoBlock", parse.Jump])
-Uses = NewType("Uses", OrderedSet[parse.VariableKey])
-
-
-@dataclasses.dataclass(frozen=True, eq=False)
-class ProtoBlock(instrumentation.InstrumentingBase):  # type: ignore[misc,no-any-unimported]
-    """Stores abstract data flow for a code block."""
-
-    flow: values.IntraBlockFlow
-    jump_targets: tuple[JumpTarget, ...] = dataclasses.field(default=(), init=False)
-    uses: Uses = dataclasses.field(default_factory=lambda: Uses(OrderedSet()), init=False)
-
-    def __repr__(self) -> str:
-        ops = "\n".join(f"  {i.opname}" for i, _ in self.flow.symbolic)
-        return f"ProtoBlock: {hex(id(self))}\n{ops}"
-
-    def __hash__(self) -> int:
-        return id(self)
-
-    def __post_init__(self) -> None:
-        self.uses.update(self.flow.uses)
-
-    def add_jump_target(self, other: ProtoBlock, jump: parse.Jump) -> None:
-        """We need to add jump targets after all ProtoBlocks are initialized."""
-
-        # Override `frozen=True` for this one limited use case.
-        object.__setattr__(self, "jump_targets", self.jump_targets + ((other, jump),))
-
-
-@dataclasses.dataclass(frozen=True, eq=False)
-class ProtoGraph:
-    protoblocks: tuple[ProtoBlock, ...]
- root: ProtoBlock - parents: Mapping[ProtoBlock, tuple[ProtoBlock, ...]] - - Provenance = values.ParsedSymbolic | tuple[type["ProtoGraphTransform"], "ProtoGraph"] - provenance: Provenance - - def __init__(self, protoblocks: Iterable[ProtoBlock], provenance: Provenance) -> None: - G = algorithms.TypedDiGraph[ProtoBlock]() - for protoblock in (protoblocks := tuple(protoblocks)): - is_return = tuple(protoblock.flow.symbolic)[-1][0].opname == parse.RETURN_VALUE - G.add_node(protoblock, is_return=is_return) - - for protoblock in protoblocks: - for destination, jump in protoblock.jump_targets: - G.add_edge(protoblock, destination, adjacent=not jump) - - assert protoblocks - object.__setattr__(self, "protoblocks", tuple(algorithms.sort_adjacent(G))) - assert len(G) == len(self.protoblocks) == len(protoblocks), (len(G), len(self.protoblocks), len(protoblocks)) - - object.__setattr__(self, "root", self.protoblocks[0]) - root_stack = [(k, v) for k, v in self.root.flow.begin_state if k.scope == parse.VariableScope.STACK] - assert not root_stack, f"Root block should not have stack inputs: {root_stack}" - - nodes = cast(Iterable[ProtoBlock], G.nodes) # For some reason mypy needs this. - parents = {protoblock: tuple(G.predecessors(protoblock)) for protoblock in nodes} - object.__setattr__(self, "parents", FrozenDict(parents)) - object.__setattr__(self, "provenance", provenance) - - @classmethod - def from_code(cls, co: CodeType) -> ProtoGraph: - """Given a method, disassemble it to a sequence of simple blocks.""" - parsed = values.ParsedSymbolic.make(parse.ParsedFunctional.make(co)) - protoblocks = tuple( - ProtoBlock(values.IntraBlockFlow(symbolic, begin, end)) for symbolic, begin, end in parsed.blocks - ) - for source, sink, jump in parsed.provenance.provenance.edges: - protoblocks[source].add_jump_target(protoblocks[sink], jump) - - return cls(protoblocks, parsed) - - def __iter__(self) -> Iterator[ProtoBlock]: - yield from self.protoblocks - - def __getitem__(self, index: int) -> ProtoBlock: - return self.protoblocks[index] - - def __len__(self) -> int: - return len(self.protoblocks) - - def __repr__(self) -> str: - return "\n\n".join(repr(protoblock) for protoblock in self) - - @property - def flat_flow(self) -> Iterable[tuple[parse.ThunderInstruction, values.Symbolic, values.Materialized]]: - for protoblock in self: - for instruction, symbolic in protoblock.flow.symbolic: - yield instruction, symbolic, protoblock.flow.materialized[instruction] - - @property - def is_linked(self) -> bool: - # NOTE: `is_linked` is vacuously True for a single block graph. - flat_begin = itertools.chain(*(i.flow._begin.values() for i in self if i is not self.root)) - return len(self) == 1 or any(not isinstance(i, values.AbstractRef) for i in flat_begin) - - def unlink(self) -> ProtoGraph: - return Unlink(self).apply(or_default=True) - - def link(self) -> ProtoGraph: - if result := ProtoGraphTransform.chain(self, AddTransitive, MatchStacks, Connect): - assert AddTransitive(result).apply(or_default=False) is None - assert (result or self).is_linked - return result or self - - def debug_print_protoflows(self) -> None: - """ - Print out the node_flow for each protoblock in the - protograph, in a way that's nice to read and debug with. 
- """ - - counter = 0 - idxes: dict[values.AbstractValue, int] = {} - for pb in self: - for node in pb.flow.materialized.values(): - for val in itertools.chain(node.inputs.ordered, node.outputs): - if val not in idxes.keys(): - idxes[val] = counter - counter += 1 - - def to_index_str(values: tuple[values.AbstractValue, ...]) -> str: - indices = (str(idxes[v]) for v in values) - return f"({', '.join(indices)})" - - for i, pb in enumerate(self): - print(f"Protoblock {i}:") - print(f"{'':>22}Inputs, Outputs") - for instruction, node in pb.flow.materialized.items(): - print(f" {instruction.opname:>20}, {to_index_str(node.inputs.ordered)} -> {to_index_str(node.outputs)}") - print("\n") - - -# ============================================================================= -# == Graph transforms (Base classes) ========================================== -# ============================================================================= -# ProtoGraphTransform -# ReplaceProtoBlocks -# ReplaceValues -# CondenseValues -# ReplaceSymbolic - - -class ProtoGraphTransform(abc.ABC): - """Handles mechanical portions of graph rewrites. - The base case is unopinionated; it simply accepts whatever new ProtoGraph is - emitted by `self._apply`. The primary feature it provides is checking. - - NOTE: - The convention adopted is for the pass logic to produce `T | None` - (e.g. `ProtoGraph | None`) where `None` signals that no change is - applicable. - - Forbid Linked: - A key invariant of ProtoGraph is that every AbstractValue has exactly - **one** producer, which is set by the symbolic flow. (With the exception - of `AbstractRef`s which are placeholders for an as-yet unspecified - AbstractValue.) However, within a block there is a flat list of - **concrete** values specifying the state at the start of the block. - - If one were to replace all instances of `X` in a ProtoGraph with `Y`, - this invariant would be preserved. On the other hand, if one were to - replace `X` with `Y` **only at the symbolic producer of `X`** then - downstream blocks could still have `X` as a block input, despite the - fact that `X` no longer has a producer. (Note that this is only a problem - across blocks; within blocks the materialization pass respects the update - and emits a consistent materialized state for the new ProtoBlock.) - - It is often convenient to simply rewrite the symbolic flow within a - single ProtoBlock. In that case the correct procedure is to generate an - unlinked ProtoGraph, perform the local rewrites, and then relink it. - (Where the connection pass will handle reconciliation automatically.) - - Check idempotence: - Nearly all passes are expected to be idempotent. This provides a good - deal of free test coverage since it produces both a test case (the result - of `self._apply`) and an expected result. (That `self._apply` returns - `None`.) We perform this check many times in order to flush out - non-deterministic passes. (Though the value is configurable if a pass is - particularly expensive.) - - However, given the potential added start up latency and possibility of - spurious failures this check is gated by `debug_asserts_enabled`, which - defaults to `False`. (Except for unit tests.) - """ - - _forbid_linked: bool = False - _kwargs: FrozenDict[str, Any] # Used to replay transform for `idempotent` check. - _idempotent_repeats: int = 10 - - @abc.abstractmethod - def _apply(self) -> ProtoGraph | None: - """Override this method to emit an (optional) new ProtoGraph.""" - ... 
- - def __new__(cls, *args: Any, **kwargs: Any) -> ProtoGraphTransform: - self = super().__new__(cls) - bound = inspect.signature(self.__class__.__init__).bind(None, *args, **kwargs).arguments - bound.pop("self") - bound.pop("proto_graph") - self._kwargs = FrozenDict(bound) - return self - - def __init__(self, proto_graph: ProtoGraph) -> None: - assert not (self._forbid_linked and len(proto_graph) > 1 and proto_graph.is_linked), self - assert isinstance(proto_graph, ProtoGraph) - self._protograph = proto_graph - - @property - def protograph(self) -> ProtoGraph: - return self._protograph - - @overload - def apply(self, or_default: Literal[False]) -> ProtoGraph | None: - ... - - @overload - def apply(self, or_default: Literal[True]) -> ProtoGraph: - ... - - def apply(self, or_default: bool = False) -> ProtoGraph | None: - result = self._apply() - if debug_asserts_enabled(): - result_to_check = result or self.protograph - for i in range(self._idempotent_repeats): - assert self.__class__(proto_graph=result_to_check, **self._kwargs)._apply() is None, (i, self) - return result or (self.protograph if or_default else None) - - @staticmethod - def chain(proto_graph: ProtoGraph, *transforms: type[ProtoGraphTransform]) -> ProtoGraph | None: - initial = proto_graph - for transform in transforms: - proto_graph = transform(proto_graph).apply(or_default=True) - return None if proto_graph is initial else proto_graph - - -class ReplaceProtoBlocks(ProtoGraphTransform): - """Helper to replace individual ProtoBlocks while retaining the same ProtoGraph topology.""" - - @abc.abstractmethod - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - ... - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - pass - - def _apply(self) -> ProtoGraph | None: - # TODO(robieta): Right now block order is load bearing, so we have to preserve it. - transformed = {i: self.apply_to_protoblock(i) for i in self.protograph} - - if any(transformed.values()): - replacements = {i: ProtoBlock(flow or i.flow) for i, flow in transformed.items()} - for old_protoblock, new_protoblock in replacements.items(): - self.post_apply(old_protoblock, new_protoblock) - for old_target, is_jump in old_protoblock.jump_targets: - new_protoblock.add_jump_target(replacements[old_target], is_jump) - return ProtoGraph(replacements.values(), provenance=(self.__class__, self.protograph)) - return None - - -class ReplaceValues(ReplaceProtoBlocks): - """Copies the ProtoGraph with value replacements. - - NOTE: This is strictly a condensing transform, and this is only invertible - (using another `ReplaceValues`) in trivial cases. - """ - - _retain_uses: bool = True - - @abc.abstractproperty - def replace_map(self) -> values.ReplaceMap: - ... - - @functools.cached_property - def _replace_map(self) -> values.ReplaceMap: - replace_map = self.replace_map - assert not (invalid := [k for k in replace_map if isinstance(k, values.NonPyObject)]), invalid - return FrozenDict(algorithms.flatten_map(replace_map)) - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - return protoblock.flow.substitute(self._replace_map) - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - if self._retain_uses: - new.uses.update(old.uses) - - -class CondenseValues(ReplaceValues): - ValueEdges = Iterable[tuple[values.AbstractValue, values.AbstractValue]] - - @abc.abstractproperty - def edges(self) -> ValueEdges: - ... 
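-
-    # Contract sketch (illustration): subclasses yield pairs of values to be
-    # identified, e.g. `Connect` below yields (block output for key k, successor's
-    # input for key k); `replace_map` condenses these edges and introduces an
-    # AbstractPhiValue wherever a value can flow in from more than one source.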
- - @property - def replace_map(self) -> values.ReplaceMap: - replace_map: dict[values.AbstractValue, values.AbstractValue] = {} - edges = itertools.chain(self.edges, self._phivalue_constituent_edges) - for v, condensed in algorithms.compute_condense_map(edges).items(): - # Check invariants. - assert condensed - if not isinstance(v, values.AbstractPhiValue): - invariants = ({c.identity for c in condensed} == {v.identity}, not isinstance(v, values.AbstractRef)) - assert all(invariants) or not any(invariants), (invariants, v, condensed) - - # `AbstractPhiValue._unpack_apply` will determine if we need an AbstractPhiValue. - if (replacement := values.substitute_value(values.AbstractPhiValue(tuple(condensed)), {})) != v: - replace_map[v] = replacement - - return FrozenDict(replace_map) - - @property - def _phivalue_constituent_edges(self) -> ValueEdges: - # AbstractPhiValues are somewhat unusual in that mismatches between blocks - # are expected (that's sort of the point...) so we need to decompose them - # so the condense pass doesn't get tripped up. - for _, initial_ref in self.protograph.root.flow.begin_state: - if isinstance(initial_ref, values.AbstractPhiValue): - yield from ((constituent, initial_ref) for constituent in initial_ref.constituents) - - -class ReplaceSymbolic(ReplaceProtoBlocks): - _forbid_linked = True - - @abc.abstractmethod - def apply_to_symbolic( - self, - instruction: parse.ThunderInstruction, - symbolic: values.Symbolic, - inputs: values.HybridMap[values.AbstractValue], - ) -> values.Symbolic | None: - ... - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - flow_state = values.DigestFlow(protoblock.flow._begin) - updated_symbolic: dict[parse.ThunderInstruction, values.Symbolic | None] = {} - for i, symbolic in protoblock.flow.symbolic: - updated_symbolic[i] = self.apply_to_symbolic(i, symbolic, symbolic.inputs.map(flow_state.get)) - _ = flow_state.next(i, updated_symbolic[i] or symbolic) - - if any(updated_symbolic.values()): - new_symbolic = {k: v or protoblock.flow._symbolic[k] for k, v in updated_symbolic.items()} - return dataclasses.replace(protoblock.flow, _symbolic=FrozenDict(new_symbolic)) - return None - - -# ============================================================================= -# == Graph transforms (Applied) =============================================== -# ============================================================================= -class Unlink(ReplaceProtoBlocks): - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - if protoblock is not self.protograph.root: - uses = (flow := protoblock.flow).uses.copy() - end: values.Symbolic.EndState = FrozenDict({k: v for k, v in flow._end.items() if k != v}) - uses.update(v for v in end.values() if isinstance(v, parse.VariableKey) and not v.is_const) - any_non_ref = any(not isinstance(i, values.AbstractRef) for i in flow._begin.values()) - if any_non_ref or len(end) < len(flow._end) or flow._begin.keys() ^ uses: # symmetric_difference - begin: FrozenDict[parse.VariableKey, values.AbstractValue] - begin = FrozenDict({k: values.AbstractRef(f"Unlink: {k}") for k in uses}) - return dataclasses.replace(protoblock.flow, _begin=begin, _end=end) - - return None - - def _apply(self) -> ProtoGraph | None: - result = super()._apply() - assert len(result or self.protograph) == 1 or not (result or self.protograph).is_linked, result - return result - - -class AddTransitive(ReplaceProtoBlocks): - """Extend abstract value flows to include those 
needed by downstream blocks. - This pass effectively functionalizes the abstract value flow by plumbing - reads through parents as transitive dependencies. Note that we assume - variables are only modified by `STORE_...` and `DELETE_...` instructions. - This is not a sound assumption since opaque calls (`CALL_FUNCTION`, - `CALL_METHOD`, etc.) could mutate global and nonlocal variables. This does - not, however, pose an overall soundness problem because we can check for - state mutations during inlining and rerun flow analysis. - """ - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - flow = protoblock.flow - end = FrozenDict({**{use: use for use in self.target_uses(protoblock, self.expanded_uses)}, **flow._end}) - if (missing := self.expanded_uses[protoblock].difference(protoblock.uses)) or (end != flow._end): - begin = {**{k: values.AbstractRef("Transitive") for k in missing}, **flow._begin} - return dataclasses.replace(flow, _begin=FrozenDict(begin), _end=FrozenDict(end)) - return None - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - new.uses.update(self.expanded_uses[old]) - - @functools.cached_property - def expanded_uses(self) -> Mapping[ProtoBlock, Uses]: - """Identify new transitive value dependencies. - The process is more involved than simply checking for mismatches because - adding a transitive value to a block may necessitate adding a transitive - value to the prior block and so on. - """ - uses = {protoblock: protoblock.uses.copy() for protoblock in self.protograph} - blocks_to_process = collections.deque(uses.keys()) - - while blocks_to_process: - protoblock = blocks_to_process.popleft() - target_uses = self.target_uses(protoblock, uses) - - # The reason we can ignore ALL `_OutputRef`s (including those that would index into a composite) - # is that the (potential) composite's dependencies are already handled by `ProtoBlock._flow_uses`. - transitive_uses = OrderedSet( - source - for use in target_uses - if isinstance(source := protoblock.flow._end.get(use, use), parse.VariableKey) - and source.scope != parse.VariableScope.CONST - ) - - if transitive_uses - uses[protoblock]: - uses[protoblock].update(transitive_uses) - blocks_to_process.extend(self.protograph.parents[protoblock]) - - return FrozenDict(uses) - - @staticmethod - def target_uses(protoblock: ProtoBlock, uses: Mapping[ProtoBlock, Uses] = FrozenDict()) -> Uses: - flat_uses = itertools.chain(*(uses.get(target, target.uses) for target, _ in protoblock.jump_targets)) - return Uses(OrderedSet(use for use in flat_uses if use.scope != parse.VariableScope.CONST)) - - -class MatchStacks(ReplaceProtoBlocks): - """Ensure stacks match across blocks. - - ProtoGraph doesn't rely on stack behavior (push, pop TOS, etc.), however it - is still a good sanity check. (Which is why `Connect._inter_block_edges` asserts.) 
- """ - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - upstream = OrderedSet[parse.VariableKey]() - for parent in self.protograph.parents[protoblock]: - upstream.update(k for k in parent.flow._end if k.scope == parse.VariableScope.STACK) - - if delta := upstream - protoblock.flow._begin: - begin = {**protoblock.flow._begin, **{k: values.AbstractRef(f"Match stack: {k}") for k in delta}} - return dataclasses.replace(protoblock.flow, _begin=FrozenDict(begin)) - return None - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - new.uses.update(old.uses) - - -class Connect(CondenseValues): - @property - def edges(self) -> CondenseValues.ValueEdges: - yield from self._inter_block_edges(self.protograph) - yield from self._graph_input_edges(self.protograph) - - @staticmethod - def _graph_input_edges(proto_graph: ProtoGraph) -> CondenseValues.ValueEdges: - for key, initial_ref in proto_graph.root.flow.begin_state: - if isinstance(initial_ref.identity, values.ExternalRef): - continue - - assert isinstance(initial_ref, values.AbstractRef), initial_ref - assert key.scope not in ( - parse.VariableScope.CONST, - parse.VariableScope.STACK, - ), (key, proto_graph.root.flow._begin) - yield values.CompositeValue().add_identity(values.ExternalRef(key)), initial_ref - - @staticmethod - def _inter_block_edges(proto_graph: ProtoGraph) -> CondenseValues.ValueEdges: - for protoblock in proto_graph: - for child, _ in protoblock.jump_targets: - outputs = dict(protoblock.flow.end_state) - child_inputs = dict(child.flow.begin_state) - for key, child_input in child_inputs.items(): - yield outputs.get(key, values.NonPyObject(values.NonPyObject.Tag.MISSING)), child_input - - # `AddTransitive` should ensure the stacks match. - # (Except for return blocks which may discard the stack.) 
- opname = tuple(child.flow.symbolic)[-1][0].opname - if opname not in parse.RAISE_RETURN_INSTRUCTIONS: - s_out = tuple(sorted(i.identifier for i in outputs if i.scope == parse.VariableScope.STACK)) - s_in = tuple(sorted(i.identifier for i in child_inputs if i.scope == parse.VariableScope.STACK)) - assert s_out == s_in, f"{s_out=} != {s_in=}, {opname}" diff --git a/thunder/core/script/protograph_passes.py b/thunder/core/script/protograph_passes.py deleted file mode 100644 index 657d6d6c8f..0000000000 --- a/thunder/core/script/protograph_passes.py +++ /dev/null @@ -1,74 +0,0 @@ -import dataclasses -from collections.abc import Iterable - -from thunder.core.script import parse, values -from thunder.core.script.protograph import ProtoGraph, ProtoGraphTransform, AddTransitive, ReplaceSymbolic -from thunder.core.utils import debug_asserts_enabled - -ValueEdges = Iterable[tuple[values.AbstractValue, values.AbstractValue]] -KNOWN_TUPLE = values.TraitName("__known_tuple") - - -def _connect_protograph(proto_graph: "ProtoGraph") -> "ProtoGraph": - proto_graph = proto_graph.link() - assert AddTransitive(proto_graph).apply(or_default=False) is None - for protoblock in proto_graph: - for k, v in protoblock.flow.begin_state: - assert not v.is_detail, (k, v) - return proto_graph - - -class MarkTuples(ReplaceSymbolic): - def apply_to_symbolic( - self, - instruction: parse.ThunderInstruction, - symbolic: values.Symbolic, - _: values.HybridMap[values.AbstractValue], - ) -> values.Symbolic | None: - if instruction.opname == "BUILD_TUPLE" and isinstance(output := symbolic.outputs[0], values.IntermediateValue): - assert len(symbolic.outputs) == 1, symbolic.outputs - ordered = tuple(values.Reference(i) for i in range(-len(symbolic.inputs.ordered), 0)) - new_output = values.CompositeRef(ordered=ordered).add_named(KNOWN_TUPLE, values.ConstRef(True)) - return dataclasses.replace(symbolic, outputs=(new_output.add_identity(output),)) - return None - - -class IndexTuples(ReplaceSymbolic): - def apply_to_symbolic( - self, - instruction: parse.ThunderInstruction, - symbolic: values.Symbolic, - inputs: values.HybridMap[values.AbstractValue], - ) -> values.Symbolic | None: - replacement: values.Symbolic | None = None - if instruction.opname == "BINARY_SUBSCR": - to_index, index = inputs.ordered - is_tuple = isinstance(to_index, values.CompositeValue) and to_index.get(KNOWN_TUPLE) - index_key = index.key if isinstance(index, values.ExternalRef) and index.key.is_const else None - if is_tuple and index_key and isinstance(idx := index_key.identifier, int): - assert len(symbolic.outputs) == 1 - replacement = dataclasses.replace(symbolic, outputs=((values.Reference(0), values.Reference(idx)),)) - - elif instruction.opname == "UNPACK_SEQUENCE": - (to_unpack,) = inputs.ordered - if isinstance(to_unpack, values.CompositeValue) and to_unpack.get(KNOWN_TUPLE): - indices = (values.Reference(idx) for idx in range(-1, -len(symbolic.outputs) - 1, -1)) - outputs = tuple((values.Reference(0), idx) for idx in indices) - replacement = dataclasses.replace(symbolic, outputs=outputs) - - elif instruction.opname == "UNPACK_EX": - pass # TODO(apaz-cli): figure out indexing. 
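-
-        # Worked example for the rewrites above (sketch): once MarkTuples has
-        # tagged the BUILD_TUPLE output of `t = (x, y)` as a known tuple,
-        # `a, b = t` (UNPACK_SEQUENCE) is rewritten to reference the tuple's
-        # elements directly, so no unpacking remains in the abstract flow.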
-
-        return replacement if (replacement and replacement.outputs != symbolic.outputs) else None
-
-
-def _tuple_fold(proto_graph: ProtoGraph) -> ProtoGraph:
-    """Replace tuple accesses (`BINARY_SUBSCR`, `UNPACK_SEQUENCE` instructions) with their members, if known."""
-    return ProtoGraphTransform.chain(proto_graph, MarkTuples, IndexTuples) or proto_graph
-
-
-def apply_protograph_passes(protograph: ProtoGraph) -> ProtoGraph:
-    protograph = _tuple_fold(protograph.unlink())
-    protograph = _connect_protograph(protograph)
-    assert AddTransitive(protograph).apply(or_default=False) is None
-    return protograph
diff --git a/thunder/core/script/python_ir.py b/thunder/core/script/python_ir.py
deleted file mode 100644
index 2d66fbb77b..0000000000
--- a/thunder/core/script/python_ir.py
+++ /dev/null
@@ -1,507 +0,0 @@
-import collections
-import dis
-import inspect
-import sys
-import types
-from typing import Any, Dict, List, Optional, Tuple, Union
-from collections.abc import Callable
-from collections.abc import Hashable
-
-from thunder.core.script.graph import (
-    assert_block,
-    _generate_raises,
-    Graph,
-    GraphSummaryCallback,
-    MROAwareObjectRef,
-    Node,
-    SourceInformation,
-    Value,
-    insert_before,
-    insert_after,
-    _Undefined,
-)
-from thunder.core.script.parse import RETURN_VALUE
-from thunder.core.script.python_ir_data import get_instruction, X_THUNDER_STORE_ATTR
-from thunder.core.utils import OrderedSet
-
-
-def undo_ssa(gr: "Graph") -> tuple[list[Value], list[str], list[str], list[Any]]:
-    consts: list[Any] = []
-    names: list[str] = []
-
-    def get_value(v: Value, n: Node, inpidx: int | None = None) -> None:
-        if n.i.opname == "CALL_METHOD" and inpidx == 1:
-            bl = assert_block(n.block)
-            idx = bl.nodes.index(n)
-            if idx > 0 and bl.nodes[idx - 1].i.opname == "LOAD_METHOD":
-                # if we just had a LOAD_METHOD, it already put inputs 0 and 1 on the stack
-                return
-            else:
-                # else the loading has been separated from the call, so we
-                # switch to call LOAD_ATTR/CALL_FUNCTION instead
-                n.i = n.i.modify_copy(opname="CALL_FUNCTION", opcode=None)
-                return
-        if isinstance(v.value, _Undefined):
-            idx = len(consts)
-            consts.append(
-                _generate_raises(f"attribute error '{type(v.value.value)}' object has no attribute '{v.value.attr}'")
-            )
-            new_n = Node(
-                i=get_instruction(opname="LOAD_CONST", arg=idx),
-                outputs=[Value(value=consts[idx], is_const=True)],
-                inputs=[],
-            )
-            new_n.inserted_for = n
-            insert_before(new_n, n)
-            new_n = Node(i=get_instruction(opname="CALL_FUNCTION", arg=0), outputs=[v], inputs=[consts[idx]])
-            new_n.inserted_for = n
-            insert_before(new_n, n)
-            return
-        if v.is_const:
-            idx = len(consts)
-            consts.append(v.value)
-            new_n = Node(i=get_instruction(opname="LOAD_CONST", arg=idx), outputs=[v], inputs=[])
-            new_n.inserted_for = n
-            insert_before(new_n, n)
-        elif isinstance(v.value, MROAwareObjectRef):
-            # this works for attribs, but for methods? maybe have a pass eliminating/making explicit the super...
-            get_value(v.value.obj, n)
-        elif v.parent is not None:
-            assert v.name is not None
-            get_value(v.parent, n)
-            if n.i.opname == "CALL_METHOD" and inpidx == 0:
-                # print("###inputs", n.inputs, v, v in n.inputs)
-                try:
-                    idx = names.index(v.name)
-                except ValueError:
-                    idx = len(names)
-                    names.append(v.name)
-                new_n = Node(
-                    i=get_instruction(opname="LOAD_METHOD", arg=idx),
-                    outputs=[v, v.parent],
-                    inputs=[v.parent],
-                )
-                new_n.inserted_for = n
-                insert_before(new_n, n)
-            elif n.i.opname == "LOAD_ATTR":
-                # print("###load attr", n.outputs, n.i.argval)
-                pass
-            else:
-                assert v.name is not None
-                try:
-                    idx = names.index(v.name)
-                except ValueError:
-                    idx = len(names)
-                    names.append(v.name)
-                new_n = Node(
-                    i=get_instruction(opname="LOAD_ATTR", arg=idx),
-                    outputs=[v],
-                    inputs=[v.parent],
-                )
-                new_n.inserted_for = n
-                insert_before(new_n, n)
-        elif v.is_global and isinstance(v.value, Hashable) and v.value in __builtins__:
-            # Builtins are unmarshallable and meant to be loaded globally. If they are
-            # included in co_consts, the resulting function cannot go into a .pyc file.
-            # Originally, the plan was to check if the value is a builtin by checking
-            # if its type is "<class 'builtin_function_or_method'>". However, this
-            # turned out not to work since torch for some reason decided to set the
-            # type of `torch.nn.functional.has_torch_function` to also be a builtin.
-            if v.name not in names:
-                names.append(v.name)
-            idx = names.index(v.name)
-            new_n = Node(i=get_instruction(opname="LOAD_GLOBAL", arg=idx), outputs=[v], inputs=[])
-            new_n.inserted_for = n
-            insert_before(new_n, n)
-        elif v.is_global:  # make binding the globals optional?
-            if v.value not in consts:
-                consts.append(v.value)
-            idx = consts.index(v.value)
-            new_n = Node(i=get_instruction(opname="LOAD_CONST", arg=idx), outputs=[v], inputs=[])
-            new_n.inserted_for = n
-            insert_before(new_n, n)
-        else:
-            idx = local_vars[v]
-            # assert idx >= 0
-            new_n = Node(i=get_instruction(opname="LOAD_FAST", arg=idx), outputs=[v], inputs=[])
-            new_n.inserted_for = n
-            insert_before(new_n, n)
-
-    for bl in gr.blocks:
-        for n in bl.nodes:
-            n.block = bl
-
-    local_vars: dict[Value, int] = {}
-    lv_names: OrderedSet[str] = OrderedSet()
-
-    def get_or_add_lv(v: Value, name: str | None = None) -> int:
-        idx = local_vars.get(v)
-        if idx is None:
-            idx = len(local_vars)
-            local_vars[v] = idx
-
-            # handle name collisions...
-            if name is None:
-                name = v.name
-
-            if name is None:
-                name = f"_tmp_{idx}"
-            else:
-                name = name.replace(".", "_").replace("[", "").replace("]", "")
-
-            if not name[:1].isalpha():
-                name = "_" + name
-            fullname = name
-            suffix = 0
-            while fullname in lv_names:
-                suffix += 1
-                fullname = f"{name}_{suffix}"
-            lv_names.add(fullname)
-            if v.name is None:  # TODO: or do this always?
- v.name = fullname - return idx - - nodes_to_skip = set() - - def store_phi_values(o: Value, o_idx: int, last_n: Node | None, cur_n: Node | None) -> Node | None: - phi_values_in_processing = set() - - def store_phi_values_inner(o: Value, o_idx: int, last_n: Node | None) -> Node | None: - if o in phi_values_in_processing: - # avoid loops - return last_n - phi_values_in_processing.add(o) - for v in o.phi_values: - # TODO: refactor into general mechanism - idx2 = get_or_add_lv(v) - # last_n = store_phi_values_inner(v, o_idx, last_n) - if o.is_const: - if o.value not in consts: - consts.append(o.value) - o_idx = consts.index(o.value) - new_n = Node(i=get_instruction(opname="LOAD_CONST", arg=o_idx), outputs=[o], inputs=[]) - new_n.inserted_for = cur_n - else: - new_n = Node(i=get_instruction(opname="LOAD_FAST", arg=o_idx), outputs=[o], inputs=[]) - new_n.inserted_for = cur_n - - nodes_to_skip.add(new_n) - if last_n is None: - insert_before(new_n, gr.blocks[0].nodes[0]) - else: - insert_after(new_n, last_n) - last_n = new_n - new_n = Node(i=get_instruction(opname="STORE_FAST", arg=idx2), outputs=[], inputs=[o]) - new_n.inserted_for = cur_n - nodes_to_skip.add(new_n) - insert_after(new_n, last_n) - last_n = new_n - return last_n - - return store_phi_values_inner(o, o_idx, last_n) - - for v in gr.local_variables_at_start: - if v is not None: - get_or_add_lv(v) - - # inputs in phi values - last_n = None - # need to make a copy of the list because we're adding items to the list - for idx, i in enumerate(tuple(local_vars.keys())): - last_n = store_phi_values(i, idx, last_n, cur_n=None) - for i in gr.blocks[0].block_inputs: # inlined parameters (partial) will be here - for v, js in zip(i.values, i.jump_sources): - if js is None and v.is_const: - last_n = store_phi_values(v, None, last_n, cur_n=None) - # print(i.values, i.jump_sources) - - names = [] - - for bl in gr.blocks: - jump_node = bl.nodes[-1] - for n in bl.nodes[:]: - processed_block_outputs = set() - if n not in nodes_to_skip: - n.inserted_for = n - for inpidx, i in enumerate(n.inputs): - get_value(i, n=n, inpidx=inpidx) - last_n = n - for o in n.outputs[::-1]: - idx = get_or_add_lv(o) - new_n = Node( - i=get_instruction(opname="STORE_FAST", arg=idx), - outputs=[], - inputs=[o], - ) - new_n.inserted_for = n - assert last_n is not None - insert_after(new_n, last_n) - last_n = new_n - if o in bl.block_outputs: - processed_block_outputs.add(o) - last_n = store_phi_values(o, idx, last_n, cur_n=n) - if n.i.opname in ("STORE_ATTR", "IMPORT_NAME"): # STORE_ATTR for unknown objs - # have a utility for this? 
- try: - idx = names.index(n.i.argval) - except ValueError: - idx = len(names) - names.append(n.i.argval) - n.i = n.i.modify_copy(arg=idx) - if n.i.opname == X_THUNDER_STORE_ATTR: - bl.nodes.remove(n) - if bl.nodes[-1].i.opname != RETURN_VALUE: # TODO Should the return block have outputs (probably not) - for o in bl.block_outputs: - if o not in processed_block_outputs: - get_value(o, n=jump_node) # before the jump - idx = get_or_add_lv(o, name="bo") - new_n = Node( - i=get_instruction(opname="STORE_FAST", arg=idx), - outputs=[], - inputs=[o], - ) - new_n.inserted_for = jump_node - insert_before(new_n, n=jump_node) - store_phi_values(o, idx, new_n, cur_n=jump_node) - - return list(local_vars.keys()), list(lv_names), names, consts - - -# this function is taken from PyTorch Dynamo (c) 2022 by Facebook/Meta licensed -# as per https://github.com/pytorch/pytorch/blob/master/LICENSE -def linetable_writer(first_lineno: int) -> tuple[list[int], Callable, Callable]: - """Used to create typing.CodeType.co_linetable See - https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt This is the internal format of the line number - table if Python >= 3.10.""" - assert sys.version_info >= (3, 9) - linetable: list[int] = [] - lineno = first_lineno - lineno_delta = 0 - byteno = 0 - - def _update(byteno_delta: int, lineno_delta: int) -> None: - while byteno_delta != 0 or lineno_delta != 0: - byte_offset = max(0, min(byteno_delta, 254)) - line_offset = max(-127, min(lineno_delta, 127)) - assert byte_offset != 0 or line_offset != 0 - byteno_delta -= byte_offset - lineno_delta -= line_offset - linetable.extend((byte_offset, line_offset & 0xFF)) - - def update(lineno_new: int, byteno_new: int) -> None: - nonlocal lineno, lineno_delta, byteno - byteno_delta = byteno_new - byteno - byteno = byteno_new - _update(byteno_delta, lineno_delta) - lineno_delta = lineno_new - lineno - lineno = lineno_new - - def end(total_bytes: int) -> None: - _update(total_bytes - byteno, lineno_delta) - - return linetable, update, end - - -def generate_function(gr: "Graph") -> Callable: - orig_gr = gr - gr, map_from_orig = gr.clone() - - local_vars, lv_names, names, consts = undo_ssa(gr) - assert len(local_vars) == len(lv_names) - - NodeKey = Union[Node, tuple[Node, bool]] - instruction_sizes: dict[NodeKey, int] = {} - - def build_address_map(end=False) -> dict[NodeKey, int]: - # Key either (for jump nodes and jump=True) - # or (, False) for non-jump in conditional jump - address_map: dict[NodeKey, int] = {} - ctr = 0 - for bl in gr.blocks: - # assumes first block is function start - for n in bl.nodes: - address_map[n] = ctr - ctr += instruction_sizes.get(n, 1) - if len(n.jump_targets) == 2: # implicit unconditional jump - ctr += instruction_sizes.get((n, False), 1) - if end: - address_map[n] = ctr - 1 - return address_map - - def make_bc() -> tuple[list[int], bool]: - bc = [] - - def write_extended_args(node_key: NodeKey, arg: int) -> bool: - # returns if instruction size has changed - instruction_size = instruction_sizes.get(node_key, 1) - if arg > 0x_FF_FF_FF or instruction_size == 4: - instruction_size = 4 - bc.append(dis.opmap["EXTENDED_ARG"]) - bc.append(arg >> 24) - if arg > 0x_FF_FF or instruction_size >= 3: - instruction_size = max(instruction_size, 3) - bc.append(dis.opmap["EXTENDED_ARG"]) - bc.append((arg >> 16) & 0xFF) - if arg > 0x_FF or instruction_size >= 2: - instruction_size = max(instruction_size, 2) - bc.append(dis.opmap["EXTENDED_ARG"]) - bc.append((arg >> 8) & 0xFF) - else: - instruction_size = 1 - - if 
instruction_size != instruction_sizes.get(node_key, 1): - instruction_sizes[node_key] = instruction_size - return True - return False - - changed_size = False - line_no = None - for bl in gr.blocks: - jump_node = None - for n in bl.nodes: - opcode = n.i.opcode - if opcode is None or opcode == -1: # Todo: opcode is typed int in ThunderInstruction, remove None here? - opcode = dis.opmap[n.i.opname] - assert opcode is not None, f"{n} has invalid opcode" - # source range instead for 3.11? - n_line_no = n.source_infos[-1].gen_line_no if n.source_infos else None - if n_line_no is not None and n_line_no != line_no: # really, the last generated one... - linetable_update( - n_line_no + gr.source_start_line, address_map[n] * 2 - ) # byte offset for Python 3.10, too... - line_no = n_line_no - if opcode in dis.hasjabs: - arg = address_map[n.jump_targets[-1].nodes[0]] - elif opcode in dis.hasjrel: - # TODO forward, backward - arg = address_map[n.jump_targets[-1].nodes[0]] - address_map[n] - 1 - else: - arg_ = n.i.arg - arg = 0 if arg_ is None else arg_ - - changed_size |= write_extended_args(n, arg) - - bc.append(opcode) - bc.append(arg & 0x_FF) - if len(n.jump_targets) > 1: - jump_node = n - if jump_node is not None: - assert len(jump_node.jump_targets) == 2 - jarg = address_map[jump_node.jump_targets[0].nodes[0]] - changed_size |= write_extended_args((jump_node, False), jarg) - i = get_instruction(opname="JUMP_ABSOLUTE", arg=jarg & 0xFF) - bc.append(i.opcode) - assert i.arg is not None - bc.append(i.arg) - return bc, not changed_size - - done = False - while not done: - linetable, linetable_update, linetable_end = linetable_writer(gr.source_start_line) - address_map = build_address_map() - bc, done = make_bc() - - inserted_for = collections.defaultdict(list) - end_address_map = build_address_map(end=True) - for n in gr.nodes(): - inserted_for[getattr(n, "inserted_for", None)].append(end_address_map[n]) - for n in orig_gr.nodes(): - info = inserted_for[map_from_orig[n]] - n.bytecode_range = (min(info), max(info)) if info else (None, None) - - linetable_end(len(bc)) - linetable_bytes = bytes(linetable) - bc_bytes = bytes(bc) - - lv_at_start = [v for v in gr.local_variables_at_start if v is not None] - co_argcount = gr.co_argcount - co_posonlyargcount = gr.co_posonlyargcount - co_kwonlyargcount = gr.co_kwonlyargcount - co_nlocals = len(local_vars) - # TODO: actually track the stack size when doing codegen (for optimizations) - co_stacksize = max(max(len(n.inputs), len(n.outputs)) for n in gr.nodes()) - co_flags = gr.co_flags - co_codestring = bc_bytes - co_consts = tuple(consts) - co_names = tuple(names) - co_varnames = tuple(lv_names) - co_filename = f"" - co_name = gr.co_name - co_firstlineno = gr.source_start_line - co_linetable = linetable_bytes # XXX - co_freevars = () - co_cellvars = () - - c = types.CodeType( - co_argcount, # int - co_posonlyargcount, # int - co_kwonlyargcount, # int - co_nlocals, # int - co_stacksize, # int - co_flags, # int - co_codestring, # bytes - co_consts, # tuple - co_names, # tuple - co_varnames, # tuple - co_filename, # string - co_name, # string - co_firstlineno, # integer - co_linetable, # bytes - co_freevars, # tuple - co_cellvars, # tuple - ) - - # types.FunctionType(code, globals, name=None, argdefs=None, closure=None) - func = types.FunctionType( - c, - { - "__builtins__": __builtins__, - }, - argdefs=tuple(gr.func_defaults), - ) - func.__kwdefaults__ = gr.func_kwdefaults - func._gr = orig_gr - - # simple cache hack - mtime = None # this signals that the cache 
should not be invalidated(!) - lines = gr.source_lines - size = len("".join(lines)) - inspect.linecache.cache[co_filename] = size, mtime, lines, co_filename - - try: - _ = tuple(dis.get_instructions(func)) - except BaseException as e: - raise RuntimeError("Unknown error generating callable") from e - - return func - - -def annotated_dis(thunder_fn, print_lines=True): - instructions = list(dis.get_instructions(thunder_fn)) - cur_pos = 0 - - class Callback(GraphSummaryCallback): - def node(self, n): - nonlocal cur_pos - before = [] - after = [] - begin_offset, end_offset = n.bytecode_range - if begin_offset is not None: - # the * 2 here and below is from Instruction.offset containing byte offsets and each - # bytecode is 2 bytes (this is true at least for Python 3.8-3.12) - while cur_pos < len(instructions) and instructions[cur_pos].offset < begin_offset * 2: - before.append(instructions[cur_pos]._disassemble()) - cur_pos += 1 - if end_offset is not None: - while cur_pos < len(instructions) and instructions[cur_pos].offset <= end_offset * 2: - after.append(instructions[cur_pos]._disassemble()) - cur_pos += 1 - return before, after - - def finish(self): - nonlocal cur_pos - l = [i._disassemble() for i in instructions[cur_pos:]] - cur_pos = len(instructions) - return l - - return thunder_fn._gr.summary(print_lines=print_lines, callback=Callback()) diff --git a/thunder/core/script/python_ir_data.py b/thunder/core/script/python_ir_data.py deleted file mode 100644 index c82eb168aa..0000000000 --- a/thunder/core/script/python_ir_data.py +++ /dev/null @@ -1,68 +0,0 @@ -import functools -import sys -from types import CodeType -from typing import Union -from collections.abc import Callable -from collections.abc import Iterable - -from thunder.core.script import parse - - -SUPPORTS_PREPROCESSING = (3, 9) <= sys.version_info < (3, 11) -X_THUNDER_STORE_ATTR = "X_THUNDER_STORE_ATTR" - - -# TODO(robieta): replace callsites. 
-get_instruction = functools.partial(parse.ThunderInstruction.make, line_no=-1)
-
-
-def debug_compare_functions_print(diffs: dict[str, tuple[list, list]]):
-    for k, (v1, v2) in diffs.items():
-        if not (v1 is None and v2 is None):
-            print(f"Differences in: {k}")
-            print(f"  CodeObject 1: {v1}")
-            print(f"  CodeObject 2: {v2}")
-
-
-def debug_compare_functions(
-    code1: CodeType | Callable, code2: CodeType | Callable, *, show=False
-) -> dict[str, tuple[list, list]]:
-    if not isinstance(code1, CodeType):
-        code1 = code1.__code__
-    if not isinstance(code2, CodeType):
-        code2 = code2.__code__
-
-    attrs = [
-        "co_argcount",
-        "co_kwonlyargcount",
-        "co_nlocals",
-        "co_stacksize",
-        "co_flags",
-        "co_consts",
-        "co_names",
-        "co_varnames",
-        "co_filename",
-        "co_name",
-        "co_freevars",
-        "co_cellvars",
-    ]
-
-    diffs = {}
-    for attr in attrs:
-        v1 = getattr(code1, attr)
-        v2 = getattr(code2, attr)
-
-        if v1 != v2:
-            # dicts are also Iterable, so handle them first and diff by key set
-            # (dict does not support the `-` operator directly).
-            if isinstance(v1, dict) and isinstance(v2, dict):
-                diffs[attr] = (set(v1) - set(v2), set(v2) - set(v1))
-            elif isinstance(v1, str) and isinstance(v2, str):
-                diffs[attr] = (v1, v2)
-            elif isinstance(v1, Iterable) and isinstance(v2, Iterable):
-                diffs[attr] = (set(v1) - set(v2), set(v2) - set(v1))
-            else:
-                diffs[attr] = (v1, v2)
-
-    if show:
-        debug_compare_functions_print(diffs)
-
-    return diffs
diff --git a/thunder/core/script/values/__init__.py b/thunder/core/script/values/__init__.py
deleted file mode 100644
index f1361cd857..0000000000
--- a/thunder/core/script/values/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from thunder.core.script.values.base import *
-from thunder.core.script.values.composite import *
-from thunder.core.script.values.materialization import *
-from thunder.core.script.values.symbolic import *
diff --git a/thunder/core/script/values/base.py b/thunder/core/script/values/base.py
deleted file mode 100644
index 1956c31050..0000000000
--- a/thunder/core/script/values/base.py
+++ /dev/null
@@ -1,200 +0,0 @@
-from __future__ import annotations
-
-import dataclasses
-import enum
-import textwrap
-from typing import overload, Any, Generic, NewType, TypeVar
-from collections.abc import Callable
-from collections.abc import Mapping
-
-from typing_extensions import Self
-
-from thunder.core.script import parse
-from thunder.core.utils import FrozenDict
-
-__all__ = (
-    # HybridMap
-    "Reference",
-    "TraitName",
-    "HybridMap",
-    #
-    # Values
-    "AbstractValue",
-    "AbstractRef",
-    "NonPyObject",
-    "IntermediateValue",
-    "ExternalRef",
-    #
-    # Substitution
-    "substitute_value",
-    "ReplaceMap",
-)
-
-
-# =============================================================================
-# == Generic (hybrid tuple/dict) container ====================================
-# =============================================================================
-T = TypeVar("T")
-T1 = TypeVar("T1")
-Reference = NewType("Reference", int)
-TraitName = NewType("TraitName", str)
-
-
-@dataclasses.dataclass(frozen=True, eq=True)
-class HybridMap(Generic[T]):
-    ordered: tuple[T, ...]
= dataclasses.field(kw_only=True, default_factory=tuple) - named: FrozenDict[TraitName, T] = dataclasses.field(kw_only=True, default_factory=FrozenDict) - - def __getitem__(self, key: Reference | TraitName) -> T: - if isinstance(key, int): - return self.ordered[key] - elif isinstance(key, str): - return self.named[key] - raise TypeError(f"Invalid key: {key}") - - def __repr__(self) -> str: - parts = [f"{self.__class__.__name__}("] - if self.ordered: - ordered = "\n".join(repr(i) for i in self.ordered) - parts.append(f" ordered:\n{textwrap.indent(ordered, ' ' * 4)}") - - if self.named: - named = "\n".join(f"{k}: {v}" for k, v in self.named.items()) - parts.append(f" named:\n{textwrap.indent(named, ' ' * 4)}") - - return "\n".join((*parts, ")")) - - @overload - def map(self, f: Callable[[T], T]) -> Self: - ... - - @overload - def map(self, f: Callable[[T], T1]) -> HybridMap[T1]: - ... - - def map(self, f: Any) -> Any: - ordered = tuple(f(i) for i in self.ordered) - named: FrozenDict[TraitName, T] = FrozenDict({k: f(v) for k, v in self.named.items()}) - return dataclasses.replace(self, ordered=ordered, named=named) - - def get(self, name: TraitName) -> T | None: - return self.named.get(name) - - def add_named(self, name: TraitName, value: T) -> Self: - named = dict(self.named) - named.update({name: value}) # Preserve order. - return dataclasses.replace(self, named=FrozenDict(named)) - - -# ============================================================================= -# == Simple value types ======================================================= -# ============================================================================= -class AbstractValue: - """Represents a value during instruction parsing. (Prior to type binding.)""" - - __is_detail = True - - def __init_subclass__(cls, **kwargs: Any) -> None: - cls.__is_detail = kwargs.pop("__is_detail", False) # `dataclasses` forces this into kwargs for some reason. - super().__init_subclass__(**kwargs) - - def __copy__(self) -> AbstractValue: - raise NotImplementedError - - @property - def is_detail(self) -> bool: - return self.__is_detail - - @property - def identity(self) -> AbstractValue: - """Analogous to `id(obj)`. For composites there is a layer of state management above the value itself. - - This is not suitable for equality checks (for example, mutation does not change - an object's identity), but it is often the appropriate target for `isinstance` checks. - """ - return self - - def _unpack_apply(self, _: ReplaceMap) -> AbstractValue: - """Recursively update any constituent references in the abstract value.""" - return self - - -@dataclasses.dataclass(frozen=True, eq=False) -class AbstractRef(AbstractValue, __is_detail=True): - """Placeholder value which will be resolved during parsing.""" - - _debug_info: str = "N/A" - - -@dataclasses.dataclass(frozen=True, eq=True) -class NonPyObject(AbstractValue): - """Singleton values used to signal some special interpreter state.""" - - class Tag(enum.Enum): - DELETED = enum.auto() - MISSING = enum.auto() - NULL = enum.auto() - - tag: Tag - - def __repr__(self) -> str: - return self.tag.name - - -class IntermediateValue(AbstractValue): - """A (potentially) new value produced by an instruction.""" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(at {hex(id(self))})" - - -@dataclasses.dataclass(frozen=True, eq=True) -class ExternalRef(AbstractValue): - """Reference values outside of the parsed code. 
(Arguments, constants, globals, etc.)""" - - key: parse.VariableKey - - def __repr__(self) -> str: - if self.key.is_const: - return f"Const({self.key.identifier})" - return f"{self.__class__.__name__}({self.key.identifier}, {self.key.scope.name})" - - -# ============================================================================= -# == Value substitution ======================================================= -# ============================================================================= -ReplaceMap = Mapping["AbstractValue", "AbstractValue"] - - -@overload -def substitute_value(v: AbstractValue, replace_map: ReplaceMap) -> AbstractValue: - ... - - -@overload -def substitute_value(v: T, replace_map: ReplaceMap) -> T: - ... - - -def substitute_value(v: Any, replace_map: ReplaceMap) -> Any: - """Find the replacement for `v`, and recursively substitute. (If applicable.) - - Some abstract values reference other abstract values. When we make substitution during - graph transformations it is necessary to also consider replacement of an abstract - value's constituents. Any subclass which must be unpacked in this manner should - override `_unpack_apply`. - """ - if not isinstance(v, AbstractValue): - return v - - new_v = replace_map.get(v, v) - if new_v != (x := replace_map.get(new_v, new_v)): - msg = f""" - `replace_map` may not contain chains. - {v} - {new_v} - {x} - See `flatten_map`.""" - raise ValueError(textwrap.dedent(msg)) - - return new_v._unpack_apply(replace_map) diff --git a/thunder/core/script/values/composite.py b/thunder/core/script/values/composite.py deleted file mode 100644 index bc00dd5469..0000000000 --- a/thunder/core/script/values/composite.py +++ /dev/null @@ -1,177 +0,0 @@ -import abc -import dataclasses -import itertools -from typing import Any, TypeVar -from collections.abc import Iterable - -from typing_extensions import Self - -from thunder.core.script import parse -from thunder.core.script.values import base, symbolic -from thunder.core.utils import FrozenDict - -__all__ = ("InternalRef", "OrderedSlice", "CompositeValue", "CompositeRef", "AbstractPhiValue") - -T = TypeVar("T") - - -# ============================================================================= -# == References =============================================================== -# ============================================================================= -class InternalRef(base.AbstractValue, abc.ABC): - @abc.abstractmethod - def _resolve(self, inputs: base.HybridMap[base.AbstractValue]) -> base.AbstractValue: - """Defines how to concretize itself.""" - ... - - @property - def is_detail(self) -> bool: - # All ref types are unsuitable for Graph binding. - return True - - @classmethod - def resolve( - cls, output: symbolic.Symbolic.Output, *, inputs: base.HybridMap[base.AbstractValue] | None = None - ) -> base.AbstractValue: - inputs = base.HybridMap() if inputs is None else inputs - if isinstance(output, (int, str)): - return inputs[output] - - elif isinstance(output, symbolic.ConstRef): - return base.ExternalRef(parse.VariableKey(output.identifier, parse.VariableScope.CONST)) - - if isinstance(output, tuple): - cls.validate_reference(output) - result = inputs[output[0]] - for idx in output[1:]: - # We can only unpack a (possibly nested) composite. 
- assert isinstance(result, base.HybridMap), result - result = result[idx] - - assert isinstance(result, base.AbstractValue) - return result - - elif isinstance(output, InternalRef): - return output._resolve(inputs) - - return output - - @staticmethod - def validate_reference(x: symbolic.NestedReference) -> None: - x = (x,) if isinstance(x, (int, str)) else x - assert isinstance(x, tuple) and x and all(isinstance(xi, (int, str)) for xi in x), x - - -# ============================================================================= -# == Nesting ================================================================== -# ============================================================================= -@dataclasses.dataclass(frozen=True, eq=True) -class OrderedSlice: - reference: symbolic.NestedReference - slice: slice - - def __hash__(self) -> int: - # `slice` isn't hashable until 3.12 - return hash((self.reference, self.slice.start, self.slice.stop, self.slice.step)) - - -@dataclasses.dataclass(frozen=True, eq=True, repr=False) -class _Composite(base.AbstractValue, base.HybridMap[T], __is_detail=True): - """Models an AbstractValue that references other (possibly also AbstractValue) state. - - Note: `ordered` and `named` should not contain cycles. - """ - - Identity = base.TraitName("__Thunder_Object_Identity") - - def _unpack_apply(self, replace_map: base.ReplaceMap) -> base.AbstractValue: - new_self = self.map(lambda x: base.substitute_value(x, replace_map)) - assert isinstance(new_self, _Composite) # For mypy since we can't hint `_Composite[T] -> _Composite[T1]` - return new_self - - def add_identity(self, identity: T) -> Self: - return self.add_named(self.Identity, identity) - - # NOTE: We don't override `identity`. (Instead retaining `return self` from AbstractValue.) - # This is because we generally won't know a good value, and passes should do their - # own type checking. (And that checking should almost always be done on the materialized - # value, not the symbolic reference.) 
- - -@dataclasses.dataclass(frozen=True, eq=True, repr=False) -class CompositeValue(_Composite[base.AbstractValue]): - def __post_init__(self) -> None: - assert all(isinstance(i, base.AbstractValue) for i in self.ordered) - assert all(isinstance(i, base.AbstractValue) for i in self.named.values()) - - @property - def is_detail(self) -> bool: - return any(i.is_detail for i in itertools.chain(self.ordered, self.named.values())) - - @property - def identity(self) -> base.AbstractValue: - return self.named.get(self.Identity, self) - - -@dataclasses.dataclass(frozen=True, eq=True) -class CompositeRef(InternalRef, _Composite[symbolic.Symbolic.Output | OrderedSlice]): - def __post_init__(self) -> None: - assert not any(isinstance(i, OrderedSlice) for i in self.named.values()) - - def _resolve(self, inputs: base.HybridMap[base.AbstractValue]) -> CompositeValue: - ordered: list[base.AbstractValue] = [] - for i in self.ordered: - if isinstance(i, OrderedSlice): - slice_target = self.resolve(i.reference, inputs=inputs) if i.reference else inputs - assert isinstance(slice_target, base.HybridMap) - ordered.extend(slice_target.ordered[i.slice]) - else: - ordered.append(self.resolve(i, inputs=inputs)) - - named: dict[base.TraitName, base.AbstractValue] = {} - for k, v in self.named.items(): - assert not isinstance(v, OrderedSlice) - named[k] = self.resolve(v, inputs=inputs) - - return CompositeValue(ordered=tuple(ordered), named=FrozenDict(named)) - - -# ============================================================================= -# == Unions =================================================================== -# ============================================================================= -@dataclasses.dataclass(frozen=True, eq=True) -class AbstractPhiValue(base.AbstractValue): - constituents: tuple[base.AbstractValue, ...] - - def __post_init__(self) -> None: - # Flatten nested PhiValues. e.g. - # 𝜙[𝜙[A, B], 𝜙[A, C]] -> 𝜙[A, B, C] - constituents = itertools.chain(*[self.flatten(i) for i in self.constituents]) - - # Ensure a consistent order. - constituents = tuple(v for _, v in sorted({hash(v): v for v in constituents}.items())) - assert not any(isinstance(i, InternalRef) for i in constituents) - object.__setattr__(self, "constituents", constituents) - - def __getitem__(self, _: Any) -> base.AbstractValue: - # The semantics of indexing into an `AbstractPhiValue`` are not well defined: - # - The order of `constituents` is arbitrary - # - It's unclear if the desire is to select one constituent or create a new `AbstractPhiValue` - # which indexes into each constituent. - # If a concrete use case emerges we can tackle it; until then we refuse for safety. 
- - # TODO(robieta): Handle traits - raise NotImplementedError - - def _unpack_apply(self, replace_map: base.ReplaceMap) -> base.AbstractValue: - result = AbstractPhiValue(tuple(base.substitute_value(v, replace_map) for v in self.constituents)) - return result if len(result.constituents) > 1 else result.constituents[0] - - @classmethod - def flatten(cls, v: base.AbstractValue) -> Iterable[base.AbstractValue]: - constituents = [cls.flatten(i) for i in v.constituents] if isinstance(v, AbstractPhiValue) else [[v]] - yield from itertools.chain(*constituents) - - @property - def is_detail(self) -> bool: - return any(i.is_detail for i in self.constituents) diff --git a/thunder/core/script/values/materialization.py b/thunder/core/script/values/materialization.py deleted file mode 100644 index e12e35d0e5..0000000000 --- a/thunder/core/script/values/materialization.py +++ /dev/null @@ -1,171 +0,0 @@ -from __future__ import annotations - -import dataclasses -import functools -import itertools -from types import MappingProxyType -from typing import Literal, TypeVar -from collections.abc import Callable, Iterator - -from thunder.core.script import parse -from thunder.core.script.values import base, composite, symbolic -from thunder.core.utils import FrozenDict, OrderedSet -from collections.abc import Iterable - -__all__ = ("Materialized", "DigestFlow", "IntraBlockFlow") -T = TypeVar("T") - - -# ============================================================================= -# == Intra-ProtoBlock abstract value flow ===================================== -# ============================================================================= -# -# `ProtoBlocks` employ a dual representation, where node inputs and outputs can -# be viewed as either a reference based DAG or a sequence of ops with concrete -# `AbstractValue` inputs and outputs. -# -# At the boundaries of a ProtoBlock values have named (VariableKey) slots; -# within the ProtoBlock there is no need for such slots (since there is no -# control flow within a block and those named slots tell you how to build the -# directed *cyclic* graph for the larger program) so they are stripped during -# parsing. -# -# The inputs of a protoblock are stored as a map of `VariableKey -> AbstractValue` -# and act as the intra-block DAG sources. The outputs are stored as references -# since every ProtoBlock output must have a unique producer. (Either an input -# or a node within the block.) -# -# The canonical representation for intra-block flow is "symbolic" (reference -# based). If an `AbstractValue` appear in a symbolic node's outputs that -# indicates that the node is that value's producer. Otherwise all inputs and -# outputs are references: inputs reference either the begin state or the -# outputs of a prior node while output references index into the node's inputs. -# -# When analyzing a graph we are generally interested in the concrete properties -# of values; provenance is generally only important when connecting blocks and -# performing rewrites. For these cases `IntraBlockFlow` generates a -# "materialized" flow which resolves all references to `AbstractValue`s. The -# symbolic representation is sufficient to emit the materialized representation, -# but the reverse is not true. 
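# [editor's note] A toy, self-contained sketch of the dual representation the
# comment above describes (hypothetical names, not the deleted classes): node
# inputs are references into the block's value table, and "materializing" the
# flow resolves every reference to a concrete value.
begin_state = [3.0, 4.0]                             # intra-block DAG sources
symbolic_nodes = [("add", (0, 1)), ("mul", (2, 1))]  # 2 refers to the add output

values = list(begin_state)
for opname, refs in symbolic_nodes:
    lhs, rhs = (values[r] for r in refs)             # resolve symbolic references
    values.append(lhs + rhs if opname == "add" else lhs * rhs)

materialized = values[len(begin_state):]             # concrete per-node outputs: [7.0, 28.0]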
-VarT = TypeVar("VarT", bound=base.AbstractValue, covariant=True) -ConcreteState = FrozenDict[parse.VariableKey, base.AbstractValue] -EndState = FrozenDict[parse.VariableKey, symbolic.Symbolic.Input] - - -@dataclasses.dataclass(frozen=True, eq=False) -class Materialized: - """Flow element where all symbolic references have been resolved to concrete `AbstractValue`s.""" - - inputs: base.HybridMap[base.AbstractValue] - outputs: tuple[base.AbstractValue, ...] - - -class DigestFlow: - """One-shot helper for materializing a block.""" - - GetT = Callable[[symbolic.Symbolic.Input], base.AbstractValue] - - def __init__(self, begin: ConcreteState) -> None: - self._begin = begin - self._result: dict[parse.ThunderInstruction, Materialized] = {} - - def next(self, instruction: parse.ThunderInstruction, symbolic: symbolic.Symbolic) -> Materialized: - """Lookup the materialized node corresponding to a symbolic node.""" - - # NB: `inputs_after_op` will be needed after we introduce mutations. - inputs = inputs_after_op = symbolic.inputs.map(self.get) - outputs = tuple(composite.InternalRef.resolve(o, inputs=inputs_after_op) for o in symbolic.outputs) - assert all(isinstance(o, base.AbstractValue) for o in outputs), outputs - - self._result[instruction] = result = Materialized(inputs, outputs) - return result - - def get(self, key: symbolic.Symbolic.Input) -> base.AbstractValue: - """Resolve a Symbolic input based on the block state at that node.""" - result: base.AbstractValue - if isinstance(key, base.NonPyObject.Tag): - result = base.NonPyObject(key) - - elif isinstance(key, symbolic.OutputRef): - inputs = base.HybridMap(ordered=self._result[key.instruction].outputs) - result = composite.InternalRef.resolve(key.idx, inputs=inputs) - - else: - assert isinstance(key, parse.VariableKey), key - result = base.ExternalRef(key) if key.is_const else self._begin[key] - return result - - -@dataclasses.dataclass(frozen=True, eq=False) -class IntraBlockFlow: - _symbolic: FrozenDict[parse.ThunderInstruction, symbolic.Symbolic] - _begin: ConcreteState - _end: EndState - - StateIterT = Iterator[tuple[parse.VariableKey, base.AbstractValue]] - - def __post_init__(self) -> None: - assert not (forbidden := tuple(i for i in self._symbolic if i in parse.FORBIDDEN_INSTRUCTIONS)), forbidden - object.__setattr__(self, "_symbolic", FrozenDict(self._symbolic)) - - missing = {i: base.AbstractRef("Inferred") for i in self.uses if i not in self._begin} - object.__setattr__(self, "_begin", FrozenDict({**missing, **self._begin})) - assert not any(k.is_const for k in self._begin), self._begin - assert not any(isinstance(v, composite.InternalRef) for v in self._begin.values()), self._begin - - object.__setattr__(self, "_end", FrozenDict(self._end)) - - @functools.cache - def __getitem__(self, key: tuple[symbolic.Symbolic.Input, Literal[0, 1]]) -> base.AbstractValue: - assert key[1] in (0, 1) - return self._computed[1:][key[1]](key[0]) - - @property - def symbolic(self) -> Iterable[tuple[parse.ThunderInstruction, symbolic.Symbolic]]: - yield from self._symbolic.items() - - @property - def materialized(self) -> FrozenDict[parse.ThunderInstruction, Materialized]: - return self._computed[0] - - @property - def begin_state(self) -> StateIterT: - yield from self._sort_and_filter_state(iter(self._begin.items())) - - @property - def end_state(self) -> StateIterT: - yield from self._sort_and_filter_state((k, self[v, 1]) for k, v in self._end.items()) - - @staticmethod - def _sort_and_filter_state(kv: StateIterT) -> StateIterT: - yield from ((k, 
v) for k, v in sorted(kv) if not isinstance(v, base.NonPyObject)) - - @property - def uses(self) -> OrderedSet[parse.VariableKey]: - assignment = (v for k, v in self._end.items() if isinstance(v, parse.VariableKey) and not v.is_const and k != v) - return OrderedSet(itertools.chain(*(s.uses for _, s in self.symbolic), assignment)) - - _Computed = tuple[ - FrozenDict[parse.ThunderInstruction, Materialized], - DigestFlow.GetT, # Begin - DigestFlow.GetT, # End - ] - - @functools.cached_property - def _computed(self) -> _Computed: - flow_state = DigestFlow(self._begin) - materialized_flow: FrozenDict[parse.ThunderInstruction, Materialized] - materialized_flow = FrozenDict({i: flow_state.next(i, s) for i, s in self.symbolic}) # Populates `flow_state` - return materialized_flow, DigestFlow(self._begin).get, flow_state.get - - def substitute(self, replace_map: base.ReplaceMap) -> IntraBlockFlow | None: - """Replace `AbstractValue`s within the flow. (Block inputs and producer nodes.)""" - replace_map_view = MappingProxyType(replace_map) - new_symbolic: FrozenDict[parse.ThunderInstruction, symbolic.Symbolic] - new_symbolic = FrozenDict({k: (s.substitute(replace_map_view)) for k, s in self.symbolic}) - begin = ConcreteState({k: base.substitute_value(v, replace_map_view) for k, v in self._begin.items()}) - - # TODO(robieta): Check if a value is only present in `materialized` and error. - if self._symbolic != new_symbolic or self._begin != begin: - return dataclasses.replace(self, _symbolic=new_symbolic, _begin=begin) - return None diff --git a/thunder/core/script/values/symbolic.py b/thunder/core/script/values/symbolic.py deleted file mode 100644 index 60bbaa08b6..0000000000 --- a/thunder/core/script/values/symbolic.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Introduce references inside simple blocks.""" -from __future__ import annotations - -import dataclasses -import itertools -import sys -from typing import Any, NamedTuple, TypeAlias -from collections.abc import Callable, Iterable - -from typing_extensions import Self - -from thunder.core.script import parse -from thunder.core.script.values import base -from thunder.core.utils import FrozenDict, safe_zip - -__all__ = ("OutputRef", "ParsedSymbolic", "Symbolic", "NestedReference", "ConstRef") - - -# ============================================================================= -# == Opcode-specific behavior ================================================= -# ============================================================================= -def rotate_N(oparg: int) -> tuple[int, ...]: - return (-1,) + tuple(range(-oparg, -1)) - - -_AliasMask = tuple[int | None, ...] -ALIAS_OPCODES = FrozenDict[str, _AliasMask | Callable[[int], _AliasMask]]( - parse.fill_ellipses( - # - # Stack manipulation - ROT_N=rotate_N, # A,B,...,Z -> Z,A,B,... - ROT_FOUR=rotate_N, - ROT_THREE=rotate_N, - ROT_TWO=rotate_N, - DUP_TOP=(-1, -1), # A -> A,A - DUP_TOP_TWO=(-2, -1) * 2, # A,B -> A,B,A,B - # - # Insertion leaves container on the stack A,B -> A - SET_ADD=(-2,), - SET_UPDATE=..., - LIST_APPEND=..., - LIST_EXTEND=..., - DICT_MERGE=..., - DICT_UPDATE=..., - MAP_ADD=(-3,), - COPY_DICT_WITHOUT_KEYS=(-2, None), # A,B -> A,C (I am unsure...) - # - # Misc. - GET_LEN=(-1, None), - MATCH_MAPPING=(-1, None), - MATCH_SEQUENCE=..., - MATCH_KEYS=(-1, -2, None) + () if sys.version_info >= (3, 11) else (None,), - # - # Jump dependent - FOR_ITER=(-1, None), - # NOTE: These instructions have been removed since they are extraneous special cases. 
- # https://github.com/faster-cpython/ideas/issues/567 - # https://github.com/python/cpython/issues/102859 - JUMP_IF_TRUE_OR_POP=(-1,), - JUMP_IF_FALSE_OR_POP=(-1,), - # - # This isn't actually correct. `LOAD_METHOD` will return either - # A -> B, A - # A -> B, NULL - # However the `A | NULL` is only consumed by `CALL_METHOD`, so it's ok to use this alias. - LOAD_METHOD=(None, -1), # A -> B,A - ) -) - - -# ============================================================================= -# == Symbolic flow ============================================================ -# ============================================================================= -IndexT: TypeAlias = base.Reference | base.TraitName -NestedReference: TypeAlias = IndexT | tuple[IndexT, ...] - - -@dataclasses.dataclass(frozen=True, eq=True) -class OutputRef: - """Identifies the producer of a value within a block.""" - - instruction: parse.ThunderInstruction # Acts as a key for the producer Flow. - idx: NestedReference # Indexes the producer's outputs. - - -class ConstRef(NamedTuple): - """Convenience wrapper to access `ExternalRef(VariableKey(..., CONST))` as a reference. - - This saves us from having to plumb constants through named inputs, since: - A) `ExternalRef`s cannot appear in Symbolic outputs. - (Since that implies a producer relationship which doesn't make sense.) - B) Symbolic reference outputs must reference an input, which would mean an entry of - `{some_random_name: VariableKey(..., CONST)}` would have to be added to named inputs - which is tedious. - """ - - identifier: Any - - -@dataclasses.dataclass(frozen=True, eq=True) -class Symbolic: - """Represents abstract flow immediately after functionalization.""" - - # VariableKey: References the value of that variable at the start of the block - # OutputRef: Reference values created by an earlier instruction within the block - # SingletonValue.Tag: Reserved for special cases. - Input = parse.VariableKey | OutputRef | base.NonPyObject.Tag - inputs: base.HybridMap[Input] - - # NestedReference: Aliases the input at this position. - # AbstractValue: New value created by this instruction - Output = NestedReference | ConstRef | base.AbstractValue - outputs: tuple[Output, ...] - - BeginState = FrozenDict[parse.VariableKey, base.AbstractValue] - EndState = FrozenDict[parse.VariableKey, Input] - Block = tuple[FrozenDict[parse.ThunderInstruction, "Symbolic"], BeginState, EndState] - - def __post_init__(self) -> None: - # If an `AbstractValue` appears in `Symbolic.outputs` that implies that the symbolic - # node in question is the value's producer. However it doesn't make sense for an external - # value to be produced within the compiled function. - assert not any(isinstance(o, base.ExternalRef) for o in self.outputs), self - - @property - def uses(self) -> Iterable[parse.VariableKey]: - """Block inputs used by this node. - - NOTE: This does not include values produced by an earlier node in the block. 
- """ - for i in itertools.chain(self.inputs.ordered, self.inputs.named.values()): - if isinstance(i, parse.VariableKey) and not i.is_const: - yield i - - def substitute(self, replace_map: base.ReplaceMap) -> Self: - outputs = tuple(base.substitute_value(o, replace_map) for o in self.outputs) - return dataclasses.replace(self, outputs=outputs) - - -# ============================================================================= -# == Conversion from functional representation ================================ -# ============================================================================= -@dataclasses.dataclass(frozen=True) -class ParsedSymbolic: - blocks: tuple[Symbolic.Block, ...] - provenance: parse.ParsedFunctional - - @classmethod - def make(cls, parsed: parse.ParsedFunctional) -> ParsedSymbolic: - blocks: list[Symbolic.Block] = [] - for block, begin_state, end_state in parsed.blocks: - # `functionalize_blocks` produces unique values, so provenance is unambiguous. - producers: dict[parse.PlaceholderValue | None, Symbolic.Input] = {v: k for k, v in begin_state.items()} - producers[None] = base.NonPyObject.Tag.DELETED - assert len(producers) == len(begin_state) + 1, (producers, end_state) - - symbolic_blocks: dict[parse.ThunderInstruction, Symbolic] = {} - for instruction, raw_inputs, raw_outputs in block: - for idx, o in enumerate(raw_outputs): - assert o not in producers - producers[o] = OutputRef(instruction, base.Reference(idx)) - - outputs: tuple[Symbolic.Output, ...] = tuple(base.IntermediateValue() for _ in raw_outputs) - if alias := ALIAS_OPCODES.get(instruction.opname): - mask = alias(len(outputs)) if callable(alias) else alias - mask = (base.Reference(i) if i is not None else i for i in mask) - outputs = tuple(o if o_mask is None else o_mask for o, o_mask in safe_zip(outputs, mask)) - inputs = base.HybridMap(ordered=tuple(producers[i] for i in raw_inputs)) - symbolic_blocks[instruction] = Symbolic(inputs, outputs) - - begin = {k: base.AbstractRef(v) for k, v in begin_state.items() if not k.is_const} - end = {k: producers[v] for k, v in end_state.items() if not k.is_const} - blocks.append((FrozenDict(symbolic_blocks), FrozenDict(begin), FrozenDict(end))) - - return cls(tuple(blocks), parsed) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 42dcddced4..16b195801d 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -545,7 +545,6 @@ def _flatten(bsym: BoundSymbol): # TODO Test with buffers def populate_grads(grads: list[TensorProxy], tom: None | torch.nn.Module = None, args=None, kwargs=None) -> None: idx: int = 0 - from thunder.common import ThunderOptimizedModule from thunder import ThunderModule, compile_data if isinstance(tom, ThunderModule) or thunder.compile_data(tom).using_jit: @@ -563,16 +562,6 @@ def populate_grads(grads: list[TensorProxy], tom: None | torch.nn.Module = None, idx += 1 return - if tom is not None and isinstance(tom, ThunderOptimizedModule) and tom._additional_param_values is not None: - for p in tom._additional_param_values: - if p.requires_grad: - # Supports grad accumulation (like when weight tying) - if p.grad is not None: - p.grad += grads[idx] - else: - p.grad = grads[idx] - idx += 1 - # Short-circuits if there are no args or kwargs if args is None and kwargs is None: return @@ -603,7 +592,6 @@ def clear_grads(module: torch.nn.Module) -> None: b.grad = None -from thunder.core.script.noinline import noinline from thunder.core.interpreter import make_opaque from thunder.core.langctxs import langctx, 
Languages @@ -611,7 +599,7 @@ def clear_grads(module: torch.nn.Module) -> None: # TODO RC1 Replace with langctx def torchctx(fn): _fn = langctx(Languages.TORCH)(fn) - return make_opaque(noinline(_fn)) + return make_opaque(_fn) _grad_fn_map: dict[Any, Callable] = {} diff --git a/thunder/numpy/__init__.py b/thunder/numpy/__init__.py index b53e9f26f5..ed2e6474aa 100644 --- a/thunder/numpy/__init__.py +++ b/thunder/numpy/__init__.py @@ -11,9 +11,6 @@ from thunder.core.symbol import Symbol import thunder.clang as clang -# TODO RC1 Remove this -from thunder.core.script.noinline import noinline - # # NumPy operator definitions @@ -28,7 +25,7 @@ def __init__(self, *, method_name: None | str = None): def __call__(self, fn: Callable) -> Symbol: _fn = langctx(Languages.NUMPY)(fn) - _fn = noinline(_fn) + # TODO: register _fn as opaque with the interpreter or do this in jit_ext? sym = Symbol(name=fn.__name__, meta=_fn) if self.method_name is not None: diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index ce237d9c97..dfdbeb8312 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -133,25 +133,11 @@ def executors_list(self) -> list[extend.Executor]: @singledispatchmethod def make_callable_legacy(self, fn, **kwargs): - # TODO: an error is thrown for many functions because __code__ and - # inspect.signature for wrapped functions is not matching. - # KeyError: 'args' - # thunder/core/script/frontend.py:125: KeyError - # with disable_preprocessing=False - # See: https://github.com/Lightning-AI/lightning-thunder/issues/386 - disable_preprocessing = kwargs.pop("disable_preprocessing", True) - return thunder.compile( - fn, executors_list=self.executors_list(), disable_preprocessing=disable_preprocessing, **kwargs - ) + assert kwargs.pop("disable_preprocessing", True) + return thunder.compile(fn, executors_list=self.executors_list(), disable_preprocessing=True, **kwargs) @singledispatchmethod def make_callable(self, fn, **kwargs): - # TODO: an error is thrown for many functions because __code__ and - # inspect.signature for wrapped functions is not matching. 
- # KeyError: 'args' - # thunder/core/script/frontend.py:125: KeyError - # with disable_preprocessing=False - # See: https://github.com/Lightning-AI/lightning-thunder/issues/386 return thunder.jit(fn, executors=self.executors_list(), **kwargs) @make_callable.register From db89e2d6fcad9db3ee7d0582c84b29d3c2aa66e7 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 14 Mar 2024 15:56:26 +0200 Subject: [PATCH 09/44] Remove duplicate code for backward definitions (PR1436) --- thunder/common.py | 4 +- thunder/core/prims.py | 2 +- thunder/core/transforms.py | 503 +++------------------------- thunder/core/utils.py | 4 +- thunder/core/vjp_utils.py | 86 ++++- thunder/executors/apex_entropyex.py | 4 +- thunder/executors/sdpaex.py | 109 ------ thunder/tests/opinfos.py | 5 + thunder/tests/test_grad.py | 38 ++- 9 files changed, 171 insertions(+), 584 deletions(-) diff --git a/thunder/common.py b/thunder/common.py index 860b6b522a..c0ec2fa770 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -338,8 +338,8 @@ def translate(x: Any, *, name: str | None = None) -> Any: # TODO Update cacheable types def _make_subkey_for(x: Any) -> tuple[bool, None | tuple]: - if isinstance(x, torch.Tensor): - return True, (torch.Tensor, x.shape, x.device, x.dtype, x.requires_grad) + if isinstance(x, (torch.Tensor, TensorProxy)): + return True, (type(x), x.shape, x.device, x.dtype, x.requires_grad) # TODO Add NumPy ndarray support if isinstance(x, np.ndarray): diff --git a/thunder/core/prims.py b/thunder/core/prims.py index f477ee2778..035106bb30 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -2841,7 +2841,7 @@ def slice_meta( # NOTE: slice is named "slice_prim" and not "slice" because it conflicts with Python's "slice" builtin -slice_prim = make_prim(PrimIDs.SLICE, "slice", meta=slice_meta, tags=(OpTags.SHAPE_OP,)) +slice_prim = make_prim(PrimIDs.SLICE, "slice_prim", meta=slice_meta, tags=(OpTags.SHAPE_OP,)) def squeeze_meta(a: TensorProxy, /, dims: tuple[int, ...]) -> TensorProxy: diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 16b195801d..e1c0d8dc70 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -24,6 +24,7 @@ NumberProxy, Proxy, TensorProxy, + FloatProxy, variableify, unvariableify, CollectionProxy, @@ -59,6 +60,7 @@ convolution, ) from thunder.core.transform_common import dce +from thunder.core.vjp_utils import make_aug_forward_and_backward from thunder.extend import Executor import thunder.torch as ltorch @@ -668,8 +670,16 @@ def _convert_element_type_prim_grad(a: Number | TensorProxy, dtype: type | dtype # NOTE prims.iota creates no grad associations register_grad(pids.IOTA, prims.iota) -# NOTE prims.uniform creates no grad associations -register_grad(pids.UNIFORM, prims.uniform) + +def _uniform_grad(shape, minval, maxval, *, device, dtype): + fwd, saved = uniform_aug_fwd(shape, minval, maxval, device=device, dtype=dtype) + g = get_grad(fwd) + _, gminval, gmaxval = uniform_backward(*saved, g) + put_grads((minval, maxval), (gminval, gmaxval)) + return fwd + + +register_grad(pids.UNIFORM, _uniform_grad) # # Reshaping and permuting operator grads @@ -867,12 +877,11 @@ def _abs_prim_grad(a: Number | TensorProxy) -> Number | TensorProxy: register_grad(pids.ABS, _abs_prim_grad) -@torchctx def _cos_prim_grad(a: Number | TensorProxy) -> Number | TensorProxy: - fwd = prims.abs(a) + fwd = prims.cos(a) g = get_grad(fwd) - put_grad(a, g * (-ltorch.sin(a))) + put_grad(a, g * (-prims.sin(a))) return fwd @@ -952,12 +961,11 @@ def _rsqrt_prim_grad(a: 
Number | TensorProxy, /) -> Number | TensorProxy: register_grad(pids.RSQRT, _rsqrt_prim_grad) -@torchctx def _sin_prim_grad(a: Number | TensorProxy) -> Number | TensorProxy: - fwd = prims.abs(a) + fwd = prims.sin(a) g = get_grad(fwd) - put_grad(a, g * ltorch.cos(a)) + put_grad(a, g * prims.cos(a)) return fwd @@ -1228,7 +1236,9 @@ def _embedding_prim_grad( # -def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any]) -> None | Callable: +def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable: + cd = get_compile_data() + executors_list = cd.executors_list if cd is not None else executors_list # Checks if the executor which has priority for this operation has a specific grad transform for it for ex in executors_list: if ex.can_execute_or_fuse(bsym): @@ -2412,43 +2422,30 @@ def zeros_like(x): # The augmented_primal function takes the primal values and returns the primal # result and the residuals (saved values for the backward). augmented_forward_impls = { - prims.PrimIDs.ABS: lambda x: (prims.abs(x), (x,)), prims.PrimIDs.ACOS: lambda x: (prims.acos(x), (x,)), prims.PrimIDs.ACOSH: lambda x: (prims.acosh(x), (x,)), - prims.PrimIDs.ADD: lambda x, y: (prims.add(x, y), tuple()), prims.PrimIDs.ASIN: lambda x: (prims.asin(x), (x,)), prims.PrimIDs.ASINH: lambda x: (prims.asinh(x), (x,)), prims.PrimIDs.ATAN: lambda x: (prims.atan(x), (x,)), prims.PrimIDs.ATANH: lambda x: (prims.atanh(x), (x,)), prims.PrimIDs.ATAN2: lambda x, y: (prims.atan2(x, y), (x, y)), - prims.PrimIDs.COS: lambda x: (prims.cos(x), (x,)), prims.PrimIDs.COSH: lambda x: (prims.cosh(x), (x,)), prims.PrimIDs.DIGAMMA: lambda x: (prims.digamma(x), (x,)), - prims.PrimIDs.DIV: lambda x, y: (prims.div(x, y), (x, y)), - prims.PrimIDs.ERF: lambda x: (prims.erf(x), (x,)), prims.PrimIDs.ERFC: lambda x: (prims.erfc(x), (x,)), prims.PrimIDs.ERFINV: lambda x: (prims.erfinv(x), (prims.erfinv(x),)), prims.PrimIDs.ERFCINV: lambda x: (prims.erfcinv(x), (prims.erfcinv(x),)), prims.PrimIDs.EXP2: lambda x: (prims.exp2(x), (prims.exp2(x),)), prims.PrimIDs.EXPM1: lambda x: (prims.expm1(x), (prims.expm1(x),)), prims.PrimIDs.LGAMMA: lambda x: (prims.lgamma(x), (x,)), - prims.PrimIDs.MUL: lambda x, y: (prims.mul(x, y), (x, y)), prims.PrimIDs.NDTRI: lambda x: (prims.ndtri(x), (prims.ndtri(x),)), - prims.PrimIDs.SIN: lambda x: (prims.sin(x), (x,)), prims.PrimIDs.SINH: lambda x: (prims.sinh(x), (x,)), - prims.PrimIDs.SUB: lambda x, y: (prims.sub(x, y), tuple()), prims.PrimIDs.SQRT: lambda x: (prims.sqrt(x), (prims.sqrt(x),)), - prims.PrimIDs.EQ: lambda x, y: (prims.eq(x, y), (x, y)), prims.PrimIDs.NE: lambda x, y: (prims.ne(x, y), (x, y)), - prims.PrimIDs.GE: lambda x, y: (prims.ge(x, y), (x, y)), prims.PrimIDs.GT: lambda x, y: (prims.gt(x, y), (x, y)), prims.PrimIDs.LE: lambda x, y: (prims.le(x, y), (x, y)), - prims.PrimIDs.LT: lambda x, y: (prims.lt(x, y), (x, y)), - prims.PrimIDs.LOG: lambda x: (prims.log(x), (x,)), prims.PrimIDs.LOG10: lambda x: (prims.log10(x), (x,)), prims.PrimIDs.LOG1P: lambda x: (prims.log1p(x), (x,)), prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)), - prims.PrimIDs.NEG: lambda x: (prims.neg(x), tuple()), prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)), prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)), } @@ -2458,42 +2455,28 @@ def zeros_like(x): # The backward function takes the residuals and cotangents and returns the # vector-Jacobian products for each argument. 
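# [editor's note] A minimal, self-contained sketch of the contract stated above,
# using plain Python floats instead of thunder proxies: the augmented forward
# returns (primal, residuals) and the backward maps (residuals..., cotangent) to
# one vector-Jacobian product per differentiable input (shown here for sin,
# mirroring the SIN entries this patch deletes from these tables).
import math

def sin_aug_fwd(x):
    return math.sin(x), (x,)            # primal plus residuals saved for backward

def sin_backward(x, g):
    return g * math.cos(x)              # VJP for the single input

primal, residuals = sin_aug_fwd(0.5)
grad_x = sin_backward(*residuals, 1.0)  # equals cos(0.5)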
backward_impls = { - prims.PrimIDs.ABS: lambda x, g: g * prims.sign(x), prims.PrimIDs.ACOS: lambda x, g: -g / prims.sqrt(1.0 - x * x), prims.PrimIDs.ACOSH: lambda x, g: g * prims.rsqrt(x * x - 1.0), - prims.PrimIDs.ADD: lambda g: (g, g), prims.PrimIDs.ASIN: lambda x, g: g / prims.sqrt(1.0 - x * x), prims.PrimIDs.ASINH: lambda x, g: g * prims.rsqrt(1.0 + x * x), prims.PrimIDs.ATAN: lambda x, g: g / (1.0 + x * x), prims.PrimIDs.ATANH: lambda x, g: g / (1.0 - x * x), - prims.PrimIDs.COS: lambda x, g: prims.mul(g, -prims.sin(x)), prims.PrimIDs.COSH: lambda x, g: prims.mul(g, prims.sinh(x)), - prims.PrimIDs.DIV: lambda x, y, g: (g / y, -g * x / (y**2)), - prims.PrimIDs.ERF: lambda x, g: g * 2.0 / math.sqrt(math.pi) * prims.exp(-x * x), prims.PrimIDs.ERFC: lambda x, g: -g * 2.0 / math.sqrt(math.pi) * prims.exp(-x * x), prims.PrimIDs.ERFINV: lambda result, g: g * 0.5 * math.sqrt(math.pi) * prims.exp(result**2), prims.PrimIDs.ERFCINV: lambda result, g: -g * 0.5 * math.sqrt(math.pi) * prims.exp(result**2), prims.PrimIDs.EXP2: lambda result, g: g * result * math.log(2.0), prims.PrimIDs.EXPM1: lambda result, g: g * (result + 1.0), prims.PrimIDs.LGAMMA: lambda x, g: g * prims.digamma(x), - prims.PrimIDs.MUL: lambda x, y, g: (g * y, g * x), prims.PrimIDs.NDTRI: lambda result, g: g * prims.exp(0.5 * result**2) * math.sqrt(2.0 * math.pi), - prims.PrimIDs.SIN: lambda x, g: prims.mul(g, prims.cos(x)), prims.PrimIDs.SINH: lambda x, g: prims.mul(g, prims.cosh(x)), - prims.PrimIDs.SUB: lambda g: (g, -g), prims.PrimIDs.SQRT: lambda result, g: g / (2.0 * result), - prims.PrimIDs.FULL: NoPullback(num_args=2), - prims.PrimIDs.EQ: ZeroBackward(num_args=2), prims.PrimIDs.NE: ZeroBackward(num_args=2), - prims.PrimIDs.GE: ZeroBackward(num_args=2), prims.PrimIDs.GT: ZeroBackward(num_args=2), prims.PrimIDs.LE: ZeroBackward(num_args=2), - prims.PrimIDs.LT: ZeroBackward(num_args=2), - prims.PrimIDs.LOG: lambda x, g: g / x, prims.PrimIDs.LOG10: lambda x, g: g / (x * 2.302585092994046), prims.PrimIDs.LOG1P: lambda x, g: g / (x + 1), prims.PrimIDs.LOG2: lambda x, g: g / (x * 0.6931471805599453), - prims.PrimIDs.NEG: lambda g: -g, prims.PrimIDs.FMOD: lambda x, y, g: (g, -g * prims.trunc(x / y)), } @@ -2627,29 +2610,6 @@ def polygamma_backward(n: int, a: Proxy, g): return None, g * polygamma(n + 1, a) -@register_augmented_forward(prims.PrimIDs.RSQRT) -def rsqrt_augmented(x): - """Augmented rsqrt operation. - - Args: - x (Variable): input tensor. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.rsqrt(x) - residuals = (primal,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.RSQRT) -def rsqrt_backward(result, g): - # An alternative derivation used by JAX is -0.5 * g * rsqrt(x) / x - # where rsqrt(x) and x are saved for the backwards pass. - # This derivation was selected because it avoids saving the input tensor. - return -0.5 * g * result**3.0 - - @register_backward(prims.PrimIDs.ATAN2) def atan2_backward(x, y, g): alpha = 1.0 / (x * x + y * y) @@ -2658,32 +2618,6 @@ def atan2_backward(x, y, g): return grad_x, grad_y -@register_augmented_forward(prims.PrimIDs.SUM) -def sum_aug_fwd(x, dims): - """Augmented sum operation. - - Args: - x (Variable): Tensor to be summed. - dims (Tuple[int, ...]): Dimensions to be summed. - - Returns: - VJPDual: Primal and residuals. 
- """ - primal = prims.sum(x, dims) - residuals = ( - x.shape, - dims, - ) - - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.SUM) -def sum_backward(x_shape, reduced_dims, g): - # One return per positional argument of prims.sum - return restore_reduced_dims(g, reduced_dims, x_shape), None - - @register_augmented_forward(prims.PrimIDs.VAR) def var_aug_fwd(a, dim, *, correction): v = prims.var(a, dim, correction=correction) @@ -2705,13 +2639,6 @@ def var_backward(a, dim, correction, v, g): return (2 * g * (a - mean)) / normalization_scalar -@register_augmented_forward(prims.PrimIDs.VAR_MEAN) -def _var_mean_aug_fwd(a, dim, *, correction): - v, m = prims.var_mean(a, dim, correction=correction) - - return (v, m), (a, dim, correction, m) - - def n_elem_reduced(a_ndim, a_shape, dims): dims = utils.canonicalize_dims(a_ndim, dims) reduction_size = 1 @@ -2726,27 +2653,6 @@ def mean_backward(a_ndim, a_shape, dims, grad): return restore_reduced_dims(grad, dims, a_shape) * mean_local_grad -# TODO: fix division by zero when n_elem_reduced == 0 or when mean.numel == 0 -# by returning zeros_like(a) or similar. -# TODO: fix grad when correction > n_elem_reduced. -@register_backward(prims.PrimIDs.VAR_MEAN) -def _var_mean_bwd(a, dim, correction, mean, grad_v, grad_m): - n_elem_reduced = a.numel // mean.numel if a.numel != 0 else 1 - - def mean_backward(a, dims, grad): - mean_scale = 1.0 / n_elem_reduced - grad = restore_reduced_dims(grad, dims, a.shape) - return mean_scale * grad - - def var_backward(a, dims, correction, mean, grad): - normalization_scalar = n_elem_reduced - correction - grad = restore_reduced_dims(grad, dims, a.shape) - mean = restore_reduced_dims(mean, dims, a.shape) - return (2.0 * grad * (a - mean)) / normalization_scalar - - return var_backward(a, dim, correction, mean, grad_v) + mean_backward(a, dim, grad_m) - - @register_augmented_forward(prims.PrimIDs.PAD) def pad_aug_fwd(a, padding_value, padding_config): return VJPDual((prims.pad(a, padding_value, padding_config),), (a, padding_config)) @@ -2816,42 +2722,15 @@ def grad_chooser_backward(primal, x, x_shape, reduced_dims, g): return out -register_backward(prims.PrimIDs.AMAX)(grad_chooser_backward) register_backward(prims.PrimIDs.AMIN)(grad_chooser_backward) -# TODO: exact same for amin, argmax, argmin -@register_augmented_forward(prims.PrimIDs.AMAX) -def amax_aug_fwd(x, dims): - """Augmented amax operation. - - Args: - x (Variable): Tensor to compute amax on. - dims (Tuple[int, ...]): Dimensions to compute amax over. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.amax(x, dims) - - residuals = ( - primal, - x, - x.shape, - dims, - ) - - return VJPDual(primal, residuals) - - @register_augmented_forward(prims.PrimIDs.AMIN) def amin_aug_fwd(x, dims): """Augmented amin operation. - Args: x (Variable): Tensor to compute amin on. dims (Tuple[int, ...]): Dimensions to compute amin over. - Returns: VJPDual: Primal and residuals. """ @@ -2867,26 +2746,6 @@ def amin_aug_fwd(x, dims): return VJPDual(primal, residuals) -@register_augmented_forward(prims.PrimIDs.EXP) -def exp_aug_fwd(x): - """Augmented exp operation. - - Args: - x (Variable): Tensor to be exponentiated. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.exp(x) - residuals = (primal,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.EXP) -def exp_backward(result, g): - return g * result - - @register_augmented_forward(prims.PrimIDs.POW) def pow_aug_fed(x, y): """Augmented the pow operation. 
@@ -2931,106 +2790,11 @@ def tan_backward(result, g): return g * (1 + result * result) -@register_augmented_forward(prims.PrimIDs.TANH) -def tanh_aug_fwd(x): - """Augmented tanh operation. - - Args: - x (Variable): Tensor to be passed to tanh. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.tanh(x) - residuals = (primal,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TANH) -def tanh_backward(result, g): - return g * (1.0 - result * result) - - # NOTE: Jax uses np.argsort in its transpose vjp computation def _argsort(seq): return sorted(range(len(seq)), key=seq.__getitem__) -@register_augmented_forward(prims.PrimIDs.TRANSPOSE) -def transpose_aug_fwd(a, permutation): - primal = prims.transpose(a, tuple(permutation)) - residuals = (permutation,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TRANSPOSE) -def transpose_backward(permutation, g): - undo = _argsort(permutation) - return prims.transpose(g, tuple(undo)) - - -@register_augmented_forward(prims.PrimIDs.RESHAPE) -def reshape_aug_fwd(a, shape): - primal = prims.reshape(a, shape) - residuals = (a.shape,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.RESHAPE) -def reshape_backward(orig_shape, g): - return prims.reshape(g, orig_shape) - - -@register_augmented_forward(prims.PrimIDs.SLICE) -def slice_aug_fwd(a, start_indices, end_indices, strides): - primal = prims.slice_prim(a, start_indices, end_indices, strides) - residuals = (a.shape, start_indices, end_indices, strides) - return VJPDual(primal, residuals) - - -# Adapted from https://github.com/google/jax/blob/main/jax/_src/lax/slicing.py#L768 -@register_backward(prims.PrimIDs.SLICE) -def slice_backward(shape, start_indices, end_indices, strides, g): - padding = None - if strides is None or np.all(np.equal(strides, 1)): - padding = tuple(zip(start_indices, np.subtract(shape, end_indices), (0,) * len(start_indices))) - else: - real_limits = np.add( - start_indices, - np.where(np.equal(g.shape, 0), 0, np.add(1, np.multiply(np.subtract(g.shape, 1), strides))), - ) - padding = tuple(zip(start_indices, np.subtract(shape, real_limits), np.subtract(strides, 1))) - - # We used NumPy arithmetics above, but the current infra expects Python ints. - padding = tree_map(int, padding) - result = prims.pad(g, const_as(0, g.dtype), padding) - - return result - - -@register_augmented_forward(prims.PrimIDs.BROADCAST_IN_DIM) -def broadcast_in_dim_aug_fwd(a: Proxy, shape: Sequence[int], broadcast_dimensions: Sequence[int]) -> VJPDual: - primal = prims.broadcast_in_dim(a, shape, broadcast_dimensions) - residuals = (a, shape, broadcast_dimensions) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.BROADCAST_IN_DIM) -def broadcast_in_dim_backward(a, shape, broadcast_dimensions, g): - from thunder.torch import sum - - # If g is None, then the primal was a constant and the pullback is zero. - # TODO: implement None propagation in the VJP infrastructure so that we don't need to do this. 
- if g is None: - return None, None, None - unit_dims = tuple(i for i, s in enumerate(a.shape) if s == 1) - bcast_dims = tuple(b for i, b in enumerate(broadcast_dimensions) if i not in unit_dims) - reduce_dims = tuple(s for i, s in enumerate(range(len(shape))) if i not in bcast_dims) - g = sum(g, reduce_dims) - g = unsqueeze(g, unit_dims) - return g - - @register_augmented_forward(prims.PrimIDs.DEVICE_PUT) def device_put_aug_fwd(a: TensorProxy, device: Device) -> TensorProxy: primal = prims.device_put(a, device) @@ -3043,19 +2807,6 @@ def device_put_backward(orig_device, g): return prims.device_put(g, orig_device), None -@register_augmented_forward(prims.PrimIDs.CONVERT_ELEMENT_TYPE) -def convert_element_type_aug_fwd(a: Proxy, dtype: dtypes.dtype) -> VJPDual: - primal = prims.convert_element_type(a, dtype) - residuals = (a.dtype if isinstance(a, TensorProxy) else (a.python_type if isinstance(a, NumberProxy) else type(a)),) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.CONVERT_ELEMENT_TYPE) -def convert_element_type_backward(a_dtype, g): - # perform cast back to input type during backward - return prims.convert_element_type(g, a_dtype), None - - @register_augmented_forward(prims.PrimIDs.CONVOLUTION) def convolution_aug_fwd( a: Proxy, @@ -3257,32 +3008,6 @@ def pad_transpose_and_push_groups_into_batches(t): return (input_grad, weight_grad, bias_grad) -@register_augmented_forward("torch.nn.functional.cross_entropy") -def cross_entropy_aug_fwd( - input: Proxy, - target: Proxy, - weight=None, - size_average=None, - ignore_index=-100, - reduce=None, - reduction="mean", - label_smoothing=0.0, -) -> VJPDual: - from thunder.torch import cross_entropy - - primal = cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing) - residuals = (input, target, weight, reduction, ignore_index, label_smoothing) - return VJPDual(primal, residuals) - - -@register_backward("torch.nn.functional.cross_entropy") -def cross_entropy_backward(input, target, weight, reduction, ignore_index, label_smoothing, g): - from thunder.torch import cross_entropy_backward - - ginput = cross_entropy_backward(g, input, target, weight, reduction, ignore_index, label_smoothing) - return ginput - - @register_augmented_forward("torch.log_softmax") def log_softmax_aug_fwd(input: TensorProxy, dim: int, *, dtype=None) -> VJPDual: from thunder.torch import log_softmax @@ -3412,64 +3137,6 @@ def softmax_backward(primal, dim, g): return primal * (g - (primal * g).sum(dim, keepdim=True)) -@register_augmented_forward(prims.PrimIDs.MATMUL) -def matmul_aug_fwd(a: TensorProxy, b: TensorProxy) -> VJPDual: - primal = prims.matmul(a, b) - residuals = (a, b) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.MATMUL) -def matmul_backward(a, b, g): - from thunder.torch import sum - - last_dim = (-1,) - first_dim = (-2,) - if a.ndim == 1 and b.ndim == 1: - return g * b, g * a - - if b.ndim == 1: - ga = unsqueeze(g, last_dim) @ unsqueeze(b, last_dim).mT - gb = a.mT @ unsqueeze(g, last_dim) - if g.ndim > 1: - gb = squeeze(gb, last_dim) - gb = sum(gb, tuple(range(gb.ndim - 1))) - return ga, gb - - if a.ndim == 1: - ga = unsqueeze(g, first_dim) @ b.mT - if g.ndim > 1: - ga = sum(ga, tuple(range(ga.ndim - 1))) - gb = unsqueeze(a, first_dim).mT @ unsqueeze(g, first_dim) - return ga, gb - - return g @ b.mT, a.mT @ g - - -@register_augmented_forward(prims.PrimIDs.LINEAR) -def linear_aug_fwd(a: TensorProxy, b: TensorProxy, c: TensorProxy | None) -> VJPDual: - primal = 
prims.linear(a, b, c) - residuals = (a, b, c) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.LINEAR) -def linear_backward(a, b, c, g): - from thunder.torch import matmul, sum - - first_dim = (-2,) - ga = matmul(g.reshape(-1, g.shape[-1]), b).reshape(a.shape) - if a.ndim == 1: - gb = matmul(unsqueeze(g, first_dim).mT, unsqueeze(a, first_dim)) - else: - gb = matmul(g.reshape(-1, g.shape[-1]).mT, a.reshape(-1, a.shape[-1])) - assert list(gb.shape) == list(b.shape), f"linear_backward: {gb.shape} != {b.shape}" - if c is None: - return ga, gb, None - gc = sum(g, tuple(range(g.ndim - 1))) if g.ndim > 1 else g - return ga, gb, gc - - def iter_bound_symbols(bound_symbols): """Iterate over bound symbols, skipping symbols that are not supported by the transforms infrastructure. @@ -3562,31 +3229,6 @@ def decomposed_fn_backward_rule(decomposed_fn, args, kwargs, saved_for_backward, return result -@register_augmented_forward(prims.PrimIDs.CAT) -def cat_aug_fwd(tensors: list[TensorProxy], dim: int) -> VJPDual: - primal = prims.cat(tensors, dim) - residuals = ( - type(tensors), - [t.shape[dim] for t in tensors], - dim, - ) - - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.CAT) -def cat_backward( - tensors_seq_type: type, tensor_dim_lens: list[int], dim: int, g: TensorProxy -) -> tuple[Sequence[TensorProxy]]: - grads = [] - - slice_start = 0 - for dim_len in tensor_dim_lens: - grads.append(slice_in_dim(g, slice_start, slice_start + dim_len, dim=dim)) - slice_start += dim_len - return (tensors_seq_type(grads),) - - @register_augmented_forward("torch.Tensor.contiguous") @register_augmented_forward("torch.contiguous") def contiguous_aug_fwd(x: TensorProxy, /, *, memory_format: torch.memory_format = torch.contiguous_format) -> VJPDual: @@ -3603,25 +3245,11 @@ def contiguous_backward(*residuals_and_grad) -> TensorProxy: return g -@register_augmented_forward(prims.PrimIDs.WHERE) -def where_aug_fwd(condition: TensorProxy, x: TensorProxy, y: TensorProxy) -> VJPDual: - primal = prims.where(condition, x, y) - residuals = (condition,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.WHERE) -def where_backward(condition, g): - return prims.where(condition, g, 0.0), prims.where(condition, 0.0, g) - - -@register_augmented_forward(prims.PrimIDs.RECIPROCAL) def reciprocal_aug_fwd(a: TensorProxy) -> VJPDual: primal = reciprocal(a) return VJPDual(primal, (primal,)) -@register_backward(prims.PrimIDs.RECIPROCAL) def reciprocal_backward(primal, g): return -g * primal * primal @@ -3635,38 +3263,6 @@ def reciprocal_joint_forward_backward_rule(a: TensorProxy) -> TensorProxy: return result -@register_augmented_forward(prims.PrimIDs.SQUEEZE) -def squeeze_aug_fwd(a: TensorProxy, dims: Sequence[int]) -> VJPDual: - primal = squeeze(a, dims) - residuals = (dims,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.SQUEEZE) -def squeeze_backward(dims: Sequence[int], g: TensorProxy) -> TensorProxy: - return unsqueeze(g, dims) - - -@register_augmented_forward(prims.PrimIDs.TAKE) -def take_aug_fwd(x: TensorProxy, index: TensorProxy, dim: int) -> VJPDual: - primal = prims.take(x, index, dim) - residuals = ( - x.shape, - x.device, - x.dtype, - index, - dim, - ) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TAKE) -def take_backward( - shape: Sequence[int], device: Device, dtype: dtypes.dtype, index: TensorProxy, dim: int, g: TensorProxy -): - return prims.index_add(prims.full(shape, fill_value=0, device=device, 
dtype=dtype), index, g, dim) - - @register_augmented_forward("torch.index_put") def index_put_aug_fwd( a: TensorProxy, /, indices: Sequence[TensorProxy], values: TensorProxy, accumulate: bool = False @@ -3705,33 +3301,11 @@ def index_put_backward(indices: Sequence[TensorProxy], values: TensorProxy, accu return clang.index_put(g, indices, ltorch.zeros_like(values), False), g_values -@register_augmented_forward(prims.PrimIDs.TAKE_ALONG_AXIS) -def take_along_axis_aug_fwd(x: TensorProxy, index: TensorProxy, dim: int) -> VJPDual: - primal = prims.take_along_axis(x, index, dim) - residuals = ( - x.shape, - x.device, - x.dtype, - index, - dim, - ) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TAKE_ALONG_AXIS) -def take_along_axis_backward( - shape: Sequence[int], device: Device, dtype: dtypes.dtype, index: TensorProxy, dim: int, g: TensorProxy -): - return prims.scatter_add(prims.full(shape, fill_value=0, device=device, dtype=dtype), index, g, dim) - - -@register_augmented_forward(prims.PrimIDs.UNIFORM) def uniform_aug_fwd(shape, minval, maxval, *, device, dtype): primal = prims.uniform(shape, minval, maxval, device=device, dtype=dtype) return VJPDual(primal, (primal, minval, maxval)) -@register_backward(prims.PrimIDs.UNIFORM) def uniform_backward(primal, minval, maxval, g): # uniform is implemented as (maxval - minval) * uniform(shape, 0, 1) + minval unscaled_primal = (primal - minval) / (maxval - minval) @@ -3768,6 +3342,23 @@ def get_executor_specific_aug_fwd_rule(symbol) -> RuleInfo | None: return None +def is_constant_for_vjp(symbol: prims.Symbol) -> bool: + """Check if a symbol is constant for the VJP transform. + + Args: + symbol (prims.Symbol): Symbol to check. + + Returns: + bool: True if the symbol is constant, False otherwise. + """ + are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args) + return ( + are_all_args_non_differentiable + or symbol.are_all_args_constant + or symbol.sym.id in nondifferentiable_vjp_symbols + ) + + def vjp_symbol_mapper(symbol: prims.Symbol, *args, **kwargs): """Symbol mapper for the VJP transform. @@ -3780,7 +3371,7 @@ def vjp_symbol_mapper(symbol: prims.Symbol, *args, **kwargs): Callable: A function that computes the VJP of the symbol. 
""" # Constant case - if symbol.are_all_args_constant or symbol.sym.id in nondifferentiable_vjp_symbols: + if is_constant_for_vjp(symbol): def vjp_impl_const(symbol, *args, **kwargs): args, kwargs = tree_map(lambda x: x.primal if isinstance(x, VJPDual) else x, (args, kwargs)) @@ -3793,9 +3384,11 @@ def vjp_impl_const(symbol, *args, **kwargs): # Normal case, we have a proxy tangent vjp_impl = augmented_forward_impls.get(symbol.sym.id) - vjp_impl = get_executor_specific_aug_fwd_rule(symbol) or vjp_impl + if _get_gradfn(symbol) is not None: + vjp_impl, backward_fn = make_aug_forward_and_backward(symbol) + if isinstance(vjp_impl, RuleInfo): # We should use this rule only if checker returns True for the current # symbol's arguments @@ -3813,6 +3406,7 @@ def vjp_impl_const(symbol, *args, **kwargs): # It could be a torch.dropout with 0.0 probability, so we skip it if symbol.sym.id == "torch.nn.functional.dropout": return None + print(f"VJP for {symbol} is not implemented") raise NotImplementedError(f"VJP for {symbol.sym.id} is not implemented") def _vjp_impl(*args, **kwargs): @@ -3846,6 +3440,9 @@ def check_bsym_for_vjp(bsym): if bsym.sym.id in backward_impls and bsym.sym.id in augmented_forward_impls: return True + if bsym.sym.id in _grad_fn_map: + return True + # We could not find a VJP for this symbol, so we try to decompose it # into sub-symbols and check if they are supported if len(bsym.subsymbols) > 0 and not bsym.sym.is_prim: @@ -3929,6 +3526,8 @@ def put_grad(v: Variable, val: Any) -> None: elif isinstance(v, Sequence) and val is None: # broadcast None to the right shape safe_map(put_grad, v, [None] * len(v)) + elif isinstance(v, Sequence) and isinstance(val, Sequence): + safe_map_flat(put_grad, v, val) else: # Skip writing to constants pass @@ -3947,7 +3546,7 @@ def put_grad(v: Variable, val: Any) -> None: # Otherwise, we will need to rewrite the pullback functions cotangents = tree_flatten(cotangents)[0] residuals = forward_env[symbol_output[0].name].residuals - if symbol.are_all_args_constant or symbol.sym.id in nondifferentiable_vjp_symbols: + if is_constant_for_vjp(symbol): # We can skip the pullback if all the arguments are constant continue @@ -3968,6 +3567,9 @@ def put_grad(v: Variable, val: Any) -> None: aug_forward = augmented_forward_impls.get(symbol.sym.id) aug_forward = get_executor_specific_aug_fwd_rule(symbol) or aug_forward + if _get_gradfn(symbol) is not None: + aug_forward, backward = make_aug_forward_and_backward(symbol) + if isinstance(aug_forward, RuleInfo): backward = backward_impls[aug_forward.executor, symbol.sym.id] @@ -3984,9 +3586,11 @@ def put_grad(v: Variable, val: Any) -> None: # If the backward returns a dict, we assume that it is a dict of # forward arguments to the corresponding # gradients/cotangents/adjoints/sensitivities. + used_names = set() for i, (k, v) in enumerate(inspect.signature(aug_forward).parameters.items()): if v.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): put_grad(symbol.args[i], result.get(k, None)) + used_names.add(k) # For developer convenience, we allow using the name from the # forward meta in addition to the name from the augmented forward @@ -3995,7 +3599,8 @@ def put_grad(v: Variable, val: Any) -> None: # precedence. 
for i, (k, v) in enumerate(inspect.signature(symbol.sym.meta).parameters.items()): if v.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): - put_grad(symbol.args[i], result.get(k, None)) + if k not in used_names: + put_grad(symbol.args[i], result.get(k, None)) continue if not isinstance(result, Sequence): diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 9dfb95e767..1edebe70e4 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -812,8 +812,8 @@ def safe_map(f, *args): def safe_map_flat(f, *args): def convert_sequences_to_tuple(x): - if isinstance(x, Sequence): - return tuple(x) + if not isinstance(x, str) and isinstance(x, Sequence) and not isinstance(x, Proxy): + return tuple(convert_sequences_to_tuple(y) for y in x) return x args_flat_spec = safe_map(lambda x: tree_flatten(convert_sequences_to_tuple(x)), args) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 5f5a67febd..3c8128ffeb 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -1,4 +1,3 @@ -import copy import inspect from inspect import Parameter, Signature from itertools import chain @@ -13,6 +12,9 @@ from thunder.core.transform_common import dce +_cache = {} + + def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable]: """ Given a bound symbol, return a pair of forward and backward functions @@ -33,13 +35,20 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable A pair of forward and backward functions. """ import thunder - from thunder.core.transforms import _grad_fn_map + from thunder.common import _make_cache_key + from thunder.core.transforms import _get_gradfn, eval_trace - joint_forward_backward = _grad_fn_map.get(bsym.sym.id, None) + joint_forward_backward = _get_gradfn(bsym) utils.check( joint_forward_backward is not None, lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", ) + + key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs)) + cached_result = _cache.get(key, None) if subkey is not None else None + if cached_result is not None: + return cached_result + joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) consumers = utils.consumers(joint_trace) @@ -67,20 +76,31 @@ def find_backward_output(forward_input): bw_outputs_args = tree_map(find_backward_output, joint_trace.args) bw_outputs_kwargs = tree_map(find_backward_output, joint_trace.kwargs) meta_parameters = inspect.signature(bsym.sym.meta).parameters + meta_parameters = { + name: param + for name, param in meta_parameters.items() + if param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY) + } bw_outputs = {name: bw_output for name, bw_output in utils.safe_zip(meta_parameters, bw_outputs_args)} bw_outputs = bw_outputs | bw_outputs_kwargs flat_bw_outputs, _ = tree_flatten(bw_outputs) - backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, bw_inputs) - unpacking_ops = ( + backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0]) + skip = ( prims.PrimIDs.UNPACK_EMPTY_DICT, prims.PrimIDs.UNPACK_KEY, prims.PrimIDs.UNPACK_SEQUENCE, prims.PrimIDs.UNPACK_TRIVIAL, + prims.PrimIDs.GET_GRAD, ) - backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in unpacking_ops] + backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in skip] backward_bsyms.append(prims.python_return.bind(bw_outputs, output=())) + 
forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] + forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] + forward_bsyms = utils.find_producer_symbols(joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies) + backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in forward_bsyms] + # Find required info from forward trace for backward trace backward_producers = utils.producers(backward_bsyms) saved_for_backward = [] @@ -91,7 +111,37 @@ def find_backward_output(forward_input): if arg not in backward_producers and variableify(arg) not in map(variableify, tree_flatten(bw_inputs)[0]): saved_for_backward.append(arg) - backward_params = [Parameter(x.name, Parameter.POSITIONAL_OR_KEYWORD) for x in chain(saved_for_backward, bw_inputs)] + saved_for_backward = list({variableify(arg): arg for arg in saved_for_backward}.values()) + + # Augment forward trace to include saved_for_backward as output + augmented_forward_trace = from_trace(joint_trace) + augmented_forward_trace.bound_symbols = [ + b for b in joint_trace.bound_symbols if b.sym.id not in (PrimIDs.PUT_GRAD, PrimIDs.GET_GRAD) + ] + return_bsym = augmented_forward_trace.bound_symbols[-1] + assert return_bsym.sym.id == PrimIDs.RETURN + augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( + (joint_trace.output, saved_for_backward), output=() + ) + # Remove put/get grad and backward symbols from augmented forward trace + augmented_forward_trace = dce(augmented_forward_trace) + + # Check if any of the bound symbols in the backward trace are also in the + # augmented forward trace + # If so, remove them from the backward trace + same_bsyms = set(augmented_forward_trace.bound_symbols) & set(backward_bsyms) + if same_bsyms: + backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in same_bsyms] + additional_saved = [o for bsym in same_bsyms for o in bsym.flat_proxy_outs] + saved_for_backward += list({variableify(arg): arg for arg in additional_saved}.values()) + augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( + (joint_trace.output, saved_for_backward), output=() + ) + + backward_params = [ + Parameter(getattr(x, "name", f"arg{i}"), Parameter.POSITIONAL_OR_KEYWORD) + for i, x in enumerate(chain(saved_for_backward, bw_inputs)) + ] backward_signature = Signature(backward_params) def backward_fn(): @@ -106,15 +156,15 @@ def backward_fn(): backward_trace.kwargs = {} backward_trace.bound_symbols = backward_bsyms - # Augment forward trace to include saved_for_backward as output - augmented_forward_trace = from_trace(joint_trace) - augmented_forward_trace.bound_symbols = copy.copy(joint_trace.bound_symbols) - return_bsym = augmented_forward_trace.bound_symbols[-1] - assert return_bsym.sym.id == PrimIDs.RETURN - augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( - (joint_trace.output, saved_for_backward), output=() - ) - # Remove put/get grad from augmented forward trace - augmented_forward_trace = dce(augmented_forward_trace) + # Creating new functions instead of using partial to avoid limitations in + # codeutils.get_siginfo + # https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/codeutils.py#L349-L353 + def fw_fn(*args, **kwargs): + return eval_trace(augmented_forward_trace, *args, **kwargs) + + def bw_fn(*args, **kwargs): + return eval_trace(backward_trace, *args, **kwargs) + + _cache[key] = fw_fn, bw_fn - return augmented_forward_trace.python_callable(), 
backward_trace.python_callable() + return fw_fn, bw_fn diff --git a/thunder/executors/apex_entropyex.py b/thunder/executors/apex_entropyex.py index 7d5eec13fb..7d622cf25c 100644 --- a/thunder/executors/apex_entropyex.py +++ b/thunder/executors/apex_entropyex.py @@ -10,7 +10,7 @@ from thunder.core.proxies import TensorProxy from thunder.core.symbol import Symbol from thunder.core.utils import check, same_shape -from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, sum_backward +from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, restore_reduced_dims from thunder.core.transforms import ( register_augmented_forward_with_checker, register_backward, @@ -304,7 +304,7 @@ def _apex_cross_entropy_grad( if reduction == "mean": g = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), g) elif reduction == "sum": - g, _ = sum_backward(max_log_sum_exp.shape, (0,), g) + g = restore_reduced_dims(g, (0,), max_log_sum_exp.shape) # NOTE Apex's xentropy bwd requires the grad computation to be performed in fp32 a_ = a.contiguous() diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index 8c9aa4657c..932fdc934f 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -707,112 +707,3 @@ def _scaled_dot_product_attention_checker( execution_transform=_scaled_dot_product_attention_fused, grad_transform=_scaled_dot_product_attention_grad, ) - - -def scaled_dot_product_attention_aug_fw( - query: TensorProxy, - key: TensorProxy, - value: TensorProxy, - attn_mask: TensorProxy | None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: float | None = None, -): - # NOTE Select fused sdpa using PyTorch eager mode selection behavior - # See https://github.com/Lightning-AI/lightning-thunder/issues/622 - backend = _fused_sdp_choice(query, key, value, attn_mask, dropout_p, is_causal, scale) - - utils.check( - backend != SpdaBackend.ERROR, - lambda: "Unable to find valid backend for scaled_dot_product_attention.", - ) - utils.check( - backend != SpdaBackend.MATH, - lambda: "The fallback to sdpa thunder reference is not implemented.", - exception_type=NotImplementedError, - ) - - tensor_args = (query, key, value) - scalar_args = (dropout_p, is_causal) - input_args = (*tensor_args, attn_mask, *scalar_args, scale) - if backend == SpdaBackend.FLASH_ATTENTION: - # Use flash attention kernel - (primal, *remaining_results, debug_attn_mask) = sdpfa_gradfwd(*tensor_args, *scalar_args, scale=scale) - # NOTE Remaining results contains [logsumexp, *flash_attn_only_residuals, *philox_residuals] - residuals = (*input_args, primal, *remaining_results) - return primal, residuals - elif backend == SpdaBackend.MEMORY_EFFICIENT: - # Use memory efficient kernel, which supports fp32 and attention mask arguments - (primal, logsumexp, *philox_residuals) = sdpea_gradfwd(*tensor_args, attn_mask, *scalar_args, scale=scale) - flash_attn_only_residuals = (None,) * 4 - residuals = (*input_args, primal, logsumexp, *flash_attn_only_residuals, *philox_residuals) - return primal, residuals - - -register_augmented_forward_with_checker( - sdpa_ex, - "torch.nn.functional.scaled_dot_product_attention", - _scaled_dot_product_attention_checker, - scaled_dot_product_attention_aug_fw, -) - - -@register_backward((sdpa_ex, "torch.nn.functional.scaled_dot_product_attention")) -def scaled_dot_product_attention_backward( - query: Proxy, - key: Proxy, - value: Proxy, - attn_mask: None | Proxy, - dropout_p: float, - is_causal: bool, - scale: None | 
float, - out: Proxy, - logsumexp: Proxy, - cum_seq_q: None | Proxy, - cum_seq_k: None | Proxy, - max_q: None | int, - max_k: None | int, - philox_seed: Proxy, - philox_offset: Proxy, - grad_out: Proxy, -): - tensor_args = (query, key, value) - scalar_args = (dropout_p, is_causal) - flash_attention_args = (cum_seq_q, cum_seq_k, max_q, max_k) - philox_args = (philox_seed, philox_offset) - use_flash_attn = all(map(lambda a: a is not None, (cum_seq_q, cum_seq_k, max_q, max_k))) - if use_flash_attn: - ( - grad_query, - grad_key, - grad_val, - ) = sdpfa_bwd( - grad_out, - *tensor_args, - out, - logsumexp, - *flash_attention_args, - *scalar_args, - *philox_args, - scale=scale, - ) - # grad_attn_mask is None since it is not supported by flash_attention kernel - return grad_query, grad_key, grad_val - else: - ( - grad_query, - grad_key, - grad_val, - grad_attn_mask, - ) = sdpea_bwd( - grad_out, - *tensor_args, - attn_mask, - out, - logsumexp, - *philox_args, - *scalar_args, - scale=scale, - ) - return grad_query, grad_key, grad_val, grad_attn_mask diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 9a3a7687e0..30376202a7 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2819,6 +2819,11 @@ def diagonal_sample_generator(op, device, dtype, requires_grad, **kwargs): ltorch.diagonal, sample_input_generator=diagonal_sample_generator, torch_reference=torch.diagonal, + test_directives=( + # thunder.torch.diagonal meta function is not correctly implemented for + # input case ((1, 2, 0, 3), -1, 0, -1) + DecorateInfo(pytest.mark.xfail(strict=True), "test_vjp_correctness"), + ), ) shape_ops.append(diagonal_opinfo) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 659dcb6523..0757be3f61 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -514,7 +514,7 @@ def test_vjp_correctness_sdpa_manual(op, device, dtype, executor, comp): vjp(filtered_op), disable_torch_autograd_support=True, disable_preprocessing=True, - executors_list=executor.executors_list() + [sdpa_ex], + executors_list=[sdpa_ex, *executor.executors_list()], )(filtered_args, (v,)) comp(actual_out, expect_out) @@ -822,6 +822,42 @@ def fun_bw(a, b, g): torch.testing.assert_close(actual_bw, expected_bw) +@instantiate( + dtypes=NOTHING, +) +def test_make_aug_forward_and_backward_var_mean(executor, device, _): + # This test checks that the split of the joint forward/backward function for + # var_mean correctly puts the forward part into the augmented forward + # function and the backward part into the backward function without + # overlapping symbols. 
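+    # (For orientation, a hedged sketch of the contract under test:
+    #      aug_fw, bw = make_aug_forward_and_backward(bsym)
+    #      primal, saved = aug_fw(*bsym.args, **bsym.kwargs)
+    #      grads = bw(*saved, *cotangents)
+    #  with `var_mean` appearing only on the forward side of the split.)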
+ from thunder.core.vjp_utils import make_aug_forward_and_backward + from thunder.core.prims import var_mean + + def fun(a): + return var_mean(a, (0,), correction=1) + + x = torch.tensor((2, 2), device=device, dtype=torch.float32) + + trace = thunder.trace()(fun, x) + var_mean_bsym = trace.bound_symbols[-2] + assert var_mean_bsym.sym.name == "var_mean" + + aug_fw, bw = make_aug_forward_and_backward(var_mean_bsym) + aug_fw = executor.make_callable(aug_fw) + out, saved = aug_fw(x, (0,), correction=1) + bw = executor.make_callable(bw) + _ = bw(*saved, *out) + bw_trace = thunder.last_traces(bw)[0] + assert "var_mean" not in (s.sym.name for s in bw_trace.bound_symbols) + + +def test_no_duplicate_backward_registered(): + from thunder.core.transforms import backward_impls, _grad_fn_map + + same_keys = set(_grad_fn_map.keys()).intersection(set(backward_impls.keys())) + assert not same_keys, f"Duplicate keys: {same_keys}" + + @instantiate( dtypes=NOTHING, ) From 23139849f1e0a454d764c9a7016fc303d7db02f8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 14 Mar 2024 17:00:12 +0100 Subject: [PATCH 10/44] docker: build images with explicit `CUDNN_FRONTEND` (PR2438) --- .azure/docker-build.yml | 19 ++++++++++++------- .azure/gpu-tests.yml | 8 ++++---- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/.azure/docker-build.yml b/.azure/docker-build.yml index 436a3a3a8b..a8e56f2952 100644 --- a/.azure/docker-build.yml +++ b/.azure/docker-build.yml @@ -40,13 +40,15 @@ jobs: #maxParallel: "3" matrix: # CUDA 12.1 - 'cuda 12.1 | torch 2.2': - {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0'} - 'cuda 12.1 | torch 2.3 /nightly': - {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source'} + 'cuda 12.1 | torch 2.2 | cudnn FE v1.1': # todo: drop updating this image when CI transition to newer FE version + {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0', CUDNN_FRONTEND: "1.1.0"} + 'cuda 12.1 | torch 2.2 | cudnn FE v1.2': + {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0', CUDNN_FRONTEND: "1.2.0"} + 'cuda 12.1 | torch 2.3 /nightly | cudnn FE v1.1': # todo: drop updating this image when CI transition to newer FE version + {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source', CUDNN_FRONTEND: "1.1.0"} + 'cuda 12.1 | torch 2.3 /nightly | cudnn FE v1.2': + {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source', CUDNN_FRONTEND: "1.2.0"} #'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found - # how long to run the job before automatically cancelling - timeoutInMinutes: "95" # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" variables: @@ -54,7 +56,7 @@ jobs: PYTHON_VERSION: '3.10' imageRepository: 'pytorchlightning/lightning-thunder' dockerfilePath: 'dockers/ubuntu-cuda/Dockerfile' - imageTag: 'ubuntu$(UBUNTU_VERSION)-cuda$(CUDA_VERSION)-py$(PYTHON_VERSION)-pt_${TORCH_VERSION/v/}' + imageTag: 'ubuntu$(UBUNTU_VERSION)-cuda$(CUDA_VERSION)-cudnn-fe$(CUDNN_FRONTEND)-py$(PYTHON_VERSION)-pt_${TORCH_VERSION/v/}' pool: 'lit-rtx-3090' workspace: clean: all @@ -74,11 +76,13 @@ jobs: -f $(dockerfilePath) \ --build-arg UBUNTU_VERSION="$(UBUNTU_VERSION)" \ --build-arg CUDA_VERSION="$(CUDA_VERSION)" \ + --build-arg CUDNN_FRONTEND_CHECKOUT="v$(CUDNN_FRONTEND)" \ --build-arg PYTHON_VERSION="$(PYTHON_VERSION)" \ --build-arg TORCH_VERSION="$(TORCH_VERSION)" \ --build-arg 
TRITON_VERSION="$(TRITON_VERSION)" \ --build-arg TORCH_INSTALL="$(TORCH_INSTALL)" \ . --no-cache + timeoutInMinutes: "95" displayName: 'Build base image' - bash: | @@ -98,6 +102,7 @@ jobs: echo $(DOCKERHUB_PAT) | docker login --username $(DOCKERHUB_USER) --password-stdin docker push $(imageRepository):$(imageTag) condition: ne(variables['Build.Reason'], 'PullRequest') + timeoutInMinutes: "35" displayName: 'Push base image' #- task: Docker@1 diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index 387e47fc9c..8055eb2c6b 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -17,17 +17,17 @@ jobs: matrix: # CUDA 12.1 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.2 | regular': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_2.2.1' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.1.0-py3.10-pt_2.2.1' CUDA_VERSION_MM: '121' 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.2 | distributed': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_2.2.1' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.1.0-py3.10-pt_2.2.1' CUDA_VERSION_MM: '121' testing: 'distributed' 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | regular': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_main' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.0-py3.10-pt_main' CUDA_VERSION_MM: '121' 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | distributed': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_main' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.0-py3.10-pt_main' CUDA_VERSION_MM: '121' testing: 'distributed' # how long to run the job before automatically cancelling From e7f42298442301bc3c982de52a11bff9a707650c Mon Sep 17 00:00:00 2001 From: nikitaved Date: Fri, 15 Mar 2024 07:51:54 +0100 Subject: [PATCH 11/44] Gen JIT: sharp edge when calling `torch.*` which are not yet part of `ltorch` (PR2451) --- thunder/core/interpreter.py | 5 ++++- thunder/core/jit_ext.py | 33 +++++++++++-------------------- thunder/tests/test_jit_general.py | 20 +++++++++++++++++++ 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 43513ec489..904d371b88 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -6044,7 +6044,10 @@ def _impl(fn, *args, **kwargs): return _interpret_call(unbound_fn, slf, *args, **kwargs) # (2) Handles lookasides - lookaside_fn: None | Callable = compilectx.lookaside(fn, *args, **kwargs) + lookaside_fn: INTERPRETER_SIGNALS | None | Callable = compilectx.lookaside(fn, *args, **kwargs) + if lookaside_fn is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + # Happens with sharp edges, for example + return lookaside_fn if lookaside_fn: runtimectx.record_lookaside(lookaside_fn) res = lookaside_fn(*args, **kwargs) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 3a6f17a2f4..b3e2acfb23 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -884,14 +884,20 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable: lookaside = default_lookaside(fn, *args, **kwargs) if lookaside is None: - if is_opaque(fn) and fn not in _safe_functions: + + def is_from_torch(fn): + return hasattr(fn, "__module__") and fn.__module__ and fn.__module__.startswith("torch") + + if is_opaque(fn) 
and is_from_torch(fn):
+            # Torch functions have __name__ defined
+            fn_name = f"{fn.__module__}.{fn.__name__}"
+
+            # For now, only torch-like opaque functions are sharp edges
             return _general_jit_sharp_edge(
-                f"Trying to call opaque function {extract_callable_name(fn)}, but it's unsupported. Please file an issue requesting supporting.",
+                f"Trying to call function {fn_name}, but it's unsupported. Please file an issue requesting support.",
                 None,
             )
 
-        return None
-
     return lookaside
 
 
@@ -975,24 +981,7 @@ def proxy_name_replacer(arg: Any):
 def _general_jit_global_callback(orig_value: Any, name: str) -> Any:
     _maybe_update_proxy_name(orig_value, name)
 
-    # Allows loading the torch module
-    value = orig_value
-    if (
-        value is torch
-        or (value is torch.nn.modules.module._global_backward_pre_hooks)
-        or (value is torch.nn.modules.module._global_backward_hooks)
-        or (value is torch.nn.modules.module._global_forward_hooks)
-        or (value is torch.nn.modules.module._global_forward_pre_hooks)
-        or (value is torch.nn.functional)
-        or (value is thunder.core.proxies.get_langctx)
-        or (value is prop_lookaside_helper)
-    ):
-        return value
-
-    return _general_jit_sharp_edge(
-        f"Tried to loading global {name}. Global support is limited.",
-        value,
-    )
+    return orig_value
 
 
 _safe_provenance_inst = {
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 2c591d3680..d4c6dc92d0 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -23,6 +23,7 @@
 import thunder.core.prims as prims
 from thunder import pytorch_executor, nvfuser_executor
 from thunder.executors.sdpaex import sdpa_ex
+from thunder.core.jit_ext import JITSharpEdgeError
 
 
 #
@@ -49,6 +50,25 @@ def skipif_not_pytorch_2_1(f):
     )(f)
 
 
+def test_jitting_through_opaque_torch_symbols_sharp_edge():
+    def no_sharp_edge(x):
+        # randn_like is in ltorch
+        return torch.randn_like(x)
+
+    def sharp_edge(x):
+        # rand_like is not yet in ltorch
+        return torch.rand_like(x)
+
+    x = torch.rand(1)
+
+    jno_sharp_edge = thunder.jit(no_sharp_edge, sharp_edges="error")
+    jno_sharp_edge(x)
+
+    jsharp_edge = thunder.jit(sharp_edge, sharp_edges="error")
+    with pytest.raises(JITSharpEdgeError):
+        jsharp_edge(x)
+
+
 def test_binary_add_tensors():
     def foo(a, b):
         return a + b
 

From 09bf63d76afb32eb249c07ad3b2f94ee7bbca58e Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Fri, 15 Mar 2024 16:18:39 +0100
Subject: [PATCH 12/44] autograd: fix backward support for executors which
 register per executor backward impls (PR2454)

---
 thunder/core/transforms.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index e1c0d8dc70..b05351ca14 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1237,6 +1237,11 @@ def _embedding_prim_grad(
 
 
 def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable:
+    # If an executor-specific `aug_fwd_rule` exists, we will use that,
+    # so we return `None` here.
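+    # (Illustrative, based on registrations used elsewhere in this series:
+    #  `register_augmented_forward_with_checker(ex, "op_id", checker, aug_fw)`
+    #  plus `@register_backward((ex, "op_id"))` populates the executor-specific
+    #  rule path, while `ex.register_implementation(sym, grad_transform=...)`
+    #  populates the grad-transform path this function returns; the early
+    #  return below keeps the two registration styles from shadowing each
+    #  other.)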
+ if get_executor_specific_aug_fwd_rule(bsym): + return None + cd = get_compile_data() executors_list = cd.executors_list if cd is not None else executors_list # Checks if the executor which has priority for this operation has a specific grad transform for it @@ -3317,11 +3322,11 @@ def uniform_backward(primal, minval, maxval, g): nondifferentiable_vjp_symbols = (prims.PrimIDs.BITWISE_AND, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL) -def get_executor_specific_aug_fwd_rule(symbol) -> RuleInfo | None: +def get_executor_specific_aug_fwd_rule(symbol: BoundSymbol) -> RuleInfo | None: """Get executor specific augmented forward rule. Args: - symbol (prims.Symbol): Symbol to get the rule for. + symbol (BoundSymbol): BoundSymbol to get the rule for. Returns: RuleInfo: Rule info for the symbol. From ca0b478561ee4e22c6fb8a7dd5599c20506bd423 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 15 Mar 2024 18:34:35 +0100 Subject: [PATCH 13/44] fix split_fw_bw call to not use python_callable (PR2460) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ivan Yashchuk Co-authored-by: Carlos Mocholí --- notebooks/dev_tutorials/fsdp_tutorial.ipynb | 280 +++++++++++--------- thunder/__init__.py | 4 +- thunder/executors/torch_autograd.py | 14 +- thunder/tests/distributed/test_ddp.py | 7 +- thunder/tests/test_extend.py | 17 +- 5 files changed, 184 insertions(+), 138 deletions(-) diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb index c5c46d960a..170eb68ee8 100644 --- a/notebooks/dev_tutorials/fsdp_tutorial.ipynb +++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb @@ -188,19 +188,21 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n", - " # input: "cuda:0 f32[64, 64]" \n", - " # t_0_weight: "cuda:0 f32[64, 64]" \n", - " # t_2_weight: "cuda:0 f32[64, 64]" \n", - " t0 = torch.nn.functional.linear(input, t_0_weight, None) # t0: "cuda:0 f32[64, 64]"\n", - " # t0 = ltorch.linear(input, t_0_weight, None) # t0: "cuda:0 f32[64, 64]"\n", - " # t0 = prims.linear(input, t_0_weight, None) # t0: "cuda:0 f32[64, 64]"\n", - " [t1] = nvFusion0(t0)\n", - " # t1 = prims.tanh(t0) # t1: "cuda:0 f32[64, 64]"\n", - " t2 = torch.nn.functional.linear(t1, t_2_weight, None) # t2: "cuda:0 f32[64, 64]"\n", - " # t2 = ltorch.linear(t1, t_2_weight, None) # t2: "cuda:0 f32[64, 64]"\n", - " # t2 = prims.linear(t1, t_2_weight, None) # t2: "cuda:0 f32[64, 64]"\n", - " return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())\n", + "def augmented_forward_fn(*args):\n", + " # args: "Collection" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = torch.nn.functional.linear(t0, t1, None) # t3: "cuda:0 f32[64, 64]"\n", + " # t3 = ltorch.linear(t0, t1, None) # t3: "cuda:0 f32[64, 64]"\n", + " # t3 = prims.linear(t0, t1, None) # t3: "cuda:0 f32[64, 64]"\n", + " [t4] = nvFusion0(t3)\n", + " # t4 = prims.tanh(t3) # t4: "cuda:0 f32[64, 64]"\n", + " t5 = torch.nn.functional.linear(t4, t2, None) # t5: "cuda:0 f32[64, 64]"\n", + " # t5 = ltorch.linear(t4, t2, None) # t5: "cuda:0 f32[64, 64]"\n", + " # t5 = prims.linear(t4, t2, None) # t5: "cuda:0 f32[64, 64]"\n", + " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())\n", "\n" ], "text/latex": [ @@ -212,19 +214,21 @@ "\n", "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", 
"\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} input: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", - " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}0\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", - " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}2\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", - " \\PY{n}{t0} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t0 = ltorch.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t0 = prims.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{p}{[}\\PY{n}{t1}\\PY{p}{]} \\PY{o}{=} \\PY{n}{nvFusion0}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{)}\n", - " \\PY{c+c1}{\\PYZsh{} t1 = prims.tanh(t0) \\PYZsh{} t1: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t2 = ltorch.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t2 = prims.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t2}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t2}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = prims.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{p}{[}\\PY{n}{t4}\\PY{p}{]} \\PY{o}{=} \\PY{n}{nvFusion0}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{)}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = prims.tanh(t3) \\PYZsh{} t4: 
\\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = prims.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t5}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t5}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t4}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ @@ -235,19 +239,21 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n", - " # input: \"cuda:0 f32[64, 64]\" \n", - " # t_0_weight: \"cuda:0 f32[64, 64]\" \n", - " # t_2_weight: \"cuda:0 f32[64, 64]\" \n", - " t0 = torch.nn.functional.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", - " # t0 = ltorch.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", - " # t0 = prims.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", - " [t1] = nvFusion0(t0)\n", - " # t1 = prims.tanh(t0) # t1: \"cuda:0 f32[64, 64]\"\n", - " t2 = torch.nn.functional.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", - " # t2 = ltorch.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", - " # t2 = prims.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", - " return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())" + "def augmented_forward_fn(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = torch.nn.functional.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = prims.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " [t4] = nvFusion0(t3)\n", + " # t4 = prims.tanh(t3) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = torch.nn.functional.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = prims.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())" ] }, "execution_count": 4, @@ -481,18 +487,20 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n", - " # input: "cuda:0 f32[64, 64]" \n", - " # t_0_weight: "cuda:0 f32[64, 64]" \n", - " # t_2_weight: "cuda:0 f32[64, 64]" \n", - " t0 = ltorch.linear(input, t_0_weight, None) # t0: "cuda:0 f32[64, 64]"\n", - " # t0 = ltorch.linear(input, t_0_weight, None) # t0: "cuda:0 f32[64, 64]"\n", - " # t0 = prims.linear(input, t_0_weight, None) # t0: "cuda:0 f32[64, 64]"\n", - " t1 = prims.tanh(t0) # t1: "cuda:0 f32[64, 64]"\n", - " t2 = ltorch.linear(t1, t_2_weight, 
None) # t2: "cuda:0 f32[64, 64]"\n", - " # t2 = ltorch.linear(t1, t_2_weight, None) # t2: "cuda:0 f32[64, 64]"\n", - " # t2 = prims.linear(t1, t_2_weight, None) # t2: "cuda:0 f32[64, 64]"\n", - " return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())\n", + "def augmented_forward_fn(*args):\n", + " # args: "Collection" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = ltorch.linear(t0, t1, None) # t3: "cuda:0 f32[64, 64]"\n", + " # t3 = ltorch.linear(t0, t1, None) # t3: "cuda:0 f32[64, 64]"\n", + " # t3 = prims.linear(t0, t1, None) # t3: "cuda:0 f32[64, 64]"\n", + " t4 = prims.tanh(t3) # t4: "cuda:0 f32[64, 64]"\n", + " t5 = ltorch.linear(t4, t2, None) # t5: "cuda:0 f32[64, 64]"\n", + " # t5 = ltorch.linear(t4, t2, None) # t5: "cuda:0 f32[64, 64]"\n", + " # t5 = prims.linear(t4, t2, None) # t5: "cuda:0 f32[64, 64]"\n", + " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())\n", "\n" ], "text/latex": [ @@ -507,18 +515,20 @@ "\n", "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} input: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", - " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}0\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", - " \\PY{c+c1}{\\PYZsh{} t\\PYZus{}2\\PYZus{}weight: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", - " \\PY{n}{t0} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t0 = ltorch.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t0 = prims.linear(input, t\\PYZus{}0\\PYZus{}weight, None) \\PYZsh{} t0: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t1} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t1: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t2 = ltorch.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t2 = prims.linear(t1, t\\PYZus{}2\\PYZus{}weight, None) \\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t2}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t\\PYZus{}0\\PYZus{}weight}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t2}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n+nb}{input}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t\\PYZus{}2\\PYZus{}weight}\\PY{p}{)}\\PY{p}{,} 
\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = prims.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = prims.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t5}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t5}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t4}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ @@ -532,18 +542,20 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def augmented_forward_fn(input, t_0_weight, t_2_weight):\n", - " # input: \"cuda:0 f32[64, 64]\" \n", - " # t_0_weight: \"cuda:0 f32[64, 64]\" \n", - " # t_2_weight: \"cuda:0 f32[64, 64]\" \n", - " t0 = ltorch.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", - " # t0 = ltorch.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", - " # t0 = prims.linear(input, t_0_weight, None) # t0: \"cuda:0 f32[64, 64]\"\n", - " t1 = prims.tanh(t0) # t1: \"cuda:0 f32[64, 64]\"\n", - " t2 = ltorch.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", - " # t2 = ltorch.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", - " # t2 = prims.linear(t1, t_2_weight, None) # t2: \"cuda:0 f32[64, 64]\"\n", - " return {'output': t2, 'flat_args': [input, t_0_weight, t_2_weight], 'flat_output': (t2,)}, ((input, t1, t_2_weight), ())" + "def augmented_forward_fn(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = prims.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = prims.tanh(t3) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = 
prims.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())" ] }, "execution_count": 10, @@ -553,10 +565,10 @@ ], "source": [ "### DON'T TRY THIS AT HOME\n", - "computation_trace.bound_symbols[3].sym = cache_rec.computation_traces[0].bound_symbols[3].subsymbols[0].sym\n", - "if cache_rec.computation_traces[0].bound_symbols[4].subsymbols:\n", - " computation_trace.bound_symbols[4] = cache_rec.computation_traces[0].bound_symbols[4].subsymbols[0]\n", - "computation_trace.bound_symbols[5].sym = cache_rec.computation_traces[0].bound_symbols[5].subsymbols[0].sym\n", + "computation_trace.bound_symbols[2].sym = cache_rec.computation_traces[0].bound_symbols[2].subsymbols[0].sym\n", + "if cache_rec.computation_traces[0].bound_symbols[3].subsymbols:\n", + " computation_trace.bound_symbols[3] = cache_rec.computation_traces[0].bound_symbols[3].subsymbols[0]\n", + "computation_trace.bound_symbols[4].sym = cache_rec.computation_traces[0].bound_symbols[4].subsymbols[0].sym\n", "\n", "wrap_as_highlighted_code(computation_trace)" ] @@ -701,7 +713,7 @@ " t5 = prims.tanh(t4) # t5: "cuda:0 f32[64, 64]"\n", " t6 = ltorch.linear(t5, t3, None) # t6: "cuda:0 f32[64, 64]"\n", " # t6 = prims.linear(t5, t3, None) # t6: "cuda:0 f32[64, 64]"\n", - " return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t5, t3), ()))\n", + " return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t3, t5), ()))\n", "\n" ], "text/latex": [ @@ -729,7 +741,7 @@ " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t6 = prims.linear(t5, t3, None) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t6}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t6}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t6}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t6}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ @@ -756,7 +768,7 @@ " t5 = prims.tanh(t4) # t5: \"cuda:0 f32[64, 64]\"\n", " t6 = ltorch.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", " # t6 = prims.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", - " 
return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t5, t3), ()))" + " return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t3, t5), ()))" ] }, "execution_count": 12, @@ -862,7 +874,7 @@ ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n", ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n", ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n", - ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */
# Constructed by Dead Code Elimination (took 1 milliseconds)\n",
        "import thunder\n",
        "import thunder.core.devices as devices\n",
        "import thunder.core.dtypes as dtypes\n",
@@ -897,7 +909,7 @@
        "  t17 = prims.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
        "  t18 = prims.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
        "  t19 = prims.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
-       "  t20 = prims.add(t5, t10)  # t20: "cuda:0 f32[64, 64]"\n",
+       "  t20 = prims.add(t5, t9)  # t20: "cuda:0 f32[64, 64]"\n",
        "  t21 = ltorch.reshape(t18, -1, 64)  # t21: "cuda:0 f32[64, 64]"\n",
        "    # t21 = prims.reshape(t18, (64, 64))  # t21: "cuda:0 f32[64, 64]"\n",
        "  t22 = ltorch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
@@ -909,11 +921,12 @@
        "    # t25 = prims.reshape(t16, (64, 64))  # t25: "cuda:0 f32[64, 64]"\n",
        "  t26 = ltorch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
        "    # t26 = prims.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
-       "  t27 = prims.add(t9, t22)  # t27: "cuda:0 f32[64, 64]"\n",
+       "  t27 = prims.add(t10, t22)  # t27: "cuda:0 f32[64, 64]"\n",
        "  t28 = prims.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
        "  t29 = ltorch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
        "    # t29 = prims.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
-       "  t30 = ltorch.sub(1.0, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
+       "  t30 = ltorch.sub(1, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
+       "    # _ = prims.convert_element_type(1, float)\n",
        "    # t30 = prims.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
        "  t31 = ltorch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
        "    # t31 = prims.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
@@ -940,12 +953,12 @@
        "    # t43 = prims.div(t39, 2.0)  # t43: "cuda:0 f32[64, 64]"\n",
        "  p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p44: "FUTURE cuda:0 f32[32, 64]"\n",
        "  t45 = thunder.distributed.prims.wait(p44)  # t45: "cuda:0 f32[32, 64]"\n",
-       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))\n",
+       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n",
        "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 1 milliseconds)}\n", "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{devices} \\PY{k}{as} \\PY{n+nn}{devices}\n", "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{dtypes} \\PY{k}{as} \\PY{n+nn}{dtypes}\n", @@ -980,7 +993,7 @@ " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t10}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", @@ -992,11 +1005,12 @@ " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, (64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} 
\\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n", " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", @@ -1023,11 +1037,11 @@ " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "# Constructed by Dead Code Elimination (took 1 milliseconds)\n", "import thunder\n", "import thunder.core.devices as devices\n", "import thunder.core.dtypes as dtypes\n", @@ -1062,7 +1076,7 @@ " t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", " t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", " t19 = prims.add(t3, t8) # t19: \"cuda:0 
f32[64, 64]\"\n", - " t20 = prims.add(t5, t10) # t20: \"cuda:0 f32[64, 64]\"\n", + " t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", " t21 = ltorch.reshape(t18, -1, 64) # t21: \"cuda:0 f32[64, 64]\"\n", " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", " t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", @@ -1074,11 +1088,12 @@ " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", " t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", - " t27 = prims.add(t9, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", " t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", " t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", - " t30 = ltorch.sub(1.0, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(1, float)\n", " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", " t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", @@ -1105,7 +1120,7 @@ " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n", " p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n", " t45 = thunder.distributed.prims.wait(p44) # t45: \"cuda:0 f32[32, 64]\"\n", - " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))" + " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))" ] }, "execution_count": 13, @@ -1279,10 +1294,10 @@ " # t19 = ltorch.add(t3, t8, alpha=None) # t19: "cuda:0 f32[64, 64]"\n", " # t19 = prims.add(t3, t8) # t19: "cuda:0 f32[64, 64]"\n", " del t3, t8\n", - " t20 = torch.add(t5, t10) # t20: "cuda:0 f32[64, 64]"\n", - " # t20 = ltorch.add(t5, t10, alpha=None) # t20: "cuda:0 f32[64, 64]"\n", - " # t20 = prims.add(t5, t10) # t20: "cuda:0 f32[64, 64]"\n", - " del t5, t10\n", + " t20 = torch.add(t5, t9) # t20: "cuda:0 f32[64, 64]"\n", + " # t20 = ltorch.add(t5, t9, alpha=None) # t20: "cuda:0 f32[64, 64]"\n", + " # t20 = prims.add(t5, t9) # t20: "cuda:0 f32[64, 64]"\n", + " del t5, t9\n", " t21 = torch.reshape(t18, (-1, 64)) # t21: "cuda:0 f32[64, 64]"\n", " # t21 = ltorch.reshape(t18, (-1, 64)) # t21: "cuda:0 f32[64, 64]"\n", " # t21 = prims.reshape(t18, (64, 64)) # t21: "cuda:0 f32[64, 64]"\n", @@ -1305,10 +1320,10 @@ " # t26 = ltorch.matmul(t24, t25) # t26: "cuda:0 f32[64, 64]"\n", " # t26 = prims.matmul(t24, t25) # t26: "cuda:0 f32[64, 64]"\n", " del t24, t25\n", - " t27 = torch.add(t9, t22) # t27: "cuda:0 f32[64, 64]"\n", - " # t27 = ltorch.add(t9, t22, alpha=None) # t27: "cuda:0 f32[64, 64]"\n", - " # t27 = prims.add(t9, t22) # t27: "cuda:0 f32[64, 64]"\n", - " del t9, t22\n", + " t27 = torch.add(t10, t22) # t27: "cuda:0 f32[64, 64]"\n", + " # t27 = ltorch.add(t10, t22, alpha=None) # t27: "cuda:0 f32[64, 64]"\n", + " # t27 = prims.add(t10, t22) # t27: "cuda:0 f32[64, 64]"\n", + " del t10, t22\n", " t28 = torch.add(t20, t26) # t28: "cuda:0 f32[64, 64]"\n", " # t28 = ltorch.add(t20, t26, alpha=None) # t28: "cuda:0 f32[64, 64]"\n", " # t28 = prims.add(t20, t26) # t28: "cuda:0 f32[64, 64]"\n", @@ 
-1316,8 +1331,9 @@ " t29 = torch.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n", " # t29 = ltorch.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n", " # t29 = prims.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n", - " t30 = torch.sub(1.0, t29) # t30: "cuda:0 f32[64, 64]"\n", - " # t30 = ltorch.sub(1.0, t29, alpha=None) # t30: "cuda:0 f32[64, 64]"\n", + " t30 = torch.sub(1, t29) # t30: "cuda:0 f32[64, 64]"\n", + " # t30 = ltorch.sub(1, t29, alpha=None) # t30: "cuda:0 f32[64, 64]"\n", + " # _ = prims.convert_element_type(1, float)\n", " # t30 = prims.sub(1.0, t29) # t30: "cuda:0 f32[64, 64]"\n", " del t29\n", " t31 = torch.mul(t27, t30) # t31: "cuda:0 f32[64, 64]"\n", @@ -1372,7 +1388,7 @@ " del t43\n", " t45 = torch_wait_prim_impl(p44) # t45: "cuda:0 f32[32, 64]"\n", " del p44\n", - " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))\n", + " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n", "
\n" ], "text/latex": [ @@ -1441,10 +1457,10 @@ " \\PY{c+c1}{\\PYZsh{} t19 = ltorch.add(t3, t8, alpha=None) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t19 = prims.add(t3, t8) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{k}{del} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\n", - " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t10}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = ltorch.add(t5, t10, alpha=None) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = prims.add(t5, t10) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t10}\n", + " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t20 = ltorch.add(t5, t9, alpha=None) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t20 = prims.add(t5, t9) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\n", " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t21 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", @@ -1467,10 +1483,10 @@ " \\PY{c+c1}{\\PYZsh{} t26 = ltorch.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{k}{del} \\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\n", - " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t27 = ltorch.add(t9, t22, alpha=None) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t27 = prims.add(t9, t22) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t9}\\PY{p}{,} \\PY{n}{t22}\n", + " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t27 = ltorch.add(t10, t22, alpha=None) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t27 = prims.add(t10, t22) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\n", " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t28 = ltorch.add(t20, t26, alpha=None) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t28 = prims.add(t20, t26) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", @@ -1478,8 +1494,9 @@ " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} 
\\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t29 = ltorch.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t30 = ltorch.sub(1.0, t29, alpha=None) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t30 = ltorch.sub(1, t29, alpha=None) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n", " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", " \\PY{k}{del} \\PY{n}{t29}\n", " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", @@ -1534,7 +1551,7 @@ " \\PY{k}{del} \\PY{n}{t43}\n", " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", " \\PY{k}{del} \\PY{n}{p44}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ @@ -1602,10 +1619,10 @@ " # t19 = ltorch.add(t3, t8, alpha=None) # t19: \"cuda:0 f32[64, 64]\"\n", " # t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", " del t3, t8\n", - " t20 = torch.add(t5, t10) # t20: \"cuda:0 f32[64, 64]\"\n", - " # t20 = ltorch.add(t5, t10, alpha=None) # t20: \"cuda:0 f32[64, 64]\"\n", - " # t20 = prims.add(t5, t10) # t20: \"cuda:0 f32[64, 64]\"\n", - " del t5, t10\n", + " t20 = torch.add(t5, t9) # t20: \"cuda:0 
f32[64, 64]\"\n", + " # t20 = ltorch.add(t5, t9, alpha=None) # t20: \"cuda:0 f32[64, 64]\"\n", + " # t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", + " del t5, t9\n", " t21 = torch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", " # t21 = ltorch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", @@ -1628,10 +1645,10 @@ " # t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", " del t24, t25\n", - " t27 = torch.add(t9, t22) # t27: \"cuda:0 f32[64, 64]\"\n", - " # t27 = ltorch.add(t9, t22, alpha=None) # t27: \"cuda:0 f32[64, 64]\"\n", - " # t27 = prims.add(t9, t22) # t27: \"cuda:0 f32[64, 64]\"\n", - " del t9, t22\n", + " t27 = torch.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " # t27 = ltorch.add(t10, t22, alpha=None) # t27: \"cuda:0 f32[64, 64]\"\n", + " # t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " del t10, t22\n", " t28 = torch.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", " # t28 = ltorch.add(t20, t26, alpha=None) # t28: \"cuda:0 f32[64, 64]\"\n", " # t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", @@ -1639,8 +1656,9 @@ " t29 = torch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", " # t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", - " t30 = torch.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", - " # t30 = ltorch.sub(1.0, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " t30 = torch.sub(1, t29) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(1, float)\n", " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", " del t29\n", " t31 = torch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", @@ -1695,7 +1713,7 @@ " del t43\n", " t45 = torch_wait_prim_impl(p44) # t45: \"cuda:0 f32[32, 64]\"\n", " del p44\n", - " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t16, t14), ())), (t38, t45, t42))" + " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))" ] }, "execution_count": 14, @@ -1998,7 +2016,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/thunder/__init__.py b/thunder/__init__.py index ae252b9836..e857ec23bd 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -548,9 +548,7 @@ def get_computation_and_inputs(*args, **kwargs): # thunder_backward may recursively call compile and wraps the result in a # torch.autograd.Function to support embedding of Thunder-compiled # functions in torch's Autograd - computation_trc, backward_trc = split_forward_backward( - computation_trc.python_callable(), cd, cs, *inps - ) + computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) computation_traces.append(computation_trc) cs.last_computation_transformation_start = time.time_ns() diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 2eb3f1f56b..c58551875a 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -161,7 +161,7 @@ def wrapper(*args, **kwargs): return decorator -def split_forward_backward(func, compile_data, compile_stats, /, *args, **kwargs): +def 
split_forward_backward(computation_trc, compile_data, compile_stats, /, *args, **kwargs):
     from thunder import trace
     from thunder.executors.passes import transform_for_execution
     from thunder.executors.passes import del_last_used
@@ -170,6 +170,17 @@ def split_forward_backward(func, compile_data, compile_stats, /, *args, **kwargs
     from thunder.cudagraphs import CUDAGraphExecutor
     from thunder.distributed.utils import sort_waits, sort_data_parallel_syncs, sort_waits_for_zero3
     from thunder.distributed.transforms import FSDPCommBucketing
+    from thunder.core.transforms import eval_trace
+
+    # TODO: the trace->func->trace roundtrip could likely be simplified (and look nicer)
+    # we cannot use python_callable() here, see the old repo's #2458
+    if not isinstance(computation_trc, TraceCtx):
+        # for the legacy codepath
+        func = computation_trc
+    else:
+
+        def func(*args):
+            return eval_trace(computation_trc, *args)
 
     utils.check(compile_data is not None, lambda: "`compile_data` is required")
 
@@ -178,6 +189,7 @@ def make_trace(func):
         trace(compile_data=compile_data, inline_trace=False, insert_ddp_syncs=not compile_data.using_jit), func
     )
 
+    computation_trc.kwargs = {}
     # NOTE: This function is rather slow, so it's intended to be used
     # behind a cache.
     ba = signature(func).bind(*args, **kwargs)
diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index 4ae624b6db..62219d73f4 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -488,8 +488,11 @@ def test_rematerialize_all_gather(self):
         result_fwd_trc, result_bwd_trc = rematerialize_all_gather(fwd_trc, bwd_trc)
 
         # check the return statement in forward trace is updated
-        sharded_param_names = ("t_net1_weight", "t_net2_weight")
-        unshard_param_names = ("t5", "t16")
+        # TODO: this is not stable w.r.t. details of the processing; the sharded names correspond to ("t_net1_weight", "t_net2_weight")
+        # in the original trace and are inputs to all_gather, and the unsharded names are the outputs of the corresponding wait.
+        # If you fix this to be dynamically discerned, you'll be my hero.
+ sharded_param_names = ("t3", "t4") + unshard_param_names = ("t10", "t21") result_saved_for_bwd = [x.name for x in fwd_trc.bound_symbols[-1].args[1][0]] self.assertTrue(all(t not in sharded_param_names for t in result_saved_for_bwd)) self.assertTrue(all(t in result_saved_for_bwd for t in unshard_param_names)) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index 00dfe71fd7..d678bd531b 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -158,7 +158,13 @@ def fn(a, b): def myadd_trafo(a, b): return myadd2(a, b) - myex.register_implementation(myadd1, execution_transform=myadd_trafo) + def myadd_grad_trafo(a, b): + res = myadd2(a, b) + grad_res = get_grad(res) + put_grads((a, b), (grad_res, grad_res)) + return res + + myex.register_implementation(myadd1, execution_transform=myadd_trafo, grad_transform=myadd_grad_trafo) cfn = thunder.jit(fn, executors=[myex]) res = cfn(a, b) @@ -166,4 +172,13 @@ def myadd_trafo(a, b): s = str(thunder.last_traces(cfn)[-1]) assert "myadd2" in s and "myadd1" not in s + a.requires_grad_() + + res = cfn(a, b) + + s = str(thunder.last_traces(cfn)[0][-1]) + assert "myadd2" in s and "myadd1" not in s + + a.requires_grad_() + deregister_executor(myex) From 590c77349f589f799595645774788c81f4dc6efe Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sat, 16 Mar 2024 07:09:01 +0100 Subject: [PATCH 14/44] split last_traces into last_traces, last_backward_traces (PR2463) --- docs/source/basic/mlp_mnist.rst | 3 ++- docs/source/reference/thunder.rst | 1 + notebooks/dev_tutorials/fsdp_tutorial.ipynb | 3 ++- thunder/__init__.py | 28 +++++++++++++-------- thunder/common.py | 10 ++++---- thunder/executors/torch_autograd.py | 8 +++--- thunder/tests/distributed/test_ddp.py | 15 +++++------ thunder/tests/test_examine_memory.py | 12 ++++----- thunder/tests/test_extend.py | 2 +- thunder/tests/test_grad.py | 15 +++++------ 10 files changed, 54 insertions(+), 43 deletions(-) diff --git a/docs/source/basic/mlp_mnist.rst b/docs/source/basic/mlp_mnist.rst index f4335c91f4..e7e81912c3 100644 --- a/docs/source/basic/mlp_mnist.rst +++ b/docs/source/basic/mlp_mnist.rst @@ -90,7 +90,8 @@ Here's the code:: # The training model has both "forward" and "backward" traces, corresponding # to its forward and backward computations. # The evaluation model has only one set of traces. 
- fwd_traces, bwd_traces = thunder.last_traces(jitted_train_model) + fwd_traces = thunder.last_traces(jitted_train_model) + bwd_traces = thunder.last_backward_traces(jitted_train_model) eval_traces = thunder.last_traces(jitted_eval_model) print("This is the trace that thunder executed for training's forward computation:") diff --git a/docs/source/reference/thunder.rst b/docs/source/reference/thunder.rst index ced14ed4e8..706f9d6f08 100644 --- a/docs/source/reference/thunder.rst +++ b/docs/source/reference/thunder.rst @@ -25,6 +25,7 @@ Querying information on compiled functions and modules compile_data compile_stats last_traces + last_backward_traces last_prologue_traces cache_option cache_hits diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb index 170eb68ee8..a4f61b47c3 100644 --- a/notebooks/dev_tutorials/fsdp_tutorial.ipynb +++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb @@ -1818,7 +1818,8 @@ "# # # # # # # #\n", "# Check the traces\n", "# # # # # # # #\n", - "fwd_traces, bwd_traces = thunder.last_traces(cmodel)\n", + "fwd_traces = thunder.last_traces(cmodel)\n", + "bwd_traces = thunder.last_backward_traces(cmodel)\n", "\n", "# # # # # # # #\n", "# Print and check to see if they match ours\n", diff --git a/thunder/__init__.py b/thunder/__init__.py index e857ec23bd..0c4ee7c7f8 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -527,6 +527,9 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_prologue_execution_stop = time.time_ns() computation_traces = [computation_trc] + cs.last_traces = computation_traces + backward_traces = [] + cs.last_backward_traces = backward_traces computation_trc = dce(computation_trc) computation_traces.append(computation_trc) @@ -549,7 +552,8 @@ def get_computation_and_inputs(*args, **kwargs): # torch.autograd.Function to support embedding of Thunder-compiled # functions in torch's Autograd computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) - computation_traces.append(computation_trc) + # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces + # by split_forward_backward cs.last_computation_transformation_start = time.time_ns() @@ -569,10 +573,8 @@ def get_computation_and_inputs(*args, **kwargs): if backward_trc is not None: backward_fn = backward_trc.python_callable() - backward_traces = [backward_trc] else: backward_fn = None - backward_traces = [] # TODO RC1 Update the cache cache_entry = CacheEntry( @@ -582,7 +584,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.interpreter_cache.append(cache_entry) cs.last_computation_transformation_stop = time.time_ns() - cs.last_traces = [computation_trc] + extraces + cs.last_traces += extraces cs.last_prologue_traces = [prologue_trc] + protraces cs.last_prologue = pro @@ -697,25 +699,31 @@ def compile_stats(fn) -> CompileStats | None: return getattr(fn, "_lc_cs", None) -# TODO We should remove compiledata.last_traces in favor of forward_last_traces and backward_last_traces -# TODO: should we return fw and bw from separate functions. The return type (list or tuple of lists) is not so nice -def last_traces(fn) -> list[TraceCtx] | tuple[list[TraceCtx], list[TraceCtx]]: +def last_traces(fn) -> list[TraceCtx]: """Obtains the list of computation traces that have been produced for the last run of the function. This is a list of traces mirroring the progression of transformations being applied to the trace (at index 0) that has been acquired from interpreting the user program. 
-    If the function has forward and backward, a tuple of them is returned.
+    If the function has forward and backward, the forward is returned.
     """
     cs = compile_stats(fn)
     if cs is None:
        raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
-    if cs.forward_last_traces is not None and cs.backward_last_traces is not None:
-        return cs.forward_last_traces, cs.backward_last_traces
     if cs.last_traces is None:
         raise TypeError(f"{fn} doesn't seem to have been called yet.")
     return cs.last_traces
 
 
+def last_backward_traces(fn) -> list[TraceCtx]:
+    """Obtains the list of backward traces that have been produced for the last run of the function and the selected prologue."""
+    cs = compile_stats(fn)
+    if cs is None:
+        raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
+    if cs.last_backward_traces is None:
+        raise TypeError(f"{fn} doesn't seem to have been called yet.")
+    return cs.last_backward_traces
+
+
 def last_prologue_traces(fn) -> TraceCtx:
     """Obtains the list of prologue traces that have been produced for the last run of the function and the selected prologue."""
     cs = compile_stats(fn)
diff --git a/thunder/common.py b/thunder/common.py
index c0ec2fa770..914f7e5359 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -62,9 +62,7 @@ def __init__(self):
         self.last_interpreted_history = None
 
         # torch.autograd.Function specific data
-        self.primal_trace = None
-        self.forward_last_traces = None
-        self.backward_last_traces = None
+        self.last_backward_traces = None
 
         # Timing stats
         self.last_trace_host_start: int = -1
@@ -653,7 +651,6 @@ def _execute_trace(
 # Today all tensor outputs will be torch tensors, even if the input was NumPy arrays
 # provided in the NumPy language ctx -- what should the outputs be? Should we provide
 # a helper to convert torch tensors to NumPy arrays on output?
-# TODO Provide an option to not preprocess (for debugging)
 
 
 def _create_callable(
@@ -736,6 +733,9 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]:
         # Resets use of compile flags
         cs.last_compile_reasons = defaultdict(list)
         with compile_data_and_stats(cd, cs):
+            traces: list[TraceCtx] = []
+            cs.last_traces = traces
+            cs.last_backward_traces = []
             # Determines whether to use autograd.Function or not
             # autograd.Function (which supports calling .backward() in PyTorch) is used when:
             # 1) The grad() transform is not applied
@@ -794,7 +794,7 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]:
             # Starts recording a sequence of traces (this is not inlined)
             trc: TraceCtx = trc_or_result
 
-            traces: list[TraceCtx] = [trc]
+            traces.append(trc)
 
             # Applies transforms
             for transform in transforms:
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index c58551875a..93c0c2d874 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -205,6 +205,9 @@ def make_trace(func):
     primal_trace = make_trace(func)(*args, **kwargs)
     primal_trace = sort_data_parallel_syncs(primal_trace)
 
+    if compile_stats is not None:
+        compile_stats.last_traces.append(primal_trace)
+
     # torch.autograd.Function doesn't support non-flat outputs, the
     # grads wouldn't be propagated and backward receives None for each
     # non-flat non-tensor output. 
The output must also be a flat tuple, @@ -329,8 +332,7 @@ def make_trace(func): bw_traces.append(bw_extrace) if compile_stats is not None: - compile_stats.primal_trace = primal_trace - compile_stats.forward_last_traces = fw_traces - compile_stats.backward_last_traces = bw_traces + compile_stats.last_traces += fw_traces + compile_stats.last_backward_traces += bw_traces return fw_extrace, bw_extrace diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 62219d73f4..9b7ad2244f 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -452,7 +452,7 @@ def test_ddp_grad_bucketing(self, executor, bucket_size_in_mb: int): x = torch.ones((2, 12)).to(device) cm(x).mean().backward() - bwd_extrace = thunder.last_traces(cm)[1][-1] + bwd_extrace = thunder.last_backward_traces(cm)[-1] bsym_sym_id_list = [bsym.sym.id for bsym in bwd_extrace.bound_symbols] pack_syms = tuple(filter(lambda a: a == pack_prim_impl.id, bsym_sym_id_list)) unpack_syms = tuple(filter(lambda a: a == unpack_prim_impl.id, bsym_sym_id_list)) @@ -476,13 +476,14 @@ def test_rematerialize_all_gather(self): m = ToyModel().to(device) cm = thunder.jit( fsdp(m, device=device, broadcast_from=0), - interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON, ) x = torch.ones((2, 12), device=device) cm(x).mean().backward() - fwd_trc = thunder.last_traces(cm)[0][0] - bwd_trc = thunder.last_traces(cm)[1][0] + (fwd_trc,) = ( + t for t in thunder.last_traces(cm) if getattr(t.get_provenance(), "pss", "") == "Augmented forward pass" + ) + bwd_trc = thunder.last_backward_traces(cm)[0] from thunder.core.rematerialization import rematerialize_all_gather result_fwd_trc, result_bwd_trc = rematerialize_all_gather(fwd_trc, bwd_trc) @@ -829,8 +830,8 @@ def check_inflight_allgather_number(trc, n: int, is_bucket: bool): loss.backward() # get the trace before sorting - fwd_trc = thunder.last_traces(cm)[0][-2] - bwd_trc = thunder.last_traces(cm)[1][-2] + fwd_trc = thunder.last_traces(cm)[-2] + bwd_trc = thunder.last_backward_traces(cm)[-2] from thunder.distributed.utils import limit_in_flight_allgathers @@ -1104,7 +1105,7 @@ def _test_native_ddp_helper(input_data): tdist.destroy_process_group(pg) if rank == 0: - bwd_extrace_sym_ids = [bsym.sym.id for bsym in thunder.last_traces(cmodel)[1][-1].bound_symbols] + bwd_extrace_sym_ids = [bsym.sym.id for bsym in thunder.last_backward_traces(cmodel)[-1].bound_symbols] pack_unpack_update_bucket_view_found = ( "torch_pack_prim_impl" in bwd_extrace_sym_ids and "torch_unpack_prim_impl" in bwd_extrace_sym_ids diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index 1d01bef894..13bd05360b 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -63,12 +63,13 @@ def bar(a, b): # [4] [2,2] with runtime_allocated_memory(device): cbar(a, b) - traces = thunder.last_traces(cbar) - fwd_extrace = traces[0][-1] + fw_traces = thunder.last_traces(cbar) + fwd_extrace = fw_traces[-1] max_mem_fwd = get_alloc_memory(fwd_extrace) assert max_mem_fwd[0] == 144 assert sum(max_mem_fwd[1].values()) == get_return_memory(fwd_extrace.bound_symbols[-1]) # 144 - bw_extrace = traces[1][-1] + bw_traces = thunder.last_backward_traces(cbar) + bw_extrace = bw_traces[-1] max_mem_bw = get_alloc_memory(bw_extrace) assert max_mem_bw[0] == 144 assert sum(max_mem_bw[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1]) # 32 @@ -137,9 +138,8 @@ def test_nanogpt_block(executor, device, dtype): result = 
cblock(inp) with runtime_allocated_memory(device): result.backward(torch.ones_like(result)) - traces = thunder.last_traces(cblock) - fw_extrace = traces[0][-1] - bw_extrace = traces[1][-1] + fw_extrace = thunder.last_traces(cblock)[-1] + bw_extrace = thunder.last_backward_traces(cblock)[-1] fw_alloc_mem = get_alloc_memory(fw_extrace) bw_alloc_mem = get_alloc_memory(bw_extrace) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index d678bd531b..03277f3419 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -176,7 +176,7 @@ def myadd_grad_trafo(a, b): res = cfn(a, b) - s = str(thunder.last_traces(cfn)[0][-1]) + s = str(thunder.last_traces(cfn)[-1]) assert "myadd2" in s and "myadd1" not in s a.requires_grad_() diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 0757be3f61..e2100a585c 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1007,19 +1007,16 @@ def test_torch_autograd_module_get_compile_stats(executor, device, _): out.backward(g) compile_stats = compile_stats(lc) - primal_trace = compile_stats.primal_trace - forward_traces = compile_stats.forward_last_traces - backward_traces = compile_stats.backward_last_traces + forward_traces = compile_stats.last_traces + backward_traces = compile_stats.last_backward_traces assert isinstance(forward_traces, list) assert len(forward_traces) >= 1 assert isinstance(backward_traces, list) assert len(backward_traces) >= 1 - assert isinstance(primal_trace, TraceCtx) - fw_bw_traces = thunder.last_traces(lc) - assert isinstance(fw_bw_traces, tuple) - assert len(fw_bw_traces) == 2 - assert fw_bw_traces[0] == forward_traces - assert fw_bw_traces[1] == backward_traces + fw_traces = thunder.last_traces(lc) + bw_traces = thunder.last_backward_traces(lc) + assert fw_traces == forward_traces + assert bw_traces == backward_traces @instantiate( From 371db29df37027bb48f4e24cede2e2a24cbfbc3b Mon Sep 17 00:00:00 2001 From: Tom Fogal <60981+tfogal@users.noreply.github.com> Date: Sat, 16 Mar 2024 00:49:08 -0700 Subject: [PATCH 15/44] Broad but minor cleanups. 
(PR2462) --- .azure/docker-build.yml | 1 + .azure/gpu-tests.yml | 2 +- dockers/ubuntu-cuda/Dockerfile | 2 +- docs/source/index.rst | 1 - examples/lit-gpt/_ddp_thunder.py | 4 +- examples/lit-gpt/train.py | 2 +- examples/lit-gpt/train_fsdp.py | 2 +- examples/llama2.c/model.py | 5 +- examples/llama2.c/sample.py | 10 +- examples/llama2.c/tinystories.py | 2 +- examples/llama2.c/train.py | 6 +- notebooks/dev_tutorials/extend.ipynb | 2 +- notebooks/dev_tutorials/patterns.ipynb | 441 ------------------------- setup.py | 6 +- thunder/benchmarks/distributed.py | 2 +- thunder/benchmarks/targets.py | 13 +- thunder/clang/__init__.py | 4 +- thunder/common.py | 10 +- thunder/core/interpreter.py | 13 +- thunder/core/jit_ext.py | 4 +- thunder/core/langctxs.py | 2 +- thunder/core/proxies.py | 2 +- thunder/core/symbol.py | 15 +- thunder/core/trace.py | 8 +- thunder/core/transform_common.py | 6 +- thunder/core/transforms.py | 14 +- thunder/distributed/__init__.py | 2 +- thunder/distributed/transforms/fsdp.py | 2 +- thunder/distributed/utils.py | 1 - thunder/examine/__init__.py | 6 +- thunder/examine/memory_caculation.py | 2 +- thunder/executors/apex_entropyex.py | 13 +- thunder/executors/cudnn_layernormex.py | 4 +- thunder/executors/cudnnex.py | 3 - thunder/executors/nvfuserex.py | 1 - thunder/executors/nvfuserex_impl.py | 25 +- thunder/executors/passes.py | 4 +- thunder/executors/sdpaex.py | 20 +- thunder/executors/torch_compile.py | 17 +- thunder/executors/torchex.py | 14 +- thunder/numpy/langctx.py | 2 +- thunder/tests/distributed/test_ddp.py | 6 +- thunder/tests/lit_gpt_model.py | 2 +- thunder/tests/llama2_model.py | 4 +- thunder/tests/nanogpt_model.py | 9 +- thunder/tests/opinfos.py | 151 +++++---- thunder/tests/test_core.py | 24 +- thunder/tests/test_cudnn_executor.py | 9 +- thunder/tests/test_elementwise.py | 9 +- thunder/tests/test_grad.py | 12 +- thunder/tests/test_interpreter.py | 11 +- thunder/tests/test_jit_functional.py | 8 +- thunder/tests/test_jit_general.py | 11 +- thunder/tests/test_networks.py | 2 +- thunder/tests/test_nvfuser.py | 5 +- thunder/tests/test_nvfuser_remat.py | 4 +- thunder/tests/test_ops.py | 2 +- thunder/torch/__init__.py | 31 +- 58 files changed, 271 insertions(+), 724 deletions(-) delete mode 100644 notebooks/dev_tutorials/patterns.ipynb diff --git a/.azure/docker-build.yml b/.azure/docker-build.yml index a8e56f2952..73233ae78c 100644 --- a/.azure/docker-build.yml +++ b/.azure/docker-build.yml @@ -51,6 +51,7 @@ jobs: #'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" + timeoutInMinutes: "95" variables: UBUNTU_VERSION: '22.04' PYTHON_VERSION: '3.10' diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index 8055eb2c6b..22ba01eddc 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -111,7 +111,7 @@ jobs: condition: eq(variables['testing'], 'distributed') displayName: 'Testing: distributed' - # todo for Mike as he promised some time ago already... or shall it ne another workflow so keep time low? 
+ # todo (mruberry): decide whether this should be here or in another workflow #- bash: | # python benchmarks/ops_benchmark.py nanogpt-gelu # python benchmarks/nvfuser_benchmarks.py nanogpt-mlp -x thunder diff --git a/dockers/ubuntu-cuda/Dockerfile b/dockers/ubuntu-cuda/Dockerfile index 777d723697..e815d827f6 100644 --- a/dockers/ubuntu-cuda/Dockerfile +++ b/dockers/ubuntu-cuda/Dockerfile @@ -74,7 +74,7 @@ RUN \ RUN \ echo "CUDA_VERSION=$CUDA_VERSION ; CUDNN_VERSION=$CUDNN_VERSION " && \ CUDA_VERSION_MM=${CUDA_VERSION%.*} && \ - # there is missing cudnn for 12.1 so use 12.2 instead + # There are some test failures from cuDNN 12.1, so 'upgrade' requests for 12.1 to 12.2. CUDA_VERSION_MM="${CUDA_VERSION_MM/12.1/12.2}" && \ CUDNN_BASE_VER=${CUDNN_VERSION%%.*} && \ CUDNN_PACKAGE_VER="${CUDNN_VERSION}+cuda${CUDA_VERSION_MM}" && \ diff --git a/docs/source/index.rst b/docs/source/index.rst index 6011f29193..c019b44e7e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -118,7 +118,6 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik :caption: Experimental dev tutorials notebooks/dev_tutorials/extend - notebooks/dev_tutorials/patterns .. TODO RC1: update notebooks diff --git a/examples/lit-gpt/_ddp_thunder.py b/examples/lit-gpt/_ddp_thunder.py index 8d53a567a4..1bd07619df 100644 --- a/examples/lit-gpt/_ddp_thunder.py +++ b/examples/lit-gpt/_ddp_thunder.py @@ -199,8 +199,8 @@ def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: f" Got: {module.__class__.__name__}." ) - # see https://github.com/Lightning-AI/lightning-thunder/issues/2085 - # for why we cannot just return `module.no_sync()` + # issue "Limitations of the current DDP no_sync implementation" has + # details on why we cannot just return `module.no_sync()` from thunder.distributed import skip_data_parallel_grad_sync previous, self._enabled = self._enabled, enabled diff --git a/examples/lit-gpt/train.py b/examples/lit-gpt/train.py index bf5b5e1e6e..412711ce5a 100644 --- a/examples/lit-gpt/train.py +++ b/examples/lit-gpt/train.py @@ -15,7 +15,7 @@ def main(compile: str = "eager", dynamic: bool = False) -> None: fabric = L.Fabric(devices=1, precision="bf16-true") - fabric.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP) + fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) config = Config.from_name(model_name) print(f"Loading model with {config.__dict__}") diff --git a/examples/lit-gpt/train_fsdp.py b/examples/lit-gpt/train_fsdp.py index c855d61b92..e896d52ef3 100644 --- a/examples/lit-gpt/train_fsdp.py +++ b/examples/lit-gpt/train_fsdp.py @@ -38,7 +38,7 @@ def main( fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-true") fabric.launch() - fabric.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP) + fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) config = Config.from_name(model_name) fabric.print(f"Loading model with {config.__dict__}") diff --git a/examples/llama2.c/model.py b/examples/llama2.c/model.py index aaf4aad819..297af9e1f6 100644 --- a/examples/llama2.c/model.py +++ b/examples/llama2.c/model.py @@ -65,7 +65,6 @@ def apply_rotary_emb( xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting - # https://github.com/Lightning-AI/lightning-thunder/issues/1106 a, b = freqs_cos.shape freqs_cos = freqs_cos.view(1, a, 1, b) freqs_sin = 
freqs_sin.view(1, a, 1, b) @@ -244,7 +243,7 @@ def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - # https://github.com/Lightning-AI/lightning-thunder/issues/1108 + # see issue "Unexpected KeyError when self attribute is set inside forward" #self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the output on the very last position @@ -258,7 +257,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): param_dict = {pn: p for pn, p in self.named_parameters()} # filter out those that do not require grad param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no. # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] diff --git a/examples/llama2.c/sample.py b/examples/llama2.c/sample.py index 9184340203..b8ccacfa48 100644 --- a/examples/llama2.c/sample.py +++ b/examples/llama2.c/sample.py @@ -20,11 +20,10 @@ temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability tokenizer = "" # override the tokenizer model path -seed = 1337 +seed = 42 device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 
-# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' -compile = True # Use lightning.compile to compile the model to be faster +compile = True # Use thunder.jit to compile the model to be faster exec(open('configurator.py').read()) # overrides from command line or config file # ----------------------------------------------------------------------------- @@ -33,7 +32,6 @@ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast -# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 # ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] ctx = nullcontext() # torch.amp.autocast(device_type=device_type, dtype=ptdtype) @@ -57,10 +55,10 @@ from thunder.executors.sdpaex import sdpa_ex executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor] - cmodel = thunder.compile(model, disable_torch_autograd_support=True, executors_list=executors) + cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors_list=executors) # the generate implementation is not compile friendly, so bind the compiled model to the generate implementation generate = partial(Transformer.generate, cmodel) - # workaround for https://github.com/Lightning-AI/lightning-thunder/issues/954 + # workaround for "Foward nn.Module attributes through the ThunderOptimizedModule" cmodel.params = model.params else: generate = model.generate diff --git a/examples/llama2.c/tinystories.py b/examples/llama2.c/tinystories.py index cafc1b164a..5ef5c6a247 100644 --- a/examples/llama2.c/tinystories.py +++ b/examples/llama2.c/tinystories.py @@ -191,7 +191,7 @@ def __iter__(self): # get DDP rank info rank = dist.get_rank() if dist.is_initialized() else 0 # combine the worker_id and worker_rank to create a unique seed for rng - seed = 42 + worker_id + 1337 * rank + seed = 42 + worker_id + 1942 * rank rng = random.Random(seed) print(f"Created a PretokDataset with rng seed {seed}") if self.vocab_source == "llama2": diff --git a/examples/llama2.c/train.py b/examples/llama2.c/train.py index 18290df075..58d88d4729 100644 --- a/examples/llama2.c/train.py +++ b/examples/llama2.c/train.py @@ -70,7 +70,6 @@ warmup_iters = 1000 # how many steps to warm up for # system device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks -# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 # dtype = "bfloat16" # float32|bfloat16|float16 compile = "thunder" # eager|torch|thunder # ----------------------------------------------------------------------------- @@ -118,12 +117,11 @@ if master_process: os.makedirs(out_dir, exist_ok=True) -torch.manual_seed(1337 + seed_offset) +torch.manual_seed(42 + seed_offset) torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast # note: float16 data type will automatically use a GradScaler -# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 # ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, 
"float16": torch.float16}[dtype] ctx = nullcontext() # torch.amp.autocast(device_type=device_type, dtype=ptdtype) @@ -313,7 +311,7 @@ def get_lr(it): if ddp: # in DDP training we only need to sync gradients at the last micro step. # the official way to do this is with model.no_sync() context manager, but - # I really dislike that this bloats the code and forces us to repeat code + # this forces us to repeat code. # looking at the source of that context manager, it just toggles this variable train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 with ctx: diff --git a/notebooks/dev_tutorials/extend.ipynb b/notebooks/dev_tutorials/extend.ipynb index 8b2fbe9036..304f6d9f6f 100644 --- a/notebooks/dev_tutorials/extend.ipynb +++ b/notebooks/dev_tutorials/extend.ipynb @@ -70,7 +70,7 @@ "source": [ "# Our operator executor will use the \"multimul\" function as a new example operator.\n", "# This function uses NumPy to perform two multiplications of four inputs.\n", - "# This functions very contrived, but will be useful to illustrate the extend submodule's capabilities.\n", + "# This function's contrived, but will be useful to illustrate the extend submodule's capabilities.\n", "def multimul_impl(\n", " a: Number | torch.Tensor, \n", " b: Number | torch.Tensor,\n", diff --git a/notebooks/dev_tutorials/patterns.ipynb b/notebooks/dev_tutorials/patterns.ipynb deleted file mode 100644 index b22a8598fd..0000000000 --- a/notebooks/dev_tutorials/patterns.ipynb +++ /dev/null @@ -1,441 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Thunder pattern matching for transformations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This developer tutorial discusses patterns -- sequences of operations that can be matched and replaced with traceable functions. \n", - "# It's a work-in-progress, and it currently only discusses how patterns can be constructed and how they're matched,\n", - "# along with some related utilities." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# Imports the modules, classes, and functions we'll need for this tutorial\n", - "import torch\n", - "\n", - "import thunder\n", - "from thunder.core.patterns import Pattern, bind_names, numbered_ancestors\n", - "from thunder.core.proxies import TensorProxy\n", - "from thunder.core.symbol import BoundSymbol" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# To match a pattern, start by creating a Pattern object\n", - "p = Pattern()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Then define one or more \"matchers\" that determine if a BoundSymbol is a \"match\", and add them to the \n", - "# pattern using its match() method\n", - "\n", - "# The matcher signature not only accepts a BoundSymbol to review, but also a list of BoundSymbols that were\n", - "# already matched by the pattern, and a match_ctx dictionary that contains whatever state you like from previous matches\n", - "# The matcher returns True if the BoundSymbol should be matched, and False otherwise. 
When returning True it should return\n", - "# a dict that will be used to update the match_ctx for future matches. This will be clearer in a moment with an example.\n", - "# The following matcher is very permissive, and it matches any add operation.\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "a = torch.randn((2, 2))\n", - "b = torch.randn((2, 2))\n", - "\n", - "# An example program that performs an addition and a subtraction\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = a - b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# The matcher is told to match any addition\n", - "p.match(add_matcher)\n", - "\n", - "# Calling the Pattern object on a trace returns a list of matches. \n", - "# Each match is a list of (int, BoundSymbol) tuples, where int is the \n", - "# position of the BoundSymbol in the trace.\n", - "matches = p(trc)\n", - "\n", - "# In this case there is just one match -- the first addition\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\")], [(3, t1 = ltorch.add(a, b, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # t1 = prims.add(a, b) # t1: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "def foo(a, b):\n", - " c = a + b\n", - " d = a + b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# When the program is changed to include two additions, both additions\n", - "# are matched and two matches are created.\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (3, t1 = ltorch.add(a, b, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # t1 = prims.add(a, b) # t1: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# In addition to matching a single operation, a pattern can match any number of sequential operations --\n", - "# that is, operations that are immediately adjacent to each other. We do this by providing\n", - "# max_times and (optionally) min_times arguments to match()\n", - "# Negative max_times values are interpreted as matching the pattern any number of times\n", - "# Matching multiple operations occurs greedily and before any additional matching can occur\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher, min_times=1, max_times=-1)\n", - "\n", - "# The pattern now matches once, and the single match contains both additions\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (3, t1 = ltorch.sub(a, b, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # t1 = prims.sub(a, b) # t1: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Multiple operations can also be matched by calling match() multiple times. 
Each \n", - "# match() attempts to evaluate itself in the order it's called.\n", - "\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = a - b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# Let's match an addition followed by a subtraction on the same inputs. This will also show how to update the match_ctx dict\n", - "# and let us use the bind_names() utility.\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " # bind_names() produces an object with properties corresponding to the function's (Symbol's) parameters, when\n", - " # accessed they return their corresponding arguments\n", - " bn = bind_names(bsym)\n", - " # Stores the inputs in the context\n", - " return True, {'a': bn.a, 'b': bn.b}\n", - " \n", - " return False, None\n", - "\n", - "def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'sub':\n", - " bn = bind_names(bsym)\n", - "\n", - " # Acquires the previously stored values from the match_ctx\n", - " a = match_ctx['a']\n", - " b = match_ctx['b']\n", - "\n", - " # Matches the sub only if the arguments are the same as the addition's, and in the same order\n", - " if a is bn.a and b is bn.b:\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(sub_matcher)\n", - "\n", - "# Matches the addition and the subtraction\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Another version of the above example that uses the previously_matched argument to decide whether to match\n", - "# the subtraction\n", - "\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = a - b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " # Doesn't update the context -- the context is just scratch space for you\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'sub':\n", - " my_bn = bind_names(bsym)\n", - "\n", - " add_bsym = previously_matched\n", - " add_bn = bind_names(add_bsym)\n", - "\n", - " # Matches the sub only if the arguments are the same as the addition's, and in the same order\n", - " if add_bn.a is my_bn.a and add_bn.b is my_bn.b:\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(sub_matcher)\n", - "\n", - "# Matches the addition and the subtraction\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (4, t2 = ltorch.sub(a, b, alpha=None) # t2: \"cpu f32[2, 2]\"\n", - " # t2 = prims.sub(a, b) # t2: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Operations in a pattern don't have to be next to each other, but they do have to within a \"window\" of 
\n", - "# the previous operation. Currently the window is 5 operations. Each operation also has to be \n", - "# \"reorderable\" to be \"next to\" operations that were already matched. This isn't always\n", - "# possible. If an operation consumes an input that is not directly from a previously matched symbol, but\n", - "# is derived from the output of a previously matched symbol, then it cannot be reordered adjacent to the \n", - "# other operations in the pattern.\n", - "# Let's see how this works with two examples.\n", - "\n", - "# An operation between the first addition and second subtraction doesn't stop the previous pattern\n", - "# from matching as expected, because the operation producing d can be reordered to be \n", - "# immediately after the operation producing c\n", - "def foo(a, b):\n", - " c = a + b\n", - " x = a + 2\n", - " d = a - b\n", - " return c, d, x\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# The match is the same as when x isn't computed\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[]\n" - ] - } - ], - "source": [ - "# Too many intervening operations pushes the subtraction out of pattern matching \"window\" and prevents\n", - "# the match\n", - "# In the future we may expose an option to set the window larger -- share your thoughts by filing an issue!\n", - "def foo(a, b):\n", - " c = a + b\n", - " x = a + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " d = a - b\n", - " return c, d, x\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# No matches because the computation of c and the computation of d are separated by too many operations\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[]\n" - ] - } - ], - "source": [ - "# The computation of e depends on the computation of d and the computation of c\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = c - 5\n", - " e = c + d\n", - " return e\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(add_matcher)\n", - "\n", - "# Attempting to match two additions fails, because the computation of e cannot be reordered next to the computation of c\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (3, t1 = ltorch.sub(t0, 5, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # _ = prims.convert_element_type(5, float)\n", - " # t1 = prims.sub(t0, 5.0) # t1: \"cpu f32[2, 2]\"), (4, t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[2, 2]\"\n", - " # t2 = prims.add(t0, t1) # t2: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Including the subtraction in the pattern allows it to be matched\n", - "def sub_matcher(bsym: BoundSymbol, *, previously_matched: 
list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'sub':\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(sub_matcher)\n", - "p.match(add_matcher)\n", - "\n", - "matches = p(trc)\n", - "print(matches)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/setup.py b/setup.py index 0f20566fcd..566f2bbc40 100755 --- a/setup.py +++ b/setup.py @@ -46,10 +46,8 @@ def _prepare_extras( about = _load_py_module("__about__.py") # https://packaging.python.org/discussions/install-requires-vs-requirements / -# keep the meta-data here for simplicity in reading this file... it's not obvious -# what happens and to non-engineers they won't know to look in init ... -# the goal of the project is simplicity for researchers, don't want to add too much -# engineer specific practices +# keep the meta-data here for simplicity in reading this file. it's not obvious +# what happens and to non-engineers they won't know to look in init. setup( name="lightning-thunder", version=about.__version__, diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py index 6c9e5e889b..c46a699ca5 100644 --- a/thunder/benchmarks/distributed.py +++ b/thunder/benchmarks/distributed.py @@ -211,7 +211,7 @@ def parse_args() -> argparse.Namespace: # TODO Port these benchmarks to pytest (and targets.py) -# See https://github.com/Lightning-AI/lightning-thunder/issues/1404 +# See issue "Create distributed pytest benchmarks" if __name__ == "__main__": args = parse_args() diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py index bd54735beb..9f02b23d35 100644 --- a/thunder/benchmarks/targets.py +++ b/thunder/benchmarks/targets.py @@ -289,7 +289,9 @@ def wrapper(*args, **kwargs): return wrapper -# To compare with PyTorch and torchcompile +# To compare with PyTorch and raw torch.compile (i.e. not through thunder). The +# latter can help us isolate whether it's something we need to fix ourself or +# report upstream. 
torch_fwd_bwd = partial(thunder_fwd_bwd, compile_fn=torch_executor) torchcompile_fwd_bwd = partial(thunder_fwd_bwd, compile_fn=torch_compile_executor) @@ -432,7 +434,8 @@ def test_nanogpt_gelu_grad(benchmark, executor: Callable): # TODO Improve cross entropy's fwd+bwd perf when using the PyTorch executor -# See https://github.com/Lightning-AI/lightning-thunder/issues/1319 +# See "torch.cross_entropy implementation has incorrect dtype metadata + bwd +# is very slow" @pytest.mark.parametrize( "executor,", fwd_executors, @@ -454,7 +457,8 @@ def test_nanogpt_cross_entropy_fwd(benchmark, executor: None | Callable): # TODO Improve cross entropy's fwd+bwd perf when using the PyTorch executor -# See https://github.com/Lightning-AI/lightning-thunder/issues/1319 +# See "torch.cross_entropy implementation has incorrect dtype metadata + bwd +# is very slow" @pytest.mark.parametrize( "executor,", (grad_executors + apex_grad_executors), @@ -476,7 +480,8 @@ def test_nanogpt_cross_entropy_grad(benchmark, executor: None | Callable): # TODO Improve cross entropy's fwd+bwd perf when using the PyTorch executor -# See https://github.com/Lightning-AI/lightning-thunder/issues/1319 +# See "torch.cross_entropy implementation has incorrect dtype metadata + bwd +# is very slow" @pytest.mark.parametrize( "executor,", (fwd_executors + cudnn_layernorm_fwd_executors), diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index c0c711f5b8..3fd192802f 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -20,7 +20,7 @@ from thunder.core.proxies import TensorProxy, pyval, pytype, proxy, AnyProxy, Proxy import thunder.core.devices as devices -# This file defines the operations in lightning.compile's "core" language. +# This file defines the operations in thunder.jit's "core" language. # # These operators are intended to be used when defining user-facing languages, like the torch or NumPy # languages. @@ -1002,7 +1002,7 @@ def stride_order(a: TensorLike, order: None | Sequence[int] = None) -> TensorLik .. note:: - No other lightning.compile operations specify how their outputs are represented in memory, and lightning.compile + No other thunder.jit operations specify how their outputs are represented in memory, and thunder.jit does not model strides. This operation is an explicit directive to construct a dense, non-overlapping and strided tensor, but operations on that tensor do not have to preserve those properties. """ diff --git a/thunder/common.py b/thunder/common.py index 914f7e5359..67afb496a0 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -157,7 +157,7 @@ def __init__( ): # Records whether we're using the thunder.jit() entrypoint or not # The thunder.jit() entrypoint introduces important architectural updates, - # but some components are still designed to work with older architectures for + # but some components are still designed to work with the older entrypoint # and are being temporarily maintained to facilitate their development. 
self.using_jit = using_jit @@ -244,7 +244,7 @@ def processed_function(self): def _unpack_inputs(fn, tracectx: TraceCtx, args, kwargs, *, rename_proxies: bool): tracectx.unpacking() - # Translates tensors, arrays, and dtypes to lightning.compile types + # Translates tensors, arrays, and dtypes to thunder.jit types # TODO Translate NumPy dtypes def translate(x: Any, *, name: str | None = None) -> Any: # NOTE Unpacking proxies @@ -628,7 +628,7 @@ def _execute_trace( # Constructs the Python callable c = extrace.python_callable() - # TODO RC1 Remove this option (by modeling torch.compile as another executor) + # TODO RC1 Remove this option (by using the torch.compile executor) if compile_data.use_torch_compile: c = torch.compile(c) @@ -647,7 +647,7 @@ def _execute_trace( # TODO review functions which compute large objects unrelated to tensors and how # they're handled # TODO can the language context be detected from the inputs? -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/316 +# TODO: # Today all tensor outputs will be torch tensors, even if the input was NumPy arrays # provided in the NumPy language ctx -- what should the outputs be? Should we provide # a helper to convert torch tensors to NumPy arrays on output? @@ -766,7 +766,7 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_host_stop = time.time_ns() return result - # TODO Revisit compile() behavior when hit in a trace ctx + # TODO Revisit jit() behavior when hit in a trace ctx # This will inline the invocation of compile into the current # trace (UNLESS there was a cache hit, per above) # This interaction between the cache and tracing seems odd diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 904d371b88..b6563bbe13 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -3597,7 +3597,7 @@ def _check_exc_match_handler(inst: dis.Instruction, /, stack: InterpreterStack, stack.append(isinstance(left, right)) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1523 +# TODO See issue "Fix COMPARE_OP handler" # https://docs.python.org/3.10/library/dis.html#opcode-COMPARE_OP @register_opcode_handler("COMPARE_OP") def _compare_op_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: @@ -4212,8 +4212,6 @@ def _jump_backward_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> # https://docs.python.org/3.11/library/dis.html#opcode-JUMP_BACKWARD_NO_INTERRUPT -# TODO: we currently ignore the NO_INTERRUPT part, -# https://github.com/Lightning-AI/lightning-thunder/issues/1631 @register_opcode_handler("JUMP_BACKWARD_NO_INTERRUPT", min_ver=(3, 11)) def _jump_backward_no_interrupt_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> int: assert type(inst.arg) is int @@ -4490,7 +4488,6 @@ def _load_global_handler( return check_and_append(stack, obj) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1525 # https://docs.python.org/3.11/library/dis.html#opcode-LOAD_METHOD @register_opcode_handler("LOAD_METHOD") def _load_method_handler( @@ -4524,7 +4521,6 @@ def _load_method_handler( stack.append(meth) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1661 # https://docs.python.org/3.11/library/dis.html#opcode-LOAD_NAME @register_opcode_handler("LOAD_NAME") def _load_name_handler( @@ -4567,7 +4563,6 @@ def _make_cell_handler(inst: dis.Instruction, /, frame: InterpreterFrame, **kwar frame.localsplus[i] = c -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1526 # 
https://docs.python.org/3.10/library/dis.html#opcode-MAKE_FUNCTION @register_opcode_handler("MAKE_FUNCTION") def _make_function_handler( @@ -5077,7 +5072,6 @@ def do_raise(exc: Any = Py_NULL(), cause: Any = Py_NULL()) -> Literal[INTERPRETE return INTERPRETER_SIGNALS.EXCEPTION_RAISED -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1660 # https://docs.python.org/3.11/library/dis.html#opcode-PRINT_EXPR @register_opcode_handler("PRINT_EXPR") def _print_expr_handler( @@ -5350,7 +5344,6 @@ def impl(tos, name, tos1): return res -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1552 # https://docs.python.org/3.10/library/dis.html#opcode-STORE_DEREF @register_opcode_handler("STORE_DEREF") def _store_deref_handler( @@ -5651,7 +5644,7 @@ def _send_handler( ) -> None | int | INTERPRETER_SIGNALS: # SEND(delta) # Equivalent to STACK[-1] = STACK[-2].send(STACK[-1]). Used in yield from and await statements. - # If the call raises StopIteration, pop the top value from the stack, push the exception’s value attribute, and increment the bytecode counter by delta. + # If the call raises StopIteration, pop the top value from the stack, push the exception's value attribute, and increment the bytecode counter by delta. assert isinstance(inst.arg, int) send_value = stack.pop() generator = stack[-1] @@ -6333,7 +6326,7 @@ def _run_frame( assert len(frame.interpreter_stack) >= try_block.level + 3 with frame.interpreter_stack.set_cur_instruction(PseudoInst.EXCEPTION_HANDLER): del frame.interpreter_stack[try_block.level + 3 :] - exc_type = frame.interpreter_stack.pop() # we ignore that and asume == type(exc_value) + exc_type = frame.interpreter_stack.pop() # we ignore that and assume == type(exc_value) exc_value = frame.interpreter_stack.pop() exc_traceback = frame.interpreter_stack.pop() if exc_value != None: diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index b3e2acfb23..04349bbfe4 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -193,7 +193,8 @@ def is_uncopyable(val: Any, /) -> bool: # - calling a function with a side effect (e.g. randn, print) # TODO RC1 What kind of error should a sharp edge raise? # TODO RC1 Improve sharp edges warnings and errors to show the source line -# https://github.com/Lightning-AI/lightning-thunder/issues/2099 +# See issue "jit: Improve "sharp edges" errors and warnings to show the sharp +# edge's source location" # Context for the minimal interpreter @@ -643,7 +644,6 @@ def decorator(fn: Callable): # general_jit lookasides # -# TODO Add all general_jit operation translations (see https://github.com/Lightning-AI/lightning-thunder/issues/1804) _general_jit_lookaside_map = {} diff --git a/thunder/core/langctxs.py b/thunder/core/langctxs.py index 4f9031d64e..bcc12b6141 100644 --- a/thunder/core/langctxs.py +++ b/thunder/core/langctxs.py @@ -9,7 +9,7 @@ # Context variables, context managers, and helpers related to setting the language context. # The language context is a context variable that determines how methods on proxies are resolved. # For example, in NumPy, ndarray.size returns the number of elements in the array. In PyTorch, -# torch.Tensor.size(dim=None) returns the tenor's shape when dim is None, and the length of the +# torch.Tensor.size(dim=None) returns the tensor's shape when dim is None, and the length of the # specified dimension when dim specifies a dimension (using an integer offset). 
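To make the resolution rule above concrete, here is a minimal, self-contained sketch of method resolution driven by a context variable. The names below (`_langctx`, `_methods`, `get_method`) are hypothetical illustrations for this document, not thunder's actual API:

    import contextvars
    import math
    from collections.abc import Callable

    _langctx = contextvars.ContextVar("langctx", default="torch")

    # Each language maps a method id to an implementation; "size" differs
    # between the torch and NumPy conventions described above.
    _methods: dict[str, dict[str, Callable]] = {
        "torch": {"size": lambda shape, dim=None: shape if dim is None else shape[dim]},
        "numpy": {"size": lambda shape, dim=None: math.prod(shape)},
    }

    def get_method(id: str) -> Callable:
        # Missing methods raise AttributeError so hasattr() behaves correctly
        # when proxies route __getattr__ through a lookup like this one.
        methods = _methods[_langctx.get()]
        if id not in methods:
            raise AttributeError(f"no method {id} in the {_langctx.get()} language")
        return methods[id]

    assert get_method("size")((2, 3)) == (2, 3)  # torch: the shape
    assert get_method("size")((2, 3), 0) == 2    # torch: length of dim 0
    _langctx.set("numpy")
    assert get_method("size")((2, 3)) == 6       # numpy: total element count
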
# diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index f8016f5ea3..92138bfcbc 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -793,7 +793,7 @@ def __rxor__(self, other): # # Shift operations # - # Issue https://github.com/Lightning-AI/lightning-thunder/issues/594 + # Issue "Implement logical and arithmetic left and right shifts" # tracks implementing these def __lshift__(self, other): diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 354e7ddc18..cb22553aa0 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -109,10 +109,10 @@ def default_python_printer( # A symbol represents a function and how it can be transformed # name is a string name for the operation -# meta should use lightning.compile functions to evaluate the function; -# it will be called with lightning.compile proxies +# meta should use thunder.jit functions to evaluate the function; +# it will be called with thunder.jit proxies # id is an optional value to use when translating the function to executors -# is_prim should be True if the Symbol represents a lightning.compile primitive +# is_prim should be True if the Symbol represents a thunder.jit primitive # python_printer is a function that will produce valid Python for calling the # operation; this can usually be set to None, in which case the default python # printer will be used for the Symbol. Symbols that control their own printing @@ -196,14 +196,6 @@ def module(self) -> None | ModuleType: result = inspect.getmodule(fn_) return result - # Properties used in transforms (defined later) - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/326 - # Remove this from here (think how symbols could be extended with transforms) - # self.grad_defined = False - # self.grad_ignored = False - # self.grad_fwd = None - # self.grad_bwd = None - def __repr__(self) -> str: return f"[Symbol name={self.name}]" @@ -313,7 +305,6 @@ def __post_init__(self): # Constructs a new BoundSymbol with default values taken from this BoundSymbol # Override values can be specified as kwargs - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/680 # Issue -- Provide a pattern for updating subsymbols when swapping outputs # Maybe this can also just swap one set of symbols for another? # Consider adding verification that the new and old output have the same metadata diff --git a/thunder/core/trace.py b/thunder/core/trace.py index 4e2bed9eba..88f1950585 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -19,7 +19,7 @@ from thunder.core.codeutils import ContextObject -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/327 +# TODO see issue "Improve TraceProvenance" # Make this more interesting / printer better -- maybe let # practitioners acquire the pass callable so they can replicate the pass? # This class is intended to describe how the trace was constructed @@ -36,7 +36,7 @@ def __repr__(self) -> str: # TODO Should traces be BoundSymbols? 
-# TODO https://github.com/Lightning-AI/lightning-thunder/issues/323 +# TODO issue "Create a mechanism for freezing TraceCtx objects" # Add validation that a constant is never assigned to / reassigned # Possibly separate the ideas of a trace -- a series of scopes containing bound symbols -- # and a TraceCtx, which can produce new traces @@ -303,7 +303,7 @@ def python_ctx(self) -> dict: return import_ctx # TODO Account for multi-line signatures - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/324 + # TODO issue "Add type annotations to Python function produced by traces" # Consider extending the signature with type information, in particular the # the type information of the return value might be interesting def python(self, *, print_depth: int = 1) -> str: @@ -395,7 +395,7 @@ def keyfn(class_or_module: type | ModuleType) -> str: reset_tracectx(token) # Returns a Python callable that executes the trace - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/323 + # TODO issue "Create a mechanism for freezing TraceCtx objects" # Create a mechanism for freezing traces and cache the compilation def python_callable(self, *, global_dicts: None | dict = None) -> Callable: python_str: str diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 7b602084b6..fa74313bfe 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -134,9 +134,9 @@ def replace_redundant_inputs( return new_bsyms -# TODO(crcrpar): Implement a mechanism to keep track of supported ops that cannot be CSE'd. -# For example, `uniform`, `dropout`, and `scaled_dot_product_attention`. -# See: https://github.com/Lightning-AI/lightning-thunder/issues/671 +# These are ops that are not referentially transparent. We need to treat such +# ops specially when optimizing; for example, CSE cannot coalesce two calls +# into one for ops in this set. 
NON_FUNCTIONAL_OPS: set[prims.PrimIDs | str] = { prims.PrimIDs.UNIFORM, "torch.uniform", # this doesn't exist as of the PR diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index b05351ca14..8d18d40904 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -540,7 +540,7 @@ def _flatten(bsym: BoundSymbol): # # -# Functions related to functionalizing TOMs +# Functions related to functionalizing ThunderOptimizedModules # @@ -1583,7 +1583,7 @@ def _selector(eligible_nodes: list[Node]) -> int: return gradtrc - # NOTE This is a kludge to indicate that we shouldn't support PyTorch's autograd because + # NOTE This is a kludge to indicate that we shouldn't use PyTorch's autograd because # we're using our own autograd transform cfn._using_grad_transform = True @@ -1850,7 +1850,6 @@ def broadcast_in_dim_vmap( ) -> BatchedValue: bdim = a.batch_dim # TODO: remove this when shape and broadcast_dimensions become mandatory kwargs - # See https://github.com/Lightning-AI/lightning-thunder/issues/181 shape, _ = safe_zip(*shape) if len(broadcast_dimensions) > 0: broadcast_dimensions, _ = safe_zip(*broadcast_dimensions) @@ -2167,7 +2166,6 @@ def broadcast_in_dim_jvp(a: JVPDual, shape: tuple[JVPDual, ...], broadcast_dimen x, xd = a # TODO: shape and broadcast_dimensions should be tuples of ints # but for now it's a tuple of JVPDuals - # See https://github.com/Lightning-AI/lightning-thunder/issues/181 if len(shape) > 0 and isinstance(shape[0], JVPDual): shape, _ = safe_zip(*shape) if len(broadcast_dimensions) > 0 and isinstance(broadcast_dimensions[0], JVPDual): @@ -3563,7 +3561,6 @@ def put_grad(v: Variable, val: Any) -> None: if symbol.sym.id == "torch.nn.functional.dropout" and not symbol.subsymbols: # We can skip the pullback if the dropout probability is 0.0 # Assuming that the dropout symbol has the same output and argument - # https://github.com/Lightning-AI/lightning-thunder/issues/906 assert symbol.output.name == symbol.args[0].name, "Dropout symbol has a different output and argument" if symbol.args[1] == 0.0 or symbol.args[2] is False: continue @@ -3642,7 +3639,7 @@ def is_differentiable(arg): result = tuple(next(iter_result) if is_differentiable(arg) else None for arg in symbol.args) - # See https://github.com/Lightning-AI/lightning-thunder/issues/977. + # See "Backward impl for ops of the type Sequence[TensorProxy], ... -> ... results in None grads." # This is a temporary workaround. if symbol.sym.id in (prims.PrimIDs.CAT, "torch.cat", "torch.stack"): safe_map_flat(put_grad, symbol.args, result) @@ -3679,7 +3676,8 @@ def vjp_call_metafunc(detached: bool, primals, cotangents, trace: Trace, **kwarg # TODO: Can't use a Symbol here because mixed executor sybsymbols seem to be -# unsupported. See https://github.com/Lightning-AI/lightning-thunder/issues/1308 +# unsupported. See issue "Could not find an executor for bound symbol when its subsymbols +# are not fully supported by a single executor" vjp_call = partial( vjp_call_metafunc, False ) # Symbol(id=Transforms.VjpOp, name="vjp_call", meta=partial(vjp_call_metafunc, False)) @@ -3804,7 +3802,7 @@ def unpacking_fn(saved_for_backward, cotangents): # NOTE: Returning namedtuples from compiled functions doesn't work. See: -# https://github.com/Lightning-AI/lightning-thunder/issues/881 +# "Allow returning namedtuples from compiled functions" # Note [Grad forward output spec] # If it did work it would be nice to use this namedtuple # instead of the plain tuple or dict that we're using now. 
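To make the NON_FUNCTIONAL_OPS comment above concrete: common-subexpression elimination keys expressions on (op, args), which is only sound for referentially transparent ops. A toy sketch (not thunder's CSE implementation) of how coalescing goes wrong for a random op:

    import torch

    # A naive CSE: key each call on (op, args) and reuse the first result.
    _seen: dict[tuple, torch.Tensor] = {}

    def naive_cse(op: str, args: tuple, compute) -> torch.Tensor:
        key = (op, args)
        if key not in _seen:
            _seen[key] = compute()
        return _seen[key]

    # uniform-like ops are not referentially transparent: two calls with
    # identical arguments must yield independent samples.
    a = naive_cse("torch.rand", (3,), lambda: torch.rand(3))
    b = naive_cse("torch.rand", (3,), lambda: torch.rand(3))
    assert a is b  # the naive CSE wrongly aliased two distinct samples
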
diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index ecb280149f..39ae65bda5 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -234,7 +234,7 @@ def main():
         # Starts broadcasts
         # TODO Make these broadcast asyncs
         # TODO Perform up to two broadcasts at a time
-        # https://github.com/Lightning-AI/lightning-thunder/issues/727
+        # See issue "Update ddp to use async broadcasts"
         # TODO "Bucket" small tensors together before broadcasting
         with torch.no_grad():
             for param in model.parameters():
diff --git a/thunder/distributed/transforms/fsdp.py b/thunder/distributed/transforms/fsdp.py
index c07afce08e..ad92f5e6db 100644
--- a/thunder/distributed/transforms/fsdp.py
+++ b/thunder/distributed/transforms/fsdp.py
@@ -198,7 +198,7 @@ def maybe_swap_proxies_of_bsym_and_update_swap_map(bsym: BoundSymbol) -> bool:
                 lambda: f"{variableify(param)} not found in param set: {(variableify(p) for p in self.original_params)}",
             )
             if param not in self.param_to_bucket:
-                # This path is hihly likely to be backward reduce-scatter bucketing:
+                # This path is highly likely to be backward reduce-scatter bucketing:
                 # when a param does not require grad, a trace could still have reduce-scatter
                 # and wait in its trace while the grad in the return statement is already
                 # replaced with `None`.
diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py
index f34f98f3ad..9d9fa7bf42 100644
--- a/thunder/distributed/utils.py
+++ b/thunder/distributed/utils.py
@@ -52,7 +52,6 @@ def key(node: Node) -> int:
 
 
 # TODO: Currently prefer the most memory-efficient way for ZeRO3,
-# https://github.com/Lightning-AI/lightning-thunder/issues/1925
 # Need a strategy to balance the efficiency
 # and memory usage in the future
 def sort_waits_for_zero3(execution_trace):
diff --git a/thunder/examine/__init__.py b/thunder/examine/__init__.py
index 23661586bb..c75bc5f3b7 100644
--- a/thunder/examine/__init__.py
+++ b/thunder/examine/__init__.py
@@ -43,7 +43,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
 # TODO Maybe have this print additional information and return more metadata?
-# TODO Accept kwargs for compile (like langctx)
+# TODO Accept kwargs for jit (like langctx)
 # TODO Add profiling (or profiling option) to determine if we have a slowdown
 # TODO If an error occurs, try to minify the program to produce a smaller sample to reproduce the error
 def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs):
@@ -141,7 +141,7 @@ def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs):
 
         return
 
-    # Step 3 Attempts to compile the function using lightning.compile
+    # Step 3 Attempts to compile the function using thunder.jit
    try:
        cfn = thunder.jit(fn)
    except Exception as e:
@@ -151,7 +151,7 @@ def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs):
        )
        raise e

-    # Step 4 Attemps to execute the function using lightning.compile
+    # Step 4 Attempts to execute the function using thunder.jit
    lc_result: Any
    try:
        lc_result = cfn(*args, **kwargs)
diff --git a/thunder/examine/memory_caculation.py b/thunder/examine/memory_caculation.py
index 45cc74a9a0..149349748c 100644
--- a/thunder/examine/memory_caculation.py
+++ b/thunder/examine/memory_caculation.py
@@ -22,7 +22,7 @@
     "torch_wait_prim_impl",
 )
 
-# A whitelist registry of symbols that require special memory calculation;
+# A registry of symbols that require special memory calculation;
 # if not registered, the default memory calculation function is used. 
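As context for the registry comment above, a condensed sketch of the register-or-fall-back pattern it describes. The names here (`memory_impls`, `register_memory_impl`, the byte-counting logic) are hypothetical illustrations, not the actual thunder registry:

    from collections.abc import Callable

    memory_impls: dict[str, Callable[[int, int], int]] = {}

    def register_memory_impl(sym_id: str):
        def deco(fn: Callable[[int, int], int]) -> Callable[[int, int], int]:
            memory_impls[sym_id] = fn
            return fn
        return deco

    def default_memory_impl(numel: int, itemsize: int) -> int:
        # Default assumption: the output allocates numel * itemsize bytes.
        return numel * itemsize

    @register_memory_impl("reshape")
    def reshape_memory_impl(numel: int, itemsize: int) -> int:
        return 0  # a view allocates nothing, so it needs a special rule

    def calculate_memory(sym_id: str, numel: int, itemsize: int) -> int:
        # Unregistered symbols fall back to the default calculation.
        return memory_impls.get(sym_id, default_memory_impl)(numel, itemsize)

    assert calculate_memory("add", 8, 4) == 32     # default path
    assert calculate_memory("reshape", 8, 4) == 0  # registered special case
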
memory_calculate_impls: dict[Symbol, Callable] = dict() diff --git a/thunder/executors/apex_entropyex.py b/thunder/executors/apex_entropyex.py index 7d622cf25c..8a82e04e20 100644 --- a/thunder/executors/apex_entropyex.py +++ b/thunder/executors/apex_entropyex.py @@ -49,7 +49,8 @@ def apex_available() -> bool: # TODO Consider performing the reduction as part of a traceable epilogue -# See https://github.com/Lightning-AI/lightning-thunder/issues/1357 +# See "Update the apex cross entropy executor to put its reduction in a +# traceable epilogue" # NOTE Apex's cross entropy doesn't accept ignore_index >= 0, or the weight, size_average, or reduce parameters def _apex_cross_entropy_impl( a: torch.Tensor, @@ -196,12 +197,10 @@ def _cross_entropy_checker( return True -# Check out -# https://github.com/Lightning-AI/lightning-thunder/blob/main/dev_tutorials/thunder-add-vjp-rule.md -# for a tutorial on how to add a VJP rule for any Symbol. We use our new -# primitives to register a VJP rule for torch.nn.functional.cross_entropy. This -# function is registered as the augmented forward rule for -# torch.nn.functional.cross_entropy below +# Check out the 'add vjp rule' dev tutorial on how to add a VJP rule for any +# Symbol. We use our new primitives to register a VJP rule for +# torch.nn.functional.cross_entropy. This function is registered as the +# augmented forward rule for torch.nn.functional.cross_entropy below def apex_cross_entropy_forward_rule( a, target, diff --git a/thunder/executors/cudnn_layernormex.py b/thunder/executors/cudnn_layernormex.py index 1a84811278..b6f260cdd7 100644 --- a/thunder/executors/cudnn_layernormex.py +++ b/thunder/executors/cudnn_layernormex.py @@ -19,9 +19,7 @@ def cudnn_available() -> bool: return CUDNN_AVAILABLE -# WARNING: cudnn executor is experimental. Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880~ - +# WARNING: cudnn layernorm executor is experimental. Tests that use cudnn might fail. from dataclasses import dataclass from functools import lru_cache from typing import Union, Dict diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 20759bb085..9fb6c50a48 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -21,9 +21,6 @@ def cudnn_available() -> bool: return CUDNN_AVAILABLE -# WARNING: cudnn executor is experimental. 
Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880~ - from dataclasses import dataclass from functools import lru_cache from typing import Union, Dict diff --git a/thunder/executors/nvfuserex.py b/thunder/executors/nvfuserex.py index 3e25890dde..27d4ce721b 100644 --- a/thunder/executors/nvfuserex.py +++ b/thunder/executors/nvfuserex.py @@ -33,7 +33,6 @@ def required_nvfuser_version() -> LooseVersion: return LooseVersion("0.0.1") -# NOTE We require nvFuser version 0.0.1 or greater def nvfuser_available() -> bool: v = nvfuser_version() return v is not None and v >= required_nvfuser_version() diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index d9edbe5283..5204ab1d24 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -206,7 +206,7 @@ def create_fd( # NOTE nvFuser's default max length is 1024 operations at the time of this writing # This arbitrarily increases it to 9999 # TODO Review splititng very large fusions or removing the max length restriction completely - # See https://github.com/Lightning-AI/lightning-thunder/issues/901 + # See "Very large nvFuser fusions hit max_length" fd = FusionDefinition(max_length=9999) with fd: # NOTE Adding constants is disabled for the moment in favor of definining them inline @@ -524,8 +524,7 @@ class nvFuserExecutor(FusionExecutor): def __init__(self): super().__init__("nvfuser", version=nvfuser.version()) - # TODO: Replace this with a query to current CompileData after - # https://github.com/Lightning-AI/lightning-thunder/pull/1517 is merged + # TODO: Replace this with a query to a compile option self._use_rematerialization = True fuel_str = os.getenv("NVFUSER_OPTIMIZATION_FUEL") @@ -762,8 +761,6 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: # TODO has_cuda_input_or_output is too restrictive a check on what should be fused # TODO check whether a function would output a CPU tensor? -- can nvFuser fuse such operations? # ex. device_put to a CPU device from a CUDA device - # (mruberry) I don't know if nvFuser even attempts to fuse any operation that can go - # cross-device today def _should_fuse(a: Node, b: Node): def _can_fuse_node(n: Node): # if already merged, then node can be fused @@ -842,8 +839,7 @@ def _can_fuse_node(n: Node): # Some of the operations might be better placed with its consumers (for # example residual connection in transformer block). This pass moves - # them to the consumer. See - # https://github.com/Lightning-AI/lightning-thunder/issues/1520 + # them to the consumer. if self._use_rematerialization: fusedtrace = rematerialize(fusedtrace) @@ -1091,8 +1087,8 @@ def broadcast_in_dim( def _cat_check(tensors: list[TensorProxy], dim: int) -> bool: - # nvFuser cat fusion is currently disabled due to - # https://github.com/Lightning-AI/lightning-thunder/issues/1071 + # nvFuser cat fusion is currently disabled due to issue: + # "nvFuser doesn't support cating with an empty tensor" return False # Validates tensors and concatenated dimension lengths @@ -1156,7 +1152,7 @@ def _pad_check(a: TensorProxy, padding_value: Number, padding_config: tuple[int, # nvFuser's pad op requires pad_widths to be a sequence of Python numbers # (lo_n, hi_n, lo_{n-1}, hi_{n-1}, ...) where dimensions are counted in reverse # as shown, and dilation is not supported. 
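(The hunk continues below by contrasting this with thunder's convention.) To make the reversed flat layout concrete, a small sketch converting an ndim-length list of (lo, hi, dilation) triples into the flat sequence nvFuser expects; `to_nvfuser_pad_widths` is a hypothetical helper written for this illustration:

    def to_nvfuser_pad_widths(padding_config: list[tuple[int, int, int]]) -> list[int]:
        # nvFuser wants (lo_n, hi_n, lo_{n-1}, hi_{n-1}, ...): last dim first.
        widths: list[int] = []
        for lo, hi, dilation in reversed(padding_config):
            assert dilation == 0, "nvFuser's pad op does not support dilation"
            widths.extend((lo, hi))
        return widths

    # Pad a 2D tensor by 1 on each side of dim 0 and 2 on each side of dim 1:
    assert to_nvfuser_pad_widths([(1, 1, 0), (2, 2, 0)]) == [2, 2, 1, 1]
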
-# This is in constrant to lightning.compile's pad primitive, which specifies padding
+# This is in contrast to thunder.jit's pad primitive, which specifies padding
 # and dilation as an ndim-length list of (lo, hi, dilation) triples.
 # NOTE padding_value must be an nvConstant (or nvScalar?)
 def pad(
@@ -1262,7 +1258,8 @@ def squeeze(a: TensorProxy, /, dims: Sequence[int], *, fd: FusionDefinition, lc_
 # register_supported(PrimIDs.TAKE, take, _take_check)
 
 # TAKE_ALONG_AXIS is currently disabled
-# See https://github.com/NVIDIA/Fuser/issues/458
+# There was an nvFuser bug that prevented this, which is now fixed; we should
+# investigate re-enabling take_along_axis.
 # # TODO Check that the nvFuser version is >= 0.0.10 when this operator was added
 # def take_along_axis(a: TensorProxy, /, index: TensorProxy, dim: int, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
 #     nv_a = getnv(a, fd, lc_to_nv_map)
@@ -1715,12 +1712,6 @@ def div(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefinitio
     nva = getnv(a, fd, lc_to_nv_map)
     nvb = getnv(b, fd, lc_to_nv_map)
 
-    # TODO nvFuser sometimes generates an innacurate result when dividing by a number
-    # Remove this workaround once the issue is fixed
-    # See: https://github.com/NVIDIA/Fuser/issues/160
-    if isinstance(b, Number):
-        return fd.ops.mul(nva, fd.ops.reciprocal(nvb))
-
     # NOTE It's currently significantly faster for nvFuser to multiply the reciprocal than divide
     # return fd.ops.div(nva, nvb)
     return fd.ops.mul(nva, fd.ops.reciprocal(nvb))
diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py
index 144aa30974..8f1604e718 100644
--- a/thunder/executors/passes.py
+++ b/thunder/executors/passes.py
@@ -162,8 +162,8 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor])
     return extrace
 
-# NOTE: See more details for motivation in the following issue:
-# https://github.com/Lightning-AI/lightning-thunder/issues/515
+# This is needed to ensure that subsymbol changes are reflected in the Python
+# code generator.
 def _update_fusion_call_ctx(bsym: BoundSymbol) -> BoundSymbol:
     """Update the call_ctx information of the fusion BoundSymbol object.
 
diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py
index 932fdc934f..005171e4e1 100644
--- a/thunder/executors/sdpaex.py
+++ b/thunder/executors/sdpaex.py
@@ -50,8 +50,8 @@ def ceil_div(a: int, b: int) -> int:
 def _sdpa_pad_head_dimension(a: torch.Tensor) -> torch.Tensor:
     head_size = a.shape[-1]
 
-    # NOTE short-circuit path when we already have compatible head_size
-    # See https://github.com/Lightning-AI/lightning-thunder/issues/1505
+    # If the head is already a multiple of 8, then we don't need to pad. The
+    # pad op can be quite expensive in some cases.
     if head_size % 8 == 0:
         return a
     padding_size = ceil_div(head_size, 8) * 8 - head_size
@@ -59,8 +59,7 @@ def _sdpa_pad_head_dimension(a: torch.Tensor) -> torch.Tensor:
 
 
 def _sdpa_slice_head_dimension(a: torch.Tensor, head_size: int) -> torch.Tensor:
-    # NOTE short-circuit path when we already have compatible head_size
-    # See https://github.com/Lightning-AI/lightning-thunder/issues/1505
+    # ditto pad_head_dimension: the slice can be expensive, so skip if possible.
     if head_size % 8 == 0:
         return a
     return a[:, :, :, 0:head_size]
@@ -491,8 +490,8 @@ def _scaled_dot_product_attention_fused(
     *,
     scale: None | float = None,
 ):
-    # NOTE Select fused sdpa using PyTorch eager mode selection behavior
-    # See https://github.com/Lightning-AI/lightning-thunder/issues/622
+    # Figure out which SDPA to use. 
There are performance cliffs to the various + # implementations, and this makes the decision cognizant of those cliffs. backend = _fused_sdp_choice(query, key, value, attn_mask, dropout_p, is_causal, scale) utils.check( @@ -530,8 +529,8 @@ def _scaled_dot_product_attention_grad( *, scale: None | float = None, ): - # NOTE Select fused sdpa using PyTorch eager mode selection behavior - # See https://github.com/Lightning-AI/lightning-thunder/issues/622 + # Figure out which SDPA to use. There are performance cliffs to the various + # implementations, and this makes the decision cognizant of those cliffs. backend = _fused_sdp_choice(query, key, value, attn_mask, dropout_p, is_causal, scale) utils.check( @@ -640,8 +639,9 @@ def _fused_sdp_choice( is_causal = is_causal.value if LooseVersion(torch.__version__) < LooseVersion("2.2.0"): - # NOTE Select fused sdpa using PyTorch eager mode selection behavior - # See https://github.com/Lightning-AI/lightning-thunder/issues/622 + # Figure out which SDPA to use. There are performance cliffs to the + # various implementations, and this makes the decision cognizant of + # those cliffs. backend = torch._fused_sdp_choice( fake_query, fake_key, diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 9d82e12526..a48635e878 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -72,7 +72,8 @@ def torch_interpreted_func(*args): # _transform_for_operator_executor_execution implementation that need to be # fixed first. One issue is that it doesn't maintain the ssa form of the # trace, which is needed for all the passes to work correctly. - # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1767 + # TODO: issue "Try using _transform_for_operator_executor_execution for + # torch.compile executor" torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs) compiled_func = torch.compile(torch_trace.python_callable()) @@ -84,13 +85,13 @@ def compiled_func_wrapper(*args): orig = getattr(torch._dynamo.eval_frame.guarded_backend_cache, "skip_backend_check_for_run_only_mode", None) try: # TODO: Remove this hack - # This is a hack to get around the fact that for some reason Dynamo - # doesn't recreate a guard for the compiled function called from the - # backward thread. This is a problem because the guard is created - # with the forward thread id, and the guard is not valid for the - # backward thread. I couldn't come up with a small repro to file an - # issue to PyTorch. - # https://github.com/pytorch/pytorch/issues/114674 + # Dynamo doesn't recreate a guard for the compiled function called + # from the backward thread. This is a problem because the guard is + # created with the forward thread ID, and the guard is not valid + # for the backward thread. + # Issue filed: https://github.com/pytorch/pytorch/issues/114674 + # We should be able to remove this hack once we're sure that the + # above fix has propagated to all supported PyTorch releases. 
torch._dynamo.eval_frame.guarded_backend_cache.skip_backend_check_for_run_only_mode = True return compiled_func(*args) finally: diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 85cfd9c2ce..535fddd943 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -633,7 +633,7 @@ def _stride_order_prim_impl(a: torch.Tensor, order: Sequence[int]) -> torch.Tens rsqrt = _register_torch_operation("rsqrt") # # NOTE That PyTorch's "sgn" corresponds with the "sign" primitive sgn = _register_torch_operation("sgn", like=ltorch.sign) -# # NOTE torch.sign isn't bound here because lightning.compile always uses sgn +# # NOTE torch.sign isn't bound here because thunder always uses sgn # sign = _register_torch_operation("sign") signbit = _register_torch_operation("signbit") sin = _register_torch_operation("sin") @@ -1095,8 +1095,8 @@ def _index_put_prim_transform( return index_put(a, indices, values, accumulate) -# NOTE torch.compile currently fails to compile scatter add in bfloat16 -# TODO RC1 Separate this into a torch.compile executor +# NOTE torch.compile has a compilation issue with scatter add in bfloat16, +# hence the special case here. # NOTE The scatter add transforms must set the torch language context explicitly so the .to() method # on tensors is resolved (alternatively they could explicitly call thunder.torch.to) @langctx(Languages.TORCH) @@ -1109,7 +1109,8 @@ def _scatter_add_prim_transform(a: TensorProxy, /, index: TensorProxy, value: Te return scatter_add(a, dim, index, value) -# NOTE torch.compile currently fails to compile scatter add in bfloat16 +# NOTE torch.compile has a compilation issue with scatter add in bfloat16, +# hence the special case here. @langctx(Languages.TORCH) def _scatter_add_transform(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike: # NOTE scatter_add does not participate in type promotion, so if a has the bfloat16 dtype, then so does src @@ -1243,7 +1244,7 @@ def _cross_entropy_backward_impl( ) # TODO Add support nll_loss_nd, weight tensor, and label_smoothing options. - # See https://github.com/Lightning-AI/lightning-thunder/issues/704 + # See issue "Add support for remaining cross_entropy_loss arguments." utils.check(a.ndim <= 2 and target.ndim <= 1, lambda: f"multi-dimension cross-entropy is not supported.") utils.check(weight is None, lambda: f"weight tensor argument is not supported.") @@ -1601,7 +1602,7 @@ def _unpack_prim_impl( ) -> list[torch.Tensor]: return torch._utils._unflatten_dense_tensors(buffer, tensors) - # TODO(crcrpar): Make this compatible with the coming torch_compile executor as it's doing really well for cat and reshape. + # TODO(crcrpar): Make this compatible with the torch.compile executor as it's doing really well for cat and reshape. # NOTE(crcrpar): why no caching/resue of buffer? # This prim is only used by fsdp backward for now. # Bucketing of reduce-scatter, i.e., creating a buffer for @@ -1621,7 +1622,6 @@ def _unpack_prim_impl( # To support individual copies from gradient to its bucket requires a mask or an arrayy of indices to achieve correct behavior. # In PyTorch, the op for this is [`Tensor.index_copy_`](https://pytorch.org/docs/stable/generated/torch.Tensor.index_copy_.html) where even the index tensor needs to be on the same device as ``self`` and ``tensor``. # So caching of the bucketing for fsdp backward would bloat up the memory consumption, which is the main reason this doesn't do any caching. 
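To see the buffer relationship this comment describes, a minimal round trip with the same private torch helpers the prim impl above calls (shapes borrowed from the [4, 2] and [4] example that follows):

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    # Two gradients, as in the example below: shapes [4, 2] and [4].
    grads = [torch.randn(4, 2), torch.randn(4)]

    flat = _flatten_dense_tensors(grads)           # one contiguous bucket
    views = _unflatten_dense_tensors(flat, grads)  # per-grad views of the bucket

    assert flat.numel() == sum(g.numel() for g in grads)
    for g, v in zip(grads, views):
        assert v.shape == g.shape and torch.equal(g, v)
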
- # See https://github.com/Lightning-AI/lightning-thunder/pull/1669/commits/a942b87e88738ce94f874c21d4adc38749ff10d7#diff-c2fd275781ba0c4aa7eec811bebb7bf0b6ca52a236b510ce7dfbb831d4d9bb40R197-R233 for the potential implementation's clumisiness. # # example of two unsharded gradients of [4, 2] and [4], world size of 4: # -------- ------ diff --git a/thunder/numpy/langctx.py b/thunder/numpy/langctx.py index 0f6e445d8f..75961b6efb 100644 --- a/thunder/numpy/langctx.py +++ b/thunder/numpy/langctx.py @@ -23,7 +23,7 @@ def has_method(self, id: str) -> bool: return id in _method_name_to_fn_map def get_method(self, id: str, *args, **kwargs) -> Callable: - # Note: concrete implmenetations should only raise AttributeError or + # Note: concrete implementations should only raise AttributeError or # return None for "missing" methods as the proxies will # route __getattr__ to here and hasattr relies on __getattr__ # throwing AttributeError (only) when the attribute does diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 9b7ad2244f..c321a01394 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -134,7 +134,7 @@ def _run(cls, rank, test_name, file_name, pipe): "DDP test requires CUDA and NCCL `torch.distributed` backend", ) class CompileDDPTest(DataParallelTestCase): - # Ref: https://github.com/Lightning-AI/lightning-thunder/issues/646 + # Reference issue "Add an example of DDP(compile(model)) to tests" def test_ddp_compile_module(self): model = ToyModel().to(self.rank) ddp_model = DDP(thunder.jit(model, device_ids=[self.rank])) @@ -157,7 +157,7 @@ def test_ddp_compile_module(self): last_loss = loss.detach().item() assert init_loss > last_loss - # Ref: https://github.com/Lightning-AI/lightning-thunder/issues/599 + # Reference issue "[tracker] Support DistributedDataParallel" def test_compile_ddp_module(self): model = ToyModel().to(self.rank) with self.assertRaisesRegex( @@ -652,7 +652,7 @@ def test_ddp_grad_parity_with_without_bucketing(self, executor): else: self.assertEqual(tuple(p.grad for p in cm.parameters() if p.grad is not None), gradients) - # TODO(crcrpar): Add torch compile to executors_list once it's available. 
+ # TODO(crcrpar): Add torch compile to executors_list @common_utils.parametrize( "executor,bucketing_strategy,fsdptype", product( diff --git a/thunder/tests/lit_gpt_model.py b/thunder/tests/lit_gpt_model.py index 32889c6cc5..57a85089bc 100644 --- a/thunder/tests/lit_gpt_model.py +++ b/thunder/tests/lit_gpt_model.py @@ -139,7 +139,7 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> k = self.k.index_copy_(2, input_pos, k) v = self.v.index_copy_(2, input_pos, v) return k, v - # THUNDER unsupported: https://github.com/Lightning-AI/lightning-thunder/issues/1145 + # See issue: "Support more indexing operators (index_copy and index_add)" k = self.k = torch.index_add(self.k, 2, input_pos, k) v = self.v = torch.index_add(self.v, 2, input_pos, v) # THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro) diff --git a/thunder/tests/llama2_model.py b/thunder/tests/llama2_model.py index bf70e56531..b8277cd757 100644 --- a/thunder/tests/llama2_model.py +++ b/thunder/tests/llama2_model.py @@ -64,7 +64,6 @@ def apply_rotary_emb( xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting - # https://github.com/Lightning-AI/lightning-thunder/issues/1106 a, b = freqs_cos.shape freqs_cos = freqs_cos.view(1, a, 1, b) freqs_sin = freqs_sin.view(1, a, 1, b) @@ -247,7 +246,8 @@ def forward(self, tokens: torch.Tensor, targets: torch.Tensor | None = None) -> if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - # https://github.com/Lightning-AI/lightning-thunder/issues/1108 + # Workaround for issue "Unexpected KeyError when self attribute is + # set inside forward" # self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the output on the very last position diff --git a/thunder/tests/nanogpt_model.py b/thunder/tests/nanogpt_model.py index 128ac8c0f4..9d81b3abb8 100644 --- a/thunder/tests/nanogpt_model.py +++ b/thunder/tests/nanogpt_model.py @@ -69,8 +69,8 @@ def __init__(self, config): # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") # NOTE: The original Karpathy's script hides bias registration behind a flag - # but we don't do that here. We always register bias, because of preprocessing bug: - # https://github.com/Lightning-AI/lightning-thunder/issues/605 + # but we don't do that here. We always register bias due to a now-fixed + # bug in thunder. # TODO: Move the bias registration to be happening `if not self.flash` once the bug is fixed. # if not self.flash: # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") @@ -173,8 +173,8 @@ def __init__(self, config): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # with weight tying when using torch.compile() some warnings get generated: # "UserWarning: functional_call was passed multiple values for tied weights. - # This behavior is deprecated and will be an error in future versions" - # not 100% sure what this is, so far seems to be harmless. TODO investigate + # This behavior is deprecated and will be an error in future versions". + # So far this seems to be harmless. 
TODO investigate
         self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying
 
         # init all weights
@@ -236,7 +236,6 @@ def forward(self, idx, targets=None):
             # NOTE: Advanced indexing is not yet supported in Thunder
             # RuntimeError: Advanced indexing currently only supports tensors as sequence elements
             # inference-time mini-optimization: only forward the lm_head on the very last position
-            # See https://github.com/Lightning-AI/lightning-thunder/issues/894
             # logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
             logits = self.lm_head(x)
             loss = None
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 30376202a7..d934e18dfd 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -770,8 +770,8 @@ def _abs_torch(x: torch.Tensor | Number):
 # digamma is defined for all complex numbers EXCEPT negative integers and zero
 digamma_opinfo = OpInfo(
     clang.digamma,
-    # NOTE: Restrict domain to avoid singularities because of
-    # https://github.com/Lightning-AI/lightning-thunder/issues/1138
+    # NOTE: Restrict domain to avoid singularities because of issue
+    # "OpInfos do not use singularity_fn to produce "more stable" samples."
     domain=(eps, math.inf),
     # NOTE: digamma returns NaN for all negative integers. It returns -Inf when x = 0.
     singularity_fn=lambda x: torch.where(x > 0, x, (x - torch.round(x))),
@@ -1079,7 +1079,7 @@ def _abs_torch(x: torch.Tensor | Number):
         dtypes=(datatypes.float16, datatypes.complex32),
         devicetypes=(devices.DeviceType.CPU,),
     ),
-    # see https://github.com/csarofeen/pytorch/issues/2367
+    # Used to be an nvFuser bug here; TODO explore removing this xfail
     DecorateInfo(
         pytest.mark.xfail,
         "test_core_vs_torch_consistency",
@@ -1179,8 +1179,7 @@ def _abs_torch(x: torch.Tensor | Number):
     sample_input_generator=elementwise_unary_generator,
     torch_reference=_elementwise_unary_torch(torch.sgn),
     test_directives=(
-        # TODO Need to add nvfuser specific support for complex sign
-        # https://github.com/csarofeen/pytorch/issues/2492
+        # TODO nvFuser needs support for complex sign
         DecorateInfo(
             pytest.mark.xfail,
             dtypes=(datatypes.complexfloating,),
@@ -1284,7 +1283,9 @@ def _abs_torch(x: torch.Tensor | Number):
     sample_input_generator=elementwise_unary_generator,
     torch_reference=_elementwise_unary_torch(torch.tan),
     test_directives=(
-        # See https://github.com/csarofeen/pytorch/issues/2360
+        # TODO investigate nvFuser's implementation here; for complex datatypes
+        # nvFuser's tan might be inaccurate, causing numerical mismatches, but
+        # also this concern is potentially stale in 03/2024.
         DecorateInfo(
             pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,)
         ),
@@ -1305,7 +1306,9 @@ def _abs_torch(x: torch.Tensor | Number):
     sample_input_generator=elementwise_unary_generator,
     torch_reference=_elementwise_unary_torch(torch.tanh),
     test_directives=(
-        # See https://github.com/csarofeen/pytorch/issues/2360
+        # TODO investigate nvFuser's implementation here; for complex datatypes
+        # nvFuser's tanh might be inaccurate, causing numerical mismatches, but
+        # also this concern is potentially stale in 03/2024. 
        DecorateInfo(
            pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,)
        ),
@@ -1358,7 +1361,9 @@ def _abs_torch(x: torch.Tensor | Number):
     sample_input_generator=partial(elementwise_unary_generator, exclude_zero=True),
     torch_reference=_elementwise_unary_torch(torch.log),
     test_directives=(
-        # See https://github.com/csarofeen/pytorch/issues/2360
+        # TODO investigate nvFuser's implementation here; for complex datatypes
+        # nvFuser's log might be inaccurate, causing numerical mismatches, but
+        # also this concern is potentially stale in 03/2024.
         DecorateInfo(
             pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,)
         ),
@@ -1379,7 +1384,9 @@ def _abs_torch(x: torch.Tensor | Number):
     sample_input_generator=partial(elementwise_unary_generator, exclude_zero=True),
     torch_reference=_elementwise_unary_torch(torch.log10),
     test_directives=(
-        # See https://github.com/csarofeen/pytorch/issues/2360
+        # TODO investigate nvFuser's implementation here; for complex datatypes
+        # nvFuser's log10 might be inaccurate, causing numerical mismatches, but
+        # also this concern is potentially stale in 03/2024.
         DecorateInfo(
             pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,)
         ),
@@ -1407,14 +1414,16 @@ def _abs_torch(x: torch.Tensor | Number):
     sample_input_generator=elementwise_unary_generator,
     torch_reference=_elementwise_unary_torch(torch.log1p),
     test_directives=(
-        # See https://github.com/csarofeen/pytorch/issues/2360
+        # TODO investigate nvFuser's implementation here; for complex datatypes
+        # nvFuser's log1p might be inaccurate, causing numerical mismatches, but
+        # also this concern is potentially stale in 03/2024.
         DecorateInfo(
             pytest.mark.xfail,
             "test_core_vs_torch_consistency",
             executors=("nvfuser",),
             dtypes=(datatypes.complexfloating,),
         ),
-        # NOTE: Torch gives wrong result: https://github.com/pytorch/pytorch/issues/94333
+        # NOTE: Torch has an issue: https://github.com/pytorch/pytorch/issues/94333
         DecorateInfo(
             pytest.mark.skip,
             "test_core_vs_torch_consistency",
@@ -1452,7 +1461,9 @@ def _abs_torch(x: torch.Tensor | Number):
     sample_input_generator=partial(elementwise_unary_generator, exclude_zero=True),
     torch_reference=_elementwise_unary_torch(torch.log2),
     test_directives=(
-        # See https://github.com/csarofeen/pytorch/issues/2360
+        # TODO investigate nvFuser's implementation here; for complex datatypes
+        # nvFuser's log2 might be inaccurate, causing numerical mismatches, but
+        # also this concern is potentially stale in 03/2024.
         DecorateInfo(
             pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,)
         ),
@@ -1556,7 +1567,7 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
         dtypes=(datatypes.float16,),
         devicetypes=(devices.DeviceType.CPU,),
     ),
-    # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1444
+    # TODO: we might have a tolerance issue here with relu6.
     DecorateInfo(
         pytest.mark.xfail,
         "test_vjp_correctness",
@@ -1577,7 +1588,7 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
         "test_core_vs_torch_consistency",
         dtypes=(datatypes.bool8,),
    ),
-    # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1444
+    # TODO: we might have a tolerance issue here with relu6. 
    DecorateInfo(
        pytest.mark.xfail(strict=False),
        "test_vjp_correctness",
@@ -1613,7 +1624,7 @@ def selu_error_generator(op, device, dtype=torch.float32, **kwargs):
            datatypes.bfloat16,
        ),
    ),
-    # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1444
+    # TODO: we might have a tolerance issue here with selu, as with relu6.
    DecorateInfo(
        pytest.mark.xfail,
        "test_vjp_correctness",
@@ -1674,8 +1685,9 @@ def selu_error_generator(op, device, dtype=torch.float32, **kwargs):
        devicetypes=(devices.DeviceType.CPU,),
        active_if=LooseVersion(torch.__version__) < "1.13",
    ),
-    # TODO: nvfuser needs to return copy for integer dtypes.
-    # https://github.com/csarofeen/pytorch/issues/2499
+    # TODO: nvFuser does not define an integer trunc() and thus compilation
+    # fails. They should probably map integer trunc() to an identity op.
+    # Until they do, this test won't work for integer types.
    DecorateInfo(
        pytest.mark.xfail,
        "test_core_vs_torch_consistency",
@@ -1742,7 +1754,8 @@ def elementwise_binary_generator(op, device, dtype, requires_grad, *, no_rhs_num
    sample_input_generator=elementwise_binary_generator,
    torch_reference=torch.add,
    test_directives=(
-        # See https://github.com/csarofeen/pytorch/issues/2549
+        # See issue "broadcast_in_dim: The size of contiguity must equal to the
+        # number of non-broadcasting IterDomains"
        DecorateInfo(
            pytest.mark.skip,
            "test_jvp_correctness",
@@ -1810,7 +1823,7 @@ def elementwise_binary_generator(op, device, dtype, requires_grad, *, no_rhs_num
    sample_input_generator=elementwise_binary_generator,
    torch_reference=torch.copysign,
    test_directives=(
-        # See https://github.com/Lightning-AI/lightning-thunder/issues/2218
+        # See issue: "flaky test:
+        # test_vjp_correctness_copysign_torch_cuda_float64 is flaky"
        DecorateInfo(
            pytest.mark.xfail,
            "test_vjp_correctness",
@@ -1827,8 +1841,7 @@ def elementwise_binary_generator(op, device, dtype, requires_grad, *, no_rhs_num
    sample_input_generator=elementwise_comparison_generator,
    torch_reference=torch.eq,
    test_directives=(
-        # There's a problem of reducing a tensor produced by full op
-        # See https://github.com/NVIDIA/Fuser/issues/132
+        # TODO: enable this; there was a now-fixed nvFuser bug causing issues.
        DecorateInfo(
            pytest.mark.xfail,
            "test_vjp_correctness",
@@ -1926,8 +1939,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
    sample_input_generator=elementwise_comparison_generator,
    torch_reference=torch.ge,
    test_directives=(
-        # There's a problem of reducing a tensor produced by full op
-        # See https://github.com/NVIDIA/Fuser/issues/132
+        # TODO: enable this; there was a now-fixed nvFuser bug causing issues.
        DecorateInfo(
            pytest.mark.xfail,
            "test_vjp_correctness",
@@ -1969,8 +1981,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
    sample_input_generator=elementwise_comparison_generator,
    torch_reference=torch.le,
    test_directives=(
-        # There's a problem of reducing a tensor produced by full op
-        # See https://github.com/NVIDIA/Fuser/issues/132
+        # TODO: enable this; there was a now-fixed nvFuser bug causing issues.
        DecorateInfo(
            pytest.mark.xfail,
            "test_vjp_correctness",
@@ -1987,8 +1998,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
    sample_input_generator=elementwise_comparison_generator,
    torch_reference=torch.lt,
    test_directives=(
-        # There's a problem of reducing a tensor produced by full op
-        # See https://github.com/NVIDIA/Fuser/issues/132
+        # TODO: enable this; there was a now-fixed nvFuser bug causing issues. 
DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -2018,7 +2028,8 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=elementwise_binary_generator, torch_reference=torch.mul, test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the + # number of non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_jvp_correctness", @@ -2040,8 +2051,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=elementwise_comparison_generator, torch_reference=torch.ne, test_directives=( - # There's a problem of reducing a tensor produced by full op - # See https://github.com/NVIDIA/Fuser/issues/132 + # TODO: enable this; there was a now-fixed nvFuser bug causing issues. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -2066,15 +2076,15 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): pytest.mark.skip, dtypes=(datatypes.float16, datatypes.bfloat16), ), - # See https://github.com/Lightning-AI/lightning-thunder/issues/972 - # PyTorch's nextafter may be causing CUDA illegal memory accesses + # TODO There was an issue with nextafter in PyTorch that should now be + # resolved; re-enable this and test. DecorateInfo( pytest.mark.skip, "test_core_vs_torch_consistency", devicetypes=(devices.DeviceType.CUDA,), ), - # See https://github.com/Lightning-AI/lightning-thunder/issues/972 - # PyTorch's nextafter may be causing CUDA illegal memory accesses + # TODO There was an issue with nextafter in PyTorch that should now be + # resolved; re-enable this and test. DecorateInfo( pytest.mark.skip, executors=("torch",), @@ -2101,8 +2111,8 @@ def polygamma_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs polygamma_opinfo = OpInfo( ltorch.polygamma, - # NOTE: Restrict domain to avoid singularities because of - # https://github.com/Lightning-AI/lightning-thunder/issues/1138 + # NOTE: Restrict domain to avoid singularities. See issue "OpInfos do not + # use singularity_fn to produce "more stable" samples" # NOTE: polygamma returns NaN, -Inf, or Inf for all negative integers. domain=(eps, math.inf), sample_input_generator=polygamma_sample_input_generator, @@ -2155,7 +2165,7 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe "test_core_vs_torch_consistency", dtypes=(datatypes.complex32,), ), - # See https://github.com/csarofeen/pytorch/issues/2361 + # TODO For complex numbers we have some numerical consistency issues. 
DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", @@ -2241,7 +2251,8 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe ), # torch doesn't support bool true_divide DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.bool8,)), - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the + # number of non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_vjp_correctness", @@ -2384,7 +2395,8 @@ def addcmul_addcdiv_sample_generator(op, device, dtype, requires_grad, **kwargs) "test_core_vs_torch_consistency", dtypes=(datatypes.exact,), ), - # This test is flaky, see https://github.com/Lightning-AI/lightning-thunder/issues/2244 + # See issue "flaky test: + # test_vjp_correctness_addcdiv_nvfuser_cuda_float64" DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -2510,8 +2522,7 @@ def clamp_sample_generator(op, device, dtype, requires_grad, **kwargs): torch_reference=torch.clamp, dtypes=(datatypes.signedinteger, datatypes.unsignedinteger, datatypes.floating), test_directives=( - # This test is flaky - # See https://github.com/Lightning-AI/lightning-thunder/issues/1992 + # see issue "test_vjp_correctness_clamp_nvfuser_cuda_float64 is flaky" DecorateInfo( pytest.mark.skip, "test_vjp_correctness", @@ -2715,7 +2726,8 @@ def broadcast_in_dim_error_generator(op, device, **kwargs): pytest.mark.xfail, "test_errors", ), - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the number of + # non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_jvp_correctness", @@ -3207,8 +3219,7 @@ def pad_sample_generator(op, device, dtype, requires_grad, **kwargs): # Versions of above examples but with padding between elements set to 0 ((2, 2), ((1, 1, 0), (-1, 2, 0))), ((2, 0, 3), ((1, 0, 0), (1, 1, 0), (0, 0, 0))), - # See https://github.com/Lightning-AI/lightning-thunder/issues/415 - # The PyTorch lowering does not handle this case properly + # See issue "PyTorch pad prim lowering handles out-of-bands negative padding incorrectly" # ((7, 5), ((0, 0, 0), (-6, 2, 0))), ((5, 7), ((0, 0, 0), (-6, 2, 0))), ((3, 2, 5), ((-2, 1, 0), (1, -1, 0), (-1, 3, 0))), # negative pad in all 3 dims @@ -3235,12 +3246,12 @@ def _jax_pad(a, padding_value, padding_config): executors=("torch",), dtypes=(datatypes.complexfloating,), ), - # See issue https://github.com/Lightning-AI/lightning-thunder/issues/2053 + # See issue "pad+nvFuser: wrong results when applied to 1-numel inputs" DecorateInfo( pytest.mark.xfail, executors=("nvfuser",), ), - # See issue https://github.com/Lightning-AI/lightning-thunder/issues/2053 + # See issue "pad+nvFuser: wrong results when applied to 1-numel inputs" DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -3311,7 +3322,7 @@ def pad_torch_error_generator(op, device, dtype=torch.float32, **kwargs): # TODO: only remove these cases when the executor is nvfuser -# FIXME: Zero-dim cases are skipped due to https://github.com/csarofeen/pytorch/issues/2383 +# TODO: zero-dim cases had a bug, now fixed; re-enable. 
# FIXME: tensors with no elements are skipped because of no nvfuser support def reshape_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3431,8 +3442,8 @@ def slice_in_dim_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=slice_in_dim_sample_generator, jax_reference=jax.lax.slice_in_dim if JAX_AVAILABLE else None, test_directives=( - # nvfuser executor doesn't support pad correctly - # See https://github.com/Lightning-AI/lightning-thunder/issues/285 + # TODO: nvfuser executor didn't support pad correctly, but now it should. + # Test and re-enable. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -3443,8 +3454,8 @@ def slice_in_dim_sample_generator(op, device, dtype, requires_grad, **kwargs): shape_ops.append(slice_in_dim) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/416 -# Add strides and slicing outside tensor boundaries +# See issue "Slice prim samples need strides and slicing beyond tensor +# boundaries" def slice_prim_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3466,8 +3477,8 @@ def slice_prim_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=slice_prim_sample_generator, jax_reference=jax.lax.slice if JAX_AVAILABLE else None, test_directives=( - # nvfuser executor doesn't support pad correctly - # See https://github.com/Lightning-AI/lightning-thunder/issues/285 + # TODO: nvfuser executor didn't support pad correctly, but now it should. + # Test and re-enable. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -3528,7 +3539,7 @@ def split_sample_generator(op, device, dtype, requires_grad, **kwargs): ((4, 6, 7), 3, -1), ((4, 6, 7), 9, 1), ((4, 6, 7), (1, 2, 1, 2), 1), - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/420 + # See issue "nvFuser split test failure" # ((4, 6, 7), (3, 1, 2, 0, 0, 1), -1), ((4, 4, 12), 4, 2), ) @@ -3748,8 +3759,8 @@ def tensor_split_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=tensor_split_sample_generator, torch_reference=torch.tensor_split, test_directives=( - # nvfuser executor doesn't support pad correctly - # See https://github.com/Lightning-AI/lightning-thunder/issues/285 + # TODO: nvfuser executor didn't support pad correctly, but now it should. + # Test and re-enable. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -4237,7 +4248,8 @@ def unsqueeze_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=unsqueeze_sample_generator, jax_reference=jax.lax.expand_dims if JAX_AVAILABLE else None, test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the + # number of non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_jvp_correctness", @@ -4452,7 +4464,8 @@ def _replace_random_percentage(a: torch.Tensor, value: Number, percentage: float dtypes=(datatypes.complex32,), devicetypes=(devices.DeviceType.CPU,), ), - # See https://github.com/csarofeen/pytorch/issues/2369 + # nvFuser had issues with complex reductions, now fixed; TODO re-enable + # this test. 
DecorateInfo( pytest.mark.xfail, dtypes=(datatypes.complexfloating,), @@ -4502,7 +4515,8 @@ def var_sample_generator(op, device, dtype, requires_grad): dtypes=(datatypes.complex32,), devicetypes=(devices.DeviceType.CPU, devices.DeviceType.CUDA), ), - # See https://github.com/csarofeen/pytorch/issues/2369 + # nvFuser had issues with complex reductions, now fixed; TODO re-enable + # this test. DecorateInfo( pytest.mark.xfail, dtypes=(datatypes.complexfloating,), @@ -4552,8 +4566,7 @@ def var_sample_generator(op, device, dtype, requires_grad): # Complex var is not supported yet dtypes=(datatypes.floating,), test_directives=( - # TODO FIXME nvFuser fails to compile var_mean for these tests - # See https://github.com/Lightning-AI/lightning-thunder/issues/1438 + # See issue "nvFuser fails to compile some var_mean tests" DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", @@ -5199,7 +5212,8 @@ def einsum_error_generator(op, device, **kwargs): supports_grad=True, # TODO: test all integer types and figure out their dtype. dtypes=(datatypes.float32, datatypes.float64), - # See https://github.com/Lightning-AI/lightning-thunder/issues/1643. + # See issue "Disabled einsum tests might hide potential issues in our + # testing/op implementations" # Testing only float32, float64 now. # types=(datatypes.int64, datatypes.floating), # domain=(-1, +1), @@ -6077,7 +6091,7 @@ def group_norm_error_generator(op, device, **kwargs): dtypes=(datatypes.float16, datatypes.bfloat16), devicetypes=(devices.DeviceType.CUDA,), ), - # See https://github.com/Lightning-AI/lightning-thunder/issues/1405 + # This should be fixed now; TODO re-enable, test DecorateInfo( pytest.mark.xfail, executors=("nvfuser",), @@ -6465,7 +6479,8 @@ def embedding_sample_generator(op, device, dtype, requires_grad, **kwargs): dtypes=(datatypes.floating, datatypes.complexfloating), test_directives=( # TODO Investigate these discrepancies -- some dtype x executor configurations seem to be fine - # See https://github.com/Lightning-AI/lightning-thunder/issues/1387 + # See issue "phantom grad's embedding computation is divergent from + # PyTorch's" DecorateInfo( custom_comparator(partial(assert_close, atol=1, rtol=2)), "test_phantom_grad_vs_torch_consistency", @@ -6803,9 +6818,8 @@ def cross_entropy_reference_generator(op, device, dtype, requires_grad, **kwargs # TODO Enable cross entropy bwd weight support -# see https://github.com/Lightning-AI/lightning-thunder/issues/834 # TODO Enable test cases after adding support nll_loss_nd, weight tensor, and label_smoothing options. 
-# See https://github.com/Lightning-AI/lightning-thunder/issues/704 +# TODO see issue "Add support for remaining cross_entropy_loss arguments" def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -7152,7 +7166,7 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs): dtypes=(datatypes.float16,), devicetypes=(devices.DeviceType.CPU,), ), - # https://github.com/Lightning-AI/lightning-thunder/issues/1032 + # This should be fixed now; TODO re-enable and test DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -7169,7 +7183,8 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs): prob_distr_ops = [] -# multinomial testing is currently disabled due to https://github.com/Lightning-AI/lightning-thunder/issues/2258 +# multinomial testing is currently disabled due to issue "randomness: enable +# PyTorch generators for operations like multinomial" # def multinomial_sample_generator(op, device, dtype, requires_grad, **kwargs): # make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index f7dfbd1267..54b27f4964 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -532,7 +532,7 @@ def foo(tup): @instantiate(dtypes=NOTHING) def test_type_promotion_tensors(executor, device, _): if executor == TorchExecutor: - pytest.xfail("https://github.com/Lightning-AI/lightning-thunder/issues/406") + pytest.xfail('see issue "vmap of sum doesn\'t work when dims are passed as a keyword argument"') def foo(a, b): return a + b @@ -590,7 +590,7 @@ def bar(a, b, c): @instantiate(dtypes=NOTHING) def test_type_promotion_numbers_and_tensors(executor, device, _): if executor == TorchExecutor: - pytest.xfail("https://github.com/Lightning-AI/lightning-thunder/issues/406") + pytest.xfail('See issue "Type promotion with the torchexecutor and elementwise operations is incorrect"') def foo(a, b, c): return a + b + c @@ -1130,7 +1130,8 @@ def test_detached_trace(executor, device: str, _): def test_normalized_args_prims_sum(executor, device: str, dtype: dtypes.dtype): # This test verifies that the recorded trace for a call to prims.sum # has its positional and keyword arguments normalized to the same form. - # See: https://github.com/Lightning-AI/lightning-thunder/issues/195 + # See issue "vmap of sum doesn't work when dims are passed as a keyword + # argument" a = make_tensor((2, 2), device=device, dtype=ltorch.to_torch_dtype(dtype)) def func_dim_posarg(x): @@ -1221,7 +1222,8 @@ def foo(x): assert str(trace).count("Testing") == 1 -# Check for https://github.com/Lightning-AI/lightning-thunder/issues/471 +# Check to verify the issue in "KeyError thrown in thunder.executor.utils.Region +# when None is passed in as input". 
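The regression test below guards the failure mode named in the comment. A minimal reproduction, assuming only the public `thunder.jit` entry point (the function body is illustrative):

    import torch
    import thunder

    def fn_with_none(x, y):
        # y is intentionally unused and passed as None; constructing the
        # execution Region used to raise KeyError for such inputs
        return x * 2.0

    jfn = thunder.jit(fn_with_none)
    a = torch.randn(2, 2)
    torch.testing.assert_close(jfn(a, None), a * 2.0)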
@instantiate(dtypes=(thunder.float32,)) def test_argument_of_none(executor, device, dtype): from thunder.executors.utils import Region @@ -1683,7 +1685,7 @@ def test_transforms_vmap_axis_size(executor, device, _): @instantiate( dtypes=NOTHING, - decorators=(pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2118"),), + decorators=(pytest.mark.xfail(reason='issue "flaky test: test_transforms_vjp_{2_1, 1_2}_nvfuser_cuda_None"'),), ) def test_transforms_vjp_1_2(executor, device, _): from thunder.core.transforms import vjp @@ -1790,7 +1792,7 @@ def func(x): @instantiate( dtypes=NOTHING, - decorators=(pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2118"),), + decorators=(pytest.mark.xfail(reason='issue "flaky test: test_transforms_vjp_{2_1, 1_2}_nvfuser_cuda_None"'),), ) def test_transforms_vjp_2_1(executor, device, _): from thunder.core.transforms import vjp @@ -1830,7 +1832,8 @@ def func_2_1(x, y): # executors=( # nvFuserExecutor, # # TODO: Enable Torch executor once the issue with sum is fixed -# # See: https://github.com/Lightning-AI/lightning-thunder/issues/438 +# # See issue "Different behavior of sum(tensor, ()) for nvFuser and +# # Torch executor" # ), # ) # def test_transforms_vmap_inline_value_and_grad(executor, device, _): @@ -1950,8 +1953,8 @@ def f(a): assert "thunder.computation" in excinfo.traceback[-1].path -# TODO Add nvFuser support (https://github.com/Lightning-AI/lightning-thunder/issues/809) -# TODO Make these OpInfo tests (https://github.com/Lightning-AI/lightning-thunder/issues/810) +# TODO See issue "Add contiguous and clang.stride_order OpInfos that check stride +# consistency with PyTorch" @instantiate( dtypes=NOTHING, executors=(TorchExecutor,), @@ -2191,7 +2194,8 @@ def func(qkv): @instantiate(dtypes=NOTHING) def test_no_passthrough_symbol(executor, device, _): # A test case for the situation reported in - # https://github.com/Lightning-AI/lightning-thunder/issues/1131 + # "backward trace contains symbols not present in forward that cause + # NotImplementedError" # When an operation simply passes through its input, we should not # add it to the trace. diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 5ab26bd5af..6d76761a54 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -29,7 +29,8 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req from thunder.tests.opinfos import SampleInput # TODO: cudnnex seems to produce large mismatches against reference when tensor initialized from the wider default range of [-9,9] - # https://github.com/Lightning-AI/lightning-thunder/issues/1871 + # See issue "cuDNN SDPA backward might return NaNs for inputs with absolute + # value more than certain threshold" make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-0.5, high=0.5) n_head = 2 @@ -87,10 +88,6 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req ) -# WARNING: cudnn executor is experimental. 
Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880 -# NOTE This test modifies the global executor map, so it technically should not -# be run in parallel with other tests @requiresCUDA def test_cudnn_sdpa(): # expect sdpa to fail for 8.9.2 and below @@ -157,8 +154,6 @@ def snippet_torch_consistency(op, torch_op, sample): assert_close(thunder_result, torch_result, equal_nan=True, atol=0.0625, rtol=5e-2) -# WARNING: cudnn executor is experimental. Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880 # TODO Make it easier for executors to write tests like this, including writing them out-of-tree # TODO The executor passed below is just a "dummy" that actually gets ignored -- we should provide # a way to use decorators like @ops without a particular executor diff --git a/thunder/tests/test_elementwise.py b/thunder/tests/test_elementwise.py index a76db9e673..af5dcfb58b 100644 --- a/thunder/tests/test_elementwise.py +++ b/thunder/tests/test_elementwise.py @@ -25,7 +25,7 @@ def test_elementwise_dunder_operations_on_numbers(executor, device, dtype): # (math.floor, (bool, int, float)), # (operator.inv, (bool, int)), # (operator.neg, (bool, int, float, complex)), - # # TODO https://github.com/Lightning-AI/lightning-thunder/issues/713 + # # TODO see issue "Implement positive operations" # # operator.pos, # (builtins.round, (bool, int, float)), # (math.trunc, (bool, int, float)), @@ -65,8 +65,8 @@ def foo(a): assert_close(actual, expected) -# TODO Test operator and method variants using OpInfos -# See https://github.com/Lightning-AI/lightning-thunder/issues/710 +# TODO: see issue "Test operator and method variants of operations using +# OpInfos" @instantiate(dtypes=(thunder.float32,)) def test_core_tensor_methods(executor, device, dtype): def foo(a, b, c, d): @@ -126,7 +126,8 @@ def test_where(executor, device, dtype): torch_result = torch_fn(pred, i64, -2.3) assert_close(thunder_result, torch_result) - # TODO Fix https://github.com/Lightning-AI/lightning-thunder/issues/711 + # TODO fix issue "Currently nvFuser tensor x float operations result in + # float64 results" # float x int # thunder_result = thunder_fn(pred, 3., 5) # torch_result = torch_fn(pred, 3., 5) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index e2100a585c..3a44a25dbc 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -26,7 +26,7 @@ # TODO: Move this to thunder.tests.opinfos op_skip = { - # See https://github.com/Lightning-AI/lightning-thunder/issues/226 + # See issue "Support closures of torch.Tensor" # TODO: AttributeError: 'Tensor' object has no attribute 'true_dtype' "masked_fill", # TODO: RuntimeError: Expected index=tensor([2, 3, 2, 0, 3, 1, 0, 2], @@ -635,7 +635,8 @@ def bar(a, b): dtypes=NOTHING, ) def test_convert_element_type_with_float(executor, device, _): - # Verifies a fix for https://github.com/Lightning-AI/lightning-thunder/issues/537 + # Verifies the fix for "grad transform hits error: AttributeError: 'float' + # object has no attribute 'dtype'" from thunder.core.transforms import value_and_grad a = make_tensor([5], dtype=torch.float32, device=device) @@ -708,7 +709,9 @@ def sincos_backward(sin_x, cos_x, g1, g2): assert trace.output[0] == trace.bound_symbols[4].output -# TODO: Fix flaky test https://github.com/Lightning-AI/lightning-thunder/issues/1919 +# TODO: see issue +# 
"thunder/tests/test_grad.py::test_torch_autograd_saved_tensors_memory_release +# is flaky" @pytest.mark.xfail(strict=False, reason="This test is flaky") @requiresCUDA def test_torch_autograd_saved_tensors_memory_release(): @@ -906,8 +909,7 @@ def func(a): def test_torch_autograd_crazy_collections_in_and_out(executor, device, dtype): from thunder.executors.torch_autograd import thunder_backward - # Borrowed from - # https://github.com/Lightning-AI/lightning-thunder/blob/3401475ee47d5a732b6b4d5dcbd88afcd9bed81d/thunder/tests/test_core.py#L117 + # Borrowed from `test_crazy_collections_in_and_out`. def foo(a, b, c, *, ka, kb, kc): d = { 5: 2, diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index e113438f93..447647b8bb 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -1010,7 +1010,8 @@ def foo(x): # } -# See https://github.com/Lightning-AI/lightning-thunder/issues/2078 +# Test for issue "jit: passing jitted functions as arguments to jitted +# functions fails." def test_reduce_jitted_reduce_fn(jit): import functools @@ -1482,7 +1483,7 @@ def foo(): assert jfoo() is True -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/1824") +@pytest.mark.xfail(reason='"exec() and eval() lookaside ignores locals()"') def test_exec_import_star(jit): # Assert that we can actually generate the instruction to_exec = "from itertools import *" @@ -2606,8 +2607,8 @@ def test_displayhook(jit): import io import code - # TODO: Implement the lookaside for exec(). Under he hood, `code.InteractiveInterpreter().runsource('5;6;7')`` - # just compiles the string and calls exec(), plus a little bit of irrelevant error handling. + # TODO: Implement the lookaside for exec(). Under the hood, `code.InteractiveInterpreter().runsource('5;6;7')`` + # just compiles the string and calls exec(), plus a little bit of error handling. # I'm not entirely convinced that the PRINT_EVAL is going through our system at the moment, but # it for sure would with an exec() lookaside. I'm also not sure what makes InteractiveInterpreter # interactive. It isn't *actually* in interactive mode. So, why is PRINT_EXPR in the interpreted @@ -2616,7 +2617,7 @@ def test_displayhook(jit): py_redirect = io.StringIO() with redirect_stdout(py_redirect): # Avoid clobbering this interpreter's display hook, and ensure it's interactive. - # Why is this necessary? I'm not sure. + # Why is this necessary? 
interpreter = code.InteractiveInterpreter() def smt(s): diff --git a/thunder/tests/test_jit_functional.py b/thunder/tests/test_jit_functional.py index e65ed69aec..7fece51f72 100644 --- a/thunder/tests/test_jit_functional.py +++ b/thunder/tests/test_jit_functional.py @@ -292,7 +292,7 @@ def test_binary_ops_compare_numbers(): def test_binary_ops_int_numbers(): - # Issue https://github.com/Lightning-AI/lightning-thunder/issues/594 for more ops + # TODO: see issue "Implement logical and arithmetic left and right shifts" # "<<", ">>", int_ops = ["+", "&", "//", "*", "%", "|", "**", "-", "/", "^"] @@ -1574,7 +1574,7 @@ def foo(a, b): assert_close(expected, actual) -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2191") +@pytest.mark.xfail(reason='issue: "jit-eager: allow sets as a return value"') def test_return_set(): def foo(a, b): return {a, b} @@ -2453,7 +2453,7 @@ def foo(): jfoo() -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2184") +@pytest.mark.xfail(reason='issue: "sharp edges: loading closures"') def test_input_closure_sharp_edge(): x = 5 @@ -2486,7 +2486,7 @@ def _test_fn_global_no_sharp_edge_fn(): return 7 -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2189") +@pytest.mark.xfail(reason='issue: "sharp edge: allow function and module loads"') def test_fn_global_no_sharp_edge(): def foo(x): return x + _test_fn_global_no_sharp_edge_fn() diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index d4c6dc92d0..28fc8a54fc 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -350,7 +350,7 @@ def foo(a, b): jfoo = thunder.jit(foo) # TODO Add test for bool - # See https://github.com/Lightning-AI/lightning-thunder/issues/1990 + # see issue "Binary addition on booleans should promote to an integer" cases = ( (2, 3), (2.1, 3.4), @@ -397,7 +397,7 @@ def foo(a, b): jfoo = thunder.jit(foo) # TODO Add test for bool - # See https://github.com/Lightning-AI/lightning-thunder/issues/1990 + # see issue "Binary addition on booleans should promote to an integer" cases = ( (2, 3), (2.1, 3.4), @@ -414,7 +414,7 @@ def foo(a, b): _test_add_global_global = 2 -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/1935", raises=BaseException) +@pytest.mark.xfail(reason='"disallow global reads and writes (temporarily)"', raises=BaseException) def test_global_fails(): def foo(): return _test_add_global_global @@ -425,7 +425,10 @@ def foo(): jfoo() -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/1936", raises=BaseException) +@pytest.mark.xfail( + reason='"Raise an error when a program attempts to write to a nonlocal that was captured from outside the interpreter"', + raises=BaseException, +) def test_nonlocal_outside_interpreter_fails(): def foo(): x = 3 diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 971201bcba..15760f2948 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -41,7 +41,7 @@ def test_nanogpt_complete(executor, device, dtype): # TODO Investigate grad inconsistency # TODO: Add float16 and bfloat16 comparison tests here and to all other tests in # this file. 
-# https://github.com/Lightning-AI/lightning-thunder/issues/907 +# See issue "Add half precision dtype tests to test_networks.py" @instantiate(dtypes=(thunder.float32,)) def test_nanogpt_complete_autograd(executor, device, dtype): tdtype = ttorch.to_torch_dtype(dtype) diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 166c4ae3c6..de35f4a37c 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -315,7 +315,7 @@ def func(w, x, y, z): @instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)) def test_cse_rematerialization(executor, device, _): - # Unit test for https://github.com/Lightning-AI/lightning-thunder/issues/2046 + # Unit test for "llama2.c example failed with bookend disabled." from thunder.tests.llama2_model import Transformer, ModelArgs from thunder.core.pytree import tree_flatten @@ -614,8 +614,7 @@ def func(x: torch.Tensor, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: dtypes=NOTHING, executors=( nvFuserExecutor, - # NOTE torch executor does not have bookend optimization. - # See comment: https://github.com/Lightning-AI/lightning-thunder/issues/571#issuecomment-1610778432 + # NOTE We might want to do transpose bookend optimization for other executors than nvFuser. ), ) def test_bookend_meta_optimization(executor, device, _): diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index 0f8560278b..b94403fcf3 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -310,8 +310,8 @@ def test_find_cut_dropout(executor, device, _): ext_producer_outputs = find_external_producer_outputs(utils.consumers(trace), (), producer, consumer) cut = find_cut(ext_producer_outputs, producer, consumer) # Note t5 is the boolean mask for dropout. It should be chosen over the t6 - # that is the float32 mask. See this issue for the original problem: - # https://github.com/Lightning-AI/lightning-thunder/issues/706 + # that is the float32 mask. See this issue: "The Recomputation Algorithm on + # Dropout choses a float32 mask to save" assert cut == ("t0", "t5", "t9") diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py index 7152e432ac..69c7a67adf 100644 --- a/thunder/tests/test_ops.py +++ b/thunder/tests/test_ops.py @@ -37,7 +37,7 @@ def snippet_torch_consistency(op, torch_op, sample, comp): thunder_result = op(*sample.args, **sample.kwargs) torch_result = torch_op(*sample.args, **sample.kwargs) - # TODO Review how lightning.compile returns Exception information + # TODO Review how thunder.jit returns Exception information if isinstance(thunder_result, Exception): raise thunder_result diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 055df74f1e..efb6319d83 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -12,7 +12,7 @@ import opt_einsum -# Initialies the language context +# Initializes the language context from thunder.torch.langctx import register_method import thunder.clang as clang @@ -225,8 +225,9 @@ def _parse_to_device_and_dtype( dtype = to_dtype(dtype) # Case 1 -- tensor first else: - # See https://github.com/Lightning-AI/lightning-thunder/issues/317 - # It'd be nice to write torch.Tensor here instead of TensorProxy + # It'd be nice to write torch.Tensor here instead of TensorProxy. 
+ # See issue "Translate isinstance(a, torch.Tensor) calls so that + # TensorProxies can pass as torch.Tensors" utils.check_type(tensor_dtype_or_device, TensorProxy) device_ = tensor_dtype_or_device.device if device is None else to_device(device) dtype_ = tensor_dtype_or_device.true_dtype if dtype is None else to_dtype(dtype) @@ -414,7 +415,8 @@ def multinomial( ) -> TensorLike: utils.check(out is None, lambda: "Non-None out is not supported", NotImplementedError) - # See https://github.com/Lightning-AI/lightning-thunder/issues/2258 + # See issue "randomness: enable PyTorch generators for operations like + # multinomial" utils.check( generator is None, lambda: f"multinomial does not yet support specifying a generator", NotImplementedError ) @@ -431,7 +433,7 @@ def multinomial( # TODO Maybe update this to return an offset of how far to advance the seed to acquire new values -# See https://github.com/Lightning-AI/lightning-thunder/issues/1360 +# See issue "Maybe return offset from thunder.torch.uniform_philox" @torchsymbol(is_method=False, id="torch.uniform_philox") def uniform_philox( shape: Sequence[int], @@ -1390,9 +1392,9 @@ def zeta(a, b, /): # For calculate op1(a, op2(value, op2(b, c))) by promoting all input tensors at once # NOTE use this explicit type promotion because a direct combination of add/mul will have a redundant cast, -# which may lead to accuracy problems, see: -# https://github.com/Lightning-AI/lightning-thunder/pull/1155#discussion_r1342653591 for details -# TODO remove this when the optimization pass is ready: https://github.com/Lightning-AI/lightning-thunder/issues/1178 +# which may lead to accuracy problems. +# TODO remove after issue "Redundant cast removal could be performed through metadata-only +# operations, like broadcasting" is resolved def addcmul_addcdiv_helper( a, b, c, op1, op2, *, value=None, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ): @@ -2868,8 +2870,9 @@ def _avg_pool_helper( # Dimensionality of the kernel. kernel_numel = reduce(operator.mul, kernel_size, 1) - # nn.functional.avg_pool does not have `divisor_override` for some reason. - # TODO: seems like an oversight from PyTorch and/or 1d case is very niche. + # nn.functional.avg_pool does not have `divisor_override`. + # TODO: look into PyTorch side; is this behavior deliberate? Could be that + # 1D case is niche. # If needed, handle it with checks and transforms. For now unconditionally # override value with kernel_numel. if divisor_override is None or dim == 1: @@ -3031,7 +3034,8 @@ def _dropout_helper(a, p): # TODO Add annotations, make not a prim # The backward decomposition of cross_entropy cannot be efficiently fused, so we have this cross_entropy_backward # primitive. Executors can override the primitive using internal implementations. -# See https://github.com/Lightning-AI/lightning-thunder/issues/660 +# See issue "Cross_entropy is decomposed for backward but the decomposition is +# not fusible currently" @torchsymbol("cross_entropy_backward", id="cross_entropy_backward", is_prim=True) def cross_entropy_backward(g, a, /, target, weight, reduction, ignore_index, label_smoothing): return TensorProxy(like=g, shape=a.shape) @@ -3599,7 +3603,8 @@ def log_softmax(a: TensorLike, /, dim: int, *, dtype: None | dtypeLike = None) - # TODO Update annotations and consider moving to torchex # We improve the efficiency of cross_entropy backward decomposition by adding the log_softmax_backward # and nll_loss_backward primitives. 
Executors can override the primitives using internal implementations. -# See https://github.com/Lightning-AI/lightning-thunder/issues/660 +# See issue "Cross_entropy is decomposed for backward but the decomposition is +# not fusible currently" @torchsymbol("log_softmax_backward", id="log_softmax_backward") def log_softmax_backward(g: TensorProxy, /, output: TensorProxy, dim: int, dtype: dtypeLike) -> TensorLike: dtype: dtypes.dtype = to_dtype(dtype) @@ -3845,7 +3850,7 @@ def softmax(a: TensorLike, /, dim: int, *, dtype: None | dtypeLike = None) -> Te if torch.distributed.is_available(): DistributedReduceOpLike = str | torch.distributed.ReduceOp | dist_prims.DistributedReduceOps - # string name, PyTorch enum value, lightning.compile enum value + # string name, PyTorch enum value, thunder.jit enum value _reduceop_triples = (("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),) def to_thunder_distributed_reduce_op(op: DistributedReduceOpLike | None): From 08164bf86cf28e6255da8138e1d66e0eb7a993ff Mon Sep 17 00:00:00 2001 From: nikitaved Date: Sat, 16 Mar 2024 08:53:37 +0100 Subject: [PATCH 16/44] general jit: test enabling sharp edges (PR2459) --- thunder/__init__.py | 5 +++++ thunder/core/jit_ext.py | 3 +++ thunder/core/options.py | 22 +++++++++++----------- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 0c4ee7c7f8..770507697c 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -15,6 +15,7 @@ from thunder.core.options import ( INTERPRETATION_OPTIONS, resolve_interpretation_option, + resolve_sharp_edges_option, CACHE_OPTIONS, SHARP_EDGES_OPTIONS, ) @@ -343,6 +344,10 @@ def jit( if additional_transforms is None: additional_transforms = [] + # Make sharp_edges == warn default if not supplied and if in the general jit + if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON and sharp_edges is None: + sharp_edges = SHARP_EDGES_OPTIONS.WARN + # TODO RC1 Refine the compile data option to remove unused options cd = CompileData( fn=fn, diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 04349bbfe4..e626a77f2a 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -889,6 +889,9 @@ def is_from_torch(fn): return hasattr(fn, "__module__") and fn.__module__ and fn.__module__.startswith("torch") if is_opaque(fn) and is_from_torch(fn): + if fn.__module__.startswith("torch._C"): + return lookaside + # Torch functions have __name__ defined fn_name = f"{fn.__module__}.{fn.__name__}" diff --git a/thunder/core/options.py b/thunder/core/options.py index a71c6ac628..521af983eb 100644 --- a/thunder/core/options.py +++ b/thunder/core/options.py @@ -170,16 +170,16 @@ def resolve_sharp_edges_option(x: Any, /) -> SHARP_EDGES_OPTIONS: elif isinstance(x, str): seo = _str_to_sharp_edges_option(x) - if seo is None: - _unknown_option("sharp edges", _str_to_sharp_edges_options_map.keys(), "allow", x) - - if seo is SHARP_EDGES_OPTIONS.WARN: - warnings.warn( - f"The 'warn' sharp edges option is experimental and still in development. It may not work as expected." - ) - if seo is SHARP_EDGES_OPTIONS.ERROR: - warnings.warn( - f"The 'error' sharp edges option is experimental and still in development. It may not work as expected." - ) + if seo is None: + _unknown_option("sharp edges", _str_to_sharp_edges_options_map.keys(), "allow", x) + + if seo is SHARP_EDGES_OPTIONS.WARN: + warnings.warn( + f"The 'warn' sharp edges option is experimental and still in development. It may not work as expected." 
+ ) + if seo is SHARP_EDGES_OPTIONS.ERROR: + warnings.warn( + f"The 'error' sharp edges option is experimental and still in development. It may not work as expected." + ) return seo From 59c7fb1ca39342d6554da5bf32ae71f93ff073b4 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sat, 16 Mar 2024 12:27:30 +0100 Subject: [PATCH 17/44] zero to thunder v0 (PR2431) --- .../adding_custom_operator_backward.ipynb | 287 +-- notebooks/zero_to_thunder.ipynb | 1956 +++++++++++++++-- thunder/__init__.py | 7 +- 3 files changed, 1915 insertions(+), 335 deletions(-) diff --git a/notebooks/adding_custom_operator_backward.ipynb b/notebooks/adding_custom_operator_backward.ipynb index 2ed02db162..dc44b8c3d7 100644 --- a/notebooks/adding_custom_operator_backward.ipynb +++ b/notebooks/adding_custom_operator_backward.ipynb @@ -370,11 +370,11 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def computation(a, target):\n", - " # a: \"cuda:0 f32[2048, 50257]\" \n", - " # target: \"cuda:0 i64[2048]\" \n", - " (res, _) = apex_xentropy_forward(a, target, None, None, -100, None, 'none', 0.0)\n", - " del a, target\n", + "def computation(logits, labels):\n", + " # logits: \"cuda:0 f32[2048, 50257]\" \n", + " # labels: \"cuda:0 i64[2048]\" \n", + " (res, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", + " del logits, labels\n", " return res" ] }, @@ -626,188 +626,31 @@ }, { "cell_type": "markdown", - "id": "39fd6fce", + "id": "b4ec7c57", "metadata": {}, "source": [ - "With this, we can use the `grad` transform to get the gradient:" + "With these registrations, we can compile a function and it will be automatically transformed into forward and backward and wrapped in a PyTorch autograd.Function calling the backward trace computed by Thunder.\n" ] }, { "cell_type": "code", "execution_count": 12, - "id": "d9f6dfde", + "id": "8c5da6f2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "call apex_cross_entropy_grad(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", - " call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", + "call apex_cross_entropy_grad(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])\n", + " call apex_xentropy_forward_meta(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])\n", " |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", - " call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), 0.0)\n", + " call apex_xentropy_backward_meta(TensorProxy(name=t2, 
shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=label_smoothing, value=0.0])\n", " |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)\n", "\n", "|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)\n", "\n", - "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.7825, -1.1014, -0.9563, ..., 0.2801, 0.5359, -1.4094],\n", - " [ 1.1592, 0.8128, 0.5846, ..., 1.0255, 0.4217, 0.2548],\n", - " [ 0.8622, 0.5320, -1.5205, ..., -1.4938, -1.0423, -0.9527],\n", - " ...,\n", - " [-0.8978, 2.1914, 0.1603, ..., 0.0704, -0.7642, 1.4002],\n", - " [ 0.1750, 0.6244, 1.1711, ..., 0.3491, -0.5760, -1.4034],\n", - " [ 1.3689, -1.5422, 0.8149, ..., 0.9625, 1.0281, 1.4206]],\n", - " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([43812, 33387, 31729, ..., 27740, 2907, 8268], device='cuda:0'), None, None, -100, None, none, 0.0)\n", - "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.8060, 10.7141, 11.4505, ..., 11.2361, 10.6558, 11.2219],\n", - " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3327, 11.3294, 11.3304, ..., 11.3231, 11.3170, 11.3209],\n", - " device='cuda:0'))\n", - "\n", - "call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.7825, -1.1014, -0.9563, ..., 0.2801, 0.5359, -1.4094],\n", - " [ 1.1592, 0.8128, 0.5846, ..., 1.0255, 0.4217, 0.2548],\n", - " [ 0.8622, 0.5320, -1.5205, ..., -1.4938, -1.0423, -0.9527],\n", - " ...,\n", - " [-0.8978, 2.1914, 0.1603, ..., 0.0704, -0.7642, 1.4002],\n", - " [ 0.1750, 0.6244, 1.1711, ..., 0.3491, -0.5760, -1.4034],\n", - " [ 1.3689, -1.5422, 0.8149, ..., 0.9625, 1.0281, 1.4206]],\n", - " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([43812, 33387, 31729, ..., 27740, 2907, 8268], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3327, 11.3294, 11.3304, ..., 11.3231, 11.3170, 11.3209],\n", - " device='cuda:0'), 0.0)\n", - "|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[2.6187e-05, 3.9804e-06, 4.6022e-06, ..., 1.5845e-05, 2.0466e-05,\n", - " 2.9255e-06],\n", - " [3.8294e-05, 2.7081e-05, 2.1557e-05, ..., 3.3501e-05, 1.8315e-05,\n", - " 1.5500e-05],\n", - " [2.8425e-05, 2.0432e-05, 2.6236e-06, ..., 2.6946e-06, 4.2325e-06,\n", - " 4.6290e-06],\n", - " ...,\n", - " [4.9265e-06, 1.0818e-04, 1.4192e-05, ..., 1.2971e-05, 5.6303e-06,\n", - " 4.9039e-05],\n", - " [1.4491e-05, 2.2712e-05, 3.9235e-05, ..., 1.7247e-05, 6.8383e-06,\n", - " 2.9895e-06],\n", - " [4.7630e-05, 2.5919e-06, 2.7372e-05, ..., 3.1723e-05, 
3.3876e-05,\n", - " 5.0159e-05]], device='cuda:0')\n", - "\n", - "Difference: 1.3969838619232178e-09\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def computation(logits, labels):\n", - " # logits: \"cuda:0 f32[2048, 50257]\" \n", - " # labels: \"cuda:0 i64[2048]\" \n", - " (_, t0) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", - " t4 = torch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", - " # t4 = ltorch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", - " # t4 = prims.full((2048,), 1.0, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[2048]\"\n", - " t3 = apex_xentropy_backward(t4, logits, labels, t0, 0.0) # t3: \"cuda:0 f32[2048, 50257]\"\n", - " del t4, logits, labels, t0\n", - " return [t3]\n" - ] - } - ], - "source": [ - "logits = torch.randn([2048, 50257], device=\"cuda\", requires_grad=True)\n", - "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", - "\n", - "grad_jfn = thunder.core.transforms.grad(jfn)\n", - "actual_grad, = grad_jfn(logits, labels)\n", - "\n", - "expected_grad, = torch.autograd.grad(loss_fn(logits, labels).sum(), logits)\n", - "\n", - "\n", - "print(\"Difference:\", (actual_grad - expected_grad).abs().max().item())\n", - "print(thunder.last_traces(grad_jfn)[-1])\n", - " \n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "f9f85e3b", - "metadata": {}, - "source": [ - "But life isn't completely simple. When we noticed that we thought about how to do backward for a long time, this is our previous approach, that is (in March 2024) needed for getting PyTorch Autograd integration.\n", - "This works by having a _forward rule_ for generating a tuple of result and values saved for backward and a _backward rule_ that takes the saved values and output grad to compute the input grads, much like PyTorch autograd itself, but with the pluggable executor architecture of Thunder.\n", - "\n", - "We are working at allowing you to skip this part!" 
- ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "b0379bc4", - "metadata": {}, - "outputs": [], - "source": [ - "from thunder.core.transforms import register_augmented_forward_with_checker, register_backward\n", - "\n", - "def apex_xentropy_forward_rule(\n", - " a,\n", - " target,\n", - " weight=None,\n", - " size_average=None,\n", - " ignore_index=-100,\n", - " reduce=None,\n", - " reduction=\"mean\",\n", - " label_smoothing=0.0,\n", - "):\n", - " loss, max_log_sum_exp = apex_xentropy_forward(\n", - " a,\n", - " target,\n", - " weight,\n", - " size_average,\n", - " ignore_index,\n", - " reduce,\n", - " reduction,\n", - " label_smoothing,\n", - " )\n", - " primal = loss\n", - " saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing)\n", - " return primal, saved_for_backward\n", - "\n", - "register_augmented_forward_with_checker(\n", - " apex_xentropy_ex,\n", - " \"torch.nn.functional.cross_entropy\",\n", - " apex_xentropy_checker,\n", - " apex_xentropy_forward_rule,\n", - ")\n", - "\n", - "@register_backward((apex_xentropy_ex, thunder.torch.cross_entropy.id))\n", - "def apex_cross_entropy_backward_rule(\n", - " logits, labels, max_log_sum_exp, reduction, smoothing, grad\n", - "):\n", - " if reduction != \"none\":\n", - " raise ValueError(f\"Invalid reduction: {reduction}\")\n", - "\n", - " grad_logits = apex_xentropy_backward(\n", - " grad,\n", - " logits,\n", - " labels,\n", - " max_log_sum_exp,\n", - " smoothing,\n", - " )\n", - " return grad_logits, *([None] * 7)" - ] - }, - { - "cell_type": "markdown", - "id": "b4ec7c57", - "metadata": {}, - "source": [ - "With these registrations, we can compile a function and it will be automatically transformed into forward and backward and wrapped in a PyTorch autograd.Function calling the backward trace computed by Thunder.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "8c5da6f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)\n", "|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", @@ -861,7 +704,7 @@ " -4.9872e-05, -6.3328e-05]], device='cuda:0')\n", "\n", "Max error in loss: 9.5367431640625e-07\n", - "Max error in logits grad: 1.3969838619232178e-09\n" + "Max error in logits grad: 2.384185791015625e-07\n" ] }, { @@ -966,13 +809,12 @@ " return (t3, None)]" ] }, - "execution_count": 14, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from thunder.core.transforms import value_and_grad\n", "from thunder import torch as ltorch\n", "\n", "torch.manual_seed(0)\n", @@ -988,10 +830,10 @@ "actual_loss = cfn(logits, labels)\n", "go = torch.randn_like(actual_loss)\n", "\n", - "actual_grads, = torch.autograd.grad(actual_loss, logits, go)\n", + "actual_grad, = torch.autograd.grad(actual_loss, logits, go)\n", "\n", "expected_loss = loss_fn(logits, labels)\n", - "expected_grads, = torch.autograd.grad(expected_loss, logits, go)\n", + "expected_grad, = torch.autograd.grad(expected_loss, logits, go)\n", "\n", "print(\"Max error in loss:\", (actual_loss - expected_loss).abs().max().item())\n", "print(\"Max error in logits grad:\", (actual_grad - 
expected_grad).abs().max().item())\n", @@ -999,6 +841,102 @@ "thunder.last_traces(cfn)[-1]" ] }, + { + "cell_type": "markdown", + "id": "54d6a5ea", + "metadata": {}, + "source": [ + "Alternatively, we can also use the `grad` transform to get the gradient:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c88118eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "call apex_cross_entropy_grad(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", + " call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", + " |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", + "\n", + " call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), 0.0)\n", + " |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)\n", + "\n", + "|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)\n", + "\n", + "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],\n", + " [ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],\n", + " [ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],\n", + " ...,\n", + " [-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],\n", + " [-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],\n", + " [ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],\n", + " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), None, None, -100, None, none, 0.0)\n", + "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.0233, 11.9095, 11.2898, ..., 10.9289, 10.7487, 10.7455],\n", + " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],\n", + " device='cuda:0'))\n", + "\n", + "call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],\n", + " [ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],\n", + " [ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],\n", + " ...,\n", + " [-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],\n", + " [-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],\n", + " [ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],\n", + " 
device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],\n", + " device='cuda:0'), 0.0)\n", + "|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[2.0706e-05, 1.4403e-05, 4.1058e-06, ..., 1.4309e-05, 5.3827e-06,\n", + " 6.0079e-06],\n", + " [1.0461e-04, 4.8840e-05, 1.7949e-05, ..., 2.9621e-05, 7.3879e-06,\n", + " 3.8697e-05],\n", + " [2.3705e-05, 3.6822e-05, 2.5485e-05, ..., 1.0948e-05, 2.7806e-05,\n", + " 1.4513e-05],\n", + " ...,\n", + " [2.4836e-06, 1.0338e-05, 3.1331e-06, ..., 2.2417e-05, 1.0857e-05,\n", + " 1.8259e-05],\n", + " [7.0235e-06, 2.1758e-05, 1.3145e-05, ..., 7.3762e-06, 7.1699e-06,\n", + " 1.3360e-05],\n", + " [1.6078e-05, 4.2941e-06, 6.9897e-06, ..., 1.2304e-05, 4.6857e-05,\n", + " 4.1070e-06]], device='cuda:0')\n", + "\n", + "Difference: 1.3969838619232178e-09\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(logits, labels):\n", + " # logits: \"cuda:0 f32[2048, 50257]\" \n", + " # labels: \"cuda:0 i64[2048]\" \n", + " (_, t0) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", + " t4 = torch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", + " # t4 = ltorch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", + " # t4 = prims.full((2048,), 1.0, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[2048]\"\n", + " t3 = apex_xentropy_backward(t4, logits, labels, t0, 0.0) # t3: \"cuda:0 f32[2048, 50257]\"\n", + " del t4, logits, labels, t0\n", + " return [t3]\n" + ] + } + ], + "source": [ + "logits = torch.randn([2048, 50257], device=\"cuda\", requires_grad=True)\n", + "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", + "\n", + "grad_jfn = thunder.core.transforms.grad(jfn)\n", + "actual_grad, = grad_jfn(logits, labels)\n", + "\n", + "expected_grad, = torch.autograd.grad(loss_fn(logits, labels).sum(), logits)\n", + "\n", + "\n", + "print(\"Difference:\", (actual_grad - expected_grad).abs().max().item())\n", + "print(thunder.last_traces(grad_jfn)[-1])\n" + ] + }, { "cell_type": "markdown", "id": "e234a47b", @@ -1008,8 +946,7 @@ "\n", "- We defined a custom executor with custom operations (Symbols in Thunder language), each with a *Meta-* (data propagation) *function* and an implementation.\n", "- We defined and registered rules to map existing operations to our new operations. This allows us to use optimizations on our model without changing the model's code! \n", - "- We defined a gradient rule and saw how we the `grad` transform uses it.\n", - "- We saw another (older) way to implement forward and backward rules that is currently needed to get automatic integration with PyTorch's autograd.\n", + "- We defined a gradient rule and saw how our automatic PyTorch Autograd integration or the explicit `grad` transform uses it.\n", "\n", "Now go and implement your favourite optimized operators. 
We would love to hear about your use-cases!\n" ] diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb index 9c7a5468a3..68f61a47a0 100644 --- a/notebooks/zero_to_thunder.ipynb +++ b/notebooks/zero_to_thunder.ipynb @@ -5,265 +5,1907 @@ "id": "1638964c", "metadata": {}, "source": [ - "# Zero to thunder" + "# Zero to Thunder\n", + "\n", + "Here we take a very short tour of what is possible with Thunder.\n", + "\n", + "To get started we import it (and a bunch of things for this notebook)." ] }, { "cell_type": "code", - "execution_count": 5, - "id": "e8953e57", + "execution_count": 1, + "id": "28b99b58", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '..')\n", + "import inspect\n", "\n", - "import torch, thunder\n", "\n", - "from thunder.tests.lit_gpt_model import Config, Block" + "import torch, thunder\n" ] }, { - "cell_type": "code", - "execution_count": 37, - "id": "0a62c587", + "cell_type": "markdown", + "id": "54f87aba", "metadata": {}, - "outputs": [], "source": [ - "from lit_gpt.model import Config, LLaMAMLP" + "## Compiling a first module with Thunder\n", + "\n", + "So let's get started! As a \"Hello World\", let us apply it to it to a small model, say, the MLP part found in Llama 2. We take it from LitGPT." ] }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 2, "id": "d6ca6328", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LLaMAMLP(\n", + " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + ")\n" + ] + } + ], "source": [ - "cfg = Config.from_name('Llama-2-7b-hf')\n", + "class LLaMAMLP(torch.nn.Module):\n", + " def __init__(self, n_embd, intermediate_size) -> None:\n", + " super().__init__()\n", + " self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", + " self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", + " self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " x_fc_1 = self.fc_1(x)\n", + " x_fc_2 = self.fc_2(x)\n", + " x = torch.nn.functional.silu(x_fc_1) * x_fc_2\n", + " return self.proj(x)\n", + "\n", + "\n", "with torch.device(\"cuda\"):\n", - " m = LLaMAMLP(cfg)\n", - "\n" + " m = LLaMAMLP(4096, 11008)\n", + "for p in m.parameters():\n", + " p.requires_grad_(False)\n", + "\n", + "print(m)" + ] + }, + { + "cell_type": "markdown", + "id": "702ea054", + "metadata": {}, + "source": [ + "Now we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. 
It will wrap our MLP in a `ThunderModule`" ] }, { "cell_type": "code", - "execution_count": 52, - "id": "3a159966", + "execution_count": 3, + "id": "67ca2d37", + "metadata": {}, + "outputs": [], + "source": [ + "thunder_model = thunder.jit(m)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "964e2689", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "LLaMAMLP(\n", - " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + "ThunderModule(\n", + " (_model): LLaMAMLP(\n", + " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + " )\n", ")" ] }, - "execution_count": 52, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "m" + "thunder_model" + ] + }, + { + "cell_type": "markdown", + "id": "59db20f6", + "metadata": {}, + "source": [ + "Our Thunder module computes (up to numerical accuracy) the same thing as our original model and for a small model like this, it also has approximately the same performance." ] }, { "cell_type": "code", - "execution_count": 53, - "id": "67ca2d37", + "execution_count": 5, + "id": "7f4de1b3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 1.4901161193847656e-07\n", + "58.2 ms ± 306 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", + "58.7 ms ± 50.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], "source": [ - "tm = thunder.jit(m)" + "x = torch.randn(2, 2048, 4096, device=\"cuda\")\n", + "print('deviation:', (thunder_model(x) - m(x)).abs().max().item())\n", + "\n", + "%timeit thunder_model(x); torch.cuda.synchronize()\n", + "%timeit m(x); torch.cuda.synchronize()" + ] + }, + { + "cell_type": "markdown", + "id": "8835543e", + "metadata": {}, + "source": [ + "So what has changed?\n", + "Quite a bit!\n", + "\n", + "When we call the Thunder module, it do the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. 
We can see all this by showing the last computation trace:" ] }, { "cell_type": "code", - "execution_count": 54, - "id": "964e2689", + "execution_count": 6, + "id": "a6f4b77c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "ThunderModule(\n", - " (_model): LLaMAMLP(\n", - " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", + " # x: \"cuda:0 f32[2, 2048, 4096]\" \n", + " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", + " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", + " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", + " x_fc_1 = torch.nn.functional.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", + " del t_fc_1_weight\n", + " x_fc_2 = torch.nn.functional.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", + " del x, t_fc_2_weight\n", + " [result] = nvFusion0(x_fc_1, x_fc_2)\n", + " # t9 = prims.neg(x_fc_1) # t9: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # t10 = prims.exp(t9) # t10: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # t11 = prims.add(1.0, t10) # t11: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # t12 = prims.reciprocal(t11) # t12: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # a = prims.mul(x_fc_1, t12) # a: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # result = prims.mul(a, x_fc_2) # result: \"cuda:0 f32[2, 2048, 11008]\"\n", + " del x_fc_1, x_fc_2\n", + " t18 = torch.nn.functional.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", + " # t18 = ltorch.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", + " # t18 = prims.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", + " del result, t_proj_weight\n", + " return t18" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "thunder.last_traces(thunder_model)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "a0071924", + "metadata": {}, + "source": [ + "For more detail of what is going on in this trace:\n", + "- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function which has all the MLP parameters as arguments.\n", + "- It has recorded the tensor metadata.\n", + "- Operations have been mapped from the PyTorch functions to `thunder.torch`(aka `ltorch`) equivalents and decomposed into _primitive operations_.\n", + "- The multiplication and activation (`x = torch.nn.functional.silu(x_fc_1) * x_fc_2`have been put into one NVFuser fusion. (NVFuser here is (a particularly important) one of many optimizations, and we make it easy to add your own.) 
\n", + "- You can see how the parameters are obtained and the metadata is checked in the prologue - get it through `thunder.last_prologue_traces(thunder_model)[-1]`.\n", + "\n", + "You can actually see the series of traces, `last_traces` gives you a list of transformed traces in chronological order - for example the initial trace `thunder.last_traces(thunder_model)[0]` does not have the fusion yet.\n" + ] + }, + { + "cell_type": "markdown", + "id": "7749aed1", + "metadata": {}, + "source": [ + "## Compiling a more complex model\n", + "\n", + "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller momdel here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d53e0c43", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GPT(\n", + " (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n", + " (transformer): ModuleDict(\n", + " (wte): Embedding(32000, 4096)\n", + " (h): ModuleList(\n", + " (0-3): 4 x Block(\n", + " (norm_1): RMSNorm()\n", + " (attn): CausalSelfAttention(\n", + " (attn): Linear(in_features=4096, out_features=12288, bias=False)\n", + " (proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (norm_2): RMSNorm()\n", + " (mlp): LLaMAMLP(\n", + " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + " )\n", + " )\n", + " )\n", + " (ln_f): RMSNorm()\n", " )\n", ")" ] }, - "execution_count": 54, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tm" + "from lit_gpt import GPT\n", + "from thunder.tests.lit_gpt_model import Config\n", + "cfg = Config.from_name('Llama-2-7b-hf')\n", + "cfg.n_layer = 4 # fewer layers\n", + "with torch.device('cuda'):\n", + " m = GPT(cfg)\n", + "m\n" + ] + }, + { + "cell_type": "markdown", + "id": "e536a4aa", + "metadata": {}, + "source": [ + "Again we jit our model and compare the output..." 
] }, { "cell_type": "code", - "execution_count": 60, - "id": "7f4de1b3", + "execution_count": 8, + "id": "36a7be96", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 1.8477439880371094e-06\n" + ] + } + ], + "source": [ + "thunder_model = thunder.jit(m)\n", + "\n", + "inp = torch.randint(1, m.config.vocab_size, (1, 512), device=\"cuda\")\n", + "\n", + "actual = thunder_model(inp)\n", + "expected = m(inp)\n", + "\n", + "print(\"deviation:\", (actual - expected).abs().max().item())\n" + ] + }, + { + "cell_type": "markdown", + "id": "2f681093", + "metadata": {}, + "source": [ + "Just like before, we can see the program it ran:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ac7e8bc9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(1.4901e-07, device='cuda:0', grad_fn=)" + "# Constructed by Delete Last Used (took 1 milliseconds)\n", + "import torch\n", + "from torch import Tensor\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " t3, \\\n", + " t4, \\\n", + " t5, \\\n", + " t6, \\\n", + " t7, \\\n", + " t8, \\\n", + " t9, \\\n", + " t10, \\\n", + " t11, \\\n", + " t12, \\\n", + " t13, \\\n", + " t14, \\\n", + " t15, \\\n", + " t16, \\\n", + " t17, \\\n", + " t18, \\\n", + " t19, \\\n", + " t20, \\\n", + " t21, \\\n", + " t22, \\\n", + " t23, \\\n", + " t24, \\\n", + " t25, \\\n", + " t26, \\\n", + " t27, \\\n", + " t28, \\\n", + " t29, \\\n", + " t30, \\\n", + " t31, \\\n", + " t32, \\\n", + " t33, \\\n", + " = args\n", + " del args\n", + " t38 = torch.nn.functional.embedding(t0, t33, None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t38 = ltorch.embedding(t0, t33, None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t334 = ltorch.reshape(t0, [512]) # t334: \"cuda:0 i64[512]\"\n", + " # t334 = prims.reshape(t0, (512,)) # t334: \"cuda:0 i64[512]\"\n", + " # t335 = prims.take(t33, t334, 0) # t335: \"cuda:0 f32[512, 4096]\"\n", + " # t38 = ltorch.reshape(t335, [1, 512, 4096]) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t38 = prims.reshape(t335, (1, 512, 4096)) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", + " t34 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t34: \"cuda:0 f32[512, 128]\"\n", + " t35 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t35: \"cuda:0 f32[512, 128]\"\n", + " t374 = torch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n", + " # t374 = ltorch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n", + " # t374 = prims.broadcast_in_dim(t17, [1, 4096], [1]) # t374: \"cuda:0 f32[1, 4096]\"\n", + " t375 = torch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t375 = ltorch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t375 = prims.broadcast_in_dim(t374, [1, 1, 4096], [0, 2]) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t374\n", + " t47 = Tensor.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t47 = ltorch.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t47 = prims.broadcast_in_dim(t375, (1, 512, 4096), (0, 1, 2)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t375\n", + " t475 = torch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 4096]\"\n", + " # t475 = ltorch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 
4096]\"\n", + " # t475 = prims.broadcast_in_dim(t24, [1, 4096], [1]) # t475: \"cuda:0 f32[1, 4096]\"\n", + " t476 = torch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t476 = ltorch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t476 = prims.broadcast_in_dim(t475, [1, 1, 4096], [0, 2]) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t475\n", + " t311 = Tensor.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t311 = ltorch.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t311 = prims.broadcast_in_dim(t476, (1, 512, 4096), (0, 1, 2)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t476\n", + " t478 = torch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n", + " # t478 = ltorch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n", + " # t478 = prims.broadcast_in_dim(t16, [1, 4096], [1]) # t478: \"cuda:0 f32[1, 4096]\"\n", + " t479 = torch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t479 = ltorch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t479 = prims.broadcast_in_dim(t478, [1, 1, 4096], [0, 2]) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t478\n", + " t331 = Tensor.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t331 = ltorch.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t331 = prims.broadcast_in_dim(t479, (1, 512, 4096), (0, 1, 2)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t479\n", + " t403 = torch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n", + " # t403 = ltorch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n", + " # t403 = prims.broadcast_in_dim(t21, [1, 4096], [1]) # t403: \"cuda:0 f32[1, 4096]\"\n", + " t404 = torch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t404 = ltorch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t404 = prims.broadcast_in_dim(t403, [1, 1, 4096], [0, 2]) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t403\n", + " t98 = Tensor.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t98 = ltorch.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t98 = prims.broadcast_in_dim(t404, (1, 512, 4096), (0, 1, 2)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t404\n", + " t406 = torch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n", + " # t406 = ltorch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n", + " # t406 = prims.broadcast_in_dim(t18, [1, 4096], [1]) # t406: \"cuda:0 f32[1, 4096]\"\n", + " t407 = torch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t407 = ltorch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t407 = prims.broadcast_in_dim(t406, [1, 1, 4096], [0, 2]) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t406\n", + " t118 = Tensor.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t118 = ltorch.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t118 = prims.broadcast_in_dim(t407, (1, 512, 4096), (0, 1, 2)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t407\n", + " t427 = torch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n", + " # t427 = ltorch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n", + " # t427 = prims.broadcast_in_dim(t22, [1, 4096], [1]) # t427: \"cuda:0 f32[1, 4096]\"\n", + " t428 = torch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t428 = ltorch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t428 = 
prims.broadcast_in_dim(t427, [1, 1, 4096], [0, 2]) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t427\n", + " t169 = Tensor.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t169 = ltorch.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t169 = prims.broadcast_in_dim(t428, (1, 512, 4096), (0, 1, 2)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t428\n", + " t430 = torch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n", + " # t430 = ltorch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n", + " # t430 = prims.broadcast_in_dim(t19, [1, 4096], [1]) # t430: \"cuda:0 f32[1, 4096]\"\n", + " t431 = torch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t431 = ltorch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t431 = prims.broadcast_in_dim(t430, [1, 1, 4096], [0, 2]) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t430\n", + " t189 = Tensor.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t189 = ltorch.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t189 = prims.broadcast_in_dim(t431, (1, 512, 4096), (0, 1, 2)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t431\n", + " t451 = torch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n", + " # t451 = ltorch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n", + " # t451 = prims.broadcast_in_dim(t23, [1, 4096], [1]) # t451: \"cuda:0 f32[1, 4096]\"\n", + " t452 = torch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t452 = ltorch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t452 = prims.broadcast_in_dim(t451, [1, 1, 4096], [0, 2]) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t451\n", + " t240 = Tensor.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t240 = ltorch.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t240 = prims.broadcast_in_dim(t452, (1, 512, 4096), (0, 1, 2)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t452\n", + " t454 = torch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n", + " # t454 = ltorch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n", + " # t454 = prims.broadcast_in_dim(t20, [1, 4096], [1]) # t454: \"cuda:0 f32[1, 4096]\"\n", + " t455 = torch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t455 = ltorch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", + " # t455 = prims.broadcast_in_dim(t454, [1, 1, 4096], [0, 2]) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", + " del t454\n", + " t260 = Tensor.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t260 = ltorch.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t260 = prims.broadcast_in_dim(t455, (1, 512, 4096), (0, 1, 2)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t455\n", + " t395 = torch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n", + " # t395 = ltorch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n", + " # t395 = prims.broadcast_in_dim(t34, [1, 512, 128], [1, 2]) # t395: \"cuda:0 f32[1, 512, 128]\"\n", + " del t34\n", + " t396 = torch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t396 = ltorch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t396 = prims.broadcast_in_dim(t395, [1, 1, 512, 128], [0, 2, 3]) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t395\n", + " t63 = Tensor.expand(t396, (1, 32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t63 = ltorch.expand(t396, (1, 
32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t63 = prims.broadcast_in_dim(t396, (1, 32, 512, 128), (0, 1, 2, 3)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t396\n", + " t398 = torch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n", + " # t398 = ltorch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n", + " # t398 = prims.broadcast_in_dim(t35, [1, 512, 128], [1, 2]) # t398: \"cuda:0 f32[1, 512, 128]\"\n", + " del t35\n", + " t399 = torch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t399 = ltorch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t399 = prims.broadcast_in_dim(t398, [1, 1, 512, 128], [0, 2, 3]) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t398\n", + " t65 = Tensor.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t65 = ltorch.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t65 = prims.broadcast_in_dim(t399, (1, 32, 512, 128), (0, 1, 2, 3)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t399\n", + " [t44, t48] = nvFusion0(t38, t47)\n", + " # t39 = prims.mul(t38, t38) # t39: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t40 = prims.sum(t39, (2,)) # t40: \"cuda:0 f32[1, 512]\"\n", + " # t41 = prims.broadcast_in_dim(t40, [1, 512, 1], [0, 1]) # t41: \"cuda:0 f32[1, 512, 1]\"\n", + " # t42 = prims.div(t41, 4096.0) # t42: \"cuda:0 f32[1, 512, 1]\"\n", + " # t43 = prims.add(t42, 1e-05) # t43: \"cuda:0 f32[1, 512, 1]\"\n", + " # t44 = prims.rsqrt(t43) # t44: \"cuda:0 f32[1, 512, 1]\"\n", + " # t45 = prims.broadcast_in_dim(t44, (1, 512, 4096), (0, 1, 2)) # t45: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t46 = prims.mul(t38, t45) # t46: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t48 = prims.mul(t46, t47) # t48: \"cuda:0 f32[1, 512, 4096]\"\n", + " t49 = torch.nn.functional.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t49 = ltorch.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t49 = prims.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", + " t50 = torch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t50 = ltorch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t50 = prims.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " del t49\n", + " t51 = torch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t51 = ltorch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t51 = prims.transpose(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " del t50\n", + " (t52, t53, t54) = torch.split(t51, (1, 1, 1), 2)\n", + " # (t52, t53, t54) = ltorch.split(t51, (1, 1, 1), 2)\n", + " # t52 = prims.slice_prim(t51, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t52: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t53 = prims.slice_prim(t51, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t53: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t54 = prims.slice_prim(t51, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t54: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " del t51\n", + " t55 = torch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t55 = ltorch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t55 = prims.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t52\n", + " t56 = torch.reshape(t53, (1, 32, 
512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t56 = ltorch.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t56 = prims.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t53\n", + " t57 = torch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t57 = ltorch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t57 = prims.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t54\n", + " t58 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t58: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " t68 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t68: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " t78 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t78: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " del t55\n", + " t80 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t80: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " del t56\n", + " t60 = torch_slice_prim_impl(t58, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t60: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t59 = torch_slice_prim_impl(t58, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t59: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t69 = torch_slice_prim_impl(t68, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t69: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t70 = torch_slice_prim_impl(t68, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t70: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " [t61, t71] = nvFusion1(t60, t70)\n", + " # t61 = prims.neg(t60) # t61: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t71 = prims.neg(t70) # t71: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " del t60, t70\n", + " t62 = torch.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t62 = ltorch.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t62 = prims.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t61, t59\n", + " t72 = torch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t72 = ltorch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t72 = prims.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t71, t69\n", + " [t67, t77] = nvFusion2(t58, t62, t63, t65, t68, t72)\n", + " # t64 = prims.mul(t58, t63) # t64: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t66 = prims.mul(t62, t65) # t66: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t67 = prims.add(t64, t66) # t67: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t74 = prims.mul(t68, t63) # t74: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t76 = prims.mul(t72, t65) # t76: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t77 = prims.add(t74, t76) # t77: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t58, t62, t68, t72\n", + " t79 = torch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t79 = ltorch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t79 = prims.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t67, t78\n", + " t81 = torch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t81 = ltorch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t81 = prims.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t77, t80\n", + " (t82, t83, t84, t85) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t79, t81, t57, None, 0.0, True, 0.08838834764831843)\n", + " t86 = torch.permute(t82, 
(0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t86 = ltorch.permute(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t86 = prims.transpose(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " t87 = torch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t87 = ltorch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t87 = prims.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t86\n", + " t88 = torch.nn.functional.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t88 = ltorch.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t88 = prims.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t89, t95, t99] = nvFusion3(t38, t88, t98)\n", + " # t89 = prims.add(t88, t38) # t89: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t90 = prims.mul(t89, t89) # t90: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t91 = prims.sum(t90, (2,)) # t91: \"cuda:0 f32[1, 512]\"\n", + " # t92 = prims.broadcast_in_dim(t91, [1, 512, 1], [0, 1]) # t92: \"cuda:0 f32[1, 512, 1]\"\n", + " # t93 = prims.div(t92, 4096.0) # t93: \"cuda:0 f32[1, 512, 1]\"\n", + " # t94 = prims.add(t93, 1e-05) # t94: \"cuda:0 f32[1, 512, 1]\"\n", + " # t95 = prims.rsqrt(t94) # t95: \"cuda:0 f32[1, 512, 1]\"\n", + " # t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t97 = prims.mul(t89, t96) # t97: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t99 = prims.mul(t97, t98) # t99: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t88\n", + " t101 = torch.nn.functional.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t101 = ltorch.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t101 = prims.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", + " t100 = torch.nn.functional.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t100 = ltorch.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t100 = prims.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", + " [t107] = nvFusion4(t100, t101)\n", + " # t102 = prims.neg(t100) # t102: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t103 = prims.exp(t102) # t103: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t104 = prims.add(1.0, t103) # t104: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t105 = prims.reciprocal(t104) # t105: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t106 = prims.mul(t100, t105) # t106: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t107 = prims.mul(t106, t101) # t107: \"cuda:0 f32[1, 512, 11008]\"\n", + " t108 = torch.nn.functional.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t108 = ltorch.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t108 = prims.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t109, t115, t119] = nvFusion5(t108, t118, t89)\n", + " # t109 = prims.add(t108, t89) # t109: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t110 = prims.mul(t109, t109) # t110: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t111 = prims.sum(t110, (2,)) # t111: \"cuda:0 f32[1, 512]\"\n", + " # t112 = prims.broadcast_in_dim(t111, [1, 512, 1], [0, 1]) # t112: \"cuda:0 f32[1, 512, 1]\"\n", + " # t113 = prims.div(t112, 4096.0) # t113: \"cuda:0 f32[1, 512, 1]\"\n", + " # t114 = prims.add(t113, 1e-05) # t114: \"cuda:0 f32[1, 512, 1]\"\n", + " # t115 = prims.rsqrt(t114) # t115: \"cuda:0 f32[1, 512, 1]\"\n", + " # t116 = prims.broadcast_in_dim(t115, (1, 512, 4096), (0, 1, 2)) # 
t116: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t117 = prims.mul(t109, t116) # t117: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t119 = prims.mul(t117, t118) # t119: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t108\n", + " t120 = torch.nn.functional.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t120 = ltorch.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t120 = prims.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", + " t121 = torch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t121 = ltorch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t121 = prims.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " del t120\n", + " t122 = torch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t122 = ltorch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t122 = prims.transpose(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " del t121\n", + " (t123, t124, t125) = torch.split(t122, (1, 1, 1), 2)\n", + " # (t123, t124, t125) = ltorch.split(t122, (1, 1, 1), 2)\n", + " # t123 = prims.slice_prim(t122, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t123: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t124 = prims.slice_prim(t122, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t124: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t125 = prims.slice_prim(t122, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t125: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " del t122\n", + " t126 = torch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t126 = ltorch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t126 = prims.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t123\n", + " t127 = torch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t127 = ltorch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t127 = prims.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t124\n", + " t128 = torch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t128 = ltorch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t128 = prims.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t125\n", + " t149 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t149: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " t151 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t151: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " t129 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t129: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t126\n", + " t139 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t139: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t127\n", + " t130 = torch_slice_prim_impl(t129, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t130: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t131 = torch_slice_prim_impl(t129, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t131: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t141 = torch_slice_prim_impl(t139, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t141: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t140 = torch_slice_prim_impl(t139, [0, 0, 
0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t140: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " [t132, t142] = nvFusion6(t131, t141)\n", + " # t132 = prims.neg(t131) # t132: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t142 = prims.neg(t141) # t142: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " del t131, t141\n", + " t143 = torch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t143 = ltorch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t143 = prims.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t142, t140\n", + " t133 = torch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t133 = ltorch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t133 = prims.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t132, t130\n", + " [t138, t148] = nvFusion7(t129, t133, t139, t143, t63, t65)\n", + " # t145 = prims.mul(t139, t63) # t145: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t147 = prims.mul(t143, t65) # t147: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t148 = prims.add(t145, t147) # t148: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t135 = prims.mul(t129, t63) # t135: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t137 = prims.mul(t133, t65) # t137: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t138 = prims.add(t135, t137) # t138: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t129, t133, t139, t143\n", + " t150 = torch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t150 = ltorch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t150 = prims.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t138, t149\n", + " t152 = torch.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t152 = ltorch.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t152 = prims.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t148, t151\n", + " (t153, t154, t155, t156) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t150, t152, t128, None, 0.0, True, 0.08838834764831843)\n", + " t157 = torch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t157 = ltorch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t157 = prims.transpose(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " t158 = torch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t158 = ltorch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t158 = prims.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t157\n", + " t159 = torch.nn.functional.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t159 = ltorch.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t159 = prims.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t160, t166, t170] = nvFusion8(t109, t159, t169)\n", + " # t160 = prims.add(t159, t109) # t160: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t161 = prims.mul(t160, t160) # t161: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t162 = prims.sum(t161, (2,)) # t162: \"cuda:0 f32[1, 512]\"\n", + " # t163 = prims.broadcast_in_dim(t162, [1, 512, 1], [0, 1]) # t163: \"cuda:0 f32[1, 512, 1]\"\n", + " # t164 = prims.div(t163, 4096.0) # t164: \"cuda:0 f32[1, 512, 1]\"\n", + " # t165 = prims.add(t164, 1e-05) # t165: \"cuda:0 f32[1, 512, 1]\"\n", + " # t166 = prims.rsqrt(t165) # t166: \"cuda:0 f32[1, 512, 1]\"\n", + " # 
t167 = prims.broadcast_in_dim(t166, (1, 512, 4096), (0, 1, 2)) # t167: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t168 = prims.mul(t160, t167) # t168: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t170 = prims.mul(t168, t169) # t170: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t159\n", + " t172 = torch.nn.functional.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t172 = ltorch.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t172 = prims.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", + " t171 = torch.nn.functional.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t171 = ltorch.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t171 = prims.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", + " [t178] = nvFusion9(t171, t172)\n", + " # t173 = prims.neg(t171) # t173: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t174 = prims.exp(t173) # t174: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t175 = prims.add(1.0, t174) # t175: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t176 = prims.reciprocal(t175) # t176: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t177 = prims.mul(t171, t176) # t177: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t178 = prims.mul(t177, t172) # t178: \"cuda:0 f32[1, 512, 11008]\"\n", + " t179 = torch.nn.functional.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t179 = ltorch.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t179 = prims.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t180, t186, t190] = nvFusion10(t160, t179, t189)\n", + " # t180 = prims.add(t179, t160) # t180: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t181 = prims.mul(t180, t180) # t181: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t182 = prims.sum(t181, (2,)) # t182: \"cuda:0 f32[1, 512]\"\n", + " # t183 = prims.broadcast_in_dim(t182, [1, 512, 1], [0, 1]) # t183: \"cuda:0 f32[1, 512, 1]\"\n", + " # t184 = prims.div(t183, 4096.0) # t184: \"cuda:0 f32[1, 512, 1]\"\n", + " # t185 = prims.add(t184, 1e-05) # t185: \"cuda:0 f32[1, 512, 1]\"\n", + " # t186 = prims.rsqrt(t185) # t186: \"cuda:0 f32[1, 512, 1]\"\n", + " # t187 = prims.broadcast_in_dim(t186, (1, 512, 4096), (0, 1, 2)) # t187: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t188 = prims.mul(t180, t187) # t188: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t190 = prims.mul(t188, t189) # t190: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t179\n", + " t191 = torch.nn.functional.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t191 = ltorch.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t191 = prims.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", + " t192 = torch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t192 = ltorch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t192 = prims.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " del t191\n", + " t193 = torch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t193 = ltorch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t193 = prims.transpose(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " del t192\n", + " (t194, t195, t196) = torch.split(t193, (1, 1, 1), 2)\n", + " # (t194, t195, t196) = ltorch.split(t193, (1, 1, 1), 2)\n", + " # t194 = prims.slice_prim(t193, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # 
t194: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t195 = prims.slice_prim(t193, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t195: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t196 = prims.slice_prim(t193, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t196: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " del t193\n", + " t197 = torch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t197 = ltorch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t197 = prims.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t194\n", + " t198 = torch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t198 = ltorch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t198 = prims.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t195\n", + " t199 = torch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t199 = ltorch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t199 = prims.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t196\n", + " t200 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t200: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " t210 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t210: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " t220 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t220: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " del t197\n", + " t222 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t222: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " del t198\n", + " t201 = torch_slice_prim_impl(t200, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t201: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t202 = torch_slice_prim_impl(t200, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t202: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t211 = torch_slice_prim_impl(t210, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t211: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t212 = torch_slice_prim_impl(t210, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t212: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " [t203, t213] = nvFusion11(t202, t212)\n", + " # t203 = prims.neg(t202) # t203: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " del t202, t212\n", + " t214 = torch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t214 = ltorch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t214 = prims.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t213, t211\n", + " t204 = torch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t204 = ltorch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t204 = prims.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t203, t201\n", + " [t209, t219] = nvFusion12(t200, t204, t210, t214, t63, t65)\n", + " # t216 = prims.mul(t210, t63) # t216: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t218 = prims.mul(t214, t65) # t218: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t219 = prims.add(t216, t218) # t219: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t206 = prims.mul(t200, t63) # t206: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t208 = prims.mul(t204, t65) # t208: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t209 = 
prims.add(t206, t208) # t209: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t200, t204, t210, t214\n", + " t223 = torch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t223 = ltorch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t223 = prims.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t219, t222\n", + " t221 = torch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t221 = ltorch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t221 = prims.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t209, t220\n", + " (t224, t225, t226, t227) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t221, t223, t199, None, 0.0, True, 0.08838834764831843)\n", + " t228 = torch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t228 = ltorch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t228 = prims.transpose(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " t229 = torch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t229 = ltorch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t229 = prims.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t228\n", + " t230 = torch.nn.functional.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t230 = ltorch.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t230 = prims.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t231, t237, t241] = nvFusion13(t180, t230, t240)\n", + " # t231 = prims.add(t230, t180) # t231: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t232 = prims.mul(t231, t231) # t232: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t233 = prims.sum(t232, (2,)) # t233: \"cuda:0 f32[1, 512]\"\n", + " # t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1]) # t234: \"cuda:0 f32[1, 512, 1]\"\n", + " # t235 = prims.div(t234, 4096.0) # t235: \"cuda:0 f32[1, 512, 1]\"\n", + " # t236 = prims.add(t235, 1e-05) # t236: \"cuda:0 f32[1, 512, 1]\"\n", + " # t237 = prims.rsqrt(t236) # t237: \"cuda:0 f32[1, 512, 1]\"\n", + " # t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2)) # t238: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t239 = prims.mul(t231, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t241 = prims.mul(t239, t240) # t241: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t230\n", + " t242 = torch.nn.functional.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t242 = ltorch.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t242 = prims.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", + " t243 = torch.nn.functional.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t243 = ltorch.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t243 = prims.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", + " [t249] = nvFusion14(t242, t243)\n", + " # t244 = prims.neg(t242) # t244: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t245 = prims.exp(t244) # t245: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t246 = prims.add(1.0, t245) # t246: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t247 = prims.reciprocal(t246) # t247: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t248 = prims.mul(t242, t247) # t248: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t249 = prims.mul(t248, t243) # t249: \"cuda:0 f32[1, 512, 11008]\"\n", + " t250 = 
torch.nn.functional.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t250 = ltorch.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t250 = prims.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t251, t257, t261] = nvFusion15(t231, t250, t260)\n", + " # t251 = prims.add(t250, t231) # t251: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t252 = prims.mul(t251, t251) # t252: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t253 = prims.sum(t252, (2,)) # t253: \"cuda:0 f32[1, 512]\"\n", + " # t254 = prims.broadcast_in_dim(t253, [1, 512, 1], [0, 1]) # t254: \"cuda:0 f32[1, 512, 1]\"\n", + " # t255 = prims.div(t254, 4096.0) # t255: \"cuda:0 f32[1, 512, 1]\"\n", + " # t256 = prims.add(t255, 1e-05) # t256: \"cuda:0 f32[1, 512, 1]\"\n", + " # t257 = prims.rsqrt(t256) # t257: \"cuda:0 f32[1, 512, 1]\"\n", + " # t258 = prims.broadcast_in_dim(t257, (1, 512, 4096), (0, 1, 2)) # t258: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t259 = prims.mul(t251, t258) # t259: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t261 = prims.mul(t259, t260) # t261: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t250\n", + " t262 = torch.nn.functional.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t262 = ltorch.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", + " # t262 = prims.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", + " t263 = torch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t263 = ltorch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " # t263 = prims.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", + " del t262\n", + " t264 = torch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t264 = ltorch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " # t264 = prims.transpose(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", + " del t263\n", + " (t265, t266, t267) = torch.split(t264, (1, 1, 1), 2)\n", + " # (t265, t266, t267) = ltorch.split(t264, (1, 1, 1), 2)\n", + " # t265 = prims.slice_prim(t264, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t265: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t266 = prims.slice_prim(t264, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t266: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " # t267 = prims.slice_prim(t264, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t267: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", + " del t264\n", + " t268 = torch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t268 = ltorch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t268 = prims.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t265\n", + " t269 = torch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t269 = ltorch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t269 = prims.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t266\n", + " t270 = torch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t270 = ltorch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t270 = prims.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t267\n", + " t271 = torch_slice_prim_impl(t268, [0, 0, 0, 0], 
[1, 32, 512, 128], [1, 1, 1, 1]) # t271: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " t281 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " t291 = torch_slice_prim_impl(t268, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t291: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " del t268\n", + " t293 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t293: \"cuda:0 f32[1, 32, 512, 0]\"\n", + " del t269\n", + " t272 = torch_slice_prim_impl(t271, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t272: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t273 = torch_slice_prim_impl(t271, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t282 = torch_slice_prim_impl(t281, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t282: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " t283 = torch_slice_prim_impl(t281, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t283: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " [t274, t284] = nvFusion16(t273, t283)\n", + " # t274 = prims.neg(t273) # t274: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t284 = prims.neg(t283) # t284: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " del t273, t283\n", + " t275 = torch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t275 = ltorch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t275 = prims.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t274, t272\n", + " t285 = torch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t285 = ltorch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t285 = prims.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t284, t282\n", + " [t280, t290] = nvFusion17(t271, t275, t281, t285, t63, t65)\n", + " # t277 = prims.mul(t271, t63) # t277: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t279 = prims.mul(t275, t65) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t280 = prims.add(t277, t279) # t280: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t287 = prims.mul(t281, t63) # t287: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t289 = prims.mul(t285, t65) # t289: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t290 = prims.add(t287, t289) # t290: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t271, t275, t281, t285\n", + " t292 = torch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t292 = ltorch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t292 = prims.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t280, t291\n", + " t294 = torch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t294 = ltorch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t294 = prims.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t290, t293\n", + " (t295, t296, t297, t298) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t292, t294, t270, None, 0.0, True, 0.08838834764831843)\n", + " t299 = torch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t299 = ltorch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " # t299 = prims.transpose(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", + " t300 = torch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t300 = ltorch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t300 = prims.reshape(t299, (1, 512, 4096)) 
# t300: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t299\n", + " t301 = torch.nn.functional.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t301 = ltorch.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t301 = prims.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t302, t308, t312] = nvFusion18(t251, t301, t311)\n", + " # t302 = prims.add(t301, t251) # t302: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t303 = prims.mul(t302, t302) # t303: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t304 = prims.sum(t303, (2,)) # t304: \"cuda:0 f32[1, 512]\"\n", + " # t305 = prims.broadcast_in_dim(t304, [1, 512, 1], [0, 1]) # t305: \"cuda:0 f32[1, 512, 1]\"\n", + " # t306 = prims.div(t305, 4096.0) # t306: \"cuda:0 f32[1, 512, 1]\"\n", + " # t307 = prims.add(t306, 1e-05) # t307: \"cuda:0 f32[1, 512, 1]\"\n", + " # t308 = prims.rsqrt(t307) # t308: \"cuda:0 f32[1, 512, 1]\"\n", + " # t309 = prims.broadcast_in_dim(t308, (1, 512, 4096), (0, 1, 2)) # t309: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t310 = prims.mul(t302, t309) # t310: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t312 = prims.mul(t310, t311) # t312: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t301\n", + " t314 = torch.nn.functional.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t314 = ltorch.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t314 = prims.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", + " t313 = torch.nn.functional.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t313 = ltorch.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t313 = prims.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", + " [t320] = nvFusion19(t313, t314)\n", + " # t315 = prims.neg(t313) # t315: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t316 = prims.exp(t315) # t316: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t317 = prims.add(1.0, t316) # t317: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t318 = prims.reciprocal(t317) # t318: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t319 = prims.mul(t313, t318) # t319: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t320 = prims.mul(t319, t314) # t320: \"cuda:0 f32[1, 512, 11008]\"\n", + " t321 = torch.nn.functional.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t321 = ltorch.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t321 = prims.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", + " [t322, t328, t332] = nvFusion20(t302, t321, t331)\n", + " # t322 = prims.add(t321, t302) # t322: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t323 = prims.mul(t322, t322) # t323: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t324 = prims.sum(t323, (2,)) # t324: \"cuda:0 f32[1, 512]\"\n", + " # t325 = prims.broadcast_in_dim(t324, [1, 512, 1], [0, 1]) # t325: \"cuda:0 f32[1, 512, 1]\"\n", + " # t326 = prims.div(t325, 4096.0) # t326: \"cuda:0 f32[1, 512, 1]\"\n", + " # t327 = prims.add(t326, 1e-05) # t327: \"cuda:0 f32[1, 512, 1]\"\n", + " # t328 = prims.rsqrt(t327) # t328: \"cuda:0 f32[1, 512, 1]\"\n", + " # t329 = prims.broadcast_in_dim(t328, (1, 512, 4096), (0, 1, 2)) # t329: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t330 = prims.mul(t322, t329) # t330: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t332 = prims.mul(t330, t331) # t332: \"cuda:0 f32[1, 512, 4096]\"\n", + " del t321\n", + " t333 = torch.nn.functional.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", + " # t333 = ltorch.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", + 
" # t333 = prims.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", + " return {'output': t333, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33], 'flat_output': (t333,)}, ((t0, t10, t100, t101, t107, t109, t11, t115, t118, t119, t12, t128, t13, t14, t15, t150, t152, t153, t154, t155, t156, t158, t160, t166, t169, t170, t171, t172, t178, t180, t186, t189, t190, t199, t221, t223, t224, t225, t226, t227, t229, t231, t237, t240, t241, t242, t243, t249, t25, t251, t257, t26, t260, t261, t27, t270, t28, t29, t292, t294, t295, t296, t297, t298, t3, t30, t300, t302, t308, t31, t311, t312, t313, t314, t32, t320, t322, t328, t331, t332, t38, t4, t44, t47, t48, t5, t57, t6, t63, t65, t7, t79, t8, t81, t82, t83, t84, t85, t87, t89, t9, t95, t98, t99), (False, True, True, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2))" ] }, - "execution_count": 60, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "x = torch.randn(2, 2048, 4096, device=\"cuda\")\n", - "(tm(x) - m(x)).abs().max()\n" + "thunder.last_traces(thunder_model)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "4944f352", + "metadata": {}, + "source": [ + "Well, that is quite a bit to look through.\n", + "But here is a key thing: The function now returns a buch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n", + "`thunder.last_backward_traces(thunder_model)[-1]`)." ] }, { "cell_type": "code", - "execution_count": 61, - "id": "a6f4b77c", + "execution_count": 10, + "id": "4d90df65", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-0.9922, 0.5946, -0.2173, ..., -0.0981, -0.5058, 0.2747],\n", + " [-1.1552, 0.5770, -0.7432, ..., 0.0688, 0.1238, 0.6786],\n", + " [-0.7813, 0.6960, 0.1235, ..., -0.4840, 0.1373, 0.6490],\n", + " ...,\n", + " [ 0.3711, 0.1656, 0.3350, ..., -0.0294, 0.3670, 0.5099],\n", + " [-0.2544, -0.8470, 0.2063, ..., -0.1341, 0.1877, 0.2612],\n", + " [ 0.3420, -1.1421, 0.9222, ..., 0.5636, 0.1666, 0.6947]]],\n", + " device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "actual" + ] + }, + { + "cell_type": "markdown", + "id": "7dcec40f", + "metadata": {}, + "source": [ + "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6ba7f715", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "maximum deviation grads: 0.00042724609375\n" + ] + } + ], + "source": [ + "actual_grads = torch.autograd.grad(actual.sum(), m.parameters())\n", + "expected_grads = torch.autograd.grad(expected.sum(), m.parameters())\n", + "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))" + ] + }, + { + "cell_type": "markdown", + "id": "0261eb11", + "metadata": {}, + "source": [ + "But is it faster? Yes!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "854f29a5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "154 ms ± 281 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", + "150 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "import gc\n", + "gc.collect()\n", + "%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()\n", + "%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "eb177aad", + "metadata": {}, + "outputs": [], + "source": [ + "del m, thunder_model\n", + "import gc\n", + "gc.collect()\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "id": "1d31e7f8", + "metadata": {}, + "source": [ + "So far, so good! Thunder should work with LitGPT today, and we are busy adding the support required to run other models as well!" + ] + }, + { + "cell_type": "markdown", + "id": "d23ebbf5", + "metadata": {}, + "source": [ + "## Distributed with Thunder\n", + "\n", + "Those Large Language Models are called Large for a reason, and the memory of a single GPU is invariably too small. So we need multiple GPUs.\n", + "\n", + "Happily, Thunder sports an FSDP interface to use multiple cards in our box.\n", + "\n", + "You still need to set up the process group, but as far as the model is concerned,\n", + "\n", + "```python\n", + "model = thunder.jit(thunder.distributed.fsdp(model))\n", + "```\n", + "\n", + "is all you need. Because it is tricky to run multiprocessing from notebooks, we write a small example into a file and run it through `torchrun`.\n", + "\n", + "Check out our LitGPT Thunder examples for complete distributed training and finetuning!\n",
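+ "\n", + "For reference, the process-group setup in the example file below boils down to the following sketch (run under `torchrun`, which sets the `LOCAL_RANK` environment variable for each process):\n", + "\n", + "```python\n", + "import os\n", + "import torch.distributed\n", + "\n", + "torch.distributed.init_process_group(backend='nccl')  # one process per GPU\n", + "rank = int(os.environ[\"LOCAL_RANK\"])\n", + "device = f\"cuda:{rank}\"\n", + "```"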
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "18dd3379", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting zero_to_thunder_fsdp_simple_example.py\n" + ] + } + ], + "source": [ + "%%writefile zero_to_thunder_fsdp_simple_example.py\n", + "import sys\n", + "sys.path.insert(0, '..')\n", + "from thunder.tests.lit_gpt_model import GPT, Config\n", + "\n", + "import torch\n", + "import torch.distributed\n", + "import thunder\n", + "import thunder.distributed\n", + "import os\n", + "\n", + "# Create Model\n", + "# NOTE: We create the model on CPU.\n", + "device='cpu'\n", + "torch.set_default_dtype(torch.bfloat16)\n", + "model = GPT.from_name('llama2-like')\n", + "# Setup for distributed\n", + "torch.distributed.init_process_group(backend='nccl')\n", + "rank = int(os.environ[\"LOCAL_RANK\"])\n", + "\n", + "device = f\"cuda:{rank}\"\n", + "x = torch.randint(1, model.config.vocab_size, (1, 1024), device=device)\n", + "\n", + "# thunder.distributed.fsdp takes care of moving the parameter\n", + "# shard to the correct GPU for the current process.\n", + "model = thunder.jit(thunder.distributed.fsdp(model)) # <---------------------------------------\n", + "\n", + "# Run the forward pass.\n", + "res = model(x)\n", + "res.sum().backward()\n", + "\n", + "res = model(x)\n", + "res.sum().backward()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "2bad9b64", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] \r\n", + "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n", + "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \r\n", + "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n" + ] + } + ], + "source": [ + "!torchrun --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py" + ] + }, + { + "cell_type": "markdown", + "id": "9c65e75d", + "metadata": {}, + "source": [ + "So there we have it: FSDP is just a matter of wrapping the model in `fsdp`." + ] + }, + { + "cell_type": "markdown", + "id": "4a6d7a20", + "metadata": {}, + "source": [ + "## Extending Thunder\n", + "\n", + "But we promised that Thunder is extensible. Let's see what that looks like.\n", + "\n", + "Specifically, we will incorporate the RMSNorm kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n", + "\n", + "In Thunder, extensions (as well as most built-in optimizations, which use the exact same mechanism) work with _executors_ handling operations. Let us define one."
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f7639065", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[# Constructed by Augmented forward pass\n", - " import thunder\n", - " import thunder.core.prims as prims\n", - " import torch\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t6 = prims.mul(t0, t5) # t6: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()),\n", - " # Constructed by Transform for execution (took 2 milliseconds)\n", - " import torch\n", - " import torch.nn.functional\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = torch.nn.functional.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = ltorch.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = torch.nn.functional.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = ltorch.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " [t3, t5, t6, t7] = nvFusion0(t0, t1)\n", - " # t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t6 = prims.mul(t0, t5) # t6: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, 
t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()),\n", - " # Constructed by Update Call Context (took 0 milliseconds)\n", - " import torch\n", - " import torch.nn.functional\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = torch.nn.functional.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = ltorch.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = torch.nn.functional.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = ltorch.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " [t7] = nvFusion0(t0, t1)\n", - " # t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t6 = prims.mul(t0, t5) # t6: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, t0, t1, t7, t_proj_weight), ()),\n", - " # Constructed by Delete Last Used (took 0 milliseconds)\n", - " import torch\n", - " import torch.nn.functional\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = torch.nn.functional.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = ltorch.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = torch.nn.functional.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = ltorch.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " [t7] = nvFusion0(t0, t1)\n", - " # t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t6 = prims.mul(t0, t5) # t6: 
\"cuda:0 f32[2, 2048, 11008]\"\n", - " # t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, t0, t1, t7, t_proj_weight), ())]" + "my_ex" ] }, - "execution_count": 61, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "thunder.last_traces(tm)[0]" + "my_ex = thunder.extend.OperatorExecutor('my_ex', version='0.0.1')\n", + "thunder.extend.register_executor(my_ex)" + ] + }, + { + "cell_type": "markdown", + "id": "a63595ab", + "metadata": {}, + "source": [ + "For our base implementation, we take the ccode from [LitGPT's RMSNorm implementation](https://github.com/Lightning-AI/litgpt/blob/7c1574925f973e64c0a53e056b77229bedee1619/lit_gpt/rmsnorm.py)\n", + "\n", + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "247074b3", + "metadata": {}, + "outputs": [], + "source": [ + "from thunder import TensorProxy\n", + "\n", + "# Taken from LitGPT, who in turn credit:\n", + "# Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:\n", + "# https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.\n", + "\n", + "def rms_norm_impl(x: torch.Tensor, weight, dim: int, eps: float, add_unit_offset: bool) -> torch.Tensor:\n", + " dtype = x.dtype\n", + " x = x.float()\n", + " # NOTE: the original RMSNorm paper implementation is not equivalent\n", + " norm_x = torch.mean(x * x, dim=dim, keepdim=True)\n", + " x_normed = x * torch.rsqrt(norm_x + eps)\n", + " x_normed = x_normed.to(dtype=dtype)\n", + " if add_unit_offset:\n", + " # Gemma model requires a unit offset\n", + " # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176\n", + " return x_normed * (1 + weight)\n", + " return x_normed * weight\n", + "\n", + "def rms_norm_meta(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool) -> TensorProxy:\n", + " return TensorProxy(like=x)\n", + "\n", + "rms_norm = my_ex.register_operator('rms_norm', meta=rms_norm_meta, fn=rms_norm_impl)\n" + ] + }, + { + "cell_type": "markdown", + "id": "75ad1dbf", + "metadata": {}, + "source": [ + "Because evil monkey-patching is a thing for short demos is a thing, let's replace LitGPT's own implementation. For your own model, you might start out with a that in your code directly." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "e0bdecd3", + "metadata": {}, + "outputs": [], + "source": [ + "import lit_gpt.rmsnorm\n", + "if not hasattr(lit_gpt.rmsnorm, 'ThunderOrigRMSNorm'):\n", + " lit_gpt.rmsnorm.ThunderOrigRMSNorm = lit_gpt.rmsnorm.RMSNorm\n", + "\n", + "class ThunderizedRMSNorm(lit_gpt.rmsnorm.ThunderOrigRMSNorm):\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " # This isn't the best paradigm. 
:/\n", + " if thunder.core.interpreter.is_jitting():\n", + " return rms_norm(x, self.weight, self.dim, self.eps, self.add_unit_offset)\n", + " else:\n", + " return super().forward(x)\n", + "\n", + "lit_gpt.rmsnorm.RMSNorm = ThunderizedRMSNorm" + ] + }, + { + "cell_type": "markdown", + "id": "d6b7d056", + "metadata": {}, + "source": [ + "We can try our new RMSNorm: " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0ebd5dd1", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 0.0\n" + ] + }, + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(x, t_weight):\n", + " # x: \"cuda:0 f32[256, 4096]\" \n", + " # t_weight: \"cuda:0 f32[4096]\" \n", + " t7 = rms_norm(x, t_weight, -1, 1e-06, False) # t7: \"cuda:0 f32[256, 4096]\"\n", + " del x, t_weight\n", + " return t7" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.device('cuda'):\n", + " norm_module = ThunderizedRMSNorm(4096)\n", + " x = torch.randn(256, 4096)\n", + "\n", + "# we're not quite there to handle forward and backward yet, we'll re-enable them below\n", + "for p in norm_module.parameters(): \n", + " p.requires_grad_(False)\n", + "\n", + "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n", + "\n", + "expected = norm_module(x)\n", + "actual = thunder_norm_module(x)\n", + "\n", + "print(\"deviation:\", (expected - actual).abs().max().item())\n", + "\n", + "thunder.last_traces(thunder_norm_module)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "8c620a38", + "metadata": {}, + "source": [ + "But why did we do this? Well, we can now layer a faster implementation on top.\n", + "For this we take the [unsloth RMSNorm](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rms_layernorm.py) kernels. We the bits that were in the forward and backward of the `autograd.Function` into our implementation functions and define the corresponding metas." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a7a26f5f", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. 
All rights reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "import triton\n", + "import triton.language as tl\n", + "import torch\n", + "\n", + "MAX_FUSED_SIZE = 65536\n", + "next_power_of_2 = triton.next_power_of_2\n", + "\n", + "def calculate_settings(n):\n", + " BLOCK_SIZE = next_power_of_2(n)\n", + " if BLOCK_SIZE > MAX_FUSED_SIZE:\n", + " raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n", + " f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n", + " num_warps = 4\n", + " if BLOCK_SIZE >= 32768: num_warps = 32\n", + " elif BLOCK_SIZE >= 8192: num_warps = 16\n", + " elif BLOCK_SIZE >= 2048: num_warps = 8\n", + " return BLOCK_SIZE, num_warps\n", + "\n", + "@triton.jit\n", + "def _rms_layernorm_forward(\n", + " Y, Y_row_stride,\n", + " X, X_row_stride,\n", + " W, W_row_stride,\n", + " r, r_row_stride,\n", + " n_cols, eps,\n", + " BLOCK_SIZE : tl.constexpr\n", + "):\n", + " \"\"\"\n", + " Fast RMS Layernorm kernel\n", + " Inspiration from a Triton tutorial:\n", + " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n", + " \"\"\"\n", + " row_idx = tl.program_id(0)\n", + " col_offsets = tl.arange(0, BLOCK_SIZE)\n", + " mask = col_offsets < n_cols\n", + "\n", + " Y += row_idx * Y_row_stride\n", + " X += row_idx * X_row_stride\n", + " r += row_idx * r_row_stride\n", + "\n", + " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n", + " W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)\n", + "\n", + " row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n", + " inv_var = tl.math.rsqrt(row_var + eps)\n", + " tl.store(r, inv_var)\n", + " normed = X_row * inv_var\n", + " normed = normed.to(W_row.dtype) # Exact copy from HF\n", + " output = normed * W_row\n", + " tl.store(Y + col_offsets, output, mask = mask)\n", + "\n", + "\n", + "@triton.jit\n", + "def _rms_layernorm_backward(\n", + " dY, dY_row_stride,\n", + " X, X_row_stride,\n", + " W, W_row_stride,\n", + " r, r_row_stride,\n", + " dW, dW_row_stride,\n", + " n_cols, eps,\n", + " BLOCK_SIZE : tl.constexpr,\n", + "):\n", + " \"\"\"\n", + " Fast RMS Layernorm kernel for the backward pass\n", + " Inspiration from a Triton tutorial:\n", + " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n", + " \"\"\"\n", + " row_idx = tl.program_id(0)\n", + " col_offsets = tl.arange(0, BLOCK_SIZE)\n", + " mask = col_offsets < n_cols\n", + "\n", + " dY += row_idx * dY_row_stride\n", + " X += row_idx * X_row_stride\n", + " r += row_idx * r_row_stride\n", + "\n", + " dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n", + " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n", + " W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n", + "\n", + " # Get saved row variance\n", + " inv_var = tl.load(r).to(tl.float32)\n", + " normed = X_row * inv_var\n", + "\n", + " dY_W = dY_row * 
W_row\n", + "\n", + " rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)\n", + " output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)\n", + " tl.store(dY + col_offsets, output, mask = mask)\n", + " \n", + "def rms_layernorm_forward_impl(X, W, eps):\n", + " shape = X.shape\n", + " dim = shape[-1]\n", + " X = X.view(-1, dim)\n", + " n_rows, n_cols = X.shape\n", + " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n", + "\n", + " Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = \"cuda\")\n", + " r = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n", + "\n", + " _rms_layernorm_forward[(n_rows,)](\n", + " Y, Y.stride(0),\n", + " X, X.stride(0),\n", + " W, W.stride(0),\n", + " r, r.stride(0),\n", + " n_cols, eps,\n", + " BLOCK_SIZE = BLOCK_SIZE,\n", + " num_warps = num_warps,\n", + " )\n", + " return Y.view(*shape), (r, BLOCK_SIZE, num_warps)\n", + "\n", + "def rms_layernorm_forward_meta(X, W, eps):\n", + " n_cols = X.shape[-1]\n", + " n_rows = 1\n", + " for i in X.shape[:-1]:\n", + " n_rows *= i\n", + " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n", + " Y = TensorProxy(like=X, requires_grad=True)\n", + " return (Y,\n", + " (TensorProxy(shape=(n_rows,), device=X.device, dtype=thunder.dtypes.float32, requires_grad=False),\n", + " BLOCK_SIZE, \n", + " num_warps,\n", + " )\n", + " )\n", + "\n", + "def rms_layernorm_backward_impl(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n", + " shape = dY.shape\n", + " dim = shape[-1]\n", + " dY = dY.view(-1, dim)\n", + " n_rows, n_cols = dY.shape\n", + " dW = X\n", + " dX = dY.clone()\n", + " _rms_layernorm_backward[(n_rows,)](\n", + " dX, dX.stride(0),\n", + " X, X .stride(0),\n", + " W, W .stride(0),\n", + " r, r .stride(0),\n", + " dW, dW.stride(0),\n", + " n_cols, eps,\n", + " BLOCK_SIZE = BLOCK_SIZE,\n", + " num_warps = num_warps,\n", + " )\n", + " dX = dX.view(*shape)\n", + " return dX\n", + "\n", + "def rms_layernorm_backward_meta(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n", + " return TensorProxy(like=dY)" + ] + }, + { + "cell_type": "markdown", + "id": "b70eba5f", + "metadata": {}, + "source": [ + "With this, we can just register the additional operators:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "f8f1e77e", + "metadata": {}, + "outputs": [], + "source": [ + "unsloth_rms_norm_forward = my_ex.register_operator('unsloth_rms_norm_forward', meta=rms_layernorm_forward_meta, fn=rms_layernorm_forward_impl)\n", + "unsloth_rms_norm_backward = my_ex.register_operator('unsloth_rms_norm_backward', meta=rms_layernorm_backward_meta, fn=rms_layernorm_backward_impl)" + ] + }, + { + "cell_type": "markdown", + "id": "2426263d", + "metadata": {}, + "source": [ + "But instead of monkey-patching more, we can now register the kernel as an _implementation_ of the base `rms_norm` primitive defined above. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`rms_norm`) in terms of our new operator - so it has the call signature of the `rms_norm`. Because - like many fast implementations - the unsloth RMS norm does not implement the operator in full generality (to do them justice, they have a variant adding the unit offset, we just didn't copy it over), we implement a checker function, too: It takes the arguments of the operator we want specialize and returns a bool whether our implementation handles the given inputs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "6b5c8320", + "metadata": {}, + "outputs": [], + "source": [ + "def rms_norm_to_unsloth(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n", + " assert dim == -1 and not add_unit_offset\n", + " res, _ = unsloth_rms_norm_forward(x, weight, eps)\n", + " return res\n", + "\n", + "def rms_norm_to_unsloth_checker(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n", + " if dim != -1 or add_unit_offset:\n", + " return False\n", + " if weight.requires_grad:\n", + " return False # the unsloth rms norm backward only gives the grad w.r.t. x\n", + " return x.device.devicetype == thunder.devices.DeviceType.CUDA and weight.device.devicetype == thunder.devices.DeviceType.CUDA\n", + "\n", + "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker, execution_transform=rms_norm_to_unsloth)\n" ] }, { "cell_type": "markdown", "id": "eec7c95a", "metadata": {}, "source": [ "So let us give that a try! Works great..." ] }, { "cell_type": "code", "execution_count": 23, "id": "965ba1d7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 9.5367431640625e-07\n" ] }, { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast()\n", "def computation(x, t_weight):\n", " # x: \"cuda:0 f32[2048, 4096]\" \n", " # t_weight: \"cuda:0 f32[4096]\" \n", " (t7, (_, _, _)) = unsloth_rms_norm_forward(x, t_weight, 1e-06)\n", " del x, t_weight\n", " return t7" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with torch.device('cuda'):\n", " norm_module = ThunderizedRMSNorm(4096)\n", "\n", "# unfortunately, we meet dragons if we don't do this at this stage\n", "for p in norm_module.parameters(): \n", " p.requires_grad_(False)\n", "\n", "thunder_norm_module = thunder.jit(norm_module, executors=[my_ex,]) \n", "x = torch.randn(2048, 4096, device=\"cuda\")\n", "\n", "expected = norm_module(x)\n", "actual = thunder_norm_module(x)\n", "\n", "print(\"deviation:\", (expected - actual).abs().max().item())\n", "\n", "thunder.last_traces(thunder_norm_module)[-1]" ] }, { "cell_type": "markdown", "id": "0e3e4d85", "metadata": {}, "source": [ "And this is also automatic when we instantiate a larger llama2-like model:" ] }, { "cell_type": "code", "execution_count": 24, "id": "7fff2522", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 4.76837158203125e-07\n" ] } ], "source": [ "torch.set_default_dtype(torch.float32)\n", "with torch.device('cuda'):\n", " m = GPT(Config.from_name('llama2-like'))\n", "\n", "for p in m.parameters():\n", " p.requires_grad_(False)\n", "\n", "thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())\n", "\n", "inp = torch.randint(1, m.config.vocab_size, (1, 128), device=\"cuda\")\n", "actual = thunder_model(inp)\n", "expected = m(inp)\n", "\n", "print(\"deviation:\", (actual - expected).abs().max().item())" ] }, { "cell_type": "markdown", "id": "b538cb40", "metadata": {}, "source": [ "By peeking into the trace, we can see that it actually used the 
unsloth RMS kernels:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "c260cb25", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[' (n_1, (_, _, _)) = unsloth_rms_norm_forward(x, t_transformer_h_0_norm_1_weight, 1e-05)',\n", + " ' (t110, (_, _, _)) = unsloth_rms_norm_forward(t102, t_transformer_h_0_norm_2_weight, 1e-05)',\n", + " ' (t139, (_, _, _)) = unsloth_rms_norm_forward(t130, t_transformer_h_1_norm_1_weight, 1e-05)',\n", + " ' (t215, (_, _, _)) = unsloth_rms_norm_forward(t207, t_transformer_h_1_norm_2_weight, 1e-05)',\n", + " ' (t243, (_, _, _)) = unsloth_rms_norm_forward(t235, t_transformer_ln_f_weight, 1e-05)']" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'rms' in s]" + ] + }, + { + "cell_type": "markdown", + "id": "0f6c0780", + "metadata": {}, + "source": [ + "But what about the backward?\n", + "\n", + "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "7670a872", + "metadata": {}, + "outputs": [], + "source": [ + "from thunder.core.transforms import get_grad, put_grads\n", + "\n", + "def unsloth_rms_norm_grad(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool):\n", + " res, (r, BLOCK_SIZE, num_warps) = unsloth_rms_norm_forward(x, weight, eps)\n", + " grad_res = get_grad(res)\n", + " grad_x = unsloth_rms_norm_backward(x, weight, r, eps, BLOCK_SIZE, num_warps, grad_res)\n", + " put_grads((x,), (grad_x,))\n", + " return res\n", + "\n", + "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker,\n", + " execution_transform=rms_norm_to_unsloth,\n", + " grad_transform=unsloth_rms_norm_grad \n", + " )\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "d31aced0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([256, 4096]) torch.Size([256, 4096]) torch.Size([4096]) torch.Size([256]) torch.Size([256, 4096])\n", + "(4096, 1) (4096, 1) (1,) (1,) (4096, 1)\n", + "maximum deviation grads: 3.5762786865234375e-07\n" + ] + } + ], + "source": [ + "with torch.device('cuda'):\n", + " norm_module = ThunderizedRMSNorm(4096)\n", + " norm_module.weight.requires_grad_(False)\n", + " x = torch.randn(256, 4096, requires_grad=True)\n", + "\n", + "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n", + "\n", + "actual = thunder_norm_module(x)\n", + "expected = norm_module(x)\n", + "actual_grads = torch.autograd.grad(actual.sum(), x)\n", + "expected_grads = torch.autograd.grad(expected.sum(), x)\n", + "\n", + "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))" + ] + }, + { + "cell_type": "markdown", + "id": "be218e9d", + "metadata": {}, + "source": [ + "And here is our module having the unsloth backward:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "ac00153b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + 
"@no_autocast()\n", + "def backward_fn(saved_for_backward, cotangents):\n", + " # saved_for_backward: \"Collection\" \n", + " # cotangents: \"Collection\" \n", + " C0, \\\n", + " C1, \\\n", + " = saved_for_backward\n", + " clear_collection(saved_for_backward)\n", + " del saved_for_backward\n", + " t4, \\\n", + " = cotangents\n", + " clear_collection(cotangents)\n", + " del cotangents\n", + " t0, \\\n", + " t1, \\\n", + " t3, \\\n", + " = C0\n", + " clear_collection(C0)\n", + " del C0\n", + " f0, \\\n", + " = C1\n", + " clear_collection(C1)\n", + " del C1\n", + " t2 = unsloth_rms_norm_backward(t0, t1, t3, f0, 4096, 8, t4) # t2: \"cuda:0 f32[256, 4096]\"\n", + " del t0, t1, t3, f0, t4\n", + " return (t2, None)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "thunder.last_backward_traces(thunder_norm_module)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "26ac79f0", + "metadata": {}, + "source": [ + "That's it! Do check out our LitGPT studios and the other tutorial notebooks.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "ce4217b7", + "id": "586cdd30", "metadata": {}, "outputs": [], "source": [] diff --git a/thunder/__init__.py b/thunder/__init__.py index 770507697c..5d6f698a76 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -344,9 +344,10 @@ def jit( if additional_transforms is None: additional_transforms = [] - # Make sharp_edges == warn default if not supplied and if in the general jit - if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON and sharp_edges is None: - sharp_edges = SHARP_EDGES_OPTIONS.WARN + # TODO: verify that tutorials don't have false positives and enable warning by default + # # Make sharp_edges == warn default if not supplied and if in the general jit + # if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON and sharp_edges is None: + # sharp_edges = SHARP_EDGES_OPTIONS.WARN # TODO RC1 Refine the compile data option to remove unused options cd = CompileData( From 7071ccddeeb5199fc3730a728676dc39400a0176 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sat, 16 Mar 2024 21:24:31 +0100 Subject: [PATCH 18/44] refresh notebooks (PR2468) --- docs/source/index.rst | 1 - notebooks/.ignore.ci | 3 - notebooks/adding_custom_operator.ipynb | 495 +++++++++------- notebooks/adding_operator_executor.ipynb | 688 ----------------------- 4 files changed, 301 insertions(+), 886 deletions(-) delete mode 100644 notebooks/adding_operator_executor.ipynb diff --git a/docs/source/index.rst b/docs/source/index.rst index c019b44e7e..3ce2ca87c0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -110,7 +110,6 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik Extending thunder notebooks/adding_custom_operator notebooks/adding_custom_operator_backward - notebooks/adding_operator_executor .. 
toctree:: :maxdepth: 1 diff --git a/notebooks/.ignore.ci b/notebooks/.ignore.ci index 10cac1cc86..8655f707bb 100644 --- a/notebooks/.ignore.ci +++ b/notebooks/.ignore.ci @@ -1,5 +1,2 @@ -adding_custom_operator.ipynb adding_custom_operator_backward.ipynb -adding_operator_executor.ipynb dev_tutorials/extend.ipynb -dev_tutorials/patterns.ipynb diff --git a/notebooks/adding_custom_operator.ipynb b/notebooks/adding_custom_operator.ipynb index c293bc3167..7515a36c15 100644 --- a/notebooks/adding_custom_operator.ipynb +++ b/notebooks/adding_custom_operator.ipynb @@ -26,6 +26,14 @@ "from enum import Enum" ] }, + { + "cell_type": "markdown", + "id": "a1b6863a", + "metadata": {}, + "source": [ + "Let us define some helper functions (execute the cell below) for printing what's going on." + ] + }, { "cell_type": "code", "execution_count": 2, @@ -33,7 +41,6 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Helper functions (execute this cell)\n", "import functools\n", "\n", "_indentation = 0\n", @@ -83,135 +90,92 @@ }, { "cell_type": "markdown", - "id": "a06c6260", + "id": "c8e1626f", "metadata": {}, "source": [ "Our new operator has the following signature `sincos(x: Tensor) -> Tuple[Tensor, Tensor]`. It takes a tensor as input and returns a tuple of two tensors. The first tensor is the sine of the input and the second tensor is the cosine of the input.\n", "\n", - "We call all callables that should be recorded in the trace Symbols. Symbols are the building blocks of the trace. Symbols are either primitives or composite operators. Composite perators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.\n", + "We call all callables that should be recorded in the trace *Symbols*. Symbols are the building blocks of the trace. Symbols are either primitives or composite operators. Composite operators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.\n", + "\n", + "The easiest way to register a new operator is to define a *meta* function - describing what the metadata of the outputs looks like given the metadata of the inputs - and an implementation (dealing with concrete objects like Python `Number`s and PyTorch `Tensor`s), and to register both of them through an executor. This will automatically create a symbol for us.\n", + "\n", - "Let's create a new Symbol called `sincos` and implement it in Python."
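+    "To make the meta contract concrete (the name `twice_meta` here is hypothetical and just for illustration), note that a meta manipulates only proxies - given input metadata it produces output metadata, without touching any data:\n",
+    "\n",
+    "```python\n",
+    "# Hypothetical meta for an elementwise op: the output has the same\n",
+    "# shape, dtype and device as the input, which `like=` copies over.\n",
+    "def twice_meta(x: TensorProxy) -> TensorProxy:\n",
+    "    return TensorProxy(like=x)\n",
+    "```\n",
+    "\n",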
+ "So we create an executor:" ] }, { "cell_type": "code", "execution_count": 3, - "id": "764c203a", + "id": "f680ae37", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Help on class Symbol in module thunder.core.symbol:\n", - "\n", - "class Symbol(builtins.object)\n", - " | Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = , _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False) -> None\n", - " | \n", - " | Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = , _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False)\n", - " | \n", - " | Methods defined here:\n", - " | \n", - " | __call__(self, *args, **kwargs)\n", - " | Call self as a function.\n", - " | \n", - " | __delattr__(self, name)\n", - " | Implement delattr(self, name).\n", - " | \n", - " | __eq__(self, other: 'Symbol') -> 'int'\n", - " | Return self==value.\n", - " | \n", - " | __getstate__ = _dataclass_getstate(self)\n", - " | # _dataclass_getstate and _dataclass_setstate are needed for pickling frozen\n", - " | # classes with slots. These could be slightly more performant if we generated\n", - " | # the code instead of iterating over fields. But that can be a project for\n", - " | # another day, if performance becomes an issue.\n", - " | \n", - " | __hash__(self) -> 'int'\n", - " | Return hash(self).\n", - " | \n", - " | __init__(self, name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = , _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False) -> None\n", - " | Initialize self. 
See help(type(self)) for accurate signature.\n", - " | \n", - " | __repr__(self) -> 'str'\n", - " | Return repr(self).\n", - " | \n", - " | __setattr__(self, name, value)\n", - " | Implement setattr(self, name, value).\n", - " | \n", - " | __setstate__ = _dataclass_setstate(self, state)\n", - " | \n", - " | bind(self, *args, output, subsymbols=(), _call_ctx=None, **kwargs) -> 'BoundSymbol'\n", - " | \n", - " | name_with_module(self)\n", - " | \n", - " | normalize(self, *args, **kwargs)\n", - " | \n", - " | ----------------------------------------------------------------------\n", - " | Readonly properties defined here:\n", - " | \n", - " | module\n", - " | \n", - " | ----------------------------------------------------------------------\n", - " | Data descriptors defined here:\n", - " | \n", - " | __weakref__\n", - " | list of weak references to the object (if defined)\n", - " | \n", - " | id\n", - " | \n", - " | is_fusion\n", - " | \n", - " | is_prim\n", - " | \n", - " | meta\n", - " | \n", - " | name\n", - " | \n", - " | python_impl\n", - " | \n", - " | python_printer\n", - " | \n", - " | ----------------------------------------------------------------------\n", - " | Data and other attributes defined here:\n", - " | \n", - " | __annotations__ = {'_bind_postprocess': 'None | Callable', '_hash': 'O...\n", - " | \n", - " | __dataclass_fields__ = {'_bind_postprocess': Field(name='_bind_postpro...\n", - " | \n", - " | __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,or...\n", - " | \n", - " | __match_args__ = ('name', 'meta', 'python_impl', 'id', 'is_prim', 'is_...\n", - "\n" - ] + "data": { + "text/plain": [ + "[sincos_executor, sdpa]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from thunder.core.symbol import Symbol\n", - "\n", - "help(Symbol)" + "sincos_executor = thunder.extend.OperatorExecutor(\"sincos_executor\", version='0.1')\n", + "thunder.add_default_executor(sincos_executor)" + ] + }, + { + "cell_type": "markdown", + "id": "4f147274", + "metadata": {}, + "source": [ + "We define meta and implementation: " ] }, { "cell_type": "code", "execution_count": 4, - "id": "ba10b306", + "id": "d5a72aff", "metadata": {}, "outputs": [], "source": [ "@log\n", - "def sincos_meta(input):\n", - " return (TensorProxy(like=input), TensorProxy(like=input))\n", + "def sincos_meta(inp):\n", + " return (TensorProxy(like=inp), TensorProxy(like=inp))\n", "\n", - "class CustomOps(Enum):\n", - " sincos = 0\n", - "\n", - "sincos = Symbol(\n", - " id=CustomOps.sincos,\n", - " name=\"sincos\",\n", - " meta=sincos_meta,\n", - " is_prim=True,\n", - ")" + "@log\n", + "def sincos_impl(inp):\n", + " return torch.sin(inp), torch.cos(inp)" + ] + }, + { + "cell_type": "markdown", + "id": "a06c6260", + "metadata": {}, + "source": [ + "And register it as `sincos`:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "03516b03", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Symbol name=sincos]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sincos = sincos_executor.register_operator('sincos', meta=sincos_meta, fn=sincos_impl)\n", + "sincos" ] }, { @@ -224,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "8c5da6f2", "metadata": {}, "outputs": [], @@ -236,13 +200,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "aef98360", "metadata": {}, "outputs": [], 
"source": [ - "a = torch.randn(1, device=\"cuda\")\n", - "b = torch.randn(1, device=\"cuda\")" + "a = torch.randn(1)\n", + "b = torch.randn(1)" ] }, { @@ -255,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "87f9f6e7", "metadata": {}, "outputs": [ @@ -263,7 +227,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Couldn't find an eager implementation for sincos\n" + "Attempting to execute outside of a tracing context, which is not supported\n" ] } ], @@ -284,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "f938dff7-bac6-4807-b79d-a16cb5c6d90c", "metadata": {}, "outputs": [ @@ -292,23 +256,25 @@ "name": "stdout", "output_type": "stream", "text": [ - "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))\n", + "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", - "# import __main__ as __main__\n", - "# import thunder as thunder\n", - "# import thunder.torch as ltorch\n", + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import thunder\n", + "import thunder.torch as ltorch\n", "import torch\n", + "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", + "@no_autocast()\n", "def fun(a, b):\n", - " # a: \"cuda:0 f32[1]\" \n", - " # b: \"cuda:0 f32[1]\" \n", - " (t0, t1) = __main__.sincos(a)\n", - " t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cuda:0 f32[1]\"\n", - " # t2 = prims.add(t0, t1) # t2: \"cuda:0 f32[1]\"\n", - " t3 = ltorch.add(t2, b, alpha=None) # t3: \"cuda:0 f32[1]\"\n", - " # t3 = prims.add(t2, b) # t3: \"cuda:0 f32[1]\"\n", + " # a: \"cpu f32[1]\" \n", + " # b: \"cpu f32[1]\" \n", + " (t0, t1) = sincos(a)\n", + " t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[1]\"\n", + " # t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"\n", + " t3 = ltorch.add(t2, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"\n", " return t3\n" ] } @@ -321,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "6eb4818b", "metadata": {}, "outputs": [ @@ -329,17 +295,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: \"cuda:0 f32[1]\" |\n", - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: \"cuda:0 f32[1]\" |\n", - "Bound symbol with id=CustomOps.sincos is represented in the trace as |(t0, t1) = __main__.sincos(a)|\n", - "Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cuda:0 f32[1]\"\n", - " # t2 = prims.add(t0, t1) # t2: \"cuda:0 f32[1]\"|\n", + "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: \"cpu f32[1]\" |\n", + "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: \"cpu f32[1]\" |\n", + "Bound symbol with id=sincos is represented in the trace as |(t0, t1) = sincos(a)|\n", + "Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[1]\"\n", + " # t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"|\n", " It has the following 
subsymbols:\n", - " id=PrimIDs.ADD |t2 = prims.add(t0, t1) # t2: \"cuda:0 f32[1]\"|\n", - "Bound symbol with id=torch.add is represented in the trace as |t3 = ltorch.add(t2, b, alpha=None) # t3: \"cuda:0 f32[1]\"\n", - " # t3 = prims.add(t2, b) # t3: \"cuda:0 f32[1]\"|\n", + " id=PrimIDs.ADD |t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"|\n", + "Bound symbol with id=torch.add is represented in the trace as |t3 = ltorch.add(t2, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"|\n", " It has the following subsymbols:\n", - " id=PrimIDs.ADD |t3 = prims.add(t2, b) # t3: \"cuda:0 f32[1]\"|\n", + " id=PrimIDs.ADD |t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"|\n", "Bound symbol with id=PrimIDs.RETURN is represented in the trace as |return t3|\n" ] } @@ -364,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "41566de2-a60f-4c87-a3d6-58e6a89dc38b", "metadata": {}, "outputs": [], @@ -374,151 +340,292 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "bbbb90c2", + "execution_count": 12, + "id": "24af4b99", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))\n", + "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", - "Could not find executor for bound symbol (t0, t1) = __main__.sincos(a)\n" + "call sincos_impl(Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1413]))\n", + "|<- sincos_impl = (Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1408]), Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.9900]))\n", + "\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NoneType, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type bool, which is not identified as an input. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type SequenceIter, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type int, which is not identified as an input. 
This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NotImplementedType, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type StopIteration, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([0.7666])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "try:\n", - " cfun(a, b)\n", - "except RuntimeError as e:\n", - " print(e)" + "cfun(a, b)" ] }, { "cell_type": "markdown", - "id": "3b1fd6e3", + "id": "d7cec09d", "metadata": {}, "source": [ - "There's no registered executor for `sincos` so we need to register an executor for our new primitive. Let's do that." + "Let's check how our function is represented in the execution trace now (change to `thunder.last_traces(cfun)[0]` to see the trace before transformations)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a7ff30ef", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(a, b):\n", + " # a: \"cpu f32[1]\" \n", + " # b: \"cpu f32[1]\" \n", + " (res, cos) = sincos(a)\n", + " del a\n", + " result = torch.add(res, cos) # result: \"cpu f32[1]\"\n", + " # result = ltorch.add(res, cos, alpha=None) # result: \"cpu f32[1]\"\n", + " # result = prims.add(res, cos) # result: \"cpu f32[1]\"\n", + " del res, cos\n", + " t3 = torch.add(result, b) # t3: \"cpu f32[1]\"\n", + " # t3 = ltorch.add(result, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(result, b) # t3: \"cpu f32[1]\"\n", + " del result, b\n", + " return t3" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "thunder.last_traces(cfun)[-1]" ] }, { "cell_type": "markdown", - "id": "026680b3-7b46-4f4b-b16b-641fa9bdcdf4", + "id": "35b71375", "metadata": {}, "source": [ - "Check out the \"adding-operator-executor.ipynb\" notebook to see how to implement an executor for a Symbol." 
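+    "\n",
+    "(For reference, the `register_operator` call above already returned such a `Symbol` for us; a quick way to convince yourself:)\n",
+    "\n",
+    "```python\n",
+    "from thunder.core.symbol import Symbol\n",
+    "\n",
+    "# The operator we registered earlier is itself a Symbol instance.\n",
+    "assert isinstance(sincos, Symbol)\n",
+    "print(sincos.name)  # prints: sincos\n",
+    "```"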
+ "For a peek under the hood, we can also first create a new symbol (without reference to an executor) and then register an executor for that.\n" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "2460f808-eacb-4a0f-8f62-6a17e3dce6e8", + "execution_count": 14, + "id": "f28094bb", "metadata": {}, "outputs": [], "source": [ - "from thunder.executors import add_operator_executor\n", - "\n", - "@log\n", - "def checker_sincos(a):\n", - " # We allow the sincos function to be called with any tensor\n", - " return True\n", - "\n", + "from thunder.core.symbol import Symbol\n", "@log\n", - "def executor_sincos(a):\n", - " return torch.sin(a), torch.cos(a)\n", + "def sincos_meta(input):\n", + " return (TensorProxy(like=input), TensorProxy(like=input))\n", "\n", - "op_map = {\n", - " CustomOps.sincos: (\"sincos\", checker_sincos, executor_sincos)\n", - "}\n", + "# this gives a nice, unique, printable id\n", + "class CustomOps(Enum):\n", + " sincos2 = 0\n", "\n", - "add_operator_executor(\"sincos_executor\", op_map, add_to_default_executors=True)" + "sincos2 = Symbol(\n", + " id=CustomOps.sincos2,\n", + " name=\"sincos2\",\n", + " meta=sincos_meta,\n", + " is_prim=True,\n", + ")" ] }, { "cell_type": "code", - "execution_count": 13, - "id": "d864fa05", + "execution_count": 15, + "id": "7fbab758", "metadata": {}, "outputs": [], "source": [ - "# Let's try again\n", - "cfun = thunder.compile(fun, disable_preprocessing=True)" + "def fun2(a, b):\n", + " sin, cos = sincos2(a)\n", + " return sin + cos + b\n", + "\n", + "cfun2 = thunder.jit(fun2)" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "24af4b99", + "execution_count": 16, + "id": "950d74ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))\n", + "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", - "call checker_sincos(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- checker_sincos = True\n", + "Failed to find an executor for bound symbol bsym=(res, cos) = __main__.sincos2(a)\n" + ] + } + ], + "source": [ + "try:\n", + " cfun2(a, b)\n", + "except RuntimeError as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "aadcf2a9", + "metadata": {}, + "source": [ + "There's no registered executor for `sincos` so we need to register an executor for our new primitive. Let's do that." + ] + }, + { + "cell_type": "markdown", + "id": "995febba", + "metadata": {}, + "source": [ + "Check out the \"adding-operator-executor.ipynb\" notebook to see how to implement an executor for a Symbol." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "956a4a6b", + "metadata": {}, + "outputs": [], + "source": [ + "@log\n", + "def checker_sincos2(a):\n", + " # We allow the sincos function to be called with any tensor\n", + " return True\n", + "\n", + "@log\n", + "def executor_sincos2(a):\n", + " # we need something here that works with TensorProxies during the transformations,\n", + " # so we use functions from thunder.torch or thunder.clang or other Symbols\n", + " return thunder.torch.sin(a), thunder.torch.cos(a)\n", + "\n", + "sincos_executor.register_implementation(sincos2, checker=checker_sincos2, execution_transform=executor_sincos2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "1c77c508", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", + "\n", + "call checker_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", + "|<- checker_sincos2 = True\n", + "\n", + "call executor_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", + "|<- executor_sincos2 = (TensorProxy(name=t4, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t5, shape=(1,), dtype=float32, device=cpu))\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([0.7666])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Let's try again\n", + "cfun2 = thunder.jit(fun2)\n", + "cfun2(a, b)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "f9797cf2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(a, b):\n", + " # a: \"cpu f32[1]\" \n", + " # b: \"cpu f32[1]\" \n", + " res = torch.sin(a) # res: \"cpu f32[1]\"\n", + " # res = ltorch.sin(a) # res: \"cpu f32[1]\"\n", + " # res = prims.sin(a) # res: \"cpu f32[1]\"\n", + " cos = torch.cos(a) # cos: \"cpu f32[1]\"\n", + " # cos = ltorch.cos(a) # cos: \"cpu f32[1]\"\n", + " # cos = prims.cos(a) # cos: \"cpu f32[1]\"\n", + " del a\n", + " result = torch.add(res, cos) # result: \"cpu f32[1]\"\n", + " # result = ltorch.add(res, cos, 
alpha=None) # result: \"cpu f32[1]\"\n", + " # result = prims.add(res, cos) # result: \"cpu f32[1]\"\n", + " del res, cos\n", + " t3 = torch.add(result, b) # t3: \"cpu f32[1]\"\n", + " # t3 = ltorch.add(result, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(result, b) # t3: \"cpu f32[1]\"\n", + " del result, b\n", " return t3" ] }, - "execution_count": 15, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's check how our function is represented in the execution trace now\n", - "thunder.last_traces(cfun)[-1]" + "thunder.last_traces(cfun2)[-1]" ] }, { @@ -550,7 +657,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/notebooks/adding_operator_executor.ipynb b/notebooks/adding_operator_executor.ipynb deleted file mode 100644 index eac5685d3e..0000000000 --- a/notebooks/adding_operator_executor.ipynb +++ /dev/null @@ -1,688 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "b6f1f42d-f146-4c9c-8ed8-74f2bcf153f0", - "metadata": {}, - "source": [ - "# Adding an operator executor\n", - "\n", - "We are going to write a simple executor for `prims.add` function that calls NumPy's addition function. Our executor will be restricted to only work with inputs with certain properties. We will use the `add_operator_executor` function to create our executor." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "576d267d-9cef-4414-a722-b2cef0665cce", - "metadata": {}, - "outputs": [], - "source": [ - "import thunder\n", - "import torch\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "02e16bf5", - "metadata": {}, - "outputs": [], - "source": [ - "#@title Helper functions (execute this cell)\n", - "import functools\n", - "\n", - "_indentation = 0\n", - "def _log(msg=None):\n", - " \"\"\"Print a message at current indentation.\"\"\"\n", - " if msg is not None:\n", - " print(\" \" * _indentation + msg)\n", - "\n", - "def _log_indent(msg=None):\n", - " \"\"\"Print a message and then indent the rest.\"\"\"\n", - " global _indentation\n", - " _log(msg)\n", - " _indentation = 2 + _indentation\n", - "\n", - "def _log_unindent(msg=None):\n", - " \"\"\"Unindent then print a message.\"\"\"\n", - " global _indentation\n", - " _indentation = _indentation - 2\n", - " _log(msg)\n", - " \n", - "def log(func):\n", - " \"\"\"A decorator for functions to log arguments and results.\"\"\"\n", - " name = func.__name__\n", - " def pp(v):\n", - " \"\"\"Print certain values more succinctly\"\"\"\n", - " vtype = str(type(v))\n", - " if isinstance(v, tuple):\n", - " return \"({})\".format(pp_values(v))\n", - " elif isinstance(v, thunder.core.proxies.TensorProxy):\n", - " return f\"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})\"\n", - " elif isinstance(v, torch.Tensor):\n", - " return f\"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}\"\n", - " else:\n", - " return str(v)\n", - " def pp_values(args):\n", - " return \", \".join([pp(arg) for arg in args])\n", - "\n", - " @functools.wraps(func)\n", - " def func_wrapper(*args):\n", - " _log_indent(\"call {}({})\".format(name, pp_values(args)))\n", - " res = func(*args)\n", - " _log_unindent(\"|<- {} = {}\\n\".format(name, pp(res)))\n", - " return res\n", - "\n", - " return func_wrapper" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": 
"666fa494-21f5-4ed7-829e-f8648fddb13a", - "metadata": {}, - "outputs": [], - "source": [ - "# This is our test function\n", - "def fun(a, b):\n", - " return a + b * a" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "edbf395f-9549-4ae3-957c-ba34fc956b3f", - "metadata": {}, - "outputs": [], - "source": [ - "# This is our test input\n", - "a = torch.randn(2, 2, device=\"cuda\")\n", - "b = torch.randn(2, 1, device=\"cuda\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "f938dff7-bac6-4807-b79d-a16cb5c6d90c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "# import thunder as thunder\n", - "# import thunder.torch as ltorch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f32[2, 2]\" \n", - " # b: \"cuda:0 f32[2, 1]\" \n", - " t1 = ltorch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"\n", - " # t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " t2 = ltorch.add(a, t1, alpha=None) # t2: \"cuda:0 f32[2, 2]\"\n", - " # t2 = prims.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"\n", - " return t2\n" - ] - } - ], - "source": [ - "# Let's see first how this function is represented as a trace\n", - "trace = thunder.trace()(fun, a, b)\n", - "print(trace)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "6eb4818b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: \"cuda:0 f32[2, 2]\" |\n", - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: \"cuda:0 f32[2, 1]\" |\n", - "Bound symbol with id=torch.mul is represented in the trace as |t1 = ltorch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"\n", - " # t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"|\n", - " It has the following subsymbols:\n", - " id=PrimIDs.BROADCAST_IN_DIM |t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"|\n", - " id=PrimIDs.MUL |t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"|\n", - "Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(a, t1, alpha=None) # t2: \"cuda:0 f32[2, 2]\"\n", - " # t2 = prims.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"|\n", - " It has the following subsymbols:\n", - " id=PrimIDs.ADD |t2 = prims.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"|\n", - "Bound symbol with id=PrimIDs.RETURN is represented in the trace as |return t2|\n" - ] - } - ], - "source": [ - "# We can loop over the recorded operations that we call BoundSymbols\n", - "for bound_symbol in trace.bound_symbols:\n", - " print(f\"Bound symbol with id={bound_symbol.sym.id} is represented in the trace as |{bound_symbol}|\")\n", - " if bound_symbol.subsymbols:\n", - " print(\" It has the following subsymbols:\")\n", - " for subsymbol in bound_symbol.subsymbols:\n", - " print(f\" id={subsymbol.sym.id} |{subsymbol}|\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "41566de2-a60f-4c87-a3d6-58e6a89dc38b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Help on function add_operator_executor in module thunder.executors:\n", - "\n", - "add_operator_executor(name, op_map, *, add_to_default_executors: bool = True) -> None\n", - "\n" - ] - } - ], - "source": [ - "from thunder.executors import 
add_operator_executor\n", - "\n", - "help(add_operator_executor)" - ] - }, - { - "cell_type": "markdown", - "id": "026680b3-7b46-4f4b-b16b-641fa9bdcdf4", - "metadata": {}, - "source": [ - "The key argument here is `op_map`.\n", - "\n", - "`op_map` is a dictionary with the id of the operator we're providing executor for as a key and `(name, checker_fn, implementation_fn)` tuple as a value.\n", - "\n", - "* `name` is the name of our execution function that would be appearing in the execution trace.\n", - "* `checker_fn` accepts the same set of arguments as the operator itself but returns `True` or `False` to signal to the executor orchestrator whether this particular set of inputs is supported or not.\n", - "* `implementation_fn` accepts real PyTorch tensors and expected to return PyTorch tensors." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "e02aaf0d", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's define the addition function that can work only with NumPy's ndarrays\n", - "\n", - "@log\n", - "def add_numpy(a, b):\n", - " assert isinstance(a, np.ndarray), \"a must be a NumPy ndarray\"\n", - " assert isinstance(b, np.ndarray), \"b must be a NumPy ndarray\"\n", - " return np.add(a, b)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7ddbe12e", - "metadata": {}, - "outputs": [], - "source": [ - "# We also need conversion functions from PyTorch to NumPy and back\n", - "@log\n", - "def torch_to_numpy(tensors):\n", - " return tuple(t.detach().cpu().numpy() for t in tensors)\n", - "\n", - "@log\n", - "def numpy_to_torch(arrays, device):\n", - " return tuple(torch.from_numpy(arr).to(device) for arr in arrays)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "2460f808-eacb-4a0f-8f62-6a17e3dce6e8", - "metadata": {}, - "outputs": [], - "source": [ - "@log\n", - "def checker_add_numpy(a, b):\n", - " # Suppose we only support float32 dtype, 2D, and (2, N) shape\n", - " first_condition = a.dtype == b.dtype == thunder.dtypes.float32\n", - " second_condition = a.ndim == b.ndim == 2\n", - " third_condition = a.shape[0] == b.shape[0] == 2\n", - " return first_condition and second_condition and third_condition" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "3a61b05f", - "metadata": {}, - "outputs": [], - "source": [ - "@log\n", - "def executor_add_numpy(a, b):\n", - " np_a, np_b = torch_to_numpy((a, b))\n", - " np_res = add_numpy(np_a, np_b)\n", - " res, = numpy_to_torch((np_res,), a.device)\n", - " return res" - ] - }, - { - "cell_type": "markdown", - "id": "c502944e", - "metadata": {}, - "source": [ - "Now we have all the pieces to create our executor." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e7d3eb2f", - "metadata": {}, - "outputs": [], - "source": [ - "op_map = {\n", - " thunder.prims.PrimIDs.ADD: (\"add_numpy\", checker_add_numpy, executor_add_numpy)\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "11f71c82", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's send our operator map to `add_operator_executor` to register our executor under the name \"custom_add_executor\"\n", - "\n", - "add_operator_executor(\"custom_add_executor\", op_map, add_to_default_executors=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "d864fa05", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's test our executor\n", - "\n", - "cfun = thunder.compile(fun, executors_list=[\"custom_add_executor\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "24af4b99", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Could not find executor for bound symbol t1 = ltorch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"\n", - " # t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"\n" - ] - } - ], - "source": [ - "try:\n", - " cfun(a, b)\n", - "except RuntimeError as e:\n", - " print(e)" - ] - }, - { - "cell_type": "markdown", - "id": "d74d0c97", - "metadata": {}, - "source": [ - "The above function errors out because we haven't provided an executor for `ltorch.mul` yet. Let's do that." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "f1ff48d0", - "metadata": {}, - "outputs": [], - "source": [ - "cfun = thunder.compile(fun, executors_list=[\"custom_add_executor\", thunder.executors.TORCH])" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b1527d5e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call checker_add_numpy(TensorProxy(name=a, shape=(2, 2), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2, 2), dtype=float32, device=cuda:0))\n", - "|<- checker_add_numpy = True\n", - "\n", - "call checker_add_numpy(TensorProxy(name=a, shape=(2, 2), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2, 2), dtype=float32, device=cuda:0))\n", - "|<- checker_add_numpy = True\n", - "\n", - "call executor_add_numpy(Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0'))\n", - " call torch_to_numpy((Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0')))\n", - " |<- torch_to_numpy = ([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - "\n", - " call add_numpy([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - " |<- add_numpy = [[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]\n", - "\n", - " call 
numpy_to_torch(([[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]), cuda:0)\n", - " |<- numpy_to_torch = (Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0'))\n", - "\n", - "|<- executor_add_numpy = Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0')\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0')" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cfun(a, b)" - ] - }, - { - "cell_type": "markdown", - "id": "c55b5ed6", - "metadata": {}, - "source": [ - "Our logging decorator shows us that the `checker_add_numpy` function got called twice with `TensorProxy` as arguments and both times the function returned `True`. This means that our executor is going to be used for this particular execution trace.\n", - "\n", - "Then we see that the `executor_add_numpy` function is called with regular PyTorch tensors as arguments and it returns a regular PyTorch tensor." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "a7ff30ef", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "# Constructed by Delete Last Used\n", - "# import torch as torch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f32[2, 2]\" \n", - " # b: \"cuda:0 f32[2, 1]\" \n", - " t1 = torch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " del [b]\n", - " t2 = add_numpy(a, t1) # t2: \"cuda:0 f32[2, 2]\"\n", - " del [a, t1]\n", - " return t2" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Let's check how our function is represented in the execution trace now\n", - "thunder.last_traces(cfun)[-1]" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "0868c882", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call executor_add_numpy(Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0'))\n", - " call torch_to_numpy((Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0')))\n", - " |<- torch_to_numpy = ([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - "\n", - " call add_numpy([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - " |<- add_numpy = [[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]\n", - "\n", - " call numpy_to_torch(([[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]), cuda:0)\n", - " |<- numpy_to_torch = (Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values 
tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0'))\n", - "\n", - "|<- executor_add_numpy = Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0')\n", - "\n" - ] - } - ], - "source": [ - "# Let's test whether the result is correct\n", - "cfun_torch = thunder.compile(fun, executors_list=[thunder.executors.TORCH])\n", - "expected = cfun_torch(a, b)\n", - "actual = cfun(a, b)\n", - "torch.testing.assert_close(expected, actual) # Should not raise an exception" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "f978b2de", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[SampleInput args=(tensor([[-5.6039, 5.0201, -8.2948, -0.1738],\n", - " [ 8.4915, -2.8353, -7.4601, -4.3015],\n", - " [ 6.0777, -7.6420, 3.4135, 3.2371],\n", - " [-0.8413, -1.7334, -1.0025, -0.7366]], device='cuda:0'), tensor([[ 4.5391, 1.5542, 7.9208, -1.3760],\n", - " [-6.5864, 8.6491, 6.1823, -1.8481],\n", - " [ 7.9385, -0.4884, 4.2281, 1.3158],\n", - " [-4.6107, 3.5805, 3.1749, -4.5989]], device='cuda:0')) kwargs={}]\n" - ] - } - ], - "source": [ - "from thunder.tests.opinfos import add_opinfo\n", - "\n", - "sample = next(add_opinfo.sample_input_generator(add_opinfo, device=\"cuda\", dtype=torch.float32, requires_grad=False))\n", - "print(sample)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "f07882f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call checker_add_numpy(TensorProxy(name=a, shape=(4, 4), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(4, 4), dtype=float32, device=cuda:0))\n", - "|<- checker_add_numpy = False\n", - "\n" - ] - } - ], - "source": [ - "# Let's test whether the result is correct\n", - "expected = cfun_torch(*sample.args)\n", - "actual = cfun(*sample.args)\n", - "torch.testing.assert_close(expected, actual) # Should not raise an exception" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "057689f5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "# Constructed by Delete Last Used\n", - "# import torch as torch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f32[2, 2]\" \n", - " # b: \"cuda:0 f32[2, 1]\" \n", - " t1 = torch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " del [b]\n", - " t2 = torch.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"\n", - " del [a, t1]\n", - " return t2" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# The order of executors matters today\n", - "cfun_torch_first = thunder.jit(fun, executors=[thunder.executors.TORCH, \"custom_add_executor\"])\n", - "cfun_torch_first(a, b)\n", - "thunder.last_traces(cfun_torch_first)[-1]" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "2aca9db7", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's try inputs that are not supported by our executor\n", - "a = torch.randn(3, 2, device=\"cuda\", dtype=torch.float64)\n", - "b = torch.randn(3, 1, device=\"cuda\", dtype=torch.float64)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "4b3e1589", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call checker_add_numpy(TensorProxy(name=a, shape=(3, 2), dtype=float64, device=cuda:0), TensorProxy(name=t1, shape=(3, 2), 
dtype=float64, device=cuda:0))\n", - "|<- checker_add_numpy = False\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "# Constructed by Delete Last Used\n", - "# import torch as torch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f64[3, 2]\" \n", - " # b: \"cuda:0 f64[3, 1]\" \n", - " t1 = torch.mul(b, a) # t1: \"cuda:0 f64[3, 2]\"\n", - " del [b]\n", - " t2 = torch.add(a, t1) # t2: \"cuda:0 f64[3, 2]\"\n", - " del [a, t1]\n", - " return t2" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Let's see how our function is represented in the execution trace now with the new unsupported inputs\n", - "cfun(a, b)\n", - "thunder.last_traces(cfun)[-1]" - ] - }, - { - "cell_type": "markdown", - "id": "122ead11", - "metadata": {}, - "source": [ - "That's it! We've created our first executor. The process is very similar for other existing operators. There are two ingridients that are required to create an executor:\n", - "* `checker_fn` that checks whether the executor is applicable for a particular set of inputs (works with `TensorProxy` objects),\n", - "* `implementation_fn` that implements the operator (works with regular PyTorch tensors)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aec61cf6", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From eca8315f9e64073d8a3d59fddebd04df766b7aea Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 17 Mar 2024 06:51:23 +0900 Subject: [PATCH 19/44] Access `CompileData` through `compile_data` (PR2450) --- thunder/distributed/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py index 39ae65bda5..0a0eb943d9 100644 --- a/thunder/distributed/__init__.py +++ b/thunder/distributed/__init__.py @@ -68,9 +68,11 @@ def skip_data_parallel_grad_sync() -> None: def _sync_grads(module: torch.nn.Module) -> None: + import thunder + params_with_grad = [p for p in module.parameters() if p.grad is not None] grads = [p.grad for p in params_with_grad] - process_group = module._lc_cd.process_group_for_ddp + process_group = thunder.compile_data(module).process_group_for_ddp torch._foreach_div_(grads, process_group.size()) with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm: for g in grads: From a5b5490df79aab9f0bc5d424f930ef2a2edbdb5d Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Sun, 17 Mar 2024 11:55:20 +0100 Subject: [PATCH 20/44] zero_to_thunder: fix typos and add a cell to install lit-gpt (PR2467) --- notebooks/zero_to_thunder.ipynb | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb index 68f61a47a0..1c536f0cde 100644 --- a/notebooks/zero_to_thunder.ipynb +++ b/notebooks/zero_to_thunder.ipynb @@ -163,7 +163,7 @@ "So what has changed?\n", "Quite a bit!\n", "\n", - "When we call the Thunder module, it do the computation in a single function without control flow. 
And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:"
+    "When we call the Thunder module, it does the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:"
    ]
   },
@@ -241,7 +241,9 @@
    "source": [
     "## Compiling a more complex model\n",
     "\n",
-    "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller momdel here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):"
+    "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller model here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):\n",
+    "\n",
+    "**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt"
    ]
   },
@@ -1023,7 +1025,7 @@
    "metadata": {},
    "source": [
     "Well, that is quite a bit to look through.\n",
-    "But here is a key thing: The function now returns a buch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n",
+    "But here is a key thing: The function now returns a bunch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n",
     "`thunder.last_backward_traces(thunder_model)[-1]`)."
    ]
   },
@@ -1278,7 +1280,7 @@
    "id": "a63595ab",
    "metadata": {},
    "source": [
-    "For our base implementation, we take the ccode from [LitGPT's RMSNorm implementation](https://github.com/Lightning-AI/litgpt/blob/7c1574925f973e64c0a53e056b77229bedee1619/lit_gpt/rmsnorm.py)\n",
+    "For our base implementation, we take the code from [LitGPT's RMSNorm implementation](https://github.com/Lightning-AI/litgpt/blob/7c1574925f973e64c0a53e056b77229bedee1619/lit_gpt/rmsnorm.py)\n",
     "\n",
     "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n"
    ]
   },
@@ -1320,7 +1322,7 @@
    "id": "75ad1dbf",
    "metadata": {},
    "source": [
-    "Because evil monkey-patching is a thing for short demos is a thing, let's replace LitGPT's own implementation. For your own model, you might start out with a that in your code directly."
+    "For this short demo, we monkey-patch LitGPT to replace its own implementation. For your own model, you might start out with that in your code directly."
    ]
   },
@@ -1415,7 +1417,7 @@
    "metadata": {},
    "source": [
     "But why did we do this? Well, we can now layer a faster implementation on top.\n",
-    "For this we take the [unsloth RMSNorm](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rms_layernorm.py) kernels. 
We the bits that were in the forward and backward of the `autograd.Function` into our implementation functions and define the corresponding metas."
+    "For this we take the [unsloth RMSNorm](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rms_layernorm.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions and define the corresponding metas."
    ]
   },
@@ -1927,7 +1929,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.7"
+   "version": "3.10.13"
   }
  },
  "nbformat": 4,

From 42f2309922d94940c06f5be5d5a60e294b3205fd Mon Sep 17 00:00:00 2001
From: apaz 
Date: Sun, 17 Mar 2024 06:02:40 -0500
Subject: [PATCH 21/44] Last minute docs revision (PR2471)

---
 README.md                               | 20 +++++++++---------
 docs/source/advanced/inside_thunder.rst |  6 +++---
 docs/source/basic/overview.rst          | 27 +++++++++++++------------
 docs/source/basic/sharp_edges.rst       | 11 +++-------
 4 files changed, 30 insertions(+), 34 deletions(-)

diff --git a/README.md b/README.md
index b119d4860b..0f59d0007c 100644
--- a/README.md
+++ b/README.md
@@ -78,23 +78,23 @@ See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with T

 ## What's in the box

-Given a program, Thunder can generate an optimized program that:
+Given a Python callable or PyTorch module, Thunder can generate an optimized program that:

-- computes its forward and backward passes
-- coalesces operations into efficient fusion regions
-- dispatches computations to optimized kernels
-- distributes computations optimally across machines
+- Computes its forward and backward passes
+- Coalesces operations into efficient fusion regions
+- Dispatches computations to optimized kernels
+- Distributes computations optimally across machines

 To do so, Thunder ships with:

-- a JIT for acquiring Python programs targeting PyTorch and custom operations
-- a multi-level IR to represent them as a trace of a reduced op-set
-- an extensible set of transformations on the trace, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`)
-- a way to dispatch operations to an extensible collection of executors
+- A JIT for acquiring Python programs targeting PyTorch and custom operations
+- A multi-level IR to represent operations as a trace of a reduced op-set
+- An extensible set of transformations on the trace, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`)
+- A way to dispatch operations to an extensible collection of executors

 Thunder is written entirely in Python. Even its trace is represented as valid Python at all stages of transformation. This allows unprecedented levels of introspection and extensibility.

-Thunder doesn't generate device code. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like:
+Thunder doesn't generate code for accelerators directly. 
It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like: - [torch.compile](https://pytorch.org/get-started/pytorch-2.0/) - [nvFuser](https://github.com/NVIDIA/Fuser) diff --git a/docs/source/advanced/inside_thunder.rst b/docs/source/advanced/inside_thunder.rst index 36013d66ae..2ae68509d8 100644 --- a/docs/source/advanced/inside_thunder.rst +++ b/docs/source/advanced/inside_thunder.rst @@ -8,9 +8,9 @@ Bytecode interpretation Thunder's interpreter works by: -1. disassembling the PyTorch module or function into CPython bytecode -2. interpreting the bytecode using an extended Python interpreter -3. generating a sequential trace of operations on tensors and numbers +1. Disassembling the PyTorch module or function into CPython bytecode +2. Interpreting the bytecode using an extended Python interpreter +3. Generating a sequential trace of operations on tensors and numbers Representing Operations ======================= diff --git a/docs/source/basic/overview.rst b/docs/source/basic/overview.rst index cb272c0b9a..c456635e62 100644 --- a/docs/source/basic/overview.rst +++ b/docs/source/basic/overview.rst @@ -3,7 +3,7 @@ Thunder Overview This section introduces Thunder's core concepts and architecture. For more details, see :doc:`Inside thunder <../advanced/inside_thunder>`. -Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must be “valid” - it must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*. +Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*. This translation begins with:: @@ -13,7 +13,7 @@ or:: jitted_fn = thunder.jit(my_function) -When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a jitted function. +When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a function that when called will jit compile a path through the original function given information about the inputs. When the jitted module or function is called:: @@ -23,22 +23,23 @@ or:: jitted_fn(*args, **kwargs) -Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but this is actually required to produce a trace. Different inputs can produce different traces, since the operations called may different based on the properties of the input. -The trace is generated by running the bytecode through an extensible Python interpreter implemented in Python itself, that can be extended to perform instructions in a different way compared to what standard CPython does. 
As such, it can be instrumented to construct a trace of operations performed on tensors or numbers, and keep track of the provenance of all objects being part of the program. 
+As suggested above, Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but since control flow (and therefore the operations captured) may vary depending on the input, this is actually required to produce a trace. These traces are cached, so that if inputs of the same type, shape, etc. are used again, the trace can be reused.

-If replacing CPython with Python itself sounds problematic from a performance perspective, keep in mind that the initial interpretation of a deep learning program is typically amortized during the subsequent interpretations, due to the iterative nature of deep learning programs. In other words, if the meta data of inputs (like tensor shape) doesn't change and control-flow conditions are unchanged, then there's no point in constructing a new trace, and we can rely on smart caching to just execute a trace right away.
+Traces are generated by running the bytecode through a custom Python interpreter, which is itself implemented in Python. This interpreter has been extended to perform instructions in a different way compared to what standard CPython does. In particular, it constructs a trace of operations performed on tensors or numbers, and keeps track of the provenance of all objects in the program, whether they originated from inside the interpreter or outside.

-Traces don't typically deal with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators, but it records the operators along one path of the traceable function into the trace.
+Much like other machine learning frameworks, traces don't typically deal directly with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators. Instead, it records the operators along one path of the traceable function.

-Traces can be transformed (like for backward) and optimized (like by replacing calls to PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process see the :doc:`thunder step by step ` section.
+If replacing CPython with an interpreter written in Python sounds problematic from a performance perspective, you would be largely correct. We haven't yet put any time into optimizing it, and we think it consumes roughly 400x as much CPU time as CPython. However, the function only needs to be jitted once per equivalence class of inputs, and CPU is not a bottleneck in most machine learning pipelines. As long as the metadata of the inputs (such as a tensor's shape) and control flow conditions are not changed, we can rely on smart caching to immediately execute an optimized trace. The end result is a faster total execution time. 
+ +Traces can be transformed (like for ``backward()``) and optimized (like by replacing calls to eager PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process, see the :doc:`thunder step by step ` section. To recap, the complete translation process is: -- For PyTorch modules, a Thunder-optimized module is created from the original module -- For PyTorch functions, compilation produces a compiled function -- When the module or function is called, the trace is generated, swapping some inputs with “proxies” -- The trace is transformed and optimized to produce an execution trace -- The execution trace is converted into a Python function and called +- For PyTorch modules, a Thunder-optimized module is created from the original module. +- For PyTorch functions, compilation produces a compiled function. +- When the module or function is called, the trace is generated, swapping some inputs with “proxies”. +- The trace is transformed and optimized to produce an execution trace. +- The execution trace is converted into a Python function and called. -As mentioned above, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times. +As mentioned, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times. diff --git a/docs/source/basic/sharp_edges.rst b/docs/source/basic/sharp_edges.rst index bf2b0abf16..d62590316b 100644 --- a/docs/source/basic/sharp_edges.rst +++ b/docs/source/basic/sharp_edges.rst @@ -10,12 +10,6 @@ Inplace operations Inplace PyTorch operations like `t.add_(1.0)` are not supported in Thunder yet. Support for inplace operations is coming soon. -Complex control flow --------------------- - -Control flow is supported in Thunder, but certain constructs might still be unsupported. - -In particular, attributes need to be resolved at tracing time for control flow to work. Data-dependent control flow, that is, when a condition depends on the value of tensors rather than its meta-data like shape or type, is currently not supported. Tensor subclasses ----------------- @@ -24,13 +18,14 @@ Thunder currently supports Python data types and PyTorch tensors as inputs of fu Subclasses of these types, e.g. lazy tensors, nested tensors, or sparse tensors are not supported today. + Tracing Python builtins, standard library operations and functions that call other languages -------------------------------------------------------------------------------------------- Calling a Python builtin, standard library operation, or a function that calls into another language is safe to trace, so long as the following rules are observed: -1. The function must not have side effects. 
For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement. -2. The function must not manipulate tensor metadata or data. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. +1. The function should not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement. +2. The function must not manipulate tensor data or metadata. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. To implement such operations, see :doc:`Adding Custom Operators <../notebooks/adding_custom_operator>` 3. The function must not produce different results across invocations. Again, since the operation won't appear in traces, Thunder cannot replicate an operation that produces different results when it's invoked, like ``random.random()`` will. .. From 65497980bfb323a9ff8d0ced318aa085754283c1 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Sun, 17 Mar 2024 19:40:18 +0100 Subject: [PATCH 22/44] docs: update formatting and examples (PR2472) --- docs/source/fundamentals/installation.rst | 4 +- .../intermediate/additional_executors.rst | 46 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/docs/source/fundamentals/installation.rst b/docs/source/fundamentals/installation.rst index 7ee759ea82..8d41a24047 100644 --- a/docs/source/fundamentals/installation.rst +++ b/docs/source/fundamentals/installation.rst @@ -56,11 +56,11 @@ Thunder can easily integrate OpenAI Triton kernels. You can install Triton using Install Thunder =============== -You can now install Thunder +You can now install Thunder:: pip install git+https://github.com/Lightning-AI/lightning-thunder.git -Alternatively you can clone the Thunder repository and install locally +Alternatively you can clone the Thunder repository and install locally:: git clone https://github.com/Lightning-AI/lightning-thunder.git cd lightning-thunder diff --git a/docs/source/intermediate/additional_executors.rst b/docs/source/intermediate/additional_executors.rst index 76bb270bf0..911d7f16e9 100644 --- a/docs/source/intermediate/additional_executors.rst +++ b/docs/source/intermediate/additional_executors.rst @@ -10,11 +10,9 @@ Triton CrossEntropy Executor The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). 
It can be used like in the following example:: + import torch import thunder - from thunder.executors import nvfuserex, torchex - from thunder.executors.triton_crossentropy import deregister_triton_entropyex, register_triton_entropyex - - register_triton_entropyex(add_to_default_executors=False) + from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex def xentropy(logits, labels, weight, reduction, ignore_index): return thunder.torch.cross_entropy( @@ -23,7 +21,7 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an jitted_xentropy = thunder.jit( xentropy, - executors_list=['triton_crossentropy', nvfuserex, torchex] + executors=[triton_cross_entropy_ex,] ) device = 'cuda' @@ -41,43 +39,42 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an This prints:: - # Constructed by Delete Last Used + # Constructed by Delete Last Used (took 0 milliseconds) import torch + from thunder.executors.torchex import no_autocast + @torch.no_grad() - def xentropy(logits, labels, weight, reduction, ignore_index): + @no_autocast() + def computation(logits, labels, weight): # logits: "cuda:0 f32[2048, 50257]" # labels: "cuda:0 i64[2048]" # weight: "cuda:0 f32[50257]" - # "sum" - # ignore_index: "int 10106" - t22 = triton_cross_entropy(logits, labels, weight, None, ignore_index, None, "sum", 0.0) # t22: "cuda:0 f32[]" - del [logits, labels, weight, ignore_index] - return t22 + t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]" + del logits, labels, weight + return t23 -As shown in the above trace, ``triton_cross_entropy()`` is the one running the operation. +As shown in the above trace, ``triton_crossentropy()`` is the one running the operation. Apex CrossEntropy Executor ========================== The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this:: + import torch import thunder - from thunder.executors import nvfuserex, torchex - from thunder.executors.apex_entropyex import deregister_apex_entropyex, register_apex_entropyex - - register_apex_entropyex(add_to_default_executors=False) + from thunder.executors.apex_entropyex import apex_ex def xentropy(logits, labels): return thunder.torch.cross_entropy( logits, labels, reduction='mean', ignore_index=-1 ) - jitted_xentropy = thunder.jit(xentropy, executors_list=['apex_xentropy', nvfuserex, torchex]) + jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,]) device = 'cuda' dtype = torch.float32 - logits = torch.randn([2048, 50257], device=device, dtype=thunder.torch.to_torch_dtype(dtype)) + logits = torch.randn([2048, 50257], device=device, dtype=dtype) labels = torch.randint(0, 50257, [2048], device=device) jitted_xentropy(logits, labels) @@ -86,14 +83,17 @@ The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an This prints:: - # Constructed by Delete Last Used + # Constructed by Delete Last Used (took 0 milliseconds) import torch + from thunder.executors.torchex import no_autocast + @torch.no_grad() - def xentropy(logits, labels): + @no_autocast() + def computation(logits, labels): # logits: "cuda:0 f32[2048, 50257]" # labels: "cuda:0 i64[2048]" - t18 = apex_cross_entropy(logits, labels, None, None, -1, None, "mean", 0.0) # t18: "cuda:0 f32[]" - del [logits, labels] + (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0) + del logits, labels return t18 showing that Apex is running the operation. 
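The two executors can also be combined. Thunder asks each executor, in the order listed, whether it can claim an operation, so earlier entries take precedence and anything left unclaimed falls back to Thunder's default executors. Below is a minimal sketch of this (assuming both extensions above are installed; the ``xentropy`` function is the same one used in the examples)::

    import torch
    import thunder
    from thunder.executors.apex_entropyex import apex_ex
    from thunder.executors.triton_crossentropy import triton_ex

    def xentropy(logits, labels):
        return thunder.torch.cross_entropy(
            logits, labels, reduction='mean', ignore_index=-1
        )

    # Executors are tried in order: Apex gets the first chance to claim the
    # cross-entropy, and the Triton kernel is the next candidate.
    jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex, triton_ex])

    logits = torch.randn([2048, 50257], device='cuda')
    labels = torch.randint(0, 50257, [2048], device='cuda')

    jitted_xentropy(logits, labels)
    print(thunder.last_traces(jitted_xentropy)[-1])

Whichever symbol appears in the printed trace (``apex_cross_entropy`` or ``triton_crossentropy``) tells you which executor claimed the operation.
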
From 8a29afb6b438a78ebb405e1b4492985ac8b29d28 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 19 Mar 2024 05:12:10 -0700 Subject: [PATCH 23/44] Make cudnn tests strict. (PR2470) --- thunder/tests/test_cudnn_executor.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 6d76761a54..5c609228f3 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -202,17 +202,19 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): supported_devicetypes=(devices.DeviceType.CUDA,), ) def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp): - ran_atleast_one = False for sample in op.reference_inputs(device, dtype, requires_grad=True): - from thunder.executors.cudnnex import cudnn_ex - # Enforce tensor arguments are contiguous for torch reference contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args)) # query, key, value grad_inputs = list(contiguous_args[:3]) - if (attn_mask := sample.args[3]) is not None and attn_mask.requires_grad: - grad_inputs.append(attn_mask) + if (attn_mask := sample.args[3]) is not None: + if attn_mask.requires_grad: + grad_inputs.append(attn_mask) + # TODO(#2470): With cudnn frontend 1.1 and A100, this test hits + # RuntimeError when `attn_mask` is provided: `[cudnn_frontend] + # Error: No execution plans built successfully`. + continue # Compute vjp result using PyTorch expect_out = op.torch_reference(*contiguous_args, **sample.kwargs) @@ -230,17 +232,10 @@ def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp): executors_list=executor.executors_list() + [cudnn_ex], ) - try: - actual_out, actual_grad = cfoo(filtered_args, (v,)) - except Exception as e: - continue + actual_out, actual_grad = cfoo(filtered_args, (v,)) comp(actual_out, expect_out, atol=1e-2, rtol=1e-2) # compare gradients of query, key, value, and attn_mask for eg, ag in zip(expected_grad, actual_grad): comp(eg, ag, atol=2e-1, rtol=2e-2) - - ran_atleast_one = True - - assert ran_atleast_one == True From ffc5c6caed0170b179edf4d259a8c3acca575fa4 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Tue, 19 Mar 2024 14:04:20 +0100 Subject: [PATCH 24/44] test: update previous usage of thunder.last_traces (PR2474) --- thunder/tests/test_transformer_engine_executor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_transformer_engine_executor.py b/thunder/tests/test_transformer_engine_executor.py index 60aa6f64ea..41b6f83e36 100644 --- a/thunder/tests/test_transformer_engine_executor.py +++ b/thunder/tests/test_transformer_engine_executor.py @@ -70,7 +70,8 @@ def fn(x, w1, w2): assert_close(w2.grad, te_linear2.weight.grad) # Verifies te_linear was called - forward_trace, backward_trace = thunder.last_traces(cfn) + forward_trace = thunder.last_traces(cfn) + backward_trace = thunder.last_backward_traces(cfn) assert any(bsym.sym.name.startswith("te_linear") for bsym in forward_trace[-1].bound_symbols) assert any(bsym.sym.name.startswith("te_functional_linear_backward") for bsym in backward_trace[-1].bound_symbols) @@ -180,6 +181,6 @@ def foo(x, w): ) cfunc(x, w) - fwd_traces, _ = thunder.last_traces(cfunc) + fwd_traces = thunder.last_traces(cfunc) # Verify that we have replaced `prims.linear` with `te_linear` assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols) From 
e8d936f8efafce138e9509104f494903e5818b1d Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 19 Mar 2024 15:05:18 +0200 Subject: [PATCH 25/44] Initialize with meta device, move optimizer init to be after sharding (PR2476) Thank you @IvanYashchuk @parthmannan @carmocca --- thunder/benchmarks/benchmark_litgpt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 40db2db5de..ae9e8b2084 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -108,16 +108,18 @@ def __init__( if n_layers is not None: self.config.n_layer = n_layers - # Initialize the model and the optimizer + # Initialize the model self.model = self.init_model() - self.optimizer = configure_optimizers( - self.model, weight_decay, learning_rate, (beta1, beta2), device_type="cuda" - ) # Setup the distributed algorithm choices if self.distributed_mode != "none": self.model = self.setup_distributed() + # Initialize the optimizer after the model is sharded if using FSDP + self.optimizer = configure_optimizers( + self.model, weight_decay, learning_rate, (beta1, beta2), device_type="cuda" + ) + # Compile the model if self.compile not in ["eager", None]: self.model = self.setup_compile() @@ -137,8 +139,9 @@ def __init__( def init_model(self): print(f"Loading model with {self.config.__dict__}") + init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device t0 = time.perf_counter() - with self.device: + with init_device: model = GPT(self.config) model.to(dtype=torch.bfloat16) print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") From 0a1b14ae01bec1e49b8322c0db8c497d707d29c0 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 20 Mar 2024 06:35:20 -0700 Subject: [PATCH 26/44] register lookasides in register_operator (PR2487) --- thunder/extend/__init__.py | 12 ++++++++++++ thunder/tests/test_extend.py | 7 +++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 2d2ce1fd2a..2acdcf6427 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -25,6 +25,7 @@ "add_always_executor", "remove_default_executor", "remove_always_executor", + "register_lookaside", ] @@ -208,6 +209,7 @@ def register_operator( module: None | type | ModuleType = None, fn: None | Callable = None, bind_postprocess: None | Callable = None, + replaces: None | Callable = None, python_printer: Callable = default_python_printer, ) -> Symbol: assert (like is None) ^ (meta is None), "Expected one and only one of 'like' and 'meta' to be specified" @@ -237,6 +239,9 @@ def _bind_postprocess(bsym: BoundSymbol) -> None: ) self.opmap[name] = sym + if replaces is not None: + register_lookaside(replaces, sym) + return sym def register_implementation( @@ -381,3 +386,10 @@ def deregister_executor(ex: Hashable | Executor) -> None: remove_always_executor(id) remove_default_executor(id) + + +def register_lookaside(function, symbol) -> None: + """register `symbol` as a lookaside for `function`""" + import thunder.core.jit_ext + + thunder.core.jit_ext._general_jit_lookaside_map[function] = thunder.core.jit_ext.interpreter_needs_wrap(symbol) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index 03277f3419..b7dd300d12 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -137,14 +137,17 @@ def test_register_implementation_custom_op(): myex 
= OperatorExecutor("myex", version="0.1") register_executor(myex) + def official_add(a, b): + return a + b + def _myadd(a, b): return a + b - myadd1 = myex.register_operator("myadd1", like=_myadd, fn=_myadd) + myadd1 = myex.register_operator("myadd1", like=_myadd, fn=_myadd, replaces=official_add) myadd2 = myex.register_operator("myadd2", like=_myadd, fn=_myadd) def fn(a, b): - return myadd1(a, b) + return official_add(a, b) cfn = thunder.jit(fn, executors=[myex]) From 69c9327d217f07a7ce5ddb775cac7014fac22125 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 20 Mar 2024 07:38:28 -0700 Subject: [PATCH 27/44] Clean the cudnn test. (PR2477) --- thunder/tests/test_cudnn_executor.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 5c609228f3..c9bd6277ec 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -34,42 +34,42 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-0.5, high=0.5) n_head = 2 - N = 8 + N = 8 # batch size # TODO: multiple of 8 seems to produce NaNs - L = random.randint(1, 10) * 64 + L = random.randint(1, 10) * 64 # query's sequence length alignment_factor = 8 - S = random.randint(1, 10) * alignment_factor - E = random.randint(8, 16) * alignment_factor - Ev = random.randint(8, 16) * alignment_factor + S = random.randint(1, 10) * alignment_factor # key/value's sequence length + E = random.randint(8, 16) * alignment_factor # query/key's embedding size + Ev = random.randint(8, 16) * alignment_factor # value's embedding size # 4-dim (multiheaded) causal cases q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) - yield SampleInput(q, k, v, attn_mask := None, dropout_p := 0.0, is_causal := True) + yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True) # TODO: cudnnex seems to have a few mismatches. Will be enabled in a later PR. # Non-contiguous input tensor case nq = make(N, n_head, L, E).permute(0, 1, 3, 2) nk = make(N, n_head, L, E).permute(0, 1, 3, 2) nv = make(N, n_head, L, E).permute(0, 1, 3, 2) - yield SampleInput(nq, nk, nv, attn_mask := None, dropout_p := 0.0, is_causal := False) + yield SampleInput(nq, nk, nv, None, dropout_p=0.0, is_causal=False) # Test the scale factor which was added in torch 2.1 if LooseVersion(torch.__version__) >= LooseVersion("2.1.0"): q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) - yield SampleInput(q, k, v, attn_mask := None, dropout_p := 0.0, is_causal := False, scale=0.123) + yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=False, scale=0.123) # TODO: cudnnex only support of grad_attn_mask with batch dim 1 and both sequence lenghts divisible by 64. Release 9.0.1 will relax this constraint. 
# Additive attn_mask q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) additive_attn_mask = make((1, n_head, L, S), dtype=q.dtype).tril() - yield SampleInput(q, k, v, attn_mask := additive_attn_mask, is_causal=False) + yield SampleInput(q, k, v, additive_attn_mask, is_causal=False) # Boolean attn_mask q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) bool_attn_mask = make((1, n_head, L, S), dtype=torch.bool, low=1, high=1, requires_grad=False).tril() - yield SampleInput(q, k, v, attn_mask := bool_attn_mask, is_causal=False) + yield SampleInput(q, k, v, bool_attn_mask, is_causal=False) grad_sdpa_cudnn_opinfo = OpInfo( From abecbd82d8a86e3074ff0c075632e496bc360494 Mon Sep 17 00:00:00 2001 From: Riccardo Felluga <11768013+riccardofelluga@users.noreply.github.com> Date: Wed, 20 Mar 2024 17:17:16 +0200 Subject: [PATCH 28/44] Enable autocast for llama2.c example (bf16) (PR1771) Co-authored-by: Ivan Yashchuk --- examples/llama2.c/README.md | 21 +++++++++--------- examples/llama2.c/sample.py | 2 +- examples/llama2.c/train.py | 44 +++++++++++++++++++++---------------- thunder/core/transforms.py | 11 ++++++---- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/examples/llama2.c/README.md b/examples/llama2.c/README.md index 5acb8f0742..4a999da840 100644 --- a/examples/llama2.c/README.md +++ b/examples/llama2.c/README.md @@ -28,9 +28,9 @@ The code is configured to run with Thunder by default. Results with 1 GPU: -- ~339 ms/iter (torch.compile 'inductor') -- ~347 ms/iter (thunder nvfuser) -- ~431 ms/iter (eager) +- ~215 ms/iter (torch.compile 'inductor') +- ~239 ms/iter (thunder nvfuser) +- ~339 ms/iter (eager) CUDAGraphs are not used as the results were worse with them. @@ -46,15 +46,14 @@ nanoGPT doesn't implement KV caching so this is expectedly slow. 
Please checkout ## Setup ```text -Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime) +Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime) Is debug build: False -CUDA used to build PyTorch: 12.1 -CUDA runtime version: 12.1.105 +CUDA used to build PyTorch: 12.4 +CUDA runtime version: 12.4.99 GPU 0: NVIDIA A100-SXM4-40GB -Nvidia driver version: 525.125.06 +Nvidia driver version: 550.54.14 -pytorch-triton @ https://download.pytorch.org/whl/nightly/pytorch_triton-3.0.0%2B901819d2b6-cp310-cp310-linux_x86_64.whl -torch @ https://download.pytorch.org/whl/nightly/cu121/torch-2.3.0.dev20240130%2Bcu121-cp310-cp310-linux_x86_64.whl -lightning-thunder==8b107c6fe531c94c6705dbf39700863685ba5b65 -nvfuser_cu121==0.1.5.dev20240131 +triton == 3.0.0 +torch == 2.4.0a0+git685ace3 +nvfuser @ 0.2.0+git70101da ``` diff --git a/examples/llama2.c/sample.py b/examples/llama2.c/sample.py index b8ccacfa48..094b6c3f74 100644 --- a/examples/llama2.c/sample.py +++ b/examples/llama2.c/sample.py @@ -55,7 +55,7 @@ from thunder.executors.sdpaex import sdpa_ex executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor] - cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors_list=executors) + cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors=executors) # the generate implementation is not compile friendly, so bind the compiled model to the generate implementation generate = partial(Transformer.generate, cmodel) # workaround for "Foward nn.Module attributes through the ThunderOptimizedModule" diff --git a/examples/llama2.c/train.py b/examples/llama2.c/train.py index 58d88d4729..206a4e065d 100644 --- a/examples/llama2.c/train.py +++ b/examples/llama2.c/train.py @@ -70,8 +70,8 @@ warmup_iters = 1000 # how many steps to warm up for # system device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks -# dtype = "bfloat16" # float32|bfloat16|float16 -compile = "thunder" # eager|torch|thunder +dtype = "bfloat16" # float32|bfloat16|float16 +compile = "thunder" # thunder|torch|eager # ----------------------------------------------------------------------------- config_keys = [ k @@ -122,8 +122,15 @@ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast # note: float16 data type will automatically use a GradScaler -# ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] -ctx = nullcontext() # torch.amp.autocast(device_type=device_type, dtype=ptdtype) +ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] +ctx = ( + nullcontext() + if device_type == "cpu" + else torch.amp.autocast(device_type=device_type, dtype=ptdtype) +) +# Disable other than FlashAttention backends for SDPA +torch.backends.cuda.enable_math_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) # task-specific setup iter_batches = partial( @@ -179,10 +186,11 @@ model.load_state_dict(state_dict) iter_num = checkpoint["iter_num"] best_val_loss = checkpoint["best_val_loss"] + model.to(device) # initialize a GradScaler. 
If enabled=False scaler is a no-op -scaler = torch.cuda.amp.GradScaler(enabled=(False)) # dtype == "float16")) +scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) # optimizer optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) @@ -190,19 +198,19 @@ optimizer.load_state_dict(checkpoint["optimizer"]) checkpoint = None # free up memory -raw_model = eval_model = train_model = model +raw_model = model # wrap model into DDP container if ddp: if compile == "thunder": from thunder.distributed import ddp - train_model = ddp(train_model) + model = ddp(model) else: # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at # construction time since NCCL does not support `ComplexFloat` - train_model._ddp_params_and_buffers_to_ignore = {"freqs_cis"} - train_model = DDP(train_model, device_ids=[ddp_local_rank]) + model._ddp_params_and_buffers_to_ignore = {"freqs_cis"} + model = DDP(model, device_ids=[ddp_local_rank]) # compile the model if compile == "thunder": @@ -212,31 +220,29 @@ from thunder.executors.sdpaex import sdpa_ex executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor] - eval_model = thunder.compile(eval_model.eval(), disable_torch_autograd_support=True, executors_list=executors) - train_model = thunder.compile(train_model.train(), executors_list=executors) + model = thunder.jit(model, executors=executors) elif compile == "torch": print("compiling the model with torch... (takes a ~minute)") - eval_model = torch.compile(eval_model) - train_model = torch.compile(train_model) + model = torch.compile(model) # helps estimate an arbitrarily accurate loss over either split using many batches @torch.no_grad() def estimate_loss(): out = {} if compile != "thunder": - eval_model.eval() + model.eval() for split in ["train", "val"]: batch_iter = iter_batches(split=split) losses = torch.zeros(eval_iters) # keep on CPU for k in range(eval_iters): X, Y = next(batch_iter) with ctx: - logits = eval_model(X, Y) + logits = model(X, Y) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1) losses[k] = loss.item() out[split] = losses.mean() if compile != "thunder": - train_model.train() + model.train() return out # learning rate decay scheduler (cosine with warmup) @@ -313,9 +319,9 @@ def get_lr(it): # the official way to do this is with model.no_sync() context manager, but # this forces us to repeat code. 
# looking at the source of that context manager, it just toggles this variable - train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 + model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 with ctx: - logits = train_model(X, Y) + logits = model(X, Y) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1) loss = loss / gradient_accumulation_steps # immediately async prefetch next batch while model is doing the forward pass on the GPU @@ -325,7 +331,7 @@ def get_lr(it): # clip the gradient if grad_clip != 0.0: scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(train_model.parameters(), grad_clip) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) # step the optimizer and scaler if training in fp16 scaler.step(optimizer) scaler.update() diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 8d18d40904..c24e44a3ce 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3984,10 +3984,13 @@ def decorator(func): def maybe_downcast_to(dtype, args): allowed_downcast_types = (dtypes.float16, dtypes.bfloat16, dtypes.float32) - if all(tree_map(lambda a: a.dtype in allowed_downcast_types, args)): - return tree_map(lambda a: maybe_convert_to_dtype(a, dtype), args) - else: - return args + + def map_fn(a): + if isinstance(a, TensorProxy) and a.dtype in allowed_downcast_types: + return maybe_convert_to_dtype(a, dtype) + return a + + return tree_map(map_fn, args) @register_autocast_rule("torch.matmul") From 505f9ac8bd43ca96a7e5174c073a90245be79074 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 20 Mar 2024 08:18:25 -0700 Subject: [PATCH 29/44] Update zero to thunder to new extensibility example (PR2488) --- notebooks/zero_to_thunder.ipynb | 4258 +++++++++++++++++++++++-------- 1 file changed, 3264 insertions(+), 994 deletions(-) diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb index 1c536f0cde..a1a888cc72 100644 --- a/notebooks/zero_to_thunder.ipynb +++ b/notebooks/zero_to_thunder.ipynb @@ -3,7 +3,11 @@ { "cell_type": "markdown", "id": "1638964c", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "# Zero to Thunder\n", "\n", @@ -21,16 +25,18 @@ "source": [ "import sys\n", "sys.path.insert(0, '..')\n", - "import inspect\n", - "\n", "\n", - "import torch, thunder\n" + "import torch, thunder" ] }, { "cell_type": "markdown", "id": "54f87aba", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Compiling a first module with Thunder\n", "\n", @@ -40,7 +46,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "d6ca6328", + "id": "892be718", "metadata": {}, "outputs": [ { @@ -62,26 +68,26 @@ " self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", " self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", " self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)\n", - "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " x_fc_1 = self.fc_1(x)\n", " x_fc_2 = self.fc_2(x)\n", " x = torch.nn.functional.silu(x_fc_1) * x_fc_2\n", " return self.proj(x)\n", - "\n", - "\n", "with torch.device(\"cuda\"):\n", " m = LLaMAMLP(4096, 11008)\n", "for p in m.parameters():\n", " p.requires_grad_(False)\n", - "\n", - "print(m)" + "print(m)\n" ] }, { "cell_type": "markdown", "id": "702ea054", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "Now 
we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. It will wrap our MLP in a `ThunderModule`" ] @@ -125,8 +131,12 @@ }, { "cell_type": "markdown", - "id": "59db20f6", - "metadata": {}, + "id": "47d24f2d-0e89-4fe8-8154-9b50f2633e1b", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "Our Thunder module computes (up to numerical accuracy) the same thing as our original model and for a small model like this, it also has approximately the same performance." ] @@ -135,15 +145,19 @@ "cell_type": "code", "execution_count": 5, "id": "7f4de1b3", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 1.4901161193847656e-07\n", - "58.2 ms ± 306 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", - "58.7 ms ± 50.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", + "62.1 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -157,20 +171,25 @@ }, { "cell_type": "markdown", - "id": "8835543e", - "metadata": {}, + "id": "7996acc7-de20-4aa5-80f0-1ab6042e2650", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "So what has changed?\n", - "Quite a bit!\n", + "So what has changed? Quite a bit!\n", "\n", - "When we call the Thunder module, it does the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:" + "When we call the Thunder module, it do the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:" ] }, { "cell_type": "code", "execution_count": 6, "id": "a6f4b77c", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { @@ -221,8 +240,12 @@ }, { "cell_type": "markdown", - "id": "a0071924", - "metadata": {}, + "id": "2ef89186-70cd-4737-9695-ed282da2a56c", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, "source": [ "For more detail of what is going on in this trace:\n", "- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function which has all the MLP parameters as arguments.\n", @@ -237,13 +260,17 @@ { "cell_type": "markdown", "id": "7749aed1", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Compiling a more complex model\n", "\n", - "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller model here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):\n", + "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller momdel here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):\n", "\n", - "**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. 
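Before scaling up, the pattern this notebook uses throughout can be restated in its smallest form; a sketch with a plain function rather than a module (the names here are illustrative only):

```python
import torch, thunder

def f(x, w):
    return torch.nn.functional.silu(x @ w) * x

jf = thunder.jit(f)   # thunder.jit accepts plain functions as well as nn.Modules
x = torch.randn(4, 4)
w = torch.randn(4, 4)
y = jf(x, w)          # the first call traces and compiles the function
print(thunder.last_traces(jf)[-1])  # inspect the final execution trace
```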
See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt" + "**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt." ] }, { @@ -260,7 +287,7 @@ " (transformer): ModuleDict(\n", " (wte): Embedding(32000, 4096)\n", " (h): ModuleList(\n", - " (0-3): 4 x Block(\n", + " (0-15): 16 x Block(\n", " (norm_1): RMSNorm()\n", " (attn): CausalSelfAttention(\n", " (attn): Linear(in_features=4096, out_features=12288, bias=False)\n", @@ -288,7 +315,8 @@ "from lit_gpt import GPT\n", "from thunder.tests.lit_gpt_model import Config\n", "cfg = Config.from_name('Llama-2-7b-hf')\n", - "cfg.n_layer = 4 # fewer layers\n", + "cfg.n_layer = 16 # fewer layers\n", + "torch.set_default_dtype(torch.bfloat16)\n", "with torch.device('cuda'):\n", " m = GPT(cfg)\n", "m\n" @@ -312,7 +340,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "deviation: 1.8477439880371094e-06\n" + "deviation: 0.03125\n" ] } ], @@ -329,22 +357,37 @@ }, { "cell_type": "markdown", - "id": "2f681093", - "metadata": {}, + "id": "9947e8df-cd2d-447d-90b9-ee08bb5a9fb2", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "Just like before, we can see the program it ran:" + "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.\n", + "\n", + "Just like before, we can see the program it ran, it is a lot longer, though." ] }, { "cell_type": "code", "execution_count": 9, "id": "ac7e8bc9", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, { "data": { "text/plain": [ - "# Constructed by Delete Last Used (took 1 milliseconds)\n", + "# Constructed by Delete Last Used (took 10 milliseconds)\n", "import torch\n", "from torch import Tensor\n", "import torch.nn.functional\n", @@ -388,626 +431,2728 @@ " t31, \\\n", " t32, \\\n", " t33, \\\n", + " t34, \\\n", + " t35, \\\n", + " t36, \\\n", + " t37, \\\n", + " t38, \\\n", + " t39, \\\n", + " t40, \\\n", + " t41, \\\n", + " t42, \\\n", + " t43, \\\n", + " t44, \\\n", + " t45, \\\n", + " t46, \\\n", + " t47, \\\n", + " t48, \\\n", + " t49, \\\n", + " t50, \\\n", + " t51, \\\n", + " t52, \\\n", + " t53, \\\n", + " t54, \\\n", + " t55, \\\n", + " t56, \\\n", + " t57, \\\n", + " t58, \\\n", + " t59, \\\n", + " t60, \\\n", + " t61, \\\n", + " t62, \\\n", + " t63, \\\n", + " t64, \\\n", + " t65, \\\n", + " t66, \\\n", + " t67, \\\n", + " t68, \\\n", + " t69, \\\n", + " t70, \\\n", + " t71, \\\n", + " t72, \\\n", + " t73, \\\n", + " t74, \\\n", + " t75, \\\n", + " t76, \\\n", + " t77, \\\n", + " t78, \\\n", + " t79, \\\n", + " t80, \\\n", + " t81, \\\n", + " t82, \\\n", + " t83, \\\n", + " t84, \\\n", + " t85, \\\n", + " t86, \\\n", + " t87, \\\n", + " t88, \\\n", + " t89, \\\n", + " t90, \\\n", + " t91, \\\n", + " t92, \\\n", + " t93, \\\n", + " t94, \\\n", + " t95, \\\n", + " t96, \\\n", + " t97, \\\n", + " t98, \\\n", + " t99, \\\n", + " t100, \\\n", + " t101, \\\n", + " t102, \\\n", + " t103, \\\n", + " t104, \\\n", + " t105, \\\n", + " t106, \\\n", + " t107, \\\n", + " t108, \\\n", + " t109, \\\n", + " t110, \\\n", + " t111, \\\n", + " t112, \\\n", + " t113, \\\n", + " t114, \\\n", + " t115, \\\n", + " t116, \\\n", + " t117, \\\n", " = args\n", " del args\n", - " t38 = torch.nn.functional.embedding(t0, t33, 
None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t38 = ltorch.embedding(t0, t33, None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t334 = ltorch.reshape(t0, [512]) # t334: \"cuda:0 i64[512]\"\n", - " # t334 = prims.reshape(t0, (512,)) # t334: \"cuda:0 i64[512]\"\n", - " # t335 = prims.take(t33, t334, 0) # t335: \"cuda:0 f32[512, 4096]\"\n", - " # t38 = ltorch.reshape(t335, [1, 512, 4096]) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t38 = prims.reshape(t335, (1, 512, 4096)) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " t34 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t34: \"cuda:0 f32[512, 128]\"\n", - " t35 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t35: \"cuda:0 f32[512, 128]\"\n", - " t374 = torch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n", - " # t374 = ltorch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n", - " # t374 = prims.broadcast_in_dim(t17, [1, 4096], [1]) # t374: \"cuda:0 f32[1, 4096]\"\n", - " t375 = torch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t375 = ltorch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t375 = prims.broadcast_in_dim(t374, [1, 1, 4096], [0, 2]) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t374\n", - " t47 = Tensor.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t47 = ltorch.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t47 = prims.broadcast_in_dim(t375, (1, 512, 4096), (0, 1, 2)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t375\n", - " t475 = torch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 4096]\"\n", - " # t475 = ltorch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 4096]\"\n", - " # t475 = prims.broadcast_in_dim(t24, [1, 4096], [1]) # t475: \"cuda:0 f32[1, 4096]\"\n", - " t476 = torch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t476 = ltorch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t476 = prims.broadcast_in_dim(t475, [1, 1, 4096], [0, 2]) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t475\n", - " t311 = Tensor.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t311 = ltorch.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t311 = prims.broadcast_in_dim(t476, (1, 512, 4096), (0, 1, 2)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t476\n", - " t478 = torch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n", - " # t478 = ltorch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n", - " # t478 = prims.broadcast_in_dim(t16, [1, 4096], [1]) # t478: \"cuda:0 f32[1, 4096]\"\n", - " t479 = torch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t479 = ltorch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t479 = prims.broadcast_in_dim(t478, [1, 1, 4096], [0, 2]) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t478\n", - " t331 = Tensor.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t331 = ltorch.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t331 = prims.broadcast_in_dim(t479, (1, 512, 4096), (0, 1, 2)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t479\n", - " t403 = torch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n", - " # t403 = ltorch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n", - " # t403 = prims.broadcast_in_dim(t21, [1, 4096], [1]) # t403: \"cuda:0 f32[1, 4096]\"\n", - " t404 = torch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t404 = 
ltorch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t404 = prims.broadcast_in_dim(t403, [1, 1, 4096], [0, 2]) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t403\n", - " t98 = Tensor.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t98 = ltorch.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t98 = prims.broadcast_in_dim(t404, (1, 512, 4096), (0, 1, 2)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t404\n", - " t406 = torch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n", - " # t406 = ltorch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n", - " # t406 = prims.broadcast_in_dim(t18, [1, 4096], [1]) # t406: \"cuda:0 f32[1, 4096]\"\n", - " t407 = torch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t407 = ltorch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t407 = prims.broadcast_in_dim(t406, [1, 1, 4096], [0, 2]) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t406\n", - " t118 = Tensor.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t118 = ltorch.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t118 = prims.broadcast_in_dim(t407, (1, 512, 4096), (0, 1, 2)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t407\n", - " t427 = torch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n", - " # t427 = ltorch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n", - " # t427 = prims.broadcast_in_dim(t22, [1, 4096], [1]) # t427: \"cuda:0 f32[1, 4096]\"\n", - " t428 = torch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t428 = ltorch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t428 = prims.broadcast_in_dim(t427, [1, 1, 4096], [0, 2]) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t427\n", - " t169 = Tensor.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t169 = ltorch.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t169 = prims.broadcast_in_dim(t428, (1, 512, 4096), (0, 1, 2)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t428\n", - " t430 = torch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n", - " # t430 = ltorch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n", - " # t430 = prims.broadcast_in_dim(t19, [1, 4096], [1]) # t430: \"cuda:0 f32[1, 4096]\"\n", - " t431 = torch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t431 = ltorch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t431 = prims.broadcast_in_dim(t430, [1, 1, 4096], [0, 2]) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t430\n", - " t189 = Tensor.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t189 = ltorch.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t189 = prims.broadcast_in_dim(t431, (1, 512, 4096), (0, 1, 2)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t431\n", - " t451 = torch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n", - " # t451 = ltorch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n", - " # t451 = prims.broadcast_in_dim(t23, [1, 4096], [1]) # t451: \"cuda:0 f32[1, 4096]\"\n", - " t452 = torch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t452 = ltorch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t452 = prims.broadcast_in_dim(t451, [1, 1, 4096], [0, 2]) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t451\n", - " t240 = Tensor.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t240 = 
ltorch.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t240 = prims.broadcast_in_dim(t452, (1, 512, 4096), (0, 1, 2)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t452\n", - " t454 = torch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n", - " # t454 = ltorch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n", - " # t454 = prims.broadcast_in_dim(t20, [1, 4096], [1]) # t454: \"cuda:0 f32[1, 4096]\"\n", - " t455 = torch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t455 = ltorch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t455 = prims.broadcast_in_dim(t454, [1, 1, 4096], [0, 2]) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t454\n", - " t260 = Tensor.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t260 = ltorch.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t260 = prims.broadcast_in_dim(t455, (1, 512, 4096), (0, 1, 2)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t455\n", - " t395 = torch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n", - " # t395 = ltorch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n", - " # t395 = prims.broadcast_in_dim(t34, [1, 512, 128], [1, 2]) # t395: \"cuda:0 f32[1, 512, 128]\"\n", - " del t34\n", - " t396 = torch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t396 = ltorch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t396 = prims.broadcast_in_dim(t395, [1, 1, 512, 128], [0, 2, 3]) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " del t395\n", - " t63 = Tensor.expand(t396, (1, 32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t63 = ltorch.expand(t396, (1, 32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t63 = prims.broadcast_in_dim(t396, (1, 32, 512, 128), (0, 1, 2, 3)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t396\n", - " t398 = torch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n", - " # t398 = ltorch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n", - " # t398 = prims.broadcast_in_dim(t35, [1, 512, 128], [1, 2]) # t398: \"cuda:0 f32[1, 512, 128]\"\n", - " del t35\n", - " t399 = torch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t399 = ltorch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t399 = prims.broadcast_in_dim(t398, [1, 1, 512, 128], [0, 2, 3]) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " del t398\n", - " t65 = Tensor.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t65 = ltorch.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t65 = prims.broadcast_in_dim(t399, (1, 32, 512, 128), (0, 1, 2, 3)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t399\n", - " [t44, t48] = nvFusion0(t38, t47)\n", - " # t39 = prims.mul(t38, t38) # t39: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t40 = prims.sum(t39, (2,)) # t40: \"cuda:0 f32[1, 512]\"\n", - " # t41 = prims.broadcast_in_dim(t40, [1, 512, 1], [0, 1]) # t41: \"cuda:0 f32[1, 512, 1]\"\n", - " # t42 = prims.div(t41, 4096.0) # t42: \"cuda:0 f32[1, 512, 1]\"\n", - " # t43 = prims.add(t42, 1e-05) # t43: \"cuda:0 f32[1, 512, 1]\"\n", - " # t44 = prims.rsqrt(t43) # t44: \"cuda:0 f32[1, 512, 1]\"\n", - " # t45 = prims.broadcast_in_dim(t44, (1, 512, 4096), (0, 1, 2)) # t45: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t46 = prims.mul(t38, t45) # t46: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t48 = prims.mul(t46, t47) # t48: \"cuda:0 f32[1, 512, 4096]\"\n", - " t49 = 
torch.nn.functional.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t49 = ltorch.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t49 = prims.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", - " t50 = torch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t50 = ltorch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t50 = prims.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t49\n", - " t51 = torch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t51 = ltorch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t51 = prims.transpose(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t50\n", - " (t52, t53, t54) = torch.split(t51, (1, 1, 1), 2)\n", - " # (t52, t53, t54) = ltorch.split(t51, (1, 1, 1), 2)\n", - " # t52 = prims.slice_prim(t51, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t52: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t53 = prims.slice_prim(t51, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t53: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t54 = prims.slice_prim(t51, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t54: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t51\n", - " t55 = torch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t55 = ltorch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t55 = prims.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t52\n", - " t56 = torch.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t56 = ltorch.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t56 = prims.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t53\n", - " t57 = torch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t57 = ltorch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t57 = prims.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t54\n", - " t58 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t58: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t68 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t68: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t78 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t78: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t55\n", - " t80 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t80: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t56\n", - " t60 = torch_slice_prim_impl(t58, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t60: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t59 = torch_slice_prim_impl(t58, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t59: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t69 = torch_slice_prim_impl(t68, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t69: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t70 = torch_slice_prim_impl(t68, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t70: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t61, t71] = nvFusion1(t60, t70)\n", - " # t61 = prims.neg(t60) # t61: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t71 = prims.neg(t70) # t71: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t60, t70\n", - " t62 = torch.cat((t61, t59), -1) # 
t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t62 = ltorch.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t62 = prims.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t61, t59\n", - " t72 = torch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t72 = ltorch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t72 = prims.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t71, t69\n", - " [t67, t77] = nvFusion2(t58, t62, t63, t65, t68, t72)\n", - " # t64 = prims.mul(t58, t63) # t64: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t66 = prims.mul(t62, t65) # t66: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t67 = prims.add(t64, t66) # t67: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t74 = prims.mul(t68, t63) # t74: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t76 = prims.mul(t72, t65) # t76: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t77 = prims.add(t74, t76) # t77: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t58, t62, t68, t72\n", - " t79 = torch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t79 = ltorch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t79 = prims.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t67, t78\n", - " t81 = torch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t81 = ltorch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t81 = prims.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t77, t80\n", - " (t82, t83, t84, t85) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t79, t81, t57, None, 0.0, True, 0.08838834764831843)\n", - " t86 = torch.permute(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t86 = ltorch.permute(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t86 = prims.transpose(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t87 = torch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t87 = ltorch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t87 = prims.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t86\n", - " t88 = torch.nn.functional.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t88 = ltorch.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t88 = prims.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t89, t95, t99] = nvFusion3(t38, t88, t98)\n", - " # t89 = prims.add(t88, t38) # t89: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t90 = prims.mul(t89, t89) # t90: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t91 = prims.sum(t90, (2,)) # t91: \"cuda:0 f32[1, 512]\"\n", - " # t92 = prims.broadcast_in_dim(t91, [1, 512, 1], [0, 1]) # t92: \"cuda:0 f32[1, 512, 1]\"\n", - " # t93 = prims.div(t92, 4096.0) # t93: \"cuda:0 f32[1, 512, 1]\"\n", - " # t94 = prims.add(t93, 1e-05) # t94: \"cuda:0 f32[1, 512, 1]\"\n", - " # t95 = prims.rsqrt(t94) # t95: \"cuda:0 f32[1, 512, 1]\"\n", - " # t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t97 = prims.mul(t89, t96) # t97: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t99 = prims.mul(t97, t98) # t99: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t88\n", - " t101 = torch.nn.functional.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t101 = ltorch.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t101 = prims.linear(t99, t11, None) 
# t101: \"cuda:0 f32[1, 512, 11008]\"\n", - " t100 = torch.nn.functional.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t100 = ltorch.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t100 = prims.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t107] = nvFusion4(t100, t101)\n", - " # t102 = prims.neg(t100) # t102: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t103 = prims.exp(t102) # t103: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t104 = prims.add(1.0, t103) # t104: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t105 = prims.reciprocal(t104) # t105: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t106 = prims.mul(t100, t105) # t106: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t107 = prims.mul(t106, t101) # t107: \"cuda:0 f32[1, 512, 11008]\"\n", - " t108 = torch.nn.functional.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t108 = ltorch.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t108 = prims.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t109, t115, t119] = nvFusion5(t108, t118, t89)\n", - " # t109 = prims.add(t108, t89) # t109: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t110 = prims.mul(t109, t109) # t110: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t111 = prims.sum(t110, (2,)) # t111: \"cuda:0 f32[1, 512]\"\n", - " # t112 = prims.broadcast_in_dim(t111, [1, 512, 1], [0, 1]) # t112: \"cuda:0 f32[1, 512, 1]\"\n", - " # t113 = prims.div(t112, 4096.0) # t113: \"cuda:0 f32[1, 512, 1]\"\n", - " # t114 = prims.add(t113, 1e-05) # t114: \"cuda:0 f32[1, 512, 1]\"\n", - " # t115 = prims.rsqrt(t114) # t115: \"cuda:0 f32[1, 512, 1]\"\n", - " # t116 = prims.broadcast_in_dim(t115, (1, 512, 4096), (0, 1, 2)) # t116: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t117 = prims.mul(t109, t116) # t117: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t119 = prims.mul(t117, t118) # t119: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t108\n", - " t120 = torch.nn.functional.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t120 = ltorch.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t120 = prims.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", - " t121 = torch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t121 = ltorch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t121 = prims.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t120\n", - " t122 = torch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t122 = ltorch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t122 = prims.transpose(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t121\n", - " (t123, t124, t125) = torch.split(t122, (1, 1, 1), 2)\n", - " # (t123, t124, t125) = ltorch.split(t122, (1, 1, 1), 2)\n", - " # t123 = prims.slice_prim(t122, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t123: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t124 = prims.slice_prim(t122, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t124: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t125 = prims.slice_prim(t122, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t125: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t122\n", - " t126 = torch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t126 = ltorch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 
f32[1, 32, 512, 128]\"\n", - " # t126 = prims.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t123\n", - " t127 = torch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t127 = ltorch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t127 = prims.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t124\n", - " t128 = torch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t128 = ltorch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t128 = prims.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t125\n", - " t149 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t149: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " t151 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t151: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " t129 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t129: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t126\n", - " t139 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t139: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t127\n", - " t130 = torch_slice_prim_impl(t129, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t130: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t131 = torch_slice_prim_impl(t129, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t131: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t141 = torch_slice_prim_impl(t139, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t141: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t140 = torch_slice_prim_impl(t139, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t140: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t132, t142] = nvFusion6(t131, t141)\n", - " # t132 = prims.neg(t131) # t132: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t142 = prims.neg(t141) # t142: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t131, t141\n", - " t143 = torch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t143 = ltorch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t143 = prims.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t142, t140\n", - " t133 = torch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t133 = ltorch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t133 = prims.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t132, t130\n", - " [t138, t148] = nvFusion7(t129, t133, t139, t143, t63, t65)\n", - " # t145 = prims.mul(t139, t63) # t145: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t147 = prims.mul(t143, t65) # t147: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t148 = prims.add(t145, t147) # t148: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t135 = prims.mul(t129, t63) # t135: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t137 = prims.mul(t133, t65) # t137: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t138 = prims.add(t135, t137) # t138: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t129, t133, t139, t143\n", - " t150 = torch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t150 = ltorch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t150 = prims.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t138, t149\n", - " t152 = torch.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t152 = ltorch.cat((t148, t151), -1) # 
t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t152 = prims.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t148, t151\n", - " (t153, t154, t155, t156) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t150, t152, t128, None, 0.0, True, 0.08838834764831843)\n", - " t157 = torch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t157 = ltorch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t157 = prims.transpose(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t158 = torch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t158 = ltorch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t158 = prims.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t157\n", - " t159 = torch.nn.functional.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t159 = ltorch.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t159 = prims.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t160, t166, t170] = nvFusion8(t109, t159, t169)\n", - " # t160 = prims.add(t159, t109) # t160: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t161 = prims.mul(t160, t160) # t161: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t162 = prims.sum(t161, (2,)) # t162: \"cuda:0 f32[1, 512]\"\n", - " # t163 = prims.broadcast_in_dim(t162, [1, 512, 1], [0, 1]) # t163: \"cuda:0 f32[1, 512, 1]\"\n", - " # t164 = prims.div(t163, 4096.0) # t164: \"cuda:0 f32[1, 512, 1]\"\n", - " # t165 = prims.add(t164, 1e-05) # t165: \"cuda:0 f32[1, 512, 1]\"\n", - " # t166 = prims.rsqrt(t165) # t166: \"cuda:0 f32[1, 512, 1]\"\n", - " # t167 = prims.broadcast_in_dim(t166, (1, 512, 4096), (0, 1, 2)) # t167: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t168 = prims.mul(t160, t167) # t168: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t170 = prims.mul(t168, t169) # t170: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t159\n", - " t172 = torch.nn.functional.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t172 = ltorch.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t172 = prims.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", - " t171 = torch.nn.functional.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t171 = ltorch.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t171 = prims.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t178] = nvFusion9(t171, t172)\n", - " # t173 = prims.neg(t171) # t173: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t174 = prims.exp(t173) # t174: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t175 = prims.add(1.0, t174) # t175: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t176 = prims.reciprocal(t175) # t176: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t177 = prims.mul(t171, t176) # t177: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t178 = prims.mul(t177, t172) # t178: \"cuda:0 f32[1, 512, 11008]\"\n", - " t179 = torch.nn.functional.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t179 = ltorch.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t179 = prims.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t180, t186, t190] = nvFusion10(t160, t179, t189)\n", - " # t180 = prims.add(t179, t160) # t180: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t181 = prims.mul(t180, t180) # t181: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t182 = prims.sum(t181, (2,)) # t182: \"cuda:0 
f32[1, 512]\"\n", - " # t183 = prims.broadcast_in_dim(t182, [1, 512, 1], [0, 1]) # t183: \"cuda:0 f32[1, 512, 1]\"\n", - " # t184 = prims.div(t183, 4096.0) # t184: \"cuda:0 f32[1, 512, 1]\"\n", - " # t185 = prims.add(t184, 1e-05) # t185: \"cuda:0 f32[1, 512, 1]\"\n", - " # t186 = prims.rsqrt(t185) # t186: \"cuda:0 f32[1, 512, 1]\"\n", - " # t187 = prims.broadcast_in_dim(t186, (1, 512, 4096), (0, 1, 2)) # t187: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t188 = prims.mul(t180, t187) # t188: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t190 = prims.mul(t188, t189) # t190: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t179\n", - " t191 = torch.nn.functional.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t191 = ltorch.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t191 = prims.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", - " t192 = torch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t192 = ltorch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t192 = prims.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t191\n", - " t193 = torch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t193 = ltorch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t193 = prims.transpose(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t192\n", - " (t194, t195, t196) = torch.split(t193, (1, 1, 1), 2)\n", - " # (t194, t195, t196) = ltorch.split(t193, (1, 1, 1), 2)\n", - " # t194 = prims.slice_prim(t193, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t194: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t195 = prims.slice_prim(t193, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t195: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t196 = prims.slice_prim(t193, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t196: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t193\n", - " t197 = torch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t197 = ltorch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t197 = prims.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t194\n", - " t198 = torch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t198 = ltorch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t198 = prims.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t195\n", - " t199 = torch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t199 = ltorch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t199 = prims.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t196\n", - " t200 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t200: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t210 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t210: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t220 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t220: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t197\n", - " t222 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t222: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t198\n", - " t201 = 
torch_slice_prim_impl(t200, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t201: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t202 = torch_slice_prim_impl(t200, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t202: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t211 = torch_slice_prim_impl(t210, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t211: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t212 = torch_slice_prim_impl(t210, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t212: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t203, t213] = nvFusion11(t202, t212)\n", - " # t203 = prims.neg(t202) # t203: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t202, t212\n", - " t214 = torch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t214 = ltorch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t214 = prims.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t213, t211\n", - " t204 = torch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t204 = ltorch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t204 = prims.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t203, t201\n", - " [t209, t219] = nvFusion12(t200, t204, t210, t214, t63, t65)\n", - " # t216 = prims.mul(t210, t63) # t216: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t218 = prims.mul(t214, t65) # t218: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t219 = prims.add(t216, t218) # t219: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t206 = prims.mul(t200, t63) # t206: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t208 = prims.mul(t204, t65) # t208: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t209 = prims.add(t206, t208) # t209: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t200, t204, t210, t214\n", - " t223 = torch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t223 = ltorch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t223 = prims.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t219, t222\n", - " t221 = torch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t221 = ltorch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t221 = prims.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t209, t220\n", - " (t224, t225, t226, t227) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t221, t223, t199, None, 0.0, True, 0.08838834764831843)\n", - " t228 = torch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t228 = ltorch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t228 = prims.transpose(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t229 = torch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t229 = ltorch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t229 = prims.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t228\n", - " t230 = torch.nn.functional.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t230 = ltorch.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t230 = prims.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t231, t237, t241] = nvFusion13(t180, t230, t240)\n", - " # t231 = prims.add(t230, t180) # t231: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t232 = prims.mul(t231, t231) # t232: 
\"cuda:0 f32[1, 512, 4096]\"\n", + " t122 = torch.nn.functional.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t122 = ltorch.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1867 = ltorch.reshape(t0, [512]) # t1867: \"cuda:0 i64[512]\"\n", + " # t1867 = prims.reshape(t0, (512,)) # t1867: \"cuda:0 i64[512]\"\n", + " # t1868 = prims.take(t117, t1867, 0) # t1868: \"cuda:0 bf16[512, 4096]\"\n", + " # t122 = ltorch.reshape(t1868, [1, 512, 4096]) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t122 = prims.reshape(t1868, (1, 512, 4096)) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t118 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t118: \"cuda:0 f32[512, 128]\"\n", + " t119 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t119: \"cuda:0 f32[512, 128]\"\n", + " t2015 = torch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " # t2015 = ltorch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " # t2015 = prims.broadcast_in_dim(t53, [1, 4096], [1]) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " t2016 = torch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2016 = ltorch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2016 = prims.broadcast_in_dim(t2015, [1, 1, 4096], [0, 2]) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2015\n", + " t133 = Tensor.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t133 = ltorch.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t133 = prims.broadcast_in_dim(t2016, (1, 512, 4096), (0, 1, 2)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2016\n", + " t2356 = torch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " # t2356 = ltorch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " # t2356 = prims.broadcast_in_dim(t82, [1, 4096], [1]) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " t2357 = torch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2357 = ltorch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2357 = prims.broadcast_in_dim(t2356, [1, 1, 4096], [0, 2]) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2356\n", + " t1609 = Tensor.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1609 = ltorch.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1609 = prims.broadcast_in_dim(t2357, (1, 512, 4096), (0, 1, 2)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2357\n", + " t2359 = torch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " # t2359 = ltorch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " # t2359 = prims.broadcast_in_dim(t58, [1, 4096], [1]) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " t2360 = torch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2360 = ltorch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2360 = prims.broadcast_in_dim(t2359, [1, 1, 4096], [0, 2]) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2359\n", + " t1645 = Tensor.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1645 = ltorch.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1645 = prims.broadcast_in_dim(t2360, (1, 512, 4096), (0, 1, 2)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2360\n", + " t2044 = torch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " # t2044 = ltorch.unsqueeze(t69, 0) # t2044: 
\"cuda:0 bf16[1, 4096]\"\n", + " # t2044 = prims.broadcast_in_dim(t69, [1, 4096], [1]) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " t2045 = torch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2045 = ltorch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2045 = prims.broadcast_in_dim(t2044, [1, 1, 4096], [0, 2]) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2044\n", + " t205 = Tensor.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t205 = ltorch.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t205 = prims.broadcast_in_dim(t2045, (1, 512, 4096), (0, 1, 2)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2045\n", + " t2380 = torch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " # t2380 = ltorch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " # t2380 = prims.broadcast_in_dim(t83, [1, 4096], [1]) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " t2381 = torch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2381 = ltorch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2381 = prims.broadcast_in_dim(t2380, [1, 1, 4096], [0, 2]) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2380\n", + " t1717 = Tensor.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1717 = ltorch.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1717 = prims.broadcast_in_dim(t2381, (1, 512, 4096), (0, 1, 2)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2381\n", + " t2047 = torch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " # t2047 = ltorch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " # t2047 = prims.broadcast_in_dim(t60, [1, 4096], [1]) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " t2048 = torch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2048 = ltorch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2048 = prims.broadcast_in_dim(t2047, [1, 1, 4096], [0, 2]) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2047\n", + " t241 = Tensor.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t241 = ltorch.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t241 = prims.broadcast_in_dim(t2048, (1, 512, 4096), (0, 1, 2)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2048\n", + " t2383 = torch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " # t2383 = ltorch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " # t2383 = prims.broadcast_in_dim(t59, [1, 4096], [1]) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " t2384 = torch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2384 = ltorch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2384 = prims.broadcast_in_dim(t2383, [1, 1, 4096], [0, 2]) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2383\n", + " t1753 = Tensor.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1753 = ltorch.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1753 = prims.broadcast_in_dim(t2384, (1, 512, 4096), (0, 1, 2)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2384\n", + " t2068 = torch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " # t2068 = ltorch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " # t2068 = prims.broadcast_in_dim(t70, [1, 4096], [1]) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " t2069 = torch.unsqueeze(t2068, 1) 
# t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2069 = ltorch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2069 = prims.broadcast_in_dim(t2068, [1, 1, 4096], [0, 2]) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2068\n", + " t313 = Tensor.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t313 = ltorch.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t313 = prims.broadcast_in_dim(t2069, (1, 512, 4096), (0, 1, 2)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2069\n", + " t2404 = torch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " # t2404 = ltorch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " # t2404 = prims.broadcast_in_dim(t84, [1, 4096], [1]) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " t2405 = torch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2405 = ltorch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2405 = prims.broadcast_in_dim(t2404, [1, 1, 4096], [0, 2]) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2404\n", + " t1825 = Tensor.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1825 = ltorch.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1825 = prims.broadcast_in_dim(t2405, (1, 512, 4096), (0, 1, 2)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2405\n", + " t2071 = torch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " # t2071 = ltorch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " # t2071 = prims.broadcast_in_dim(t61, [1, 4096], [1]) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " t2072 = torch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2072 = ltorch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2072 = prims.broadcast_in_dim(t2071, [1, 1, 4096], [0, 2]) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2071\n", + " t349 = Tensor.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t349 = ltorch.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t349 = prims.broadcast_in_dim(t2072, (1, 512, 4096), (0, 1, 2)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2072\n", + " t2407 = torch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " # t2407 = ltorch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " # t2407 = prims.broadcast_in_dim(t52, [1, 4096], [1]) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " t2408 = torch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2408 = ltorch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2408 = prims.broadcast_in_dim(t2407, [1, 1, 4096], [0, 2]) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2407\n", + " t1861 = Tensor.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1861 = ltorch.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1861 = prims.broadcast_in_dim(t2408, (1, 512, 4096), (0, 1, 2)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2408\n", + " t2095 = torch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " # t2095 = ltorch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " # t2095 = prims.broadcast_in_dim(t62, [1, 4096], [1]) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " t2096 = torch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2096 = ltorch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2096 = 
prims.broadcast_in_dim(t2095, [1, 1, 4096], [0, 2]) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2095\n", + " t457 = Tensor.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t457 = ltorch.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t457 = prims.broadcast_in_dim(t2096, (1, 512, 4096), (0, 1, 2)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2096\n", + " t2092 = torch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " # t2092 = ltorch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " # t2092 = prims.broadcast_in_dim(t71, [1, 4096], [1]) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " t2093 = torch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2093 = ltorch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2093 = prims.broadcast_in_dim(t2092, [1, 1, 4096], [0, 2]) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2092\n", + " t421 = Tensor.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t421 = ltorch.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t421 = prims.broadcast_in_dim(t2093, (1, 512, 4096), (0, 1, 2)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2093\n", + " t2116 = torch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " # t2116 = ltorch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " # t2116 = prims.broadcast_in_dim(t72, [1, 4096], [1]) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " t2117 = torch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2117 = ltorch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2117 = prims.broadcast_in_dim(t2116, [1, 1, 4096], [0, 2]) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2116\n", + " t529 = Tensor.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t529 = ltorch.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t529 = prims.broadcast_in_dim(t2117, (1, 512, 4096), (0, 1, 2)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2117\n", + " t2119 = torch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " # t2119 = ltorch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " # t2119 = prims.broadcast_in_dim(t63, [1, 4096], [1]) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " t2120 = torch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2120 = ltorch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2120 = prims.broadcast_in_dim(t2119, [1, 1, 4096], [0, 2]) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2119\n", + " t565 = Tensor.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t565 = ltorch.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t565 = prims.broadcast_in_dim(t2120, (1, 512, 4096), (0, 1, 2)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2120\n", + " t2140 = torch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " # t2140 = ltorch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " # t2140 = prims.broadcast_in_dim(t73, [1, 4096], [1]) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " t2141 = torch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2141 = ltorch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2141 = prims.broadcast_in_dim(t2140, [1, 1, 4096], [0, 2]) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2140\n", + " t637 = Tensor.expand(t2141, (1, 512, 4096)) # t637: 
\"cuda:0 bf16[1, 512, 4096]\"\n", + " # t637 = ltorch.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t637 = prims.broadcast_in_dim(t2141, (1, 512, 4096), (0, 1, 2)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2141\n", + " t2143 = torch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " # t2143 = ltorch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " # t2143 = prims.broadcast_in_dim(t64, [1, 4096], [1]) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " t2144 = torch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2144 = ltorch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2144 = prims.broadcast_in_dim(t2143, [1, 1, 4096], [0, 2]) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2143\n", + " t673 = Tensor.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t673 = ltorch.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t673 = prims.broadcast_in_dim(t2144, (1, 512, 4096), (0, 1, 2)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2144\n", + " t2164 = torch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " # t2164 = ltorch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " # t2164 = prims.broadcast_in_dim(t74, [1, 4096], [1]) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " t2165 = torch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2165 = ltorch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2165 = prims.broadcast_in_dim(t2164, [1, 1, 4096], [0, 2]) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2164\n", + " t745 = Tensor.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t745 = ltorch.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t745 = prims.broadcast_in_dim(t2165, (1, 512, 4096), (0, 1, 2)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2165\n", + " t2167 = torch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " # t2167 = ltorch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " # t2167 = prims.broadcast_in_dim(t65, [1, 4096], [1]) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " t2168 = torch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2168 = ltorch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2168 = prims.broadcast_in_dim(t2167, [1, 1, 4096], [0, 2]) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2167\n", + " t781 = Tensor.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t781 = ltorch.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t781 = prims.broadcast_in_dim(t2168, (1, 512, 4096), (0, 1, 2)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2168\n", + " t2188 = torch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " # t2188 = ltorch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " # t2188 = prims.broadcast_in_dim(t75, [1, 4096], [1]) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " t2189 = torch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2189 = ltorch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2189 = prims.broadcast_in_dim(t2188, [1, 1, 4096], [0, 2]) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2188\n", + " t853 = Tensor.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t853 = ltorch.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t853 = prims.broadcast_in_dim(t2189, (1, 
512, 4096), (0, 1, 2)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2189\n", + " t2191 = torch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " # t2191 = ltorch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " # t2191 = prims.broadcast_in_dim(t66, [1, 4096], [1]) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " t2192 = torch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2192 = ltorch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2192 = prims.broadcast_in_dim(t2191, [1, 1, 4096], [0, 2]) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2191\n", + " t889 = Tensor.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t889 = ltorch.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t889 = prims.broadcast_in_dim(t2192, (1, 512, 4096), (0, 1, 2)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2192\n", + " t2212 = torch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " # t2212 = ltorch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " # t2212 = prims.broadcast_in_dim(t76, [1, 4096], [1]) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " t2213 = torch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2213 = ltorch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2213 = prims.broadcast_in_dim(t2212, [1, 1, 4096], [0, 2]) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2212\n", + " t961 = Tensor.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t961 = ltorch.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t961 = prims.broadcast_in_dim(t2213, (1, 512, 4096), (0, 1, 2)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2213\n", + " t2215 = torch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " # t2215 = ltorch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " # t2215 = prims.broadcast_in_dim(t67, [1, 4096], [1]) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " t2216 = torch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2216 = ltorch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2216 = prims.broadcast_in_dim(t2215, [1, 1, 4096], [0, 2]) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2215\n", + " t997 = Tensor.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t997 = ltorch.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t997 = prims.broadcast_in_dim(t2216, (1, 512, 4096), (0, 1, 2)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2216\n", + " t2236 = torch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " # t2236 = ltorch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " # t2236 = prims.broadcast_in_dim(t77, [1, 4096], [1]) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " t2237 = torch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2237 = ltorch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2237 = prims.broadcast_in_dim(t2236, [1, 1, 4096], [0, 2]) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2236\n", + " t1069 = Tensor.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1069 = ltorch.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1069 = prims.broadcast_in_dim(t2237, (1, 512, 4096), (0, 1, 2)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2237\n", + " t2239 = torch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " # t2239 
= ltorch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " # t2239 = prims.broadcast_in_dim(t68, [1, 4096], [1]) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " t2240 = torch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2240 = ltorch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2240 = prims.broadcast_in_dim(t2239, [1, 1, 4096], [0, 2]) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2239\n", + " t1105 = Tensor.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1105 = ltorch.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1105 = prims.broadcast_in_dim(t2240, (1, 512, 4096), (0, 1, 2)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2240\n", + " t2260 = torch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " # t2260 = ltorch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " # t2260 = prims.broadcast_in_dim(t78, [1, 4096], [1]) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " t2261 = torch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2261 = ltorch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2261 = prims.broadcast_in_dim(t2260, [1, 1, 4096], [0, 2]) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2260\n", + " t1177 = Tensor.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1177 = ltorch.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1177 = prims.broadcast_in_dim(t2261, (1, 512, 4096), (0, 1, 2)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2261\n", + " t2263 = torch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " # t2263 = ltorch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " # t2263 = prims.broadcast_in_dim(t54, [1, 4096], [1]) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " t2264 = torch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2264 = ltorch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2264 = prims.broadcast_in_dim(t2263, [1, 1, 4096], [0, 2]) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2263\n", + " t1213 = Tensor.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1213 = ltorch.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1213 = prims.broadcast_in_dim(t2264, (1, 512, 4096), (0, 1, 2)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2264\n", + " t2284 = torch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " # t2284 = ltorch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " # t2284 = prims.broadcast_in_dim(t79, [1, 4096], [1]) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " t2285 = torch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2285 = ltorch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2285 = prims.broadcast_in_dim(t2284, [1, 1, 4096], [0, 2]) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2284\n", + " t1285 = Tensor.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1285 = ltorch.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1285 = prims.broadcast_in_dim(t2285, (1, 512, 4096), (0, 1, 2)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2285\n", + " t2287 = torch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " # t2287 = ltorch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " # t2287 = prims.broadcast_in_dim(t55, [1, 4096], [1]) # t2287: \"cuda:0 bf16[1, 
4096]\"\n", + " t2288 = torch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2288 = ltorch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2288 = prims.broadcast_in_dim(t2287, [1, 1, 4096], [0, 2]) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2287\n", + " t1321 = Tensor.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1321 = ltorch.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1321 = prims.broadcast_in_dim(t2288, (1, 512, 4096), (0, 1, 2)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2288\n", + " t2308 = torch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " # t2308 = ltorch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " # t2308 = prims.broadcast_in_dim(t80, [1, 4096], [1]) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " t2309 = torch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2309 = ltorch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2309 = prims.broadcast_in_dim(t2308, [1, 1, 4096], [0, 2]) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2308\n", + " t1393 = Tensor.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1393 = ltorch.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1393 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2309\n", + " t2311 = torch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " # t2311 = ltorch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " # t2311 = prims.broadcast_in_dim(t56, [1, 4096], [1]) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " t2312 = torch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2312 = ltorch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2312 = prims.broadcast_in_dim(t2311, [1, 1, 4096], [0, 2]) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2311\n", + " t1429 = Tensor.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1429 = ltorch.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1429 = prims.broadcast_in_dim(t2312, (1, 512, 4096), (0, 1, 2)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2312\n", + " t2332 = torch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " # t2332 = ltorch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " # t2332 = prims.broadcast_in_dim(t81, [1, 4096], [1]) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " t2333 = torch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2333 = ltorch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2333 = prims.broadcast_in_dim(t2332, [1, 1, 4096], [0, 2]) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2332\n", + " t1501 = Tensor.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1501 = ltorch.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1501 = prims.broadcast_in_dim(t2333, (1, 512, 4096), (0, 1, 2)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2333\n", + " t2335 = torch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " # t2335 = ltorch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " # t2335 = prims.broadcast_in_dim(t57, [1, 4096], [1]) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " t2336 = torch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2336 = ltorch.unsqueeze(t2335, 1) # t2336: \"cuda:0 
bf16[1, 1, 4096]\"\n", + " # t2336 = prims.broadcast_in_dim(t2335, [1, 1, 4096], [0, 2]) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2335\n", + " t1537 = Tensor.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1537 = ltorch.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1537 = prims.broadcast_in_dim(t2336, (1, 512, 4096), (0, 1, 2)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2336\n", + " t2036 = torch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2036 = ltorch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2036 = prims.broadcast_in_dim(t118, [1, 512, 128], [1, 2]) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " del t118\n", + " t2037 = torch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2037 = ltorch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2037 = prims.broadcast_in_dim(t2036, [1, 1, 512, 128], [0, 2, 3]) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t2036\n", + " t154 = Tensor.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t154 = ltorch.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t154 = prims.broadcast_in_dim(t2037, (1, 32, 512, 128), (0, 1, 2, 3)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t2037\n", + " t2039 = torch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2039 = ltorch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2039 = prims.broadcast_in_dim(t119, [1, 512, 128], [1, 2]) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " del t119\n", + " t2040 = torch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2040 = ltorch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2040 = prims.broadcast_in_dim(t2039, [1, 1, 512, 128], [0, 2, 3]) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t2039\n", + " t157 = Tensor.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t157 = ltorch.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t157 = prims.broadcast_in_dim(t2040, (1, 32, 512, 128), (0, 1, 2, 3)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t2040\n", + " [t129, t137] = nvFusion0(t122, t133)\n", + " # t123 = prims.convert_element_type(t122, dtypes.float32) # t123: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t124 = prims.mul(t123, t123) # t124: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t125 = prims.sum(t124, (2,)) # t125: \"cuda:0 f32[1, 512]\"\n", + " # t126 = prims.broadcast_in_dim(t125, [1, 512, 1], [0, 1]) # t126: \"cuda:0 f32[1, 512, 1]\"\n", + " # t127 = prims.div(t126, 4096.0) # t127: \"cuda:0 f32[1, 512, 1]\"\n", + " # t128 = prims.add(t127, 1e-05) # t128: \"cuda:0 f32[1, 512, 1]\"\n", + " # t129 = prims.rsqrt(t128) # t129: \"cuda:0 f32[1, 512, 1]\"\n", + " # t130 = prims.broadcast_in_dim(t129, (1, 512, 4096), (0, 1, 2)) # t130: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t131 = prims.mul(t123, t130) # t131: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t135 = prims.convert_element_type(t133, dtypes.float32) # t135: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t136 = prims.mul(t131, t135) # t136: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t138 = torch.nn.functional.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t138 = ltorch.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 
12288]\"\n", + " # t138 = prims.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t139 = torch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t139 = ltorch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t139 = prims.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t138\n", + " t140 = torch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t140 = ltorch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t140 = prims.transpose(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t139\n", + " (t141, t142, t143) = torch.split(t140, (1, 1, 1), 2)\n", + " # (t141, t142, t143) = ltorch.split(t140, (1, 1, 1), 2)\n", + " # t141 = prims.slice_prim(t140, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t141: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t142 = prims.slice_prim(t140, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t142: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t143 = prims.slice_prim(t140, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t143: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t140\n", + " t144 = torch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t144 = ltorch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t144 = prims.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t141\n", + " t145 = torch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t145 = ltorch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t145 = prims.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t142\n", + " t146 = torch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t146 = ltorch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t146 = prims.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t143\n", + " t147 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t147: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t162 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t162: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t177 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t177: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t144\n", + " t179 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t179: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t145\n", + " t149 = torch_slice_prim_impl(t147, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t149: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t148 = torch_slice_prim_impl(t147, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t148: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t163 = torch_slice_prim_impl(t162, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t163: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t164 = torch_slice_prim_impl(t162, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t164: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t152, t167] = nvFusion1(t147, t149, t162, t164)\n", + " # t150 = prims.convert_element_type(t149, dtypes.float32) # t150: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t151 = prims.neg(t150) # t151: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t152 = 
prims.convert_element_type(t151, dtypes.bfloat16) # t152: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t165 = prims.convert_element_type(t164, dtypes.float32) # t165: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t166 = prims.neg(t165) # t166: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t167 = prims.convert_element_type(t166, dtypes.bfloat16) # t167: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t149, t164\n", + " t168 = torch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t168 = ltorch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t168 = prims.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t167, t163\n", + " t153 = torch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t153 = ltorch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t153 = prims.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t152, t148\n", + " [t161, t176] = nvFusion2(t147, t153, t154, t157, t162, t168)\n", + " # t155 = prims.convert_element_type(t147, dtypes.float32) # t155: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t170 = prims.convert_element_type(t162, dtypes.float32) # t170: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t156 = prims.mul(t155, t154) # t156: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t158 = prims.convert_element_type(t153, dtypes.float32) # t158: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t159 = prims.mul(t158, t157) # t159: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t160 = prims.add(t156, t159) # t160: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t161 = prims.convert_element_type(t160, dtypes.bfloat16) # t161: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t171 = prims.mul(t170, t154) # t171: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t173 = prims.convert_element_type(t168, dtypes.float32) # t173: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t174 = prims.mul(t173, t157) # t174: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t175 = prims.add(t171, t174) # t175: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t176 = prims.convert_element_type(t175, dtypes.bfloat16) # t176: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t147, t153, t162, t168\n", + " t178 = torch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t178 = ltorch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t178 = prims.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t161, t177\n", + " t180 = torch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t180 = ltorch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t180 = prims.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t176, t179\n", + " (t181, t182, t183, t184, _, _, t185, t186, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t178, t180, t146, 0.0, True, scale=0.08838834764831843)\n", + " t188 = torch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t188 = ltorch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t188 = prims.transpose(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t189 = torch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t189 = ltorch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t189 = prims.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t188\n", + " t190 = torch.nn.functional.linear(t189, t85, None) # t190: \"cuda:0 
bf16[1, 512, 4096]\"\n", + " # t190 = ltorch.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t190 = prims.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t194, t201, t209] = nvFusion3(t122, t190, t205)\n", + " # t191 = prims.convert_element_type(t190, dtypes.float32) # t191: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t192 = prims.convert_element_type(t122, dtypes.float32) # t192: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t193 = prims.add(t191, t192) # t193: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t194 = prims.convert_element_type(t193, dtypes.bfloat16) # t194: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t196 = prims.mul(t193, t193) # t196: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t197 = prims.sum(t196, (2,)) # t197: \"cuda:0 f32[1, 512]\"\n", + " # t198 = prims.broadcast_in_dim(t197, [1, 512, 1], [0, 1]) # t198: \"cuda:0 f32[1, 512, 1]\"\n", + " # t199 = prims.div(t198, 4096.0) # t199: \"cuda:0 f32[1, 512, 1]\"\n", + " # t200 = prims.add(t199, 1e-05) # t200: \"cuda:0 f32[1, 512, 1]\"\n", + " # t201 = prims.rsqrt(t200) # t201: \"cuda:0 f32[1, 512, 1]\"\n", + " # t202 = prims.broadcast_in_dim(t201, (1, 512, 4096), (0, 1, 2)) # t202: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t203 = prims.mul(t193, t202) # t203: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t207 = prims.convert_element_type(t205, dtypes.float32) # t207: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t208 = prims.mul(t203, t207) # t208: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t209 = prims.convert_element_type(t208, dtypes.bfloat16) # t209: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t210 = torch.nn.functional.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t210 = ltorch.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t210 = prims.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t211 = torch.nn.functional.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t211 = ltorch.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t211 = prims.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t225] = nvFusion4(t210, t211)\n", + " # t212 = prims.convert_element_type(t210, dtypes.float32) # t212: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t214 = prims.exp(t213) # t214: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t215 = prims.add(1.0, t214) # t215: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t216 = prims.reciprocal(t215) # t216: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t220 = prims.mul(t212, t216) # t220: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t223 = prims.convert_element_type(t211, dtypes.float32) # t223: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t224 = prims.mul(t220, t223) # t224: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t225 = prims.convert_element_type(t224, dtypes.bfloat16) # t225: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t226 = torch.nn.functional.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t226 = ltorch.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t226 = prims.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t230, t237, t245] = nvFusion5(t194, t226, t241)\n", + " # t228 = prims.convert_element_type(t194, dtypes.float32) # t228: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t227 = prims.convert_element_type(t226, dtypes.float32) # t227: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t229 = prims.add(t227, t228) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t230 = 
prims.convert_element_type(t229, dtypes.bfloat16) # t230: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t232 = prims.mul(t229, t229) # t232: \"cuda:0 f32[1, 512, 4096]\"\n", " # t233 = prims.sum(t232, (2,)) # t233: \"cuda:0 f32[1, 512]\"\n", " # t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1]) # t234: \"cuda:0 f32[1, 512, 1]\"\n", " # t235 = prims.div(t234, 4096.0) # t235: \"cuda:0 f32[1, 512, 1]\"\n", " # t236 = prims.add(t235, 1e-05) # t236: \"cuda:0 f32[1, 512, 1]\"\n", " # t237 = prims.rsqrt(t236) # t237: \"cuda:0 f32[1, 512, 1]\"\n", " # t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2)) # t238: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t239 = prims.mul(t231, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t241 = prims.mul(t239, t240) # t241: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t230\n", - " t242 = torch.nn.functional.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t242 = ltorch.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t242 = prims.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", - " t243 = torch.nn.functional.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t243 = ltorch.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t243 = prims.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t249] = nvFusion14(t242, t243)\n", - " # t244 = prims.neg(t242) # t244: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t245 = prims.exp(t244) # t245: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t246 = prims.add(1.0, t245) # t246: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t247 = prims.reciprocal(t246) # t247: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t248 = prims.mul(t242, t247) # t248: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t249 = prims.mul(t248, t243) # t249: \"cuda:0 f32[1, 512, 11008]\"\n", - " t250 = torch.nn.functional.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t250 = ltorch.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t250 = prims.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t251, t257, t261] = nvFusion15(t231, t250, t260)\n", - " # t251 = prims.add(t250, t231) # t251: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t252 = prims.mul(t251, t251) # t252: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t253 = prims.sum(t252, (2,)) # t253: \"cuda:0 f32[1, 512]\"\n", - " # t254 = prims.broadcast_in_dim(t253, [1, 512, 1], [0, 1]) # t254: \"cuda:0 f32[1, 512, 1]\"\n", - " # t255 = prims.div(t254, 4096.0) # t255: \"cuda:0 f32[1, 512, 1]\"\n", - " # t256 = prims.add(t255, 1e-05) # t256: \"cuda:0 f32[1, 512, 1]\"\n", - " # t257 = prims.rsqrt(t256) # t257: \"cuda:0 f32[1, 512, 1]\"\n", - " # t258 = prims.broadcast_in_dim(t257, (1, 512, 4096), (0, 1, 2)) # t258: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t259 = prims.mul(t251, t258) # t259: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t261 = prims.mul(t259, t260) # t261: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t239 = prims.mul(t229, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t243 = prims.convert_element_type(t241, dtypes.float32) # t243: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t244 = prims.mul(t239, t243) # t244: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t245 = prims.convert_element_type(t244, dtypes.bfloat16) # t245: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t246 = torch.nn.functional.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t246 = ltorch.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t246 = prims.linear(t245, t4, 
None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t247 = torch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t247 = ltorch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t247 = prims.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t246\n", + " t248 = torch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t248 = ltorch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t248 = prims.transpose(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t247\n", + " (t249, t250, t251) = torch.split(t248, (1, 1, 1), 2)\n", + " # (t249, t250, t251) = ltorch.split(t248, (1, 1, 1), 2)\n", + " # t249 = prims.slice_prim(t248, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t249: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t250 = prims.slice_prim(t248, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t250: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t251 = prims.slice_prim(t248, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t251: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t248\n", + " t252 = torch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t252 = ltorch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t252 = prims.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t249\n", + " t253 = torch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t253 = ltorch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t253 = prims.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t250\n", - " t262 = torch.nn.functional.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t262 = ltorch.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t262 = prims.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", - " t263 = torch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t263 = ltorch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t263 = prims.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t262\n", - " t264 = torch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t264 = ltorch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t264 = prims.transpose(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t263\n", - " (t265, t266, t267) = torch.split(t264, (1, 1, 1), 2)\n", - " # (t265, t266, t267) = ltorch.split(t264, (1, 1, 1), 2)\n", - " # t265 = prims.slice_prim(t264, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t265: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t266 = prims.slice_prim(t264, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t266: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t267 = prims.slice_prim(t264, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t267: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t264\n", - " t268 = torch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t268 = ltorch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t268 = prims.reshape(t265, (1, 32, 512, 
128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t265\n", - " t269 = torch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t269 = ltorch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t269 = prims.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t266\n", - " t270 = torch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t270 = ltorch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t270 = prims.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t267\n", - " t271 = torch_slice_prim_impl(t268, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t271: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t281 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t291 = torch_slice_prim_impl(t268, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t291: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t268\n", - " t293 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t293: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t269\n", - " t272 = torch_slice_prim_impl(t271, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t272: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t273 = torch_slice_prim_impl(t271, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t282 = torch_slice_prim_impl(t281, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t282: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t283 = torch_slice_prim_impl(t281, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t283: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t274, t284] = nvFusion16(t273, t283)\n", + " t254 = torch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t254 = ltorch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t254 = prims.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t251\n", + " t285 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t285: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t287 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t287: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t255 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t255: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t252\n", + " t270 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t270: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t253\n", + " t256 = torch_slice_prim_impl(t255, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t256: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t257 = torch_slice_prim_impl(t255, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t257: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t272 = torch_slice_prim_impl(t270, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t272: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t271 = torch_slice_prim_impl(t270, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t271: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t260, t275] = nvFusion6(t255, t257, t270, t272)\n", + " # t258 = prims.convert_element_type(t257, dtypes.float32) # t258: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t259 = prims.neg(t258) # t259: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t273 = prims.convert_element_type(t272, 
dtypes.float32) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t274 = prims.neg(t273) # t274: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t284 = prims.neg(t283) # t284: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t273, t283\n", - " t275 = torch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t275 = ltorch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t275 = prims.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t274, t272\n", - " t285 = torch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t285 = ltorch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t285 = prims.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t284, t282\n", - " [t280, t290] = nvFusion17(t271, t275, t281, t285, t63, t65)\n", - " # t277 = prims.mul(t271, t63) # t277: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t279 = prims.mul(t275, t65) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t280 = prims.add(t277, t279) # t280: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t287 = prims.mul(t281, t63) # t287: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t289 = prims.mul(t285, t65) # t289: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t290 = prims.add(t287, t289) # t290: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t271, t275, t281, t285\n", - " t292 = torch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t292 = ltorch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t292 = prims.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t280, t291\n", - " t294 = torch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t294 = ltorch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t294 = prims.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t290, t293\n", - " (t295, t296, t297, t298) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t292, t294, t270, None, 0.0, True, 0.08838834764831843)\n", - " t299 = torch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t299 = ltorch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t299 = prims.transpose(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t300 = torch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t300 = ltorch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t300 = prims.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t299\n", - " t301 = torch.nn.functional.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t301 = ltorch.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t301 = prims.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t302, t308, t312] = nvFusion18(t251, t301, t311)\n", - " # t302 = prims.add(t301, t251) # t302: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t303 = prims.mul(t302, t302) # t303: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t304 = prims.sum(t303, (2,)) # t304: \"cuda:0 f32[1, 512]\"\n", - " # t305 = prims.broadcast_in_dim(t304, [1, 512, 1], [0, 1]) # t305: \"cuda:0 f32[1, 512, 1]\"\n", - " # t306 = prims.div(t305, 4096.0) # t306: \"cuda:0 f32[1, 512, 1]\"\n", - " # t307 = prims.add(t306, 1e-05) # t307: \"cuda:0 f32[1, 512, 1]\"\n", - " # t308 = prims.rsqrt(t307) # t308: \"cuda:0 f32[1, 512, 1]\"\n", - " # t309 = prims.broadcast_in_dim(t308, (1, 512, 4096), (0, 1, 2)) # 
t309: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t310 = prims.mul(t302, t309) # t310: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t312 = prims.mul(t310, t311) # t312: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t301\n", - " t314 = torch.nn.functional.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t314 = ltorch.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t314 = prims.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", - " t313 = torch.nn.functional.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t313 = ltorch.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t313 = prims.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t320] = nvFusion19(t313, t314)\n", - " # t315 = prims.neg(t313) # t315: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t316 = prims.exp(t315) # t316: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t317 = prims.add(1.0, t316) # t317: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t318 = prims.reciprocal(t317) # t318: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t319 = prims.mul(t313, t318) # t319: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t320 = prims.mul(t319, t314) # t320: \"cuda:0 f32[1, 512, 11008]\"\n", - " t321 = torch.nn.functional.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t321 = ltorch.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t321 = prims.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t322, t328, t332] = nvFusion20(t302, t321, t331)\n", - " # t322 = prims.add(t321, t302) # t322: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t323 = prims.mul(t322, t322) # t323: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t324 = prims.sum(t323, (2,)) # t324: \"cuda:0 f32[1, 512]\"\n", - " # t325 = prims.broadcast_in_dim(t324, [1, 512, 1], [0, 1]) # t325: \"cuda:0 f32[1, 512, 1]\"\n", - " # t326 = prims.div(t325, 4096.0) # t326: \"cuda:0 f32[1, 512, 1]\"\n", - " # t327 = prims.add(t326, 1e-05) # t327: \"cuda:0 f32[1, 512, 1]\"\n", - " # t328 = prims.rsqrt(t327) # t328: \"cuda:0 f32[1, 512, 1]\"\n", - " # t329 = prims.broadcast_in_dim(t328, (1, 512, 4096), (0, 1, 2)) # t329: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t330 = prims.mul(t322, t329) # t330: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t332 = prims.mul(t330, t331) # t332: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t321\n", - " t333 = torch.nn.functional.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", - " # t333 = ltorch.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", - " # t333 = prims.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", - " return {'output': t333, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33], 'flat_output': (t333,)}, ((t0, t10, t100, t101, t107, t109, t11, t115, t118, t119, t12, t128, t13, t14, t15, t150, t152, t153, t154, t155, t156, t158, t160, t166, t169, t170, t171, t172, t178, t180, t186, t189, t190, t199, t221, t223, t224, t225, t226, t227, t229, t231, t237, t240, t241, t242, t243, t249, t25, t251, t257, t26, t260, t261, t27, t270, t28, t29, t292, t294, t295, t296, t297, t298, t3, t30, t300, t302, t308, t31, t311, t312, t313, t314, t32, t320, t322, t328, t331, t332, t38, t4, t44, t47, t48, t5, t57, t6, t63, t65, t7, t79, t8, t81, t82, t83, t84, t85, t87, t89, t9, t95, t98, t99), (False, True, True, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 
0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2))" + " # t275 = prims.convert_element_type(t274, dtypes.bfloat16) # t275: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t257, t272\n", + " t261 = torch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t261 = ltorch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t261 = prims.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t260, t256\n", + " t276 = torch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t276 = ltorch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t276 = prims.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t275, t271\n", + " [t269, t284] = nvFusion7(t154, t157, t255, t261, t270, t276)\n", + " # t263 = prims.convert_element_type(t255, dtypes.float32) # t263: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t278 = prims.convert_element_type(t270, dtypes.float32) # t278: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t264 = prims.mul(t263, t154) # t264: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t266 = prims.convert_element_type(t261, dtypes.float32) # t266: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t267 = prims.mul(t266, t157) # t267: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t268 = prims.add(t264, t267) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t279 = prims.mul(t278, t154) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t281 = prims.convert_element_type(t276, dtypes.float32) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t282 = prims.mul(t281, t157) # t282: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t283 = prims.add(t279, t282) # t283: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t284 = prims.convert_element_type(t283, dtypes.bfloat16) # t284: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t255, t261, t270, t276\n", + " t288 = torch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t288 = ltorch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t288 = prims.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t284, t287\n", + " t286 = torch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t286 = ltorch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t286 = prims.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t269, t285\n", + " (t289, t290, t291, t292, _, _, t293, t294, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t286, t288, t254, 0.0, True, scale=0.08838834764831843)\n", + " t296 = torch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t296 = ltorch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t296 = prims.transpose(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t297 = torch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t297 = ltorch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t297 = prims.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t296\n", + " t298 = torch.nn.functional.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t298 = ltorch.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t298 = prims.linear(t297, 
t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t302, t309, t317] = nvFusion8(t230, t298, t313)\n", + " # t300 = prims.convert_element_type(t230, dtypes.float32) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t299 = prims.convert_element_type(t298, dtypes.float32) # t299: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t301 = prims.add(t299, t300) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t302 = prims.convert_element_type(t301, dtypes.bfloat16) # t302: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t304 = prims.mul(t301, t301) # t304: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t305 = prims.sum(t304, (2,)) # t305: \"cuda:0 f32[1, 512]\"\n", + " # t306 = prims.broadcast_in_dim(t305, [1, 512, 1], [0, 1]) # t306: \"cuda:0 f32[1, 512, 1]\"\n", + " # t307 = prims.div(t306, 4096.0) # t307: \"cuda:0 f32[1, 512, 1]\"\n", + " # t308 = prims.add(t307, 1e-05) # t308: \"cuda:0 f32[1, 512, 1]\"\n", + " # t309 = prims.rsqrt(t308) # t309: \"cuda:0 f32[1, 512, 1]\"\n", + " # t310 = prims.broadcast_in_dim(t309, (1, 512, 4096), (0, 1, 2)) # t310: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t311 = prims.mul(t301, t310) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t315 = prims.convert_element_type(t313, dtypes.float32) # t315: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t316 = prims.mul(t311, t315) # t316: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t317 = prims.convert_element_type(t316, dtypes.bfloat16) # t317: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t318 = torch.nn.functional.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t318 = ltorch.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t318 = prims.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t319 = torch.nn.functional.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t319 = ltorch.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t319 = prims.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t333] = nvFusion9(t318, t319)\n", + " # t320 = prims.convert_element_type(t318, dtypes.float32) # t320: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t321 = prims.neg(t320) # t321: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t322 = prims.exp(t321) # t322: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t323 = prims.add(1.0, t322) # t323: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t324 = prims.reciprocal(t323) # t324: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t328 = prims.mul(t320, t324) # t328: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t331 = prims.convert_element_type(t319, dtypes.float32) # t331: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t332 = prims.mul(t328, t331) # t332: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t333 = prims.convert_element_type(t332, dtypes.bfloat16) # t333: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t334 = torch.nn.functional.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t334 = ltorch.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t334 = prims.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t338, t345, t353] = nvFusion10(t302, t334, t349)\n", + " # t336 = prims.convert_element_type(t302, dtypes.float32) # t336: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t335 = prims.convert_element_type(t334, dtypes.float32) # t335: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t337 = prims.add(t335, t336) # t337: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t340 = prims.mul(t337, t337) # t340: \"cuda:0 f32[1, 512, 
4096]\"\n", + " # t341 = prims.sum(t340, (2,)) # t341: \"cuda:0 f32[1, 512]\"\n", + " # t342 = prims.broadcast_in_dim(t341, [1, 512, 1], [0, 1]) # t342: \"cuda:0 f32[1, 512, 1]\"\n", + " # t343 = prims.div(t342, 4096.0) # t343: \"cuda:0 f32[1, 512, 1]\"\n", + " # t344 = prims.add(t343, 1e-05) # t344: \"cuda:0 f32[1, 512, 1]\"\n", + " # t345 = prims.rsqrt(t344) # t345: \"cuda:0 f32[1, 512, 1]\"\n", + " # t346 = prims.broadcast_in_dim(t345, (1, 512, 4096), (0, 1, 2)) # t346: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t347 = prims.mul(t337, t346) # t347: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t351 = prims.convert_element_type(t349, dtypes.float32) # t351: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t352 = prims.mul(t347, t351) # t352: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t354 = torch.nn.functional.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t354 = ltorch.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t354 = prims.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t355 = torch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t355 = ltorch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t355 = prims.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t354\n", + " t356 = torch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t356 = ltorch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t356 = prims.transpose(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t355\n", + " (t357, t358, t359) = torch.split(t356, (1, 1, 1), 2)\n", + " # (t357, t358, t359) = ltorch.split(t356, (1, 1, 1), 2)\n", + " # t357 = prims.slice_prim(t356, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t357: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t358 = prims.slice_prim(t356, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t358: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t359 = prims.slice_prim(t356, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t359: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t356\n", + " t360 = torch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t360 = ltorch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t360 = prims.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t357\n", + " t361 = torch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t361 = ltorch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t361 = prims.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t358\n", + " t362 = torch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t362 = ltorch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t362 = prims.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t359\n", + " t363 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t363: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t378 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t378: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t393 = 
torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t393: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t360\n", + " t395 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t395: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t361\n", + " t364 = torch_slice_prim_impl(t363, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t364: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t365 = torch_slice_prim_impl(t363, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t365: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t379 = torch_slice_prim_impl(t378, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t379: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t380 = torch_slice_prim_impl(t378, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t380: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t368, t383] = nvFusion11(t363, t365, t378, t380)\n", + " # t366 = prims.convert_element_type(t365, dtypes.float32) # t366: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t367 = prims.neg(t366) # t367: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t368 = prims.convert_element_type(t367, dtypes.bfloat16) # t368: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t381 = prims.convert_element_type(t380, dtypes.float32) # t381: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t382 = prims.neg(t381) # t382: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t383 = prims.convert_element_type(t382, dtypes.bfloat16) # t383: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t365, t380\n", + " t369 = torch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t369 = ltorch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t369 = prims.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t368, t364\n", + " t384 = torch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t384 = ltorch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t384 = prims.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t383, t379\n", + " [t377, t392] = nvFusion12(t154, t157, t363, t369, t378, t384)\n", + " # t371 = prims.convert_element_type(t363, dtypes.float32) # t371: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t386 = prims.convert_element_type(t378, dtypes.float32) # t386: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t372 = prims.mul(t371, t154) # t372: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t374 = prims.convert_element_type(t369, dtypes.float32) # t374: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t375 = prims.mul(t374, t157) # t375: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t376 = prims.add(t372, t375) # t376: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t377 = prims.convert_element_type(t376, dtypes.bfloat16) # t377: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t387 = prims.mul(t386, t154) # t387: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t389 = prims.convert_element_type(t384, dtypes.float32) # t389: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t390 = prims.mul(t389, t157) # t390: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t391 = prims.add(t387, t390) # t391: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t363, t369, t378, t384\n", + " t394 = torch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t394 = ltorch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t394 = prims.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t377, t393\n", + " t396 = torch.cat((t392, t395), -1) # t396: 
\"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t396 = ltorch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t396 = prims.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t392, t395\n", + " (t397, t398, t399, t400, _, _, t401, t402, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t394, t396, t362, 0.0, True, scale=0.08838834764831843)\n", + " t404 = torch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t404 = ltorch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t404 = prims.transpose(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t405 = torch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t405 = ltorch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t405 = prims.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t404\n", + " t406 = torch.nn.functional.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t406 = ltorch.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t406 = prims.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t410, t417, t425] = nvFusion13(t338, t406, t421)\n", + " # t408 = prims.convert_element_type(t338, dtypes.float32) # t408: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t407 = prims.convert_element_type(t406, dtypes.float32) # t407: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t409 = prims.add(t407, t408) # t409: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t410 = prims.convert_element_type(t409, dtypes.bfloat16) # t410: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t412 = prims.mul(t409, t409) # t412: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t413 = prims.sum(t412, (2,)) # t413: \"cuda:0 f32[1, 512]\"\n", + " # t414 = prims.broadcast_in_dim(t413, [1, 512, 1], [0, 1]) # t414: \"cuda:0 f32[1, 512, 1]\"\n", + " # t415 = prims.div(t414, 4096.0) # t415: \"cuda:0 f32[1, 512, 1]\"\n", + " # t416 = prims.add(t415, 1e-05) # t416: \"cuda:0 f32[1, 512, 1]\"\n", + " # t417 = prims.rsqrt(t416) # t417: \"cuda:0 f32[1, 512, 1]\"\n", + " # t418 = prims.broadcast_in_dim(t417, (1, 512, 4096), (0, 1, 2)) # t418: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t419 = prims.mul(t409, t418) # t419: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t423 = prims.convert_element_type(t421, dtypes.float32) # t423: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t424 = prims.mul(t419, t423) # t424: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t425 = prims.convert_element_type(t424, dtypes.bfloat16) # t425: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t426 = torch.nn.functional.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t426 = ltorch.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t426 = prims.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t427 = torch.nn.functional.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t427 = ltorch.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t427 = prims.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t441] = nvFusion14(t426, t427)\n", + " # t428 = prims.convert_element_type(t426, dtypes.float32) # t428: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t429 = prims.neg(t428) # t429: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t430 = prims.exp(t429) # t430: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t431 = prims.add(1.0, t430) # t431: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t432 = 
prims.reciprocal(t431) # t432: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t436 = prims.mul(t428, t432) # t436: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t439 = prims.convert_element_type(t427, dtypes.float32) # t439: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t440 = prims.mul(t436, t439) # t440: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t441 = prims.convert_element_type(t440, dtypes.bfloat16) # t441: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t442 = torch.nn.functional.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t442 = ltorch.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t442 = prims.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t446, t453, t461] = nvFusion15(t410, t442, t457)\n", + " # t444 = prims.convert_element_type(t410, dtypes.float32) # t444: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t443 = prims.convert_element_type(t442, dtypes.float32) # t443: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t445 = prims.add(t443, t444) # t445: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t446 = prims.convert_element_type(t445, dtypes.bfloat16) # t446: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t448 = prims.mul(t445, t445) # t448: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t449 = prims.sum(t448, (2,)) # t449: \"cuda:0 f32[1, 512]\"\n", + " # t450 = prims.broadcast_in_dim(t449, [1, 512, 1], [0, 1]) # t450: \"cuda:0 f32[1, 512, 1]\"\n", + " # t451 = prims.div(t450, 4096.0) # t451: \"cuda:0 f32[1, 512, 1]\"\n", + " # t452 = prims.add(t451, 1e-05) # t452: \"cuda:0 f32[1, 512, 1]\"\n", + " # t453 = prims.rsqrt(t452) # t453: \"cuda:0 f32[1, 512, 1]\"\n", + " # t454 = prims.broadcast_in_dim(t453, (1, 512, 4096), (0, 1, 2)) # t454: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t455 = prims.mul(t445, t454) # t455: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t459 = prims.convert_element_type(t457, dtypes.float32) # t459: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t460 = prims.mul(t455, t459) # t460: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t461 = prims.convert_element_type(t460, dtypes.bfloat16) # t461: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t462 = torch.nn.functional.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t462 = ltorch.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t462 = prims.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t463 = torch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t463 = ltorch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t463 = prims.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t462\n", + " t464 = torch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t464 = ltorch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t464 = prims.transpose(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t463\n", + " (t465, t466, t467) = torch.split(t464, (1, 1, 1), 2)\n", + " # (t465, t466, t467) = ltorch.split(t464, (1, 1, 1), 2)\n", + " # t465 = prims.slice_prim(t464, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t465: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t466 = prims.slice_prim(t464, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t466: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t467 = prims.slice_prim(t464, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t467: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t464\n", + " t468 = 
torch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t468 = ltorch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t468 = prims.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t465\n", + " t469 = torch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t469 = ltorch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t469 = prims.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t466\n", + " t470 = torch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t470 = ltorch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t470 = prims.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t467\n", + " t471 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t471: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t486 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t486: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t501 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t501: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t468\n", + " t503 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t503: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t469\n", + " t472 = torch_slice_prim_impl(t471, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t472: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t473 = torch_slice_prim_impl(t471, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t473: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t487 = torch_slice_prim_impl(t486, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t487: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t488 = torch_slice_prim_impl(t486, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t488: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t476, t491] = nvFusion16(t471, t473, t486, t488)\n", + " # t474 = prims.convert_element_type(t473, dtypes.float32) # t474: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t475 = prims.neg(t474) # t475: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t476 = prims.convert_element_type(t475, dtypes.bfloat16) # t476: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t489 = prims.convert_element_type(t488, dtypes.float32) # t489: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t490 = prims.neg(t489) # t490: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t491 = prims.convert_element_type(t490, dtypes.bfloat16) # t491: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t473, t488\n", + " t477 = torch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t477 = ltorch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t477 = prims.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t476, t472\n", + " t492 = torch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t492 = ltorch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t492 = prims.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t491, t487\n", + " [t485, t500] = nvFusion17(t154, t157, t471, t477, t486, t492)\n", + " # t479 = prims.convert_element_type(t471, dtypes.float32) # t479: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t494 = prims.convert_element_type(t486, dtypes.float32) # t494: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t480 = prims.mul(t479, t154) # t480: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " 
# t482 = prims.convert_element_type(t477, dtypes.float32) # t482: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t483 = prims.mul(t482, t157) # t483: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t484 = prims.add(t480, t483) # t484: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t485 = prims.convert_element_type(t484, dtypes.bfloat16) # t485: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t495 = prims.mul(t494, t154) # t495: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t497 = prims.convert_element_type(t492, dtypes.float32) # t497: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t498 = prims.mul(t497, t157) # t498: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t499 = prims.add(t495, t498) # t499: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t500 = prims.convert_element_type(t499, dtypes.bfloat16) # t500: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t471, t477, t486, t492\n", + " t502 = torch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t502 = ltorch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t502 = prims.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t485, t501\n", + " t504 = torch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t504 = ltorch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t504 = prims.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t500, t503\n", + " (t505, t506, t507, t508, _, _, t509, t510, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t502, t504, t470, 0.0, True, scale=0.08838834764831843)\n", + " t512 = torch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t512 = ltorch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t512 = prims.transpose(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t513 = torch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t513 = ltorch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t513 = prims.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t512\n", + " t514 = torch.nn.functional.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t514 = ltorch.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t514 = prims.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t518, t525, t533] = nvFusion18(t446, t514, t529)\n", + " # t516 = prims.convert_element_type(t446, dtypes.float32) # t516: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t515 = prims.convert_element_type(t514, dtypes.float32) # t515: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t517 = prims.add(t515, t516) # t517: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t518 = prims.convert_element_type(t517, dtypes.bfloat16) # t518: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t520 = prims.mul(t517, t517) # t520: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t521 = prims.sum(t520, (2,)) # t521: \"cuda:0 f32[1, 512]\"\n", + " # t522 = prims.broadcast_in_dim(t521, [1, 512, 1], [0, 1]) # t522: \"cuda:0 f32[1, 512, 1]\"\n", + " # t523 = prims.div(t522, 4096.0) # t523: \"cuda:0 f32[1, 512, 1]\"\n", + " # t524 = prims.add(t523, 1e-05) # t524: \"cuda:0 f32[1, 512, 1]\"\n", + " # t525 = prims.rsqrt(t524) # t525: \"cuda:0 f32[1, 512, 1]\"\n", + " # t526 = prims.broadcast_in_dim(t525, (1, 512, 4096), (0, 1, 2)) # t526: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t527 = prims.mul(t517, t526) # t527: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t531 = 
prims.convert_element_type(t529, dtypes.float32) # t531: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t532 = prims.mul(t527, t531) # t532: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t534 = torch.nn.functional.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t534 = ltorch.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t534 = prims.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t535 = torch.nn.functional.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t535 = ltorch.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t535 = prims.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t549] = nvFusion19(t534, t535)\n", + " # t536 = prims.convert_element_type(t534, dtypes.float32) # t536: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t537 = prims.neg(t536) # t537: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t538 = prims.exp(t537) # t538: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t539 = prims.add(1.0, t538) # t539: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t540 = prims.reciprocal(t539) # t540: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t544 = prims.mul(t536, t540) # t544: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t547 = prims.convert_element_type(t535, dtypes.float32) # t547: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t548 = prims.mul(t544, t547) # t548: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t549 = prims.convert_element_type(t548, dtypes.bfloat16) # t549: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t550 = torch.nn.functional.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t550 = ltorch.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t550 = prims.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t554, t561, t569] = nvFusion20(t518, t550, t565)\n", + " # t552 = prims.convert_element_type(t518, dtypes.float32) # t552: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t551 = prims.convert_element_type(t550, dtypes.float32) # t551: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t553 = prims.add(t551, t552) # t553: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t556 = prims.mul(t553, t553) # t556: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t557 = prims.sum(t556, (2,)) # t557: \"cuda:0 f32[1, 512]\"\n", + " # t558 = prims.broadcast_in_dim(t557, [1, 512, 1], [0, 1]) # t558: \"cuda:0 f32[1, 512, 1]\"\n", + " # t559 = prims.div(t558, 4096.0) # t559: \"cuda:0 f32[1, 512, 1]\"\n", + " # t560 = prims.add(t559, 1e-05) # t560: \"cuda:0 f32[1, 512, 1]\"\n", + " # t561 = prims.rsqrt(t560) # t561: \"cuda:0 f32[1, 512, 1]\"\n", + " # t562 = prims.broadcast_in_dim(t561, (1, 512, 4096), (0, 1, 2)) # t562: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t563 = prims.mul(t553, t562) # t563: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t567 = prims.convert_element_type(t565, dtypes.float32) # t567: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t568 = prims.mul(t563, t567) # t568: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t569 = prims.convert_element_type(t568, dtypes.bfloat16) # t569: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t570 = torch.nn.functional.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t570 = ltorch.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t570 = prims.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t571 = 
torch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t571 = ltorch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t571 = prims.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t570\n", + " t572 = torch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t572 = ltorch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t572 = prims.transpose(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t571\n", + " (t573, t574, t575) = torch.split(t572, (1, 1, 1), 2)\n", + " # (t573, t574, t575) = ltorch.split(t572, (1, 1, 1), 2)\n", + " # t573 = prims.slice_prim(t572, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t573: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t574 = prims.slice_prim(t572, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t574: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t575 = prims.slice_prim(t572, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t575: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t572\n", + " t576 = torch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t576 = ltorch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t576 = prims.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t573\n", + " t577 = torch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t577 = ltorch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t577 = prims.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t574\n", + " t578 = torch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t578 = ltorch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t578 = prims.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t575\n", + " t579 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t579: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t594 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t594: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t609 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t609: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t576\n", + " t611 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t611: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t577\n", + " t580 = torch_slice_prim_impl(t579, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t580: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t581 = torch_slice_prim_impl(t579, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t581: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t595 = torch_slice_prim_impl(t594, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t595: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t596 = torch_slice_prim_impl(t594, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t596: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t584, t599] = nvFusion21(t579, t581, t594, t596)\n", + " # t582 = prims.convert_element_type(t581, dtypes.float32) # t582: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t583 = prims.neg(t582) # t583: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t584 = prims.convert_element_type(t583, dtypes.bfloat16) # t584: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t597 = 
prims.convert_element_type(t596, dtypes.float32) # t597: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t598 = prims.neg(t597) # t598: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t599 = prims.convert_element_type(t598, dtypes.bfloat16) # t599: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t581, t596\n", + " t600 = torch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t600 = ltorch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t600 = prims.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t599, t595\n", + " t585 = torch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t585 = ltorch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t585 = prims.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t584, t580\n", + " [t593, t608] = nvFusion22(t154, t157, t579, t585, t594, t600)\n", + " # t587 = prims.convert_element_type(t579, dtypes.float32) # t587: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t602 = prims.convert_element_type(t594, dtypes.float32) # t602: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t603 = prims.mul(t602, t154) # t603: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t605 = prims.convert_element_type(t600, dtypes.float32) # t605: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t606 = prims.mul(t605, t157) # t606: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t607 = prims.add(t603, t606) # t607: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t608 = prims.convert_element_type(t607, dtypes.bfloat16) # t608: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t588 = prims.mul(t587, t154) # t588: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t590 = prims.convert_element_type(t585, dtypes.float32) # t590: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t591 = prims.mul(t590, t157) # t591: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t592 = prims.add(t588, t591) # t592: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t593 = prims.convert_element_type(t592, dtypes.bfloat16) # t593: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t579, t585, t594, t600\n", + " t612 = torch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t612 = ltorch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t612 = prims.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t608, t611\n", + " t610 = torch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t610 = ltorch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t610 = prims.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t593, t609\n", + " (t613, t614, t615, t616, _, _, t617, t618, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t610, t612, t578, 0.0, True, scale=0.08838834764831843)\n", + " t620 = torch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t620 = ltorch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t620 = prims.transpose(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t621 = torch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t621 = ltorch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t621 = prims.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t620\n", + " t622 = torch.nn.functional.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t622 = ltorch.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + 
" # t622 = prims.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t626, t633, t641] = nvFusion23(t554, t622, t637)\n", + " # t624 = prims.convert_element_type(t554, dtypes.float32) # t624: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t623 = prims.convert_element_type(t622, dtypes.float32) # t623: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t625 = prims.add(t623, t624) # t625: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t626 = prims.convert_element_type(t625, dtypes.bfloat16) # t626: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t628 = prims.mul(t625, t625) # t628: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t629 = prims.sum(t628, (2,)) # t629: \"cuda:0 f32[1, 512]\"\n", + " # t630 = prims.broadcast_in_dim(t629, [1, 512, 1], [0, 1]) # t630: \"cuda:0 f32[1, 512, 1]\"\n", + " # t631 = prims.div(t630, 4096.0) # t631: \"cuda:0 f32[1, 512, 1]\"\n", + " # t632 = prims.add(t631, 1e-05) # t632: \"cuda:0 f32[1, 512, 1]\"\n", + " # t633 = prims.rsqrt(t632) # t633: \"cuda:0 f32[1, 512, 1]\"\n", + " # t634 = prims.broadcast_in_dim(t633, (1, 512, 4096), (0, 1, 2)) # t634: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t635 = prims.mul(t625, t634) # t635: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t639 = prims.convert_element_type(t637, dtypes.float32) # t639: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t640 = prims.mul(t635, t639) # t640: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t641 = prims.convert_element_type(t640, dtypes.bfloat16) # t641: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t643 = torch.nn.functional.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t643 = ltorch.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t643 = prims.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t642 = torch.nn.functional.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t642 = ltorch.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t642 = prims.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t657] = nvFusion24(t642, t643)\n", + " # t644 = prims.convert_element_type(t642, dtypes.float32) # t644: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t645 = prims.neg(t644) # t645: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t646 = prims.exp(t645) # t646: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t647 = prims.add(1.0, t646) # t647: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t648 = prims.reciprocal(t647) # t648: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t652 = prims.mul(t644, t648) # t652: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t655 = prims.convert_element_type(t643, dtypes.float32) # t655: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t656 = prims.mul(t652, t655) # t656: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t657 = prims.convert_element_type(t656, dtypes.bfloat16) # t657: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t658 = torch.nn.functional.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t658 = ltorch.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t658 = prims.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t662, t669, t677] = nvFusion25(t626, t658, t673)\n", + " # t660 = prims.convert_element_type(t626, dtypes.float32) # t660: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t659 = prims.convert_element_type(t658, dtypes.float32) # t659: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t661 = prims.add(t659, t660) # t661: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t662 = prims.convert_element_type(t661, dtypes.bfloat16) # t662: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t664 = prims.mul(t661, 
t661) # t664: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t665 = prims.sum(t664, (2,)) # t665: \"cuda:0 f32[1, 512]\"\n", + " # t666 = prims.broadcast_in_dim(t665, [1, 512, 1], [0, 1]) # t666: \"cuda:0 f32[1, 512, 1]\"\n", + " # t667 = prims.div(t666, 4096.0) # t667: \"cuda:0 f32[1, 512, 1]\"\n", + " # t668 = prims.add(t667, 1e-05) # t668: \"cuda:0 f32[1, 512, 1]\"\n", + " # t669 = prims.rsqrt(t668) # t669: \"cuda:0 f32[1, 512, 1]\"\n", + " # t670 = prims.broadcast_in_dim(t669, (1, 512, 4096), (0, 1, 2)) # t670: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t671 = prims.mul(t661, t670) # t671: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t675 = prims.convert_element_type(t673, dtypes.float32) # t675: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t676 = prims.mul(t671, t675) # t676: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t677 = prims.convert_element_type(t676, dtypes.bfloat16) # t677: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t678 = torch.nn.functional.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t678 = ltorch.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t678 = prims.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t679 = torch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t679 = ltorch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t679 = prims.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t678\n", + " t680 = torch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t680 = ltorch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t680 = prims.transpose(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t679\n", + " (t681, t682, t683) = torch.split(t680, (1, 1, 1), 2)\n", + " # (t681, t682, t683) = ltorch.split(t680, (1, 1, 1), 2)\n", + " # t681 = prims.slice_prim(t680, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t681: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t682 = prims.slice_prim(t680, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t682: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t683 = prims.slice_prim(t680, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t683: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t680\n", + " t684 = torch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t684 = ltorch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t684 = prims.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t681\n", + " t685 = torch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t685 = ltorch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t685 = prims.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t682\n", + " t686 = torch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t686 = ltorch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t686 = prims.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t683\n", + " t687 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t687: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t702 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t702: \"cuda:0 bf16[1, 32, 512, 
128]\"\n", + " t717 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t717: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t684\n", + " t719 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t719: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t685\n", + " t688 = torch_slice_prim_impl(t687, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t688: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t689 = torch_slice_prim_impl(t687, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t689: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t703 = torch_slice_prim_impl(t702, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t703: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t704 = torch_slice_prim_impl(t702, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t704: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t692, t707] = nvFusion26(t687, t689, t702, t704)\n", + " # t690 = prims.convert_element_type(t689, dtypes.float32) # t690: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t691 = prims.neg(t690) # t691: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t692 = prims.convert_element_type(t691, dtypes.bfloat16) # t692: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t705 = prims.convert_element_type(t704, dtypes.float32) # t705: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t706 = prims.neg(t705) # t706: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t707 = prims.convert_element_type(t706, dtypes.bfloat16) # t707: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t689, t704\n", + " t708 = torch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t708 = ltorch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t708 = prims.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t707, t703\n", + " t693 = torch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t693 = ltorch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t693 = prims.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t692, t688\n", + " [t701, t716] = nvFusion27(t154, t157, t687, t693, t702, t708)\n", + " # t695 = prims.convert_element_type(t687, dtypes.float32) # t695: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t710 = prims.convert_element_type(t702, dtypes.float32) # t710: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t711 = prims.mul(t710, t154) # t711: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t713 = prims.convert_element_type(t708, dtypes.float32) # t713: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t714 = prims.mul(t713, t157) # t714: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t715 = prims.add(t711, t714) # t715: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t716 = prims.convert_element_type(t715, dtypes.bfloat16) # t716: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t696 = prims.mul(t695, t154) # t696: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t698 = prims.convert_element_type(t693, dtypes.float32) # t698: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t699 = prims.mul(t698, t157) # t699: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t700 = prims.add(t696, t699) # t700: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t701 = prims.convert_element_type(t700, dtypes.bfloat16) # t701: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t687, t693, t702, t708\n", + " t720 = torch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t720 = ltorch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t720 = prims.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t716, t719\n", + " t718 = torch.cat((t701, 
t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t718 = ltorch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t718 = prims.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t701, t717\n", + " (t721, t722, t723, t724, _, _, t725, t726, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t718, t720, t686, 0.0, True, scale=0.08838834764831843)\n", + " t728 = torch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t728 = ltorch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t728 = prims.transpose(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t729 = torch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t729 = ltorch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t729 = prims.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t728\n", + " t730 = torch.nn.functional.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t730 = ltorch.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t730 = prims.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t734, t741, t749] = nvFusion28(t662, t730, t745)\n", + " # t732 = prims.convert_element_type(t662, dtypes.float32) # t732: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t731 = prims.convert_element_type(t730, dtypes.float32) # t731: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t733 = prims.add(t731, t732) # t733: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t734 = prims.convert_element_type(t733, dtypes.bfloat16) # t734: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t736 = prims.mul(t733, t733) # t736: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t737 = prims.sum(t736, (2,)) # t737: \"cuda:0 f32[1, 512]\"\n", + " # t738 = prims.broadcast_in_dim(t737, [1, 512, 1], [0, 1]) # t738: \"cuda:0 f32[1, 512, 1]\"\n", + " # t739 = prims.div(t738, 4096.0) # t739: \"cuda:0 f32[1, 512, 1]\"\n", + " # t740 = prims.add(t739, 1e-05) # t740: \"cuda:0 f32[1, 512, 1]\"\n", + " # t741 = prims.rsqrt(t740) # t741: \"cuda:0 f32[1, 512, 1]\"\n", + " # t742 = prims.broadcast_in_dim(t741, (1, 512, 4096), (0, 1, 2)) # t742: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t743 = prims.mul(t733, t742) # t743: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t747 = prims.convert_element_type(t745, dtypes.float32) # t747: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t748 = prims.mul(t743, t747) # t748: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t749 = prims.convert_element_type(t748, dtypes.bfloat16) # t749: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t750 = torch.nn.functional.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t750 = ltorch.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t750 = prims.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t751 = torch.nn.functional.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t751 = ltorch.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t751 = prims.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t765] = nvFusion29(t750, t751)\n", + " # t752 = prims.convert_element_type(t750, dtypes.float32) # t752: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t753 = prims.neg(t752) # t753: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t754 = prims.exp(t753) # t754: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t755 = prims.add(1.0, t754) # t755: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t756 = 
prims.reciprocal(t755) # t756: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t760 = prims.mul(t752, t756) # t760: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t763 = prims.convert_element_type(t751, dtypes.float32) # t763: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t764 = prims.mul(t760, t763) # t764: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t765 = prims.convert_element_type(t764, dtypes.bfloat16) # t765: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t766 = torch.nn.functional.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t766 = ltorch.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t766 = prims.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t770, t777, t785] = nvFusion30(t734, t766, t781)\n", + " # t768 = prims.convert_element_type(t734, dtypes.float32) # t768: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t767 = prims.convert_element_type(t766, dtypes.float32) # t767: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t769 = prims.add(t767, t768) # t769: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t770 = prims.convert_element_type(t769, dtypes.bfloat16) # t770: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t772 = prims.mul(t769, t769) # t772: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t773 = prims.sum(t772, (2,)) # t773: \"cuda:0 f32[1, 512]\"\n", + " # t774 = prims.broadcast_in_dim(t773, [1, 512, 1], [0, 1]) # t774: \"cuda:0 f32[1, 512, 1]\"\n", + " # t775 = prims.div(t774, 4096.0) # t775: \"cuda:0 f32[1, 512, 1]\"\n", + " # t776 = prims.add(t775, 1e-05) # t776: \"cuda:0 f32[1, 512, 1]\"\n", + " # t777 = prims.rsqrt(t776) # t777: \"cuda:0 f32[1, 512, 1]\"\n", + " # t778 = prims.broadcast_in_dim(t777, (1, 512, 4096), (0, 1, 2)) # t778: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t779 = prims.mul(t769, t778) # t779: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t783 = prims.convert_element_type(t781, dtypes.float32) # t783: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t784 = prims.mul(t779, t783) # t784: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t785 = prims.convert_element_type(t784, dtypes.bfloat16) # t785: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t786 = torch.nn.functional.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t786 = ltorch.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t786 = prims.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t787 = torch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t787 = ltorch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t787 = prims.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t786\n", + " t788 = torch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t788 = ltorch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t788 = prims.transpose(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t787\n", + " (t789, t790, t791) = torch.split(t788, (1, 1, 1), 2)\n", + " # (t789, t790, t791) = ltorch.split(t788, (1, 1, 1), 2)\n", + " # t789 = prims.slice_prim(t788, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t789: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t790 = prims.slice_prim(t788, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t790: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t791 = prims.slice_prim(t788, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t791: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t788\n", + " t792 = 
torch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t792 = ltorch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t792 = prims.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t789\n", + " t793 = torch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t793 = ltorch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t793 = prims.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t790\n", + " t794 = torch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t794 = ltorch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t794 = prims.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t791\n", + " t795 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t795: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t810 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t810: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t825 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t825: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t792\n", + " t827 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t827: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t793\n", + " t796 = torch_slice_prim_impl(t795, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t796: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t797 = torch_slice_prim_impl(t795, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t797: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t811 = torch_slice_prim_impl(t810, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t811: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t812 = torch_slice_prim_impl(t810, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t812: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t800, t815] = nvFusion31(t795, t797, t810, t812)\n", + " # t798 = prims.convert_element_type(t797, dtypes.float32) # t798: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t799 = prims.neg(t798) # t799: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t800 = prims.convert_element_type(t799, dtypes.bfloat16) # t800: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t813 = prims.convert_element_type(t812, dtypes.float32) # t813: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t814 = prims.neg(t813) # t814: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t815 = prims.convert_element_type(t814, dtypes.bfloat16) # t815: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t797, t812\n", + " t816 = torch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t816 = ltorch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t816 = prims.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t815, t811\n", + " t801 = torch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t801 = ltorch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t801 = prims.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t800, t796\n", + " [t809, t824] = nvFusion32(t154, t157, t795, t801, t810, t816)\n", + " # t803 = prims.convert_element_type(t795, dtypes.float32) # t803: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t818 = prims.convert_element_type(t810, dtypes.float32) # t818: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t819 = prims.mul(t818, t154) # t819: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " 
# t821 = prims.convert_element_type(t816, dtypes.float32) # t821: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t822 = prims.mul(t821, t157) # t822: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t823 = prims.add(t819, t822) # t823: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t824 = prims.convert_element_type(t823, dtypes.bfloat16) # t824: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t804 = prims.mul(t803, t154) # t804: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t806 = prims.convert_element_type(t801, dtypes.float32) # t806: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t807 = prims.mul(t806, t157) # t807: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t808 = prims.add(t804, t807) # t808: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t809 = prims.convert_element_type(t808, dtypes.bfloat16) # t809: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t795, t801, t810, t816\n", + " t828 = torch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t828 = ltorch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t828 = prims.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t824, t827\n", + " t826 = torch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t826 = ltorch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t826 = prims.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t809, t825\n", + " (t829, t830, t831, t832, _, _, t833, t834, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t826, t828, t794, 0.0, True, scale=0.08838834764831843)\n", + " t836 = torch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t836 = ltorch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t836 = prims.transpose(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t837 = torch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t837 = ltorch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t837 = prims.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t836\n", + " t838 = torch.nn.functional.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t838 = ltorch.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t838 = prims.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t842, t849, t857] = nvFusion33(t770, t838, t853)\n", + " # t840 = prims.convert_element_type(t770, dtypes.float32) # t840: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t839 = prims.convert_element_type(t838, dtypes.float32) # t839: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t841 = prims.add(t839, t840) # t841: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t842 = prims.convert_element_type(t841, dtypes.bfloat16) # t842: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t844 = prims.mul(t841, t841) # t844: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t845 = prims.sum(t844, (2,)) # t845: \"cuda:0 f32[1, 512]\"\n", + " # t846 = prims.broadcast_in_dim(t845, [1, 512, 1], [0, 1]) # t846: \"cuda:0 f32[1, 512, 1]\"\n", + " # t847 = prims.div(t846, 4096.0) # t847: \"cuda:0 f32[1, 512, 1]\"\n", + " # t848 = prims.add(t847, 1e-05) # t848: \"cuda:0 f32[1, 512, 1]\"\n", + " # t849 = prims.rsqrt(t848) # t849: \"cuda:0 f32[1, 512, 1]\"\n", + " # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t851 = prims.mul(t841, t850) # t851: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t855 = 
prims.convert_element_type(t853, dtypes.float32) # t855: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t856 = prims.mul(t851, t855) # t856: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t857 = prims.convert_element_type(t856, dtypes.bfloat16) # t857: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t858 = torch.nn.functional.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t858 = ltorch.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t858 = prims.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t859 = torch.nn.functional.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t859 = ltorch.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t859 = prims.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t873] = nvFusion34(t858, t859)\n", + " # t860 = prims.convert_element_type(t858, dtypes.float32) # t860: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t861 = prims.neg(t860) # t861: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t862 = prims.exp(t861) # t862: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t863 = prims.add(1.0, t862) # t863: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t864 = prims.reciprocal(t863) # t864: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t868 = prims.mul(t860, t864) # t868: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t871 = prims.convert_element_type(t859, dtypes.float32) # t871: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t872 = prims.mul(t868, t871) # t872: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t873 = prims.convert_element_type(t872, dtypes.bfloat16) # t873: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t874 = torch.nn.functional.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t874 = ltorch.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t874 = prims.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t878, t885, t893] = nvFusion35(t842, t874, t889)\n", + " # t876 = prims.convert_element_type(t842, dtypes.float32) # t876: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t875 = prims.convert_element_type(t874, dtypes.float32) # t875: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t877 = prims.add(t875, t876) # t877: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t878 = prims.convert_element_type(t877, dtypes.bfloat16) # t878: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t880 = prims.mul(t877, t877) # t880: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t881 = prims.sum(t880, (2,)) # t881: \"cuda:0 f32[1, 512]\"\n", + " # t882 = prims.broadcast_in_dim(t881, [1, 512, 1], [0, 1]) # t882: \"cuda:0 f32[1, 512, 1]\"\n", + " # t883 = prims.div(t882, 4096.0) # t883: \"cuda:0 f32[1, 512, 1]\"\n", + " # t884 = prims.add(t883, 1e-05) # t884: \"cuda:0 f32[1, 512, 1]\"\n", + " # t885 = prims.rsqrt(t884) # t885: \"cuda:0 f32[1, 512, 1]\"\n", + " # t886 = prims.broadcast_in_dim(t885, (1, 512, 4096), (0, 1, 2)) # t886: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t887 = prims.mul(t877, t886) # t887: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t891 = prims.convert_element_type(t889, dtypes.float32) # t891: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t892 = prims.mul(t887, t891) # t892: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t893 = prims.convert_element_type(t892, dtypes.bfloat16) # t893: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t894 = torch.nn.functional.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t894 = ltorch.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t894 = prims.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t895 = 
torch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t895 = ltorch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t895 = prims.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t894\n", + " t896 = torch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t896 = ltorch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t896 = prims.transpose(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t895\n", + " (t897, t898, t899) = torch.split(t896, (1, 1, 1), 2)\n", + " # (t897, t898, t899) = ltorch.split(t896, (1, 1, 1), 2)\n", + " # t897 = prims.slice_prim(t896, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t897: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t898 = prims.slice_prim(t896, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t898: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t899 = prims.slice_prim(t896, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t899: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t896\n", + " t900 = torch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t900 = ltorch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t900 = prims.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t897\n", + " t901 = torch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t901 = ltorch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t901 = prims.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t898\n", + " t902 = torch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t902 = ltorch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t902 = prims.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t899\n", + " t935 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t935: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t903 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t903: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t918 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t901\n", + " t933 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t933: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t900\n", + " t904 = torch_slice_prim_impl(t903, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t904: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t905 = torch_slice_prim_impl(t903, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t905: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t919 = torch_slice_prim_impl(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t920 = torch_slice_prim_impl(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t908, t923] = nvFusion36(t903, t905, t918, t920)\n", + " # t906 = prims.convert_element_type(t905, dtypes.float32) # t906: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t907 = prims.neg(t906) # t907: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t908 = prims.convert_element_type(t907, dtypes.bfloat16) # t908: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t921 = 
prims.convert_element_type(t920, dtypes.float32) # t921: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t922 = prims.neg(t921) # t922: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t905, t920\n", + " t924 = torch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t924 = ltorch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t924 = prims.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t923, t919\n", + " t909 = torch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t909 = ltorch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t909 = prims.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t908, t904\n", + " [t917, t932] = nvFusion37(t154, t157, t903, t909, t918, t924)\n", + " # t911 = prims.convert_element_type(t903, dtypes.float32) # t911: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t926 = prims.convert_element_type(t918, dtypes.float32) # t926: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t927 = prims.mul(t926, t154) # t927: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t929 = prims.convert_element_type(t924, dtypes.float32) # t929: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t930 = prims.mul(t929, t157) # t930: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t931 = prims.add(t927, t930) # t931: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t932 = prims.convert_element_type(t931, dtypes.bfloat16) # t932: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t912 = prims.mul(t911, t154) # t912: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t914 = prims.convert_element_type(t909, dtypes.float32) # t914: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t915 = prims.mul(t914, t157) # t915: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t916 = prims.add(t912, t915) # t916: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t903, t909, t918, t924\n", + " t936 = torch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t936 = ltorch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t936 = prims.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t932, t935\n", + " t934 = torch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t934 = ltorch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t934 = prims.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t917, t933\n", + " (t937, t938, t939, t940, _, _, t941, t942, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t934, t936, t902, 0.0, True, scale=0.08838834764831843)\n", + " t944 = torch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t944 = ltorch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t944 = prims.transpose(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t945 = torch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t945 = ltorch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t945 = prims.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t944\n", + " t946 = torch.nn.functional.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t946 = ltorch.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + 
" # t946 = prims.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t950, t957, t965] = nvFusion38(t878, t946, t961)\n", + " # t948 = prims.convert_element_type(t878, dtypes.float32) # t948: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t947 = prims.convert_element_type(t946, dtypes.float32) # t947: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t949 = prims.add(t947, t948) # t949: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t950 = prims.convert_element_type(t949, dtypes.bfloat16) # t950: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t952 = prims.mul(t949, t949) # t952: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t953 = prims.sum(t952, (2,)) # t953: \"cuda:0 f32[1, 512]\"\n", + " # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: \"cuda:0 f32[1, 512, 1]\"\n", + " # t955 = prims.div(t954, 4096.0) # t955: \"cuda:0 f32[1, 512, 1]\"\n", + " # t956 = prims.add(t955, 1e-05) # t956: \"cuda:0 f32[1, 512, 1]\"\n", + " # t957 = prims.rsqrt(t956) # t957: \"cuda:0 f32[1, 512, 1]\"\n", + " # t958 = prims.broadcast_in_dim(t957, (1, 512, 4096), (0, 1, 2)) # t958: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t959 = prims.mul(t949, t958) # t959: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t963 = prims.convert_element_type(t961, dtypes.float32) # t963: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t964 = prims.mul(t959, t963) # t964: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t965 = prims.convert_element_type(t964, dtypes.bfloat16) # t965: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t967 = torch.nn.functional.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t967 = ltorch.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t967 = prims.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t966 = torch.nn.functional.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t966 = ltorch.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t966 = prims.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t981] = nvFusion39(t966, t967)\n", + " # t968 = prims.convert_element_type(t966, dtypes.float32) # t968: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t969 = prims.neg(t968) # t969: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t970 = prims.exp(t969) # t970: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t971 = prims.add(1.0, t970) # t971: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t972 = prims.reciprocal(t971) # t972: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t976 = prims.mul(t968, t972) # t976: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t979 = prims.convert_element_type(t967, dtypes.float32) # t979: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t980 = prims.mul(t976, t979) # t980: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t982 = torch.nn.functional.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t982 = ltorch.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t982 = prims.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1001, t986, t993] = nvFusion40(t950, t982, t997)\n", + " # t984 = prims.convert_element_type(t950, dtypes.float32) # t984: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t983 = prims.convert_element_type(t982, dtypes.float32) # t983: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t985 = prims.add(t983, t984) # t985: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t988 = prims.mul(t985, 
t985) # t988: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t989 = prims.sum(t988, (2,)) # t989: \"cuda:0 f32[1, 512]\"\n", + " # t990 = prims.broadcast_in_dim(t989, [1, 512, 1], [0, 1]) # t990: \"cuda:0 f32[1, 512, 1]\"\n", + " # t991 = prims.div(t990, 4096.0) # t991: \"cuda:0 f32[1, 512, 1]\"\n", + " # t992 = prims.add(t991, 1e-05) # t992: \"cuda:0 f32[1, 512, 1]\"\n", + " # t993 = prims.rsqrt(t992) # t993: \"cuda:0 f32[1, 512, 1]\"\n", + " # t994 = prims.broadcast_in_dim(t993, (1, 512, 4096), (0, 1, 2)) # t994: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t995 = prims.mul(t985, t994) # t995: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t999 = prims.convert_element_type(t997, dtypes.float32) # t999: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1000 = prims.mul(t995, t999) # t1000: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1001 = prims.convert_element_type(t1000, dtypes.bfloat16) # t1001: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1002 = torch.nn.functional.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1002 = ltorch.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1002 = prims.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1003 = torch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1003 = ltorch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1003 = prims.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1002\n", + " t1004 = torch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1004 = ltorch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1004 = prims.transpose(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1003\n", + " (t1005, t1006, t1007) = torch.split(t1004, (1, 1, 1), 2)\n", + " # (t1005, t1006, t1007) = ltorch.split(t1004, (1, 1, 1), 2)\n", + " # t1005 = prims.slice_prim(t1004, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1005: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1006 = prims.slice_prim(t1004, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1006: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1007 = prims.slice_prim(t1004, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1007: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1004\n", + " t1008 = torch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1008 = ltorch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1008 = prims.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1005\n", + " t1009 = torch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1009 = ltorch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1009 = prims.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1006\n", + " t1010 = torch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1010 = ltorch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1010 = prims.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1007\n", + " t1026 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1026: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1041 = torch_slice_prim_impl(t1008, 
[0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1041: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1043 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1043: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1009\n", + " t1011 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1011: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1008\n", + " t1027 = torch_slice_prim_impl(t1026, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1027: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1028 = torch_slice_prim_impl(t1026, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1028: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1013 = torch_slice_prim_impl(t1011, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1013: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1012 = torch_slice_prim_impl(t1011, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1012: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1016, t1031] = nvFusion41(t1011, t1013, t1026, t1028)\n", + " # t1014 = prims.convert_element_type(t1013, dtypes.float32) # t1014: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1015 = prims.neg(t1014) # t1015: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1016 = prims.convert_element_type(t1015, dtypes.bfloat16) # t1016: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1029 = prims.convert_element_type(t1028, dtypes.float32) # t1029: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1030 = prims.neg(t1029) # t1030: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1031 = prims.convert_element_type(t1030, dtypes.bfloat16) # t1031: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1013, t1028\n", + " t1032 = torch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1032 = ltorch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1032 = prims.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1031, t1027\n", + " t1017 = torch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1017 = ltorch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1017 = prims.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1016, t1012\n", + " [t1025, t1040] = nvFusion42(t1011, t1017, t1026, t1032, t154, t157)\n", + " # t1019 = prims.convert_element_type(t1011, dtypes.float32) # t1019: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1034 = prims.convert_element_type(t1026, dtypes.float32) # t1034: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1020 = prims.mul(t1019, t154) # t1020: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1022 = prims.convert_element_type(t1017, dtypes.float32) # t1022: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1023 = prims.mul(t1022, t157) # t1023: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1024 = prims.add(t1020, t1023) # t1024: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1035 = prims.mul(t1034, t154) # t1035: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1037 = prims.convert_element_type(t1032, dtypes.float32) # t1037: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1038 = prims.mul(t1037, t157) # t1038: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1039 = prims.add(t1035, t1038) # t1039: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1040 = prims.convert_element_type(t1039, dtypes.bfloat16) # t1040: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1011, t1017, t1026, t1032\n", + " t1042 = torch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1042 = 
ltorch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1042 = prims.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1025, t1041\n", + " t1044 = torch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1044 = ltorch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1044 = prims.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1040, t1043\n", + " (t1045, t1046, t1047, t1048, _, _, t1049, t1050, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1042, t1044, t1010, 0.0, True, scale=0.08838834764831843)\n", + " t1052 = torch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1052 = ltorch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1052 = prims.transpose(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1053 = torch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1053 = ltorch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1053 = prims.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1052\n", + " t1054 = torch.nn.functional.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1054 = ltorch.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1054 = prims.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1058, t1065, t1073] = nvFusion43(t1054, t1069, t986)\n", + " # t1056 = prims.convert_element_type(t986, dtypes.float32) # t1056: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1055 = prims.convert_element_type(t1054, dtypes.float32) # t1055: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1057 = prims.add(t1055, t1056) # t1057: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1058 = prims.convert_element_type(t1057, dtypes.bfloat16) # t1058: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1060 = prims.mul(t1057, t1057) # t1060: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1061 = prims.sum(t1060, (2,)) # t1061: \"cuda:0 f32[1, 512]\"\n", + " # t1062 = prims.broadcast_in_dim(t1061, [1, 512, 1], [0, 1]) # t1062: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1063 = prims.div(t1062, 4096.0) # t1063: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1064 = prims.add(t1063, 1e-05) # t1064: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1065 = prims.rsqrt(t1064) # t1065: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1066 = prims.broadcast_in_dim(t1065, (1, 512, 4096), (0, 1, 2)) # t1066: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1067 = prims.mul(t1057, t1066) # t1067: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1071 = prims.convert_element_type(t1069, dtypes.float32) # t1071: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1072 = prims.mul(t1067, t1071) # t1072: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1073 = prims.convert_element_type(t1072, dtypes.bfloat16) # t1073: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1074 = torch.nn.functional.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1074 = ltorch.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1074 = prims.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1075 = torch.nn.functional.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1075 = ltorch.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1075 = prims.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1089] = 
nvFusion44(t1074, t1075)\n", + " # t1076 = prims.convert_element_type(t1074, dtypes.float32) # t1076: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1077 = prims.neg(t1076) # t1077: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1078 = prims.exp(t1077) # t1078: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1079 = prims.add(1.0, t1078) # t1079: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1080 = prims.reciprocal(t1079) # t1080: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1084 = prims.mul(t1076, t1080) # t1084: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1087 = prims.convert_element_type(t1075, dtypes.float32) # t1087: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1088 = prims.mul(t1084, t1087) # t1088: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1089 = prims.convert_element_type(t1088, dtypes.bfloat16) # t1089: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1090 = torch.nn.functional.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1090 = ltorch.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1090 = prims.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1094, t1101, t1109] = nvFusion45(t1058, t1090, t1105)\n", + " # t1092 = prims.convert_element_type(t1058, dtypes.float32) # t1092: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1093 = prims.add(t1091, t1092) # t1093: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1094 = prims.convert_element_type(t1093, dtypes.bfloat16) # t1094: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1096 = prims.mul(t1093, t1093) # t1096: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1097 = prims.sum(t1096, (2,)) # t1097: \"cuda:0 f32[1, 512]\"\n", + " # t1098 = prims.broadcast_in_dim(t1097, [1, 512, 1], [0, 1]) # t1098: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1099 = prims.div(t1098, 4096.0) # t1099: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1100 = prims.add(t1099, 1e-05) # t1100: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1101 = prims.rsqrt(t1100) # t1101: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1102 = prims.broadcast_in_dim(t1101, (1, 512, 4096), (0, 1, 2)) # t1102: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1103 = prims.mul(t1093, t1102) # t1103: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1107 = prims.convert_element_type(t1105, dtypes.float32) # t1107: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1108 = prims.mul(t1103, t1107) # t1108: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1109 = prims.convert_element_type(t1108, dtypes.bfloat16) # t1109: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1110 = torch.nn.functional.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1110 = ltorch.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1110 = prims.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1111 = torch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1111 = ltorch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1111 = prims.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1110\n", + " t1112 = torch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1112 = ltorch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1112 = prims.transpose(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1111\n", + " (t1113, t1114, t1115) = torch.split(t1112, (1, 1, 1), 2)\n", + " # (t1113, 
t1114, t1115) = ltorch.split(t1112, (1, 1, 1), 2)\n", + " # t1113 = prims.slice_prim(t1112, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1113: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1114 = prims.slice_prim(t1112, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1114: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1115 = prims.slice_prim(t1112, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1115: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1112\n", + " t1116 = torch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1116 = ltorch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1116 = prims.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1113\n", + " t1117 = torch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1117 = ltorch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1117 = prims.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1114\n", + " t1118 = torch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1118 = ltorch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1118 = prims.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1115\n", + " t1119 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1119: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1134 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1134: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1149 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1149: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1116\n", + " t1151 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1151: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1117\n", + " t1120 = torch_slice_prim_impl(t1119, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1120: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1121 = torch_slice_prim_impl(t1119, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1121: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1136 = torch_slice_prim_impl(t1134, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1136: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1135 = torch_slice_prim_impl(t1134, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1135: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1124, t1139] = nvFusion46(t1119, t1121, t1134, t1136)\n", + " # t1122 = prims.convert_element_type(t1121, dtypes.float32) # t1122: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1123 = prims.neg(t1122) # t1123: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1137 = prims.convert_element_type(t1136, dtypes.float32) # t1137: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1138 = prims.neg(t1137) # t1138: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1121, t1136\n", + " t1125 = torch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1125 = ltorch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1125 = prims.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1124, t1120\n", + " t1140 = torch.cat((t1139, t1135), 
-1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1140 = ltorch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1140 = prims.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1139, t1135\n", + " [t1133, t1148] = nvFusion47(t1119, t1125, t1134, t1140, t154, t157)\n", + " # t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1142 = prims.convert_element_type(t1134, dtypes.float32) # t1142: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1128 = prims.mul(t1127, t154) # t1128: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1130 = prims.convert_element_type(t1125, dtypes.float32) # t1130: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1131 = prims.mul(t1130, t157) # t1131: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1132 = prims.add(t1128, t1131) # t1132: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1133 = prims.convert_element_type(t1132, dtypes.bfloat16) # t1133: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1143 = prims.mul(t1142, t154) # t1143: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1145 = prims.convert_element_type(t1140, dtypes.float32) # t1145: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1146 = prims.mul(t1145, t157) # t1146: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1147 = prims.add(t1143, t1146) # t1147: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1148 = prims.convert_element_type(t1147, dtypes.bfloat16) # t1148: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1119, t1125, t1134, t1140\n", + " t1152 = torch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1152 = ltorch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1152 = prims.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1148, t1151\n", + " t1150 = torch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1150 = ltorch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1150 = prims.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1133, t1149\n", + " (t1153, t1154, t1155, t1156, _, _, t1157, t1158, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1150, t1152, t1118, 0.0, True, scale=0.08838834764831843)\n", + " t1160 = torch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1160 = ltorch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1160 = prims.transpose(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1161 = torch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1161 = ltorch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1161 = prims.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1160\n", + " t1162 = torch.nn.functional.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1162 = ltorch.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1162 = prims.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1166, t1173, t1181] = nvFusion48(t1094, t1162, t1177)\n", + " # t1164 = prims.convert_element_type(t1094, dtypes.float32) # t1164: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1163 = prims.convert_element_type(t1162, dtypes.float32) # t1163: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1165 = prims.add(t1163, t1164) # t1165: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1166 = 
prims.convert_element_type(t1165, dtypes.bfloat16) # t1166: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1168 = prims.mul(t1165, t1165) # t1168: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1169 = prims.sum(t1168, (2,)) # t1169: \"cuda:0 f32[1, 512]\"\n", + " # t1170 = prims.broadcast_in_dim(t1169, [1, 512, 1], [0, 1]) # t1170: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1171 = prims.div(t1170, 4096.0) # t1171: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1172 = prims.add(t1171, 1e-05) # t1172: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1173 = prims.rsqrt(t1172) # t1173: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1174 = prims.broadcast_in_dim(t1173, (1, 512, 4096), (0, 1, 2)) # t1174: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1175 = prims.mul(t1165, t1174) # t1175: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1179 = prims.convert_element_type(t1177, dtypes.float32) # t1179: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1180 = prims.mul(t1175, t1179) # t1180: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1181 = prims.convert_element_type(t1180, dtypes.bfloat16) # t1181: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1182 = torch.nn.functional.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1182 = ltorch.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1182 = prims.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1183 = torch.nn.functional.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1183 = ltorch.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1183 = prims.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1197] = nvFusion49(t1182, t1183)\n", + " # t1184 = prims.convert_element_type(t1182, dtypes.float32) # t1184: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1185 = prims.neg(t1184) # t1185: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1186 = prims.exp(t1185) # t1186: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1187 = prims.add(1.0, t1186) # t1187: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1188 = prims.reciprocal(t1187) # t1188: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1192 = prims.mul(t1184, t1188) # t1192: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1195 = prims.convert_element_type(t1183, dtypes.float32) # t1195: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1196 = prims.mul(t1192, t1195) # t1196: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1198 = torch.nn.functional.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1198 = ltorch.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1198 = prims.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1202, t1209, t1217] = nvFusion50(t1166, t1198, t1213)\n", + " # t1200 = prims.convert_element_type(t1166, dtypes.float32) # t1200: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1199 = prims.convert_element_type(t1198, dtypes.float32) # t1199: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1201 = prims.add(t1199, t1200) # t1201: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1202 = prims.convert_element_type(t1201, dtypes.bfloat16) # t1202: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1204 = prims.mul(t1201, t1201) # t1204: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1205 = prims.sum(t1204, (2,)) # t1205: \"cuda:0 f32[1, 512]\"\n", + " # t1206 = prims.broadcast_in_dim(t1205, [1, 512, 1], [0, 1]) # t1206: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1207 = prims.div(t1206, 4096.0) # t1207: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1208 = 
prims.add(t1207, 1e-05) # t1208: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1209 = prims.rsqrt(t1208) # t1209: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1210 = prims.broadcast_in_dim(t1209, (1, 512, 4096), (0, 1, 2)) # t1210: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1211 = prims.mul(t1201, t1210) # t1211: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1215 = prims.convert_element_type(t1213, dtypes.float32) # t1215: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1216 = prims.mul(t1211, t1215) # t1216: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1217 = prims.convert_element_type(t1216, dtypes.bfloat16) # t1217: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1218 = torch.nn.functional.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1218 = ltorch.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1218 = prims.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1219 = torch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1219 = ltorch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1219 = prims.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1218\n", + " t1220 = torch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1220 = ltorch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1220 = prims.transpose(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1219\n", + " (t1221, t1222, t1223) = torch.split(t1220, (1, 1, 1), 2)\n", + " # (t1221, t1222, t1223) = ltorch.split(t1220, (1, 1, 1), 2)\n", + " # t1221 = prims.slice_prim(t1220, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1221: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1222 = prims.slice_prim(t1220, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1222: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1223 = prims.slice_prim(t1220, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1223: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1220\n", + " t1224 = torch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1224 = ltorch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1224 = prims.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1221\n", + " t1225 = torch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1225 = ltorch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1225 = prims.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1222\n", + " t1226 = torch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1226 = ltorch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1226 = prims.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1223\n", + " t1227 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1227: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1242 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1242: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1257 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1257: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1224\n", + " t1259 = torch_slice_prim_impl(t1225, [0, 0, 
0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1259: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1225\n", + " t1228 = torch_slice_prim_impl(t1227, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1228: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1229 = torch_slice_prim_impl(t1227, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1229: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1243 = torch_slice_prim_impl(t1242, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1243: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1244 = torch_slice_prim_impl(t1242, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1244: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1232, t1247] = nvFusion51(t1227, t1229, t1242, t1244)\n", + " # t1230 = prims.convert_element_type(t1229, dtypes.float32) # t1230: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1231 = prims.neg(t1230) # t1231: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1232 = prims.convert_element_type(t1231, dtypes.bfloat16) # t1232: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1245 = prims.convert_element_type(t1244, dtypes.float32) # t1245: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1246 = prims.neg(t1245) # t1246: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1247 = prims.convert_element_type(t1246, dtypes.bfloat16) # t1247: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1229, t1244\n", + " t1233 = torch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1233 = ltorch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1233 = prims.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1232, t1228\n", + " t1248 = torch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1248 = ltorch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1248 = prims.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1247, t1243\n", + " [t1241, t1256] = nvFusion52(t1227, t1233, t1242, t1248, t154, t157)\n", + " # t1235 = prims.convert_element_type(t1227, dtypes.float32) # t1235: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1250 = prims.convert_element_type(t1242, dtypes.float32) # t1250: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1236 = prims.mul(t1235, t154) # t1236: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1238 = prims.convert_element_type(t1233, dtypes.float32) # t1238: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1239 = prims.mul(t1238, t157) # t1239: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1240 = prims.add(t1236, t1239) # t1240: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1241 = prims.convert_element_type(t1240, dtypes.bfloat16) # t1241: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1251 = prims.mul(t1250, t154) # t1251: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1253 = prims.convert_element_type(t1248, dtypes.float32) # t1253: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1254 = prims.mul(t1253, t157) # t1254: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1255 = prims.add(t1251, t1254) # t1255: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1256 = prims.convert_element_type(t1255, dtypes.bfloat16) # t1256: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1227, t1233, t1242, t1248\n", + " t1258 = torch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1258 = ltorch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1258 = prims.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1241, t1257\n", + " t1260 = torch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # 
t1260 = ltorch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1260 = prims.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1256, t1259\n", + " (t1261, t1262, t1263, t1264, _, _, t1265, t1266, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1258, t1260, t1226, 0.0, True, scale=0.08838834764831843)\n", + " t1268 = torch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1268 = ltorch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1268 = prims.transpose(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1269 = torch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1269 = ltorch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1269 = prims.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1268\n", + " t1270 = torch.nn.functional.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1270 = ltorch.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1270 = prims.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1274, t1281, t1289] = nvFusion53(t1202, t1270, t1285)\n", + " # t1272 = prims.convert_element_type(t1202, dtypes.float32) # t1272: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1271 = prims.convert_element_type(t1270, dtypes.float32) # t1271: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1273 = prims.add(t1271, t1272) # t1273: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1274 = prims.convert_element_type(t1273, dtypes.bfloat16) # t1274: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1276 = prims.mul(t1273, t1273) # t1276: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1277 = prims.sum(t1276, (2,)) # t1277: \"cuda:0 f32[1, 512]\"\n", + " # t1278 = prims.broadcast_in_dim(t1277, [1, 512, 1], [0, 1]) # t1278: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1279 = prims.div(t1278, 4096.0) # t1279: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1280 = prims.add(t1279, 1e-05) # t1280: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1281 = prims.rsqrt(t1280) # t1281: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1282 = prims.broadcast_in_dim(t1281, (1, 512, 4096), (0, 1, 2)) # t1282: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1283 = prims.mul(t1273, t1282) # t1283: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1287 = prims.convert_element_type(t1285, dtypes.float32) # t1287: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1288 = prims.mul(t1283, t1287) # t1288: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1289 = prims.convert_element_type(t1288, dtypes.bfloat16) # t1289: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1290 = torch.nn.functional.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1290 = ltorch.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1290 = prims.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1291 = torch.nn.functional.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1291 = ltorch.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1291 = prims.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1305] = nvFusion54(t1290, t1291)\n", + " # t1292 = prims.convert_element_type(t1290, dtypes.float32) # t1292: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1293 = prims.neg(t1292) # t1293: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1294 = prims.exp(t1293) # t1294: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1295 = 
prims.add(1.0, t1294) # t1295: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1296 = prims.reciprocal(t1295) # t1296: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1300 = prims.mul(t1292, t1296) # t1300: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1303 = prims.convert_element_type(t1291, dtypes.float32) # t1303: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1304 = prims.mul(t1300, t1303) # t1304: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1305 = prims.convert_element_type(t1304, dtypes.bfloat16) # t1305: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1306 = torch.nn.functional.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1306 = ltorch.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1306 = prims.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1310, t1317, t1325] = nvFusion55(t1274, t1306, t1321)\n", + " # t1308 = prims.convert_element_type(t1274, dtypes.float32) # t1308: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1307 = prims.convert_element_type(t1306, dtypes.float32) # t1307: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1309 = prims.add(t1307, t1308) # t1309: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1310 = prims.convert_element_type(t1309, dtypes.bfloat16) # t1310: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1312 = prims.mul(t1309, t1309) # t1312: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1313 = prims.sum(t1312, (2,)) # t1313: \"cuda:0 f32[1, 512]\"\n", + " # t1314 = prims.broadcast_in_dim(t1313, [1, 512, 1], [0, 1]) # t1314: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1315 = prims.div(t1314, 4096.0) # t1315: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1316 = prims.add(t1315, 1e-05) # t1316: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1317 = prims.rsqrt(t1316) # t1317: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1318 = prims.broadcast_in_dim(t1317, (1, 512, 4096), (0, 1, 2)) # t1318: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1319 = prims.mul(t1309, t1318) # t1319: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1323 = prims.convert_element_type(t1321, dtypes.float32) # t1323: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1324 = prims.mul(t1319, t1323) # t1324: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1325 = prims.convert_element_type(t1324, dtypes.bfloat16) # t1325: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1326 = torch.nn.functional.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1326 = ltorch.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1326 = prims.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1327 = torch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1327 = ltorch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1327 = prims.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1326\n", + " t1328 = torch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1328 = ltorch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1328 = prims.transpose(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1327\n", + " (t1329, t1330, t1331) = torch.split(t1328, (1, 1, 1), 2)\n", + " # (t1329, t1330, t1331) = ltorch.split(t1328, (1, 1, 1), 2)\n", + " # t1329 = prims.slice_prim(t1328, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1329: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1330 = prims.slice_prim(t1328, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1330: 
\"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1331 = prims.slice_prim(t1328, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1331: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1328\n", + " t1332 = torch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1332 = ltorch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1332 = prims.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1329\n", + " t1333 = torch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1333 = ltorch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1333 = prims.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1330\n", + " t1334 = torch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1334 = ltorch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1334 = prims.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1331\n", + " t1335 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1335: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1350 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1350: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1365 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1365: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1332\n", + " t1367 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1367: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1333\n", + " t1336 = torch_slice_prim_impl(t1335, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1336: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1337 = torch_slice_prim_impl(t1335, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1351 = torch_slice_prim_impl(t1350, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1351: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1352 = torch_slice_prim_impl(t1350, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1352: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1340, t1355] = nvFusion56(t1335, t1337, t1350, t1352)\n", + " # t1338 = prims.convert_element_type(t1337, dtypes.float32) # t1338: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1339 = prims.neg(t1338) # t1339: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1340 = prims.convert_element_type(t1339, dtypes.bfloat16) # t1340: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1353 = prims.convert_element_type(t1352, dtypes.float32) # t1353: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1354 = prims.neg(t1353) # t1354: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1355 = prims.convert_element_type(t1354, dtypes.bfloat16) # t1355: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1337, t1352\n", + " t1341 = torch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1341 = ltorch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1341 = prims.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1340, t1336\n", + " t1356 = torch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1356 = ltorch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1356 = prims.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1355, t1351\n", + " [t1349, t1364] = nvFusion57(t1335, 
t1341, t1350, t1356, t154, t157)\n", + " # t1343 = prims.convert_element_type(t1335, dtypes.float32) # t1343: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1358 = prims.convert_element_type(t1350, dtypes.float32) # t1358: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1344 = prims.mul(t1343, t154) # t1344: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1346 = prims.convert_element_type(t1341, dtypes.float32) # t1346: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1347 = prims.mul(t1346, t157) # t1347: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1348 = prims.add(t1344, t1347) # t1348: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1349 = prims.convert_element_type(t1348, dtypes.bfloat16) # t1349: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1359 = prims.mul(t1358, t154) # t1359: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1361 = prims.convert_element_type(t1356, dtypes.float32) # t1361: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1362 = prims.mul(t1361, t157) # t1362: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1363 = prims.add(t1359, t1362) # t1363: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1364 = prims.convert_element_type(t1363, dtypes.bfloat16) # t1364: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1335, t1341, t1350, t1356\n", + " t1366 = torch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1366 = ltorch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1366 = prims.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1349, t1365\n", + " t1368 = torch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1368 = ltorch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1368 = prims.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1364, t1367\n", + " (t1369, t1370, t1371, t1372, _, _, t1373, t1374, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1366, t1368, t1334, 0.0, True, scale=0.08838834764831843)\n", + " t1376 = torch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1376 = ltorch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1376 = prims.transpose(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1377 = torch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1377 = ltorch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1377 = prims.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1376\n", + " t1378 = torch.nn.functional.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1378 = ltorch.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1378 = prims.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1382, t1389, t1397] = nvFusion58(t1310, t1378, t1393)\n", + " # t1380 = prims.convert_element_type(t1310, dtypes.float32) # t1380: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1379 = prims.convert_element_type(t1378, dtypes.float32) # t1379: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1381 = prims.add(t1379, t1380) # t1381: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1382 = prims.convert_element_type(t1381, dtypes.bfloat16) # t1382: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1384 = prims.mul(t1381, t1381) # t1384: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1385 = prims.sum(t1384, (2,)) # t1385: \"cuda:0 f32[1, 512]\"\n", + " # t1386 = prims.broadcast_in_dim(t1385, [1, 512, 1], [0, 1]) # t1386: 
\"cuda:0 f32[1, 512, 1]\"\n", + " # t1387 = prims.div(t1386, 4096.0) # t1387: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1388 = prims.add(t1387, 1e-05) # t1388: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1389 = prims.rsqrt(t1388) # t1389: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1390 = prims.broadcast_in_dim(t1389, (1, 512, 4096), (0, 1, 2)) # t1390: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1391 = prims.mul(t1381, t1390) # t1391: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1395 = prims.convert_element_type(t1393, dtypes.float32) # t1395: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1396 = prims.mul(t1391, t1395) # t1396: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1397 = prims.convert_element_type(t1396, dtypes.bfloat16) # t1397: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1398 = torch.nn.functional.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1398 = ltorch.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1398 = prims.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1399 = torch.nn.functional.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1399 = ltorch.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1399 = prims.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1413] = nvFusion59(t1398, t1399)\n", + " # t1400 = prims.convert_element_type(t1398, dtypes.float32) # t1400: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1401 = prims.neg(t1400) # t1401: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1402 = prims.exp(t1401) # t1402: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1403 = prims.add(1.0, t1402) # t1403: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1404 = prims.reciprocal(t1403) # t1404: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1408 = prims.mul(t1400, t1404) # t1408: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1411 = prims.convert_element_type(t1399, dtypes.float32) # t1411: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1412 = prims.mul(t1408, t1411) # t1412: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1413 = prims.convert_element_type(t1412, dtypes.bfloat16) # t1413: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1414 = torch.nn.functional.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1414 = ltorch.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1414 = prims.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1418, t1425, t1433] = nvFusion60(t1382, t1414, t1429)\n", + " # t1416 = prims.convert_element_type(t1382, dtypes.float32) # t1416: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1415 = prims.convert_element_type(t1414, dtypes.float32) # t1415: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1417 = prims.add(t1415, t1416) # t1417: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1418 = prims.convert_element_type(t1417, dtypes.bfloat16) # t1418: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1420 = prims.mul(t1417, t1417) # t1420: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1421 = prims.sum(t1420, (2,)) # t1421: \"cuda:0 f32[1, 512]\"\n", + " # t1422 = prims.broadcast_in_dim(t1421, [1, 512, 1], [0, 1]) # t1422: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1423 = prims.div(t1422, 4096.0) # t1423: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1424 = prims.add(t1423, 1e-05) # t1424: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1425 = prims.rsqrt(t1424) # t1425: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1426 = prims.broadcast_in_dim(t1425, (1, 512, 4096), (0, 1, 2)) # t1426: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1427 = prims.mul(t1417, t1426) # t1427: \"cuda:0 f32[1, 512, 
4096]\"\n", + " # t1431 = prims.convert_element_type(t1429, dtypes.float32) # t1431: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1432 = prims.mul(t1427, t1431) # t1432: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1433 = prims.convert_element_type(t1432, dtypes.bfloat16) # t1433: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1434 = torch.nn.functional.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1434 = ltorch.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1434 = prims.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1435 = torch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1435 = ltorch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1435 = prims.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1434\n", + " t1436 = torch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1436 = ltorch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1436 = prims.transpose(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1435\n", + " (t1437, t1438, t1439) = torch.split(t1436, (1, 1, 1), 2)\n", + " # (t1437, t1438, t1439) = ltorch.split(t1436, (1, 1, 1), 2)\n", + " # t1437 = prims.slice_prim(t1436, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1437: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1438 = prims.slice_prim(t1436, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1438: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1439 = prims.slice_prim(t1436, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1439: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1436\n", + " t1440 = torch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1440 = ltorch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1440 = prims.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1437\n", + " t1441 = torch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1441 = ltorch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1441 = prims.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1438\n", + " t1442 = torch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1442 = ltorch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1442 = prims.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1439\n", + " t1443 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1443: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1458 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1458: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1473 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1473: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1440\n", + " t1475 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1475: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1441\n", + " t1444 = torch_slice_prim_impl(t1443, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1444: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1445 = torch_slice_prim_impl(t1443, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 
1]) # t1445: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1459 = torch_slice_prim_impl(t1458, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1459: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1460 = torch_slice_prim_impl(t1458, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1460: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1448, t1463] = nvFusion61(t1443, t1445, t1458, t1460)\n", + " # t1446 = prims.convert_element_type(t1445, dtypes.float32) # t1446: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1447 = prims.neg(t1446) # t1447: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1448 = prims.convert_element_type(t1447, dtypes.bfloat16) # t1448: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1461 = prims.convert_element_type(t1460, dtypes.float32) # t1461: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1462 = prims.neg(t1461) # t1462: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1463 = prims.convert_element_type(t1462, dtypes.bfloat16) # t1463: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1445, t1460\n", + " t1464 = torch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1464 = ltorch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1464 = prims.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1463, t1459\n", + " t1449 = torch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1449 = ltorch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1449 = prims.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1448, t1444\n", + " [t1457, t1472] = nvFusion62(t1443, t1449, t1458, t1464, t154, t157)\n", + " # t1451 = prims.convert_element_type(t1443, dtypes.float32) # t1451: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1466 = prims.convert_element_type(t1458, dtypes.float32) # t1466: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1467 = prims.mul(t1466, t154) # t1467: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1469 = prims.convert_element_type(t1464, dtypes.float32) # t1469: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1470 = prims.mul(t1469, t157) # t1470: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1471 = prims.add(t1467, t1470) # t1471: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1472 = prims.convert_element_type(t1471, dtypes.bfloat16) # t1472: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1452 = prims.mul(t1451, t154) # t1452: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1454 = prims.convert_element_type(t1449, dtypes.float32) # t1454: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1455 = prims.mul(t1454, t157) # t1455: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1456 = prims.add(t1452, t1455) # t1456: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1457 = prims.convert_element_type(t1456, dtypes.bfloat16) # t1457: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1443, t1449, t1458, t1464\n", + " t1476 = torch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1476 = ltorch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1476 = prims.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1472, t1475\n", + " t1474 = torch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1474 = ltorch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1474 = prims.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1457, t1473\n", + " (t1477, t1478, t1479, t1480, _, _, t1481, t1482, _) = 
sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1474, t1476, t1442, 0.0, True, scale=0.08838834764831843)\n", + " t1484 = torch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1484 = ltorch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1484 = prims.transpose(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1485 = torch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1485 = ltorch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1485 = prims.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1484\n", + " t1486 = torch.nn.functional.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1486 = ltorch.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1486 = prims.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1490, t1497, t1505] = nvFusion63(t1418, t1486, t1501)\n", + " # t1488 = prims.convert_element_type(t1418, dtypes.float32) # t1488: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1487 = prims.convert_element_type(t1486, dtypes.float32) # t1487: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1489 = prims.add(t1487, t1488) # t1489: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1490 = prims.convert_element_type(t1489, dtypes.bfloat16) # t1490: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1492 = prims.mul(t1489, t1489) # t1492: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1493 = prims.sum(t1492, (2,)) # t1493: \"cuda:0 f32[1, 512]\"\n", + " # t1494 = prims.broadcast_in_dim(t1493, [1, 512, 1], [0, 1]) # t1494: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1495 = prims.div(t1494, 4096.0) # t1495: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1496 = prims.add(t1495, 1e-05) # t1496: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1497 = prims.rsqrt(t1496) # t1497: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1498 = prims.broadcast_in_dim(t1497, (1, 512, 4096), (0, 1, 2)) # t1498: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1499 = prims.mul(t1489, t1498) # t1499: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1503 = prims.convert_element_type(t1501, dtypes.float32) # t1503: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1504 = prims.mul(t1499, t1503) # t1504: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1505 = prims.convert_element_type(t1504, dtypes.bfloat16) # t1505: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1506 = torch.nn.functional.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1506 = ltorch.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1506 = prims.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1507 = torch.nn.functional.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1507 = ltorch.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1507 = prims.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1521] = nvFusion64(t1506, t1507)\n", + " # t1508 = prims.convert_element_type(t1506, dtypes.float32) # t1508: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1509 = prims.neg(t1508) # t1509: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1510 = prims.exp(t1509) # t1510: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1511 = prims.add(1.0, t1510) # t1511: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1512 = prims.reciprocal(t1511) # t1512: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1516 = prims.mul(t1508, t1512) # t1516: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1519 = 
prims.convert_element_type(t1507, dtypes.float32) # t1519: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1520 = prims.mul(t1516, t1519) # t1520: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1521 = prims.convert_element_type(t1520, dtypes.bfloat16) # t1521: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1522 = torch.nn.functional.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1522 = ltorch.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1522 = prims.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1526, t1533, t1541] = nvFusion65(t1490, t1522, t1537)\n", + " # t1524 = prims.convert_element_type(t1490, dtypes.float32) # t1524: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1523 = prims.convert_element_type(t1522, dtypes.float32) # t1523: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1525 = prims.add(t1523, t1524) # t1525: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1526 = prims.convert_element_type(t1525, dtypes.bfloat16) # t1526: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1528 = prims.mul(t1525, t1525) # t1528: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1529 = prims.sum(t1528, (2,)) # t1529: \"cuda:0 f32[1, 512]\"\n", + " # t1530 = prims.broadcast_in_dim(t1529, [1, 512, 1], [0, 1]) # t1530: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1531 = prims.div(t1530, 4096.0) # t1531: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1532 = prims.add(t1531, 1e-05) # t1532: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1533 = prims.rsqrt(t1532) # t1533: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1534 = prims.broadcast_in_dim(t1533, (1, 512, 4096), (0, 1, 2)) # t1534: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1535 = prims.mul(t1525, t1534) # t1535: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1539 = prims.convert_element_type(t1537, dtypes.float32) # t1539: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1540 = prims.mul(t1535, t1539) # t1540: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1541 = prims.convert_element_type(t1540, dtypes.bfloat16) # t1541: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1542 = torch.nn.functional.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1542 = ltorch.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1542 = prims.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1543 = torch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1543 = ltorch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1543 = prims.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1542\n", + " t1544 = torch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1544 = ltorch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1544 = prims.transpose(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1543\n", + " (t1545, t1546, t1547) = torch.split(t1544, (1, 1, 1), 2)\n", + " # (t1545, t1546, t1547) = ltorch.split(t1544, (1, 1, 1), 2)\n", + " # t1545 = prims.slice_prim(t1544, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1545: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1546 = prims.slice_prim(t1544, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1546: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1547 = prims.slice_prim(t1544, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1547: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1544\n", + " t1548 = torch.reshape(t1545, (1, 
32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1548 = ltorch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1548 = prims.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1545\n", + " t1549 = torch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1549 = ltorch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1549 = prims.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1546\n", + " t1550 = torch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1550 = ltorch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1550 = prims.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1547\n", + " t1551 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1551: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1566 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1566: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1581 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1581: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1548\n", + " t1583 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1583: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1549\n", + " t1552 = torch_slice_prim_impl(t1551, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1552: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1553 = torch_slice_prim_impl(t1551, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1553: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1567 = torch_slice_prim_impl(t1566, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1567: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1568 = torch_slice_prim_impl(t1566, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1568: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1556, t1571] = nvFusion66(t1551, t1553, t1566, t1568)\n", + " # t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1555 = prims.neg(t1554) # t1555: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1556 = prims.convert_element_type(t1555, dtypes.bfloat16) # t1556: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1569 = prims.convert_element_type(t1568, dtypes.float32) # t1569: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1570 = prims.neg(t1569) # t1570: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1571 = prims.convert_element_type(t1570, dtypes.bfloat16) # t1571: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1553, t1568\n", + " t1572 = torch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1572 = ltorch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1572 = prims.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1571, t1567\n", + " t1557 = torch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1557 = ltorch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1557 = prims.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1556, t1552\n", + " [t1565, t1580] = nvFusion67(t154, t1551, t1557, t1566, t157, t1572)\n", + " # t1559 = prims.convert_element_type(t1551, dtypes.float32) # t1559: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1574 = prims.convert_element_type(t1566, dtypes.float32) # t1574: \"cuda:0 f32[1, 32, 
512, 128]\"\n", + " # t1575 = prims.mul(t1574, t154) # t1575: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1577 = prims.convert_element_type(t1572, dtypes.float32) # t1577: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1578 = prims.mul(t1577, t157) # t1578: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1579 = prims.add(t1575, t1578) # t1579: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1580 = prims.convert_element_type(t1579, dtypes.bfloat16) # t1580: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1560 = prims.mul(t1559, t154) # t1560: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1562 = prims.convert_element_type(t1557, dtypes.float32) # t1562: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1563 = prims.mul(t1562, t157) # t1563: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1564 = prims.add(t1560, t1563) # t1564: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1551, t1557, t1566, t1572\n", + " t1584 = torch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1584 = ltorch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1584 = prims.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1580, t1583\n", + " t1582 = torch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1582 = ltorch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1582 = prims.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1565, t1581\n", + " (t1585, t1586, t1587, t1588, _, _, t1589, t1590, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1582, t1584, t1550, 0.0, True, scale=0.08838834764831843)\n", + " t1592 = torch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1592 = ltorch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1592 = prims.transpose(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1593 = torch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1593 = ltorch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1592\n", + " t1594 = torch.nn.functional.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1594 = ltorch.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1594 = prims.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1598, t1605, t1613] = nvFusion68(t1526, t1594, t1609)\n", + " # t1596 = prims.convert_element_type(t1526, dtypes.float32) # t1596: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1595 = prims.convert_element_type(t1594, dtypes.float32) # t1595: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1597 = prims.add(t1595, t1596) # t1597: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1598 = prims.convert_element_type(t1597, dtypes.bfloat16) # t1598: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1600 = prims.mul(t1597, t1597) # t1600: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1601 = prims.sum(t1600, (2,)) # t1601: \"cuda:0 f32[1, 512]\"\n", + " # t1602 = prims.broadcast_in_dim(t1601, [1, 512, 1], [0, 1]) # t1602: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1603 = prims.div(t1602, 4096.0) # t1603: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1604 = prims.add(t1603, 1e-05) # t1604: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1605 = prims.rsqrt(t1604) # t1605: \"cuda:0 
f32[1, 512, 1]\"\n", + " # t1606 = prims.broadcast_in_dim(t1605, (1, 512, 4096), (0, 1, 2)) # t1606: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1607 = prims.mul(t1597, t1606) # t1607: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1611 = prims.convert_element_type(t1609, dtypes.float32) # t1611: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1612 = prims.mul(t1607, t1611) # t1612: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1614 = torch.nn.functional.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1614 = ltorch.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1614 = prims.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1615 = torch.nn.functional.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1615 = ltorch.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1615 = prims.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1629] = nvFusion69(t1614, t1615)\n", + " # t1616 = prims.convert_element_type(t1614, dtypes.float32) # t1616: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1617 = prims.neg(t1616) # t1617: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1618 = prims.exp(t1617) # t1618: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1619 = prims.add(1.0, t1618) # t1619: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1620 = prims.reciprocal(t1619) # t1620: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1624 = prims.mul(t1616, t1620) # t1624: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1627 = prims.convert_element_type(t1615, dtypes.float32) # t1627: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1628 = prims.mul(t1624, t1627) # t1628: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1629 = prims.convert_element_type(t1628, dtypes.bfloat16) # t1629: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1630 = torch.nn.functional.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1630 = ltorch.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1630 = prims.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1634, t1641, t1649] = nvFusion70(t1598, t1630, t1645)\n", + " # t1632 = prims.convert_element_type(t1598, dtypes.float32) # t1632: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1631 = prims.convert_element_type(t1630, dtypes.float32) # t1631: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1633 = prims.add(t1631, t1632) # t1633: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1634 = prims.convert_element_type(t1633, dtypes.bfloat16) # t1634: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1636 = prims.mul(t1633, t1633) # t1636: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1637 = prims.sum(t1636, (2,)) # t1637: \"cuda:0 f32[1, 512]\"\n", + " # t1638 = prims.broadcast_in_dim(t1637, [1, 512, 1], [0, 1]) # t1638: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1639 = prims.div(t1638, 4096.0) # t1639: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1640 = prims.add(t1639, 1e-05) # t1640: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1641 = prims.rsqrt(t1640) # t1641: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1642 = prims.broadcast_in_dim(t1641, (1, 512, 4096), (0, 1, 2)) # t1642: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1643 = prims.mul(t1633, t1642) # t1643: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1647 = prims.convert_element_type(t1645, dtypes.float32) # t1647: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1648 = prims.mul(t1643, t1647) # t1648: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1649 = 
prims.convert_element_type(t1648, dtypes.bfloat16) # t1649: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1650 = torch.nn.functional.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1650 = ltorch.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1650 = prims.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1651 = torch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1651 = ltorch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1651 = prims.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1650\n", + " t1652 = torch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1652 = ltorch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1652 = prims.transpose(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1651\n", + " (t1653, t1654, t1655) = torch.split(t1652, (1, 1, 1), 2)\n", + " # (t1653, t1654, t1655) = ltorch.split(t1652, (1, 1, 1), 2)\n", + " # t1653 = prims.slice_prim(t1652, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1653: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1654 = prims.slice_prim(t1652, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1654: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1655 = prims.slice_prim(t1652, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1655: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1652\n", + " t1656 = torch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1656 = ltorch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1656 = prims.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1653\n", + " t1657 = torch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1657 = ltorch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1657 = prims.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1654\n", + " t1658 = torch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1658 = ltorch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1658 = prims.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1655\n", + " t1689 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1689: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1691 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1691: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1659 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1659: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1656\n", + " t1674 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1674: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1657\n", + " t1660 = torch_slice_prim_impl(t1659, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1660: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1661 = torch_slice_prim_impl(t1659, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1661: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1675: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1676 = 
torch_slice_prim_impl(t1674, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1676: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1664, t1679] = nvFusion71(t1659, t1661, t1674, t1676)\n", + " # t1662 = prims.convert_element_type(t1661, dtypes.float32) # t1662: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1663 = prims.neg(t1662) # t1663: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1664 = prims.convert_element_type(t1663, dtypes.bfloat16) # t1664: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1678 = prims.neg(t1677) # t1678: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1679 = prims.convert_element_type(t1678, dtypes.bfloat16) # t1679: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1661, t1676\n", + " t1680 = torch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1680 = ltorch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1680 = prims.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1679, t1675\n", + " t1665 = torch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1665 = ltorch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1665 = prims.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1664, t1660\n", + " [t1673, t1688] = nvFusion72(t154, t157, t1659, t1665, t1674, t1680)\n", + " # t1667 = prims.convert_element_type(t1659, dtypes.float32) # t1667: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1682 = prims.convert_element_type(t1674, dtypes.float32) # t1682: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1683 = prims.mul(t1682, t154) # t1683: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1685 = prims.convert_element_type(t1680, dtypes.float32) # t1685: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1686 = prims.mul(t1685, t157) # t1686: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1687 = prims.add(t1683, t1686) # t1687: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1688 = prims.convert_element_type(t1687, dtypes.bfloat16) # t1688: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1668 = prims.mul(t1667, t154) # t1668: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1670 = prims.convert_element_type(t1665, dtypes.float32) # t1670: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1671 = prims.mul(t1670, t157) # t1671: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1672 = prims.add(t1668, t1671) # t1672: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1673 = prims.convert_element_type(t1672, dtypes.bfloat16) # t1673: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1659, t1665, t1674, t1680\n", + " t1692 = torch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1692 = ltorch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1692 = prims.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1688, t1691\n", + " t1690 = torch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1690 = ltorch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1690 = prims.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1673, t1689\n", + " (t1693, t1694, t1695, t1696, _, _, t1697, t1698, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1690, t1692, t1658, 0.0, True, scale=0.08838834764831843)\n", + " t1700 = torch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1700 = ltorch.permute(t1693, (0, 
2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1700 = prims.transpose(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1701 = torch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1701 = ltorch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1701 = prims.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1700\n", + " t1702 = torch.nn.functional.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1702 = ltorch.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1702 = prims.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1706, t1713, t1721] = nvFusion73(t1634, t1702, t1717)\n", + " # t1704 = prims.convert_element_type(t1634, dtypes.float32) # t1704: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1703 = prims.convert_element_type(t1702, dtypes.float32) # t1703: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1705 = prims.add(t1703, t1704) # t1705: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1708 = prims.mul(t1705, t1705) # t1708: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1709 = prims.sum(t1708, (2,)) # t1709: \"cuda:0 f32[1, 512]\"\n", + " # t1710 = prims.broadcast_in_dim(t1709, [1, 512, 1], [0, 1]) # t1710: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1711 = prims.div(t1710, 4096.0) # t1711: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1712 = prims.add(t1711, 1e-05) # t1712: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1713 = prims.rsqrt(t1712) # t1713: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1714 = prims.broadcast_in_dim(t1713, (1, 512, 4096), (0, 1, 2)) # t1714: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1715 = prims.mul(t1705, t1714) # t1715: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1719 = prims.convert_element_type(t1717, dtypes.float32) # t1719: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1720 = prims.mul(t1715, t1719) # t1720: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1721 = prims.convert_element_type(t1720, dtypes.bfloat16) # t1721: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1722 = torch.nn.functional.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1722 = ltorch.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1722 = prims.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1723 = torch.nn.functional.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1723 = ltorch.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1723 = prims.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1737] = nvFusion74(t1722, t1723)\n", + " # t1724 = prims.convert_element_type(t1722, dtypes.float32) # t1724: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1725 = prims.neg(t1724) # t1725: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1726 = prims.exp(t1725) # t1726: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1727 = prims.add(1.0, t1726) # t1727: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1728 = prims.reciprocal(t1727) # t1728: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1732 = prims.mul(t1724, t1728) # t1732: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1735 = prims.convert_element_type(t1723, dtypes.float32) # t1735: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1736 = prims.mul(t1732, t1735) # t1736: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: \"cuda:0 bf16[1, 512, 
11008]\"\n", + " t1738 = torch.nn.functional.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1738 = ltorch.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1738 = prims.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1742, t1749, t1757] = nvFusion75(t1706, t1738, t1753)\n", + " # t1740 = prims.convert_element_type(t1706, dtypes.float32) # t1740: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1739 = prims.convert_element_type(t1738, dtypes.float32) # t1739: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1741 = prims.add(t1739, t1740) # t1741: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1742 = prims.convert_element_type(t1741, dtypes.bfloat16) # t1742: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1744 = prims.mul(t1741, t1741) # t1744: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1745 = prims.sum(t1744, (2,)) # t1745: \"cuda:0 f32[1, 512]\"\n", + " # t1746 = prims.broadcast_in_dim(t1745, [1, 512, 1], [0, 1]) # t1746: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1747 = prims.div(t1746, 4096.0) # t1747: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1748 = prims.add(t1747, 1e-05) # t1748: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1749 = prims.rsqrt(t1748) # t1749: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1750 = prims.broadcast_in_dim(t1749, (1, 512, 4096), (0, 1, 2)) # t1750: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1751 = prims.mul(t1741, t1750) # t1751: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1755 = prims.convert_element_type(t1753, dtypes.float32) # t1755: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1756 = prims.mul(t1751, t1755) # t1756: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1757 = prims.convert_element_type(t1756, dtypes.bfloat16) # t1757: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1758 = torch.nn.functional.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1758 = ltorch.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1758 = prims.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1759 = torch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1759 = ltorch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1759 = prims.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1758\n", + " t1760 = torch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1760 = ltorch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1760 = prims.transpose(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1759\n", + " (t1761, t1762, t1763) = torch.split(t1760, (1, 1, 1), 2)\n", + " # (t1761, t1762, t1763) = ltorch.split(t1760, (1, 1, 1), 2)\n", + " # t1761 = prims.slice_prim(t1760, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1761: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1762 = prims.slice_prim(t1760, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1762: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1763 = prims.slice_prim(t1760, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1763: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1760\n", + " t1764 = torch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1764 = ltorch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1764 = prims.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del 
t1761\n", + " t1765 = torch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1765 = ltorch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1765 = prims.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1762\n", + " t1766 = torch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1766 = ltorch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1766 = prims.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1763\n", + " t1767 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1767: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1782 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1782: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1797 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1797: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1764\n", + " t1799 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1799: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1765\n", + " t1768 = torch_slice_prim_impl(t1767, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1768: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1769 = torch_slice_prim_impl(t1767, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1769: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1783 = torch_slice_prim_impl(t1782, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1783: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1784 = torch_slice_prim_impl(t1782, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1784: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1772, t1787] = nvFusion76(t1767, t1769, t1782, t1784)\n", + " # t1770 = prims.convert_element_type(t1769, dtypes.float32) # t1770: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1771 = prims.neg(t1770) # t1771: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1772 = prims.convert_element_type(t1771, dtypes.bfloat16) # t1772: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1785 = prims.convert_element_type(t1784, dtypes.float32) # t1785: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1786 = prims.neg(t1785) # t1786: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1769, t1784\n", + " t1788 = torch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1788 = ltorch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1788 = prims.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1787, t1783\n", + " t1773 = torch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1773 = ltorch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1773 = prims.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1772, t1768\n", + " [t1781, t1796] = nvFusion77(t154, t157, t1767, t1773, t1782, t1788)\n", + " # t1775 = prims.convert_element_type(t1767, dtypes.float32) # t1775: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1790 = prims.convert_element_type(t1782, dtypes.float32) # t1790: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1791 = prims.mul(t1790, t154) # t1791: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1793 = prims.convert_element_type(t1788, dtypes.float32) # t1793: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1794 = prims.mul(t1793, t157) # t1794: \"cuda:0 f32[1, 32, 
512, 128]\"\n", + " # t1795 = prims.add(t1791, t1794) # t1795: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1796 = prims.convert_element_type(t1795, dtypes.bfloat16) # t1796: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1776 = prims.mul(t1775, t154) # t1776: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1778 = prims.convert_element_type(t1773, dtypes.float32) # t1778: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1779 = prims.mul(t1778, t157) # t1779: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1780 = prims.add(t1776, t1779) # t1780: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1781 = prims.convert_element_type(t1780, dtypes.bfloat16) # t1781: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1767, t1773, t1782, t1788\n", + " t1800 = torch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1800 = ltorch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1800 = prims.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1796, t1799\n", + " t1798 = torch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1798 = ltorch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1798 = prims.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1781, t1797\n", + " (t1801, t1802, t1803, t1804, _, _, t1805, t1806, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1798, t1800, t1766, 0.0, True, scale=0.08838834764831843)\n", + " t1808 = torch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1808 = ltorch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1808 = prims.transpose(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1809 = torch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1809 = ltorch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1809 = prims.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1808\n", + " t1810 = torch.nn.functional.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1810 = ltorch.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1810 = prims.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1814, t1821, t1829] = nvFusion78(t1742, t1810, t1825)\n", + " # t1812 = prims.convert_element_type(t1742, dtypes.float32) # t1812: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1811 = prims.convert_element_type(t1810, dtypes.float32) # t1811: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1813 = prims.add(t1811, t1812) # t1813: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1814 = prims.convert_element_type(t1813, dtypes.bfloat16) # t1814: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1816 = prims.mul(t1813, t1813) # t1816: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1817 = prims.sum(t1816, (2,)) # t1817: \"cuda:0 f32[1, 512]\"\n", + " # t1818 = prims.broadcast_in_dim(t1817, [1, 512, 1], [0, 1]) # t1818: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1819 = prims.div(t1818, 4096.0) # t1819: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1820 = prims.add(t1819, 1e-05) # t1820: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1821 = prims.rsqrt(t1820) # t1821: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1822 = prims.broadcast_in_dim(t1821, (1, 512, 4096), (0, 1, 2)) # t1822: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1823 = prims.mul(t1813, t1822) # t1823: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1827 = prims.convert_element_type(t1825, dtypes.float32) 
# t1827: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1828 = prims.mul(t1823, t1827) # t1828: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1829 = prims.convert_element_type(t1828, dtypes.bfloat16) # t1829: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1831 = torch.nn.functional.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1831 = ltorch.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1831 = prims.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1830 = torch.nn.functional.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1830 = ltorch.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1830 = prims.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1845] = nvFusion79(t1830, t1831)\n", + " # t1832 = prims.convert_element_type(t1830, dtypes.float32) # t1832: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1833 = prims.neg(t1832) # t1833: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1834 = prims.exp(t1833) # t1834: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1835 = prims.add(1.0, t1834) # t1835: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1836 = prims.reciprocal(t1835) # t1836: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1840 = prims.mul(t1832, t1836) # t1840: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1843 = prims.convert_element_type(t1831, dtypes.float32) # t1843: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1844 = prims.mul(t1840, t1843) # t1844: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1845 = prims.convert_element_type(t1844, dtypes.bfloat16) # t1845: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1846 = torch.nn.functional.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1846 = ltorch.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1846 = prims.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1857, t1865] = nvFusion80(t1814, t1846, t1861)\n", + " # t1848 = prims.convert_element_type(t1814, dtypes.float32) # t1848: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1847 = prims.convert_element_type(t1846, dtypes.float32) # t1847: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1849 = prims.add(t1847, t1848) # t1849: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1852 = prims.mul(t1849, t1849) # t1852: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1853 = prims.sum(t1852, (2,)) # t1853: \"cuda:0 f32[1, 512]\"\n", + " # t1854 = prims.broadcast_in_dim(t1853, [1, 512, 1], [0, 1]) # t1854: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1855 = prims.div(t1854, 4096.0) # t1855: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1856 = prims.add(t1855, 1e-05) # t1856: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1857 = prims.rsqrt(t1856) # t1857: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1858 = prims.broadcast_in_dim(t1857, (1, 512, 4096), (0, 1, 2)) # t1858: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1859 = prims.mul(t1849, t1858) # t1859: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1863 = prims.convert_element_type(t1861, dtypes.float32) # t1863: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1864 = prims.mul(t1859, t1863) # t1864: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1865 = prims.convert_element_type(t1864, dtypes.bfloat16) # t1865: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1866 = torch.nn.functional.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " # t1866 = ltorch.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " # t1866 = prims.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " return {'output': t1866, 'flat_args': [t0, 
t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33, t34, t35, t36, t37, t38, t39, t40, t41, t42, t43, t44, t45, t46, t47, t48, t49, t50, t51, t52, t53, t54, t55, t56, t57, t58, t59, t60, t61, t62, t63, t64, t65, t66, t67, t68, t69, t70, t71, t72, t73, t74, t75, t76, t77, t78, t79, t80, t81, t82, t83, t84, t85, t86, t87, t88, t89, t90, t91, t92, t93, t94, t95, t96, t97, t98, t99, t100, t101, t102, t103, t104, t105, t106, t107, t108, t109, t110, t111, t112, t113, t114, t115, t116, t117], 'flat_output': (t1866,)}, ((t0, t10, t100, t1001, t101, t1010, t102, t103, t104, t1042, t1044, t1045, t1046, t1047, t1048, t1049, t105, t1050, t1053, t1054, t1058, t106, t1065, t1069, t107, t1073, t1074, t1075, t108, t1089, t109, t1090, t1094, t11, t110, t1101, t1105, t1109, t111, t1118, t112, t113, t114, t115, t1150, t1152, t1153, t1154, t1155, t1156, t1157, t1158, t116, t1161, t1162, t1166, t1173, t1177, t1181, t1182, t1183, t1197, t1198, t12, t1202, t1209, t1213, t1217, t122, t1226, t1258, t1260, t1261, t1262, t1263, t1264, t1265, t1266, t1269, t1270, t1274, t1281, t1285, t1289, t129, t1290, t1291, t13, t1305, t1306, t1310, t1317, t1321, t1325, t133, t1334, t1366, t1368, t1369, t137, t1370, t1371, t1372, t1373, t1374, t1377, t1378, t1382, t1389, t1393, t1397, t1398, t1399, t14, t1413, t1414, t1418, t1425, t1429, t1433, t1442, t146, t1474, t1476, t1477, t1478, t1479, t1480, t1481, t1482, t1485, t1486, t1490, t1497, t15, t1501, t1505, t1506, t1507, t1521, t1522, t1526, t1533, t1537, t154, t1541, t1550, t157, t1582, t1584, t1585, t1586, t1587, t1588, t1589, t1590, t1593, t1594, t1598, t16, t1605, t1609, t1613, t1614, t1615, t1629, t1630, t1634, t1641, t1645, t1649, t1658, t1690, t1692, t1693, t1694, t1695, t1696, t1697, t1698, t17, t1701, t1702, t1706, t1713, t1717, t1721, t1722, t1723, t1737, t1738, t1742, t1749, t1753, t1757, t1766, t178, t1798, t18, t180, t1800, t1801, t1802, t1803, t1804, t1805, t1806, t1809, t181, t1810, t1814, t182, t1821, t1825, t1829, t183, t1830, t1831, t184, t1845, t1846, t185, t1857, t186, t1861, t1865, t189, t19, t190, t194, t20, t201, t205, t209, t21, t210, t211, t22, t225, t226, t23, t230, t237, t24, t241, t245, t25, t254, t26, t27, t28, t286, t288, t289, t29, t290, t291, t292, t293, t294, t297, t298, t3, t30, t302, t309, t31, t313, t317, t318, t319, t32, t33, t333, t334, t338, t34, t345, t349, t35, t353, t36, t362, t37, t38, t39, t394, t396, t397, t398, t399, t4, t40, t400, t401, t402, t405, t406, t41, t410, t417, t42, t421, t425, t426, t427, t43, t44, t441, t442, t446, t45, t453, t457, t46, t461, t47, t470, t48, t49, t5, t50, t502, t504, t505, t506, t507, t508, t509, t51, t510, t513, t514, t518, t525, t529, t533, t534, t535, t549, t550, t554, t561, t565, t569, t578, t6, t610, t612, t613, t614, t615, t616, t617, t618, t621, t622, t626, t633, t637, t641, t642, t643, t657, t658, t662, t669, t673, t677, t686, t7, t718, t720, t721, t722, t723, t724, t725, t726, t729, t730, t734, t741, t745, t749, t750, t751, t765, t766, t770, t777, t781, t785, t794, t8, t826, t828, t829, t830, t831, t832, t833, t834, t837, t838, t842, t849, t85, t853, t857, t858, t859, t86, t87, t873, t874, t878, t88, t885, t889, t89, t893, t9, t90, t902, t91, t92, t93, t934, t936, t937, t938, t939, t94, t940, t941, t942, t945, t946, t95, t950, t957, t96, t961, t965, t966, t967, t97, t98, t981, t982, t986, t99, t993, t997), (False, True, False, True, True, True, True, True, True, True, True, True, True, True, 
True, True, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 0.0, 4096.0, 4096.0, 0.08838834764831843, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))" ] }, "execution_count": 9, @@ -1016,13 +3161,18 @@ } ], "source": [ + "print(actual.grad_fn)\n", "thunder.last_traces(thunder_model)[-1]" ] }, { "cell_type": "markdown", - "id": "4944f352", - "metadata": {}, + "id": "558f2553-37f7-4b58-b7cd-a744155613a8", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, "source": [ "Well, that is quite a bit to look through.\n", "But here is a key thing: The function now returns a bunch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n", @@ -1032,19 +3182,19 @@ { "cell_type": "code", "execution_count": 10, - "id": "4d90df65", + "id": "59643398-d6e2-4c32-81bd-145a1198b1f3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.9922, 0.5946, -0.2173, ..., -0.0981, -0.5058, 0.2747],\n", - " [-1.1552, 0.5770, -0.7432, ..., 0.0688, 0.1238, 0.6786],\n", - " [-0.7813, 0.6960, 0.1235, ..., -0.4840, 0.1373, 0.6490],\n", + "tensor([[[ 0.4160, -0.4668, 1.1016, ..., 0.5430, 1.2656, 0.2891],\n", + " [ 0.3320, -0.0557, 1.7891, ..., 1.0703, 1.0078, 1.2266],\n", + " [ 0.6836, -0.2871, 0.9531, ..., 0.0806, 0.7070, 0.8477],\n", " ...,\n", - " [ 0.3711, 0.1656, 0.3350, ..., -0.0294, 0.3670, 0.5099],\n", - " [-0.2544, -0.8470, 0.2063, ..., -0.1341, 0.1877, 0.2612],\n", - " [ 0.3420, -1.1421, 0.9222, ..., 0.5636, 0.1666, 0.6947]]],\n", + " [ 0.7695, -0.1260, 0.7266, ..., 0.1118, -0.0238, -1.2656],\n", + " [-0.7773, -0.5547, -0.3047, ..., -0.1807, 0.1895, 0.6875],\n", + " [ 0.8867, 0.4766, 0.3984, ..., 0.0815, -0.0879, 0.3477]]],\n", " device='cuda:0', grad_fn=)" ] }, @@ -1059,10 +3209,10 @@ }, { "cell_type": "markdown", - "id": "7dcec40f", + "id": "17341d86-d4c9-46bd-ac5e-3a05da1ff72c", "metadata": {}, "source": [ - "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced." + "Let us clean up a bit." 
] }, { "cell_type": "code", "execution_count": 11, "id": "6ba7f715", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "maximum deviation grads: 0.00042724609375\n" - ] - } - ], + "outputs": [], "source": [ - "actual_grads = torch.autograd.grad(actual.sum(), m.parameters())\n", - "expected_grads = torch.autograd.grad(expected.sum(), m.parameters())\n", - "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))" + "del actual, expected\n", + "import gc\n", + "gc.collect();" ] }, { "cell_type": "markdown", "id": "0261eb11", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "But is it faster? Yes!" ] }, { "cell_type": "code", "execution_count": 12, - "id": "854f29a5", + "id": "bccec79b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "154 ms ± 281 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", - "150 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", + "208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ - "import gc\n", - "gc.collect()\n", "%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()\n", "%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()" ] }, + { + "cell_type": "markdown", + "id": "1d31e7f8", + "metadata": {}, + "source": [ + "So far, so good! Thunder should work with LitGPT today and we are busy adding the support required to run other models as well!\n" + ] + }, { "cell_type": "code", "execution_count": 13, - "id": "eb177aad", + "id": "ecad9125-bbf2-42c8-b11c-23eed4a6cd8f", "metadata": {}, "outputs": [], "source": [ "del m, thunder_model\n", "import gc\n", "gc.collect()\n", - "torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "markdown", - "id": "1d31e7f8", - "metadata": {}, - "source": [ - "So far, so good! Thunder should work with LitGPT today and we busy are adding the support required to run other models as well!"
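+ "# release the GPU memory that PyTorch has cached but no longer uses\n",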
+ "torch.cuda.empty_cache()\n" ] }, { "cell_type": "markdown", - "id": "d23ebbf5", - "metadata": {}, + "id": "49e3273c-99be-4370-9e59-121c00481b4e", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Distributed with Thunder\n", "\n", @@ -1162,7 +3310,11 @@ "cell_type": "code", "execution_count": 14, "id": "18dd3379", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "outputs": [ { "name": "stdout", @@ -1174,21 +3326,19 @@ ], "source": [ "%%writefile zero_to_thunder_fsdp_simple_example.py\n", - "import sys\n", - "sys.path.insert(0, '..')\n", "from thunder.tests.lit_gpt_model import GPT, Config\n", - "\n", - "import torch\n", - "import torch.distributed\n", - "import thunder\n", - "import thunder.distributed\n", "import os\n", + "import torch, torch.distributed\n", + "import thunder, thunder.distributed\n", "\n", "# Create Model\n", "# NOTE: We create the model on CPU.\n", "device='cpu'\n", "torch.set_default_dtype(torch.bfloat16)\n", - "model = GPT.from_name('llama2-like')\n", + "cfg = Config.from_name('Llama-2-7b-hf')\n", + "cfg.n_layer = 8 # fewer layers\n", + "model = GPT(cfg)\n", + "\n", "# Setup for distributed\n", "torch.distributed.init_process_group(backend='nccl')\n", "rank = int(os.environ[\"LOCAL_RANK\"])\n", @@ -1199,13 +3349,19 @@ "# thunder.distributed.fsdp takes care of moving the parameter\n", "# shard to the correct GPU for the current process.\n", "model = thunder.jit(thunder.distributed.fsdp(model)) # <---------------------------------------\n", - "\n", + "print(f\"rank {rank} computing\")\n", "# Run the forward pass.\n", - "res = model(x)\n", - "res.sum().backward()\n", - "\n", - "res = model(x)\n", - "res.sum().backward()\n" + "for i in range(10):\n", + " res = model(x)\n", + " res.sum().backward()\n" + ] + }, + { + "cell_type": "markdown", + "id": "97e8edbf-424d-49a7-8ed6-12cb5e5d65fc", + "metadata": {}, + "source": [ + "Now we can launch it. Note that you need two GPUs for this to run correctly." ] }, { @@ -1213,17 +3369,22 @@ "execution_count": 15, "id": "2bad9b64", "metadata": { - "scrolled": false + "scrolled": true, + "slideshow": { + "slide_type": "skip" + } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] \r\n", - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n", - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \r\n", - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n" + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] \n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
\n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", + "rank 1 computing\n", + "rank 0 computing\n" ] } ], @@ -1234,21 +3395,29 @@ { "cell_type": "markdown", "id": "9c65e75d", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "source": [ - "So there. FSDP with just wrapping the model in `fsdp`." + "So there. FSDP with just wrapping the model in `fsdp`.\n" ] }, { "cell_type": "markdown", "id": "4a6d7a20", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Extending Thunder\n", "\n", "But we promised that thunder is extensible. Let's find out what's up with that.\n", "\n", - "Specifically, we will incorporate the RMSNorm kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n", + "Specifically, we will incorporate the fast rope embedding kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n", "\n", "In Thunder, extensions (as well as most builtin optimizations which use the exact same mechanism) work with _executors_ handling operations. Let us define one." ] @@ -1277,91 +3446,94 @@ }, { "cell_type": "markdown", - "id": "a63595ab", - "metadata": {}, + "id": "2fe3b40b-c6e9-417c-ab7a-32606cee871a", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "source": [ - "For our base implementation, we take the code from [LitGPT's RMSNorm implementation](https://github.com/Lightning-AI/litgpt/blob/7c1574925f973e64c0a53e056b77229bedee1619/lit_gpt/rmsnorm.py)\n", + "For our base implementation, we take the code from [LitGPT's implementation](https://github.com/Lightning-AI/litgpt/blob/be6139e1fd4b240d253efd58124457496d23d173/litgpt/model.py#L355-L361)\n", "\n", - "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n" + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n", + "Because we will demonstrate Thunder's ability to divert functions in the model, we make a version here that will not be diverted." ] }, { "cell_type": "code", "execution_count": 17, - "id": "247074b3", - "metadata": {}, + "id": "3e74436b-d8eb-472b-9d6d-b6412378fde7", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "outputs": [], "source": [ - "from thunder import TensorProxy\n", - "\n", - "# Taken from LitGPT, who in turn credit:\n", - "# Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. 
BSD 3-Clause License:\n", - "# https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.\n", - "\n", - "def rms_norm_impl(x: torch.Tensor, weight, dim: int, eps: float, add_unit_offset: bool) -> torch.Tensor:\n", - " dtype = x.dtype\n", - " x = x.float()\n", - " # NOTE: the original RMSNorm paper implementation is not equivalent\n", - " norm_x = torch.mean(x * x, dim=dim, keepdim=True)\n", - " x_normed = x * torch.rsqrt(norm_x + eps)\n", - " x_normed = x_normed.to(dtype=dtype)\n", - " if add_unit_offset:\n", - " # Gemma model requires a unit offset\n", - " # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176\n", - " return x_normed * (1 + weight)\n", - " return x_normed * weight\n", - "\n", - "def rms_norm_meta(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool) -> TensorProxy:\n", - " return TensorProxy(like=x)\n", - "\n", - "rms_norm = my_ex.register_operator('rms_norm', meta=rms_norm_meta, fn=rms_norm_impl)\n" + "import lit_gpt\n", + "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", + " head_size = x.size(-1)\n", + " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n", + " x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)\n", + " rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)\n", + " roped = (x * cos) + (rotated * sin)\n", + " return roped.to(dtype=x.dtype)" ] }, { "cell_type": "markdown", - "id": "75ad1dbf", - "metadata": {}, + "id": "a63595ab", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "source": [ - "For this short demo, we monkey-patch LitGPT to replace its own implementation. For your own model, you might start out with a that in your code directly." + "### Registering operators\n", + "\n", + "Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n", + "\n", + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n" ] }, { "cell_type": "code", "execution_count": 18, - "id": "e0bdecd3", + "id": "247074b3", "metadata": {}, "outputs": [], "source": [ - "import lit_gpt.rmsnorm\n", - "if not hasattr(lit_gpt.rmsnorm, 'ThunderOrigRMSNorm'):\n", - " lit_gpt.rmsnorm.ThunderOrigRMSNorm = lit_gpt.rmsnorm.RMSNorm\n", + "import torch, thunder\n", + "from thunder.tests.lit_gpt_model import GPT\n", + "from thunder import TensorProxy\n", "\n", - "class ThunderizedRMSNorm(lit_gpt.rmsnorm.ThunderOrigRMSNorm):\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " # This isn't the best paradigm. 
:/\n", - " if thunder.core.interpreter.is_jitting():\n", - " return rms_norm(x, self.weight, self.dim, self.eps, self.add_unit_offset)\n", - " else:\n", - " return super().forward(x)\n", + "def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", + " return lit_gpt.model.apply_rope(x, cos, sin)\n", + "\n", + "def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", + " return TensorProxy(like=x)\n", "\n", - "lit_gpt.rmsnorm.RMSNorm = ThunderizedRMSNorm" + "apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n", + " replaces=lit_gpt.model.apply_rope)" ] }, { "cell_type": "markdown", "id": "d6b7d056", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "We can try our new RMSNorm: " + "### Testing our new operator " ] }, { "cell_type": "code", "execution_count": 19, "id": "0ebd5dd1", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1379,12 +3551,13 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def computation(x, t_weight):\n", - " # x: \"cuda:0 f32[256, 4096]\" \n", - " # t_weight: \"cuda:0 f32[4096]\" \n", - " t7 = rms_norm(x, t_weight, -1, 1e-06, False) # t7: \"cuda:0 f32[256, 4096]\"\n", - " del x, t_weight\n", - " return t7" + "def computation(x, t_1_cos, t_1_sin):\n", + " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n", + " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n", + " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n", + " t2 = apply_rope(x, t_1_cos, t_1_sin) # t2: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", + " del x, t_1_cos, t_1_sin\n", + " return t2" ] }, "execution_count": 19, @@ -1393,37 +3566,37 @@ } ], "source": [ - "with torch.device('cuda'):\n", - " norm_module = ThunderizedRMSNorm(4096)\n", - " x = torch.randn(256, 4096)\n", + "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n", "\n", - "# we're not quite there to handle forward and backward yet, we'll re-enable them below\n", - "for p in norm_module.parameters(): \n", - " p.requires_grad_(False)\n", + "def test_apply_rope(x, m):\n", + " return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n", "\n", - "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n", + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "\n", - "expected = norm_module(x)\n", - "actual = thunder_norm_module(x)\n", + "expected = test_apply_rope(Q, m); actual = thunder_apply_rope(Q, m); print(\"deviation:\", (expected - actual).abs().max().item())\n", "\n", - "print(\"deviation:\", (expected - actual).abs().max().item())\n", - "\n", - "thunder.last_traces(thunder_norm_module)[-1]" + "thunder.last_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", "id": "8c620a38", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ + "### Optimized kernels\n", + "\n", "But why did we do this? Well, we can now layer a faster implementation on top.\n", - "For this we take the [unsloth RMSNorm](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rms_layernorm.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions and define the corresponding metas." 
+ "For this we take the [unsloth fast rope embedding](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rope_embedding.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions. Note that we include the transpositions in our setup in order to have compatibility to the LitGPT implementation. This change in memory layout of the operands can have a large effect on the runtime though, so our timings are likely not representative of the ones the Unsloth project gets in their use of the same triton kernels." ] }, { "cell_type": "code", "execution_count": 20, - "id": "a7a26f5f", + "id": "6e6d0b1e-ba14-43e5-b0d9-27c0e3b46879", "metadata": {}, "outputs": [], "source": [ @@ -1459,196 +3632,214 @@ " elif BLOCK_SIZE >= 2048: num_warps = 8\n", " return BLOCK_SIZE, num_warps\n", "\n", + "@triton.heuristics({\"BACKWARD_PASS\": lambda args: args[\"BACKWARD_PASS\"],})\n", "@triton.jit\n", - "def _rms_layernorm_forward(\n", - " Y, Y_row_stride,\n", - " X, X_row_stride,\n", - " W, W_row_stride,\n", - " r, r_row_stride,\n", - " n_cols, eps,\n", - " BLOCK_SIZE : tl.constexpr\n", + "def _rope_embedding(\n", + " Q, Q_row_stride,\n", + " cos, cos_row_stride,\n", + " sin, sin_row_stride,\n", + " seqlen, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS: tl.constexpr,\n", + " BLOCK_SIZE : tl.constexpr,\n", "):\n", " \"\"\"\n", - " Fast RMS Layernorm kernel\n", - " Inspiration from a Triton tutorial:\n", - " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n", + " Calculates the RoPE Embedding quickly\n", + " RoPE is Q * cos + rotate_half(Q) * sin\n", + " See our blog post for more info\n", " \"\"\"\n", - " row_idx = tl.program_id(0)\n", - " col_offsets = tl.arange(0, BLOCK_SIZE)\n", - " mask = col_offsets < n_cols\n", + " row_position = tl.program_id(0)\n", + " group_head_position = tl.program_id(1)\n", + " col_offsets = tl.arange(0, BLOCK_SIZE)\n", + " half_head_dim = head_dim // 2\n", + " mask = col_offsets < half_head_dim\n", "\n", - " Y += row_idx * Y_row_stride\n", - " X += row_idx * X_row_stride\n", - " r += row_idx * r_row_stride\n", + " sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n", + " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", + " cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n", + " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", "\n", - " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - " W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)\n", + " if BACKWARD_PASS:\n", + " # See our blog post for more info.\n", + " sin1 = -sin1\n", + " pass\n", "\n", - " row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n", - " inv_var = tl.math.rsqrt(row_var + eps)\n", - " tl.store(r, inv_var)\n", - " normed = X_row * inv_var\n", - " normed = normed.to(W_row.dtype) # Exact copy from HF\n", - " output = normed * W_row\n", - " tl.store(Y + col_offsets, output, mask = mask)\n", + " head_start = group_head_position * group_size\n", + " head_end = min((head_start + group_size), n_heads)\n", "\n", + " for i in range(head_start, head_end):\n", + " offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets\n", + " offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim\n", "\n", - "@triton.jit\n", - "def _rms_layernorm_backward(\n", - " dY, dY_row_stride,\n", - " X, X_row_stride,\n", - " W, W_row_stride,\n", - " r, 
r_row_stride,\n", - " dW, dW_row_stride,\n", - " n_cols, eps,\n", - " BLOCK_SIZE : tl.constexpr,\n", - "):\n", - " \"\"\"\n", - " Fast RMS Layernorm kernel for the backward pass\n", - " Inspiration from a Triton tutorial:\n", - " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n", - " \"\"\"\n", - " row_idx = tl.program_id(0)\n", - " col_offsets = tl.arange(0, BLOCK_SIZE)\n", - " mask = col_offsets < n_cols\n", - "\n", - " dY += row_idx * dY_row_stride\n", - " X += row_idx * X_row_stride\n", - " r += row_idx * r_row_stride\n", - "\n", - " dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - " W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - "\n", - " # Get saved row variance\n", - " inv_var = tl.load(r).to(tl.float32)\n", - " normed = X_row * inv_var\n", - "\n", - " dY_W = dY_row * W_row\n", - "\n", - " rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)\n", - " output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)\n", - " tl.store(dY + col_offsets, output, mask = mask)\n", - " \n", - "def rms_layernorm_forward_impl(X, W, eps):\n", - " shape = X.shape\n", - " dim = shape[-1]\n", - " X = X.view(-1, dim)\n", - " n_rows, n_cols = X.shape\n", - " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n", - "\n", - " Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = \"cuda\")\n", - " r = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n", - "\n", - " _rms_layernorm_forward[(n_rows,)](\n", - " Y, Y.stride(0),\n", - " X, X.stride(0),\n", - " W, W.stride(0),\n", - " r, r.stride(0),\n", - " n_cols, eps,\n", + " # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n", + " Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n", + " Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n", + "\n", + " tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n", + " tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n", + " pass\n", + "pass\n", + "\n", + "\n", + "def fast_rope_embedding_forward(Q, cos, sin):\n", + " Q = Q.transpose(1, 2).clone()\n", + " cos, sin = cos.squeeze(), sin.squeeze()\n", + " batch, seq_len, n_heads, head_dim = Q.shape\n", + " Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n", + " n_rows, n_cols = Q.shape\n", + " assert(seq_len <= cos.shape[0])\n", + "\n", + " # [TODO] Changing blocksize to head_dim//2 seems to have\n", + " # some concurrency / un-deterministic issues.\n", + " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)\n", + " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n", + " n_groups = triton.cdiv(n_heads, group_size)\n", + "\n", + " grid = (n_rows, n_groups, )\n", + " _rope_embedding[grid](\n", + " Q, Q.stride(0),\n", + " cos, cos.stride(0),\n", + " sin, sin.stride(0),\n", + " seq_len, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS = False,\n", " BLOCK_SIZE = BLOCK_SIZE,\n", " num_warps = num_warps,\n", " )\n", - " return Y.view(*shape), (r, BLOCK_SIZE, num_warps)\n", - "\n", - "def rms_layernorm_forward_meta(X, W, eps):\n", - " n_cols = X.shape[-1]\n", - " n_rows = 1\n", - " for i in X.shape[:-1]:\n", - " n_rows *= i\n", - " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n", - " Y = TensorProxy(like=X, requires_grad=True)\n", - " return (Y,\n", - " (TensorProxy(shape=(n_rows,), device=X.device, dtype=thunder.dtypes.float32, requires_grad=False),\n", - " BLOCK_SIZE, \n", - " 
num_warps,\n", - " )\n", - " )\n", - "\n", - "def rms_layernorm_backward_impl(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n", - " shape = dY.shape\n", - " dim = shape[-1]\n", - " dY = dY.view(-1, dim)\n", + " Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)\n", + " return Q, (BLOCK_SIZE, num_warps) \n", + "\n", + "def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):\n", + " dY = dY.transpose(1, 2)\n", + " batch, seq_len, n_heads, head_dim = dY.shape\n", + " dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n", + " # Must be reshape not view\n", " n_rows, n_cols = dY.shape\n", - " dW = X\n", - " dX = dY.clone()\n", - " _rms_layernorm_backward[(n_rows,)](\n", - " dX, dX.stride(0),\n", - " X, X .stride(0),\n", - " W, W .stride(0),\n", - " r, r .stride(0),\n", - " dW, dW.stride(0),\n", - " n_cols, eps,\n", + "\n", + " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n", + " n_groups = triton.cdiv(n_heads, group_size)\n", + "\n", + " grid = (n_rows, n_groups, )\n", + " _rope_embedding[grid](\n", + " dY, dY .stride(0),\n", + " cos, cos.stride(0),\n", + " sin, sin.stride(0),\n", + " seq_len, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS = True,\n", " BLOCK_SIZE = BLOCK_SIZE,\n", " num_warps = num_warps,\n", " )\n", - " dX = dX.view(*shape)\n", - " return dX\n", + " dY = dY.view(batch, seq_len, n_heads, head_dim)\n", + " dY = dY.transpose(1, 2) \n", + " return dY\n" + ] + }, + { + "cell_type": "markdown", + "id": "ed1e9be3-d1c9-4c4b-bf14-a025a03687ac", + "metadata": {}, + "source": [ + "We also define the corresponding meta functions." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d7e6612d-f1fc-497c-9d64-15ef99824086", + "metadata": {}, + "outputs": [], + "source": [ + "def fast_rope_embedding_forward_meta(Q, cos, sin):\n", + " batch, n_heads, seq_len, head_dim = Q.shape\n", + " n_rows, n_cols = batch*seq_len, n_heads*head_dim \n", + " assert(seq_len <= cos.shape[0])\n", + "\n", + " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)\n", + " return TensorProxy(like=Q), (BLOCK_SIZE, num_warps) \n", "\n", - "def rms_layernorm_backward_meta(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n", + "def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):\n", " return TensorProxy(like=dY)" ] }, { "cell_type": "markdown", "id": "b70eba5f", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "With this, we can just register the additional operators:" + "### Register optimized operators\n", + "\n", + "Just like the `apply_rope` before, we can register operators for the optimized forward and backward." 
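A quick standalone sanity check of the raw Triton forward against the PyTorch reference can save debugging before anything is registered with Thunder. This is a minimal sketch, assuming `Q`, `m`, `apply_rope_copy`, and `fast_rope_embedding_forward` from the cells above are in scope; the 0.1 tolerance is an assumed bound for bfloat16, not a value from the notebook:

```python
# Hypothetical standalone check (nothing here goes through Thunder yet).
# Q, m, apply_rope_copy, and fast_rope_embedding_forward are defined in the
# cells above.
out, (block_size, num_warps) = fast_rope_embedding_forward(Q, m.cos, m.sin)
ref = apply_rope_copy(Q, m.cos, m.sin)
max_dev = (out - ref).abs().max().item()
# 0.1 is an assumed bfloat16 tolerance.
assert max_dev < 0.1, f"Triton kernel disagrees with the reference: {max_dev}"
print(block_size, num_warps)  # settings chosen by calculate_settings(head_dim // 2)
```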
]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 22,
    "id": "f8f1e77e",
    "metadata": {},
    "outputs": [],
    "source": [
-    "unsloth_rms_norm_forward = my_ex.register_operator('unsloth_rms_norm_forward', meta=rms_layernorm_forward_meta, fn=rms_layernorm_forward_impl)\n",
-    "unsloth_rms_norm_backward = my_ex.register_operator('unsloth_rms_norm_backward', meta=rms_layernorm_backward_meta, fn=rms_layernorm_backward_impl)"
+    "unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward', \n",
+    "                                                     meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)\n",
+    "unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward', \n",
+    "                                                     meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "2426263d",
-   "metadata": {},
+   "metadata": {
+    "slideshow": {
+     "slide_type": "slide"
+    }
+   },
    "source": [
-    "But instead of monkey-patching more, we can now register the kernel as an _implementation_ of the base `rms_norm` primitive defined above. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`rms_norm`) in terms of our new operator - so it has the call signature of the `rms_norm`. Because - like many fast implementations - the unsloth RMS norm does not implement the operator in full generality (to do them justice, they have a variant adding the unit offset, we just didn't copy it over), we implement a checker function, too: It takes the arguments of the operator we want specialize and returns a bool whether our implementation handles the given inputs."
+    "### Implementations for operators\n",
+    "\n",
+    "Do we need to divert `apply_rope` again? No!\n",
+    "We can register the specialized kernel as an _implementation_ of our base `apply_rope` operator. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`apply_rope`) in terms of our new operator - so it has the call signature of `apply_rope`. Because - like many fast implementations - the unsloth rope embedding does not implement the operator in full generality (well, actually they mainly want a 4d tensor input), we implement a checker function, too: it takes the arguments of the operator we want to specialize and returns a bool indicating whether our implementation handles the given inputs."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 23,
    "id": "6b5c8320",
    "metadata": {},
    "outputs": [],
    "source": [
-    "def rms_norm_to_unsloth(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n",
-    "    assert dim == -1 and not add_unit_offset\n",
-    "    res, _ = unsloth_rms_norm_forward(x, weight, eps)\n",
+    "def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n",
+    "    assert len(x.shape) == 4\n",
+    "    res, *_ = unsloth_apply_rope_forward(x, cos, sin)\n",
     "    return res\n",
     "\n",
-    "def rms_norm_to_unsloth_checker(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n",
-    "    if dim != -1 or add_unit_offset:\n",
+    "def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:\n",
+    "    if len(x.shape) != 4:\n",
     "        return False\n",
-    "    if weight.requires_grad:\n",
-    "        return False # the unsloth rms norm backwward only gives the grad w.r.t. x\n",
-    "    return x.device.devicetype == thunder.devices.DeviceType.CUDA and weight.device.devicetype == thunder.devices.DeviceType.CUDA\n",
+    "    return (x.device.devicetype == thunder.devices.DeviceType.CUDA and\n",
+    "            cos.device.devicetype == thunder.devices.DeviceType.CUDA and\n",
+    "            sin.device.devicetype == thunder.devices.DeviceType.CUDA)\n",
     "\n",
-    "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker, execution_transform=rms_norm_to_unsloth)\n"
+    "my_ex.register_implementation(apply_rope,\n",
+    "                              checker=apply_rope_to_unsloth_checker,\n",
+    "                              execution_transform=apply_rope_to_unsloth)\n"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "eec7c95a",
-   "metadata": {},
+   "metadata": {
+    "slideshow": {
+     "slide_type": "slide"
+    }
+   },
    "source": [
-    "So let us give that a try! Works great..."
+    "So let us give it a try! Works great..."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 24,
    "id": "965ba1d7",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "deviation: 9.5367431640625e-07\n"
+      "deviation: 0.015625\n"
      ]
     },
     {
@@ -1668,49 +3859,45 @@
      "\n",
      "@torch.no_grad()\n",
      "@no_autocast()\n",
-     "def computation(x, t_weight):\n",
-     "  # x: \"cuda:0 f32[2048, 4096]\" \n",
-     "  # t_weight: \"cuda:0 f32[4096]\" \n",
-     "  (t7, (_, _, _)) = unsloth_rms_norm_forward(x, t_weight, 1e-06)\n",
-     "  del x, t_weight\n",
-     "  return t7"
+     "def computation(x, t_1_cos, t_1_sin):\n",
+     "  # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n",
+     "  # t_1_cos: \"cuda:0 f32[4096, 16]\" \n",
+     "  # t_1_sin: \"cuda:0 f32[4096, 16]\" \n",
+     "  (t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)\n",
+     "  del x, t_1_cos, t_1_sin\n",
+     "  return t2"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
-   "with torch.device('cuda'):\n",
-   "    norm_module = ThunderizedRMSNorm(4096)\n",
-   "\n",
-   "# unfortunately, we meet dragons if we don't do this at this stage\n",
-   "for p in norm_module.parameters(): \n",
-   "    p.requires_grad_(False)\n",
-   "\n",
-   "thunder_norm_module = thunder.jit(norm_module, executors=[my_ex,]) \n",
-   "x = torch.randn(2048, 4096, device=\"cuda\")\n",
-   "\n",
-   "expected = norm_module(x)\n",
-   "actual = thunder_norm_module(x)\n",
+   "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n",
    "\n",
+   "expected = test_apply_rope(Q, m)\n",
+   "actual = thunder_apply_rope(Q, m)\n",
    "print(\"deviation:\", (expected - actual).abs().max().item())\n",
    "\n",
-   "thunder.last_traces(thunder_norm_module)[-1]"
+   "thunder.last_traces(thunder_apply_rope)[-1]"
   ]
  },
  {
   "cell_type": "markdown",
-  "id": "0e3e4d85",
-  "metadata": {},
+  "id": "69a93d3d-3a88-4297-b330-23a7fff2c4b4",
+  "metadata": {
+   "slideshow": {
+    "slide_type": "slide"
+   }
+  },
   "source": [
    "And this is also automatic when we instantiate a larger llama2-like model:"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 24,
+  "execution_count": 25,
   "id": "7fff2522",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "deviation: 4.76837158203125e-07\n"
+     "deviation: 5.960464477539062e-07\n"
     ]
    }
   ],
@@ -1742,34 +3929,37 @@
  {
   "cell_type": "markdown",
   "id": "b538cb40",
-  "metadata": {},
+  "metadata": {
+   "slideshow": {
+    "slide_type": "slide"
+   }
+  },
   "source": [
-   "By peeking into the trace, we can see that it actually used the unsloth RMS kernels:"
+   "By peeking into the trace, we can see that it actually used the 
unsloth apply rope:" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "id": "c260cb25", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[' (n_1, (_, _, _)) = unsloth_rms_norm_forward(x, t_transformer_h_0_norm_1_weight, 1e-05)',\n", - " ' (t110, (_, _, _)) = unsloth_rms_norm_forward(t102, t_transformer_h_0_norm_2_weight, 1e-05)',\n", - " ' (t139, (_, _, _)) = unsloth_rms_norm_forward(t130, t_transformer_h_1_norm_1_weight, 1e-05)',\n", - " ' (t215, (_, _, _)) = unsloth_rms_norm_forward(t207, t_transformer_h_1_norm_2_weight, 1e-05)',\n", - " ' (t243, (_, _, _)) = unsloth_rms_norm_forward(t235, t_transformer_ln_f_weight, 1e-05)']" + "[' (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',\n", + " ' (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',\n", + " ' (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',\n", + " ' (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']" ] }, - "execution_count": 25, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'rms' in s]" + "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'apply_rope' in s]" ] }, { @@ -1777,79 +3967,97 @@ "id": "0f6c0780", "metadata": {}, "source": [ - "But what about the backward?\n", + "### But what about the backward?\n", "\n", - "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`." + "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`. 
\n" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "id": "7670a872", "metadata": {}, "outputs": [], "source": [ "from thunder.core.transforms import get_grad, put_grads\n", "\n", - "def unsloth_rms_norm_grad(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool):\n", - " res, (r, BLOCK_SIZE, num_warps) = unsloth_rms_norm_forward(x, weight, eps)\n", + "def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):\n", + " res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)\n", " grad_res = get_grad(res)\n", - " grad_x = unsloth_rms_norm_backward(x, weight, r, eps, BLOCK_SIZE, num_warps, grad_res)\n", + " grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)\n", " put_grads((x,), (grad_x,))\n", " return res\n", "\n", - "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker,\n", - " execution_transform=rms_norm_to_unsloth,\n", - " grad_transform=unsloth_rms_norm_grad \n", + "my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,\n", + " execution_transform=apply_rope_to_unsloth,\n", + " grad_transform=unsloth_apply_rope_grad \n", " )\n", "\n" ] }, { - "cell_type": "code", - "execution_count": 27, - "id": "d31aced0", + "cell_type": "markdown", + "id": "219dfaa4-cdef-47de-b60c-7c7c1642cb84", "metadata": {}, + "source": [ + "Note that the parts are not actually executed at the same time in the actual computation, but just during tracing.\n" + ] + }, + { + "cell_type": "markdown", + "id": "68226a4a-6ad8-43fb-b92f-c1e8eec6f13e", + "metadata": {}, + "source": [ + "And let us try our function using the optimized backward" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "ccc3ed63-ddc2-4b0e-bcd0-f77d66fefe9f", + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([256, 4096]) torch.Size([256, 4096]) torch.Size([4096]) torch.Size([256]) torch.Size([256, 4096])\n", - "(4096, 1) (4096, 1) (1,) (1,) (4096, 1)\n", - "maximum deviation grads: 3.5762786865234375e-07\n" + "res deviation: 0.015625\n", + "grad deviation: 0.0078125\n" ] } ], "source": [ - "with torch.device('cuda'):\n", - " norm_module = ThunderizedRMSNorm(4096)\n", - " norm_module.weight.requires_grad_(False)\n", - " x = torch.randn(256, 4096, requires_grad=True)\n", + "Q.requires_grad_()\n", "\n", - "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n", + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())\n", "\n", - "actual = thunder_norm_module(x)\n", - "expected = norm_module(x)\n", - "actual_grads = torch.autograd.grad(actual.sum(), x)\n", - "expected_grads = torch.autograd.grad(expected.sum(), x)\n", + "expected = test_apply_rope(Q, m)\n", + "go = torch.ones_like(expected)\n", + "gr_expected, = torch.autograd.grad(expected, Q, go)\n", + "actual = thunder_apply_rope(Q, m)\n", + "gr_actual, = torch.autograd.grad(actual, Q, go)\n", "\n", - "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))" + "print(\"res deviation:\", (expected - actual).abs().max().item())\n", + "print(\"grad deviation:\", (gr_expected - gr_actual).abs().max().item())" ] }, { "cell_type": "markdown", - "id": "be218e9d", + "id": "63cb61ee-c791-49d1-ba5c-3fe4b5b9a9d5", "metadata": {}, "source": [ - "And here is our module having the unsloth backward:" + "And with 
`last_backward_traces` we can check that our module is using the unsloth backward:" ] }, { "cell_type": "code", "execution_count": 29, - "id": "ac00153b", - "metadata": {}, + "id": "cd12ca02-6f06-4d88-b5b7-25c4c27dbc9a", + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { @@ -1864,7 +4072,7 @@ " # saved_for_backward: \"Collection\" \n", " # cotangents: \"Collection\" \n", " C0, \\\n", - " C1, \\\n", + " _, \\\n", " = saved_for_backward\n", " clear_collection(saved_for_backward)\n", " del saved_for_backward\n", @@ -1872,19 +4080,14 @@ " = cotangents\n", " clear_collection(cotangents)\n", " del cotangents\n", - " t0, \\\n", " t1, \\\n", - " t3, \\\n", + " t2, \\\n", " = C0\n", " clear_collection(C0)\n", " del C0\n", - " f0, \\\n", - " = C1\n", - " clear_collection(C1)\n", - " del C1\n", - " t2 = unsloth_rms_norm_backward(t0, t1, t3, f0, 4096, 8, t4) # t2: \"cuda:0 f32[256, 4096]\"\n", - " del t0, t1, t3, f0, t4\n", - " return (t2, None)" + " t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4) # t3: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", + " del t1, t2, t4\n", + " return (t3, None, None)" ] }, "execution_count": 29, @@ -1893,29 +4096,96 @@ } ], "source": [ - "thunder.last_backward_traces(thunder_norm_module)[-1]" + "thunder.last_backward_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", - "id": "26ac79f0", + "id": "2776d183-0232-495e-aa75-3b90e799c841", "metadata": {}, "source": [ - "That's it! Do check out our LitGPT studios and the other tutorial notebooks.\n" + "### Comparing and exploring optimizations\n", + "\n", + "It is also straightforward to compare potential optimizations.\n", + "\n", + "Note again, that our use of the unsloth kernel might not result in the same performance as the unsloth project sees due to differences in the hardware used, software environment, or memory layout of the operands." ] }, { "cell_type": "code", - "execution_count": null, - "id": "586cdd30", + "execution_count": 30, + "id": "a5e0ce05", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "eager\n", + "3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", + "thunder + unsloth\n", + "6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", + "thunder default (nvfuser)\n", + "1.4 ms ± 4.98 µs per loop (mean ± std. dev. 
of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "def test_apply_rope_copy(x, m):\n",
    "    return apply_rope_copy(x, m.cos, m.sin)\n",
    "\n",
    "test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n",
    "test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)\n",
    "y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)\n",
    "y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)\n",
    "y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)\n",
    "\n",
    "print(\"eager\")\n",
    "%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n",
    "print(\"thunder + unsloth\")\n",
    "%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n",
    "print(\"thunder default (nvfuser)\")\n",
    "%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08b8454f-c725-470c-92a5-56b2206af0e8",
   "metadata": {},
   "source": [
    "That's it!\n",
    "\n",
    "## Conclusion\n",
    "\n",
    "To wrap up, we hope you got a taste of:\n",
    "\n",
    "- Getting things going with Thunder:\n",
    "\n",
    "  - Applying Thunder through `thunder.jit` and\n",
    "  - using FSDP by just wrapping the model in `thunder.distributed.fsdp` before compilation.\n",
    "\n",
    "- Seeing what's going on by inspecting traces:\n",
    "\n",
    "  - `thunder.last_traces` for the forward traces,\n",
    "  - `thunder.last_backward_traces` for the backward.\n",
    "  \n",
    "- Extending Thunder:\n",
    "\n",
    "  - registering operators with the `OperatorExecutor`,\n",
    "  - defining implementations with custom forward and backward to include optimized kernels.\n",
    "\n",
    "Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will find bugs and missing pieces. Naturally, we would love for you to help us fix these! You can find us on the [Thunder section of the Lightning forums](https://lightning.ai/forums/c/thunder) or in the `#thunder` channel on the [PyTorch-Lightning slack](https://pytorch-lightning.slack.com/). \n",
    "\n",
    "Do check out our LitGPT studios and the other tutorial notebooks.\n"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
@@ -1929,7 +4199,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.13"
+   "version": "3.10.10"
  }
 },
 "nbformat": 4,
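To recap the extension workflow end to end, here is a condensed sketch using the same `OperatorExecutor` API exercised in the notebook above. The executor and operator names (`sketch_ex`, `double`, `double_fast`) are invented for illustration; everything else follows the calls demonstrated in the tutorial.

```python
import torch
import thunder
from thunder import TensorProxy
from thunder.extend import OperatorExecutor, register_executor

sketch_ex = OperatorExecutor("sketch_ex", version="0.1")
register_executor(sketch_ex)

# One meta function serves both operators: same shape and dtype in and out.
def double_meta(x: TensorProxy) -> TensorProxy:
    return TensorProxy(like=x)

def double_impl(x: torch.Tensor) -> torch.Tensor:
    return x + x

double = sketch_ex.register_operator("double", meta=double_meta, fn=double_impl)

def double_fast_impl(x: torch.Tensor) -> torch.Tensor:
    return 2 * x  # stand-in for an optimized kernel

double_fast = sketch_ex.register_operator("double_fast", meta=double_meta, fn=double_fast_impl)

# Map the base operator onto the "fast" one, like apply_rope -> unsloth above.
def double_to_fast(x: TensorProxy) -> TensorProxy:
    return double_fast(x)

sketch_ex.register_implementation(double, checker=lambda x: True,
                                  execution_transform=double_to_fast)

def fn(x):
    return double(x)

jfn = thunder.jit(fn, executors=(sketch_ex,) + thunder.get_default_executors())
print(jfn(torch.ones(2)))            # tensor([2., 2.])
print(thunder.last_traces(jfn)[-1])  # the trace shows double_fast
```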
From 318ba28aede3a2f47d094501570047c3a2d56a3d Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 20 Mar 2024 08:19:22 -0700
Subject: [PATCH 30/44] Remove unused imports. (PR2483)

---
 thunder/executors/sdpaex.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py
index 005171e4e1..a2ff22007b 100644
--- a/thunder/executors/sdpaex.py
+++ b/thunder/executors/sdpaex.py
@@ -17,8 +17,6 @@
     get_grad,
     put_grad,
     put_grads,
-    register_augmented_forward_with_checker,
-    register_backward,
 )
 
 from thunder.extend import OperatorExecutor, register_executor

From dfefdcafce078c1c0491492c57c68618a9816f7e Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Wed, 20 Mar 2024 09:57:35 -0700
Subject: [PATCH 31/44] Update README.md (PR2490)

---
 README.md                                     | 24 +++++++++++++++---
 .../lightning_thunder_lightmode_nobyline.png  | Bin 0 -> 116751 bytes
 .../normalized_training_throughput_zero2.png  | Bin 0 -> 407506 bytes
 .../images/training_throughput_single.png     | Bin 0 -> 310475 bytes
 4 files changed, 20 insertions(+), 4 deletions(-)
 create mode 100644 docs/source/_static/images/lightning_thunder_lightmode_nobyline.png
 create mode 100644 docs/source/_static/images/normalized_training_throughput_zero2.png
 create mode 100644 docs/source/_static/images/training_throughput_single.png

diff --git a/README.md b/README.md
index 0f59d0007c..0a19713d1c 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,30 @@
+![](docs/source/_static/images/lightning_thunder_lightmode_nobyline.png)
+
 # Welcome to ⚡ Lightning Thunder
 
-Lightning Thunder is a deep learning compiler for PyTorch. It makes PyTorch programs faster both on single accelerators or in distributed settings.
+Lightning Thunder is a source-to-source compiler for PyTorch.
+
+It makes PyTorch programs faster both on single accelerators and in distributed settings.
+
+Thunder aims to be usable, understandable, and extensible.
+
+## Performance
+
+Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
+
+![](docs/source/_static/images/training_throughput_single.png)
+
+We achieve a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.
+
+Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision, as FP8 support for FSDP is underway).
 
-The main goal for Lightning Thunder is to allow optimizing user programs in the most extensible and expressive way possible.
+![](docs/source/_static/images/normalized_training_throughput_zero2.png)
 
-**NOTE: Lightning Thunder is alpha and not ready for production runs.** Feel free to get involved, expect a few bumps along the way.
+**NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way.
 
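+For a first taste of the API before the install steps below, compiling a function is a two-line change; a minimal sketch assuming only the `thunder.jit` entry point used throughout the tutorial:
+
+```python
+import torch
+import thunder
+
+def foo(a, b):
+    return a + b
+
+jfoo = thunder.jit(foo)   # compile once...
+a = torch.randn(2, 2)
+b = torch.randn(2, 2)
+print(jfoo(a, b))         # ...then call it like the original function
+```
+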
 ## Install Thunder
 
-Install the nvFuser nightly, which will also install the matching PyTorch nightly:
+Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, which will also install the matching PyTorch nightly:
 
 ```bash
 pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
diff --git a/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png b/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png
new file mode 100644
index 0000000000000000000000000000000000000000..831a23e2339905636bc7b0a6aae48ed493b8113a
GIT binary patch
literal 116751

[base85-encoded binary patch data for the new PNG images omitted]
z+h1RXJdwqyjY0NPX#|~YM*LIv+@!RdC!3{LwND9i>|gdE@i|Lw@VDS|#oYFW)0B54=r|hPkghuuQjg6%!CRXH)N-n(#YiNJeBd^}t3y#q>8EysG^)=c-5fuf>24Fha`s~-#~5--X5LRZ0so)XkGObO$n(`$HSYd1HWsY}X_WEvEx zTIw*oE<(#DN5wYfIT5nr%dtlKXtTnD+vS}fO_oYbd?NXKymsylml&CO{VF#NXeZX! zs5dLk0k3!r*)nbky!I7O<|b% ze3e@}KXjAm&hEr+L()<%L`x=(%Q!&=OcNB?Y#HY>BpFmU^YtyoFRdr$Wza-|&tgID zd8*xu2zp80Sgd9@lZtEHEondM_;ckM-$2pQ%PUc#Xt<>``o5Ly#GCifk_497gj)#4 zMv5#ah{&&9f<^dzqvnSObVd}#-DC7=Y+W?A41*) z>XHx(czl&zlkB)1b6ES2v*iTeq2qQJe$M6ebP~7nKNY=QVv42c*X%bNA0QSgu{4Lu zMjKi}Jc88q>!7ATCUi`FhnbNmWtNjHNHTpwU|)ke8auJXM(KwNVJLucMCC_4L^#0~ z$PgkR{fd8dgjPJ$+_=LL{B*WCV^I4b|fmYQf2 z?2}4-V?!%en^>!TZIV+bJJHFB#L-BGW{_#I$n>!q-?~A=9XRH@Qx}6P*x@9+ECXF1~J)BX}Z}%9RnyU-R)oG7zHl$i^z}g`9JUP7KU5^XX$+OwDEX zMd*>MlmQw+pud=>@?7t`BN$~W@$>e3u_IR7D+zU11_+3#Y^*a6VTLHacF`VB z`(7$?0XVmxW>USXCYd%iU$cj|>z?0EP`=Wv=W|w(s(*ikAdbcx-)_Azay*;2&ozdk zpY!!UR{bwdK}zF!uqN>(jioHjiGD)N1*7cmL}q#>e$5q&fQ#Ha3A_mmMUS+Wq&7O* zgadDo%P*97PFuPVQ-^fg@|I#1RZf(kJIkcUMsSXUii*-Ciuk`+>S(UX|FL+3!t9wnB)oU{2`7A+VESX*KIsZ%j| zi2!qcdGoQG(8Ne3EzWC%0nhdoCabR&5vUGa>ZyAY4=fJ8otK}co^h^Gx4oCo%?;lF zws*W+tl4UxXoNtM`keG#*Au0b#_F_5%JrJ+vJbGJ$QucT$7OH^0q9L32`P!7WhBVb z2XzOzp%ib3c$2dY>e(m^f$FK|trNAETWSF#nAG^{JjY^nM4L&w1(!&Yc;L5!(s)Pu z3}tCyV^c)Y6t84fv(CgsdQHnf%hm-ctZ76<~?NwJ-j?sbvGz!^7z6V7^4%+&cIqb}-nO3FqgN(}vm=D&@ z4a9L^39WrChXpUk@kZg)9juDw&%PCxUPuLIDL8?i>6%XB8smCV<_60*(?sX3^VG4S zuT2PSwjB@8xrZ*8{=`64C(@t6vGAMAW;bsMh&PM&TUsYhSS2zoq$OblLuqzuzJ5Yh z-?KmD6M|ghvMuX*OuA*;C6m+1=bK==+($;DTi*Vu>8dB4RESNxY zn__kE2)hJwX}x>`3Pyooz2Umu{XAx=^g=t0McU4eiN)Ccc4FCNLGYMWL}m>l*L0{E zqp;LFUlr2|sX!!1)WC@>p*yuFQKxAkn2_lQi9}VFRWdSI%UX(4vU+ARO&8Us&zk5Q zvX@4m+HLsI%Pct^!r9_2y^#4)7rZpmFx6f-FxY8}=^5fe>B;iIpLP` zCEoz5n7<^|(Ji7M$ePLCI0hm#XLaqQm3r$A*xXdTzHe4m^0knfe7zHz_j!+4zmpImRL+`hfTUg7`!fpT?Ip=c`T{M0XcVm_0K$PMa) z4a&%V-vwDR^_9~pINqDwgpLnDpva_m1t`EUPnK#{|Ai6}@d3yX7lTFdCn={gm?Y-2 zjyt4|VVM{dWLE&fJ7W~?ouN|i4T@4w6Ogp=8r~yH6nw6vGLo#=Pp`8T)nolr=c)wM znSC^&hBK9uv|$vAb*{5NkmQNp$tfRv|H&vBQI-JslzT~pqNmrpBy;%_GiYwEx!iOo_{jZwvGKp)*G~r=NWlJE^-5r>=m~|9vbhH71S3C1HNojI6fK zh3=iV{MJ#Itg1Z=Tb32WQbtf#g958bRzZ&y&(?q<=X3LHxa?W>Obs01kg;g$!7_EU zVkb)Uy-rS|$CbH*mV#|X0TPOh9zk$o5R#*9oh-lrr%qD6o^O`)I#|s9of~DmhMp%r z%CwE9@})LoJ9;ZpdXX|%>;e`Xg;A|_raSAcyTlQ41f6yG%ffUDBvDRGtmch9jqR*! 
zR6C4axTZScr0LR#`M2-#oY@t&gWTAY-qbRTqG7l2al43PRIA8W=xgkctSNLcI&Ren zfX!{Q$V+!hzc5|!z8=HY`YH;rRvmq4>dy`pYgvv?2;DaI?8R-Tu6}!!Sx~Yj0p)rhl53^k0de>?cqw6SO(GoBiG8g3g z!~F=w=247|7^RC)jTZ0|^{%DMm|==%3^!5C9?Ff<{NW`-U?vcM4i;8aN2W73xI@jSa=k~x>Z^v0cmejM*QjxM z0M{1FJ|wR@Z3iTZ@TN2iVV`VD34}>O(Z`V&i>x@0uQjU~Poe zd+$@T!(47%-_&C@Xo91O{-tZYgt74sp6}fM9=Z`U^77=HeftFAZyQ)r>O@)3eb3?AW**=TH3MuJR}IYrU@ zMF4okbFceM0vx`aLGqSAjP(D}CXvTgsny<@+v_gqm9$mq{h6se=F8cFN6{K@#E4UW z|5TU$BY^&h!qh1)NVH;+-bYzY;ZgN;Twjy)JXBP=R5p3KRx>!U$eiDn z(oY%C(gWU_ofTnIViybgrpU1C!%JMsbapIPd(v0;8$;_)ts9q|ftSaIF!A#M_;WiO z54LSuuAdDcJuu#CsP1c{&gb$6*tOOIycK$Vu6<6J@(eV1f1Eu~ea(^^klRNVclPpm z%KOq0_di3>;>$A4zbF-u2#a-^OM^ z?mO?r0Bm^4YJe5Gi8{cSouy7t%JQ%WU6{0V3kt~2fi9}h8T0VgdfRGkHAtS2^omaS z>d!T)1XXL}Pe{A?iqJ{}-qxB?&usmsN9ff_V`tG9tjBqLSs-e=EsxzGEjQp+_byEY z)`t1l9Dta3-%TXfLI#I+>K_`+8{@@SOSu0ubI<(umc^FZVP_m~XGe64-eHXhTbu87 z9ss=#(eaDjkHhEX^Z4=oUgQl`;}659VkV+st~Bo0%^TCps_Xx7{BfnZL?oh+0ePp* zi>r@U_iZT&(~3bE6c}HNIIob6B;DQ>DZIh=q)yeKk=7<*i)4{$yKfd%{OCR!5$sfB z=^rR#WKtnl9H1LthFjPtAusy#8$)Trd41J^1_7*<01{f(DSPsx9!4}9Gb2fvppcP1 z>`ZyaGpR_%DxxO4Z_CXZ(C=3Di3gau=~}&aIR41(A@I8fy!YJu^u37IJh4@0l|Vgq zIaD)H_C^(wJ)-?)g%z^Z5wEBt$Z+^uo=(}Q1;X%*LcObLv#%NM-WBa5OHg6NZl89_ zgF4qXbFE6gzihWY+FzG*ZGESZs zNpL)A=H2W2^W!~i=lsTF7eGdi_uuY&^j=@>2ltmTML8Huwd7}+z;N1AZvWuZb0$cF zJ4Oo4>CcAn_d^SE-U0Bp`j1W@r5dR)|rZ84Q@y~zD~XuiR=XE*mGx^btivmy=qZ&c=iZq zRjBR!3(i~~pPHotqq?ARtdprGQyw3pmr91p_GVn#4xWlQe5c!?|C5l?#%g83XUYz5l|I*cQ}AS{DzZ$@e=Naek&IT2m^ANQP1A~COf9$hOeQit zzF0P3rCAJ8wNe!2k&H7okyH?Ngd1M)$CMo9E?`|aD^XC|l7JzG6?+%Jmt``f;28cV zb6pfKifnEzEnJg!W&;QKbD+k=kr_BdQfi1dIW8`4+@{(aoBzqg_vTA5asw#ihPI^F zTe_;jjp-@&SpD$R8; z&@{Sv;*=72PGKOfd+<5*CdtVw41+s9e)g8vTzT-HqD|r!<+fmLyIDQiKF{yF^BPRmgK?9=llC-*Z14Q&3o1Fd5eV*@)f_)9FMJ7m!lVPz%WEym|yblBo)!Zv4Y`+gdM@CIvKI5Ds z*JUZkIJ`&Op6~If_K3Wf!)m&O{7-_eLY@3#UJP7ZlNEM9dFV9+I8k5 zkb}D5v*iNtMfH)#r9HX*otykK_`~{!{DuX7{u=MbGG|WOu88Z;RYiA76=2@WYe-yP zI7(352ezFq3@>%Udf}5kad+J7m$s`Im_oiZNgiPAtdJJKLN8^ING*40LhN#9;EqJQ z<}szMfgU5TOj{NzAW>ovICY-VECPMDUBS`#$bxFgs$?AR*~}d#Tca@!AM#uUP}IHp zmFI`Q`w}NF2%+bG7+`E!$7plg1eV-7eH|Yc3Zek7n5==be*S)>602mP!L3(H4|dQV zceoDq8Qa%FiEzpU1MtuHF0NCC=_-+m`Da3m!NXB*d-55887AB&cj`-X_c_FroT69e zY(#9|G->5LbaL8fYQqhnw87cP4vP=9M|FqGlPRnJ>`C>B(sRuGnrb!BYNk55yF@J5 z2^Yb0rWEM1(fXfwm5Z6%?5FH{F`Z@COA(hr*P#7hb%o=HMh())sohBii1$&S+G{k6 z8m1ZPisE0zbU~9SARth5G z1{6_8-+hRnLsN@ahmw|=3N?6F;&wTx;rJ*peiF4y6&;`wQ5lTV3npjuY#UQF-NJMb+Eq+X*V$%2%c#Vf;cToFtMSsi^qLLh zQkKmqRYHRNPUFe6w1yn*67=nE&}GTrWeUf153h+77zsjWO3>2`p%zY$OpkJBtKsUF zq_R5O$sN@>K42J8L5r1Vm7?R+Q;x0$iA)~=fLtGG-d7#oM<1-WT4SjIAW5y4&|ac* z1kNpeP}3$+!X5{VM~_L9)mW9(1jX*4)i-~b=7wU|l2you=FQCMK$Z!nnu71p6guK< zx8cRLAyMbA&#q>fUEBx;r=C4;dXhX2O`crto~XQ zWFFq+td8~}VkcB6Q>I9kDE^;6Iu>w(>V*nl@{x`8Cz$%#R|AsW{iDg%kCqkz`}H6~ z?8Vb(AuYcWoY5SSjKgN`$WBCy;tvUcpmk{CAQ$3mn19l#wI*xf5`Nm%jgW-sgISd; zRQ*Kf@hJjn5^#I|0|KGU`}l`Yh3x^0D650TsyCRfs)xF*j*j-cjOi-evL1Y}v*zcn zeou5hao^u{dNyQ7c7BIqk9?nHH%lLY8&Jg$TyQ0vAs4g=A}u!TciyGW&>=mnUqDkK zFD)xj7&#FBJR6tR*1-2v5^dkKt7+9@e{0bpE+P9};r|so5qlM=2asqd>cT3B!~fe6 zP=YA%#Hfs-1c8Q5kR zM8GCNt=_k)LL|y`$o}8v`_I!k%-Qg6W4e~)Z!veoU1IwKmOsV5;szSs@!D`y6b3de z4qBwua2*6pf=S?;An4W-j_&6OONF7HD8UldVEW+H^YomYV6Ez2(pjstz{8l6MoZv>;%fB`JL#PrBfV7p+ z>0VQ|V7m=ImCymGgXx0dhCuBO=&g5sl}ljM3-i#cC|n>Oe(95=l#8&3dQ@{^F~Z{O zF*x@&U3Yc#?mT?vl2m8JH_}<7a{nI>E>TpIfx*velAWd&A?{7h!2~HP3XdQ>R^%7P zjBJRYAud7>Cjt+hY;+bIf?Irs-Y2~p;L8@-GXM1qWnERW8RC4|j!_{X1EO!@lz{=n ztPC~2JZb1}wwH>%+L*kT$~=|{D!hymi=rKo$&j33)XP##VpJdq!L71Z9>5-KYT{(k z5MX2`=I=9j1GomQP{em?qI54iApN#&o2%d+UWyHyQryZ zB(Q_;>kk=@wToBEsW#UH4<^KkZWP9@Ug$43g5H;gE>{2nX3qb+1|`V^?koy(Sh`w^ 
zE;aClZr%OQhZ-jauHtW*%Z#{YgT@21bEUeDfi!tGSp==%Rg}$Vm;_ks3E{xR_$vCp zS7U+CjtL5##F$SBlzuHIAlO%98 z`Nk?3{@+0H5;|_xOdXSfb^&lB+-yl+m#YH+pyGgZmtZNQ3x9EICx^!eZg481v^RoH~;A> zhh=bF1@m|s8%1}|QgNG>)Gd+0qe`s#ICuZ#!l#vxC?V7(1Y;PeZT%G*&~<092~%LS zr2Zl0mm8Co`MEpDQ<`Ik#TiR9wqyj6Gw_DX7k6Wjj9_Lxay=R_TH&=23e#6V!J3nz zs;)0C4|0{joW;8kSFtVgj3!+XA~_Tl3JQ#7kQtG>My%$ovytF6vGcB)#r$0+4US3( z?zGv$99QLlim-Mmt0|r8rL01o28#y*nz6-_tk$6yFJi;mD~=#)#seG$1%u#e~giza~g?bB?-W!q?;F?Cxjjw$jtxF3By<^ z#G;PQhE0j`E&iFm$d^ey*#;5eZ)DMfhsf1#7-IgLw~}i%b_&?L5u_fqv|T_D8~Zj_ zQ_Jcv#=QC#f#p%&_HRiDI*oXec91pk20`D4UYIZqaNt*l>_#L`pw%K{=mHCzY)b7R zhlAHOryvli)&{l;&r%R}nJa)w@! zA&EH*?4iR5nS)sUxjV$dC#~UQ7%g%MV;ljsq>1F5BR7@#3QyflrL&?@9lto^3@`Cd z>+s8o*N3?+O({%~UdX6diF}L|0Kr~@Ws+*X;lPT1Di71N*pk1qcFwRbMBbQ5DsAUE z=uXT6-%Llm*nZWs{_wXF;WH~zv+q7&k>|X)U09J6X5c0l%^4hqAozW*42ekf*JO9C zpIQ+_&^7WbVsljdUafsCkKvV#Y>EJD-3< z2>BEt)!Sru|6ww=r3*+lZZoSIJL$2g{z)iW9CY<&lE7qZaE7KY7gB<=kLF^`qN2H; zIqH|98qiL-oE24UE06D-0;1!*)-NtPHBu)-glOf@4K_l(OeZIWs!%0XP^XH6G!1fI zhkB*$Qo;C z0|_T!BSU;U6sJ8!yq;6OOapZ}CLD<;)+)a?U9ipSqlho5t>t?DpSBhi!!n!)RiQ@> zz+BTuOf)G>+i@aGvsxAfTWTKZZp^K^|$BI{ET9#t^k6-?A#TOy(miL!*Dzy2Z zTxM+G`JuCL`M%*ntCXPaX0{1u9jTyz?NsDe4XN5HvckRE*uy+|JFMe*rHJyl_)#k( z!2UX~u>b7m{IXKTa>IG^znva0won6hof=iP(%rO&5fKhQZ~L$Jr>H^%@)oER(<&R6 z!|$(;vJxNJvfZ@j&-_-cFOQp^mup4Jr*O|?CqW+`4vvnj)8Io@p1I^+`*$dxf^VPp z{A7x`g^J}nnA1^5`0j*5E|=-aK9e-19Wy#Pyfso4k2e|i|>Fh2K^90ad^|pTZ|!-vi!)2!T57lb-9KCY)7T^&}cAOjHI@IX#v*d zH&H-T=|;T{*yE|)pomQV#_^HXe;}hqVPou=p&Aq|2kdJCNEx825?bL!^Rba7ih*^{ z^^1hI|J{R}-o_`X$}hVe3T=bW zV^4mdh~W+&@hg;5{rWo@^u9xA5qUe557(Jd%R_l-U^6mMZ$DloFiufwi7z|Pdzsq< zBp6e&!Uh_9`gH8v`9j>Wgw_3Dh~O3uX*s&`j6Y2Pu$qz6p?c>i z{1UXXk}%}M1K6a8ehHqLcvi>DVgRT>BX!1AT7MfF{kF-K)ZyTHT2^Aa7y{i$O!a*z zL6FuZM{}J^?aobU%uB_Z#8S6=Hy{gq_xKx|>v59d8RVh+4Djrr=kBxCxEL|(DJY>i zOSkjQ^L;;Sy1&(H3=U|QWQb&Z+=N%pze)TBe(wCoyNA+vc%+|4>34Q{Q_^wT7Z zs~H2Pb@@zb+%2ZDz?wgGn1fc?&$e%~c5CbQz1nu$A=miOp!k_H8e}%badGJ|yHIdFOQ+?wdHn*Np`@mQ=VDvU zS!`7wx5x@0RiQb{Zn*l0skGrn-hKA)#9GsD!TZ=x?e zmNHI;^#^sp;M@1m&X!KX*ETs-_45vQPFmXOnYc0@$TFnv@fA(Y zhe8~OQE^~hR#Gb$osFAW8e9AEz~^j^RY#qGk)#m2w^H+~AB!rGuPoIbO z(=6q4`I5trk;NCie5q_Km8{x<9ojYhIyzuR>km7{l|yZ^nG%~_yK-JVUs`bdHV>ad z+k&|zBlAtasq!x#FJSo^?V@g%YJQunwPoblKgXw3E!UEsb$Ayn%h-`u0T&xm&t~Hr zrQd-Pf2AO=@*aCX4wc{CdjK4M%U2F@3S$twUekRAUMmt~qcFt%K|UyRVC*Hb zG>QMbVe}29t+ql$XyPy#mqiTPsGv*b!t^6ixahYA`&k!}@$L0$fzG{4u@*75lL^b|e|c1P z2sEoaf1g6pKeDa%lKTqNp==Bf<;TX(4u}(clrr)($!51cRiR+jA-iE)j_t9em(j0% z1g^Cg+0|G?_%H6bj0ueEz{B(Ft3C6g1K~)p0ryiTR%=)TB?))~55PWEYK04IsIIf_ z_(TEKjC7(DRD#pW-m%ZXxKqSdNVu;3Gs6NL?kw{|E(b10G?NT92+{j%l_6>73hP*R{He8VCqKBt7_*q3+!r(Y^=J>Djxpw-S+=#FApaoW4@@wsY@Dikc&N;QJMkc*2N-jx z^5k844-RQ>jBs?8^B#E?#1=V5r`>;MOJ0eZO!vr-Zlzg#Th((a93In1vx=ACN^!kw zsTRn5y-8wd+2L-w@Pkn?l?QIIJz`SUmGKF7LFl9>=v zATCRH)@O4Iz}feul)Zz7v_MbbZvOO)T%5wptSjn&O)eU{e;;f(66q@oFCa#iTfs^r z!p=L974nJi3T8;^!P5jTKm;km)~YKWxFYnY%53a9vHMNjp2`#(2HD#DA!-a#TTNhgl zUJJ9vXV0H<@{XRjFHiatC5F|;r8}u~K8tL0w+`x!2`s>^pZl*4$3mrKGgy%8EnJG} za|Nmed2Aykzr*`L?|W^2sK%kAdrF7R9{HdfI8W&7|I;gUS=k?^{bR$oT7c9E7XgNJ z-Xru7lFWHqQDi-Mkxuexog4bva<>aCJLsD${V}^lkY$cI*7wVZ@3AYvs8)69t`3s4 zH3Q;g_OW^59yydoD^3}$rX@JoyA$F^4GXLF98jYT+@-MFr8x9uokuJ|P_#|MDz-Zl zx>`0OZ^gNK4*aUdZNlw32~ZR(`%lRb3b$`ncTZ20lb{L@-M^qDZaQl5CjuFpqz%?a z1D-ODVxN80v!EyL2TGiBdAH^vr{>lLuH$yQqc@I}OgR(xmD|gwjgHQ+IdiwlBsjw_ofTq123LzA9T>Zg02jjx^uWi#`l#4KO5zK(knFW*Jq}wHfBbGdjhIQE&INSJYC3Pg-!J$}*c)Lt1#8LreJwZxdh! 
zLv%UYPw(&^x?wYr??jp?T2N^r>sz74Zr2YKHQ<7x3!%$b5&w#*vXRv(CDJA#WcMXu z6n}hzdRj(=o+4~QF?AM=JXP!@SQSu0y@@(v6`+&qA1eUqx;rS<7U)w%wTW-rZ8bq1 z6EFq7Et6Lx?q7+|ZB%%^(f}I60YZ9j3Kx{gSPrm5wiJm9Y?!!q&#p~4=VJIRy`tI* zUA>ItdKwR9yfAlpBy{IwlY3?`^h#D}5&c)0(nWtz#J#7Xh)3r(1gp4i`BVNpp(Q@* zsnhmv8$1SLvFAR)l4XmhjoO|=rM>QWSe2JD580X%Q_BQJSu^Cw1ZCJV8W5>@p9>qv zR`6TJ^0P|az~s}MZruQo**zvTOIfOYpF_Cc+yM*n)g3R_+;5qdn@dl_<0A38-+EPF z9~EV}o*A|Aj~dImHyo#n^8<{wb0;2GzquOtJRpM=sx(1$@QEx{7~N6I;q{Twi@7RA z*@200c|j>){g?T(f(RhYqwGV+^Yc*z#u4O@$pwe?FG;BcU786+j7e~)mB|Pf+v-ce zXuXj^@Q_p)mY&nm@-}tB| zJVyPUQfMTfWD;-J*f}=Kh|%`muoe?wrmoVpY7bN`YZ3WW_1fcdu~#0um)MgSv5IeK zEeNdMieLz6oCDi+;W;CSkN^0>kP!ZOW(E;(BIIf|{=4@o^K;G%II)!MNbQzDlylNp z_@J?Pb)O7-w(`3h>Z6{5K*1_pLLLS&X!!R7mVPB~RGh2V`V{15SpahuhnY(O{Xx|C z<@J^>X)Su~Pfjh`B$4xg%h{+*g@7bYT0(`RvgheZwj#GY4X>l^**Tf7;9A*LXFRI> z>Z-k~j2&mg81f!U%Je7ToAhTkHXhl!j&o|ZbYQL)tVu*}#2YH5KhGn1s z2r@~op_(VfUpA+y44;sItFv2E3-ue@2NhE=M}an;WD^XFv_V|eXeN9%;GKH8@p}RO zT4nP+dTcZ4+*Bil6wFi9KW7zpDx)M%g=bH-QDe22uROUlI;YOybVWX#q;1zaTw+(FbTUM-^cLN1%^sX*zwN2;(bC47Js<|Lz z(0AG(d*qM**uR)aX8S8Ix9N9>_>pmHpS~v+B8{~Ar#lG*3`4739^)=$$`<3Zpz+Nm z4Tulbp4-~`0_`|nKUgcD>4K$J(p|1&VZS6)(q_GYB0bOlx)z`3E9Xx{9TO~NNpaI9 zr`4rCtX^p|xw2IvSWT*bpZ8{a;T3WI-5+)9eaKNWH>gBVI)7&|!OmH@SH8GvUfz^S zmi!6d@BNQC@2RO}S^x2a)W;hQIGXO^j~GslE9(lMo=#jBi9ed@x<9v)2g)2G=EWr& zPW%x1_d1$8O*gE{Ii~a7s*oKsYZ0<1JXbtrCmD*pU*1&d2pt-A!~H~A9TI%3?WML| zE0YC=Qc453Z8*Sm14Z$uP*Dq|F=%R@;r~ zgE7jUtQ$H2(Rw1~i&K>OHe%{0apN`GF+bf?gJo}MerHR{_{1eAN+I!m-)nySeRoxO z8)jRFif{s!;`AWgh^nyyDw})|do6YJkoZy%cmgjU^$j(4VEGnC#E8b9v{ViA-C!Ry zf@esrx(bjyS;Bosp{Wv(Mnd`Msx~4ju}6!s#|<_s1LOz!a;M<>pIGZC?8PB5P@PH& zn?R?Kb(by}`bThi2}5G6kPhsj@LZqpui;+{+DHht+L%tj=G*FBt)eEBDVnY4BR}HD zizRuwR{2DjRl@R3lnf<&cm2zsfeipzY2bhwYD5X!<9wQbeH*$u;S^7 z&NBA0&|%)u$BWQ|0qyITB5Ntd@=H&nlxwvU)Ht9=xHigm8zMtlMD7K}B3Xn>W6p&1 z>K1g>_mf~dl1G)B)(+$-H6XDEmm_PwU~JsET#1L|W3LT=IQi{MygLCVlA!V2^7B43 ztRC^c(D-Blqj^4thcBqrdAYc?#_bR6b?oL(@d71QJB zJMDR6gIxUy7KZ9avsp-4!(BX$DX4ofbRx*%W0%;w5ges)^+!@|cmoI7TV42S+>O51 zFPcKoP+Kby3_h74z<|dU%Mq%4^APJxpXP=ys4Ja>Bc+#?>xZvpw6y+Uoz$|pWZ;Bor8Ox*rfOn-); z;RoKxT^v-opc#{=e40S+AEcnsW{Df3SMIl$tXL*nCg4aK{5qhx((wFpCf_Hvqs-=# z+|P`T#cwuCS9WB6S@wkU6cwu_;NEM4Jdbu+b2s$1RF54)EC_bN%94^VEnrGe-Nf(2 zOeuK%mPdyZR1g1AMr`&ppFYBzELB)5mYrXOD z4w0|68M6IMDxI+&gZ4q@!(PYE9x-;%-P{Mw{=;RfMfJx>d~i*`=(vF;R5gqZh!vFm z@}K^2xpkfmP||ay&DfW`ifMHWk%QMl62f^jSw$UriX@Fr*j{#)q?viIK!t*5gReLn z9@{lo^WQqeD;L-g9s$KD549blFy9I?@Vvb50AGMos= zSPWdaSS|!Q?+(WJu&#?+aX{Ff`lyk6{wsCn2(K}Cm>u$lfnXD>2T4Fg zf7=vmG=LxNHKbrl}4V=tziP5{=&wN z=dQ51yDbYAki{Ov=Tc|Brn_hrYM6Z3wJ9U#`z945;dNW1+>oCtJJe%a)|9w$q-?Rq zA@qjz-2Y!Bt{uj@Lj=;jAs5l(jg#M=vGV9GCU|pmrGUp7wz77eX(I?8nKxFmMz!u3 zU9#kdieOfU4ilHAK@Ktx)6xs@T}>d zVdLcq>bZ9$?7t!J6{C2`=Nq=Cb5`5Mnqpb-7bD4$_Qai*NQp$LdcHO2Web^ zrsV)sYZNVbEB4hF*uMc%t=ym!VPh7|8zfqc-W*Kh<)?@qtJ|qo*iw=GRT77Og z*b#QM;QfyT16h$y&@m90#U%$SWVo!0Q~b?YojBz4n6eM<%k>InOM^ugJOoy2F%AKr zW6H4mWys;W#r<`=+9-~|(Xty*+zvJ}B=+2^Q&mzXTm~gGJ!6v~*wUgCk^#>|!+No4XbI#!Ae)2V zhO^XCnIrsDkr#7NoWlhrrtzBY;&*fWyxRk%y8{f7h}pAx!oV~g^8vA=pLb^gp_up}&s7tk7?k%OE69{r0hXr<-GCQjRVT~zCbE(iiiMwHM+ z$b}J|Y|*>j@y|$neF*{4hVh%_bVs0sMxCOUw$jD#G9Usx5CB?ygI+RchtIfHb}9kF zZBq3F^|>wwj)ej{17RhATI9vCs#nf&_b#*`__27XCH;E9&D=u=y=XrjiEfXeo+@8s!`468ZkU;i7l1nR~u{9o=Y$ntoepdDmF zzPc^jYfGu=|QXu(+74R^wC&6u9K|A@cmy8vWNW9L+7CLl3+{k&w-#Z@S zjnT>He)@O0=x)02D1ma3V0KCuF&t%!bYp3p%N?fIaHp4a!sxIki&#c#K<}FyAYSof z!EXC&;dPRr(0*6nex#F2w*UFjhv;Bqxj-xen3W;0zf+vTCq@jJMdweW z*QsXh%$vRBEg-@X$V&NBTR5vxO%B!&E4Z^{K1*L#7$*qho*Z_e`clPTDYGCwGGhf< zYIg({A;Mwk&sIT;7ESax??Vvec-6YP=f2z}plP*`9->_H4E)gl@#+>0PdL{ocs;J1 
z`~*(!75js+`8AfyTiNR@co&aonT<3rXRJl&6tyedhf#W1J)Yu;p3F6-PC-l{Y|)c} z8P{@zCoWwx12i#y#Xil3D=}nh#ZCsXk5u1v4ya8|`)f(U>XUJtaWG|L2gQDSG}lln z5)1>q6dxzxZaGAab!J&MxP!2#6gwD&y5CHd={Y!{N@Wj zYM}dVKgT@9u9e;I6uiX`fHMkqpv-J+5ax7XNP^}XYWrhz4rCy>dU%%z+Cp8S5t4o# zT*KAoYTGB;-NDxX#z^~ zadyRW(*_&xL+eq?JQLikYX_=c;~&-W6V5NwS1~w@aTbJDDzUme*PVwsOCNp)w{R>! z?ePjilPOlfah6)uzq0}2aaGG)t1xXsV#V}wKYSI&78!K7RB*IC1eqhL*diniJ_fii zI2z3`(;9B4seQjTVq} z=hB)?aHlQtNC-*J8#F@9Y6j2RK2CWiFn11rejn&+1KHpHq7%tq8~jz9oL*z4fNL|z z{%&#hzx-9Iak!FxCJbWa;<@2+5Qz8NL*+|;T@6*m%^YMJ+cR4-^!j;Qh?`R{2xU#* z!u_B0rtNo5Im!LGZUzNEmp5~qBrc-z9b@mR?6}x7QjenZ807$$wm2vxfjULYI!d4O zD6xN{IbQ}#43&=|S>UBT4#s{|raAxYixF_d*zcB=XWLvfs^d zEcAY;@+a7DRA!lUnEU$1q=3@h1{S}o-=4ht0<5c0|l;Bt{hl4&JRN$ zr<{ki!MM3o+yI&&KhLH5#Q0u*L)OhWZ@xD3c;MUog=KXrQ)mWu; zB#JzH_6p{vmhij%5Cq)bd;)(xZ+jiyz|WLFDP1+?M9i;(5r@bED?|o}vZ`;xhQOsa z2z9q23_>>{Btmpye983RdQ1DuD+b+N6GDJTL&nEf1a!StMspi@`Zo?e$91{zNF2c> z^F_zhUJYrM2j0awXuDxKcN3;z53+Yv!@JZ6S!;@$NBdc2+~#s%b7kiYw*Ceq+KxX3 z(jsbYuKp$K)0?*H!6RD>oFPdg^#p9uan~v0CC{Ou4Lc-H*ULwK(}&YXRn{=Kk`;XK z`QTtisat>%+mc-s`#3@8$>b@zxw7fGZP2g!yzgrz6OjLbAUQS-Sgt*)uQS1 zlUu(yjdMiMdf%C8{L*-PBb>o@RVEVd`9gnL?OK*e03S-saBf^2In!@d6hiBdk|s|K zfIg%lDnw^vdId=*grM+d5O+faDK^K%-qK*lt(IeNUqENaC+in@F_>cfqgu!GEdVjS z{-QEgNRH1;&Y|BUIb06cpc4WtrIBM>PN<9lW!sNo`eZm48aji#M<%M__i*!V{@Kte zmQ$dshFYhXMJkA9QvlC}V&lnH|H`e^bD-waFu`G{t0cmYx zepL~|>K_UlMk@yY2*$zRQ$sBy_~1XC(0L@7(WCyVX>@$h?;i;yP0(En^8ax4mO*iK zP1NY%1b24`?he5vxVyW%ySozz?(XgoTn2a7;10nZ?mX{zzq(cT=b4$RsoLGAW$j+8 z9iSXd)r?-^C#YZ3MiGHY`4{u;1p6n@cV|yEb}c|H&XziD{MIdPr7*u5b?6Ini+IfF zF_iqRaTmL7qZ;D*VU_Rk{`&>`1(zOXUPxDd?>2YgsC_tKHJgR!CEobP^d~H3`@J~| ziyQ>n9*6=XGhH_d{^j;>p1gC|8ShGf*x#V-N!NRKbjlTr#;T&fX6@bd?ZbQ2{fS=o zu(6L0`EC!%vxPgcm+*^~(>!c+7V#HtR%+u&xVL?)WGV!b1$zur8W(iMO-_Lx-{SH1&;P?s+Re ztT^s%yso)@jM1>1m#1yfu*TLC|OeY=$CMw+#`)hYSx^^AR5d+@fkmuSmfFkt$4Ll3!V z^dm%nIOdo-zf4+=VNbBPG|f6bja1tx+Vt?vekKn`Ub~?LpgEu>e{!QuHmAH=QwBsQF1}a7Q)Yx^d#Kj@whOztLW23}={Bdbup1BSkK!2gOS-(&F~U`gG4m!2>mjq)FJS&DHQ) zvD8Uc&kS{Z#aA7M2dHGfx1+Ec&Elr9$EQ<`OBp%xM2OrtDGaTS>Ww$Zf>0ll%J-t= z0cWm@TTWq52-XB1;DamIszd%Q$2=g?KKG!ebE2jk&0ba#vMjezhq-1(sZ*EM1vP1J zZE*qZY=i)XX>x$yx~VCqVLh^6WHxaGWwcG7ba{t*x6^w#TEn0?@t0lz!eM;f>dxJA z;01GiJY7R;qp72DqEwzjiLy`1G&Eny*2$muO|PddkC{_+TK)7DzO)q$bTNTeK;X`M zJ3hq+k=vL12Lf`H#UGYTTH(3f|NR`}U!@7&8%{Ni-}x|M+$G-63(%y5M1{n>m(|28 z$IlWb^X&1 zD_rM-KJCy|@!1={rTs*99TvvM#%vNcRw_CeW-MVU|HZW;ka&3I=?!|Po(KL&*tcD< zx%DQzzh&9BxF*3H&H;8qXGw2dt9MXz)tsy4&Of@j+dc`Uhid50aLJuyz$Fv!f-@5H z%lIsR>qjCYC9d9;buZppl0)NQ5dd%Iy?vNQPFp5V=jVwZB$o6A|8>cm^851`I z@vt=9$y1pD+9CMP`Q6vBPd^Sj`VeX1uqyr zETsc{Ea{Q(hF%l9BC_%K#bj5ePU48!!Uk*hNv7=HWsMp+P1Ek|;=^okcjuVOTXFHC zz5M{QUkV$JuFRA2s44dA9tHEW0ezj_;e-P#H-gIV(=Y1@0U50Z)FfxBj*XeN-I<-C z#Ng0~%XPw%tSfGjVO=bU(hDZEEvntpY%)x+z0EZ-y6qchcoe^FAO#wwf9NWPh&s%F z#~@Ol@iM_k$+vvrEdyrj0#4)hF6ZmEv%QWSRKs+Sf3Ihrd6hPA$)sJtKY&eVHE-O# zyEKbm&VFbTH~+ckMkXkWHc_+y<|)yptVg#x*;m7XrWaCrQw+BHIaa5<1vW-_DBU<* z*Q8ibUSE-dR@V^bP~4sJ_;YnyG2ISVubPhk5LLMuyuPFj6pz~O#&|F{_?jlH(H;mxwE4o{(o@PsDDNBDJh z)4dwd^H@{@&vH9#e#<5ad1Yhm=VZoSR^v7aF0Xwq3vv(4Ru7zAR0o9{_Nx1W5biOonY^qc4=5K`6LQ54Q~mULz=EiA0# zEa^~+mRh-K-1wMcte|3kn~rWP_Q zRCJ(*Df;60go;FhF(?nH!-pdTR!-cG0{ouoA3dN-t@})FTU2RyS>|f$Wdf_4r>Se4 zaUGF1-k9T-1_9mW5Uxx@ngpu?@Gb3`?Z~J=jbePR)T+4s-s>OkeEdwq4l5&q=blX+ z5Jt9+1hQKl;Z*y*f^x(D!XKKes!Jds1Gqj{J&9+14DZrclyr2NoSe)=ttsQzd=?VO zq>Cl+a|!+_Y5Uk_oxOLOx#p{dwo>o${IWs~VAa1tttpT^uQ+ed~kVB@{Q(J*JNS*$4g#>&|qj$TtLY zFM`Avcd#n@mRKI<5Xz=}N3uuv*)c(Xuwz&V1|x*nSc9k5-y_XJi1{lR2Ci1YEjH?g zehdGh(5h}wkpMzcoduhmnDEa)588O4n|+7L7YY*Xg&Sn1h3*kaQ3uM;e#j)6k@s|6Er!Gfc(#hw??k5;a4SHD6k)Qc7}p3- 
zjk_K~3?4)is2nBK7{}4yAO`H^G5A4(zurvbQ~z7BmJ}cb@wNXsa@EO8M#sc?KZZ&$ z)kTn9yYOKg7u+ZF#S_sQZ5CO#SDyn>eP9&1FIdcTDF(6<(|XWgdF@lnLqw6%E8yrQ zG?BRCHSzV3VdjKmsoCliW@$#rTB#D5;tVqIY1_vk{=xag!~fVrf1!bYBP8L~usm9M zW7hUK#76~*vWm<_LZkG{87YUgY8DF zM*1#`a;wJ6Trucfm28kkM#Xo*PCp>1Y>~{Hommv!ux)}oNx=&+IZ^wzfpc$#&y#ny zmN0COsfcc&1&zROYZ8bWgfwKJrDcgto5@wI_$HxvjO`=vx}V&j2n>U6#xMk1#+_#8s_Y^!N>UK!!6{`2Ee zFMxAaTHS0Dcs5ReysuK>{I&sIi-Gqq#M&o zmf`zO?)SBeLaXU2ltA31LF`exZ~xu~SF*ZXDAm$=1`Mb&J+m>0Wq2I47?i|=UI$PP zGG6Qp5=HnCj1DQu>ic_Wlqy$z$Pjnt&B<0GM`H?1`x_Av$rIlrOl7dYiJsB&k7PIf zbFg2?`Q4d)rdefTf2Gx%`ANrK1n0_Iv}nfh=4B2twKItG-kML?PVC{VUndQlA8pm-TjQnyo>Q36cATfd8K&3B$ruA^#omb*d8xi0Lf8Zg}v z{WG)u`k+}6(dMD{q=YH5e8SwUNBA3FS#(>VH>S7hZT2MaLg*L2B~_9T`U6CacBI9> z1NlMd$VsHbzk#b!Fj+ovmm1ELus#{#OBW(Zv?vg=3)h%K^7{J5fG>U?3$kg zJ@#HO-vzXL5&G&wNGJgX&>4^%OU{45J_Dnm+{+Bfl@)iEcZQgMcUG)ashLcQ7I}00 zrU7KBVZ7k^Kt7vCIeOT-oon@%1}c%_$QsYN;++t_-0=91n16>d97uQqsxvwZU`Hhe z)LA(PE=>_ueE@0G$s6W_$BGv*euQ-%40VHDNFI0z7pjNxCq6tP>_#Et$lwUh zs};LPt}4`+c%mo$;tO|aKQtK9YXs!{a@2lwh1p_6@N)$_hp%az>j_2~vZQLFJNJZQ zdnh6A{u{W#`8fJo0WOsF-bawrjcOL&84+ybUC3B<>ac|F2qY1p~Z1A*mnxdr8{V}n#5^%i1#{1j;zuw@tvdkQv%b8jsrQsYyiWv@q>F6pvIW79nzwR1!tmAr zes>ho*1@}yoXvnM(U6{K9oR8|%swBvoFg2%Fno|uvtsom#&T#ZQh93m7t*td6XBTh zJ~S~#bp0UFxg#QVAsg{IjzxcBdG-Z!fC? zSJqqOiZ8cHukovkW3^%@g!p;zTzmuub30kGG8-{BD) zn@4D7qA)?k%t$bbI#_!lHfNL{twHo!R$&dlrXE-atPbmE+RK#C0(Gj`9y1+_v>>$~ z-M>GJi(Kx4sFHnhi!{soi=MT6vvlDyW9;7_PH;mR^#RNac z0A9$D=p@tv@hC#XToQpM`DoPXTMe(rPU zZ=uXfA3zwCXtkjA)*L^NR6yU^!)hwndmKpKg6V_kEyW}vL2WDA(zD$DuY-tGx=)sZ zPoSWt^(WEelh z@zD7SOKQCJQkJo@S|Wdp2PlM2s&jSYgP-=m(%hEFnx8tZE~L%IGSi`O;ktFikCGRH;0?gc)V%S2?h>f(X5PwL=a&9 zP=y`aNVZllt4Q+RGGNBytK-D_>uQ>gGuVDGdez*`0{~vFMuy`V>RiaFM}z&;ezg8a zU!ex#L@z6b6nh_Fl5qCXjzpDCAe?YJVWF};U~VodPs7cxIa#ryr=>urzjcY0KulQ**_BN1W} zHGa7k_w67L_^;?*)B0<{9^&deqn`ugG7zVFrrFFyg3dm{p&7)=qeZo3MXZEUD4pjU zV?JwE2Y;Ly*<2?TsKkqFbu$+Z5{CUJGIqSI>-Tnbu2O;j5ZqdvK{cZ%{g%7gBf|kx z+!+dx$16*32R!f(DzOhB{es=(ef{e>IftN$%!cY=M9I^K3?rFaS6>qR6Eq^uNFIRZ zqAeAe@mHN7N+e%AwdQX%**@Kc&?E;RUxp)ejl|i1If9;+wdmtxNJUlw<6^Kh)Ak&JMt4$@2lJENlirC_xbhiS`T1; z>3Ji>Ih(wvaaA#Yf~ERcm#OQrLwMuZ_+N%VCvV`8xR*zoh`;z_0DGUGu|2+|JwDII zWqMkROU`%hI$?N`{-ep^hjuxogVJ-^u_ZJEz&Eb*3|Z<4G3dDR3sU?|t$S|k3GuHk zGC$+FcUQD_pJ{RjQKR4^{lED*`)Rja?STWLCFR>+(hSN_D1W4UR6!7nRpYdRBJGk* zT>$d<&N-2Z4H`2pAm*j8QIn=J}fAvIg*DFR)qZY0>QA3wU} zM*?=)SuJqz(NNpR31S`pQdR<7{vw!g2Mx)YL#wLq*FkUnhFfb0UgshbY4lr;U&vR2 zh@T`yxB%=ekw`-*>i5CPcfgxkI}z7RE_Q(J{;4*i)UdS$CbuCJPKsP~Npt9&E6`PqpQ zj9^@=M*v2nXOh9LI#cwdRyA^`umnBgMSrZb_}p5pbTsxoel5;eV~Zh&os zt`>h{p*o~^Wn^u+ZDnt*^}4%-r(_6a*Xr6eeP7w8+H>VgJhBpK`OFN#JQvxE+`BS9 z4JYP>`$eb&#IXX?K%*&b6X!1ds9iH*kmN?z;sJ8%)Qq%3L|%_4N|(4--3Ge0GULm1 z3RPnt9xl56)KDU|Kyoq zHnMm9wQAI>*fJ$#cd=hJnw z?yg1QNXt?~F>}hFrB~8aXN>rHs~g@4#?3-OupC^dB+++LX0Xud**=lMfau@gn+0}% zjW|-;8ye_j#B%gdRe9;bjo@D9&@4OEr2f*ypD6UznF z$cfn0faanuWnB7!C3+;wSSmtMy8D|#el9KBr`6#d1JT$Mu?FUB74I}2Bl3=7zdJ$rxc=Jc|Ivx zC_v#pq*XtghuBYvGahDy&e*$E!BZrtUB|$&u{VuaT^_8>>hK1qGA_x{(H9qTvApT= z*NtS25|Wu|!z<ou^Lygh+d6^yYfJTJ%YE z{Q7S64%230Nn973p09@cg4go7ewhUdn{~&QZ?A?{$dYO@lMA+$$dzf%1+{?Z;MOY0 z#J2?=2C)(Yzbs{+C1t~VgyVFeI^qk*!fV2DGxKQ(P+X1|G;5LH#mo6v>&=j!45`d^ zo}F3%ubrA(!ceun!@$69zQnn^lG@T@LD+J&<8$3TEPq1n=xXCdXSU_N(_P#nuOT~# zVfc;D#kjP$h3YK`DG-O9r2MMv9|2Y zyaZS@Pd=9idA6V3X8Vuy>aUi4mJlGh!m0s%IL&kRs&?%O7tV};Q9lkD{$GkSBOWN2 z67%x-53FgN7j%k&^!mU1p)}cfqvI`B?{Do%oSvoog8w}E(~nmS$?9_Ic(o!|qe<1~+oThNqddH-UPvNVDKBnpD^6VD zea-OZm+d?!&?rZ&?4b>INpa70n4LSDIYd21>a3OgLFCqrnjAfKLAlHhRIA zg%xU!I&1FbyyuJ{ye69=?r?f>7_3tSkH?reh7tGm5*Xqb0-?;S%|PO1nZ+^4T)tty zkn^N&xUQ&)a@$?~hQ8#ahqV@8w9H%Ch2n8&zA*(!;n 
z+G%9y)gok6cMDztKi#O(Zt_2-uSRB1`OUm8VR zC^Jl-R`G(P)xS3D@G-*}|9Hg>!>r_*YfuXz4L@hy(wR?JTCbXs2qZ?JB#O`-nZ^Sc z_R()7A+jZQH7*_wbt@$jTK;UFMTqZ5te6$u%?d&h5}WAYspex#H4e-^I3zz~5e9f` z>MIGt5HBJiieRCotgOarc58O*W%~D@QdS$OO6}63*~ri29vgZpMr7`W!BqVgc8?h1}ATI z2fjsdyxG~raaoR{TX|VLSaO_(OM|j5R}zd(TIzi?Xg)wd)jrnTOQ>_|AMwU=@ZjKN zfiF&CrnroAQ)!cXFYUt@zji#9%QRg8jDpo0ufdt<_ICkK*^i^x>l)aA$D z4acw}HBP~-f-@FlrQ37xT8cTi(;$z>U!Gi={TO)fSLTWMK+v|TwSUTa;2+A(d}l)D z2|steg%qsH60lA?1?V54Jv~!zUm44+0R(zuE|cHHx&A&25g~sKt|var4+g)6xcfya zuwYrjqet)fesFJNP@vXboc^dxJW4bs(8#bzfcf6uJ*gy#Z)h=!E(JqG`RCVOW|Zk6 zTP(0&QnCQ@wYU7u(N#DqBIWW1di?NTL|=?HRNuheN<J5Iyf<0AVh#H-uAj(kd&kDw<)nq9> z2sFj*iY%g4!l}?tX-E`FX~vpk73|KI(v@W|3?bjG(KkvD(Q-eGpAmsy1;q#PlA*CB z+2w2GkqDAT5R`yLLKs-6{bNUXM$DCAcl6Uh)66#pmgruY4jV;4H5H^QYI zh~<~Zj{sVi?B7hnFQcQ{G5vC;qQKvB@~PkN(QCR^GXF}zAhdx%=OfAc?ci52ggFUx z>lF#E%;Liu9?!q_SXi}Hejkq&lprnM_H@mG(GTxPLO4orZ`6}e@ZSI`a84(TmHa43 zNmN#93ua^3a85&v+_mC)cbJzd$@&9j4Kr%C znI|8WES)2iw%O@DIjw`|2d-}}LR?ai6V`B9PZc*hqK#HGCS|t5z%z1&h}sJU)Wo9e z4)f`NJCGuhac(S1#?cz>2N%eRyUSzORMvZX6yVi!I9?~&W_+0!G70{t;55bBZ*f}C z`yV>fkB|&1Q$c3n6`OLZ>o}2w6~}I-w`)N>y9>$4uF+X&{aaM}fc|?uOWs#Y#0eIS zxELmIwH15}<>*n!z_p zBOKsM37pL21XTS5a1yj=Z%e&5qOTIYFXXQa2CwWN2MPv7R04ONTv}PH$O2_!9%-F% zI%FFx*(>+k4qIitA9RAnJD(%`4hHel+YMFUY6iV!7%8V?xhuE#z3k2b1znIW3r#Ob z659-{r#rmM5J7EZa(ZJy0OB1r?-<0_xY5D+ZUC z5x81kP|D8^4i18`szKKs7NJ4^lQxWOJrfH_CHhtnK435}w!^#TyfESl}7#*dWq> zQdW&afTKqu#AeIG6EV`$kDh-Rg)TQrHTwI>ECebijs<&zP0?$_FbB|3{uxu$Q1rS9swr3BL>a5?yH<>?ilSwd6 z6k7Z1kH?;=q9b`?ztp(K0DaDrr;UZ|#mX8IE}J-n1Dzq6K$ZUqOIkxu-Tbo~I~S;6 zMKJiES`@oJjBk8vwGj@Z<*qi!Cb6EbJX079bPthuWPOVNl=HS|-cTtm0OrVxP%u&} zbOc;PUYWSa6+vj<715u$mEoG4)*yV1;N)rV)WN(glQ6Ii*?#u*Or*mg&(MDU=tFdI z6o0w-o$2W4!n$b!mJO-AxF(>rFnAUVal35?m(OF*CxSI0*oTmMX%gguaF|rjSca*t zIU1z`+_SaoOiYH&qPZ>m#r0st5owjntjPYW)+-c{U`UQB7(YODkRi^yXr4nsq*`F% zQuW}cau*qytqwOrs^ww9cC4*6C$V070+8%_FlvBP!Lb8W24u&5@D)0Qm<`Bpul~X{ zAvi^W%YpsZP=_ru&{x5(xZ+1o%EV%@j@*)K3mxcg%bLB=anPb5_Q7GpDen14rt=Z!LQWAIaC_t6X40%u3bZMs{Q;7d zdRHcPZr#Wk6(L)JrvaJbZH}O69m&Y)(E(KEgPz$=?(bdWVf16ZHoO6W?Eh2u3t1wq z?(g|MHhQRi0lHs+Sbk3TMPpnyjZZZJif{zTCG6EBea+imbGD2je@bsIj2`M==RH><*8O0WO$pLa( z*pmnVjF;HI$rtS)!iyzr%(bYP3XzDL)F=(NGWFoX)UIMI2N=oVf#Qh&!no3L7Mr<* z_0-GHz`hzj!eSR$9M*;M@BIbq0yc>{TNdg^(1yd$cms2AvBmRHjcVz3A1du8L4=Br zo#Fw7hQa4;qS}`n&6gAi4iHXo>N>s3fz|cOcNZo}cFZ7}F_-@GYZDkVa*K73vGXQ; zo1uh;=$8@jgmbM=htF1YyAEG>SQD`?FY}Gn`A2ldsQ^}=ObxzlXg=UfYKGr9QS-%a zf~AN_`i(=j&@mHd}hm zOjp%V@Mp%K>0tO75~So#3b$@xI#mtKDR=`M&;%lpqn)(^0ttyTvz_3QdnAHE@w=|C zk>EKUL)EGV?@GF?sOh>9`w&x@W#I#Jh;W?{;`iziDp=Hwl(j(Cju)oi3VZ!EvSK+K zakCg;xp2cK{Z>G|UTl0KBzEzMX|Xvoi|-slKanW3HgZh5TL}j&n~aCMMNd2JFxW@d z=lNh>nHCcQashTf309xKJ44csCY!er;4L+U8QJChL^P|^tJ2~X|MNP9h#T<49*fr) z5aHBd9pr46tS*C3kXuYO>WLv&_q7TXyptfE`rm>6e|mpsj;gu9ocUBY&>(j#`+a0* zz}F{Di82g5JVqoGydMLndV3KPSy%{km?z%$Eb-rsd5H7mD7pD=v9b?Pd!*Xi0VPH zmgje(*aYC*75(K951c(3lJfXVQCjL-*vkc5%Ry66y347?=MS11w zqi)!@RO+4z(MsYePBOV?sckUgXY){BT+X_l0D{XN=%{ zfjo*Y&K(pzKnaox})2SwWt@ zsGbsF@c(LCA{n>{-7gkRXwX7nW$;ydFjJgAQiY)h;@jUSpI|R5x^SWK07zK5aCG(Y za+4;>!ZQVkzrs#0wGHgAJ0uhLFC%` zv5Hiqm%+YiO(G$Nv4YG&FB4cPAFz79wr(KE`VSL{@HEHVPP-EjEToOd4c16&{yRNy zR=JwlW>W|OQf4xYqq{=}b6M*|!DA@SAU}DP4jK@Gg(mm6@>nACj!}{SRJbFw7Jr$? 
zPJbx<4^q-!{l(V5IP)tqRc4F1c(+3!{aZjsoDy9Jyr-)#F7+~ehS}T=b48OD1}`*> z$(G|R<-;pw<6U@bSt&-rAtYK85>_@dL9O@otooe4lo5ku;VRG)TvUVZzoY*DyZyLz zU`xb%$B*1nc>lsdYDEYLB#OBYNkWn)$V-_3coC7QL-wuEN$pZ0=vqFcUBlr8QPnwU zfJnM{qMUO1oDfD2s|QuaL^1@kky}idvFj{U*UxvTz#%>-6J^J3(we?T-ec^>(BI^4 z_*FeTPua^XJbj4r66Cz7BxrKCL6@SU(F?PWNowCG8@AeWTKICgDv4sb$EcI4tU7w=_$MBC<2Hi@aps7rzx z5;GZRA}S#EWpnQreuhz~5hNFCn|x=i6^)`@%komtiW0Hw()nz^=ISrN0Cn*AEib!; zd9nJ!Y%w&ZNl7$eHmakVP?k@&l~hsi*kDR@6Wb*(b?h^Cas5glDsOC%W~*m4xgueG zZAt&t-&6!Q8X_n*aJhJ`r}d`Nq`ocRs++0tU)4dSa=Yv;)n|_4JKxiB+^Qox@Z1O` zCwhnM=yi6buW6MYvES8VE_;kVWQ*Zi)^0 z7e<;xA4MQuEjTH0=VDh9n^W@!o-NKTZM!!%r#8cc?Rr%NKX=V-A4bE6MY#-S@^Ybh zEjwR|!#S%zxc~2|{$GbCEcMR@MUfH?FWGHsjAi;EZJYzb!AAlzvYC+Ngw z@x1#YCuXXsY9YQ357CB!a-{fXCIq55pnbf=K7%m~>F?xRW(b+>*|C>qF-d7!83X5pcw8l610Rw6WxVTeI>Ydg+1 z=H)yI?m%6%BTN@$x{S*;l2~=aO5>R#n-JS~b33nxgibX}Q(o?vv4WR0=#d3b%m%D6k&Uxvi8-3Kx%$Asj#^|l zjV2U1Td#|DSHmBq-ynvW%ZkZ0?hfRQGYBMrmO3%2cd^QqF)8@4sjxh5#$a0T&h!qa zM3b85Ng=37;28FVEEr44{}CD#ag%inMn8o%k-bbn3`k`VD)!J@bIAJ+{eMCXFYJ_4 zp5OI$_pkvx{h{~a!GY=dL!R`datK)hV`}~_4!ABezyt+?*;$Mn;YsPbHoCU%X2%Uk zqT-m@AGau;@w>g|yTvXKggV(&+!aG(e>av$S^Qr^!#-AU-dJ|h+cD4VKtbV*e+zy~ zL9iq4^;Kcl8I(8HVE(~uRYi?P-QtNqM}qK8vO=%e+h3X4=Wi*4-9_-(9M&1#f4v&$ zBrD*Wzs)9oXK^bK9_!h!Rr(;5wI@J!xD0fxMv5m04Hw{|u?*A(C*tP49x`Nd<5^}d zJac!YCeDuOXFZ*hG^QzrgDbEsBY$4B*eA-|n~E+2U<5%blL{weQ)e`0{xUj*cEM&#dXp|EvD zKCbEUb?sUia59}jjvS3yP{_CB?-_X|qO`q|Hs{mq2s;k%VI0`KD=vUv*j%W)mifRZ+L}^Wrgxyk!%Rg(7%py z$W6%Q;Jh@xL$K9-1E{MGx1FZ;~z3z@?a?bNyef)=TB(r_3cyYh%X2bou*_Sz{ z?_YY*>dkXlYv-5UFj%dHq4~@01`$PCMR7wH#ozH_S)MdKN11}&B10Emb7epaqV2u;yF<29N|kHrx+QTc*vVh%Wa zLqR^T{$hqOGmaO~dx}rl{Ru3b|I_}(yE^adE;ejg9YjjOLD|#zDfF%)$Ul3>Vb`V_#2GZ{!MXZ* zO3(owLFOVTMm=x8HS>)x2#zazJFb81`I?f-@DENH$7r=q8y~aCyt6H2fxRcM;Su6~ z9^*;?=ymw624;HvZ$~~eCfuj2tc>ewrx!#d0aZg&B|O>4hyD9B44N`uZ+ZW#`fF7` zDeXUa4ROLPtokc|;ibCf>D3LlcHEDHAchQHqMTJxPcu_=Sf>Gj`_yjgIF1BKo)W*ASzMIG@C6wlbBS|Mq!UMbJmY;FFH)Bx=T#d3@rG>hsX%zRBPj z<6~$>ci$jjKSE5%6yQ2eYV!kO-JxOG@*K%6OSmSCuBj^C4*`O6RbfdOM~Fcl;gwDW z@PyD+5d>tO>jMf7<03&z*`a|2dZ!cRn?c6HTcwBN=YwKL-MB4dfGB=|$qzp&2iJVlr0327Rf%-b5 z76IGxRJKg2KluZ5A7Z3mlW@BC&kQcltff>rEaag-v3W{~llL{Q zWG!^@t+etNDB-$YVb+FO%!K%F^9=dlmchKje;p=J`EE9TPUZz9 zeJ%z(hq|$ToA{~jbdgR9?T8RkZ3h=zhYz1$Bu6nTS&m*CPO&Fa3wnJUw%zvU4V?B8 zan#e^l)vHQIPfoy^!nBHH9hq98*-7#GHMZHX4enA$gETjksu)~QmSccjmGV~vP9fv z=X8ni*gYp=bAp18+zLqDTxMe>4kWv1J(hSK`@d>GQ^|#!I6Aioif+a4Ki`)BMps)8 zg7=w$ZheJR^H?XHa}gEc zpvLEDlwnaDEWQY3(M>Wg?P=TH)mkO6c<8vZJ@^CJ1QLTa#d4f1#|_3d$; z{8In)S}(G+6bV)0Osjc;D-`>-U$STBuq~7q=krRE>qYIHX%;SWjb?b~4)G&QAuM`Z z>#XjVKnu^GZFJ?1<%S<=422fNaCgUF2MM52k$VnlJA;6unJ*<*9Y-G4D_F>*JeJp9 z=T^{05oq-Jt79<2FTF^8D0kk^eMfRVp?OmN0gMT}-?5Fkm4qz3lw+qW>ED%rGabBq29 z871bbZUVE9ZcX8+pfQDTy&_!$9?++)GU{S6eG3*G9nR4K+_(h_B>C2qRC--N*vhoTS`Mmd} zY6~ zYH-7Lz2BM=Y+&!vBpxU1?WLkVtAwWA{3xpLBr4GnF!Pp-mgz#<%+$*PhKZjShn%Q` zwheVH3=fMV_InHKsDmVreqmK6wfPEV{^yulE!s6lp8qrnF)f;f7&r_-N^Ij+EXnL> z>bhQfvk2PHpyS-KNMmWYgE=^+PFep2?LMLM)W^sDzPg#n>X%)2MYdJ#1P>F>$RT#x z<4E1|u!l1;z_=xsa9t)O9uV?ib<0DpvNwsG-Jmn`i!io67)5;uuOZkomtLTYl^=ltpWWM@x;lT7T?vuT%N{kQ=^ zSbBsh67YiZ6H4;le+DYi{bU0Eu zi!!4~nyx9i{OH&lraC0*?KitK#}=o-t)(#~O83))WJ224K#Y7Tfc@2K;`>waFl*@7 z@Mh8g|K#ku^d=_Iq%kr{9VYX-zKbf-xnA7;Hwu|b6Yc!}L)0}!*VzT@6Wg|(G-#a0 zW@Fnn8e5HR+qTiLv7IzNaT@cUe)q0*f4=X}J$uc}o_TP*5N)ClQ-#b}gD>nB3z6#wfIGWA@()A;ClO5aFmg3$bfsQQ^8+e`(!u z%Q|{QN{6?U6sWX?l$VDLndHcU`ISQ`2#(TXU=^8WQ2>u++7Zefq{0xY$4s0m@j1eF zzAuCcN7v3Y=@iuMXWt&^F=Qp(Zv?`cC`Hp5ue#rj2GcVB!w5W8{+8t*FW{v?iel@I4F!&8-bPO}oO z^gy8c-uAx}S0YQ?u)2i@HJT?7taX3>Hc6(WwmB6X_O?5{GSjT^3LVS?EIumhM29AL 
zE;Y~XfaANh20kINvUlzCxQ!BUx9dd4y;DKI2p0J+f{zRDbh}KEBir8*>6?Pt0Ijp} zgdOdJ9xk|D&uf7kB6Z;IHY7M_%FDsFPeV=>MJ=JYvGFJ-hC@$fvmQN*0^Wd}czy{+ zBj=zrIWBmi?UBZ~Ps@_%SVJMrpp7I-1zKDWt$(!nv?F*74cFG|N!gr*CgHucgwTHX9=tl{fwi}0s?1vzT+C@uJd(-X?4>7~C$6~vZEURP zT6R}kB)aC#UwtBgPj%OxFX%rWBj z@eqF3_rdj5Mf1ok=kN z6qz9br=bEeldTJ3ecA)ug5u|Ht^8gj*;mfU9wJVM#DB-cV-5}7BQvQaKnl`nu2`IA zM;zJ!9p=}8;|z=3j+SPQJkZ~i6@4O^`?u`lU#d$jHY?Ls#5qjEr+mPHU>MonnQdT; zCQ@Kmap!lUIuL6@E5><$-8>Kk2g_R&T6``p%!`l;Ms01gZ&n7mVOSMQvzw@L)1FHj z)rcPhOkaNnN|j!JcTtSdWGA#}2CvLe)Gu8PX4(|DDw&1S09wW?a&^^h2OpeH1_!B^?RlMk_a6p5iP`$$|bK*o;(XU z!ALycGJqE&%mAA)dH47khcY=sGFB$|^*WBfg3J+OtN?k3TwX#3k|0~Q4W}^XK!f4C zakH+=;r)i)dY&DVDA1d0u6fItr1{8(YW8xOZ9-2pN#aZE?e|!FtI+S04oUM|s- z(bXVSY8+TpGYwRM^&55rx3!gcxI3NZf*Ai9=X1_zUh}KaG zV9kyGTh7=rLP1H=+X`;xmK*a~6^sE>n1*f&R$vjbj@Ef8XRRIn~) z76eT)@y$8PXyF#1^!=xVViia4`GbQ5viRS7B(3r?7t~Y(3w7u7=2O+?eTu#TUV{|K z*UR)vK!w0dZmGrDf?IOA#o6KN=Isefrlef{j_I+01AnkpuVd--)!DO0-do*G@C(np zV$6f;t;WjXHa5Gi+*Vnv4$XxT&ow&1o6EJ-(}Wa4zz*zJ(~@qJmt+!SxBzlR1PUWW z_C!IALehFTS}c_DvE%k0LFd>Mv~F@`PTp0KAJBpWHruNDvDo}d?;sgsq23q0N;LTCZ4>si@`nSwdrkwOr^C2H>1?~(f0eK&^IIhI< z<@`F7FPG4cpMisHa5PyQW@k!jjv+*gyT#9UM|iSvvJ|R|=9$%sfaS)&?33!l(Ss`L zCoFNc`a;9bTaQ91V%x`A;dqBG>RU@hUNGI^OjXcvJu*~xIKnxeL=t7i!&>0JH5HCt z@w3nc_fp~Ye9k-|dHe_(20b>T^5-yUMvx8#PyD^=>E-PD5nd2S%6iwdRcUy{tDls% z!_Nu%K|2ulshSPjxz(25?U(Av`-)|lqLrI+3@@^N0Y4qhBPhU=mDv=>?SlQ<9S-xc zCAHT*;Q{hYG6rQA5M!0hX&aj>5Bz8qn-+bV`U+MB@TgDGt!3hi@2W4G`>6`&?X-0y zx_y4{##YaA{F%_PHt-V%bKTrk0Xp^k<2d2r)%7Aou1aSG@b`f|r4>3_kdjW{I=&zq zi4S=>{bbYJOf<9kbL!y~!~64Kg6p$0(a>wdFw-S3j4Lx`^P`GrPLyr6iqQvn1P zS^wipJ%lZ5SUSzQAQQ#qTP`VtXLx#`nAF)fbCAgT1uzn6x>2kZ8ovOpSMsks-5U90!xwhZHNvh)Ja5Q9 zI=o;cD6noWN1$L!xd*AaK>Ux|q5@RRwyCLH@C`Gg8}$oUt}vqoGo^RCXV<`8etEn< z^bIv;nrR0PErMi%N|e(SJ*n$2?$`1)Jcy_p1m4PAL(7&qR^0kye8UBn1pKqGYn^jF zb6OeNNpYwyQ0kW9b`S4!ZeL;&XRi4F{{FhduoAUHD|nUs=#ad!@f~DZRpW)ErSCF0 zEW+U$Pm>PVvQpz2CN;FO{H@5=E1-~3iG#<@QDVy`lPC@yjv{XW8-<1soFev80eSQYcjTTd-h7aALuSn^4`U?w+z?4IA?(P>f+`Uc^@!>wDxv0YiYbLP96SeyO@Y!_vw#`&qACK2${n zX~P`UYa^!87aLD_FLI){DBP~7wYZcLXe^#gxMEuZSOe6u+Ae8ojdI)@1(jT7Yx8Z! 
zNwvFP2!`z~2-n3>?b&~#3AprU&U-v@G>6lJ*pZ-ZyDy1{WMN)Kw{K5XBT+$znQw9x z#L^7jV$P$_uoQQS>%Jf;qFM$;8O12l1|>vQx`zH#TX+>EmMz*;pkYsK`cd|-N{JG* zf>rpuqm9A8Y}#2o!Eeo!6eba_G0^ON#V6FpY`7wSi!Bb`IX= z`SwPzMScSpWs!r|1!^l}NT?BFxLV7@urBfb#sQZd6DddCHqY-Z>npwqXcUDapI|h^ z0%~Qt7XK-7xoi#(?c4%W(4&)RpmT1Paq!Q)j5zXeQ?y{mdqQ} zeui0>3uQ{*3*>*3@(*;+|5rtkGHHIB17THP{^+nLJ+Pux$XvBJ%z>4Mk>(U2RU+IX zX-Ci__Wo{JgH-O-eOZc1Uatx>+`=>a!USeWP!xo}9xpLB28(}60D;hnmk!%9L$&*j z`oZjt{ME9H?KCrJgfGwqb#M+zAt9R1OmpCxW?~y07c%gSP!_G#wc|+0dF(llZJ>Xl zJOOTFClC!I+Yowe-9npk$gcC8Sdu(4I!e8xw{bu#Bc3;4msT&T;iBBg6wZx7Td06o zIFf6mpSTRuO{(}(v;crI^uaV-dBd}#mnIM$pNw=13k*tsu{Sau)AQ5*ya2658baC| zYSf|5WO#Mi0aYxMC<6%@pc$Q(`jh@pyY>#OnV4__ipdz!leM^&jq+eZ6x%Ob=uwh+ zissLW{^w+4a)yCttBPh5bdfhU-2mSmJE7W`PqmeqO`c0Xu#q_Y z{{{rl(6wkF{U)<7Rt$lhz?$?I4aWz0>>201EW zUnYB1%B~~DK~J@#q!7YVt_Q<}iDo({f5d~^vBEcYtmi@4nue{1T_L4-EBk=t%5|;q ze&2^^?_R$`8tXr#Dd(cT6y72Q3&f>_qRkeXl=3rVGYPMBv=92ulYb#^1I9~$Ywc+g z>~`t6i1KSbu9&XAIJN?BtuY4aTP1=cBd$)n0Ze(&$I@Yjo5sL;0Q%V)C!3JvxYL-; zWTqvz8nRo3XY}s(`XA(yxF9De2)3pMmVgfxN6r_+s;lt8SMmx9W>vz?sDW(YJABIi zGuk9{xDzW>c1!?;T+M*g{_#!ryz=E=G2e{#&uw8#Zf5B=u=$K=|I=8rBrI^iGDhYP zZjD$h%Q_O)53m)tXcsR3AyLLjXxi+5j5Rj(o#JP=u3CPU(cpQnud2kceiS6&#b2tx z5sw0(xE}DRL0b41lDvQYHqKDo=3(M@=urkOoId;?n)2SA{|ltu>GMs{(L#)~!Sytu z!k8Dc%5>M&fv+TRzj$wT#!6rbPJ|E4$c;_;^e+rfm~SoT>WcE$cmWj<5P9ZwU7kNP z>f9WDtr7~Bo!DH^@C zwU=y8j<0~~Uo(p{{4lVKXEFQbnXVUu=1Ixz%Le|o2MgYw`N@vso*mix8C`RKW*b*= z|K$Glyu$m7KUASDHOsPFp1^WmIVzH|2i;~j79u{?KpJ`)x`(7_g2PidA;UwV1ayxA^` zwNs0PBX!Q=MTE@^c3IULZk~gFyD&q|;$#UXd}EQA zgEZObdg@;18n8KVS>9NlTv#BvCNE`sH#d+r!04p8fNWD+@`bTddQZ_*_ma{y78Sbo zi_;fYu(XTUqs*^d$s!mbLzJ?cn_dT{`~Y}+R1-#khwJE1KTD?@*RHoqY?_IH*=6~T*e%t1 z2}wQirdyTNse^x8`oOr2OdM_z;YzBmH$bSFM79rw{i6C^>kq@Na6tWX>2!cjIKeP* zlrS_Pny@{uOSxxfkcuh#o9_K{4p^d>RytCePsx%K}zjwmGP5 z>auY^3TfALeMrLXPX@A*ha;a3wZrR_9OkQnq5b8eLUxgJ(f^tSfZyG=Cbh@V= zh)XlgA>s!PXLCmpVw`ex%{Gk#$Qxi-%)q_o=%vp@iv-{pF=<7u25$j&2y+_E)_2uj zw}AP0s8H4LU(an~^~H>IGQ zXvIzB2G$rhYrL&w5yE5;b~R8-*vBThq61dwRR<9YBpaQoZ^od1!Jj9oy>c%XG}QbU z!;R(Yw@$F?6|qaeS!Jcr>2X|z=%H+sX5iS&qS~*z>_!`5m2!|x7WH*RnSA#+BV}Ez z37O%LCi-d*FoQ`6c6*F3b>}+J9@BI-Nw9scJKOS-RLdtA@h|rx2J_@{AAI_*$Kl(c zo9!DLD|GnhJPTyev~A-mR=HcEJG&e;tziW-pOvn2Uq;yG$G)ANGSdg*G{GDWdW94~ zMQ_vI>kZNR*7|rf35{mGdkUni*IG-}&yOltP0lOwm(kGB)cuC#1(7lbRrTTz9oD0| zoPi7xZ!+&XTD4vDfF2zX*SKycjtu3@v-a)n$5t=MZMU!PS&<)92jmVnf3)deU4h%4 zJhv&#mOt`D{I>J)LgUX%;`+2+ z9&#qPZB%E~YLT{Ams%z_D<2I?{Cb7IN)>%u8ZZ6lEp^o`{k0r?+BsU|?EOIZxDdo+ zxw2N40#djZ`h@h)d)-O}^0N~*fTNgNd=GF#cKDIR_gfGZ1#JQ^n~x%e*=bvvdf zdZh}6>#gG7*2ahw#1$$XnX$=qKHOS~T2{nd&uUp%P4f_jS;U#nkjH9QicT(qA;Wwk z|8OPBeTTq#D9J9Lhf(y5OH%~Gcq zmKEZ|BtyMcCAnMq6k8a0;-i#==;FET9EeOL7RG%`MKn>8I6EWpe1#e=lU2iRqW4K5 zmIaVEaW_~!_++zn1F}6(7Q1c>T({9WrBM&+a1!;Y6a?G z_qE+H>zPZ7e)8tW?(d|cZ=N>a#SXXpvfq}ffCw}lVGfP5|R=K~&yB2oA3lSnNa zvAx>0!rt~9uIeomK@Zj+xfyTSV3x$npdn7Z z0{Gabgr9(1?MT;}9>pInIQVflw*3cYu1xvG&=6SXXzmg>p5Ueg>DmP(iXQpPF+z|? 
zq`V~dqZq49H9c7Ctdc<$1!~?NHjY?@2gI#uFP&{BEeD8&d<+sq9{XYI>a;uZJ;?xm<*u|WbTIEwM&c_sff2tUnYKY#B zG|AN4cOQS!L^`H#Pp8WrDt!gL&3vu%P^6F519|Iq%|TS-IK}7J*ZbSHd#FqeixR=y z(;HZnPL><`&FdLrY|tiqs_uosKeSPdwXf zApc&G&19AdvzEW_6#d#oTI8CMh-+oZgf5aIA^N1-!%X)`0dgDT<>9HA?DugpNA;1g zJB^7_uB%Ji&*=lneR`Z?$Pkq;1^U`tx=8n=R@G|BFp1FoSRmsf_aI78h89IEG%WlO zq=jI$=M?CuBb75e%WS{!iatBR~A*=pHhRuYCjLNR7iOr5p z3_qP5-G;SYk2t>wj|5k6hwv`Hh0azRU&Y*;|!|sE$Lamx7cCxB3ey};nQTgR)8+Xvc0;quF57#onsu$hA*Xqn_Ia(WQ$ z!kY^hIlEl?z4R#G0#ek z{%-Cfc%x~^HWei_Ih|D@2SY`KbIk*-N?bGwS)PhLr8U9+w3Yp~lVuc7tNlB0%eB{^ zZi2b*mU)i(Z9`CEwAR4qM*1b|a}wobW&33hdMu>1WzpXm=V^m%64O|Z z3}A8+W*;zc75F#@3b2dRtd}-~FStb9s)^Z*1HDk`J#LXC>HkEql9y?2Yd-$*oO8aR zJUI5hCK2?c+NHR=dcRG&+wA*1RTqEw3eIvt4A29d;Ogt?+4$oX-yc6q-{ZIWMVkmH^|QAnuz*i)kAvjP;S&o$`$f0Pi!V~in=db_S{fVC0MBvbeB}C!inS8{i2p{@ zJt}gNbd07WX5{-qdoaJz6`<{xaCV4{YI`BF7}$hk&^&CAlB^!Zh%}Aq2%CgNJxM;z zrt&mqe2WmYM8DB=vtkgIN(gzUtfLv)uj>qW!*Q)WScBJU7(RRgDr+j@MJh=!)C>gK z!NAUmH&~8*e5Ay&zEJB6T)0e%AU$x7s0UtW-u1Sze{igq1rP?bHnS4L83N8eR_%>C zz^Y|BeD-cZU}=3I3`IhFBRleStc!g6Y1KR2!`j_ z6OC-%lsN6Tz|ZP{l#O%1Y2WQloI~*x1ZvKlp=gLL(RU(zTQkkj_&SPgvsU9UR^aPK zeX;#plNN^)?8DJ)vJY~Q1;kpayY|{F2j~H%teS&=A9UAGc5wDgN_x)?szaVN+sSim zcZ8sVSfu2;sz1%lJZHr4-wLRvLf50GIsdG7R9Rr`JO5AX;sPpp*Y0R!yv{V2>9knW z#(T5m6=#IS+1!;jgXj)AKjkxQv_A+B%}PY-ojsQ@W%1z2;|os!FsMu3;-y5CVF?2s zM=mJyVU-`sAlMtqXBI&uRdSda)wCY_4my(V+|425uZaTiK=Q^raXn)Y4V{_GSodbWe9V<`eVOM{xynlZv1OQK0!N+=?0uXHNWM*=%y#F1JA zHov_&q=|<<AI)1C`O{Px0Ke{*>~v(w!?G1tMda z$DP~QCCYIH@Vk>Fze+;pIz7F-GS6U3lqHK_`Y!{>C^wNo7=#uqqQI}`3jdD2a^n4d zaum=9I(d?D)3eDF1%*ZZQfLIbu>jm06Wa6xA*gyASDwa>YD(t-WcO$zCA#yXhf2UC zB0B6Xxp`QliO$Ou+(@g#=&LJ&+c&tX*TivoeH%wta9UeYRQPI-RX^cXIq-O!TA8U- zaQI8?9yM+ma_Qo$kG%h zu#tYM%6L^o$R$XD_k&r{c7|e-*CL4TVESr?=KKPx7C8`CfV7UHb1}`GWWo_E6UP8E zz}7?pR_@n_W90U>z2(dvrVF0hs|n*vU1Z%U7nL>wz@ez1*dXDkVCYR;BsgGSZ_5$V z6LG%`S9Iy-N>ED0);1xJ&#{?UZsca?N9oh=zz?6|jN}0Gn+`u(0R_fSqc@60w#gJC z)QZKxQ&h6zurTl5ZB39zV7fBS0Th$aeKGEF}XVTY7*eR)_;2^)FN9< z(T&SQD4WPbNowr;6Bz6(@us{RwVNBQC`yYRjAiH?Z8d%n-@jZwbZd{^w2!B8U9?64 z#owEogVk!!j#$l8JqlMqg~TcTSmyhyAeZnjI9un{{5I=1H+@cE0HQU#)Mx;T7N2Jd z^q=2ege`qP!|Z=OYLe7lzIpd~`>VutpnH9GJ&q5-C*3~^h1RH30k|aYo7{+@LVD#Y zMY~qN3&E^y>y#@Cz2bhhov9ecEQce8k^4Knd057_poH8wT<3a6fYjsPvj5zVT3B!hjm%UeA2Ia`}WYFi=rzGn{n;5htKZ1f3mG3-NvRDhri^BLqRSB z3&tZHgS>v7SLw1cus&8DM(&FgD;k`7!Id$~Spo}>yQe_uqBu`oj!$Y8%ogKPE^l4K z((^eXXAFLv^kJUL(&PknT(T5tS%M2!0-)iWJ-8PL6d{CjSeXn%1S zljzCzYOb-mh8wg^-_>T*3;JdPc5E6TTeQnYO10$qrO4C9ehXLLjTe?N!1mbqJg2Lm z+s<*o6SozvKzE<)PlV+?>UXUBcldx5a#$=930nmjKK~ojJJ~nH3$yqhtm%O+5}mdc zmxLpd{k}cjn;Wi*PEf|MRKryvPTs~3+r!c>M;9n}g1+J9JRPMK_b6n`m#h-GhHJqC*d!MQxW(_(etNKk|zSPA)h zAn3W9c_Y-8veoZdeb&tC!=Et=^CV8p8*l~=5!vd|k^odocOhoCt5QKM0MXxw*|J2l zPOHY05f>ydkkz{|tInqINMhWkuD!*M;%8z;P-~efGNi_4bV5PFEC6K)Mz8MFx-WJ7 zzO0H^eI$K9i--S8=szGcm|R8j#a19FF`TncN+p`HoeYt{zfTyvlIcGY&8oGRsiN$w&QY+nk&mclEwbS(uMM&*@->gBIwq`P5T2|!BO z>39{juH2K#e)^??>hE0T$3WFRt_2(^!#ooqNH zS)6C11CvU}@-z6wbFKsO(qFZnUb&|qxcJ$LQW3J1`G5g|$w_N0l3}iE(pJ{tIj4yF z9I`Ja-mA=PpV| zhFK6bhGDY*HB`KOIJfqBJHL9*BbooaGw{#zyUHW#!)W9s`2-rg1Z=O&Hlcgj#q+^s ziLyY`FKbuX2@J=V+t@g`zdAyua#BYiVz#}94eU{O4RJhx!EeTG-OzBF;)`KQly5aK zGL)6KKs184rAs#4%uh_o8HoXJMy}R73l?rjM$JQRLm-C|02r_O?ixl&e1$Z7nHE{)I+&58M|HMdrq_`CI$Da_w_^rtHhLuLDZ6&pQY=wu?KYccw3wniLar^^c6IJRi?C4qlkO^gRo8NEfDpBTB4v2SH5OMo` z&04wrm&uN8m5L)*_UxXj1X+Y;U~m6Y5wOO+C0}?HkFNA~&M@Pmbv{^gQdITyM?oi*+bv zGX>obpB`PsKK0v#T~>}RAkyi=cR{_^^{QqVeh}%b7Q#~)G`%S=L3z{{;OF{ej!mpL<^}Fcoa>yj9$f2g zSoE--V3*GFoG*FmRtyg}|B@5sUsqG=#f6t)G@fMkK4dgNMi$6d*$!~MOOSNNg$(Qh z1`vzw>8*uCNM;a_#}EPe&4~%&VAI26xKRDV!bDG*LC_HR#qnk@Z-RsT??Pl@Qcog! 
z2Ae;y7!{I=%ZFG>)sZ=IN|xzg^)C!FzZvJUS+_3n<|_x+UFJq>!JR#TVS$g}x{}qc z$Db~bK+s1M=34@P`${LcgnpN}u$7l`x>fCc1t~)n-M&yDS0i1+z6&F^TRq%((~@HZ zw2Cyf=b0y$@6X53ehn9h&{YB?PwV#ByJpA&n1k&bOcbv|7W}@i#+xWzlr!qr&GSM3 z`g#%{J(zEFnY8D)nKHCoo61LO8FLoT_8?B=0#3t}=!91YtgVk1T1q)TmQB~su8^&I4>;i(tSHedqkRDRxldwTX2&G7nc(5|^GrufFu z2JDdtG;rAN^Z%?mT*+#-NU@1)1JMuld#xY0Kz2GS2KU=G$B({gN<0QjmiQjdg~LXb z0y7R|*hbHVXk}9Q$TrTSg(6bR_b6yCbNw+iScla&uMJApu|@-93$j$&O2OXw<)VZ-}meIGZJ0%^L`IuE((uhi2E<}vpPd}j*T zq~q{JKSwXiyC2MUL9Igf5p<*`ml&gH2z zanL_)ixh-c)*nofS)XccFJzIJc@Sk*94Xn_GJXl00Sja&d7apE;U+{Q^rZ(3Ry3n= zS;PpHO^0>O4a?NC2rV71)I&7lt5?_KUzm1vPE2gWZTa&r+rYbkpvlUx4`}L^>+g|u3;t(oYMg`Z)2sGl)uZ@M zt-lUPlGRTmX|k^lQ~*dZMO1T_A{hAEA9$GsJ??U;LGsZL4V%&u1nbpS!(}BQ=HlKU z3X!QdQKW*%n1V_O^&Q$C%!-pBi57MHnql@CXVxTDj*OHfD~chR{VyHtdpVWd|Mvxbvxr1~xa3DF=t= z13+KT&PiqlyrS?gB^W!5HASuy5;3BP=FNHnj;?QBJ^5Hk=zy`s;mLh~?>{)zjXpUo)LSKxckxyD*$diUU4^k7RUn5Pw?%rK;d>*C_i@z3ilRA4vD_&%a;1T{4X7_Z zIsX`%QY4%?3)qE7ARzK08Eg{aS&K#or&qz-qD`Hh{Cwpjy#WKt-yJy^C2P9%U(Pge z#3b^2Z-1RT9=!{?KaHPxvJ~DjRq%;`Hivd@7P%_G^5MR>$;ZC5*ew;rJqtV6TIp8LPf09P$I}OypKu2 zaKXkg>|%>D`nWnCO2ZMNX@brcnP5DLF=o)!jXnr&_PkAg_fusTVUO_n!@8R8<(X$o zGN`$W*%8L-Y;NYt5q>roa?y=Z1NttZ^q7rpN4M4`3D;9e)J=-+4B;D&Vp5$jfT9cX??6aA!*W@N?+ z;Me&u-}aFMs0Y5TOwMea-t}Z#doNmImJsDK^lfiqld7}mUZGqqkC%Ho{CMN7_+*cl z#FX=z$}HdgO|mQbs2DNL1H$55D+$`p+(LvvMknT0+~<=9d|^9IfL3xv9bv^Zhd{GV zPTcLSLYAtC`WAaOv9={js+t;(%mw==9$WyQ5We8sQZF@d8hDd_F^TlX9pWlvTpzp9 znFay#xGChztp=x9%NOgph+3gMfadm-R5Kg1aJM>_4V4zjD3bbN=uZ)leMsB)!bLT- zV`njE416B9UK#cZPtfMVu2xuhui|)!7K5iddYk|Ud4n7w=DOfgD;uO}!$WF5la|N> zXi}?|wWHXch4JYoMLG@qwwvN0o$!^-xUx{7WrMDc5)xBfo5=Q_7l%WIji=JpgBpE+ zMf9V?gT~oX<*;{4bA{LZlt8BYQPWE=SjLtkx*EX0QBi7^2jWaAelFcJwklM|k#Uob z-l^LKXkJXbG=(6iBm9D|Pl|0Zh+-iit$EDqyy^cJ`ie!RQ=`~ocn(4?}bbC5Kr z$c|q;O3#0sYLLTV$6cV$`zd+N?z&B*r@l)GvOYm7A5YCAE|!lTNSRy$X;NTIvfvfw z@U;uBp|xt3cwf6AJzhIEc=`INj7-iiC3IpGo&DZdFDz``nSD(`Lz4!hD1{A9DJn&z zGf+zq`Gy$l4KT}HjBtuq#}+wAV`eR~7>M%~BaoSP796!7KNvRWG(a?NF~Gd9HYbtQ zmQLTipMu=QAI@!AkSQcEUXcm2s_60?06B&w4H9wX4@c`3frOT4`y8wTUG*u9;?XY% zZVF9dj<{-58`HvlPQ7j3_%(MHSzKeuTSEnrIi~KpWCpD@B5>LABq5VAQZ~u((gA#z z4W54ZST(BJn#K@9S{_Vvm>F&V==aqwgq9XwD}0d!u+~|6W6m31s79B3y-M0aN+XsneB@|>2vPU3hG5Oq-OK&_GhE$>lil0#%zJ1 zY2teBWkJ|3n4<{$<_k~9vy3$ybMN5YOAo_1ISVO0<6FHvbSyBfF+318h9@ck)x_2v zG5)P&vG?q2%dwQF0TpIpBhugrnQfhothkmPW4l@ZqWYVjrgcV%zJ=y=sP@GL^em!x zqnr=V{TFgsmG%eF*hm{}unaq&g^VNRMGT+2oj6hiQ_70E6ht$VWCGZ_p7OlKXuDf;QPUJ{J6aZvU5JPJ*Ds zF{MW+IIoQaGfe6btUe(=(1}B^rdbryr!()5ba*$Q;@y2wGL{lp<0uw_+6MS;Y|jd8 z{63#D?j7_3&SbKTY8bFQ^0U^qhaXBhNRGWz+S(XtcQreT2`Tp8@Y(Uf+zKp?jz{G( zn$v-CRwsI|+ss7Numm<|EgzqtyLfrFm!C(>CWoz>L5i0}d?DADbMTYTbyK1_sQGlqhXM)Kep9Bl45Og*o}cbprY5#bc9Tb1v^n^ z8S&uc5@cZX@kjJE<|oRnuqEPRje`|oVKC(UxhxVl?f|?Yvk>MG4k?=PaFOpNGcr|V zDL|#F%0vd*Zd@3KFxvh>xts>gmG;P_?llFSVRaH98WLnv!)&IZHcdPJAgQI@uHiE{ z;bUjv?sybmm142+{La4P&jts z>`Rik`lVLpklY(W3iG{$B03=U+zmCr1x^%kC2O$O#`vTNs%j>7-NC8o+1KAP+;H!HBUS*6RS0+G&w|E0u~x zERdnGPR^!E%KxAB5<9LB#kV`0<+3wi$fD#G;d1JN+r=NzhiyG5&AC9RjkXUyFa}0` zxU&GMdIotR7!H~85ZMbkc;d~GlGz{=P1*3T_@&9E`O^hjSJ#vXNkE$!tD6K#K5LrL z=t45mV1A~-#M)eP80$0S!tEGr1%62A3hm^WVz)SAdI%nfhn6y>mY5vD zEL|hm;4cn`H$+w}8=#zOi3shqdp6^UZw>a=vsnm3>~}A}Xh=ysBd zclDPkNpoQimsF|WcKHf#THl7ORd@;M%VKCghIsGDfx;NlKFKjg>%ig!-N-Xf_svcC z@;duROR!G)0n0~*_rbBky1SeaknCF~mK_#vZOEXkoKAgOUf;+UOcwSIz7B%; z`uR*%*&|mVqzUNwHFPYR^?V(joDwj6{x5|qYy1yYUl|p3+kI`Iq_iL{-Cd$ED2Q|o z-AGE8bd0og4lPJaH%Lpz03$;;NJ&c$!wk$j-uL~yYyH2m)-1l9bFRJ5*}r|AONrg5 z(46G{lS|>V#28p4&6A&xe}*Lyeji7#tu751dwqciLOpCa4*K5hjRNJO>fdoD-c!9d z@+g|Al_dO%i=oN*Yb}%Vj&gnr=LnOY-A1OHbMkkb$cSj1GLksbhoZ280wBBRR6n(- 
znY}y7=!EJormPI|ScM{q5awu#T`Bde)dR$}^g&KOhdn?7ZPuH4p4$Llbg1gQdUb~M zM?S*-j5&CCLm!+>_coOv8`=p}qiu7QbuCy9L+Z1;IIYOO`$k8DpW*I>db3sPPP`FmzD2vq z1FOPWxH5};&x=pmpTQ3c(KAb1>UVoB2TN;i-pXm+=Mc0r;$*vaZ`|jd!LQBTPTTDo}vbF_7^tl1utkvuL6Ys%9B6_3$Q5R=7O z&#CP1QRm zC`h6I(2)J%*f^Ia{M?kAz`}zw^45b8C<{bN)lI&e$Tp-bHcZFdpOzW!gdkVpT|{{j z9y>|5LFgn;(7^Vk`f`=Di17_P#US_>TMrfsV_0JLxS>#h>JhTr=wMAR&;Ph}Xp3)a zBa$J=YHw4bwe?TBKMrFUvCMGbkik>!=k&fF?`zjee~*T#D~!DMB#mCm4tYlq@UEl; z$Y+n64J68ZL%Q7PCgIhR(-dh13~2Pu|B%~tR;M7g_qZv{a9TxY{J7n!7N@~FTXY`4 z=fW&Pf=+_^GleU21V^Q<`kef-s*CuG%q0Y@#I)9o$!*CQN{@u6yx4HD1( zAi9oau19{NP@8Zzg)1eTNz@nInB?J&bJ3wRnN^41 zAvl*1?bhFJPas6(xuCqh?|MfP@$pB4GSO=Zm=JdM-GmN$)|AS@h_DZX${j-bQFSj#wN;TV4tF2 zNX8*O-tK>DT@-j%Kg$N(HrIXUnS_06E>eEUTF>|bRyaw-lHd2#g*)9nz9&P*F9Z*Fr;J=B6qB)4^ zJlK4EOh&uI!>Gb`*9@bLghb0{hyl=ZC}TZAWrA>*ARgmZH7n^e2qb|AE#AI;~afat_uT3ZLqGxlFrn@}o#~^2klj?Gz z+Kb&~7@VwigvYXz>HlZO^FFhI8fng&fR0@mPnbdguZ0GTw7(GTx^|T&& zfkO*%--%jN>v7xjTB+j^ihAw3M?8fY*18p2)?l8Mw&JHgVINDeOU8~u2Q!M494DB1 zE#5{?iQhe@+@KT8`TX15tEd@8Lq-c9xVh90jCoa7-@@2k(5XxD@j?KW-=t0-{Y|pE zQpa)jI|`1C9)9S~fA$9ypbI(>@21h3pI*O-OxKPBR&1WH_^@2@AtY4|t^GBX@?9pH zcIr=NpUi%?#FQKj{y_+Ek`P%(s`S7mgS{1%NVM6^C{+@5mi9s z1qJLItfn|spLOdeD*N7Atr6_S`tG_$ob7y=?zlX8C|iAFn2Kt`)scEBS11+NMaWML z1$7%dfWH-qv`edVtSCHl7m`AE22^UkA$6w|Ob;+Jl@7`4UJVOHCtegnh!kr^O@1KP z(03`KnoQ!E#pC>V7|1%4(wAFdqMgHgz+e!zndK@;iS{nrKR`0{IYgr|Vs6>?wxxUe zx(<=^OWTQ@4gTwYPER5d<1gQed2*&}r+-QK6sX{QEfl+5K**|mWX7&ZDX1N+FfZil z!6_&|F5ggY+?tk0KJPtgLD@nsJV~UY6lJz(RSI4vKxD{F#r4xGXTOVA@?RKJdD*N{ z2&!6eCG+JpeQ@MtpO;)m}K<;d20*d88VkwE05+bKl(k2Jn zw%64x)w{hhzW>$p(|3|6>_Od9-85cjJe1iWQsI?_oi^lUPMQn^3*tZM{3Xp+|8w1U z>2_n;O%12r_QkVQDcg?$yArl%>sH<-n_`Nq4H-AE8xl@LFxQN#$--_hmW&v0V*1J^h3YvZ${xL?vNx3HDOX?@Ud z|B9g3H$NT{NizTClqdRkfw#+Nb6Ktde~?e(hx8|;eXlnnYXE4}IpsEC!^I|Qk)ah~ z+-ejQ9QAK{mPQ1WPA$=k8O|Khr1~2~JaE3)sz3Ekdgbcn!wb2 zzX=eiHk-eXg~?xaabGBo!u$*glma56&-}4)8!{H$j$dfWM@MnWv`(C9*nv{t>FAwulkp(JHBfla%#{LBL(#&CQc?3Ea>#>>^A?Ok*zB@|r%j zE*tYzo~a){m%iJSzWcpM)cS}AoxOGbmtHVw3ml%1kx;mOO}~=$reD5bD?j{V6f2RP zQttUAi@CZIg_44>H{Pn+$6@7Hd_w&Jc``B1k9T&{i5+{3uvAcVd-7HHG9fzYN*OPd zqaKXQWZ{KH2ji%?P{;N7M6AbInk<_0kHgnwQ}!(TMwWqJ1k)J@c_#EEg_#{BHZ|F% zJ1VP{qAL#_pN$JS;;}mE+>O+9Gr^g)SK=~?h^q#{Ni>V&)u z@I-bRapgwj?r}@Zuqqq`(JN`9tUozEQQET~Rd-lzpOve8<)D|K%YF3BGD_59Cg&Z^ z9cnt%G~wmN8#g|clCiz((X*Awnw9UW{ustz@}HlQ6yLUL#6#8={J1oGLA-41;qL_t zZ2O){0CE_W54HN>c@N9+8sn~^I9J1IojUYthd^X~xs#(AOtIFb&E&o=+oUl}0NnAn zH*ZuA?fwSVG@Bn%wGpp4eA$M)y{lhln#hsxkoza*IITnr;g%OHlkfU72i3P(VZi5I z?%XfXZ|?jmW9k?mxl4baBxh-Phfsi1_wk3#3kO$ClVd4GJ(R)`4Hj^hNC&mN&LDh9 z($#-9FBqYuW3iMQM3$7xGA6}jBs9u3^G#LXF|@`#g*C^5T6r;*o@b=IF4L?Dhv~&D z79h}SnK@4E9X8h6+9uBF@`IPF9HI@}+c#favQal{eg?;^T^D1JGYqxQ==;^W%AzuB zo|_F>I)!{Y52UiB_vC%@@;Hx6sZ$XBonBRdWtX>izv-;w`8vTY8(zbe#Q{@gd$`K7 zT`C}L2ieHkNInO{8;e%9l?DuILD3ggG8U_Y^Ogpg7I_EhkA5x0tT|{vqSP{7vT%nR z{gQrvAupMc`{ke3d(J;~ZWaXGm}jIbv~u?6Bj5Ucfl_vH%zmBXcZb2FAp=Me-P9{br zBXI{EVj9!(36uZ4{1X}k5ny{f1~af~RqnRH%~oRM%#mCbSew$2#mSVj8@X_;ttsxO zcb#v1x8)%%cH%*W3sZ2*ag zd%0YrULR00Gc3b~K=qgEl%2WU!McrDw!F7Y*SXP9?lbQogf)L>xtO9`;*c(9KWSg4 zVO0|)sX_hnKK)ft>?855kNjp?B2w1Q(}hem?P^~}3nSL?Z8P5+7(J?MymPAHT3qh|Zu0hXQN}=#xAdLMfLyEGH9w*f-5hkR8#R1A zkbC(OlAbWkyET?6+~QdM-{D!BRR7ZP9a>BN){mfQzV<mnV`qPw`7KZERZ5sg! 
zG`T~sT|2Lqi#;kyOC~BMH(I*7w;U(OSi5S`HFYTGH_?4>Wjp@t$D6_7&DQca6;Z() zVT$#)0x6q8H}Y=dN_%nE&T0P4_JFxM8^<$0YHk@f;6eK1n*yDHJJ#GKL-ppMJZL3Xc*bbPaVjp{ze&cdf(aY6)B}t9YE| z{o7HL=mjJ_X5ggT#Y~N+tXrGzF$QE|5Hrd-AqVLXu=l)8ZK%tn;&FE89)Jc;(# zdLUfSWgjUhN|D$_y}q`xY9x$RrJYiK+SX5V2HX&e{8f5oq~)srTpp``i39S1z6N;o z)J%3v;d@PB3y6z&fb`8%s>mXr`Qp&>Y0C_r)MCb#n}8p#9=alNCF+fOk>Q9D4<4le zo4DDB8)UHk;hYG^6n9BBO(6MzYV?YtxPd7_$NS-e%g-iw!uJPekJHoOdvK#RmK4im zKEtZBCKIvCR{8=y)ui>#C}Hi_v&9#eU5Bhwj>W&WUbD@O>8JRwmc)yuImx0m`n@KP z?b_^&MV&EE$Nt!4A{e4r`leI`rluOT=k<+ZcJ$!2HaT6a(qe2lKj^R(yt$J6YIW(a zp_*n@<8f;MB~#gI3v%Zm-P~ z@q$1;5_f$q=WBX*?nHlS4Tn)4tH}J##Le@qH2lKThQRbE)b)iyI7CrBj{O7Kb5hf3 zKG6_$+_FTWD#yaPAJg-nt7Afax*I7!_SBFFsdVhsdaF{beCS1~KYUZb!MJ5k_+wFl zjFmGZB92Oj&r{7Os?r+Z-M~72ut|InD4y~&9?T*Ko)=j>L#-s>sT`3>Q2c~l@mG^->ud!8r$WIb?olk;YveXlT9_cwtXsi4W{*o=rL19_p9*U^gax|5>0>#Cptx2Xu%Zr*BZ`36U>g#Dw*i0;lB;vMm#lO`_o3YYZ?uc}nS>68WE_bTXrCIc~Q%F(5 z$x`=BJBqbPaJB;~+4oPlqq`!Sw@T-Q)llSZ$6m~Yl!-3tBn7LEHCk4K-I$La+fVGfG>n)> z1ky!Gwk&kLYrx-EUO6diN}Kc%6SQsp$uisgs?jYjDqwxxH`?52`a3_@X{WivavsQM_{?n?pK;6iR|^Q)!e8Z2i)ZKaataM)wOzaX;}-AAnp7#bkPy-HubZ{X)yk2w4X$MXSA4Tq#ScXNo!m|cmKhX^ z-H0q*4S}zqeIwKm=B?be!3IOUEH>|32eX5Dx|TIGDOkcGll-E&X+U1N*&wY0T${SO zYnnE*J1O_Agk^v#x+w(CxL6@*usrnLm-ET^TnyqiY6+OIje0iZ8Bjf&bQCX%4yjWm z2{2Y76?4N&UJbYlFugU^ry=u^xwPCg{hcGQGD|bDoG}_Gg=PW^RG&tLk*(ljqQ4o8 zP4CkSw_~XfBpocnAH}Po44dY;fr1-gcPf?@;+Gum0hOF@0}CQkCf^rhy=3I8<(a6^{&Hk87EBVmX&=Ay z_I6>-yV zZym;}qEcx?%~^2}<8$W1T!l-Tg7z#q^l4`#vHM&tD)u9pb86y9N@~L2L}=t$?AU# zW?l}uPKX-jYWS@wH$FZmL#*&6UP9rarj;|cD|Eyy_O;?fVG6Mt&spR;u0w8tw%t6( z=^Tz17*PKr`AiWt@|Z zm$oQ%GfqtT8t)UnqJ5B1(Y^zIeA;)Co_`Ih%05=)9>6KTdP^( zuxi;;{NdqbKe(E;0PwsrG?EI@4zq;c@skHvw64~b#bAmmlNi#;Og^>8fj5OC`4aj@mTNfXo`kx9#%%0hajEQ zkFM2scTBSk&+ic(C+Y+E!?~ViYCVISPxpjt?$8>r2rJViRGgfaW#^*pv4f(XUnPUf zYiN=+AEgXEHJC^)4=R;ylSO)fHWA`6qGGTzI~Sy1+?7Lsz58>9rIxT8#}-QxU72Iu zv^;!3JL=OP_TxDgmThKC;Ej52c^EkvaR2FTBqrt`3Ux&%hu+#AGT+Z^Hb=fnF*E3e zRgSPzitX5_0JU`r_nHb03kVk&P1HUYJ#e&srTilQkc)`I{{;M;Zm+49>bLDr>?d$5 z)|!IX*7{()E`;Jwb`bX~3Z%6cnZC93O1s@Fr>sCoN(5JB$!-yO{{{YwHeU~(%# zx$#$J3T)V#KW4Q#kGZMNne;pB#nWf;BZ{mY8-!VV+^!N-A72;|9t8S5n3B4B&-@}+ zVcH;G*qq<*HwL>MIsj6f%bdHMOw?zfp3xfely7Tq8-cX zo{h;KR@O6ta)$h#hEULWaR1|c($LETLxg-*(VEYVos7(FkkC;rOQ3tOjEoXYCRk@OpyMaNAMdsBeqABEAaw-Hi25F_O`FzwEswZ?y-i#VVh<0ZCVaBayq~dj^A{P@%_Z!1vvTu-MtalRw*(er_d2XP z6I5Q4y3cKrx~s|rnXaOYPpt+U{`ND=GC@d8?*qHR=ZBB*BaDrU4lrLl0!9BhKpNtetF zG481Ra0z~WUs#J@B!b?Vg;WoGFJ`}pnRT00Q!a}(H5b*}-SpB?<7nnlQuJ$k$LWx_ z2Tzfr1wVRy!HH2H65p}sd$|#0!>x5)$Lq!N1`?@2ZU2pn=MOKhJbwcH75C>dNGb>Cj5mo(34TQdOuCE1pg!^#;HCRLXmS*V6cCW+gyDmM>W zcgB8$w#o(kgr9I?&vxpi-0 z)j&$=_Sn{|#jMWsnThv5_o8wqWxwB@l=%UGzV2kT0l5H5U+nAKEaaWiQVk#bzz#BP z9waCFM>X;6W&go0Hd6dv1G+s(|BF1W~^{o1b2WC!Jw)q*%LRtJ6g=A)6$T$I6i8zQ{lH*!%2y>oc;Y*^@dF2#A!vk zS#99-)#@aFjdBLTW45&FI?XLecrrlk>)@U^X#E=8L6$IM3Sc7(-5=c zj$UvaEHBN2mdW0Ux!OceKq>=Kenef+@K*0cUQv+q>CP!{DB=RJJ<#a9|Kj+|$tp%? 
diff --git a/docs/source/_static/images/normalized_training_throughput_zero2.png b/docs/source/_static/images/normalized_training_throughput_zero2.png
new file mode 100644
index 0000000000000000000000000000000000000000..be6e5888c3717f5472eb3a2db874dea2c30cfc70
Binary files /dev/null and b/docs/source/_static/images/normalized_training_throughput_zero2.png differ
z>O!VEPHN!OHZfS!KOE{*p@=Z>IVQdwuftQ}a2ecM`h0PNRq^C&+e6sRjk`$;({YF{<<6{`l^~P#YK5C`0wGcOwb5Nl%`}6 zl_=VxIl2UnBdK5_Tx>Sgzt8rp2M!)4(E-?A`PN5%aQWFBNj5m|g{C3C*1pVwvj{c| zwre>5)lO2~S0B2USNBy3CJ2gBJ^(fdV%YSb1So8QIfja0H*t}(})m1+LHU(>LW25)b> z5G6E44z{_hMDHJ^$pU$Fm`JTS_6TKbU5!_>p(MAg`rDldQ~B9L9U<@h`F6y-)e53I zy-OU97(ci>V7t$zg2qCR)W3KCw3Udj)597^e`lWQ!>XJ;3yX;MyIK#1jf=?>4v$d! z!DHw^=QjZuF6Dbe@=D2S;;h5UzlIH}jqqJX2$1EZnC?>>zW ziV3^fg`lk|4uoamfIe?)JQwWUSz;ZT7I`vl&00w+`ZV~ofTl(58GIW3tj=@Nqa=nUJbe&sOfJdo@svqb`~$De zov(aRCdH2D5-f)U8;{xldYnO}F=hRE8%|dFs#FjJq>V*&Dqul`q~i7EZZ92koHOYaRxA)!?Xi_0jk2>{NXZsvLH5Dm;@vJ7NUT8w=$MfJf3@P7QzJ z(hrA}!`^^)bDe&gU*@yqS5k_@KpZ_Y#}UW;ay@!Kc4E-REzLlC;7#h0T$6RG6B>Gi0*>xRNHr$)Qbxkc&md#F_oq={7RU`dH$5q?Jw zark*UgtX+3j+DDUy06=LFcp(o19 zLkvTcvHZv?T0&fF$rkpYzPRHADHpfXy72#I0RR)%h3J+8JH=y$myR@@t;`E0hch}A z;!>m?@9R-LLu~y^N!6U*LDwee8YdR2Y5A=4%K92=(-Ye`7sC#;_@Jld1G1F)_1VD8 z@xurILe9|_#Ad+~h3AF8DPcd4;vEaxglqHn)%bQ@68}wUUNOK4O}=z^Zu%VOp|%%1 zxy_P!5LG&gYicg<(&aquc?(J8P#)0eaf{t6{@K1&HZnAI@fcY7=UX6)H?lvq-X&t{ zc{t7f>h!H8(@}~Hu;bw_K)Z_~dLH7PA0ND#+K=Y*ka?fz(jVMi&lmi<4q#{sd%_MI z9dsT28(TwNf!iT9hLkr(=S7gC!mP8XV4V26O*6qM5N!_g=yXZ4?f21{;Fuqq{k1{c zLVSiQZ$YA5YN!U4ixbG{0$lo4=PmCbq)=B>-Tzun zSPr<5{W893;QpiUsZ%**g5KT4WezUllh};*4xd;C|F%#44RTE@@0GS6)G47a3u^7S zGJWqGz&z1~%?;!nbH<#`&=gtK5jF!y0(@9ycL{<5AeYs3!9O}bLD>_1%%7$@p*yIH zZav0^8#SBOYizuN+d*ZMmjQqtA5|40Kj5?aA`ROWg#WgVQaZJ%GM#u@bT7(}J%2uW z$4kTZ(xOf1+}#7=sO$g5`Rpx>e%RIwsCQiE;vrX${3!6@Rk_bsJ2OlZ1bV|eK4zde zqSvicr^I*{2@ZPEt18awEW+NDGqa7E6WURxtV34gSRo3}`wE4AAjm}s|Kqar3FHru zjdSW}{U|lybJ}~## zv642N(NBX5+lgK#Y2IgD@1A@|rQnq>_7=~Y8m6@AX79N;MXj2A+uYtA|1WAWwUsNN z%k%Tl4{u)OjqSM{)~V|SCV`@zszQwD5seI)G}p)l{95Fi1BkUhcq1an4qtU6s-&7> zCn3g}{Sfv7fxPH}kc0~Nhou_}kbhV*hY}NL0d6zmXNw*0pmUPyox1KgLu%g&vGG%x z(67F5;+-tstG=n|SOY+AFrhudWj1^DQV{w14x@=989!B00p4gGN2T;`Xz5{!mUOt7ZcC>_-@9o4Dnw@R=u%KW)J>%s&Kqcz6 z0%&q+{ttWa9oA&l?G0N{U=R=)X#paGf(n9w25E_i5S0=gl-?vDF!UZE1O+501VlhU zN@$8o?>$sQ@6t=?gib;WA%ri^IrBX8zVDp(eBa;Sx$eKX$bDaH@3r=-zqPi~`xn_e zybN3>?PhTqTSJZ0g@kjUM4!Ov}-B%xm-EtBNe&!cl7G>szFW^-#&wLt z+&?bBrdSxaQZz|J3M1l zd_aU_8jjoRcORFdx||{bKFV)nU87>$bALNJSOiH#$~QHeG!3Qy25m znjJC7+hs8u*>C*SesaXgq*UaIIk+Wg9YPj`knj1^O5VuuaLVdqT?qBCNdhs0Ut;k2 zzJQzdtK*(yT2GJt&`&js!1W2$eqU*rP#%Qc%yay8v|$<-4`a%zBgxb8YugP9luUm& zot_b)+L*#?OSVbkdMUC1L;4!_sJO?}cw=DOq*8_j7eM4ZO7p>la?BDX)<6D%iy{Se zEpsj*Hn)SK72djX1&rM1ICK*}AiD)i6(S9)UsQNi=&fdNsCN06TIb0<%*y57^sMn< zlrum@#Ay79syyEo&V>PK>IR&e@u^Os_K-7i#6K6WHUgnAA5TwCgmmtj-SWko*X{e& zd5^)wa8a=(+AR1GPpzu6L9PC_=cFOxajX1C>4_znQGBO`AiXatnI!LHZv+5rhwJ`C z)QF*I7eDhG58HoQ%|gluDI1hO>{am`x~rEY>jcfRw^R%;0T@o?SMoG74~--VXn^)U zDkZ30L#yYh6u50n!>Awf0sM7vzp?KlT)^OHH-MT$tLdiCR)`Nmsl~ibM*6TiV1KiF zHN7$n-IeX!%b6_Q>Z_mzu%5txEq(8dI3G1oA~-7*TH~?^mAG{Gb$DbI@|xGXhI|JO z&)PCjY6R7T*D`K9Gq(DyA>n3EoBuC4Tan%K1i15E&c+6ig$0a@IfK59Zv9UDa}N2*Wr{= zODEB_`E%^v+(_#^@omHHM~;=V(Iu6Uw6Ou7OH5O3grB7Ul_MC|LpZT@ zUX7ZOpEfdq&G%}&OjnI?Khhmjwpq`fuZrA`p25B6sk2H%Nr$#w96&n^d)eo$SO?Af zIzEpJS{o4*a+|8){9~122LDWuyVP-3eAuht2$FXa|E;M6BuE#HE0Df0iB z%;D>Om|RGL1K0rL223OP#(9v_N*Sqa6eRqXfq(lR^U$lYB5DNRL~wNQO<>GNcAvQA z40rE{k#ou`E9ATBs23;chKD`SQkZL4ZN6KZL2vU~za@Of`A?wOjj<-f={{qUSwkXTA)6&M+I}sVwG0d{C;RGI0WrncPV(w2&PSPJ~@BB8?{W1!6EC}1e5+Kv$Y@k_|tukH3ub_#BLvxZ&~4j zWc=)p9rW=ImOcOyZg==OR-zE|z2uE4@)&Hk)V{k(s-Du72Itq}(>I!5V#Z!hkP7}d z+&`uh7@)7p<(4tB^zlMn4axNA6o@#miSHtTKrUUrOI9875kwm*opSBt=nHC^*8?nk z6Y-z~L%5W3gwbiAX+kX*K6ENmI3DC|NndmDv))=7r}@`#DICSCz(t$@VFTj`QRH4e zr34gIEo1K>@bM_gGy*-#9v%K19yAP^8cD>Rm8Ok6fOx=UyGGWzlZIH%D44Fs~CFm0gnfBF?;Xtf`f2Q8_wi#y@W1lqAeqB3A3FU$5yG 
z$^PqhJ*48BXt;X(=+In4LlCtN+!5Kbn5XbL4;3zb^b>5T&>gP6!?-nW41cyvVJb4e z-vO@HvI8&Nlb7m?6ozv*)p*&8I$mKUkjC&$1Ec1%a9YdA=Q^?+E(3dAUjynNo?ksx zs4^#0C%YJq(HTcvs&3G~BnOTwakqotN*w)zmL82tRXN}#{N}BEa}@kQ>qUUU0)qE*R zWun+C4V$jpXzA%d?Z7q@M@z>vtn*yMd^6Z{-xYZuXWw1DiL1tRxq3T(HVK6ZFh94b z8A7d@dL}kkLy!%nJi;U;)YZtIdR`zE03U@#&=+4~>@{dt=+Hdht%jg{R@xJn+>kzl zg5;(A6-Hjov$L*KjW~aP2<3)FZr?jP+GfBEn$L2UDQ+}TE;#R!(z6t9L*@Yd_7>2M zk<3zmFO#epEoW7siU=$zX-*{IVx#Uy=WsO7w(lR{?j_TuuH~(nNIb`|#M~GA?GanW z7VWYKOlb566R*#(>WHSniBZbubC9_>`+~}9nco?x>L4uAy0iG;>fwc1ZfZ6-T}qf5 zCA-^<;5+PeZNt=3-sK94;l*&u!aT48T70sdH)^_ShH9hyEOM>ntLw{vDpUlzKo@wV zYMzGnlI`(i_ms9s%F3wxMjDJyE@YN+2<$_=u>f%Dr&J5u5_I=}ZxE9?oVD2z?Lau8 zvW|#nm+Vg?r!T!zlWs7uoLV#5FrpNxj#ubYH3OEVwUb7=LDZJ=wjbsgX&0oe1v8)| z%gY_(RY3Z^w`}E%4dOivT1Aq_!6$?{7nV%M;t`mLrJV-=ynbJT4F+Rf+|_b6&;40e zE$Mmu>~|ij!brV-KcOpag0lkBML~Wz`2B?LFv8s*+@NoIx{tyu55)YqffRrmQO3*;4VGl`sM zIU@dM!skZk=>13eJ_${Cdc2n#u%u@3RsBWwp=G?<(KYV{X~%V@y251y*%CrCJ?`WT zKRNe3-zOiQbU7&b4ap+%_7_rmg9&Fkn?QWMmdfxi!`?K=*Sq4wChjl2TapYJ7kHEt z6U$xtD-8Hfw@m53g)W3N)cNc zFsNs53mOwfO>>o{bBg+AsZ8!iXj^Z8ME87s>5K>)O$P%(QsL|vUGQi_u;4xytr^LE zlg!?E_G_}y3d8v zZF3C*GWYum%3^&kCoF1mtXg=?jI%DrE?%CmV)k-guF5_>qrd*%W97Ege`2*BGvz_+ zcdKR;nHb3WOHP~#cZIDEjNa$UKN2eyqw;u9PE?(Asd&|Xc`S2yHZajzzDjJFXHDy) zyW07&3BehOJK{Z7y;M$iT&@LVBU&$O`zmn5c5AQXDE zV@lqAH$X}<70An%t`v4zzDz{yX?SyDh@mi#=+8mEPA1Q}65-No%^E9t5{lS+Z<=U&3(#P{U4sy-}W| zCeDE9X#)O`u4p{L>DxpC3$M-ExVQ=0?Q1F~U(XOXo@8pecihGFd)GH5B`}URso|)1l zRZ-8vK|Ke2X+CxduIa%J>L!J4h7Q*|PoFU6OMU&ycdYFm2S3>|0FjT}vn!!}04j~l zC?iQjEhZZu1>JDA!>q#No5rj4z!zTzr; z<&OA3W8x(UN=8nKrkV|lQafu}^6FNinf0ZhNvq<9S)pUAXthewr9S3bI%~i%ufcY} zh`qeW+LGE$yuK2CpCl?zuYaNH4?jugQCbcLB7ek2iVf!y-C0H5m`XtSX;m}IEG9(7 zV<@Z5^l0tEam>ug)8O~?MU%r{NTliK;HIJJ6E?3*nEiz$^!nF0d1&)ru-c3Iv^4DA z0E~FLjwmp>fMv7o&5(&$5&xPMajQCyz2J8st{^STK7+QSR)~h5>MRXYwDa2|eI3lB zTvj~T`LS=RD;FRdDJWZ-bx=}x1;IxMSM4UrH`cyZa zSqoQA3AXFZ{!ESzEjT=B z?ld>~6Uj9~FZ9@Ud%i1BfKWmcN88YY?zR^F1NgKYrfP}Pa7)(9PFgW^)LAkV7S^X2Tw*$_YuUR}TmX-B zr$0k!`M&kNLtoCG6SNi#8~+;h+AGC0NC^3|0uSV74V(h6N&f^iPY zzeP_Xh#h#1_=1c0_USIEx#SYJk*4nz)>0WJ3G@vYcDaXej?(pp zN;XrLCNduw6|rwbRp${qHm|-Qy2GRD2b)3EJb9|PP^OSax+tIvIO#V(mL=)?Zq>m5 z5?mz54alyvOH6gjQ*(H4wwlE*>s7@20T2O4iWKMy=&X4bpuwm7?*-{lt?#+q_HI|cw2q4LF1W`@AibVCu#76> zb!g^;_)r=(cu?XKXwoC7ysF=_IF&8nHznUWW_hr+l<&3r^T#`EcDoC!ryY8a8Oc)F z^)wT#iy&*m0~NkyeTq+mj-kgp=`J4lFZ5LdzxzYi1Qb`rsqzqculcDYE%jfZyiXq z2UKjXPyV!qNBe=|6PW3()%oO_*5*i5cU$P%2yRH7*HmxqM-r?cWisE8-X(Q)nPru#Fq7G3q}f=x8+~awFoXfy{Dz zIB$M^iP`EMuc61QhaQ2N9>(K(DW`My8blovls%)NoXTf*hjuyz1LkKwrOe^(GSoYL zwqCzp9Ml@OSk2=>^1e~{vN?q^X2#=NvU*DPV=Ie_ump-c7~OOHd$+J}DCw>m%(Lxn zR~qC~9ChVuUEJ5ssFz|z@5QI!f;n5!pV>ycx#w5&S8oTVkcId))#@VqSXQRB<=NtR zwu$j%|JoKs&VZr|7V5w%+pOT}_&vsmIK15M=)7RQy9hb(S>EkzovfAj4J~pTDd`WV zzBPT++b}WkjN!$mwHSfk#6f5rH|j#WS2}FMl<%zB@$g7r7ELr^Msd_M8vDlXpNfWd zh8``$qx_x*ms#qrJ%K2Xiq6xn_UL-}2{&k?@E+hB5LDUT(w3FD!1{-Gg|=Oq<5@R; z!{!pt>{wlm2A+$Cav=4N-*h(H>P?Z`tbz(Qv&a!8T(~}Fl_O$ddz0q7IC5fAVvXjV z5GTy8A=2&WF^>tcpgk^f3i*Ivs5QuLaofQr$QODV7g2g+@cXg#y)oOKAW7=`I9Df_ z`%Y4R0k0Stm~=DihJ z(&G=xRF*MYS#&{3ec@JEvhCzIDc2j@vG$ycTMOL)GJEyOSLM)~H;L7IFb{ICO6IY8 zx4$ceFl^Ji?&^2W9~t%ib1X7S9YFauOF0-mxc%*L2l6`ywkWM97+U z`4)-a+IG;-5Gx{0ow#Cp?BeqnFo&Bs& zDJm~xHdhGI(~a2anx$l~5q-=?TcpYgltAP*Qf$zZoW_8cW_G{LUd(!?QDC~8S$#=* zxwKoYzd^(+rHOh0rl`R=Y-D~>2oiQnKHQs}A&xo0%Tfqm_Mwo0bH##_bNV}XFZB-|H?@V$ zkQOR&XuXl8o=ye|_p=g80#*~KPs#OtolK4R0YlCGhu&ed15CS~cTsVw;CSU^r4TwE z#Y$Gy%fB+JI12iz-Fkn*Y9y-7MXZH5&sr%Ht#crd;hw|olx#fZn=xZjtJN&m7ZEP` z;;MJXY={j3l`Xj&yj81Et&mlmiX%TO@4X@xY1V)?>xO4gZPVpd7SxY>UlVYhm1?2k 
zy|b>@Aqz{hwGl)jRL4d6*>Kk-tv%nQbd@+><2^5I2IM_3T(5S%@LE|OoBaOFXDML4 zH2AV~*xiN)!e(g-=@Q9IH5ouu_QpO*WVnzz>hCFRgwhG2serN&E!}7fZtz#O=s{S` zLNv3IEin)xpm;^dgw!O8oMlUk6hgvdxmAWJg2HJ}q$XOJ*~fbOZu|Di;o63arhz~; zKlT(WGk~6{zHru>DG4LwU=)P(9E}16tCPRC3lc5eyGg%zNf0#clm{Qyj>N1s8_8R@ zg_+HZ2WzZ_Vm!LU1bo_mXeyT7k7Trpe8heS!A}bXZ?zqmWf9+Uj>QtJTn^E~h9gaI z7@7f6OBYc}SK{$728zvo;!GRgHd@GAZWWiPYR}*_Tm!e;FZkWk9=<5;77dk39(oBv z%Q91m!ebWa&DsbZkNiQ{MxU;h(_)Qdi*j>~*rWJXDfm@~woGmT)x*ArseWsS2Y@f4 zf?v7$vpfAwZ+T3>C8g={RNXBJiW)~agZTS9@4cVm6rnCYUsx%h*@L2MGWTiyi=@hF zKH)5UkgoEGTXRb7WcO=-(*uPVaf}rBQjfy1cEu0r-eS!puSa7FxO>!?q8Ra9V8~hN z658Fzz+`38AJIjEwn613FTo@EyN%E_gIzI5Q4<_1pMa%sh{`q7$YroOLfe0d$PPwT1NNko6Cp4QtQmDi7b=n@%h<=n#I-aa_uQi+rtwYXJLS zHnT5_5^wWEr4|XaJMDGJk@Dz~s#+Gs5z!vzWT`qOpf}ndB}CnNBnhc%4V&z5IE8Un zRlL<5luRayAOBigLyWT5Kt!o?9*{Geq^W67c!cL!D0o=LYRX~=dG5G+jOO~ld%dgF zJ{HPJNnSn8s34R*>E$(~?WE*?vW<4IBI$crL<|HY}MS&1?$cU zC6Kq?bR#;l_SFuudCi-}$rqT}P}iohHN+hQ*MdfZsyw>Ptnjk*EvQwoI*S?OmX0`D zD0ofG-nqQ0Puu6;9YI@UmWncv3Hq*lkzn(BTJ%DYgSoLR_mrn>;zQY>p~PUV9}BtB zD}!6HXR%vm%L@q>HVv^3O7pdS#DRRDAYq}8x(#=u-qcVN;-fRB#_-vWed01x9#71L zc0UcOiY6B@21|&4+%nouH%Ej3&C=Ki&6F=t^u_#`%{k&Vh;JzOR;c6wQ8!O(4D-1C zH*VoRr5c^@{5fp~3-?15uW42)FGR$&DP7#ifO87XiZ&z(W#1qceW?douZ3|$tt*2} zt5na`-Z3b=LTkQ9UTjN%NI2enYijA+yK8;VYWmfZuhjb1U5!q9tEOT_NJy9MoKjj{ zrl>H-2(vMAeg-h!6qItxw`gx6t@_vOI5q#8#^D2h1;II0g<)X{No z^o@sZrF74n0;11D(h^T&5H})$mtmD<*1W!DZ*;-@k&u}M6UCk7mbCE}+cR#&xH6ON zt#>ea@~1ZMFVYZBpo&RUoI5fp+Z_wbNJFa$6~4^!7<)Df4hDG*-0-m--Z6^P5-H1B z>bLsjoV%N;7ACK#{2p7fJ@Q+bLgQx%SLEIUK-5W_z4b#7E?+ihBx=E6T~x`Sjnm&Y zSt*IR_y_Az3gHh%d(Y7pi<6!EEt1Twt;T(m!*%OBElq)yig6xBtHKWlzNmwd1TZr1T^gvVy&sU^I0rp0V3{w7V$#K~6NiG+oS`6j*hxz4F#)~I&K zgo-F(uN}E0U8aTa0Q&w?sIIu1$-TF?c8DdHY6m{NsU!}eeP8Ll4aXam!LWW7JXIsH z+h{z)ohvw3WyE(ULjHWouzrSQ*NT-aC3t& zfiiG?j-_YhCCKv=Y&6v*w8qD62DZpbl`w7M-t()|7sh4l@Q&0ymKK%i45sQlksZouELkJ7y%^4+ic*cGMuZ0zL#z9B%>XO?&dk zi-ZXrqf=D{-QZ8ti8U`0*|Bw-9OmeLP0{WS^1IpLLrHo&Oj5Pik4r!Vm^g~Yz-$OjimrW8+eIuwS)(T5@FKQy}vJTl}Cl2*(NijTI>{3w@X4^x) z?IzgYY#{QF>EBMN=n1CBDMKB83xJi6ZOc+$ZwilRDC+~l|D%~#zKC+5qR6=AwvMWQ zys}n~m|8l_nx7TDKirJ?R6EhsXWauS$)bqCYNso+U?^KTf?C#ED-CK|!m&GR-xqhR z7_3>P7066~Gw}?69jZuGezz;+vvz*9c;z&r+VXt!nab;7Io9hbt5XH5* zQt*@&0c}TPUB70!FPo=ezN@Go0BiCBS#gOPG^eN>Dd5Pu`-1FdR&Pp~`R#?4OXLNX zT%vi>b!`Ky-Ox0P#F`(PlWoLyNVaevM)@eg)MsQqNx4b_TD{jI@x9KY)S^W87o7%= z09^mZfSu;vFNjQB{CdQ^IzFYxH-g7(C9Uq)-Vm{>Of0&klG=X8jmeSG1vEY5LUV`g zP1&4QX;i8@gHD3In&YhMZ4Xl+VSCr%9vq%MgHc#q)BgL2T81)BFaS>iSF*lpwh+|< z$mZ6m*>?#RDIntojmVMO52lZ@TZEZ&bV^yZ<_`PsbvG&P4wsp)zTUvy0l$%QjAOoN zM<^kfp$`$89dIRu{TUu)hCbpo^I>V&_DHQ|;Sg-HC88}140SMCfYxofNUc9eB%u4+ zH=hc?pp+chd)S&J0`I_$GP7@qpwaDU7n0iv>Qh6>9YnNC@$MH0{Z79xw<8l)X(k1(=OUnNnm@Z&bkL9BNj?9lHgEh%F^(#`^Tf`0fVb0RIi}L=OqoNgEM5 zN7S&X!ZDh4o>kPC}O59aB4_T$sG8eo1>~Zq%7ldd?Bgs+i;ZMKtw^e%*N` z&dpw}Mz(+-qT)N2h(+aR`iA6*B^JDKm6EJ8exT4fju@)PpJ)P4o)e9fn)RQ&ws!I8 zjai|J&E3R%PCS=Yw+QRzCZa-kS_c8J&jXs2ym4$Kg*^ zqy@^ZOPWihsqh@Sdt{M+KNgVzBLh(`*7}I;-NA13mZ|G^^E+sAC&-Ts5M9g?!HXCQlug!xZOw!Yqu9>!Voqi(wJd^{p9No|1duA?!5$!oN#--Q; z&Fw7Jwl3M^Vm*jMlwYx>%61-&SIp8>E?}c2utE`(3xdzDM$#mGg`&1)Z+W}{&*xRl zR!fIT`u;*}QWIQ@Y2lq`n}NV|Xm2Fd{Zw?`LE>Vo`b2hpECZib;IwYc#GXzc?jBn~;v)9x3U67@cNvSHdJ6%Tyllqf0e^kghWM zjah4k-$!B{TubJxRUVmb$kXTF8_6k(Q5QexUKW#v?!anGg4}i@+(E&><91$+MyMnZ zSfOd}j$1S_RXy+*L8hY>$y*A~fSbAp`)yC8(Nu-^yazWmAB!cCbANDW7w`;l z>mZxODsRM=+hA%b+vA)tF}e64W^#C?b{KQ6WhEMsk)hV#r)D;GwZ;wn#`!oC?B}K( zA&7WjF2KF-J0lKb(Rkym-c4>b@HAQ#112Mv+Sq^xhT#`Bzte>5E#WUM_q>&D$Z8bv;w-C5u1>SdOrZ8nK`3JbJE96F@3sZ3^sB`>8_ zcq#JsgAYjyf=8tX^Sy6`A~GYVLxu`l0`fHr9?^y6D(-{F1k;p0X%d^Ztu2ql%&iWh 
zfCo{PQLnW@MU=|4KZhmo?-ffWTK2R{mKDh?$Lbh?g`zi2T=LAw`Z!K)`<7PTxi$!l zq?gyFSM@AK(O#wIjL+BN5BZ*-GT1&|jLXzb@tt|?9awYQJlP#?x{}tz=lhLt&G<0K z_iOd%<>fmUd!wRvRhH5)xU3ZwJ+lH)IkKf%H&Aov3D_er*Ch|5LrO1?>)z?BqwpTj z?)e35ZYG09Tgh<7)cImU5o7jxp`7p9xB^Oqz?Dq$z|HzR8lF#!1ghu9ti!jHbg2aS zDeBwu!yJ_Ww+v}Rg)Behs;ILkq4;ZVZd zZnsyZnR8#+*X&=GUODpRs(q6ui-%~3>cmjnVZ(K-QIVGKel+!{*$1*bh0QW=8$@4Fk)uu?DP+9YWv(%#obuznQQV82NpAO=5kab?-GYc zmPZr!F;hue6EtDl;tDF^yr~}8iQd}EuC-%s7ttQfIUdy0=S$RC_A)=ZSZxtgL!1_L zURrXv*R(+q@KqJEMlx&IC5rm0^JoSC(55R0_n9B~mn_)aCW7@*% z?(4M^1=n=Fu4`=Gts8sey@E}DSSkBFR|0lmRJKyQSj>6@FwouSzK~FZ#%D`lEU%cZ$jOUUQ6$W}wGcZsAtF zg^n@~uj?lN(8Z->iwC33 zh3W0l8dwj+m7dD;)qM($0X$Q5R*Y^~zp&0E@9!;l-sTsImZ|UgE1Q*>B;F|@OfBgd ze|k-f+v!xMplfF;$68QJmEHHT)9&w%9?B|jMNCCiQ`1Y3>mS{G#r)rsUTb*<=_9ZB zAo7s=ZhHQ29!tR<9OmrWG)T~khq=re-wrEWlp4+LhL79}Ils0S4?y<=`*Ie#XMCmB z{N?t7p_DhTZ$lhwsl(S4`hc~!%=%|ahuRF%@~^@sdzoO$>En-hO=mS%ayecX?PlNDowNcV>ry+`Y9>*%4uLqumE)H@OC2 z52lnWx0sIFj#E%>?7%p3Ou5I%m3I8}JYw;Bc6?NzMDp7&0Gu?0PsJ zaOmv(bJCNSbmSpR>-iCZ+6=$7Sgf1f!E9zvTfMllR3Fn)vn`3|z4y1=jf{r}RwY~Q zd^}36p(!glb3sM1upL&IHOIkWQLhPLuER;SPBTA6WN!--~aZ$!ExYejJVj%U{|+F+qvtY1@B<1k+Er40wfvV_g$e(05%?91$u1eu*)DL|K)|;pSHKJ;FgU z!Kb{k{xg+vIT;}(aTQ2Cf53XrAUdpG0$Pth1K{3ianLAurCv9PUKqc%V_dguJg#3y z9b5v5)w(TQ|8tPZpXHOEY0NV}tK(-LaM=7*PG5dj${%0(w-fVMtnwEjh#Op6UN})e zt;)?Fwe8{l)a#~9DJM>i_P$471{KdEPAv{M|0_kB%7WAteWx%cU;3Vq+3qRY(GbFD z4hp&(oAa*%&R=s?{vzjpq3y3f>baqe0#Baz%b%%M#?1nPOy8{jV;svry_SFR6pmQH ziEHj>{I_0S{1kO<%-R2yHM_mxA^EM=z~A1aAh?JcA67{l45~&L^r@yk()*Xv{%gs9 zk+@~_+dK5!!NauS{?MPRzI198e{tEt>>-UwO9;PX0V3|LyMo;|CG5KYf5*4hALv{r3JN9sl3X!r!Fw-|~@i=O;I6 zPyamdYyaJW|MzQuQw3Yif1VqmpJi9eVa8fsT{ryKh@^kq{r_Vv;t>o+m~5W-ACdSU zk@!y;`yY|`3)S#HBJn>W@xSWR{~s6*@xXdOQ@4#n?D@sEvRWVa4+GT{(BYs)-|HA- zf2%9TLyn-gAwMm4JsV5T+;{C6#kV>0`Ih$;uB-yw(v3@q2%gIVy#J*)gp~X3r~NN~ z^s`R)zuAA-iC5+Lq4S)NAx7BBi12bjQNoK4k$BQfD%GWVRp)~B3V0k6rS7ecfZuF# zZm$&xys z>A`M>D{}gBA{4(p;f*hzbgt|o$M4$*wd124j0%^V7cW@10ou^$@ufD3LU85R7yfjm z;2ISVYS|O9@r>#GNZ-xu)%Zu@6+eMhSz+~pnzTy(2a9~3Vfe1|X=}zUq`f=myaF$X zwjY)+^UBAE2Qt4-B=tT}kHE-f=ZGwg9*VYnIYg#u9vXj-xdum*a15@+d#f|TYL--;7bBwdYd4AC`K*wDW zbEHf~jST3+1ef*Yj}F3HEsHTW?r=hk^1`(B&1#`LSbl8?p9c_AxP~!ybjJ0#iPB$m zXgeTYZoTLEFNN#ZSNd;%vWe6AAEXQG#-!X``otO)UB5kKJr|0(z8Nl_U8Lf>dnTr` zaDGpLgxFA$$4-D#gns4Z6rNZtyrFR5)qwn=#Mgu6sEXYmy0mlSOxE@xt0<{2eRI8d z^~u=JV=!>v;v9>+Tt@-Q!LpElei1!@d!?}#-T7+X%Ymy4{?rG=lG5e;q|L_`5wKNw z?ZuYs#IqKABF{|unX*EPck$$6m$7d~;vArKbL7^)%ToWR;|f;&)A`=|hp2-RBRWm` z#(o*q-opEM15^}l^Tj!QsSs40ny>)E+_6I7SA1?%kzDkKZl468Ez~g5^^eu2hFfF; z*5I;?-mk7yE(cTs#Qi={r24s2H?{ny1cug6@)|lpGshbi7S~>s_e+2#j=mZKlsaQt6Y=>Ou$ zQVh=jjQpPeL*zHjdg&&%kQ$AV^dYVkUzP7cIj8Jh*I(Sq@!W9YCn2vY&Dx(viBhPV zsP+dRoG)0(g`sze|c+cp;fF( zc4TkEGzSxgOmW}81ph8aaCWm|Hp6+$W?uthy2Q?M5G)J*yIZ$!hTgw3lj|MfcR<+&{07hHRWK5F#tHR^h#`ei7R6s9t1kha z`P-O78g{!}rw{ums*xho$*!{`Ssz@x+y3@F-L$XYTZIZU-%%EG0>T*J(R3ER-vs3| zg0&^fzFUii^))|;#Yl=OKp_m3;`yWS?Q42GtJD_S)&~fjGkBIl%)#Wwd ziaxuD=k#7XqC8|!cD{ydQiD2U$(;YEeflqg*{zj7c_#BuBBvWskK8)MkF=8)M-<6@ zXn1sc@i^v&e28E7AKaRZjJnH9MbJmYuM!g~fzMo|l6>y~e3k43gBH6;Nuy`^{fuc_ z=(#3&C!FKW_iA@E)Bq5fTXZOps&ctT2BDb34(6~^Rh;R1)dmHs#JK-a#|JgD{MCvj z@|umTU_k2|_H5H@#}#wBeCGWp{R^lsdXL}d-8fvlJN@Enz$GhTo9VNQsgM70xfA@ykJ{5`c~b(8|bNxM;^=5aCdC z?^B$J4@*46b_YK$Up2`t4`jb6_r5!AymGZIy%q`plR zB~cCE?cGT~_zD^yhN!#0k@sz(k0sEKZfl|PR38E>^HS2I%E^(PwNi0A&VLVC&CdLp zq4@SsR(6al9_G|?YEo&gVPzOHG9_TH-X?!enY`OPJU6(uxpiVhZiZi?&VT+8gQX`V zh||L#d2>42lpZXNTD&OkEG~BuoCrahBl%%*%re0^=WLtakyIv9z`O&z{)60qh}mOVl&W2^0iyz8`{GkT=#D4GyKX);5WN&ZW*%^qu zsp5G9wd1iPDqp03tUaaB(5{B{=$@!AqeAOM!?!BS|_!_sT3COGIT|#2be6G50T!I|=^yt#Z 
zHe;)YszUO1P4o!?1bK}f=3Dr1Vx-#wu=YE#^}>7Xc0%Q0F?r(RZpq*BQTiel-2*Ov zDxCA>Q;&gv%O`@@Sq$WAj%*te=PC!wBbc)ZGUO>?d$=%MLD|0a>5J@#5t6rG*vXSF z`G}ZY?+BXQXP)C^9;p0!27Ry`I3AwZQUm|wG+(FihgJxvXaecn@RfNeZr*_9*P$~H zmIdv7h|N6onA)`o*2rf=RYb~2SXh)=z$*t?v!M`x)9|tBdWnH1EUeB_{_alaZ_-9+ za>#EkzrLDLcNYyyn|(nO$Y`W37jXeD;YmW9;q;qllgjaN1~fqA4z!Hjogll?srur* zaO{EdB%khCsJVxaUwD!{I!!Y#0O)8AWK1DG3nthm_?ewOy|B$DqXrS)Ox>tVbtcL! zs}YCjSB_*=>zWAhTxQEz%f-@7$BDHYe-yj6iPWFw$XBya+xcAV?ixYiG`8Gq87+(A zmvq;GRI%=3jtkXIq1{czu3U&g79t!e4Y#HW5^`a4V)vr}Ea6eE9ou+}Gav4N!DV#7+u?n6dU z_Q4QsJOh`SqO5pojktD57~#YO zGRSpf=0U^>QYWv8lg5{U$a`in`Noz#c{UM$HS9YbbJ(%$ht@h!Ch#*%Ll}pbbhXWk zB`N-p(AmW%}|0DYcHmB8C##qMOb#nE6 z5%HtD$HOcK^YSJ2>B94&`J@@jtvdG{>&Pql$UL>X<1Bd3Z6C*sE|USK{Zr&537D~f z^lQ)XkNij5M3&6D%VL)*#6OgOP=Xn3N9#$WmDjI-C^rMXFK=(qN44*{^Lqat6PM_j@SKgHizbQOMc~>K6C5q)p!`Ye@A(wd>xoFSy*WG zYlJjQjSB81d65!a517+D+XzJoz?;b`!77LI7fCEbhefS>eP7@*8>^z5@+~ldqs`;} zfKhtnF_lnuEDI{7ZCF538Zv{lm4Y%lhIIi)hEb6!D8%|+S+Ff5)-Z_ePR9Y0?g|r39FYmhKYV_4jt1%m_$3q2>4nKT55D z`I!SeUphZJJ=)G!DafxW)TmpKSzl9ibjjtg(%nLY<9O?KxR2@S?;qHU4Sleuei0k^^ ziLmrL{d9{&a@+K*jc3cA6u5IXreU014-HcL&a4tKX^8~>cH4sr@8ltIqF^)<*V z;At-UrXDX@pCh~RPzG1&^h)yeT)&%OK(X2J_^1s#(BOD*#+`7u;iB60TG~cjGCuG{ z8CPp#n@L4ZeDCAZSVPd+;+q3L)m)mUlJ708k_2|OaQx@g%)?(TYUZibRk@4b{Cuau zv)NWal=xewH1!wB-Iu2&(MuMteMIz8|ELJ#5f zuU4Lwo?ttAdh*IAQKQaesmGJg2Sc!YPLEY3{myv{Y*cNlHVMm5*o@HM0^)qfD^D5f zam1NcJ~yb<&K>z3XpvuZ;m3R5>A#YWyDh|K{q?)`a?#FuQY^--aEG=iwsq$ZCRHN3 zFGx@_&cyBEY8>P`)yjRf;%VhIn7h#AYv2lYX}f{?WUePhuY`d_%=Tm8Wqz!Vr2=u& z&1y9E8`})@icP!Cgf6hHxa-wzAdQ~mZgGv{ojd~`OPp#or#|V{Y*^XMSKzq6WEW<4 z*a^n2+}li#_63-^J94*DVT%WdP4AO9XqI2Ofznm!d}W1G#)(0!8mLVl`K~ye-~r1G zT|Tl_i@^Rh-m!@+hbe_Dlj$z+57I7V@Xd+pv*QXq?vYNz+p$=Go{^&oh2FnCE}nNR z=z{}`uBpsLlF`l`3HPPBpY{y>?H60{&R0Tj_a;~F5Eqp=IV+APLzFg<(K8yn=88E^ zc*OWGgMXYXJvuX7kSpt;`ZP^w0ck2@g40g!2&t2)`1Dfp;={(Rp`vFR$81y95@q~S zH4RctaniMJh$UXI5^aFr9MGLxx%FEAOi_*`)A&;h%q4;NGessRHQ203;ltk7nhBi5 z$EVEB*6}ehnn78;g-&NGr^ukpkgjKTOfXD3=fVf% zFsXXcG2>&a9(~v=C@|??M4%Iz^G0XSN=H^TUcgJy%j&m3A$8EZ?Ep zwcH3}y+6|}F~o5=HNQ3UqG>&Z6ezTMpfjk-kTZRzw!JGh*?C_Ep^tK7tfYcpG%W4R zSbsV%ajw9^VDi(0FE$qkWKBxe&omv3V`XycX1SXJ%y$}QN+iGvCN<=oUn(k7^(-oL zsQ0Y*>Xxg{JSpL&zEcAZqnU!2#a(+%CQx+6^0~+{KH9a+=j;YS-1V4~}#Jzp0DY>%@>&-t`T26bltO&0{ z{=U(7UxR7q(Uf@I|HIaM2Q?LbTf+hZA|)!)q^XFAfK-7%009LN0Trb;1q7r@O(=nY zfJ*No9Rx&r3q3RqJpn1v&=ToENC=^Y_Qm&|d++bQ^Ud=|CYh7u%$#%fv-etSuT6%n zWxq8+bOrcJra!~Dw2UriuJ_+Mil32wr<>X3%5!YAa_>J@7x$lg%NmyFW=t4MXt!W} zi~Miv*SoouKKcK>aDjt#FiEA#g_YfwC9g%JS!nZQ^G3}ABM-;4E!3QE>GrKx({i8F zQv7&b(_+arSgn3V4gNtHe{+a)<=}M|U-N$A^sxUnACY%4>iiS4j&-j&MOAjv zp3BFd7s@3PzhvX5Q|GX&(atc*=>)DXN;7Yiy|Qctc4s%GQa7tl&zvk0C3lqh!tYeHCpO z>z{j6RmhDtji;8r$+nuDPT6mD0A0ryFEOj2tZcHz_-H0Nd)bCxlveM%>z2+eD>k_h z7uB|h<)gRJ8l7nEYCQbAYoMj=qt4c`I1h!Hqh%$}y@Qd&fc~V=Z|fnx`ITo=U3A60 zrbk}vRWYJ?MevESYsl}!5Mq!9#gZ{rT!38P%c0tE;sHIAH*bf$N=X!{8R4H`mRel5 zoTr-fHcsTNfzMUfEeqxzb~Fs#X;Ziomq(ym3FthWO-IkLBS;S*;(Pk|O`#8 zCS2@?%zCb@U_4F99kd@ezaNK>4`{dVos~l!u~cRDY53C4A`}WfoSs(?{==E)+|bL+ zjP=gZl6`HsW*Fj~u7ufMVy$pN6jzlRo>lnKU!gWtVJ;6^mQ8Tc!5Tqnx9_^`DHUIC zd;Df0PjT>#WI@Qs&=yP?I{3i2uRUw&(3^x}4n;V5fj8d=FPf8>gw5YwEBzr;8CG`L3Q0$9!KZ{*JI(rQtsPN* z62h`GAsHgk{5HYm_(6StUjGVQ$K->r_qbah6qGzui~z4zFdNfcEw@4{AF^82{vqQF z1yOtm#B6DFXy zS9$#ip{S2o^TF7KiHQoA4j;~+p5SSdBbRDUPYwYIA%E?oEoc64k38L;?mV70;DOT* z9WewfGFA15o$9>_yyrp z@Om^su2>r?j`q%TYS~lAKvpCT^4d5$$#-fjml;pI>eTkT0Uq{3D}jjXN+_E1m6Jx$ ze)TD}n)@ru-1F-D!e@b{c+L`^I>S$qPyXJ7Sttnr5yS&SSO^kQD z*xceeKdfw60j;IX9i$$4H?ov1I2#`=3#o}y9n`sJcv zCgU%Z;3NM*E~bkEwe!9xn*`|DJ4qoflFM~;2&OuvQpIAXjj%Bo7Kz} 
z*5cn($y85ju*FTx9Qvo6RQ-%AiC(OxbgzZUbNBdi-(?g9G6<8m6{2>VCEmDaKildT zG4HV~yeI`u(Y1_NkxX33nm<}5(PpK##L+>u&k0<7u9)`XV9m4p5dhM!9TmZ$@^yB- zo(SD*i$(`8%ez}WMgij2>>O#b@x=iZ$smBssvYt2nkUw{Mld^ac6dCdq0Jr?&N- zGg?=T@N|Q<4-mTF%isjGZ-DDPnj5{`n!}(9Sy7dcE*7dXzC#?nQLOHq7HWY!#61=S z`BYljwFV3a&F}*r8*UMmXvg}pLFQER{>&IsA;E+G! z&2-&)-?Da7xED7}hTqY#YF|$JAwcRuwY80s#rIt;{?Iv3-aB)f?dKg1s__dM?Lqs0 z9f{TV(yP;i>i=*_#hmg_=l|-od(`I2?2faUo~KSV3FT70h_Tbeubq>*jA8J{{L2|*1 zUBp-|B$wLDn6TpY2;a0mx~{V0vl~n`oNh4+p?Og_iYf=UoCML#OQo7q7L1ZH*9KC& z$DJU$uQf<802>;D^1*_6)gx~@$xLW|j^t|bHO%euV638YQYTr7$r;=H38jt{33LhN z&jYV_O?Ke;*N_z93Q&q^ZhM~k-mt1Tt;!a!NPoF(v`>&#nNdx)GJ3hF5CfTjaFb>0`2`xVKUBuK))MI(zcz8%xme}qa=0Ri(V8W;`-Z0 z|F&FKKbu=urHyGY#Am3mJXQ?2*iQzWk^sFT!Uz2(wvG()&Gf7&N*pOL{f;9Ub@z9Q z6IR36wmnhlSj8`ASDPp*(?bNYT0@B00pu0z_Gp0z30&HK_{HU9VyK>^utaYe$`Y!^ z8fq1k>)isd?Pi;qP?PcDFxX_IW-86^!kgSd7ba*8Wa>?f<>!t#6YYN7hcJ3076ggT z+jsT~_%_Ae8p=AI>V13EmMVk1p1mHFCvra;d}vpUYKac;F$Ob7iGX`-eRYsHFCt0Z z@7(50gg9e~t5@1vj_PL#qj^DI&>hoIKQO^KPi2Oj-sA@E>JP#sAG(+?kTIPnJBmWi zO5I!GRz3BpEz`4`ibvhTN1Ckr&*wgh5d$%8Q9K>14LF_hGv;Rv;q5}}yIdM1bH)Z<-#+RQCeyz95C{^2V@#!326>sRdD98;>(f1Vy zCUr}7okGw-b0nJhzAp~2Y_BWu_I|H@wrS{x;asmC7O8#{lHq>xUO|WBX6Sz=geuJN zGt6>MGxYgzl?(r}Mgx!V8RP$`q-a$Yy?K0VZ)J4!EZL-YZt0V`noJmbuV}dLo5?ge z({#?}*m?=I&R1qqiH_QgVHJ{2nw4LlRTT_AI*%{2z812o)Ek!GV-f38^jhxzO?{y-H(-(t$Ae!qKs-= z+KtfMFx5Bddiq`sc?} ztVA@3wRl3xYk4Rmz$JRdXe*E>m{4g$cpO!7&;?KZq&iC+BnZaUA~j+OG-h-R$)`i0}PY4&|<_;(Ds+*4#j@yyKT#mt{iNT+WQ?T0CpFo8x z-F(Q!IcMuYn{N~Ng|^=7g71p)B#CYO=qCK-<7@EyyuLU#g|&VXRy}G<3g=M(hg#aukz4SfAUWd;@Sa3I^&1Oz_Dky-x zs(rf{1UN>mvUjfXL6SYyc_BI(>RTU-1ae95#2TbHy0(5Op5Pdn=l+$xN@g(67o0l) zq5V?d2toCm343;hEVN`0>+z@>z6~9mPH#Cey9N^(RUTovl`zS<+E%uCKV?zCij-`w znJ69fd+2in%~1f|DF69JOu3_OM-Aw%=#jzPyt#a=j8Im?)K%X)XHgwGQFo@QJ2c=bUA- z3FoxX3IJ!UIEUFJU<5l2+?gqn_O6l!DjcS3WdA8x}v9I zjjERFy|J{ajew_>*S3~|E+37r4$+)Eg;9)k`SXM)^y^0I2aCH|tRdZWuN~gh-s3)q zefh61X)3&^>xUTGW%w@2aaxwNCcdDDB1v8D*9tvZ_B?rQS;m;KTQ#4cKGBtx0HqZ5 zvZy68>IV@7R{wi7rz_jdjsJHJdUIUO?|-B11~5s)z0)EMY7x->tJJyRMkn&fxqg`| z_U(>;q&9-1RcBr{kLPSourv=7n$JOwCrfrH&dU=h}kBcMT{y&@-XZ+{H@^Bg0}0^8$;30;>Ty?uxu^D)X`{*tE{xb zt-uJfm)yE|rBXILx=(6+SlDYqQ^>KWZ9;z+#0&`DF}c&u+Ar>bo>5(7HByUqeN^t` z)gNO7z8L-06@pL2v8$mVOeJ-+=2oLz66Y-TQg&AcT3%d%DReajLzf-^sk#&YWkMRP zC4pVijsRgCKe~PiM80sy@7viCNqkV;kR(FJe~ySiTyVSQqOI4K%xSAoZC+mjRBuR; zjz6VcoRaR*c&xOyh_j)WK;Qc_@*L_R#w{JY*mB_Ebz`hLncmi+_ilTVn3K5NqL2;c zfqy$Qx?Wjy@}hw9qEB_z@9Wo&)@}qpY~|x#~e2836jdHTIH@R?Bnn}vg;u8l2$cne6hue>~-{(#Dcu;id3+N z$tHl_aX8*U_ef*1hvCWl|3-h9RFlL(Y_r{YD}z1NQl*e^<`76#Y`dMRtC8Hw9PKu` z9iznIYcOpMl~Kz;@KScY;>WRnltCR65u5rjU?d0{y=`u3l5Q zrpKq;y8y}~N_uiAn0JJwfDP8JgHH+7C*R`s3g}QIrgjogCc9hHbU&OM-QBhQ8^NRl zUdj$!$-KwEgl9huSM{UEsCpM#t@a;wJd#eHwLRUj^=#+7aqa8qc?D8|0c2@3LI*w2 zD`_LMq#rIV?t7NwlO1G}k9tG73!`(1cR-}`#%bzJ$c(ClL`#5_=UfEmuesi|VA|)P z@xy)>tS<0?bW^-}RFK!MExiO|7qr>$Ba*mem4=vksc_S0xTi}@VAqC^s`ga^G?w6Cwr{& z%_kTu^D|{H#S6mJZ$G z36(us&^TGJIyUZHw7iuVnVtV3xWHOQCbxLr#;cpLH^%np58T^q`2^1(Kv-&ja+Ou> zZ@`*Ov1e-i?b!JwchF~)Xo-NU-=~clQHyrb07Fkyn`I-R)0aC>@W~tKEmrHC5%3}#NZhlDSIno zg)~Fw!PbOg`|KPRvu$XDA&YiauQB@XtIU#@ZK*#qqW+Cb;x_~?r7rU(2@B`}nIqht z=~9y9YsNNnypw}4s{@Vg$xd@EraH-eKE6SlbJrj@%T80Ha9X7++oDKsy{kjgM?KQR zeD>}M$@w`GdWf^fxRBK+d!u*IS}qnDzf~ks_F}q|#Wv1HxenZg7U`d7J7B;7L8)Sn z*3m1#A^s~JfIHsZ_XKkd55BRrqgprwaZx&sKNZQ0(8cZ!?|2-0b~=RYblYBB2@;ZH zIG;7;~v+ZMWkJz*n0DaT>7myt?~CI zz_m4cA!_*i`-=`2e#68ts6n6>!_lUQHfeaU8X6bXUn~bJvl`Ux%kOOr7%cM%mwv_} zePhR1Jszl>>XWYxg>+FQmpZl|LqYWwr-;MNY5Fkbd-L15dkUqzzr29u@@6pE$dyd! 
zt>WAQk%JDOjh1$ChC3EG+{{93gj4N1G3+#4D*O^kWxe910VCktZdfR{*a=Rwy)3|a zfDalSv|qXrWs56`vnr3K$YvjnNef)Z=i6gI|KKO-NUM0 z(_rr#25Y%JBz4EGZNu|Yxl2#-lo>ms0ag@mA&K=ld>|Zwd+6{zD44|q-Nia@HB=Pq z?{7HW#*y3gPM?P9ZI&K%SsI>~JlLzMJ_@A`+$^gMx0}3K&Sxg>{!tt1E1mXCccZxd zj15fE6;tKDL>md2(Uca^!{0E(qt4{Ce12axUV z$lzo#-%#X4`BUUte&cv#W{Y^TNyNh4!dbiSYhb%I<+EK%XNFcR&c_X;1PqEs>Ohj` z+q%3q4ix4od^QDo(BP|~$*xN{Z?%o_x#*k3UgAVwkh_!-_RZ(eM^1LIYY$EzyRal| zjcnsmg@cJ{M&`s6T==xSqv%RkUFAwu+a`9?on2d(7VPZKXqpl{UR_u3W6+b8P(wvP-yHZkG*mgJj8ETAhj09>#>p|kqLr`) zXNG7Hn6lZHURQMO{Wd<`6B@uqW;R)G-OeF`c>2Scl*cry`R9+9segDnm375Fr=*{E z(8N?94OX*Bhpg!vf0-1_LE25PIzCYFT4Mu?g3N9u{s8GheIFqo)8!MYj@|R78O*P^ zQu9!-|5>guhD@S0>({D)TY*=H{XcA3Kde$b@xi@8md3wzki^Mz!DSw8 zS9P?R9*aj!Dr0#iLS78L9T@HJBE)p{sqfZ~Bo1c&sOkUOZ}#RofXTknDK;II3*?;< z%#4AF4akq6A9|R-L8AUmaC#GE#PkU1>(~{ppV@WZ!w;quuTJ92ob8P>jDO-quMb`7 zih~JRANZ*n`$tsh$)VvldTftRq`_3>lABc$W=>utJ(FC|m`_>WS zcs}K9R=ZoAPV_V6o(x6z$f^hpOYN(y@IQOcwIuNO_B^ zG1c992g6PrCX{f?1M%SIaifx*QPT^U3dQHE-7XD_f)3=jLhqC3AYvksUcA+R+d zU1MkGF|?fZT>HiE@BJl^W(UjeoI%!Is(>Onb=&9uys++;s>_C=sUXm#UNm!Da2b!L!H@^aNTGytIt_j zPi>Uh6U$-)5tY#Vg&cyvCx7nyG;Rib)X|+ZuiB8v&(cAVfGvd(7R?uRt=lt{l+}!F zphwrT2uoqC@e|U1G@F$NJVNj8w9_$29%$R!h4EiWO`0hZaU`8H$Q`?8_lR7sq;?Lz zmn&YawrCtcxld$@mY;DR@1y4w*r#DhHC;NYbEQ?ziGYgK4C#jcXj&`H$$`6_DnF#L zs0&RQ%I;20eQN3+SvrWz3=E4Ap3@^0bQff=a&_*l4SWV(L6A{Rvhb3#^l!<^~}_#0rE9t8t)Yr)3FD((WBp=J+JPbWzDaF#6`f zmq>c^Vb&Xda-D%E+j$z8DRF(p3V~&9M!LgGCqE@Ffmj*qaH2K?MK;LCFiGDuF6~MI z3@t42AO62#Y1u%{U5NQ#P&!E0qsrCq{fAa5hXG^E;YKU51RN0GoNhJtu!vl@zssP^ zaJ~0x?`@F+r)7P%gxH+e%-^pbnoBh7V-t*Z%w^nrlg=>wp|`=w5H{h{9=#hL0d2ch zng560K>H&Pcu&)kj;>zl@42^D{seU+{yEZ@q9^-NZU*`F1C$w?gP&1O+8I{dq;f(+6d zEMGMprL9g3Pw`0}%aNx?-&o?s0pYySv{d@Eagsk_G#;IM`15B*Y-}TLKajfv94SVU zOeYOoE6b}}WjtI~>VzyLDZW1I)C@P^6vBtF?a5w^C;};yeB8Vh7nCB}{7nK3xpaJd zRnIE6fcd{wC4I(>WST`+a24vP_|C53@C&o^s}dq)vu*UF*a%&tM#}zaqYcS|yk| zblV=^Jf1V3Ua_%!h{|=9K7M$W^>;z3B+S!KTrT@0pA9l4y~-dLNEy>r>d?EhV`&>6 zK^tXosL)c#03@P-E7|z*yc1s^Tr%WXHDo-9tjnFYgEF05oZETM)89rT5*<2fP7OSVcQ`a`xxhq4uwR2XxkY_78xDRqb<2N1m z7sWJwBl@R!+=lF5@JmF3h!C`^EqI|wMWu@|LEEL8phTsO&B^`8wiy_Om#J@#-4V}1 zlzUKaQkQ=4sPGz4Ql*xIAx0Ui88M>AlNhWKja!~pzR%41xTf(!OPABJzk6kDL5?7y z#E{xtja_4J=q}Bcx?)u%=ss99vVT2~x}YR>Ro8wo(MJC;Nu<5NDQ(QqS2AR2C}=^U zLnr+#Tas~moLjliW_LntiUcVc4Tc5mvnrY2Gj+HJ#h{v+e@Iy*CcApSYOf#E#C4Zh zH;VF&>MYuXbQfdR9HadU=EoSDXuRojB}YN(B{cwQMSZWuKikA!&L#rZXQxa$`bW&A z7zgGE6g)k{7#OQc_DNJ3EeQETa(_2`lu4GY266V23wb)ec26pr+{B$l=8>#uv#DF2 zBnY6=brXUfrLSe=Oxs?lPdC4VpqE~L??_-9P$fua+=sn~j?nhPoM{MY*S7%XbrhbpsIKtoDTjozpc@AJ`+RlSl5+VEnE3sJ}yRNTCT=WC9SnI z%8P2xv^VhY6ZH31Gfd7r=kdE&N_H4Yf-q52wK-?Yvg`d##J~;Pbj$*y%MoNg&WYbM z5hoFOhADuAF)4g4{BHH=Yb>YKzHl(^51lf3%+L7@^D(tXeNARQ)$=^FP@SVbx88-1 z?S$L-XuC6YS{#+}&JO^s7l|PKV>}8ZVPCk{9p_1a1>jo(A0J-K9#;&|$m&W>M!%6n z`t00HL5FPmL}NNO{px?DhXB-nkNPJEfu`pBE7eoD++vMN({is1FRx>{HAu`MI04^L zO#_a}N#$?s*X!p>)wV2x-T;mU1Tc3$^$GYqX6>ITcSuIa=N^2zcR-S%sOTcbq^{c7 z@mcHVQIBunqJ#Aa4sGU^tVs#0;xeOzrl{IE!XuI%UgLLTWx~BnM`K8&G)&SNBYgJf z(OTZgTDD=~!URw7{#g<$=vp=I4}_}nyNKo{zmH277KC<&!%CBw?b`uD6Q-75jQ|cV z29y6wNNU9@AY;4kA`g5r%zs#`0QlGe72GS@&L6s$0lULs&D6@~wj}oiCb?j+4(<{3 zP=o#@-Cf9x7Ldse>Q_U0R(JI6k7#cKq4u*ONO_E)a&99W6}_pN>-Wo=Fq_l2JvHNg zG=&XHn6C(Ms+tZ31qNbU~>%C862VrRp232aIKT~oQ(j&et zNi(d0X=;_yZoq88wSyrRx%7hOrGTnxEq()7=Fu|3aU#3CbKU!FLgU@8{eKUjq?17r z|Aym@9M!>E*(uEC!Z9)Qv^HCUo#_)UZQD_@Z^JD9vk6fS6;ET(L+4fjv2s(2)WK1e z^oy-{$s&RoM}KWB-UpmK-y-m$Qg9H48T_*O51+mw-7Q(PXU0UWh0E~{Ijicn`EMg= zkH7nd!c<1OJ@7434z6T|*R?|hAo<<#x3qoB4_$2!vBF-2rHzGb#9epJF%D64Y@=dU>)_f~t8L{36Hp zS&eP;8v)x^Z^=&!`vet3tK*!l<{Wi}xqzM82Se`VvrQ7jAKwZC4u1xf6V7eKtk?Im 
z@n!EGAMCHHS(oJWM|`nO2|}l3!e?oAykgT!R&H|7IVx%e^cVg~XqY0be!`MuokCjZ z&wUo6V~txnC*=B~d@3)kAapGo_egMNd~)f?g7ummgkUnKNIOfQ8fyA&hY7>8?W6$_2j!Rh%u_!9- z(yD1$WND5eJmuI*ikI%QF0$?z8cW%UaWE zah_5ntX8_-EtmU-AJP!9#jZhm@T^K;${zTr&s#?66Y1T}=GYhew%r^eYPqt>?Md<^ zN**9&Rl5?MX`gRoUNZ^{aIk6Pb&8ANYLGJ(5;!WCQqWm4#uTIQmO!&CLm7_D@qBjT zo1Ipj=v@vqpfuxjQ1C4=@VY)YGMd1OShUGmTwIE+m#e8B&-h-AKLOE%Bo7NJgo0Ki;&kqGW zLeW9qdSb%f{nw;4ei1&tuCs;RSYQ#$mLYf&K2Zy06|$P}AE4h5q$1XeLmKjL&&b_S zw_MgI92z{HLHV@H$Sh@-$*?Bktf=$J_{}s;hbC_-*Thmu1q5^5pL2LNf7q|9{&QId zYr+p>u-&&4P+t3zf;>Oy>NT1jfl$QUFREK!S+?7Mrb%bd7(G08Ogr4W!Y-{vrBrSB zn#8U09ZhdK6G9kwS$OUh;%(H z;i#+qf#71y;}XL~WP1m_hGkoRi@KsSHIr?~6z1U$&G(?ZW(-phe3_jNQf;ArRWBUT zYN~Vo!Vt?dSY3Uxhhx?NR~c+2Ja!QdCeIOT8#3j%_J4s(g~R#wQPeBgfs*I2NAB_X z+2;WMg`K+}Y{-ZAW}Wbi@Rj3L(5l?1<9}~r@Km->B0Kim;Hz)i?!sFj=Zr?(>3dM?|SNNSdd$R zl6-GOkycE4jsmc5=fR7}YrRi@{Ib?JN!VF(mJj^gDfe9^^{q*Xi9RpTfDe?z7qI}= z5fjKfECeF0G;i&Q+WCNy57DqLm;S3Gt9*Sv*>VCevPag?*%bmjd&hi#hxcb1;B3+XsAIU2G!@_pGl+VBhw z@>%`!0L(*OQ56FSVZO;OJk=GPWcz^&UmKM{AY3I78SJ$Zaz^gjCHu_gG&a zuQegN#5d!9gu27-VBQvfEA7gRpkYU~oHpLiWp@l1izEI`ORsnl-ea` zp3F;Onkxk`!^^xcBpx-Y7gw_&Nq0DkaiZD+Kk0AZfB@BZFcYS@kI+c&knX!!97cFl zTXE-N-MXNYTxmL#8==hl`o<~l;ow(e+&3Ox-=|*rIr4Fwd~Csms=%2OS|U}JW~vif z(6*|1%xpzvpMS$uyi6%9?{z);R$XQ(xcC&X6R_&zY;1!ISLex<^5wpDKe{)@G&_vq zyMM0QRWl9GIh)a17#aecAB8pL_8KKP>%+u+Z!;sP`D1vC#GAiAq9&`V*3cF#?x?qV zy669nE+YkCXB;@yvWJlqsXtsX9D(PL?E-V{Et<3eY7J|u8@G9htLOg?9X4c<7q_ng zAP@QZCyjDac!QkV{uLDu^LzwHqg(`)INFCD~_OJ^P28|JPEL`1LPv2prWOO@Vd_Lxz%Hp z5AS&c#)X-hH{(S`d}00=R5AkM4IJFj_v4Tjp$U?D*LaOe7J2qX&YelP0UrHoyGC^W zA*?q_?!Jp+^=Cw1JGT_)8En5mlkgUVjAlYIoG#5}H6`0Oi!g-We^Z`;GzLw*?yjuL z#Yd_vFy@}i?G4y|3QE^K93svY#9qBm_hNbuBH zcw<#XrDx|AGd5+9rgya~5}vc*dJCf7K8Bke{FsAP5`pWiNb}amvlDc}lls$dr1$nW z1y9a1eHS^p_+3-CkR33llGe8BN#Ttch;2alR==H4z6MYk)}TPZXVujHQ5xXXNLG!0 z9aNRE(eWUNxsren$e!16i8x#Ll{v66smt2l4Lj#~g*p#DZ)?Y2-}RZ@p<5$*KAZB| zL0jgNqMQ@5k2H}TyLm4ZC(%h`td3m7F!cI6rh1(Om}JaeL4&PEfVv;wf5%ZN5r$=( z<6j^lN@e6Xo;zZb--Cg_19SfedPc$4fz0rd%1F1Or}0tnzxBn%QVgEv2Xg08QsUN3 zBU67I@4M?OS)y9$g|OtYyPBR5_bvMt=b%oG#h^>noN=Y~osX=dsvHf#8yql4#kgR{PVYEdcMmNuhGqa$+Y+C(}J z21*)W2<6Mkk~BMyU^{j%r>l{K*l-Kf&g*Mo*b%!>DdpDriuJ4aZiBnRAZ!wDa7N!U^=*fge!_`ddy_kto7hv^Pl}&N z1zhSBSNr`zz0rL@)H@@L-?w||>ECp%{US}POgY-vg4h8TZ~xYyIidmON`KYaGKk6* ztpukr0$JSsSJ=tqPPw{cHGS3w zqWgveW|WUolY@FQ>3h;n7AbFAnNPew-h~tc7Ilq)Q=jAOxx%Q4A#znjSosNM_k&oj zauYj?+B3RfZN?k$O6`6RprWPiZC&T}j7AHs>8e5jA9kf9Bu8r4x3w-YCUpuivd* z{%5IcugbFPBHngAHEb@4SqlyU1ztLJ)3wv*Ku;KD{rVP;7>5rQ&MK<;7Q^e@k(@M5 zREhl9)AEgm-83Ge8w94|d$oB`Im4$ajt5C^59qkz8KO7mTkL*?AhB2CGYC^|$1zGf z+sAHQkS4K6`s&({ET@y`AmOeX*hcM2v!9dzAirhSnCa^K8Jw`Sj?>gPE(0R|t~q4G z)pmOSsiwBOsu-Cb@wSjY9cv=7ABl!HT=(1*TqG=F@93%)=?YP2{jXxDd_uZRB#{1O zm!O*<|Eha1xSx=UTjhPG(TaBH%T2n3R@__0TZ-m4O81RU`J-`(*5hZ`>gDGBKf-b}L!Wt1ko&dBYUdQv&|UK>=Gb}Uqy@gvy^l^9n(X2j7l2ed zW`SBZ;_`sQWu>{;MMG8B`n#@E4JY^a!aRfPxzgZeXW~ty$qX4#hWKDR{6%w}EMP(_ zADY5N#ir#t~z65*!rd-0@<(`<97Wp0^k zIegy)_OjAoo+b$$9^=h@fW-eqlZIf`4nIu9DN#q1qSoFAHl%aHTP$i0sNnT?oQBTR zZ1@2h21Mh#RbX8d_i>ZDmNRi68PZl6DTnzU)1~Ex3%c_{#%bM_FRQ<5o*$a!9OQFp;bMnf zjPbPm+JdhpW0AtIWz6zAO-)hcU(FSPHGApy#B&;u%*moI_bOIbIHbDs#nmU{rOVvT zmAFzJz`pZC>jnvkZ0XeqB*0y(NfZDgtjJXH8V$UA%NY6^QeNP7?xXvSKEdRSVgTU3QLK+%JNJb zYjV$f67_dCEl>K}(snR*uj@JTDboy`E4&&i>SU=5u6j&17z$dB`6KreE`4(dhedEx zip#2A)*N?P&8xnOHH4xtGfijQq-`W7Oj_I)(M#-{$Yfo1$W zH_wuSvt>|fQfoIktz5LJvgZlYFCA$xLPt{D@9@&o9JidO@p0*WKli|vOnXQXTf_7* z_*N?$GA_^Ym%9XaYOmd<{*w)c(^No=b;fq#k}1HCcGRozh4shnPeW~&(-%7N+dt%6 z2M;DUL!<~w(OfKpGP4MQsS|8Jfj(&JkvxGtUQtTda+-S(^JU*GJ{?{;WYc)eWPzuo 
zi3bL2S{o+T3S4pr1{1r*i@=A2oiuH!tk3ilq#}$K_H3c0k`{1UV1D{kyoRBKDpkg{v%E${q-ve)DkgnJtOYA9#q(n56Ym z9?xLF+BU`f~*}} z?c*ki1sS7{jegZhOkc!ZGc&>567ycT*t(`AQDgQ%YnB2mk20^lM?duF`u^)fC5b`h zas9D!{Cu#JN(H^mGUfL1GgJ@K{$NW$nBodbZje($qG3L0HnPQb8^P~Xj008%Ezz$fX5g-wjFu<(qby;iiBJB zD=zuE1P$j(JDB3@$A{9tHjSnC@M-cbw@Uq8?ft%Xy}xNs{jj%m*<<5z0@yQl+=w!J zUglGTJBK#39ivW3=F*;E@u;b&xff6ttBUF zSRmP`Th6YEEvkZX)~jK8hE4PzPTqM^E_7tgOWNTDJFMw^qyR#E+PDYuyyK8oL@_Zu zA%0>!^XQ8jCS&t~M$ShaaK$J+HzdDka$tMn<$2dZbbAv$@5#f#EX)0o{OoyY<(3kS9ea5AD&mH63+Hkq8_F$HQsmJr~=!FE5asHU22px%6;HkF?V{ z+_`nmO=g&>w>;T?Gl7%gWp?-Db}!-8_O7ht~{-AT;8P z?iv9PhGl-wA2>G{#>Y^AyTtT3>WNo}(|1v6vJ62k_VpnX?0mN1TLsZPAGY>E+56s? zJsw?Uwbx)~==<`$Am1?kYy468hTW}|U+@sN+=q>F_vFJ1wFQB0ymgZ?V`<$P_SY36 zy3ITjf>;ECxx#wJGJ^38&Dqid*U347gt8Iwf-2hi`H3<)a@^*S_6b_j()bLt8jV~L z2oW&#x`BS>q~NnTfhml$+33Fb!WMdZ4+5TKcVlvM2jqEY^Exlv)jU>?pnWHhF*%=c zL_nBhZy#9}p5f)qHnSBU;6_x|u-W=(+R(aCASK9Kl;)GqAh7f2J3d$|=DY0omN+2@ zFdrDoi`r^ND7w3U5F+~ckFb;hyzb@4Z1d%ioi*xUCJJxE-rj%vFtU>j>{mvFX9i%F zTA1sONZ_0kgg8tRoA~#wfL=tgg=qhQLQ5W%sz3MjsM;#m2btbLV#hJ&&&2ugczZh~cMrW4+1i z{P?#h@EJpaCI_DRaZ}|-j03nBs!u$2`B6ysxoYc`hy2Ug!Gsb|*5N<6yB^B@XZotC zpS5{~8$PH<=OI5$C0Hn+oPE0bRAcAz z{h|a1wp4(dqjI6EUlg!qTs_?QQbpH6)){lBZMKjLY6auJc^_;1R@5)#iLbW9{Dyp1 z@sgQVG0?b+#F{V*(wm<>gJ;Kfh@p{>cQ^o}cX$nIqJ!-|@s9?P4c5-vf>=kKWlCdw za)k;ooH*`WRf(dw0W0i?fF}RBlhGk%yU&in#G#MK~qVY-gOuzhqbF(ZIzROfk3D`uPkdNzLEw0%(55@~~4^X|jW9DqSw@|0C-?quFryxNmz4 zYFF%%s?=-|Vy~){wpyc9v}m;U4r1@xqeewgYW%e*s%C1H*t4~XJ!(X#O`rU)=enNz zKIb_(d6k^xMNYoI`ThJ7Tn7(7;Dx?}YT0vlePh(ZyH;4fzoQqlDi4>{e8`+VuDIw8 zqbxpr>V*@Pxv#E3vn&fKHp825g^caJi(hu2F>#x9vZzf9`0k98nWl0s;8pX&C+Vya z_~OS$zBXEqg@O-3R|d!r#{Y;a$M1}3aGk-_e+^RjQqA(hZ=hso#cAyZT<>ej(YMPg zFeowPaWU8l^L5c26!uG?Hz`r!e6G|38xw0B)0vpFw=MFq$67$rNS~A00o#bnrO7qJq~g zZE5(06@rX9j@)$SY|Hu~g%X;asrytUOLJ~9PD{_)jpqqukLN~HGDJwHaEN+unQ{Ezv9lAcY*ebH$5X?8MSZq^S%PLhCV|0ogZ9r2(C&$j{FMAqfDl$ce$NkV=#5m3_Yj zaKw(~3We3$sc{W1qS*{}KLqiIx18wU>dfM@GArLEd<~TIlkM=u5Vb524*061l1chi zm(Lg)(9+5FmrvVR&q97O4`)+8KfB$)k=`aD<^F6`ePdPT8d4Lx$v@~;&YS-fLOX4u z4mTYmzc!~C-BDmXQZUkhwm4aB2rAzD?Fa}QZaH=*y0_B&m zy+Q7|R!1z9zf^aQCbhbQ#}j7h1$4^K#`4Zeeb*FBziCUqf9xt#t@Sp@2K&OAigZfS zWrI>XXC!hh>g3%KhNhYhCQ~_TyqHF-HUYR~1D_D4PoiCvzD8ofDY+jhEH*R`+w8vr znja#TL)GY{1Ctxx?T1RS@)XRcFnUU9P8W>uTHRUvUUTHBU=4&X7yBX_i^g!yijz)&kcS1XynI&noTeU%k9 z)KLc+Om^8s4uNP$$b#01uuSikvmJkM&uhEedjy8hXN?OCU2lj5QfsV-6!p{7=0?X3 z;eYO~-3BF#L5J(ser~iQb`bXBP~Ds>E#p+n9%Rmfp$$H_u#)tbIhZR09{wK8`#%M? 
z+dyxKpP{-!679cM)%G!!{kt=l5dcQ+V&#BBuY5tpfg~ICl`X@ zDd+tsu*WghrOV2>F9e6j5&BZzeNiWxZoP#j&XazR(BLDexY*Oe_(e~=z zT)GCAx&Y54#G9w(XA;{J8!aXbq)<>w?-tq4<%teSff5?m=-gi6&5nncp>xl0saK2MOzXs(;gx7zw9;r92~I^d@e(03 zY}0o`&*#t5%|)x#EV~#)*4n*g4IkH7Zve>1smqDR9S(9-hr;+&JZ?ZiUBSjZwF%F-0T0*E9ukObd3y>Bf^N} zP<soI1jw9ekb9*q*jqDH7GC&CQ1q`EPM+9+VbJzEss4$Al0uhkow$RBUGhORz(1 z|H#1nX44r%iG7_u9#_qz&aZK7@gk&OQY8;;85wT-tXhtIo#*+7yiqo^g-t=+SgKmV z_d+wLM_QG1c$EfICkwcdbS+~Dq>(}m-&t6q_NOetmymCk9M|}9uP^iWxM#ruDM-7Y zJR8yb+wC~`#%tI=>s$m-6~pAWU9e$uO<8XG06OW7N#Zr*rEk%UV< z4HMFmY@laV!iP|yhu$#Z}OyZIi)D>=7qb&f*$8;$2xWXqw zj22_u5c~bWTcX5OgMwtTrhC*4Fj|g8BtTH2Tf&LqT%>!|vKd;s~!-DQY_L z!ZNpACE7+uZlgs56dEyhbWaVX=mLa?3bh#+4NHFC?^C+*ALSzOI}sgWBB1HPlb23o z!2@&s{G0xP7J`aNTs-y_2ClbqGz1BwzZ&#%5WCb;<`UcE@7K5T)=0vs&01IYQ1KEs zx!iK~J6_)uQG5R_{x%^)dHhkZoRk4YR}W2tusZU+-vvj;dPAm}%so>AdAF;`>T32c zv-#WnxUk@j(C*OS-;T{cL7VPG4Yfsuk_SSJSM(x7S;e2O90k%ATsO-O$kZ4g#`HdZIKBg`u+nN1a~%1P22<2PTHh3| zlvGKQ_vZsCKlr?wRea$)pxG@nXwh!I!+0+$_>?4J`Mfz043YDkgG*_bS234AemmzH z?gO~v)_dMnel$5+63y-=^SC=(_iEPs_*?5WM!hLu^=uzOuiezxNjLBeXG*kP6>T1h^Ph&kvU^EjD;BeKV6TQAE&)48jnRsVU%92_S7*=8#Q^X zi4j6P8CqVm)atrun>nD9_9oV9{?#m7*^-`Os%i(rz_+lG&`oLdJFoj^IYG%j>#d}Z zd)`g1lhHh(uXKLs{+eV1ZZs8qhXZxz1sO)+_yco3>*JL94ZDlJ=R;~|d~{pMr?{KD z=-&zeZW(98056JCv*qFW&`50QZ1F^AFO32KU82s;>o}iGea~ zEO393R+KR|*A5(jTh9I^W1t%O0g_fYO(!-Xtb z`qVH*zjT8J36>uI=^i?@QOmOyA9-$>BnH6hil?#Qvk~2}As~lzOGo)|bY6k5zeWyy zmas}{K23RgK4l`|-joD*ua(v~1egsY#GV><`S`6Q>}KS}(@$E$11YV~%vjS=b((ir zL@Q4~xG1Sel|wvxgPeMVN!3^78?6+dH%#?Uaa3xf@}LPTXsDHgflypyc{rJ){bGnQ z28bla=tEytzzg8{QOdV$Bg=|bN$HqOzIRiAohkzlsW1D0&_Lyx=SJ=k>NN=4`Y*Tr-%Ek`dbm&B zlY9;v6r_Dkxjn+IJz_z2?ZBs&MOJtvxJfA3t-yHLr;D5@qVasV=$OKy`fIZC%(I&_ zv1NH`1)r%ziUX-xc}w;uvb2gfZut{S3lrbCwf)Ll6sJYnY&S1l?oQbJ^DQLkLg6aF zcwWpCS@l&ocSVusBxoijQew~QSR-Ta*cY{@zpnYeMP4u7dLR2%V2xl-centq8n(EZ zdp5l@p1Y8{@V}yxPK{1vP+N6Q+${ub6VZIGPTKUx2P)mKA@^ zA)uafq01J(j|7LHUJLu5I(8rL?Lsl%)?Ajq{hANBc-5j(GoRbCi}C)xu*Uy;S2k!f z=CXU)rRlidairR#G#39*Z1r&%Fl$2&P18kf%pn#3pk~HpKh;#>~3#ZGnJ9VL=I%a7FMJuGIc++|LA$Xn{UT|X4|#9d z4`g;VzD-fJ>6P6ADeogFWuZO~3jLE<7xm(hX6PJG=i<$-v)`qN{j#qMG3*K$YP0bs z>+-MleDgy{KvYG|%}#F1aBI;02HXwDcY&df5>%P@Tf6k?qX#@Je}|KIX7u|i3Zm7g zqm%{ZDfh22UQNl{DubS)HBLHHigv~q+D4Wz@&iqqo?R{g7U#HNQxjpc=?r=HS)Ph_ z9qjfze*(Y6W+LCb6EM>DUMV`Xf+X_Eg@J0|*y2PcJUt!Zu1q&&HC} zZ#7f&?hn-?P+94!eJ$50tyN!7uN!-{;iiMpngCZ;mY)m$(8w@j{{L;pE=zSd$0e;k zd2SJCR~quU09md+{^g^ezM6tqGIfyigF$C~&AR3ErK*p4#dlJvtHIJGOn$^4s@-uY zQr{^CYG4LR)^*dSuxMR63=O`|GPlADLN}*M|NL|J&(d~{oL=fv#0KHTs zmH|2%XnI!SKjL>7QS`m6P>ox4X+GfC%EJwe9T&cF4m=ue)6EvfV;oaOC+kXomVDcF-hnS$SvmliEhc-j>U4D!>WAdpZ8@;MnT) zRH*AU|F0!ALj3SG1V-rOMmdE9oJ7sGm3c-i(qDDJ{EFhqau zCuCUzDfd}>+C)u%H+QbQu(1(Kas!O5-RlwqJlo*`Zs0#f8Cu&@*m6}K6G`iKAgIN? 
zYn=?f8*FY`(XGyL?4DWp`vL@L_>j{3+scH_e*`f-wrU)?lqG(={^`tf6`$vAr8Hgk z7jC^d!$5o1!IhLwf$rBk3$1$#G-v1&#mB~=8~QcaqRJFO+iHyrD$1h79=_|tJG{+j zEEXmh>%zSD<=1`_y(P~HQ(qyJ5b*#TJ3cX`jjy%(hs@pI-_50uNUdG7V>1c`Q+STK z4kT?P<8{JObz>t6XCeEVZcc0Ai(gP-vQ$qx!WHQgjV`IAq5z&5{lQ`S{KGre&jtRJ z>fe22Vtl3$%6nMYAPN1qotFKLF-RsTh133LPW(q>#{gw151F5*XKI%kijnGrkjLHQ z_WpYonB~ld>vOluGN0zYI?(ZHrN^#h@5RGS>)AdPXO0 z@<2f)c4C@54)g+Ar;Js1Nq$(l`@5#y z%X{_5Cto@MyhG1C7Z@J1;i~se29yVtYQro5Exy!WW*slqUYsW)1pUA5H*4=IV}c1- zPQqhyup>*UP-cLuk5J##@1t_!c>-LCsJ3k4^6j|ZkP#MZ`z`*XHstvh&n=|A`J=%^ zWvXhrms2DG)4W+Xm%<~;d&gP6jaBJx?;B6gn^A2#sb5Fk&6Wm30`@#Io9>|8q4!p9I2c_tuP zR9d0BK9D)3=|y7|W#jYKTGxmx$8#`Ug!Lp3Q+wNe#4x{I;dJn%BWk?&JMiyC6dVX| zzC2xyl~jzcEG0|JWV~bHr~4cZ#gc%6JK5_wo^;RLn?Kp))f0R7-$^ROSCR)$%3$sO zeCsaR_Ig8Iec>`vZM0%e zeb7FXMkwp#br@fcc+##b15fV5hn=tShXrp@ar>`m2Rm49Hr99<%Zd^={DXPIqacs;&yj@ANn@IE?I3c{*)_z@Z2nWXDyNwcsbIfC;#aLt zPrh~plUt?Y({c6CGw0q^v4ck!x;2o=JIP|GbiqMFwpHPeH{#jr^r_V^pEx|>a&B9< z_8UUaG0uK@ON+(x@h;yu1^^Fx`7RDs_zk#*!O1!{naxPWxn~ZuvU;tC@iUn1SF-J{fU3uj?-jRz^r^XMV@^iR!v^9`e%M&AU$rGbc}S8}62NJvS! z=lloBk(=S3!lh`KDV=E($}XSIqcPy2+nQBH7}+AQBjQ9LDKKfm4O)sGg^2!+i{ z=8$@@vdnhm^2FcIa?%VM$-h&VFv@Uza%>Xu2+<=~SsdAZ3+}3M8e#}u;r%M5cazRjL{8;%C8@n$y}I7_x+7?}dFTI(Zw>oGASVA;t8i(Q|I0z?LFxZIYIQWc ztT-aOH+0u1;MZGlbUZ>3HGGmAzIPY%?P&MwbZ>6=SW>+qzc1gA(#7tw;;rOW%dh^u z)U-Q_)?WCQE!5ur3}?TG(dYh`x?JQEIVfNwoVsUQF$haxfw3Zy#j8oX7>RSaBuXBC zH~$;X3};K$xOMOJMHgEkL+VhDxju`O$S+Ck^IJEW64BLR4@e%qx$%(Z2Is+hL-hc) z@UqHy!7D<=^)9pZ;dc6-sh&*`y?IQVnmwXiXOa7UsQE&IYlQ-F*+@bu5G5>9XoVcfQxnv=>IpfEC z4YoEKWCWbpinWe1AH44+w`KW&{5FM+QU$&`Hw4c?=wOjqu?5zBU_h|ZGDWwO3T!tg z^W7P<^sNih;&(G|%Hm^Jly56z-uhNb4Fa$SA=yCMG{4n7mOd=y&NgiM$_xTb;a52E zoiTt!`vb_*4JZdN+^kH+zxblIC>Z!-t4&G3dS+0Ul^j}!yI&Pgd>u@4MNs(jSYU8v zF{bJ<3nKf>efYGI6D_zJZdCQmWQ_TvN)+eUogWIT!iH9bzrXYW~f#IveNmH zT?9c<4mA$g8naV5?oDO!ze(NQ{92yEj*ID@8hkw1KxHJ$f+@G*r;)Qwq0k|$BKe6B?EzvdmJUZ$;In<$ za3XY^6~lQoHas)1H!8~a37uYjTBoUIeluVOhOes0eggX6{E{?Y)(=bB^ZCVR&oWQ*oo4p-z0dFr6*7~(@{m1;>~&I3$C>-%NNWZgV1X(Yfq-}66g z9HIkvGCGK?gih9 zC(SxY_T1t*qy!#jwMNu_Y*4|J#+_cb=rMnV7Ih^AR6?=!N3idaj81`?{3e;C7$$lg zm-~)dJGofxZP4=Tustrxrg4=+dJhr4lkyhou>!aSawL{}<-RI4?mAS>4Fkqhb7YCi@Mm9t(nTkYlSTL0(gU;$DAn$zen>LG?oRd;Y|}xL5P(fUAow zOXDgx)$eS(PePJw*w=-_Z+#^Hp0o(3Fe)>&tm*Aen6Uybnbaq{=*W}27+PF^-VCJa zKcqhe9#YMw=qhc@#Y@<6Cj|)g{mkPld2rrvQ2jr#CXJ+dd1YL;y4UvqJt<;>^heCe z<$gW*a`5GUZwZ>Ux0#XqXKHZW-M8`Q$$rN#Q=z{1`xoK6Ewah0!+~96NwbKH=@g1* zyN$BR_B*PUWPQ4{yYThP${-;_#vqA~!xUPSda`{^0_h?GKqW~L*lVNgM$Gth7gR}> zA?H6)lOXvOY)r$u2?5-QLek*>rH}Bk**>Jk#ZVuT#QyFe zWRXiCYO;*{m;AA8)rVNkOj#4Z^+Lx#Of>VYF>FsD1fGRJ$bjF?Q<2fkh0_=0b=5Y7 znJMrqr7}n~uXWq%la)hU6IC^gdBow>Cr*NhJJgNIjU$AuFfe|Q=&6M$7X!RHfv8e- zxb+(#*W=#E;63xFRKu$)Q^~0Ld85m0q~WD{J~ea| z`tEwMOi+3Sb=C0e@;86py(qoiY`P5vl;6)w-<>SL%%LW2p>KW~b|H2SSNHA9rAzUt zhFufCZh=9}+=-x{TN2WO74^45>}@sP=BGb^Seu^08>>xi%MJ{v&;L*#Kj6d>&Awjm zGvE|DxZp#ps=dx+i9ep(ja^QN4vhQ7aGj^upy>In$ybIn)v%e!=2U8U(wH$?0V^)f zeiqSe7YEeMz%i`nNEfYUDm}rS)EY~ce@}h{uP|Im3ZM4eqkiqi(SF+^O5Kt|H?N1b zq?O|=N=Z()GetWgUXJ0!w*0iju2zsn;5!kLKMoMvqT~-J?b?w^+^H zAcSWRBfseIYO04lr3&kb6U1(V&e0bRV~Xj0!V8DWN?gi)d>4bnY(Dqq`z!NGDMVtq zw1Cm?M7>wIGWe8_dJ-&b#- zfA7?a^eg97Ai#|g~@ zy7VD-5w*^esrml6At)kd>fVqiEV#8ue!{`AE>n5bmvA9kT>#7Y7RsU}-8ni5B`>i0 z#V5YTaG&1$)w++EC~QB2^N2DrNA@DPx^Y`2O44Qb|9%b5=lG zPW_}11z)Gx_NMw5ze&7AbnT(ZYuYKqeoM&SP3ZSdkh6d23wn;`sn!I!`ry;Rk8>_( zf5x=awfuR66IW036Vh|s{S;#K&u=Cnl4|`_RjXUG_MX#qAfswvlWlabP)S9Ci+%+| zS|MREG3U6huN*~gdsU#(6mB2m>l7~;c6t45Es?fMP238-$4(gct<2MQwWLi?)^I#; zgf9T&j67x&9IRE}<6Nhij_J;r7!zgm#$S4BTm&Xh_)m=~Oee^QQCH5IK*Jw>zp&Pr 
zw_s7KSz!)EL?gn!KkDi}NZ)g-*hkrPuV_N57BQavzQOZ3v$XRgP)n%oICYL*yjC9U zxy=K?7?Qc(M3A`XIej~Cv!f~iBfTpsl4T{5c4s_S33a@x946CwLn2I+rLI5CXWykn zU{WSDFsOKUb0_*DcBn=}2rzGh+^wfER+y$^>}{*q^gIZ7@POyP!FMkSO#9spb&zB3 z^MaVj{yy4gRw09By1L)^f@! zHuu|`-Gx&z#^yQh*ToS^YsGWl&d*1+w6md1zCUHzR29Vi^sDaNW3YHBcU`!zinY2b zuORZ#@!2pJb@yCqQ1e-O%akslG|e!VJ47^YA{3 z<{9s#7~p%Wy+=fk9F1dkj2-LOfwy%JjiIbEbPu3Q^6mz+;Br;6!F*Z!FS5e}E6RQt z0e0>p`-f_o?06J_D6S|TfXkz2qF}!nqeJx1fDB_o-J2z?f=*Wh!-tus>jhi4tRIj8m+f-lX=UzBadcTrOM9M5-ABsyIHY7SPiAB*X zQut`kGkJ^wdS8Z%rg2N02usr(y=L(i2Q0nkICbOyLNOmkyfOh17v<9=m`H$?U0RIv z$yfP_hA3wlvZO$no({Va4J+=&E|JyjPyIb zm<%XSw(oxQ_qt)DGnYsUXCX^mPua#oRHuR$U7Fm%-cTcPP0cpxkQ0_J^P;z<%Z1Iv zO%YW5lA4-8gd$cDN7!M3ty+>ThH7$tbrO+K|v1}wpmDd{&?OMYa5W<_!bQNSa2HF zB;I9bPt(aH!Q)VqpQp3^ZiqiE4{0aNvSsu)`k14_1VnWU{6Yt^GFHKtHs`jO9VTY2 zpOqyFNUP9~R1Dvy-JHo{G8|mV^sD==I4pbFeGvgBeMJT~tkFa*(04v_lmHiC z-Ab6Iwf-74Omgfwi#v|3WevmjXeOmmHrfK6g#)CafnA@!&`lk(wx+xj()kg^-SEBS;qjnIt1|Ux2DpUShiwQ^qpHv&V)NCBTyPe}${&o##2i z^JZlC%P74ysdOOYnE{FBj5CEn1Y7lQ(6c;IV`ATdn|68#*}L=pCV;Ec&c%rlnWNIE znK0REOZw^Xyp#O;>AaRT0Y9}8^cdi`L{B$gwbHP1I&}DYdAordDhcy35x?YL8bJBJ zvYHN+SvfqxrY`tNF)!ZsUOUBGe>~w{q&-UcBYxXEOmw00W|Bq=Tgyrf<~Tvs+qq^D zI0X8@j^%rWKeq@PGW@k&=IdYEiuD`(!i#y-{Iiu;!HF)j2CxGym81xX9WhJAgz!*m zUOPR=O}R?(VPyrkr#U*iKow(}ZZa-XP|yt82fPLa(g_%|YQE&78$l*+D3-CiZzh?1 zdaw}Y00<>JMk`IqtP*+I#NBVE)h)e?I+&&0Wgg-URf^eA8oK-d>=oh<4T|wXmtc^B zZKhR^vIzWamOV!v3q<7coNuL%h5XyUtR*kBd$NdfE|{DdO!ZW~9zW~ge!K^q?c2Uu zj%6b?L|?Va)+xO1;Nk#`Pl!NO2jgtd%zFnzN@E(^!hE4tWr7Q0r}N2wh7KnOKki=) z=dmHp4z9X`VzjN1o|?FJ&mM26$y$^6fuTZG5%8+6@cTM{XH&HQu;Rh`AG^n$2(Qq&bf%&(0wzb;kW<2v)>89hF@3@6TvM#XuFpW)J!V5TyxOwt*7xLC`r z9Pf`;Ou%J~%8aV4I5g|zB{9^eN>+2A@-Z(*35){6bu;OLaFxOy3&YE08O`^IPNlF= z_Aoi5v9Q0ibZM}M2g+T|9@Pr!!KhY?IFF<0C1`4RFS<;~*tt5jJHxV~dpIce9YEi) zk)mmQwRdTMFZKv2tv>XfPC#(>zI+^ogdN$iz$gP+fd69wV1>XTqV}@+dF&pOfF0YH zxd#5?SNJwEb@m!agI2EF!r+RyeG1F>SFgzO8Xt2VNA%=ylZ>n0M6O@6`Z{;$sOpXK zjbNtxXx8Y!`80HC5B#6Ogq%(bS^A7|IPZf=lx0L&ee7;}E3T{vcjpB)B#5YUAJIcDXU)%zwll6qmre z3@Pihuj|mmp4H8~a|S&>znD$jRAG}|Ij5O9Kj~Jt2K6@hHh*Ik7Tyh0}L2!<5OAsN3hb(^NnLX3qg zaS8L-tPY%}x^5Ab91=JRAm%$mXJF=}NU%pX$*1+WQ@7}2h1EKAh~Bv{xY%jQZLd^P z;h@_5h?YfTD%ds;d9f%UQ#N>jo;We;?^M|=qYM-pg3iaX=07s-)zZ-Z{@zBY)(}zdnh904z%1a_dM}lST!A#mhOjn zDqU1n!TjNf#FI8Pr8P)DHdv?!sV1AhOOeHPf6W5Ds7o?4JPa)r}J{)fbae`-xroM(C9(4+d|dtUG`HvTO5Kah4DD|`*&8i7mA&t zf@k+Sg(OOsE@L0O^qdq}5^hEd<+lNF;m2U>YcB>abwi?PVM`_oe?Y3oGUwB^kWP4W zc$r|J^IQ2Dw2fB4SjN6NB?}QDdOMKi!#GE8vVcmbIGt5D=7f(@!3wkhj;@~+vmee| z37|MHVHy=pUEUL1N={VG(aJbgm+zTs`Sjo5Q5&l3N{K3Q|D6!3m66!B9Nx4cXqDaj zgj9F@8+Qz9LuF-!j?j~UIlSM%U6s9QXw{|EK%s0IVKUKDk)_$GX6=fe6k2ksEIjuj zxE0f^vAXgS85CNz#8xBA=n(tgtaVOLhvrmX^u^TRXVUMS!{5ej&C8bJ^6yeE zLvTGe$SBx*Z8U0C9dg4{c8I6iA@qDS%@&ViT3f;4olM2w#^g_WTK$U=`Ps4ahPKzz zHLbSVR;b(QX>xf#At;1t7kM?r0ci%XTPEs6yL?`2*hVr-8gM-J?6KNxwLsAevTV&) zG}eoKcDK{+GgTiL!o>@-??uvw#?pyvm#T#Sgz3m24MVuYiEcQ#RQB)<xtn0=-;xfh^o@Fxx z&j*QSm@IOCnx*X(kwRxBuVUnejZb@WJhE28>77d1(=6)ShRD4~Wgz2`d+PFoK;b_0 zYUKStUf0;|4*3^%B0ff%N)I_&t6!84JX$Wk^HV&(*Dec-ak0eyVLtxPpN-x8svP zst0d}A+EP^4+*pD#qb%a?9{#qnDkI#-iPSJI|LqG52Slt1<^6;1X-))p^63Z+&?Cn zhB(`$#Hq%x*>~=bkPQLdS+7QnWnT2z-DK~u>O+KJIcDr$O77&u29Cb`MkU@=DmgK; zg9oyE9O_bGv@s3a(%JpGun?I@Ilnq<mE5-g6k!n6aE}uJ;FN{L~_^Xz-cl^aFT^ z85ahkz;n<~?DL@lz;xiK{U3Lwi8k6BBU};$6hG$@?X<4n;qB&cSt}o;i=3-g>Vqw3 zb_Q)RtA6wrax6NWJ$CE5tXbub3JhjkcsvXxE@@|;SsO$b!2rW5D6vH2e z1`G!skBUzl3r6O$;rjz3FK*@~*lx{9N&lVI4W7 zu~|4gfQ6X|c-wtQ_oO;AAr?B<3cA0-E^Nd#_^z})7HrNWs>XiacMP5f9d?sC?f3Yr zn2BwC%xMKXXT^$K(=?(HsuFeus?ePAn&6p$Q`pt4;n9xm^f1ZyO}`*wL-^*${oJ)uq#A2x6`Eco!cG^uP5lV8Iac=qbJTgz 
z-BM&+sLC6yybl(=r46(NuP&7s^xG-ZSf&RH>-J-AC%nCI#^QWyIF$A8 zW4*e&JK;gs-NAo~j3yO;rHrS^fPl&Xm~Z$bTrAQl=HAH_rcU#G)3xDmjC+GI zbV1^MmmU)Pn*Ep;n@;O`Vwb;HI6oeaJmhpM*g+jfaj%XEB))*|OMIbqFO^o%``#Kr z{fKHvew}_i{8-Rhqq=2W>;PtTFax4Jd2EcaNaVhg7uqWHag7VYq~%O7ViS-E~y4?r;FUS z%k`<1+(FsnTYq=k2fxr___p-n{g)a{@8Tyt^ZblIW-2=+AjGL>W(&7#QaLxP*gBub zOp%URv27@_iabUroT090@&_*hHG{hr%y*Pmj-81^L1;wS9>22`)C?+px_vDVHOWz} zbd-LyX>Ix6e$xqJl58P!*55Wi`D5Y!g*?-Y+|I-QHk)cge+|QfuTuiYip(}C~nA!T`-qUK~JQhr}G6LE~`_kz>>zaJ$9>0r(jd|ZMtAO z_&wj_KKS(P$=l4xeh(`KUSDZ2f_CyxfBP5TMrFNYz`K>@ZT~VRmYe1YFIrrNt}jI18x;5}ocz_G zyt}?~uSlhpe`qb!lE}s4_3);ud*V#9@|vwHa`qSC41b$2AfU70lG2$_FxV+3C~jb2 zoJ+St3)-5+%ti=OuEAc(@}}*)YR!gQtqG|dtL(&`y`WrS)~4CdPF|l_PVZ2oZQ<=N z{>+qbD+zo5IOta_%0OF{2f98nv+}@(ED!!7jdlld?LDggFD!li2E%d49i~#xMBCnc z>x)}_GrjPn^~>%6ia#Y*5vs`Y#c2%`cl1Fp{Mm+tipIrJda}LKWY?Dcz}g>C>?rW- z#q)u6{eg;h&gRl5N#HNQVFP=Aq-{1QMf$vgjtrf-5z`0*1#bj6#vIEvW8l79V=KrY z+-Izf^RNMX9$uI$-xlvHv<6Gu0q+V1UY|+cpybX#sHsn?46&stAWMbZ_Og-pEyV!} zot~P?POWVAe0nVYh_S4cKTVsszO?*euv$^#6vd68pb^ zf}J?octeL^#K+RTMZk=S_dM*@lVYI;M#OdH{TS_qc3HyEF!48UPyAn@ub;;s;{Cl4{svin&{kZY!SK6JGJbQa(wISza zEIyP{^Py2zJEnPx$iru+m2-UU7~OM5$pWpa1ez1Ff1l3ZaGsWWO}vcC6;|u4VFhL* zi!kHV#?~-^ZPRPB!|tLX9#+SBWQG$w^@?5A%vJnxy8qcg- z+==OVwWXKKs`4n$E{Y#vHUAarIP0naZC|}{_fx~*fp+$57L3 zA$be?iwTtHAnAD!xLZ3flZ2?I(wtO7e{oOT48Y?8mjVK?+76)k=lB8HA+~;E;154+ zobaSDoQOMXy;f;F9?eV+*jWebLe_c7yk?%|=u z#qCaCT&?a>D)-TN-GfdWRML+l8%*DkyEHz{kg+Z8TvKhlq{Tg9{f0A{%F0~$b7K6< z0a3*TQFkS9N%mpe|>1$=bLf84fYu{*oZ6iv5_EY@e7V7DPI0-Gr3U$ZD%}re`XaE6b>u=Kk68(k3NTFKA@Q z=63i$bSxea#tIWI_#rRMnMG^;6xwaa^4%IieAfmG3dDuDTs^~R&ZA*M>WRFlK?CWf zM^fK|?Al|49mAX*@8dS`fO7kv)Yjm{nQEt=}flDBmr&?x9qGctZ6~!1fpY zKPzn!43q;q8Y(*Un*H)O*z)Upqq3xYa?8#GVy?qG=WHdQt+~m@9wP}rZ z=+m8=-KwR`k_!<`>?%-bsLgF_Xhkf1A8BLb5nJmtp5SvJX&j{4yCM?)lS#S89Tvn? 
zygWrU#9Es}9;_q-=`OnM_ZUvRq0S>e)h69@`sdd7s4i>X@ip$>^%l{x-Z$LfS?F#I zQO%9X5(*IZfAJUu5_Vue#wfd89g+@!83xM4!p`jUTk1g)xJ9$B&3W}@aN@b}2%_J@ zC;lVwgQ|9ilvaP(!-$|&3pxwEzyipdAC@^oNq(OovdvXON{|NVJX(3FZLR3V@m&ye zuJw@T=ggg=7sd)8mVlzut>H4^mwtDk)%|EuJ6?sF(`}FAJAziHO#^*;dxG8;q7mTz zZN0fDK&(NT@Z!{m&Ppt!_ug}v@w%aH<0^WqrR5YcY8X;s39=U~XkdI%D)*n_)#b%&#TvZ9EPgU<8p=^Z+@$v-52MRNJ4~qt^7aiy=6d@>-Ih@ zp`42YC~#Ly`q(l9hAf`}lUGBi?yluCDvlyukojJwYH?S1yy z|4;9i^GO_LnESb7t!rKDy4R9Do$Wmj9>HJ?ns4|dPPqIv*n{3-ImH8S`nK7(YMUJh zSRq@}XK3L#^31$>zsU8}lR+vYj2+4{;cnCio!;0@Ew-af720K{B}I;69frjA<4&h~ z@*dG$ZTe6%{c%&Au&9*v$hEg&QFLLdT|5>Qcfy7h4XgPQ6D!?Ms(Ne5Ro-{p8D>&S zal*kCPn@@nuzqEqr7Hg}ztfu(1LvF=7E5u?G^5TDbHX2?A#UprZ{y#G+;&|r>l{Ao zN+=()qjxSN9S>DEr|=1Tqp=~DKyet|?TJ{w&$pfyThv9pO*upTT@EO1>OUNIwpAPy zn`A>>&$Q@)9bKl}FVkey)5OzhYwEdQ$TJBi+=Bw@W6MY|wDayQ8_C+_ zN$%wWHg*X(dTiQ<0)xM$kuRKuVrNvR_-a$2xC5>F*LN8~ z{@A_j&y!ghFz}RRFiCC^Vnezfz+)U{)H{-WhxJQ}L<>2Yk$X@51RCyC`d`=` zAE^6tX)}kV)_tO)u4wxV0beJ7~gu6@=VkB^G= z+SN7kgjcNBC{+oW;I|&Xr12tJ>_05LHOjlK+E5NOR!P3q1?E>av+lEDsw}wWh@vHHp5dmIQgF%tWjfiyJ?Erx3lv&i?UT7E|Oy9uwFB^Yp}2= z`eAvcE=1-{`=WA$%9rtjl%ns2)W@;#auw69*h%M%x8WY(nF)0-*8}!Me3g7M!fp?@E zIqNARoRdnJhe=1f_=|F{`X)VATcp_TzIOn&W<+@}#tD4^^ zSPzNVczr!vQ@^F?8q>B8qOH8b}PjX z+G_e=y@jP4cP%bs>Vzata2`5I5SbIm%{yf=3igzW&`_h`oD0YsZ*Fi>^Hz8pn2)>m zd}mLNYK-iI8dmN`2w1w04#YZ_m4YM}GHjm6uF~$}53(yqr6S*|iu>+wZTl#>k!fW6y6-iQw{(>QLI@h2**E zfzBwdy9~__!G5|x?|wneS*m`)ewjPJ51ynyc*${@M&edaONAv4>^Nwh82b{+hX{Y#%if5Quvu8%E|7vgL) ztM%bN<>vs#aQ~dpm>A_9w~^~KJKdTk)Y0y$-f1&Gn4LhBJ#o(D7AonMN0<8e!A-@lLAlXqapKBHVH@_9rtO%S2Bp0N^&2|MFo}px(Q&Ldu&X;yCay`S zw0Xzw&5`A@p{6=^JH%sz;-bqRb05Zf2yW`zgVJ$j4Q4#>*Fo35M&8NI^BV)|_8xS) zcJA9_d1E>YnsK-1XXixk*)U2JY;^Kklp_xAe-%FMMbzxsjo#j$r8v6=~V$Tv8y7K?n1%g>Sr6IU;F=8#>u@Vh?=Fg2ZYsu+#4_82g#N zk^;5}-Fz^lsaLYCcqL&>TWHc2$9C)sujhSDp}sHHKNv?iV}w{|)?*69oz0YKinZyS z&9gT%dYQ2ss=)K1wFMcjeWLqguT2vq1!KI*)W)mgpG+V$_hVoqS~%01UolRR?(ta*i=A{gjInI- z)(dcs2cL&tPt~iAu}KPW7?8H0-ay4{c709XOdKT<(xlR7&GyOPZb;kgoTMKL?fmfM z%hOgz^XBK7U&7d!k1U1P&DGPpvn_iVu6wTr8plPijS*>I_D0@m*VG-Q5^CO+qnfrI zg-$4jX9t##brIF0V5ro&lpR=KSFp(9XT^zHMJb5|>{S(Sjxf7ic6F&QKFbZ=;z1`Y z+CvGN)nSNv92_@-I90rD8`0HwZ|;RN+Ig23ujAzkqTNJn05%Ilpn^)krrz^uTn?+1 zE}kZY$4kdMKb;8hTlfRavPvxj*6I<*THP_5%%~n$>yCi$2}H?d{P|*RMpqMUzaW)6 zl;oNPBqynQUtsWCwc#41G4gNSBCTkO-C3Qr7FSPRpz-OgC3rBR zCP!yqSueSj9$dF5sx=|#VAAiqEIAUT|J)Kab9Smpgc`Qg}j*YRD#iVuds8Xp}SGy`n?0sZ0%K0#ku$j`>uLBrfHD(6N>zNw9Wyc(_N$=)}#O z)1OX8tI-w&l?>$&*?HsNO9OGy&Ex#i_PT=EHWG<8x{nSMw!ZeUU~EJ_yBFYlO6kPE1b8;VOpb3ZIL z(k5HZSC}VA2(}};vwE|4VUvCN#uPhD8udPX(5~9!x=hV1q@?pi|5K?ZgRqy$r^icI z!`6wEzf?)TrW>Q{=~@8StYt+2wS)DiUujhHt58qZthT#re2tUdSF2BXil!< z{`{3osMJ$-dMO*PgH#;Uu1Lkak@#NIbYry$olSZ%e>t>E!}M2Utl{8c6=_YcJCGS?FG?BGiI22jCmk6U9E~Yo;YkwjoS&7;i&H z5MEItU$Z?84g-5$o%N<$Lc|?+W+h;sxveHEDGm($%I#2+B?$6-yUIwfivH@BWLHGV zoLZO2hHpJ>Za*sAM>*MFsU^%6a2VsZqHMn2+0GODNY}d-XCs%cS({DlF}-1S#$SkQ zj7!v#y^BdtYIu-6=NbC0mra@QTd5uwpCg`fz&XJ5=iikGAn>h|2d#1%c$*f!OyWaU z2C1gYf`<2g#Z>K@;m<$AESz>Frv)pEMJOY9VCA{`l_=+DiHPDo>Y0Y@^J$2L0!ie+SLEYKR<+*6j< zwX~g?_CZQ4teGw6P* zCuKT9x7loWvG6rr(~cRTI~EV5Tq6R^_d+Fe7$iG#nbvuNPpU?Z93 zgwJ|-xecHlvHyIxCuNX-P-~PW* zxs+^bQrvJb{;?NR^-dZ1T2o{DU2)r|di_I;o>!h+in(tMMZv%5l@5-T4#u@BX5Zz~ zc*)t%YnO-z-6cEV=%;Ovr92*e!RTmEaKKYxqt3$|rZhh`*VxF5chi>QHiLYsBsLyH zlUhWc*Yl2euw1Lt!=n0w_5uBBXuIT6bSEFTzt6l9^Z_Ymjd|H37KD6Kywn5Y2tAUurARzl)jd&g8)Nj1VXxe-6Bfa&AY(9 zb`|do)A92BkPnC6=P9k$qQqLuE|iXl``BCnc*gT{eaKUxT0^hXaWRzf-cJ8Fz7Y9MeCCkC*?q&q41|Cyx zNX9Ul2kz4_)QRPbZ+fN;yqY~TnaL(zw!S45m0I2f_lvis4xO-hlz)eX>N1R~oAPAk z&#dxe^O!l_sQiK$;Vzhu1yiJHDZOI9wKnwNw}mAQT86>n-dF|_wlmK?G&=$htMJX; 
[GIT binary patch payload: base85-encoded binary data omitted — not reconstructable as text]
z4h}gcYtQ&D4`sWRa96PT@Z$8ecVh6magB#jsrOD#cG_6p5Q$UnlgV`=x=c4ob%X99 zvBP0qZGctb8S^*8XSEGQqKy<#Wuzas;>A~o|_eue3R#TIfT3BrZ!?bGX|cu zG=D|)?l@Om3**^(bhc{nya~t2GXZtQG1T~2)$QkNs&z*4aV2j_{R^Y;Y$DgSjV6x? z1Wss6yzmoTuqARa^R|l`Aejx+SV`?TOa7*QK8_GYE{E+ixMU*ig=9q+R_;QiPf7O zIy|XD?mMa|@_DtIB-ZwgP>1d0nZ{lRRuW#L3W2Xq_k1qqs$-@tdT`By@^EZ;Ba^l6 z)j_I-x?+lFx2aYOBO$>rUTDQmZ^M)$f7sD85Fhz50saDsp|BDB=d+o0kQC(4^4;#d zOq=wD3|k!8dOF8(L?P(+Nkg2HbFNC5aTV(BqndRl3m)>*noSkagf z<+j64$mq>Hz@fom`?6xvpBr-@CjlmmGyreeO5dfD(*R-Fga=Q1B$a3$UZqDgXOni z*5zq8@nSkgB@y^*0yr5HKKRP4IInTq@$L^_pkhNao;7dz{dPO8Hc(|aWo!<|uV=Fh zk4R2BR4JA-SrIQppx#(A#+S~fJMV)u58uRgiZmzWYX&8$&1u2-burNg7o}!Gu)y`U>Yb4`H*X z;M~!&e9%~lT8B1X&tztW-qIS5w~u-j_AdwUI&vZZ6`Kz1#)-JAp~rM0-5fYJ{5*>CPh zVD8D~VkLUQ`&O+61@Jae9a2h zg#%yIK`wNgq}rez<>i3FR1jhR`~4rEPd{y)`1bV&5@JS19_fn0CX-K&9W!l7kLf=+ z7G*33{{9?D@i-)$<{D}H@76rBAnG7yOZk{DPo;;_1Ap)1#R!=Z?}<8(34E$2xJJTk z@x|HAP4Z*FwF+cuLV4piryz#6jnSWY30G-SSD>$^Xc|Icg$i9o>xS0W!rYlYxvl)< zugY7~sCzt^Jr@ZpgilM{iiMvmBa=d|onm==7cKvy^}7vE5@E2}JXo*jMXD*IP72>aGb){y7YtHv zh6!7}Y9}qvZN`S^E1a~x|6Tboh}Zz}NH?JTibQ~f)ZmKnk7 zvDC`Dt&Pa8dL1|u1@sxIx!_Cs)QhA~2{V0Z#vH)y5NCf#BSbPj{gy3IVDmK|&*#;m zEX>_m-?v9KEb|1u()+k}!NkdQdQxcUyi&yPY{e&cr6KVqj3BhKn18Yab9hCAD0*;S zhbRUK?WtzeZMzYYz*z|+XSG>CL|}eD?!5&Ufq9wchG`pt`H|sJ)_oG}US3|tRrq)B zXXXq+jGx&T)}OHHHh%FtB@?cx@y|7lw5HT?cF|JYGkZpeW^^9PyA~pFdFd|s-AUl! z<>#*Q8p1dENXq>)kMmL=nL(zjahjX3y$k8R#QOfQvp221 zCf~=sk4y8?wwE!1daJ`%Bj%ET8S%(M8GbYFDf~)1sAVr_#JSiW_OE~L-Cc(_inE(q{q16oRbRQE`Irm zlF+e|TAwOf_KZ#CxAX_w{UP3oQHtqNF88hPFEl16C8fw!K9RGya^+pXE&RkbZpQSD z5xVuO2eo)vMDetJUG){q0IMl4xebWYx{hjt(Hs1NT>Mg%XlrP zSsx85s+Ann>PsKp0pR6vRd@;KaH;y2FuCrj-&NgF=SE~xc-%`*LoSM=Og4i{Z{oc> z4~IdM;ux3A($lz(rN`&xSCS8FJ39p*(Ry_tw?%oUht= zKTdcNdq<|!7{d_;W$VJ>{s+-RYjp`vKl91m@^*UL{((P^CG1g> zYl7PQw@Yz_>$VxMIL{kwd^iwJ#=ER^`1t3fb^425_mrBq-$Z3N3>lr1?o|8Bm07LP z$xxO1_i1nxwve#$j|Vth4dl-Wt@0gg?43M*>p_yWC!FtMt$8(7d41{Upo(bAS3k$ZBwo+YLVWsCWY>PjM@fRF z=9gS?i&aZC-|n?R8nonKjkOR<>@#7~OY9_+9R>D;IJP)%THQv2MRCHUL`ax>ngiaF zXm5A0{nkg=6w)1*?zdJ_PE5B**wxXEpG`7P`(>RhG&@|+5+_FL?e|C-sl(Xvb&0Kx zg~iC5QpBq-q6yc{(&ZX4Q9sKMGd!~i8E@TW^_5PYXE~DIU#NTOs<)u{%Mf?iZT9y& z9eayF{1sjxMp^ZcNx!t7j)5YfqOE1QA2!4GZ(fsN7q5Ozrw41vjrHrNm+!AU*WvTV zD;JNB&h88J9C1c`3#zIE`K}m1e);@_Ge=PAfY7hKdp-uu-8BOS{0=U_UT}|>&B%6S5DclIlep+ z&(ESUu6Q1K4)pHJQc=s~hkI(I*nHkXG2bk^tz69==rmV>JToRe&>Co6@&_E&la^j_ zeW@3;VRtZ>sik*Pa3YkSP@qD>G2`fOUt4KCE0jEYc+8_$b>Q8bR%8>EVr_}hWZu9= zax8?*-TruAP*rz5H6Ep1`31_K%yzoUBnZq!hxwnD076xhqZu0yhzqK74pLMJAW&o`V-w4!?8Acn;Bo<`FQl!lD+vp$NVC#Rp{5 zTI(Qz8qFB!Rh0nGBinFzg%AGrvrk+yL?DH4OmP@+&&?Db(ItpkDzAvZJmfmHb2X@{ zt{*+v|Nqs~5ttj2tHp9w3xze*VzGvNe0=%Pr9U}2IW3~Lr=VAysr<{eOp?JJlzxHd z`bGjOO8%?AU>AUeF`0?HmA?qg$@;Fz8YiQAsY2*yxj+Ob9wYZ1 zXYN9RwJ9_eF@ooIJ{cl#=O?w+D#HAMTK*;bf_P;gK}Zp=Og!ucHpCPMD($AGo=7s@ zPjNSUm@l4Bqu|R*|H^lt&=?$qoj@QBU#41H4A*yU&A=`8u7#NB zJDy++)$aysLwctoGi4d1(x~&m-Vc8JE11InN(4!lZufneS2WI6f`5Et&i%jnv7nMw`Mf!r zU*F2g^%?^Xe^OFXxjiFvUa0i0XJ^0b?Qt^&HTac<_uoskJ`n`WDUXhfbU_{K5;7Vu zFi$<^PLuDpw}kQp!n?rBar>E+`1W=g&Na-~*kvO!xhk9+Z5u6BRmL(hGJSWVpS6A| zQPP7$jW_Zg?ny^!TZj@WoC?@GeV>gpt0lVq7}GjbNMvr@G3*!2wcD< z7)Q%r4-CN`9b%AMAu0|Q^U#jvJhfAKnLI>bEMFa1W?9b$kRYBK#*L0 zkW8O?Y{A`}ow>lkeHcc__kfuCoBVZQ^-D6V=X_yRw4Y|1ig6H`e()FM(VD_1mO)9b z5aN3(pfN$G+HYO}x=Al2Rdm9JG+I!f4x9k2x?MGUSA2?9n~rX5XSpTd_^2We==KXY z$aDX}8oJuQjx~OaYrv0^!2#2TNwfozLsH5}o5Rykc&0MsnW{t0?ctf6Ngu@FB37c! 
z8CacFDQ z`HMVg`tEZT*;nqnGhr_#6{T5%!;p=a!Vd>^``wYk160)9Ba0qDSk90Bd>B;JCFRN9 za0IB0l5}D@lla20knGPfE3NhWBYM+p5c;_Hggl1;Zt%J4-$oOZwo?u=i>@B>yp624 zZrw_R`N$T?tQYeaH^TDsECm=7T&8Ngdmoe78RFC8AF8|`2$hbSpD`wyC)nddE7n>I?qm>!>+xIS;X)$ty zXhXzjb_V?YU;50%JgR0gI(FwH*IG+791e?Fg=z~+FKk(S_Qz&>mzTw>(egXSy4zz- z<74aeeCky0|2B@Ox5>sSBfe5h+1I3d4m-ng{tMOWcOMPqhD;5wR3ODPCnv5b3L6dN ze#*-bt;{oTKalH#ye=sJ%C|&7cZ1wIh&6qvRs5&m%;Nk9VInuL_|?hONRjM#?~730 zSq|G5mmSCTY~@%+*|Z{6t74cmUsrK64}Y_}dfa|MkwZ%7XMcg3E;aZt3cI5VZ7QJA zB58861)^%7N@$(c}1P(G15v zEeE#-RrUMI2ksFi=j6^^UTYzyD2-dad9gQfQl=#8!g>b`I=ZoS#e@4u+ zF8#b+qK*K6T#wm9Dy%f=ut0M6eI&x+W_S>P!tamQJ*3Q~*FxpO#zhW}pNR698U*+( z+RH=%Mw0ITTDAe{)&+uz*h;ane5i`~uNcu&mYeOD+aI?)7*;->p-@<@G-P?W8F{RH zJ}$maJZ!9YJ?EiKz|_wn9{w4*dSyZ~10=%t9I#vWD=>b>wMYErv01;q*newNXY)5M z)XP|Ny&XoTkIP3RPS~JVpV}`DNRSqR(w^lU&wQ|)Giq6-PIaBZ)^Ua{1~D6QvI8M_=)a9cbYGMU}^v{?aQA3UNUa&?pO3ezh^_OVvZWqwcC`$o-JM zev)edhjtgZ%|C-RsUp`yUV-9%YXijVf|iaMmvVSoCBl%xKlBs8^#7cG{*NXh=6@g7 zcxgX?0kis6*BvrU-B3B=e>aW)4|Je_#dS7=S^94X6Z~fyHyd^^yW7l-WWP9B-k`F( zkXRM9;!!1%Q$k4lVfR#w#Dxo*0)^i<>i(RcXEA6g%cC^AF`!)Zd@ImfL@=Y=G!D?a zaX{Z{0IDYnT=d$+*_V@$hSjOKRtRymaUsz9HMT^sDIaPH{;r{Dyg_UZLqRxAt^K_# zoRX50cl};fJU~Lc$e>_<&QBh2vBgT4LC?Fj6s+|ivJ&-pN9@lXPk@V|EwM(pa24QK z?H;iEBgTL!olGEakz*${%=V9|IXq1|N03!sj$di zs^%XV8~G)dm%G7LS%Pu7NzixB|0BVd=nv-O1NE^%U5<_i1d6NN?+16cnPIn9E=1t> znj<^-oQ$d)1sWj&O%Zf^Dg@?uZn@%o&__!d=lL`grA^f-m|X@UEy?(=i1bimWcMHY z4D}XZpXpBgZ|w8`m6$f-E&aitAmCV?&B6ce_ycYvBB8}ft$!RdrT=!vY(N>-8d%Xr zs`}~0F30)E39$w2<%o-m=lc>R7$D|C6s|^lId$Kk2+i2SP@M>gIqmjCV*xZ{R%lf+ zC7SWxWD#EqN}JDw%$5y}QGQ|U_VZlVP?}7=A{ftI=99{#@KRSI{{2ErT<7na$bA!tD( zU8@F|_`Eq~89JzpIztk{8Hg?WSPgLeitqD)FU0ZRzZkrYASLATzf9p=aOZOUuk(M9 zGvp^VeS~=Svy_+qUt(~2ZZCY=vsbW$+OGcYQcsd66i@YHZF{kOkwr$x2`YeP?&Y@l zU~#tVhPH23FL?Sdvwi!cQQDb=yDdNlemke?f)&Ra4D8esnBTh68nXp{NKWpq9C#(; z<%$^yxq-5?f5qbwLNg|kL?B~phQ`J)4S3VqgrIR<`tTA=WS9qwKvYnrPXyq@^=Dc{ z;cJTu3etgQ(KbK4L5kf*>7c)Gbs5swt>wiNm>WQ0b=<{$);UVyk8zCg+NG-PpK?c z$F7HUXyj}+k&s|T&;DJAj2>52rP3iOOM6Q&XJwpK8MtW1DaI8m@Wa3P(T_2rG3{8> zwMkkV&F9+9I{?ASLv2C;b^eEce3idmj8rUI++9Guc6EJ0E+F#7B zFBX>fjnB+_6hJ)l45YAK&s+sw&KIL_P^|e#4z9#M_#vZxuw*ifT#`*Y%Etq!_qhy~ z#(wp_EF|*X3@6`3AVHXyf*C078Zrw-VAkGrUeAEwxq8xnAVOpC#LRlu;C6VmXYau6 zq>Tg%!EgBg_>bGkgt*RMOzVN*Loam~tmFB7<|_ElUuvdYzumdSZ+$S7Y%}}H_uI$% z+n=GyJAJ{`8YW&Ny;BWSMVWr3a!`BmP2wMptQDAv1rzDx{RS(wYW;cQKQmytvK zxjM8)45CO=cr(4+T|<^J?UNi`hGsijmXP8t@$MeC3K7^hl=-lPJUKKx{_Tj4AvqyT z&xw2ql_=`{QlBX7U%nk^XpotB<@@p;%NaA+RQAc$`p$$cvL>ft3Ow;dAcMXC^ekJS zfpP=h&LMa^JmjQ@1h%rgKLY4s2Je8Ack4>_y?tZ8m zMcl{&kNov-6H~xgr<`xEe(2%aEC_(u)aX1FX`A@+{m0Xr(Q$CLJ@U~{uPrSt!3@3PurlPf}cOhv=4I&Q}WHJT*$YU~#<0^D1jdu=JQF!*O7NPL^BM$v_4ZsqRQ4 z+KZslU&+G$XZH#$_-^A|60)Lss%?4MZomI-^I2$IzEb<^L-F(H3`zn}{Ih=N)&$m8 z)NkA24B<2X)Blpi;_=b-h4LMvGxuSW z_~JXsxkOHFjE0(8*kd+DDG@FN(w~gUYGCo?p%Jo83L`*^vvANrx9p-`P+wfi>iV(VsFAQ( zj{R2ihVg)r{4f6;F}D^7FpXvv&8@A~DpZtBA!4|v>i1ij)-RZMJ-

KKNK&jb#a9YH{L2cHXYZnf2>nTTk4L4 zaS+6J6&tI`W(-#h{+|~>dM8WzGx8q{n9-7KeS$ zSCdMax`AGaMnz7%L}bobIR3dGy;gX5bf+;MaylN%W)Zyht1XvVQ+{Rtd<*va(5H9_ z>OG=4@f-=hYda8(lG5}V!|IJ3y~J1{HB))6IrEH9f6i0&!mY;`|6K4|R?wdk-~Py* z0mxH*CfIxw9(^%%txGy!7e<|_0P|uU*~%gytu?UW_)=^Lcp5#T0XDo|{T$L%9A0Qt zJgYbyt@yn`Wu{{v;vbp; zt4GprE>6p`ZA>g%AGZfElT`3ITrYQm_Di~<)aDPsx;*Evo$lj~kXRlei1#r#GFr^0 zMVgf_G-SX8tptyMm=%e!`f#Qo;o{GWcr{E7h%}@sU+yyBb>gRWeKRv;sDqs+fSeY0 zj}w42FPik*NjN9Rb~L{Mdxqn;yWM$FIAFChK&}19B+LK)L5V^1JQ0V09OL~8O@pdM z?i^hWu3u`a)iu|@AJTDM_2VurJ1BSQFp=4QL4tc$jp*4Fn^$MtF2DT;yYYocfIRpU zKQj(oho`vtSlHc=YqS|h2d3Jm)UK_$pHAZ{u;qq(CCRsvs=@N%k7RbtK1>vbB}S#v zNd^!h+&N~962qD=+!C)0^ZP8g`7Bdg_pG}M!%J5jyIE4HcLCk|#$rves<_Jb^XM2l zCaVOnVHqYei@>3$_AmUt8&b z&RqLGIZ(U1XYrh#@-~YTj=;B8yE^gm`DU>fbA`%}Sw%X}^eHot`ar-_9|>~;aYN=q z6CvP9_e}PIlB0#Sa=1&yuQ#g+bhM?qtzq?z=89=a_dL);Bo8>@rrY(?R8}lCN zhW8v~TKgmw0Tz_hS+GhmtP4iv9}JgFqL5|hWcx3^?M0G-4|RuozCxQO4tYH7;|mSy zdH|V+HqK8|o8@-BzsW(!ruZVJGRM1yxX2>!$~Q*xhSh?Hn6H+bSkY*y9V5 zPX+fs5BJW9xLKtD=ZBZyhXM&sN&K!OUww}c_IId3T{JFDWDBsjbr!a)l;$}OIk=E< z*(s!-(*45J5;3u+7VJF>=J%o4URyk+%}NgC3{WEQd1yQX6kwI7hY13?pM<{80|*Dk zaJJ!$WMKSbgn;19Kf^*0jE%mDL;v0N``nbWUtgOi1o0xW7Ys~$$8$yafYZiec}ABT z-4mM!FXA7?(n93&0-SE) z|C+=TE~y%quGPywSfpvVZz_Wzbxk7=6LgW2t(c=Mue0S0*g|dz4pWkbOMc%Q34RP7Ok=dZsxat!r7Tb+oPt zJ$L{j-=IuQRi+ja@k2G@$Np;^`kMj4{m~y#FRzdVA?OTefRc?V;dje&!714zR9{pJ z0MR3A4tgW?n_~Xc2Gi8}(oje0j49&6P?!?+0&q{3DO{RVD4S zT(5)~2^NLk_p^t+VLLTce@SZGG_u-PrJ%6Zd!4F<2L^?7F@fXk4ZRfR4DTzDb`!^O z-J_?z=~a?@g7Q<(Kq6Kgro~srqlE9|uocoH-@GkPdj^Un<0H3v-?=2|6$o%FO@nm! zo^t2JQJM&XMS5$>G%sk3tQ-EDJ-HS4Nj z?~V?H8V@N4F+j&{(l-olGMW|xRW^-;Qk#Z*L8rUVpol(SNHDHX-#`|J`c?j3?01`H ztJIN+u6nyNUaER-hw^i?sn{$~lcZ{a1LBpl&Ly!o<(bAp>>m~pjfvmS_)CTn&edr-`)VAXv7Z36A*#ALYQrS}9q7=nhxXY1y+FtxUxTDm zVY#rPwpJ>Bh03_z52#+1LbUdWCHc#rzm%$7wvxNC25ou9S>jV|WT56AP2nmu0y@q= zhWsZxME3d??oCMS+*Ps~Q4-HvEh-kH%;kHb8LU%k&nmZKoAq_`bw)`-*?F=j2F?f`=?l@p_d+fv-*DOfre8b$Y3t2Z%@qs zr6Jt@%bSiIACNpg^&{KB6XX+Xax)D&8J$heqXrlSyI`6WX{?^L%vl*CZgrp=h9IZ# zPs$8U0V7IWFd}Pt*>t{1VHsNG$;g(HF;{Qa?a_Z0DQOu!yG$Sy;*h6>@Rbo+x^=@8 z3ha9dcq#n*ThpZIQRo!}J5FKz&V|5xQzmj0AoV;*GozHFm}Y7rbI9*MBKc1OB%pR2 zIUjI*eA3`hNhgYT|Nfc9qg>$dnrr)0*T4^3z`gIUA;fViE>|rj*hODw6XGQy%yRfj zomQz-0Qi2_^`X>^p?u+Q$lHtGY}Fb6QQGTUDl@8TGB3nlfHq8Ypz9B8pv6FzG#*!G z=$2yMW4Pa7MWM4!Drr0iL`SA1k@jY*(icnXq4+P%<400hc$sw!=Phc|%T0<(Amef&-&jb1UlOL@JwlML&7^VG8}nhjq)^+Ues+<-Nqx<8qlx8i$v zm8;riLqa_5uv5N3acqWauG38~XaIrUBU3EbyeoNm>fq-ggzdFqd1NAJ1sUS?h>L>; zS5j6#@srLfp6qhIGGHUvFoYlitDeBXdV2aqwzuVY7fkOi)9ddtM@MvSu(ua}eqzep zW*p-SGJn1)y@MmDY>8J}%iFK&eH;Qxi`!4bwVhydMJ}8eivZ#m8<_nmL^YUZ?$rK_ ze|+`bncxWxRtpQUHa=sJg|ZOE8i~`I4@v`l%k&&EnI$cU7VfDZ z#z(Iyn4I$;W554H7~lua%|+{Li`l^`KtIgO(!E{8zu-hP;+3V73nj1YU9kS&$F?;bQt~7RWI;4iZB#@JIA$~y16COIAq`$(K=SDY^CZ?@dk{skg4sl85_pRj- zF1kt1`*f%im*d)*997GVxDNwSh0I4?^?Y5+2UqPZ*N`?Vn0TI)vbrVy z#E~|v3>)q`{p9FIDNwC!CX7$zLDv9%&yTox>UpumM_<3pYZ>wk0@p|e3kUdsxA$D+ zA9a(dQ(;_A9REe!{A0b)%6RStW`$oa7qW5a1-Gxrnxdc++UyUTlIwdevi8R^FP2iHMHmYU5@Q-Vj}mx z)ZN1(I%taq|CybG;PaKvCs7m$nP|lblDf067Py)**5l)8&+vu;|3zxbKE{%po+*Hd z*sSTo?a1oyhz`a%ccN;FxuI2(VI%3zA>sR2?HS z{eZBrZAkcwxUBa$Hi{|UItJ;3p`!ooX7cJkQQaBmJ&+`4q{$^TZH_eviHLtkW*|Ln)YZVO@^>0n&Ob2N^&nfHcM92qQ8K`vL9DdBp=e=XaZtB8ghK!k&dV`L2 z;|3&wdH=yXa%g|@RK`-){=SR;@Sd&D4|wT&P9DPYL1aH87{C0LgoQ=*GtN8+NyR~U z;)q+(Rs+9!_^+tvk43di7vq;mLE>h3WF(~zcjAA#U0$ll7YMFt2khtryT@}H%ECd@ z-D)#Z!&NDMKm5%hJ-+f8xY5G{4Xg{+FOctz_LxP;KviGLb<`cDCkWgbIaefvk?$_7 z4C_{LB2p4*SYkV5X>t0@iNloRUlT2O^SP&66m2)Cf;o3Y3aePgW&oB9hYZCi>KIiw zFWdsoc~Hb0FbYRhFEEw9sQUak)HNTHtoTLtv{=Vod@2vWt(b{Ofz@U69k~s49LNh= 
z=Plrd5}Dm5tq5tMXpa`{El-tv;n?&cT2g(SnAd4vR&F+{rZl`ei`|?vyD;4S-KZSq zar8!M_#HEy8{_Lo_Bf$04aYtJ--O9c`V|i({fCr(DtaR48B2gQKbc+4&3%xN#(+?6 zhGyy!k6A}cG9{Y%znS_aei&)wVc&I}*o%f>)Pa%4+#jT{<1z zC#uu?Rz{VOp0Hagric%{uk`Q(>RXxmD%@IY^G{aG0Gb{iv8UMce>(wBV_Ar??)^slG0|_mMIqo+D%Bt6aZO4i|Ck*s zc9eR4uHpR4g1_WnUfiYhjT+IO2;Ek1yrp%<*~=_H#re&sn=_(t1Q>KSv7uE@n$bbE2jNG`%AX^EdP4r4hyq#1@@KoKO z6ly>hw>f6gu&~0M{es+a?C_{Z!L%c~tjWFA_2^;;rs~thuD?R2_BJ_6?t-g0p}WH9^)K$3E}fL5bG_;m!}kaSW;b z6y@0xapI3xgZ%uX=n}YTY!paZO)N@&e_ECFUYi@O%rmx#H4vVbbPtA_N*|1}(!X)w zx6=rrd!P9BB+Yu-jrYWH=*BUG+Wbg-;1mKV(k74BN{w`8BkKEH`9{n;ReEerOBtrQ zW&l^$ywvhJv8v)8G?hwtTv9D%V-Y)Wq7tWkm$IP9m*>KN>1?UhVH_5hkyal&OZtW^ z6E*K`A|+;yEU66ncAg}-<{=mOHFs!EQ>Ne~B_P-u(hwYr9q+x6CPnr9W#?(Z&LLhq`qx zl1Fo{zl5@v%~KnI;coHpe9XI_cYKAvXB6|&S`IJtG`?8 zx$~E#APrDcLBjicjN2kzulnY?|DT#4uW)vPHdIm%um(DIMqh)Bpkp(1Z*6sYz-=&H z>gt`2Z*f~DWdNF2qQ9nQe=8TEo1;c~4F)M+6UT4aIVo=eQX`T)MnABx2wS$~8lsph z@7_I;d$^dHJyVF-A)oq9)=TOgZW6zq8j7RBp6W`Y1$Zy2_-Vw6^6AB7a=C0*NG|z& zQ}Tu4pq=%*IOXAA(<{iDxb$bWc~M=gL7*KfFc`OZfjM!j?Rl|8KGC$G&I*IYApxD1 zxaCDEm)nkbpgY>WaDBrN3(ML|RzIqZdr}5P%4yzPPV)Zh$K&A-5)tP6^;OJ5_u3>BPg($LvueXffM`+T*IY+kqdse52O1Wpl_xZ50EfGqyY%%4kVigQ`1C0;Q zY&DJ#6#ZKszL9d`p08$=AK-f{(w>n?Sc%&n=ZjeehjSSWu|YySKEV=wOveg6#{D31 ziY~Cb*^9xfQPA{j^kDLW02zF&0eIQ><+-m?sHO(rp9efa24CDB}G;6!HI`{8h1C4C#Pg-6w6IME;%+{u8Dd z-CyNDq;dyb#Do=wIbeR7Y>-?d;Z_7@%R1by9Y`9qZ3dwK*~ZEmQ2{Wa8g(BxB{oXN zg(vy*?42ykXqRsI2e79@XVMnlMwKFdH#SLCK=pE7o$INQ_08&*2elmz{pjGQl@UKh zzbT}UZ0(X5=m&}Pi{q0{GKwPje9>(+9V>iLg_))}i%67oL(jwV$w$qOFHwm#2-KI2yr&0x5HqMu0ljV1i@9f5s= z#r_K3BALf+?SjQgk`5=TFxM$5B@H37TlNM0lbrFY#R9uvoP~zN))QjXw>7VK!^;|6 z*|bZB9*z+Cb~gaJG< zd3V*aICcD;A&(LGRqjz8p5{p?|LVl*$9AqzXI3LZra=;vVr@``fmD3)LIbURQZn@Q zOmsQZ+ehoUr0zz|r_CE{b#qHyMcXhKnMO?cUeY5$kz)`nsHQybWsVp+>v*(DJpK55 zy%{r>HnlORsA#HNVhzf^HFfEE6xHA|N*6JdC^=WI!TmpIONk{<*P-WP*o)`#?>=72 zm6MkUCfo=DudWB)V?vG6bu=Ve34ex-vWED7-R_a5x&1O3pM9|-@U$-1!NoW8{Vgv6 zgJup}k#|vH80r(ZyDmwoT&KRKSmnsf(2U^W(WF86BUsF1+g7!~gRNEMVBJ3T366Mf zX_p^`-&fInU7>HV`$D{Lf$q2SCXe`iXuw^9%G2)OT*i&_GY%{^fLyqbvv1)6F#|f4 zqe)ECS_++7PsjQHvgZ4I8QaRG-7uTWpN=tO&=ex=G&cFt+HCN!w_~RBHH=5{d7I?4 z>9;(WiM0E<9x@;n*^Dz=6_hYT)uOoQt=y8ZBhNL@{F72O&(1!J=sH)-dD@-zS2qxy zs{VJ4tb>v0|1g&EE2~3^pJK+?YL1ZT^7NQD6nQ17FPI7tTu*q&Di7%t(%(gBB2Z_g zru%Vp1CLOqscb%>qB=RxW1s)pb}E5rXb_)!4;gNc7Ag8)?A3y~ih?!2n5pXON>Kq4 zW8A;OMf9*mFjII6ACdh(bnXJ#Ir%(7GN~@3)q!Rz0$lC|3Igab?a+uyvBItJaI<(@ z)pz=#w2xO<2BPRsjgJsN-_!Au8!3ii@cxkv6oERamJkIR-b^9Zvl{(~Y75sY^Qxne zdp1pr(gHd3_1Phbweb6*cgQjCqS@F^2{#3hCc$GBHVqZ!M}J!x@K)RT9b-LK75G&p zv$knx95dNjNzM6Fz)i{B-dWWZ`aI)>C=%Lx3+s&Xz2FabmOG9o)6y~Be1}`ce@tnj zvWw!nCg**=?)0ecgV?deKVF;%>FEeJ;#i$TP7)b(XMF5fl$=&1OwX;92c?@ z&`?o*Yl|#WA7AY~ZwgQVp;jkG!Tw|B-o|BD&jtDHT$iUBmIO|$dO!2I%BG*%*8lsi zoB!{-{_r->(ChYs-q$w_u9M%$&)serilG(GYfK5&X5Rc<{oE^+l66Nq@BsE$; zwgp7eN~cNvjE#HS*ScQR@e=(hiyh|&B^{0%)&gVeanhX<^lkzDD}~@LKA=ubAwA-y zP2A-yeSd^u1z9S>TRnNr2&20K0Lf9UvG{{|5Zt)nmJ9Jp!{wp{tT4B6t^qU7)-bL`jwd3Gv5Q^>t&N2K1P797*|YQUB%uzK)DQ2@R-?i^ zvR})FioioKz}@$d0X{kr7Lx|9WMjz^nITHWNRjwH(=bx)-U$I#b`7Y|vIBU?O-N?0p4-rDK-w^EAnyr$i&q;G}cZ8N_WflPtV!xXLveB&8 zsD$0s(47SfffyEfhO8eHXxPAAp6tGFyE`A}KLPAVl!YC+_3BBnWR>0O&bv&^X$5YDG_MSb? 
ze@^Xz!o7*tn2^MG|BR>6lyH@UXM4~rwCV`uqVHB|j~mWQH7>G0!<8|Qh<(5mEO>}k z{)yx+^YT7s8N!Iy)N#}cLaRKhgMV7#q-Po}4bM0Fmtq95NgV1C`8lsY8V3$LFL?3?`;y3%5V7|H)jFwB zR?!k>p9xOghBROEE6}1AJ(3oJ@Wo$9JboS-#nda}lMb_~wo)z&Sd3JbTT*A711bxH z-_dk%Sf5<+Yh%Nk$mJreXoJRBS^M@Z{uFefBi>Mz@4ka_s0M?VHUqbvw#pmPuS9&wk81aiIawSUym^(C-Tb}v^s5(cWT0@AoQNj)j#)Og72i2he0{T!Y=3h@TAjXSASK$YFl?i*4_tn8Hx&tgUbC z-oid@(Mp3sbgvfBh7)vVI zUu>$z$vlPkc#C#x>=`I%F_vO_gb&sXs6m~`UmTMy)Mg`-75_C3$OpQrYFPJ>7SF<{ z5M-r|cKGrtD63R(XCsd>1#jGsuOtK& zt4_si?AeRjF(F1T!^|Qlr5^se95;@n=Qwn+&c3*2ik^G!UBZBOll_uojb-GrF9Ij2 z*#!SF6<}DQ;9}Ot*+^1XkG#QS$v8A?acsBKaRbjmIO> zzkU0CEkw9eBL5`6X!Nr1^T2$$dRt_fOxJr^Yo!6}qiWS+*Hrm;V%C4od&Nwixi>AR zP_7<)#l?NpAoAn&pC2s&;O+NQnkA)F!iE{oOyR?pZdC7neE+*`@j&i^HDUEu^y|BP z4)!{rn8b8i+)DcH8{)ibaVId(7ZU@?EyCKYA zCMo3~7JnWzKd?UNG8b}>`i-2b|IpSaBY|%{@~YS$87$+KS7M}QnNMf4_rfga$Q~zI zx4ud=lhzY`lN)Vdt2ToBGW0-nuddywAGobR_&FpK`#o*WP}~oyy4DpEhUB(EZFS0; zeBr6|*#`2jGX!XY7uybgaGX~arIq(pdGctH=wz(;sNz+aCnmQ*e*#CKR4R>BCFlc| zT%R24c`hq2RXhhm>s96Rh)*KFhhy*+NY+@Sjb~kkAIne@5YK9NRMJ4CiD^HHyxhO| z_(m^3;$zo<^_Z}rj6H0L?z^hb9|o-A+EQ1vT?cX2o@Y_+$fd%ZpDgVihGrb*wbVG^ zb6257V0e(a`XZXF4V8j9DaN6&Pm=Ar^1}K zqz30>nD3jk;s0$+WK;`on|8E?!_HqXyUnRVKW(BGBgc8&;qwEDA5GD#OAw*{+Xmi1pZv;$@MQY{WO-QG?2e^8Kd8!j}EGiN5+j{?4o z|Ipkht&Zn-pl>g1z>utdu2?)f7T?A3TGVfOX%klSsn{)V$!ds>YwPRbs{)zuO7qld zpSF)hRPU%}ty#W)NpqRy(qK?1=wB}SP?WnoW2)Y$uaIq>3XYDg7v38x;!W5-3L4%Y zC7v!e0Wf`QP0SJ^HX9_2L&*#!z?@>FXX}hzgkyyB)$qx|PZi$gm;SDIiJY7afH|wG z%L^L}U(s(|l;(3~?IlBhtcOjVW7}B_};C=Ovo>F2w42Y=+IY~-*{;2oErq7e$0DmwW2Z-{VyYryqP zb0IR4qWl?Gu3eg{4=zOjHd~_IF2_$!uZqO*Z4!fXaUGH(no!8Z8S$mMPUf99xjB53 zmQCX4fI7aPgVlmx`dt6*yfb?KiG;=7)Eo9}C-_l*1*vKt_2!pYN2HEvHFQ4bWkNa7^&-3~He;?+M>KJof=k-3{ z=X-g*-Y*8We&Ie}TSD_vLo$+7|CYn=j(}Jz4wLf%f24y!C{5O&8u=!krw6J#b|kP_ zK|`bPxf{Wdl)2CXD04hJ+9bSYJjhia;@l8-;YHsC-w|vfy}|tA!DPxC-EMDL-l1ZK zuO-=JEK5PYk4OiRafGf`l)lTTz*#TWyvd(WU&=h2vqv%ny%yLMIRscuoq%_5_n*E} zzB;6`E#s&tuA2~kwb)#~QerpVtV-s5gyi;aZm*e>Z)V+uwLGQim8$J+ ztKggB`OX-~Yu>rpIOy8&OV2sz&mi@!y}s?i$(LL_TXWCE=L{DXI7yX{^l|>7Ke6L> zVq(uIQ5RiH#QsD$Hgg}hEiAvoYNi=Gie(>KR=-o|J{rse|0EW)nn{nYrt zEAoTZVawm4s0|l=GUZA`jyi`~)-8_@V_rV?ir?M8gtWe{{3E*k+A-}d=_ozl;P|ai zetADi$TGsoL=?^_p+yvGl-m(Z*Ghe4T?6mW<)A?^Hx;Vy zp0;o%-QS67Id#>_w;&N1VflJ%&(9YZ$nEwt8rIM~+~<4te$DrQ8HIv^a1Vhb7<_4!v;&<2BBm zC;MAVok`vbM%Nko?OOzBCC>b<7b6T>T;l<-uH)mNs3<5P9YM|8+Cxy5;-h0<&zWNh z24XPyry2>1Wa&F*fgS~vt6>1r}O+7fhi?d`uY9C9V zr=`B>akDE#B-_;D+u|dmhVN79olP=-@3JK#__*%qrtuZN^|%d8w0EBR-5LfT;}YC@ zRwspkQqXWYyiK-W#BIKQIA_>0icdU&)iy?kDVJ}{HE+q6Rw4q$ha6DwM-8~B3_NSB zi9^@EaF}+6>_M0?BITJ7Eoi-nUOJw`;);eqS z2Zgt(HJQ@}&`L8>5(<^k^A*zz%%p8k)BET0_4O=Y1CRQSD~;a1R-O>Kx7`EP;TWwI$xQzYs{Es_w1!#&d zP5e^AyEQ5v*|qKqi>=<-lYRFCzeDSRcuXJ_E#zqIeI@Ztb@>Q|VUH)=*8H8rhs*>S z1&WmXHRBM80WONwOlR}bMa9lcFD2Qd&a;LE&ScNmzt#{C|7r&?6Pi3FPb&y!@|VFf z@}6N1>K(t(x8|kV(nYD7BUHP9Td1WT+&~|=Wf;HCLKy&0i@doRY&K= zuHQv)fo1*nH_8HA+rxidTx-4OJojCv?wk~Jp%#4BwT2XYyf?nd@ZAf??mFeV7hCwi zOpczW{vxd7;=-xFgE5?eo0-cl_1Zs!INotG`Si-b!ca-O&y*frRY>jx&t*qTZzV_0Lo=}?zdL~ma%B-xNx@zm%UJ<%+W^t3mg=sB%emH)&ObWZEik9vGC` z2DP5$Xi!V=X3IC=UzrvSDjdBa#KeD?`q_vG3Uzc*L|1PLMW5Ssr}^aF%S-5)(-jxl zIE!=^Iz-dL_S2-q9l2V(P0@)VGHEr%Vgc|8>&5+mDJdC!=($VeoE_E+Z*sW0A8mN= zcbAMK=W+~dt3m+#A+Zrgi64J;9OKum``eNEw=~MmJKxcv?{gUNOFC%;Qb-}kX*Gq< zp1(_a?`ERKcbDF1A1675g4-C=G^)8WSQu!!|0w^|0MgXz9J>T|G~lM)Ev1wh=t~P$ z-5SHPyKnN{e|sBkF;13yJB0?kY9dJs5ZU4~E(N~8DZ%*!)fxsIMPM2MzSrCzvil^f z=GNQ%{{PtS5)tY&F0wGovWToL5E>M3#SL~e6?^L|(rHF3v-1Y$gAMcA@Pb-tiFLM& ztC`PqU*47X&^$IQ7g|i|NByjmfMs5f3Vadl^||2R(vB)_&e1G~%aQjV)C~M(x|btT zEU#szjcL8plr3S$$SxxzFq-_^A3t|h^*&jBZDcEgLX{ib&2W6*kpD9!yD|v4N^m~# 
zhZ}W98fwUL*9`6;Z(yY1mWo*(IX;-q(qNk$WlfbNj(x9_90l*YPPtF-KRSKSET><_4&DRq2btG7D;ZcYRWJnm?b4O-Uy z>R?G{^PiGbTnEyAS@UOg2!)q~^H7M{9;E}yR z6qyPXBs~h7h6_+$=%w?YRlk92&*ToD%93)tia}rgdW$Cw7gRnEm9cAoPB#9e#mi0z;*WalPRs^XyEwi}?+iGekA2tyeoGW@IAz^R#Y>$qPCVs) zBxSS_;=~V%*#3Zd*5}wLk6)F)ESYlXYQV_k7~#Y}l}C{U#JD-MkwvW*N*7raIiTSh z>nJK(16H>|16Su{|GJunGKe0(t-t=%&#!x0e*7<^SL@EoJ0}st@c;xro0ZSsK``yV zLf+SOq|ZNw6Jv>4+)VO-meXmT?DqD{&Dq{0JIBEnpPnv&^u}f%FzVSmY!p{BTtu8S zrN_R9Tc3}bhn{1aVnk`>J*WQAoW0*OVSl`U5NJmsm}FR;6Y_g}$W_;w#TA)3?i!t= z={MZ9yVt4O45nhMIF4MC6J>c&Z!b^i_b_PjT7o#3{7z z@Ym@P0lRnx4gggo3fden9mLOfGb-C^J|C~ij`PM5V}xKtjOR}d##+*#PQ(I3%bqqa za@>Dp^jPBd(ZHW0V{8q$qIA)^Z`^T|j2e2;nlrin)0UtT*66jrke&{SzRj+Ur9An{ z6X*mKnZ;LF+L@_@f$Mtq>az5xO z%K{NaEAoPx3fC@ior1*mrXT0T1)(kFlVo3ej8AW|hNzC6>CQgEl3XPFXc4702`4Z~ zXMmjNtj-;7<~uR1Ly1O8+;0beb^3R@F;{N}e?>R-DqA{tQ?`EF1ji#dN4W_x z>~PCaiXKO#9W<<_EnNM&3YHgji;J3FK^E4N4NMLVfDcjTGSrktD*ITkDbG&dxIXK@jtUV9!ahspNPc}-KZ}4ST{K%invEJ#)w3;s&o!@dG)t8ZH7lq;#%_(L zqa`NauQ}%)%NMW$3A_c6$S34?g>GLjItwX14fTf(C+E@aA{F_hQi6Cz75PF^p!vJ) zh36VF_se?u9QxlF+Av@43dm?h8;37_|L4iM_n%MB3k9=!dxJ|VyAAfO-~1-s=XBLPHjU-S?rHygdNQ--BdB#~DmFf4V;mvI+q67TMa*qo~kyAxZW^k}7VMF2u0n zi1n@zkKBRaMd<|*1wGpD;lUYROe`l3rxCCg=j1C9lv4Tpn`>HUSaXIbB?3h4;~)!egf#+Vl~6B9PV_PZr0aYgoR|u>JxN zGQHKgxvGNay7UNG;T1j31)Dlo^>DYM{EhgB!4AE+0C~18QxeTJygIQFUt-9bL-93( z_nE;S(@HBZio17#B&f%FPt(&@e=uh`orqGN6U&dja7@XD6j~5j6{2c@R1j(U1z82e zi>jJ%+TyV)SFs|Be7hQ2j6S+>e45n*Cd$z%@v+v-1u(HPUd1q!2k-sk;X-2`I0J45 z6Jn^1ZD=kub#E0*U$6}nQ_hZ94gZg?R*BIwKD1G|%6{do-mAIyQ=4DEz~-2h6TM1$ z!{x(p`KGW!e(40hK2Tiz$B*PG2t3@;p7Q4D`epR~PQ?aAiBb}Ma# zz{|K}@~Tnb7hq-vA_&o zfqaV5!OlgiY!_ag_qUL`r*#5hly<<#+aabLfmR4g-b}_OpBXd>vK2lWxmCsC1 z(N(0dQ*MojXM@;FUxM+J##`Om>o=&$E>ZZJc6r?H8no(q!E#WPW1;b&$lxAn!qs-2 zo3X?fyil^WD%&px*-(zXSU#Rjjo$>@T_E6kH#~UHlYL*hr%PcQlk%kR(DLT~e)Eyr zdrA>I@awfRx$My5#*{p%tmZ}643ro{j7<+NY2#Kg;tKC0PY^l9eQ~`S3@OESx@B5< zm!#}=fQc1rYedC)$;wvnS4WHtt+_(ZH5$MJWIm}Tu+yfCHj=S1i^|Hsp#y6m$iP)S3 z>6`c7AzZF-_U#>sjPonrqCc!j{_&MoNqMLD(bt;18^=KAHy$Aj0>LiJ<(HGV;GP6I z7#pJ<6s+l$y&(D`pz`bKRl8q0?WrAiy-QiaP-RCn-0`mM;)3Gq?NtJN%x^{mEHlEA zqU0u+Ydur39px^c`J*@5G&Mm1&f0NPaZ`}0|MMo2%>jdf=CY#LiHH-#4=%kj(|T;A z%j9~|w;&Ke+U#sXcD~rxYIo#;Fef({A1WfA9-4WmAZ0!lp*9LQ|24Xy_bdl~v$pxFW<`{w zyzf9$0s>2V|148=s$1zg7BPS|A zg8Gtn%Sy1zE>MhZN~Q`F3Gm#AO<+zKv8HlX)84mwA)2>)d2=E_R9uwGn&f7$Vz}Y2 zDK;Vi^sQ8;XBRHr^k`_bwmCJ3v=54E4hj7-ag|y;cjlg1H|II3Uge-r9ey*qsgYaF z9dijFP=X8RUe>z#lqO|3=CXw}{9Hg=kXZAJCEJ!0VI&mSahq2=n`VyJ{ZkQ#WJ}69 zrZQ)Be8RcNJtvfABWc?t`It(2d8@(ZZXA}Mw%_}#5cNV>OF)g%iG*ueI_qP3S>A0| zOreC|50MNu5A3zZCPX@z6ap>}QB2ACa&RPuQyVT+MDgP}G`;Hj`C0ZrPV`fz0PR2u z7CO}Ry_Jg%yNzQX#EL38Y30V@$r8QX=KE$A?UXl+1G>%Um|2b4P`*{qPR!aM;**ce);%c^KO7Pi_~;B zW-OxQ28o&20*&qhbV$yoJofJ12>Ww*K6*aFi`p5E9f}}fXAD7LL|6+hi#W=E;M_v` zLXuOO>|SwoJ4!SA{Z}5xdZW@4|OBI@v4I(<`BIj$@^Bw~&>4k3TuCN%?^j0Gt7` zD+CnkZ5_0!;o{EwG;4t1@nRY%@A$!}WC(=BnIjl^n?bgztJvHu50FPivwB2HGS3%m zjGNtR*3q`I-4hExyl^{;;jo^55shZ(eEqc}LTsVI8wQT~_z*^}3<=yYb((3JSsmAs z1YBZ@+4L;#4o?xeD@Zpz5!RocMs0cI1gx0n-~KRnvHHdDap!ze?ST)^*LoeR(@|k9 z3n{Cuo1+n>%^eoik%6jnF+3W?Qd8&bW{Y~bT96JHLHJiNkM(21Ru^bWUhEg#FA*!8 z-oO1jR=A3^>9|TxqkCNMrd`u6*P(TwqQz8dXR!5pw$YxmAa@Y&Rb>Xjw6a zpnPCtDQKCEXh&O;56RvhSGl^+bEHeaJf_8$2YI@P6na7taudLRSAe|j^N@39K*@Sn z(#)W!T>g@Czwe#>mqfVx;AoII?@_Ux#~u}K`#3SFEgC9lpT5$F&s<(SzXcQsZjH)| zoF}>sHxcx@DRh3pay!b+H2?O{*j;U%c>J2P>U)fysZ6LlYN*iVQC!h%l;Scz@A($p z^wK@Gj{Sz2zl!JjgcyEolHy|bM-mQeI-3GP<{SIKU(?hmHpH^-$0EK69y}I-Pn7|Q zGwvtuvl(%eTq82pCIL_$K!<3keAB_KJs-nBA$%RgTpLl9*O%-tH$z`4{)mOqY7dcX zq;Z=|i({*sgP*zomj0{4Q^+3LSlIia>C|?8)KIO9d>ztJst6IR*-?_9Ym5D8 
zcsmUju$Jyww^ld#KwHyKD?58NWJzg8DBZ@-5=CSf(C?jRPRI5dYGi8bEVvlF_c<-+ zxY88-%#^w`S%~4+o_nUE5>iRs$l`8wK!CC_#KnMIMijWj*dhzr=|}T}{WZP%YGJ8-e_n~}{wB%PsHG7+cC&<-x=V|66H3B+36@R$!4wvr& zk;W=Pn5&z&a3A8XEHb@KRFNFYId5G$EAdr|YZs8soKvC5p7&cG?r7X;|AR_5%UiU^8?VunkxHyAH^ z{AkZ)vMei}o zus_F|47?w>EmEmMaGJkNNsW5sW`e!u7X}ye%AFY<)!mW{0wwI% zbF^xd09qIRLm}rHIdR*niJj9A;0XWvl7RWD#gCHq=#vH@-rrfXM0A<7%Pz89oJcAWhFiIdV;_Mvxv`d4o#nANGEk zO-=WZD1)hpR&;8uX=Z{f8_@{LN|k>SGbpl|&d` zI3@l9)2KCfCnd-v7(+>!FqMKr7c~t^^ExH&!?abr7kgcJ%5qYq4c&PDe2bof2lR-( zyB@1?X_r*iJL?40zx4SDdL+M_G_R_`*D6 zU?rT0@plOd{{)iEc6)nru}C3}-FgEe zEQ^r~1_AT+k54+Zf*9GT1n6Oy9mMatY|Mvk}&||i;SKkTH?_(PiSPX`nEH;u!^_2xk7-}9nvh0k5%am zah1=%2TIzNp@)+*2v#kzX@J#A<=*)$Iqb>K(J3@>9b*7zMx&@yU08}$FSXsMMO|(( z(GKGY@OLgg?(;$I3QbhhV6HlKK_^e5;uTh)1E#;Qn>-ty30KOCJtDU!roPc*-V2M-$ z1OQ~{rH;b?0h(PIO1C7~;Wr3Q8mCD6zf;2HOn>kGVB z6jE;L83ENvu*X+d?fwMre-cokMx(s(w)km)nUsV=11@gP`TLklaN`wwz)LY2_n!%I zh`9`BabQoL8hMa#4j(kYzbM zy3LGcIDC(I;(RPn1O`b38gyG2nEV}F3ZZ(Gr?KXCa_%eW&Xl>&VA;^f%vrwWCC!IW z!qpAhFKm-uoI=hyo;s#7sV-YdOj6qgPbe>K1TQH)E2@p0Y69C#b8$$el!`hNvzdfZ zLd-kXU#;2#Tbs-0BbdTjk1~lvm{Z8i^r{fbwz>9PHJer|0T~F5i zDyCgu(5ca8$>n;s$ClIiNpb8q9CU}3&JODgDYZNjIx+hQND%a}$SL-ue|7zazTwqP ze&DTFiJ$HjMN~7zFc)|_&EL0<_p~baex(RfjfyZpSpD^d$?kbqjOhnZzOVjz`~E8< z%W~AXrS8TrrSFRIpd{9czD|NPeU{B3%KZ$62C0YL5L82;E>!FIN-ABQ(0COQOcJWg2el*Pw4?J53lC8DZ?EZ&#<4a9@}V$qCG zjpd4qxfO1RJ<73R2zn>hxMuUa-I65BOSt;*ZFPt!q60}S<8y_3O>*>eztV=aB(FH0 zW?u(Oz)gU8#Qf8x2o05>zkm>&(KW_;2PCVtjr)t)ntu+#U@%(hS20GZ_+KE1_JZAi zs-hp6&XcGtg>~xIs7l>kD9NXQ{M_sdBAf^AA?{Ii(D&6#3cq|25Bo_vaoJCh9e-$)Dp{k zD+Huyv0O1O&fX7u0S3wu_X6rL73U}6#GxuT&sA7XJ^S*l#or#ATbGY=8Y$y~+0@n_ zttXtxi$IJ$j^Nlvg$>#eixGhf1eZ8nP9$UGd6P>nNURtz^mTse*1xqIVWv{ICz-+_WtXg*Fl?z66tFb%&6prU@CV(WDOZU-V(Ka38CdsUpP8rR$8dU zO)K}^XT(E_iz?x=BmdaleD3ndd5#20T#9_lX2f@TD>qZ22?OhA$GG^|*EBzuo8;DM z1A2E{!@Q06KgJZ`lOyQaFxrx)<5*jVpA%x8%}pmyE&*Lye?M9kqO- z4&R%BwsyJ8Yo~dMnS}-)g~VWm_b_YXDk5Jtj8U!bZLxk-$vBzLQVxPOR@@mC6*m&n zE7=*Dqh(8D8NSWviK;p8&uaI4;n5Gq$4t&SGAR_x);gg^r$wWel$=4hUS}f(naznc z=MIyPakkO49F&4!N`;;LFB^a*O+}Q%;euHOu1F>B7$&+Mfmu29Iu}?AFZ1fk8`6#w z@SI&%pt?e183SW`BtHGSW)^y?*L^u1Exh+ZC8%+(CzLNsPpIqv1Y#F3N6n4-sHD+f@g2?`R)(JGz4C4-u^HE>_Fph z#Z-RN_yj@`-@htGbt|oZj%WU8cQVH}Pd<1rsu4Vwffm#xjMwERheoSjjK3O(_NjBc zweMza)*%=tw(9t@(uQ3u(YBfh^8%>PdBfF9xfqs7=M8F%gXd_eQwc0jJ4fK4j!WT4L%`{4rIfD$%e{ z7Y@r*kA+yXnC{`)IZiwxA1~ROH+xa)+=kJ?x*~ep_s>1q1aOFbiD56cIGB(C>dCd@ zc8_B}yUu~kRkep>XRtTBudEQ_SAizC9!A)w4r=_zJ%`bL@{B7ZT9^wx{AmFdsSD@1 zw)Qrk*pHu_?6&A$KNDBmi;WV|^5?xu-Y?}w`zrATk;v0EjYz`&&yRXcSgMLY*y!4v zi2a$ZAi`*}8ba!~xYgps3korYZ66a86_Dw;Xq^9RaceR;#pp>rACQ$@_PVn1 z#X!{pG+X@G)TV*!jR3vpo9yJMi&io9!8Xm%;#Cp#++_RaoF8)6`k~^0i`MY+R~T*V zT=eYY4h~1%c07oHHX>D3K=8u7CfKQ`8Ti$;ps_Oj71}SLUODgsC1)|coOnNqZ=7r1 zW7^@xWds}`D3H~j($|=^aJ}<^@+1)6<5vkISEj;7DyU>?=uFBiI1k->Zel^`Egi z%=Fm)M#m^eRr1^wQsmOCn4QvK#~37VQJM3jvA4vA1oWS@lOH1UgC)@8PT9WOJ%tCy z{bK`ZAhsYt!#W$lSQ|OkHKaC1FlWod%|4?b0xk7!Bw|4_F~ATPpq0WMd(qUA_^aqJ zx6;!ljZqaizZ6wK!=IDG=*rJb7ozF-w{=-9J+x*6a|V(!Lxf+|n)$Y#A<^1qg#V<7 z4L+!Q_u5;h8L z7)xYo%VrySv8jPpnz&Gtpq;r%)$=a=LX-8b*Z3vx^yr>;-#bIT1U&2K7_2W4lPLD7 z{C<;o1Bxb<(4eI+$-wlKryIB-R`dm)PEeDd@I2^AkX77}kKX1u7>vHu)-#XMh`~a; zD3`LXbj@d6q|xmkEfY1Pp*O~i{>^v!6bU+F{hPOPFCs`2xU#)6y?%nQ#apYetzJE) zyq!|1@-`6VQ*xDTJZbr}`=fA$DLdQ>V0~~2jT-7 zDQjO6VTw}x2Of1u&k<3@74-teR4+_*X)w2_GC&37Bx{Zn!;X0Rdjv(bXj{jbGWUT$X;cK0&f34NO1W7e7U4!!M9+X{dbu`Z zaI@i>DhEe-iEJLNua$}D&TI8SlZb2`sLdZ-0K|hiU=b+@Ldwr*Lu2Znp9YuACmi>hD;!6T6je{ z7{g2W5t1x**`}V}1IP!bx1ctr!$Vb)IHUE1YF?D$27fdN>vC%np3Jfqkmw>4q8-R+DTSb1LrL>B*Gs*`b}6r4l{Nk6bse{xpsL&i 
zNUwO9KPKlNVPg|@Bq|nr{nmrlz4KX6Rvx$Q@7^keIw@S~o~0iL4Gx3k<@cDObT+Vz zl5$LFF_S{#%6|w*rUUK|7Kdvb6hN*Ug~*^SCXlApY+!Qw!-I5 z-(ZQhNafE$0wXfM2CEq#%sUw@EJ7;fR?ZAGgHj0yrW_aENGwy>&w2SN{1r)wu=s`} z^MYs9Nz$IbBhi z7&RE{M$pV*1q1QJq$Ewbmf4sek|e;E5k#_;?u`O(zEg_lBhAj)nHnJ%Li5& zM01bntV@=;UDMQD$q`%nm|nw_Rk+$(Jaz_d>~)W_tmn0>LL@KoFo3lwsSv2^l(5b> zK?s~*D3YHp+sebIMXD>}BTgqo=@?aRdTXB{ZTr`sV}gNFH^Qw~kc4R9*YD!=a;8kw zYHQJ|nz`R)#~F1JHaep|jAZHkss?`UHy3`MWs~sLfinm491eFeePtE6N8bLrkc4I- z@kExGZ1!&z!Vf&k}VY6o!j) zvC`a{MyQsNqhLLh5X_Qdq64(fVO6%8^*WCmjYJxaGM+V5zVIbrT53=tED(D&pjv?95ES)0Tu)6A*OS=hKtX6*_JIAUo8$Sc zSwmRLEd@k8%i4Bd=$RyA6~D)Fqv^Fw|MdBMq;_QnT7>s(^ukgm`58%q1%BQuwB{cVRBe|TEkr#yo(Mun_= z5(X>gWXk&(4;rjhf`kSbov15fD<0%b&nn?&(blG&vxCR-OKuQIhluofGeSg@ziYM% zGy7bSTW!dqRy!d;J0V0{vn&;W3!k43Gw`}DwEF6uEk>P*yjxfGURBP; z2=EeEHBjsP8@U>-EUb*N!b2(gjl2BC0`wT&RDK^av>Y&k44BVvqP;Hp2 z_!fh^ouIIof#0bP`+zG>thdA8H>yQtnh|#wSYeyY=yEx$JcsoKE-DQ9Z1<-m_bE34 z$nqV+?-!6ybhmds#eN!ZOK!iLPgHFd)VL1UXP)>7Ay71QHS2~tzL+#3fW zz{q<2f?AQs?qG2#_oJq?z7AEtzyxa#-Q=dg8=`J>-v{(cbSHZdopwvCG6bhfqcxJ} zni#-1oT42UHruHZf7j)leWp?`%v>cnv2;*Sh2biNa7E|Y()?t{4#7+yYf1$y(#1Jh z_qVo}O+wkxPekqjC`jFECgmhj_SIM0Tst@VIAU|cU>Nof7zoIK%PC*bzPgKpDWJ8^ z5;nJSR!SD8>>osk55~x}no-!iJ+6omRC{5HV45?CTYS$FqD%3)--4yj!S-FkJV=RN zSAIPM=W0I}0;-V-;+Pnt&tvtl)S)Kl;WzP>^F7i`5_7_mOlle`+)7LgtPNcj_QCOp zHLScBouvUgrmOU*S9~yQ4Bwc~096DED~N)JVp-=5Zor2GRQJ0WrAXBgz1$OYL!^7> z=|7ls#0>tCPsmXM(s(aBudI7=kt^6y27Li2&Ysx}WJoVYrHBqY*5F2!vmj#LC2!rO z@5B|LbiL*K{y~Em2t8fCo@rh;J{H@NP6NC`J1u^-Fik{>h(k8+IW7P|$KkxF{km^= zpBdR=Go|+zF4xPhu1$h)RH;z_Z9@5l%R(V-Ics+X0F$As{rL6F#2b*Ub17hN9Jugc zt?I0Xh@Nt1HL`?6Ui%*JP(Du`74;%yiq>oOMs?r~tqo1tMFkuR({s4dqKfZ3x&?qt zE%97}Pe+sRNlh-!9H6W~eNJ(5__^+HDK~1VcL6PHJVh5Z+b=u1Cx7`1RGKe77ICqd z@*6{OkR2VB=xE#1sPj(tZ|;(krx*7Oft0b;_s)1Vkqnmfz5WuL{md!QQJS2>ocZba z9e9^4g-Kgi7;xXO7)O-iUypP4QZ@$XFAe+Jj{|@vPdaKIN`VV0L=T<_FKzN%(LEJriQ2Xj|;zf+_lg71&?3r8|-9E^=?TI zWk1cr-&5l5e_a6_`ImH;FTseoNO#IyV_FL$MB2NI?KTldSECu@szGVTs#ERbKIWs0 z57uubV@=@Uwt===IRxUoI`pcI^zQM{PFVeFOZC`$U$MzI4O-&G*Lvcr%JhUJw~Z;w z(-0%2Zd!;ET7$f2RT+3xlXrh(hpg+xT+WHF7(Ld)i`O_Oz@YW5ZtteG?}005|8enn zMWh$ye#*Y28!sh4&cRTAWj4GmvcDrvCA!z4>r>zZ4lt~LC_N)VQof)GDmzv+AFL$g zrQ?(Czc$K~FFFGKKC)?*}jy zxrw#$W*vkROZH6O{?fzU;epiE)>tbN7vn&Xo<-9q!Cj=69mn}INB-gpf$t)Gg!iVj zNgXfm#d6bV?d`e0`(B4}VX;qQ(gB@nla&19HD~JKxX?_6%SdcP#R10eSH z64s_%isQMj*C-C$#S}pGKy5Z(NmiC6nxahPR!+88O>*I1T}?*P3p9BqKNACb9W?wk zWhJM0HD6u<&|+b)M&bL_AN-u<^%dTvv;Z_SR7V#-2&}_+T$o)z*&AYv5Z~_X=_hb8 zWOgxF#;m3M;@a-AP8a3vPn9gT_8WJEv@T=p~?p;?Kzx`OYwVjq5!x%!t>8QN=~P6cl}P+N5b?RDXgZ7N)R+ak3V3 z>CYqi`bDwJy8-D@b0gqL)QKbTildG{V8Jx zqZh{+Cv6nXmg|q{l35&P!l?9X`Z-Nw|KhPCVhJdPF@-7m7zwPE-(Dfn%Cm_gajTvb zA@f*`{%CIwotynV+c^WD8gL0XRJk#_cRmy0Xx{H>%g$pAcVZuKiW}=D?H~WxB8U$m(0_-FCtY7_XBbr#8u?m|~y7_&KyuPiH=T@T6SU?~$xS z)rjm?y4O{=4pPPr(P@lqZCS1L{G; zyKR`e{}n9$7kTJ^{~p242Pn1c0*MmXS z7wFEOea@hsLQ^pe_zbaxcS#7(n-LAC(z}}d{2;P4IdHhoF3=g&g(bx%ZS~$x&Xlg{ zd7po*RKGs1IwN|9Os`;o%muz|L%<9!YXpj8Kl5^ee^w);L583GUw;fSQsXh$ARL5Z zcO9TjAj1Fm-baJ;CS+@)kYG6=3X81h;coVz~8kzJ=|8HC; zbR!2j;M1nh6+oFnJQ!qUY53pnPeP0AXy(tc9Vj;+ueKBgO-@eV{Nl?tkiDz?Lt3{v zk!m4--E05ol}|2B$L7t&D;agfEVK&>+03D@6oOy*e-H6VUy#-{-`%e~44`BEk?)ip zkUu7pm%o)Y@ZsuzzRV07Qn2PB^q2pa`#@ZXu;yGkNbqkeJ^&3}re|9sgJo}`IMVPN zC(z6Q&-@+1_K!h%3m?5FT(ZvNzni?L@lPmYVi|~0?My9=a@r66S}DgjEkIPI;8eBW z6fm$?)BF6jo~YyZ(c5+D1KCG4$H6)*P#sLs^H)>3x{7mU`D*Z2H`U2LbghAsCE{CF zet;2^(Ry}cS6!B;iI~HD)^1-YUSTJo-+x5&Kd)bFiWX!HHRWJ|-yf;JboBh69*u++ zDfY1f`6dK%Y$t8KqAK|2QGXC6xQF+P8lm3yd0ryhTQGC=8r!cNSI1bxY@f=-D=6R3 zqIo2DA+b;z04Cq^Jr?=cFq2`|~6me)1(X_J^o9xN( 
zhAA)Qd5!D-*3F)Bo=hnWONh%4S}{#X{+6Vx(z9flwmUkDkz87Ogj4;mw@CnQbCFIW zMJ|LC+$Q(Klc5=O?EmBkGJ4N?oyip@A;(RXy$RR$p&zu|960xAdMi$P)mfn`r7w;- zIp7iP&%8(8u)a2B;~I69B_`bwlh&HWB@v|(8}GHgUrRMdvywHtn@pvy*n+9&M-Nhd z4l^iXrTH1EYI`yramfJyf;N3|Y)yJ!ZV927ld{zPUIy)v?^$Ng*O_%}nkbc0VpxX? z@myeXHv>lM?xZm7X_m&b@xIzJ?_@c9m;Jxq9xvfhA9jf44A}S$*Se`_M6+u%{s-6g zkn;>ZW<2Q}ErBXKs$mW^M9GKq0> zrvDL*kH{==)%$S*=t>VS15DqIFhH1fV1>}JN;-rY_JVrm>RB)Zp3$!QTb9ET`!$Ic z9FeG@>hboMQFCLq-EDmYhs-vB68k@*pf~70GxmUo)ZO`gQ}M%Y&%4%zREse|KG7$F z?ID05^4-22;TtFNDdD7i#+|#+eR?Y)PT3Wy?>-{?(-h%c+nzda9Cl#t5*@_oX_8P} zp3O=!!)By1V!RU1C?yu=Pab^2^Pk+bq5WC>!%`u)TMxTEfM5Nz#xeT;aF3*SzA;`2 z^nu{6GW1iHwSNF_-Re6yb=Zgqo~<7B4Wxd4RK+?)#i7!48%(aLfaqKHm-fK)bc^o* zqhrsBId{kW+PbpP>Uon3oQ$|1koZ9J!Gd_0CFp${6uyV(CSM?(4v8e0rE@$m=F$^@ zUcS)>+Qs(DvzEV~QAeqaz}l7K48@!G?(Tpp^9^IwW1oyBiOfw(T5^(k=U(Be`nx+3 zRKW!Y0MTah{n32<7ibz%kQ&oO-v-g{ex`Fxuav>tw)4W*0rktHc4u!skwm)VvpQ3U zM*e>@QyJ(~%r^yAI=)ZhYhXg3{ND~>7&W2D+6HL61LT7HMe;?FDF^sCU<7}kL*J=+n@ zWWUuXSM3Mpaoj^NIlX{}ce>Y&JJ~1v2w|qa)As8o<9rxJAPt9v4rK?BciPRkl?H0& z_sVkp&YROMCYRb)fk{X!-uZ{t8AkkH-O+FriR?;%%CT48$LOliI0?ao%2m6~C`IqK zw#j&dzMh4JZGhv>6OJQJ6wP0Y z?P)28bW$WTrT~?~{$AHD$*f8+JLyY$^LB~E)}0Ck0KZV2A?hBY3^7SNjC=YG>X2AK z6ppvQC+_Ro$uZ1R%?QElnGC)rlKdCz*Sm9`yLCpF{R)lX=DmTZ(t?l1UyPX(-&3P_ zaprOX@n|@+3DH0PGR@5_sFT(R?|IR~9b{7%@$|lBOtlle%#)8G}}mJWIt&^IXycQ~FfUwlzOLuGdg*Wbe9nQ{8kbSaP3OuabMC z*aXBZ%!zt;G{{A-bAY?`l?)#Jj5BATHMgnzTHZr{!?#P!i~~$Tq0KeF>st5W$7AxZ zpFPadD|sbJYOQ)N`F5qz{7e_ET!98Qc;;`^dkroD;M$D=`T4jZe#u4FR70g5FkQhS zy;s~cP4UyU$sFY+Y@VlwT`xa1(Z#DB>48wIF`9~l-M#JcMwbqY`*nt5ga5C+GY^Nl z@B4mPlR;8r#=d0BUmWXnz_OQe}>SyQPPnv+r)%a~zot+o@{ zlbsOa`ObM=*L`1SclZ4~f8NLOoc)*MaMbbp{l1^?XL-M0pA5OD*r_j*a?jt6(lpQk zXYdtGQ$PrGaH@U=xDg%#`4xkQcxb}2d?VX-z_!dP^V`%M@n4VlU(eg>BH#S)Xj;<9ewvw(>%ZfPQnXr`ksTXx+~w|)9Su|eeH9+Y&wVR z!LMr-e8;{=Dyd^qKSvSZxD+CAgUh%xR9{Xzu+*XzuXc~MMy+%771$ltaLh2>++kYd zxzTupgg7SJroPwL$%XO?8!nMdo>p!?ZMDB5jDMu?iJ0=3_V?@&`B@;}(;Xgp z&4{5WH6ImE_B>bHO8JuYqD{T(ZkT@usf>)zd?3~l)Decd#_@Zo&B<6IIwr>L{m^6zLqRk_7iQuA! 
zDE@%5j0dl<35goNOJMbzsvXmO;;dBk6=Yk1fBHz!abv!Ke=aDu!Z8`8VUA-;8#}z7$cg{!T^yNj zB(+TuKe$j^Jghi>W#@B@o&43P?cWpf#`6DqAph~TV7xEfCvr`(kU8o;8BGf|D~V)l zyD%1k3fwx^aM$S52%MaoO)FajF5L+z8^2-^308326(A<{og!CDc(l|9XB6V6Zg_e4 zg++B#F?vD=rlUiq(ww6Qz1OVtJtEP)6Y#5F6b1w2lBcuQv)_-MK|eVAy7dJPrvoO- zG9%`wUGF`VN%y?q+%tRCzLn3(r?n3Zrfh)vT)YIX zhX}e`A!JNTKeg7$pXZalQr)4Yu8I!uG#Hltm@U-efzM#1d%bsYzjbH|P;(;>lFM29 z=EcueE#$7&DqxE@F$%tWmElpzx>Kd}+flilI8^z2kC%_G1E0hbuX{cv-qC?yvbsCx z!CP<|y#Dpcw`bws9^0&BV0S%CViaqrwy0J#?)?@mfQHQ!EC8AXf}Gc>njOz43Z1#* zdus~Y{wRr2q}nI*0XG&5fOL9rua*_lZ|bcGp(P{4as&|BnGH7e?}a-P)2Xz+P#Z}W zimcFAoxD%hX}LADOUo!{`}+dARur{aRXnmZBlEc$|j7TNwk0_&ge z@lrC^+DlsHVXjR>*BD?%IX8ZIq7~t#^Z38S;rDf4i+5V>z^gQ7ssv}0@mEN7K#k>; zEnpCTd$cv5&&4v6TRV1WlPye+9xtk+*RQA%*(_BQLN{r&XI^d<( z<~Og}>R)W7+P|bKyb|2(WwTD*K~40Ft`0Jj1bO=y((5v|jOcmKrvNebZu)5OFtv&W zd<&#N*MYcPQl7iFIP@dqdYkscK?11*{uO8`L{}SA+@jVS>-P(BdIaG69nm_TS<7u< zE9sv5d-D;VoW1SWGe!<@O^xY~R*p7?^u`G+j9P@~+pc7^Y9-zY{|I963*0PETkYmz z>o=AxZZq+gdPWDaa`wonbrx^geIxhw2|3MJ%1Ym*H{Guln~%7(fw}@X?kKO_Fbap zEaFv`--tM`@h=**Jc&(+P2bYV%Io6lJ5!uXdO#p4kJS_`iM>0*3>)9h{P}W3)*PxE z#OA(2UA~ZCFrZ~#LiTn1sQyzP&T1xDg|o@Ao8d38+||vL19k|q&R~+!vhCqdo*&45_47t0F6#3AzASZMW5t_;% zKtC*X7)rY2+x?EHYG%yw)aC(A+#~E{f>)D?<16q?a)>u|WGUB_c-Ky#jvaqiyJ(ql z<}#J@u>cUF+dmBw#)imJ!r_MDq56day$F)WFoJihqHVWJOZf@Du`Qr3yoNmhCE~$7 zQ}KrbYImcSPL?ml886)cwFJmCuX$CeF zbkeYGV(wi)tzP!b?e0}KCE5foRuPxocR1%?5RaUv;FNfWt~wu2fe5n9o@cD&bC@Ey`TK`!vkLy4~VG zG~Iv8sI35&nh&e}0k%JYD)D2JwvY>~Yo!C!eD|KX(m5OmY88BKtW^vzV0an#rg~mA zSC&~@kR1*jfDvtX&`$vZA>ppC$4%*ifch8$qA7bXT>EC?168iXdx2l#RXit?LV9b! z>Ho^2sjnC2V8tdI+YCDB*%fasl}$z8d+=tnh)G@ic}YC^tQNtUXXW*~b5znXCZ7AI zI)L#yZ6Vs%FCm1L*k1zKLHEZmz>NMlDY7VhI8k8=TlX+YNp7l+R3pXJi$$wTs6U)1 z;pKopDwHYYe4=v)$?~Rsp(gut2G-#Pg-%$fT)NqVn&z5M`f^Hy-NAt6gRlum&!+*a@tiFg0%}I0&z^b3y9b2&aa!#gsV%zQ)-&I{kN#yM|S%Nf? z+h6bp5ORjXl$|mra9yE5B%3Um>7SO@6arm6+>MD&PAkUsm;MeQCP?Cv65A?FK-8qBRNyw(7AnF7#|5o+!>o~Z<;8+w z6Nz0rj5*|QvPAT%J(l&Q#xFBE{1!<5M5FE~xe2Ij>(<7cdM)y! zGvj)0@HkJjDPxxdtNaj}Kah!I8u&THsa#gimCqUvW2g#MxF3l;?3WD?4>?4X{~kj> zcS62_&1~>TJnz!pKFFQ2!#RG1ho;>39dwSvwogugBAbu@`|BQ?fp_R1?)_BFMyl&Y zKsrNdL69tZD(o7O*GhFeDQ!%2aLo)KnGw*e4Me*GLn(quXkz+pH#r(rteD^bX{XvZ zxbO&Onjqag`e9qe@vphYc}%Id)b{Ho-7lWdTZru~KMZyleh_<@f4kxFuTsqWIb29` zs!}Ywv1t#qbgMHSeS8jpz|zK|ClDC`bI_J%e?XY_&gsaw5cHdpaTr96X!KTd;g7lS z!vvnHfTsPY-h1fXORkXNgEIG*>?f$g^qn#cH>NA~dR)A}0@y;rv|N?3A#o4#G%j$z zx$8hI<%1usZ{pD9e$m{W|`z12b_Y%yB%R}PSRgF zOZJoVQ^n)F#A*Qz+}VGcAfV9%+|)wOuHH#o$fne$5`2hI&n5bv+gzf(2VGiB%;1oH zRZVj!++OO9v~0MnLYx3ZcKIAd;n7`C(;yK@L&onuTlAHnH{d z0f#UxC*H~JHxl`lgMlZL_yYfAUz|AkO=@fmUVXkuY!meg87F$FwnEj#`UCN}{7RUW z@%|rm*ie!hc_a0q+=SfoyS{Q>=);3l<<`rlDdj^F#&-Ti&%ho2rk#MOePA8brFLqn z>uk@8qIz*1{N(S0;UTX8FXlb2lnPsZ2--jqTIA9G+kOx5pWegSH<5RXKkwqp_fTKk zv(1wD1(?Q3^$x`tj>w4ey4U01;fBLrZlhZ)Z2nqi4&C}x*QmeiwPgQLL7J&-QpEIK z$$(9P=a7F@q5l~S9CNvxrqRGKgZ0bx^`LV_K8z`o0N}4{GGC2F_JO!pjYJmJ+%OG? 
z&NCdSY=&GhEJKs4UxeP?IB5f8M5837zGq#0a+Un^KLRvq4_-t`U@lIZI*g^{O;>!&_{X( zbh*?lijG^-PA(gw3i4=|;hg`|FO_ifRRcfw0wZ%VqMI23COBiK)qM0N<3I#+^bFdI z*`u)9qvnns$UJ-f3o;n#X>}M@f*hZh7}Hz*nK#S6Y7@wO6#rF*{8Taz@ z5f1*5cO7*=*ehfB7~p$?=U22A<_UTn3?;$tIa)e%h=Ud$KIa<_@IJ|Ml}YfB)f5 zt9=i0ub+oGI~92e=ER0UJ5PZHEc9-{_NYaHq*0b4*xgNEc!NQY3f?Ki-eW;W8=bL% zc4mRAjQ_kR?sGiRbx1*WH!XuniK6b1<*jgtgDBzkqDgFyyW}B|{cS3d-Wy?;wwR z!t%?+>nK=Dv}wJZOV{W?E$wLrNMgZ}LR-x`u|3Bs4VZ>3@h>jbjX?MenqGw9B9)eB zAj_%vAapCAlU&{@IafJY>*?CI3bR5^r8>t{2nx8Er?tg7O-3&=1-{kX_K7Kz!^_6# za5@Tl@-yX#erCyH@Q!ZF(ud8g4l$X2Uz&e;0hm?gAWp<`dT9Ms5cZQN@RwqA4r;U` zn@5gOU~6tp*zkndX+*||fv23+0j83lP{h(Gd}O#KF^!YY{;qH0S}szaW~oDqgWnth z2Ej6vN-;yZFH(B&l=>I_m2USm@dUrFOYQ7px8I!FzV>qyFfzX&a(wXEfbH7qNMjI~ z;H^i8ATh?K)3nbKK_C9TRrs$@r0U%f4Ow3KSdd!j6A*}UR&VVaYe{sa2M-s!dUC1~ zh~}Y64Q%1>=uv_f9+ zPk(!J=^G%niILpNd@X4P?Cd-Hp4uf-=;2Y)QCf?QyN<3vDLgW&lxFEl2+7ks9<~aK zMdW^r1V;2eXx|p+f>c4*w-f4tdW?VleZ1m{U|6!6d_3r#Nm)T)SoJf5qIq zZ$R05npLlEw)tm_Y0$C(v_u=C)o4~CVRfLYS%3SP85AN(h*i0&J;|PK-7cE_CH{hX z^Q0g6v_G&|FYZ}+tn_4-;h=G`zZxcViV{-Kk#RjncEnniHyg$2ce{z>iM#ygm#NM! ziVX&S6|eT1NkG(@d<3Gol__5E&S1J)=x5#q1NwKV+4l!OW%K4X2nz}Sz~t6jQ9you zjb;^hfIK!YAzpFlEhwfX;|_qO$&K}lHRKU0ce==oMm#+J)Ar`IBT({xwt@eAx5mco zn{C5AqU@ks-!yVJ%saFvhkN4aFS^FP=kJr!Vw*O=te_wll%?PXG4*HeC&G}ims7`L zOs`qVB5pqeE%b`&4mNbsbj@=+)!b{M&d7EgzO5+~`nBhD98-ANuU%*23faakcX}@l z_qnb!(rPKz&)|2{TByrE>$*yk>$5i%{jH?xrxD;G| zV0OO%Tr^qvDa_g5G}ywB7l+66mphFc=tpWB--@{;{~qj{X)(ePa5F!;Ocv*!ADWyP zLZfNZ^%udUMF(=M@f+Z;#0Vy!-Pn#nw8dWGaQOo2eMM=U$ti|gK?VSk5jg(TXWb0} zvp~Y^wXBr{k9_BD;2|us02qP+Np4*1{APQ`qpxRtK?k*O%wC%}D#Ts2cv2)!pNuBO zy=KtKBguMPt?#u_0rk&=1&k?qGP1m$aLUOj_5c&-frh*%v5Qh@$l=MIBXF}v%quYu zx?m`)~oKDGI0Y~lZT$^YX9KpxjjjgSq_xuOBWe#CZ#0cruVzd|&A`}2vPlE^>9 zL#GWpk(0#?aJtc)5M>cdIN&3OS=a`)jUb$gn{-au9>n{(H``lo*zU^>duhOvA_<7)L*O_!Z}is}``|vW?oubG@^MA8IUh!bfp^D|=DCx}>EiqR z1k7<#0EBkz;wogFL4I6Dwha4p=9;N1Oru*}NqF6@e(g`^C+RbQWdssHwPFik66r(6v8V^&v1HFOAl`_zuz|x5zb5~yiqN-;) zX~C^X+FgJfWw1V3_^lvGRb(tEFAQ~-s1bpC&sg+x+&bDYsUpCX25UBbaCEmV6o(te zl9plo?!Ax5si2g0Nzfx%2bOt=EYHGd;L*nDFYz7O_=?!H*z2&f^r3W8Vy9~WGql~y z74<0;(2Q@wIy+dUcDKAzE6P*7to+p;d_sEAmn5N?a%-sP!7Bs?VXd^|j^8yeAn47I#YDcPD9r|#5L6m#$`z`j79X(3)aw8q`_44w( zu+e&S9vGU9bz|fFNojJo5V)h#k9Sw>S!FrhP-MHqwoIO7Yb`iaIio4TD3C9=ve^1! zYs~i@W^L^3O-H!7gj(FY;UIs6;gRHdxH}g&mkh0gm{pQWn0fwGa-+K$E1KJ{>Xl-! 
zl0dmPg9*Q(G=td_rR*|sxx8r~{B)4HxaN(872m-yi6-KUfMP;u|~xsPoml^|_>0hcca^^OaLqW;L>YyBC}w9>DIb>92QS!0n3E64y+{8Nr_O zYgRPjr=QM6gxy+*524)E0k^fg3mhwRfxB-NY1_zLQ)p^55@pRkjiI&5?pzqqS zT$kM?Khh#8_%NO{!kXqT+YSPkxWg*|)iTv-p)lyT!vc!iJoWXn{Alp6pr@eFD~}}K z!{DWp5y#GDjJ0SCkKt^Ftew$?N0}$QZ?nG>m;+--8(6--OEJqUKCahSWVOn8=_&R- z*9&$U6W5jSy~tKGI(*2ir)+RKdSjDNF$w&}_RaLDaJpTF$7FmUtN~Ql!XqyMHbeH< z-ah4uoc%x44>J{hYy-aS)TKkJT`g+$6xU-GUScVga!pQEvdmx9^UyyZBEqbTQyMjkG*6VNyK--*; zc)b)t2S9K*H{L9!Ow@wM+WhA(E#Gk?^eX7NI8~X7{$U1%4);t=#0sBH2uI2MIR=-i zt4=ixN_!rKkihXvzg`$=j^7!*o{^6TQ-@By<2S#kl3@lV(>f*fk-IQRs>u<(_r*_e zZC;eeaJ@kMaw3CWh~1;xf0rX4caNcs?w6wLB_z+?0XyNQ-Q@Um-*X8$Uj6B@v4{iu zBbIMq&fS*rLCDIYv{-GJu!njH^{I0lg}$S{w(}!sWxhkHd62*kEvpp}fskBXhHyg) z=&w(vFXGIoxQ2z(aOZozfrIYVb$h_N=kT-7)!3(;R>Z)1jiL=O+>h` za~RUtErGG7F|axP%`iau7e_jBQ?KJ|QI@kNIm3|Hx9k6Qtp8kjUNO`$e3*9_QE<2q zyHl)ir^}Pn$k^CLS7h^*D5V?G2;o2A&d3*banTgO8=q}wGe}yjYP-Y_T9?UX`>JB8 zQBSVZuhfS?mrN^m;8d);Q0ccmj#4izTWDLb&jbF$@mX#_As{ea#^>%S?VNI6>)mG0SMS znIem5uWX!vo9F}ZuNXp1NcGNtqI0uIrpBRZM8rEN}t*{65O@>$?+ zQ2UMavf;d&bho25$qWN@PLDulc`sa8_{*&qFaniZC7N|AsA@aw4 zIKc$C-0@yXN-c7Fv!s;KhD|RQFu^s(%|!||(?lywPd0ldgoErOY3Jh?*iJKxrm*hd) z)*I)eB=Kf4eLHDZOq=@`5ge%`LVpsARK%E2@m>8NV<=9VUiRXJ&4@~-LtgMW$!!lk zqO|p>)*Vn~$ul{LJ4&O2IUU&)T-efhIE^8^o3#&o$;U#$0-jkscXV2$$~mX(j;08Z z`WSA9E(d@`S+T(}rlA8~gsLRJSOP)}|7;PE$8^{8pG~K9C$Jl{k4I{fF}v=oQ0!SF zc|95*VP~k#(;j~Jx?M3nQ43uk!AM1dbRmdanZg|wnC6yllIWEV;zaB?^jbZf?@m?z0JaWSee(0}d#j z9$o@%V3_0>qfEnpls=gknB~H200tJ{Sc;kUl+@)S-7UY8jidm8H+RGw24?Cp3iGkF#!b0^;E(D)wu5u6}Sk4Amg! zPnSkDSG8XITs@Nfcmk*^wnYpLJKsuv>mE-e7{Y8jI{-F9#7vQ9r1KPKRIk_+c-A}c z@e|4vKKhQIBUA+&n6F~PG{rwZq!iF;U}L!PM&aj|S=%Mj+4wdqk@KR6^X?cew(Eta z_>HZ2hhIc;FdKfnlLXz-$nk7B8u!ewgJLs1Fb8vO41V2wQ%e|gLmEPp#s zp?&@V+B5cZfRTdQ8nq^kyYlT*@JJLG;SR`=-&a+qpW5S9cuQ-lJoXgrw=-5pnk6Gb znR3oyZ%f&N$jg|Ze10>i`uhDtEv$7=;;#huomz5Ej`!-H(4(1e_!LC3D=gVq;q3Ae z3k5v1#NebdE&MCfark(4ML58w^Q7+yIHJ=lafZAH)3^t}=+CDc zvlCV6BK6eaEf&QFe()aex8OY9&oDNmGA5L|D~&s86^EmSQd8^*@l0uwTXfSptN-;4 z48=svb~E?E&LEE#`9&(5#jsQ)AO~&J0ex*cXtrD%2=-E_*R0tUxfqc%6y`yk)v7QbRg)I#3EAnTTo1goDdKf~Kc z69adi+gmDT2RGRRWAmGhD`VDwAb8>`6P?kJvEz*?)Vb&^CX%zDy}QC)ukUT=0Kj66 zbe~StLRuievNWpFEyQw%zGmgyQEl#*4P8d7-%6cTP# zbqOYLZ1X%B+`oZL!TEDQ=M=4jL|vK&Bo4QyHv(34oUGa=GK%@Rz;px!hR3LOkEKl+Pknk z-}!9HBC~@S-l<{b*o?aqJWu6r2=xc7AG3a9)kUzgYfTlYIy@xxwFU6?`%vYh*Ai6j zc5fwV!9`t>WvM<~CkX<^`w^2_750=GYbVMF19=bx^nMNkEJ@hvO&|#aBGr8zqtKck z8`NK-V7tuNpml0SszTCtM0+|(hK~!c02*{FVE_u(NKsifHse#j zt?6E$NR&Z->IclLG=`kDG%;jS&AmM+@fHC_d%@0N7GLBLSoN&yFBM5F!DosiXsB%Cn`$c#SV z&0}jgD*5{41qDD1=(3Be$OXxM+E}qQq46%tzxWUNh~7F3NL0=)M{4;G^oS%7YZkMG z*aZhgmm4yzyi)=zeS2fLm%_kkl(wm1Gd2H%s@+H$l!WNlHYgg$iTCD*tFP$#Y=>ps zt^Fxm7k##K>**pW&44Xrcp+(%7Ln$GaR0>{4l+1&WJM6P%}U+LQpTJ;y8hJM*T@~C zMiC;%*V1@2c$|Mz-{6FwW|85NyCPpyMe{_foTS&*7ofRG2dlm8b4^1jhhG4`a)axU z5FAB{f4%8S1z;V104cBQd#YWepq%wyn6jd|JPab$XwnDDVkekd!gB_J%z|jZoF^q^ zd3)UGll4xhFXgZJHNdu^OWNS{0gLtL-kJNym1)!3)=vW znYZ$NeVwWf06;=tt$#`ldYbd7CuI)}h@$IcLM&KUjw}V}-4fOgD?EPmK zx7eG^^q2}`By-v!fNQ30rsZ(E?_RH;e<$uPVs;jsXS&qVhiBo3uq{urhVoGNO(9m7 z9~jUM0jY%98}_C{-5SHz$r+zD+~rr~GMx%oTaD^dGsY2EtEb-lPeTVgJ3s?M*u|wy z_iLG;LWSoBe0BMqt>g=-*4?c?>Ljh4KnQ9{z}tY z7w`PoN7&`kYKPab{g24Rb9c%lKY&e_=k}+Y)k=cy#0hBsV=STcrgZVaq#uF}{;evf zZiyT^*sI*5P>^@~N1>49;pJ26myb6A(IgNppFkZnc`27}QCkS{^UuJmeEaFDoRCJ>bm^GiMI;UQ7T#nuHgcaWhfG2mlLr@bmM=6?w zVO8S_&)+$5GBN0301!o-YN{5T`>wDtr4Jh_feURuY^cPJmbPe_Q;Ud(kFcesH+c#U z$`c+DgCAjS5*Jua{S&B{Cas_W=tS%wf3rUXhK(lTLC6Qe$Z%Tf^iuHZ9ZEvt_l`kg zrxyx(0%>m8c$%CU6W%Fx)p8{r6DiSosi)w3>%W%kps&o2q(I1drm$QL19gr^xvtuOLI+SBJ@#%NR=LZ6bkQ?rpAs1w?F<#J9;?l$kOlehhNV9emw$M 
zS=upM^iFd1e`95F8lvGH)SooU93J0m1FdzwV%xwrhJ1?QceBrl9^C0vH4&hix&#q| zjG+WZ*>q*G2K9%Qj^ z@M9uIK^;fGi)nzrom!u;lx%E>)QL?;%}bIe$%NV#NvwyNQ%GwsMW5soC+ zRp7<GqK9ZQT@5Y{N89pwoCBSq+PwF&BP}l&w*twwv%&2aI>g0fjX5P_h`y{u z%U;90C7o5V+6NM+lJTN{%p;0Uvmtz;CU_9>skQ@M?HSs(8k~2k#d;c6is-m!zVadTk6<;|+}t?(3Oj3e9(dw7Kqa=N=wuBRjwTp6n&&>{-6D6`{0! z=dZpNWH8GOne+00mmT!iCfLpi4o&j1>JISmdwg&{JT1Tantys}fBI-6(F(wPdTl3Y zTQF=uKk`@Zo?rdmPM&2QU|f5}%W7f8)xNCToxYBDrQkerd61tcd)0+MsiIp-{r1tdz6oRla@K#-gf1O${I zAURL<`+fht_n)~lYt3D=hPB!*yPtZh&Z&KNowMudXjNrd94t~S2n2#7FDI=Ifk5jZ z5VSSueeg+9l+GXo@<7~LN=j8;N(!Os;%H%QXAXhLMW*C5@p%yynK&+w&fx?zB8qaVPKQuF>i*`h_qgi7{&O;-XDB>Jdti_TS#usE{bGWxwaa zIuXX`3wXqcUTpKH8o|uwPYr_42JP2O=3vdJ?<6-|=rbXu1`)!AN$HrpQgR^in$+9-wKxs9QY@2}C3o*qF8D`|sg>CHU7^gIH0(?$-G0t&o(p%G33mp6 zU5Z0RD_<2`$V-}-xkh-mrgA?~j3|Sh#Z1uBP|t*%hbFN%M-msC$xE#fo|aOx=i#{C zccb=LbG_G%;0t>y_l69&>gObPm;`y$Dwf21EbOPwupNmu)g&t7iiUZXOoGDW@ELQt zITrfB1m@(c)C(WPNVMgX6C+E^r_cQrKQJ=0;hJFe!o$16pi$EZkKxg19F2vzCqpK( z4;JK#IP1r%njL)o%veLz#8riv~CE@4yVqn#+NfKOq*7JVpeeJ|H zXC53yC{{{4pn^s97TX90E6os^VTJX5VR9FZ!oD^oFm;Z3%KzLtF&on*)Vz(Pm+d0F zcBs}%66IwQmcOvnOBmt%D17#z&4_1D);^~PJ9XVh2|_8VDG8NFkU{m3ckWSIFDb+$+L~2GX)uJtpFRo; ze(A-Ynuv$vtpHI$&|XJ+vQ-KbJF->Q-3E?6s#_s}n{}tgxsB0E`H#IOmV1DRfQ3di zwV?b7^<7;(*v6>f{?;pUf9vpj?!lWe2`}jucZmD$@OmDjnwHy4bIg(o_-WX?>7^M+M_wON=l9zi%JyR zo`k`Jv27n3BZJUolb?q_#VJop9tEZ0F z-{v`1-%Ak&xyB2Z6{&t+t;0s?pr2Y*p2XAvv!eJ%wz4R&&C8i$e6-+=*f-yy87}8^r$y#d#b8 zikVOCCQZdD!Gi1COJYnTDu=*J=tT_5k-mveb&RM^R$+`#SF@jjHcHACD831g$antb#e<&{2hTqmr~jqbfb~6oCZAzE_LqAD zzJKtkDMX{sjZeeu)!XYo(_pO632 zyAWM)%cPhn9hUT})je^2V*8}=&Ey;3H~i)0Td&(*otTT4!as8vdo1$>lPQ zg>_|Rnory@-7!wkPWVo6;-A$QT#Tz*p*PSs*epC*xVM18`zfoCR5revPyd7D2j#t( z1qJuIPsM@?{vkfeJ|8b7P87D=NzHJWahmRHo2ESre^%J96d$1UX_Ik8ZbW#bFRPVj z)Yii`xuLQFweeicINcKpwF-T5pA1Kt?D*j<%Hnt>f0^OMxJac9z8hH+GO4Ihd+j!vRdd&=`7XK z>2tBxv&`z`Hu^c5+oxJJc9k!eFY1kEnHrcGnEH=KSghM}&zB5sX0eW^S!bKJemwAQqJ{cAFu zR^u5PmAA)Feg^ILFW=iMd&*kJV^L@cms?NSS_)a_x`rP$jyH@;6j~-a5Z?^-EzA6< z8{bG{{_bA6gE|YJt#UJW;KCbo~XeO%UW6Y?h(wHD3x7xY*1|Jb3`k$63Pa~`mG zlUr-2J74*|GW}Nh7Voa-c7Hy1H**&sMMvC8+*;iC-HTU#KQ6m4BrGH{C3q$LKUaO8 z3{?q(Np?%7MK1j$cJ_K}qij^_`)hb?hQmh3Ztk*1cGY9`^+8GPueq^W<}b^!HE}SR zDcKM?<)1?XS86kADTN8r_=SZ9%msP=u6${fLYyAjyRA>}T8f=L*X`Ge4_Ps{ zG0QNgG57KB;c3%(JdRh)%5>wLN9blOB}qP)mz&R>5My&%+IBW5t16=}~OA)w~fe8?3O7Yu}i;3@_?mB z$W-Q;XNl+L&~PrxO6A{5Lu&mN!*eDfI$w>`tG5GhDFp6)jQ;3*r1^K^FR6@P0lC@^ z-MYL6gL~9oAN|_5!Q2#Wl|LzA@=aT`7wbo4SPiR*e%VvcE@O|4^Hq#xRwJKaYCKe) ze#?}#z&GNsP+&b86A}}UouIO%63ye-}l{w$D7ts8{|mOW8pPJXlW@<-#VNlCwaC*!Mv@z@#G%NIlJ z?W&uaFAPu3IUNtfjbv{=oz4c+lPD~te`5XK64#>Gu-oslvLU2xRmG@x=(RDI>Rv%y z@m#yzK*agv!X^E$k}acck*(ltM76z+Wv#>fC-$BC8Kdfh>N$ODosXq|Z|>hlFq7Tz z%{Xu^Xg2+xv{>EL*AHo)($dvRb)cBpB$yPBmapvVz2^=Nw0 zs;#}xwba;f|3mW3@vicBe{Yqp+z7qM1y*18v$VZ?S4CI3`znip>*904X`#cRjgc&m zv?xoc*gi;YmrS2bTTXlO(ebwn@@meg=hhPK5_5aCbLVyNYG}Eb-QC>c?-}o1EDvz@KuxxM0e${L+&$YC8KohP(a3BqrB@mA~AovcFn*ZwKmKo`$NpQk3`LI2~rw>OO| zTwJW+3RYKd2RCCcRtHzwf8FGN-ACHo)zrn>$<5l)0fD@)v5BL*n+OdJa-jeC``0|p zy{!LpBnQ`jn*|JHM}EW3$;QF{ANK}bg^`~Ms#<%Q+v!SM+XJ70F+@4JcsYds>F|H~ z>OY74FI{#1vnv-D7vF#F`d_~KKX=t~HFuG6vw;{k!vj{qo-(h1rp|{}*5U z%hCUQ3Op@}CCvUGt0sy?ay&l_7LwdrT16AQ1DGNIP@Ta)jQ@HEZB%ud-zVe%6ygwh zX$eg)l-(SRmZ{!_FGekSI4q?+C=xLfgvF(dN>7*=byUTri_Q`d?V&tm>cyDQ7v*l! 
za;&)W?J=3?+=UVd|Yerk3rl(J`ZcR6bvat{Rt zL4h6c1l!Hmc`t|(6S<)--Hl<$e26=ILbJ|#XZL8v{tS6@j}p-v{E)10YP1*>%vd({b_!0tKqcGH5Z%cZ=#Ys*Jqm!M{}VM5%F<- z0dYr(A@JF|PZ|@>4T=ZvO`f-oTT^DgdaK@F&L#F?)S%qu=(%V{V*3PMja05U$Cm=X z8(e%eaR^$7G=eOQMZ!jT3LQ-xQHu5xZWyBHJH>p1Hr1A-94Ml0iqdas(^CG8P!c{Z zt@Kx;C_(iNP8iE$C4xiszNfGAh(xK)4kekp?!zF#z0gu!@~|X&&TWY(9BKuao?6F4 zSlvMVbI;*dY`GQQFN$03RRda}Xs#>~6*y}8NdH8^f|BSZQS*SAh#$tNI*eyd>tWqI zP~FNblJ&`$1kY}SnR7lXK^ZN9`911bnUo3>JGHPK30PfA0y%mJ1p4&xI}2x+B>W4` zYU=}p*`Yy-02i?&T=x8>68UhSEAJ?iHW_MoIMQm9U)h9U_0c4sLnPCZ_=L-GP>mUa z!nbP9AX8$TsH=o^bNfFk`)U&35IgiLW1>KH)kAm%HDtD-GV{{dy5~2xy-z6X1Fn_x zW1P{-7(^TSUQ9MT)f>Ry!=x2Q5TH%5JgiIVbGR>T??Rx|G^fj*9a9+2qoG`pAv-pZg7!1?nmLm(b@J zt};$~a7j4P!i0mC%mkXuycw3Ri4kfUfzFr@gV$MZK@%{oyVlcFx;kVBL8G`OPMYLPMLKNLL=#veJ99ACWgRV@tvX zW(=dysbh)sB$9C!8Y~{OXuk1{sZ%?ue4gw)07pYr==vi-ITurCf?Qmf6mU8*q6HlV z2E{-TV-pN7G*?w{#8W+zT*!8>Z^Cj7=cN)0mmFA=v>mK0m&?d!B!lX*==?vdx8D77 zc(GaTx`(;Mz?xpY5E~T00NS2{62`J@jje-u+#_l4MApo1p=u zF`(+aJ=L)NJadc>@ONVm!G4Pul+^b`BKnxCg6R(uYUs8yO6F=IS4gOhlG7uhMh~Hm zjY4>s;U*6rqEn2gLRA3egHA^@wk3swu$ z_mTLG1fmGpwC*WfkOza&%R9+!Qlo;qj}e10@hljJ?n}b4XUd87308ezVJua&vcbS} z37FyoG?Gv-vSLK%uvB-x=aXkW#=nL;p($z}lkW%A=)#=o{=$k8m&W3_m zqY%JCLxDLdLL}k1l2Bb~Gp=iM;9~4o;4ta7~*xPxEFLz=WfciBeMo3hUsi3pB;Xiz+ouialJH1KGC5ke6cSmK zc%@IELE>}3KX|8(XHgluC99kBWJkECHKFuHR-r`ZcFy5WC_3tp+k*~Hj= zf`ID2lubaDfS{2B1jv8_2acQhATV>QQ)A5M)C&Lf3isq9kBT|TBI2z4ZsO@Jav6Jjd4VQ9m1^k)rXi`8Qdxi zfKx#i0w@4r2Ji+=7%+Fh?EI_nmpynFxXfa=SQV<C_1sov#>x_|xeX*E@b^xag}XM}6J1j2kpvrnQ-a9Ik zR@YVhI*7DZm5EjH@qVD@@!KHL{kFz~_gF3*mNAptngfFm3Xh8qRb6w_auqa*KhPu| z^O8NkkJ60kwd1|%v8P4}4T3<;0L>#9P zY8V+|XiF4%!Qv6mfiKYBfZ3zi>u6%5iFajz_KY8llmPV>pBByCcv)Knbl%xiFRrU$ zlVRd1U6&y=f$A20!lmg$Zjbk&4ghlM&wvl#BP(`x8MKfL+nt;T(YAV3hi2I4a;;C8%z%-4eF5XFj8%4%Xt4R?Q{ zokmTtEt0~CWeYXa0i%#G68RfEYy_}0rH{;ZI(j>2M5yIiW;AEaerNrj`&wL;=#GUsz6*wTkd{O* z1vjQmsKI<_O#8{m{Kmhc`7+@O#`0VfV3{!_G{na)NvZMV6>=WQ`SDL{(OOo1s1Ycj zp8+hEAr2BB*9Cd75t_w&12hyFx6S0aQ*Q_T)`d;o)o_`7aeoY)%B5Cx)onqkdR7#Y zME_kW{_l+vLS&RPkUR7Q+~hL?P(8;2ZX=Q-lE6GMQKj7Sz|@E(0(G-+zJ})RD_&l` za6F+<;nV<=DkBJE;R~V*0rLq8V{`_K6)lDO2tbGcNmnwj3hOoq^|G}VO3!awRy)?G zeg~!sjP{q2R28d?qQ1LN9C(Yh0P;%|;!*DkG9iHbDre)QYG^50Y)3g?9KaF;B_s(w z2brh6K_bDtKC)(F8IShY{sG8N1H`6Qntdwvy|75bSe`gOMEXVETLZlg0@X#GuPh~i zdLegd2Oex@-jMe*U`B@>DY@w40Q{^V)yjx(mYj;G0DHuAT}{6u+Elzu=QQ~#GTL%L zH{*SjQZZ;4l542#&&RYD^-a*Ax_%QV)Mm8~j9N_CJL*miNjOxQub{eSjMU03z`G{V zKPbUmwZL5E33W|7V0FaS#MLLABO&h(IbU#Y3yNR|!Rn-WC8t2zul7{fuh3x^0BF@B z!devZ&Vl9fT^$E5GQD~L?!ejAMGv-QFJzqOHa+xIqBg_yykgMx1)=a6NiF{`OC2B%C&0(*SVIWAMR0K4#688d&>~7S0<$i?s2> zW(^T=(*HG;i>z7z`$*NGmL|Z2Z44OsNDWqw#TmweMM@i{3*6BCOg2>lp#+@cd`}qg z$`EqTFFlS-l2OD5v_|F_UaBt)SX-YJ+xGgI_;;OS{SjmK|2D6AWY7rwX^>J;$^i(B zK&y0%G zEdbsMbL03GBzOZV1a|-X6fVDlFN`JM)+Io@y34$97t00H3eJY98&%A5i6td2pr^Uh z5E5fK93X*~p~i5V7wP+RGKzbR=z=Grtb>o47TIFv#)Dt*^hck|h{hWy@5t%=$ zqdTe(6aUsEU{TH9*z*-LAYW};c{=*}jINK+Nhyei?=!K9W5U~{z7L-3e?$wZi?~h+ z*Lj~ilWPx356)VM8u|vF`qVbcjwGQ2f8t(GC4Qs0{2VY zoheV)&~!FkQCBHCMIE1JfIJ^0CJIrcWLoGb)%Z!?&Z{SaU81l7ub!}GMGFy$_nMgr&N(C763ECVf*NHuFR;CsS|$gnu-(HI0>CdlB_aRHJ3$4IjpQNxM;hL2Y+W`o zAagBpw1yl;=E4wXW?vqH3LDzf3j4Dryj>ceM2yyp)-4p!rwBs@xh*Tmt-LBkUg1E) zkf~0-!!1NBcZFk>HAOdcTNAK^6ezC*2pc<2hj1G==6_l^J_MW$l_CL|O^*eL1!U&? 
ze>;O1322gUohb%&EN$4T0o2o((tS!&?g6R^(iJ9;p(sFffdgV|o2IO}k0z-ELS2(&8f=;X4$DU%arVP z0a)D};(<8`?uXeF|1NdJyX=8CTa<(qH$JvNDhRq6Mf~6c-IpFymi&GBGH5ymU_^dp z0vfCZQZ}K7+E$&g0*; zEd-PnVAJ6{#U&Xr(&l{{DD+zx1DouI(yW@?xiJDc zg*BVd^CAu~JRcbAAqtHE$nX}CB%bY+G34ru@`4$p>*rk~UviJEY!)lM`^%N)YKaZi^nHj7a3Xx^qXt`@YIpEf6~dF#f7wLEiI z_)bIs8$NSBxY}Bhr-$7HQUDGV|16Qcyumb9_nit%pC}DfA6tPDY0-8A{p_z<>4u2QiwY0to&wJ`tvesLQlL;V%1)m@LmD zAf3kQ2Wd|-3p6tBuXBex3D`n)!-7u6fe!J63aYDEi?#ru5jnEMZydF<>pmw86deZ` zbS>74eOlgV0(&`6(!p(n=MC;Kx(>+83&2r9cApR##9Rt^N+(l~PIePCRq@ zF>Ls7$9}WsAxIu#0Tt@9IH3d4&I0J1Dz5k5dC34osiASboCjrEyCSGy7-V|{jU*q8$Bb4XUK)A@k~CNlyCmjCZ(fH1=?#I%DV zNodBg8`e3$Q{$aPj<$q=^%5N#mV?^K;n+mW9L*I4sEp2c2NN%;o45{PP+dIBv;>0B zyMdqI%1(}psg^WjqPbrCWD8V+eSHAAIl-}Cg7T~=6F9=!0YXrXRD@#0!TBOEd=G~@ z76u5>geMpWlxxP5mrBF{6%>4rUWW#ep@p2&EGxJJ;)yq`ERTD|)&;VdKU9WU%)Kuj zQ=(EUtPM{j(`4MRXB=4xUSbR#^98?3FmmLMbO*VjF0ggbB$ydwMLbBgZv-c3jS!H6 z4DNsvQs}E`2fIQ+aFR^i;~Hut>*#f2N?lu2_iReIz??$GVMe{aWy*jU$78ip|Gnrp zJ^zD6ht(B*L{V*{0ff$oN>T-cA>1bZFcwL`fKlUUd}si4O8*_`FeTyqKDFgsAs-%}e9qhs0-4#n2=~u(>`*_`SXERYT`?4 zk6JUsI?I%T^$$Mt)B0>n`tE;U^DZ&$?g}U^bgkEZ!5!&;w;8y=NAxkx#5lfjvXZm@ zv!L3LB(Fi#=C0ift+r)6J?p{rt*Z9&?RY=&uCHiqybNVDsRH{ILu2xy3vb>#ENG{v zq%f5!?tCEi{z|I#R>iWuEFJ##FbtkvD)!O74)?vXOcayp)X&1lMDA~==6XId4m`Q2 z_TLWj;(7myEsS}apVFQ?1059?^)q>A^1>=OQt8?OXgl=5{wxZE;vzTQskz)rrcP(* zmOe(D2nL0^uVZq22!o-EI307@E#;VxV*bAQYE1hCQRjz4Yx!;=$G*SMr`wgE4tg#u z|5?dsR-d+{xp-!Hnt5`MyHifk4-kD(p19d{$;xXE{U6`^ByMaaUNemZ+I235e;m=8 zzM)!-fVTCzr3NRpsv`wAKaHG!+9I!!6cLjdUEL`1(>Y`mAH($>@8#Bsw{+ZUL!aOXz$o3rBwrFBPbN z0sz&ocIrM_h$H~&y9T;#6jT&MQ2Hf`m9&vIiX_~K z-{@1^>@q)~$)M`a4NyFZf0-eZgzHOJNrT;;a5$KS4oqYB>O-axg8H30;)e!m^5Krg z(LIx<|C~mi#K7zAeh}!=pjg8t(Y-qm4Q#&}PO%s2#fAqTB=o&%q_w$+DnSjnN>vZ< z1QHB^bt&9HcnYd3faq79&2G?dyH5tfwrh6fFDMt^neG7EVx`fh0r7*iDGtZ;zOX#K znHgqJf~h`%9Xvt$qU-{NbFnqn=TiG6R9ARu6cDfbI#Nah*=|%^WuRAH=d%CiU=zg} zdaStD7n2&GWuOGVLkfS4tkFI};t=SM;s6Ng_^cK%b#^43ZN2df2UZ9DvEz6FVsfnZ z1XOqMWM*8DouIoSu%p|qxvv6jH;Lq+Ku_93M@MMTUQn<9?ZKi)G?-4c~nkE*S-*W_4fWJ+E z`0(kUg4|m|B1!n8Gj4Fc_Wdb<0~D$9(e6+!Z&iEddDq!TdVQt=;3Tf}khT{`FoJ-> ze6CPHiGV>+y~?P9o*larf*5l@D*o!V=YKDZ1@-(=KLB9V%?>bmt<2BoK=gR|nS4N! 
zOES$!eZyp8y5EO2$RLgy50Nh4B)re|A1%Kax=V-E2$3~RYV85knqhKNv9!1sc;@yUCV#Gl(Mx!Kw48|_r zM`K+}-nOwAa3RV*KcYhqCFbXE4!iq9iHBRC7r8Ox*H#5$60{Mf5XGRQ1nGtHJvcks zO4~@32w92J+vM*Hpt}>`84sOAXc<;X&OALrk-zysXML(X3tHKW71TByvpP3@?2uf1wLJbJAvcD*G_8mZBN3olknOP3I z9u)%#W9}1O60Ni^(5%QE13t{iGCuPPaNQ+0j6f|x0+UH{VpnaBHmaBHXaxN|2$>40 z7hAzG)FtmwoaB6Dc&0bLgQ1#;gSK<7tdyvdBp7OoeFpdKwe&n~8;n=KZPi?ID2tWH+_DCnGXP+y3z=0!Z)H<48#tC=UQhODsl&wg` z)IA(aDT|7Ptd)__o_r{CT ztlLARH$|&vFeU@jZffp-H7d~MV@x!HP86j^%hUIoaqJu|-4~}E`*sVa{>-U)h{Rpi zULD$Ra2|@CHNPB0As&gf%Kf)(Rxs7Ei6Aul{IbW1J1V=HVkKwAWBF*N&ji$Ny@s8h ziOV9YYkO@I+4+JZ$qh+UL|t>5fDbvprGYaAzC8hx)K?TbXgF#D^plV=4^8qS_*3`` z|AR;t!)t*nsSuKYYlt$4;)tre3>MzTgQ?W41c3>_+TDQDPqNiZV=nY1p)qLRc>?;T zNWg7Bu&0VHlC2>D4dul_?{=ITzo0R7J2Aov!qL#PP>f{huMyj{UxIbCF_2oTjSE_R z_sWL0nxX6$kZ{EVIbE{w{@O^=+bHpV!_Sony%@&R1s3cnLHAX4G0^1=odRww2h?hy z?1A$VBb!q237E3MdPl^UA25`RLEd+9>!0$|d{AoA!RCThx|ms`L~bnW%9UKH1h}-Z zlyo5qB>9uXVm@k8!aK`@p}8}}DZeAU50Fx`+W{!TRV+_IYJ z4*qGcyl3wU)K}j3Nn^nK$73pb$sFR@8E_+ES$+No-YRueW74KSwI4CNUe+LwT$`D8 zHm|hQsTE_Devw7yW{jBF6uH*%VmTvOe1C3f4R!9axa|b&+gNUQzET9+C#m2Y^)*_w zMj6bJPh;H!VUw%_x>>)Bq+Sh8@Ja7n;@T6fjM0TXJ&Z+nRGn8*>ap9$c6eZ{CQC&y z!Sh55_3JxBfzL*K$RD>34qhrEjyP%>yKe^s<#}MhRSTV1w{sN{lp_6+kA87IZomI- zdx|@;`0eZOwzGI)9jt1hQX|ZSCBqBGG%^fXYver7&Z+kwJkcOZ2xArSb-s@ji5b+ZXhYod96y_`YbjX>kZ6(Cs5cbN`iJE=XtbxTrK>PA1FG2eGMN&i=4o*St33)0h^XN%$HFu zFu(AB1IvZo##A_I{oAAh^Z*=ZJ(fG!*CF5AIk^r>N(8hm_B-+82b9=5T7tt56OEjg ztwtjI;vEY9IC?^O7Ia^v*sUH4h-KH3T04wyNm|6(>{aMeT7SOJie=qS^?a6-o!7m4 z!juQ>P_Bta<7Rjm;qP$3-mtc`-Dejla`2g)S*-&EwN!`Lyz0jfdNRFkz+z8sh*iuT zTOZGcs#cVm$;Hl&6@3xAr7-bMPCc8|H&f|(m**pl07b*p&Lfz|EG2XAqG_iIvzoDW zMpKnj*&y7tGzd1rAw#wG-u4k6U+P@$?4hPaw}pahMQ@IA*8}&&cV6yj*0gSo1RqO- zEx-zx{O@-ulSAKvt;zIxyL`-44`T%uBW;GMMfxs0r;Y6MXiYW-QOJcY3Q*w`+jxac zO^5jwX97{`Y;zl9x;Hyak~4@q(=$Dj^LbZ$tZw`A`5l#AODK7QMtR_bDSrQx^{OI+ zo@UnFdl!|ai*vV5-&Rd>j$Ch)Ua4P4!GM1)=n~D-e&ci5Bz#!@!2}oSWlt%QJNi2$ zetm2s_oz^E$Q{*8LUr53Vd;}Xj{mXViY%4M0I(Ee0sU<-V6-+kaXGf4yPz8PdSGKS zp!m+O5LV#MM~_pBIi9{K_1eMw1RST?|!EfEfys8621aMbWRoj`nd$;XDn>04O&f z!wN61j$4z!2Hn}^7vd~gU|CL@`!Wc$_(W^Up_?Jn8yLx02)LGAS7+_rZqa2zLo8#8 z1sd_TyjlanQz1eq>;>Q@7^51dwduMozDJ*v6NFK1K2?y=7w|kQ_&iM-r`_IcGZqtX z{i&cM``R-HdLwJ@`)XiAc_Nbb)CmJF6>G43WkcOp^r6mUpd{=x&yB4tCUO=X*F~<* zSFeq|p!H{TVQ^0U75BZ^COaXw6$ey23l{0HN)A2FY{WgAEGfHw{@2*5Z|;qB`chcq zf}boiV)+&%N&SXlkKgs4C(^#tUC4RjaL@Hb%mLp@Rx?`+v$`UH{>#bLRLZqzJOh+{ zOPZv&XHmr?oZ?E6a`kNn7S3qs&oP)IjVwM39Z2u3+|!IheQc~+Un*-^-K$go74sc) zug)6_MFs3$r0Q%^!y~77gD*ehCwCs!c)hc53hl_6IsbxwxcXW0jXqp1#IWZkQRby5 zebGFP^;%g({c@37khh@dVC4d~)k&8G{e|wDbaB}tdjW|o)zh5RV<&SNWY_-7_d4!9 zNcXV@s}=l-Gq!hB^w-aHqpXg&Q8C~;UGoo8D7nm~1@?pQSBQf8uxnd`*p$(_ATaj? 
zYv_{iq1@05A3_tNBKu5yGUEQcvP#~^I7rdkoyme$uM%UOeU`2XS@)9kctHaT!sUR+ zPZ0vdEco$9v1;d4FyT96pDDCoP)!^D7;YJKf`i|1`= zlP^OebB=p=TTw93c?IgR#&L~ifd$!z4uW~QO`Xl)N1~u0BxJ+*MJ#pFhN1&R3b}>X zH_TzZLE`8#ska}eeqQaJ{9=Yby#8m*q@uW?CUIu+`P9@}77@i&Yxr^dJtDcf<&MeK zx&2!VDOwg;OgP#HezBf2o$4F12%Umdc1=fe3a9w_iWTHaI8zvbQK9kI@pUQ@+aXl(NZh_wBrYWjW9E|OziugjRP+~^ZCJg14n>eD3*CMLn8fZcG`oG3(;OE z)|t;XxZRy6?bgeEkOFm07pp$3I=L* zV0C~dg5NZ?8umiyhdl0ByAIwWi{w1HBk>`yxlCJx)~SQ^>`Sm5Sn3C~Y#pG;9z0{7 z1GV2hKLDtz`T90Df$5NK8AxR%elR{FgLC7gwRok#iB^%lVg0 zd=;|#ldbILj)w(4I(0JVzRkkl_uMk1K2?7Z&%Nj0!-UV%n?s@|N)>BgA-}|B@j6j6 z2u9S=cue%&(l6BMISJv*Jb$lO1XuPE2j;aI2ao949y2^}(2c|cac2|ms2&9C=9x<2 zr|umiGrL~=oG*?tO^Ta0b!|Yr<_87d236^`xjT1mJAIsi^0~eYliAMKoNKciCFaGI zIqa^cg9Za472+RfW?2s#^Hk?fID{;Q%*M7O%fIgru98ibh8P`+fhMFqf;u69SxwfB zl2VMyF}5a_9qdpyflKYoH4crs72=2H!kM;J(`{*^E)!N0<21A}&jf3zF@L1S zCP^gUhVI1rGYXH8V3|fEE#_UX%;I&*7#>OmcdM~qk)e{|M8VlP9dh>51-taJ zigQdmVvYcHK2i1sR!T)C0a6A9-?ojXPrz>%A03H;05Did!0pU?RbXP7TpA2^g%evT zVp@=(kk%1R-mmZ_k|9K41DJbeWpx~%q?s1fGoj|9g`p#CP=tokL`D=s?=zrqd4&y zIPOXd@I=Kblg-en>3<@5dZC52*JW;N2KWRFe-w`6L1?%!>=tz@9~Z7M|eHXBt(^igDtj_=NdA`5k_0nPwymuE*D_wpF2QSgKoH+mokq zJN^E*B6<-EXzj3W_s(bn9P&`w(fBLL*rTF@A zT+60Nyq58cmm9(wf;K|0AA#6Zk0|;cR?&;(ISG_#xe(bK3wcgCZd!UQW72M7O{){> zjshPSk3=(lV8!;5S?#0c9kwod+SPfK>S-~-b98~hv_m$K3QQlpH5dh=xYCGd#h5&= zt>m#{G|9(5j*F7g4xoN!nItsQlc+;4Lxcm{O_L;f)F@S9l#?+C)*x;KT9@&kU`|n2 zxe@5AvoB^?g;IceGoi37A3PD4rTN7h?Aq0a)znAUfaWEP(haO#9S@L3pRqCU-PF@X zwE;(o8xZ+!$n;Vgn6N0v4}GS-_b3}Hti9(%rtGxxBZu`0*UC+H3GY7 z>Vdm}wubbjM->CfhUe02k$&#{bjuEl3`}EeRGGe2ts`hqscVp%ap!5-XEOVRj^Gcf z__bzo4by(OHej%Rr_b?+Co@Bo|4AoP|F6rrUny4OxGqCrxuNAjh3gKyKz`Px06Xe| zcQ1mF@Ud6+sgTq^dz8zT8zTwfuHV!PXo)4GZz&s6?>@)8F2s8K{t<(3EkXDeEVfok zeEnVg0yt|bP>mB_VtVtNeeBAowql`2eQ>1E2|~ehf!r8>V)F^p)|tx995%5if85CrRw|=IPYYi;Xe-puk8eUsu!|5B^VO|~TB)MREm_ScqB50hjp8R@C zZ%2WAYFFYN{>hnTk=;YgyQL1dKBQTzfA_bZ(_ZUstL4kpK2tTWeyw?vTL-Zi!>q=w zcW;E83f>ad%Bfl2sAA@H8d=?Cut8-jh$&Ahz{dCRK7ZM9z|V{V3;qj_AA%r`;bjra z5qX$Wwq4My-q@00S)HMyql^0z|McERNndbo9~tb=ZVg<9re8EfN2W0Sd{$#bee=kJ zuu|$O!ODH$P$h;vNt}SU6m4a5?f>xhmSI`0UArhDEg=%pB_OS!AV?!hN=P?ImxLgl z7J_uQN`u5hNJ=QJf;1>0-AFg=aXa7l+wWT6+H3z<$FcrAKbXjT?(4oro#Pzm$ZiO? 
z?E(nu<9O6>>c@tb^GO&M6aBo|F)$LvaOc|P*a0K`uV#FQUTzZLz` zPEeri3exIV3c3Hltj)P&K$dJMMbZwXfm5IXXshQ;pvjx$W|_qcu!Q~EmZ3~_bL7jm zmEBsD>=LmrpHy*sr{mKLXJ$>r7)vrxJ_qcY`DAu#ob&vNO##6rsTJk(8 zw~a(yuPbgcooat5TO3Kh@}AsWO^)dEt$;78^~%&{%C=E+*GH|&sqj(u9D(t#h@=hL z_}p*+eYSn>CW`m5YkBJBR?nP(#aXR{JU7ZnN_g$a;Yjtp%F!Aa4^b$$Z3=6VkgP$5 zsfj1=kK=EK>PHHD0i@S+9vKmBKT71PkT^iOx`|0*xQ2kctGu}`3g(roxmkvbA}gVL z=%tyNxJyxD;l2;vP_*JYz51f96ne$O^(y|il>_NVUmj|Oa(ku-H@3k#d9K4D>_JS_ zRILP5r=$Mor*l9ySYyUS3xy>Hdek&Db-fQjg&rf=dMo4kjL-{^MvP6ig zgMQ_sF6kse@o-aR9LY|Q(zR+S^X-+TG`yPOyO$)~?k@J*7756RiwR>ie*=NA1ZzndAG7 z=f1k1d^}0L%7*Qj$qtB(zJ4fgyxlC^C40K~uYy>RtjXFjBl_F3WHG!Q{;Y z$(z4)*!zxUD1Ay*3GKL@50Z=ptU6p?)CQ}>#ijR_d~ou=p8mpG>F3pA)u{1hdCGIbyIJ{om&vm}J~R*Fe$Z9D-}6r=rEcYnr>Ce|Ya&97 zOKo)*JyMoOZ~7|*zeJZZz)50PaL zvi5NtFO6SQJ8zAuG2Myz9*VJd-*LpK1A{w}GT2|`xS0}45it&Dh9l?m!>0_#0tqSS z$zRJ=K3TgzXh};Yf=hWDS7w8sW2#3?GpB1v@&;VWo1_%zOA7gDJ3J9>Ik9@SDyJ=c zlpWJK&o9qcHhhdjAV$e}q=Tl@+%Z6ih8gf8H{Gl2!tTc8qz-Obyb%A)3@sW~(;Fa% ziRn9i{1fa(s`Eb38)9o1YYcmaPhS(&pU(;6XK10n0Ap@j#xo+iM9GL#^{Y9QS;^fD zQdF$^t(c6HyFg=eEQ&%qx+3h7RisAriqRH7cGHQ#}CW$btS zbCPZUcO_qE2hJ36QI({S!9tJUoqBRaSZ4@m zWunY1chwNBo<-ArQcrAV_ZIprTF_?^p1ny$JFc+&?#QutKyl$`kt^Os4kIWGTJ07D zSta8&Q99`c&cS6JeG~=`ZIn)~=)P4pp^xmCg18-bx!Kaw2tKFM&QBJVpY7GN`;bDh z(KX-jGvP7g>}-#aZq2-_VO&h6`gq9FmBKqg+~Jh^CISAZIZZ)SVzTJk`_c+0>3(vN ztun-=1fEE@B^7__`M~#_g|J28-dO3#Y->6^)Ds0E01Q26qwbpnXK#7um&b+PX<%^0 zSaipC!i_NpRdd5>9bz=Tk?->`<$*F9U%$-f3DlqplN3ft@{yJWgwn;H?!km!%r78Ew=EOUb0b^%4xwGg36wyg>!H3Ql71sCxcZlxiPx~Y z*+=Mc51WXm@Tz2k+~}%3x92;7Wo_Su{R;(L!mFym9*Zn;IAVy1%tQYs_5eq;CbU;U z`lTj;HM}}@=#fENXggC^0tl+;TPoi@fR!8($X>u(roESId)5BlIyCaq&`#ABrM%Jm zj{~@Ea~h>6liKg@#pHj^?N&`<1V{7J3KtctHEF(1BNZRID!f7;D=J;X=bX;5dv?fldnnR(z zQlGcL`vFTgUB|frU3W~z?ui)}FPd&?i1o_>C{@{6A01|*fjeVHz-TnG^MDI?de+wQ zUq{Rp(88r8Smn(2KW)==4vW=9?$1!HSbq%UdG75PY=#`<0iWi6jKr!duH{b(Rdhrb zZfJgICa@i>d-}Kr@HGqA)An7nv&b}n1lins*vjV@b#Ke+%`ZaH#_hC?E8RB(u9uR3 z2lT?G>QD25ZDX%-om{sl?&W$urBMXtX^v+6$$#0hmPBm)^1$ zjUoGEv3CN@J1DQ?ND}K9K;t$7;h@n2MZcUdwlW+6_t-xAN{jm+b{o&aeA110Xquos zlcLYVENW^Z{&^`4dM%dvOsSvId6E9`o46}Y*RCByVG@%HOKaPfiIy?_CAkjDwKRa2 z$SVwNZ6(?PyhB>I0W}v`1D$-k_aa3Avcnkj^As(0DohY#5b%F~Ni5_RBwmVv8Gvsw zqenst-9Rsuf`5Yr6XG@sl+VJ zh8sjJ`yTRSq*Utnj3y=xW}M>FE?;P={D zyh;}Q#0i$E2vR&X!9t+7TNL`G%4B0SZZ!~X-VBo2ZC?0E*8F$5AN1`xA8H0aK z=>w$e5f3itZsE?}N@NVx#yXXYaHAI42aw*mL^A;@!dkiMavUQ$LbGYtPln|J*(Bym zoMDO#jy$o3*TODo8+m4z6-0^M%jL_Yl%xi!&$XS^`TGGJi0BNasZJ!;b-XH!pzP&*Z$Ii-dPXGt*M-%L6siy%~Vawot?wn>i`ItO=vvH7Jl#scGBA~V-kpsxLZSW zQozQZJ}*=Z+Zsa-Tk>MngOp9N$Zvy^RLdhKpu-jw4q;gd1$a+P@AuY`GJqP(@6F-Dy*IQ# z0OXx`OnQUs@21)rL?}&+oMpTy9+(%-_+HN(Qm(0t6iNzJQjW})`;9MmFESgg>`LV< za>(puo*Ox`6jn$cW;dgV>y_zgvc)Q7sh)~qkielN6vX{}DURYNFLsXS)|dNOf23`# z`hC9&Dr)7#z6q_<4JzzWY_vn#7?%0<(8d@uB}g(pUDlQlV7zq^?LyUX4G>WuHh6!z zF%?z7B3D0oPcmBhO_2+L(i*NUo}_u6m}OKEt}hcML=*e)hf^PFT0 z?RJnyfQjR$L(7NrP$XPr3=(xx=HcP#tT(B|g14c9!6;zT6foKS5 zzA6&syoxd0Z&2Y+(J)yuG~l5$vK~T@DP|QoELSUBZ@90*D=JbGmIm-$8!rD-!iKp3 z97=d)H@OXMgSfN%hvtp0;>*M)05;XXf7 z{-J3iUEE1Ff&;bOH67b0byt#?+aF1gHxcBgDenuw;*2EqCu$!pQ~f4M%sU752C-0? 
zIY@tU7T6p@`n_30v?O0(IF|5Fxa*^H6a6Npe9XlXx|(!)4K(dAW_+h z30+7vXr8uu>pgunMV72l)u)b#f4`R;#@- zDjc8Y?H>MH z=8y$(3+uCJo2dNMv;H&HY)O-BkZewHjAUnZ8=t(Z7|Y)%%~wg6UUh-}pWP^nZdA^a zWBQec1gE6;sU?{y$k6l9(OC7x-M>}^;Bn_5tcEtd@4%>Kh4H3vbQ7q@Eyiozuv4Xy zHAt|p<1&)?v@lq_{p!Q}vyzEeN4uGQtJ&CAH_6ko8`F7p_ot6)hL4*|7gxDb`oq83 zn&iG#6EE8gCzQLsJ2|kRHd} zP!>k<`g3!8dI%T2IHe?V+6q{ik3DY+W)G-)FYgobpk|l7vnxAL(0u9!SRuX;?_Hzz;8Q^!N9-oQy-WjJP71?!)_ z;)ei3r8dn(yj#bM;0u`v761PoDj02l=@aNh$3Dx%H(Vetg5)C)s8rLC8^vzwTSXc@S}E}p#n$KTYxMecP1`0*`5C3pk}$^T7e zl=zM1|2s1Z;`V&P+!1}Bv7{MdG^dV|S1}*?dp#@aAnr6hGM~GV*(RO%N&%GW!HHyv z#eWu&EBu_`;|(mS_;=aeDUlSqTMDxx7*pgc*ZqurHe31qPCSr;b`F#@V}^#b%F{^b zxl3;mPxn)YYKCPXj8k;ZUOJ>Gs1^jHd~6!TZDETEAEaY7SMKb=p$z5GGQoE#1eaL% zV!{QCPr0!#2hut{->5rM@qKn*Nq=zxb2$xa%dN$`kUvNo1w$67S(*rd)@!v70)$fW z4|wp!0Z=P7$6|v&a4peq(#ekLzT3?JeS$gl+}sz>rI`XX0% z0D0IFZaEskG^X|2f9SFBWexv+S#jjcLYV!(@MW(^QQWs7MbF04y%$c)0URpXbdiJ0 zJ}|VH%;R*@`C(x`Xp@?Rpa~WjvUM0KmTo-#Ym&npC_f#*M^(jO(wbsHI}5DpUdrQQQW}|IMR=hr9Hu`5!a+ zdon<&vhaoeYZ5hdypo=|^S(gb&D=#K>vQ;5)`zrC;{fQ5K%Xq5t_s|EM6ZkoMa%YWl|C({J!= z?q*9#^RuBcl<33_Uvymna7j`0itL5UZ_W=2p^En3ee(4UW9a*vbNHo?h=K2;xVbt9 zIZLr3(!njayecrW2ySA0n>d1=HN$JNfM!MB|NqNRjrpTzF9bQ;M`0 zh-bKa3iMTqQL`P`ehDDz9KOBPnmG4!(Z2XQLv)5;uyTd{E-pZ*wWG1}mY?^n)U z++EPUPr4sOJUpvUqe0x^0*iG!o>LlKgKV&8l?K1O7;>`Rw!dmPFrF@V77Pu@)d(H|cz>2)_CDxQ~(rQniY|US=*d(aGAdRF(Km zRQJK?=pu2~+OcD)5z5j=NklI?SG42(-`f}`dWD#GvHwO|!fQpdHLfhHvC-cEj$w;r zF%4avp(cvqkm-p;=y$$Fx*48k;ymZ9ucsRgfort0N zpUUaCABnzEtDW*YJ?cbQB*=_50cTc2gCRm(@5)Xv;yelBk`uX?^IzwOuK{TX3l%*9 zc$RCKJwNbJ2i$oN-7n~-#&V=IBDq?uvdcqN$>$Qg&)IC~qkV}O$Jks>zprFlu(Icj zS6!5%_^1%aNq@wD4+2!S`iLUjAq36gf#(;wq2-C>a1X4YNk>o-kb2^2$u&&(r{*L?J^O6xxmwn4zBSt^YvT zBId{?<<$5s7}ebwVjPZBDdjH8%WNkR(%o|Zj*`s}t*Bi8v77)`Jp7G(T~hOWkM5sP z>nR?lm~9cSn1AeN(IqfCaMGZ^{d`^r^?GJXKC)FR5#|setJvghs8gHjIt|=@Y$Rf@Gm3o=}@$flq=wcf>q+yOer_YV3rQ)T@ z6+<>A-+3$KjL%0hhCX`O#|&nAHvTZKvcqLiQ@gWfA$`Ro%@N~|g?q0QPA4>NrrAVh z{xJl!zmdlTfp197sDpq%LhAK^q~Nl!uT;!U`{;4$mOK6{kvNToHwn`G;md8RbThl6 z9uZ5$Gh^*U+D}zU3I8r@D-u50&C>FG%5pd17KT_IH1)DWuUhPZGqh}WrOf>nbwtx# zKDo}yH5t)KckJe9Oc62nJ?QOo4*13E4Oi+akjKvb=dm}5uU;E|SBqro)hlx35p%}p zg$LOhnMme|KTxO_3jXD-#fkotebtZxWd%4C@Q$|?W}xvZ9BF}+EWiquH%u#5Ft66( zW(%d{i-&>xhp@_9fJ)4(A`I3JV-wS-UJ`H?0X?6jn>Ol0*S-FM-2Wa%Qz4S9f&1Dl zu-shQuX|_I=6}aDj#VI`(Jxna;?(qkkr3Q{D+!Ia&;3&D!yd^aM45~o3c~=Co0`lDXCdXuLqjQe-8?+ zSz%@uL8^P-?X)hJ69_;s3BU}WE9(9Vi=HlFY=3T5y67?R3d+wgAX2raeoY#=MxqU}milUKxH+eGZsZhsa z=cs(V_WEk1zh8#de7$XIayp}80kvpx1(FF=BCb>fCL&l=D=$Fq(nkuB)zuZVm$B!fCrW_;r-VeM>v-8uooNTPD)!`2Gs6%W}NNhW*JunA02rm3A^T``fPRD zGz(aMjf8?k@LL4?e+2=^V`Px0tq1S>);h{_QRG)=k6g7Z3>JNtW=f=F; zl=%g&{0JBs%kd>#3`xp7?%ik*B)r z!GoMV|2|?FeP+iW7t`Y%goNdjF69Su@o%^0)6~SH04}8++}%hopzT1mHFrbPHKD%k z&U&>>I5d7|J&S{La@ce$tfMYc@n5A{D<~|Eh9sZQ^E>80`P*O0`$0BTSNHrb=$@!o zR{OA}X|uv_GAI2Jd@AXvOjS9_`Eu6aVMsPWWru zNqCjBK{D?zuyPWTM&F1!HDCU-#IM07SDmdpEp-wdf0@FsGVq&K%nzx`t{_bYH_j4W zorxHNZagWUvjoP_o*;_!G$h866{i_^s4_N4PrZV0=>pBg{@t8mO>y~NsrSDB0f9_fIQx%`hZau8WJ0Yj=+VPW^5=Bm8y7f|T^7Gt+yh#4-3HPt!AxR}N}+ZgDP*+c0U-god+MU%)J zn&UMWVKIW|8<-}nMS7*3+deoEVta|Jpnlbf0G%nR9@8P6uJr%m{hRm@@4qE**%n@Q zbXA^DxL>K-3uU?b`(obNyxhj$p3l8|_BG)ZqoVk81$sv$FJes?D+`c|=?Ht?pmgxf z*od!iKo7(4H+XRl0|9?P(>AqL{9R}&wdHb7-!k!(Pp*|oY4PU7)!bC`xecjs@T;#a zLRiTwA`71{KTep);5pG>{g5vE6bbkN$V^MF>-zdEatJlDMain?N5+E&(Gld%gTubu zC3RM^(B&fl%lI4MVeo`-cd7a<~m6_3rXUGyN(OJg|vCK7B#tinvcTw?VZk zHI2|2wKj_}a9>?ETPYV5+O8vbU@S&r5#gf0^R^YA7;n^@l4d&Uu!N9j1n%^De17VI zuel<9#&7lBRjn&JOQKI*EOibE@iai)dh@6*F7xc=Qfc1d_DaJ2s#obpHc-Kxs!Bkq z|IW}fcZM0&W8s!VB;@ZI%**ZCoO_iJ0{vAKMRfk(3-m7e+?lw#)<}?wJW_&oXxe{X 
z?CO0eVjQiJk3jo5oF|bIyRDEyrFxVa9GNMti^posqip5|0r}F#N}5o?N@n3;3{6Vd z5@(v5G;?yHXz5Yp<>cF*aG3ZMs*oPt(P+1&vJkuMAj*AvBUboZaZSzs=ESjT{QUGo zUzZ_+(2(Okc9-~R*G&J@o`u(%qIE&XKCN~Jr^p=C+c*CXv#;~>$DHQ}IZdqA6%8K4 zjl($pd*fl6I8~k=d+@sLf>^DnwAG2|Rc7D#` z<}+l5wFr&kM258QgBV8b2*bcetcPe;$sOOaos!UA5vePxG2f2jC|*DR+Twpboa76U zW|@=45mB)jmY3FGueO^vHR}hAO_9@HWBWE!&v!D%oJKG1WMTdsW;9^nm~nVB7EGh$ zA1)udH>dUHGnSvkQ5d*6Bh_@qP=*4EjQcW;0jY;xG2G`WN{*8W1mdk+N$jjhK^Z4L zf_Qg@61QT~duvqujNM%2WY2QA6u>}bJLexU*UOmP4Uw0mOeVEcW)k~Z{ozgcezWAu zn_rU64L+OlTP#xXXF4e6N8DNwAfAX?RLE|cWfG2lfvs9RJ<(g~pZuUIB5PpyaxpKSXn(?_S_q%(T`0LxZ zywc{11R=$>{kH}Dlj3es{l6*hzY_bG0=Lz7TIc0GKo^Ce3bX4;X@v{)mpcjg>KQ=| zJp_;Fxikv+t=-T>=9V}f?a9iQU%QiAZGDsT{k=EQAS9V}umwqRApHk4t=(O|%tr=y z8|#@9v)=4Jl*e9@_D>rtXlk)aq37ygRzAYvqEIWJB~E-Ej;|4lfu??!lH%$_mT!fw zH$P)Mh0SndUQyg$6K7Pd;8O3?84qtA?~*XJ2jAV#UN)tSQ4`N$ z!bNEmu!?kx?ZjhyZiEnd?(*_-KP{Nvz9dDFfyEW?=Xjfo^q zoZ6kR!Tt7Pk~dLm$-#^)WnF3U+;HpB;?^1N=*{YK`?1Fr_I)2CF0o!AR=hq0)Amhc zvPx!Bh9|@Y*%ZMQJCfrborAe(j1*Yv7)J!7CI;*^Gb`q9-Tb;$M=ZJ%RdPU;@|av?IiKCaoEMs5(n^3QTa)#0>rJiZElhn1W!9XVm|I1_}y=jrxfX=SRT%5 zf5DiBiKU$96LmD}_TYLir#~C30BV5)@?G^!KNH0HlRw7SYC@(PMe!Gfh3O3-SBx*a zOGSI$b5Hls&w&1R($9oAIWwjXbxw<<$2b0__Q`)M8YH#z@=0_cE{qa6gNKlfZbMRB#F1GjOjWj-= zuCrypijmA@-aanI0AZBy+}3Pbqg$m0zx~I^9eOj4)ScLjPxLW-6N-HA%lst;f3kQ} zt zZ+}_&Y(HHoA|m?U26^<+r_q&)SC$d2NeeXJ?^fBKL>C!x zZ*_9*+NNqcqYE`$;0JI1IbZLmHE+uaesmlc@dQo#x(lYinh#=NJ+_dF_|f9~3dWN} z%jqv6@K$CK;WOZCwZL03m8vScVXDirRI`b;8O1y3P&ufDg3+XGFV+ZC2X_z#>vZ*Y z0&*B-G9hR3+m zndfgAENd^-z04ZU87-kXr9BlJyhYe&^*m&>+sQ{Ov8V^%7nPv4Z1tRHAFtDl?&;a~ z!j2u^fv!|xMQNs~ZloN0NNH&K{bDXOFB2`B_DX78SLvhI zdr!y9>!zpe)9#)(IOKtbX`&(BMjXMPCjhO+AQZW@1gB~MQm%(QztfZPZSABk;-9F= zG1;?D_Km(f1*{Dr!>B;KkVYAopy|Yv3v9Oi@8;GHTsSK$BOL;f zL=$1-s1O`|NuZ;6hf?;SjJ?DH9ns!XvZi6Kq1R6;Yl<6;1PRu8W7=ui-zV?IHq~0< z-9<0EyHoOD<{B-DqEgQ1AZgK8n&~^873=IIC3%8muRfPX4&cbtI;f^Sd9qlAy{Xwp zZ^W{;OJlA%j}j{qr_FCg3GN{H(ZL5Stcp_by~=1A=B_#w zt!7R+-XPUw0ZXB}b_5qdfV=2x@+~f&)kqnlC=m3!{;ek+A)un5CjZ$#%N~&O;p2pe zXS!Fw^E^adWKi+b@%Tm;Y-&xMw^Bhy>d>#mER~x4%2jFMe`q};Cs^?+ zvKyC{dTjk#C9u@Nt+jn}WJb}|;x4Mcmf>qU6{0vco2}oceZ8#Ld%!r(v17XQx=(*X zt)q(Pj?5c=B}#5o&C3UNaUU5JjI&0MGD}Z-lt(L%JfxJ?MXu;KyVa_Bzc3$TC?B?I zeV2IRSF&>;7h)`OM~Y=Xqqon()aTbFQ{g{jt8Pb5MjGqF{J~d#@aHuCx~g-#d)|_+ z`fYd6mSgUH^DW}tJJ&+JSDJ~Ee7!`R1GG!G#h2q)-CI`L42QduBl?dAVi|@zizmMk zJ@|TVSSNZovm-tzdYdt~A@H zYZT;6jQ0JNQ4+%*&J0lO5-*Vqr$B z#B-Bl>RsCcm*%c4~s?S5ALqc6#=ghBav2A1d#!42;%--#( z;a$|Y5Nbcls>D9VQ#*HZ2D2S|?pb^Je$>z{=Nd=*VegZUI%?YA3H}DXoFx-!aArY? zdoNdKpd@C*T^9g8>bhNdw5x zt(Pa?)xA{EaOhi&;ThC36mWo^>f-$7iqRjg3|ll^Cmw`}aspBZl=eciL%tA85pt=! 
z?pK$T^i6=y7U%Js$P}+$4Ghn_dGEMmuYo95>^>bI=DRy`_L@4>G9tS4!y&wHg&)z(dQ#7nAvOBmt zVQ~0N^tdx^%Vx)IbRjbzwJWHvaO)M%{yygR#;rV>oYiRK$&Ah)f_p&UKEjIW>`ss16@wkHa;Klc@#8E(@%HB{;RWb#JTsG(41Y$!Y7hP3}1 z@h{EHUr@tl7}wYbWs@;bGan!R)Ni~5Gb76fq@Ff|r(4^0yhCqgn|sJ2EgCWRg&Ox> zYL;ioQtWu2o*SPZB{Z2S6a-tu!E!FPK`zgF5J9W?{emWXVGoT0*MyEK1=9ormvztq zhGsEV*hgurcp4za)m@yKWYfI85zIA^$k&q|$vBxIQH@kiQ#bPs6j=F>N%>|3xn9npkC`R_I)zy8z z$^4lIZ5TJXlk5u`+j{0r=BC~+^c~Bk>8W4w*RW++QUWomBvpKAQyVlV8Nrw0uRbL% zo47UGG+*|m^%d(kO>Oar`}W(YMJ#UFx0*1daw<8R zT;J1jcNb}_T-p#%F6Qa_d3fg%o<+iNn8TGBb-E&5;lSYTNltsTn9=Y_SY^eC#Kg-=E_0e9V{Gd|;6j(i^29&KtW(6UWjezA2Fk#{Z^@Yja>L|+T zcZ&ue!r8l^Y9P1|v&~k|4IY4w_We*7M7V_8u=Xa;ou+yO4c#th-zXqYwvXRLYrGIV z`f~;utD@|FQ2HPa&12OAM7)A-91Ut4CA6m}#$0n@>F6u@)OfO?JKT{dZXA^pj&9qE zGHQ?g?zi8!@sl;)*}JK>OA@DBQnA9{T<=QoZq{1(>ezkNOFD%z5?`N`ULDYusZ{e^M;GU%6);E5zM{6CPTt2yI*EB8 zF6UO-chf{^&o6W%Nohqbz_*Pqy&Rbu0+$agFmy>{9=msDwooT+P;a<7yYf|s z^w-2rYf>71d9zX?bf-3XZYSBbZ0YmwQl-S}%_m*9z_76_@j7uKZ8Qo>`t{Y=fI^?B zA7*EIzmelEme}v|6&)K1$xDp7HgTX^1u~)LPO`rSp?xlX!buZ_(^Vb!8EMW83kI|q zhMPqoPZxeNVBhc7{aR?4P^pI3pQ7eepJXproP$zl=PlrKglLkE6KOQBn2l+L>R%S_ zJ&B3hS9FgHu9eB1|7;mY6iD4_k zAw`ip4|JNWizlZmLW9_-*;?Bde+P>UxB>`U2s=n)ESLCK zoOiF#)Zv!XCCwFz!ixi>p}T#?F)y#0`<*V;UcHxv!Kxu?{KJijs<8s^OAEdwItt?% zdGwxSuzlXSw77jrE@SsDs{Ry9H0rWc;zvXQNiC(@ufA%ll*2@}$}0UGCbHeaA`3 z_^jDhhQ)d%@mI_@u!58C-1+xex!&?2YW}6UzOt7s$-`_+bD$nxk|deHZ#`p|R>&>@ zJmX!BkJGoH`}GzNvNX~j2}UMu4!`F~!rq`KruRT-KPct)IJ&toK!eESxSqP=s%C`P z`gpbKHS;RQWDu$}**HKX{AO=ieWw-bqylznV0vc2VKKhk-H%!}@#v`gjjrY`0W(?o z7J1J^q$PEBZZT(3=$erdlnhYvR;rS(yk{i1r*o^bCm}4oSJ8Hw+LQ-(NFqof_t_Ia zF7qlvjE=smJiA=8pZEMbQ&lXxa@Ts6+N4(ojrqDoOP6x@4@;AK^JwELLOZV!JS2e` z2ckjUd#%taQr~0)R+v=$YauMIDK%)A&|~eS(h|SCqE=lQX~mHwr3`6nHe1SvpIXVl zj?aO<@8C_LTvvfZMttaWj#Vw{jB#vLO8xLm;AZ;oUEph#?3C?>Y_v57! zi4u%?H|=B8oz+Z^757SryTs+@{0O5~-;Rcca$%Bv;f#LNaJU8(RrkUiV zD`|-cQ>$J}NUB>s@v^JAOp*M5O~Q>;#UAODi*J9aKAG0_w^X$O+lJ=5j$Z-bGU|g) z;(~C#JA`*5TWiWH(yXIe<>>{=7s%w!MdAMn?23CJ1$I#iU)b&Abz`~kaSmRit9I>4 zcs|hR5=u<*ZdOiBCqhZShwVIazP`RFp5po~g&Nqy7qcLW5LI7t1@Ljxagj1AN+b=O zd!x3SE@$PtJp{!dtS5_FJZ>S+m;Nl&@_%4dyzjp~O-@`Dw}4gjQs@%jWEutcf?|C+x!i$WF}WiI80l2+5I?tk5CjqV(;WfY|Q`y zI&n_OAzTFY%^VXfQKS@lACMjcR1t0_fjs0C2A*uY)pa5Lg)D?_tf48<1perk9#GRN zI~k-SSsU-=-@rv>6{t;>)jdDN5Lgog_S)@R7Lv+~PtW!SIkb{aHhp7G=o~jKYHKKe zeOc;cB-jy#?ay;Cjig}rs++e?iTBgN^i7F$q=^$uK7xyS2Aaz6uS&)nl9?|J4yi~{ z=u}xJWWo>@CWdVk)}(M?cwp^1L$pU}8lxidy2zmWk}aB%KORl*r`MBT-4`F*3^IrI zangalK%OxNLq$&JqZ-CNoWKw5b%qo!w` zU&ChpJQN{J+$cCqZbQ7kPh4D*CPJF~WQcMJnqsTT!#yd6BgvI$sWb0c?(-qf}oO+AssD2(vG+(f^b! zM+D(|sq|eI%LE zCbjWgm%>*q{KwT$<8Mg2D+9){g74O=-G7r8F2F=}c~m8bD1%qVBIucZGWly3x!HuX zues8QF-HDLYyE4Exon~Z%6~0t92vuxy=1dZ|Lh0|>3NdF0+4s~35?Zf6-R>~i|5}M zpI`fFcoN+X)Lj{%K!2akWS@yBz5@T8EC~fDt^cii|I14_4{}LKaZZNI5poQbl%^h7 z6Nf+^7IFghama*V_Ox;{zJ~^I)9#~-u|w-rSK3mwA33w7`4#s#_(8G!H}9vnzS)3x z`QZeYrBzC0{#RJ}c+?zbtW<=V8t*SDU95hBk4!np|8?S(mQmRQjNQ|% zp%$wnWqo4>`kenvBWFU3^Ry|;bNpVU#dH8nyocPUAv4eK@2w6n47rPr_p0I^DTyR5 zyH9(->nEC-0lEij)~q>0Vx)|rK|Es;GDkR(7S2Bvw^T4uCyIJ7zc9j2Kx9pkCh?d2 zi|q!9;d?IEe3!pZ8u;IzROs=&WC9*&Z$?uMWP$UsMUB=vWaGhzs^0d5p>hm}x%`(2 zU3Pu)-TBIN7I<-X3)C?@X>ll*@l%5BV$|}BD4;LhvQ??>HeErhz@xfZz zc-XPil_eiD3rm!?d9hs)=X!iCQzk2^n6;zNcdX*25uQ~1vb6^o{1_Z}xS%Od7Yt&? 
z^S?^1!$o@FogXgC6dHrM$2qbfAL#MMNTR2qQ=oUZe3jd!pZ!`Y!8wWz8KpXgdgR(7gum}-l9n1IC>W9r)Ix5E=mWNmpbtu?j5KH zu-$KiuX0f;I#-O!T>kZx6AA7Jhw@8YERSK%4ww^QLRLc=cr9fkqS<6ltzilvVGyC( z1VX;?)uVc^D0IEsd_ z0xSA&P_sZ14Bb;*uUdEqeY;K~0vqF+7tp?Z%b68N!|O16Pi~lN_jJyf$mxrwEC~oM zmuuuSrv+0*KZHLlJhwzDhtN^i#o#LTy=DxJ5kb6}+Ei7WWc-i!_q(+}ZN*g()B;e^ z|J~T*Hx^RR^`E3Z&??el>Ji-zLmJH-_@{_DF5$z7K?9lj?;&Zi2JeBs-OKP!hESoM zd2fe+&B$ zB3sL!ea*~4p0)qm@EWHIC}YLI{0KYBmC_W+j%(b^tao8s{gO}n*YO{xu6@Z{zQ0@h_M^m-B!ZWuUMjo??tk5U?KE8U||Mv`{i|r79ep; zvlmwn>;LG!gRyax{{X@*eP97AOerVC!-FMy0hhF&i$_o@{-!WLknhZKqeUUq#lGv7 zdWxpIyu54AKoR|4aXtT6 zuM@um1@di@F%lsqY%S}JKEW=X_(iXf4fyF#Ya)g(SPc2mlV@CfM-+CT-E3}KV3JldEBl=-&*tg^UH2@99omJ zy-_O$(3?=vy9cgzb_$PtQ5WBymg0tu5xb1JOCuG(bJNYGeuYqOu|k%Cvq2bl2p5EM z@c-W=TAp3H7n}G%v6Z!KkG!wk>7AL#1P&%$ABkPGGvUk5*tb?40aZNN;zLFHrEMTM zlwA@q_Yj>A&T#V`egDx1Ne)R0A-5+V84bij*T>1toj+$(+j0p})hu24^d&^Go_6=g zgSV`Q>N8sw>SFsEhsC45^2JBB$E#jGpu9<63%<)>IvW$O#!O2l(UcYKQEW5Z35pu7 zj(5(X_6aV*-1*vrV`H2!D-)J;eM43W{47bcv=EyuIvqU$da+?wkiew6v zS#D&W3rSLU6bVgq8zMtQA~Pj2l$qb!I_G`QdCob{aGvx2{y3lK`Ftw&cki{=y4JO> zz1E5SIolb=(DA&9mU;e9LXvf;)urx&6dAU3g?>@Cih- zCR!=iMH#VDt_;D#)XJ;(?k5iV@icQecizE&``o1qNH3} z3f=LT3>w#pifVt(Q#@W^zBlP7z5R5%#>IANT5#W^VTh$M_?oGzz4iadogdXj3(vn0 zh#CtOeExjlRBj`JI+Kq{I#p^4g$q2Wbner-0mk@I7(AlY6ZN*~$7UsbXcy(ZaT1=9CS`3k zLqgE&CeZB%X?OSY45&sUQlo8GXgGfqKm5Zsk!Q0)=s8%!c(cg^xamBVqABpO38vhD z-Y|~qa9Jh@w$Cu52)t+75cRqQNYVjY8LxmK3F&ZrpBokH(ftR^8j7z>GLW3EfZF9S zYEJ}j89skm8S5!?9s#|7Qr7A)A9VbY@1JwHiq0bzfzw;}kH577YpKsOSb7eM?j!KX zKK&Ae?Xjx{uAQt*u2uduIV@B)ZdT~wD{;9 zk3UrCAC}jUa6_p_VAXL8hiLgFVU*M-=pQg&Q?OK9mP1U5aq24R@zQ)x6UgJ@C6K`t z`p5FjH3UD)KKtOcj07}|p6IWsS%St>StpCKt*b89!(6VoKo}>Fze34m4^O$!>k*a5 zWXt2H@}18MJ(Oq0#IL-L4d0TllVt||f7NF0^q_CrCxSx=Lo&zS#tksM_LlpMd5ENN zIw!(uzNy?*kG1DMxH&q8rAhWA>j{MB?Bo;wUDsn z{{};_ZnI3hLMgzw0>GWZD>mA|>hq&!ofqJ9I)GZeskg8Nn_iN@1-HT|2!j@b-r2ao zwJ3~Xy#Mb6s(*pq8c&h(7&6F*KrgW~cprGqew5Z1u)E+uKOZQXEa)J^We2Bw9)Zv; zA@{tFw9Cc3kKK&Hot4qekVV77dZ2NLu1hJsbBZ~9u@8@+?&S_XF|dAo`W4dlaBkTZFC^5 zWr2J(9R5&eA96w9Ma{1M*r;fPxn9#<>(uqVXTV*@b7y+|Hzxiv)ceT8t-k>EaYbst z=jlzCSvo^LLN6J$?d$IqH*MLg?BBHI$B$JLS0krJERF+Hc}z-ftgBOPG!D?1cc62D z(R;DlyyD}xR~JlcZ<}q|n&kYXkW|3x7@jjyD21R6Ov{O+%VNK5GHPJNeANg&>y;7A z_AIl1EW@7gYs>5ea#o54cS^gE8Cjh6(e7GYN1fg+S}=*Lr$mmRw5ILTXN)(pOm{BE z8>VHOZU=nwA7jZM5f=`)Y|kom>Un_j$$uml#Z?nGTKc-LG@QS_YcXRkYxKCledQe4 zE96u&m@)zBghbwKk$KL&(y}59TN#|*%;$;-mYU~b)wZFo(-7dyx~NrGreH5f&&B+{thbZ*3YNp|Bxh;dr_L<|Y3Mj}b<8rNh` z2yt-irQlH`8Q}-oceI#OUMD$Xw-=vhz(a$* z-Kvf&tZPh(CJuf~cH|`Jkub~BV(a*K8ZypayIi)hr~H6bq3wa*#NsatLRDAmX)ckF z$MI!H2>X>8&6aVihQf&I5N?YXuDogRi*v6BJQnIB8I}38+WBDgd!<&G@O>eb)pA(k;4I%`T@rllNKP3HdzQEnppi(YQb@&3=qMz| z=t^G<6MlO^?g1^9jg1}#Us7-m@nRbHY7S6fbV(`pp|7_hyG?w~M<~yB%rk&5i{9D& z8W+C{mLbg*NdmusOyzs%J!fi6WETC633?S=_z7-}#CHsw0aj#T!tP7-A9yYi2WQ_{ zz_i>HV`YpkMYicxTI?rNm@3W!BYDp}N$(C5-W?OZ4bw;->rRiii^}~HCk;q!t)M3ZQC_~dv$SQ8b$Zg8`5HVFTJ@;Le6yl*H?ii)e}lh@qP7v%5&1R+m}{<5Pw zCqMk_>~ab-wUho!kBNi_oR8eD@7w*_ci`f+Rex_aUy6yAs2Ml`g=Y-3K>E?~q2KZjxwyU0vPcele?Z z)SsfQ{EIK);X__ncOdT6z;b^UXx7z4t)GnhGTYKzQb_{GbgrMg^Y=v8!i%+TIz}tCwbuW~p%IYb6=l#i_ zJ79Wd^xNp$O!e2a*!iAEu*;ude5(G^bl`?iQBjfYVX-G)gB75yi}$mDI4C9YV1b*^)C|+)8+YS1UxdKUI%N`PA=7Mx2ME*Y zZ^G&M(jPjZ=Ii|P`d+*?lL)QI`^<@q#U6G~@h1lh!T)m0snXLLO-f zKeca_J3pS@Tkdm!9@|NSmA?p=t@M8`+h4U+ZbHIl4A(3oM?w8NM6vE}F2V))&+!8K z8gPf8J8$(*(nI}0mteeBiMlO0ERPYcf$Q7KT8Bu89`0ZMj_C1c`=^Ra**^zu@EYOy z1!E@vt-cm&$B1Ya|8@w2{=K|+T?)5o;|h(rG#EJ)%nyHkt#?Ye#^D-K!dHOm zQ00vHiZKmV!t4F}_p{I**9?u#UFvQLFbm-oyFx;4p>wjd7ego|JkF3W^c&g^ANd-r zQ0HauRlrKtbaXpZ;ls)8+^CS_z926C4#ArHkvLb46o#$f8F|7u0CeiIO=)~cBf&KG 
zS@HVLEqtKRp?opvFmAO!qSyIAV29sj8tr@oTD+$AbtWs0sjoXQjjNI1oHk=4J6xGi zLvTh?)AU$08XG-=PN%;J4ZqUAqXs%Yp!ZK#DxQI%+zrq~)49T9Btqs&07T8t5bPfZ zY%LvKMd2n7g_~j7j;4TEpReJ1 zPwRAU$pR*wH=g5p-S5p!7{5n@Jtw0R%}Ia>JFupvKG?`Zm6NtYUo!ewwmW^%rNx%E zVE`zn@)0J_^TL$wxajEUnKzKp=QCpK#0DuKJs^z+K&^Fk{VFYjV<-PlU3jTFZV7q8 ztfeN0I=*nCgw`|q9V#c;0hQaGxVm<%;^DnGP3DF+6t*3;SOMt;KoRGz_1iLme=fL<6l#0FlNwmi*N%zsLDoODBOA{ z-7NsSdAv`1#K7vKW4B19g8kiS?0P&5Xi7)`91q_wNrdC||2@M)8jND_-!4c6wDMtH z0bWr6;AxV#fCrXS8Kx0Hxm?apKR{Z?I3quT;RJJMQH2pDbS`@x-GuG{!u}<6|8wWt zT7Q6y6b?oMLh9lfuDf{qP2;de#u49dB)OtGDy|Ug|Ha6r6K5PFs?6`;8fdk`Uj)TT=Kdg4#?twFP)^Gh#4AF#@Yr7-Dx>`vk=j*u%dgdmu4F{z25H+G^mK9`X_Tn@i3y>#DlK7YcF8 z_&<(Qj*-N{{>$v;e2VHqig>%^50nS6OYC+--OjHQ2KLPPood>P(PhrjsGAGsi_^X6 zwAh88pVe1{RE>q<)3-@w>>c;m&gwmQO(7Q(Ok^T%B;dgVi;Zb@w{dIPC0YlCDQU8u zDig&x~P-6Davdu1qqH}xpV!8J=U?Z-4q zCD0|uGbscqq~f2S`>ZS3%%cV1SNaw=Vew;yTCkh_PTM8L73M49x0k?uy`;fLoQ41N z#)DJ}C~(U6X5W`>#a}VRP0XKfeHxDtr6aO9TUT`ic9;n79AF(QROPE*(e)B=sEsP! zG$h2PJ(HEZa$o+r z##@LM=T5^aWvCVBW1p-qf379H2u9&bXQaspkRpvq!?T42lK7s#g-f`NxP+a>gH{@e znD|049lu+!C-Ml8d--7+Pk4Xefyc}b0^%yQ%*@S4AVA}xVAu#9X43o%Jnx!@FFpHy zMa)k%0AAlTz@w3_6D1!70J-`6jNI0H9QTqS6Se<4vY0{>V0{EjBRpxri1MfIr!^m~ z$v#T|y0~yu!NdRLe%eXz&M;y50U1B!Lf7#n{$HFuW(I3#ZVH*-CBj?M;n|VKA*D>@ zTWNqJ3;r88vR~4eL_zT9g=$3ynecf;+$V5f9D+YHbNTzSh`CuOmWbXpD<_s3%&R3b z9rq`quEIZp4A6CVgC1_ub#Z@3IQ}O~*@hpX&sw~U>}q#5OldJhkMj=tOM!*$rvFHf zmA6{;%8-XJ^R4vnAj(V;QTE$wC#{PUcR1a*zRYpsgZjqJHe(IOefXs_Tm`$OtEZfP zAnb1^DPqzOx~uj1&FS+!m>X+p^sb)US8#0`Sk@`B8yuSd=p64h-m4MpVl_HUOxUs79Qs#2(>F`%W3ln%ENzVp1h<|Ya!=VP zOc}6x;|Vg}1&utF-DY$6RHB5G@-E$B{HXvj@eOrlU&$wu##xfD59^4wvz-p`l#u6| zt$#OiKG|lFG$w&bK_z`OG}+a~Y29aPj2R$dHiV~;E(Lt8w~T$$6UF8}<}_tzML>dn zl+9idN^Yrm-U#9T#X8wxm}=;4FbzQZp&yLj^j>xZJKx6}ySBH%t&343F(idZxJqmr z(7oYQ5CdVP_UABipxSbMFFP?n=yPMLa6{|J_K~J9*BVdB3Fw+jI=SX#>5D(`QOV8$ zXPC>>jF^Ytpw*>AHzxT3vKj+q1#skq>$XUn7lwct70ddxz4puv;^~*oa!+bns;;x!Qc>IZE+8_aZZ;($^Y+d4Z*8?Z-%A8&Q4;AUpjaINDAsx{Tzxc^A%n_PR0 zls-<*_xqhEpw=Qhe*7tY5aucHA+1KcHp-(9z~=8}syVUwS&H5pqjr5~P`db^N00@D zm$V`(wd<#NLjXs=LGj_xMmfbPTQQ|DVPE6Iu24t}IDj-miNVj{al|ugC7tb%m!5Kc zI3C_?X`-R%az3-eFNp?ojW+89ziHq8_nuJ(kQ9Xc;G;i%eEVkns1^h~zyIdKxA1M| z_V29fVR6T{;O)c&zb(^X&wN$K%{){8c^7-Q2v(2J>Bz{)%*Wq3a5`XopW;Q+_1UBv z(+dnoxk!S{=zm0(x98mk>bYRoO#kB~n+9veZb1o9TJn=NPJlF_;Q&No3d9>v0%QN+ z0X@q8es+f2fGmqzx}BrV?t>$wp)B7Hg8YfQ0Y%OGAgp=x{9i)-5a1Y_U^<#08h2*r z>+9<++X7o~7qME)3~y5B5C%#3Z;s@baPX(B7AIk^#9Dw2uzmcf_WV`du*Pg}kctBU zi*pa>T@JZJv=lt^cSOrp!DpG7e?<6vp~0eA&r9Rl%D1zI*}2ow($K4S{dIU6+-Mju zOTY9=dHwz?;m3#s$Nf52(`WgStHrCz1VXO1{#>*Dz_jEA6;4t(LyA3Iy3YVA0mN&& zMS&#;;*E}kpnMr2($)06`dN?)O|o!7oq=R#2wX*YW7as78W@v8JK@F?qn2s0M@aG{ zqKLW2$dE1hDq|WaCuy)xCVTh^QC`$lV<(i76NnNeyIxzHxp{76kJ>)#HYoJI_v|Oj zhY$NpylJd=K@!6${YMdky-6ccS|hRJlctg&{7we_H>6e-V`S$$~n`V2B&?NB&bKvY{C(EHmu^P7gE^ zpTq@yO35TjgnoW^FYU!C*tK;_eJ|jU$bfGb!RV5eabB^2m7N(#dmjl|ja+uc)>BI5 z0)bT2xQp(;&&2$f+trlQp6`mJ6=T=RXy@`%2%F6zp#tixBy~$|-1&Q%Yt51$AEorH%Bh@%pN6Vp zAINV<2p77{8zfO<8mCgH1C*kt#Tka2c9$OdTCJpwYn{Om#?Llr=Py?p1td2&0U|O! 
z^Q+6Bzt2YGW>yAqdHHd#N23Vp<=F5nWIMEdT~1mB;<05T@Js5_Cu(+7Jv*}F7KYMB z#8fS2;3Y8PH;rljnqv|Q?9YfjBG)wE;^W6clv=MhO?vjS%S{~udPgEdG$$mIq@ zbOw6uCoaG(8h&@*R$?!XqxdP{p-Y#zHmmM5OotzAUL3K%c5!|i(RDL`KQp}n04O?1 zk`-V3OuSVEmqH@piXlW9@5wAx#rZ%V;t^e&Ws`y4yVg~LAe7Cus`kS|i(ch=$aaLZ zW7zP}u)VQB>6-xRw|CTKGN5#Bq zV>=1S#>VJptm!YE87HfDJ=uTECwrr>q=L!*(o>S)@gtY_lS^Z6(9Uexww0T^GKyk~ z#Voq+p8S3~`sgD8mGx7YJIbrhoEobQN6eztR(IG3vkp3}9+7%+@}zOxfoIe-67dJO zvDuIIRj(<#eehlG`(#;g`|;;<8Asn5@+3{j7ko-zP;fhc%sk`ORCr-w;oW;wnkP@6 zK7DJg&bFYwbC;)@UbW&)zucne3y#F?1H&JtxfuP$E;!no%$Lu+iFLBOU&*HAW) zlo2aRxJL>N{m-1qdhb?q0-ma!u7H_@CL=}W0ZU1N*uI$dr_w?~$Hz2=X?^EKJ>d$X8FwbYww3nwRI6hQ9 zAnDe?i#uYz*dbfD-Ox(C%l5c3y@-L=P00#J>(7tF`cl}e7GisSzsf8%D!Rq5Fjnth zv76YNIVrmD>F8Jc8I6(&YO~T@i0eNgS_JF z%u2O=D5Xb%k|WkK3upXjr$b#ify-&$4B>%JgIHA z2Q!DtHJ+BantVE5BAsjREfYkP@FSze&c2%n9rO~Zzac?AONWVEDR%^*?9jsAUG8q{ z7E4J<=>Zk@2I_#XA=W9MYZFB*aVy`y6n_#U*48TI z<)HEs=naxFgki{e+*?&~?)cV?ebN3mC;IZw>fdt1p{W@3?zMGGDJ_0lnF~+Gy5Bwz ztbw6_jWsLY8+RXlfHsq^KHrLLK3I-z1-;5BfTg-s633RWwAex7$R|T<_?F?z{AU2s znrs8=1j)*U2Kg+ z6b<{@xb}j~(Qr*WMn7Fcg6;-$j3GwwTpe$mOR2Ekz69tys9VrJaeao2jApMVoT+U9NvQ*vz58(6 zzh3gcwiQ~eRSF2N@XXb0c)JTw3aNk`$O9=T5Ui~t3eOa?+7DMUW_#+#hYtCJM+?2{9 zFI!uMVg2sIjsN%SzV&i;NO^r|_Kt(15UqvTHLvNuDvz!c-u1FIOZu%{sDeJyan$fS zny{OnpKsnno6eOjsL=25be!Zgza}WgHrb#_joQj&Kpl4L27P)c$NOC+0L5tJYrIBsx&f$qL>9bLLLnvvRn-O20d`b}E?7 z>Ycflr`vlFb*Ss81CmC<4Pn6q=~(0%bS&5EDgUGWiKW9fZRn|Olq?V4i1si}Nb|E6 z?mFIVR`36aNk%S%%INKNO|c5Wp{#_{R{1%kb>=QdJ@<6cUX$ZrePVMv+Us7`x5Ia* z*W`|$J@rWRe0u1v`|efX^bU8w)VRNFl(Teox$dNva&Auv9=?`zzBI#kyssL0Q~GLDAIE2m z6c6p)VytrVB>dAY+;deS==c|Tb$!o+e=bbB5E!`pOtXbQay0PlyX$!s`2r7Qs!DU@ z`V9?)iM@k49-n&6gzyhDOpJ@S3_SsQY5(cnEQ}W|J-fJI& zi6+e4yQGlG+aB#|cuI?H$dj&^ju{_d>nFw~$!ajl?o*8OA*RB*OzmOY0rNS+3u5y;xl(jtupyjZ{Xc-y z5YfU-guFQ83!6sXQjARIH<$FAYC@**qEk)-zf(3I{0d9T!_IBdL=(}nFyXD~G}xUa zwI)!=YXDJIWYbqT{}&I`arGcx8+;DxYD{~K8_xY&3=3KUJNSVMmk!g5Ly0Gxi&FQf za;K&iE}WZB!x$C#SvbeMgNSi0p*^EW-R53a4K^gOkdJYe)eem$j z482ZA_&-Dn;|>s2aJd=8zvEPI0*}L&0UVGF^CCkd;!a2K(#<+Y*u1NcpEs}HjXVvo0X--QX=cFNp-PjR# zQtjT~4`%jZbVVzl@qYG9yPad==C7~jJYIjU{cXUfiK3rW>M!Sq9#6;!wdXzBE^wuF z%=O;86R{zbD;|%S&yPS=MGSR)9)a20c$lZ;acfGn{3BU+wf~CI3V|L7Y&EjR0!x3# z>OZ;(7Rj>@cT>RaZTi=7>LzvWd=$Zved7GNK}?T+p!UOa9uZH(JchdPwJqL2-QA(o|Q@b%2e20l@p$CCohH-Ltd22q5|JePmy8+s~Nk zO|hTOzalT;z#C$aDQ75?bx%JgNnfI>-*2TI=iu}CV4cbyn!U-O#(GNhn?8vgL3Mqb zH7d1#;_d6_!00N|cD{;)Kff??=zA-MkkDx4(Z3H*@=+Jp%EchZk%x=U)oM8Py_5kH z$$FE$e)*LI6)7(1*)1*tRh>|oG)8x)c0b6CIb!<+F}l>XsZ1_QT8s?GO>r~#Au41a zW~FhY1#^6ENLj|;Xe+)46v=I8`QoKhhl5xkT!wQ+v;sNJEYeR`{N59k1{BB7(n%vf zHKF>4eu`(|11$xKk|yKCMH*H>X`u$VkyK~urI@p(=dXawd#{s5;>!1?+BzG+kjq@ z*jv&8ReS+9e<6h4FBXLfEADZ&E(Nr{sZX8ATA%26H##;pvJkkjFk;KGl04Z{ru%)m zKkK%4v9Pc(5gwpV-p*$Oc&-h5FqHhwEC#rwn9_fJ1OFYQsCylJAe9z1-tK^PVKFV% z znDZkv7RG@r;u$c*e|7;t#xmU}0u=T+x3e~@gG?iGEj|Wr8BB;}7o-Ic!g}+4%5@-q z@|W)FiA6M>ITYd^L_7>z6Mbf8?C<+`s6or7*ov0J=jiGyxnem9-Wof zU8Hx-DE;6M_~-+GgxBGzi6CF2$TULtgCcV$*)0&r=3$;TR00asI=m$su1nkae3|G4 zjYJOtO`O0-5D0DNzlnm7dND__V;WCBF|mXa@=x={oP>GOiDC?4-sp}`E9fVv*FRF@ z-amAyx({u@T?kI?`Tw&l!1oj!hhBje3!v+}za=mEH>}O0L7;#LW4^iC+uQ3GunfL(a;JQDHO6q%>~)m|+^j+EU0 zB*tDK^IVDZ>55`Z<7Fk|H_t|dKjiN6_`d6?9p)~yEqepw&bM$Dm`6Wy5+9>-czNu6 zhK-g&)XCSd^(xnX#d>Jn+=YpZnLOf!SJEzdV>$*L1Cml7B6&>a7o?`-^#UA*43Y+El*TlTQb zkn#^z;zy?k=vsX&h1EN{Xh|vvEn~+fSci z-&lOr_}O~lafb{VOOL#hD|8Gut#BUs_$*{|emL~=81K0Z0QF=qd{dNVo#i2G2t+2% z|3x4|fcfR|7~){_PE2Il>mC-3#4^2Y0;)LX%Vzyh#_(4$0beS>Rs`|%kHM6P*yDCe(*iDj8~7ZjlWjZXAQoU+*yeyLc;VQm^f(v;2|ggMB-JCfVahiSNu1dcG&G{f9wN zT7!`(WiNQ2UM$dq53rWtqZ8{ndB6v~`ShLLfW-Jy=L-^C72$PcBA 
zf0Y~wzYmYGG8iCLOe>Q{$3bNhOqAGBTqCiv=1Iq#$@eQN4{FKx~l}G zVDqtGbCJ?WjSn*3KcOq>Z6TREB!p-KXTUL?63jQr@pd#8{_m0H=zy!f`O_T>RUCHf zjcL218j02SupkuXG~xgdw|u<4{V+tNs6B2Ib%3g-z-6-tPcRSgc{2a<#IvE`DYvtr^fq|=le`vu8y9U(w({KdxHLTy}!4|!oKb4ckAO@ zr3JT38s|no-fPaTR|3Be#j=S0qt0H;L&2l_R@`|g@p)=WzfQEnAff6uf+R_W6A<coxk8zDXTNlXr<}m!5XYKYV4*!M-&`509;gsD98_JSV z-MO{ql)WYnAO4QFB64U(y45N_L({i^-!tPkBKRU#kCm(qDu|P661yxoB#f`@xcaVa zQBhjshV290l)0Dhp7M(ufoMC{Dh;K8j&KH>u8#_Q4V5$sGl>)jH5nxZ-l5Q$jPps^0^THojTTUv?jc<@;Y$bb*qWt7)naD+r#Js4) zl~KBFE0+Pu}XAeH8$GaM)=nI25fH=P#LU`cM+bvn)I)$#D98l7SsGwhw;n z-JhYg*M1kCC$zE^x-@jByC|3f;dklX9o~j?zEtVlR;@^)Nf5k0WCy7g6TCfSN`dS= z@Vi4GubAHe<%f_vMsrZGG3=X7d_=75`*wF{@jW5PUN^-p1{}{yzvP?QvzO*{if#e{mXkhpjD?-FW&Nbb*PW4W`?z9AcrRvOo>JNM+Ow#IFh_Qib(g0 z^d#Cp*hmkxJ<_nyLQ;;Nb zepfrchwweI0D=|Z(Wt0A!uH_XB%Gb$O_9=4eb%Z@H#9EkIdTN(=IZFo;CRp?_mTNJc%PV|XzR%@dt{ zGN$PE`XnGVq0NyUXe4jeVt7{UE5#}Iwm?y1oskDwYi9}5(1|Cu)Efr1-*Nh!8z~xh z4Kx$Z4FYhw7Qjs9)zjl@fGdM>q#c+C6+4?u053~OS* zhQ-=(k~eDj;!qCH$Y4&2OqehlWP~y>_POL&9UK*le2LXYok_@-Pk^hh0{SQTOh2H z8>ap-k<0j6vS(^mp*I0*xy(Tk{-@mkQqu!sv9aB{6b<4(fe+m{Ll#Vt-9pRpG^~mMatS@xCgf-8y7c4so zJkoBza*b{Hduvh0OvizI>%b>OjM1m6*s`JIM38bFhT;#pYkc6Z`BGuuQJGW(N~)S} zQ83eHK{Js3R#WXn2*+x+70gGe^6&q4*PExm^EBmV84wbUFSo}HX^I)`FETuVtJR|m zj`NNzrV*$aU&;Q*cyz-|Hh1v`Xxk3Xs_o%Mi2C%?+Si-xmblQvFjQ$DUqu}7@6+}; zU=uaSpfH?Y+KYoS5;e_vbvk}Y^Mknd;@(|=7A!zM09Av76P$#Dtsnwgdda86#bkZq zlv`BbymsLJnt9!t4k!oDJnpJ8=v)8(XhCM_lX;CG3^=rRTXP5)br@#tU(|9t!IPq` zyxdxK-WussfTOmfic?@$*8r-a<-g+`6vMe2Dwd%llEEc{1F%`2J17cl*8_n&P##ym z42^`qduX&nYL9dPpDayWVWdgIyFs=mg%?0aF9u^D&%4Ekas(V)D&+D#<@^0XmPp_? z^T4-78)Jd%4n7?>*x&vT4s0V&$3R_GYFoG4TN*@*LchR2qTT{`pY*dK?oI8Gg__b?s9>X-QX7PkT&ED?EZB!iz~rldJF<1qF|xi@n?vN9!tCLa zrf-9tmTuQUE&s^8s=S>4b7yh9K|y>%)u?m{9^O=BR4v!6f38V%P#9wVW#@lHUfb9; zvt7tro%`=uGt)Al_XJfrR-m@3Mf;Du*_}J!k?)QBzJ&vT7vM>J%1iZy_%(=o;Z`2? zJiw*(nO2-=^{jVWt#cz<_Ai^Iq^1&TJ^K9@zEyu1fkutjYgQ_2;^PhE3(W&x?2;u4 zSz;Rt3A)yy68FT=D1e1CvsfNiNBI{~sV z0>9+s?(xdm5lIc+{}sscTRF{bs6!p2Sbhh+EDl6K20wGeC7Kxby!BBqv28ofjsX}6 z&Hf_rrSdls@SRN`Qz&pd7g8KBcBmhJ)Id&?-Dj8vAR7GpG&hbVcpah zTm!nP`s%?Wv}K#(r;#x-uCR*3DQ zIuB@;l_I(QQYOWG&XT;GL9_gB>Ka+r}OaxZvYb zExL~hQVY^{o%u;*l4C3N;agKI6-*5>-U~W#1QwoExDeMQ>U!Y$1|}wOe7-dhz4f?G z+jdEtEC;w(`sXuxeW8rl&e=~#xQ&7T^NC0Ht>|+M?-+td)+Aj9g6597wUrpc!FYHorP0oa*WA9C{7ONq zZ*9uE5z2SdY;Im|qPN9iBd&k%I-7t-7UQzw)4 zFMPeP7ZJGGjp0J)X3BIya7Q)D?EUtGv|q z0n;KE53QDD+0PuIKfY2eR3*g90-MFKd&`f+zyJ~h11T3e>?Ehnd@d8u=?0&%3_fU; z@$kay8t-2BP`_br4!rZ-3N55fE-tKrp_6vdMi{caCz_CIJ<87{1G-TU(TzH}@0In? 
z>7u&Vf0|YOZ&kQA{$Or;kr=4bncIRPF)8$7gY@BJ=#M;42efVB`O%ojeUl(TvnZlJRkKf$A1-W8g0QB2)QazsWCox6#J&Pe{@gErZeB3XGA@SMbkJ zUQKsQ)L?Y@M#4Z(_#q?M_U=_n{`*I6Ng|dH7oD0)QN#~>SdTg`aB^*yX1HQvQOAXF zZkzvnX;DXTGm@liEyMmE{a^I=?^QcI{yxw7CGA?xZC&P#o2H_TQl^-9pSar$HofVt z?sG%j+bjI%1=nI#uTGGPe9P^)Z*x0urgsOR`MzJ!{LB4>gCT5pPOHri1g#&lT&*hL zg%Fh&Tm_YZE?OQ=al`Fb2lgTzvG|nXHDTe7O3(u5dtIX-C9% z869=@o|*kI7s7B8$uoNlr;Kfbiy;Y)#$9UIKd|MjvQpaeA-8Di2f26<;?Xt~F33(~ zw(Hmy!D$u((9WG%$b>tfmJgawNA%R!e(c(~FEDr)`Gc#GTII45w4|X(-7Q~&swF3G zW~_m_2CgS9!!C3a=q&71)ueack~M_7Kw!$VV3YGoEFk7+rwh&ga^Lt6NQ?5MHwt9 zS}7P=l$eJ2X}CMw+pKz@pxk)RgVT}aaDbPaf$RPQ!H-DA^mh!&A}vLtMeBx{RoqpD z(jgtT8Jcj?(-yFEq9@vSfoe$imhp91VEdInS>S+1V!NU8i5R)jy<>aL&6}E=8(P*U zTN=6@vP2;B%U=?-^50>rRDP~++ZqcCDhjOdS5ku-QVfO4c4WxJu9wQ=F zCAc7%L=Xk19rtFZHo=F=4fuR1dH-TccF@ZliC=LQqaG{G`$WBzXi?Z_Ko^nHe(T8v zI&B3F!pp9s3Kdp*_dtYundKO?qEkKrAvw86XKbf_=>>CtskiAot=kiMg7Mm*1=Eb*axk|g$e`_dRCPJJx;%+eFg6pnGa zby1aWOdLc^H!Hb;9C(eSl9E%FVd$M%Js!aQqn!WQIjzg`xF8N2P`yZF93ODtyELRpJ z%-3YyH8z1XGzyJ6@{Zb0dwVn}A7`@U|H%T8njFS}HSEcPmV3ps`Lb04$Lb&hJy3_d z%)zBEEYKo?ZK?F{@-mhUi|l#4izrZa{5nv*kcHWKXRPcBZL8dG>|{CsarXNn*D zX8v#9(_&GrbIih3KeEv;Zf%Nh&+0;%Jsn0zY##egKEjhpJF$33KWpXc)+21MDViA^ z`Xbd`>U1>l7$KmZCyD6WztYzvzY_wiIKT>tt2%CDv1q|Xf0xbnnA9!sOZ(X5=g{St zap1i+9l9$qpS2qF?na*lb*q>xWIqjTzl_A4GrzxUg6Twp|B6>P!7kD>D#Ubr>zCnp z!$iRDyM6_mCi7@JsolhGYs;dCj+9FKBP6P|3+q=Tc3t1=abBV)+^H5q3|E?~-!SJR z>gnsc+A$0EuWmjwnZ6=^%Ge$o12pP?vnYK#$Ue+EP%NJiT>!&B*KAn@^`l*S6`3{g895q@6RAejf*d1_umy zWJkcX&;7X@S^K_j1Ce{?Ae*gOZP1qO!!YYE)_1Bx8)#tr?kZpJPWuGt9!qmT)~J3X z@*kfqJ%u+|7o-sAV_QORu>Ka*^}`=d|A{|5VOO}I_E}4D^%4#Cs>3Vrk?1wk;Llv} zARF6eO6Z5u*pZevw>X?P4@3z4H!($j6*HNap_FDUI$_WtBpDd*W^7Vaqdl7Z2I_T{DSauw(IfcH)?| zzeZ0oGV&rd%ayisAiQDJP18M6#neM}^67ohEbch_Z4{ik`J zj(xlZH4V;>At)}gV*NhmMK1$?(eX;Ygbz~h_v0ZT>Qb=iC_``HzMB5qoQ)`(bo@g0Mt_s zL&b(AaJmpWupKp2JQgrD7GQv?WyBjCYmRcOUtid@WGZ@WtS#^<3(VIW2AHt>eO2XS z6-2MF_Cp*!NMoS}Xd_1BhvlpORiH;pvKmev!24ETVJlrkR)7_*edI{Na<1)cyNRoC zKIPB8{j;A~DJgXAz9Rew0FU7s(PO5TsnKn%)2+T`wW;Oct6~PYvZ&NYOzU5@ulVxt zy~m?Gx=l=FfWMN($TD;z7q6}j35E+wSzBL8B7bv*Ukhrxb}uhQraeLw$Z4~)w5Zqf zo2B2&8qgaL&`v|STcNLEZJaM3E~d6Y!aa_LnB5j9qKFd5p#$}{q3b<8{WpDYcPHL~ zXipCi93fH$Yn&WB$RB|!Yu1x-O)-xc*`m_{0oLnd8>?9!wG8Gvm`a7!5>!t%{ixnM zHeHLV+rYl9PZeWBb<4`bZipE$_Ft&qxq=de-E9O)V5dP8<$luic>w<*dB^`p+?&T!xwdV<%Up=cOd zGG@w{P#Kf3%&|;mD4Ay}4P<7SDPsw>%(Do~JeAC33Yq6+THkeR_rCiX_OqXNKhO7l z|Lnbge*0Iq`@XJoIL_lZPG><773UOx5^#MZm)=+!puwqEJrBCpmo?kkdg0Rm8m*{Y za?XDkQlR?8xUJZBfGuVBb5@IOhbYK2r4UMrQZ4aG1YbV=CrbKfGT%SSG($nKagFmU z!f)&+x?3GPCbCizDRgdgC(Q?iG;64oeTcnyR9ffeyj!p8A_g=G8uY9~Pt}!^-AgaP zGebZa^vvSnM{^Zu_zpJalwX=O1S}oT%0 z``W(t^%SbP49+^Z<9&+$f{lr^yyyzi%%D0U>{CbIZjk1IA8L< za*Bf=A_l+FjoCDFanZ@ESdw)Al4BhGwi?=#GJ5`k1*XfyOr#(CgehkLPP>I}sn?MYxbvc1f3Tj6|iR=iN*DX}Z zE00I%ah*7<>>B7tD@&5MC!k3sR9A#tUMd~6bnx99nilh27xVl94elEO_pLRS1-8f_ z1CCqS&(?|+z)v{ z@QP~}t2_yEvfOVu8Qh$uk&5Zrw9uuOtC#ZKUsnDwTAt6PtJ(oxBz&FTOAkF0M()?2 zfWW8MB^>~i`NK>BtVsVF0CP(7C@+tyflec@@H>jhX}?fY<~8S+U3b|_HlNZ$X~Hw` z&B4Rd1Uz%m1x4vy2s(y}QeY45vJeLFvnsJhXoBeo7WHlSts5uauRFd*iCdv`#GHEy zJ_fM!m|AP)U(Q|sDMbf~*#XJ0&Yt$59!ks>P-4E$oIOCIF`|IflKcbmjAV%5Ng|*T z04=A0(Aec8+|mJEs0-!o^5@&zPcc0jk2hCA$0lTq4e>5~2A>v+8U}XMQIBrl;hi@0 z7R&`&rGF1x0Bx|>{v-=3*PtY`(G8MPLjexZ+f3eiZMHOXrNCMC$4RE=G<LU%SYCC1TQ&

    5b%fi{$TWi08M}*@(sSj>+v?jSBo>tAn~-F1Bpistdl|2 z>`%XzCHY)4E0()I<8YV>88D*2$~o3I$?@?KEVxw`6kYYolcQlAh4e<(XM0dyF>jYP zHG6N5pm?l)qE*04AaUR21n|KN@%mX6cK*fFchnE@Zl|Th0ur+7L(Q+Hbet zO+>^4>Ne9me}0d@*lqr*0apKY4e;{qJ8=D}V@W@4smd#Ag+fJ%YRt3a)MCwuZq^2 z(ay}@U*g|7rw2@SS3to%$-7c@B zGh9|@-QyBJtg=b$!^SW|>QCjF8eah1m=>BCrm=wKLI0`?jT~l(W2LpIQx)NBzoZsd zG?!7^hV^kYlb@lbjK1}~xMiTpgolYT`FueTQ=3(?B7g6*SI4_>ZV=qNyx%0cTW=}0 zbwvorq*kDv0$UY3_wIvsXqH%QBHa_g1eLF59f#97`!C-J%xKBSxow~m3wGwXvK|>n z-#y2?W+aeo?0VZIRx+6QAt}MDv!#9bn`dfM}!*n_C`Y=OOH#H`=^$ zv-1nvT!wmh#@*3~P9or=-D1kUT}{T1?*EVogj;L(YPRklm+RoI zatM4>^zy2<^da;7(xbsWWc%{)W6wQWHnJv#UT6RpE`QV(m;(aIziJE4yDD*l|D`J~ z6(;~H+PeE$Uv2Si!WEY<=I|eKKK_y5#834QkT*}W`I_@kJ$#VOnT(8RugXaA8D7l) zOwX4d;-H0?dq2qg2&nBC>teUSFB)p|IqfBw%i@nYQq9vy((f5p9T>o@ojwsy1 zI=;F(JVO;WJ?gsp4DO-|xUxVc$LZc?&dd&xfT!QrQ$$L8ZehSeRsIT>#RzZ{TJ2kE zph5O9VH=T3KH}c5>UB!;)=;#~y6kf7-;95f(4*hce@DWRmfm2_cBT|`J{uRwE#W>R zHQ-ezDq$?zyy5dOFKz+ukn8zDF4Z5rg8;7ff0zsb10pbf%MiN3cw}R*l$l#3&yI#? z6t|ss0dJ?LJJof?hj}gfd6~F2RgiJaLt5E`DW|TSD?rIhKS1W&G={lC{`tOllY1&2 z@hngj;cxu&ck>L+JFBgS@V3{}(>r<5nAgjV2=AEodDDw}9_{+PknESH8SU8)9`i2o z(rgC^sYNn@v(iBO``~0+DGl`aS>6AK$Cr{jdVJ7cU4XtW0CR%R7|=U;u?Ia;BM3UE zOkosXITZ+)NO(*0&GV$~?{gykM#JXEW>M*MS7Tm*m+&risLHbhA{|u)@O#gxtSwl{ zRvZ;gF#PCav1BpE1zO_yo5x4(g zD$V7p1b~T{zjJ{=M-3i5Xqpcbpwq;!^iZAme(|YCBehK;)t9Aj9o)~U_2yhfWOKtL zup79x>9C>FL&nzHl?q%CyB0?(bVI}^o*93H2!W7)+y7{B1?b>ambqXFgUzcOh680T z0CX#2?_B5t$AE)Pg_DhR2uN8u+>S&e9KUB5C4-O%Qaw)^FqQ$oiVbp5pGkzRRqJl6 zsu?>vc-Ib;Z+Gi}uAt7M18=G7O*5$HA`~-_43L~(!68#DeOGLkYDpwOZ>6qLCbvgS%y*G0`P$%ds0O&gE9Is|CXjraf%-|TeC%g{EE57=eJBn|+$2ftFQV*bRN!SQcy;I?56)-!*k09`$`lP*VecHq_plR(0G6Vm{cMMjN=@L*Z(+Xm{!TwH zGx+`wrs|-&3i^*fv>;2H>c9|!xvM1A1`C9R(l2eqtC^m;8Y2{OLo~^04%{irE3$1- zptsz8#?%0O4rtz1cW~S`v+k*eI`ZkGaxi!F_TL{>X|PBQq7Umr;15*$>F~7WKsEW% z8jhtJ+~IGk)@HFXsh#mhd1@D)f@b!mUzzFLrF*&mYx;5bw`tJtnf`Ud@*xH>^bW_t zZ^^;><^6N;aq>XigQ?R!0N~Q`WRidLI%>%jfvyB`qJ)F3_5LvN02w3)Jb>hYRLv&J zzymGP0Emx*niMkrKCM7QX|o4}`UI6S@FT~>1s!0Q^XViGkc#o_QjPxp zlqcQb+O>mM(_W=DId4OP&v))9VO&AlT$gEbqX&oM{McVRB_Nz6sOW&rG_*7+7GQCm zjb+Ib;P;kc2!14??!tSOR>ka@?-b4~4hBX(kcIVDJFEaL@r@bIuc@d~nrH z{d`&H(7}FuYJE2V*G^}+qXU}m0bBP^R~3iAeJAY9{s(bJ>6aszyU0ClhXeE-3Q^9v|Oz6+X;ZpNm*p1 zPUHaRaebGiVZ;X-7jR4T6>1INZ^Rer{4d|Q=r^eD+2|RT*f=o@DxnNskCc!2{{L}m z{p&~1+G*h!v>Qm^=#9QBVr%sP+wVZ=SBlIffU@Q>-`igS)DQ$z<^JC+D!{uY>^u_l z{*9l~fA>oVDm8^b^&&C*Z})(KoXMaGc(4H;YI{#(kkca9i*) zY=L_cLhDsYFRwN3w#eCAS|Hv83e_^QwZh)_Qv+VB9L=xS2LvM8M7(=_z~3XoX9gh+ z+RxnHrju1T3%%`=>jet`wenTW2+1w{VfYg2tgKleR(j=GjF34$#n&)^u;a{-v=#3v z(Hv>JC4iyZlM_0gEhJ2_2%xE#1aH>NuOhiS@3zMP(RbKY{Z!qqcR8~+O0V;Q109BE zat=*R04(DnnBWE{*X9R1EHvpt`{EAQF2CnE(kVc951bV6tA1YkRN~+SC9ls=eA>N6 z2tLvO(Fn9w{29DzZ`ziSASxsrgMM+tn`-|VJm>v7+7f^B?DtEQa8p;1#m0d&!%z1| zo`-pk?-IuwKoTP0$p#W1Vy8pw_f`nel=zn8{*JqJ3NZWRmBJwV&Iq{Bwp}Z8ZgR^O@9V~wf6bj8xFRcomJ;SJhPW0 zx&#Okay;*|F%#xs$)bF?C~FtnuLq)?c2a#B0DAU`KWF*=LFI=oL!#KgwOhRnz_g(A zi{xNSvc-K%MNdFe1SCw5x&5hbN00W;r+8k!3z*1$qyQt;fmMwY@rMBgA{0}BS=tNO zz*&5|pMd~$P_|+AN3_AXL7+hQoQUgqJiENj{41JD-fI8vMAHS3wQVm}-#CTv3Qte^ z2*L+>$2)6Eut({N#O5=&#QoQAdzcIk9`b{06U3w`54@g8Q-RdvrZF(!$*0zK%b=?Q zavxcpwyj*1oK$_;62qU1rZ*R~A?)>}_g~H+>TYrqBSp>q~s~56>vjvH| z_U*sB=Lo#Q-)xt3KqElLU#SV{__P^-q60=4$o9JiTVWNMWePRYuQZ&20;1?7aQ+C4 z76pIt@noZav*Z-IcGrLheO`lR_{>I}f;1VMige-}$4{%2yP%Gu*#R)?KSZNd4m7n> zq&CWPOs&PtRi5JF0&j2nSG!;CO!@8nV5z{7VCirB>%JFoFg)PMTQIx~-o#+Im5Via z?$hBMRHvTU0kYU>sT@cQe_D!$P!n3-spuE&ssXw!V3HAivqz2{oK6uTyWA5V_2p@m z!FD2%wH_r2`8T8HTm|EG(hRw?tfKrve<(S=BCIdeFSwDeb)n zvBGUH4rD-K)(h%m*{)xH1ObkTh;C-E&+W0}{3(Rv_|@APV#T}Pjls*K6-y0Z1h|D< 
zR-)qzXi2b;v~&`+i{x2}#lqexAW$+Kc&2-FcsYI*Ikc>M4M(IA)&HQ#`3uP@9WM!z zCibd{H3NBB<5bOp@x0-7T-92Jz7Yz`y(Gk!s69cbbK13I{aN=wiph1N*qcDiiD{@ zH`w{_6$!x!m^Cni5yWCs@!`gA`G2}lR-lTl!M0{=yx)KW6*5PO8HN9$bp3x9fVpAX zuYk}OU&`ePq%-`93l7E(uQF0W+2GH80{@%Y;QyV6p-;lRc71yrH6hxfD#OnidF{4& z_d(Dk_fkgbOJtD!s!M}6jYIdPuM=*-q1Vk(&?h`f0u{^>eV1JcXt$2=A$`{`KcnX$ z=W}>VeoiRvEk71{qj-$ZKBH|fTZcDsQCP7$mrg(0*<~%wxQG}8>$)3Fvrp$R`f^Wv0v8UCM8YD+* z8~R?>7h}BUq?vW^fsm;7Ui^;copio%ty%UuYS1x}91v;&Z4zlADH* z2_lKo6a9kkE{ykBLK~jv(-)V3zxoeaP*AIXE(CwN%F5$^U@A(3LI_ZUw3V(Cyaf5Y z2UA!i3NY=pH=9BvKnpPigKA-gFVOmg9bkx~kRyiw`!eV)_v&2E(J4|4 zKM$(V41o`)Ktp!*S7?SBvQJJS58FUPhWw8Jqc`9?10<`k71TNs_p=!V!BQ=pA8?Qr z;1d=!dFAf1sU-KiPD1e#hm3c$drg#e=7orKnRLur52{-GXyL)cel3X%tM4kEFSIp4 z@X{p?IwPm>BR3X6)-VDr`NH&Ei1*Yo^imS!nZ%zXhoh3HOZ+Uc9X&5N<_L%}^H*ct z?Rg6d(3g@eIs!PH;>=r!#MmHf3MnL7Roir}lZ~oQb^WfGzp3_I!^C2gpCg|8#rZds z+-TdeY9oM<_5`h?TQ5f5%daX1gda0|`=dQU%~R1 z@A-G_ud;)Io&q$`gC=`ae`p8#hlK|s!*~XOljyFCC}0jz`|2(vzu6~%!cJk^C$A?n z17Dc%e&2F_iQ3t?J#(P-I_~wgh~3N=@Rr$dW#(|8pBNQ~%-gWJW4pP(xhQb`sz25z zQ;CkaCC~qpCiF&Cn?FEXUKs%!j{p0qR-;Blb> z1pi7rjXU(#%}1J_Kk&Z(^KsM>X2$=J^%)#tjRJ_RSgXL5hyIoA>Q$^;J0*rtJ_$M{eG1i_!x-`{n%zg$|qAeYuS;k3N-T!VrZ zP4Cl=bkmZ5c^Cs)@@QK1j3_r3EA11#h3-4L58^I#qg;wUn;2VJvRA>Cfs69h@Ib*6 zz-iTY;uK;Y!@r?^y! zLgsJAyX@_1E-@t`=HyrL{GkQS|s{$o-e zzzkAr!%;VZy>zGcw=>re56+hb@sByo#O$)xdR3?Z6~UbpE^UW*W@Mn1!wXqlx_03~ zCg?i@Dx(r&Ls`2T0|6?wpJNRwOChg7GFk8Fd*!Oa)1&-gx8hROwz0%?)UByne}ZWk z8lcf#KEFVdt*r7zcF5DF9=br?%DTG^YRv%{GQnu3PAXql?1b8?63`jvu6Lu4K2@Qa=!#H=eBfQ+lx$9d71> z)KRrHs2z6?>`L4;nGQ19e=W7qr2yt_VuWB{N<ib_PEdQi^kA_tNbQ+W*;-^YqJ$d(I1l<|agqIqWcP=GerJQBYWtj_e!&4BTLZ?lr7YaM=x%xLDiDY@6B*(Csi z=_{(3KSJZ1dH50u0DxO3D#vgHq&v~zUG_!4eTF(})z2Nz;GAJUR`}uEZ>3qWjXR*? za%qD`8;##z?OkCzfX>i$9_f7UpNf~b_z@(LLzqGdNs4r=dM zoc7LTYXUwgtUNRpd^$2WKPqhZHuzX0_`=ROf4|PBHe4G*E?V?~OqoG1`RmpF% z?f*0MQ*U`D0=-R9RJjU3U$?+?O3Ztg45$V;=sku$Kw~yNw*4V(==}<#{`md=Z{gei z;BF8$LcHK5aQQZNGK;i`B8+m7nFNJ7c&E)0)_{}>EghN$j#uA3fox|()6MhMmB!0s zSapfLg#s#xt3#r%Sub{bv=Q zG=PF$D{|u1dwEb=UE$*WL+lcu*%68jRt#)G8u62#Y1vc~2f(@yR({DcjL;4I7O|8E zT0@0WfFg26wK9BkS$k>6=S}uI|1P_nR>}$us zKE#}YtlKR1HtwLmNrn?t$v5qFbqKo;ij0Q(pgjUu7TlyS2BMx%8K`N;cBcrVS!|sQ zqD_`cCn>N=W7xCU&sj3jpSo~V@||8Y!qYE!^THt8o9lanDa=ef4gEj-irqVA;NBq2lcNI(KN)LTr2dRk@l|iZ#wxLy7j(*m*4b&CQEaUs>#JFKeWqJW!(f zYPS7_l6UvS#d7PRBY&IuoP724@Dog;=j$+F+m&n!FldOs#t8@p(U3hR+F`7E%HZZ9 zu=etOJyy34i1iuUi(EW~N5-1PB~@xiDzc65&nPZ0yr1HFd^o{>cm*;Lkk}r6I9=GQ zuCm+bRREsMnG)D>DJmZq-WV5 z4inOpp8l4Tuon7IhUg1-gz5aPc@V(VJ;2-N(r_^r>>E)CHzBr`0fPdsnevUVm&oNPq zq-{%gqgGifOD)Gt0CEIyC&r6*m$epGVQW1+&4<3;drS&US5OHK)e@%^p|T9Sj60~| zns`Ra;^Ecr!1Uotkk2t_-yZ51U*3m#netf~Kp3{xP%Yc&QHgZ8+3j$-+d%{}6o3!b zY^9CWy}r@{ak+zh?gee-4`o|BwqN>sKHx=h*4P5HFGKIh3KZK%{qj5H$EB6Y_F^e( z5!y6}LRBqTR0$e-E*3S570i~PXJ308yM24I_L0)({dVM4Xp?7THHd%-UJAiCcm{F4 zSbiLbjjk`x^_uW$pkPoF;(V#RckB97N~`02GY3tHV;eUfV|pq-pL%e#sYf~c66hWe zewc@kJ}O%}H#4OFPx$^rDV&=LT`kggyIl~hC(nWmJsEJ(jRJq`P51+Ec9l5u9D0L@?1}dNy6d$X%!%!{ z1h0TR#Qyk4&YTd{HfSxaE_)E=`p~H9gUljx^=KK^e6(LO*(N~~BsO$!H5iR3%pLPW z1IT`ANKG5wt(ZZ*r^Z*FH_-@$&ecvA%U@nBPkF($yR32WS%YKjwd1J!cMiMxl3c0H zHyUb#y$)(OuYY>M16nMuX^=?lwX(+gUXNw2K^sG^S3=&Vz{Pesoa|yL>)a%e7Da%i zeyxD)9^v9_I2h6bhj!1+iZ#uO*8mzE$Kq_wAnFDWPklCZ+9W`Vb$h-;8}*Q$d$}T# zw2R9b{N6S~A(xjcg|RWQQ1^CSb}*$lt8xd1N>4}N1nJnTn9rd_cSbA=dui6^)HiOKE8}z{JlZd9OL$c4i|n1p<;bYZ84De41T((hb4?Dsc)V|1 zMgYK@0n5mK&MTf{gjVga3pK|4mwlYluz@=3y)!e|9?5mV%zK3!LWuda~FbEM$cHpit-TJRP1PmGiCmDv}|@pkmvMIlT<4h_sPgI*EA8c^$;p{K1Gq z0L??jkqKlO(Z}>2G$a!&isfdDID#2G*O=xSAT*0(=MsU_4HwG?ad45|i!_VP91P8}a;Mj~}_j zEg#X!kic>Zpaf#A=Me9k8l4?aQ 
zPP86^{`JW!{0)yLC9-TAYcLg<)q&caHgNW1h)t}<;kA9Cd>Yn!|B1VBPw+8x`YX+_p@Q><%=O36=eK#}W2W>k z(M7VVJS?BSd^@YC*?o;uK8D;)c**|)!3>Eho3ZWcgCSK4Ld0vu^O)1QWD2)#)&?+R z$c3Lnoeh3AfoTiS%n)G#-I%viHor!aVGI_@dLm}(gAJbDc|62P7S>~?p(Ry|c55J! zHq^3q%Iv*^>0s+w4ozbyuSnU79uAj1v}ljN1c8m8ADfr4!)XYg5q7l$$jD;NKS?j} zSV3BUoQZ0{DlStkf%D$iyl5lWkuUup9H(z!A!eY4>_Oofr!IL%AiiWA3R`~8m3#rW zE{nVMhHm~DWh(8|!@b3l1cZfViUR<0EIFClM)4DJN2prukz;)|;Ud8aG5F?2qqpQ~ zjh52-^Uvyf>uVxBHt8mxyf*Iictb(cXjw4(5zn)&*O74jfi^1zl0heI`s25uCb^jI zc?LuwQSU3#k;eEt%Dj@&Wmn<`IpuEIS~K66EYQ66#>3`Z09=Xf5Jg8sjb!-ZpVGwH zf$90eK2qv>mrVHW*l}TYl@cH2jnM(4C5Du6MRwJljwwY7md#o@8GZ2iWje@4j;mq0WvEpLS0Q&mdt{aqkI+lXlN-r&{fE zycHkanl0f{zVR3lte>(W}5Cvs<}7(Wf;_q!tl=&cMp1aC*d z^wn{rTRB1K0lD5Mdlv*-8^*E!&Qpejn$28?=DpEvfH=ijwV`ET_OK|^vR9KKmpkB3 z9t)W=^`zM*Hf;gT|FuIP{AxU7j$=9e#zJ2CX!-@L$rMz6NtNJ_h&DNx*X7JDV)-7% zvQPnJGSg}f9+3k$V5-V#mG~f#o9=<_K>zm`L81-I$9vt{VHSY_Z~{Vv&`gKAs2BzE zV=$a@J{T_hMv$GX+;CtqzV1uf%jY4zy2%l*-o%9yy(E{9p(eDX{m3Igfi|^y=!8KU^y_!kAc(v>M~fQ`&u)XMZ*i4HL(7<>GL;e zybGB&K$Z+tl+Fdh1*|My?+s7Ju+)_SVo?M9)&{6V8s*@vD1I@ET;r<2spEPiOtv5H zpV->l+iOQ(jb$KavUGm|?q%_uBtuh+HcJvA ze0D$6Sx8$VyG>3`v@G|Do*ts#vU0(uq~VwD#aS617~mnotm1^xa>Vu$w@uF=-^7P! zROprmkzUduMBMNEvVEpdmqYDVW3ZyFpH|rg+o5^}gjRTDoZC^vSh^K9En%cTJGbs+ z^J)l5twQ!{`D%*Q{pGWsZ&$6oF1`qiu)bMZsCjX?`G&|lzqoKk`IsN(^0aMT`ccew zPHpSe*bdXsT+R9b<@3*&O5|BgxaYs}7fMRH;)4Q8Miwkac}jt0Hw;gWn1tlS zd;cr&X93=Rj4m!zFfOhnrwE|*l+Av8u4Ae_#=JgDBL(Cpu+IL3gWDs#tM754V41Y# zbCyp;#{z^*>13XTGH5_9adB`TAgw&o#edY2K^eKkA&8I70q3~oSO$IAxap&jVzE%JnvhHbR0}VqgC19W#Qo)JoiLK28X08pn}y!uE8^U78{Oh=!{Og`qhF#06& zLqyS3{#jL}nN0H`sg(&oRuvUOtIRtWQZ;XUaP5clIbnCVF}J2YHyS00(UqlzVW)XZ zzcYqqfCwJ4S6b0sEKP_QE<#fHK)GFs8i_sF>ACWwwPe2SxjAbPLjPL((z|c%I@sn) z2Adf|gnov=`UwL(KX^hpGe{R~i1j*GWGRyhS#-JyQ!E;vm%S%z|BQT$j>jqQdqu_S zKOVAUw0_)(WmFHdc*{8#coNhLonBLYHg=9Om7+Ri0dR&D%VC0>#BG$dSEV%je6m&V zrZ%6O`6gl`^X5}8-RT26QyxS{#Buo;Z6uQO3#n@IB|^J%awlPoe%J0xiHMx{-R(_j zQ{~ocwCYHdKyUULByc=~*_s|dTg#@*?4E8a8l}&)R!4$mJJd0-<2c$I0U6}^OCKJL z;f%m~!7~Ng;ZH12e#PgzKvM7_5rl5e&@o^Z4=tXLkJI%5KE^Fzqn6~YLL#&G@YAAE z$Da}TemE!V2Z#3AwI2d^U-Oug_?Uolf{>32U%%*@*jXmvVa)7GG|A^*ibH^^>2vm} zPzF;NdVl}Btk`D3W??~TY#r61m7N-WQg@FItcLijFUZsv`J#$Hx8)a+e)5OclcJ@h z&^kW*Q>0k)RoVf#!d=zSydN@sTf9n`k0~OcBYn^Ud~KbX-RgK zSB4Qcy?c4)QE9`ub|Dw}KHutb*WV~_dIv8(k-$KM*58&}5OqNBR^ABXIQg_IIFsfX z5n@-Y)ti5qN6c2-gqb@xCfs;f)-&B8#cJcpOV^&v+bOY{sAGLxv`JzrV?nn~q}rm2 z+WV#_8fU5)ra>^Dp72?szCeJS^SSCHNr4XDkm8!YDOH~qhllCO?-U^I=5Ff=LQIR; z78>l;bOEH{Q*IHd(_kt$2=-s8einFC2aMkeAIudJ%_tU1NW`{W2n%7kkV#2&485aS z^-)N|x6qi-OMnESCfaLKC7a!Hk$J~Kmx@uLp z>{{Xu-CnVj9~?S1?OICL{44?mnA#xiioY&r034_{E5v2tSt6}Ood4@JyUvmw(An9;>k6i5S8DQBi_^n?Spq)gU zA6SFiDdou4u1J2y8v^?#;mcm%FSkc1)4Iu@@nRe?odh2Aqyu34o0^fGM^2XX%$N?h zfGAbimeTcC;>yJyEBE6zKaC1aKEdM-*_z{TQ$t11rSnj4#`g`4Y2ww?*2$fbbTP1? 
zThG9M$qn+zhAjW1hhJMWYf6oa+sZTo$;JKexq0#U7`ME9Zp2&cvXJUPBCOuuPClA; z&wym9ViWOJVKbWH+vU~!KNk+A0n!4L7P4#YI{Dd85a@0`~*WMrf0 zi2Ht&cF<#bg1)nR$M9^;-p>2#QQ{+89wey|RIFfe_IyJZC9!{L8Mj|?7P-KZfvG>*_R>eK5N&i`Rj#c zI24d~cwS2$o3mQmU+}cJ#x+jp*qtjYBCxFPVfXOiZCL%PIV&0Qd_|WGifqslc3=K- zn)Pd+U6li&_2-I0D=UXWlLJq7y@Y#fpL@kW_>wZ`&Fja22+YD=7hQA;lB&3$i~mr; zqy2|)rA%kkCbPi;gFvT%r0cStub+>}i``3t8aP6P7T7ud(5p-1`SeG-8g?ONx1Z_G zQwS3fb&DV6l$Y`bt1}d5?3ub%-Cv|>bZsb93m>6f=~NjO2z!72%)OrMn-%M`g`#C6 zyTxd!0C=vpSVpCwHKtt=Zm0kqGcbuJ{{`2 zp9b}$0TvHWk_djvfIRQVcrux?6S7ko_ZOB z+}*ley)8}U0$SkKH(AL%vH-GI^=M=5T_}3IYZjLQ(T6Rcb46LTHcHs%>cCio@UM}g zo$Qukesr)znRt6D5Fr`nmTnQWfE2e#JhS{x4vu*(jJ7Qq^Q)XuD6sf}LKP|K1ogI7 zn-0+*(3~p~p@y|(>#Mtk)aS=g=TIM<%i6fJ_52kWX09`r&5%)IX7*;Tb`q)s)3e0=FLoL4N3w&dJv011YI^5|GkdFlcF9byF>hS4*1hgKkZ<$J z4}PQ6nqEwqNLBV-0DSP!$zd)cC2_BE_^|zi`B0 zsFgEZOI4Qk#+VV#qiI6><9-2fVW&}vw$-MWq(~^!_tH@`pIT1f<-cq%r$oBDpJ%6r z*U3kkj`DP}Ik(=gC362_@|}e=*|F@N)v84%!rC=}YU=9b>RarNbdmJiFo9sr9Ey+N z%u=6ZWUL^c%vXfX%-P1@6gZ(G$*RIAb4}S`3ST*6=G{hHi>CS=tgu2IktyimVTqcZ zgcRg`x3WjwhvgrY1-<3C8~mhsedHKMd8nQs+AhZmHrN(6cV6T}%F`0ENgL~t5gjz3 z@Wo=!O#`?!_<4*Gabg-((o(d1kK% zz|~q?AKTr6`uhcMofa}NO!3n^)JHoUtAPMb4fe}nze)vSw-H2byGtj^Se=}H9N2of z9ADp0Pa5Sopeb-t>$T7+G&i(*eIak?dh+2W{@g)t2phk(cszqk>f;5yK^+aest)Dy z;%9KTE3{`kQ>0|S%QKv5P_<&K+BLkb^^B;1D!M$-xhu=n*5=$3J%LNI)%?4Jy7`33 z`Dw60W`XA!OUB;G14_oFHeBxCzCJ`+-U*?0o_Ty~ZZa6|EHX7GKn>we9?qhjvYz@g zEIVPrw%dME28QgZh`(rN0z?)26gRg>P2}&`h=ne2O>R5n)|m>KMP;T$W4wx1hGUXS20I@ zFHa6f*BSI~N@Qktie@a??5aMRey7}MW!(@AH&-EBb#77fOkh0ln2l7xo!^}Yy40~a zbF$KYx1<#m=wA2%GM0Dyq2O_}NR~nMtxK0FA{=g`Z(fE9UMl9m(y|M%gSzf8)}SE% zFOkQqDP`no!?yJfGopu}-3rUTxIF{J?8ae7%qfb}P zxIO>$9dxB-USR8GM3T#3^XKp248MXAet1vTLeJ2R!aj)oR7YRh2t%7Pq(V}_nfUrc zFp}Z!l*M^voZ;=Mt@Y#f1%0Wd`#CA<15lXK`JIndBWc|;)uFM~HFAz_CeiHqWM?*T zG<=<8!Ik2v!;($kyprAjMrrAu1{UbWA-@`X=0bcq%SPIS(Dr(KX_lQ@5<*dVRMEY! 
z8<9U_MG|PN-uY%VntUjnu)&*w&J1egr!4Z*XZ#y3_WDw`AEU{Y zKid28zlV8>?D>3W`vMkZyw*S*zK5=c5(l)6jo>2qrWB0fs9*Z!L!DSg9lAJj%s&Bz z=|A1#|BHa)^(?VWIroi2;e?t=po(f}g;SU+-JhHW)X8@~_r!UC_Ahn@87{QkHV&7j zR&i)PuNgc~l*GWUvAnfsUY=j0SIex%Xe|${tHY*I$-W8S7DBLJ0}jjP!VTG2u6|Fc zPQRKYT)Jx?(*#A8t1l%hQ#d6%W9O&D05ph5%1s)_*32~#)hWaC9$VdTD`Zk+3uprI|cnDp_*VXy5hxKpQ)Zt@vDY@?{Q~W=%`|`Q^ zYtSj>0hZim67?3uaZPfNI1T#BwTVIvP%$yDa6VHl-ula%lgD)|qY?2@%Ccsv^n~2C z*UjJK?kx)y%H!AZk{!UiX2Nz)(^#wnol*Un|2Xw=BXR~&V>>@Qv6-HNN!s$MW@3GNU4%8f` z9kM{7c_%l?`3AFdI@9Kg1w|gtd2t2e{Z^a+c!lXzWyx`|f5$zn!xB#4eju+lQ&+;g zQ=xn{aF%1!A$V2aE?W7C>*iv1b*GZ=DgU^0RVIziCI@kyd;(q~d2-d~shn(_52O#X z+IFOeEA67uY?-fcr=mHXyoFGywc6S=m&u^hfo04x97LA6{?mD5qO-ajIZ2UE1w^eB(H!goip0XZ7)7wEJt(P^>ArE{{atshi*yGEt zoP+(ws->%nzPFtfJ&`(Xv2!vaGX}_6B-C6`)MY;ngWNJ_K!JRopE+mF})vvY}TW#LaeeI0tjrPgvcslZEf3*(vXWMViyZ8Y*JU(+-y?;Vl z#)SJ<8bhV&qB@YVdhh>(c0aCOz!&@*FgS8H(dqe#^*gChLI&0y7L@ZcezZ zH(eu;M4?=c_wo&vwNAB|R3D8ADK6cAF%-tq!<{M~pzCYw-+)^N2V$-GLcSb`_*>Ry zoQZoz&wDF{f}43=tJ5egJ*Ye2pc)DHH)rDvqCPpFf>FkScH#x>&{nxaPbqEk z=CA~~12(_&3B3#Rg_R3a?~}4z>`F^^hAjoh%tko_f_%a zOyPa=$Oi`927lFpnZ5D@RLYkCjFJ1V8J5KR*Z3?l)PYF3f_(Kv!FDYH_^VqCy%aGb zo<3XSOKj-T|5u^cBsd+4SdskxE4+=xnottn?+Kxbi-*AI*)JE}-Y7eep>(=oB>>L= z8$#Iyox+prZiqviS6#I`b`(x{z-Or0lp%w`{1-qbgxF3A3H0X%<*^ljK~L`+sKW6h z^hF=9IP>r}&yUKp$_Q#x)k0#s($Gg~Q3!?tX0(C+YLCTQF)HW+;=n1g<9Lpv;ljZF zt7VGW{TWQv$IHM(Cl?Ra94HYO6YLj!nj3m)&=h7Og`gH8w9Enhe42kh-;KkDBPG6y zBwVQzlz0ECEdTFkn3zOFA(kjT#+y2R2fND=1augZ0A3*<9{9nDyewTZP&MkqN^5UW zRbce^@gqUaA9kmt90Qa%r%#n&;O;s88io}XFlM8iHqahtxp1QXd5R%-4J|;AgwR6u zp25zHw+ZX7$bWDG&jR*pPvr>i1%aMYpv?L~x*_-~b?f?Hv4^;DHIRFD%=-kK*B9*s zGRXlhLSdd0DfgP5tTTc zp#@G5qT~39AKZx8^UZ*OBLHBaeu_$GM*3gpg4W%nhTws;8`3;yanIM&0Y-P|(TUZW zfE-?a0M7|D<$`(im9STR!>k3%%sp{U*sd0#2f{II_3y%)jX(F>XgO9)ox}~Sb8>+5 ziLb{bX2{kdRa3x@3rK-3&q}Y*{m$`!^1>o0fQI?FO`I(P3}vP|3cm{=X8wS&Q$mCb`u*~ctLkOFyDoJs zyD%w^i4-$7La>PYVfpKzfeoNH!)j76+~!}Xo>#3KutCMNMaTF~ezf@mrRyz}TD;X? 
z35VK%)v{y|9yb}u?;iEXG4GWJv{Z#+`W3I9?ZE?gh6qjKz7Qn3OZ#RJU!tIo4xeUs z0dEGCGj^F*Ar{TBeWp%Wx^kGqt=v)8V+uLg- zprXaA9a}Q_k3||XtWn1!=5OPUY!H50ggOeSgg41l3>0p3h{*01n|v85^qikROw}@d zH#-3fMr>BPDn{|OoiTVHY2BmuE+6;q89s+3qjYGjlO=t2e_CbBxzoXWak&^JRLEwa zuOl52pBGNy8Q-bB*c@&2}tt%hEd?v!257)dRiygAOlGv@j(`*zt?YYC31WZw8ovRG|nvNCGwx%Ew<}y3kn`0<;STJDp+r3Y<9w zSlb2;vOvdODWY?C?)JssKKov*S-`}f@I-1qY!FXX$lvakdCYynm~b;ao+WL?`Mh#^ zmYml2^Qfjg?8$aPz-1xHA1+1sy9>YqxxIPl4muwAj-=)}wf&@vQUCQ>b6(pnZR<7) zGO1Q*O|9rOY56iNWd+MfsqE?-L!)oC>^4WbCJn_B;52Y1?aE^W(C$tsR55daXxRBH zF=`38DWVu&&`%P}>iALs{)L3-y6-4&${q&m*95!-22KSW=AAijtWBK%W*~yH#VfXy z(#FRi5=@?pX`XabQjC5umV+JhOn(6$&?PEb>MUOM6L0-Z|KJYAM^9wX1$pELV~#N3 z&5a!uc_wb9+19xHhupurHAQ#mr{h_r0ko2ake>vwi{C9+ufw3CuwPZt5WsW!f#>@5 zU0O6A7>=va=Jytn@pd9zA|fr4AdB=Cf$u73p`P*h-)`ts%NdvafJx(pNuE!sT-oME zekKB9vWWa=b`aoZmN``vRyfJ*HJ=W|@{;`2V&a>(&3 zrhS<0KfJeLcGh)P;s|P%=r8DDz&Lqdo`sSe?5e)L01-eVyacMkWY4(>q(IH&8}s6d zzYq#?3(vevMBwibLrFR(g+5R0|NrvV_;FJl0>EKmB>%4^0-kq7tvN~Fv8$8N zN?}VKPhK0~)u@hPaeqUPIronUaj0wZg}SCOu@C6LHSzdKIxUn?gkn4ZrvPEN?Nd1e z*Z^(E=2T9__6rr=$jT$V7l(H0AmB)g6^L^F6Cdu4L6qC#S)=%93H`;ZfEr+UX~2ae z|L5AE)v%jvs)MD$$_|VXh7rPw)rUnZ$!~ft%L5ih9eM#)c?)tcC?JbTVoIdBlJkS7%olm1VT0DG55>wp1nEY)yQEQR z>5^_fkOpb#mPT4(&g-4K=HB@+Yi9VxTKwQW`|PvxdG;>x{OQ2RHLhX&qdT0JI=~y4 zKe*w3U4R>O+`gLS77ORu(8!2wiSjk=ZomTQJxqXDEF%P#N)}S`BWIenXRtDq%utec z(+GSv#X6g1$=-+`%CB>LczrJaTyg_qpkA|U))@j>LCZ$MzB9lvIiGGQgbJ)FgnS8J z%fpMR1yhGVQao!0lby|7??0XcUCKIhTkZ#)!uRiG5E{!jb4yes>Ju-0VNfElnPiJH<@_ZPhzyFb=C^(yZDC}iMq z7-zL?+Ue*HjU79P{xAjwX)zNOJXGlDI`qn<&$qMOvdQ}yT0JlF09c-inl(NW;`ft# z?W=6ESmO>spl>Zajn;ps;N-Ur<-)q1MOc!Q`X!L zospWFzvAAWES!3cE#KeScl|CScowg)Nsy!d=jPn(^nAxiT)(~rH!)ig0ryMG&HRhO z(W4KQvnT7x+`rmZYaQ1Mf;@US#9`DUgB$_?s;ER8xDO~H5;3CQkYi2;vMa`s@`ap) ze1Nzb0g8rrP}v~bm7^GXjW^exDCVE^tW)35H3;;4Pg%BJwD2)%o!(Slu5Iiz;}w3J zzEALxe!w|kyVJhasw2S3HA-`)x6$*oXeM!a)w4Kl+U7T`&suqg$Z@%Ep_9|58|~-E z3O506x!17ZJuW*(r6Lh7N^O@%RI++p3daM|R7$DrS-S`4_ubglWCmJg?umPP2&tTA zedE261lM&}DEmt)-uG>-cWkCE{64cbA$oN!F)?Jfi%d4BKXKNQTHa5UKm5T_obDVi zb_TJ$&f+a^;P(MMYCcb95`$d9Q_4RrX!4i6VX3#(?XNku%9 za%uuEzi1z~+~19)k?GP$`6aaR$6tMy*tFPGFlOtqjU_jf<`~K7#yEI>(kN%W+9i{x z5#XRC9Y?7pb|rxe2sf0$zOSoi(4tASl}oiN3belHsVfXfHh~s}2U%}#td2|S@jEe6 zt7hXM6V6T)X$r}rI5eaQT9av!oB~`%WT*fX0%ZfoaCsi!_az3b8@ZNCkpL7*lZWs% z_$ZFCEQf$p!idaj2{06#szw9Bj!JT$hCY-!yGIF%caQ*2s7}U!s;h}RZDYJZZHV9) zWaRLD=Hi>RhILoZNlX5PSWB?$!Tj)OH)g&uThdI^7rG|`PqJmM(ppcqt8)G%^~I2> z=<4dWkb2>b7pP?H=AhguaI|Vl(#VO*e>u}zn3-t@Qxlk{C*r(NG`2??UYyxfeEzaf zuWOB0JA89@u>d>yb$1&&xz_tVKaEKvyMXAK+isv_Uo+-O&$KA90sAr z(?J&Nl@9nG+YQ|VGodC@QSTryG5a-|1v6t)P7LWaR-L`!rku4klil(iCNm>rZu8-u z0vNNMfButC&DxrXOxS&_yx!;bV)|UDNTV>ud2h;kp2ZM2N#h%IqLg zQfBz;Mg}(aTlDh{f-bhYIM+J>jONwqmbg> zov}T8C)Z)Q+4^u%ASok*gb(yIn~U9!M6oNFm|0jheV@J>m@&@z`1^M`_U34o(Cb!Y zrktD0BRP!%<(S4k)rE1Bquv2xK$4EFg1)(nTHC7eH)?@*QQp5{v5*&6Gi`5%7`e zp>Z-%iP}!D7ALToYPOik$iW7%V`LDHX!e7M>U+@n{7k((gw-;!c&Lkvhw(|<_>YD% zSzeI4WXf2b^s;i8)3Q(Zl|7@~d~H<^nb1{fu;|SWjSWiDQjau6oy|rmvAc)f3zwEmXDHkH@S)??KDtx~ck0=Ik6cj^~BHEYtvI3mCjXc<}(qv90EL;=c006io?ZVP^})&9PbKo=s8Z+e$X9)nGD=G?n@qH620&a>q57Z&^Uc$< zW|^vi>K=g{>B`+HCMOgUE!AZJi1aB_*HA;V%Ag5+Nazbu3jM+Q8hmHOzpA4QR3`$y z+Zzsz){n7(1ohCdw}%SYZk-|FAXY#%N5ug74?!vdp_P(}zQg{-!r3tmGVSm9$skRE#F zkso;X=XRmFrE-SY(6JXUyjq=_O37lAV5-)donxljLTSFi0aMKHbRcIEonvTi`)sS# zs=b~u2@CM_YOa@f)NR9-JYK&;%CITF7G18QGk?AefEqVC9}9AE#(3qB@t*$v>$xB( z><%(`V3fIO;#Xp8c-jxnpu}WJkMs!+91^XQ^;Zzt>dl%HRNq#8RCoB15p#IOFZH;Z zni*e{LPPC<_z>lL&?^e02^x@HdoK1Ug=1Tym+gtckIOP57P|znQ+D2L))sTWT5~6_ z`PT>{nAPLUf?h={h#=@T*ENG!hunV%o`2kHRK~LzES8Hs6c(ZVrPJ&heO{7fFZ}1E 
zKPUI=JL8RQ+smXn8?o%_L&k-S8P``GPpQTM%K1#XMMjh&6gZTlcrFg*dG zLBLNGw>oNVJDT`UWGju5hX)bflf}--pHs#Cu~H~D(`EIuzfiMY6tO-Z^90sbYq5wG z3#oP_GJx_S1LUnb(t&(NqT&!CKBxiIU&9l9f*in))Y8Bo=18eZC1aT%F zDzJ%a3vx0MC+@O)gMwJh{X)y(4|CQ4Go5~3l9Qs@>}n_In4fTz<}^Jo5E9PMZrJoh zD`h8HJ7ui(a_QS7G6w6Xm!(=KXUB;2?$^K1MyXxh=KvmyE<{%e&4krD;(S5gy#*dJqKJfj)Ah`qmcK;4HU~(n2 zfxU&|n@X%tginW1;`}f1Q8+OgcM1UF)W5KSF^JXEqqO|-dWwG6Qby^bXVfX>%41+_ ziekWa(h{_9H4annA>>2=&nvkUro_UDD5CZ)I~K;qtii$3V2011KXaROP?l&{ko#P3 zGO-JObw8oz0k8k{Nam`=9B?`JDy?3ZYu1;`_Y|OkU2%4H9GER~R4dKS;xVy5*Zvs| zm;Ur@ZOxI}M(bRaxha?x*vzWgp^bxA&2^QA`jR;>O9?Xq@+Kd712QoDPSM0RbV#2O z;zQ_=;$h%^sc^3Tz6BZRl&Gvj$Z7vPW%>X0l+pOrkBFQo$Zq#_l-tz0+I((zriv%? zCL5QF+w{WFUPMInjg=&Sp30N*(yMi++T7yZtS@|ikEfD#6Xnz9cW-3lmV*C#zRAU5 z91&C4IQ9xe`i3!-@1iH_SAB#KUuqULx)b75^{ zME=Fd(6Ww>&pF*0UAJU5DnFS;d(IZ0!-%wdYqzlQ^kP>Gz~#l>P3^cL$PQ=thCaRr zYW?3rvekQeZ3DcaKyh(YcLztTb^pk6u%Zy%-7``WmY>?NHxKFe{(tHBe+4PRx3X+7 z2q7~IrZXys6b1QSKB{y@Dh^YO7FKx)!_gHy-T$>aIaO!FH{b5p?nC7nnsmHKz;004 z%era}A#_0yBP7hU)+iB(?7Tl)cwrMMNdPhqi%6IGhW7OKK1OGihB!J+hZcvv0;N>% z5qcx?b8Db%ucL_H%XSobJhM7pceLF4$ zNSW}@9n?{Ef1PR>lJv1P1hw*Me?2GMl=#;yDCPBX(E|YxiX+&(kK4>}XbZxiqLc0O5r5-v ztXL4tnw)GrcKqsI6`kNN=0k&(eZvt59KdWBS!O0(X^;)XT0c_)DoUf0$sb_fv`hbH zx{PSF+m&gqYTd^ZVB2>t+x@}|8H9>nmg0+CF4p{738Z0UVTpV#@^%GoIE{1ijeZ?j z7GNC=DuOVu@#VebPZz~;TScq6>MPquL_G|E?*_C>qzHfYGs|39w=U?6QbJ1uVdJ4k z;K+KYCRFn1mYO(>2uu2;J^U;Afq-|6kiC=-R&;NEY2uC2w@34GD0lkFIT-k_?wED7 zO{9=rMzt_Z!{*f|r6vBHA~ZnNJxsgx!&Z4N-U$o5T2h7yJMF}%!)WTDxAG7WwAEZ{NCRHD9QzmfJD+nq;T$pq{=`TTQ@3 zovQRUxp%w8_Ti^}kt$E#aDc~h^vzLe9@j+IM~)SLz$tywtVG{hD??gueTtKnq&}@= zzP~nRALUNfC0u8kA*Kk2KD?Gmfjl0ImMJ@2iL1-5;kG(R=lX44QA9L~P>9CsUp8f# zzb{tKiBm12ddZzLd2|DRrn+e2ge=c}{x;3wyKV4_9`T)(1}9qdr(%g1rJ)E?r3t|# zT2a*lCKi^*3obqL@t>`~m(!?jCK)vJ+4!ixjq&jC_}W|0;hNm4$H>4S7ryXvl*t@v zagIR-qk|RoDg{Wn?J*lEsPR07MnoW^pr}Cg9u#BwALk`PgZ5VIwol z5_93Q1YFH1eFc8Sc`AB^$4rJ?g)^U~Y}wMuZmVG3j<**n5_4R%oyw0BIPI-F9kfj& zU(sHDspxfbR(g4mjLjm(E?$YTTBS-g1^}fO&mtw0*E=Hcq!5e1K|NSJ39~r(=BiFkHN&AtRx{2o`0x7i3g^LED zgtRyDxyEob9G9U;zu|E3r?O_i*`KT;bHQBz)lNvj*_*E`wacs;7#)3HNZrXrFH(wS z%T*m94hMZT0^AdJP-*w~hXcBXL`j7XX?!EqN@LtsAO9c(;-WN}=%%W{zj$g-x2pvz z>;Y!ebp5`d7k}yx8Z-Eu`qF0$k{@)>vKCNZre2zrm9(Bss*BGQfa0YFLkNcg8Z4cM zjYWctHDThrm-^BBs+u^Nic9bnV?PDnAi^i2fv-@e52%N{wZABcEOyMIKL9}@D0Vnf zvE>@9@S>V$|BL9{K$G;pi%#^3+$>8J!dXQW*%N8)XIXnVRc0LS5$4Vd3~Ebk0lDQW zw=S!_a~wb0oSXW8xH2d4MfhjNr$IuDdcR)P$G^PHMfN^UhF)bB0lpn)Oo;tDnfmwN zJvY#ju>@j{4Jd#JNJ8Uh|MVHObins*U}5RVzbgw9SxGvH1B)Q#rV7kK_?RTmU*GL` z2YNB_c?*v~ctOW#6XUoLQ*!D=z8Ep=7IRtQuS9)-xBF4j^^yLuvwh9x*MeRjo1!|= z0Y8?3um0qsvLU%pUk9YpfB_4K8iYG`ZNl4PL<*87 zeyB}>O%_6qE*U z@sPNoZXLGm|0JwR8YNSce~2ipvV~=&3aAsgor8`O_>^OsresJ+IKx9n+fEw{peHu= z9Nb#^7f{ytt8*Mc=eFGAvF0B^wUBhE!EM?2Q47jw<-`5gO$`*MFAz#S}q`Ib$w=HWhtlT8G%d!WkBBQM_ zZaK9II8!FcZL29mKVq;r;U{ejwF~~1A#*|nZ&bHE-b^7zTeQJ6uPyUrSqLC8T2eb( z{juZDNT^N5R@fm8B2)yKwiIHZ<=@zvAD)#tbh3jJ6-G^ICXx2{B}|qDi$n=1+yJ_Y z!phZgf{zy%6VVN$cUFjr3r^>=j#1dLM`lmsSsp;7{CVl$g5UYV8$l-{X0*3QQy@}p zx|PX{m=7Ck{zaABB`5YC!K|qHQ35nkVb_o26{BH)Ye$s1 c5gqr4ZmF=eaI4>faNtKlMpe38(&YVr0AKrvF8}}l literal 0 HcmV?d00001 From 9ca350146fd0e6a022c985032c189d2825bba65f Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 20 Mar 2024 11:22:27 -0700 Subject: [PATCH 32/44] Link to Zero to Thunder studio (PR2492) --- README.md | 4 ++++ docs/source/conf.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0a19713d1c..d5d2bc29b4 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,10 @@ Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). 
 
 **NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way.
 
+## Start with Thunder
+
+Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial).
+
 ## Install Thunder
 
 Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, which will also install the matching PyTorch nightly:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index dcf2414b54..052fa2437c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -49,7 +49,10 @@
 github_user = "Lightning-AI"
 github_repo = project
 
-linkcheck_ignore = [rf"https://github.com/Lightning-AI/lightning-thunder(/.*|\.git)"]
+linkcheck_ignore = [
+    rf"https://github.com/Lightning-AI/lightning-thunder(/.*|\.git)",
+    rf"https://github.com/Lightning-AI/.*/blob/.*#.*",  # github anchors are tricky
+]
 
 # -- Project documents -------------------------------------------------------
 

From 2b82a6b14dd46fe6148fdc2905dbe8870e4aa77e Mon Sep 17 00:00:00 2001
From: Jirka
Date: Wed, 20 Mar 2024 20:29:25 +0100
Subject: [PATCH 33/44] releasing `0.1.0`

---
 .github/workflows/ci-testing.yml |  3 ++-
 setup.py                         | 19 +++++++++++++++++--
 thunder/__about__.py             |  8 ++------
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index 34751f2795..cbd8fb2aa6 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -79,7 +79,8 @@ jobs:
       - name: Install package & dependencies
         run: |
           pip --version
-          pip install -e '.[test]' -U \
+          pip install -e . -U \
+            -r requirements/test.txt \
             --find-links=${TORCH_URL} ${PIP_EXTRA_FLAG}
           pip list
         shell: bash
diff --git a/setup.py b/setup.py
index 566f2bbc40..b0ee14e897 100755
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@ def _load_py_module(fname, pkg="thunder"):
 
 def _load_requirements(path_dir: str, file_name: str = "requirements.txt") -> list:
     reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines())
-    return list(map(str, reqs))
+    return [r for r in list(map(str, reqs)) if "@" not in r]
 
 
 def _prepare_extras(
@@ -43,6 +43,19 @@ def _prepare_extras(
     return extras
 
 
+def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
+    """Load the readme as the long package description."""
+    path_readme = os.path.join(path_dir, "README.md")
+    with open(path_readme, encoding="utf-8") as fp:
+        text = fp.read()
+    # https://github.com/Lightning-AI/lightning-thunder/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
+    github_source_url = os.path.join(homepage, "raw", version)
+    # replace relative repository paths with absolute links to the release
+    # do not replace all "docs", since the readme points some other sources at a particular docs path
+    text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}")
+    return text
+
+
 about = _load_py_module("__about__.py")
 
 # https://packaging.python.org/discussions/install-requires-vs-requirements /
@@ -58,7 +71,9 @@ def _prepare_extras(
     download_url="https://github.com/Lightning-AI/lightning-thunder",
     license=about.__license__,
     packages=find_packages(exclude=["thunder/tests", "docs"]),
-    long_description=about.__long_doc__,
+    long_description=_load_readme_description(
+        path_dir=_PATH_ROOT, homepage=about.__homepage__, version=about.__version__
+    ),
     long_description_content_type="text/markdown",
     include_package_data=True,
     zip_safe=False,
diff --git a/thunder/__about__.py b/thunder/__about__.py
index d9fb64ba1d..15e838ef4f 100644 --- a/thunder/__about__.py +++ b/thunder/__about__.py @@ -1,21 +1,17 @@ -__version__ = "0.0.0dev" +__version__ = "0.1.0" __author__ = "Lightning-AI et al" __author_email__ = "community@lightning.ai" __license__ = "Apache 2.0" __copyright__ = f"2024 {__author__}" __homepage__ = "https://github.com/Lightning-AI/lightning-thunder" __docs__ = "Lightning Thunder project." -# todo: consider loading Readme here... -__long_doc__ = """ -Lightning Thunder is a deep learning compiler for PyTorch. -""" + __all__ = [ "__author__", "__author_email__", "__copyright__", "__docs__", - "__long_doc__", "__homepage__", "__license__", "__version__", From 4cac85243bbc01c823611141729a0c310c2c72e0 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 20 Mar 2024 15:36:50 -0700 Subject: [PATCH 34/44] Update README.md --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d5d2bc29b4..9e32d31ec9 100644 --- a/README.md +++ b/README.md @@ -36,14 +36,20 @@ pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.c Install Thunder: +```bash +pip install lightning-thunder +``` + +It's actually not a bad idea to install directly from `main`: + ```bash pip install git+https://github.com/Lightning-AI/lightning-thunder.git ``` -or install from the local repo: +or from the local repo if you want to tinker with the internals: ```bash -pip install . +pip install -e . ``` ## Hello World From 5bedfc41bfe3fa2c3535c71e6bf572a467f5e4dc Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 20 Mar 2024 23:56:34 +0100 Subject: [PATCH 35/44] readme: fix pre-commit link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9e32d31ec9..8f461fb518 100644 --- a/README.md +++ b/README.md @@ -172,4 +172,4 @@ See LICENSE file for details. 
[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml) [![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml) [![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest) -[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg?badge_token=mqheL1-cTn-280Vx4cJUdg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main?badge_token=mqheL1-cTn-280Vx4cJUdg) +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main) From 348597fd045903aa232b6811bf6bffa392edbd65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Mar 2024 23:41:26 -0700 Subject: [PATCH 36/44] Use the throughput utility for benchmarking (#21) --- examples/lit-gpt/test_parametrized.py | 49 ++++++------- thunder/benchmarks/benchmark_litgpt.py | 99 +++++++++++--------------- 2 files changed, 64 insertions(+), 84 deletions(-) diff --git a/examples/lit-gpt/test_parametrized.py b/examples/lit-gpt/test_parametrized.py index bca55173fa..5e658b6447 100644 --- a/examples/lit-gpt/test_parametrized.py +++ b/examples/lit-gpt/test_parametrized.py @@ -7,12 +7,13 @@ MID_BENCHMARK_OUT - use this env variable to control whether you want to see the combined results between each test. BENCHMARK_OUT_FORMAT - use this env variable to control the format in which the results are presented. - Uses 'xlsx' by default. More format support to come soon. + Uses 'xlsx' by default. Supported: 'none', 'print', 'xlsx'. 
''' import torch from absl.testing import parameterized from absl.testing import absltest +from collections import defaultdict import os import subprocess import json @@ -48,6 +49,9 @@ def add_to_dataframe(self): self.dataframe_data.append(self.perf_metrics_dict) def complete_dataframe(self, is_teardown): + if not self.dataframe_data: + # The benchmark probably failed + return #Called when tearing down the parametrized test #This generates a summarized dataframe for each perf metric and saves as a xlsx file df = pd.DataFrame(self.dataframe_data) @@ -59,7 +63,7 @@ def complete_dataframe(self, is_teardown): self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index() self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index() - if self.output_format not in ('none', 'print'): + if self.output_format == "xlsx": output_ext = {'xlsx': '.xlsx', }[self.output_format] if not is_teardown: filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext) @@ -84,7 +88,6 @@ def complete_dataframe(self, is_teardown): print(self.memory_used_GB_df) def run_benchmark(self, kwargs): - # benchmark_file = 'thunder/benchmarks/benchmark_litgpt.py' command_list = [] for key, val in kwargs.items(): command_list.append("--" + str(key) + "=" + str(val)) @@ -98,32 +101,26 @@ def run_benchmark(self, kwargs): print(f'Running {" ".join(subprocess_cmd)!r}') proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True) + + self.perf_metrics_dict = {} + if os.path.exists(self.json_file_path): + with open(self.json_file_path, 'r') as file: + self.perf_metrics_dict = json.load(file) + # Cleanup after the benchmark finishes. It might have failed before creating this + os.remove(self.json_file_path) + if proc_output.returncode: - print(proc_output.stdout) - print(proc_output.stderr) - proc_output.check_returncode() - - with open(self.json_file_path, 'r') as file: - self.perf_metrics_dict = json.load(file) - os.remove(self.json_file_path) #cleanup after test finishes - - if self.perf_metrics_dict['average_iter_time'] is None: - if 'CUDA out of memory' in proc_output.stdout: - self.perf_metrics_dict['average_iter_time'] = 'OOM' - self.perf_metrics_dict['model_flops'] = 'OOM' - self.perf_metrics_dict['model_flop_per_sec'] = 'OOM' - self.perf_metrics_dict['tokens_per_sec'] = 'OOM' - self.perf_metrics_dict['tokens_per_sec_per_gpu'] = 'OOM' - self.perf_metrics_dict['memory_used_GB'] = 'OOM' + if 'CUDA out of memory' in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr: + defaultdict_oom = defaultdict(lambda: "OOM") + defaultdict_oom.update(self.perf_metrics_dict) + self.perf_metrics_dict = defaultdict_oom pass_str = "TestCase did not finish reporting metrics due to CUDA out of memory error. Reporting OOM and triggering test success." return True, pass_str - else: - print(proc_output.stdout) - print(proc_output.stderr) - fail_str = "Testcase did not finish reporting metrics due to an unknown error. Triggering test failure." - return False, fail_str - else: - return True, "Test passed successfully." + print(proc_output.stdout) + print(proc_output.stderr) + fail_str = "TestCase did not finish reporting metrics due to an unknown error. Triggering test failure." + return False, fail_str + return True, "Test passed successfully." 
class Test(parameterized.TestCase): diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index ae9e8b2084..9120584989 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -9,13 +9,9 @@ import thunder from thunder.tests.lit_gpt_model import Config, GPT, Block -try: - from lightning.fabric.utilities.throughput import measure_flops +from lightning.fabric.utilities.throughput import measure_flops +from lightning.fabric.utilities import Throughput - # from lightning.fabric.utilities import Throughput - LIGHTNING_AVAILABLE = True -except: - LIGHTNING_AVAILABLE = False world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) @@ -109,7 +105,10 @@ def __init__( self.config.n_layer = n_layers # Initialize the model + t0 = time.perf_counter() + print(f"Loading model with {self.config.__dict__}") self.model = self.init_model() + print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") # Setup the distributed algorithm choices if self.distributed_mode != "none": @@ -138,14 +137,10 @@ def __init__( } def init_model(self): - print(f"Loading model with {self.config.__dict__}") init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device - t0 = time.perf_counter() with init_device: model = GPT(self.config) - model.to(dtype=torch.bfloat16) - print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - + model.to(dtype=torch.bfloat16) return model def setup_distributed(self): @@ -243,7 +238,7 @@ def pad_collate(batch): y_padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=-1) return x_padded, y_padded - train_data = DummyDataset(self.model.max_seq_length, self.dynamic) + train_data = DummyDataset(self.config.block_size, self.dynamic) train_dataloader = DataLoader( train_data, batch_size=self.micro_batch_size, num_workers=2, collate_fn=pad_collate ) @@ -251,24 +246,30 @@ def pad_collate(batch): return train_dataloader def calculate_model_flops(self): - input_ids, targets = next(self.train_data_iter) - input_ids = input_ids.to(device=self.device) - targets = targets.to(device=self.device) + meta = torch.device("meta") + device = self.device + self.device = meta + + # calculate flops on a meta-device model because we only care about the shapes and + # because the flops calculator installs hooks on the model + meta_model = self.init_model() - model_fwd = lambda: self.model(input_ids) + x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta) + model_fwd = lambda: meta_model(x) model_loss = lambda y: torch.nn.functional.cross_entropy( - y.reshape(-1, y.size(-1)), targets.reshape(-1), ignore_index=-1 + y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1 ) - if LIGHTNING_AVAILABLE: - self.perf_metrics["model_flops"] = measure_flops(self.model, model_fwd, model_loss) / 1e12 + self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss) + + self.device = device def train(self): t0 = None - # if global_rank in [0, None]: - # #Calculate the model FLOPs - # self.calculate_model_flops() - # Setup Perf Collection - # self.throughput = Throughput(window_size=10, world_size=world_size) + if global_rank in [0, None]: + # Calculate the model FLOPs + self.calculate_model_flops() + # Setup throughput Collection + self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size) if "transformerengine" in self.compile: 
import transformer_engine.pytorch as te @@ -326,45 +327,30 @@ def train(self): print( f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}" ) - - # if global_rank in [0, None] and i >=warmup_iter: - # self.throughput.update( - # time=(t1-t0), - # flops=self.model_flops, - # batches=i, - # samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), - # lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.model.max_seq_length), - # ) - - # metrics = self.throughput.compute() - # if i % 10 == 0: - # print(metrics) + if i >= self.warmup_iter: + self.throughput.update( + time=(t1 - t0), + flops=self.perf_metrics["model_flops"], + batches=i, + samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), + lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size), + ) if global_rank in [0, None]: # print(f"Total time: {(t1 - t0):.2f}s") - # print(f"Average time per iter: {((t1 - t0)*1000)/(max_iters-warmup_iter):.2f}ms") self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter) def add_perf_metrics(self): - # tokens_per_sec = total number of benchmarked iterations x global BS x block_size / total elapsed time (s) - # = global BS x block_size / (total elapsed time (s)/total number of benchmarked iterations) - # = global BS x block_size / average iter time (s) - self.perf_metrics["tokens_per_sec"] = ( - self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["average_iter_time"] - ) # tokens/s - if self.perf_metrics["model_flops"] is not None: - self.perf_metrics["model_flop_per_sec"] = ( - self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["average_iter_time"] - ) - if world_size is not None: - self.perf_metrics["model_flop_per_sec"] *= world_size + metrics = self.throughput.compute() + self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"]) + self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"]) self.perf_metrics["memory_used_GB"] = torch.cuda.max_memory_allocated() / 1e9 def add_model_info_to_metrics(self): if global_rank in [0, None]: self.perf_metrics["model_name"] = self.model_name self.perf_metrics["Num GPUS"] = world_size - self.perf_metrics["Seq Len"] = self.model.max_seq_length + self.perf_metrics["Seq Len"] = self.config.block_size self.perf_metrics["Micro BS"] = self.micro_batch_size self.perf_metrics["Global BS"] = self.global_batch_size self.perf_metrics["GA"] = self.gradient_accumulation_steps @@ -416,7 +402,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None benchmark.add_perf_metrics() print( - f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.model.max_seq_length}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}" + f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.config.block_size}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}" ) print( f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B" @@ -429,12 +415,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None print(f"Compiler: {benchmark.compile}") print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms") print(f"Memory used: 
{benchmark.perf_metrics['memory_used_GB']:.02f} GB") - print(f"Throughput (Tokens/s): {benchmark.perf_metrics['tokens_per_sec']:.02f} tokens/s") - print( - f"Normalized Throughput (Tokens/s/GPU): {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f} tokens/s/gpu" - ) - if benchmark.perf_metrics["model_flop_per_sec"] is not None: - print(f"Model TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec']:.02f} TFLOP/s") + print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}") + print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}") + print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}") except Exception as error: # Helps catch OutOfMemory Errors and post processing of errors From 65c5092b99d6859b85d5574ad4802a0b46659875 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Mar 2024 10:02:11 +0100 Subject: [PATCH 37/44] Bump hypothesis from 6.99.8 to 6.99.10 (#11) Bumps [hypothesis](https://github.com/HypothesisWorks/hypothesis) from 6.99.8 to 6.99.10. - [Release notes](https://github.com/HypothesisWorks/hypothesis/releases) - [Commits](https://github.com/HypothesisWorks/hypothesis/compare/hypothesis-python-6.99.8...hypothesis-python-6.99.10) --- updated-dependencies: - dependency-name: hypothesis dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index c6d95336ef..a1402bd69b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -8,7 +8,7 @@ pytest-timestamper ==0.0.9 graphviz ==0.20.1 fdm ==0.4.1 expecttest ==0.2.1 # for test_ddp.py -hypothesis ==6.98.15 # for test_ddp.py +hypothesis ==6.99.10 # for test_ddp.py numpy # for test_ops.py einops # for test_einops.py lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5 From 8f508f2b4330eb1bf29822af6e54a29b273447a3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Mar 2024 10:06:00 +0100 Subject: [PATCH 38/44] Bump ipython[all] from 8.22.1 to 8.22.2 (#13) Bumps [ipython[all]](https://github.com/ipython/ipython) from 8.22.1 to 8.22.2. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/8.22.1...8.22.2) --- updated-dependencies: - dependency-name: ipython[all] dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/docs.txt | 2 +- requirements/notebooks.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/docs.txt b/requirements/docs.txt index f12c617045..14615a08c4 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,7 +1,7 @@ sphinx ==5.3.0 myst-parser ==1.0.0 nbsphinx ==0.9.3 -ipython[all] ==8.22.1 +ipython[all] ==8.22.2 pandoc ==2.3 docutils >=0.16 sphinxcontrib-fulltoc ==1.2.0 diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt index d2f9d92e56..47a14902ca 100644 --- a/requirements/notebooks.txt +++ b/requirements/notebooks.txt @@ -1 +1 @@ -ipython[all] ==8.22.1 +ipython[all] ==8.22.2 From 14cef317d381f62c63bcc63fc919fa6fa9c24fe3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Mar 2024 10:07:15 +0100 Subject: [PATCH 39/44] Bump Lightning-AI/utilities from 0.10.1 to 0.11.0 (#16) * Bump Lightning-AI/utilities from 0.10.1 to 0.11.0 Bumps [Lightning-AI/utilities](https://github.com/lightning-ai/utilities) from 0.10.1 to 0.11.0. - [Release notes](https://github.com/lightning-ai/utilities/releases) - [Changelog](https://github.com/Lightning-AI/utilities/blob/main/CHANGELOG.md) - [Commits](https://github.com/lightning-ai/utilities/compare/v0.10.1...v0.11.0) --- updated-dependencies: - dependency-name: Lightning-AI/utilities dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * actions-ref: v0.11.0 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- .github/workflows/ci-checks.yml | 6 +++--- .github/workflows/docs-build.yml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index 491d275173..ae28b10c53 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -16,14 +16,14 @@ jobs: # actions-ref: main check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.10.1 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.0 with: azure-dir: ".azure" check-package: - uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.10.1 + uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.0 with: - actions-ref: v0.10.1 + actions-ref: v0.11.0 import-name: "thunder" artifact-name: dist-packages-${{ github.sha }} testing-matrix: | diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index a9c8d99b6f..d37e381b40 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -15,7 +15,7 @@ defaults: jobs: build-docs: - uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.10.1 + uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0 with: python-version: "3.10" requirements-file: "requirements/docs.txt" From a27bb65b1b9525c5b51d2b096ad122c10f5a2d6c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Mar 2024 10:10:44 +0100 Subject: [PATCH 40/44] Bump pypa/gh-action-pypi-publish from 1.8.12 to 1.8.14 (#18) Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.8.12 to 1.8.14. 
- [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.8.12...v1.8.14) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release-pypi.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index e97fedf8e0..078f9e6066 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -27,7 +27,7 @@ jobs: # We do this, since failures on test.pypi aren't that bad - name: Publish to Test PyPI if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@v1.8.12 + uses: pypa/gh-action-pypi-publish@v1.8.14 with: user: __token__ password: ${{ secrets.test_pypi_password }} @@ -35,7 +35,7 @@ jobs: - name: Publish distribution 📦 to PyPI if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@v1.8.12 + uses: pypa/gh-action-pypi-publish@v1.8.14 with: user: __token__ password: ${{ secrets.pypi_password }} From d37eebc5d2f5232633c37fc833afb65cecd5d40a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 21 Mar 2024 11:18:44 -0400 Subject: [PATCH 41/44] simplify readme 1/n (#31) * Update README.md * pre-commit: running and fixing... --------- Co-authored-by: github-actions[bot] --- README.md | 76 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 8f461fb518..d3d325a155 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,94 @@ -![](docs/source/_static/images/lightning_thunder_lightmode_nobyline.png) +
+<div align="center">
+<img alt="Thunder" src="docs/source/_static/images/lightning_thunder_lightmode_nobyline.png">
+
+
+**Make PyTorch models Lightning fast.**
+
+______________________________________________________________________
+
+<p align="center">
+  <a href="https://lightning.ai">Lightning.ai</a> •
+  <a href="#performance">Performance</a> •
+  <a href="#get-started">Get started</a> •
+  <a href="#install-thunder">Install</a> •
+  <a href="#hello-world">Examples</a> •
+  <a href="#features">Features</a> •
+  <a href="#documentation">Documentation</a> •
+</p>
+
+[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE)
+[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml)
+[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml)
+[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest)
+[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main)
+
+</div>
 
 # Welcome to ⚡ Lightning Thunder
 
-Lightning Thunder is a source-to-source compiler for PyTorch.
+**Thunder makes PyTorch models Lightning fast.**
 
-It makes PyTorch programs faster both on single accelerators or in distributed settings.
+Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (ie: nvFuser, torch.compile, cuDNN, and TransformerEngine FP8).
 
+Works on single accelerators and in multi-GPU settings.
 Thunder aims to be usable, understandable, and extensible.
 
 ## Performance
 
 Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best in class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
 
-![](docs/source/_static/images/training_throughput_single.png)
+<div align="center">
+<img alt="Thunder" src="docs/source/_static/images/training_throughput_single.png">
+</div>
 
-We achieve a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.
+Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.
 
 Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision, support for FSDP is underway).
 
-![](docs/source/_static/images/normalized_training_throughput_zero2.png)
+<div align="center">
+<img alt="Thunder" src="docs/source/_static/images/normalized_training_throughput_zero2.png">
+</div>
 
 **NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way.
 
-## Start with Thunder
+## Get started
 
 Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial).
 
 ## Install Thunder
 
-Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, which will also install the matching PyTorch nightly:
+Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, and Thunder together
 
 ```bash
+# install nvFuser which installs the matching nightly PyTorch
 pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
-```
 
-Install Thunder:
-
-```bash
+# install thunder
 pip install lightning-thunder
 ```
 
-It's actually not a bad idea to install directly from `main`:
+<details>
+  <summary>Advanced install options</summary>
+
+
+### Install from main
 
 ```bash
 pip install git+https://github.com/Lightning-AI/lightning-thunder.git
 ```
 
-or from the local repo if you want to tinker with the internals:
+### Install to tinker and contribute
+
+Install this way to tinker with the internals and contribute:
 
 ```bash
 pip install -e .
 ```
 
+</details>
    + + ## Hello World Here is a simple example of how Thunder lets you compile and run PyTorch code: @@ -82,7 +119,7 @@ print(result) The compiled function `jfoo` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of larger PyTorch programs. -## Running training +## Train models Thunder is in its early stages, it should not be used for production runs yet. @@ -102,7 +139,7 @@ python examples/lit-gpt/train_fsdp.py See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder. -## What's in the box +## Features Given a python callable or PyTorch module, Thunder can generate an optimized program that: @@ -132,7 +169,7 @@ Thunder doesn't generate code for accelerators directly. It acquires and transfo Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations. -## Build the documentation +## Documentation Docs are currently not hosted publicly. However you can build them locally really quickly: @@ -168,8 +205,3 @@ Thunder is very thoroughly tested, so expect this to take a while. Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license. See LICENSE file for details. - -[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml) -[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml) -[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest) -[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main) From f37ac34399c8f039637b527a79064d106abb4bc2 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 21 Mar 2024 17:20:36 +0200 Subject: [PATCH 42/44] Remove deregister_augmented_forward_and_backward (#30) --- thunder/core/transforms.py | 86 ----------------- thunder/executors/apex_entropyex.py | 74 --------------- thunder/executors/cudnnex.py | 108 +++------------------- thunder/executors/transformer_engineex.py | 25 +++-- thunder/tests/test_cudnn_executor.py | 6 +- 5 files changed, 25 insertions(+), 274 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index c24e44a3ce..772e65a84d 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1237,11 +1237,6 @@ def _embedding_prim_grad( def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable: - # If executor specific `aug_fwd_rule` exists then we will use that, - # so we return `None` here. 
- if get_executor_specific_aug_fwd_rule(bsym): - return None - cd = get_compile_data() executors_list = cd.executors_list if cd is not None else executors_list # Checks if the executor which has priority for this operation has a specific grad transform for it @@ -2484,15 +2479,6 @@ def zeros_like(x): } -@dataclass(**default_dataclass_params) -class RuleInfo: - checker: Callable - rule: Callable - fw_fallback: Callable - bw_fallback: Callable - executor: Executor - - def register_augmented_forward(op): """Decorator to register an augmented forward implementation for a symbol. @@ -2510,40 +2496,6 @@ def decorator(func): return decorator -def register_augmented_forward_with_checker(executor, op, checker, rule): - """Decorator to register an augmented forward implementation for a symbol. - - Args: - executor (Executor): Executor to which the rule applies. - op (Ops): Symbol for which to register the augmented forward implementation. - checker (Callable): Function that checks if the rule should be applied. - rule (Callable): Function that applies the rule. - """ - fw_fallback = augmented_forward_impls.get(op, None) - bw_fallback = backward_impls.get(op, None) - augmented_forward_impls[executor, op] = RuleInfo(checker, rule, fw_fallback, bw_fallback, executor) - - -def deregister_augmented_forward_and_backward(op): - """Deregisters an augmented forward implementation and a backward - implementation for a symbol. - - Args: - op (Ops): Symbol for which to deregister the augmented forward - implementation and the backward implementation. - - Returns: - None - """ - # Restore the fallback implementation if it exists - if isinstance(augmented_forward_impls[op], RuleInfo): - backward_impls[op] = augmented_forward_impls[op].bw_fallback - augmented_forward_impls[op] = augmented_forward_impls[op].fw_fallback - else: - del augmented_forward_impls[op] - del backward_impls[op] - - def register_backward(op): """Decorator to register a backward implementation for a symbol. @@ -3320,31 +3272,6 @@ def uniform_backward(primal, minval, maxval, g): nondifferentiable_vjp_symbols = (prims.PrimIDs.BITWISE_AND, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL) -def get_executor_specific_aug_fwd_rule(symbol: BoundSymbol) -> RuleInfo | None: - """Get executor specific augmented forward rule. - - Args: - symbol (BoundSymbol): BoundSymbol to get the rule for. - - Returns: - RuleInfo: Rule info for the symbol. - """ - cd = get_compile_data() - if cd is None: - return None - - # Search for the executor specific rules. When there are multiple rules - # for the same symbol, we use the left-most executor in the list (i.e. - # the one with the highest priority) and we fallback to the next one if - # the checker returns False. - for executor in cd.executors_list: - candidate = augmented_forward_impls.get((executor, symbol.sym.id)) - if isinstance(candidate, RuleInfo) and candidate.checker(*symbol.args, **symbol.kwargs): - return candidate - - return None - - def is_constant_for_vjp(symbol: prims.Symbol) -> bool: """Check if a symbol is constant for the VJP transform. 
@@ -3387,19 +3314,10 @@ def vjp_impl_const(symbol, *args, **kwargs): # Normal case, we have a proxy tangent vjp_impl = augmented_forward_impls.get(symbol.sym.id) - vjp_impl = get_executor_specific_aug_fwd_rule(symbol) or vjp_impl if _get_gradfn(symbol) is not None: vjp_impl, backward_fn = make_aug_forward_and_backward(symbol) - if isinstance(vjp_impl, RuleInfo): - # We should use this rule only if checker returns True for the current - # symbol's arguments - if vjp_impl.checker(*symbol.args, **symbol.kwargs): - vjp_impl = vjp_impl.rule - else: - vjp_impl = vjp_impl.fw_fallback - if vjp_impl is None: # We could not find a VJP for this symbol, so we try to decompose it if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs): @@ -3567,14 +3485,10 @@ def put_grad(v: Variable, val: Any) -> None: backward = backward_impls.get(symbol.sym.id) aug_forward = augmented_forward_impls.get(symbol.sym.id) - aug_forward = get_executor_specific_aug_fwd_rule(symbol) or aug_forward if _get_gradfn(symbol) is not None: aug_forward, backward = make_aug_forward_and_backward(symbol) - if isinstance(aug_forward, RuleInfo): - backward = backward_impls[aug_forward.executor, symbol.sym.id] - if backward is None: if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs): # We could not find a backward for this symbol, so we try to decompose it diff --git a/thunder/executors/apex_entropyex.py b/thunder/executors/apex_entropyex.py index 8a82e04e20..818199ad5b 100644 --- a/thunder/executors/apex_entropyex.py +++ b/thunder/executors/apex_entropyex.py @@ -11,10 +11,6 @@ from thunder.core.symbol import Symbol from thunder.core.utils import check, same_shape from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, restore_reduced_dims -from thunder.core.transforms import ( - register_augmented_forward_with_checker, - register_backward, -) from thunder.extend import OperatorExecutor, register_executor @@ -197,76 +193,6 @@ def _cross_entropy_checker( return True -# Check out the 'add vjp rule' dev tutorial on how to add a VJP rule for any -# Symbol. We use our new primitives to register a VJP rule for -# torch.nn.functional.cross_entropy. This function is registered as the -# augmented forward rule for torch.nn.functional.cross_entropy below -def apex_cross_entropy_forward_rule( - a, - target, - weight=None, - size_average=None, - ignore_index=-100, - reduce=None, - reduction="mean", - label_smoothing=0.0, -): - loss, max_log_sum_exp = apex_xentropy( - a, - target=target, - reduction=reduction, - label_smoothing=label_smoothing, - ) - primal = loss - saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing) - return primal, saved_for_backward - - -register_augmented_forward_with_checker( - apex_ex, - ltorch.cross_entropy.id, - _cross_entropy_checker, - apex_cross_entropy_forward_rule, -) - - -# This function is the backward rule for torch.nn.functional.cross_entropy. It -# accepts the primal output and saved_for_backward from the forward pass and -# returns the backward output. The backward output is a tuple of the backward -# output for each differentiable Tensor input to the forward pass. In this case, -# the forward pass has 1 such input, so the backward output is a single Tensor. 
-# This function is registered as the backward rule for -# torch.nn.functional.cross_entropy -@register_backward((apex_ex, ltorch.cross_entropy.id)) -def apex_cross_entropy_backward_rule( - logits, - labels, - max_log_sum_exp, - reduction, - smoothing, - grad, -): - from thunder.core.transforms import mean_backward, sum_backward - - if reduction == "mean": - grad = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), grad) - elif reduction == "sum": - grad = sum_backward(max_log_sum_exp.shape, (0,), grad) - elif reduction == "none": - pass - else: - raise ValueError(f"Invalid reduction: {reduction}") - - grad_logits = apex_xentropy_bwd( - grad, - logits, - target=labels, - max_log_sum_exp=max_log_sum_exp, - label_smoothing=smoothing, - ) - return grad_logits - - # Translate calls from torch.nn.functional.cross_entropy to apex_xentropy (when the checker above returns True) def _cross_entropy_transform( a: TensorProxy, diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 9fb6c50a48..75494cff5a 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -35,8 +35,6 @@ def cudnn_available() -> bool: get_grad, put_grad, put_grads, - register_augmented_forward_with_checker, - register_backward, ) from thunder.extend import OperatorExecutor, register_executor import thunder.torch as ltorch @@ -338,7 +336,11 @@ def _cudnn_sdpa_forward_checker( if d % 8 != 0 or d > 128: return False - return True + is_backward_supported = _cudnn_sdpa_backward_checker( + query, key, value, attn_mask, dropout_p, is_causal, scale=scale + ) + + return True and is_backward_supported @langctx("torch") @@ -601,99 +603,6 @@ def cudnn_sdpa_bwd_impl( ) -@langctx("torch") -def cudnn_sdpa_aug_fw_rule_checker( - query: TensorProxy, - key: TensorProxy, - value: TensorProxy, - attn_mask: None | TensorProxy, - dropout_p: float, - is_causal: bool, - *, - scale: None | float, -) -> bool: - from thunder.core.compile_data import get_compile_data - - cd = get_compile_data() - if cudnn_ex in cd.executors_list: - is_forward_supported = _cudnn_sdpa_forward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - is_backward_supported = _cudnn_sdpa_backward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - return is_forward_supported and is_backward_supported - return False - - -def cudnn_sdpa_aug_fw_rule( - query, - key, - value, - attn_mask=None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: float | None = None, -): - output, softmax_stats, seed, offset = cudnn_sdpa_fwd( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - saved_for_backward = ( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - output, - softmax_stats, - seed, - offset, - ) - return output, saved_for_backward - - -register_augmented_forward_with_checker( - cudnn_ex, - "torch.nn.functional.scaled_dot_product_attention", - cudnn_sdpa_aug_fw_rule_checker, - cudnn_sdpa_aug_fw_rule, -) - - -@register_backward((cudnn_ex, "torch.nn.functional.scaled_dot_product_attention")) -def cudnn_sdpa_backward_rule( - query: Proxy, - key: Proxy, - value: Proxy, - attn_mask: None | Proxy, - dropout_p: float, - is_causal: bool, - scale: None | float, - out: Proxy, - softmax_stats: Proxy, - seed: Proxy, - offset: Proxy, - grad_out: Proxy, -): - return cudnn_sdpa_bwd( - grad_out, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - out, - softmax_stats, - seed, - offset, - scale=scale, - ) - - 
@langctx("torch") def _cudnn_sdpa_transform( query: TensorProxy, @@ -726,7 +635,7 @@ def _cudnn_sdpa_grad( ) g = get_grad(primal) - grad_query, grad_key, grad_val, grad_attn_mask = cudnn_sdpa_bwd( + grads = cudnn_sdpa_bwd( g, query, key, @@ -740,6 +649,11 @@ def _cudnn_sdpa_grad( offset, scale=scale, ) + if attn_mask is None: + grad_query, grad_key, grad_val = grads + else: + grad_query, grad_key, grad_val, grad_attn_mask = grads + put_grads((query, key, value), (grad_query, grad_key, grad_val)) if attn_mask is not None: put_grad(attn_mask, grad_attn_mask) diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index 91d32e7a88..ced5d8fdb1 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -19,10 +19,6 @@ import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, CollectionProxy from thunder.core.symbol import Symbol -from thunder.core.transforms import ( - register_augmented_forward_with_checker, - register_backward, -) from thunder.extend import OperatorExecutor, register_executor __all__ = [ @@ -411,15 +407,6 @@ def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | Ten return False -register_augmented_forward_with_checker( - transformer_engine_ex, - prims.linear.id, - linear_forward_rule_checker, - linear_forwad_rule, -) - - -@register_backward((transformer_engine_ex, prims.linear.id)) def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad): return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx) @@ -429,9 +416,21 @@ def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.T return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False) +def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy: + out, saved_for_backward = linear_forwad_rule(a, w, b) + g = prims.get_grad(out) + ga, gw, gb = linear_backward_rule(*saved_for_backward, g) + prims.put_grad(a, ga) + prims.put_grad(w, gw) + if b is not None: + prims.put_grad(b, gb) + return out + + # Registers the implementation for torch.nn.functional.linear transformer_engine_ex.register_implementation( prims.linear, checker=_linear_checker, execution_transform=_linear_transform, + grad_transform=_linear_grad, ) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index c9bd6277ec..4128e02914 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -110,10 +110,8 @@ def test_cudnn_sdpa(): query = 1 * (torch.randn(shape_Q, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) key = 2 * (torch.randn(shape_K, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) value = 3 * (torch.randn(shape_V, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) - is_causal = False - attn_mask = torch.randn( - s_q, s_kv, requires_grad=False, device="cuda", dtype=thunder.torch.to_torch_dtype(dtype) - ) + is_causal = True + attn_mask = None expected = torch.nn.functional.scaled_dot_product_attention( query, key, value, is_causal=is_causal, attn_mask=attn_mask From 735b8758262aba9f8d6c7f9308ca6aae76395d6a Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 21 Mar 2024 12:03:36 -0500 Subject: [PATCH 43/44] Minor Readme cosmetics (#28) --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d3d325a155..7d60063864 100644 --- a/README.md +++ b/README.md @@ -36,7 
+36,7 @@ Thunder aims to be usable, understandable, and extensible. ## Performance -Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best in class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt). +Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
    Thunder

@@ -121,9 +121,9 @@ The compiled function `jfoo` takes and returns PyTorch tensors, just like the or

 ## Train models

-Thunder is in its early stages, it should not be used for production runs yet.
+Thunder is in its early stages and should not be used for production runs yet.

-However, it can already deliver outstanding performance on models supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama2, Gemma, Falcon, and derivatives.
+However, it can already deliver outstanding performance on LLM models supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others.

 Run training loop for Llama, single-GPU:

@@ -141,7 +141,7 @@ See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with T

 ## Features

-Given a python callable or PyTorch module, Thunder can generate an optimized program that:
+Given a Python callable or PyTorch module, Thunder can generate an optimized program that:

 - Computes its forward and backward passes
 - Coalesces operations into efficient fusion regions

@@ -204,4 +204,4 @@ Thunder is very thoroughly tested, so expect this to take a while.

 ## License

 Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
-See LICENSE file for details.
+See the [LICENSE](LICENSE) file for details.

From f0e57ed40b3054570e3c0b898f0171e21e58c39f Mon Sep 17 00:00:00 2001
From: Sugato Ray
Date: Thu, 21 Mar 2024 13:11:12 -0400
Subject: [PATCH 44/44] Add `py.typed` to support type-hinting PEP-561 (#32)

---
 thunder/py.typed | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 thunder/py.typed

diff --git a/thunder/py.typed b/thunder/py.typed
new file mode 100644
index 0000000000..e69de29bb2
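The commits above retire the executor-specific RuleInfo / `register_augmented_forward_with_checker` / `register_backward` plumbing in favor of per-executor `register_implementation(..., grad_transform=...)` calls, following the transformer_engine change. Below is a minimal sketch of that registration pattern for a toy executor; the names `myex` and `mymul` and the product rule itself are illustrative assumptions, not part of these patches:

import thunder.core.prims as prims
from thunder.core.proxies import TensorProxy
from thunder.extend import OperatorExecutor, register_executor

myex = OperatorExecutor("myex", version="0.1")
register_executor(myex)

def _mymul_impl(a, b):
    return a * b

# The executor's concrete kernel; `like=` supplies the symbol's meta function.
mymul = myex.register_operator("mymul", like=_mymul_impl, fn=_mymul_impl)

def _mul_checker(a: TensorProxy, b: TensorProxy) -> bool:
    # A single checker now gates both transforms, replacing the separate
    # RuleInfo checker/fw_fallback pair removed from transforms.py above.
    return True

def _mul_transform(a: TensorProxy, b: TensorProxy) -> TensorProxy:
    return mymul(a, b)

def _mul_grad(a: TensorProxy, b: TensorProxy) -> TensorProxy:
    # Same shape as _linear_grad above: run the forward, pull the output
    # cotangent with get_grad, and push input grads with put_grad.
    # (Assumes same-shape tensors; broadcasting is not handled here.)
    out = mymul(a, b)
    g = prims.get_grad(out)
    prims.put_grad(a, g * b)
    prims.put_grad(b, g * a)
    return out

myex.register_implementation(
    prims.mul,
    checker=_mul_checker,
    execution_transform=_mul_transform,
    grad_transform=_mul_grad,
)

With `grad_transform` in place there is no separate backward registry keyed on (executor, symbol id); the grad rule lives next to the execution transform it pairs with.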
zn9Wzs)mL?QKBV4W{3ecQYk$eLY~EVW1pj|O)BbP_|9brx6mN#Ls|SaU7=$1bpoBanm2|!P+9CMHF^xORg<-dFkMQ}*SG*S$tYl9&& z<`u>j)5?ftJqATBYu=-}toYE$t-xBw5SwkK(I*Zh42+@7ESS-%e4v8-r`cZG{!@Rk zn|&lu25!?d(vD>2K4iU}u+@q*paJ1Iyuhd|XWEkusQ|u*WiFArb)e8&^TwaHy7zwxKci!u3HqSZjesr4RX5S12TDZg0V0C1_c9G&K z;Ca7G5G6Uo9NMzl5#8P)=x8^Ndj_J)@lQ&=gFuMWpO@-J}EoE_KB>Sd%e|plaQ?!5e`VAGo z{ByQ8jlp$T{e`4Qtmw|8mGzIf@xd0Jtw0p0Am>JZ08VkX9Dwbq3pHkFG5_f=yPvt|4GGiVJ_QKO-O~sxddFZ0_08d<%6~&2o&PnHw;5jAClf&kA`wHxi|3#=&#oJ zW_NME=Luh0U)Z2a{~Pzp$Z!2^+LIQE%)!NRPq8QGt&I#dJ3F7v?jjgp4zxaSVqii;N$(x$>MK)%kti)|YfGMoRS zZ+h+`5z(p!fncIg0-eDEKk@2%tCEiv-&}9-^!dv;7ucM3XS)HwO&7Uj>7b_dnqb9$LCj$SH46N_%jhx}d7A>v-~VB-(GxON_rV@dx9vyN~plWz^)LWH|ypfWbG&nFyI z6VAT~^W9TLYmWtwOL^MCKA^SdWEJaM=#TL$MNH^l*BAl)I85q&8zTP-+SA0_`rB67 z=V-QpQE8K;`qY?yJ8xjV&C||qXUJ!JY-Ylmyt@?>MAl_IsE)djz{Hw{MdQ=9R~3piDxv2S%*i?=U=b z$_n49NO7>jRlEseG`VWwWQ|)(8Cl>({GJ6;7WplV$UTaKC_X_N|D#9i=Z%`y20%dN zOxn)ERU8D|0=2cZVda(vmL4;iAxI@YsmfA>V|IWn>6Q490bR$7boEdsq^CqRBbOu) zraYzU7~g5_Iny_+05I8NlItKhST}lj!M_Y6hFhTAO^$a=>BXyr6%ABH>buHuM1ui} z1no4ZSU5viTtZQ1H5Z@;o>&f?l9p+hK8&&WKEb#*Bs4fUxT39p?OXY{gZ}la`ba~Q z2p?j;eoI5sUK{LvPCOGbRGzo^)i*32NVO3eKJa|OS|onVeMB2tdg7lH4EE_mqXl}h zIdb{g_>N9d)x?3#3C8-P6;31cMg`I3=}cR6B$I0J{z$*;gEg!O=x&Pu$89QqMu%(g z_T#Jc*CO*TWS`<)c~iA(cVV@`wIr1~Pybralj8WmNN1yJ`@GaDOI_*5^{0sDtj?8| zO^=#FdRTQSUE(Mg?vahmB~YdyK4=@l4|-zMXHA$Pb)fQWA@AXlPK=V`5^81>&^$ra&gr3)|Lbg2d9mgN%V1pB#;`q{HyN z)2P`2@{X=_r9*spCc4$U#J!$8>?x&-5af6krlYwAbY#hA9)<`fiSQK~3ZJ4Bj#FRd ze6YaU{V4`&^69x(Js#}FP*Os!($F*__n_=EO--1Al3_8E-DiZAYOOE|>qfUMgyA(Shn+l`7ps&au z_IVc({PO#Ar+!j^9E&rW{Y~bgem{Vr(MLcvFA^=JW`ZGWp%7`Rb5J0(? z3FbaBK#RpT6_sNlmc=33J$_#q{WwAFA*8^f-rTY31JP!GWgQPP3+wd*S*gvMOr(hn zJ7DhBHaK|y`+WWG6QowTJjwp!5Fk8m5ry*~%N{3Pgve6v5?6#-tkQ}x0bN*^xNf3F z6mE=S{9F_^kro%_qx+Vs>GT< zx+bRPXiqj2iC#XY%g}4tx2Qqqsd?k{=Oqh%JomUJ#|pl|i_!&G+{;c$QOF6*Z}|x; zdQE6-I&J#f=gXn!P_2HsOdn?gYI>eK?>{s%R>>r`a_rx1E zcd)-#=EFTEyWr+H`Fwjyao(W4p2E+hf3Ed}@mR)WW>my*kA(4dM@Pu*i|xBS+iJ%* z6~7u4_{vmfNR8bn4_zxhncY^{FGjxN5^e{BcfXoOy<9Ds!v;ME= z#hJCvx%a*IwXbl)u{<#5g0qqV_Vec<)?uvwZUDjW_UA)VqizARBc%A?+miLVyo%`{FMhr&;bR!K&PDP#47sbPgdC{kbV)p)~ zi+pD2yc65&@#!0{F<(*pBU9 zY}YD5k}bj7!=t6?1DKN=E}{jT^-SiV*);HhyCPvfnWtNyi8;++tQm+D;RQJcCot&* zYW%~{_~h#~A;99K=*pF`2Dg&TCd04(nR~k1v%bLe1xiUS0!XcoxJXe#nYq=i@$zia zHxjSL#TzwWag}?E&AleasnBtrBsm?5r5ooL67r3%tOEY|>wb01-OiMhClbNbo`MM~ z-p(8bGCM-Rg;yFk&8sjf3FgCvrCc6wH+8kXtLdFf0f~CArD3(Ar7XSeE2On^pvafW zN*5J?BSNuW0?jGs?Eku=FvKu_Tv(oqF-^{yk`=+6L5$_wD$^u8*Ypx zc6_)lUUgPpQg?dnrZjy#wQt)C#BqCv0Xr_kFx&0_=jgYQ060{A^hkXFPRkIa-C&|N z|Iv~g+3So0?~s`El?=uv2KS87htIj4qx$iqRap|>Cn6OQ zlibEVn(@Cp!>IbaoL%J7H&B-A1-H1k7^#a!erTYN{z%udHet1lhruR{CE-Zvv42%z*Nflh8b6Wk_FrxrY1*l1M>GFZyShI~mr5&eMsxrxSOw=+--iEZJ$( zuC)qSUoje-t~Ua4Ca=*Ltp>L?=eC^zF9-U7ht&zpM_G0_Iw?iH%NEnj>nEpF6zR67 z9U67Mzm@WGVKe>>%vY4Y(2u==@9Ke0*&IIb&RHqTCGut+X^$^0-nPo5Z3gi2$K`fy@Ox@-VE-ELC1@DUsl}j(bX&}EN!YTV%britQYxUujoHo(%Vd>;D9 z6(J)lf$FO=)HFSj>{bkBWbj>S=aGmmwD9dT7SE*{I! 
z3J}N7-oU6|4l((Q#VUd~#(b+SEauT30n?KzFk$OzcLUf`K8-J(_O-^MU?nEpu+w8M z?DE@MQN5jhWeLmOL@fjEypKhpGU&ne2IO*i1!kXGy0bk!m)A=jE~zu$&g`REbyjxK zBjIeNR$#siJN4Tci)gSL*aSp}Xwu1;^A<}A96n8i zP7Ot7F#9h)_3yy**&NQB>ucs<$n8NsH~P3aVGYC4{Ga`G-=Ck{5F!onBtcAJ_!sdf zc^*VCO1W|V)Eiz|zWYq-Pf#v~{`XIz&oVIPR+@s4TM2gHY+Y_bO03^EONEDNV;^Uk zyApZKUGa_uHTC7s+~TCrw8*2(;URf;)bshIT!wEQ{Z9K_hPcF_$rlr|)futv8bixU zLbB=1BSjD8js|TTT(^wJ>BJsQs-DFDgGB=HnZpk$KM?Ky+EX^y=ZYqEeP=wxfYS`5 zC_qd>!)wZsBT>S~lcdo)Q+J5No%-OL;}TNByQ7EPXU{#D7}y372_tGLbN?!A1FVgQT# z9V~0%yJc*Zpto2dQE*qj^A+7^gU>TZHP%jO#u}61j!vRgZggrvo_56%%U64;d1`su z3TSX6u87SN(`7mOlrsWxtyB+wOS^>hm<{?hPSkMcltOY;SE%MP;dq*HCeff&uUndO zV$>9LVFi8MTc5PVE-07!e z?WC7s)l)jbrle`fFo!F}9l?#2C;Eo|Gw=oK+NEoP-oNPS&^A~r5@jPao0!)+an-Wo zo8VV^EHtgT;@I5Q@}hHs|FIajL1XRZ4Z*01x!Bl3`x>2;-*3rI(>;>z%kNag05E%f zS_bAoPDr1@VMBt#7v{w08o& zWIl4=HTGQ3yC=}lA?J$W_B8woX?S5NA>Xni6J0%H=Kl8ry z^)coZ?o#T-@&9?GGsDq|Y6$pHUALq$%QL;9vW!0EY0ZIK>e$}3ciPBVB|AZVj6sn; z_5RPAk9vpy}kM*7%2SlHM3?IfvP?rm>OT+i0uTVZ(I_g^uFyZ>h@{DBg8Ip&j(is||B zJtowf{vXvaxt_S{P>BZ(DZl_>?>pIIPP&37SV{;$+&n@M2^8%`w*wBwAWPc9bqxxGs=FWI&7T5=`;@maL9qmV}A z7t8R|5tiMCun?3)8*~p8Lpqk+I5Ecb;RKHDPW;Y@+E2)L|Dh6x%+)s~dBd(=FzKgF zqTg@XGocnOmpyfmL>bHW@#T*BwWU0zR8>B_+h#q*(=n_+V>GGNmaciS+>HgXPL2?KB zn0171F4$>C=nrI`v3`QNY207=w0@nf-)r0EHqN*_zo}Q*Km(yy-i&~#OxE|J3tK!AYm2Mn0rap zO%F9bT&jjM(#Kr?wBN^F)=R?lyrYf?3)n(?ecIW&5MzSD0i@zI#NS*6^w4K8M$73P zfzdpP@G3}YSvDp6ua3)-{POojve60p+Y-fT^ovgZ-JBVFWiQ@B+r6!O)sjo&36E;0 z@jcPLE3Cwe!bb`qs>4n(b>}W+P244Jjaz7}44#6@w#W}yyi<@WtZOFDi!Ca4Z8w)8 z0vr=pf}=UENh{lD%Tt~U{Bm@=3k|69s41(}wLYY_ub`l7s=xnxOWR`;JdOzu7xA649gXY?ejHXX{k8;-QAjUB0g{K|L9Hd|b5Q;gBP zBKm8LH~NmL58a1XE;#|M)67Hwgm7=+2*ucUBd!;1F3g|lp%0qu$Ls8-tTFDJl=F?Q zeKcOW;o*o58QW^hp-h_oWdjW+1KWAQYcK}}!-=9!!E8aUHfJlh?75qF2>s_u4g997 z{a_tlQ29GZtLY|#6sGpj?TOVrLtl-MLd_O-_Ra%(zg#Tim1xM3h$-+Eo#EM09kBQ^K`ffF{F5&ayWodQ2(Dms`3Eeuk?rL`NS^0kDD9GTew@W(Ix zSLcgs)PHP{pNcsLu8?^gPk^6@(!7MzL0D15H;%IeW1O@!v-O;t9K`;PpG~csJIj{H z^Z{K4o{PJge4V0=Q5%TTx>nPikk#VD^shWIRL2AC&pSXFar>w=Xy zzyJRH>FvP!&1i3q5F`2sHe?K#7d8a0p!JCXKyI;0vh{{)f?v0$5=#qy&rYs9-2!xd z^KB12kOnR^izL}ng22^ZX5G?h3tx>lt<*CA);LbHp?q4+w#|qEbx8|RqfPph&PTdf zhkrBo@+Y%OzbLf(6@NtCc-Xvl{P)N904$VC_(>rfpH4yRidZgC0BX+a$-n9~8AczonG~|_DT099Hw>^V{Jas4i zjFUTC%G85qh4HGk5R6072>rpVurhJ-r=KhcQ4((t5>VRcTgC-*La#R@I=AB%2h_cP z8Q`rrqP;EbjQeiC7hh(H1+pBeiOy^8+Jz|lasQRs+`7K?+*6cE5Q9aN`@N0F#vXZsfNhd%eEzUJR z@fpe?B`^>{afcOWn?oNHa3Z~}rV1n1O9LFgzFJk;`eR2J{>{G$+R|GB_6aU$ZG!a` zqT?q=Z2S@0BR;cbb#e~NH_EoPBu5=y^A0Mz^9RSA5tDi4F8Kl4sv~kancIR_*s2!Z zs{b{9QdoqxdP==px^oB-FQiv785zEBv>FTWX9q28w~xN=rjG&of9KN4i;ey?Y=pLH ztAcQNL|yGqid+TClqMH6O^;{WJ%h~Lneskj|EimglwKgj|~Nt6_-CkA@oed4{pj>3Zbzx0|R5$30V;K2xc7 zk+PH~)Pm+t_t!WeyTZ5x@LDek^yB=?;oQ=XE_IfdN4D!;3!#z{TnpI@O{tv4tW6U; zD0z@>eTi^QJ-3f$1(A%PP|~;t%u%elj5azw?v8TrzK?Qm6Nwpy;Zo|d=}{IBP3DC0 z>~C}wHqv2eSTTa_iB;cr8gB*k3bE|m7!W3tPs0}tq#KgF;wvK3F&EunHn3=P6)-xJ zwU`<6^Oee^s)zZQ?CJ-9qmGzIyl7P-M5$3+c7Xdkyk%ti+|iDEs1e!}1%$+PsQU3K ze_u>b!-2-``No{Z%JMz$`!a*^c`toRvFNYB|B7U>?2^SjQ(S4r(6##=6{h!<7V?N* zoqXO!_G=~~kwONs{X~QX~+A6J0MzbEh)+ zY;bVf)XcgL#2}V`%F9vf5#1JfxwR?ck}04Yb0jiO(yDHA+YE9os$^|H8<>+X-2gmV z466@)h}kh2NeB76tFLV+x}|9 z7-d{$B53}>=Z24i$Qcrq-yX)~thbYXlX)q^c-vJVg({X(?o)Fk`T_@I83(MUtTKDw_Fk`yEcCvcL#Qe zML2GII%5f%Sg_Qd%va=nP}?9|9SdqXZj(McLaPA^tN^BM7A&_Z#JDsF*IJ}}r=qhnKjw69JMPC#Rj& z>syOix>^TBn+?cm)AEode{4NA)izrGV&r)Jn2&&)AYb)Ndyt%cO+V;Z;8^R9f3ud& zDB!FvVmCkVCRH3TQ%^XXJyFT6m$Vl+YD81%f z`8n~*tNIi$uJ+-ORg%HlU`J6&P*1r#vWGEvbN@erBg{pVZi??OLz8Mcs<^3n1<}9G z^E8T%2IOM4D$5sUm_L2SL;1+wWy}Fpv{=7b7Qk$BLnPtQQZL)6;SSGSE;ooXtWmpJ 
z9GL>%+K6aQr2j+_ry3{81qrXIVpkugZ49Fh#!iUD#+DqqW$OGj>qm>e$eL_&FT%Sa zR}}T>49_;+p~uXcoUUGT7WW3PpzbZSXZGXW0whG*oc|D6$|?TJI9}FFdoQuyR!t92 ztE9LaGe5*cLAh_&u{Nz4bbpjfe^3Z8c`#Vf^)^npaetV-`LBzv{UuXFOy5hIT6(D1 z!|!w1&z1Iz1k^D09r3&x>(~`n=^HU>bNkm)w{SyDJ7l`hXY`I3(4u!$#XSJA7}6e=;I@X@)AHjz~DX2sI%tFO>G)%v_XC;CBl! z5kwG-W%toS+$p742@D+O2sB?D(iQa%?k^@`BN#KlPt)@d6W^VEIg*4V^(~s?gNS?_s!aPY%QQCEkZ|;4v1;w5B^lqC&u1*UX-iW1y^Z6pZ(o@s za~0HXM7nvKansC+1~317==VE{D&LJ3TY{JzA&-EqS5d?2O)UqSB0UFgw|B!w2pMrP z+sDUig(njpZiBVrkIk*6JZ$5mKS_N_5Ikddq(-u#q!)_~&YIb`YaGQn9#&lAjD=82 zcRshiwXE$gnLl4yGB++e5&`TUnrJ_rhTnQ(P+tw;|A;MS3XrH^OF}26^Knwq?!et7 z1!hwMM6^MVGzQV_lhyv`J-&Ce;pAAEAR0FMTfQ; zyz6_}DPny1R1E24n$(J?qln+d>lFYuTB^J*pJSe42}uz_gv$8@z`EXN^hbYTnvS;n z;mA^NFVzJ2(HQ0Ik5}yY8HK`#^i{dL(dXxp#-#=px5w)Qqfcwo0@iE6f5=lMd<`I5 zAy%W?TNyB70Erd73MzebdnF-Hyj}8yKY9H?KmMBpdT#7_ce;!Y(JT>I*|S;w(teh` zNqj!Ad+pD%o=f@PK(@fcpLf35^iWxP4H$XeuG?dleTd%MA^pQ|<>=sOj;J5yVqt3A zNchUp{!p#{in>SjBmK&R(Mi{NQkYzl2w&N*MP+=>5! zVP{DK#SAVJ*}i)tH}=zKaf|Fty0*OU%?vEO!+LyX#+v?&5wkUMD9TGnP|P$=RsBkx zw&P!%QgDs-qxi{}erT)KWdQJOGg!v4B2C_`*d81WcE8M-SQ-t)4>YMCy`&Y@kB3A$ zR76CMoo1+R*nD!})A$~`YTs{U4b=)EX>g1g{nFav|C8&z*%MGckj_EVw%z!qM#>ji z?Z)dQ+aK?clK@_Jby%#(VC*^%u0 zz7)1p~@3s7aT)6}sgqRWaq@WY3AsNHmJkYlQt} zlEhO|MK!5VP4ZQaE8QlwzEk;4!~#->;gA?$2U)yFE;p8AWE*z^5L^$iQ9%#A0tH+OP*kZg z9w;D|6aAf6OIykt;;gQJxnUE%QMIT%N*g%NQ9iE9kno|BG63AbF3|gZ6PKoGv=9Qp zK>bBMhor4x9V3En(Z}uTjGqP07R>#{+W(CBo^r4pKRh&h(tPvK)%&$^824zlbWq`z zQ1`$$r4zBPG_={2u%~y2$XAmX@rX)*Jq}a)}hmJ}w7= z`2wEQ%o&8t2m2d2yFDr6d2LoxxmIU@=4+tsVu10ADw=u1@55BVMv7=;Gbb$|&h$Ky z0gto!)<$gFl>A}sX4(BBA=zq!Ob(l7yH1r