From cf6485df3c06a8a847ddc6876d199c69a90c575b Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 6 Jan 2026 17:07:21 -0700 Subject: [PATCH 01/99] rework profiling --- applications/llama_3.2_1b/custom_profile.py | 48 +++++++++++ applications/llama_3.2_1b/inference.py | 92 +-------------------- operators/common/aie_base.py | 24 ++++-- 3 files changed, 70 insertions(+), 94 deletions(-) create mode 100644 applications/llama_3.2_1b/custom_profile.py diff --git a/applications/llama_3.2_1b/custom_profile.py b/applications/llama_3.2_1b/custom_profile.py new file mode 100644 index 00000000..2f1199c5 --- /dev/null +++ b/applications/llama_3.2_1b/custom_profile.py @@ -0,0 +1,48 @@ +import sys, time, inspect, json + +# The current call stack; for each active function call, we store (func_identifier, start_time) +call_stack = [] + +# The cumulative time spent in each call stack path +# Map {function identifier: tuple (cumulative_time, {sub_call_function_identifier: cumulative_time, ...}) } +time_per_path = [0.0, {}] + + +def profile_call(frame, event, arg): + global call_stack + + timestamp = time.perf_counter() + + func_name = frame.f_code.co_name + filename = frame.f_code.co_filename + line_no = frame.f_lineno + func_identifier = f"{str(frame.f_code.co_filename)}:{frame.f_code.co_firstlineno}:{frame.f_code.co_name}" + + if event == "call": + call_stack.append((func_identifier, timestamp)) + elif event == "return": + if 0 == len(call_stack): + return + last_func_identifier, start_time = call_stack[-1] + if last_func_identifier != func_identifier: + print(call_stack) + raise RuntimeError(f"Function return mismatch: expected {last_func_identifier}, got {func_identifier}") + elapsed = timestamp - start_time + + this_path_time = time_per_path + for f, _ in call_stack: + this_path_time = this_path_time[1].setdefault(f, [0.0, {}]) + this_path_time[0] += elapsed + + call_stack.pop() + + +def enable_profiling(): + sys.setprofile(profile_call) + + +def store_profile(path): + sys.setprofile(None) + with open(path, "w") as f: + json.dump(time_per_path[1], f, indent=2) + diff --git a/applications/llama_3.2_1b/inference.py b/applications/llama_3.2_1b/inference.py index a2408d47..9ce47c22 100755 --- a/applications/llama_3.2_1b/inference.py +++ b/applications/llama_3.2_1b/inference.py @@ -37,91 +37,7 @@ generate, ) -# Global logger for profiling -_profile_logger = None - - -def profile_function_calls(frame, event, arg): - """ - Profile function that logs start and end times of every function call. - - Args: - frame: The current stack frame - event: The event type ('call', 'return', 'c_call', 'c_return', 'c_exception') - arg: Event-specific argument - """ - global _profile_logger - - if _profile_logger is None: - return - - func_name = frame.f_code.co_name - filename = frame.f_code.co_filename - line_no = frame.f_lineno - - # Create a readable function identifier - func_identifier = f"{filename}:{func_name}:{line_no}" - - if event == "call": - # Function is being called - timestamp = time.perf_counter() - _profile_logger.debug(f"[CALL] {func_identifier} started at {timestamp:.9f}") - - elif event == "return": - # Function is returning - timestamp = time.perf_counter() - _profile_logger.debug(f"[RETURN] {func_identifier} ended at {timestamp:.9f}") - - return profile_function_calls - - -def enable_profiling(logs_dir_name): - """Enable function call profiling using sys.setprofile.""" - global _profile_logger - - # Create a dedicated logger for profiling - _profile_logger = logging.getLogger("function_profiler") - _profile_logger.setLevel(logging.DEBUG) - # Prevent propagation to root logger to avoid console output - _profile_logger.propagate = False - - # Create log file for profiling data - timestamp = time.strftime("%Y%m%d_%H%M%S") - log_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - logs_dir_name, - f"profile_{timestamp}.log", - ) - - # Add file handler for profiling (only file, no console output) - profile_handler = logging.FileHandler(log_path) - profile_handler.setLevel(logging.DEBUG) - profile_formatter = logging.Formatter("%(asctime)s - %(message)s") - profile_handler.setFormatter(profile_formatter) - _profile_logger.addHandler(profile_handler) - - # Set the profile function - sys.setprofile(profile_function_calls) - _profile_logger.info("Function profiling enabled") - - # Explicitly call profile_function_calls to log this function's call - import inspect - - frame = inspect.currentframe() - profile_function_calls(frame, "call", None) - - -def disable_profiling(): - """Disable function call profiling.""" - global _profile_logger - - sys.setprofile(None) - if _profile_logger: - _profile_logger.info("Function profiling disabled") - # Close all handlers - for handler in _profile_logger.handlers[:]: - handler.close() - _profile_logger.removeHandler(handler) +import custom_profile _iron_chat = r""" @@ -421,7 +337,7 @@ def set_prefill_time(): # Enable function profiling if args.profile: - enable_profiling(logs_dir_name) + custom_profile.enable_profiling() try: prompt = args.prompt @@ -445,5 +361,5 @@ def set_prefill_time(): ) finally: if args.profile: - # Disable profiling when done - disable_profiling() + custom_profile.store_profile(Path(logs_dir_name) / "profile.json") + diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index 5238f6f5..fa74f7b3 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -213,17 +213,29 @@ def run_runlist(self): insts_bos = set( self.xrt_kernels[kernel_name][2] for (kernel_name, *_) in self.runlist ) - for bo in bos | insts_bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) + sync_to_device(bos | insts_bos) start = time.perf_counter() - self.xrt_runlist.execute() - self.xrt_runlist.wait() + execute_runlist(self.xrt_runlist) stop = time.perf_counter() - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) + sync_from_device(bos) elapsed = stop - start return elapsed +def sync_to_device(bos): + for bo in bos: + bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) + + +def sync_from_device(bos): + for bo in bos: + bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) + + +def execute_runlist(runlist): + runlist.execute() + runlist.wait() + + class AIEOperatorConstraintError(RuntimeError): pass From af46d819ba64905d10e7d04fcba1e0ae469d99d5 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 6 Jan 2026 17:07:36 -0700 Subject: [PATCH 02/99] vibe-coded flame graph visualization --- .../llama_3.2_1b/visualize_profile.py | 413 ++++++++++++++++++ 1 file changed, 413 insertions(+) create mode 100644 applications/llama_3.2_1b/visualize_profile.py diff --git a/applications/llama_3.2_1b/visualize_profile.py b/applications/llama_3.2_1b/visualize_profile.py new file mode 100644 index 00000000..23a36040 --- /dev/null +++ b/applications/llama_3.2_1b/visualize_profile.py @@ -0,0 +1,413 @@ +import json +import argparse +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from collections import defaultdict + +# Try to import seaborn, fall back to matplotlib if not available +try: + import seaborn as sns + HAS_SEABORN = True +except ImportError: + HAS_SEABORN = False + +def load_profile_data(json_file): + """Load profile data from JSON file.""" + with open(json_file, 'r') as f: + return json.load(f) + +def extract_function_info(full_name): + """Extract function name, filename, path components, and full path from the identifier.""" + # Remove parameters if present + if '(' in full_name: + full_name = full_name.split('(')[0] + + # Split by '/' to get path components + path_parts = full_name.split('/') + + # Get the last part which contains filename:line:function + last_part = path_parts[-1] + parts = last_part.split(':') + + if len(parts) >= 3: + # Format: filename:line:function + filename = parts[0].strip() + func_name = parts[-1].strip() + elif len(parts) >= 2: + # Format: filename:function or similar + filename = parts[0].strip() + func_name = parts[-1].strip() + else: + filename = "" + func_name = full_name.strip() + + # Store directory components (excluding the last part which is filename:line:func) + dir_parts = path_parts[:-1] if len(path_parts) > 1 else [] + + return { + 'func_name': func_name, + 'filename': filename, + 'dir_parts': dir_parts, + 'full_path': full_name + } + +def collect_unique_identifiers(profile_data, time_threshold_pct=1.0): + """ + Collect all unique function identifiers from the profile data. + + Args: + profile_data: Dict {func: [time, {children}]} + time_threshold_pct: Minimum percentage to consider + + Returns: + Set of unique function identifiers + """ + # Calculate total time + total_time = sum(child[0] for child in profile_data.values() if isinstance(child, list)) + if total_time == 0: + return set() + + threshold = (time_threshold_pct / 100.0) * total_time + identifiers = set() + + def collect_from_node(func_id, node_data): + """Recursively collect identifiers.""" + if not isinstance(node_data, list) or len(node_data) != 2: + return + + time, children = node_data + + # Add this identifier (regardless of threshold, we want all unique functions) + identifiers.add(func_id) + + # Recurse to children + for child_id, child_data in children.items(): + collect_from_node(child_id, child_data) + + # Process all root functions + for func_id, func_data in profile_data.items(): + collect_from_node(func_id, func_data) + + return identifiers + +def build_disambiguation_map(identifiers): + """ + Build a map from full identifier to minimal disambiguated name. + + Args: + identifiers: Set of unique function identifiers + + Returns: + Dict mapping full_identifier -> disambiguated_name + """ + from collections import Counter, defaultdict + + # Extract info for all identifiers + full_info = {} + func_name_groups = defaultdict(list) + + for full_id in identifiers: + info = extract_function_info(full_id) + full_info[full_id] = info + func_name_groups[info['func_name']].append(full_id) + + result = {} + + # Process each group of same-named functions + for func_name, id_list in func_name_groups.items(): + if len(id_list) == 1: + # Unique function name, use as-is + result[id_list[0]] = func_name + else: + # Multiple functions with same name, need disambiguation + # Try progressively longer path suffixes until we find something unique + max_dirs = max(len(full_info[full_id]['dir_parts']) for full_id in id_list) + + disambiguated = False + for num_dirs in range(0, max_dirs + 1): + # Build candidates with this many directory components + candidates = {} + for full_id in id_list: + info = full_info[full_id] + dir_parts = info['dir_parts'] + filename = info['filename'] + + if num_dirs == 0: + # Just filename + candidate = f"{filename}:{func_name}" if filename else func_name + else: + # Take last num_dirs directories + filename + path_suffix = dir_parts[-num_dirs:] if len(dir_parts) >= num_dirs else dir_parts + if path_suffix and filename: + candidate = "/".join(path_suffix) + f"/{filename}:{func_name}" + elif filename: + candidate = f"{filename}:{func_name}" + else: + candidate = func_name + + candidates[full_id] = candidate + + # Check if all candidates are unique + if len(set(candidates.values())) == len(candidates): + # Apply the disambiguation to all functions in this group + result.update(candidates) + disambiguated = True + break + + # Fallback to full path if still not unique (shouldn't happen) + if not disambiguated: + for full_id in id_list: + result[full_id] = full_info[full_id]['full_path'] + + return result + +def build_hierarchical_layout(profile_data, time_threshold_pct=1.0, zoom_path=None): + """ + Build hierarchical layout for flame graph with proper parent-child positioning. + + Args: + profile_data: Either dict {func: [time, {children}]} or [time, {child_calls}] + time_threshold_pct: Minimum percentage of total time to display + zoom_path: Optional list of disambiguated function names to zoom into (e.g., ['inference', 'generate']) + + Returns: + List of rectangles with (depth, x_start, width, func_name, time, pct) + """ + # Handle both formats: dict or [time, {children}] + if isinstance(profile_data, dict): + root_children = profile_data + elif isinstance(profile_data, list) and len(profile_data) == 2: + _, root_children = profile_data + else: + return [], 0.0 + + if root_children: + # Calculate total time from root level + total_time = sum(child[0] for child in root_children.values() if isinstance(child, list)) + if total_time == 0: + return [], 0.0 + + # Build disambiguation map for all unique functions + unique_identifiers = collect_unique_identifiers(root_children, time_threshold_pct) + disambig_map = build_disambiguation_map(unique_identifiers) + + # If zoom_path is specified, find the subtree to zoom into + if zoom_path: + # Build reverse map: disambiguated_name -> [full_identifiers] + reverse_map = defaultdict(list) + for full_id, disambig_name in disambig_map.items(): + reverse_map[disambig_name].append(full_id) + + # Navigate to the zoomed node + current_data = root_children + current_depth = 0 + + for target_name in zoom_path: + # Find matching function in current level + found = False + for func_id, func_data in current_data.items(): + disambig_name = disambig_map.get(func_id, func_id) + if disambig_name == target_name: + if isinstance(func_data, list) and len(func_data) == 2: + _, current_data = func_data + current_depth += 1 + found = True + break + + if not found: + print(f"Warning: Could not find '{target_name}' in zoom path. Available at this level:") + for func_id in list(current_data.keys())[:10]: + print(f" - {disambig_map.get(func_id, func_id)}") + return [], 0.0 + + # Use the zoomed subtree as root + root_children = current_data + # Recalculate total time for the zoomed view + total_time = sum(child[0] for child in root_children.values() if isinstance(child, list)) + if total_time == 0: + return [], 0.0 + + threshold = (time_threshold_pct / 100.0) * total_time + rectangles = [] + + def process_node(func_name, node_data, depth, x_start, parent_time=None): + """Recursively process nodes and position them.""" + if not isinstance(node_data, list) or len(node_data) != 2: + return + + time, children = node_data + + # Calculate width as proportion of total time + width = time / total_time + pct_total = (time / total_time) * 100 + + # Calculate percentage relative to parent (if parent exists) + if parent_time is not None and parent_time > 0: + pct_parent = (time / parent_time) * 100 + else: + pct_parent = 100.0 # Root nodes are 100% of themselves + + # Get disambiguated name from the map + display_name = disambig_map.get(func_name, func_name) + + # Add rectangle for this function + # Mark whether it should be labeled based on threshold + rectangles.append({ + 'depth': depth, + 'x_start': x_start, + 'width': width, + 'func_name': display_name, + 'full_identifier': func_name, + 'time': time, + 'pct': pct_parent, # Use parent-relative percentage + 'pct_total': pct_total, # Keep total percentage for reference + 'show_label': time >= threshold + }) + + # Process children with proper positioning + # Children should be positioned within this function's span + child_x = x_start + for child_name, child_data in children.items(): + if isinstance(child_data, list) and len(child_data) == 2: + child_time = child_data[0] + process_node(child_name, child_data, depth + 1, child_x, parent_time=time) + # Move position for next child + child_x += child_time / total_time + + # Process all root-level functions + x_pos = 0.0 + for func_name, func_data in root_children.items(): + if isinstance(func_data, list) and len(func_data) == 2: + func_time = func_data[0] + process_node(func_name, func_data, 0, x_pos) + x_pos += func_time / total_time + + return rectangles, total_time + + return [], 0.0 + +def draw_flame_graph(rectangles, total_time, output_file='flame_graph.png'): + """Draw flame graph visualization.""" + if not rectangles: + print("No data to visualize") + return + + # Calculate layout + max_depth = max(rect['depth'] for rect in rectangles) + fig, ax = plt.subplots(figsize=(20, max_depth + 2)) + + # Color palette - rocket colormap + if HAS_SEABORN: + colors = sns.color_palette("pastel") + else: + # Use matplotlib's tab20 colormap + cmap = plt.cm.get_cmap('tab20') + colors = [cmap(i) for i in range(20)] + + for rect in rectangles: + depth = rect['depth'] + x_start = rect['x_start'] + width = rect['width'] + func_name = rect['func_name'] + pct = rect['pct'] + time_abs = rect['time'] + + # Convert to absolute time coordinates + x_start_abs = x_start * total_time + width_abs = width * total_time + + # Choose color based on function name hash + color_idx = hash(func_name) % len(colors) + + patch = mpatches.Rectangle( + (x_start_abs, depth), width_abs, 0.8, + facecolor=colors[color_idx], + edgecolor='black', + linewidth=1 + ) + ax.add_patch(patch) + + # Add text label if above threshold AND width is sufficient + # Use absolute width for threshold check + if rect.get('show_label', True) and width_abs > 0.015 * total_time: # Threshold in absolute time + # Create wrapped text that fits within the rectangle + import textwrap + + # Calculate approximate character width based on rectangle width + # Rough estimate: each character is about 0.06 inches at fontsize 7 + fig_width_inches = 20 # From figsize + chars_per_inch = 14 # Approximate at fontsize 7 + rect_width_inches = (width_abs / total_time) * fig_width_inches + max_chars = int(rect_width_inches * chars_per_inch) + max_chars = max(max_chars, 3) # At least 3 characters + + # Wrap the function name + wrapped_name = '\n'.join(textwrap.wrap(func_name, width=max_chars, break_long_words=True, break_on_hyphens=False)) + + # Build label with wrapped name + label = f"{wrapped_name}\n{pct:.1f}%\n{time_abs:.3f}s" + + # Limit number of lines to fit in rectangle height (0.8 units) + max_lines = 3 # Approximately 3 lines fit in 0.8 height + label_lines = label.split('\n') + if len(label_lines) > max_lines: + label = '\n'.join(label_lines[:max_lines]) + + ax.text( + x_start_abs + width_abs/2, depth + 0.4, + label, + ha='center', va='center', + fontsize=7, + clip_on=True + ) + + ax.set_xlim(0, total_time) + ax.set_ylim(-0.5, max_depth + 0.5) + ax.set_xlabel('Cumulative Time (seconds)', fontsize=12) + ax.set_ylabel('Call Stack Depth', fontsize=12) + ax.set_title('Flame Graph - Profile Visualization', fontsize=14, weight='bold') + ax.set_yticks(range(max_depth + 1)) + ax.grid(axis='y', alpha=0.3) + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches='tight') + print(f"Flame graph saved to {output_file}") + plt.show() + +def main(): + parser = argparse.ArgumentParser(description='Generate flame graph from profile JSON data') + parser.add_argument('input', nargs='?', default='profile.json', + help='Input JSON profile file (default: profile.json)') + parser.add_argument('-o', '--output', default='flame_graph.png', + help='Output flame graph image file (default: flame_graph.png)') + parser.add_argument('-t', '--threshold', type=float, default=1.0, + help='Time threshold percentage for displaying functions (default: 1.0)') + parser.add_argument('-z', '--zoom', type=str, default=None, + help='Zoom into a specific call path using disambiguated names separated by ">" (e.g., "inference>generate")') + + args = parser.parse_args() + + # Parse zoom path if provided + zoom_path = None + if args.zoom: + zoom_path = [name.strip() for name in args.zoom.split('>')] + print(f"Zooming into path: {' > '.join(zoom_path)}") + + # Load profile JSON + profile_data = load_profile_data(args.input) + + # Build hierarchical layout with specified threshold and zoom + rectangles, total_time = build_hierarchical_layout(profile_data, time_threshold_pct=args.threshold, zoom_path=zoom_path) + + if not rectangles: + print("No data to visualize") + return + + print(f"Total profiled time: {total_time:.2f}s") + print(f"Displaying {len(rectangles)} function calls above {args.threshold}% threshold") + + # Draw flame graph + draw_flame_graph(rectangles, total_time, output_file=args.output) + +if __name__ == '__main__': + main() From 1f184ebc0f458879a3b9bbe7f061d982feca7d0c Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 8 Jan 2026 09:16:26 -0700 Subject: [PATCH 03/99] plot updates --- applications/llama_3.2_1b/bar_plot_profile.py | 150 ++++++++++++++++++ .../llama_3.2_1b/visualize_profile.py | 54 +++++-- 2 files changed, 189 insertions(+), 15 deletions(-) create mode 100755 applications/llama_3.2_1b/bar_plot_profile.py diff --git a/applications/llama_3.2_1b/bar_plot_profile.py b/applications/llama_3.2_1b/bar_plot_profile.py new file mode 100755 index 00000000..05f4ab82 --- /dev/null +++ b/applications/llama_3.2_1b/bar_plot_profile.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +""" +Generate a bar plot showing the top 15 most expensive functions by cumulative time. +""" +import json +import argparse +import matplotlib.pyplot as plt +from collections import defaultdict + +def load_profile_data(json_file): + """Load profile data from JSON file.""" + with open(json_file, 'r') as f: + return json.load(f) + +def extract_function_name(full_identifier): + """Extract just the function name from the full identifier.""" + # Remove parameters if present + if '(' in full_identifier: + full_identifier = full_identifier.split('(')[0] + + # Split by '/' to get path components + path_parts = full_identifier.split('/') + + # Get the last part which contains filename:line:function + last_part = path_parts[-1] + parts = last_part.split(':') + + if len(parts) >= 3: + # Format: filename:line:function + return parts[-1].strip() + elif len(parts) >= 2: + # Format: filename:function or similar + return parts[-1].strip() + else: + return full_identifier.strip() + +def aggregate_time_by_function(profile_data): + """ + Aggregate cumulative time for each function across all call sites. + + Args: + profile_data: Dict {func: [time, {children}]} + + Returns: + Dict mapping function name to total cumulative time + """ + time_by_function = defaultdict(float) + + def process_node(func_id, node_data): + """Recursively process nodes and accumulate time.""" + if not isinstance(node_data, list) or len(node_data) != 2: + return + + time, children = node_data + + # Extract function name and add time + func_name = extract_function_name(func_id) + time_by_function[func_name] += time + + # Recurse to children + for child_id, child_data in children.items(): + process_node(child_id, child_data) + + # Process all root functions + for func_id, func_data in profile_data.items(): + process_node(func_id, func_data) + + return time_by_function + +def create_bar_plot(time_by_function, output_file, top_n=15): + """ + Create a bar plot showing the top N most expensive functions. + + Args: + time_by_function: Dict mapping function name to cumulative time + output_file: Path to save the plot + top_n: Number of top functions to display + """ + # Sort by time and get top N + sorted_functions = sorted(time_by_function.items(), key=lambda x: x[1], reverse=True) + top_functions = sorted_functions[:top_n] + + # Prepare data for plotting + function_names = [func for func, _ in top_functions] + times = [time for _, time in top_functions] + + # Create the plot + fig, ax = plt.subplots(figsize=(12, 8)) + + # Create horizontal bars (easier to read function names) + bars = ax.barh(range(len(function_names)), times, color='steelblue') + + # Customize the plot + ax.set_yticks(range(len(function_names))) + ax.set_yticklabels(function_names) + ax.set_xlabel('Cumulative Time (seconds)', fontsize=12) + ax.set_ylabel('Function Name', fontsize=12) + ax.set_title(f'Top {top_n} Most Expensive Functions by Cumulative Time', fontsize=14, fontweight='bold') + + # Add value labels on the bars + for i, (bar, time) in enumerate(zip(bars, times)): + width = bar.get_width() + ax.text(width, bar.get_y() + bar.get_height()/2, + f' {time:.3f}s', + ha='left', va='center', fontsize=10) + + # Invert y-axis so the highest time is at the top + ax.invert_yaxis() + + # Add grid for better readability + ax.grid(axis='x', alpha=0.3, linestyle='--') + ax.set_axisbelow(True) + + # Tight layout + plt.tight_layout() + + # Save the plot + plt.savefig(output_file, dpi=150, bbox_inches='tight') + print(f"Bar plot saved to {output_file}") + + # Print summary statistics + total_time = sum(times) + print(f"\nTop {top_n} Functions Summary:") + print(f"Total cumulative time (top {top_n}): {total_time:.3f}s") + for i, (func_name, time) in enumerate(top_functions, 1): + print(f"{i:2d}. {func_name:40s} {time:8.3f}s") + +def main(): + parser = argparse.ArgumentParser( + description='Generate a bar plot of the top N most expensive functions by cumulative time' + ) + parser.add_argument('input', help='Input profile JSON file') + parser.add_argument('-o', '--output', default='bar_plot.png', + help='Output image file (default: bar_plot.png)') + parser.add_argument('-n', '--top-n', type=int, default=15, + help='Number of top functions to display (default: 15)') + + args = parser.parse_args() + + # Load profile data + profile_data = load_profile_data(args.input) + + # Aggregate time by function name + time_by_function = aggregate_time_by_function(profile_data) + + # Create the bar plot + create_bar_plot(time_by_function, args.output, args.top_n) + +if __name__ == '__main__': + main() diff --git a/applications/llama_3.2_1b/visualize_profile.py b/applications/llama_3.2_1b/visualize_profile.py index 23a36040..442d7115 100644 --- a/applications/llama_3.2_1b/visualize_profile.py +++ b/applications/llama_3.2_1b/visualize_profile.py @@ -296,13 +296,13 @@ def draw_flame_graph(rectangles, total_time, output_file='flame_graph.png'): max_depth = max(rect['depth'] for rect in rectangles) fig, ax = plt.subplots(figsize=(20, max_depth + 2)) - # Color palette - rocket colormap - if HAS_SEABORN: - colors = sns.color_palette("pastel") - else: - # Use matplotlib's tab20 colormap - cmap = plt.cm.get_cmap('tab20') - colors = [cmap(i) for i in range(20)] + # Define base colors: blue and green alternating by row + import colorsys + blue_hue = 0.58 # Blue in HSV + green_hue = 0.33 # Green in HSV + + # Track x-position at each depth to determine column parity + depth_positions = {} for rect in rectangles: depth = rect['depth'] @@ -316,17 +316,41 @@ def draw_flame_graph(rectangles, total_time, output_file='flame_graph.png'): x_start_abs = x_start * total_time width_abs = width * total_time - # Choose color based on function name hash - color_idx = hash(func_name) % len(colors) + # Alternate hue between blue and green by depth (row) + hue = blue_hue if depth % 2 == 0 else green_hue + + # Track column index at this depth + if depth not in depth_positions: + depth_positions[depth] = [] + + # Find column index (how many rectangles we've seen at this depth) + column_idx = len(depth_positions[depth]) + depth_positions[depth].append(x_start_abs) + # Alternate brightness by column: odd columns are lighter, even are darker + if column_idx % 2 == 0: + saturation = 0.6 + value = 0.85 + else: + saturation = 0.5 + value = 0.95 + + # Convert HSV to RGB + rgb = colorsys.hsv_to_rgb(hue, saturation, value) + + # Create rectangle with no vertical spacing (height=1.0) and only left/right borders patch = mpatches.Rectangle( - (x_start_abs, depth), width_abs, 0.8, - facecolor=colors[color_idx], - edgecolor='black', - linewidth=1 + (x_start_abs, depth), width_abs, 1.0, + facecolor=rgb, + edgecolor='none', + linewidth=0 ) ax.add_patch(patch) + # Add left and right borders only + ax.plot([x_start_abs, x_start_abs], [depth, depth + 1.0], 'k-', linewidth=0.2, zorder=10) + ax.plot([x_start_abs + width_abs, x_start_abs + width_abs], [depth, depth + 1.0], 'k-', linewidth=0.2, zorder=10) + # Add text label if above threshold AND width is sufficient # Use absolute width for threshold check if rect.get('show_label', True) and width_abs > 0.015 * total_time: # Threshold in absolute time @@ -354,7 +378,7 @@ def draw_flame_graph(rectangles, total_time, output_file='flame_graph.png'): label = '\n'.join(label_lines[:max_lines]) ax.text( - x_start_abs + width_abs/2, depth + 0.4, + x_start_abs + width_abs/2, depth + 0.5, label, ha='center', va='center', fontsize=7, @@ -362,7 +386,7 @@ def draw_flame_graph(rectangles, total_time, output_file='flame_graph.png'): ) ax.set_xlim(0, total_time) - ax.set_ylim(-0.5, max_depth + 0.5) + ax.set_ylim(0, max_depth + 1) ax.set_xlabel('Cumulative Time (seconds)', fontsize=12) ax.set_ylabel('Call Stack Depth', fontsize=12) ax.set_title('Flame Graph - Profile Visualization', fontsize=14, weight='bold') From eedc527f0fed5d1fff4bb8f51df6a9e6f0c50acf Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 13 Jan 2026 16:19:00 -0700 Subject: [PATCH 04/99] simplified implementation (no KV cache yet) --- applications/llama_3.2_1b/llama_cpu.py | 272 ++++++++++++++++++ .../llama_3.2_1b/llama_inference_harness.py | 181 ++++++++++++ 2 files changed, 453 insertions(+) create mode 100755 applications/llama_3.2_1b/llama_cpu.py create mode 100644 applications/llama_3.2_1b/llama_inference_harness.py diff --git a/applications/llama_3.2_1b/llama_cpu.py b/applications/llama_3.2_1b/llama_cpu.py new file mode 100755 index 00000000..20534485 --- /dev/null +++ b/applications/llama_3.2_1b/llama_cpu.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 + +import torch +import math +from llama_inference_harness import harness + +# Operators +# ########################################################################## + +def apply_rope(x, angles): + """Apply RoPE to input tensor x using precomputed angles.""" + # x: (batch, seq_len, num_heads, head_dim) after view and before transpose + # angles: (context_length, head_dim) + _, seq_len, _, head_dim = x.shape + angles_slice = angles[:seq_len] # (seq_len, head_dim) + + # Split into even and odd dimensions + x1 = x[..., : head_dim // 2] # (batch, seq_len, num_heads, head_dim//2) + x2 = x[..., head_dim // 2 :] # (batch, seq_len, num_heads, head_dim//2) + + # Get cos and sin from angles + cos = angles_slice[:, ::2] # (seq_len, head_dim//2) + sin = angles_slice[:, 1::2] # (seq_len, head_dim//2) + + # Reshape for broadcasting: (1, seq_len, 1, head_dim//2) + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + # Rotate: [x1*cos - x2*sin, x1*sin + x2*cos] + rotated = torch.empty_like(x) + rotated[..., : head_dim // 2] = x1 * cos - x2 * sin + rotated[..., head_dim // 2 :] = x1 * sin + x2 * cos + + return rotated + + +def rms_norm_forward(x, weight, eps=1e-5): + """RMSNorm: Root Mean Square Layer Normalization.""" + # x: (batch, seq_len, dim) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + return weight * x + + +def grouped_query_attention_forward( + x, + W_query, W_key, W_value, W_out, + angles, + mask=None, + num_heads=32, + num_kv_groups=8, + kv_cache=None, + input_pos=None, +): + """ + Grouped Query Attention forward pass. + + Steps: + 1. Linear projections (Q, K, V) + 2. Reshape for multi-head + 3. Apply RoPE to Q and K + 4. Repeat K and V for grouped attention + 5. Compute attention scores (Q @ K^T / sqrt(d)) + 6. Apply mask and softmax + 7. Compute attention output (scores @ V) + 8. Concatenate heads and project + """ + batch, seq_len, d_in = x.shape + head_dim = W_query.shape[0] // num_heads + + # Step 1: Linear projections + queries = torch.nn.functional.linear(x, W_query) # (batch, seq_len, d_out) + keys = torch.nn.functional.linear(x, W_key) # (batch, seq_len, num_kv_groups * head_dim) + values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) + + # Step 2: Reshape for multi-head + queries = queries.view(batch, seq_len, num_heads, head_dim) + keys = keys.view(batch, seq_len, num_kv_groups, head_dim) + values = values.view(batch, seq_len, num_kv_groups, head_dim) + + # Step 3: Apply RoPE + queries = apply_rope(queries, angles) + keys = apply_rope(keys, angles) + + # Transpose for attention computation: (batch, num_heads, seq_len, head_dim) + queries = queries.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + + # Step 4: Repeat K and V for grouped attention + group_size = num_heads // num_kv_groups + keys = keys.repeat_interleave(group_size, dim=1) + values = values.repeat_interleave(group_size, dim=1) + + # Step 5: Compute attention scores + # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, seq_len) + # -> (batch, num_heads, seq_len, seq_len) + scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim) + + # Step 6: Apply mask and softmax + if mask is not None: + scores = scores.masked_fill(mask, float('-inf')) + + attention_weights = torch.nn.functional.softmax(scores, dim=-1) + + # Step 7: Compute attention output + # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) + # -> (batch, num_heads, seq_len, head_dim) + context = torch.matmul(attention_weights, values) + + # Step 8: Concatenate heads and project + # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) + context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) + + output = torch.nn.functional.linear(context, W_out) + + return output + + +def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): + """ + SwiGLU Feed-Forward Network. + + SwiGLU: x -> (SiLU(fc1(x)) * fc2(x)) -> fc3 + where SiLU(x) = x * sigmoid(x) + + Steps: + 1. Two parallel linear projections (gate and up) + 2. Apply SiLU to gate + 3. Element-wise multiplication + 4. Down projection + """ + # Step 1: Parallel projections + gate = torch.nn.functional.linear(x, fc1_weight) # gate projection + up = torch.nn.functional.linear(x, fc2_weight) # up projection + + # Step 2: Apply SiLU activation + gate_activated = torch.nn.functional.silu(gate) + + # Step 3: Element-wise multiplication + hidden = gate_activated * up + + # Step 4: Down projection + output = torch.nn.functional.linear(hidden, fc3_weight) + + return output + + +def transformer_block_forward( + x, + weights, + layer_idx, + angles, + mask, + num_heads, + num_kv_groups, +): + """ + Transformer block forward pass. + + Steps: + 1. Pre-norm (RMSNorm) + 2. Grouped Query Attention + 3. Residual connection + 4. Post-norm (RMSNorm) + 5. Feed-Forward Network + 6. Residual connection + """ + # Step 1: Pre-norm + norm1_weight = weights[f'model.layers.{layer_idx}.input_layernorm.weight'] + x_norm = rms_norm_forward(x, norm1_weight) + + # Step 2: Attention + attn_W_query = weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'] + attn_W_key = weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'] + attn_W_value = weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'] + attn_W_out = weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'] + + attn_output = grouped_query_attention_forward( + x_norm, + attn_W_query, attn_W_key, attn_W_value, attn_W_out, + angles, + mask, + num_heads, + num_kv_groups, + ) + + # Step 3: Residual + x = x + attn_output + + # Step 4: Post-norm + norm2_weight = weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'] + x_norm = rms_norm_forward(x, norm2_weight) + + # Step 5: FFN + ffn_fc1 = weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'] + ffn_fc2 = weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'] + ffn_fc3 = weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'] + + ffn_output = swiglu_ffn_forward(x_norm, ffn_fc1, ffn_fc2, ffn_fc3) + + # Step 6: Residual + x = x + ffn_output + + return x + + +def llama_forward_pass( + input_ids, + weights, + angles, + config, +): + """ + Complete Llama model forward pass. + + Args: + input_ids: (batch, seq_len) token indices + weights: Dict of model weights from safetensors + angles: Precomputed RoPE angles + config: LlamaConfig with model hyperparameters + + Returns: + logits: (batch, seq_len, vocab_size) + """ + batch, seq_len = input_ids.shape + + # Step 1: Token embedding + tok_emb_weight = weights['model.embed_tokens.weight'] + x = torch.nn.functional.embedding(input_ids, tok_emb_weight) # (batch, seq_len, emb_dim) + + # Step 2: Create causal mask + mask = torch.triu( + torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), + diagonal=1 + ) + + # Step 3: Apply transformer blocks + for layer_idx in range(config.n_layers): + x = transformer_block_forward( + x, + weights, + layer_idx, + angles, + mask, + config.n_heads, + config.n_kv_groups, + ) + + # Step 4: Final normalization + final_norm_weight = weights['model.norm.weight'] + x = rms_norm_forward(x, final_norm_weight) + + # Step 5: Output projection (check for tied embeddings) + if 'lm_head.weight' in weights: + lm_head_weight = weights['lm_head.weight'] + else: + lm_head_weight = weights['model.embed_tokens.weight'] + + logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) + + return logits + + +# Main +# ########################################################################## + +def main(): + harness(llama_forward_pass) + +if __name__ == "__main__": + main() diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py new file mode 100644 index 00000000..c80349b6 --- /dev/null +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +Inference harness -- all the necessary code _other_ than the actual model (forward pass). +Exposes a 'harness' function that can be called with a 'forward_pass' function that implements the model. +The 'harness' function does the following: +1. Load and set up model weights, tokenizer, and RoPE angle look-up table. +2. Tokenize the provided input prompt. +3. Run the generation loop to produce new tokens; this calls the provided forward_pass function. Decode and print each generated token. +""" + +import torch +import math +import sys + +import safetensors.torch +import tiktoken, tiktoken.load + + +# Configuration +# ########################################################################## + +class LlamaConfig: + """Fixed model configuration for Llama 3.2 1B""" + + # Model architecture + vocab_size = 128256 + emb_dim = 2048 + n_layers = 16 + n_heads = 32 + n_kv_groups = 8 + head_dim = emb_dim // n_heads # 64 + hidden_dim = 8192 + + # RoPE + rope_base = 500000.0 + context_length = 131072 + + # Generation + temperature = 0.7 + top_k = 50 + + # Sampling + dtype = torch.float32 + + # Tokenization + special_tokens = { + "<|begin_of_text|>": 128000, + "<|end_of_text|>": 128001, + "<|start_header_id|>": 128006, + "<|end_header_id|>": 128007, + "<|eot_id|>": 128009, + } + special_tokens.update({ + f"<|reserved_{i}|>": i + for i in list(range(128002, 128006)) + list(range(128009, 128256)) + }) + + +# Utilities +# ########################################################################## + +def compute_rope_angles(head_dim, context_length, rope_base=500000.0): + """Compute RoPE (Rotary Position Embedding) angles.""" + # Precompute the frequency tensor + inv_freq = 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + position = torch.arange(context_length).float() + freqs = torch.outer(position, inv_freq) + + cos = torch.cos(freqs) + sin = torch.sin(freqs) + + # Interleave cos and sin - create angles buffer + angles = torch.empty(context_length, head_dim) + angles[:, ::2] = cos + angles[:, 1::2] = sin + return angles + + +def get_tokenizer(tokenizer_path, config): + mergeable = tiktoken.load.load_tiktoken_bpe(tokenizer_path) + return tiktoken.Encoding( + name="llama3.2-1b", + pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" + r"|[^\r\n\p{L}\p{N}]?\p{L}+" + r"|\p{N}{1,3}" + r"| ?[^\s\p{L}\p{N}]+[\r\n]*" + r"|\s*[\r\n]+" + r"|\s+(?!\S)" + r"|\s+", + mergeable_ranks=mergeable, + special_tokens=config.special_tokens, + ) + + +# Generation loop +# ########################################################################## + +def generate_token( + config, + weights, + angles, + forward_pass, + token_ids, +): + generated_tokens = [] + + # Step 1: Forward pass + logits = forward_pass( + token_ids, + weights, + angles, + config, + ) + + # Step 2: Get logits for last token + last_token_logits = logits[:, -1, :] # (batch, vocab_size) + + # Step 3: Temperature scaling + if config.temperature > 0: + last_token_logits = last_token_logits / config.temperature + + # Step 4: Top-k filtering + if config.top_k is not None: + top_logits, top_indices = torch.topk(last_token_logits, config.top_k) + min_val = top_logits[:, -1:] + last_token_logits = torch.where( + last_token_logits < min_val, + torch.tensor(float('-inf')), + last_token_logits + ) + + # Step 5: Sample + probs = torch.nn.functional.softmax(last_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + + return next_token.item() + + +def harness( + forward_pass, + weights_path="/scratch/roesti/models/llama3.2-1b/model.safetensors", + tokenizer_path="/scratch/roesti/models/llama3.2-1b/tokenizer.model", + prompt="The capital of France is ", + num_tokens=100 +): + + seed = 1608560892 + torch.manual_seed(seed) + + config = LlamaConfig() + + # Load model weights and tokenizer + weights = safetensors.torch.load_file(weights_path) + tokenizer = get_tokenizer(tokenizer_path, config) + + # Compute RoPE angle look-up table + angles = compute_rope_angles( + config.head_dim, + config.context_length, + config.rope_base + ) + + # Tokenize prompt + token_ids = [config.special_tokens["<|begin_of_text|>"]] + token_ids += tokenizer.encode(prompt) + assert len(token_ids) + num_tokens <= config.context_length, "Prompt + new tokens to generate too long (exceed context)" + token_ids = torch.tensor([token_ids], dtype=torch.long) + + # Generate tokens + print(prompt, end='', flush=True) + for _ in range(num_tokens): + next_token = generate_token(config, weights, angles, forward_pass, token_ids) + token_ids = torch.cat([token_ids, torch.tensor([[next_token]])], dim=1) + + token_text = tokenizer.decode([next_token]) + print(token_text, end='', flush=True) + + +if __name__ == "__main__": + main() + From 5372fcab2766fe1255506cc40df18f4b9ab0d441 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 14 Jan 2026 13:24:52 -0700 Subject: [PATCH 05/99] add KV cache --- applications/llama_3.2_1b/llama_cpu.py | 216 +++++++++--------- .../llama_3.2_1b/llama_inference_harness.py | 61 +++-- 2 files changed, 153 insertions(+), 124 deletions(-) diff --git a/applications/llama_3.2_1b/llama_cpu.py b/applications/llama_3.2_1b/llama_cpu.py index 20534485..667cba62 100755 --- a/applications/llama_3.2_1b/llama_cpu.py +++ b/applications/llama_3.2_1b/llama_cpu.py @@ -4,11 +4,12 @@ import math from llama_inference_harness import harness + # Operators # ########################################################################## -def apply_rope(x, angles): - """Apply RoPE to input tensor x using precomputed angles.""" +def rope_forward(x, angles): + """Rotary positional embedding using precomputed angles""" # x: (batch, seq_len, num_heads, head_dim) after view and before transpose # angles: (context_length, head_dim) _, seq_len, _, head_dim = x.shape @@ -23,6 +24,7 @@ def apply_rope(x, angles): sin = angles_slice[:, 1::2] # (seq_len, head_dim//2) # Reshape for broadcasting: (1, seq_len, 1, head_dim//2) + # (The same cosine and sine values are used across batch and heads.) cos = cos.unsqueeze(0).unsqueeze(2) sin = sin.unsqueeze(0).unsqueeze(2) @@ -35,7 +37,7 @@ def apply_rope(x, angles): def rms_norm_forward(x, weight, eps=1e-5): - """RMSNorm: Root Mean Square Layer Normalization.""" + """Root Mean Square Layer Normalization""" # x: (batch, seq_len, dim) variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + eps) @@ -44,103 +46,99 @@ def rms_norm_forward(x, weight, eps=1e-5): def grouped_query_attention_forward( x, + keys_cache, + values_cache, W_query, W_key, W_value, W_out, angles, mask=None, num_heads=32, num_kv_groups=8, - kv_cache=None, - input_pos=None, ): - """ - Grouped Query Attention forward pass. - - Steps: - 1. Linear projections (Q, K, V) - 2. Reshape for multi-head - 3. Apply RoPE to Q and K - 4. Repeat K and V for grouped attention - 5. Compute attention scores (Q @ K^T / sqrt(d)) - 6. Apply mask and softmax - 7. Compute attention output (scores @ V) - 8. Concatenate heads and project - """ batch, seq_len, d_in = x.shape + assert W_query.shape[0] >= num_heads and W_query.shape[0] % num_heads == 0 head_dim = W_query.shape[0] // num_heads - + assert W_key.shape[0] == num_kv_groups * head_dim + assert W_value.shape[0] == num_kv_groups * head_dim + num_preceding_tokens = keys_cache.shape[2] + assert keys_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) + assert values_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) + # Step 1: Linear projections - queries = torch.nn.functional.linear(x, W_query) # (batch, seq_len, d_out) - keys = torch.nn.functional.linear(x, W_key) # (batch, seq_len, num_kv_groups * head_dim) - values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) - - # Step 2: Reshape for multi-head - queries = queries.view(batch, seq_len, num_heads, head_dim) - keys = keys.view(batch, seq_len, num_kv_groups, head_dim) - values = values.view(batch, seq_len, num_kv_groups, head_dim) - - # Step 3: Apply RoPE - queries = apply_rope(queries, angles) - keys = apply_rope(keys, angles) - - # Transpose for attention computation: (batch, num_heads, seq_len, head_dim) - queries = queries.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) + # This multiplication produces queries, keys and values for all tokens in the sequence. + # The weight matrix is such that multiple queries, keys and values are generated for each token. + # For each token, each head corresponds to one query. + # In particular, each token gets `num_heads` queries and `num_kv_groups` keys/values (keys/values shared for multiple queries). + # Due to the structure of the matmul, all queries, keys and values are contiguous for each token. + # Note that during the decode phase, seq_len=1, and we are only calculating the projections for the most recent token -- the keys and values of previous tokens will be concatenated in step 4. + queries = torch.nn.functional.linear(x, W_query) # (batch, seq_len, num_heads * head_dim) + keys = torch.nn.functional.linear(x, W_key) # (batch, seq_len, num_kv_groups * head_dim) + values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) + queries = queries.view(batch, seq_len, num_heads, head_dim) # (batch, seq_len, num_heads, head_dim) + keys = keys.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) + values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) + + # Step 2: Apply RoPE + queries = rope_forward(queries, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) + keys = rope_forward(keys, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) + + # Step 3: Transpose for attention computation + # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. + # Transpose so that heads are consecutive for attention computation: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) + queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) + keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + + # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. + keys_cache = torch.cat([keys_cache, keys], dim=2) + values_cache = torch.cat([values_cache, values], dim=2) + keys = keys_cache + values = values_cache - # Step 4: Repeat K and V for grouped attention + # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value group_size = num_heads // num_kv_groups keys = keys.repeat_interleave(group_size, dim=1) values = values.repeat_interleave(group_size, dim=1) - # Step 5: Compute attention scores + # Step 6: Compute attention scores # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, seq_len) # -> (batch, num_heads, seq_len, seq_len) + # Entry at row i, column j, indicates how much token i's query attends to token j's key. scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim) - # Step 6: Apply mask and softmax + # Step 7: Apply mask + # This ensures causality, so that tokens in the future cannot attend to tokens in the past. if mask is not None: scores = scores.masked_fill(mask, float('-inf')) + # Step 8: Apply softmax to squeeze scores into probabilities (0, 1) attention_weights = torch.nn.functional.softmax(scores, dim=-1) - # Step 7: Compute attention output + # Step 9: Compute attention output # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) # -> (batch, num_heads, seq_len, head_dim) context = torch.matmul(attention_weights, values) - # Step 8: Concatenate heads and project + # Step 10: Concatenate heads and project # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) output = torch.nn.functional.linear(context, W_out) - return output + return output, keys_cache, values_cache def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): - """ - SwiGLU Feed-Forward Network. - - SwiGLU: x -> (SiLU(fc1(x)) * fc2(x)) -> fc3 - where SiLU(x) = x * sigmoid(x) - - Steps: - 1. Two parallel linear projections (gate and up) - 2. Apply SiLU to gate - 3. Element-wise multiplication - 4. Down projection - """ - # Step 1: Parallel projections + # Step 1: Parallel projections: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) gate = torch.nn.functional.linear(x, fc1_weight) # gate projection up = torch.nn.functional.linear(x, fc2_weight) # up projection # Step 2: Apply SiLU activation - gate_activated = torch.nn.functional.silu(gate) + gate_activated = torch.nn.functional.silu(gate) # (batch, seq_len, swiglu_hidden_dim) - # Step 3: Element-wise multiplication - hidden = gate_activated * up + # Step 3: Element-wise multiplication (apply the 'gating') + hidden = gate_activated * up # (batch, seq_len, swiglu_hidden_dim) - # Step 4: Down projection + # Step 4: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) output = torch.nn.functional.linear(hidden, fc3_weight) return output @@ -148,39 +146,33 @@ def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): def transformer_block_forward( x, - weights, - layer_idx, - angles, - mask, + attn_keys_cache, + attn_values_cache, num_heads, num_kv_groups, + W_norm1, + W_attn_query, + W_attn_key, + W_attn_value, + W_attn_out, + W_norm2, + W_ffn_fc1, + W_ffn_fc2, + W_ffn_fc3, + rope_angles, + attn_mask ): - """ - Transformer block forward pass. - - Steps: - 1. Pre-norm (RMSNorm) - 2. Grouped Query Attention - 3. Residual connection - 4. Post-norm (RMSNorm) - 5. Feed-Forward Network - 6. Residual connection - """ - # Step 1: Pre-norm - norm1_weight = weights[f'model.layers.{layer_idx}.input_layernorm.weight'] - x_norm = rms_norm_forward(x, norm1_weight) + # Step 1: RMS normalization + x_norm = rms_norm_forward(x, W_norm1) # Step 2: Attention - attn_W_query = weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'] - attn_W_key = weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'] - attn_W_value = weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'] - attn_W_out = weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'] - - attn_output = grouped_query_attention_forward( + attn_output, attn_keys, attn_values = grouped_query_attention_forward( x_norm, - attn_W_query, attn_W_key, attn_W_value, attn_W_out, - angles, - mask, + attn_keys_cache, + attn_values_cache, + W_attn_query, W_attn_key, W_attn_value, W_attn_out, + rope_angles, + attn_mask, num_heads, num_kv_groups, ) @@ -189,36 +181,35 @@ def transformer_block_forward( x = x + attn_output # Step 4: Post-norm - norm2_weight = weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'] - x_norm = rms_norm_forward(x, norm2_weight) - - # Step 5: FFN - ffn_fc1 = weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'] - ffn_fc2 = weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'] - ffn_fc3 = weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'] + x_norm = rms_norm_forward(x, W_norm2) - ffn_output = swiglu_ffn_forward(x_norm, ffn_fc1, ffn_fc2, ffn_fc3) + # Step 5: fully-connected feed-forward network + ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) # Step 6: Residual x = x + ffn_output - return x + return x, attn_keys, attn_values def llama_forward_pass( - input_ids, - weights, - angles, config, + weights, + rope_angles, + input_ids, + attn_keys_caches, + attn_values_caches, ): """ Complete Llama model forward pass. Args: - input_ids: (batch, seq_len) token indices - weights: Dict of model weights from safetensors - angles: Precomputed RoPE angles config: LlamaConfig with model hyperparameters + weights: Dict of model weights from safetensors + rope_angles: Precomputed RoPE angles + input_ids: (batch, seq_len) token index/indices + attn_keys_caches: Previously computed keys for past tokens, one for each transformer layer + attn_values_caches: Previously computed values for past tokens, one for each transformer layer Returns: logits: (batch, seq_len, vocab_size) @@ -230,21 +221,30 @@ def llama_forward_pass( x = torch.nn.functional.embedding(input_ids, tok_emb_weight) # (batch, seq_len, emb_dim) # Step 2: Create causal mask - mask = torch.triu( + attn_mask = torch.triu( torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 ) # Step 3: Apply transformer blocks for layer_idx in range(config.n_layers): - x = transformer_block_forward( + x, attn_keys_caches[layer_idx], attn_values_caches[layer_idx] = transformer_block_forward( x, - weights, - layer_idx, - angles, - mask, + attn_keys_caches[layer_idx], + attn_values_caches[layer_idx], config.n_heads, config.n_kv_groups, + W_norm1=weights[f'model.layers.{layer_idx}.input_layernorm.weight'], + W_attn_query=weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], + W_attn_key=weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], + W_attn_value=weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], + W_attn_out=weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], + W_ffn_fc1=weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], + W_ffn_fc2=weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], + W_ffn_fc3=weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], + W_norm2=weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], + rope_angles=rope_angles, + attn_mask=attn_mask, ) # Step 4: Final normalization @@ -259,7 +259,7 @@ def llama_forward_pass( logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) - return logits + return logits, attn_keys_caches, attn_values_caches # Main diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py index c80349b6..6a40ee16 100644 --- a/applications/llama_3.2_1b/llama_inference_harness.py +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -11,6 +11,7 @@ import torch import math import sys +import time import safetensors.torch import tiktoken, tiktoken.load @@ -38,9 +39,6 @@ class LlamaConfig: # Generation temperature = 0.7 top_k = 50 - - # Sampling - dtype = torch.float32 # Tokenization special_tokens = { @@ -101,15 +99,19 @@ def generate_token( angles, forward_pass, token_ids, + attn_keys_caches=None, + attn_values_caches=None, ): generated_tokens = [] # Step 1: Forward pass - logits = forward_pass( - token_ids, + logits, attn_keys_caches, attn_values_caches = forward_pass( + config, weights, angles, - config, + token_ids, + attn_keys_caches, + attn_values_caches ) # Step 2: Get logits for last token @@ -133,7 +135,7 @@ def generate_token( probs = torch.nn.functional.softmax(last_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) - return next_token.item() + return next_token.item(), attn_keys_caches, attn_values_caches def harness( @@ -161,19 +163,46 @@ def harness( ) # Tokenize prompt - token_ids = [config.special_tokens["<|begin_of_text|>"]] - token_ids += tokenizer.encode(prompt) - assert len(token_ids) + num_tokens <= config.context_length, "Prompt + new tokens to generate too long (exceed context)" - token_ids = torch.tensor([token_ids], dtype=torch.long) - + prompt_token_ids = [config.special_tokens["<|begin_of_text|>"]] + prompt_token_ids += tokenizer.encode(prompt) + assert len(prompt_token_ids) + num_tokens <= config.context_length, "Prompt + new tokens to generate too long (exceed context)" + prompt_token_ids = torch.tensor([prompt_token_ids], dtype=torch.long) + + # Set up KV cache -- initially empty + # This is what passes information from previous tokens to the current token during generation + attn_keys_caches = [torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=weights["model.layers.0.self_attn.k_proj.weight"].dtype) for _ in range(config.n_layers)] # (batch_size, n_kv_groups, seq_len, head_dim) + attn_values_caches = [torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=weights["model.layers.0.self_attn.v_proj.weight"].dtype) for _ in range(config.n_layers)] # (batch_size, n_kv_groups, seq_len, head_dim) + # Generate tokens + # First token (prefill) + n_tokens_generated = 0 + t_prefill_start = time.perf_counter() + first_token, attn_keys_caches, attn_values_caches = generate_token(config, weights, angles, forward_pass, prompt_token_ids, attn_keys_caches, attn_values_caches) + token_text = tokenizer.decode([first_token]) + n_tokens_generated += 1 print(prompt, end='', flush=True) - for _ in range(num_tokens): - next_token = generate_token(config, weights, angles, forward_pass, token_ids) - token_ids = torch.cat([token_ids, torch.tensor([[next_token]])], dim=1) - + print(token_text, end='', flush=True) + t_prefill_stop = time.perf_counter() + + # Remaining tokens (decode) + last_token = torch.tensor([[first_token]]) + t_decode_start = time.perf_counter() + for _ in range(num_tokens-1): + next_token, attn_keys_caches, attn_values_caches = generate_token(config, weights, angles, forward_pass, last_token, attn_keys_caches, attn_values_caches) token_text = tokenizer.decode([next_token]) + n_tokens_generated += 1 print(token_text, end='', flush=True) + last_token = torch.tensor([[next_token]]) + t_decode_end = time.perf_counter() + + t_prefill = t_prefill_stop - t_prefill_start + t_decode = t_decode_end - t_decode_start + sys.stderr.write("\n\n=== Performance Statistics ===\n") + sys.stderr.write(f"[Prefill] Time to first token: {t_prefill:7.3f} s\n") + sys.stderr.write(f"[Decode] Time per token (mean): {t_decode / (n_tokens_generated - 1):7.3f} s\n") + sys.stderr.write(f"[Decode] Tokens per second: {(n_tokens_generated - 1) / t_decode:7.3f}\n") + sys.stderr.write(f"[Total] Time per token (mean): {(t_prefill + t_decode) / n_tokens_generated:7.3f} s\n") + sys.stderr.write(f"[Total] Tokens per second: {n_tokens_generated / (t_prefill + t_decode):7.3f}\n") if __name__ == "__main__": From f0d3289e631c23afa1c988cf1f8f8f466e7fa23a Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 14 Jan 2026 14:45:32 -0700 Subject: [PATCH 06/99] start with simplified llama for NPU --- applications/llama_3.2_1b/llama_cpu.py | 72 ++-- .../llama_3.2_1b/llama_inference_harness.py | 160 ++++---- applications/llama_3.2_1b/llama_npu.py | 358 ++++++++++++++++++ 3 files changed, 471 insertions(+), 119 deletions(-) create mode 100755 applications/llama_3.2_1b/llama_npu.py diff --git a/applications/llama_3.2_1b/llama_cpu.py b/applications/llama_3.2_1b/llama_cpu.py index 667cba62..258e342c 100755 --- a/applications/llama_3.2_1b/llama_cpu.py +++ b/applications/llama_3.2_1b/llama_cpu.py @@ -2,7 +2,7 @@ import torch import math -from llama_inference_harness import harness +import llama_inference_harness as harness # Operators @@ -194,31 +194,13 @@ def transformer_block_forward( def llama_forward_pass( config, - weights, - rope_angles, - input_ids, - attn_keys_caches, - attn_values_caches, + state ): - """ - Complete Llama model forward pass. - - Args: - config: LlamaConfig with model hyperparameters - weights: Dict of model weights from safetensors - rope_angles: Precomputed RoPE angles - input_ids: (batch, seq_len) token index/indices - attn_keys_caches: Previously computed keys for past tokens, one for each transformer layer - attn_values_caches: Previously computed values for past tokens, one for each transformer layer - - Returns: - logits: (batch, seq_len, vocab_size) - """ - batch, seq_len = input_ids.shape + batch, seq_len = state.token_ids.shape # Step 1: Token embedding - tok_emb_weight = weights['model.embed_tokens.weight'] - x = torch.nn.functional.embedding(input_ids, tok_emb_weight) # (batch, seq_len, emb_dim) + tok_emb_weight = config.weights['model.embed_tokens.weight'] + x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) # (batch, seq_len, emb_dim) # Step 2: Create causal mask attn_mask = torch.triu( @@ -228,45 +210,45 @@ def llama_forward_pass( # Step 3: Apply transformer blocks for layer_idx in range(config.n_layers): - x, attn_keys_caches[layer_idx], attn_values_caches[layer_idx] = transformer_block_forward( + x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( x, - attn_keys_caches[layer_idx], - attn_values_caches[layer_idx], + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], config.n_heads, config.n_kv_groups, - W_norm1=weights[f'model.layers.{layer_idx}.input_layernorm.weight'], - W_attn_query=weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], - W_attn_key=weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], - W_attn_value=weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], - W_attn_out=weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], - W_ffn_fc1=weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], - W_ffn_fc2=weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], - W_ffn_fc3=weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], - W_norm2=weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], - rope_angles=rope_angles, + W_norm1=config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'], + W_attn_query=config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], + W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], + W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], + W_attn_out=config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], + W_ffn_fc1=config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], + W_ffn_fc2=config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], + W_ffn_fc3=config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], + W_norm2=config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], + rope_angles=config.angles, attn_mask=attn_mask, ) # Step 4: Final normalization - final_norm_weight = weights['model.norm.weight'] + final_norm_weight = config.weights['model.norm.weight'] x = rms_norm_forward(x, final_norm_weight) - # Step 5: Output projection (check for tied embeddings) - if 'lm_head.weight' in weights: - lm_head_weight = weights['lm_head.weight'] - else: - lm_head_weight = weights['model.embed_tokens.weight'] + # Step 5: Output projection + lm_head_weight = config.weights['model.embed_tokens.weight'] logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) - - return logits, attn_keys_caches, attn_values_caches + + return logits, state # Main # ########################################################################## def main(): - harness(llama_forward_pass) + prompt = "The capital of France is " + config, state = harness.init(prompt=prompt) + print(prompt, end='', flush=True) + harness.generate(config, state, llama_forward_pass) if __name__ == "__main__": main() diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py index 6a40ee16..736d48c1 100644 --- a/applications/llama_3.2_1b/llama_inference_harness.py +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -21,37 +21,65 @@ # ########################################################################## class LlamaConfig: - """Fixed model configuration for Llama 3.2 1B""" - - # Model architecture - vocab_size = 128256 - emb_dim = 2048 - n_layers = 16 - n_heads = 32 - n_kv_groups = 8 - head_dim = emb_dim // n_heads # 64 - hidden_dim = 8192 - - # RoPE - rope_base = 500000.0 - context_length = 131072 - - # Generation - temperature = 0.7 - top_k = 50 - - # Tokenization - special_tokens = { - "<|begin_of_text|>": 128000, - "<|end_of_text|>": 128001, - "<|start_header_id|>": 128006, - "<|end_header_id|>": 128007, - "<|eot_id|>": 128009, - } - special_tokens.update({ - f"<|reserved_{i}|>": i - for i in list(range(128002, 128006)) + list(range(128009, 128256)) - }) + def __init__(self, weights_path, tokenizer_path): + # Model architecture + self.vocab_size = 128256 + self.emb_dim = 2048 + self.n_layers = 16 + self.n_heads = 32 + self.n_kv_groups = 8 + self.head_dim = self.emb_dim // self.n_heads # 64 + self.hidden_dim = 8192 + + # RoPE + self.rope_base = 500000.0 + self.context_length = 131072 + + # Generation + self.temperature = 0.7 + self.top_k = 50 + + # Tokenization + self.special_tokens = { + "<|begin_of_text|>": 128000, + "<|end_of_text|>": 128001, + "<|start_header_id|>": 128006, + "<|end_header_id|>": 128007, + "<|eot_id|>": 128009, + } + self.special_tokens.update({ + f"<|reserved_{i}|>": i + for i in list(range(128002, 128006)) + list(range(128009, 128256)) + }) + + # Load model weights and tokenizer + self.weights = safetensors.torch.load_file(weights_path) + self.tokenizer = get_tokenizer(tokenizer_path, self.special_tokens) + # TODO: Assert that weight dimensions match config + + # Compute RoPE angle look-up table + self.angles = compute_rope_angles( + self.head_dim, + self.context_length, + self.rope_base + ) + + +class LlamaModelState: + def __init__(self, config): + # Current IDs of tokens being processed (most recent token for decode; all prompt tokens for prefill) + self.token_ids = torch.empty(0, dtype=torch.long) + + # Set up KV cache -- initially empty + # This is what passes information from previous tokens to the current token during generation + self.attn_keys_caches = [ + torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=config.weights["model.layers.0.self_attn.k_proj.weight"].dtype) # (batch_size, n_kv_groups, seq_len, head_dim) + for _ in range(config.n_layers) + ] + self.attn_values_caches = [ + torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=config.weights["model.layers.0.self_attn.v_proj.weight"].dtype) # (batch_size, n_kv_groups, seq_len, head_dim) + for _ in range(config.n_layers) + ] # Utilities @@ -74,7 +102,7 @@ def compute_rope_angles(head_dim, context_length, rope_base=500000.0): return angles -def get_tokenizer(tokenizer_path, config): +def get_tokenizer(tokenizer_path, special_tokens): mergeable = tiktoken.load.load_tiktoken_bpe(tokenizer_path) return tiktoken.Encoding( name="llama3.2-1b", @@ -86,7 +114,7 @@ def get_tokenizer(tokenizer_path, config): r"|\s+(?!\S)" r"|\s+", mergeable_ranks=mergeable, - special_tokens=config.special_tokens, + special_tokens=special_tokens, ) @@ -95,23 +123,15 @@ def get_tokenizer(tokenizer_path, config): def generate_token( config, - weights, - angles, forward_pass, - token_ids, - attn_keys_caches=None, - attn_values_caches=None, + state ): generated_tokens = [] # Step 1: Forward pass - logits, attn_keys_caches, attn_values_caches = forward_pass( + logits, state = forward_pass( config, - weights, - angles, - token_ids, - attn_keys_caches, - attn_values_caches + state ) # Step 2: Get logits for last token @@ -135,64 +155,56 @@ def generate_token( probs = torch.nn.functional.softmax(last_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) - return next_token.item(), attn_keys_caches, attn_values_caches + return next_token.item(), state -def harness( - forward_pass, +def init( weights_path="/scratch/roesti/models/llama3.2-1b/model.safetensors", tokenizer_path="/scratch/roesti/models/llama3.2-1b/tokenizer.model", prompt="The capital of France is ", - num_tokens=100 ): + config = LlamaConfig(weights_path, tokenizer_path) + state = LlamaModelState(config) seed = 1608560892 torch.manual_seed(seed) - config = LlamaConfig() - - # Load model weights and tokenizer - weights = safetensors.torch.load_file(weights_path) - tokenizer = get_tokenizer(tokenizer_path, config) - - # Compute RoPE angle look-up table - angles = compute_rope_angles( - config.head_dim, - config.context_length, - config.rope_base - ) - # Tokenize prompt prompt_token_ids = [config.special_tokens["<|begin_of_text|>"]] - prompt_token_ids += tokenizer.encode(prompt) - assert len(prompt_token_ids) + num_tokens <= config.context_length, "Prompt + new tokens to generate too long (exceed context)" + prompt_token_ids += config.tokenizer.encode(prompt) + assert len(prompt_token_ids) <= config.context_length, "Prompt + new tokens to generate too long (exceed context)" prompt_token_ids = torch.tensor([prompt_token_ids], dtype=torch.long) - # Set up KV cache -- initially empty - # This is what passes information from previous tokens to the current token during generation - attn_keys_caches = [torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=weights["model.layers.0.self_attn.k_proj.weight"].dtype) for _ in range(config.n_layers)] # (batch_size, n_kv_groups, seq_len, head_dim) - attn_values_caches = [torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=weights["model.layers.0.self_attn.v_proj.weight"].dtype) for _ in range(config.n_layers)] # (batch_size, n_kv_groups, seq_len, head_dim) + state.token_ids = prompt_token_ids + + return config, state + +def generate( + config, + state, + forward_pass, + num_tokens=100 +): # Generate tokens # First token (prefill) n_tokens_generated = 0 t_prefill_start = time.perf_counter() - first_token, attn_keys_caches, attn_values_caches = generate_token(config, weights, angles, forward_pass, prompt_token_ids, attn_keys_caches, attn_values_caches) - token_text = tokenizer.decode([first_token]) + first_token, state = generate_token(config, forward_pass, state) + token_text = config.tokenizer.decode([first_token]) n_tokens_generated += 1 - print(prompt, end='', flush=True) print(token_text, end='', flush=True) t_prefill_stop = time.perf_counter() # Remaining tokens (decode) - last_token = torch.tensor([[first_token]]) + state.token_ids = torch.tensor([[first_token]], dtype=torch.long) t_decode_start = time.perf_counter() for _ in range(num_tokens-1): - next_token, attn_keys_caches, attn_values_caches = generate_token(config, weights, angles, forward_pass, last_token, attn_keys_caches, attn_values_caches) - token_text = tokenizer.decode([next_token]) + next_token, state = generate_token(config, forward_pass, state) + token_text = config.tokenizer.decode([next_token]) n_tokens_generated += 1 print(token_text, end='', flush=True) - last_token = torch.tensor([[next_token]]) + state.token_ids = torch.tensor([[next_token]], dtype=torch.long) t_decode_end = time.perf_counter() t_prefill = t_prefill_stop - t_prefill_start diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py new file mode 100755 index 00000000..350c741f --- /dev/null +++ b/applications/llama_3.2_1b/llama_npu.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 + +import torch +import math +from pathlib import Path +import sys +import llama_inference_harness as harness + +repo_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(repo_root)) + +from operators.common.aie_context import AIEContext +from operators.common import AIEOperatorBase +from operators import ( + AIERMSNorm, + AIEGEMM, +) + + +# AIE Operator Configuration +# ########################################################################## + +aieops = None + +class AIELlamaOperators: + + def __init__(self, config, prompt_len): + self.context = AIEContext() + + # RMS Norm + self.final_norm_prefill = AIERMSNorm( + size=prompt_len, + eps=1e-5, + num_aie_columns=8, + num_channels=2, + tile_size=config.emb_dim, + context=self.context, + ) + self.final_norm_decode = AIERMSNorm( + size=config.emb_dim, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=config.emb_dim, + context=self.context, + ) + + def setup(self): + self.context.compile_all() + self.context.prepare_runtime() + + +# Operators +# ########################################################################## + +def rope_forward(x, angles): + """Rotary positional embedding using precomputed angles""" + # x: (batch, seq_len, num_heads, head_dim) after view and before transpose + # angles: (context_length, head_dim) + _, seq_len, _, head_dim = x.shape + angles_slice = angles[:seq_len] # (seq_len, head_dim) + + # Split into even and odd dimensions + x1 = x[..., : head_dim // 2] # (batch, seq_len, num_heads, head_dim//2) + x2 = x[..., head_dim // 2 :] # (batch, seq_len, num_heads, head_dim//2) + + # Get cos and sin from angles + cos = angles_slice[:, ::2] # (seq_len, head_dim//2) + sin = angles_slice[:, 1::2] # (seq_len, head_dim//2) + + # Reshape for broadcasting: (1, seq_len, 1, head_dim//2) + # (The same cosine and sine values are used across batch and heads.) + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + # Rotate: [x1*cos - x2*sin, x1*sin + x2*cos] + rotated = torch.empty_like(x) + rotated[..., : head_dim // 2] = x1 * cos - x2 * sin + rotated[..., head_dim // 2 :] = x1 * sin + x2 * cos + + return rotated + + +def rms_norm_forward(x, weight, eps=1e-5): + """Root Mean Square Layer Normalization""" + # x: (batch, seq_len, dim) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + return weight * x + + +def grouped_query_attention_forward( + x, + keys_cache, + values_cache, + W_query, W_key, W_value, W_out, + angles, + mask=None, + num_heads=32, + num_kv_groups=8, +): + batch, seq_len, d_in = x.shape + assert W_query.shape[0] >= num_heads and W_query.shape[0] % num_heads == 0 + head_dim = W_query.shape[0] // num_heads + assert W_key.shape[0] == num_kv_groups * head_dim + assert W_value.shape[0] == num_kv_groups * head_dim + num_preceding_tokens = keys_cache.shape[2] + assert keys_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) + assert values_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) + + # Step 1: Linear projections + # This multiplication produces queries, keys and values for all tokens in the sequence. + # The weight matrix is such that multiple queries, keys and values are generated for each token. + # For each token, each head corresponds to one query. + # In particular, each token gets `num_heads` queries and `num_kv_groups` keys/values (keys/values shared for multiple queries). + # Due to the structure of the matmul, all queries, keys and values are contiguous for each token. + # Note that during the decode phase, seq_len=1, and we are only calculating the projections for the most recent token -- the keys and values of previous tokens will be concatenated in step 4. + queries = torch.nn.functional.linear(x, W_query) # (batch, seq_len, num_heads * head_dim) + keys = torch.nn.functional.linear(x, W_key) # (batch, seq_len, num_kv_groups * head_dim) + values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) + queries = queries.view(batch, seq_len, num_heads, head_dim) # (batch, seq_len, num_heads, head_dim) + keys = keys.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) + values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) + + # Step 2: Apply RoPE + queries = rope_forward(queries, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) + keys = rope_forward(keys, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) + + # Step 3: Transpose for attention computation + # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. + # Transpose so that heads are consecutive for attention computation: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) + queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) + keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + + # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. + keys_cache = torch.cat([keys_cache, keys], dim=2) + values_cache = torch.cat([values_cache, values], dim=2) + keys = keys_cache + values = values_cache + + # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value + group_size = num_heads // num_kv_groups + keys = keys.repeat_interleave(group_size, dim=1) + values = values.repeat_interleave(group_size, dim=1) + + # Step 6: Compute attention scores + # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, seq_len) + # -> (batch, num_heads, seq_len, seq_len) + # Entry at row i, column j, indicates how much token i's query attends to token j's key. + scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim) + + # Step 7: Apply mask + # This ensures causality, so that tokens in the future cannot attend to tokens in the past. + if mask is not None: + scores = scores.masked_fill(mask, float('-inf')) + + # Step 8: Apply softmax to squeeze scores into probabilities (0, 1) + attention_weights = torch.nn.functional.softmax(scores, dim=-1) + + # Step 9: Compute attention output + # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) + # -> (batch, num_heads, seq_len, head_dim) + context = torch.matmul(attention_weights, values) + + # Step 10: Concatenate heads and project + # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) + context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) + + output = torch.nn.functional.linear(context, W_out) + + return output, keys_cache, values_cache + + +def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): + # Step 1: Parallel projections: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) + gate = torch.nn.functional.linear(x, fc1_weight) # gate projection + up = torch.nn.functional.linear(x, fc2_weight) # up projection + + # Step 2: Apply SiLU activation + gate_activated = torch.nn.functional.silu(gate) # (batch, seq_len, swiglu_hidden_dim) + + # Step 3: Element-wise multiplication (apply the 'gating') + hidden = gate_activated * up # (batch, seq_len, swiglu_hidden_dim) + + # Step 4: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) + output = torch.nn.functional.linear(hidden, fc3_weight) + + return output + + +def transformer_block_forward( + x, + attn_keys_cache, + attn_values_cache, + num_heads, + num_kv_groups, + W_norm1, + W_attn_query, + W_attn_key, + W_attn_value, + W_attn_out, + W_norm2, + W_ffn_fc1, + W_ffn_fc2, + W_ffn_fc3, + rope_angles, + attn_mask +): + # Step 1: RMS normalization + x_norm = rms_norm_forward(x, W_norm1) + + # Step 2: Attention + attn_output, attn_keys, attn_values = grouped_query_attention_forward( + x_norm, + attn_keys_cache, + attn_values_cache, + W_attn_query, W_attn_key, W_attn_value, W_attn_out, + rope_angles, + attn_mask, + num_heads, + num_kv_groups, + ) + + # Step 3: Residual + x = x + attn_output + + # Step 4: Post-norm + x_norm = rms_norm_forward(x, W_norm2) + + # Step 5: fully-connected feed-forward network + ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) + + # Step 6: Residual + x = x + ffn_output + + return x, attn_keys, attn_values + + +def llama_forward_pass( + config, + state +): + batch, seq_len = state.token_ids.shape + + tok_emb_weight = config.weights['model.embed_tokens.weight'] + x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) # (batch, seq_len, emb_dim) + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), + diagonal=1 + ) + + if seq_len == 1: + return llama_forward_pass_decode(config, state, x, attn_mask) + else: + return llama_forward_pass_prefill(config, state, x, attn_mask) + +def llama_forward_pass_prefill( + config, + state, + x, + attn_mask +): + batch, seq_len, _ = x.shape + + # Step 3: Apply transformer blocks + for layer_idx in range(config.n_layers): + x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( + x, + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], + config.n_heads, + config.n_kv_groups, + W_norm1=config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'], + W_attn_query=config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], + W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], + W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], + W_attn_out=config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], + W_ffn_fc1=config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], + W_ffn_fc2=config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], + W_ffn_fc3=config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], + W_norm2=config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], + rope_angles=config.angles, + attn_mask=attn_mask, + ) + + # Step 4: Final normalization + x = aieops.final_norm_prefill(x) + + # Step 5: Output projection (check for tied embeddings) + if 'lm_head.weight' in config.weights: + lm_head_weight = config.weights['lm_head.weight'] + else: + lm_head_weight = config.weights['model.embed_tokens.weight'] + + logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) + + return logits, state + + +def llama_forward_pass_decode( + config, + state, + x, + attn_mask +): + batch, seq_len, _ = x.shape + + # Step 3: Apply transformer blocks + for layer_idx in range(config.n_layers): + x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( + x, + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], + config.n_heads, + config.n_kv_groups, + W_norm1=config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'], + W_attn_query=config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], + W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], + W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], + W_attn_out=config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], + W_ffn_fc1=config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], + W_ffn_fc2=config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], + W_ffn_fc3=config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], + W_norm2=config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], + rope_angles=config.angles, + attn_mask=attn_mask, + ) + + # Step 4: Final normalization + x = aieops.final_norm_decode(x) + + # Step 5: Output projection + lm_head_weight = config.weights['model.embed_tokens.weight'] + + logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) + + return logits, state + + +# Main +# ########################################################################## + +def main(): + global aieops + prompt = "The capital of France is " + config, state = harness.init(prompt=prompt) + + aieops = AIELlamaOperators(config, 2048) + aieops.final_norm_prefill.weight = config.weights['model.norm.weight'] + aieops.final_norm_decode.weight = config.weights['model.norm.weight'] + aieops.setup() + + print(prompt, end='', flush=True) + harness.generate(config, state, llama_forward_pass) + +if __name__ == "__main__": + main() From fd6f759bfa2dddf3836cb7fb94e7063e4a3251fc Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 14 Jan 2026 15:09:31 -0700 Subject: [PATCH 07/99] offload last layer GEMM/GEMV --- applications/llama_3.2_1b/llama_npu.py | 36 ++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 350c741f..4dee2fb0 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -14,6 +14,7 @@ from operators import ( AIERMSNorm, AIEGEMM, + AIEGEMV ) @@ -27,7 +28,7 @@ class AIELlamaOperators: def __init__(self, config, prompt_len): self.context = AIEContext() - # RMS Norm + # Final RMS Norm self.final_norm_prefill = AIERMSNorm( size=prompt_len, eps=1e-5, @@ -45,6 +46,31 @@ def __init__(self, config, prompt_len): context=self.context, ) + # Final GEMM + self.out_head_prefill = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.vocab_size, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=True, + use_static_weight=True, + separate_c_tiles=True, + partition_N=4, + context=self.context + ) + self.out_head_decode = AIEGEMV( + M=config.vocab_size, K=config.emb_dim, + num_aie_columns=8, + is_mv=True, + use_static_weight=True, + tile_size_input=4, + tile_size_output=32, + context=self.context + ) + def setup(self): self.context.compile_all() self.context.prepare_runtime() @@ -293,7 +319,7 @@ def llama_forward_pass_prefill( else: lm_head_weight = config.weights['model.embed_tokens.weight'] - logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) + logits = aieops.out_head_prefill(x) # (batch, seq_len, vocab_size) return logits, state @@ -331,9 +357,7 @@ def llama_forward_pass_decode( x = aieops.final_norm_decode(x) # Step 5: Output projection - lm_head_weight = config.weights['model.embed_tokens.weight'] - - logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) + logits = aieops.out_head_decode(x) # (batch, seq_len, vocab_size) return logits, state @@ -349,6 +373,8 @@ def main(): aieops = AIELlamaOperators(config, 2048) aieops.final_norm_prefill.weight = config.weights['model.norm.weight'] aieops.final_norm_decode.weight = config.weights['model.norm.weight'] + aieops.out_head_prefill.weight = config.weights['model.embed_tokens.weight'].T + aieops.out_head_decode.weight = config.weights['model.embed_tokens.weight'].T aieops.setup() print(prompt, end='', flush=True) From ec0d372c91cb73fa6b12cd17e2f67ba37389f574 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 12:50:24 -0700 Subject: [PATCH 08/99] add profile path analyzer --- .../llama_3.2_1b/profile_path_analyzer.py | 348 ++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100755 applications/llama_3.2_1b/profile_path_analyzer.py diff --git a/applications/llama_3.2_1b/profile_path_analyzer.py b/applications/llama_3.2_1b/profile_path_analyzer.py new file mode 100755 index 00000000..638d938d --- /dev/null +++ b/applications/llama_3.2_1b/profile_path_analyzer.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +""" +Profile Path Analyzer + +Analyzes profile.json to aggregate time spent in selected functions across call paths. +Supports fuzzy matching and "zooming into" specific call paths. +""" + +import json +import argparse +import sys +from typing import Dict, List, Tuple, Set +import matplotlib.pyplot as plt +import matplotlib + + +def parse_function_name(full_name: str) -> Tuple[str, str, str]: + parts = full_name.rsplit(':', 2) + if len(parts) == 3: + return parts[0], parts[1], parts[2] + return full_name, '', '' + + +def fuzzy_match(pattern: str, full_name: str) -> bool: + return pattern.lower() in full_name.lower() + + +def get_all_paths(profile: Dict, current_path: List[str] = None) -> List[Tuple[List[str], float]]: + """ + Extract all call paths from the profile with their total times. + + Args: + profile: Profile dictionary + current_path: Current call path being explored + + Returns: + List of (call_path, total_time) tuples + """ + if current_path is None: + current_path = [] + + paths = [] + + for func_name, data in profile.items(): + time_spent = data[0] + children = data[1] if len(data) > 1 else {} + + new_path = current_path + [func_name] + paths.append((new_path, time_spent)) + + if children: + child_paths = get_all_paths(children, new_path) + paths.extend(child_paths) + + return paths + + +def is_subpath(path: List[str], zoom_path: List[str]) -> bool: + """ + Check if path is a subpath of zoom_path. + + Args: + path: The path to check + zoom_path: The zoom-in path + + Returns: + True if path starts with zoom_path + """ + if len(path) < len(zoom_path): + return False + + for i, func in enumerate(zoom_path): + if not fuzzy_match(func, path[i]): + return False + + return True + + +def find_matching_functions(profile: Dict, patterns: List[str]) -> Dict: + """ + Find all function names in the profile that match any of the patterns. + If multiple patterns match, uses the most specific (longest) pattern. + + Args: + profile: Profile dictionary + patterns: List of search patterns + + Returns: + Dict mapping function names to their matched pattern + """ + matching = {} + + def traverse(data: Dict): + for func_name, value in data.items(): + # Find all matching patterns and use the longest (most specific) + matched_patterns = [p for p in patterns if fuzzy_match(p, func_name)] + if matched_patterns: + # Use the longest pattern as it's the most specific + most_specific = max(matched_patterns, key=len) + matching[func_name] = most_specific + + if len(value) > 1 and value[1]: + traverse(value[1]) + + traverse(profile) + return matching + + +def aggregate_times(profile: Dict, selected_funcs: Dict, zoom_path: List[str]) -> Dict[str, float]: + aggregated = {func: 0.0 for func in selected_funcs.values()} + + def traverse(data: Dict, current_path: List[str], counted_parent: str | None): + for func_name, (time_spent, subprofile) in data.items(): + new_path = current_path + [func_name] + # Start with parent's pattern, will be overridden if this function is selected + counted_this_iter = counted_parent + + # Check if this path is within the zoom scope + if not zoom_path or is_subpath(new_path, zoom_path): + # If this function is selected, add its time + if func_name in selected_funcs: + pattern = selected_funcs[func_name] + if counted_parent is None: + aggregated[pattern] += time_spent + elif counted_parent != pattern: + raise RuntimeError(f"Double-counting detected: pattern '{counted_parent}' calls '{pattern}', which would lead to double-counting '{pattern}'.") + counted_this_iter = pattern + + # Recurse into children + if subprofile: + traverse(subprofile, new_path, counted_this_iter) + + traverse(profile, [], None) + return aggregated + + +def calculate_total_time(profile: Dict, zoom_path: List[str]) -> float: + """ + Calculate total time for the zoomed-in path. + + Args: + profile: Profile dictionary + zoom_path: Zoom-in path (empty list for entire profile) + + Returns: + Total time in seconds + """ + if not zoom_path: + # No zoom - return total of all top-level functions + return sum(value[0] for value in profile.values()) + + # Find the zoomed function and return its time + def find_zoom_time(data: Dict, path_remaining: List[str]) -> float: + if not path_remaining: + return 0.0 + + pattern = path_remaining[0] + + for func_name, value in data.items(): + if fuzzy_match(pattern, func_name): + if len(path_remaining) == 1: + # Found the zoom target + return value[0] + else: + # Continue searching deeper + if len(value) > 1 and value[1]: + result = find_zoom_time(value[1], path_remaining[1:]) + if result > 0: + return result + + return 0.0 + + return find_zoom_time(profile, zoom_path) + + +def print_results(total_time: float, aggregated: Dict[str, float], plot_file: str = None): + """ + Print the analysis results. + + Args: + total_time: Total time in the zoomed scope + aggregated: Dictionary of function times + plot_file: Optional path to save a matplotlib bar plot + """ + print("\n" + "="*80) + print("PROFILE ANALYSIS RESULTS") + print("="*80) + + print(f"\nTotal time (zoomed scope): {total_time:.6f} seconds") + print("-"*80) + + # Sort by time (descending) + sorted_funcs = sorted(aggregated.items(), key=lambda x: x[1], reverse=True) + + selected_total = 0.0 + for func_name, time in sorted_funcs: + selected_total += time + percentage = (time / total_time * 100) if total_time > 0 else 0 + print(f"{func_name}") + print(f" Time: {time:.6f}s ({percentage:.2f}%)") + print() + + # Calculate "other" time + other_time = total_time - selected_total + other_percentage = (other_time / total_time * 100) if total_time > 0 else 0 + + print("-"*80) + print(f"Selected functions total: {selected_total:.6f}s ({selected_total/total_time*100:.2f}%)") + print(f"Other (unselected): {other_time:.6f}s ({other_percentage:.2f}%)") + print("="*80) + + # Create plot if requested + if plot_file: + create_bar_plot(total_time, sorted_funcs, other_time, plot_file) + + +def create_bar_plot(total_time: float, sorted_funcs: List[Tuple[str, float]], other_time: float, output_file: str): + """ + Create a bar plot showing time spent in each selected pattern and "other". + + Args: + total_time: Total time in the zoomed scope + sorted_funcs: List of (pattern, time) tuples sorted by time + other_time: Time spent in unselected functions + output_file: Path to save the plot + """ + # Prepare data + labels = [func for func, _ in sorted_funcs] + ["Other"] + times = [time for _, time in sorted_funcs] + [other_time] + percentages = [(time / total_time * 100) if total_time > 0 else 0 for time in times] + + # Create figure + fig, ax = plt.subplots(figsize=(10, 6)) + + # Create bars + bars = ax.bar(range(len(labels)), times, color='steelblue', alpha=0.8) + + # Color the "Other" bar differently + bars[-1].set_color('lightgray') + + # Customize plot + ax.set_xlabel('Pattern', fontsize=12, fontweight='bold') + ax.set_ylabel('Time (seconds)', fontsize=12, fontweight='bold') + ax.set_title(f'Profile Analysis - Total Time: {total_time:.3f}s', fontsize=14, fontweight='bold') + ax.set_xticks(range(len(labels))) + ax.set_xticklabels(labels, rotation=45, ha='right') + + # Add value labels on bars + for i, (bar, time, pct) in enumerate(zip(bars, times, percentages)): + height = bar.get_height() + ax.text(bar.get_x() + bar.get_width()/2., height, + f'{time:.3f}s\n({pct:.1f}%)', + ha='center', va='bottom', fontsize=9) + + # Add grid for readability + ax.grid(axis='y', alpha=0.3, linestyle='--') + ax.set_axisbelow(True) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save plot + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"\nPlot saved to: {output_file}") + plt.close() + + +def main(): + parser = argparse.ArgumentParser( + description='Analyze profile.json to aggregate time spent in selected functions.', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Analyze all torch calls + %(prog)s logs/profile.json --select torch + + # Analyze multiple patterns + %(prog)s logs/profile.json --select torch numpy inference + + # Zoom into a specific call path + %(prog)s logs/profile.json --select torch --zoom inference forward + + # List all unique functions + %(prog)s logs/profile.json --list-functions + """ + ) + + parser.add_argument('profile', help='Path to profile.json file') + parser.add_argument('--select', '-s', nargs='+', metavar='PATTERN', + help='Function patterns to select (fuzzy match)') + parser.add_argument('--zoom', '-z', nargs='+', metavar='PATTERN', + help='Call path to zoom into (fuzzy match sequence)') + parser.add_argument('--list-functions', '-l', action='store_true', + help='List all unique function names in the profile') + parser.add_argument('--plot', '-p', metavar='FILE', + help='Save a bar plot to the specified file (e.g., plot.png)') + + args = parser.parse_args() + + # Load profile + try: + with open(args.profile, 'r') as f: + profile = json.load(f) + except FileNotFoundError: + print(f"Error: Profile file '{args.profile}' not found.", file=sys.stderr) + sys.exit(1) + except json.JSONDecodeError as e: + print(f"Error: Invalid JSON in profile file: {e}", file=sys.stderr) + sys.exit(1) + + + # Require --select for analysis + if not args.select: + print("Error: --select is required for analysis", file=sys.stderr) + parser.print_help() + sys.exit(1) + + # Find matching functions + selected_funcs = find_matching_functions(profile, args.select) + + if not selected_funcs: + print(f"Error: No functions matched the patterns: {args.select}", file=sys.stderr) + sys.exit(1) + + print(f"\nMatched {len(selected_funcs)} function(s) from patterns: {args.select}") + + # Prepare zoom path + zoom_path = args.zoom if args.zoom else [] + + if zoom_path: + print(f"Zooming into path: {' -> '.join(zoom_path)}") + + # Calculate total time + total_time = calculate_total_time(profile, zoom_path) + + if total_time == 0: + print(f"Error: Could not find zoom path or it has zero time.", file=sys.stderr) + sys.exit(1) + + # Aggregate times + aggregated = aggregate_times(profile, selected_funcs, zoom_path) + + # Print results + print_results(total_time, aggregated, args.plot) + + +if __name__ == '__main__': + main() From 280917a078db0f64098e745a834d0b0f41ad7283 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 12:50:32 -0700 Subject: [PATCH 09/99] change profiling to allow annotation using contexts; remove nn.Module --- applications/llama_3.2_1b/custom_profile.py | 56 ++++++++++++++------- applications/llama_3.2_1b/inference.py | 13 ----- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/applications/llama_3.2_1b/custom_profile.py b/applications/llama_3.2_1b/custom_profile.py index 2f1199c5..89ed4a51 100644 --- a/applications/llama_3.2_1b/custom_profile.py +++ b/applications/llama_3.2_1b/custom_profile.py @@ -9,32 +9,44 @@ def profile_call(frame, event, arg): - global call_stack - - timestamp = time.perf_counter() - func_name = frame.f_code.co_name filename = frame.f_code.co_filename line_no = frame.f_lineno func_identifier = f"{str(frame.f_code.co_filename)}:{frame.f_code.co_firstlineno}:{frame.f_code.co_name}" if event == "call": - call_stack.append((func_identifier, timestamp)) + log_call(func_identifier) elif event == "return": - if 0 == len(call_stack): - return - last_func_identifier, start_time = call_stack[-1] - if last_func_identifier != func_identifier: - print(call_stack) - raise RuntimeError(f"Function return mismatch: expected {last_func_identifier}, got {func_identifier}") - elapsed = timestamp - start_time + log_return(func_identifier) - this_path_time = time_per_path - for f, _ in call_stack: - this_path_time = this_path_time[1].setdefault(f, [0.0, {}]) - this_path_time[0] += elapsed - call_stack.pop() +def log_call(func_identifier): + global call_stack + if func_identifier.endswith(":log_call") or func_identifier.endswith(":log_return") or func_identifier.endswith(":__enter__") or func_identifier.endswith(":__exit__"): + return + timestamp = time.perf_counter() + call_stack.append((func_identifier, timestamp)) + + +def log_return(func_identifier): + global call_stack + if func_identifier.endswith(":log_call") or func_identifier.endswith(":log_return") or func_identifier.endswith(":__enter__") or func_identifier.endswith(":__exit__"): + return + if 0 == len(call_stack): + return + timestamp = time.perf_counter() + last_func_identifier, start_time = call_stack[-1] + if last_func_identifier != func_identifier: + print(call_stack) + raise RuntimeError(f"Function return mismatch: expected {last_func_identifier}, got {func_identifier}") + elapsed = timestamp - start_time + + this_path_time = time_per_path + for f, _ in call_stack: + this_path_time = this_path_time[1].setdefault(f, [0.0, {}]) + this_path_time[0] += elapsed + + call_stack.pop() def enable_profiling(): @@ -46,3 +58,13 @@ def store_profile(path): with open(path, "w") as f: json.dump(time_per_path[1], f, indent=2) + +class CustomProfileContext: + def __init__(self, label): + self.label = label + + def __enter__(self): + log_call(self.label) + + def __exit__(self, exc_type, exc_value, traceback): + log_return(self.label) diff --git a/applications/llama_3.2_1b/inference.py b/applications/llama_3.2_1b/inference.py index 9ce47c22..63b2c18e 100755 --- a/applications/llama_3.2_1b/inference.py +++ b/applications/llama_3.2_1b/inference.py @@ -181,20 +181,8 @@ def inference( hook_handles.append(handle) device = torch.device("cpu") - model.to(device) chat_tokenizer = ChatFormat(tokenizer) - total_params = sum(p.numel() for p in model.parameters()) - total_params_normalized = total_params - model.tok_emb.weight.numel() - logging.info(f"Total number of parameters: {total_params:,}") - logging.info(f"Total number of unique parameters: {total_params_normalized:,}") - logging.info( - f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB" - ) - logging.info( - f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB" - ) - combined_weights = load_file(weights_file_path) # Get parameters from model config model_config = { @@ -210,7 +198,6 @@ def inference( "rope_freq": model.cfg["rope_freq"], } load_weights_into_llama(model, model_config, combined_weights) - model.to(device) del combined_weights logging.info("Preparing AIE operators...") From 7b56ef8872b50c94c79bf38b7cb4bfa8b4cb710d Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 16:25:55 -0700 Subject: [PATCH 10/99] refactoring started; RMSNorm offloaded, last layer GEMM started --- applications/llama_3.2_1b/llama_cpu.py | 4 +- .../llama_3.2_1b/llama_inference_harness.py | 17 +- applications/llama_3.2_1b/llama_npu.py | 124 +++-- operators/__init__.py | 38 +- operators/common/__init__.py | 8 +- operators/common/aie_base.py | 317 ++++++----- operators/common/aie_context.py | 184 ------- operators/gemm/op.py | 490 +++++++----------- operators/rms_norm/design.py | 94 +--- operators/rms_norm/design_weighted.py | 99 +--- operators/rms_norm/op.py | 189 ++----- 11 files changed, 537 insertions(+), 1027 deletions(-) diff --git a/applications/llama_3.2_1b/llama_cpu.py b/applications/llama_3.2_1b/llama_cpu.py index 258e342c..31849f08 100755 --- a/applications/llama_3.2_1b/llama_cpu.py +++ b/applications/llama_3.2_1b/llama_cpu.py @@ -234,9 +234,7 @@ def llama_forward_pass( x = rms_norm_forward(x, final_norm_weight) # Step 5: Output projection - lm_head_weight = config.weights['model.embed_tokens.weight'] - - logits = torch.nn.functional.linear(x, lm_head_weight) # (batch, seq_len, vocab_size) + logits = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) return logits, state diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py index 736d48c1..3d639627 100644 --- a/applications/llama_3.2_1b/llama_inference_harness.py +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -69,7 +69,9 @@ class LlamaModelState: def __init__(self, config): # Current IDs of tokens being processed (most recent token for decode; all prompt tokens for prefill) self.token_ids = torch.empty(0, dtype=torch.long) + self.reset_kv_cache(config) + def reset_kv_cache(self, config): # Set up KV cache -- initially empty # This is what passes information from previous tokens to the current token during generation self.attn_keys_caches = [ @@ -184,7 +186,8 @@ def generate( config, state, forward_pass, - num_tokens=100 + num_tokens=100, + use_kv_cache=True ): # Generate tokens # First token (prefill) @@ -197,14 +200,22 @@ def generate( t_prefill_stop = time.perf_counter() # Remaining tokens (decode) - state.token_ids = torch.tensor([[first_token]], dtype=torch.long) + if use_kv_cache: + state.token_ids = torch.tensor([[first_token]], dtype=torch.long) + else: + state.reset_kv_cache(config) + state.token_ids = torch.cat([state.token_ids, torch.tensor([[first_token]], dtype=torch.long)], dim=1) t_decode_start = time.perf_counter() for _ in range(num_tokens-1): next_token, state = generate_token(config, forward_pass, state) token_text = config.tokenizer.decode([next_token]) n_tokens_generated += 1 print(token_text, end='', flush=True) - state.token_ids = torch.tensor([[next_token]], dtype=torch.long) + if use_kv_cache: + state.token_ids = torch.tensor([[next_token]], dtype=torch.long) + else: + state.reset_kv_cache(config) + state.token_ids = torch.cat([state.token_ids, torch.tensor([[next_token]], dtype=torch.long)], dim=1) t_decode_end = time.perf_counter() t_prefill = t_prefill_stop - t_prefill_start diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 4dee2fb0..3981dd32 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -4,53 +4,61 @@ import math from pathlib import Path import sys +import ml_dtypes import llama_inference_harness as harness repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) from operators.common.aie_context import AIEContext -from operators.common import AIEOperatorBase +from operators.common import ( + AIEBuffer +) +from operators.common.utils import torch_to_numpy, numpy_to_torch from operators import ( AIERMSNorm, AIEGEMM, - AIEGEMV + #AIEGEMV ) # AIE Operator Configuration # ########################################################################## -aieops = None +aie_ops = None class AIELlamaOperators: def __init__(self, config, prompt_len): self.context = AIEContext() + self.context.build_dir.mkdir(parents=True, exist_ok=True) # Final RMS Norm self.final_norm_prefill = AIERMSNorm( - size=prompt_len, + size=prompt_len * config.emb_dim, eps=1e-5, num_aie_columns=8, num_channels=2, tile_size=config.emb_dim, - context=self.context, - ) + context=self.context + ).compile().get_callable() self.final_norm_decode = AIERMSNorm( size=config.emb_dim, eps=1e-5, num_aie_columns=1, num_channels=2, tile_size=config.emb_dim, - context=self.context, - ) + context=self.context + ).compile().get_callable() # Final GEMM - self.out_head_prefill = AIEGEMM( + min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N + config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N + config.vocab_partitions = 4 + self.out_head_prefill_compilable = AIEGEMM( M=prompt_len, K=config.emb_dim, - N=config.vocab_size, + N=config.padded_vocab_size // config.vocab_partitions, num_aie_columns=8, tile_m=64, tile_k=64, @@ -58,22 +66,49 @@ def __init__(self, config, prompt_len): b_col_maj=True, use_static_weight=True, separate_c_tiles=True, - partition_N=4, - context=self.context - ) - self.out_head_decode = AIEGEMV( - M=config.vocab_size, K=config.emb_dim, - num_aie_columns=8, - is_mv=True, - use_static_weight=True, - tile_size_input=4, - tile_size_output=32, context=self.context - ) + ).compile() + self.out_head_prefill = self.out_head_prefill_compilable.get_callable() + #self.out_head_decode = AIEGEMV( + # M=config.vocab_size, K=config.emb_dim, + # num_aie_columns=8, + # is_mv=True, + # use_static_weight=True, + # tile_size_input=4, + # tile_size_output=32, + # context=self.context + #) + + +# Allocate buffers shared with NPU +# ########################################################################## + +aie_buffers = None - def setup(self): - self.context.compile_all() - self.context.prepare_runtime() +class AIELlamaBuffers: + def __init__(self, config, prompt_len): + self.x_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) + self.x_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) + self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']) + + # Final linear layer + W_out_head_parts = aie_ops.out_head_prefill_compilable.partition_B( + torch_to_numpy(config.weights['model.embed_tokens.weight']), + config.vocab_partitions + ) + self.W_out_head_parts = [ + AIEBuffer.from_np(W_out_head_part) + for W_out_head_part in W_out_head_parts + ] + self.logits_prefill = AIEBuffer(shape=(config.vocab_partitions, prompt_len, config.padded_vocab_size // config.vocab_partitions)) + self.logits_prefill_parts = [ + self.logits_prefill.subbuffer( + length=prompt_len * (config.padded_vocab_size // config.vocab_partitions), + offset=i * prompt_len * (config.padded_vocab_size // config.vocab_partitions), + shape=(prompt_len, config.padded_vocab_size // config.vocab_partitions), + ) + for i in range(config.vocab_partitions) + ] # Operators @@ -311,15 +346,24 @@ def llama_forward_pass_prefill( ) # Step 4: Final normalization - x = aieops.final_norm_prefill(x) + aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_buffers.x_prefill.to("npu") + aie_ops.final_norm_prefill(aie_buffers.x_prefill, aie_buffers.W_final_norm, aie_buffers.x_prefill) # Step 5: Output projection (check for tied embeddings) - if 'lm_head.weight' in config.weights: - lm_head_weight = config.weights['lm_head.weight'] - else: - lm_head_weight = config.weights['model.embed_tokens.weight'] + # Since vocab size is a very large dimension unsupported by the AIE GEMM, we have to execute the GEMM in multiple partitions and reassemble the output. + # for i in range(config.vocab_partitions): + # aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) + # aie_buffers.logits_prefill.to("cpu") + # # Reassemble (transpose) the logits from partitions + # logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) + # logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) + # logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) + + aie_buffers.x_prefill.to("cpu") + x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] - logits = aieops.out_head_prefill(x) # (batch, seq_len, vocab_size) + logits = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) return logits, state @@ -354,10 +398,14 @@ def llama_forward_pass_decode( ) # Step 4: Final normalization - x = aieops.final_norm_decode(x) + aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_buffers.x_decode.to("npu") + aie_ops.final_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) + aie_buffers.x_decode.to("cpu") # Step 5: Output projection - logits = aieops.out_head_decode(x) # (batch, seq_len, vocab_size) + lm_head_weight = config.weights['model.embed_tokens.weight'] + logits = torch.nn.functional.linear(aie_buffers.x_decode.view_as_torch().unsqueeze(0), lm_head_weight) # (batch, seq_len, vocab_size) return logits, state @@ -366,19 +414,15 @@ def llama_forward_pass_decode( # ########################################################################## def main(): - global aieops + global aie_ops, aie_buffers prompt = "The capital of France is " config, state = harness.init(prompt=prompt) - aieops = AIELlamaOperators(config, 2048) - aieops.final_norm_prefill.weight = config.weights['model.norm.weight'] - aieops.final_norm_decode.weight = config.weights['model.norm.weight'] - aieops.out_head_prefill.weight = config.weights['model.embed_tokens.weight'].T - aieops.out_head_decode.weight = config.weights['model.embed_tokens.weight'].T - aieops.setup() + aie_ops = AIELlamaOperators(config, 2048) + aie_buffers = AIELlamaBuffers(config, 2048) print(prompt, end='', flush=True) - harness.generate(config, state, llama_forward_pass) + harness.generate(config, state, llama_forward_pass, use_kv_cache=False) if __name__ == "__main__": main() diff --git a/operators/__init__.py b/operators/__init__.py index fc203892..9b00a29e 100644 --- a/operators/__init__.py +++ b/operators/__init__.py @@ -1,24 +1,24 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from .axpy.op import AIEAXPY -from .dequant.op import AIEDequant -from .elementwise_add.op import AIEElementwiseAdd -from .elementwise_mul.op import AIEElementwiseMul -from .gelu.op import AIEGELU +#from .axpy.op import AIEAXPY +#from .dequant.op import AIEDequant +#from .elementwise_add.op import AIEElementwiseAdd +#from .elementwise_mul.op import AIEElementwiseMul +#from .gelu.op import AIEGELU from .gemm.op import AIEGEMM -from .gemv.op import AIEGEMV -from .layer_norm.op import AIELayerNorm -from .leaky_relu.op import AIELeakyReLU -from .mem_copy.op import AIEMemCopy -from .mha.op import AIEMHA -from .relu.op import AIEReLU +#from .gemv.op import AIEGEMV +#from .layer_norm.op import AIELayerNorm +#from .leaky_relu.op import AIELeakyReLU +#from .mem_copy.op import AIEMemCopy +#from .mha.op import AIEMHA +#from .relu.op import AIEReLU from .rms_norm.op import AIERMSNorm -from .rope.op import AIERope -from .sigmoid.op import AIESigmoid -from .silu.op import AIESiLU -from .softmax.op import AIESoftmax -from .swiglu_decode.op import AIESwiGLUDecode -from .swiglu_prefill.op import AIESwiGLUPrefill -from .tanh.op import AIETanh -from .transpose.op import AIETranspose +#from .rope.op import AIERope +#from .sigmoid.op import AIESigmoid +#from .silu.op import AIESiLU +#from .softmax.op import AIESoftmax +#from .swiglu_decode.op import AIESwiGLUDecode +#from .swiglu_prefill.op import AIESwiGLUPrefill +#from .tanh.op import AIETanh +#from .transpose.op import AIETranspose diff --git a/operators/common/__init__.py b/operators/common/__init__.py index 4fa9ae3b..b70aba2c 100644 --- a/operators/common/__init__.py +++ b/operators/common/__init__.py @@ -3,7 +3,13 @@ """Common utilities and base classes for IRON operators.""" -from .aie_base import AIEOperatorBase, AIEOperatorConstraintError +from .aie_base import ( + AIEOperatorBase, + SingleMLIRSourceOperator, + AIEBuffer, + SingleXclbinCallable, + AIERuntimeArgSpec, +) from .aie_context import AIEContext from .compilation import ( XclbinArtifact, diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index fa74f7b3..a60e06b9 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -15,115 +15,27 @@ from .aie_context import AIEContext from .aie_device_manager import AIEDeviceManager, pyxrt from .utils import numpy_to_torch, torch_to_numpy +from .compilation import ( + XclbinArtifact, + InstsBinArtifact, + KernelObjectArtifact, + KernelArchiveArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) class AIEOperatorBase(ABC): """Base class for AIE-accelerated operations""" - @classmethod - def get_default_context(cls): - """One global 'default' context if none is specified""" - if not hasattr(AIEOperatorBase, "_default_context"): - AIEOperatorBase._default_context = AIEContext() - return AIEOperatorBase._default_context - def __init__(self, context=None): self.artifacts = ( [] ) # CompilationArtifact objects are uniqued within the context - self.kernels = {} # Name -> (xclbin_path, xclbin_kernel_name, insts_path) - self.buffers = {} # Name -> required buffer size in bytes - self.buffer_static_data = {} - self.runlist = ( - [] - ) # List of (kernel_name, buffers_name, buffer_name...), will be executed in sequence - - # AIE runtime state - self.buffer_bos = {} # Buffer name -> buffer object - self.xrt_kernels = ( - {} - ) # Kernel name -> (XRT context, XRT kernel object, instruction buffer object, instruction length) - self.xrt_runlist = None - if context is None: context = self.get_default_context() context.register_operator(self) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def add_kernel( - self, - name: str, - xclbin_artifact: comp.XclbinArtifact, - xclbin_kernel_name: str, - insts_artifact: comp.InstsBinArtifact, - ): - assert name not in self.kernels - self.kernels[name] = (xclbin_artifact, xclbin_kernel_name, insts_artifact) - - def add_buffer(self, name, count, dtype=bfloat16, static_data=None): - assert name not in self.buffers - self.buffers[name] = count * np.dtype(dtype).itemsize - if static_data is not None: - assert ( - static_data.nbytes <= self.buffers[name] - ), f"Static data for buffer {name} exceeds allocated size: expected {self.buffers[name]} bytes, got {static_data.nbytes} bytes." - static_data_bytes = static_data.flatten().view(np.uint8).tobytes() - if static_data_bytes not in self.context.static_data_pool: - self.context.static_data_pool[static_data_bytes] = None - self.buffer_static_data[name] = next( - k - for k, v in self.context.static_data_pool.items() - if k == static_data_bytes - ) - - def add_to_runlist(self, kernel_name, *args): - if kernel_name not in self.kernels: - raise RuntimeError(f"No such kernel: {kernel_name}") - for arg in args: - if arg not in self.buffers: - raise RuntimeError(f"No such buffer: {arg}") - self.runlist.append((kernel_name, *args)) - - def get_bo(self, buffer_name): - return self.buffer_bos[buffer_name] - - def read_buffer(self, buffer_name, shape, copy=False, dtype=bfloat16): - """Read buffer and return values as a numpy array""" - # Create a byte accessible memory view of the buffer object - mv = self.get_bo(buffer_name).map() - - # Interpret the buffer as a 1-dimensional array then change its view to the expected shape - arr = np.frombuffer(mv, dtype=dtype, count=np.prod(shape)).reshape(shape) - - # Return an independent copy of the array if needed - return arr.copy() if copy else arr - - def read_buffer_as_torch(self, buffer_name, shape, dtype=bfloat16): - return numpy_to_torch(self.read_buffer(buffer_name, shape, dtype)) - - def write_buffer(self, buffer_name, array): - """Write buffer from a numpy array into a XRT buffer object""" - if buffer_name in self.buffer_static_data: - raise RuntimeError(f"Cannot write to static buffer: {buffer_name}") - - # Normalize the source - if isinstance(array, torch.Tensor): - src = torch_to_numpy(array) - else: - src = np.asarray(array) - - # Create a flattened 1D byte view of the source - src_bytes = src.ravel().view(np.uint8) - - bo = self.get_bo(buffer_name) - mv = bo.map() # byte accessible memory view - # Interpret the buffer as a 1-dimensional array - dst_bytes = np.frombuffer(mv, dtype=np.uint8, count=bo.size()) - - # The BO is an existing array, so copyto() can be called, which doesn't create a new array - np.copyto(dst_bytes[: src_bytes.size], src_bytes, casting="no") + self.context = context @abstractmethod def set_up_artifacts(self): @@ -133,11 +45,22 @@ def set_up_artifacts(self): Compilation will be handled automatically based on the provided description. """ pass + + @abstractmethod + def get_arg_spec(self): + pass @abstractmethod - def set_up_runtime(self): + def get_callable(self): pass + @classmethod + def get_default_context(cls): + """One global 'default' context if none is specified""" + if not hasattr(AIEOperatorBase, "_default_context"): + AIEOperatorBase._default_context = AIEContext() + return AIEOperatorBase._default_context + def compile(self, dry_run=None): """ Set up the operator and compile any necessary artifacts. @@ -165,6 +88,7 @@ def compile(self, dry_run=None): f"Compiling {len(work_list)} new artifacts for AIE operator {self.__class__.__name__}: {', '.join(str(artifact.path.name) for artifact in work_list)}" ) comp.compile(compilation_rules, work_list) + return self def add_artifacts(self, artifacts): self.artifacts.extend(artifacts) @@ -182,45 +106,6 @@ def _move_artifact_paths(self): artifact.set_path(context.build_dir / artifact.path) todo.extend(artifact.depends) - def run_runlist(self): - elapsed = 0.0 - if self.xrt_runlist is None: - # Execute as separate xclbin kernel invocations - for i, (kernel_name, *buffer_args) in enumerate(self.runlist): - context, xrt_kernel, insts_bo, insts_len = self.xrt_kernels[kernel_name] - insts_bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - bos = [self.buffer_bos[buffer_arg] for buffer_arg in buffer_args] - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - opcode = 3 - start = time.perf_counter() - run = xrt_kernel(opcode, insts_bo, insts_len, *bos) - result = run.wait() - stop = time.perf_counter() - elapsed += stop - start - if result != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: - raise RuntimeError( - f"Kernel {kernel_name} did not complete correctly: {result}" - ) - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) - else: - bos = set( - self.buffer_bos[buffer_arg] - for _, *buffer_args in self.runlist - for buffer_arg in buffer_args - ) - insts_bos = set( - self.xrt_kernels[kernel_name][2] for (kernel_name, *_) in self.runlist - ) - sync_to_device(bos | insts_bos) - start = time.perf_counter() - execute_runlist(self.xrt_runlist) - stop = time.perf_counter() - sync_from_device(bos) - elapsed = stop - start - return elapsed - def sync_to_device(bos): for bo in bos: @@ -237,5 +122,159 @@ def execute_runlist(runlist): runlist.wait() -class AIEOperatorConstraintError(RuntimeError): - pass +class SingleMLIRSourceOperator(AIEOperatorBase, ABC): + """Base class for AIE-accelerated operations""" + def __init__(self, *args, **kwargs): + AIEOperatorBase.__init__(self, *args, **kwargs) + + @abstractmethod + def get_operator_name(self): + pass + + @abstractmethod + def get_mlir_artifact(self): + pass + + @abstractmethod + def get_kernel_artifacts(self): + pass + + def get_kernel_archive_name(self): + return self.get_operator_name() + ".a" + + def get_artifacts(self): + operator_name = self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_deps = self.get_kernel_artifacts() + xclbin_artifact = XclbinArtifact.new( + f"{operator_name}.xclbin", + depends=[ + mlir_artifact, + KernelArchiveArtifact.new( + self.get_kernel_archive_name(), + depends=kernel_deps, + ), + ], + ) + insts_artifact = InstsBinArtifact.new( + f"{operator_name}.bin", depends=[mlir_artifact] + ) + return xclbin_artifact, insts_artifact + + def set_up_artifacts(self): + xclbin_artifact, insts_artifact = self.get_artifacts() + self.xclbin_artifact = xclbin_artifact + self.insts_artifact = insts_artifact + self.add_artifacts([xclbin_artifact, insts_artifact]) + + + def get_callable(self): + return SingleXclbinCallable( + xclbin_path=self.xclbin_artifact.path, + kernel_name=self.xclbin_artifact.kernel_name, + insts_bin_path=self.insts_artifact.path, + args_spec=self.get_arg_spec() + ) + +class AIERuntimeArgSpec: + def __init__(self, shape, dtype=bfloat16): + self.shape = shape + self.dtype = dtype + +class AIEBuffer: + def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): + size = np.prod(shape) * np.dtype(dtype).itemsize + self.shape = shape + self.dtype = dtype + self.bo = bo + self.on = "cpu" + self.device_manager = device_manager or AIEDeviceManager() + if not self.bo: + self.bo = pyxrt.bo( + self.device_manager.device, + size, + pyxrt.bo.host_only, + 0x10000, + ) + + def subbuffer(self, length, offset, shape, dtype=None): + if dtype is None: + dtype = self.dtype + assert np.prod(shape) == length + itemsize = np.dtype(dtype).itemsize + assert offset >= 0 + assert offset * itemsize <= np.prod(self.shape) * np.dtype(self.dtype).itemsize + assert length * itemsize + offset * itemsize <= np.prod(self.shape) * np.dtype(self.dtype).itemsize + sub_bo = pyxrt.bo( + self.bo, # parent bo + length * itemsize, # size + offset * itemsize, # offset + ) + return AIEBuffer(shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager) + + def view_as_np(self): + self.to("cpu") + # Create a byte accessible memory view of the buffer object + mv = self.bo.map() + # Interpret the buffer as a 1-dimensional array then change its view to the expected shape + return np.frombuffer(mv, dtype=self.dtype, count=np.prod(self.shape)).reshape(self.shape) + + def view_as_torch(self): + return numpy_to_torch(self.view_as_np()) + + def to(self, dest): + if dest == "npu": + if self.on != "npu": + self.bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) + elif dest == "cpu": + if self.on != "cpu": + self.bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) + else: + raise RuntimeError(f"Unknown destination for AIEBuffer.to(): {dest}") + return self + + @staticmethod + def from_np(buffer): + shape = buffer.shape + dtype = buffer.dtype + size = np.prod(shape) * np.dtype(dtype).itemsize + aie_buffer = AIEBuffer(shape=shape, dtype=dtype) + aie_buffer.view_as_np()[:] = buffer + aie_buffer.to("npu") + return aie_buffer + + @staticmethod + def from_torch(tensor): + return AIEBuffer.from_np(torch_to_numpy(tensor)) + +class SingleXclbinCallable: + def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None): + self.device_manager = device_manager or AIEDeviceManager() + self.context, self.xrt_kernel = self.device_manager.get_context_and_kernel( + str(xclbin_path), kernel_name + ) + with open(str(insts_bin_path), "rb") as f: + instructions = np.frombuffer(f.read(), dtype=np.uint32) + insts_bo = pyxrt.bo( + self.device_manager.device, + instructions.nbytes, + pyxrt.bo.cacheable, + self.xrt_kernel.group_id(1), + ) + insts_bo.write(instructions.view(np.uint8), 0) + self.insts_buffer = AIEBuffer(shape=(len(instructions),), dtype=np.uint32, bo=insts_bo) + self.insts_buffer.to("npu") + self.args_spec = args_spec + + def __call__(self, *buffers): + assert len(buffers) == len(self.args_spec) + assert all( + buffers[i].shape == self.args_spec[i].shape and buffers[i].dtype == self.args_spec[i].dtype + for i in range(len(buffers)) + ), "Input buffer shapes or dtypes do not match expected argument specification." + self.insts_buffer.to("npu") + for buffer in buffers: + buffer.to("npu") + opcode = 3 + bos = [buffer.bo for buffer in buffers] + run = self.xrt_kernel(opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos) diff --git a/operators/common/aie_context.py b/operators/common/aie_context.py index 0a76132a..a8f6e516 100644 --- a/operators/common/aie_context.py +++ b/operators/common/aie_context.py @@ -28,8 +28,6 @@ def __init__(self, use_runlist=True): def register_operator(self, operator): """Register an operator with this context""" - if self._runtime_prepared: - raise RuntimeError("Cannot register operators after runtime is prepared") operator.context = self self.operators.append(operator) @@ -38,185 +36,3 @@ def compile_all(self): self.build_dir.mkdir(parents=True, exist_ok=True) for op in self.operators: op.compile() - - def prepare_runtime(self): - """Setup XRT runtime for all registered operators""" - if self._runtime_prepared: - return - - for op in self.operators: - op.set_up_runtime() - - # Pools of preallocated buffer objects; each buffer object is allocated - # once at program start and then reused across operators where possible. - bo_pools = {} - page_sz = 4096 - get_pool_sz = lambda x: (x + page_sz - 1) // page_sz * page_sz - - # Allocate static buffers first - for buffer_data in self.static_data_pool: - logging.debug( - f"Allocating static buffer with size {len(buffer_data)} bytes." - ) - bo = pyxrt.bo( - self.device_manager.device, - len(buffer_data), - pyxrt.bo.host_only, - 0x10000, - ) - bo.write(np.frombuffer(buffer_data, dtype=np.uint8), 0) - self.static_data_pool[buffer_data] = bo - - for op in self.operators: - if len(op.kernels) == 0: - continue - - logging.info(f"Preparing runtime for AIE operator: {op.__class__.__name__}") - - # Set up kernels - for kernel_name, (xclbin, xclbin_kernel_name, insts) in op.kernels.items(): - context, xrt_kernel = self.device_manager.get_context_and_kernel( - str(xclbin.path), xclbin_kernel_name - ) - with open(str(insts.path), "rb") as f: - instructions = np.frombuffer(f.read(), dtype=np.uint32) - logging.debug( - f"Allocating instruction buffer for {len(instructions)} instructions." - ) - insts_bo = pyxrt.bo( - self.device_manager.device, - instructions.nbytes, - pyxrt.bo.cacheable, - xrt_kernel.group_id(1), - ) - insts_bo.write(instructions.view(np.uint8), 0) - op.xrt_kernels[kernel_name] = ( - context, - xrt_kernel, - insts_bo, - len(instructions), - ) - - # If multiple buffers (of the same binned size) are used in the - # same kernel invocation OR across different invocations with shared - # buffers, they require separate allocations. - conflicting_buffers = {} # map buffer -> {set of conflicting buffers} - buffer_to_runlist_entries = {} # map buffer -> set of runlist entry indices - - # First pass: track which buffers appear in which runlist entries - for idx, (kernel, *args) in enumerate(op.runlist): - for arg in args: - buffer_to_runlist_entries.setdefault(arg, set()).add(idx) - - # Second pass: determine conflicts - for idx, (kernel, *args) in enumerate(op.runlist): - for arg in args: - if arg in op.buffer_static_data: - # Static buffers never conflict - continue - pool_sz = get_pool_sz(op.buffers[arg]) - - # Buffers conflict if they're in the same runlist entry - conflicting_args = { - a for a in args if get_pool_sz(op.buffers[a]) == pool_sz - } - {arg} - - # Also conflict with buffers in other runlist entries that share - # a buffer with this entry - for other_arg in args: - if other_arg == arg: - continue - for other_idx in buffer_to_runlist_entries.get( - other_arg, set() - ): - if other_idx != idx: - _, *other_args = op.runlist[other_idx] - conflicting_args.update( - { - a - for a in other_args - if get_pool_sz(op.buffers[a]) == pool_sz - and a != arg - } - ) - - conflicting_buffers[arg] = conflicting_buffers.get( - arg, set() - ).union(conflicting_args) - - # Allocate buffers - buffer_allocations = {} - for buffer_name, buffer_min_size in op.buffers.items(): - if buffer_name in op.buffer_static_data: - static_data = op.buffer_static_data[buffer_name] - op.buffer_bos[buffer_name] = self.static_data_pool[static_data] - continue - - alloc_pool = get_pool_sz(buffer_min_size) - alloc_idx = 0 - for conflict in conflicting_buffers.get(buffer_name, set()): - if conflict not in buffer_allocations: - continue - conflict_pool, conflict_idx = buffer_allocations[conflict] - alloc_idx = max(alloc_idx, conflict_idx + 1) - - assert 0 <= alloc_idx < len(bo_pools.get(alloc_pool, [])) + 1 - if alloc_idx == len(bo_pools.get(alloc_pool, [])): - bo = pyxrt.bo( - self.device_manager.device, - alloc_pool, - pyxrt.bo.host_only, - 0x10000, - ) - bo_pools.setdefault(alloc_pool, []).append(bo) - - buffer_allocations[buffer_name] = (alloc_pool, alloc_idx) - op.buffer_bos[buffer_name] = bo_pools[alloc_pool][alloc_idx] - - # Setup runlist - _, (first_xclbin, first_xclbin_kernel_name, _) = next( - iter(op.kernels.items()) - ) - context, _ = self.device_manager.get_context_and_kernel( - str(first_xclbin.path), first_xclbin_kernel_name - ) - if self.use_runlist: - op.xrt_runlist = pyxrt.runlist(context) - for i, (kernel_name, *buffer_args) in enumerate(op.runlist): - this_context, xrt_kernel, insts_bo, insts_len = op.xrt_kernels[ - kernel_name - ] - assert this_context == context - opcode = 3 - run = pyxrt.run(xrt_kernel) - run.set_arg(0, opcode) - run.set_arg(1, insts_bo) - run.set_arg(2, insts_len) - for j, buffer_arg in enumerate(buffer_args): - run.set_arg(j + 3, op.buffer_bos[buffer_arg]) - op.xrt_runlist.add(run) - else: - op.xrt_runlist = None - - # Log allocation info - bo_count = sum(len(pool) for pool in bo_pools.values()) - bo_footprint = sum(len(pool) * pool_sz for pool_sz, pool in bo_pools.items()) - logging.info( - f"Allocated {bo_count} total buffer objects with a total memory footprint of " - + ( - f"{bo_footprint//1024//1024} MiB." - if bo_footprint >= 1024 * 1024 - else f"{bo_footprint//1024} KiB." - ) - ) - static_data_footprint = sum(len(data) for data in self.static_data_pool) - logging.info( - f"Allocated {len(self.static_data_pool)} static buffers with a total memory footprint of " - + ( - f"{static_data_footprint//1024//1024} MiB." - if static_data_footprint >= 1024 * 1024 - else f"{static_data_footprint//1024} KiB." - ) - ) - - self._runtime_prepared = True diff --git a/operators/gemm/op.py b/operators/gemm/op.py index 9201eb2c..2cb12214 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -8,12 +8,9 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) @@ -21,7 +18,7 @@ from operators.common.utils import torch_to_numpy, numpy_to_torch -class AIEGEMM(AIEOperatorBase): +class AIEGEMM(SingleMLIRSourceOperator): """AIE-accelerated General Matrix Multiplication (GEMM) layer""" def __init__( @@ -36,64 +33,50 @@ def __init__( # TODO: Add support for partitioning M and/or K # partition_M=1, # partition_K=1, - partition_N=1, num_aie_columns=8, context=None, **gemm_kwargs, ): - + num_aie_rows = 4 + min_M = tile_m * num_aie_rows + min_K = tile_k + min_N = tile_n * num_aie_columns + assert M % min_M == 0, f"M ({M}) must be multiple of {min_M}" + assert K % min_K == 0, f"K ({K}) must be multiple of {min_K}" + assert N % min_N == 0, f"N ({N}) must be multiple of {min_N}" + self.M = M + self.K = K + self.N = N self.tile_m = tile_m self.tile_k = tile_k self.tile_n = tile_n + self.num_aie_columns = num_aie_columns self.gemm_args = gemm_kwargs - - # Set frequently accessed gemm_args self.b_col_maj = gemm_kwargs.get("b_col_maj", False) self.c_col_maj = gemm_kwargs.get("c_col_maj", False) - self.weight = ( - None - if not use_static_weight - else torch.zeros((K, N), dtype=torch.bfloat16).T - ) - self.static_weight_shape = (K, N) - # The operator's M, K, N represent what the NPU operator supports. - # Calls to forward() may supply matrices of different sizes, and the - # Python code will perform necessary padding/repeated application of - # the NPU operator. - assert ( - N % partition_N == 0 - ), f"N ({N}) must be divisible by partition_N ({partition_N})" - M_padded, K_padded, N_padded = self._get_padded_dims( - M, K, N // partition_N, tile_m, tile_k, tile_n + emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( + "emulate_bf16_mmul_with_bfp16", True ) - self.M = M_padded - self.K = K_padded - self.N = N_padded - self.partition_N = partition_N + if emulate_bf16_mmul_with_bfp16: + min_tile_m, min_tile_k, min_tile_n = 8, 8, 8 + else: + min_tile_m, min_tile_k, min_tile_n = 4, 8, 8 + assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}" + assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}" + assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}" - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"gemm_{self.M}x{self.K}x{self.N}_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}" - def get_artifacts(self, prefix="gemm_"): - # Extract parameters from self + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - tile_m = self.tile_m - tile_k = self.tile_k - tile_n = self.tile_n - M = self.M - K = self.K - N = self.N - num_aie_columns = self.num_aie_columns + operator_name = self.get_operator_name() base_dir = self.context.base_dir device_str = self.context.device_manager.device_str() - - b_col_maj = self.b_col_maj - c_col_maj = self.c_col_maj dtype_in = self.gemm_args.get("dtype_in", "bf16") dtype_out = self.gemm_args.get("dtype_out", "bf16") emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( @@ -102,245 +85,174 @@ def get_artifacts(self, prefix="gemm_"): prio_accuracy = self.gemm_args.get("prio_accuracy", False) use_scalar = self.gemm_args.get("use_scalar", False) round_conv_even = self.gemm_args.get("round_conv_even", True) - - if emulate_bf16_mmul_with_bfp16: - min_tile_m, min_tile_k, min_tile_n = 8, 8, 8 - else: - min_tile_m, min_tile_k, min_tile_n = 4, 8, 8 - assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}" - assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}" - assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}" - - file_name_tile_base = f"{prefix}{tile_m}x{tile_k}x{tile_n}" - file_name_total_base = f"{prefix}{M}x{K}x{N}_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}" - xclbin_kernel_name = f"gemm_{file_name_tile_base}" - kernel_flags = [ - f"-DDIM_M={tile_m}", - f"-DDIM_K={tile_k}", - f"-DDIM_N={tile_n}", - "-DROUND_CONV_EVEN", - ] - if prio_accuracy: - kernel_flags.append("-Dbf16_f32_ONLY") - else: - kernel_flags.append("-Dbf16_bf16_ONLY") - if round_conv_even: - kernel_flags.append("-DROUND_CONV_EVEN") - if emulate_bf16_mmul_with_bfp16: - kernel_flags.append("-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16") - if b_col_maj: - kernel_flags.append("-DB_COL_MAJ") - if c_col_maj: - kernel_flags.append("-DC_COL_MAJ") - - kernel_archive = ( - f"gemm_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}.a" - ) - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_total_base}.mlir", + separate_c_tiles = self.gemm_args.get("separate_c_tiles", False) + return PythonGeneratedMLIRArtifact.new( + f"{operator_name}.mlir", import_path=operator_dir / "design.py", callback_fn="my_matmul", callback_kwargs={ "dev": device_str, - "M": M, - "K": K, - "N": N, - "m": tile_m, - "k": tile_k, - "n": tile_n, - "n_aie_cols": num_aie_columns, + "M": self.M, + "K": self.K, + "N": self.N, + "m": self.tile_m, + "k": self.tile_k, + "n": self.tile_n, + "n_aie_cols": self.num_aie_columns, "dtype_in_str": dtype_in, "dtype_out_str": dtype_out, - "b_col_maj": int(b_col_maj), - "c_col_maj": int(c_col_maj), + "b_col_maj": int(self.b_col_maj), + "c_col_maj": int(self.c_col_maj), "use_scalar": use_scalar, "emulate_bf16_mmul_with_bfp16": emulate_bf16_mmul_with_bfp16, "prio_accuracy": prio_accuracy, - "separate_c_tiles": int(self.partition_N > 1), + "separate_c_tiles": int(separate_c_tiles), "trace_size": 0, - "archive": kernel_archive, + "archive": self.get_kernel_archive_name(), "generate_taps": False, }, requires_context=False, ) - - # FIXME: We should be able to reuse the same xclbin for same tile - # sizes, only swapping out the instruction sequence for different - # problem sizes. However, there seem to be cases where this does - # not work and the GEMM appears to be misconfigured for the wrong - # size (resulting in a timeout when trying to run it). Perhaps - # XRT is caching something, or something is wrong with the run- - # time parameter (synchronization)? For now, create separate - # xclbins for each problem size. - xclbin_artifact = XclbinArtifact.new( - f"{file_name_total_base}.xclbin", - depends=[ - mlir_artifact, - KernelArchiveArtifact.new( - kernel_archive, - depends=[ - KernelObjectArtifact.new( - f"gemm_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}.o", - extra_flags=kernel_flags, - depends=[ - SourceArtifact.new( - base_dir / "aie_kernels" / "aie2p" / "mm.cc" - ) - ], - ), - KernelObjectArtifact.new( - "convert_copy.o", - [ - SourceArtifact.new( - base_dir - / "aie_kernels" - / "generic" - / "convert_copy.cc" - ) - ], - ), - ], - ), - ], - extra_flags=["--dynamic-objFifos"], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_total_base}.bin", - depends=[mlir_artifact], - extra_flags=["--dynamic-objFifos"], - ) - - return (xclbin_artifact, insts_artifact) - - def set_up_artifacts(self): - # Describe required artifacts (xclbin, insts.bin) - device_str = self.context.device_manager.device_str() - xclbin_artifact, insts_artifact = self.get_artifacts() - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - static_weights = None - if self.weight is not None: - static_weights = self.weight.T - if isinstance(static_weights, torch.Tensor): - static_weights = torch_to_numpy(static_weights) - self.add_kernel( - "gemm", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_buffer("A", self.M * self.K) - B_parts = self._partition_B(static_weights) - for i, B_part in enumerate(B_parts): - self.add_buffer( - f"B_{i}", - self.K * self.N, - static_data=B_part, - ) - self.add_buffer(f"C_{i}", self.M * self.N) - self.add_to_runlist("gemm", "A", f"B_{i}", f"C_{i}") - - def _get_B_dims(self, B_shape): - """Extract K and N dimensions from B matrix shape based on layout. - - Returns: - tuple: (K, N) dimensions regardless of B's layout - """ - if self.b_col_maj: - return B_shape[-1], B_shape[-2] # B is (N, K) -> return (K, N) - else: - return B_shape[-2], B_shape[-1] # B is (K, N) -> return (K, N) - - def forward(self, A, B=None): - """Forward pass through GEMM operation: C = A @ B""" - B_shape = B.shape if B is not None else self.static_weight_shape - - # Determine output dimensions based on matrix layout - K2, N = self._get_B_dims(B_shape) - N_part = N // self.partition_N - - # Build expected output shape based on C layout - expected_output_shape = ( - A.shape[:-2] + (N, A.shape[-1]) if self.c_col_maj else A.shape[:-1] + (N,) - ) - - # Remove batch dimension, if any - if len(A.shape) > 2: - A = A.view(-1, A.shape[-1]) - if B is not None and len(B.shape) > 2: - B = B.view(-1, B_shape[-1]) - - M, K = A.shape - - applicable = ( - K == K2 - and (M <= self.M or not self.c_col_maj) - and K <= self.K - and N <= self.N + + def get_kernel_artifacts(self): + base_dir = self.context.base_dir + emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( + "emulate_bf16_mmul_with_bfp16", True ) - if not applicable: - raise AIEOperatorConstraintError("AIEGEMM: incompatible tensor shape(s)") - - A_padded = self._pad_A(torch_to_numpy(A)) - if B is not None: - B_parts = self._partition_B(torch_to_numpy(B)) + prio_accuracy = self.gemm_args.get("prio_accuracy", False) + round_conv_even = self.gemm_args.get("round_conv_even", True) + kernel_flags = [ + f"-DDIM_M={self.tile_m}", + f"-DDIM_K={self.tile_k}", + f"-DDIM_N={self.tile_n}", + "-DROUND_CONV_EVEN", + ] + if prio_accuracy: + kernel_flags.append("-Dbf16_f32_ONLY") else: - B_parts = None - - logging.debug( - f"Executing GEMM for dimensions M={M}, K={K}, N={N} using NPU operator with M={self.M}, K={self.N}, N={self.N}" - ) - + kernel_flags.append("-Dbf16_bf16_ONLY") + if round_conv_even: + kernel_flags.append("-DROUND_CONV_EVEN") + if emulate_bf16_mmul_with_bfp16: + kernel_flags.append("-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16") + if self.b_col_maj: + kernel_flags.append("-DB_COL_MAJ") if self.c_col_maj: - result_padded = np.zeros((N, M), dtype=A_padded.dtype) - else: - result_padded = np.zeros((M, N), dtype=A_padded.dtype) - for M_lo in range(0, M, self.M): - A_part = A_padded[M_lo : M_lo + self.M, :] - result_parts = self._execute_aie_operation(A_part, B_parts) - max_M = min(M_lo + self.M, M) - for part in range(self.partition_N): - if self.c_col_maj: - result_padded[part * N_part : (part + 1) * N_part, M_lo:max_M] = ( - result_parts[part][:N_part, :max_M] + kernel_flags.append("-DC_COL_MAJ") + return [ + KernelObjectArtifact.new( + f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}.o", + extra_flags=kernel_flags, + depends=[ + SourceArtifact.new( + base_dir / "aie_kernels" / "aie2p" / "mm.cc" ) - else: - result_padded[M_lo:max_M, part * N_part : (part + 1) * N_part] = ( - result_parts[part][:max_M, :N_part] + ], + ), + KernelObjectArtifact.new( + "convert_copy.o", + [ + SourceArtifact.new( + base_dir + / "aie_kernels" + / "generic" + / "convert_copy.cc" ) + ], + ) + ] + + def get_kernel_archive_name(self): + return ( + f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}.a" + ) - # GEMM produces 2D result, reshape to expected output shape - if self.c_col_maj: - result = numpy_to_torch(result_padded[:N, :M]) - else: - result = numpy_to_torch(result_padded[:M, :N]) - result = result.view(expected_output_shape) - - return result - - def _get_padded_dims(self, M, K, N, tile_m, tile_k, tile_n): - num_aie_columns = self.num_aie_columns - num_aie_rows = 4 - - min_M = tile_m * num_aie_rows - min_K = tile_k - min_N = tile_n * num_aie_columns - - # Calculate padded dimensions - M_padded = ((M + min_M - 1) // min_M) * min_M - K_padded = ((K + min_K - 1) // min_K) * min_K - N_padded = ((N + min_N - 1) // min_N) * min_N - - return M_padded, K_padded, N_padded + def get_arg_spec(self): + return [ + AIERuntimeArgSpec((self.M, self.K)), # input A + AIERuntimeArgSpec((self.K, self.N) if not self.b_col_maj else (self.N, self.K)), # input B (weights) + AIERuntimeArgSpec((self.M, self.N) if not self.c_col_maj else (self.N, self.M)), # output C + ] - def _pad_A(self, A_np): + # def _get_B_dims(self, B_shape): + # """Extract K and N dimensions from B matrix shape based on layout. + + # Returns: + # tuple: (K, N) dimensions regardless of B's layout + # """ + # if self.b_col_maj: + # return B_shape[-1], B_shape[-2] # B is (N, K) -> return (K, N) + # else: + # return B_shape[-2], B_shape[-1] # B is (K, N) -> return (K, N) + + # def forward(self, A, B=None): + # """Forward pass through GEMM operation: C = A @ B""" + # B_shape = B.shape if B is not None else self.static_weight_shape + + # # Determine output dimensions based on matrix layout + # K2, N = self._get_B_dims(B_shape) + # N_part = N // self.partition_N + + # # Build expected output shape based on C layout + # expected_output_shape = ( + # A.shape[:-2] + (N, A.shape[-1]) if self.c_col_maj else A.shape[:-1] + (N,) + # ) + + # # Remove batch dimension, if any + # if len(A.shape) > 2: + # A = A.view(-1, A.shape[-1]) + # if B is not None and len(B.shape) > 2: + # B = B.view(-1, B_shape[-1]) + + # M, K = A.shape + + # applicable = ( + # K == K2 + # and (M <= self.M or not self.c_col_maj) + # and K <= self.K + # and N <= self.N + # ) + # if not applicable: + # raise AIEOperatorConstraintError("AIEGEMM: incompatible tensor shape(s)") + + # A_padded = self._pad_A(torch_to_numpy(A)) + # if B is not None: + # B_parts = self._partition_B(torch_to_numpy(B)) + # else: + # B_parts = None + + # logging.debug( + # f"Executing GEMM for dimensions M={M}, K={K}, N={N} using NPU operator with M={self.M}, K={self.N}, N={self.N}" + # ) + + # if self.c_col_maj: + # result_padded = np.zeros((N, M), dtype=A_padded.dtype) + # else: + # result_padded = np.zeros((M, N), dtype=A_padded.dtype) + # for M_lo in range(0, M, self.M): + # A_part = A_padded[M_lo : M_lo + self.M, :] + # result_parts = self._execute_aie_operation(A_part, B_parts) + # max_M = min(M_lo + self.M, M) + # for part in range(self.partition_N): + # if self.c_col_maj: + # result_padded[part * N_part : (part + 1) * N_part, M_lo:max_M] = ( + # result_parts[part][:N_part, :max_M] + # ) + # else: + # result_padded[M_lo:max_M, part * N_part : (part + 1) * N_part] = ( + # result_parts[part][:max_M, :N_part] + # ) + + # # GEMM produces 2D result, reshape to expected output shape + # if self.c_col_maj: + # result = numpy_to_torch(result_padded[:N, :M]) + # else: + # result = numpy_to_torch(result_padded[:M, :N]) + # result = result.view(expected_output_shape) + + # return result + + def pad_A(self, A_np): """Pad A matrix to match operator dimensions (M, K)""" M, K = A_np.shape if M % self.M == 0 and K == self.K: @@ -351,7 +263,7 @@ def _pad_A(self, A_np): A_padded[:M, :K] = A_np return A_padded - def _pad_B(self, B_np): + def pad_B(self, B_np): """Pad B matrix to match operator dimensions based on layout""" if self.b_col_maj: N, K = B_np.shape @@ -367,56 +279,16 @@ def _pad_B(self, B_np): B_padded[:K, :N] = B_np return B_padded - def _partition_B(self, B): - B_parts = [None] * self.partition_N + def partition_B(self, B, partition_N): + B_parts = [None] * partition_N if B is None: return B_parts - for i in range(self.partition_N): + for i in range(partition_N): col_start = i * self.N col_end = (i + 1) * self.N - # Just in case, pad the weights before adding the buffer if self.b_col_maj: - B_parts[i] = self._pad_B(B[col_start:col_end, :]) + B_parts[i] = self.pad_B(B[col_start:col_end, :]) else: - B_parts[i] = self._pad_B(B[:, col_start:col_end]) - self.static_weight_shape = B_parts[0].shape + B_parts[i] = self.pad_B(B[:, col_start:col_end]) return B_parts - - def _execute_aie_operation(self, A_np, B_nps=None): - """Execute GEMM operation on AIE hardware""" - M, K = A_np.shape - B_shape = B_nps[0].shape if B_nps is not None else self.static_weight_shape - K2, N = self._get_B_dims(B_shape) - C_shape = (N, M) if self.c_col_maj else (M, N) - - # Validate dimensions match operator configuration - assert M == self.M - assert K == K2 and K == self.K - assert N == self.N - - self.write_buffer("A", A_np) - if B_nps is not None: - for i, B_np in enumerate(B_nps): - self.add_buffer( - f"B_{i}", - self.M * self.N, - static_data=B_np, - ) - self.run_runlist() - result_nps = [ - self.read_buffer(f"C_{i}", shape=C_shape, dtype=bfloat16) - for i in range(self.partition_N) - ] - - # Check for NaN and fail hard - # for result_np in result_nps: - # if np.isnan(result_np).any(): - # nan_count = np.isnan(result_np).sum() - # total_count = result_np.size - # raise RuntimeError( - # f"AIE execution returned {nan_count}/{total_count} NaN values. " - # ) - - # Convert back to torch tensor - return result_nps diff --git a/operators/rms_norm/design.py b/operators/rms_norm/design.py index 2bf09b43..499c52d5 100644 --- a/operators/rms_norm/design.py +++ b/operators/rms_norm/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size): +def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size, archive_name="rms_norm.a"): per_tile_elements = 8192 if tile_size > 8192 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -46,7 +46,7 @@ def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_s # AIE Core Function declaration rms_norm_kernel = Kernel( - "rms_norm_bf16_vector", "rms_norm.o", [tile_ty, tile_ty, np.int32] + "rms_norm_bf16_vector", archive_name, [tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile @@ -120,93 +120,3 @@ def core_body(of_in1, of_out, rms_norm_kernel): # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) - - -if __name__ == "__main__": - - def str_to_device(device: str): - if device == "npu": - return NPU1() - elif device == "npu2": - return NPU2() - else: - raise ValueError(f"Device name {device} is unknown.") - - p = argparse.ArgumentParser() - # Parse command line arguments - - # Device name is required to select the AIE device: npu or npu2 - p.add_argument( - "-d", - "--dev", - required=True, - dest="device", - help="AIE Device", - type=str_to_device, - ) - # Transfer size is required to define the size of the data to be transferred - # It must be a multiple of 1024 and divisible by the number of columns and 2 channels per column - p.add_argument("-l", "--length", required=True, dest="length", help="Transfer size") - # Number of columns is required to define the number of columns to be used - # It must be less than or equal to 4 for npu and 8 for npu2 - p.add_argument( - "-co", "--columns", required=True, dest="cols", help="Number of columns" - ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) - # Tile size (columns per tile) - defaults to 1024 for backward compatibility - p.add_argument( - "-ts", - "--tile-size", - required=False, - dest="tile_size", - default="1024", - help="Tile size (columns per tile)", - ) - # Trace Size - p.add_argument( - "-tr", "--trace-size", required=True, dest="trace_size", help="Trace size" - ) - p.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - - opts = p.parse_args(sys.argv[1:]) - - length = int(opts.length) - columns = int(opts.cols) - dev = opts.device # Now this is already a device object! - - # Validate columns based on device type - if isinstance(dev, NPU1) and columns > 4: - raise ValueError("[ERROR] NPU device cannot allocate more than 4 columns") - elif isinstance(dev, NPU2) and columns > 8: - raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") - tile_size = int(opts.tile_size) - if ((length % tile_size) % columns % channels) != 0: - print( - "transfer size (" - + str(length) - + ") must be a multiple of " - + str(tile_size) - + " and divisible by the number of columns and 2 channels per column" - ) - raise ValueError - trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - - module = my_rms_norm(dev, length, columns, channels, trace_size, tile_size) - - output_file_path = Path(opts.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) diff --git a/operators/rms_norm/design_weighted.py b/operators/rms_norm/design_weighted.py index 20c4fbbe..6a2929f8 100644 --- a/operators/rms_norm/design_weighted.py +++ b/operators/rms_norm/design_weighted.py @@ -16,7 +16,7 @@ def my_weighted_rms_norm( - dev, num_elements, num_columns, num_channels, weight_length, trace_size + dev, num_elements, num_columns, num_channels, weight_length, trace_size, archive_name="rms_norm.a" ): per_tile_elements = weight_length total_cores = num_columns # For each core that does rms norm, another core will take its output to do eltwise mul @@ -53,11 +53,11 @@ def my_weighted_rms_norm( # AIE Core Function declaration rms_norm_kernel = Kernel( - "rms_norm_bf16_vector", "rms_norm_archive.a", [tile_ty, tile_ty, np.int32] + "rms_norm_bf16_vector", archive_name, [tile_ty, tile_ty, np.int32] ) eltwise_mul_kernel = Kernel( "eltwise_mul_bf16_vector", - "rms_norm_archive.a", + archive_name, [tile_ty, weights_ty, tile_ty, np.int32], ) @@ -157,96 +157,3 @@ def core_body_mul(of_in1, of_in2, of_out2, eltwise_mul): # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) - - -if __name__ == "__main__": - - def str_to_device(device: str): - if device == "npu": - return NPU1() - elif device == "npu2": - return NPU2() - else: - raise ValueError(f"Device name {device} is unknown.") - - p = argparse.ArgumentParser() - # Parse command line arguments - - # Device name is required to select the AIE device: npu or npu2 - p.add_argument( - "-d", - "--dev", - required=True, - dest="device", - help="AIE Device", - type=str_to_device, - ) - # Transfer size is required to define the size of the data to be transferred - # It must be a multiple of 1024 and divisible by the number of columns and 2 channels per column - p.add_argument("-l", "--length", required=True, dest="length", help="Transfer size") - # Number of columns is required to define the number of columns to be used - # It must be less than or equal to 4 for npu and 8 for npu2 - p.add_argument( - "-co", "--columns", required=True, dest="cols", help="Number of columns" - ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) - # Weight length - p.add_argument( - "-wl", - "--weight-length", - required=True, - dest="weight_length", - help="Weight vector length", - ) - # Trace Size - p.add_argument( - "-ts", "--trace-size", required=True, dest="trace_size", help="Trace size" - ) - p.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - - opts = p.parse_args(sys.argv[1:]) - - length = int(opts.length) - columns = int(opts.cols) - dev = opts.device # Now this is already a device object! - - # Validate columns based on device type - if isinstance(dev, NPU1) and columns > 4: - raise ValueError("[ERROR] NPU device cannot allocate more than 4 columns") - elif isinstance(dev, NPU2) and columns > 8: - raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") - weight_length = int(opts.weight_length) - # For weighted RMS norm: cores = columns (weights are broadcasted) - total_cores = columns - if (length % (weight_length * total_cores)) != 0: - print( - "transfer size (" - + str(length) - + ") must be a multiple of weight_length * total_cores (" - + str(weight_length * total_cores) - + ")" - ) - raise ValueError - trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - - module = my_weighted_rms_norm( - dev, length, columns, channels, weight_length, trace_size - ) - - output_file_path = Path(opts.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) diff --git a/operators/rms_norm/op.py b/operators/rms_norm/op.py index 4553fa6a..0d3457f2 100644 --- a/operators/rms_norm/op.py +++ b/operators/rms_norm/op.py @@ -8,8 +8,8 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -20,7 +20,7 @@ from operators.common.utils import torch_to_numpy -class AIERMSNorm(AIEOperatorBase): +class AIERMSNorm(SingleMLIRSourceOperator): """AIE-accelerated RMS Normalization layer""" def __init__( @@ -34,9 +34,10 @@ def __init__( context=None, ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert size % max_multiple == 0, "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -44,27 +45,20 @@ def __init__( self.eps = eps self.weighted = weighted - # Initializes weights to 1. Weights have size embedding dim, which is assumed to be tile size - self.weight = nn.Parameter(torch.ones(tile_size, dtype=torch.bfloat16)) - # Enforce ShimDMA limits for weighted RMS Norm (uses 2 inputs per core) # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"weighted_rms_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"weighted_rms_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design_weighted.py", callback_fn="my_weighted_rms_norm", callback_args=[ @@ -75,127 +69,40 @@ def set_up_artifacts(self): self.tile_size, 0, ], + callback_kwargs={ + "archive_name": f"{self.get_operator_name()}.a", + } ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelArchiveArtifact.new( - f"rms_norm_archive.a", - depends=[ - KernelObjectArtifact.new( - f"rms_norm.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "rms_norm.cc" - ) - ], - ), - KernelObjectArtifact.new( - "mul.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "mul.cc" - ) - ], - ), - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Runtime setup - static_weights = None - if self.weight is not None: - static_weights = torch_to_numpy(self.weight) - - self.add_buffer("input1", self.size) - self.add_buffer("input2", self.tile_size, static_data=static_weights) - self.add_buffer("output", self.size) - self.add_kernel( - "eltwise_mul", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("eltwise_mul", "input1", "input2", "output") - - def forward(self, x, y=None): - """Forward pass through RMS normalization""" - applicable = ( - len(x.shape) >= 1 and x.shape[-1] <= self.size and x.numel() <= self.size - ) - if not applicable: - raise AIEOperatorConstraintError("AIERMSNorm: incompatible tensor shape(s)") - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - if y is not None: - y_flat = y.reshape(batch, -1) - else: - y_flat = None - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat, y_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y=None): - """Execute RMS normalization on AIE hardware""" - # x, y are [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - if y is not None: - y_flat = y.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)} doesn't match configured size {self.size}" - ) - - self.write_buffer("input1", x_flat) - if y is not None: - self.write_buffer("input2", y_flat) - else: - assert ( - self.weight is not None - ), "Weights must be provided either as input or during initialization." - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"rms_norm.o", + depends=[ + SourceArtifact.new( + self.context.base_dir + / "aie_kernels" + / "aie2p" + / "rms_norm.cc" + ) + ], + ), + KernelObjectArtifact.new( + "mul.o", + depends=[ + SourceArtifact.new( + self.context.base_dir + / "aie_kernels" + / "generic" + / "mul.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec((self.size // self.tile_size, self.tile_size)), # input + AIERuntimeArgSpec((self.tile_size,)), # weight + AIERuntimeArgSpec((self.size // self.tile_size, self.tile_size)) # output + ] From 1daaa394849b1f0e3963d46b37d62537fb5e9ae2 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 16:35:41 -0700 Subject: [PATCH 11/99] last layer GEMM offloaded --- applications/llama_3.2_1b/llama_npu.py | 48 +++++++++++++++++--------- operators/common/aie_base.py | 5 +++ 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 3981dd32..93811075 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -6,6 +6,7 @@ import sys import ml_dtypes import llama_inference_harness as harness +import time repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) @@ -110,6 +111,14 @@ def __init__(self, config, prompt_len): for i in range(config.vocab_partitions) ] + for buf in ( + [ + self.W_final_norm + ] + + self.W_out_head_parts + ): + buf.to("npu") + # Operators # ########################################################################## @@ -352,18 +361,20 @@ def llama_forward_pass_prefill( # Step 5: Output projection (check for tied embeddings) # Since vocab size is a very large dimension unsupported by the AIE GEMM, we have to execute the GEMM in multiple partitions and reassemble the output. - # for i in range(config.vocab_partitions): - # aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) - # aie_buffers.logits_prefill.to("cpu") - # # Reassemble (transpose) the logits from partitions - # logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) - # logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) - # logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) - - aie_buffers.x_prefill.to("cpu") - x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] - - logits = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) + aie_buffers.logits_prefill.to("npu") + for i in range(config.vocab_partitions): + aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) + aie_buffers.logits_prefill.to("cpu") + time.sleep(0.15) # FIXME: There is a synchronization issue; without this, the latter part of logits (rightmore columns after transpose) are missing (all zeroes) + logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) + logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) + logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) + + # Reference: + # aie_buffers.x_prefill.to("cpu") + # x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] + # logits_ref = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) + # assert (logits - logits_ref).max() < 0.5 return logits, state @@ -399,9 +410,7 @@ def llama_forward_pass_decode( # Step 4: Final normalization aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - aie_buffers.x_decode.to("npu") aie_ops.final_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) - aie_buffers.x_decode.to("cpu") # Step 5: Output projection lm_head_weight = config.weights['model.embed_tokens.weight'] @@ -415,14 +424,19 @@ def llama_forward_pass_decode( def main(): global aie_ops, aie_buffers + max_seq_len = 2048 prompt = "The capital of France is " + #with open('prompt.txt', 'r') as f: + # prompt = f.read() + #prompt = prompt[:max_seq_len] + config, state = harness.init(prompt=prompt) - aie_ops = AIELlamaOperators(config, 2048) - aie_buffers = AIELlamaBuffers(config, 2048) + aie_ops = AIELlamaOperators(config, max_seq_len) + aie_buffers = AIELlamaBuffers(config, max_seq_len) print(prompt, end='', flush=True) - harness.generate(config, state, llama_forward_pass, use_kv_cache=False) + harness.generate(config, state, llama_forward_pass, use_kv_cache=True) if __name__ == "__main__": main() diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index a60e06b9..8e6ee9de 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -226,9 +226,11 @@ def to(self, dest): if dest == "npu": if self.on != "npu": self.bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) + self.on = "npu" elif dest == "cpu": if self.on != "cpu": self.bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) + self.on = "cpu" else: raise RuntimeError(f"Unknown destination for AIEBuffer.to(): {dest}") return self @@ -278,3 +280,6 @@ def __call__(self, *buffers): opcode = 3 bos = [buffer.bo for buffer in buffers] run = self.xrt_kernel(opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos) + for buffer in buffers: + buffer.to("cpu") + From 83380c20d633e62356f00808d99b6c0ffebca143 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 20:22:34 -0700 Subject: [PATCH 12/99] fixes: --- applications/llama_3.2_1b/llama_npu.py | 41 ++++--- operators/__init__.py | 2 +- operators/common/aie_base.py | 47 +++++--- operators/gemv/op.py | 161 +++++-------------------- 4 files changed, 86 insertions(+), 165 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 93811075..c17ecf17 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -19,7 +19,7 @@ from operators import ( AIERMSNorm, AIEGEMM, - #AIEGEMV + AIEGEMV ) @@ -70,15 +70,15 @@ def __init__(self, config, prompt_len): context=self.context ).compile() self.out_head_prefill = self.out_head_prefill_compilable.get_callable() - #self.out_head_decode = AIEGEMV( - # M=config.vocab_size, K=config.emb_dim, - # num_aie_columns=8, - # is_mv=True, - # use_static_weight=True, - # tile_size_input=4, - # tile_size_output=32, - # context=self.context - #) + self.out_head_decode = AIEGEMV( + M=config.vocab_size, + K=config.emb_dim, + num_aie_columns=8, + use_static_weight=True, + tile_size_input=4, + tile_size_output=32, + context=self.context + ).compile().get_callable() # Allocate buffers shared with NPU @@ -93,6 +93,7 @@ def __init__(self, config, prompt_len): self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']) # Final linear layer + self.W_out_head = AIEBuffer.from_np(torch_to_numpy(config.weights['model.embed_tokens.weight'])) # unpadded/unpartitioned, used by GEMV W_out_head_parts = aie_ops.out_head_prefill_compilable.partition_B( torch_to_numpy(config.weights['model.embed_tokens.weight']), config.vocab_partitions @@ -100,7 +101,7 @@ def __init__(self, config, prompt_len): self.W_out_head_parts = [ AIEBuffer.from_np(W_out_head_part) for W_out_head_part in W_out_head_parts - ] + ] # partitioned, padded parts of weight, used by GEMM self.logits_prefill = AIEBuffer(shape=(config.vocab_partitions, prompt_len, config.padded_vocab_size // config.vocab_partitions)) self.logits_prefill_parts = [ self.logits_prefill.subbuffer( @@ -110,10 +111,12 @@ def __init__(self, config, prompt_len): ) for i in range(config.vocab_partitions) ] + self.logits_decode = AIEBuffer(shape=(config.vocab_size,)) for buf in ( [ - self.W_final_norm + self.W_final_norm, + self.W_out_head, ] + self.W_out_head_parts ): @@ -365,7 +368,6 @@ def llama_forward_pass_prefill( for i in range(config.vocab_partitions): aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) aie_buffers.logits_prefill.to("cpu") - time.sleep(0.15) # FIXME: There is a synchronization issue; without this, the latter part of logits (rightmore columns after transpose) are missing (all zeroes) logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) @@ -409,12 +411,19 @@ def llama_forward_pass_decode( ) # Step 4: Final normalization - aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_buffers.x_decode.view_as_torch().view(1, 1, config.emb_dim)[0, 0, :] = x + aie_buffers.x_decode.to("npu") aie_ops.final_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) # Step 5: Output projection - lm_head_weight = config.weights['model.embed_tokens.weight'] - logits = torch.nn.functional.linear(aie_buffers.x_decode.view_as_torch().unsqueeze(0), lm_head_weight) # (batch, seq_len, vocab_size) + aie_buffers.logits_decode.to("npu") + aie_ops.out_head_decode(aie_buffers.W_out_head, aie_buffers.x_decode.view((config.emb_dim,)), aie_buffers.logits_decode) + aie_buffers.logits_decode.to("cpu") + + logits = aie_buffers.logits_decode.view_as_torch().view(1, 1, config.vocab_size) + # Reference: + # x = aie_buffers.x_decode.view_as_torch().unsqueeze(0) + # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) return logits, state diff --git a/operators/__init__.py b/operators/__init__.py index 9b00a29e..4bc8ba33 100644 --- a/operators/__init__.py +++ b/operators/__init__.py @@ -7,7 +7,7 @@ #from .elementwise_mul.op import AIEElementwiseMul #from .gelu.op import AIEGELU from .gemm.op import AIEGEMM -#from .gemv.op import AIEGEMV +from .gemv.op import AIEGEMV #from .layer_norm.op import AIELayerNorm #from .leaky_relu.op import AIELeakyReLU #from .mem_copy.op import AIEMemCopy diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index 8e6ee9de..1a8cd105 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -196,6 +196,8 @@ def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): pyxrt.bo.host_only, 0x10000, ) + self.memory_view = self.bo.map() + self.subviews = [] def subbuffer(self, length, offset, shape, dtype=None): if dtype is None: @@ -210,29 +212,42 @@ def subbuffer(self, length, offset, shape, dtype=None): length * itemsize, # size offset * itemsize, # offset ) - return AIEBuffer(shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager) + sub_buffer = AIEBuffer(shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager) + self.subviews.append(sub_buffer) + return sub_buffer + + def view(self, shape): + assert np.prod(shape) == np.prod(self.shape) + sub_buffer = AIEBuffer(shape=shape, dtype=self.dtype, bo=self.bo, device_manager=self.device_manager) + sub_buffer.on = self.on + self.subviews.append(sub_buffer) + return sub_buffer def view_as_np(self): self.to("cpu") - # Create a byte accessible memory view of the buffer object - mv = self.bo.map() # Interpret the buffer as a 1-dimensional array then change its view to the expected shape - return np.frombuffer(mv, dtype=self.dtype, count=np.prod(self.shape)).reshape(self.shape) + return np.frombuffer(self.memory_view, dtype=self.dtype, count=np.prod(self.shape)).reshape(self.shape) def view_as_torch(self): return numpy_to_torch(self.view_as_np()) def to(self, dest): - if dest == "npu": - if self.on != "npu": - self.bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - self.on = "npu" - elif dest == "cpu": - if self.on != "cpu": - self.bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) - self.on = "cpu" - else: + direction = { + "npu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE, + "cpu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE + } + if dest not in direction: raise RuntimeError(f"Unknown destination for AIEBuffer.to(): {dest}") + if self.on == dest: + return self + direction = direction[dest] + self.bo.sync(direction) + self.on = dest + todo = self.subviews.copy() + while todo: + sub_buffer = todo.pop() + sub_buffer.on = self.on + todo.extend(sub_buffer.subviews) return self @staticmethod @@ -275,11 +290,9 @@ def __call__(self, *buffers): for i in range(len(buffers)) ), "Input buffer shapes or dtypes do not match expected argument specification." self.insts_buffer.to("npu") - for buffer in buffers: - buffer.to("npu") + assert all(buffer.on == "npu" for buffer in buffers), "Not all input buffers have been synced on the NPU." opcode = 3 bos = [buffer.bo for buffer in buffers] run = self.xrt_kernel(opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos) - for buffer in buffers: - buffer.to("cpu") + run.wait() diff --git a/operators/gemv/op.py b/operators/gemv/op.py index b86088d4..bd10dd79 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -7,8 +7,8 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -19,7 +19,7 @@ from operators.common.utils import torch_to_numpy -class AIEGEMV(AIEOperatorBase): +class AIEGEMV(SingleMLIRSourceOperator): """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" def __init__( @@ -40,31 +40,25 @@ def __init__( tile_size_output % tile_size_input == 0 and tile_size_output >= tile_size_input ), "tile_size_output must be a multiple of tile_size_input" - self.M = M # matrix rows (if is_mv=False, matrix columns) - self.K = K # matrix columns, vector rows (if is_mv=False, matrix rows, vector columns) + self.M = M # matrix rows + self.K = K # matrix columns, vector rows self.num_aie_columns = num_aie_columns self.tile_size_input = tile_size_input self.tile_size_output = tile_size_output - self.is_mv = is_mv - if use_static_weight: - self.weight = torch.zeros( - (M, K) if is_mv else (K, M), dtype=torch.bfloat16 - ).T # weights are stored col-major/transposed - else: - self.weight = None self.xclbin_artifact = None self.insts_artifact = None - AIEOperatorBase.__init__(self, context=context) + SingleMLIRSourceOperator.__init__(self, context=context) - def get_artifacts(self, prefix="gemv_"): - # The underlying MLIR design is a matrix-vector multiplication. We support vector-matrix multiplication by transposing the matrix beforehand (AB = C <=> B^T A^T = C^T). + def get_operator_name(self): + return f"{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_aie_columns}col" + + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"{prefix}{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_aie_columns}col" - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_matvec", callback_args=[ @@ -77,116 +71,21 @@ def get_artifacts(self, prefix="gemv_"): ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"mv.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "generic" / "mv.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - return xclbin_artifact, insts_artifact - - def set_up_artifacts(self): - xclbin_artifact, insts_artifact = self.get_artifacts() - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed. - # Runtime Setup - # --- - static_weights = None - if self.weight is not None: - # Kernel expects row-major weights, so might need to transpose (torch weights are stored in col-major); - # also might need to transpose if is_mv - if self.is_mv: - static_weights = self.weight.T - else: - # Double transpose cancels out - static_weights = self.weight - if isinstance(static_weights, torch.Tensor): - static_weights = torch_to_numpy(static_weights) - self.add_kernel( - "gemv", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_buffer("matrix", self.M * self.K, static_data=static_weights) - self.add_buffer("vector", self.K) - self.add_buffer("output", self.M) - self.add_to_runlist("gemv", "matrix", "vector", "output") - - def forward(self, vector, matrix=None): - """Forward pass through GEMV operation - - Args: - matrix: Input matrix of shape (..., M, K) - vector: Input vector of shape (..., K) for MV or (..., M) for VM - is_mv: True for matrix-vector multiplication, False for vector-matrix - - Returns: - Output vector of shape (..., M) for MV or (..., K) for VM - """ - - # Flatten batch dimensions if needed - if matrix is not None: - matrix = matrix.reshape(*matrix.shape[-2:]) - original_vector_dims = vector.ndim - vector = vector.reshape(*vector.shape[-1:]) - - # For vector-matrix, we'll transpose the matrix internally - if matrix is not None and not self.is_mv: - # Transpose the matrix for vector-matrix multiplication - # (if using static weights, the matrix is already transposed once at setup if needed) - matrix = matrix.transpose(-2, -1) - - if matrix is not None: - matrix_rows = matrix.shape[-2] - matrix_cols = matrix.shape[-1] - else: - matrix_rows = self.M - matrix_cols = self.K - - vector_size = vector.shape[-1] - - applicable = ( - matrix_cols == vector_size - and matrix_rows == self.M - and matrix_cols == self.K - and (matrix is None or matrix.dtype == torch.bfloat16) - and vector.dtype == torch.bfloat16 - ) - if not applicable: - raise AIEOperatorConstraintError( - "AIEElementwiseAdd: incompatible tensor shape(s)" - ) - - if matrix is not None: - # If matrix is none, we are using static weights that have already been written to the buffer - self.write_buffer("matrix", matrix) - self.write_buffer("vector", vector) - self.run_runlist() - result = self.read_buffer_as_torch("output", (self.M,)) - - # Add back batch dimensions if we removed them earlier. - if result.ndim < original_vector_dims: - result = result.reshape(*((1,) * (original_vector_dims - 1)), -1) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"mv.o", + depends=[ + SourceArtifact.new( + self.context.base_dir / "aie_kernels" / "generic" / "mv.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec((self.M, self.K)), # matrix + AIERuntimeArgSpec((self.K,)), # vector + AIERuntimeArgSpec((self.M,)), # output + ] From 3aa1e0271a91a0473bffec0c1d323d3f6355f653 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 20:30:13 -0700 Subject: [PATCH 13/99] cleanup --- applications/llama_3.2_1b/llama_npu.py | 77 +++++++++++--------------- 1 file changed, 32 insertions(+), 45 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index c17ecf17..f6783562 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -323,19 +323,6 @@ def llama_forward_pass( diagonal=1 ) - if seq_len == 1: - return llama_forward_pass_decode(config, state, x, attn_mask) - else: - return llama_forward_pass_prefill(config, state, x, attn_mask) - -def llama_forward_pass_prefill( - config, - state, - x, - attn_mask -): - batch, seq_len, _ = x.shape - # Step 3: Apply transformer blocks for layer_idx in range(config.n_layers): x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( @@ -358,26 +345,40 @@ def llama_forward_pass_prefill( ) # Step 4: Final normalization - aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - aie_buffers.x_prefill.to("npu") - aie_ops.final_norm_prefill(aie_buffers.x_prefill, aie_buffers.W_final_norm, aie_buffers.x_prefill) + if seq_len > 1: + aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_buffers.x_prefill.to("npu") + aie_ops.final_norm_prefill(aie_buffers.x_prefill, aie_buffers.W_final_norm, aie_buffers.x_prefill) + else: + aie_buffers.x_decode.view_as_torch().view(1, 1, config.emb_dim)[0, 0, :] = x + aie_buffers.x_decode.to("npu") + aie_ops.final_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) # Step 5: Output projection (check for tied embeddings) - # Since vocab size is a very large dimension unsupported by the AIE GEMM, we have to execute the GEMM in multiple partitions and reassemble the output. - aie_buffers.logits_prefill.to("npu") - for i in range(config.vocab_partitions): - aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) - aie_buffers.logits_prefill.to("cpu") - logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) - logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) - logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) - - # Reference: - # aie_buffers.x_prefill.to("cpu") - # x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] - # logits_ref = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) - # assert (logits - logits_ref).max() < 0.5 - + if seq_len > 1: + # Since vocab size is a very large dimension unsupported by the AIE GEMM, we have to execute the GEMM in multiple partitions and reassemble the output. + aie_buffers.logits_prefill.to("npu") + for i in range(config.vocab_partitions): + aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) + aie_buffers.logits_prefill.to("cpu") + logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) + logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) + logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) + # Reference: + # aie_buffers.x_prefill.to("cpu") + # x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] + # logits_ref = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) + # assert (logits - logits_ref).max() < 0.5 + else: + # Step 5: Output projection + aie_buffers.logits_decode.to("npu") + aie_ops.out_head_decode(aie_buffers.W_out_head, aie_buffers.x_decode.view((config.emb_dim,)), aie_buffers.logits_decode) + aie_buffers.logits_decode.to("cpu") + logits = aie_buffers.logits_decode.view_as_torch().view(1, 1, config.vocab_size) + # Reference: + # x = aie_buffers.x_decode.view_as_torch().unsqueeze(0) + # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) + return logits, state @@ -410,20 +411,6 @@ def llama_forward_pass_decode( attn_mask=attn_mask, ) - # Step 4: Final normalization - aie_buffers.x_decode.view_as_torch().view(1, 1, config.emb_dim)[0, 0, :] = x - aie_buffers.x_decode.to("npu") - aie_ops.final_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) - - # Step 5: Output projection - aie_buffers.logits_decode.to("npu") - aie_ops.out_head_decode(aie_buffers.W_out_head, aie_buffers.x_decode.view((config.emb_dim,)), aie_buffers.logits_decode) - aie_buffers.logits_decode.to("cpu") - - logits = aie_buffers.logits_decode.view_as_torch().view(1, 1, config.vocab_size) - # Reference: - # x = aie_buffers.x_decode.view_as_torch().unsqueeze(0) - # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) return logits, state From 002d48733900acc3a6974dd1867ff86044ed924f Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 20:57:16 -0700 Subject: [PATCH 14/99] RMSNorm offloaded verywhere, cleanup --- applications/llama_3.2_1b/llama_npu.py | 130 +++++++++++++------------ operators/common/aie_base.py | 7 +- operators/gemm/op.py | 6 +- operators/gemv/op.py | 6 +- operators/rms_norm/op.py | 6 +- 5 files changed, 80 insertions(+), 75 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index f6783562..a920053a 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -6,7 +6,6 @@ import sys import ml_dtypes import llama_inference_harness as harness -import time repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) @@ -35,7 +34,7 @@ def __init__(self, config, prompt_len): self.context.build_dir.mkdir(parents=True, exist_ok=True) # Final RMS Norm - self.final_norm_prefill = AIERMSNorm( + self.rms_norm_prefill = AIERMSNorm( size=prompt_len * config.emb_dim, eps=1e-5, num_aie_columns=8, @@ -43,7 +42,7 @@ def __init__(self, config, prompt_len): tile_size=config.emb_dim, context=self.context ).compile().get_callable() - self.final_norm_decode = AIERMSNorm( + self.rms_norm_decode = AIERMSNorm( size=config.emb_dim, eps=1e-5, num_aie_columns=1, @@ -88,21 +87,37 @@ def __init__(self, config, prompt_len): class AIELlamaBuffers: def __init__(self, config, prompt_len): + # Vector of the current token(s) being processed through the pipeline self.x_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) self.x_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) - self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']) + self.x_norm_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) + self.x_norm_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) + + # Transformer block layer-wise RMS norm + self.W_norm1 = [] + self.W_norm2 = [] + for layer_idx in range(config.n_layers): + self.W_norm1.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.input_layernorm.weight']).to("npu") + ) + self.W_norm2.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight']).to("npu") + ) + + # Final RMS norm weights + self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']).to("npu") # Final linear layer - self.W_out_head = AIEBuffer.from_np(torch_to_numpy(config.weights['model.embed_tokens.weight'])) # unpadded/unpartitioned, used by GEMV + self.W_out_head = AIEBuffer.from_np(torch_to_numpy(config.weights['model.embed_tokens.weight'])).to("npu") # unpadded/unpartitioned, used by GEMV W_out_head_parts = aie_ops.out_head_prefill_compilable.partition_B( torch_to_numpy(config.weights['model.embed_tokens.weight']), config.vocab_partitions ) self.W_out_head_parts = [ - AIEBuffer.from_np(W_out_head_part) + AIEBuffer.from_np(W_out_head_part).to("npu") for W_out_head_part in W_out_head_parts ] # partitioned, padded parts of weight, used by GEMM - self.logits_prefill = AIEBuffer(shape=(config.vocab_partitions, prompt_len, config.padded_vocab_size // config.vocab_partitions)) + self.logits_prefill = AIEBuffer(shape=(config.vocab_partitions, prompt_len, config.padded_vocab_size // config.vocab_partitions)).to("npu") self.logits_prefill_parts = [ self.logits_prefill.subbuffer( length=prompt_len * (config.padded_vocab_size // config.vocab_partitions), @@ -113,15 +128,6 @@ def __init__(self, config, prompt_len): ] self.logits_decode = AIEBuffer(shape=(config.vocab_size,)) - for buf in ( - [ - self.W_final_norm, - self.W_out_head, - ] + - self.W_out_head_parts - ): - buf.to("npu") - # Operators # ########################################################################## @@ -280,9 +286,24 @@ def transformer_block_forward( rope_angles, attn_mask ): + batch, seq_len, _ = x.shape + # Step 1: RMS normalization - x_norm = rms_norm_forward(x, W_norm1) - + if seq_len > 1: + aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_buffers.x_prefill.to("npu") + aie_buffers.x_norm_prefill.to("npu") + aie_ops.rms_norm_prefill(aie_buffers.x_prefill, W_norm1, aie_buffers.x_norm_prefill) + aie_buffers.x_norm_prefill.to("cpu") + x_norm = aie_buffers.x_norm_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] + else: + aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x + aie_buffers.x_decode.to("npu") + aie_buffers.x_norm_decode.to("npu") + x_norm = aie_ops.rms_norm_decode(aie_buffers.x_decode, W_norm1, aie_buffers.x_norm_decode) + aie_buffers.x_decode.to("cpu") + x_norm = aie_buffers.x_norm_decode.view_as_torch().unsqueeze(0) + # Step 2: Attention attn_output, attn_keys, attn_values = grouped_query_attention_forward( x_norm, @@ -299,7 +320,20 @@ def transformer_block_forward( x = x + attn_output # Step 4: Post-norm - x_norm = rms_norm_forward(x, W_norm2) + if seq_len > 1: + aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_buffers.x_prefill.to("npu") + aie_buffers.x_norm_prefill.to("npu") + aie_ops.rms_norm_prefill(aie_buffers.x_prefill, W_norm2, aie_buffers.x_norm_prefill) + aie_buffers.x_norm_prefill.to("cpu") + x_norm = aie_buffers.x_norm_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] + else: + aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x + aie_buffers.x_decode.to("npu") + aie_buffers.x_norm_decode.to("npu") + x_norm = aie_ops.rms_norm_decode(aie_buffers.x_decode, W_norm2, aie_buffers.x_norm_decode) + aie_buffers.x_decode.to("cpu") + x_norm = aie_buffers.x_norm_decode.view_as_torch().unsqueeze(0) # Step 5: fully-connected feed-forward network ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) @@ -331,7 +365,7 @@ def llama_forward_pass( state.attn_values_caches[layer_idx], config.n_heads, config.n_kv_groups, - W_norm1=config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'], + W_norm1=aie_buffers.W_norm1[layer_idx], W_attn_query=config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], @@ -339,7 +373,7 @@ def llama_forward_pass( W_ffn_fc1=config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], W_ffn_fc2=config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], W_ffn_fc3=config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], - W_norm2=config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], + W_norm2=aie_buffers.W_norm2[layer_idx], rope_angles=config.angles, attn_mask=attn_mask, ) @@ -348,15 +382,21 @@ def llama_forward_pass( if seq_len > 1: aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x aie_buffers.x_prefill.to("npu") - aie_ops.final_norm_prefill(aie_buffers.x_prefill, aie_buffers.W_final_norm, aie_buffers.x_prefill) + aie_ops.rms_norm_prefill(aie_buffers.x_prefill, aie_buffers.W_final_norm, aie_buffers.x_prefill) else: aie_buffers.x_decode.view_as_torch().view(1, 1, config.emb_dim)[0, 0, :] = x aie_buffers.x_decode.to("npu") - aie_ops.final_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) + aie_ops.rms_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) # Step 5: Output projection (check for tied embeddings) if seq_len > 1: # Since vocab size is a very large dimension unsupported by the AIE GEMM, we have to execute the GEMM in multiple partitions and reassemble the output. + # Reference: + # aie_buffers.x_prefill.to("cpu") + # x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] + # logits_ref = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) + # assert (logits - logits_ref).max() < 0.5 + aie_buffers.x_prefill.to("npu") aie_buffers.logits_prefill.to("npu") for i in range(config.vocab_partitions): aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) @@ -364,54 +404,16 @@ def llama_forward_pass( logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) - # Reference: - # aie_buffers.x_prefill.to("cpu") - # x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] - # logits_ref = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) - # assert (logits - logits_ref).max() < 0.5 else: # Step 5: Output projection + # Reference: + # x = aie_buffers.x_decode.view_as_torch().unsqueeze(0) + # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) aie_buffers.logits_decode.to("npu") aie_ops.out_head_decode(aie_buffers.W_out_head, aie_buffers.x_decode.view((config.emb_dim,)), aie_buffers.logits_decode) aie_buffers.logits_decode.to("cpu") logits = aie_buffers.logits_decode.view_as_torch().view(1, 1, config.vocab_size) - # Reference: - # x = aie_buffers.x_decode.view_as_torch().unsqueeze(0) - # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) - - return logits, state - -def llama_forward_pass_decode( - config, - state, - x, - attn_mask -): - batch, seq_len, _ = x.shape - - # Step 3: Apply transformer blocks - for layer_idx in range(config.n_layers): - x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( - x, - state.attn_keys_caches[layer_idx], - state.attn_values_caches[layer_idx], - config.n_heads, - config.n_kv_groups, - W_norm1=config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'], - W_attn_query=config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], - W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], - W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], - W_attn_out=config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], - W_ffn_fc1=config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], - W_ffn_fc2=config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], - W_ffn_fc3=config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], - W_norm2=config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], - rope_angles=config.angles, - attn_mask=attn_mask, - ) - - return logits, state diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index 1a8cd105..89740497 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -177,9 +177,11 @@ def get_callable(self): ) class AIERuntimeArgSpec: - def __init__(self, shape, dtype=bfloat16): + def __init__(self, direction, shape, dtype=bfloat16): self.shape = shape self.dtype = dtype + assert direction in {"in", "out", "inout"} + self.direction = direction class AIEBuffer: def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): @@ -213,6 +215,7 @@ def subbuffer(self, length, offset, shape, dtype=None): offset * itemsize, # offset ) sub_buffer = AIEBuffer(shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager) + sub_buffer.on = self.on self.subviews.append(sub_buffer) return sub_buffer @@ -290,7 +293,7 @@ def __call__(self, *buffers): for i in range(len(buffers)) ), "Input buffer shapes or dtypes do not match expected argument specification." self.insts_buffer.to("npu") - assert all(buffer.on == "npu" for buffer in buffers), "Not all input buffers have been synced on the NPU." + assert all(buffer.on == "npu" for buffer, spec in zip(buffers, self.args_spec)), "Not all buffers have been synced on the NPU; for some reason even output buffers must be synced!" opcode = 3 bos = [buffer.bo for buffer in buffers] run = self.xrt_kernel(opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos) diff --git a/operators/gemm/op.py b/operators/gemm/op.py index 2cb12214..bb7c3c79 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -169,9 +169,9 @@ def get_kernel_archive_name(self): def get_arg_spec(self): return [ - AIERuntimeArgSpec((self.M, self.K)), # input A - AIERuntimeArgSpec((self.K, self.N) if not self.b_col_maj else (self.N, self.K)), # input B (weights) - AIERuntimeArgSpec((self.M, self.N) if not self.c_col_maj else (self.N, self.M)), # output C + AIERuntimeArgSpec("in", (self.M, self.K)), # input A + AIERuntimeArgSpec("in", (self.K, self.N) if not self.b_col_maj else (self.N, self.K)), # input B (weights) + AIERuntimeArgSpec("out", (self.M, self.N) if not self.c_col_maj else (self.N, self.M)), # output C ] # def _get_B_dims(self, B_shape): diff --git a/operators/gemv/op.py b/operators/gemv/op.py index bd10dd79..fcae168e 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -85,7 +85,7 @@ def get_kernel_artifacts(self): def get_arg_spec(self): return [ - AIERuntimeArgSpec((self.M, self.K)), # matrix - AIERuntimeArgSpec((self.K,)), # vector - AIERuntimeArgSpec((self.M,)), # output + AIERuntimeArgSpec("in", (self.M, self.K)), # matrix + AIERuntimeArgSpec("in", (self.K,)), # vector + AIERuntimeArgSpec("out", (self.M,)), # output ] diff --git a/operators/rms_norm/op.py b/operators/rms_norm/op.py index 0d3457f2..48442b9d 100644 --- a/operators/rms_norm/op.py +++ b/operators/rms_norm/op.py @@ -102,7 +102,7 @@ def get_kernel_artifacts(self): def get_arg_spec(self): return [ - AIERuntimeArgSpec((self.size // self.tile_size, self.tile_size)), # input - AIERuntimeArgSpec((self.tile_size,)), # weight - AIERuntimeArgSpec((self.size // self.tile_size, self.tile_size)) # output + AIERuntimeArgSpec("in", (self.size // self.tile_size, self.tile_size)), # input + AIERuntimeArgSpec("in", (self.tile_size,)), # weight + AIERuntimeArgSpec("out", (self.size // self.tile_size, self.tile_size)) # output ] From db9c4b8eb2d6ed7f5deac1df2793699e96251a0d Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 21:03:51 -0700 Subject: [PATCH 15/99] less buffer copying --- applications/llama_3.2_1b/llama_npu.py | 27 ++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index a920053a..cb6035f3 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -269,7 +269,7 @@ def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): def transformer_block_forward( - x, + seq_len, attn_keys_cache, attn_values_cache, num_heads, @@ -286,18 +286,14 @@ def transformer_block_forward( rope_angles, attn_mask ): - batch, seq_len, _ = x.shape - # Step 1: RMS normalization if seq_len > 1: - aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x aie_buffers.x_prefill.to("npu") aie_buffers.x_norm_prefill.to("npu") aie_ops.rms_norm_prefill(aie_buffers.x_prefill, W_norm1, aie_buffers.x_norm_prefill) aie_buffers.x_norm_prefill.to("cpu") x_norm = aie_buffers.x_norm_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] else: - aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x aie_buffers.x_decode.to("npu") aie_buffers.x_norm_decode.to("npu") x_norm = aie_ops.rms_norm_decode(aie_buffers.x_decode, W_norm1, aie_buffers.x_norm_decode) @@ -317,6 +313,10 @@ def transformer_block_forward( ) # Step 3: Residual + if seq_len > 1: + x = aie_buffers.x_prefill.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] + else: + x = aie_buffers.x_decode.to("cpu").view_as_torch().unsqueeze(0) x = x + attn_output # Step 4: Post-norm @@ -340,8 +340,12 @@ def transformer_block_forward( # Step 6: Residual x = x + ffn_output + if seq_len > 1: + aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + else: + aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x - return x, attn_keys, attn_values + return attn_keys, attn_values def llama_forward_pass( @@ -356,11 +360,15 @@ def llama_forward_pass( torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 ) + if seq_len > 1: + aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + else: + aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x # Step 3: Apply transformer blocks for layer_idx in range(config.n_layers): - x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( - x, + state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( + seq_len, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx], config.n_heads, @@ -377,14 +385,13 @@ def llama_forward_pass( rope_angles=config.angles, attn_mask=attn_mask, ) + # Step 4: Final normalization if seq_len > 1: - aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x aie_buffers.x_prefill.to("npu") aie_ops.rms_norm_prefill(aie_buffers.x_prefill, aie_buffers.W_final_norm, aie_buffers.x_prefill) else: - aie_buffers.x_decode.view_as_torch().view(1, 1, config.emb_dim)[0, 0, :] = x aie_buffers.x_decode.to("npu") aie_ops.rms_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) From b0942e849ca368a6d8074df517e05846990ae540 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 21:30:27 -0700 Subject: [PATCH 16/99] offload first residual --- applications/llama_3.2_1b/llama_npu.py | 37 +++++- operators/__init__.py | 2 +- operators/elementwise_add/design.py | 4 +- operators/elementwise_add/op.py | 157 ++++++------------------- 4 files changed, 68 insertions(+), 132 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index cb6035f3..1bbbc2ce 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -4,6 +4,7 @@ import math from pathlib import Path import sys +import numpy as np import ml_dtypes import llama_inference_harness as harness @@ -18,7 +19,8 @@ from operators import ( AIERMSNorm, AIEGEMM, - AIEGEMV + AIEGEMV, + AIEElementwiseAdd ) @@ -33,7 +35,7 @@ def __init__(self, config, prompt_len): self.context = AIEContext() self.context.build_dir.mkdir(parents=True, exist_ok=True) - # Final RMS Norm + # RMS Norm self.rms_norm_prefill = AIERMSNorm( size=prompt_len * config.emb_dim, eps=1e-5, @@ -51,6 +53,16 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() + # Residual additions + self.residual_add_prefill = AIEElementwiseAdd( + size=prompt_len * config.emb_dim, + tile_size=config.emb_dim + ).compile().get_callable() + self.residual_add_decode = AIEElementwiseAdd( + size=config.emb_dim, + tile_size=config.emb_dim // 8 + ).compile().get_callable() + # Final GEMM min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N @@ -94,6 +106,12 @@ def __init__(self, config, prompt_len): self.x_norm_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) self.x_norm_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) + self.attn_output_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) + self.attn_output_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) + + self.ffn_output_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) + self.ffn_output_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) + # Transformer block layer-wise RMS norm self.W_norm1 = [] self.W_norm2 = [] @@ -297,7 +315,7 @@ def transformer_block_forward( aie_buffers.x_decode.to("npu") aie_buffers.x_norm_decode.to("npu") x_norm = aie_ops.rms_norm_decode(aie_buffers.x_decode, W_norm1, aie_buffers.x_norm_decode) - aie_buffers.x_decode.to("cpu") + aie_buffers.x_norm_decode.to("cpu") x_norm = aie_buffers.x_norm_decode.view_as_torch().unsqueeze(0) # Step 2: Attention @@ -314,10 +332,19 @@ def transformer_block_forward( # Step 3: Residual if seq_len > 1: + aie_buffers.attn_output_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output + aie_buffers.attn_output_prefill.to("npu") + x_view = aie_buffers.x_prefill.view(np.prod(aie_buffers.x_prefill.shape)) + attn_output_view = aie_buffers.attn_output_prefill.view(np.prod(aie_buffers.attn_output_prefill.shape)) + aie_ops.residual_add_prefill(x_view, attn_output_view, x_view) x = aie_buffers.x_prefill.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] else: + aie_buffers.attn_output_decode.view_as_torch().unsqueeze(0)[0, 0, :] = attn_output + aie_buffers.attn_output_decode.to("npu") + x_view = aie_buffers.x_decode.view(np.prod(aie_buffers.x_decode.shape)) + attn_output_view = aie_buffers.attn_output_decode.view(np.prod(aie_buffers.attn_output_decode.shape)) + aie_ops.residual_add_decode(x_view, attn_output_view, x_view) x = aie_buffers.x_decode.to("cpu").view_as_torch().unsqueeze(0) - x = x + attn_output # Step 4: Post-norm if seq_len > 1: @@ -332,7 +359,7 @@ def transformer_block_forward( aie_buffers.x_decode.to("npu") aie_buffers.x_norm_decode.to("npu") x_norm = aie_ops.rms_norm_decode(aie_buffers.x_decode, W_norm2, aie_buffers.x_norm_decode) - aie_buffers.x_decode.to("cpu") + aie_buffers.x_norm_decode.to("cpu") x_norm = aie_buffers.x_norm_decode.view_as_torch().unsqueeze(0) # Step 5: fully-connected feed-forward network diff --git a/operators/__init__.py b/operators/__init__.py index 4bc8ba33..7da4b542 100644 --- a/operators/__init__.py +++ b/operators/__init__.py @@ -3,7 +3,7 @@ #from .axpy.op import AIEAXPY #from .dequant.op import AIEDequant -#from .elementwise_add.op import AIEElementwiseAdd +from .elementwise_add.op import AIEElementwiseAdd #from .elementwise_mul.op import AIEElementwiseMul #from .gelu.op import AIEGELU from .gemm.op import AIEGEMM diff --git a/operators/elementwise_add/design.py b/operators/elementwise_add/design.py index d1eda376..0aa6d6b1 100644 --- a/operators/elementwise_add/design.py +++ b/operators/elementwise_add/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_add(dev, num_elements, num_columns, num_channels, tile_size, trace_size): +def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, archive_name): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +37,7 @@ def my_eltwise_add(dev, num_elements, num_columns, num_channels, tile_size, trac # AIE Core Function declaration eltwise_add_bf16_vector = Kernel( - "eltwise_add_bf16_vector", "add.o", [tile_ty, tile_ty, tile_ty, np.int32] + "eltwise_add_bf16_vector", archive_name, [tile_ty, tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile diff --git a/operators/elementwise_add/op.py b/operators/elementwise_add/op.py index 3521c3c2..bfe04f82 100644 --- a/operators/elementwise_add/op.py +++ b/operators/elementwise_add/op.py @@ -8,8 +8,8 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -19,152 +19,61 @@ ) -class AIEElementwiseAdd(AIEOperatorBase): +class AIEElementwiseAdd(SingleMLIRSourceOperator): """AIE-accelerated element-wise addition""" def __init__( self, size, - num_aie_columns=None, - num_channels=None, - tile_size=None, + tile_size, + num_aie_columns=8, context=None, ): - max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert size % (num_aie_columns * tile_size) == 0, "size must be multiple of num_aie_columns * tile_size" + self.size = size self.tile_size = tile_size - self.num_aie_columns = num_aie_columns - self.num_channels = num_channels # Enforce ShimDMA limits for elementwise_add (uses 2 inputs per core) # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels - total_shimdma_channels = self.num_aie_columns * self.num_channels + total_shimdma_channels = self.num_aie_columns * 2 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" + SingleMLIRSourceOperator.__init__(self, context=context) - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None - - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"add_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"add_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_eltwise_add", callback_args=[ self.context.device_manager.device_type, self.size, self.num_aie_columns, - self.num_channels, self.tile_size, 0, + self.get_kernel_archive_name(), ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"add.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "generic" / "add.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"add.o", + depends=[ + SourceArtifact.new( + self.context.base_dir / "aie_kernels" / "generic" / "add.cc" + ) + ], + ), + ] + + def get_arg_spec(self): # Runtime setup - self.add_buffer("input1", self.size) - self.add_buffer("input2", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "eltwise_add", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("eltwise_add", "input1", "input2", "output") - - def forward(self, x, y): - """Forward pass for element-wise addition""" - applicable = ( - len(x.shape) >= 1 - and len(y.shape) >= 1 - and x.shape[-1] <= self.size - and y.shape[-1] <= self.size - and x.numel() <= self.size - and y.numel() <= self.size - and x.numel() == y.numel() - and x.shape == y.shape - ) - if not applicable: - raise AIEOperatorConstraintError( - "AIEElementwiseAdd: incompatible tensor shape(s)" - ) - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - y_flat = y.reshape(batch, -1) - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - y_flat = torch.nn.functional.pad(y_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat, y_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y): - """Execute element-wise addition operation on AIE hardware""" - # x, y are [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - y_flat = y.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size or len(y_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)}, y={len(y_flat)} doesn't match configured size {self.size}" - ) - - self.write_buffer("input1", x_flat) - self.write_buffer("input2", y_flat) - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + return [ + AIERuntimeArgSpec("in", (self.size,)), # input1 + AIERuntimeArgSpec("in", (self.size,)), # input2 + AIERuntimeArgSpec("out", (self.size,)), # output + ] From 02e53073424554437de319b2d688f81d2ebf0692 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 21:42:31 -0700 Subject: [PATCH 17/99] simplify --- applications/llama_3.2_1b/llama_npu.py | 181 +++++++++++++------------ 1 file changed, 93 insertions(+), 88 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 1bbbc2ce..39715a17 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -29,6 +29,19 @@ aie_ops = None +class AIEPrefillOperations: + def __init__(self, rms_norm, residual_add, out_head, out_head_compilable): + self.rms_norm = rms_norm + self.residual_add = residual_add + self.out_head = out_head + self.out_head_compilable = out_head_compilable + +class AIEDecodeOperations: + def __init__(self, rms_norm, residual_add, out_head): + self.rms_norm = rms_norm + self.residual_add = residual_add + self.out_head = out_head + class AIELlamaOperators: def __init__(self, config, prompt_len): @@ -36,7 +49,7 @@ def __init__(self, config, prompt_len): self.context.build_dir.mkdir(parents=True, exist_ok=True) # RMS Norm - self.rms_norm_prefill = AIERMSNorm( + rms_norm_prefill = AIERMSNorm( size=prompt_len * config.emb_dim, eps=1e-5, num_aie_columns=8, @@ -44,7 +57,7 @@ def __init__(self, config, prompt_len): tile_size=config.emb_dim, context=self.context ).compile().get_callable() - self.rms_norm_decode = AIERMSNorm( + rms_norm_decode = AIERMSNorm( size=config.emb_dim, eps=1e-5, num_aie_columns=1, @@ -54,11 +67,11 @@ def __init__(self, config, prompt_len): ).compile().get_callable() # Residual additions - self.residual_add_prefill = AIEElementwiseAdd( + residual_add_prefill = AIEElementwiseAdd( size=prompt_len * config.emb_dim, tile_size=config.emb_dim ).compile().get_callable() - self.residual_add_decode = AIEElementwiseAdd( + residual_add_decode = AIEElementwiseAdd( size=config.emb_dim, tile_size=config.emb_dim // 8 ).compile().get_callable() @@ -67,7 +80,7 @@ def __init__(self, config, prompt_len): min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N config.vocab_partitions = 4 - self.out_head_prefill_compilable = AIEGEMM( + out_head_prefill_compilable = AIEGEMM( M=prompt_len, K=config.emb_dim, N=config.padded_vocab_size // config.vocab_partitions, @@ -80,8 +93,8 @@ def __init__(self, config, prompt_len): separate_c_tiles=True, context=self.context ).compile() - self.out_head_prefill = self.out_head_prefill_compilable.get_callable() - self.out_head_decode = AIEGEMV( + out_head_prefill = out_head_prefill_compilable.get_callable() + out_head_decode = AIEGEMV( M=config.vocab_size, K=config.emb_dim, num_aie_columns=8, @@ -90,6 +103,10 @@ def __init__(self, config, prompt_len): tile_size_output=32, context=self.context ).compile().get_callable() + + # Group operations + self.prefill = AIEPrefillOperations(rms_norm_prefill, residual_add_prefill, out_head_prefill, out_head_prefill_compilable) + self.decode = AIEDecodeOperations(rms_norm_decode, residual_add_decode, out_head_decode) # Allocate buffers shared with NPU @@ -97,20 +114,25 @@ def __init__(self, config, prompt_len): aie_buffers = None +class AIEPrefillBuffers: + def __init__(self, prompt_len, emb_dim): + self.x = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + self.x_norm = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + self.attn_output = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + self.ffn_output = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + +class AIEDecodeBuffers: + def __init__(self, emb_dim): + self.x = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) + self.x_norm = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) + self.attn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) + self.ffn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) + class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline - self.x_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) - self.x_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) - - self.x_norm_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) - self.x_norm_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) - - self.attn_output_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) - self.attn_output_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) - - self.ffn_output_prefill = AIEBuffer(shape=(prompt_len, config.emb_dim), dtype=ml_dtypes.bfloat16) - self.ffn_output_decode = AIEBuffer(shape=(1, config.emb_dim), dtype=ml_dtypes.bfloat16) + self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim) + self.decode = AIEDecodeBuffers(config.emb_dim) # Transformer block layer-wise RMS norm self.W_norm1 = [] @@ -127,7 +149,7 @@ def __init__(self, config, prompt_len): self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']).to("npu") # Final linear layer self.W_out_head = AIEBuffer.from_np(torch_to_numpy(config.weights['model.embed_tokens.weight'])).to("npu") # unpadded/unpartitioned, used by GEMV - W_out_head_parts = aie_ops.out_head_prefill_compilable.partition_B( + W_out_head_parts = aie_ops.prefill.out_head_compilable.partition_B( torch_to_numpy(config.weights['model.embed_tokens.weight']), config.vocab_partitions ) @@ -135,16 +157,16 @@ def __init__(self, config, prompt_len): AIEBuffer.from_np(W_out_head_part).to("npu") for W_out_head_part in W_out_head_parts ] # partitioned, padded parts of weight, used by GEMM - self.logits_prefill = AIEBuffer(shape=(config.vocab_partitions, prompt_len, config.padded_vocab_size // config.vocab_partitions)).to("npu") - self.logits_prefill_parts = [ - self.logits_prefill.subbuffer( + self.prefill.logits = AIEBuffer(shape=(config.vocab_partitions, prompt_len, config.padded_vocab_size // config.vocab_partitions)).to("npu") + self.prefill.logits_parts = [ + self.prefill.logits.subbuffer( length=prompt_len * (config.padded_vocab_size // config.vocab_partitions), offset=i * prompt_len * (config.padded_vocab_size // config.vocab_partitions), shape=(prompt_len, config.padded_vocab_size // config.vocab_partitions), ) for i in range(config.vocab_partitions) ] - self.logits_decode = AIEBuffer(shape=(config.vocab_size,)) + self.decode.logits = AIEBuffer(shape=(config.vocab_size,)) # Operators @@ -304,19 +326,20 @@ def transformer_block_forward( rope_angles, attn_mask ): - # Step 1: RMS normalization + # Select prefill or decode operations and buffers if seq_len > 1: - aie_buffers.x_prefill.to("npu") - aie_buffers.x_norm_prefill.to("npu") - aie_ops.rms_norm_prefill(aie_buffers.x_prefill, W_norm1, aie_buffers.x_norm_prefill) - aie_buffers.x_norm_prefill.to("cpu") - x_norm = aie_buffers.x_norm_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] + ops = aie_ops.prefill + bufs = aie_buffers.prefill else: - aie_buffers.x_decode.to("npu") - aie_buffers.x_norm_decode.to("npu") - x_norm = aie_ops.rms_norm_decode(aie_buffers.x_decode, W_norm1, aie_buffers.x_norm_decode) - aie_buffers.x_norm_decode.to("cpu") - x_norm = aie_buffers.x_norm_decode.view_as_torch().unsqueeze(0) + ops = aie_ops.decode + bufs = aie_buffers.decode + + # Step 1: RMS normalization + bufs.x.to("npu") + bufs.x_norm.to("npu") + ops.rms_norm(bufs.x, W_norm1, bufs.x_norm) + bufs.x_norm.to("cpu") + x_norm = bufs.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] # Step 2: Attention attn_output, attn_keys, attn_values = grouped_query_attention_forward( @@ -331,46 +354,27 @@ def transformer_block_forward( ) # Step 3: Residual - if seq_len > 1: - aie_buffers.attn_output_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output - aie_buffers.attn_output_prefill.to("npu") - x_view = aie_buffers.x_prefill.view(np.prod(aie_buffers.x_prefill.shape)) - attn_output_view = aie_buffers.attn_output_prefill.view(np.prod(aie_buffers.attn_output_prefill.shape)) - aie_ops.residual_add_prefill(x_view, attn_output_view, x_view) - x = aie_buffers.x_prefill.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] - else: - aie_buffers.attn_output_decode.view_as_torch().unsqueeze(0)[0, 0, :] = attn_output - aie_buffers.attn_output_decode.to("npu") - x_view = aie_buffers.x_decode.view(np.prod(aie_buffers.x_decode.shape)) - attn_output_view = aie_buffers.attn_output_decode.view(np.prod(aie_buffers.attn_output_decode.shape)) - aie_ops.residual_add_decode(x_view, attn_output_view, x_view) - x = aie_buffers.x_decode.to("cpu").view_as_torch().unsqueeze(0) + bufs.attn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output + bufs.attn_output.to("npu") + x_view = bufs.x.view(np.prod(bufs.x.shape)) + attn_output_view = bufs.attn_output.view(np.prod(bufs.attn_output.shape)) + ops.residual_add(x_view, attn_output_view, x_view) + x = bufs.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] # Step 4: Post-norm - if seq_len > 1: - aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - aie_buffers.x_prefill.to("npu") - aie_buffers.x_norm_prefill.to("npu") - aie_ops.rms_norm_prefill(aie_buffers.x_prefill, W_norm2, aie_buffers.x_norm_prefill) - aie_buffers.x_norm_prefill.to("cpu") - x_norm = aie_buffers.x_norm_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] - else: - aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x - aie_buffers.x_decode.to("npu") - aie_buffers.x_norm_decode.to("npu") - x_norm = aie_ops.rms_norm_decode(aie_buffers.x_decode, W_norm2, aie_buffers.x_norm_decode) - aie_buffers.x_norm_decode.to("cpu") - x_norm = aie_buffers.x_norm_decode.view_as_torch().unsqueeze(0) + bufs.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + bufs.x.to("npu") + bufs.x_norm.to("npu") + ops.rms_norm(bufs.x, W_norm2, bufs.x_norm) + bufs.x_norm.to("cpu") + x_norm = bufs.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] # Step 5: fully-connected feed-forward network ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) # Step 6: Residual x = x + ffn_output - if seq_len > 1: - aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - else: - aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x + bufs.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x return attn_keys, attn_values @@ -381,16 +385,21 @@ def llama_forward_pass( ): batch, seq_len = state.token_ids.shape + # Select prefill or decode operations and buffers + if seq_len > 1: + ops = aie_ops.prefill + bufs = aie_buffers.prefill + else: + ops = aie_ops.decode + bufs = aie_buffers.decode + tok_emb_weight = config.weights['model.embed_tokens.weight'] x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) # (batch, seq_len, emb_dim) attn_mask = torch.triu( torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 ) - if seq_len > 1: - aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - else: - aie_buffers.x_decode.view_as_torch().unsqueeze(0)[0, 0, :] = x + bufs.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x # Step 3: Apply transformer blocks for layer_idx in range(config.n_layers): @@ -415,38 +424,34 @@ def llama_forward_pass( # Step 4: Final normalization - if seq_len > 1: - aie_buffers.x_prefill.to("npu") - aie_ops.rms_norm_prefill(aie_buffers.x_prefill, aie_buffers.W_final_norm, aie_buffers.x_prefill) - else: - aie_buffers.x_decode.to("npu") - aie_ops.rms_norm_decode(aie_buffers.x_decode, aie_buffers.W_final_norm, aie_buffers.x_decode) + bufs.x.to("npu") + ops.rms_norm(bufs.x, aie_buffers.W_final_norm, bufs.x) # Step 5: Output projection (check for tied embeddings) if seq_len > 1: # Since vocab size is a very large dimension unsupported by the AIE GEMM, we have to execute the GEMM in multiple partitions and reassemble the output. # Reference: - # aie_buffers.x_prefill.to("cpu") - # x = aie_buffers.x_prefill.view_as_torch().unsqueeze(0)[:, :seq_len, :] + # bufs.x.to("cpu") + # x = bufs.x.view_as_torch().unsqueeze(0)[:, :seq_len, :] # logits_ref = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) # assert (logits - logits_ref).max() < 0.5 - aie_buffers.x_prefill.to("npu") - aie_buffers.logits_prefill.to("npu") + bufs.x.to("npu") + bufs.logits.to("npu") for i in range(config.vocab_partitions): - aie_ops.out_head_prefill(aie_buffers.x_prefill, aie_buffers.W_out_head_parts[i], aie_buffers.logits_prefill_parts[i]) - aie_buffers.logits_prefill.to("cpu") - logits_padded_partitioned = aie_buffers.logits_prefill.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) + ops.out_head(bufs.x, aie_buffers.W_out_head_parts[i], bufs.logits_parts[i]) + bufs.logits.to("cpu") + logits_padded_partitioned = bufs.logits.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) else: # Step 5: Output projection # Reference: - # x = aie_buffers.x_decode.view_as_torch().unsqueeze(0) + # x = bufs.x.view_as_torch().unsqueeze(0) # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) - aie_buffers.logits_decode.to("npu") - aie_ops.out_head_decode(aie_buffers.W_out_head, aie_buffers.x_decode.view((config.emb_dim,)), aie_buffers.logits_decode) - aie_buffers.logits_decode.to("cpu") - logits = aie_buffers.logits_decode.view_as_torch().view(1, 1, config.vocab_size) + bufs.logits.to("npu") + ops.out_head(aie_buffers.W_out_head, bufs.x.view((config.emb_dim,)), bufs.logits) + bufs.logits.to("cpu") + logits = bufs.logits.view_as_torch().view(1, 1, config.vocab_size) return logits, state From e815c50881e850f2d2663a0e8a6bbe20b414419f Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 21:45:01 -0700 Subject: [PATCH 18/99] offload second residual --- applications/llama_3.2_1b/llama_npu.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 39715a17..263f29aa 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -373,8 +373,10 @@ def transformer_block_forward( ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) # Step 6: Residual - x = x + ffn_output - bufs.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + bufs.ffn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = ffn_output + bufs.ffn_output.to("npu") + ffn_output_view = bufs.ffn_output.view(np.prod(bufs.ffn_output.shape)) + ops.residual_add(x_view, ffn_output_view, x_view) return attn_keys, attn_values From ebe8f32f074c3b075bf2f533d8671ba8b6c0d41f Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 15 Jan 2026 22:29:40 -0700 Subject: [PATCH 19/99] SwiGLU offloaded --- applications/llama_3.2_1b/llama_npu.py | 227 ++++++++++++++++++++----- operators/elementwise_mul/design.py | 15 +- operators/elementwise_mul/op.py | 172 +++++-------------- operators/elementwise_mul/test.py | 3 +- operators/silu/design.py | 22 +-- operators/silu/op.py | 158 ++++------------- 6 files changed, 270 insertions(+), 327 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 263f29aa..9955834e 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -7,6 +7,7 @@ import numpy as np import ml_dtypes import llama_inference_harness as harness +import logging repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) @@ -22,6 +23,10 @@ AIEGEMV, AIEElementwiseAdd ) +from operators.elementwise_mul.op import AIEElementwiseMul +from operators.silu.op import AIESiLU + +logging.basicConfig(level=logging.DEBUG) # AIE Operator Configuration @@ -30,17 +35,10 @@ aie_ops = None class AIEPrefillOperations: - def __init__(self, rms_norm, residual_add, out_head, out_head_compilable): - self.rms_norm = rms_norm - self.residual_add = residual_add - self.out_head = out_head - self.out_head_compilable = out_head_compilable + pass class AIEDecodeOperations: - def __init__(self, rms_norm, residual_add, out_head): - self.rms_norm = rms_norm - self.residual_add = residual_add - self.out_head = out_head + pass class AIELlamaOperators: @@ -48,8 +46,11 @@ def __init__(self, config, prompt_len): self.context = AIEContext() self.context.build_dir.mkdir(parents=True, exist_ok=True) + self.prefill = AIEPrefillOperations() + self.decode = AIEDecodeOperations() + # RMS Norm - rms_norm_prefill = AIERMSNorm( + self.prefill.rms_norm = AIERMSNorm( size=prompt_len * config.emb_dim, eps=1e-5, num_aie_columns=8, @@ -57,7 +58,7 @@ def __init__(self, config, prompt_len): tile_size=config.emb_dim, context=self.context ).compile().get_callable() - rms_norm_decode = AIERMSNorm( + self.decode.rms_norm = AIERMSNorm( size=config.emb_dim, eps=1e-5, num_aie_columns=1, @@ -67,11 +68,11 @@ def __init__(self, config, prompt_len): ).compile().get_callable() # Residual additions - residual_add_prefill = AIEElementwiseAdd( + self.prefill.residual_add = AIEElementwiseAdd( size=prompt_len * config.emb_dim, tile_size=config.emb_dim ).compile().get_callable() - residual_add_decode = AIEElementwiseAdd( + self.decode.residual_add = AIEElementwiseAdd( size=config.emb_dim, tile_size=config.emb_dim // 8 ).compile().get_callable() @@ -80,7 +81,7 @@ def __init__(self, config, prompt_len): min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N config.vocab_partitions = 4 - out_head_prefill_compilable = AIEGEMM( + self.prefill.out_head_compilable = AIEGEMM( M=prompt_len, K=config.emb_dim, N=config.padded_vocab_size // config.vocab_partitions, @@ -93,8 +94,8 @@ def __init__(self, config, prompt_len): separate_c_tiles=True, context=self.context ).compile() - out_head_prefill = out_head_prefill_compilable.get_callable() - out_head_decode = AIEGEMV( + self.prefill.out_head = self.prefill.out_head_compilable.get_callable() + self.decode.out_head = AIEGEMV( M=config.vocab_size, K=config.emb_dim, num_aie_columns=8, @@ -104,9 +105,81 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - # Group operations - self.prefill = AIEPrefillOperations(rms_norm_prefill, residual_add_prefill, out_head_prefill, out_head_prefill_compilable) - self.decode = AIEDecodeOperations(rms_norm_decode, residual_add_decode, out_head_decode) + # SwiGLU FFN operators + # Prefill: M=prompt_len, K=emb_dim, N=hidden_dim + self.prefill.ffn_up_gate = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.hidden_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights + use_static_weight=True, + context=self.context + ).compile().get_callable() + + self.prefill.ffn_down = AIEGEMM( + M=prompt_len, + K=config.hidden_dim, + N=config.emb_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights + use_static_weight=True, + context=self.context + ).compile().get_callable() + + self.prefill.ffn_silu = AIESiLU( + size=prompt_len * config.hidden_dim, + tile_size=config.hidden_dim, + num_aie_columns=8, + context=self.context + ).compile().get_callable() + + self.prefill.ffn_mul = AIEElementwiseMul( + size=prompt_len * config.hidden_dim, + tile_size=config.hidden_dim, + num_aie_columns=8, + context=self.context + ).compile().get_callable() + + # Decode: GEMV for M=1 + self.decode.ffn_up_gate = AIEGEMV( + M=config.hidden_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.hidden_dim // 8, + context=self.context + ).compile().get_callable() + + self.decode.ffn_down = AIEGEMV( + M=config.emb_dim, + K=config.hidden_dim, + num_aie_columns=8, + tile_size_input=1, + tile_size_output=config.emb_dim // 8, + context=self.context + ).compile().get_callable() + + self.decode.ffn_silu = AIESiLU( + size=config.hidden_dim, + tile_size=config.hidden_dim // 8, + num_aie_columns=1, + context=self.context + ).compile().get_callable() + + self.decode.ffn_mul = AIEElementwiseMul( + size=config.hidden_dim, + tile_size=config.hidden_dim // 8, + num_aie_columns=8, + context=self.context + ).compile().get_callable() + # Allocate buffers shared with NPU @@ -115,28 +188,43 @@ def __init__(self, config, prompt_len): aie_buffers = None class AIEPrefillBuffers: - def __init__(self, prompt_len, emb_dim): + def __init__(self, prompt_len, emb_dim, hidden_dim): self.x = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) self.x_norm = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) self.attn_output = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) self.ffn_output = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + # SwiGLU intermediate buffers + self.ffn_gate = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) + self.ffn_up = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) + self.ffn_hidden = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) class AIEDecodeBuffers: - def __init__(self, emb_dim): + def __init__(self, emb_dim, hidden_dim): self.x = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) self.x_norm = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) self.attn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) self.ffn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) + # SwiGLU intermediate buffers + self.ffn_gate = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) + self.ffn_up = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) + self.ffn_hidden = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline - self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim) - self.decode = AIEDecodeBuffers(config.emb_dim) + self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim, config.hidden_dim) + self.decode = AIEDecodeBuffers(config.emb_dim, config.hidden_dim) # Transformer block layer-wise RMS norm self.W_norm1 = [] self.W_norm2 = [] + # SwiGLU FFN weights + self.W_ffn_gate_prefill = [] + self.W_ffn_up_prefill = [] + self.W_ffn_down_prefill = [] + self.W_ffn_gate_decode = [] + self.W_ffn_up_decode = [] + self.W_ffn_down_decode = [] for layer_idx in range(config.n_layers): self.W_norm1.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.input_layernorm.weight']).to("npu") @@ -144,11 +232,29 @@ def __init__(self, config, prompt_len): self.W_norm2.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight']).to("npu") ) + self.W_ffn_gate_decode.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight']).to("npu") + ) + self.W_ffn_up_decode.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight']).to("npu") + ) + self.W_ffn_down_decode.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight']).to("npu") + ) + self.W_ffn_gate_prefill.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'].T).to("npu") + ) + self.W_ffn_up_prefill.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'].T).to("npu") + ) + self.W_ffn_down_prefill.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'].T).to("npu") + ) # Final RMS norm weights self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']).to("npu") # Final linear layer - self.W_out_head = AIEBuffer.from_np(torch_to_numpy(config.weights['model.embed_tokens.weight'])).to("npu") # unpadded/unpartitioned, used by GEMV + self.W_out_head = AIEBuffer.from_torch(config.weights['model.embed_tokens.weight']).to("npu") # unpadded/unpartitioned, used by GEMV W_out_head_parts = aie_ops.prefill.out_head_compilable.partition_B( torch_to_numpy(config.weights['model.embed_tokens.weight']), config.vocab_partitions @@ -247,7 +353,8 @@ def grouped_query_attention_forward( # Step 3: Transpose for attention computation # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. - # Transpose so that heads are consecutive for attention computation: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) + # Transpose so that heads are consecutive for attention computation: + # (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) @@ -291,25 +398,62 @@ def grouped_query_attention_forward( return output, keys_cache, values_cache -def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): - # Step 1: Parallel projections: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) - gate = torch.nn.functional.linear(x, fc1_weight) # gate projection - up = torch.nn.functional.linear(x, fc2_weight) # up projection +def swiglu_ffn_forward(seq_len, layer_idx): + # Select prefill or decode operations and buffers + if seq_len > 1: + ops = aie_ops.prefill + bufs = aie_buffers.prefill + W_ffn_gate = aie_buffers.W_ffn_gate_prefill[layer_idx] + W_ffn_up = aie_buffers.W_ffn_up_prefill[layer_idx] + W_ffn_down = aie_buffers.W_ffn_down_prefill[layer_idx] + else: + ops = aie_ops.decode + bufs = aie_buffers.decode + W_ffn_gate = aie_buffers.W_ffn_gate_decode[layer_idx] + W_ffn_up = aie_buffers.W_ffn_up_decode[layer_idx] + W_ffn_down = aie_buffers.W_ffn_down_decode[layer_idx] - # Step 2: Apply SiLU activation - gate_activated = torch.nn.functional.silu(gate) # (batch, seq_len, swiglu_hidden_dim) + # Step 1: Gate projection: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) + bufs.x_norm.to("npu") + bufs.ffn_gate.to("npu") + if seq_len > 1: + ops.ffn_up_gate(bufs.x_norm, W_ffn_gate, bufs.ffn_gate) + else: + x_norm_view = bufs.x_norm.view(np.prod(bufs.x_norm.shape)) + ffn_gate_view = bufs.ffn_gate.view(np.prod(bufs.ffn_gate.shape)) + ops.ffn_up_gate(W_ffn_gate, x_norm_view, ffn_gate_view) + + # Step 2: Up projection: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) + bufs.ffn_up.to("npu") + if seq_len > 1: + ops.ffn_up_gate(bufs.x_norm, W_ffn_up, bufs.ffn_up) + else: + x_norm_view = bufs.x_norm.view(np.prod(bufs.x_norm.shape)) + ffn_up_view = bufs.ffn_up.view(np.prod(bufs.ffn_up.shape)) + ops.ffn_up_gate(W_ffn_up, x_norm_view, ffn_up_view) - # Step 3: Element-wise multiplication (apply the 'gating') - hidden = gate_activated * up # (batch, seq_len, swiglu_hidden_dim) + # Step 3: Apply SiLU activation to gate + ffn_gate_view = bufs.ffn_gate.view(np.prod(bufs.ffn_gate.shape)) + ops.ffn_silu(ffn_gate_view, ffn_gate_view) - # Step 4: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) - output = torch.nn.functional.linear(hidden, fc3_weight) + # Step 4: Element-wise multiplication (apply the 'gating') + bufs.ffn_hidden.to("npu") + ffn_up_view = bufs.ffn_up.view(np.prod(bufs.ffn_up.shape)) + ffn_hidden_view = bufs.ffn_hidden.view(np.prod(bufs.ffn_hidden.shape)) + ops.ffn_mul(ffn_gate_view, ffn_up_view, ffn_hidden_view) - return output + # Step 5: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) + bufs.ffn_output.to("npu") + if seq_len > 1: + ops.ffn_down(bufs.ffn_hidden, W_ffn_down, bufs.ffn_output) + else: + ffn_output_view = bufs.ffn_output.view(np.prod(bufs.ffn_output.shape)) + ops.ffn_down(W_ffn_down, ffn_hidden_view, ffn_output_view) def transformer_block_forward( seq_len, + layer_idx, attn_keys_cache, attn_values_cache, num_heads, @@ -320,9 +464,6 @@ def transformer_block_forward( W_attn_value, W_attn_out, W_norm2, - W_ffn_fc1, - W_ffn_fc2, - W_ffn_fc3, rope_angles, attn_mask ): @@ -370,11 +511,9 @@ def transformer_block_forward( x_norm = bufs.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] # Step 5: fully-connected feed-forward network - ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) + swiglu_ffn_forward(seq_len, layer_idx) # Step 6: Residual - bufs.ffn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = ffn_output - bufs.ffn_output.to("npu") ffn_output_view = bufs.ffn_output.view(np.prod(bufs.ffn_output.shape)) ops.residual_add(x_view, ffn_output_view, x_view) @@ -407,6 +546,7 @@ def llama_forward_pass( for layer_idx in range(config.n_layers): state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( seq_len, + layer_idx, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx], config.n_heads, @@ -416,9 +556,6 @@ def llama_forward_pass( W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], W_attn_out=config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], - W_ffn_fc1=config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], - W_ffn_fc2=config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], - W_ffn_fc3=config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], W_norm2=aie_buffers.W_norm2[layer_idx], rope_angles=config.angles, attn_mask=attn_mask, diff --git a/operators/elementwise_mul/design.py b/operators/elementwise_mul/design.py index 88ae1e31..f2a2a266 100644 --- a/operators/elementwise_mul/design.py +++ b/operators/elementwise_mul/design.py @@ -12,9 +12,10 @@ from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ +from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_mul(dev, num_elements, num_columns, num_channels, tile_size, trace_size): +def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, archive_name): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -36,7 +37,7 @@ def my_eltwise_mul(dev, num_elements, num_columns, num_channels, tile_size, trac # AIE Core Function declaration eltwise_mul_bf16_vector = Kernel( - "eltwise_mul_bf16_vector", "mul.o", [tile_ty, tile_ty, tile_ty, np.int32] + "eltwise_mul_bf16_vector", archive_name, [tile_ty, tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile @@ -146,11 +147,6 @@ def str_to_device(device: str): p.add_argument( "-co", "--columns", required=True, dest="cols", help="Number of columns" ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) # Tile size (elements per tile) - defaults to 1024 for backward compatibility p.add_argument( "-ts", @@ -183,9 +179,6 @@ def str_to_device(device: str): elif isinstance(dev, NPU2) and columns > 8: raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") tile_size = int(opts.tile_size) if length % (tile_size * columns) != 0: print( @@ -198,7 +191,7 @@ def str_to_device(device: str): raise ValueError trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - module = my_eltwise_mul(dev, length, columns, channels, tile_size, trace_size) + module = my_eltwise_mul(dev, length, columns, tile_size, trace_size, "mul.o") output_file_path = Path(opts.output_file_path) diff --git a/operators/elementwise_mul/op.py b/operators/elementwise_mul/op.py index 954391b8..82cecd95 100644 --- a/operators/elementwise_mul/op.py +++ b/operators/elementwise_mul/op.py @@ -1,164 +1,72 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEElementwiseMul(AIEOperatorBase): +class AIEElementwiseMul(SingleMLIRSourceOperator): """AIE-accelerated element-wise multiplication""" def __init__( - self, size, num_aie_columns, num_channels, tile_size, trace_size=0, context=None + self, + size, + tile_size, + num_aie_columns=8, + context=None, ): - max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert size % (num_aie_columns * tile_size) == 0, "size must be multiple of num_aie_columns * tile_size" + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns - self.num_channels = num_channels - self.trace_size = trace_size - - total_shimdma_channels = self.num_aie_columns * self.num_channels + # Enforce ShimDMA limits for elementwise_mul (uses 2 inputs per core) + # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels + total_shimdma_channels = self.num_aie_columns * 2 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" + SingleMLIRSourceOperator.__init__(self, context=context) - self.xclbin_artifact = None - self.insts_artifact = None - - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"mul_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" - def get_artifacts(self, prefix="eltwise_mul_"): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"{prefix}{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_eltwise_mul", callback_args=[ self.context.device_manager.device_type, self.size, self.num_aie_columns, - self.num_channels, self.tile_size, - self.trace_size, + 0, + self.get_kernel_archive_name(), ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"mul.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "generic" / "mul.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - return xclbin_artifact, insts_artifact - - def set_up_artifacts(self): - xclbin_artifact, insts_artifact = self.get_artifacts() - - mlir_artifact = xclbin_artifact.depends[0] - mlir_artifact.callback_args[0] = self.context.device_manager.device_type - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - self.add_buffer("input1", self.size) - self.add_buffer("input2", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "eltwise_mul", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("eltwise_mul", "input1", "input2", "output") - - def forward(self, x, y): - """Forward pass for element-wise multiplication""" - applicable = ( - len(x.shape) >= 1 - and len(y.shape) >= 1 - and x.shape[-1] <= self.size - and y.shape[-1] <= self.size - and x.numel() <= self.size - and y.numel() <= self.size - and x.numel() == y.numel() - and x.shape == y.shape - ) - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - y_flat = y.reshape(batch, -1) - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - y_flat = torch.nn.functional.pad(y_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat, y_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y): - """Execute element-wise multiplication operation on AIE hardware""" - # x, y are [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - y_flat = y.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size or len(y_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)}, y={len(y_flat)} doesn't match configured size {self.size}" - ) - - self.write_buffer("input1", x_flat) - self.write_buffer("input2", y_flat) - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"mul.o", + depends=[ + SourceArtifact.new( + self.context.base_dir / "aie_kernels" / "generic" / "mul.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + # Runtime setup + return [ + AIERuntimeArgSpec("in", (self.size,)), # input1 + AIERuntimeArgSpec("in", (self.size,)), # input2 + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/elementwise_mul/test.py b/operators/elementwise_mul/test.py index 1a3f762d..743fb8d7 100755 --- a/operators/elementwise_mul/test.py +++ b/operators/elementwise_mul/test.py @@ -62,9 +62,8 @@ def test_elementwise_mul( operator = AIEElementwiseMul( size=input_length, - num_aie_columns=num_aie_columns, - num_channels=num_channels, tile_size=tile_size, + num_aie_columns=num_aie_columns, context=aie_context, ) diff --git a/operators/silu/design.py b/operators/silu/design.py index 5968943b..ce85517e 100644 --- a/operators/silu/design.py +++ b/operators/silu/design.py @@ -12,15 +12,17 @@ from aie.iron.device import Tile, NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ +from aie.helpers.util import np_ndarray_type_get_shape -def my_silu(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_silu(dev, size, num_columns, tile_size, trace_size, archive_name): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] transfer_type = np.ndarray[(size,), np.dtype[xfr_dtype]] - # Calculate number of iterations per core + # Calculate number of iterations per core (using 1 channel per column) + num_channels = 1 total_cores = num_columns * num_channels per_core_elements = size // total_cores N_div_n = per_core_elements // line_size @@ -43,7 +45,7 @@ def my_silu(dev, size, num_columns, num_channels, tile_size, trace_size): # External, binary kernel definition silu_fcn = Kernel( "silu_bf16", - "silu.o", + archive_name, [line_type, line_type, np.int32], ) @@ -152,11 +154,6 @@ def str_to_device(device: str): p.add_argument( "-co", "--columns", required=True, dest="cols", help="Number of columns" ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) # Tile size (elements per tile) - defaults to 1024 for backward compatibility p.add_argument( "-ts", @@ -189,11 +186,10 @@ def str_to_device(device: str): elif isinstance(dev, NPU2) and columns > 8: raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") tile_size = int(opts.tile_size) - if ((length % tile_size) % columns % channels) != 0: + # Using 1 channel per column for SiLU + num_channels = 1 + if ((length % tile_size) % columns % num_channels) != 0: print( "transfer size (" + str(length) @@ -204,7 +200,7 @@ def str_to_device(device: str): raise ValueError trace_size = opts.trace_size - module = my_silu(dev, length, columns, channels, tile_size, trace_size) + module = my_silu(dev, length, columns, tile_size, trace_size, "silu.o") output_file_path = Path(opts.output_file_path) diff --git a/operators/silu/op.py b/operators/silu/op.py index c8a89d05..42ee4fcc 100644 --- a/operators/silu/op.py +++ b/operators/silu/op.py @@ -1,155 +1,65 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIESiLU(AIEOperatorBase): +class AIESiLU(SingleMLIRSourceOperator): """AIE-accelerated SiLU activation function""" - def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): - max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + def __init__(self, size, tile_size, num_aie_columns=8, context=None): + assert size % (num_aie_columns * tile_size) == 0, "size must be multiple of num_aie_columns * tile_size" + self.size = size self.tile_size = tile_size - - self.num_columns = num_aie_columns - self.num_channels = num_channels + self.num_aie_columns = num_aie_columns # Enforce ShimDMA limits for SiLU (uses 1 input per core) - # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels - total_shimdma_channels = self.num_columns * self.num_channels + # Maximum safe configuration: 8 columns × 1 channel = 8 ShimDMA channels + total_shimdma_channels = self.num_aie_columns * 1 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" + SingleMLIRSourceOperator.__init__(self, context=context) - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None - - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"silu_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" - def get_artifacts(self, prefix="silu_"): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"{prefix}{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_silu", callback_args=[ self.context.device_manager.device_type, self.size, - self.num_columns, - self.num_channels, + self.num_aie_columns, self.tile_size, 0, + self.get_kernel_archive_name(), ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"silu.o", - depends=[ - SourceArtifact.new( - self.context.base_dir / "aie_kernels" / "aie2p" / "silu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - return xclbin_artifact, insts_artifact - - def set_up_artifacts(self): - # If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed. - # Compilation artifacts - xclbin_artifact, insts_artifact = self.get_artifacts() - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed. + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"silu.o", + depends=[ + SourceArtifact.new( + self.context.base_dir / "aie_kernels" / "aie2p" / "silu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): # Runtime setup - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "silu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("silu", "input", "output") - - def forward(self, x): - """Forward pass for SiLU activation""" - applicable = ( - len(x.shape) >= 1 and x.shape[-1] <= self.size and x.numel() <= self.size - ) - if not applicable: - raise AIEOperatorConstraintError("AIESiLU: incompatible tensor shape(s)") - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - - pad_len = self.size - x_flat.shape[1] - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - out = self._execute_aie_operation(x_flat) - - # Remove padding if added - numel = np.prod(original_shape) - if pad_len > 0: - out = out.reshape(-1)[..., :numel] - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x, y=None): - """Execute SiLU operation on AIE hardware""" - # x is [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - - # Verify size matches expected - if len(x_flat) != self.size: - raise AIEOperatorConstraintError( - f"Input size x={len(x_flat)} doesn't match configured size {self.size}" - ) - - self.write_buffer("input", x_flat) - test_pattern = np.zeros(len(x_flat), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) - - return result + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] From 7515199e6f8990f5f8752c5776ec5fe73b319fe0 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 16 Jan 2026 09:41:51 -0700 Subject: [PATCH 20/99] offload RoPE --- applications/llama_3.2_1b/llama_npu.py | 136 ++++++++++++++++++------- operators/rope/design.py | 5 +- operators/rope/op.py | 126 +++++++---------------- 3 files changed, 142 insertions(+), 125 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 9955834e..09a2b2dc 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -25,6 +25,7 @@ ) from operators.elementwise_mul.op import AIEElementwiseMul from operators.silu.op import AIESiLU +from operators.rope.op import AIERope logging.basicConfig(level=logging.DEBUG) @@ -180,6 +181,38 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() + # RoPE operators + # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) + # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) + # angle_rows=1 because all rows use the same angle row (angles are per position) + self.prefill.rope_queries = AIERope( + rows=prompt_len * config.n_heads, + cols=config.head_dim, + angle_rows=prompt_len, + context=self.context + ).compile().get_callable() + + self.prefill.rope_keys = AIERope( + rows=prompt_len * config.n_kv_groups, + cols=config.head_dim, + angle_rows=prompt_len, + context=self.context + ).compile().get_callable() + + self.decode.rope_queries = AIERope( + rows=1 * config.n_heads, + cols=config.head_dim, + angle_rows=1, + context=self.context + ).compile().get_callable() + + self.decode.rope_keys = AIERope( + rows=1 * config.n_kv_groups, + cols=config.head_dim, + angle_rows=1, + context=self.context + ).compile().get_callable() + # Allocate buffers shared with NPU @@ -188,7 +221,7 @@ def __init__(self, config, prompt_len): aie_buffers = None class AIEPrefillBuffers: - def __init__(self, prompt_len, emb_dim, hidden_dim): + def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.x = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) self.x_norm = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) self.attn_output = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) @@ -197,9 +230,16 @@ def __init__(self, prompt_len, emb_dim, hidden_dim): self.ffn_gate = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_up = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_hidden = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) + # RoPE buffers + self.rope_queries_in = AIEBuffer(shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_queries_out = AIEBuffer(shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_keys_in = AIEBuffer(shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_keys_out = AIEBuffer(shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_angles_queries = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_angles_keys = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) class AIEDecodeBuffers: - def __init__(self, emb_dim, hidden_dim): + def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.x = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) self.x_norm = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) self.attn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) @@ -208,12 +248,19 @@ def __init__(self, emb_dim, hidden_dim): self.ffn_gate = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_up = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_hidden = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) + # RoPE buffers + self.rope_queries_in = AIEBuffer(shape=(1 * n_heads, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_queries_out = AIEBuffer(shape=(1 * n_heads, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_keys_in = AIEBuffer(shape=(1 * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_keys_out = AIEBuffer(shape=(1 * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_angles_queries = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) + self.rope_angles_keys = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline - self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim, config.hidden_dim) - self.decode = AIEDecodeBuffers(config.emb_dim, config.hidden_dim) + self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) + self.decode = AIEDecodeBuffers(config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) # Transformer block layer-wise RMS norm self.W_norm1 = [] @@ -278,41 +325,58 @@ def __init__(self, config, prompt_len): # Operators # ########################################################################## -def rope_forward(x, angles): - """Rotary positional embedding using precomputed angles""" - # x: (batch, seq_len, num_heads, head_dim) after view and before transpose - # angles: (context_length, head_dim) - _, seq_len, _, head_dim = x.shape - angles_slice = angles[:seq_len] # (seq_len, head_dim) +def rope_forward(x, angles, seq_len, num_preceding_tokens, is_query): + """Rotary positional embedding using NPU""" + # x: (batch, seq_len, num_heads_or_groups, head_dim) + # angles: (context_length, head_dim) - full angle table + batch, seq_len_actual, num_heads_or_groups, head_dim = x.shape - # Split into even and odd dimensions - x1 = x[..., : head_dim // 2] # (batch, seq_len, num_heads, head_dim//2) - x2 = x[..., head_dim // 2 :] # (batch, seq_len, num_heads, head_dim//2) + # Select prefill or decode buffers + if seq_len > 1: + ops = aie_ops.prefill + bufs = aie_buffers.prefill + else: + ops = aie_ops.decode + bufs = aie_buffers.decode - # Get cos and sin from angles - cos = angles_slice[:, ::2] # (seq_len, head_dim//2) - sin = angles_slice[:, 1::2] # (seq_len, head_dim//2) + # Select appropriate buffers and operator based on query/key + if is_query: + rope_op = ops.rope_queries + buf_in = bufs.rope_queries_in + buf_out = bufs.rope_queries_out + buf_angles = bufs.rope_angles_queries + else: + rope_op = ops.rope_keys + buf_in = bufs.rope_keys_in + buf_out = bufs.rope_keys_out + buf_angles = bufs.rope_angles_keys - # Reshape for broadcasting: (1, seq_len, 1, head_dim//2) - # (The same cosine and sine values are used across batch and heads.) - cos = cos.unsqueeze(0).unsqueeze(2) - sin = sin.unsqueeze(0).unsqueeze(2) + # Reshape x to (seq_len * num_heads_or_groups, head_dim) for NPU + x_reshaped = x.view(batch * seq_len_actual * num_heads_or_groups, head_dim) - # Rotate: [x1*cos - x2*sin, x1*sin + x2*cos] - rotated = torch.empty_like(x) - rotated[..., : head_dim // 2] = x1 * cos - x2 * sin - rotated[..., head_dim // 2 :] = x1 * sin + x2 * cos + # Get the relevant angles slice and repeat for each head/group + angles_slice = angles[num_preceding_tokens : num_preceding_tokens + seq_len_actual] # (seq_len, head_dim) + # Repeat angles for each head/group: (seq_len, head_dim) -> (seq_len * num_heads_or_groups, head_dim) + angles_repeated = angles_slice.repeat_interleave(num_heads_or_groups, dim=0) - return rotated - - -def rms_norm_forward(x, weight, eps=1e-5): - """Root Mean Square Layer Normalization""" - # x: (batch, seq_len, dim) - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + eps) - return weight * x - + # Copy to NPU buffers + buf_in.view_as_torch()[:seq_len_actual * num_heads_or_groups, :] = x_reshaped[:seq_len_actual * num_heads_or_groups] + buf_angles.view_as_torch()[:seq_len_actual, :] = angles_slice + + buf_in.to("npu") + buf_angles.to("npu") + buf_out.to("npu") + + # Execute RoPE on NPU + rope_op(buf_in, buf_angles, buf_out) + + buf_out.to("cpu") + + # Read result and reshape back + result = buf_out.view_as_torch()[:seq_len_actual * num_heads_or_groups, :].clone() + result = result.view(batch, seq_len_actual, num_heads_or_groups, head_dim) + + return result def grouped_query_attention_forward( x, @@ -348,8 +412,8 @@ def grouped_query_attention_forward( values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) # Step 2: Apply RoPE - queries = rope_forward(queries, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) - keys = rope_forward(keys, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) + queries = rope_forward(queries, angles, seq_len, num_preceding_tokens, is_query=True) + keys = rope_forward(keys, angles, seq_len, num_preceding_tokens, is_query=False) # Step 3: Transpose for attention computation # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. diff --git a/operators/rope/design.py b/operators/rope/design.py index 780e52fa..0e85e1ef 100644 --- a/operators/rope/design.py +++ b/operators/rope/design.py @@ -38,11 +38,14 @@ def rope( num_aie_columns=1, trace_size=0, method_type=None, + archive_name=None, ): dtype = bfloat16 if angle_rows is None: angle_rows = rows + if archive_name is None: + archive_name = "rope" + (f"_{method_type}" if method_type is not None else "") + ".o" assert cols % (16 * 2) == 0 and cols >= ( 16 * 2 @@ -75,7 +78,7 @@ def rope( # AIE Core Function declaration rope_kernel = Kernel( "rope", - "rope" + (f"_{method_type}" if method_type is not None else "") + ".o", + archive_name, [tensor_tile_ty, angle_tile_ty, tensor_tile_ty, np.int32], ) diff --git a/operators/rope/op.py b/operators/rope/op.py index 7bd0f091..b8d66487 100644 --- a/operators/rope/op.py +++ b/operators/rope/op.py @@ -1,39 +1,36 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIERope(AIEOperatorBase): +class AIERope(SingleMLIRSourceOperator): def __init__( self, rows: int, cols: int, angle_rows=None, - num_aie_columns=None, + num_aie_columns=1, method_type=0, context=None, ): if angle_rows is None: angle_rows = rows - if num_aie_columns is None: - num_aie_columns = 1 - + + assert cols % (16 * 2) == 0 and cols >= (16 * 2), "cols must be multiple of 32 and >= 32" + assert rows % num_aie_columns == 0, "rows must be divisible by num_aie_columns" + assert angle_rows <= rows and rows % angle_rows == 0, "angle_rows must divide rows" + assert angle_rows >= num_aie_columns and angle_rows % num_aie_columns == 0, "angle_rows must be divisible by num_aie_columns" + self.rows = rows self.cols = cols self.angle_rows = angle_rows @@ -41,19 +38,15 @@ def __init__( self.method_type = method_type assert method_type in {0, 1} - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"rope_{self.num_aie_columns}col_{self.rows}rows_{self.cols}cols_{self.angle_rows}arows_{self.method_type}m" - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"rope_{self.num_aie_columns}c_{self.rows}rows_{self.cols}cols_{self.angle_rows}arows_{self.method_type}m" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="rope", callback_args=[ @@ -64,71 +57,28 @@ def set_up_artifacts(self): self.num_aie_columns, 0, self.method_type, + self.get_kernel_archive_name(), ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"rope_{self.method_type}.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "rope.cc" - ) - ], - extra_flags=[ - "-DTWO_HALVES" if 0 == self.method_type else "-DINTERLEAVED" - ], - ), - ], - ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Runtime setup - self.add_buffer("in", self.rows * self.cols) - self.add_buffer("angles", self.angle_rows * self.cols) - self.add_buffer("output", self.rows * self.cols) - self.add_kernel( - "rope", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("rope", "in", "angles", "output") - - def forward(self, tensor, angles): - applicable = ( - tensor.shape[-2] == self.rows - and tensor.shape[-1] == self.cols - and tensor.shape[-1] % 16 == 0 - and angles.shape[-2] == self.angle_rows - and angles.shape[-1] == self.cols - ) - if not applicable: - raise AIEOperatorConstraintError("AIERope: incompatible tensor shape(s)") - - # Write data to buffers - self.write_buffer("in", tensor) - self.write_buffer("angles", angles) - - # Execute kernel - self.run_runlist() - - # Read output - result = self.read_buffer_as_torch("output", shape=tensor.shape, dtype=bfloat16) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"rope_{self.method_type}.o", + depends=[ + SourceArtifact.new( + self.context.base_dir / "aie_kernels" / "generic" / "rope.cc" + ) + ], + extra_flags=[ + "-DTWO_HALVES" if 0 == self.method_type else "-DINTERLEAVED" + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.rows, self.cols,)), # input tensor + AIERuntimeArgSpec("in", (self.angle_rows, self.cols,)), # angles + AIERuntimeArgSpec("out", (self.rows, self.cols,)), # output + ] From 7ef6b035645ab0b1a18ceea19864ecf050045501 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 16 Jan 2026 10:14:49 -0700 Subject: [PATCH 21/99] offload attention query projection linear layer --- applications/llama_3.2_1b/llama_npu.py | 101 +++++++++++++++++++++---- 1 file changed, 88 insertions(+), 13 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 09a2b2dc..c1786be8 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -213,6 +213,30 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() + # Attention projection operators + # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) + self.prefill.attn_query = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_heads * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + use_static_weight=True, + context=self.context + ).compile().get_callable() + + self.decode.attn_query = AIEGEMV( + M=config.n_heads * config.head_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.head_dim // 2, + context=self.context + ).compile().get_callable() + # Allocate buffers shared with NPU @@ -237,6 +261,8 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d self.rope_keys_out = AIEBuffer(shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) self.rope_angles_queries = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) self.rope_angles_keys = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) + # Attention projection buffers + self.attn_queries = AIEBuffer(shape=(prompt_len, n_heads * head_dim), dtype=ml_dtypes.bfloat16) class AIEDecodeBuffers: def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): @@ -255,6 +281,8 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.rope_keys_out = AIEBuffer(shape=(1 * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) self.rope_angles_queries = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) self.rope_angles_keys = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) + # Attention projection buffers + self.attn_queries = AIEBuffer(shape=(1, n_heads * head_dim), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: def __init__(self, config, prompt_len): @@ -265,6 +293,9 @@ def __init__(self, config, prompt_len): # Transformer block layer-wise RMS norm self.W_norm1 = [] self.W_norm2 = [] + # Attention projection weights + self.W_attn_query_prefill = [] + self.W_attn_query_decode = [] # SwiGLU FFN weights self.W_ffn_gate_prefill = [] self.W_ffn_up_prefill = [] @@ -279,6 +310,12 @@ def __init__(self, config, prompt_len): self.W_norm2.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight']).to("npu") ) + self.W_attn_query_decode.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight']).to("npu") + ) + self.W_attn_query_prefill.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'].T).to("npu") + ) self.W_ffn_gate_decode.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight']).to("npu") ) @@ -325,11 +362,11 @@ def __init__(self, config, prompt_len): # Operators # ########################################################################## -def rope_forward(x, angles, seq_len, num_preceding_tokens, is_query): +def rope_forward(x, angles, num_preceding_tokens, is_query): """Rotary positional embedding using NPU""" # x: (batch, seq_len, num_heads_or_groups, head_dim) # angles: (context_length, head_dim) - full angle table - batch, seq_len_actual, num_heads_or_groups, head_dim = x.shape + batch, seq_len, num_heads_or_groups, head_dim = x.shape # Select prefill or decode buffers if seq_len > 1: @@ -352,16 +389,16 @@ def rope_forward(x, angles, seq_len, num_preceding_tokens, is_query): buf_angles = bufs.rope_angles_keys # Reshape x to (seq_len * num_heads_or_groups, head_dim) for NPU - x_reshaped = x.view(batch * seq_len_actual * num_heads_or_groups, head_dim) + x_reshaped = x.view(batch * seq_len * num_heads_or_groups, head_dim) # Get the relevant angles slice and repeat for each head/group - angles_slice = angles[num_preceding_tokens : num_preceding_tokens + seq_len_actual] # (seq_len, head_dim) + angles_slice = angles[num_preceding_tokens : num_preceding_tokens + seq_len] # (seq_len, head_dim) # Repeat angles for each head/group: (seq_len, head_dim) -> (seq_len * num_heads_or_groups, head_dim) angles_repeated = angles_slice.repeat_interleave(num_heads_or_groups, dim=0) # Copy to NPU buffers - buf_in.view_as_torch()[:seq_len_actual * num_heads_or_groups, :] = x_reshaped[:seq_len_actual * num_heads_or_groups] - buf_angles.view_as_torch()[:seq_len_actual, :] = angles_slice + buf_in.view_as_torch()[:seq_len * num_heads_or_groups, :] = x_reshaped[:seq_len * num_heads_or_groups] + buf_angles.view_as_torch()[:seq_len, :] = angles_slice buf_in.to("npu") buf_angles.to("npu") @@ -373,8 +410,8 @@ def rope_forward(x, angles, seq_len, num_preceding_tokens, is_query): buf_out.to("cpu") # Read result and reshape back - result = buf_out.view_as_torch()[:seq_len_actual * num_heads_or_groups, :].clone() - result = result.view(batch, seq_len_actual, num_heads_or_groups, head_dim) + result = buf_out.view_as_torch()[:seq_len * num_heads_or_groups, :].clone() + result = result.view(batch, seq_len, num_heads_or_groups, head_dim) return result @@ -384,6 +421,7 @@ def grouped_query_attention_forward( values_cache, W_query, W_key, W_value, W_out, angles, + layer_idx, mask=None, num_heads=32, num_kv_groups=8, @@ -397,6 +435,16 @@ def grouped_query_attention_forward( assert keys_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) assert values_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) + # Select prefill or decode operations and buffers + if seq_len > 1: + ops = aie_ops.prefill + bufs = aie_buffers.prefill + W_attn_query = aie_buffers.W_attn_query_prefill[layer_idx] + else: + ops = aie_ops.decode + bufs = aie_buffers.decode + W_attn_query = aie_buffers.W_attn_query_decode[layer_idx] + # Step 1: Linear projections # This multiplication produces queries, keys and values for all tokens in the sequence. # The weight matrix is such that multiple queries, keys and values are generated for each token. @@ -404,16 +452,42 @@ def grouped_query_attention_forward( # In particular, each token gets `num_heads` queries and `num_kv_groups` keys/values (keys/values shared for multiple queries). # Due to the structure of the matmul, all queries, keys and values are contiguous for each token. # Note that during the decode phase, seq_len=1, and we are only calculating the projections for the most recent token -- the keys and values of previous tokens will be concatenated in step 4. - queries = torch.nn.functional.linear(x, W_query) # (batch, seq_len, num_heads * head_dim) + + # Query projection using NPU - write directly to RoPE input buffer to avoid CPU round-trip + bufs.x_norm.to("npu") + bufs.rope_queries_in.to("npu") + if seq_len > 1: + # Project and write to rope buffer with appropriate view + rope_queries_in_view = bufs.rope_queries_in.view((bufs.rope_queries_in.shape[0] // num_heads, num_heads * head_dim)) + ops.attn_query(bufs.x_norm, W_attn_query, rope_queries_in_view) + else: + x_norm_view = bufs.x_norm.view(np.prod(bufs.x_norm.shape)) + rope_queries_in_view = bufs.rope_queries_in.view(np.prod(bufs.rope_queries_in.shape)) + ops.attn_query(W_attn_query, x_norm_view, rope_queries_in_view) + keys = torch.nn.functional.linear(x, W_key) # (batch, seq_len, num_kv_groups * head_dim) values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) - queries = queries.view(batch, seq_len, num_heads, head_dim) # (batch, seq_len, num_heads, head_dim) keys = keys.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) - # Step 2: Apply RoPE - queries = rope_forward(queries, angles, seq_len, num_preceding_tokens, is_query=True) - keys = rope_forward(keys, angles, seq_len, num_preceding_tokens, is_query=False) + # Step 2: Apply RoPE to queries (already in rope_queries_in buffer on NPU) + # Get the relevant angles slice + num_preceding_tokens = keys_cache.shape[2] + angles_slice = angles[num_preceding_tokens : num_preceding_tokens + seq_len] # (seq_len, head_dim) + bufs.rope_angles_queries.view_as_torch()[:seq_len, :] = angles_slice + bufs.rope_angles_queries.to("npu") + bufs.rope_queries_out.to("npu") + + # Execute RoPE on NPU (data already there from query projection) + ops.rope_queries(bufs.rope_queries_in, bufs.rope_angles_queries, bufs.rope_queries_out) + + # Read queries result from NPU + bufs.rope_queries_out.to("cpu") + queries = bufs.rope_queries_out.view_as_torch()[:seq_len * num_heads, :].clone() + queries = queries.view(batch, seq_len, num_heads, head_dim) + + # Apply RoPE to keys + keys = rope_forward(keys, angles, num_preceding_tokens, is_query=False) # Step 3: Transpose for attention computation # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. @@ -553,6 +627,7 @@ def transformer_block_forward( attn_values_cache, W_attn_query, W_attn_key, W_attn_value, W_attn_out, rope_angles, + layer_idx, attn_mask, num_heads, num_kv_groups, From d43eeea0cbbab7c82871dca0b0ab2c132735e641 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 16 Jan 2026 10:15:06 -0700 Subject: [PATCH 22/99] offload attention key projection linear layer -- slight decrease in output quality --- applications/llama_3.2_1b/llama_npu.py | 61 ++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index c1786be8..6c260911 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -237,6 +237,29 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() + # Key projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) + self.prefill.attn_key = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + use_static_weight=True, + context=self.context + ).compile().get_callable() + + self.decode.attn_key = AIEGEMV( + M=config.n_kv_groups * config.head_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.head_dim // 2, + context=self.context + ).compile().get_callable() + # Allocate buffers shared with NPU @@ -263,6 +286,7 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d self.rope_angles_keys = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) # Attention projection buffers self.attn_queries = AIEBuffer(shape=(prompt_len, n_heads * head_dim), dtype=ml_dtypes.bfloat16) + self.attn_keys = AIEBuffer(shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) class AIEDecodeBuffers: def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): @@ -283,6 +307,7 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.rope_angles_keys = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) # Attention projection buffers self.attn_queries = AIEBuffer(shape=(1, n_heads * head_dim), dtype=ml_dtypes.bfloat16) + self.attn_keys = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: def __init__(self, config, prompt_len): @@ -296,6 +321,8 @@ def __init__(self, config, prompt_len): # Attention projection weights self.W_attn_query_prefill = [] self.W_attn_query_decode = [] + self.W_attn_key_prefill = [] + self.W_attn_key_decode = [] # SwiGLU FFN weights self.W_ffn_gate_prefill = [] self.W_ffn_up_prefill = [] @@ -316,6 +343,12 @@ def __init__(self, config, prompt_len): self.W_attn_query_prefill.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'].T).to("npu") ) + self.W_attn_key_decode.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight']).to("npu") + ) + self.W_attn_key_prefill.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'].T).to("npu") + ) self.W_ffn_gate_decode.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight']).to("npu") ) @@ -440,10 +473,12 @@ def grouped_query_attention_forward( ops = aie_ops.prefill bufs = aie_buffers.prefill W_attn_query = aie_buffers.W_attn_query_prefill[layer_idx] + W_attn_key = aie_buffers.W_attn_key_prefill[layer_idx] else: ops = aie_ops.decode bufs = aie_buffers.decode W_attn_query = aie_buffers.W_attn_query_decode[layer_idx] + W_attn_key = aie_buffers.W_attn_key_decode[layer_idx] # Step 1: Linear projections # This multiplication produces queries, keys and values for all tokens in the sequence. @@ -465,9 +500,17 @@ def grouped_query_attention_forward( rope_queries_in_view = bufs.rope_queries_in.view(np.prod(bufs.rope_queries_in.shape)) ops.attn_query(W_attn_query, x_norm_view, rope_queries_in_view) - keys = torch.nn.functional.linear(x, W_key) # (batch, seq_len, num_kv_groups * head_dim) + # Key projection using NPU - write directly to RoPE input buffer to avoid CPU round-trip + bufs.rope_keys_in.to("npu") + if seq_len > 1: + # Project and write to rope buffer with appropriate view + rope_keys_in_view = bufs.rope_keys_in.view((bufs.rope_keys_in.shape[0] // num_kv_groups, num_kv_groups * head_dim)) + ops.attn_key(bufs.x_norm, W_attn_key, rope_keys_in_view) + else: + rope_keys_in_view = bufs.rope_keys_in.view(np.prod(bufs.rope_keys_in.shape)) + ops.attn_key(W_attn_key, x_norm_view, rope_keys_in_view) + values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) - keys = keys.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) # Step 2: Apply RoPE to queries (already in rope_queries_in buffer on NPU) @@ -486,8 +529,18 @@ def grouped_query_attention_forward( queries = bufs.rope_queries_out.view_as_torch()[:seq_len * num_heads, :].clone() queries = queries.view(batch, seq_len, num_heads, head_dim) - # Apply RoPE to keys - keys = rope_forward(keys, angles, num_preceding_tokens, is_query=False) + # Apply RoPE to keys (already in rope_keys_in buffer on NPU) + bufs.rope_angles_keys.view_as_torch()[:seq_len, :] = angles_slice + bufs.rope_angles_keys.to("npu") + bufs.rope_keys_out.to("npu") + + # Execute RoPE on NPU (data already there from key projection) + ops.rope_keys(bufs.rope_keys_in, bufs.rope_angles_keys, bufs.rope_keys_out) + + # Read keys result from NPU + bufs.rope_keys_out.to("cpu") + keys = bufs.rope_keys_out.view_as_torch()[:seq_len * num_kv_groups, :].clone() + keys = keys.view(batch, seq_len, num_kv_groups, head_dim) # Step 3: Transpose for attention computation # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. From de2995ce547cc96d5a666c008b6728d83cb28371 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 16 Jan 2026 14:08:54 -0700 Subject: [PATCH 23/99] add batching to GEMV, fix issue when K +#ifndef VEC_SIZE +#define VEC_SIZE 64 +#endif + void matvec_scalar(uint32_t m, uint32_t k, const bfloat16 *__restrict a, @@ -40,9 +44,8 @@ Matrix-vector multiplication kernel - c: Pointer to the output vector - r: Vector size; data from the matrix and vector will be loaded in and processed in chunks of this size */ -template +template void matvec_vectorized(uint32_t m, - uint32_t k, const bfloat16 *__restrict a, const bfloat16 *__restrict b, bfloat16 *__restrict c) @@ -52,10 +55,9 @@ void matvec_vectorized(uint32_t m, const bfloat16 *b_end = b + k; for (; c < c_end; c++) { aie::accum acc = aie::zeros(); - // The following two pragmas enable pipelining the zero-overhead loop, but they do assume that k is at least - // two. This assumption should hold for any useful use of this function; if k were one, this would be a simple - // scalar multiplication of a vector. - AIE_LOOP_MIN_ITERATION_COUNT(2) + // The following two pragmas enable pipelining the zero-overhead loop, but they do assume that there are at + // least two iterations of the loop, i.e. k >= 2*r. This pragma will break the code if that is not the case! + AIE_LOOP_MIN_ITERATION_COUNT(k / VEC_SIZE) for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) { aie::vector a_vec = aie::load_v(a); aie::vector b_vec = aie::load_v(b_cur); @@ -72,25 +74,23 @@ extern "C" { * `c`. */ void matvec_scalar_bf16_bf16(uint32_t m, - uint32_t k, uint32_t row_offset, const bfloat16 *__restrict a_in, const bfloat16 *__restrict b_in, bfloat16 *__restrict c_out) { c_out += row_offset; - matvec_scalar(m, k, a_in, b_in, c_out); + matvec_scalar(m, DIM_K, a_in, b_in, c_out); } void matvec_vectorized_bf16_bf16(uint32_t m, - uint32_t k, uint32_t row_offset, const bfloat16 *__restrict a_in, const bfloat16 *__restrict b_in, bfloat16 *__restrict c_out) { c_out += row_offset; - matvec_vectorized<64>(m, k, a_in, b_in, c_out); + matvec_vectorized(m, a_in, b_in, c_out); } } // extern "C" \ No newline at end of file diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 6c260911..0b603280 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -8,6 +8,7 @@ import ml_dtypes import llama_inference_harness as harness import logging +import time repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) @@ -260,6 +261,56 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() + # Value projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) + self.prefill.attn_value = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + use_static_weight=True, + context=self.context + ).compile().get_callable() + + self.decode.attn_value = AIEGEMV( + M=config.n_kv_groups * config.head_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.head_dim // 2, + context=self.context + ).compile().get_callable() + + # Attention score computation: Q @ K^T per head + # For prefill: (seq_len, head_dim) @ (head_dim, seq_len) = (seq_len, seq_len) per head + self.prefill.attn_scores = AIEGEMM( + M=prompt_len, + K=config.head_dim, + N=prompt_len, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + use_static_weight=False, + context=self.context + ).compile().get_callable() + + # For decode: per head, (1, head_dim) @ (head_dim, max_context_len) + # Use GEMV: (max_context_len, head_dim) @ (head_dim,) = (max_context_len,) + self.decode.attn_scores = AIEGEMV( + M=prompt_len, # max possible context length + K=config.head_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=prompt_len // 8, + num_batches=config.n_heads, + context=self.context + ).compile().get_callable() + # Allocate buffers shared with NPU @@ -287,6 +338,26 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d # Attention projection buffers self.attn_queries = AIEBuffer(shape=(prompt_len, n_heads * head_dim), dtype=ml_dtypes.bfloat16) self.attn_keys = AIEBuffer(shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) + self.attn_values = AIEBuffer(shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) + # Attention score computation buffers (per-head) - parent buffer with subbuffers + self.attn_scores_queries_per_head = [ + AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) + for h in range(n_heads) + ] + self.attn_scores_keys_per_head = [ + AIEBuffer(shape=(head_dim, prompt_len), dtype=ml_dtypes.bfloat16) + for h in range(n_heads) + ] + # Parent buffer for all heads' scores: (n_heads * prompt_len, prompt_len) + self.attn_scores = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) + self.attn_scores_per_head = [ + self.attn_scores.subbuffer( + length=prompt_len * prompt_len, + offset=h * prompt_len * prompt_len, + shape=(prompt_len, prompt_len) + ) + for h in range(n_heads) + ] class AIEDecodeBuffers: def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): @@ -308,6 +379,12 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): # Attention projection buffers self.attn_queries = AIEBuffer(shape=(1, n_heads * head_dim), dtype=ml_dtypes.bfloat16) self.attn_keys = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) + self.attn_values = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) + # Attention score computation buffers (batched) + # Batched GEMV expects: (num_batches, M, K) @ (num_batches, K, 1) = (num_batches, M, 1) + self.attn_scores_keys = AIEBuffer(shape=(n_heads, emb_dim, head_dim), dtype=ml_dtypes.bfloat16) # Max context length + self.attn_scores_queries = AIEBuffer(shape=(n_heads, head_dim, 1), dtype=ml_dtypes.bfloat16) + self.attn_scores = AIEBuffer(shape=(n_heads, emb_dim, 1), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: def __init__(self, config, prompt_len): @@ -323,6 +400,8 @@ def __init__(self, config, prompt_len): self.W_attn_query_decode = [] self.W_attn_key_prefill = [] self.W_attn_key_decode = [] + self.W_attn_value_prefill = [] + self.W_attn_value_decode = [] # SwiGLU FFN weights self.W_ffn_gate_prefill = [] self.W_ffn_up_prefill = [] @@ -349,6 +428,12 @@ def __init__(self, config, prompt_len): self.W_attn_key_prefill.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'].T).to("npu") ) + self.W_attn_value_decode.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight']).to("npu") + ) + self.W_attn_value_prefill.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'].T).to("npu") + ) self.W_ffn_gate_decode.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight']).to("npu") ) @@ -474,11 +559,13 @@ def grouped_query_attention_forward( bufs = aie_buffers.prefill W_attn_query = aie_buffers.W_attn_query_prefill[layer_idx] W_attn_key = aie_buffers.W_attn_key_prefill[layer_idx] + W_attn_value = aie_buffers.W_attn_value_prefill[layer_idx] else: ops = aie_ops.decode bufs = aie_buffers.decode W_attn_query = aie_buffers.W_attn_query_decode[layer_idx] W_attn_key = aie_buffers.W_attn_key_decode[layer_idx] + W_attn_value = aie_buffers.W_attn_value_decode[layer_idx] # Step 1: Linear projections # This multiplication produces queries, keys and values for all tokens in the sequence. @@ -496,9 +583,12 @@ def grouped_query_attention_forward( rope_queries_in_view = bufs.rope_queries_in.view((bufs.rope_queries_in.shape[0] // num_heads, num_heads * head_dim)) ops.attn_query(bufs.x_norm, W_attn_query, rope_queries_in_view) else: - x_norm_view = bufs.x_norm.view(np.prod(bufs.x_norm.shape)) - rope_queries_in_view = bufs.rope_queries_in.view(np.prod(bufs.rope_queries_in.shape)) - ops.attn_query(W_attn_query, x_norm_view, rope_queries_in_view) + # ropes_queries_in is (num_heads, head_dim) + # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) + W_attn_query_view = W_attn_query.view((1, W_attn_query.shape[0], W_attn_query.shape[1])) + x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) + rope_queries_in_view = bufs.rope_queries_in.view((1, bufs.rope_queries_in.shape[0] * bufs.rope_queries_in.shape[1], 1)) + ops.attn_query(W_attn_query_view, x_norm_view, rope_queries_in_view) # Key projection using NPU - write directly to RoPE input buffer to avoid CPU round-trip bufs.rope_keys_in.to("npu") @@ -507,10 +597,28 @@ def grouped_query_attention_forward( rope_keys_in_view = bufs.rope_keys_in.view((bufs.rope_keys_in.shape[0] // num_kv_groups, num_kv_groups * head_dim)) ops.attn_key(bufs.x_norm, W_attn_key, rope_keys_in_view) else: - rope_keys_in_view = bufs.rope_keys_in.view(np.prod(bufs.rope_keys_in.shape)) - ops.attn_key(W_attn_key, x_norm_view, rope_keys_in_view) - - values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) + # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) + x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) + rope_keys_in_view = bufs.rope_keys_in.view((1, bufs.rope_keys_in.shape[0] * bufs.rope_keys_in.shape[1], 1)) + W_attn_key_view = W_attn_key.view((1, W_attn_key.shape[0], W_attn_key.shape[1])) + ops.attn_key(W_attn_key_view, x_norm_view, rope_keys_in_view) + + # Value projection using NPU + bufs.attn_values.to("npu") + if seq_len > 1: + # Project to values buffer with appropriate view + ops.attn_value(bufs.x_norm, W_attn_value, bufs.attn_values) + else: + # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) + x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) + attn_values_view = bufs.attn_values.view((1, bufs.attn_values.shape[1], 1)) + W_attn_value_view = W_attn_value.view((1, W_attn_value.shape[0], W_attn_value.shape[1])) + ops.attn_value(W_attn_value_view, x_norm_view, attn_values_view) + + # Read values result from NPU + bufs.attn_values.to("cpu") + values = bufs.attn_values.view_as_torch()[:seq_len, :].clone() + values = values.unsqueeze(0) # (batch, seq_len, n_kv_groups * head_dim) values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) # Step 2: Apply RoPE to queries (already in rope_queries_in buffer on NPU) @@ -560,12 +668,69 @@ def grouped_query_attention_forward( group_size = num_heads // num_kv_groups keys = keys.repeat_interleave(group_size, dim=1) values = values.repeat_interleave(group_size, dim=1) + context_len = keys.shape[2] + + # Step 6: Compute attention scores using NPU (per-head) + # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, context_len) + # -> (batch, num_heads, seq_len, context_len) + queries_per_head = queries.squeeze(0) # (num_heads, seq_len, head_dim) + keys_per_head = keys.squeeze(0).transpose(-2, -1) # (num_heads, head_dim, context_len) + + if seq_len > 1: + # Prefill: use GEMM per head + for h in range(num_heads): + # Copy data for this head + bufs.attn_scores_queries_per_head[h].view_as_torch()[:context_len, :] = queries_per_head[h, :, :] + bufs.attn_scores_keys_per_head[h].view_as_torch()[:, :context_len] = keys_per_head[h, :, :context_len] + + # Transfer to NPU + bufs.attn_scores_queries_per_head[h].to("npu") + bufs.attn_scores_keys_per_head[h].to("npu") + bufs.attn_scores_per_head[h].to("npu") + + # Execute GEMM for this head + ops.attn_scores( + bufs.attn_scores_queries_per_head[h], + bufs.attn_scores_keys_per_head[h], + bufs.attn_scores_per_head[h] + ) + + # Read back all results at once from parent buffer + bufs.attn_scores.to("cpu") + # Buffer is (n_heads * max_seq_len, max_seq_len), view as (n_heads, max_seq_len, max_seq_len) then slice + max_seq_len = bufs.attn_scores.shape[0] // num_heads + scores = bufs.attn_scores.view_as_torch().view(num_heads, max_seq_len, max_seq_len).unsqueeze(0)[:, :, :seq_len, :context_len] + else: + # Decode: batched GEMV with all heads at once + keys_transposed = keys_per_head.transpose(-2, -1) # (num_heads, context_len, head_dim) + + # Copy all heads' data to batched buffers + # Keys: (num_heads, context_len, head_dim) + bufs.attn_scores_keys.view_as_torch()[:, :context_len, :] = keys_transposed[:, :context_len, :] + # Queries: (num_heads, head_dim, 1) - reshape from (num_heads, 1, head_dim) + bufs.attn_scores_queries.view_as_torch()[:, :, 0] = queries_per_head[:, 0, :] + + # Transfer to NPU + bufs.attn_scores_keys.to("npu") + bufs.attn_scores_queries.to("npu") + bufs.attn_scores.to("npu") + + # Execute batched GEMV: (num_heads, context_len, head_dim) @ (num_heads, head_dim, 1) = (num_heads, context_len, 1) + t_aie_start = time.perf_counter() + ops.attn_scores(bufs.attn_scores_keys, bufs.attn_scores_queries, bufs.attn_scores) + t_aie = time.perf_counter() - t_aie_start + # Reference: + t_cpu_start = time.perf_counter() + ref = bufs.attn_scores_keys.to("cpu").view_as_torch() @ bufs.attn_scores_queries.to("cpu").view_as_torch() + t_cpu = time.perf_counter() - t_cpu_start + + # Read back result + bufs.attn_scores.to("cpu") + # Result is (num_heads, max_context_len, 1), reshape to (batch, num_heads, 1, context_len) + scores = bufs.attn_scores.view_as_torch()[:, :context_len, 0].unsqueeze(0).unsqueeze(2) - # Step 6: Compute attention scores - # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, seq_len) - # -> (batch, num_heads, seq_len, seq_len) - # Entry at row i, column j, indicates how much token i's query attends to token j's key. - scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim) + # Apply scaling + scores = scores / math.sqrt(head_dim) # Step 7: Apply mask # This ensures causality, so that tokens in the future cannot attend to tokens in the past. @@ -610,18 +775,22 @@ def swiglu_ffn_forward(seq_len, layer_idx): if seq_len > 1: ops.ffn_up_gate(bufs.x_norm, W_ffn_gate, bufs.ffn_gate) else: - x_norm_view = bufs.x_norm.view(np.prod(bufs.x_norm.shape)) - ffn_gate_view = bufs.ffn_gate.view(np.prod(bufs.ffn_gate.shape)) - ops.ffn_up_gate(W_ffn_gate, x_norm_view, ffn_gate_view) + # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) + x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) + ffn_gate_view = bufs.ffn_gate.view((1, bufs.ffn_gate.shape[1], 1)) + W_ffn_gate_view = W_ffn_gate.view((1, W_ffn_gate.shape[0], W_ffn_gate.shape[1])) + ops.ffn_up_gate(W_ffn_gate_view, x_norm_view, ffn_gate_view) # Step 2: Up projection: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) bufs.ffn_up.to("npu") if seq_len > 1: ops.ffn_up_gate(bufs.x_norm, W_ffn_up, bufs.ffn_up) else: - x_norm_view = bufs.x_norm.view(np.prod(bufs.x_norm.shape)) - ffn_up_view = bufs.ffn_up.view(np.prod(bufs.ffn_up.shape)) - ops.ffn_up_gate(W_ffn_up, x_norm_view, ffn_up_view) + # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) + x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) + ffn_up_view = bufs.ffn_up.view((1, bufs.ffn_up.shape[1], 1)) + W_ffn_up_view = W_ffn_up.view((1, W_ffn_up.shape[0], W_ffn_up.shape[1])) + ops.ffn_up_gate(W_ffn_up_view, x_norm_view, ffn_up_view) # Step 3: Apply SiLU activation to gate ffn_gate_view = bufs.ffn_gate.view(np.prod(bufs.ffn_gate.shape)) @@ -638,8 +807,11 @@ def swiglu_ffn_forward(seq_len, layer_idx): if seq_len > 1: ops.ffn_down(bufs.ffn_hidden, W_ffn_down, bufs.ffn_output) else: - ffn_output_view = bufs.ffn_output.view(np.prod(bufs.ffn_output.shape)) - ops.ffn_down(W_ffn_down, ffn_hidden_view, ffn_output_view) + # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) + ffn_hidden_view = bufs.ffn_hidden.view((1, bufs.ffn_hidden.shape[1], 1)) + ffn_output_view = bufs.ffn_output.view((1, bufs.ffn_output.shape[1], 1)) + W_ffn_down_view = W_ffn_down.view((1, W_ffn_down.shape[0], W_ffn_down.shape[1])) + ops.ffn_down(W_ffn_down_view, ffn_hidden_view, ffn_output_view) def transformer_block_forward( @@ -780,7 +952,11 @@ def llama_forward_pass( # x = bufs.x.view_as_torch().unsqueeze(0) # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) bufs.logits.to("npu") - ops.out_head(aie_buffers.W_out_head, bufs.x.view((config.emb_dim,)), bufs.logits) + # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) + x_view = bufs.x.view((1, config.emb_dim, 1)) + logits_view = bufs.logits.view((1, config.vocab_size, 1)) + W_out_head_view = aie_buffers.W_out_head.view((1, aie_buffers.W_out_head.shape[0], aie_buffers.W_out_head.shape[1])) + ops.out_head(W_out_head_view, x_view, logits_view) bufs.logits.to("cpu") logits = bufs.logits.view_as_torch().view(1, 1, config.vocab_size) diff --git a/operators/gemv/design.py b/operators/gemv/design.py index c4354518..840d3c50 100644 --- a/operators/gemv/design.py +++ b/operators/gemv/design.py @@ -30,10 +30,11 @@ - K: number of columns in the matrix == number of rows in the vector - m_input: number of input rows stored on each AIE core == chunk size for data movement of input A - m_output: number of output rows stored on each AIE core == chunk size for data movement of output C + - num_batches: number of iterations of this mat-vec to perform on contiguous matrices and vectors in memory (results concatenated) """ -def my_matvec(dev, cols, M, K, m_input, m_output=None): +def my_matvec(dev, cols, M, K, m_input, m_output=None, num_batches=1, kernel_archive="mv.o"): if m_output is None: m_output = m_input @@ -70,19 +71,20 @@ def my_matvec(dev, cols, M, K, m_input, m_output=None): L1_C_ty = np.ndarray[(m_output,), dtype_out] L3_A_ty = np.ndarray[ ( + num_batches, M, K, ), dtype_in, ] - L3_B_ty = np.ndarray[(K,), dtype_in] - L3_C_ty = np.ndarray[(M,), dtype_out] + L3_B_ty = np.ndarray[(num_batches, K,), dtype_in] + L3_C_ty = np.ndarray[(num_batches, M,), dtype_out] func_type = "vectorized" if vectorized else "scalar" matvec = Kernel( f"matvec_{func_type}_{dtype_in_str}_{dtype_out_str}", - "mv.o", - [np.int32, np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], + kernel_archive, + [np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], ) A_L3L1_fifos = [ @@ -97,7 +99,7 @@ def my_matvec(dev, cols, M, K, m_input, m_output=None): def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): one_idx = index.constant(1) - for _ in range_(0xFFFFFFFF): + for _ in range_(0xFFFFFFFF): # batch dim handled as part of this loop b = B_L3L1_fifo.acquire(1) # The kernel function computes m output rows; each core is responsible for (M/cols) output rows, so we need to call the kernel (M/cols)/m times. for i_idx in range_(M // m_output // cols): @@ -107,7 +109,7 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): j_i32 = index.casts(T.i32(), j_idx) output_row_offset = j_i32 * m_input a = A_L3L1_fifo.acquire(1) - matvec(m_input, K, output_row_offset, a, b, c) + matvec(m_input, output_row_offset, a, b, c) A_L3L1_fifo.release(1) C_L1L3_fifo.release(1) B_L3L1_fifo.release(1) @@ -129,12 +131,15 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): # The input matrix in DDR is MxK-sized (row-major); each core processes (M/cols)xK-sized matrices in chunks of mxK-sized tiles. # The chunking into mxK-sized tiles happens in the ObjectFIFO; the shim puts all data on the stream in sequence. A_taps = [ - TensorAccessPattern( - tensor_dims=(M, K), - offset=col * (M // cols) * K, - sizes=[1, 1, 1, (M // cols) * K], - strides=[0, 0, 0, 1], - ) + [ + TensorAccessPattern( + tensor_dims=L3_A_ty.__args__[0], + offset=col * (M // cols) * K + batch * M * K, + sizes=[1, 1, 1, (M // cols) * K], + strides=[0, 0, 0, 1], + ) + for batch in range(num_batches) + ] for col in range(cols) ] @@ -143,52 +148,32 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): # Collection pattern for the output vector C: each AIE core writes back its contiguous chunk of rows. C_taps = [ - TensorAccessPattern( - tensor_dims=(1, M), - offset=col * (M // cols), - sizes=[1, 1, 1, (M // cols)], - strides=[0, 0, 0, 1], - ) + [ + TensorAccessPattern( + tensor_dims=L3_C_ty.__args__[0], + offset=col * (M // cols) + batch * M, + sizes=[1, 1, 1, (M // cols)], + strides=[0, 0, 0, 1], + ) + for batch in range(num_batches) + ] for col in range(cols) ] rt = Runtime() with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as (A, B, C): rt.start(*workers) - tg = rt.task_group() - for i in range(cols): - rt.fill(A_L3L1_fifos[i].prod(), A, A_taps[i], task_group=tg) - rt.fill(B_L3L1_fifos[i].prod(), B, task_group=tg) - for i in range(cols): - rt.drain(C_L1L3_fifos[i].cons(), C, C_taps[i], task_group=tg, wait=True) - rt.finish_task_group(tg) + tg_b = rt.task_group() + for col in range(cols): + # Simple linear transfer of B, includes all batches in sequence + rt.fill(B_L3L1_fifos[col].prod(), B, task_group=tg_b) + for batch in range(num_batches): + tg_ac = rt.task_group() + for col in range(cols): + rt.fill(A_L3L1_fifos[col].prod(), A, A_taps[col][batch], task_group=tg_ac) + for col in range(cols): + rt.drain(C_L1L3_fifos[col].cons(), C, C_taps[col][batch], task_group=tg_ac, wait=True) + rt.finish_task_group(tg_ac) + rt.finish_task_group(tg_b) return Program(dev_ty, rt).resolve_program(SequentialPlacer()) - - -def main(): - argparser = argparse.ArgumentParser( - prog="AIE Matrix Vector Multiplication MLIR Design", - ) - argparser.add_argument("--dev", type=str, choices=["npu", "npu2"], default="npu") - argparser.add_argument("-M", type=int) - argparser.add_argument("-K", type=int) - argparser.add_argument("-m", type=int) - argparser.add_argument("--cols", type=int) - argparser.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - args = argparser.parse_args() - module = my_matvec(args.dev, args.cols, args.M, args.K, args.m) - - output_file_path = Path(args.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) - - -if __name__ == "__main__": - main() diff --git a/operators/gemv/op.py b/operators/gemv/op.py index fcae168e..9542891e 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -29,8 +29,9 @@ def __init__( num_aie_columns=1, tile_size_input=2, tile_size_output=None, - is_mv=True, + num_batches=1, use_static_weight=False, + kernel_vector_size=64, context=None, ): if tile_size_output is None: @@ -45,6 +46,9 @@ def __init__( self.num_aie_columns = num_aie_columns self.tile_size_input = tile_size_input self.tile_size_output = tile_size_output + self.num_batches = num_batches + self.kernel_vector_size = kernel_vector_size + assert K >= kernel_vector_size and K % kernel_vector_size == 0, "K must be multiple of kernel_vector_size" self.xclbin_artifact = None self.insts_artifact = None @@ -52,7 +56,7 @@ def __init__( SingleMLIRSourceOperator.__init__(self, context=context) def get_operator_name(self): - return f"{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_aie_columns}col" + return f"{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_batches}batch_{self.num_aie_columns}col" def get_mlir_artifact(self): operator_dir = Path(__file__).parent @@ -68,24 +72,35 @@ def get_mlir_artifact(self): self.K, self.tile_size_input, self.tile_size_output, + self.num_batches, ], + callback_kwargs={ + "kernel_archive": self.get_kernel_archive_name(), + } ) + + def get_kernel_archive_name(self): + return f"mv_{self.K}k.a" def get_kernel_artifacts(self): return [ KernelObjectArtifact.new( - f"mv.o", + f"mv_{self.K}k.o", depends=[ SourceArtifact.new( self.context.base_dir / "aie_kernels" / "generic" / "mv.cc" ) ], + extra_flags=[ + f"-DDIM_K={self.K}", + f"-DVEC_SIZE={self.kernel_vector_size}", + ] ), ] def get_arg_spec(self): return [ - AIERuntimeArgSpec("in", (self.M, self.K)), # matrix - AIERuntimeArgSpec("in", (self.K,)), # vector - AIERuntimeArgSpec("out", (self.M,)), # output + AIERuntimeArgSpec("in", (self.num_batches, self.M, self.K)), # matrix + AIERuntimeArgSpec("in", (self.num_batches, self.K, 1)), # vector + AIERuntimeArgSpec("out", (self.num_batches, self.M, 1)), # output ] From 74300bd84dc2b5617feaf80ec276bbfc5014c9f1 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 19 Jan 2026 14:07:11 -0700 Subject: [PATCH 24/99] add strided_copy operator --- operators/strided_copy/design.py | 64 ++++++++++++++++++++++++ operators/strided_copy/op.py | 84 ++++++++++++++++++++++++++++++++ operators/strided_copy/test2.py | 68 ++++++++++++++++++++++++++ 3 files changed, 216 insertions(+) create mode 100644 operators/strided_copy/design.py create mode 100644 operators/strided_copy/op.py create mode 100755 operators/strided_copy/test2.py diff --git a/operators/strided_copy/design.py b/operators/strided_copy/design.py new file mode 100644 index 00000000..f528ea49 --- /dev/null +++ b/operators/strided_copy/design.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +#from aie.extras.context import mlir_mod_ctx +#from aie.ir import StridedLayoutAttr, ShapedType +#from aie.dialects.aie import * +#from aie.dialects.aiex import * +from aie.dialects.aiex import TensorAccessPattern +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer + + +""" +Strided copy design + +This can be useful for data layout manipulation and data copying such as: +input[0, :, 0] -> output[:, 0, 0] +""" +def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, input_offset, output_buffer_size, output_sizes, output_strides, output_offset, transfer_size=None, num_aie_channels=1): + assert input_sizes[0] % num_aie_channels == 0, "Highest dimension of input_sizes must be divisible by num_aie_channels" + assert output_sizes[0] % num_aie_channels == 0, "Highest dimension of output_sizes must be divisible by num_aie_channels" + + if transfer_size is None: + transfer_size = int(np.prod(input_sizes)) + assert np.prod(input_sizes) % transfer_size == 0 + transfer_ty = np.ndarray[(transfer_size,), np.dtype[dtype],] + + inp_ty = np.ndarray[(int(input_buffer_size),), np.dtype[dtype],] + out_ty = np.ndarray[(int(output_buffer_size),), np.dtype[dtype],] + + input_taps = [ + TensorAccessPattern( + tensor_dims=(int(input_buffer_size),), + offset=input_offset + c * (input_sizes[0] // num_aie_channels) * input_strides[0], + sizes=[input_sizes[0] // num_aie_channels, *input_sizes[1:]], + strides=list(input_strides), + ) + for c in range(num_aie_channels) + ] + + output_taps = [ + TensorAccessPattern( + tensor_dims=(int(output_buffer_size),), + offset=output_offset + c * (output_sizes[0] // num_aie_channels) * output_strides[0], + sizes=[output_sizes[0] // num_aie_channels, *output_sizes[1:]], + strides=list(output_strides), + ) + for c in range(num_aie_channels) + ] + + # Use smaller FIFOs for the transfer amount + fifos_in = [ObjectFifo(transfer_ty, name=f"fifo_in_{c}", depth=2) for c in range(num_aie_channels)] + fifos_out = [fifos_in[c].cons().forward(name=f"fifo_out_{c}", depth=2) for c in range(num_aie_channels)] + + rt = Runtime() + with rt.sequence(inp_ty, out_ty) as (inp, out): + tg = rt.task_group() + for c in range(num_aie_channels): + rt.fill(fifos_in[c].prod(), inp, input_taps[c], task_group=tg) + rt.drain(fifos_out[c].cons(), out, output_taps[c], task_group=tg, wait=True) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/operators/strided_copy/op.py b/operators/strided_copy/op.py new file mode 100644 index 00000000..f08b023e --- /dev/null +++ b/operators/strided_copy/op.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 +from pathlib import Path + +from operators.common import ( + SingleMLIRSourceOperator, + AIERuntimeArgSpec, + KernelObjectArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) + + +class AIEStridedCopy(SingleMLIRSourceOperator): + """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" + + def __init__( + self, + input_sizes, + input_strides, + input_offset, + output_sizes, + output_strides, + output_offset, + input_buffer_size, + output_buffer_size, + dtype=bfloat16, + transfer_size=None, + num_aie_channels=1, + context=None, + ): + assert len(input_sizes) == len(input_strides) + assert len(output_sizes) == len(output_strides) + self.input_sizes = input_sizes + self.input_strides = input_strides + self.input_offset = input_offset + self.output_sizes = output_sizes + self.output_strides = output_strides + self.output_offset = output_offset + self.input_buffer_size = input_buffer_size + self.output_buffer_size = output_buffer_size + self.dtype = dtype + self.transfer_size = transfer_size + self.num_aie_channels = num_aie_channels + SingleMLIRSourceOperator.__init__(self, context=context) + + def get_operator_name(self): + return f"strided_copy_{'x'.join(map(str, self.input_sizes))}sz_{'x'.join(map(str, self.input_strides))}st_{self.input_offset}off_to_{'x'.join(map(str, self.output_sizes))}sz_{'x'.join(map(str, self.output_strides))}st_{self.output_offset}off_{self.transfer_size if self.transfer_size is not None else 'auto'}tr_{self.num_aie_channels}ch" + + def get_mlir_artifact(self): + operator_dir = Path(__file__).parent + + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", + import_path=operator_dir / "design.py", + callback_fn="strided_copy", + callback_args=[ + self.context.device_manager.device_type, + self.dtype, + self.input_buffer_size, + self.input_sizes, + self.input_strides, + self.input_offset, + self.output_buffer_size, + self.output_sizes, + self.output_strides, + self.output_offset, + self.transfer_size, + self.num_aie_channels, + ] + ) + + def get_kernel_artifacts(self): + return [] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", self.input_buffer_size), # matrix + AIERuntimeArgSpec("out", self.output_buffer_size), # output + ] diff --git a/operators/strided_copy/test2.py b/operators/strided_copy/test2.py new file mode 100755 index 00000000..8b3ba310 --- /dev/null +++ b/operators/strided_copy/test2.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path +import time +import torch +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from operators.strided_copy.op import AIEStridedCopy +from operators.common import AIEBuffer + +max_prompt_len = 2048 +cached_prompt_len = 9 +prompt_len = 7 +head_dim = 64 +num_heads = 32 + +transpose_concat = AIEStridedCopy( + input_sizes=(num_heads, prompt_len, head_dim,), + input_strides=(head_dim, num_heads * head_dim, 1,), + input_offset=0, + output_sizes=(1, num_heads, prompt_len, head_dim,), + output_strides=(0, max_prompt_len * head_dim, head_dim, 1,), + output_offset=cached_prompt_len * head_dim, + input_buffer_size=prompt_len * num_heads * head_dim, + output_buffer_size=num_heads * max_prompt_len * head_dim, + num_aie_channels=4 +).compile().get_callable() + +value_cache = AIEBuffer((num_heads, max_prompt_len, head_dim)) +value = AIEBuffer((prompt_len, num_heads, head_dim)) + +value_cache.view_as_torch()[:, :cached_prompt_len, :] = torch.randn(num_heads, cached_prompt_len, head_dim) +value.view_as_torch()[:prompt_len, :, :] = torch.randn(prompt_len, num_heads, head_dim) + +t_cpu_start = time.perf_counter() +value_transposed = value.view_as_torch().transpose(0, 1) +out_ref = torch.cat([value_cache.view_as_torch()[:, :cached_prompt_len, :], value_transposed], dim=1) +t_cpu = time.perf_counter() - t_cpu_start + +value_cache.to("npu") +value.to("npu") +t_aie_start = time.perf_counter() +transpose_concat(value, value_cache) +t_aie = time.perf_counter() - t_aie_start +value_cache.to("cpu") + +print(out_ref) +print(t_cpu) +aie_out = value_cache.view_as_torch()[:, :(cached_prompt_len + prompt_len), :] +print(aie_out) +print(t_aie) + +# Check which elements differ +diff = torch.abs(out_ref - aie_out) +max_diff = diff.max() +print(f"Max diff: {max_diff}") +print(f"Number of mismatches (> 1e-2): {(diff > 1e-2).sum()}") + +# Find first mismatch +mismatches = torch.where(diff > 1e-2) +if len(mismatches[0]) > 0: + for i in range(min(10, len(mismatches[0]))): + h, s, d = mismatches[0][i], mismatches[1][i], mismatches[2][i] + print(f"Mismatch at head={h}, seq={s}, dim={d}: ref={out_ref[h,s,d]}, aie={aie_out[h,s,d]}, diff={diff[h,s,d]}") + +assert torch.allclose(out_ref, aie_out, atol=1e-2, rtol=1e-2) + From b1eab7c9b9ac908156f4c2c9f6f868d6eaf64445 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 19 Jan 2026 14:07:44 -0700 Subject: [PATCH 25/99] add patchable callable --- operators/common/aie_base.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index 89740497..c1b0409d 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -145,16 +145,16 @@ def get_kernel_archive_name(self): def get_artifacts(self): operator_name = self.get_operator_name() mlir_artifact = self.get_mlir_artifact() - kernel_deps = self.get_kernel_artifacts() + kernel_deps_inputs = self.get_kernel_artifacts() + kernel_deps = [ + KernelArchiveArtifact.new( + self.get_kernel_archive_name(), + depends=kernel_deps_inputs, + ) + ] if kernel_deps_inputs else [] xclbin_artifact = XclbinArtifact.new( f"{operator_name}.xclbin", - depends=[ - mlir_artifact, - KernelArchiveArtifact.new( - self.get_kernel_archive_name(), - depends=kernel_deps, - ), - ], + depends=[mlir_artifact] + kernel_deps, ) insts_artifact = InstsBinArtifact.new( f"{operator_name}.bin", depends=[mlir_artifact] @@ -167,7 +167,6 @@ def set_up_artifacts(self): self.insts_artifact = insts_artifact self.add_artifacts([xclbin_artifact, insts_artifact]) - def get_callable(self): return SingleXclbinCallable( xclbin_path=self.xclbin_artifact.path, @@ -289,13 +288,26 @@ def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_m def __call__(self, *buffers): assert len(buffers) == len(self.args_spec) assert all( - buffers[i].shape == self.args_spec[i].shape and buffers[i].dtype == self.args_spec[i].dtype + np.prod(buffers[i].shape) >= np.prod(self.args_spec[i].shape) and buffers[i].dtype == self.args_spec[i].dtype for i in range(len(buffers)) ), "Input buffer shapes or dtypes do not match expected argument specification." self.insts_buffer.to("npu") - assert all(buffer.on == "npu" for buffer, spec in zip(buffers, self.args_spec)), "Not all buffers have been synced on the NPU; for some reason even output buffers must be synced!" + for buf in buffers: + buf.to("npu") opcode = 3 bos = [buffer.bo for buffer in buffers] run = self.xrt_kernel(opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos) run.wait() +class PatchableSingleXclbinCallable(SingleXclbinCallable): + def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None): + super().__init__(xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager) + self.baseline_instructions = self.insts_buffer.view_as_np().copy() + + def patch(self, patches): + """Apply patches with masking: dict of {position: (value, mask)}.""" + insts = self.insts_buffer.view_as_np() + insts[:] = self.baseline_instructions + for pos, (val, mask) in patches.items(): + insts[pos] = (insts[pos] & ~mask) | (val & mask) + self.insts_buffer.to("npu") From 119bb7ed167993624eff38b3697831da949b1bf7 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 19 Jan 2026 14:33:14 -0700 Subject: [PATCH 26/99] simplify llama_npu.py, make GEMV operator input shapes simpler --- applications/llama_3.2_1b/llama_npu.py | 736 +++++++++++-------------- operators/gemv/op.py | 7 +- 2 files changed, 328 insertions(+), 415 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 0b603280..da4c74ed 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -83,7 +83,7 @@ def __init__(self, config, prompt_len): min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N config.vocab_partitions = 4 - self.prefill.out_head_compilable = AIEGEMM( + self.prefill.gemv_out_head_compilable = AIEGEMM( M=prompt_len, K=config.emb_dim, N=config.padded_vocab_size // config.vocab_partitions, @@ -96,8 +96,8 @@ def __init__(self, config, prompt_len): separate_c_tiles=True, context=self.context ).compile() - self.prefill.out_head = self.prefill.out_head_compilable.get_callable() - self.decode.out_head = AIEGEMV( + self.prefill.out_head = self.prefill.gemv_out_head_compilable.get_callable() + self.decode.gemv_out_head = AIEGEMV( M=config.vocab_size, K=config.emb_dim, num_aie_columns=8, @@ -142,7 +142,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.prefill.ffn_mul = AIEElementwiseMul( + self.prefill.eltwise_mul_ffn = AIEElementwiseMul( size=prompt_len * config.hidden_dim, tile_size=config.hidden_dim, num_aie_columns=8, @@ -150,7 +150,7 @@ def __init__(self, config, prompt_len): ).compile().get_callable() # Decode: GEMV for M=1 - self.decode.ffn_up_gate = AIEGEMV( + self.decode.gemv_ffn_up_gate = AIEGEMV( M=config.hidden_dim, K=config.emb_dim, num_aie_columns=8, @@ -159,7 +159,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.decode.ffn_down = AIEGEMV( + self.decode.gemv_ffn_down = AIEGEMV( M=config.emb_dim, K=config.hidden_dim, num_aie_columns=8, @@ -175,7 +175,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.decode.ffn_mul = AIEElementwiseMul( + self.decode.eltwise_mul_ffn = AIEElementwiseMul( size=config.hidden_dim, tile_size=config.hidden_dim // 8, num_aie_columns=8, @@ -229,7 +229,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.decode.attn_query = AIEGEMV( + self.decode.gemv_attn_query = AIEGEMV( M=config.n_heads * config.head_dim, K=config.emb_dim, num_aie_columns=8, @@ -252,7 +252,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.decode.attn_key = AIEGEMV( + self.decode.gemv_attn_key = AIEGEMV( M=config.n_kv_groups * config.head_dim, K=config.emb_dim, num_aie_columns=8, @@ -275,7 +275,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.decode.attn_value = AIEGEMV( + self.decode.gemv_attn_value = AIEGEMV( M=config.n_kv_groups * config.head_dim, K=config.emb_dim, num_aie_columns=8, @@ -301,7 +301,7 @@ def __init__(self, config, prompt_len): # For decode: per head, (1, head_dim) @ (head_dim, max_context_len) # Use GEMV: (max_context_len, head_dim) @ (head_dim,) = (max_context_len,) - self.decode.attn_scores = AIEGEMV( + self.decode.gemv_attn_scores = AIEGEMV( M=prompt_len, # max possible context length K=config.head_dim, num_aie_columns=8, @@ -328,25 +328,31 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d self.ffn_gate = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_up = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_hidden = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) - # RoPE buffers - self.rope_queries_in = AIEBuffer(shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_queries_out = AIEBuffer(shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_keys_in = AIEBuffer(shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_keys_out = AIEBuffer(shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_angles_queries = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_angles_keys = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) - # Attention projection buffers - self.attn_queries = AIEBuffer(shape=(prompt_len, n_heads * head_dim), dtype=ml_dtypes.bfloat16) - self.attn_keys = AIEBuffer(shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) - self.attn_values = AIEBuffer(shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) - # Attention score computation buffers (per-head) - parent buffer with subbuffers + # Attention buffers: queries and keys serve as both projection output and RoPE input/output + self.queries = AIEBuffer(shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16) + self.keys = AIEBuffer(shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) + self.values = AIEBuffer(shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) + self.rope_angles = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) + # Attention score computation buffers (per-head) - parent buffers with subbuffers + # Parent buffer for all heads' queries: (n_heads, prompt_len, head_dim) stored contiguously + self.attn_scores_queries_all = AIEBuffer(shape=(n_heads * prompt_len, head_dim), dtype=ml_dtypes.bfloat16) self.attn_scores_queries_per_head = [ - AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) + self.attn_scores_queries_all.subbuffer( + length=prompt_len * head_dim, + offset=h * prompt_len * head_dim, + shape=(prompt_len, head_dim) + ) for h in range(n_heads) ] - self.attn_scores_keys_per_head = [ - AIEBuffer(shape=(head_dim, prompt_len), dtype=ml_dtypes.bfloat16) - for h in range(n_heads) + # Parent buffer for all KV groups' keys: (n_kv_groups, head_dim, prompt_len) stored contiguously + self.attn_scores_keys_all = AIEBuffer(shape=(n_kv_groups * head_dim, prompt_len), dtype=ml_dtypes.bfloat16) + self.attn_scores_keys_per_kv_group = [ + self.attn_scores_keys_all.subbuffer( + length=head_dim * prompt_len, + offset=g * head_dim * prompt_len, + shape=(head_dim, prompt_len) + ) + for g in range(n_kv_groups) ] # Parent buffer for all heads' scores: (n_heads * prompt_len, prompt_len) self.attn_scores = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) @@ -369,19 +375,12 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.ffn_gate = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_up = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) self.ffn_hidden = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) - # RoPE buffers - self.rope_queries_in = AIEBuffer(shape=(1 * n_heads, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_queries_out = AIEBuffer(shape=(1 * n_heads, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_keys_in = AIEBuffer(shape=(1 * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_keys_out = AIEBuffer(shape=(1 * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_angles_queries = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) - self.rope_angles_keys = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) - # Attention projection buffers - self.attn_queries = AIEBuffer(shape=(1, n_heads * head_dim), dtype=ml_dtypes.bfloat16) - self.attn_keys = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) - self.attn_values = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) + # Attention buffers: queries and keys serve as both projection output and RoPE input/output + self.queries = AIEBuffer(shape=(1 * n_heads, head_dim), dtype=ml_dtypes.bfloat16) + self.keys = AIEBuffer(shape=(1 * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) + self.values = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) + self.rope_angles = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) # Attention score computation buffers (batched) - # Batched GEMV expects: (num_batches, M, K) @ (num_batches, K, 1) = (num_batches, M, 1) self.attn_scores_keys = AIEBuffer(shape=(n_heads, emb_dim, head_dim), dtype=ml_dtypes.bfloat16) # Max context length self.attn_scores_queries = AIEBuffer(shape=(n_heads, head_dim, 1), dtype=ml_dtypes.bfloat16) self.attn_scores = AIEBuffer(shape=(n_heads, emb_dim, 1), dtype=ml_dtypes.bfloat16) @@ -457,7 +456,7 @@ def __init__(self, config, prompt_len): self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']).to("npu") # Final linear layer self.W_out_head = AIEBuffer.from_torch(config.weights['model.embed_tokens.weight']).to("npu") # unpadded/unpartitioned, used by GEMV - W_out_head_parts = aie_ops.prefill.out_head_compilable.partition_B( + W_out_head_parts = aie_ops.prefill.gemv_out_head_compilable.partition_B( torch_to_numpy(config.weights['model.embed_tokens.weight']), config.vocab_partitions ) @@ -480,175 +479,33 @@ def __init__(self, config, prompt_len): # Operators # ########################################################################## -def rope_forward(x, angles, num_preceding_tokens, is_query): - """Rotary positional embedding using NPU""" - # x: (batch, seq_len, num_heads_or_groups, head_dim) - # angles: (context_length, head_dim) - full angle table - batch, seq_len, num_heads_or_groups, head_dim = x.shape - - # Select prefill or decode buffers - if seq_len > 1: - ops = aie_ops.prefill - bufs = aie_buffers.prefill - else: - ops = aie_ops.decode - bufs = aie_buffers.decode - - # Select appropriate buffers and operator based on query/key - if is_query: - rope_op = ops.rope_queries - buf_in = bufs.rope_queries_in - buf_out = bufs.rope_queries_out - buf_angles = bufs.rope_angles_queries - else: - rope_op = ops.rope_keys - buf_in = bufs.rope_keys_in - buf_out = bufs.rope_keys_out - buf_angles = bufs.rope_angles_keys - - # Reshape x to (seq_len * num_heads_or_groups, head_dim) for NPU - x_reshaped = x.view(batch * seq_len * num_heads_or_groups, head_dim) - - # Get the relevant angles slice and repeat for each head/group - angles_slice = angles[num_preceding_tokens : num_preceding_tokens + seq_len] # (seq_len, head_dim) - # Repeat angles for each head/group: (seq_len, head_dim) -> (seq_len * num_heads_or_groups, head_dim) - angles_repeated = angles_slice.repeat_interleave(num_heads_or_groups, dim=0) - - # Copy to NPU buffers - buf_in.view_as_torch()[:seq_len * num_heads_or_groups, :] = x_reshaped[:seq_len * num_heads_or_groups] - buf_angles.view_as_torch()[:seq_len, :] = angles_slice - - buf_in.to("npu") - buf_angles.to("npu") - buf_out.to("npu") - - # Execute RoPE on NPU - rope_op(buf_in, buf_angles, buf_out) - - buf_out.to("cpu") - - # Read result and reshape back - result = buf_out.view_as_torch()[:seq_len * num_heads_or_groups, :].clone() - result = result.view(batch, seq_len, num_heads_or_groups, head_dim) - - return result - -def grouped_query_attention_forward( +def grouped_query_attention_forward_prefill( + config, x, keys_cache, values_cache, - W_query, W_key, W_value, W_out, - angles, layer_idx, mask=None, - num_heads=32, - num_kv_groups=8, ): - batch, seq_len, d_in = x.shape - assert W_query.shape[0] >= num_heads and W_query.shape[0] % num_heads == 0 - head_dim = W_query.shape[0] // num_heads - assert W_key.shape[0] == num_kv_groups * head_dim - assert W_value.shape[0] == num_kv_groups * head_dim + batch, seq_len, emb_dim = x.shape num_preceding_tokens = keys_cache.shape[2] - assert keys_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) - assert values_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) - - # Select prefill or decode operations and buffers - if seq_len > 1: - ops = aie_ops.prefill - bufs = aie_buffers.prefill - W_attn_query = aie_buffers.W_attn_query_prefill[layer_idx] - W_attn_key = aie_buffers.W_attn_key_prefill[layer_idx] - W_attn_value = aie_buffers.W_attn_value_prefill[layer_idx] - else: - ops = aie_ops.decode - bufs = aie_buffers.decode - W_attn_query = aie_buffers.W_attn_query_decode[layer_idx] - W_attn_key = aie_buffers.W_attn_key_decode[layer_idx] - W_attn_value = aie_buffers.W_attn_value_decode[layer_idx] # Step 1: Linear projections - # This multiplication produces queries, keys and values for all tokens in the sequence. - # The weight matrix is such that multiple queries, keys and values are generated for each token. - # For each token, each head corresponds to one query. - # In particular, each token gets `num_heads` queries and `num_kv_groups` keys/values (keys/values shared for multiple queries). - # Due to the structure of the matmul, all queries, keys and values are contiguous for each token. - # Note that during the decode phase, seq_len=1, and we are only calculating the projections for the most recent token -- the keys and values of previous tokens will be concatenated in step 4. - - # Query projection using NPU - write directly to RoPE input buffer to avoid CPU round-trip - bufs.x_norm.to("npu") - bufs.rope_queries_in.to("npu") - if seq_len > 1: - # Project and write to rope buffer with appropriate view - rope_queries_in_view = bufs.rope_queries_in.view((bufs.rope_queries_in.shape[0] // num_heads, num_heads * head_dim)) - ops.attn_query(bufs.x_norm, W_attn_query, rope_queries_in_view) - else: - # ropes_queries_in is (num_heads, head_dim) - # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) - W_attn_query_view = W_attn_query.view((1, W_attn_query.shape[0], W_attn_query.shape[1])) - x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) - rope_queries_in_view = bufs.rope_queries_in.view((1, bufs.rope_queries_in.shape[0] * bufs.rope_queries_in.shape[1], 1)) - ops.attn_query(W_attn_query_view, x_norm_view, rope_queries_in_view) - - # Key projection using NPU - write directly to RoPE input buffer to avoid CPU round-trip - bufs.rope_keys_in.to("npu") - if seq_len > 1: - # Project and write to rope buffer with appropriate view - rope_keys_in_view = bufs.rope_keys_in.view((bufs.rope_keys_in.shape[0] // num_kv_groups, num_kv_groups * head_dim)) - ops.attn_key(bufs.x_norm, W_attn_key, rope_keys_in_view) - else: - # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) - x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) - rope_keys_in_view = bufs.rope_keys_in.view((1, bufs.rope_keys_in.shape[0] * bufs.rope_keys_in.shape[1], 1)) - W_attn_key_view = W_attn_key.view((1, W_attn_key.shape[0], W_attn_key.shape[1])) - ops.attn_key(W_attn_key_view, x_norm_view, rope_keys_in_view) - - # Value projection using NPU - bufs.attn_values.to("npu") - if seq_len > 1: - # Project to values buffer with appropriate view - ops.attn_value(bufs.x_norm, W_attn_value, bufs.attn_values) - else: - # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) - x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) - attn_values_view = bufs.attn_values.view((1, bufs.attn_values.shape[1], 1)) - W_attn_value_view = W_attn_value.view((1, W_attn_value.shape[0], W_attn_value.shape[1])) - ops.attn_value(W_attn_value_view, x_norm_view, attn_values_view) - - # Read values result from NPU - bufs.attn_values.to("cpu") - values = bufs.attn_values.view_as_torch()[:seq_len, :].clone() - values = values.unsqueeze(0) # (batch, seq_len, n_kv_groups * head_dim) - values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) - - # Step 2: Apply RoPE to queries (already in rope_queries_in buffer on NPU) - # Get the relevant angles slice - num_preceding_tokens = keys_cache.shape[2] - angles_slice = angles[num_preceding_tokens : num_preceding_tokens + seq_len] # (seq_len, head_dim) - bufs.rope_angles_queries.view_as_torch()[:seq_len, :] = angles_slice - bufs.rope_angles_queries.to("npu") - bufs.rope_queries_out.to("npu") - - # Execute RoPE on NPU (data already there from query projection) - ops.rope_queries(bufs.rope_queries_in, bufs.rope_angles_queries, bufs.rope_queries_out) - - # Read queries result from NPU - bufs.rope_queries_out.to("cpu") - queries = bufs.rope_queries_out.view_as_torch()[:seq_len * num_heads, :].clone() - queries = queries.view(batch, seq_len, num_heads, head_dim) - - # Apply RoPE to keys (already in rope_keys_in buffer on NPU) - bufs.rope_angles_keys.view_as_torch()[:seq_len, :] = angles_slice - bufs.rope_angles_keys.to("npu") - bufs.rope_keys_out.to("npu") - - # Execute RoPE on NPU (data already there from key projection) - ops.rope_keys(bufs.rope_keys_in, bufs.rope_angles_keys, bufs.rope_keys_out) - - # Read keys result from NPU - bufs.rope_keys_out.to("cpu") - keys = bufs.rope_keys_out.view_as_torch()[:seq_len * num_kv_groups, :].clone() - keys = keys.view(batch, seq_len, num_kv_groups, head_dim) + aie_ops.prefill.attn_query(aie_buffers.prefill.x_norm, aie_buffers.W_attn_query_prefill[layer_idx], aie_buffers.prefill.queries) + aie_ops.prefill.attn_key(aie_buffers.prefill.x_norm, aie_buffers.W_attn_key_prefill[layer_idx], aie_buffers.prefill.keys) + aie_ops.prefill.attn_value(aie_buffers.prefill.x_norm, aie_buffers.W_attn_value_prefill[layer_idx], aie_buffers.prefill.values) + + # Step 2: Apply RoPE to queries and keys + aie_ops.prefill.rope_queries(aie_buffers.prefill.queries, aie_buffers.prefill.rope_angles, aie_buffers.prefill.queries) + aie_ops.prefill.rope_keys(aie_buffers.prefill.keys, aie_buffers.prefill.rope_angles, aie_buffers.prefill.keys) + + # Read results from NPU + queries = aie_buffers.prefill.queries.to("cpu").view_as_torch()[:seq_len * config.n_heads, :] + keys = aie_buffers.prefill.keys.to("cpu").view_as_torch()[:seq_len * config.n_kv_groups, :] + values = aie_buffers.prefill.values.to("cpu").view_as_torch()[:seq_len, :] # (seq_len, n_kv_groups * head_dim) + queries = queries.view(batch, seq_len, config.n_heads, config.head_dim) + keys = keys.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) + values = values.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) # (batch, seq_len, num_kv_groups, head_dim) # Step 3: Transpose for attention computation # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. @@ -665,72 +522,45 @@ def grouped_query_attention_forward( values = values_cache # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value - group_size = num_heads // num_kv_groups - keys = keys.repeat_interleave(group_size, dim=1) + group_size = config.n_heads // config.n_kv_groups values = values.repeat_interleave(group_size, dim=1) context_len = keys.shape[2] # Step 6: Compute attention scores using NPU (per-head) # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, context_len) # -> (batch, num_heads, seq_len, context_len) - queries_per_head = queries.squeeze(0) # (num_heads, seq_len, head_dim) - keys_per_head = keys.squeeze(0).transpose(-2, -1) # (num_heads, head_dim, context_len) - if seq_len > 1: - # Prefill: use GEMM per head - for h in range(num_heads): - # Copy data for this head - bufs.attn_scores_queries_per_head[h].view_as_torch()[:context_len, :] = queries_per_head[h, :, :] - bufs.attn_scores_keys_per_head[h].view_as_torch()[:, :context_len] = keys_per_head[h, :, :context_len] - - # Transfer to NPU - bufs.attn_scores_queries_per_head[h].to("npu") - bufs.attn_scores_keys_per_head[h].to("npu") - bufs.attn_scores_per_head[h].to("npu") - - # Execute GEMM for this head - ops.attn_scores( - bufs.attn_scores_queries_per_head[h], - bufs.attn_scores_keys_per_head[h], - bufs.attn_scores_per_head[h] - ) - - # Read back all results at once from parent buffer - bufs.attn_scores.to("cpu") - # Buffer is (n_heads * max_seq_len, max_seq_len), view as (n_heads, max_seq_len, max_seq_len) then slice - max_seq_len = bufs.attn_scores.shape[0] // num_heads - scores = bufs.attn_scores.view_as_torch().view(num_heads, max_seq_len, max_seq_len).unsqueeze(0)[:, :, :seq_len, :context_len] - else: - # Decode: batched GEMV with all heads at once - keys_transposed = keys_per_head.transpose(-2, -1) # (num_heads, context_len, head_dim) - - # Copy all heads' data to batched buffers - # Keys: (num_heads, context_len, head_dim) - bufs.attn_scores_keys.view_as_torch()[:, :context_len, :] = keys_transposed[:, :context_len, :] - # Queries: (num_heads, head_dim, 1) - reshape from (num_heads, 1, head_dim) - bufs.attn_scores_queries.view_as_torch()[:, :, 0] = queries_per_head[:, 0, :] - - # Transfer to NPU - bufs.attn_scores_keys.to("npu") - bufs.attn_scores_queries.to("npu") - bufs.attn_scores.to("npu") - - # Execute batched GEMV: (num_heads, context_len, head_dim) @ (num_heads, head_dim, 1) = (num_heads, context_len, 1) - t_aie_start = time.perf_counter() - ops.attn_scores(bufs.attn_scores_keys, bufs.attn_scores_queries, bufs.attn_scores) - t_aie = time.perf_counter() - t_aie_start - # Reference: - t_cpu_start = time.perf_counter() - ref = bufs.attn_scores_keys.to("cpu").view_as_torch() @ bufs.attn_scores_queries.to("cpu").view_as_torch() - t_cpu = time.perf_counter() - t_cpu_start - - # Read back result - bufs.attn_scores.to("cpu") - # Result is (num_heads, max_context_len, 1), reshape to (batch, num_heads, 1, context_len) - scores = bufs.attn_scores.view_as_torch()[:, :context_len, 0].unsqueeze(0).unsqueeze(2) + queries_buf = aie_buffers.prefill.attn_scores_queries_all.view_as_torch().view( + config.n_heads, -1, config.head_dim + ) + queries_buf[:, :seq_len, :] = queries.squeeze(0)[:, :seq_len, :] # (num_heads, seq_len, head_dim) + keys_buf = aie_buffers.prefill.attn_scores_keys_all.view_as_torch().view( + config.n_kv_groups, config.head_dim, -1 + ) + keys_buf[:, :, :context_len] = keys.squeeze(0).transpose(-2, -1) # (num_kv_groups, head_dim, context_len) + + # Transfer parent buffers to NPU once + aie_buffers.prefill.attn_scores_queries_all.to("npu") + aie_buffers.prefill.attn_scores_keys_all.to("npu") + aie_buffers.prefill.attn_scores.to("npu") + + # Execute GEMM for each head using sub-buffers + for h in range(config.n_heads): + kv_group = h // group_size + aie_ops.prefill.attn_scores( + aie_buffers.prefill.attn_scores_queries_per_head[h], + aie_buffers.prefill.attn_scores_keys_per_kv_group[kv_group], + aie_buffers.prefill.attn_scores_per_head[h] + ) + + # Read back all results at once from parent buffer + aie_buffers.prefill.attn_scores.to("cpu") + # Buffer is (n_heads * max_seq_len, max_seq_len), view as (n_heads, max_seq_len, max_seq_len) then slice + max_seq_len = aie_buffers.prefill.attn_scores.shape[0] // config.n_heads + scores = aie_buffers.prefill.attn_scores.view_as_torch().view(config.n_heads, max_seq_len, max_seq_len).unsqueeze(0)[:, :, :seq_len, :context_len] # Apply scaling - scores = scores / math.sqrt(head_dim) + scores = scores / math.sqrt(config.head_dim) # Step 7: Apply mask # This ensures causality, so that tokens in the future cannot attend to tokens in the past. @@ -749,220 +579,302 @@ def grouped_query_attention_forward( # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) - output = torch.nn.functional.linear(context, W_out) + output = torch.nn.functional.linear(context, config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']) return output, keys_cache, values_cache -def swiglu_ffn_forward(seq_len, layer_idx): - # Select prefill or decode operations and buffers - if seq_len > 1: - ops = aie_ops.prefill - bufs = aie_buffers.prefill - W_ffn_gate = aie_buffers.W_ffn_gate_prefill[layer_idx] - W_ffn_up = aie_buffers.W_ffn_up_prefill[layer_idx] - W_ffn_down = aie_buffers.W_ffn_down_prefill[layer_idx] - else: - ops = aie_ops.decode - bufs = aie_buffers.decode - W_ffn_gate = aie_buffers.W_ffn_gate_decode[layer_idx] - W_ffn_up = aie_buffers.W_ffn_up_decode[layer_idx] - W_ffn_down = aie_buffers.W_ffn_down_decode[layer_idx] - - # Step 1: Gate projection: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) - bufs.x_norm.to("npu") - bufs.ffn_gate.to("npu") - if seq_len > 1: - ops.ffn_up_gate(bufs.x_norm, W_ffn_gate, bufs.ffn_gate) - else: - # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) - x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) - ffn_gate_view = bufs.ffn_gate.view((1, bufs.ffn_gate.shape[1], 1)) - W_ffn_gate_view = W_ffn_gate.view((1, W_ffn_gate.shape[0], W_ffn_gate.shape[1])) - ops.ffn_up_gate(W_ffn_gate_view, x_norm_view, ffn_gate_view) - - # Step 2: Up projection: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) - bufs.ffn_up.to("npu") - if seq_len > 1: - ops.ffn_up_gate(bufs.x_norm, W_ffn_up, bufs.ffn_up) - else: - # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) - x_norm_view = bufs.x_norm.view((1, bufs.x_norm.shape[1], 1)) - ffn_up_view = bufs.ffn_up.view((1, bufs.ffn_up.shape[1], 1)) - W_ffn_up_view = W_ffn_up.view((1, W_ffn_up.shape[0], W_ffn_up.shape[1])) - ops.ffn_up_gate(W_ffn_up_view, x_norm_view, ffn_up_view) - - # Step 3: Apply SiLU activation to gate - ffn_gate_view = bufs.ffn_gate.view(np.prod(bufs.ffn_gate.shape)) - ops.ffn_silu(ffn_gate_view, ffn_gate_view) - - # Step 4: Element-wise multiplication (apply the 'gating') - bufs.ffn_hidden.to("npu") - ffn_up_view = bufs.ffn_up.view(np.prod(bufs.ffn_up.shape)) - ffn_hidden_view = bufs.ffn_hidden.view(np.prod(bufs.ffn_hidden.shape)) - ops.ffn_mul(ffn_gate_view, ffn_up_view, ffn_hidden_view) - - # Step 5: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) - bufs.ffn_output.to("npu") - if seq_len > 1: - ops.ffn_down(bufs.ffn_hidden, W_ffn_down, bufs.ffn_output) - else: - # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) - ffn_hidden_view = bufs.ffn_hidden.view((1, bufs.ffn_hidden.shape[1], 1)) - ffn_output_view = bufs.ffn_output.view((1, bufs.ffn_output.shape[1], 1)) - W_ffn_down_view = W_ffn_down.view((1, W_ffn_down.shape[0], W_ffn_down.shape[1])) - ops.ffn_down(W_ffn_down_view, ffn_hidden_view, ffn_output_view) +def grouped_query_attention_forward_decode( + config, + x, + keys_cache, + values_cache, + layer_idx, + mask=None, +): + batch, seq_len, emb_dim = x.shape + + # Step 1: Linear projections - write directly to queries/keys/values buffers + aie_ops.decode.gemv_attn_query(aie_buffers.W_attn_query_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.queries) + aie_ops.decode.gemv_attn_key(aie_buffers.W_attn_key_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.keys) + aie_ops.decode.gemv_attn_value(aie_buffers.W_attn_value_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.values) + + # Step 2: Apply RoPE - use same buffers for input and output + aie_ops.decode.rope_queries(aie_buffers.decode.queries, aie_buffers.decode.rope_angles, aie_buffers.decode.queries) + aie_ops.decode.rope_keys(aie_buffers.decode.keys, aie_buffers.decode.rope_angles, aie_buffers.decode.keys) + + aie_buffers.decode.values.to("cpu") + values = aie_buffers.decode.values.view_as_torch()[:seq_len, :] + values = values.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) + aie_buffers.decode.queries.to("cpu") + queries = aie_buffers.decode.queries.view_as_torch()[:seq_len * config.n_heads, :] + queries = queries.view(batch, seq_len, config.n_heads, config.head_dim) + aie_buffers.decode.keys.to("cpu") + keys = aie_buffers.decode.keys.view_as_torch()[:seq_len * config.n_kv_groups, :] + keys = keys.view(batch, seq_len, config.n_kv_groups, config.head_dim) + + # Step 3: Transpose + queries = queries.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + + # Step 4: Update cache + keys_cache = torch.cat([keys_cache, keys], dim=2) + values_cache = torch.cat([values_cache, values], dim=2) + keys = keys_cache + values = values_cache + + # Step 5: Repeat keys and values for grouped attention + group_size = config.n_heads // config.n_kv_groups + keys = keys.repeat_interleave(group_size, dim=1) + values = values.repeat_interleave(group_size, dim=1) + context_len = keys.shape[2] + + # Step 6: Compute attention scores + aie_buffers.decode.attn_scores_keys.view_as_torch()[:, :context_len, :] = keys.squeeze(0)[:, :context_len, :] + aie_buffers.decode.attn_scores_queries.view_as_torch()[:, :, 0] = queries.squeeze(0)[:, 0, :] + aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.attn_scores_queries, aie_buffers.decode.attn_scores) + aie_buffers.decode.attn_scores.to("cpu") + scores = aie_buffers.decode.attn_scores.view_as_torch()[:, :context_len, 0].unsqueeze(0).unsqueeze(2) + + # Normalize + scores = scores / math.sqrt(config.head_dim) + + # Step 7: Apply mask + if mask is not None: + scores = scores.masked_fill(mask, float('-inf')) + + # Step 8: Softmax + attention_weights = torch.nn.functional.softmax(scores, dim=-1) + + # Step 9: Compute attention output + context = torch.matmul(attention_weights, values) + + # Step 10: Concatenate heads and project + context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) + output = torch.nn.functional.linear(context, config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']) + + return output, keys_cache, values_cache + +def swiglu_ffn_forward_prefill(layer_idx): + # Step 1: Gate projection + aie_ops.prefill.ffn_up_gate(aie_buffers.prefill.x_norm, aie_buffers.W_ffn_gate_prefill[layer_idx], aie_buffers.prefill.ffn_gate) + + # Step 2: Up projection + aie_ops.prefill.ffn_up_gate(aie_buffers.prefill.x_norm, aie_buffers.W_ffn_up_prefill[layer_idx], aie_buffers.prefill.ffn_up) + + # Step 3: Apply SiLU activation + aie_ops.prefill.ffn_silu(aie_buffers.prefill.ffn_gate, aie_buffers.prefill.ffn_gate) + + # Step 4: Element-wise multiplication + aie_ops.prefill.eltwise_mul_ffn(aie_buffers.prefill.ffn_gate, aie_buffers.prefill.ffn_up, aie_buffers.prefill.ffn_hidden) + + # Step 5: Down projection + aie_ops.prefill.ffn_down(aie_buffers.prefill.ffn_hidden, aie_buffers.W_ffn_down_prefill[layer_idx], aie_buffers.prefill.ffn_output) -def transformer_block_forward( + +def swiglu_ffn_forward_decode(layer_idx): + # Step 1: Gate projection + aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_gate_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_gate) + + # Step 2: Up projection + aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_up_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_up) + + # Step 3: Apply SiLU activation + aie_ops.decode.ffn_silu(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_gate) + + # Step 4: Element-wise multiplication + aie_ops.decode.eltwise_mul_ffn(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_up, aie_buffers.decode.ffn_hidden) + + # Step 5: Down projection + aie_ops.decode.gemv_ffn_down(aie_buffers.W_ffn_down_decode[layer_idx], aie_buffers.decode.ffn_hidden, aie_buffers.decode.ffn_output) + + +def transformer_block_forward_prefill( + config, seq_len, layer_idx, attn_keys_cache, attn_values_cache, - num_heads, - num_kv_groups, - W_norm1, - W_attn_query, - W_attn_key, - W_attn_value, - W_attn_out, - W_norm2, - rope_angles, attn_mask ): - # Select prefill or decode operations and buffers - if seq_len > 1: - ops = aie_ops.prefill - bufs = aie_buffers.prefill - else: - ops = aie_ops.decode - bufs = aie_buffers.decode + # Step 1: RMS normalization + aie_ops.prefill.rms_norm(aie_buffers.prefill.x, aie_buffers.W_norm1[layer_idx], aie_buffers.prefill.x_norm) + aie_buffers.prefill.x_norm.to("cpu") + x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] + + # Step 2: Attention + attn_output, attn_keys, attn_values = grouped_query_attention_forward_prefill( + config, + x_norm, + attn_keys_cache, + attn_values_cache, + layer_idx, + attn_mask, + ) + + # Step 3: Residual + aie_buffers.prefill.attn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output + aie_ops.prefill.residual_add(aie_buffers.prefill.x, aie_buffers.prefill.attn_output, aie_buffers.prefill.x) + x = aie_buffers.prefill.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] + + # Step 4: Post-norm + aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_ops.prefill.rms_norm(aie_buffers.prefill.x, aie_buffers.W_norm2[layer_idx], aie_buffers.prefill.x_norm) + aie_buffers.prefill.x_norm.to("cpu") + x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] + + # Step 5: Feed-forward network + swiglu_ffn_forward_prefill(layer_idx) + + # Step 6: Residual + aie_ops.prefill.residual_add(aie_buffers.prefill.x, aie_buffers.prefill.ffn_output, aie_buffers.prefill.x) + return attn_keys, attn_values + + +def transformer_block_forward_decode( + config, + seq_len, + layer_idx, + attn_keys_cache, + attn_values_cache, + attn_mask +): # Step 1: RMS normalization - bufs.x.to("npu") - bufs.x_norm.to("npu") - ops.rms_norm(bufs.x, W_norm1, bufs.x_norm) - bufs.x_norm.to("cpu") - x_norm = bufs.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] + aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm) + aie_buffers.decode.x_norm.to("cpu") + x_norm = aie_buffers.decode.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] # Step 2: Attention - attn_output, attn_keys, attn_values = grouped_query_attention_forward( + attn_output, attn_keys, attn_values = grouped_query_attention_forward_decode( + config, x_norm, attn_keys_cache, attn_values_cache, - W_attn_query, W_attn_key, W_attn_value, W_attn_out, - rope_angles, layer_idx, attn_mask, - num_heads, - num_kv_groups, ) # Step 3: Residual - bufs.attn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output - bufs.attn_output.to("npu") - x_view = bufs.x.view(np.prod(bufs.x.shape)) - attn_output_view = bufs.attn_output.view(np.prod(bufs.attn_output.shape)) - ops.residual_add(x_view, attn_output_view, x_view) - x = bufs.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] + aie_buffers.decode.attn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output + aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.attn_output, aie_buffers.decode.x) + x = aie_buffers.decode.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] # Step 4: Post-norm - bufs.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - bufs.x.to("npu") - bufs.x_norm.to("npu") - ops.rms_norm(bufs.x, W_norm2, bufs.x_norm) - bufs.x_norm.to("cpu") - x_norm = bufs.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] + aie_buffers.decode.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm2[layer_idx], aie_buffers.decode.x_norm) + aie_buffers.decode.x_norm.to("cpu") + x_norm = aie_buffers.decode.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] - # Step 5: fully-connected feed-forward network - swiglu_ffn_forward(seq_len, layer_idx) + # Step 5: Feed-forward network + swiglu_ffn_forward_decode(layer_idx) # Step 6: Residual - ffn_output_view = bufs.ffn_output.view(np.prod(bufs.ffn_output.shape)) - ops.residual_add(x_view, ffn_output_view, x_view) + aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.ffn_output, aie_buffers.decode.x) return attn_keys, attn_values -def llama_forward_pass( +def llama_forward_pass_prefill( config, state ): batch, seq_len = state.token_ids.shape + + # Step 1: RoPE angles + num_preceding_tokens = state.attn_keys_caches[0].shape[2] + angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] + aie_buffers.prefill.rope_angles.view_as_torch()[:seq_len, :] = angles_slice - # Select prefill or decode operations and buffers - if seq_len > 1: - ops = aie_ops.prefill - bufs = aie_buffers.prefill - else: - ops = aie_ops.decode - bufs = aie_buffers.decode - + # Step 2: Token embedding tok_emb_weight = config.weights['model.embed_tokens.weight'] - x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) # (batch, seq_len, emb_dim) + x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) attn_mask = torch.triu( torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 ) - bufs.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - # Step 3: Apply transformer blocks + # Step 3: Transformer blocks for layer_idx in range(config.n_layers): - state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( + state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward_prefill( + config, seq_len, layer_idx, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx], - config.n_heads, - config.n_kv_groups, - W_norm1=aie_buffers.W_norm1[layer_idx], - W_attn_query=config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], - W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], - W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], - W_attn_out=config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], - W_norm2=aie_buffers.W_norm2[layer_idx], - rope_angles=config.angles, attn_mask=attn_mask, ) + # Step 4: Final normalization + aie_ops.prefill.rms_norm(aie_buffers.prefill.x, aie_buffers.W_final_norm, aie_buffers.prefill.x) + # Step 5: Output projection + for i in range(config.vocab_partitions): + aie_ops.prefill.out_head(aie_buffers.prefill.x, aie_buffers.W_out_head_parts[i], aie_buffers.prefill.logits_parts[i]) + aie_buffers.prefill.logits.to("cpu") + logits_padded_partitioned = aie_buffers.prefill.logits.view_as_torch() + logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) + logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] + + return logits, state + + +def llama_forward_pass_decode( + config, + state +): + batch, seq_len = state.token_ids.shape + + # Step 1: RoPE angles + num_preceding_tokens = state.attn_keys_caches[0].shape[2] + angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] + aie_buffers.decode.rope_angles.view_as_torch()[:] = angles_slice + + # Step 2: Token embedding + tok_emb_weight = config.weights['model.embed_tokens.weight'] + x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), + diagonal=1 + ) + aie_buffers.decode.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + + # Step 3: Transformer blocks + for layer_idx in range(config.n_layers): + state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward_decode( + config, + seq_len, + layer_idx, + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], + attn_mask=attn_mask, + ) + # Step 4: Final normalization - bufs.x.to("npu") - ops.rms_norm(bufs.x, aie_buffers.W_final_norm, bufs.x) + aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_final_norm, aie_buffers.decode.x) - # Step 5: Output projection (check for tied embeddings) - if seq_len > 1: - # Since vocab size is a very large dimension unsupported by the AIE GEMM, we have to execute the GEMM in multiple partitions and reassemble the output. - # Reference: - # bufs.x.to("cpu") - # x = bufs.x.view_as_torch().unsqueeze(0)[:, :seq_len, :] - # logits_ref = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) - # assert (logits - logits_ref).max() < 0.5 - bufs.x.to("npu") - bufs.logits.to("npu") - for i in range(config.vocab_partitions): - ops.out_head(bufs.x, aie_buffers.W_out_head_parts[i], bufs.logits_parts[i]) - bufs.logits.to("cpu") - logits_padded_partitioned = bufs.logits.view_as_torch() # (vocab_partitions, padded_seq_len, padded_vocab_size // vocab_partitions) - logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) # (padded_seq_len, padded_vocab_size) - logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] # (batch, seq_len, vocab_size) - else: - # Step 5: Output projection - # Reference: - # x = bufs.x.view_as_torch().unsqueeze(0) - # logits = torch.nn.functional.linear(config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) - bufs.logits.to("npu") - # GEMV expects: (1, M, K) @ (1, K, 1) = (1, M, 1) - x_view = bufs.x.view((1, config.emb_dim, 1)) - logits_view = bufs.logits.view((1, config.vocab_size, 1)) - W_out_head_view = aie_buffers.W_out_head.view((1, aie_buffers.W_out_head.shape[0], aie_buffers.W_out_head.shape[1])) - ops.out_head(W_out_head_view, x_view, logits_view) - bufs.logits.to("cpu") - logits = bufs.logits.view_as_torch().view(1, 1, config.vocab_size) + # Step 5: Output projection + aie_ops.decode.gemv_out_head(aie_buffers.W_out_head, aie_buffers.decode.x, aie_buffers.decode.logits) + aie_buffers.decode.logits.to("cpu") + logits = aie_buffers.decode.logits.view_as_torch().view(1, 1, config.vocab_size) return logits, state +def llama_forward_pass( + config, + state +): + batch, seq_len = state.token_ids.shape + if seq_len > 1: + return llama_forward_pass_prefill(config, state) + else: + return llama_forward_pass_decode(config, state) + + # Main # ########################################################################## diff --git a/operators/gemv/op.py b/operators/gemv/op.py index 9542891e..95ac3fee 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -99,8 +99,9 @@ def get_kernel_artifacts(self): ] def get_arg_spec(self): + batch_dim = (self.num_batches,) if self.num_batches > 1 else () return [ - AIERuntimeArgSpec("in", (self.num_batches, self.M, self.K)), # matrix - AIERuntimeArgSpec("in", (self.num_batches, self.K, 1)), # vector - AIERuntimeArgSpec("out", (self.num_batches, self.M, 1)), # output + AIERuntimeArgSpec("in", batch_dim + (self.M, self.K)), # matrix + AIERuntimeArgSpec("in", batch_dim + (self.K,)), # vector + AIERuntimeArgSpec("out", batch_dim + (self.M,)), # output ] From 7fee60d8c11db00bae34d0eca8a07893e8a6d696 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 19 Jan 2026 15:42:05 -0700 Subject: [PATCH 27/99] fix strided copy; offload KV cache concat + transpose to NPU --- applications/llama_3.2_1b/llama_npu.py | 108 ++++++++++++++++++------- operators/strided_copy/design.py | 35 ++++++-- 2 files changed, 108 insertions(+), 35 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index da4c74ed..6470c090 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -18,6 +18,7 @@ AIEBuffer ) from operators.common.utils import torch_to_numpy, numpy_to_torch +from operators.common.aie_base import PatchableSingleXclbinCallable from operators import ( AIERMSNorm, AIEGEMM, @@ -27,6 +28,7 @@ from operators.elementwise_mul.op import AIEElementwiseMul from operators.silu.op import AIESiLU from operators.rope.op import AIERope +from operators.strided_copy.op import AIEStridedCopy logging.basicConfig(level=logging.DEBUG) @@ -214,6 +216,32 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() + # Strided copy operators for cache update (transpose and concatenate) + # Keys: transpose from (1, n_kv_groups, head_dim) to (n_kv_groups, 1, head_dim) and write to cache + self.decode.strided_copy_cache_compilable = AIEStridedCopy( + input_sizes=(config.n_kv_groups, 1, config.head_dim), + input_strides=(config.head_dim, config.n_kv_groups * config.head_dim, 1), + input_offset=0, + output_sizes=(1, config.n_kv_groups, 1, config.head_dim), + output_strides=(0, prompt_len * config.head_dim, config.head_dim, 1), + output_offset=0, # Will be patched at runtime based on cached_prompt_len + input_buffer_size=1 * config.n_kv_groups * config.head_dim, + output_buffer_size=config.n_kv_groups * prompt_len * config.head_dim, + num_aie_channels=1, + context=self.context + ).compile() + + # Create patchable callable for runtime offset updates + self.decode.strided_copy_cache = PatchableSingleXclbinCallable( + xclbin_path=self.decode.strided_copy_cache_compilable.xclbin_artifact.path, + kernel_name=self.decode.strided_copy_cache_compilable.xclbin_artifact.kernel_name, + insts_bin_path=self.decode.strided_copy_cache_compilable.insts_artifact.path, + args_spec=self.decode.strided_copy_cache_compilable.get_arg_spec() + ) + + # Store head_dim for patching + self.head_dim = config.head_dim + # Attention projection operators # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) self.prefill.attn_query = AIEGEMM( @@ -382,7 +410,6 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.rope_angles = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) # Attention score computation buffers (batched) self.attn_scores_keys = AIEBuffer(shape=(n_heads, emb_dim, head_dim), dtype=ml_dtypes.bfloat16) # Max context length - self.attn_scores_queries = AIEBuffer(shape=(n_heads, head_dim, 1), dtype=ml_dtypes.bfloat16) self.attn_scores = AIEBuffer(shape=(n_heads, emb_dim, 1), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: @@ -391,6 +418,16 @@ def __init__(self, config, prompt_len): self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) self.decode = AIEDecodeBuffers(config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) + # Per-layer KV cache buffers on NPU (used by strided copy for transpose and concatenate) + self.keys_cache = [ + AIEBuffer(shape=(config.n_kv_groups, config.emb_dim, config.head_dim), dtype=ml_dtypes.bfloat16) + for _ in range(config.n_layers) + ] + self.values_cache = [ + AIEBuffer(shape=(config.n_kv_groups, config.emb_dim, config.head_dim), dtype=ml_dtypes.bfloat16) + for _ in range(config.n_layers) + ] + # Transformer block layer-wise RMS norm self.W_norm1 = [] self.W_norm2 = [] @@ -603,54 +640,53 @@ def grouped_query_attention_forward_decode( aie_ops.decode.rope_queries(aie_buffers.decode.queries, aie_buffers.decode.rope_angles, aie_buffers.decode.queries) aie_ops.decode.rope_keys(aie_buffers.decode.keys, aie_buffers.decode.rope_angles, aie_buffers.decode.keys) - aie_buffers.decode.values.to("cpu") - values = aie_buffers.decode.values.view_as_torch()[:seq_len, :] - values = values.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) + # Read results from NPU for CPU reference computation aie_buffers.decode.queries.to("cpu") queries = aie_buffers.decode.queries.view_as_torch()[:seq_len * config.n_heads, :] - queries = queries.view(batch, seq_len, config.n_heads, config.head_dim) - aie_buffers.decode.keys.to("cpu") - keys = aie_buffers.decode.keys.view_as_torch()[:seq_len * config.n_kv_groups, :] - keys = keys.view(batch, seq_len, config.n_kv_groups, config.head_dim) - - # Step 3: Transpose - queries = queries.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - - # Step 4: Update cache - keys_cache = torch.cat([keys_cache, keys], dim=2) - values_cache = torch.cat([values_cache, values], dim=2) - keys = keys_cache - values = values_cache + # Since seq_len=1, the transpose is just a reinterpretation of the shape; no actual data movement needed + queries = queries.view(batch, config.n_heads, 1, config.head_dim) - # Step 5: Repeat keys and values for grouped attention + # Step 3: Update cache using strided copy on NPU (transpose and concatenate) + # Cache is already on NPU from prefill initialization or previous decode iteration + num_preceding_tokens = keys_cache.shape[2] + context_len = num_preceding_tokens + seq_len + + # Transpose and append new keys/values to this layer's cache on NPU + aie_ops.decode.strided_copy_cache(aie_buffers.decode.keys, aie_buffers.keys_cache[layer_idx]) + aie_ops.decode.strided_copy_cache(aie_buffers.decode.values, aie_buffers.values_cache[layer_idx]) + aie_buffers.keys_cache[layer_idx].to("cpu") + aie_buffers.values_cache[layer_idx].to("cpu") + keys = aie_buffers.keys_cache[layer_idx].view_as_torch()[:, :context_len, :].unsqueeze(0) # (batch, n_kv_groups, context_len, head_dim) + values = aie_buffers.values_cache[layer_idx].view_as_torch()[:, :context_len, :].unsqueeze(0) # (batch, n_kv_groups, context_len, head_dim) + keys_cache = keys + values_cache = values + + # Step 4: Repeat keys and values for grouped attention group_size = config.n_heads // config.n_kv_groups keys = keys.repeat_interleave(group_size, dim=1) values = values.repeat_interleave(group_size, dim=1) context_len = keys.shape[2] - # Step 6: Compute attention scores + # Step 5: Compute attention scores aie_buffers.decode.attn_scores_keys.view_as_torch()[:, :context_len, :] = keys.squeeze(0)[:, :context_len, :] - aie_buffers.decode.attn_scores_queries.view_as_torch()[:, :, 0] = queries.squeeze(0)[:, 0, :] - aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.attn_scores_queries, aie_buffers.decode.attn_scores) + aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) aie_buffers.decode.attn_scores.to("cpu") scores = aie_buffers.decode.attn_scores.view_as_torch()[:, :context_len, 0].unsqueeze(0).unsqueeze(2) # Normalize scores = scores / math.sqrt(config.head_dim) - # Step 7: Apply mask + # Step 6: Apply mask if mask is not None: scores = scores.masked_fill(mask, float('-inf')) - # Step 8: Softmax + # Step 7: Softmax attention_weights = torch.nn.functional.softmax(scores, dim=-1) - # Step 9: Compute attention output + # Step 8: Compute attention output context = torch.matmul(attention_weights, values) - # Step 10: Concatenate heads and project + # Step 9: Concatenate heads and project context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) output = torch.nn.functional.linear(context, config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']) @@ -819,6 +855,14 @@ def llama_forward_pass_prefill( logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] + # Step 6: Initialize per-layer NPU cache buffers with current cache state for decode phase + for layer_idx in range(config.n_layers): + cache_len = state.attn_keys_caches[layer_idx].shape[2] + aie_buffers.keys_cache[layer_idx].view_as_torch()[:, :cache_len, :] = state.attn_keys_caches[layer_idx].squeeze(0) + aie_buffers.values_cache[layer_idx].view_as_torch()[:, :cache_len, :] = state.attn_values_caches[layer_idx].squeeze(0) + aie_buffers.keys_cache[layer_idx].to("npu") + aie_buffers.values_cache[layer_idx].to("npu") + return logits, state @@ -828,6 +872,16 @@ def llama_forward_pass_decode( ): batch, seq_len = state.token_ids.shape + # Patch strided copy operators once for all layers with current cache offset + num_preceding_tokens = state.attn_keys_caches[0].shape[2] + output_offset = num_preceding_tokens * config.head_dim + offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset + patches = { + 39: (offset_val, 0xFFFFFFFF), + 56: (offset_val, 0xFFFFFFFF), + } + aie_ops.decode.strided_copy_cache.patch(patches) + # Step 1: RoPE angles num_preceding_tokens = state.attn_keys_caches[0].shape[2] angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] diff --git a/operators/strided_copy/design.py b/operators/strided_copy/design.py index f528ea49..5969dc95 100644 --- a/operators/strided_copy/design.py +++ b/operators/strided_copy/design.py @@ -18,8 +18,19 @@ input[0, :, 0] -> output[:, 0, 0] """ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, input_offset, output_buffer_size, output_sizes, output_strides, output_offset, transfer_size=None, num_aie_channels=1): - assert input_sizes[0] % num_aie_channels == 0, "Highest dimension of input_sizes must be divisible by num_aie_channels" - assert output_sizes[0] % num_aie_channels == 0, "Highest dimension of output_sizes must be divisible by num_aie_channels" + assert len(input_sizes) == len(input_strides) + assert len(output_sizes) == len(output_strides) + + # Pad out dimensions to 4D; dropping leading dimensions leads to compiler not initializing these registers, causing hard-to-debug errors + input_sizes = [1] * (4 - len(input_sizes)) + list(input_sizes) + input_strides = [0] * (4 - len(input_strides)) + list(input_strides) + output_sizes = [1] * (4 - len(output_sizes)) + list(output_sizes) + output_strides = [0] * (4 - len(output_strides)) + list(output_strides) + + input_highest_sz_idx = max(idx for idx, sz in enumerate(input_sizes) if sz >= 1) + output_highest_sz_idx = max(idx for idx, sz in enumerate(output_sizes) if sz >= 1) + assert input_sizes[input_highest_sz_idx] % num_aie_channels == 0, "Highest dimension of input_sizes must be divisible by num_aie_channels" + assert output_sizes[output_highest_sz_idx] % num_aie_channels == 0, "Highest dimension of output_sizes must be divisible by num_aie_channels" if transfer_size is None: transfer_size = int(np.prod(input_sizes)) @@ -32,8 +43,12 @@ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, inpu input_taps = [ TensorAccessPattern( tensor_dims=(int(input_buffer_size),), - offset=input_offset + c * (input_sizes[0] // num_aie_channels) * input_strides[0], - sizes=[input_sizes[0] // num_aie_channels, *input_sizes[1:]], + offset=input_offset + c * (input_sizes[input_highest_sz_idx] // num_aie_channels) * input_strides[input_highest_sz_idx], + sizes=( + input_sizes[:input_highest_sz_idx] + + [input_sizes[input_highest_sz_idx] // num_aie_channels] + + input_sizes[input_highest_sz_idx+1:] + ), strides=list(input_strides), ) for c in range(num_aie_channels) @@ -42,16 +57,20 @@ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, inpu output_taps = [ TensorAccessPattern( tensor_dims=(int(output_buffer_size),), - offset=output_offset + c * (output_sizes[0] // num_aie_channels) * output_strides[0], - sizes=[output_sizes[0] // num_aie_channels, *output_sizes[1:]], + offset=output_offset + c * (output_sizes[output_highest_sz_idx] // num_aie_channels) * output_strides[output_highest_sz_idx], + sizes=( + output_sizes[:output_highest_sz_idx] + + [output_sizes[output_highest_sz_idx] // num_aie_channels] + + output_sizes[output_highest_sz_idx+1:] + ), strides=list(output_strides), ) for c in range(num_aie_channels) ] # Use smaller FIFOs for the transfer amount - fifos_in = [ObjectFifo(transfer_ty, name=f"fifo_in_{c}", depth=2) for c in range(num_aie_channels)] - fifos_out = [fifos_in[c].cons().forward(name=f"fifo_out_{c}", depth=2) for c in range(num_aie_channels)] + fifos_in = [ObjectFifo(transfer_ty, name=f"fifo_in_{c}", depth=1) for c in range(num_aie_channels)] + fifos_out = [fifos_in[c].cons().forward(name=f"fifo_out_{c}", depth=1) for c in range(num_aie_channels)] rt = Runtime() with rt.sequence(inp_ty, out_ty) as (inp, out): From b72432b3413300108eae822c2c0dcfc69d56753e Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 19 Jan 2026 19:26:17 -0700 Subject: [PATCH 28/99] offload repeat_interleave --- applications/llama_3.2_1b/llama_npu.py | 99 ++++++++++++-------------- operators/common/aie_base.py | 14 ++-- operators/strided_copy/test2.py | 12 +++- 3 files changed, 63 insertions(+), 62 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 6470c090..53bd6169 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -29,6 +29,7 @@ from operators.silu.op import AIESiLU from operators.rope.op import AIERope from operators.strided_copy.op import AIEStridedCopy +from operators.repeat.op import AIERepeat logging.basicConfig(level=logging.DEBUG) @@ -94,7 +95,6 @@ def __init__(self, config, prompt_len): tile_k=64, tile_n=64, b_col_maj=True, - use_static_weight=True, separate_c_tiles=True, context=self.context ).compile() @@ -103,7 +103,6 @@ def __init__(self, config, prompt_len): M=config.vocab_size, K=config.emb_dim, num_aie_columns=8, - use_static_weight=True, tile_size_input=4, tile_size_output=32, context=self.context @@ -120,7 +119,6 @@ def __init__(self, config, prompt_len): tile_k=64, tile_n=64, b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights - use_static_weight=True, context=self.context ).compile().get_callable() @@ -133,7 +131,6 @@ def __init__(self, config, prompt_len): tile_k=64, tile_n=64, b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights - use_static_weight=True, context=self.context ).compile().get_callable() @@ -238,9 +235,16 @@ def __init__(self, config, prompt_len): insts_bin_path=self.decode.strided_copy_cache_compilable.insts_artifact.path, args_spec=self.decode.strided_copy_cache_compilable.get_arg_spec() ) - - # Store head_dim for patching - self.head_dim = config.head_dim + + # Repeat interleave for keys: (n_kv_groups, context_len, head_dim) -> (n_heads, context_len, head_dim) + # Compile with max context length, then patch at runtime for actual context_len + self.decode.attn_repeat_interleave = AIERepeat( + rows=config.n_kv_groups, + cols=prompt_len * config.head_dim, # Max context length + repeat=config.n_heads // config.n_kv_groups, + transfer_size=config.head_dim, + context=self.context + ).compile().get_callable() # Attention projection operators # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) @@ -253,7 +257,6 @@ def __init__(self, config, prompt_len): tile_k=64, tile_n=64, b_col_maj=False, - use_static_weight=True, context=self.context ).compile().get_callable() @@ -276,7 +279,6 @@ def __init__(self, config, prompt_len): tile_k=64, tile_n=64, b_col_maj=False, - use_static_weight=True, context=self.context ).compile().get_callable() @@ -299,7 +301,6 @@ def __init__(self, config, prompt_len): tile_k=64, tile_n=64, b_col_maj=False, - use_static_weight=True, context=self.context ).compile().get_callable() @@ -323,7 +324,6 @@ def __init__(self, config, prompt_len): tile_k=64, tile_n=64, b_col_maj=False, - use_static_weight=False, context=self.context ).compile().get_callable() @@ -338,7 +338,6 @@ def __init__(self, config, prompt_len): num_batches=config.n_heads, context=self.context ).compile().get_callable() - # Allocate buffers shared with NPU @@ -394,7 +393,7 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d ] class AIEDecodeBuffers: - def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): + def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_context_len): self.x = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) self.x_norm = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) self.attn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) @@ -409,22 +408,23 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.values = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) self.rope_angles = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) # Attention score computation buffers (batched) - self.attn_scores_keys = AIEBuffer(shape=(n_heads, emb_dim, head_dim), dtype=ml_dtypes.bfloat16) # Max context length - self.attn_scores = AIEBuffer(shape=(n_heads, emb_dim, 1), dtype=ml_dtypes.bfloat16) + self.attn_scores_keys = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) # Max context length + self.attn_scores_values = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) # Max context length + self.attn_scores = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) - self.decode = AIEDecodeBuffers(config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) + self.decode = AIEDecodeBuffers(config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim, prompt_len) # Per-layer KV cache buffers on NPU (used by strided copy for transpose and concatenate) self.keys_cache = [ - AIEBuffer(shape=(config.n_kv_groups, config.emb_dim, config.head_dim), dtype=ml_dtypes.bfloat16) + AIEBuffer(shape=(config.n_kv_groups, prompt_len, config.head_dim), dtype=ml_dtypes.bfloat16) for _ in range(config.n_layers) ] self.values_cache = [ - AIEBuffer(shape=(config.n_kv_groups, config.emb_dim, config.head_dim), dtype=ml_dtypes.bfloat16) + AIEBuffer(shape=(config.n_kv_groups, prompt_len, config.head_dim), dtype=ml_dtypes.bfloat16) for _ in range(config.n_layers) ] @@ -624,8 +624,7 @@ def grouped_query_attention_forward_prefill( def grouped_query_attention_forward_decode( config, x, - keys_cache, - values_cache, + num_preceding_tokens, layer_idx, mask=None, ): @@ -648,49 +647,38 @@ def grouped_query_attention_forward_decode( # Step 3: Update cache using strided copy on NPU (transpose and concatenate) # Cache is already on NPU from prefill initialization or previous decode iteration - num_preceding_tokens = keys_cache.shape[2] context_len = num_preceding_tokens + seq_len # Transpose and append new keys/values to this layer's cache on NPU aie_ops.decode.strided_copy_cache(aie_buffers.decode.keys, aie_buffers.keys_cache[layer_idx]) aie_ops.decode.strided_copy_cache(aie_buffers.decode.values, aie_buffers.values_cache[layer_idx]) - aie_buffers.keys_cache[layer_idx].to("cpu") - aie_buffers.values_cache[layer_idx].to("cpu") - keys = aie_buffers.keys_cache[layer_idx].view_as_torch()[:, :context_len, :].unsqueeze(0) # (batch, n_kv_groups, context_len, head_dim) - values = aie_buffers.values_cache[layer_idx].view_as_torch()[:, :context_len, :].unsqueeze(0) # (batch, n_kv_groups, context_len, head_dim) - keys_cache = keys - values_cache = values - - # Step 4: Repeat keys and values for grouped attention + + # Step 4: Repeat keys and values for grouped attention using AIERepeat on NPU group_size = config.n_heads // config.n_kv_groups - keys = keys.repeat_interleave(group_size, dim=1) - values = values.repeat_interleave(group_size, dim=1) - context_len = keys.shape[2] + aie_ops.decode.attn_repeat_interleave(aie_buffers.keys_cache[layer_idx], aie_buffers.decode.attn_scores_keys) + aie_ops.decode.attn_repeat_interleave(aie_buffers.values_cache[layer_idx], aie_buffers.decode.attn_scores_values) # Step 5: Compute attention scores - aie_buffers.decode.attn_scores_keys.view_as_torch()[:, :context_len, :] = keys.squeeze(0)[:, :context_len, :] + # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) aie_buffers.decode.attn_scores.to("cpu") - scores = aie_buffers.decode.attn_scores.view_as_torch()[:, :context_len, 0].unsqueeze(0).unsqueeze(2) + scores = aie_buffers.decode.attn_scores.view_as_torch()[:, :context_len] # Normalize scores = scores / math.sqrt(config.head_dim) - # Step 6: Apply mask - if mask is not None: - scores = scores.masked_fill(mask, float('-inf')) - # Step 7: Softmax - attention_weights = torch.nn.functional.softmax(scores, dim=-1) + attention_weights = torch.nn.functional.softmax(scores, dim=-1).unsqueeze(-2) # (n_heads, 1, context_len) # Step 8: Compute attention output + values = aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch()[:, :context_len, :] # (n_heads, context_len, head_dim) context = torch.matmul(attention_weights, values) # Step 9: Concatenate heads and project context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) output = torch.nn.functional.linear(context, config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']) - return output, keys_cache, values_cache + return output, None, None def swiglu_ffn_forward_prefill(layer_idx): @@ -773,9 +761,8 @@ def transformer_block_forward_prefill( def transformer_block_forward_decode( config, seq_len, + num_preceding_tokens, layer_idx, - attn_keys_cache, - attn_values_cache, attn_mask ): # Step 1: RMS normalization @@ -787,8 +774,7 @@ def transformer_block_forward_decode( attn_output, attn_keys, attn_values = grouped_query_attention_forward_decode( config, x_norm, - attn_keys_cache, - attn_values_cache, + num_preceding_tokens, layer_idx, attn_mask, ) @@ -868,22 +854,24 @@ def llama_forward_pass_prefill( def llama_forward_pass_decode( config, - state + state, ): batch, seq_len = state.token_ids.shape - # Patch strided copy operators once for all layers with current cache offset - num_preceding_tokens = state.attn_keys_caches[0].shape[2] + # Patch operators once for all layers with current context length + num_preceding_tokens = state.num_preceding_tokens + context_len = num_preceding_tokens + seq_len + + # Patch strided copy operator for cache offset output_offset = num_preceding_tokens * config.head_dim offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset - patches = { + strided_copy_patches = { 39: (offset_val, 0xFFFFFFFF), 56: (offset_val, 0xFFFFFFFF), } - aie_ops.decode.strided_copy_cache.patch(patches) + aie_ops.decode.strided_copy_cache.patch(strided_copy_patches) # Step 1: RoPE angles - num_preceding_tokens = state.attn_keys_caches[0].shape[2] angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] aie_buffers.decode.rope_angles.view_as_torch()[:] = angles_slice @@ -901,9 +889,8 @@ def llama_forward_pass_decode( state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward_decode( config, seq_len, + num_preceding_tokens, layer_idx, - state.attn_keys_caches[layer_idx], - state.attn_values_caches[layer_idx], attn_mask=attn_mask, ) @@ -924,9 +911,13 @@ def llama_forward_pass( ): batch, seq_len = state.token_ids.shape if seq_len > 1: - return llama_forward_pass_prefill(config, state) + ret = llama_forward_pass_prefill(config, state) + state.num_preceding_tokens = state.token_ids.shape[1] + return ret else: - return llama_forward_pass_decode(config, state) + ret = llama_forward_pass_decode(config, state) + state.num_preceding_tokens += 1 + return ret # Main diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index c1b0409d..a336bcb2 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -287,17 +287,21 @@ def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_m def __call__(self, *buffers): assert len(buffers) == len(self.args_spec) - assert all( - np.prod(buffers[i].shape) >= np.prod(self.args_spec[i].shape) and buffers[i].dtype == self.args_spec[i].dtype - for i in range(len(buffers)) - ), "Input buffer shapes or dtypes do not match expected argument specification." + #assert all( + # np.prod(buffers[i].shape) >= np.prod(self.args_spec[i].shape) and buffers[i].dtype == self.args_spec[i].dtype + # for i in range(len(buffers)) + #), "Input buffer shapes or dtypes do not match expected argument specification." self.insts_buffer.to("npu") for buf in buffers: buf.to("npu") opcode = 3 bos = [buffer.bo for buffer in buffers] run = self.xrt_kernel(opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos) - run.wait() + ret_code = run.wait() + if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: + raise RuntimeError( + f"Kernel did not complete correctly: {ret_code}" + ) class PatchableSingleXclbinCallable(SingleXclbinCallable): def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None): diff --git a/operators/strided_copy/test2.py b/operators/strided_copy/test2.py index 8b3ba310..7205de16 100755 --- a/operators/strided_copy/test2.py +++ b/operators/strided_copy/test2.py @@ -16,17 +16,22 @@ num_heads = 32 transpose_concat = AIEStridedCopy( - input_sizes=(num_heads, prompt_len, head_dim,), - input_strides=(head_dim, num_heads * head_dim, 1,), + input_sizes=(1, num_heads, prompt_len, head_dim,), + input_strides=(0, head_dim, num_heads * head_dim, 1,), input_offset=0, output_sizes=(1, num_heads, prompt_len, head_dim,), output_strides=(0, max_prompt_len * head_dim, head_dim, 1,), output_offset=cached_prompt_len * head_dim, input_buffer_size=prompt_len * num_heads * head_dim, output_buffer_size=num_heads * max_prompt_len * head_dim, - num_aie_channels=4 + num_aie_channels=1 ).compile().get_callable() +value_cache_1 = AIEBuffer((num_heads, max_prompt_len, head_dim)) +value_1 = AIEBuffer((prompt_len, num_heads, head_dim)) +value_cache_1.view_as_torch()[:, :cached_prompt_len, :] = torch.randn(num_heads, cached_prompt_len, head_dim) +value_1.view_as_torch()[:prompt_len, :, :] = torch.randn(prompt_len, num_heads, head_dim) + value_cache = AIEBuffer((num_heads, max_prompt_len, head_dim)) value = AIEBuffer((prompt_len, num_heads, head_dim)) @@ -38,6 +43,7 @@ out_ref = torch.cat([value_cache.view_as_torch()[:, :cached_prompt_len, :], value_transposed], dim=1) t_cpu = time.perf_counter() - t_cpu_start +transpose_concat(value_1, value_cache_1) value_cache.to("npu") value.to("npu") t_aie_start = time.perf_counter() From 8b4b2b3c822e828397ac0db36af672357660398c Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 19 Jan 2026 19:59:29 -0700 Subject: [PATCH 29/99] offload normalization/scaling + softmax (with -inf masking on CPU for now) --- applications/llama_3.2_1b/llama_npu.py | 83 +++++++++++--- operators/softmax/op.py | 143 +++++++------------------ 2 files changed, 108 insertions(+), 118 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 53bd6169..937183a5 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -30,6 +30,7 @@ from operators.rope.op import AIERope from operators.strided_copy.op import AIEStridedCopy from operators.repeat.op import AIERepeat +from operators.softmax.op import AIESoftmax logging.basicConfig(level=logging.DEBUG) @@ -181,6 +182,39 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() + # Attention score scaling operators + # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector + self.prefill.attn_scale = AIEElementwiseMul( + size=config.n_heads * prompt_len * prompt_len, + tile_size=prompt_len, + num_aie_columns=8, + context=self.context + ).compile().get_callable() + + self.decode.attn_scale = AIEElementwiseMul( + size=config.n_heads * prompt_len, + tile_size=prompt_len // 8, + num_aie_columns=8, + context=self.context + ).compile().get_callable() + + # Softmax operators for attention weights + self.prefill.softmax = AIESoftmax( + rows=config.n_heads * prompt_len, + cols=prompt_len, + num_aie_columns=8, + num_channels=1, + context=self.context + ).compile().get_callable() + + self.decode.softmax = AIESoftmax( + rows=config.n_heads, + cols=prompt_len, + num_aie_columns=1, + num_channels=1, + context=self.context + ).compile().get_callable() + # RoPE operators # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) @@ -391,6 +425,13 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d ) for h in range(n_heads) ] + # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) + scale_factor = 1.0 / math.sqrt(head_dim) + self.attn_scale_factor = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) + self.attn_scale_factor.view_as_torch()[:] = scale_factor + self.attn_scale_factor.to("npu") + # Attention weights buffer (output of softmax) + self.attn_weights = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) class AIEDecodeBuffers: def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_context_len): @@ -411,7 +452,12 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_cont self.attn_scores_keys = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) # Max context length self.attn_scores_values = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) # Max context length self.attn_scores = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) - + # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) + scale_factor = 1.0 / math.sqrt(head_dim) + self.attn_scale_factor = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) + self.attn_scale_factor.view_as_torch().fill_(scale_factor) + self.attn_scale_factor.to("npu") # Attention weights buffer (output of softmax) + self.attn_weights = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline @@ -590,22 +636,29 @@ def grouped_query_attention_forward_prefill( aie_buffers.prefill.attn_scores_per_head[h] ) - # Read back all results at once from parent buffer + # Read back all results at once from parent buffer and apply scaling on NPU + aie_ops.prefill.attn_scale(aie_buffers.prefill.attn_scores, aie_buffers.prefill.attn_scale_factor, aie_buffers.prefill.attn_scores) aie_buffers.prefill.attn_scores.to("cpu") # Buffer is (n_heads * max_seq_len, max_seq_len), view as (n_heads, max_seq_len, max_seq_len) then slice max_seq_len = aie_buffers.prefill.attn_scores.shape[0] // config.n_heads scores = aie_buffers.prefill.attn_scores.view_as_torch().view(config.n_heads, max_seq_len, max_seq_len).unsqueeze(0)[:, :, :seq_len, :context_len] - # Apply scaling - scores = scores / math.sqrt(config.head_dim) - # Step 7: Apply mask # This ensures causality, so that tokens in the future cannot attend to tokens in the past. if mask is not None: scores = scores.masked_fill(mask, float('-inf')) # Step 8: Apply softmax to squeeze scores into probabilities (0, 1) - attention_weights = torch.nn.functional.softmax(scores, dim=-1) + # Write scores back to NPU buffer for softmax, handling variable seq_len and context_len + scores_buf = aie_buffers.prefill.attn_scores.view_as_torch().view(config.n_heads, max_seq_len, max_seq_len) + scores_buf[:, :seq_len, :context_len] = scores.squeeze(0) + # Pad unused regions with -inf so they don't affect softmax + scores_buf[:, :seq_len, context_len:] = float('-inf') + scores_buf[:, seq_len:, :] = float('-inf') + aie_buffers.prefill.attn_scores.to("npu") + aie_ops.prefill.softmax(aie_buffers.prefill.attn_scores, aie_buffers.prefill.attn_weights) + aie_buffers.prefill.attn_weights.to("cpu") + attention_weights = aie_buffers.prefill.attn_weights.view_as_torch().view(config.n_heads, max_seq_len, max_seq_len).unsqueeze(0)[:, :, :seq_len, :context_len] # Step 9: Compute attention output # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) @@ -661,14 +714,18 @@ def grouped_query_attention_forward_decode( # Step 5: Compute attention scores # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) + # Apply scaling on NPU + aie_ops.decode.attn_scale(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_scale_factor, aie_buffers.decode.attn_scores) aie_buffers.decode.attn_scores.to("cpu") - scores = aie_buffers.decode.attn_scores.view_as_torch()[:, :context_len] - - # Normalize - scores = scores / math.sqrt(config.head_dim) - - # Step 7: Softmax - attention_weights = torch.nn.functional.softmax(scores, dim=-1).unsqueeze(-2) # (n_heads, 1, context_len) + # Pad unused regions with -inf so they don't affect softmax FIXME: need to do this on NPU + scores_buf = aie_buffers.decode.attn_scores.view_as_torch() + scores_buf[:, context_len:] = float('-inf') + aie_buffers.decode.attn_scores.to("npu") + + # Step 7: Softmax on NPU + aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) + aie_buffers.decode.attn_weights.to("cpu") + attention_weights = aie_buffers.decode.attn_weights.view_as_torch()[:, :context_len].unsqueeze(-2) # (n_heads, 1, context_len) # Step 8: Compute attention output values = aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch()[:, :context_len, :] # (n_heads, context_len, head_dim) diff --git a/operators/softmax/op.py b/operators/softmax/op.py index 7c4cef71..21565753 100644 --- a/operators/softmax/op.py +++ b/operators/softmax/op.py @@ -2,133 +2,66 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, - KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIESoftmax(AIEOperatorBase): +class AIESoftmax(SingleMLIRSourceOperator): + """AIE-accelerated Softmax operation""" - def __init__( - self, rows: int, cols: int, num_aie_columns=1, num_channels=1, context=None - ): - self.size = rows * cols + def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, context=None): + assert rows % 16 == 0, "rows must be multiple of 16" + assert cols % 16 == 0, "cols must be multiple of 16" + assert (rows * cols) % (num_aie_columns * cols) == 0, "size must be multiple of num_aie_columns * tile_size" + self.rows = rows self.cols = cols - + self.size = rows * cols + self.num_aie_columns = num_aie_columns self.num_channels = num_channels - self.num_columns = num_aie_columns + + SingleMLIRSourceOperator.__init__(self, context=context) - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + def get_operator_name(self): + return f"softmax_{self.num_aie_columns}col_{self.num_channels}ch_{self.size}_{self.cols}t" - AIEOperatorBase.__init__(self, context=context) - - def set_up_artifacts(self): - # Compilation artifacts + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"softmax_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.cols}t" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="softmax", callback_args=[ self.context.device_manager.device_type, - self.rows * self.cols, - self.num_columns, + self.size, + self.num_aie_columns, self.num_channels, - 0, + 0, # trace_size self.cols, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"softmax.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "softmax.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"gemm_{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Runlist setup - self.add_buffer("in", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "softmax", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("softmax", "in", "output") - - def forward(self, x): - applicable = ( - x.shape[-1] * x.shape[-2] == self.size - and x.shape[-1] == self.cols - and x.shape[-1] % 16 == 0 - and x.shape[-2] % 16 == 0 - ) - if not applicable: - raise AIEOperatorConstraintError("AIESoftmax: incompatible tensor shape(s)") - - return self._execute_aie_operation(x) - - def _execute_aie_operation(self, x): - original_shape = x.shape - - # Reshape for processing - # Split x into a list of H tensors of size [S_q, S_kv] - heads = x.shape[1] - x_list = [x[0, h, :, :] for h in range(heads)] - results = [] - for i in range(heads): - x_iter = x_list[i] - input_size = x_iter.nbytes - self.write_buffer("in", x_iter) - test_pattern = np.zeros(len(x_iter), dtype=bfloat16) - self.write_buffer("output", test_pattern) - self.run_runlist() - result = self.read_buffer_as_torch( - "output", shape=x_list[i].shape, dtype=bfloat16 - ) - results.append(result) - - result = torch.stack(results, dim=0).unsqueeze( - 0 - ) # Shape: (1, heads, S_q, S_kv) - - return result + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"softmax.o", + depends=[ + SourceArtifact.new( + self.context.base_dir / "aie_kernels" / "aie2p" / "softmax.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), + AIERuntimeArgSpec("out", (self.size,)), + ] From 5575d4fc2f027112b3cd28d187f02aca5f890a90 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 19 Jan 2026 20:40:52 -0700 Subject: [PATCH 30/99] make softmax run-time parametrizable --- aie_kernels/aie2p/softmax.cc | 7 +++ applications/llama_3.2_1b/llama_npu.py | 43 ++++++++++++++----- operators/softmax/design.py | 57 ++++++++++++++++++++----- operators/softmax/op.py | 9 +++- operators/softmax/test2.py | 59 ++++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 24 deletions(-) create mode 100755 operators/softmax/test2.py diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 7d480354..6837f570 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -177,4 +177,11 @@ void partial_softmax_bf16(bfloat16 *restrict input, partial_softmax_alias_bf16(input, output, scale_buffer, input_size, row_idx, num_rows, scale); } +void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size) { + // TODO: Optimize this to use vector code + for(int32 i = unmasked_size; i < total_size; i++) { + inout[i] = (bfloat16)(-INFINITY); + } +} + } // extern "C" \ No newline at end of file diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 937183a5..409a6a91 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -199,21 +199,37 @@ def __init__(self, config, prompt_len): ).compile().get_callable() # Softmax operators for attention weights - self.prefill.softmax = AIESoftmax( + self.prefill.softmax_compilable = AIESoftmax( rows=config.n_heads * prompt_len, cols=prompt_len, num_aie_columns=8, num_channels=1, + rtp_vector_size=prompt_len, # Compile with max size context=self.context - ).compile().get_callable() + ).compile() + + self.prefill.softmax = PatchableSingleXclbinCallable( + xclbin_path=self.prefill.softmax_compilable.xclbin_artifact.path, + kernel_name=self.prefill.softmax_compilable.xclbin_artifact.kernel_name, + insts_bin_path=self.prefill.softmax_compilable.insts_artifact.path, + args_spec=self.prefill.softmax_compilable.get_arg_spec() + ) - self.decode.softmax = AIESoftmax( + self.decode.softmax_compilable = AIESoftmax( rows=config.n_heads, cols=prompt_len, num_aie_columns=1, num_channels=1, + rtp_vector_size=prompt_len, # Compile with max size context=self.context - ).compile().get_callable() + ).compile() + + self.decode.softmax = PatchableSingleXclbinCallable( + xclbin_path=self.decode.softmax_compilable.xclbin_artifact.path, + kernel_name=self.decode.softmax_compilable.xclbin_artifact.kernel_name, + insts_bin_path=self.decode.softmax_compilable.insts_artifact.path, + args_spec=self.decode.softmax_compilable.get_arg_spec() + ) # RoPE operators # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) @@ -456,8 +472,9 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_cont scale_factor = 1.0 / math.sqrt(head_dim) self.attn_scale_factor = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) self.attn_scale_factor.view_as_torch().fill_(scale_factor) - self.attn_scale_factor.to("npu") # Attention weights buffer (output of softmax) + self.attn_scale_factor.to("npu") # Attention weights buffer (output of softmax) self.attn_weights = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) + class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline @@ -716,13 +733,8 @@ def grouped_query_attention_forward_decode( aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) # Apply scaling on NPU aie_ops.decode.attn_scale(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_scale_factor, aie_buffers.decode.attn_scores) - aie_buffers.decode.attn_scores.to("cpu") - # Pad unused regions with -inf so they don't affect softmax FIXME: need to do this on NPU - scores_buf = aie_buffers.decode.attn_scores.view_as_torch() - scores_buf[:, context_len:] = float('-inf') - aie_buffers.decode.attn_scores.to("npu") - # Step 7: Softmax on NPU + # Step 7: Softmax on NPU (patched once at beginning of decode pass) aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) aie_buffers.decode.attn_weights.to("cpu") attention_weights = aie_buffers.decode.attn_weights.view_as_torch()[:, :context_len].unsqueeze(-2) # (n_heads, 1, context_len) @@ -866,6 +878,11 @@ def llama_forward_pass_prefill( num_preceding_tokens = state.attn_keys_caches[0].shape[2] angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] aie_buffers.prefill.rope_angles.view_as_torch()[:seq_len, :] = angles_slice + + # Patch softmax operator once for this prefill pass with the context length + context_len = num_preceding_tokens + seq_len + softmax_patches = {8: (context_len, 0xFFFFFFFF)} + aie_ops.prefill.softmax.patch(softmax_patches) # Step 2: Token embedding tok_emb_weight = config.weights['model.embed_tokens.weight'] @@ -927,6 +944,10 @@ def llama_forward_pass_decode( 56: (offset_val, 0xFFFFFFFF), } aie_ops.decode.strided_copy_cache.patch(strided_copy_patches) + + # Patch softmax operator for actual context length + softmax_patches = {8: (context_len, 0xFFFFFFFF)} + aie_ops.decode.softmax.patch(softmax_patches) # Step 1: RoPE angles angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] diff --git a/operators/softmax/design.py b/operators/softmax/design.py index cdac226c..d54a10ff 100644 --- a/operators/softmax/design.py +++ b/operators/softmax/design.py @@ -7,7 +7,7 @@ import argparse import sys -from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker, Buffer, WorkerRuntimeBarrier from aie.iron.placers import SequentialPlacer from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern @@ -15,15 +15,17 @@ from ml_dtypes import bfloat16 -def softmax(dev, num_elements, num_columns, num_channels, trace_size, tile_size): +def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_size, rtp_vector_size=None): per_tile_elements = tile_size - n = per_tile_elements * num_columns + if rtp_vector_size is None: + rtp_vector_size = per_tile_elements + n = per_tile_elements * num_aie_columns if num_elements % n != 0: raise ValueError( f"Number of elements ({num_elements}) must be a multiple of {n}." ) N_div_n = num_elements // n - chunk = num_elements // num_columns // num_channels # For offset calculation + chunk = num_elements // num_aie_columns // num_channels # For offset calculation dtype = bfloat16 # Define tensor types @@ -33,27 +35,47 @@ def softmax(dev, num_elements, num_columns, num_channels, trace_size, tile_size) # AIE-array data movement with object fifos of_in1s = [ ObjectFifo(tile_ty, name=f"in1_{i}_{j}") - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] of_outs = [ ObjectFifo(tile_ty, name=f"out_{i}_{j}") - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] # AIE Core Function declaration softmax_kernel = Kernel("softmax_bf16", "softmax.o", [tile_ty, tile_ty, np.int32]) + mask_kernel = Kernel("mask_bf16", "softmax.o", [tile_ty, np.int32, np.int32]) # Define a task that will run on a compute tile - def core_body(of_in1, of_out, softmax_kernel): + def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): # Number of sub-vector "tile" iterations + barrier.wait_for_value(1) + vector_size = rtp[0] for _ in range_(N_div_n): elem_in1 = of_in1.acquire(1) elem_out = of_out.acquire(1) + mask_kernel(elem_in1, vector_size, per_tile_elements) softmax_kernel(elem_in1, elem_out, per_tile_elements) of_in1.release(1) of_out.release(1) + + rtps = [ + Buffer( + np.ndarray[(1,), np.dtype[np.int32]], + name=f"rtp_{i}_{j}", + use_write_rtp=True, + ) + for i in range(num_aie_columns) + for j in range(num_channels) + ] + + barriers = [ + WorkerRuntimeBarrier() + for i in range(num_aie_columns) + for j in range(num_channels) + ] # Create a worker to run the task on a compute tile my_workers = [ @@ -63,9 +85,12 @@ def core_body(of_in1, of_out, softmax_kernel): of_in1s[i * num_channels + j].cons(), of_outs[i * num_channels + j].prod(), softmax_kernel, + mask_kernel, + rtps[i * num_channels + j], + barriers[i * num_channels + j] ], ) - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] @@ -81,7 +106,7 @@ def core_body(of_in1, of_out, softmax_kernel): [1, 1, 1, chunk], [0, 0, 0, 1], ) - for i in range(num_columns) + for i in range(num_aie_columns) for j in range(num_channels) ] @@ -90,11 +115,21 @@ def core_body(of_in1, of_out, softmax_kernel): with rt.sequence(tensor_ty, tensor_ty) as (A, C): rt.start(*my_workers) + # Set run-time parameter for actual vector size (remainder is considered padding and ignored by the computation) + def set_rtps(*args): + for rtp in args: + rtp[0] = rtp_vector_size + + rt.inline_ops(set_rtps, rtps) + + for i in range(num_aie_columns * num_channels): + rt.set_barrier(barriers[i], 1) + # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. tg = rt.task_group() # Fill the input objectFIFOs with data - for i in range(num_columns): + for i in range(num_aie_columns): for j in range(num_channels): rt.fill( of_in1s[i * num_channels + j].prod(), @@ -103,7 +138,7 @@ def core_body(of_in1, of_out, softmax_kernel): task_group=tg, ) # Drain the output objectFIFOs with data - for i in range(num_columns): + for i in range(num_aie_columns): for j in range(num_channels): rt.drain( of_outs[i * num_channels + j].cons(), diff --git a/operators/softmax/op.py b/operators/softmax/op.py index 21565753..e2443d58 100644 --- a/operators/softmax/op.py +++ b/operators/softmax/op.py @@ -16,7 +16,7 @@ class AIESoftmax(SingleMLIRSourceOperator): """AIE-accelerated Softmax operation""" - def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, context=None): + def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, rtp_vector_size=None, context=None): assert rows % 16 == 0, "rows must be multiple of 16" assert cols % 16 == 0, "cols must be multiple of 16" assert (rows * cols) % (num_aie_columns * cols) == 0, "size must be multiple of num_aie_columns * tile_size" @@ -26,11 +26,15 @@ def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, cont self.size = rows * cols self.num_aie_columns = num_aie_columns self.num_channels = num_channels + self.rtp_vector_size = rtp_vector_size SingleMLIRSourceOperator.__init__(self, context=context) def get_operator_name(self): - return f"softmax_{self.num_aie_columns}col_{self.num_channels}ch_{self.size}_{self.cols}t" + name = f"softmax_{self.num_aie_columns}col_{self.num_channels}ch_{self.size}_{self.cols}t" + if self.rtp_vector_size is not None: + name += f"_{self.rtp_vector_size}rtp" + return name def get_mlir_artifact(self): operator_dir = Path(__file__).parent @@ -45,6 +49,7 @@ def get_mlir_artifact(self): self.num_channels, 0, # trace_size self.cols, + self.rtp_vector_size, ], ) diff --git a/operators/softmax/test2.py b/operators/softmax/test2.py new file mode 100755 index 00000000..a3dfb149 --- /dev/null +++ b/operators/softmax/test2.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path +import time +import torch +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from operators.softmax.op import AIESoftmax +from operators.common import AIEBuffer + +max_context_len=2048 +prompt_len=8 +n_heads=32 + +softmax_op = AIESoftmax( + rows=n_heads, + cols=max_context_len, + rtp_vector_size=prompt_len +).compile().get_callable() + +inp = AIEBuffer((n_heads, max_context_len)) +out = AIEBuffer((n_heads, max_context_len)) + +inp.view_as_torch()[:] = torch.randn(n_heads, max_context_len) +out.view_as_torch()[:] = torch.zeros(n_heads, max_context_len) + +t_cpu_start = time.perf_counter() +out_ref = inp.view_as_torch()[:, :prompt_len].softmax(dim=-1) +t_cpu = time.perf_counter() - t_cpu_start + +inp.to("npu") +out.to("npu") +t_aie_start = time.perf_counter() +softmax_op(inp, out) +t_aie = time.perf_counter() - t_aie_start +out.to("cpu") + +print(out_ref) +print(t_cpu) +aie_out = out.view_as_torch()[:, :prompt_len] +print(aie_out) +print(t_aie) + +# Check which elements differ +diff = torch.abs(out_ref - aie_out) +max_diff = diff.max() +print(f"Max diff: {max_diff}") +print(f"Number of mismatches (> 1e-2): {(diff > 1e-2).sum()}") + +# Find first mismatch +mismatches = torch.where(diff > 1e-2) +if len(mismatches[0]) > 0: + for i in range(min(10, len(mismatches[0]))): + h, s = mismatches[0][i], mismatches[1][i] + print(f"Mismatch at head={h}, seq={s}: ref={out_ref[h,s]}, aie={aie_out[h,s]}, diff={diff[h,s]}") + +assert torch.allclose(out_ref, aie_out, atol=1e-2, rtol=1e-2) + From 7ed760ac4155f3db626be972cfa3351a5cd89fab Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 20 Jan 2026 10:03:13 -0700 Subject: [PATCH 31/99] Fix device manager singleton --- operators/common/aie_device_manager.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/operators/common/aie_device_manager.py b/operators/common/aie_device_manager.py index a06d957a..e48c39cd 100644 --- a/operators/common/aie_device_manager.py +++ b/operators/common/aie_device_manager.py @@ -19,6 +19,7 @@ class AIEDeviceManager: """Singleton manager for AIE XRT resources""" _instance = None + _initialized = False def __new__(cls): if cls._instance is None: @@ -26,6 +27,11 @@ def __new__(cls): return cls._instance def __init__(self): + # Only initialize once + if AIEDeviceManager._initialized: + return + AIEDeviceManager._initialized = True + self.device = pyxrt.device(0) self.device_type = detect_npu_device() self.contexts = {} # xclbin_path -> (context, xclbin) From e2cdc1d2016d339a575ca17c430435323cbf55da Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 20 Jan 2026 10:05:16 -0700 Subject: [PATCH 32/99] fix GEMV for large Bs --- operators/gemv/design.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/operators/gemv/design.py b/operators/gemv/design.py index 840d3c50..f988fc2e 100644 --- a/operators/gemv/design.py +++ b/operators/gemv/design.py @@ -145,6 +145,12 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): # Every column gets the entirety of the vector B, no TAP needed. # This design assumes that all of B fits on the cores. + B_tap = TensorAccessPattern( + tensor_dims=L3_B_ty.__args__[0], + offset=0, + sizes=[1, 1, 1, num_batches * K], + strides=[0, 0, 0, 1], + ) # Collection pattern for the output vector C: each AIE core writes back its contiguous chunk of rows. C_taps = [ @@ -166,7 +172,7 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): tg_b = rt.task_group() for col in range(cols): # Simple linear transfer of B, includes all batches in sequence - rt.fill(B_L3L1_fifos[col].prod(), B, task_group=tg_b) + rt.fill(B_L3L1_fifos[col].prod(), B, B_tap, task_group=tg_b) for batch in range(num_batches): tg_ac = rt.task_group() for col in range(cols): From 855bec346c533726256c341dda9a7f744098eb02 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 20 Jan 2026 10:05:53 -0700 Subject: [PATCH 33/99] rework and offload transpose operator in llama, attention weight * values GEMV offloaded --- applications/llama_3.2_1b/llama_npu.py | 72 ++++++++++++-- operators/transpose/design.py | 119 +---------------------- operators/transpose/op.py | 129 ++++++++----------------- 3 files changed, 104 insertions(+), 216 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 409a6a91..c41e82a2 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -31,6 +31,7 @@ from operators.strided_copy.op import AIEStridedCopy from operators.repeat.op import AIERepeat from operators.softmax.op import AIESoftmax +from operators.transpose.op import AIETranspose logging.basicConfig(level=logging.DEBUG) @@ -388,6 +389,29 @@ def __init__(self, config, prompt_len): num_batches=config.n_heads, context=self.context ).compile().get_callable() + + # Transpose values from (max_context_len, head_dim) to (head_dim, max_context_len) per head + self.decode.transpose_values = AIETranspose( + M=prompt_len, + N=config.head_dim, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=self.context + ).compile().get_callable() + + # GEMV for attention context: (head_dim, max_context_len) @ (max_context_len,) = (head_dim,) per head + self.decode.gemv_attn_context = AIEGEMV( + M=config.head_dim, + K=prompt_len, # max possible context length + num_aie_columns=8, + tile_size_input=4, + tile_size_output=4, + num_batches=config.n_heads, + context=self.context + ).compile().get_callable() # Allocate buffers shared with NPU @@ -465,14 +489,33 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_cont self.values = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) self.rope_angles = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) # Attention score computation buffers (batched) - self.attn_scores_keys = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) # Max context length - self.attn_scores_values = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) # Max context length + self.attn_scores_keys = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) + self.attn_scores_values = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) + self.attn_scores_values_transposed = AIEBuffer(shape=(n_heads, head_dim, max_context_len), dtype=ml_dtypes.bfloat16) + # Create per-head subbuffers for transpose operations (to avoid allocating in hot path) + self.attn_scores_values_per_head = [ + self.attn_scores_values.subbuffer( + length=max_context_len * head_dim, + offset=h * max_context_len * head_dim, + shape=(max_context_len, head_dim) + ) + for h in range(n_heads) + ] + self.attn_scores_values_transposed_per_head = [ + self.attn_scores_values_transposed.subbuffer( + length=head_dim * max_context_len, + offset=h * head_dim * max_context_len, + shape=(head_dim, max_context_len) + ) + for h in range(n_heads) + ] + self.attn_context = AIEBuffer(shape=(n_heads, head_dim), dtype=ml_dtypes.bfloat16) self.attn_scores = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) scale_factor = 1.0 / math.sqrt(head_dim) self.attn_scale_factor = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) - self.attn_scale_factor.view_as_torch().fill_(scale_factor) - self.attn_scale_factor.to("npu") # Attention weights buffer (output of softmax) + self.attn_scale_factor.view_as_torch()[:] = scale_factor + self.attn_scale_factor.to("npu") self.attn_weights = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) class AIELlamaBuffers: @@ -736,15 +779,26 @@ def grouped_query_attention_forward_decode( # Step 7: Softmax on NPU (patched once at beginning of decode pass) aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) - aie_buffers.decode.attn_weights.to("cpu") - attention_weights = aie_buffers.decode.attn_weights.view_as_torch()[:, :context_len].unsqueeze(-2) # (n_heads, 1, context_len) - # Step 8: Compute attention output - values = aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch()[:, :context_len, :] # (n_heads, context_len, head_dim) - context = torch.matmul(attention_weights, values) + # Step 8: Compute attention output on NPU + # Transpose values: (max_context_len, head_dim) -> (head_dim, max_context_len) for each head + for h in range(config.n_heads): + aie_ops.decode.transpose_values( + aie_buffers.decode.attn_scores_values_per_head[h], + aie_buffers.decode.attn_scores_values_transposed_per_head[h] + ) + + # GEMV: (n_heads, head_dim, max_context_len) @ (n_heads, max_context_len) -> (n_heads, head_dim) + aie_ops.decode.gemv_attn_context(aie_buffers.decode.attn_scores_values_transposed, aie_buffers.decode.attn_weights, aie_buffers.decode.attn_context) + + # Read context from NPU + aie_buffers.decode.attn_context.to("cpu") + context = aie_buffers.decode.attn_context.view_as_torch().unsqueeze(1) # (n_heads, 1, head_dim) # Step 9: Concatenate heads and project + # (n_heads, 1, head_dim) -> (n_heads, head_dim, 1) -> (1, 1, n_heads * head_dim) context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) + # (1, 1, n_heads * head_dim) @ (emb_dim, n_heads * head_dim)^T -> (1, 1, emb_dim) output = torch.nn.functional.linear(context, config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']) return output, None, None diff --git a/operators/transpose/design.py b/operators/transpose/design.py index 7a53365a..9a0c3b29 100644 --- a/operators/transpose/design.py +++ b/operators/transpose/design.py @@ -2,20 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from ml_dtypes import bfloat16 -from pathlib import Path import numpy as np -import argparse -import sys from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker from aie.iron.placers import SequentialPlacer -from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern from aie.iron.controlflow import range_ -from aie.helpers.util import np_ndarray_type_get_shape -def shuffle_transpose(dev, M, N, num_columns, num_channels, trace_size, m, n, s): +def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s): num_elements = M * N per_tile_elements = m * n dtype = bfloat16 @@ -163,115 +158,3 @@ def core_body(of_in1, of_out, transpose_kernel): # Place program components (assign them resources on the device) and generate an MLIR module return Program(dev, rt).resolve_program(SequentialPlacer()) - - -if __name__ == "__main__": - - def str_to_device(device: str): - if device == "npu": - return NPU1() - elif device == "npu2": - return NPU2() - else: - raise ValueError(f"Device name {device} is unknown.") - - p = argparse.ArgumentParser() - # Parse command line arguments - - # Device name is required to select the AIE device: npu or npu2 - p.add_argument( - "-d", - "--dev", - required=True, - dest="device", - help="AIE Device", - type=str_to_device, - ) - # Transfer size is required to define the size of the data to be transferred - p.add_argument( - "-M", "--workload-rows", required=True, dest="work_rows", help="Number of rows" - ) - p.add_argument( - "-N", - "--workload-columns", - required=True, - dest="work_cols", - help="Number of columns", - ) - # Number of columns is required to define the number of columns to be used - # It must be less than or equal to 4 for npu and 8 for npu2 - p.add_argument( - "-co", "--columns", required=True, dest="cols", help="Number of columns" - ) - # Number of channels is required to define the number of channels to be used - # It must be 1 or 2 - p.add_argument( - "-ch", "--channels", required=True, dest="chans", help="Number of channels" - ) - # Tile size - p.add_argument( - "-m", "--tile-rows", required=True, dest="tile_rows", help="Outer tile rows" - ) - p.add_argument( - "-n", - "--tile-columns", - required=True, - dest="tile_cols", - help="Outer tile columns", - ) - p.add_argument( - "-s", - "--kernel-dim", - required=True, - choices=["4", "8"], - dest="kernel_dim", - help="Inner tile dimension (square)", - ) - # Trace Size - p.add_argument( - "-tr", "--trace-size", required=True, dest="trace_size", help="Trace size" - ) - p.add_argument( - "--output-file-path", - "-o", - type=str, - help="Output file path for the generated MLIR module", - ) - - opts = p.parse_args(sys.argv[1:]) - - M = int(opts.work_rows) - N = int(opts.work_cols) - columns = int(opts.cols) - - dev = opts.device # Already a device object from str_to_device - - # Validate columns based on device type - if isinstance(dev, NPU1) and columns > 4: - raise ValueError("[ERROR] Device NPU cannot allocate more than 4 columns") - elif isinstance(dev, NPU2) and columns > 8: - raise ValueError("[ERROR] Device NPU2 cannot allocate more than 8 columns") - - channels = int(opts.chans) - if channels < 1 or channels > 2: - raise ValueError("Number of channels must be 1 or 2") - m = int(opts.tile_rows) - n = int(opts.tile_cols) - s = int(opts.kernel_dim) - if (((M * N) % (m * n)) % columns % channels) != 0: - print( - "transfer size (" - + str(M * N) - + ") must be a multiple of " - + str(m * n) - + f" and divisible by the number of columns ({columns}) and {channels} channels per column" - ) - raise ValueError - trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 - - module = shuffle_transpose(dev, M, N, columns, channels, trace_size, m, n, s) - - output_file_path = Path(opts.output_file_path) - - with open(output_file_path, "w") as f: - f.write(str(module)) diff --git a/operators/transpose/op.py b/operators/transpose/op.py index b71e14f8..92e21828 100644 --- a/operators/transpose/op.py +++ b/operators/transpose/op.py @@ -1,53 +1,44 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -import numpy as np -from ml_dtypes import bfloat16 from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIETranspose(AIEOperatorBase): +class AIETranspose(SingleMLIRSourceOperator): """AIE-accelerated transpose operator""" def __init__(self, M, N, num_aie_columns, num_channels, m, n, s, context=None): + assert M % m == 0, f"Matrix rows ({M}) must be a multiple of {m}" + assert N % n == 0, f"Matrix columns ({N}) must be a multiple of {n}" + assert m % s == 0, f"AIE tile rows ({m}) must be a multiple of {s}" + assert n % s == 0, f"AIE tile columns ({n}) must be a multiple of {s}" + assert M * N % (m * n * num_aie_columns * num_channels) == 0, "Transfer size must be divisible by m*n*num_columns*num_channels" + self.M = M self.N = N self.m = m self.n = n self.s = s - self.size = M * N - self.tile_size = m * n - self.num_columns = num_aie_columns self.num_channels = num_channels + + SingleMLIRSourceOperator.__init__(self, context=context) - total_shimdma_channels = self.num_columns * self.num_channels - if 1 > 1: - total_shimdma_channels *= 1 - assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - - self.xclbin_artifact = None - self.insts_artifact = None - - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"transpose_{self.num_columns}c_{self.num_channels}ch_{self.M}x{self.N}_{self.m}x{self.n}_{self.s}s" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"transpose_{self.num_columns}c_{self.num_channels}ch_{self.M}x{self.N}_{self.m}x{self.n}_{self.s}s" - - mlir_artifact = PythonGeneratedMLIRArtifact.new( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="shuffle_transpose", callback_args=[ @@ -56,73 +47,33 @@ def set_up_artifacts(self): self.N, self.num_columns, self.num_channels, - 0, self.m, self.n, self.s, ], ) - xclbin_artifact = XclbinArtifact.new( - f"{file_name_base}.xclbin", - depends=[ - mlir_artifact, - KernelObjectArtifact.new( - f"transpose_{self.m}x{self.n}.o", - depends=[ - SourceArtifact.new( - self.context.base_dir - / "aie_kernels" - / "generic" - / "transpose.cc" - ) - ], - extra_flags=[ - f"-DDIM_m={self.m}", - f"-DDIM_n={self.n}", - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "transpose", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("transpose", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIETranspose: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact.new( + f"transpose_{self.m}x{self.n}.o", + depends=[ + SourceArtifact.new( + self.context.base_dir + / "aie_kernels" + / "generic" + / "transpose.cc" + ) + ], + extra_flags=[ + f"-DDIM_m={self.m}", + f"-DDIM_n={self.n}", + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.M * self.N,)), # input + AIERuntimeArgSpec("out", (self.M * self.N,)), # output (transposed) + ] From 00861689afa263e115828f21b454feb44e391aca Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 20 Jan 2026 14:08:34 -0700 Subject: [PATCH 34/99] commit forgotten repeat operator --- operators/repeat/design.py | 63 +++++++++++++++++++++++++++++++++++ operators/repeat/op.py | 67 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 operators/repeat/design.py create mode 100644 operators/repeat/op.py diff --git a/operators/repeat/design.py b/operators/repeat/design.py new file mode 100644 index 00000000..71b98c5f --- /dev/null +++ b/operators/repeat/design.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +#from aie.extras.context import mlir_mod_ctx +#from aie.ir import StridedLayoutAttr, ShapedType +#from aie.dialects.aie import * +#from aie.dialects.aiex import * +from aie.dialects.aiex import TensorAccessPattern +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer + + +""" +Repeat interleave +""" +def repeat(dev, dtype, rows, cols, repeat, transfer_size=None): + dtype = np.dtype[dtype] + + # Try to work around hardware size limitations by breaking transfers into smaller chunks + cols_split = 1 + if cols > 1023: + for divisor in range(2, cols + 1): + if cols % divisor == 0 and cols // divisor <= 1023: + cols_split = divisor + break + else: + raise ValueError(f"Cannot split cols={cols} into chunks <= 1023; hardware limits cols to not exceed 1023") + assert cols_split <= 1023, "cols is too large, can't split into smaller transfers" + + if transfer_size is None: + transfer_size = cols + + inp_ty = np.ndarray[(rows, cols), dtype,] + out_ty = np.ndarray[(rows * repeat, cols), dtype,] + transfer_ty = np.ndarray[(transfer_size,), dtype,] + + input_tap = TensorAccessPattern( + tensor_dims=(rows, cols), + offset=0, + sizes=[repeat, rows, cols // cols_split, cols_split], + strides=[0, cols, cols_split, 1], + ) + + output_tap = TensorAccessPattern( + tensor_dims=(rows * repeat, cols), + offset=0, + sizes=[repeat, rows, cols // cols_split, cols_split], + strides=[cols, cols * repeat, cols_split, 1], + ) + + # Use smaller FIFOs for the transfer amount + fifo_in = ObjectFifo(transfer_ty, name="fifo_in", depth=2) + fifo_out = fifo_in.cons().forward(name="fifo_out", depth=2) + + rt = Runtime() + with rt.sequence(inp_ty, out_ty) as (inp, out): + tg = rt.task_group() + rt.fill(fifo_in.prod(), inp, input_tap, task_group=tg) + rt.drain(fifo_out.cons(), out, output_tap, task_group=tg, wait=True) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/operators/repeat/op.py b/operators/repeat/op.py new file mode 100644 index 00000000..13f8da00 --- /dev/null +++ b/operators/repeat/op.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 +from pathlib import Path + +from operators.common import ( + SingleMLIRSourceOperator, + AIERuntimeArgSpec, + KernelObjectArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) + + +class AIERepeat(SingleMLIRSourceOperator): + """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" + + def __init__( + self, + rows, + cols, + repeat, + transfer_size=None, + dtype=bfloat16, + context=None, + ): + self.rows = rows + self.cols = cols + self.repeat = repeat + self.transfer_size = transfer_size + self.dtype = dtype + SingleMLIRSourceOperator.__init__(self, context=context) + + def get_operator_name(self): + name = f"repeat_{self.rows}x{self.cols}_by_{self.repeat}" + if self.transfer_size is not None: + name += f"_{self.transfer_size}ts" + return name + + def get_mlir_artifact(self): + operator_dir = Path(__file__).parent + + return PythonGeneratedMLIRArtifact.new( + f"{self.get_operator_name()}.mlir", + import_path=operator_dir / "design.py", + callback_fn="repeat", + callback_args=[ + self.context.device_manager.device_type, + self.dtype, + self.rows, + self.cols, + self.repeat, + self.transfer_size, + ] + ) + + def get_kernel_artifacts(self): + return [] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.rows, self.cols)), + AIERuntimeArgSpec("out", (self.rows * self.repeat, self.cols)), + ] From ffca54121cf2f2df67beceb927fa1d4ff21af3d6 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 20 Jan 2026 14:58:04 -0700 Subject: [PATCH 35/99] offload last GEMV in GQA, reorganize/simplify decode code --- applications/llama_3.2_1b/llama_npu.py | 302 +++++++++---------------- 1 file changed, 105 insertions(+), 197 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index c41e82a2..03a532ec 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -200,21 +200,7 @@ def __init__(self, config, prompt_len): ).compile().get_callable() # Softmax operators for attention weights - self.prefill.softmax_compilable = AIESoftmax( - rows=config.n_heads * prompt_len, - cols=prompt_len, - num_aie_columns=8, - num_channels=1, - rtp_vector_size=prompt_len, # Compile with max size - context=self.context - ).compile() - - self.prefill.softmax = PatchableSingleXclbinCallable( - xclbin_path=self.prefill.softmax_compilable.xclbin_artifact.path, - kernel_name=self.prefill.softmax_compilable.xclbin_artifact.kernel_name, - insts_bin_path=self.prefill.softmax_compilable.insts_artifact.path, - args_spec=self.prefill.softmax_compilable.get_arg_spec() - ) + # Prefill uses CPU softmax to reduce NPU operator count self.decode.softmax_compilable = AIESoftmax( rows=config.n_heads, @@ -333,7 +319,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.decode.gemv_attn_key = AIEGEMV( + self.decode.gemv_attn_key_value = AIEGEMV( M=config.n_kv_groups * config.head_dim, K=config.emb_dim, num_aie_columns=8, @@ -355,15 +341,6 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - self.decode.gemv_attn_value = AIEGEMV( - M=config.n_kv_groups * config.head_dim, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.head_dim // 2, - context=self.context - ).compile().get_callable() - # Attention score computation: Q @ K^T per head # For prefill: (seq_len, head_dim) @ (head_dim, seq_len) = (seq_len, seq_len) per head self.prefill.attn_scores = AIEGEMM( @@ -412,6 +389,16 @@ def __init__(self, config, prompt_len): num_batches=config.n_heads, context=self.context ).compile().get_callable() + + # Output projection: (n_heads * head_dim,) @ (emb_dim, n_heads * head_dim)^T -> (emb_dim,) + self.decode.gemv_attn_output = AIEGEMV( + M=config.emb_dim, + K=config.n_heads * config.head_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.emb_dim // 8, + context=self.context + ).compile().get_callable() # Allocate buffers shared with NPU @@ -510,6 +497,7 @@ def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_cont for h in range(n_heads) ] self.attn_context = AIEBuffer(shape=(n_heads, head_dim), dtype=ml_dtypes.bfloat16) + self.attn_context_concat = AIEBuffer(shape=(n_heads * head_dim,), dtype=ml_dtypes.bfloat16) self.attn_scores = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) scale_factor = 1.0 / math.sqrt(head_dim) @@ -544,6 +532,7 @@ def __init__(self, config, prompt_len): self.W_attn_key_decode = [] self.W_attn_value_prefill = [] self.W_attn_value_decode = [] + self.W_attn_output_decode = [] # SwiGLU FFN weights self.W_ffn_gate_prefill = [] self.W_ffn_up_prefill = [] @@ -576,6 +565,9 @@ def __init__(self, config, prompt_len): self.W_attn_value_prefill.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'].T).to("npu") ) + self.W_attn_output_decode.append( + AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']).to("npu") + ) self.W_ffn_gate_decode.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight']).to("npu") ) @@ -619,7 +611,7 @@ def __init__(self, config, prompt_len): self.decode.logits = AIEBuffer(shape=(config.vocab_size,)) -# Operators +# Prefill # ########################################################################## def grouped_query_attention_forward_prefill( @@ -708,17 +700,9 @@ def grouped_query_attention_forward_prefill( if mask is not None: scores = scores.masked_fill(mask, float('-inf')) - # Step 8: Apply softmax to squeeze scores into probabilities (0, 1) - # Write scores back to NPU buffer for softmax, handling variable seq_len and context_len - scores_buf = aie_buffers.prefill.attn_scores.view_as_torch().view(config.n_heads, max_seq_len, max_seq_len) - scores_buf[:, :seq_len, :context_len] = scores.squeeze(0) - # Pad unused regions with -inf so they don't affect softmax - scores_buf[:, :seq_len, context_len:] = float('-inf') - scores_buf[:, seq_len:, :] = float('-inf') - aie_buffers.prefill.attn_scores.to("npu") - aie_ops.prefill.softmax(aie_buffers.prefill.attn_scores, aie_buffers.prefill.attn_weights) - aie_buffers.prefill.attn_weights.to("cpu") - attention_weights = aie_buffers.prefill.attn_weights.view_as_torch().view(config.n_heads, max_seq_len, max_seq_len).unsqueeze(0)[:, :, :seq_len, :context_len] + # Step 8: Apply softmax on CPU + scores = torch.softmax(scores.to(torch.float32), dim=-1).to(torch.bfloat16) + attention_weights = scores # Step 9: Compute attention output # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) @@ -734,76 +718,6 @@ def grouped_query_attention_forward_prefill( return output, keys_cache, values_cache -def grouped_query_attention_forward_decode( - config, - x, - num_preceding_tokens, - layer_idx, - mask=None, -): - batch, seq_len, emb_dim = x.shape - - # Step 1: Linear projections - write directly to queries/keys/values buffers - aie_ops.decode.gemv_attn_query(aie_buffers.W_attn_query_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.queries) - aie_ops.decode.gemv_attn_key(aie_buffers.W_attn_key_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.keys) - aie_ops.decode.gemv_attn_value(aie_buffers.W_attn_value_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.values) - - # Step 2: Apply RoPE - use same buffers for input and output - aie_ops.decode.rope_queries(aie_buffers.decode.queries, aie_buffers.decode.rope_angles, aie_buffers.decode.queries) - aie_ops.decode.rope_keys(aie_buffers.decode.keys, aie_buffers.decode.rope_angles, aie_buffers.decode.keys) - - # Read results from NPU for CPU reference computation - aie_buffers.decode.queries.to("cpu") - queries = aie_buffers.decode.queries.view_as_torch()[:seq_len * config.n_heads, :] - # Since seq_len=1, the transpose is just a reinterpretation of the shape; no actual data movement needed - queries = queries.view(batch, config.n_heads, 1, config.head_dim) - - # Step 3: Update cache using strided copy on NPU (transpose and concatenate) - # Cache is already on NPU from prefill initialization or previous decode iteration - context_len = num_preceding_tokens + seq_len - - # Transpose and append new keys/values to this layer's cache on NPU - aie_ops.decode.strided_copy_cache(aie_buffers.decode.keys, aie_buffers.keys_cache[layer_idx]) - aie_ops.decode.strided_copy_cache(aie_buffers.decode.values, aie_buffers.values_cache[layer_idx]) - - # Step 4: Repeat keys and values for grouped attention using AIERepeat on NPU - group_size = config.n_heads // config.n_kv_groups - aie_ops.decode.attn_repeat_interleave(aie_buffers.keys_cache[layer_idx], aie_buffers.decode.attn_scores_keys) - aie_ops.decode.attn_repeat_interleave(aie_buffers.values_cache[layer_idx], aie_buffers.decode.attn_scores_values) - - # Step 5: Compute attention scores - # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV - aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) - # Apply scaling on NPU - aie_ops.decode.attn_scale(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_scale_factor, aie_buffers.decode.attn_scores) - - # Step 7: Softmax on NPU (patched once at beginning of decode pass) - aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) - - # Step 8: Compute attention output on NPU - # Transpose values: (max_context_len, head_dim) -> (head_dim, max_context_len) for each head - for h in range(config.n_heads): - aie_ops.decode.transpose_values( - aie_buffers.decode.attn_scores_values_per_head[h], - aie_buffers.decode.attn_scores_values_transposed_per_head[h] - ) - - # GEMV: (n_heads, head_dim, max_context_len) @ (n_heads, max_context_len) -> (n_heads, head_dim) - aie_ops.decode.gemv_attn_context(aie_buffers.decode.attn_scores_values_transposed, aie_buffers.decode.attn_weights, aie_buffers.decode.attn_context) - - # Read context from NPU - aie_buffers.decode.attn_context.to("cpu") - context = aie_buffers.decode.attn_context.view_as_torch().unsqueeze(1) # (n_heads, 1, head_dim) - - # Step 9: Concatenate heads and project - # (n_heads, 1, head_dim) -> (n_heads, head_dim, 1) -> (1, 1, n_heads * head_dim) - context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) - # (1, 1, n_heads * head_dim) @ (emb_dim, n_heads * head_dim)^T -> (1, 1, emb_dim) - output = torch.nn.functional.linear(context, config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']) - - return output, None, None - - def swiglu_ffn_forward_prefill(layer_idx): # Step 1: Gate projection aie_ops.prefill.ffn_up_gate(aie_buffers.prefill.x_norm, aie_buffers.W_ffn_gate_prefill[layer_idx], aie_buffers.prefill.ffn_gate) @@ -821,23 +735,6 @@ def swiglu_ffn_forward_prefill(layer_idx): aie_ops.prefill.ffn_down(aie_buffers.prefill.ffn_hidden, aie_buffers.W_ffn_down_prefill[layer_idx], aie_buffers.prefill.ffn_output) -def swiglu_ffn_forward_decode(layer_idx): - # Step 1: Gate projection - aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_gate_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_gate) - - # Step 2: Up projection - aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_up_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_up) - - # Step 3: Apply SiLU activation - aie_ops.decode.ffn_silu(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_gate) - - # Step 4: Element-wise multiplication - aie_ops.decode.eltwise_mul_ffn(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_up, aie_buffers.decode.ffn_hidden) - - # Step 5: Down projection - aie_ops.decode.gemv_ffn_down(aie_buffers.W_ffn_down_decode[layer_idx], aie_buffers.decode.ffn_hidden, aie_buffers.decode.ffn_output) - - def transformer_block_forward_prefill( config, seq_len, @@ -881,47 +778,6 @@ def transformer_block_forward_prefill( return attn_keys, attn_values -def transformer_block_forward_decode( - config, - seq_len, - num_preceding_tokens, - layer_idx, - attn_mask -): - # Step 1: RMS normalization - aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm) - aie_buffers.decode.x_norm.to("cpu") - x_norm = aie_buffers.decode.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] - - # Step 2: Attention - attn_output, attn_keys, attn_values = grouped_query_attention_forward_decode( - config, - x_norm, - num_preceding_tokens, - layer_idx, - attn_mask, - ) - - # Step 3: Residual - aie_buffers.decode.attn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output - aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.attn_output, aie_buffers.decode.x) - x = aie_buffers.decode.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] - - # Step 4: Post-norm - aie_buffers.decode.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm2[layer_idx], aie_buffers.decode.x_norm) - aie_buffers.decode.x_norm.to("cpu") - x_norm = aie_buffers.decode.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] - - # Step 5: Feed-forward network - swiglu_ffn_forward_decode(layer_idx) - - # Step 6: Residual - aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.ffn_output, aie_buffers.decode.x) - - return attn_keys, attn_values - - def llama_forward_pass_prefill( config, state @@ -932,11 +788,6 @@ def llama_forward_pass_prefill( num_preceding_tokens = state.attn_keys_caches[0].shape[2] angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] aie_buffers.prefill.rope_angles.view_as_torch()[:seq_len, :] = angles_slice - - # Patch softmax operator once for this prefill pass with the context length - context_len = num_preceding_tokens + seq_len - softmax_patches = {8: (context_len, 0xFFFFFFFF)} - aie_ops.prefill.softmax.patch(softmax_patches) # Step 2: Token embedding tok_emb_weight = config.weights['model.embed_tokens.weight'] @@ -980,15 +831,11 @@ def llama_forward_pass_prefill( return logits, state -def llama_forward_pass_decode( - config, - state, -): - batch, seq_len = state.token_ids.shape +# Decode +# ########################################################################## - # Patch operators once for all layers with current context length - num_preceding_tokens = state.num_preceding_tokens - context_len = num_preceding_tokens + seq_len +def patch_operators_for_decode(config, num_preceding_tokens): + context_len = num_preceding_tokens + 1 # Patch strided copy operator for cache offset output_offset = num_preceding_tokens * config.head_dim @@ -1003,40 +850,104 @@ def llama_forward_pass_decode( softmax_patches = {8: (context_len, 0xFFFFFFFF)} aie_ops.decode.softmax.patch(softmax_patches) - # Step 1: RoPE angles - angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] + +def llama_forward_pass_decode(config, state): + batch, seq_len = state.token_ids.shape + assert seq_len == 1 + + patch_operators_for_decode(config, state.num_preceding_tokens) + + # Step 1: Prefill RoPE angle look-up tables + angles_slice = config.angles[state.num_preceding_tokens : state.num_preceding_tokens + seq_len] aie_buffers.decode.rope_angles.view_as_torch()[:] = angles_slice - # Step 2: Token embedding + # Step 2: Token embedding (on CPU) tok_emb_weight = config.weights['model.embed_tokens.weight'] x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) - attn_mask = torch.triu( - torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), - diagonal=1 - ) aie_buffers.decode.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x # Step 3: Transformer blocks for layer_idx in range(config.n_layers): - state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward_decode( + transformer_block_forward_decode( config, - seq_len, - num_preceding_tokens, + state.num_preceding_tokens, layer_idx, - attn_mask=attn_mask, ) + aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_final_norm, aie_buffers.decode.x) # Step 4: Final normalization + aie_ops.decode.gemv_out_head(aie_buffers.W_out_head, aie_buffers.decode.x, aie_buffers.decode.logits) # Step 5: Output projection - # Step 4: Final normalization - aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_final_norm, aie_buffers.decode.x) - - # Step 5: Output projection - aie_ops.decode.gemv_out_head(aie_buffers.W_out_head, aie_buffers.decode.x, aie_buffers.decode.logits) + # Read outputs from NPU to CPU aie_buffers.decode.logits.to("cpu") logits = aie_buffers.decode.logits.view_as_torch().view(1, 1, config.vocab_size) return logits, state +def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx): + aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm) # Step 1: RMS normalization + grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx) # Step 2: Attention; results stored in attn_output + aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.attn_output, aie_buffers.decode.x) # Step 3: Residual + aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm2[layer_idx], aie_buffers.decode.x_norm) # Step 4: Post-norm + swiglu_ffn_forward_decode(layer_idx) # Step 5: Feed-forward network + aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.ffn_output, aie_buffers.decode.x) # Step 6: Residual + + +def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx): + context_len = num_preceding_tokens + 1 + group_size = config.n_heads // config.n_kv_groups + + # Step 1: Linear projections - write directly to queries/keys/values buffers + aie_ops.decode.gemv_attn_query(aie_buffers.W_attn_query_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.queries) + aie_ops.decode.gemv_attn_key_value(aie_buffers.W_attn_key_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.keys) + aie_ops.decode.gemv_attn_key_value(aie_buffers.W_attn_value_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.values) + + # Step 2: Apply RoPE - use same buffers for input and output + aie_ops.decode.rope_queries(aie_buffers.decode.queries, aie_buffers.decode.rope_angles, aie_buffers.decode.queries) + aie_ops.decode.rope_keys(aie_buffers.decode.keys, aie_buffers.decode.rope_angles, aie_buffers.decode.keys) + + # Step 3: Update cache using strided copy on NPU (transpose and concatenate) + # Cache is already on NPU from prefill initialization or previous decode iteration + # Transpose and append new keys/values to this layer's cache on NPU + aie_ops.decode.strided_copy_cache(aie_buffers.decode.keys, aie_buffers.keys_cache[layer_idx]) + aie_ops.decode.strided_copy_cache(aie_buffers.decode.values, aie_buffers.values_cache[layer_idx]) + + # Step 4: Repeat keys and values for grouped attention using AIERepeat on NPU + aie_ops.decode.attn_repeat_interleave(aie_buffers.keys_cache[layer_idx], aie_buffers.decode.attn_scores_keys) + aie_ops.decode.attn_repeat_interleave(aie_buffers.values_cache[layer_idx], aie_buffers.decode.attn_scores_values) + + # Step 5: Compute attention scores + # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV + aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) + aie_ops.decode.attn_scale(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_scale_factor, aie_buffers.decode.attn_scores) + + # Step 7: Softmax on NPU (patched once at beginning of decode pass) + aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) + + # Step 8: Compute attention output on NPU + # Transpose values: (max_context_len, head_dim) -> (head_dim, max_context_len) for each head + for h in range(config.n_heads): + aie_ops.decode.transpose_values( + aie_buffers.decode.attn_scores_values_per_head[h], + aie_buffers.decode.attn_scores_values_transposed_per_head[h] + ) + # GEMV: (n_heads, head_dim, max_context_len) @ (n_heads, max_context_len) -> (n_heads, head_dim) + aie_ops.decode.gemv_attn_context(aie_buffers.decode.attn_scores_values_transposed, aie_buffers.decode.attn_weights, aie_buffers.decode.attn_context) + + # Step 9: Project on NPU: (emb_dim, n_heads * head_dim) @ (n_heads * head_dim,) -> (emb_dim,) + aie_ops.decode.gemv_attn_output(aie_buffers.W_attn_output_decode[layer_idx], aie_buffers.decode.attn_context, aie_buffers.decode.attn_output) + + +def swiglu_ffn_forward_decode(layer_idx): + aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_gate_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_gate) # Gate projection + aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_up_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_up) # Up projection + aie_ops.decode.ffn_silu(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_gate) # SiLU activation + aie_ops.decode.eltwise_mul_ffn(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_up, aie_buffers.decode.ffn_hidden) # Gate application (eltwise mul) + aie_ops.decode.gemv_ffn_down(aie_buffers.W_ffn_down_decode[layer_idx], aie_buffers.decode.ffn_hidden, aie_buffers.decode.ffn_output) # Down projection + + +# Main +# ########################################################################## + def llama_forward_pass( config, state @@ -1052,9 +963,6 @@ def llama_forward_pass( return ret -# Main -# ########################################################################## - def main(): global aie_ops, aie_buffers max_seq_len = 2048 From c7926e8c6367abb9b905deb14f03f09c23b336c9 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 20 Jan 2026 20:09:26 -0700 Subject: [PATCH 36/99] initial steps for automatically fusing operators --- applications/llama_3.2_1b/autofuse.py | 264 ++++++++++++++++++++++++++ 1 file changed, 264 insertions(+) create mode 100755 applications/llama_3.2_1b/autofuse.py diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py new file mode 100755 index 00000000..85e757b6 --- /dev/null +++ b/applications/llama_3.2_1b/autofuse.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 + +import torch +import math +from pathlib import Path +import sys +import numpy as np +import ml_dtypes +import logging +import time +import importlib +from aie import ir +from aie.dialects import aie, aiex +from aie.extras.context import mlir_mod_ctx + +repo_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(repo_root)) + +from operators.common.aie_context import AIEContext +from operators.common import AIEOperatorBase, AIEBuffer, SingleMLIRSourceOperator +from operators.common.utils import torch_to_numpy, numpy_to_torch +from operators.common.compilation import PythonGeneratedMLIRArtifact +from operators import AIEGEMV +from operators.elementwise_mul.op import AIEElementwiseMul +from operators.silu.op import AIESiLU + + +emb_dim = 2048 +hidden_dim = 8192 + +# Operator definitions +# --- + +gemv_ffn_up_gate_op = AIEGEMV( + M=hidden_dim, + K=emb_dim, + num_aie_columns=1, + tile_size_input=4, + tile_size_output=hidden_dim // 8, +) + +gemv_ffn_down_op = AIEGEMV( + M=emb_dim, + K=hidden_dim, + num_aie_columns=1, + tile_size_input=1, + tile_size_output=emb_dim // 8, +) + +silu_ffn_op = AIESiLU( + size=hidden_dim, + tile_size=hidden_dim // 8, + num_aie_columns=1, +) + +eltwise_mul_ffn_op = AIEElementwiseMul( + size=hidden_dim, + tile_size=hidden_dim // 8, + num_aie_columns=1, +) + + +# Buffers +# --- + +buf_W_ffn_gate = AIEBuffer.from_torch(torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16)) +buf_W_ffn_up = AIEBuffer.from_torch(torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16)) +buf_W_ffn_down = AIEBuffer.from_torch(torch.randn(emb_dim, hidden_dim, dtype=torch.bfloat16)) +buf_x_norm = AIEBuffer.from_torch(torch.randn(emb_dim, dtype=torch.bfloat16)) +buf_ffn_gate = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) +buf_ffn_up = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) +buf_ffn_hidden = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) +buf_ffn_output = AIEBuffer.from_torch(torch.zeros(emb_dim, dtype=torch.bfloat16)) + + +# Separate xclbins +# --- + +gemv_ffn_up_gate = gemv_ffn_up_gate_op.compile().get_callable() +gemv_ffn_down = gemv_ffn_down_op.compile().get_callable() +silu_ffn = silu_ffn_op.compile().get_callable() +eltwise_mul_ffn = eltwise_mul_ffn_op.compile().get_callable() + +def run_separate_xclbins(): + gemv_ffn_up_gate(buf_W_ffn_gate, buf_x_norm, buf_ffn_gate) # Gate projection + gemv_ffn_up_gate(buf_W_ffn_up, buf_x_norm, buf_ffn_up) # Up projection + silu_ffn(buf_ffn_gate, buf_ffn_gate) # SiLU activation + eltwise_mul_ffn(buf_ffn_gate, buf_ffn_up, buf_ffn_hidden) # Gate application (eltwise mul) + gemv_ffn_down(buf_W_ffn_down, buf_ffn_hidden, buf_ffn_output) # Down projection + return buf_ffn_output.to("cpu").view_as_torch() + + +# Autofused +# --- + +class FusedMLIROperator(SingleMLIRSourceOperator): + def __init__(self, name, runlist, input_args, output_args, *args, **kwargs): + assert all( + isinstance(op, SingleMLIRSourceOperator) and all(isinstance(buf, str) for buf in bufs) + for op, *bufs in runlist + ) + # Runlist is a list of operators and names for their buffer arguments. + # Shapes for the named buffer arguments are derived from the operator's argument specification. + # To pass data between operators, use the same buffer name in multiple operators. + # If the same buffer name is used in multiple operators, the required buffer shapes must match for each operator. + self.runlist = runlist + self.name = name + self.input_args = input_args + self.output_args = output_args + self.args = {} + self.populate_args() + AIEOperatorBase.__init__(self, *args, **kwargs) + + def populate_args(self): + for op, *bufs in self.runlist: + args_specs = op.get_arg_spec() + assert len(args_specs) == len(bufs), "Number of buffers must match operator argument specification" + for i, buf_name in enumerate(bufs): + args_spec = args_specs[i] + if buf_name not in self.args: + self.args[buf_name] = args_spec + else: + assert np.prod(self.args[buf_name].shape) == np.prod(args_spec.shape), f"Buffer {buf_name} has conflicting sizes between operators" + for arg in self.input_args: + assert arg in self.args, f"Input argument {arg} not found in runlist buffers" + for arg in self.output_args: + assert arg in self.args, f"Output argument {arg} not found in runlist buffers" + + def get_operator_name(self): + return self.name + + def get_kernel_artifacts(self): + kernel_artifacts = [] + for op, *bufs in self.runlist: + kernel_artifacts.extend(op.get_kernel_artifacts()) + return kernel_artifacts + + @staticmethod + def get_child_mlir_module(artifact): + assert isinstance(artifact, PythonGeneratedMLIRArtifact) + # Import the Python source file + spec = importlib.util.spec_from_file_location( + Path(artifact.import_path).name, artifact.import_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context + if artifact.requires_context: + raise NotImplementedError("Not handled, make your operator return a ctx.module") + callback_function = getattr(module, artifact.callback_fn) + mlir_module = callback_function( + *artifact.callback_args, **artifact.callback_kwargs + ) + return mlir_module + + def get_mlir_artifact(self): + device_mlir_strings = {} # op -> device str + device_ty = None + # FIXME: The proper way for this would be to create a new type of artifact (FusedMLIRArtifact) and a new compilation rule that does what this function steps _only if_ the fused MLIR file doesn't exist yet. + # As it stands, we're regenerating it on each run. + for runlist_op, *bufs in self.runlist: + if runlist_op in device_mlir_strings: + continue + artifact = runlist_op.get_mlir_artifact() + mlir_module = self.get_child_mlir_module(artifact) + for op in mlir_module.body.operations: + if not isinstance(op, aie.DeviceOp): + continue + if device_ty is None: + device_ty = op.device + # else: + # assert device_ty == op.device, "All operators in a fused operator must target the same type of AIE" + device_mlir_strings[runlist_op] = str(op) + + device_names = {} # op -> str + with mlir_mod_ctx() as ctx: + for i, (runlist_op, device_str) in enumerate(device_mlir_strings.items()): + dev_op = aie.DeviceOp.parse(device_str) + device_names[runlist_op] = f"dev{i}" + dev_op.sym_name = ir.StringAttr.get(device_names[runlist_op]) + ctx.module.body.append(dev_op) + @aie.device(device_ty) + def main(): + # Argument 0 is scratch space for intermediate values. + # All other arguments are defined by the input/output buffers. + @aiex.runtime_sequence( + np.ndarray[(1,), np.dtype[np.int8]], + np.ndarray[(1,), np.dtype[np.int8]], + np.ndarray[(1,), np.dtype[np.int8]], + ) + def sequence(input_buf, output_buf, scratch_buf): + for runlist_op, *bufs in self.runlist: + configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(device_names[runlist_op]) + configure_op = aiex.ConfigureOp(configure_sym_ref_attr) + configure_body = configure_op.body.blocks.append() + with ir.InsertionPoint(configure_body): + sequence_sym_ref_attr = ir.FlatSymbolRefAttr.get("sequence") + run_op = aiex.RunOp(sequence_sym_ref_attr, [input_buf]) + print(str(ctx.module)) + print(ctx.module.operation.verify()) + print("success") + sys.exit(0) + + def get_arg_spec(self): + pass + +class FusedFullELFCallable: + def __init__(self, op): + self.op = op + + def __call__(self, *kwargs): + assert all(kw in self.op.args for kw in kwargs), "at least one unknown argument passed" + assert all(kw in kwargs for kw in self.op.input_args), "not all input arguments passed" + assert all(kw in kwargs for kw in self.op.output_args), "not all output arguments passed" + + +swiglu_fused_op = FusedMLIROperator( + "swiglu", + [ + (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "inter_ffn_gate"), + (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "inter_ffn_up"), + (silu_ffn_op, "inter_ffn_gate", "inter_ffn_gate"), + (eltwise_mul_ffn_op, "inter_ffn_gate", "inter_ffn_up", "inter_ffn_hidden"), + (gemv_ffn_down_op, "W_ffn_down", "inter_ffn_hidden", "ffn_output"), + ], + input_args=[ + "x_norm", + "W_ffn_gate", + "W_ffn_up", + "W_ffn_down" + ], + output_args=[ + "ffn_output" + ] +) +swiglu_fused = swiglu_fused_op.compile().get_callable() + +def run_autofused(): + swiglu_fused() + +# CPU +# --- + +def run_cpu(): + x_norm = buf_x_norm.view_as_torch() + W_ffn_gate = buf_W_ffn_gate.view_as_torch() + W_ffn_up = buf_W_ffn_up.view_as_torch() + W_ffn_down = buf_W_ffn_down.view_as_torch() + + ffn_gate = torch.matmul(W_ffn_gate, x_norm) + ffn_up = torch.matmul(W_ffn_up, x_norm) + ffn_gate = torch.nn.functional.silu(ffn_gate) + ffn_hidden = ffn_gate * ffn_up + ffn_output = torch.matmul(W_ffn_down, ffn_hidden) + + return ffn_output + + +# Main +# --- + +print(run_autofused()) +print(run_separate_xclbins()) +print(run_cpu()) From f16695776584614f8f1b56d074c1b6e3993e7ee6 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 13:00:40 -0700 Subject: [PATCH 37/99] [WIP] compilation refactor --- operators/common/compilation.py | 582 ++++++++++++++------------------ 1 file changed, 253 insertions(+), 329 deletions(-) diff --git a/operators/common/compilation.py b/operators/common/compilation.py index f3f51344..46f77e87 100644 --- a/operators/common/compilation.py +++ b/operators/common/compilation.py @@ -40,102 +40,151 @@ from aie.extras.context import mlir_mod_ctx +# Global Functions +# ########################################################################## + + +def plan(rules, graph: CompilationArtifactGraph): + if all(artifact.is_available() for artifact in graph): + return [] # Everything has been compiled + for rule in rules: + if rule.matches(graph): + commands, new_graph = rule.compile(graph) + break + else: + raise RuntimeError( + f"No matching rule to compile target(s): {', '.join(artifact.filename for artifact in graph)}" + ) + return [(rule, commands, graph)] + plan(rules, new_graph) + + +def execute(plan): + for rule, commands, _ in plan: + logging.debug(f"Executing rule: {rule.__class__.__name__}") + for command in commands: + logging.debug(f" Executing command: {command}") + success = command.run() + if not success: + raise RuntimeError(f"Command failed: {command}") + + +def compile(rules, artifacts): + plan_steps = plan(rules, artifacts) + print(plan_steps) + execute(plan_steps) + + +# Compilation Artifact Graph +# ########################################################################## + + +class CompilationArtifactGraph: + def __init__(self, artifacts=None): + self.artifacts = artifacts if artifacts is not None else [] + + def __iter__(self): + return iter(self.artifacts) + + def dfs(self): + return self._traverse(True) + + def bfs(self): + return self._traverse(False) + + def _traverse(self, dfs): + visited = set() + todo = self.artifacts.copy() + while todo: + artifact = todo.pop() if dfs else todo.pop(0) + if artifact in visited: + continue + visited.add(artifact) + todo.extend(artifact.dependencies) + yield artifact + + def copy(self): + return CompilationArtifactGraph(artifacts=self.artifacts.copy()) + + def replace(self, old_artifact, new_artifact): + for i, artifact in enumerate(self.artifacts): + if artifact == old_artifact: + self.artifacts[i] = new_artifact + else: + artifact.dependencies.replace(old_artifact, new_artifact) + return self + + def populate_availability_from_filesystem(self): + for artifact in self.artifacts: + artifact.available = artifact.is_available_in_filesystem() + + def get_worklist(self, kind): + """Return a list of artifacts of the given kind that can be built in the next step (dependencies available).""" + return [ + artifact + for artifact in self.artifacts.bfs() + if isinstance(artifact, kind) + and not artifact.is_available() + and artifact.dependencies_available() + ] + + # Compilation Artifacts -# -------------------------------------------------------------------------- +# ########################################################################## class CompilationArtifact(ABC): - _instances = {} - - @classmethod - def new(cls, path, *args, **kwargs): - """Uniques artifacts based on absolute file path; any two artifacts with the same absolute path will be represented by the same object.""" - path = Path(path) - abs_path = path.absolute() - if abs_path not in cls._instances: - cls._instances[abs_path] = None - instance = cls(path, *args, **kwargs) - cls._instances[abs_path] = instance - else: - assert ( - type(cls._instances[abs_path]) == cls - ), f"Artifact with path {abs_path} is already registered with a different type" - return cls._instances[abs_path] - - def __init__(self, path, depends=None): - abs_path = path.absolute() - assert ( - abs_path in self._instances - ), "do not construct artifact objects directly; call the get() class method instead for uniquing" - self.path: Path = path - self.depends: list[CompilationArtifact] = depends if depends is not None else [] - self.users: list[CompilationArtifact] = ( - [] - ) # List of ancestor artifacts that depend on this artifact - for dependency in self.depends: - dependency.users.append(self) - self.fake_available = False + def __init__(self, filename, dependencies=None, available=False): + self.filename = filename + self.dependencies: CompilationArtifactGraph = CompilationArtifactGraph(artifacts=dependencies if dependencies is not None else []) + self.available = available + return self def __repr__(self): - return f"{self.__class__.__name__}(path={self.path}, depends={self.depends})" - - def set_path(self, new_path): - old_abs_path = self.path.absolute() - new_path = Path(new_path) - abs_path = new_path.absolute() - self.path = new_path - del CompilationArtifact._instances[old_abs_path] - CompilationArtifact._instances[abs_path] = self + return f"{self.__class__.__name__}({self.filename})" def is_available(self): - if self.fake_available: - return True - if not self.path.exists(): + """'Conceptual' availability: during a dry-run or in the planning stage, available may be True even if the underlying file does not exist yet.""" + # If any of our dependencies' dependencies are outdated, this artifact is also outdated + return self.available and self.dependencies_available() + + def dependencies_available(self): + return all(d.is_available() for d in self.dependencies) + + def is_available_in_filesystem(self): + """'Real' availability: checks if the underlying file exists and is up-to-date with respect to dependencies.""" + if not os.path.exists(self.filename): return False - for dependency in self.depends: - # If any of our dependencies' dependencies are outdated, this artifact is also outdated - if not dependency.is_available(): - return False - # If any of our direct dependencies are newer than this artifact, this artifact is invalid - if dependency.is_newer_than(os.path.getmtime(str(self.path))): + file_mtime = os.path.getmtime(self.filename) + for dependency in self.dependencies: + if not dependency.is_available_in_filesystem() or os.path.getmtime(dependency.filename) > file_mtime: return False return True - def is_newer_than(self, time): - if self.fake_available: - return True - return os.path.getmtime(str(self.path)) > time - - def delete(self): - for user in self.users: - user.depends.remove(self) - del self._instances[self.path.absolute()] - return self.users - class SourceArtifact(CompilationArtifact): + """Artifact representing a source file that does not need to be generated, is assumed to be there.""" pass class XclbinArtifact(CompilationArtifact): def __init__( - self, path, depends, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None + self, filename, dependencies, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None ): - super().__init__(path, depends) + super().__init__(filename, dependencies) self.kernel_name = kernel_name self.extra_flags = extra_flags if extra_flags is not None else [] self.xclbin_input = xclbin_input class InstsBinArtifact(CompilationArtifact): - def __init__(self, path, depends, extra_flags=None): - super().__init__(path, depends) + def __init__(self, filename, dependencies, extra_flags=None): + super().__init__(filename, dependencies) self.extra_flags = extra_flags if extra_flags is not None else [] class KernelObjectArtifact(CompilationArtifact): - def __init__(self, path, depends, extra_flags=None, rename_symbols=None): - super().__init__(path, depends) + def __init__(self, filename, dependencies, extra_flags=None, rename_symbols=None): + super().__init__(filename, dependencies) self.extra_flags = extra_flags if extra_flags is not None else [] self.rename_symbols = rename_symbols if rename_symbols is not None else {} @@ -147,212 +196,166 @@ class KernelArchiveArtifact(CompilationArtifact): class PythonGeneratedMLIRArtifact(CompilationArtifact): def __init__( self, - path, + filename, import_path, callback_fn, callback_args=None, callback_kwargs=None, requires_context=False, ): - self.import_path = import_path self.callback_fn = callback_fn self.callback_args = callback_args if callback_args is not None else [] self.callback_kwargs = callback_kwargs if callback_kwargs is not None else {} self.requires_context = requires_context - super().__init__(path) + super().__init__(filename, dependencies=[]) - def is_available(self): - if self.fake_available: - return True - is_available = super().is_available() - if is_available: - # Force regeneration if the Python source is changed - return os.path.getmtime(str(self.path)) >= os.path.getmtime( - self.import_path - ) - return is_available + +# Compilation Command +# ########################################################################## + + +class CompilationCommand(ABC): + """An abstraction for anything that can be executed to physically produce artifacts.""" + @abstractmethod + def run(self) -> bool: + pass + + @abstractmethod + def __repr__(self): + pass + + +class ShellCompilationCommand(CompilationCommand): + def __init__(self, command: list[str], cwd=None, env='copy'): + self.command = command + self.cwd = cwd + if env == 'copy': + env = os.environ.copy() + self.env = env + + def run(self) -> bool: + result = subprocess.run( + self.command, + capture_output=True, + text=True, + cwd=self.cwd, + env=self.env, + ) + return 0 == result.returncode + + def __repr__(self): + return f"Shell({self.command})" + + +class PythonCallbackCompilationCommand(CompilationCommand): + def __init__(self, callback): + self.callback = callback + + def run(self) -> bool: + return bool(self.callback()) + + def __repr__(self): + return f"PythonCallback({self.callback})" # Compilation Rules -# -------------------------------------------------------------------------- +# ########################################################################## class CompilationRule(ABC): - def __init__(self, dry_run=None): - self.dry_run = dry_run + """A compilation rule is applied to a artifact graph, producing compilation commands and a transformed artifact graph.""" @abstractmethod - def matches(self, artifact: list[CompilationArtifact]) -> bool: + def matches(self, artifact: CompilationArtifactGraph) -> bool: + """Return true if this rule can be applied to any artifact in the artifact graph.""" pass - + @abstractmethod def compile( - self, artifacts: list[CompilationArtifact] - ) -> list[CompilationArtifact]: + self, artifacts: CompilationArtifactGraph + ) -> list[CompilationCommand]: + """Apply this rule to the artifact graph, returning compilation commands. This should modify the artifact graph in-place to reflect the newly generated artifacts.""" pass class GenerateMLIRFromPythonCompilationRule(CompilationRule): - def matches(self, artifacts): + def matches(self, graph): return any( isinstance(artifact, PythonGeneratedMLIRArtifact) - and len(artifact.depends) == 0 - for artifact in artifacts + and len(artifact.dependencies) == 1 + and isinstance(artifact.dependencies[0], SourceArtifact) + and artifact.dependencies_available() + for artifact in graph.bfs(only_unavailable=True) ) - def compile(self, artifacts): + def compile(self, graph): """Generate MLIR from a Python callback that uses the MLIR bindings""" - for i, artifact in enumerate(artifacts): + commands = [] + for i, artifact in enumerate(graph.bfs(only_unavailable=True)): if not isinstance(artifact, PythonGeneratedMLIRArtifact): continue - if not all(dependency.is_available() for dependency in artifact.depends): - continue - - if self.dry_run is None: - # Import the Python source file - spec = importlib.util.spec_from_file_location( - Path(artifact.import_path).name, artifact.import_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context - ctx_callback = lambda: ( - mlir_mod_ctx() if artifact.requires_context else nullcontext() - ) - with ctx_callback() as ctx: - callback_function = getattr(module, artifact.callback_fn) - mlir_code = callback_function( - *artifact.callback_args, **artifact.callback_kwargs - ) - # Stringify the generated MLIR - if artifact.requires_context: - mlir_code = str(ctx.module) - else: - mlir_code = str(mlir_code) - - with open(artifact.path, "w") as f: - f.write(mlir_code) - - # Now that the artifact is generated, replace this artifact with the MLIR source code file - old_users = artifact.delete() - new_artifact = SourceArtifact.new(artifact.path) - for user in old_users: - user.depends.append(new_artifact) - if self.dry_run is not None: - python_cmd = "" - # Import the Python source file - python_cmd += ( - "import sys; sys.path.append(" - f'"{Path(artifact.import_path).parent}"' - "); " - ) - python_cmd += f"from {Path(artifact.import_path).stem} import {artifact.callback_fn}; " - - # Check if we need to import device classes - # Device classes have __module__ == 'abc' but need to be imported from aie.iron.device - device_classes = set() - for arg in artifact.callback_args: - obj_module = type(arg).__module__ - obj_class = type(arg).__name__ - if obj_module == "abc" and ( - obj_class.startswith("NPU") or obj_class.startswith("XCVC") - ): - device_classes.add(obj_class) - for v in artifact.callback_kwargs.values(): - obj_module = type(v).__module__ - obj_class = type(v).__name__ - if obj_module == "abc" and ( - obj_class.startswith("NPU") or obj_class.startswith("XCVC") - ): - device_classes.add(obj_class) - - if device_classes: - python_cmd += f"from aie.iron.device import {', '.join(sorted(device_classes))}; " - - if artifact.requires_context: - python_cmd += "from aie.extras.context import mlir_mod_ctx; " - python_cmd += "with mlir_mod_ctx() as ctx: " - python_cmd += f"mlir_code = {artifact.callback_fn}({', '.join(map(GenerateMLIRFromPythonCompilationRule._repr_for_codegen, artifact.callback_args))}, {', '.join(f'{k}={_repr_for_codegen(v)}' for k, v in artifact.callback_kwargs.items())}); " - if artifact.requires_context: - python_cmd += "print(str(ctx.module))" - else: - python_cmd += "print(str(mlir_code))" - self.dry_run.append(f"python3 -c '{python_cmd}' > {artifact.path}") - new_artifact.fake_available = True - artifacts[i] = new_artifact - logging.debug(f"Created MLIR source string for {artifact.path.name}") - - return artifacts - - @staticmethod - def _repr_for_codegen(obj): - """Convert an object to its string representation for code generation. - - Handles special cases like device classes that need to be instantiated - rather than using their default repr(). - """ - # Check if this is a device class from aie.iron.device - # These classes have __module__ == 'abc' but are imported from aie.iron.device - obj_module = type(obj).__module__ - obj_class = type(obj).__name__ - - # Check for known device class patterns (NPU1, NPU2, XCVC1902, etc.) - # These are imported from aie.iron.device but have __module__ == 'abc' - if obj_module == "abc" and ( - obj_class.startswith("NPU") or obj_class.startswith("XCVC") - ): - # For device classes, generate instantiation code - return f"{obj_class}()" + assert len(artifact.dependencies) == 1 and isinstance(artifact.dependencies[0], SourceArtifact), "PythonGeneratedMLIRArtifact must depend on exactly one SourceArtifact" + import_path = Path(artifact.dependencies[0].filename) + new_artifact = SourceArtifact.new(artifact.filename) + callback = lambda: self.generate_mlir(new_artifact, import_path, artifact.callback_fn, artifact.callback_args, artifact.callback_kwargs, artifact.requires_context) + commands.append(PythonCallbackCompilationCommand(callback)) + new_artifact.available = True + graph.replace(artifact, new_artifact) + return commands + + @staticmethod + def generate_mlir(output_artifact, import_path, callback_fn, callback_args=None, callback_kwargs=None, requires_context=False): + # Import the Python source file + spec = importlib.util.spec_from_file_location( + Path(import_path).name, import_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context + ctx_callback = lambda: ( + mlir_mod_ctx() if requires_context else nullcontext() + ) + with ctx_callback() as ctx: + callback_function = getattr(module, callback_fn) + mlir_code = callback_function( + *callback_args, **callback_kwargs + ) + # Stringify the generated MLIR + if requires_context: + mlir_code = str(ctx.module) + else: + mlir_code = str(mlir_code) - # Default to repr() for other types - return repr(obj) + with open(output_artifact.filename, "w") as f: + f.write(mlir_code) -class AieccCompilationRule(CompilationRule): +class AieccXclbinInstsCompilationRule(CompilationRule): def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): self.build_dir = build_dir self.aiecc_path = Path(mlir_aie_dir) / "bin" / "aiecc.py" self.peano_dir = peano_dir super().__init__(*args, **kwargs) - def matches(self, artifacts): - return any( - isinstance(artifact, (XclbinArtifact, InstsBinArtifact)) - and all(dependency.is_available() for dependency in artifact.depends) - for artifact in artifacts - ) + def matches(self, graph): + return any(graph.get_worklist((XclbinArtifact, InstsBinArtifact))) - def compile(self, artifacts): + def compile(self, graph): # If there are both xclbin and insts.bin targets based on the same source MLIR code, we can combine them into one single `aiecc.py` invocation. mlir_sources = set() mlir_sources_to_xclbins = {} - mlir_sources_to_insts_bins = {} - for artifact in artifacts: - if not isinstance(artifact, (XclbinArtifact, InstsBinArtifact)): - continue - if not all(dependency.is_available() for dependency in artifact.depends): - continue - mlir_dependencies = [ - d - for d in artifact.depends - if isinstance(d, (SourceArtifact, PythonGeneratedMLIRArtifact)) - ] - if len(mlir_dependencies) != 1: - raise RuntimeError( - f"Expected exactly one dependency of {artifact.path} to be SourceArtifact or PythonGeneratedMLIRArtifact, got: {', '.join(str(dep.path) for dep in artifact.depends)}" - ) - mlir_dependency = mlir_dependencies[0] - mlir_sources.add(mlir_dependency) + mlir_sources_to_insts = {} + worklist = graph.get_worklist((XclbinArtifact, InstsBinArtifact)) + for artifact in worklist: + mlir_dependency = artifact.mlir_input if isinstance(artifact, XclbinArtifact): mlir_sources_to_xclbins.setdefault(mlir_dependency, []).append(artifact) elif isinstance(artifact, InstsBinArtifact): - mlir_sources_to_insts_bins.setdefault(mlir_dependency, []).append( - artifact - ) + mlir_sources_to_insts.setdefault(mlir_dependency, []).append(artifact) + commands = [] # Now we know for each mlir source if we need to generate an xclbin, an insts.bin or both for it for mlir_source in mlir_sources: - # Build aiecc command using Peano compile_cmd = [ "python", str(self.aiecc_path), @@ -364,68 +367,45 @@ def compile(self, artifacts): "--dynamic-objFifos", ] do_compile_xclbin = mlir_source in mlir_sources_to_xclbins - do_compile_insts_bin = mlir_source in mlir_sources_to_insts_bins + do_compile_insts_bin = mlir_source in mlir_sources_to_insts if do_compile_xclbin: first_xclbin = mlir_sources_to_xclbins[mlir_source][ 0 - ] # FIXME: this does not handle the case of multiple xclbins with different kernel names or flags from the same MLIR + ] # TODO: this does not handle the case of multiple xclbins with different kernel names or flags from the same MLIR compile_cmd += first_xclbin.extra_flags + [ "--aie-generate-xclbin", - "--xclbin-name=" + str(first_xclbin.path), + "--xclbin-name=" + first_xclbin.filename, "--xclbin-kernel-name=" + first_xclbin.kernel_name, ] if first_xclbin.xclbin_input is not None: compile_cmd += [ - "--xclbin-input=" + str(first_xclbin.xclbin_input.path) + "--xclbin-input=" + first_xclbin.xclbin_input.filename ] if do_compile_insts_bin: - first_insts_bin = mlir_sources_to_insts_bins[mlir_source][ + first_insts_bin = mlir_sources_to_insts[mlir_source][ 0 - ] # FIXME: this does not handle the case of multiple insts.bins with different flags from the same MLIR + ] # TODO: this does not handle the case of multiple insts.bins with different flags from the same MLIR if not do_compile_xclbin: compile_cmd += ["--no-compile"] compile_cmd += first_insts_bin.extra_flags + [ "--aie-generate-npu", - "--npu-insts-name=" + str(first_insts_bin.path), + "--npu-insts-name=" + first_insts_bin.filename, ] - compile_cmd += [str(mlir_source.path)] + compile_cmd += [mlir_source.filename] - env = os.environ.copy() - logging.debug(f"Compiling MLIR with command: {' '.join(compile_cmd)}") - if not self.dry_run: - result = subprocess.run( - compile_cmd, - cwd=str(self.build_dir), - capture_output=True, - text=True, - timeout=300, - env=env, - ) - if result.returncode == 0: - logging.debug( - f"Successfully compiled {mlir_source.path} to {', '.join([str(first_xclbin.path)] if do_compile_xclbin else [] + [str(first_insts_bin.path)] if do_compile_insts_bin else [])}" - ) - else: - raise RuntimeError( - f"MLIR compilation for {mlir_source.path} failed: {result.stderr}" - ) + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) - # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them - for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts_bins]: - if sources_to.get(mlir_source, [])[1:]: - copy_src = sources_to[mlir_source][0] - for copy_dest in sources_to[mlir_source][1:]: - shutil.copy(copy_src.path, copy_dest.path) + # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them + for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts]: + if sources_to.get(mlir_source, [])[1:]: + copy_src = sources_to[mlir_source][0] + for copy_dest in sources_to[mlir_source][1:]: + shutil.copy(copy_src.filename, copy_dest.filename) + + # Update graph + for artifact in worklist: + artifact.available = True - else: - for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts_bins]: - for artifact in sources_to.get(mlir_source, []): - self.dry_run.append( - f"pushd {str(self.build_dir)} && {' '.join(compile_cmd)} && popd" - ) - artifact.fake_available = True - - # With the newly generated files, is_available() should now return True on the Xclbin and InstsBin targets return artifacts @@ -440,7 +420,7 @@ def matches(self, artifacts): isinstance(artifact, KernelObjectArtifact) and all( isinstance(dependency, SourceArtifact) and dependency.is_available() - for dependency in artifact.depends + for dependency in artifact.dependencies ) for artifact in artifacts ) @@ -453,11 +433,11 @@ def compile(self, artifacts): if not isinstance(artifact, KernelObjectArtifact): continue - if len(artifact.depends) != 1: + if len(artifact.dependencies) != 1: raise RuntimeError( "Expected exactly one dependency (the C source code) for KernelObjectArtifact" ) - source_file = artifact.depends[0] + source_file = artifact.dependencies[0] if not isinstance(source_file, SourceArtifact): raise RuntimeError( "Expected KernelObject dependency to be a C source file" @@ -477,7 +457,7 @@ def compile(self, artifacts): f"-I{str(include_path)}", ] + artifact.extra_flags - + ["-c", str(source_file.path), "-o", str(artifact.path)] + + ["-c", source_file.filename, "-o", artifact.filename] ) logging.debug(f"Running compilation command: {' '.join(cmd)}") @@ -485,7 +465,7 @@ def compile(self, artifacts): result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"Compilation failed: {result.stderr}") - logging.debug(f"Successfully compiled: {artifact.path.name}") + logging.debug(f"Successfully compiled: {artifact.filename}") else: artifact.fake_available = True self.dry_run.append(" ".join(cmd)) @@ -505,13 +485,13 @@ def _rename_symbols(self, artifact): "--redefine-sym", f"{old_sym}={new_sym}", ] - cmd += [str(artifact.path)] + cmd += [artifact.filename] logging.debug(f"Running renaming command: {' '.join(cmd)}") if self.dry_run is None: result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode == 0: - logging.info(f"Successfully renamed symbols in: {artifact.path.name}") + logging.info(f"Successfully renamed symbols in: {artifact.filename}") else: raise RuntimeError(f"Symbol renaming failed: {result.stderr}") else: @@ -526,7 +506,7 @@ def __init__(self, peano_dir, *args, **kwargs): def matches(self, artifacts): return any( - isinstance(artifact, KernelArchiveArtifact) and len(artifact.depends) > 0 + isinstance(artifact, KernelArchiveArtifact) and len(artifact.dependencies) > 0 for artifact in artifacts ) @@ -537,10 +517,10 @@ def compile(self, artifacts): continue # Get archive filename from method - archive_path = str(artifact.path) + archive_path = artifact.filename object_files = [ - str(dep.path) - for dep in artifact.depends + dep.filename + for dep in artifact.dependencies if isinstance(dep, KernelObjectArtifact) ] @@ -573,59 +553,3 @@ def compile(self, artifacts): self.dry_run.append(" ".join(cmd)) return artifacts - - -# Global Functions -# -------------------------------------------------------------------------- - - -def apply_rules(rules, artifacts): - for rule in rules: - if rule.matches(artifacts): - logging.debug(f"Applying rule: {rule.__class__.__name__}") - artifacts = rule.compile(artifacts) - break - else: - # None of the rules matched - return False, artifacts - - return True, artifacts - - -def compile(rules, artifacts): - # While some artifacts remain to be compiled (not all are available) - while not all(artifact.is_available() for artifact in artifacts): - remaining = [artifact for artifact in artifacts if not artifact.is_available()] - success, artifacts = apply_rules(rules, remaining) - if not success: - raise RuntimeError( - f"No matching rule to compile target(s): {', '.join(str(artifact.path.name) for artifact in artifacts if not artifact.is_available())}" - ) - return artifacts - - -def get_work_list(artifacts): - """ - Return a flattened artifact creation worklist in reverse topological order from dependencies. - The returned list will start with leaf nodes (artifacts with no dependencies), and any following artifacts will only contain artifacts from earlier in the list. - """ - work_list = [] - todo = list(artifacts) - visited = set() - - def dfs_visit(artifact): - if artifact in visited: - # Thanks to uniquing of artifact objects, this avoids duplicate creation of the same artifacts - return - visited.add(artifact) - # First visit all dependencies, so put leaves first (post-order) ... - for dep in artifact.depends: - dfs_visit(dep) - # ... then put parent - if not artifact.is_available(): - work_list.append(artifact) - - for artifact in todo: - dfs_visit(artifact) - - return work_list From 11d5802b7032e54c02eae585af483f2191aaf92a Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 13:01:06 -0700 Subject: [PATCH 38/99] autofuse update --- applications/llama_3.2_1b/autofuse.py | 156 ++++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 12 deletions(-) diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index 85e757b6..5431bce2 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -10,8 +10,10 @@ import time import importlib from aie import ir -from aie.dialects import aie, aiex +from aie.dialects import aie, aiex, memref from aie.extras.context import mlir_mod_ctx +import logging +logging.basicConfig(level=logging.DEBUG) repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) @@ -19,7 +21,7 @@ from operators.common.aie_context import AIEContext from operators.common import AIEOperatorBase, AIEBuffer, SingleMLIRSourceOperator from operators.common.utils import torch_to_numpy, numpy_to_torch -from operators.common.compilation import PythonGeneratedMLIRArtifact +from operators.common.compilation import SourceArtifact, PythonGeneratedMLIRArtifact from operators import AIEGEMV from operators.elementwise_mul.op import AIEElementwiseMul from operators.silu.op import AIESiLU @@ -108,6 +110,10 @@ def __init__(self, name, runlist, input_args, output_args, *args, **kwargs): self.input_args = input_args self.output_args = output_args self.args = {} + self.input_buffer_size = 0 + self.output_buffer_size = 0 + self.scratch_buffer_size = 0 + self.buffer_map = {} # Maps buffer name -> {'type': 'input'|'output'|'scratch', 'offset': int, 'size': int} self.populate_args() AIEOperatorBase.__init__(self, *args, **kwargs) @@ -125,6 +131,29 @@ def populate_args(self): assert arg in self.args, f"Input argument {arg} not found in runlist buffers" for arg in self.output_args: assert arg in self.args, f"Output argument {arg} not found in runlist buffers" + + # Calculate buffer sizes for input, output, and scratch buffers + # Scratch buffers are those that are neither input nor output + scratch_args = [arg for arg in self.args if arg not in self.input_args and arg not in self.output_args] + + # Build the buffer map with offsets and sizes for each buffer type + self.input_buffer_size = self._calculate_buffer_size('input', self.input_args) + self.output_buffer_size = self._calculate_buffer_size('output', self.output_args) + self.scratch_buffer_size = self._calculate_buffer_size('scratch', scratch_args) + + def _calculate_buffer_size(self, buffer_type, args_list): + """Calculate total buffer size and populate buffer_map for a given buffer type.""" + offset = 0 + for arg in args_list: + arg_spec = self.args[arg] + size_bytes = int(np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize) + self.buffer_map[arg] = { + 'type': buffer_type, + 'offset': offset, + 'size': size_bytes + } + offset += size_bytes + return offset def get_operator_name(self): return self.name @@ -173,33 +202,136 @@ def get_mlir_artifact(self): device_mlir_strings[runlist_op] = str(op) device_names = {} # op -> str + sequence_arg_types = {} # op -> list of expected arg types with mlir_mod_ctx() as ctx: for i, (runlist_op, device_str) in enumerate(device_mlir_strings.items()): dev_op = aie.DeviceOp.parse(device_str) device_names[runlist_op] = f"dev{i}" dev_op.sym_name = ir.StringAttr.get(device_names[runlist_op]) + + # Extract the runtime sequence argument types + # Look for aie.runtime_sequence operations + found_sequence = False + for nested_op in dev_op.body_region.blocks[0].operations: + op_name = nested_op.operation.name + if op_name == 'aie.runtime_sequence': + # Found the runtime sequence - need to extract argument types + # The runtime_sequence contains a region with the actual function + if hasattr(nested_op, 'body') and hasattr(nested_op.body, 'blocks'): + # Look for the entry block which has the arguments + if len(nested_op.body.blocks) > 0: + entry_block = nested_op.body.blocks[0] + # Extract argument types from the block arguments + arg_types = [entry_block.arguments[i].type for i in range(len(entry_block.arguments))] + sequence_arg_types[runlist_op] = arg_types + found_sequence = True + break + + if not found_sequence: + raise RuntimeError(f"Could not find runtime sequence or extract argument types for operator {runlist_op}") + ctx.module.body.append(dev_op) @aie.device(device_ty) def main(): - # Argument 0 is scratch space for intermediate values. - # All other arguments are defined by the input/output buffers. + # Argument 0 is the input buffer, argument 1 is the output buffer, argument 2 is scratch space for intermediate values. + # All buffers are bf16, so convert byte sizes to element counts + bf16_itemsize = 2 @aiex.runtime_sequence( - np.ndarray[(1,), np.dtype[np.int8]], - np.ndarray[(1,), np.dtype[np.int8]], - np.ndarray[(1,), np.dtype[np.int8]], + np.ndarray[(self.input_buffer_size // bf16_itemsize,), np.dtype[ml_dtypes.bfloat16]], + np.ndarray[(self.output_buffer_size // bf16_itemsize,), np.dtype[ml_dtypes.bfloat16]], + np.ndarray[(self.scratch_buffer_size // bf16_itemsize,), np.dtype[ml_dtypes.bfloat16]], ) def sequence(input_buf, output_buf, scratch_buf): + # Map buffer type to the appropriate consolidated buffer argument + consolidated_buffers = { + 'input': input_buf, + 'output': output_buf, + 'scratch': scratch_buf + } + for runlist_op, *bufs in self.runlist: configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(device_names[runlist_op]) configure_op = aiex.ConfigureOp(configure_sym_ref_attr) configure_body = configure_op.body.blocks.append() with ir.InsertionPoint(configure_body): + # Get the expected argument types for this operator's sequence + expected_arg_types = sequence_arg_types.get(runlist_op, []) + + # Generate subviews and reinterpret_casts for each buffer argument + buffer_ssa_values = [] + bf16_itemsize = 2 + expected_arg_types = sequence_arg_types.get(runlist_op, None) + + if expected_arg_types is None: + raise RuntimeError(f"No runtime sequence argument types found for operator {runlist_op}") + + for idx, buf_name in enumerate(bufs): + buf_info = self.buffer_map[buf_name] + buf_spec = self.args[buf_name] + + # Get the consolidated buffer this belongs to (already bf16) + consolidated_buf = consolidated_buffers[buf_info['type']] + + # Convert byte offsets/sizes to bf16 element offsets/sizes + offset_elements = buf_info['offset'] // bf16_itemsize + size_elements = buf_info['size'] // bf16_itemsize + + # Create subview from the bf16 buffer + subview = memref.subview( + consolidated_buf, + [offset_elements], + [size_elements], + [1] + ) + + # Get target shape from the expected argument type + if idx >= len(expected_arg_types): + raise RuntimeError(f"No expected type for argument {idx} (buffer {buf_name}) of operator {runlist_op}") + + target_type = expected_arg_types[idx] + expected_memref = ir.MemRefType(target_type) + target_shape = [expected_memref.shape[i] for i in range(expected_memref.rank)] + + # Verify the size matches + expected_size = np.prod(target_shape) + if expected_size != size_elements: + raise ValueError(f"Size mismatch for buffer {buf_name}: expected {expected_size} elements, got {size_elements}") + + # Build strides (assuming row-major layout) + strides = [] + stride = 1 + for dim in reversed(target_shape): + strides.insert(0, stride) + stride *= dim + + # Build the result memref type with target shape (bf16) + result_type = ir.MemRefType.get(target_shape, ir.BF16Type.get()) + + # Reinterpret_cast to reset offset to 0 and reshape + reinterpreted = memref.reinterpret_cast( + result=result_type, + source=subview, + offsets=[], + sizes=[], + strides=[], + static_offsets=[0], + static_sizes=target_shape, + static_strides=strides + ) + + buffer_ssa_values.append(reinterpreted) + + # Run the sequence with the prepared buffers sequence_sym_ref_attr = ir.FlatSymbolRefAttr.get("sequence") - run_op = aiex.RunOp(sequence_sym_ref_attr, [input_buf]) - print(str(ctx.module)) - print(ctx.module.operation.verify()) - print("success") - sys.exit(0) + run_op = aiex.RunOp(sequence_sym_ref_attr, buffer_ssa_values) + + filename = self.get_operator_name() + "_fused.mlir" + with open(filename, "w") as f: + f.write(str(ctx.module)) + src_artifact = SourceArtifact.new(Path(filename)) + src_artifact.fake_available = True + return src_artifact + def get_arg_spec(self): pass From f1a2ab9029f9f3c0853693191077296d73310121 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 14:07:15 -0700 Subject: [PATCH 39/99] refactor compilation --- applications/llama_3.2_1b/autofuse.py | 10 +- .../llama_3.2_1b/src/aie_device_manager.py | 100 ---- applications/llama_3.2_1b/src/compilation.py | 535 ------------------ operators/axpy/op.py | 16 +- operators/common/aie_base.py | 52 +- operators/common/compilation.py | 182 +++--- operators/dequant/op.py | 16 +- operators/elementwise_add/op.py | 8 +- operators/elementwise_mul/op.py | 8 +- operators/gelu/op.py | 16 +- operators/gemm/op.py | 12 +- operators/gemv/design.py | 2 +- operators/gemv/op.py | 8 +- operators/layer_norm/op.py | 16 +- operators/leaky_relu/op.py | 16 +- operators/mem_copy/op.py | 20 +- operators/mha/op.py | 34 +- operators/relu/op.py | 16 +- operators/repeat/op.py | 2 +- operators/rms_norm/op.py | 14 +- operators/rope/op.py | 8 +- operators/sigmoid/op.py | 16 +- operators/silu/op.py | 8 +- operators/softmax/op.py | 8 +- operators/strided_copy/op.py | 2 +- operators/swiglu_decode/op.py | 6 +- operators/swiglu_prefill/op.py | 6 +- operators/tanh/op.py | 16 +- operators/transpose/op.py | 8 +- 29 files changed, 260 insertions(+), 901 deletions(-) delete mode 100644 applications/llama_3.2_1b/src/aie_device_manager.py delete mode 100644 applications/llama_3.2_1b/src/compilation.py diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index 5431bce2..c6e56af4 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -328,7 +328,7 @@ def sequence(input_buf, output_buf, scratch_buf): filename = self.get_operator_name() + "_fused.mlir" with open(filename, "w") as f: f.write(str(ctx.module)) - src_artifact = SourceArtifact.new(Path(filename)) + src_artifact = SourceArtifact(Path(filename)) src_artifact.fake_available = True return src_artifact @@ -365,10 +365,10 @@ def __call__(self, *kwargs): "ffn_output" ] ) -swiglu_fused = swiglu_fused_op.compile().get_callable() +#swiglu_fused = swiglu_fused_op.compile().get_callable() -def run_autofused(): - swiglu_fused() +#def run_autofused(): +# swiglu_fused() # CPU # --- @@ -391,6 +391,6 @@ def run_cpu(): # Main # --- -print(run_autofused()) +#print(run_autofused()) print(run_separate_xclbins()) print(run_cpu()) diff --git a/applications/llama_3.2_1b/src/aie_device_manager.py b/applications/llama_3.2_1b/src/aie_device_manager.py deleted file mode 100644 index a06d957a..00000000 --- a/applications/llama_3.2_1b/src/aie_device_manager.py +++ /dev/null @@ -1,100 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Global AIE Device Manager for resource sharing and cleanup -""" - -import logging -import os -import sys -from pathlib import Path -from typing import Dict, Optional, Any -import pyxrt -from aie.iron.hostruntime.config import detect_npu_device -from aie.iron.device import NPU1, NPU2 - - -class AIEDeviceManager: - """Singleton manager for AIE XRT resources""" - - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - self.device = pyxrt.device(0) - self.device_type = detect_npu_device() - self.contexts = {} # xclbin_path -> (context, xclbin) - self.kernels = {} # (xclbin_path, kernel_name) -> kernel - - def get_context_and_kernel( - self, xclbin_path: str, kernel_name: str | None = None - ) -> (pyxrt.hw_context, pyxrt.kernel): - """Get or create hardware context and kernel for xclbin""" - # Check if we already have a context for this xclbin - - if xclbin_path not in self.contexts: - xclbin = pyxrt.xclbin(xclbin_path) - self.device.register_xclbin(xclbin) - xclbin_uuid = xclbin.get_uuid() - context = pyxrt.hw_context(self.device, xclbin_uuid) - self.contexts[xclbin_path] = (context, xclbin) - logging.debug(f"Created new context for {Path(xclbin_path).name}") - else: - context, xclbin = self.contexts[xclbin_path] - logging.debug(f"Reusing context for {Path(xclbin_path).name}") - - # Get kernel name if not provided - if kernel_name is None: - kernels = xclbin.get_kernels() - if not kernels: - raise RuntimeError("No kernels found in xclbin") - kernel_name = kernels[0].get_name() - - # Check if we already have the kernel - kernel_key = (xclbin_path, kernel_name) - if kernel_key not in self.kernels: - self.kernels[kernel_key] = pyxrt.kernel(context, kernel_name) - logging.debug( - f"Created new kernel {kernel_name} from xclbin {Path(xclbin_path).name}" - ) - else: - logging.debug( - f"Reusing kernel: {kernel_name} from xclbin {Path(xclbin_path).name}" - ) - - return context, self.kernels[kernel_key] - - def device_str(self) -> str: - return self.device_type.resolve().name - - def cleanup(self): - """Clean up all XRT resources""" - self.kernels.clear() - - # Clear contexts - for xclbin_path, (context, xclbin) in self.contexts.items(): - try: - del context - except: - pass - self.contexts.clear() - - # Clear device - if self.device is not None: - try: - del self.device - except: - pass - self.device = None - - logging.debug("Cleaned up AIE device manager") - - def reset(self): - """Reset the device manager (for debugging)""" - self.cleanup() - AIEDeviceManager._instance = None diff --git a/applications/llama_3.2_1b/src/compilation.py b/applications/llama_3.2_1b/src/compilation.py deleted file mode 100644 index c6716441..00000000 --- a/applications/llama_3.2_1b/src/compilation.py +++ /dev/null @@ -1,535 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -This file implements a simple Python-based build system. You specify what you -want to compile (*artifacts*) through subclasses of `CompilationArtifact`. -Each artifact can have a list of depenencies of other artifacts that it relies -on. Each artifact corresponds to exactly one file. If a file with a matching -name already exists, and all its dependencies are built and older than the file, -then the existing file will be reused. - -For each file name, artifacts are singletons. You create artifacts by calling -the `new` class method of the appropriate class. This ensures that artifact -objects are uniqued, i.e., calling `new` twice with the same file name will -return the same object. - -There is a special artifact for source files that do not need to get generated, -`SourceArtifact`. It is likely that in your compilation dependency graph, -the leaf nodes will be `SourceArtifact`s. - -You specify how to generate (compile) an artifact through *rules*, which are -expressed as subclasses of `CompilationRule`. This class requires you to -implement two methods: `matches` and `compile`. During compilation, we will -call `matches` on the set of remaining artifacts to see if the given rule is -able to produce any of the artifacts not available yet. If this function -returns `True`, we will call `compile` on the rule to generate the artifact. -`compile` returns a new list of artifacts, which may be the same one as -before; however, if `matches()==True`, at least one of the artifacts in the -list must be made available after calling `compile()`. -""" - -from abc import ABC, abstractmethod -from pathlib import Path -import os.path -import zlib -import logging -import subprocess -import importlib.util -from contextlib import nullcontext -from aie.extras.context import mlir_mod_ctx - - -# Compilation Artifacts -# -------------------------------------------------------------------------- - - -class CompilationArtifact(ABC): - _instances = {} - - @classmethod - def new(cls, path, *args, **kwargs): - """Uniques artifacts based on absolute file path; any two artifacts with the same absolute path will be represented by the same object.""" - path = Path(path) - abs_path = path.absolute() - if abs_path not in cls._instances: - cls._instances[abs_path] = None - instance = cls(path, *args, **kwargs) - cls._instances[abs_path] = instance - else: - assert ( - type(cls._instances[abs_path]) == cls - ), f"Artifact with path {abs_path} is already registered with a different type" - return cls._instances[abs_path] - - def __init__(self, path, depends=None): - abs_path = path.absolute() - assert ( - abs_path in self._instances - ), "do not construct artifact objects directly; call the get() class method instead for uniquing" - self.path: Path = path - self.depends: list[CompilationArtifact] = depends if depends is not None else [] - self.users: list[CompilationArtifact] = ( - [] - ) # List of ancestor artifacts that depend on this artifact - for dependency in self.depends: - dependency.users.append(self) - - def __repr__(self): - return f"{self.__class__.__name__}(path={self.path}, depends={self.depends})" - - def set_path(self, new_path): - old_abs_path = self.path.absolute() - new_path = Path(new_path) - abs_path = new_path.absolute() - self.path = new_path - del CompilationArtifact._instances[old_abs_path] - CompilationArtifact._instances[abs_path] = self - - def is_available(self): - if not self.path.exists(): - return False - for dependency in self.depends: - # If any of our dependencies' dependencies are outdated, this artifact is also outdated - if not dependency.is_available(): - return False - # If any of our direct dependencies are newer than this artifact, this artifact is invalid - if dependency.is_newer_than(os.path.getmtime(str(self.path))): - return False - return True - - def is_newer_than(self, time): - return os.path.getmtime(str(self.path)) > time - - def delete(self): - for user in self.users: - user.depends.remove(self) - del self._instances[self.path.absolute()] - return self.users - - -class SourceArtifact(CompilationArtifact): - pass - - -class XclbinArtifact(CompilationArtifact): - def __init__( - self, path, depends, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None - ): - super().__init__(path, depends) - self.kernel_name = kernel_name - self.extra_flags = extra_flags if extra_flags is not None else [] - self.xclbin_input = xclbin_input - - -class InstsBinArtifact(CompilationArtifact): - def __init__(self, path, depends, extra_flags=None): - super().__init__(path, depends) - self.extra_flags = extra_flags if extra_flags is not None else [] - - -class KernelObjectArtifact(CompilationArtifact): - def __init__(self, path, depends, extra_flags=None, rename_symbols=None): - super().__init__(path, depends) - self.extra_flags = extra_flags if extra_flags is not None else [] - self.rename_symbols = rename_symbols if rename_symbols is not None else {} - - -class KernelArchiveArtifact(CompilationArtifact): - pass - - -class PythonGeneratedMLIRArtifact(CompilationArtifact): - def __init__( - self, - path, - import_path, - callback_fn, - callback_args=None, - callback_kwargs=None, - requires_context=False, - ): - self.import_path = import_path - self.callback_fn = callback_fn - self.callback_args = callback_args if callback_args is not None else [] - self.callback_kwargs = callback_kwargs if callback_kwargs is not None else {} - self.requires_context = requires_context - super().__init__(path) - - def is_available(self): - is_available = super().is_available() - if is_available: - # Force regeneration if the Python source is changed - return os.path.getmtime(str(self.path)) >= os.path.getmtime( - self.import_path - ) - return is_available - - -# Compilation Rules -# -------------------------------------------------------------------------- - - -class CompilationRule(ABC): - @abstractmethod - def matches(self, artifact: list[CompilationArtifact]) -> bool: - pass - - @abstractmethod - def compile( - self, artifacts: list[CompilationArtifact] - ) -> list[CompilationArtifact]: - pass - - -class GenerateMLIRFromPythonCompilationRule(CompilationRule): - def matches(self, artifacts): - return any( - isinstance(artifact, PythonGeneratedMLIRArtifact) - and len(artifact.depends) == 0 - for artifact in artifacts - ) - - def compile(self, artifacts): - """Generate MLIR from a Python callback that uses the MLIR bindings""" - for i, artifact in enumerate(artifacts): - if not isinstance(artifact, PythonGeneratedMLIRArtifact): - continue - if not all(dependency.is_available() for dependency in artifact.depends): - continue - - # Import the Python source file - spec = importlib.util.spec_from_file_location( - Path(artifact.import_path).name, artifact.import_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context - ctx_callback = lambda: ( - mlir_mod_ctx() if artifact.requires_context else nullcontext() - ) - with ctx_callback() as ctx: - callback_function = getattr(module, artifact.callback_fn) - mlir_code = callback_function( - *artifact.callback_args, **artifact.callback_kwargs - ) - # Stringify the generated MLIR - if artifact.requires_context: - mlir_code = str(ctx.module) - else: - mlir_code = str(mlir_code) - - with open(artifact.path, "w") as f: - f.write(mlir_code) - - # Now that the artifact is generated, replace this artifact with the MLIR source code file - old_users = artifact.delete() - new_artifact = SourceArtifact.new(artifact.path) - for user in old_users: - user.depends.append(new_artifact) - artifacts[i] = new_artifact - logging.debug(f"Created MLIR source string for {artifact.path.name}") - - return artifacts - - -class AieccCompilationRule(CompilationRule): - def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): - self.build_dir = build_dir - self.aiecc_path = Path(mlir_aie_dir) / "bin" / "aiecc.py" - self.peano_dir = peano_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any( - isinstance(artifact, (XclbinArtifact, InstsBinArtifact)) - and all(dependency.is_available() for dependency in artifact.depends) - for artifact in artifacts - ) - - def compile(self, artifacts): - # If there are both xclbin and insts.bin targets based on the same source MLIR code, we can combine them into one single `aiecc.py` invocation. - mlir_sources = set() - mlir_sources_to_xclbins = {} - mlir_sources_to_insts_bins = {} - for artifact in artifacts: - if not isinstance(artifact, (XclbinArtifact, InstsBinArtifact)): - continue - if not all(dependency.is_available() for dependency in artifact.depends): - continue - mlir_dependencies = [ - d - for d in artifact.depends - if isinstance(d, (SourceArtifact, PythonGeneratedMLIRArtifact)) - ] - if len(mlir_dependencies) != 1: - raise RuntimeError( - f"Expected exactly one dependency of {artifact.path} to be SourceArtifact or PythonGeneratedMLIRArtifact, got: {', '.join(str(dep.path) for dep in artifact.depends)}" - ) - mlir_dependency = mlir_dependencies[0] - mlir_sources.add(mlir_dependency) - if isinstance(artifact, XclbinArtifact): - mlir_sources_to_xclbins.setdefault(mlir_dependency, []).append(artifact) - elif isinstance(artifact, InstsBinArtifact): - mlir_sources_to_insts_bins.setdefault(mlir_dependency, []).append( - artifact - ) - - # Now we know for each mlir source if we need to generate an xclbin, an insts.bin or both for it - for mlir_source in mlir_sources: - # Build aiecc command using Peano - compile_cmd = [ - "python", - str(self.aiecc_path), - "--no-compile-host", - "--no-xchesscc", - "--no-xbridge", - "--peano", - str(self.peano_dir), - ] - do_compile_xclbin = mlir_source in mlir_sources_to_xclbins - do_compile_insts_bin = mlir_source in mlir_sources_to_insts_bins - if do_compile_xclbin: - first_xclbin = mlir_sources_to_xclbins[mlir_source][ - 0 - ] # FIXME: this does not handle the case of multiple xclbins with different kernel names or flags from the same MLIR - compile_cmd += first_xclbin.extra_flags + [ - "--aie-generate-xclbin", - "--xclbin-name=" + str(first_xclbin.path), - "--xclbin-kernel-name=" + first_xclbin.kernel_name, - ] - if first_xclbin.xclbin_input is not None: - compile_cmd += [ - "--xclbin-input=" + str(first_xclbin.xclbin_input.path) - ] - if do_compile_insts_bin: - first_insts_bin = mlir_sources_to_insts_bins[mlir_source][ - 0 - ] # FIXME: this does not handle the case of multiple insts.bins with different flags from the same MLIR - if not do_compile_xclbin: - compile_cmd += ["--no-compile"] - compile_cmd += first_insts_bin.extra_flags + [ - "--aie-generate-npu", - "--npu-insts-name=" + str(first_insts_bin.path), - ] - compile_cmd += [str(mlir_source.path)] - - env = os.environ.copy() - logging.debug(f"Compiling MLIR with command: {' '.join(compile_cmd)}") - result = subprocess.run( - compile_cmd, - cwd=str(self.build_dir), - capture_output=True, - text=True, - timeout=300, - env=env, - ) - if result.returncode == 0: - logging.debug( - f"Successfully compiled {mlir_source.path} to {', '.join([str(first_xclbin.path)] if do_compile_xclbin else [] + [str(first_insts_bin.path)] if do_compile_insts_bin else [])}" - ) - else: - raise RuntimeError( - f"MLIR compilation for {mlir_source.path} failed: {result.stderr}" - ) - - # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them - for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts_bins]: - if sources_to.get(mlir_source, [])[1:]: - copy_src = sources_to[mlir_source][0] - for copy_dest in sources_to[mlir_source][1:]: - shutil.copy(copy_src.path, copy_dest.path) - - # With the newly generated files, is_available() should now return True on the Xclbin and InstsBin targets - return artifacts - - -class PeanoCompilationRule(CompilationRule): - def __init__(self, peano_dir, mlir_aie_dir, *args, **kwargs): - self.peano_dir = peano_dir - self.mlir_aie_dir = mlir_aie_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any( - isinstance(artifact, KernelObjectArtifact) - and all( - isinstance(dependency, SourceArtifact) and dependency.is_available() - for dependency in artifact.depends - ) - for artifact in artifacts - ) - - def compile(self, artifacts): - clang_path = Path(self.peano_dir) / "bin" / "clang++" - include_path = Path(self.mlir_aie_dir) / "include" - - for artifact in artifacts: - if not isinstance(artifact, KernelObjectArtifact): - continue - - if len(artifact.depends) != 1: - raise RuntimeError( - "Expected exactly one dependency (the C source code) for KernelObjectArtifact" - ) - source_file = artifact.depends[0] - if not isinstance(source_file, SourceArtifact): - raise RuntimeError( - "Expected KernelObject dependency to be a C source file" - ) - - cmd = ( - [ - str(clang_path), - "-O2", - "-std=c++20", - "--target=aie2p-none-unknown-elf", - "-Wno-parentheses", - "-Wno-attributes", - "-Wno-macro-redefined", - "-Wno-empty-body", - "-Wno-missing-template-arg-list-after-template-kw", - f"-I{str(include_path)}", - ] - + artifact.extra_flags - + ["-c", str(source_file.path), "-o", str(artifact.path)] - ) - logging.debug(f"Running compilation command: {' '.join(cmd)}") - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"Compilation failed: {result.stderr}") - logging.debug(f"Successfully compiled: {artifact.path.name}") - - if artifact.rename_symbols: - self._rename_symbols(artifact) - - return artifacts - - def _rename_symbols(self, artifact): - objcopy_path = "llvm-objcopy-18" - cmd = [ - objcopy_path, - ] - for old_sym, new_sym in artifact.rename_symbols.items(): - cmd += [ - "--redefine-sym", - f"{old_sym}={new_sym}", - ] - cmd += [str(artifact.path)] - - logging.debug(f"Running renaming command: {' '.join(cmd)}") - result = subprocess.run(cmd, capture_output=True, text=True) - - if result.returncode == 0: - logging.debug(f"Successfully renamed symbols in: {artifact.path.name}") - else: - raise RuntimeError(f"Symbol renaming failed: {result.stderr}") - - -class ArchiveCompilationRule(CompilationRule): - def __init__(self, peano_dir, *args, **kwargs): - self.peano_dir = peano_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any( - isinstance(artifact, KernelArchiveArtifact) and len(artifact.depends) > 0 - for artifact in artifacts - ) - - def compile(self, artifacts): - """Create an archive (.a) from compiled object files""" - for artifact in artifacts: - if not isinstance(artifact, KernelArchiveArtifact): - continue - - # Get archive filename from method - archive_path = str(artifact.path) - object_files = [ - str(dep.path) - for dep in artifact.depends - if isinstance(dep, KernelObjectArtifact) - ] - - # Try to find ar tool from PEANO, then system - ar_path = None - - if self.peano_dir: - # Peano has llvm-ar for archiving - peano_ar = Path(self.peano_dir) / "bin" / "llvm-ar" - if os.path.exists(peano_ar): - ar_path = peano_ar - - if ar_path is None: - raise RuntimeError( - "Could not find 'ar' tool in PEANO installation or system PATH" - ) - - cmd = [str(ar_path), "rcs", archive_path] + object_files - - result = subprocess.run(cmd, capture_output=True, text=True) - - if result.returncode == 0: - logging.debug( - f"Successfully created archive: {Path(archive_path).name}" - ) - else: - raise RuntimeError(f"Archive creation failed: {result.stderr}") - - return artifacts - - -# Global Functions -# -------------------------------------------------------------------------- - - -def apply_rules(rules, artifacts): - for rule in rules: - if rule.matches(artifacts): - logging.debug(f"Applying rule: {rule.__class__.__name__}") - artifacts = rule.compile(artifacts) - break - else: - # None of the rules matched - return False, artifacts - - return True, artifacts - - -def compile(rules, artifacts): - # While some artifacts remain to be compiled (not all are available) - while not all(artifact.is_available() for artifact in artifacts): - remaining = [artifact for artifact in artifacts if not artifact.is_available()] - success, artifacts = apply_rules(rules, remaining) - if not success: - raise RuntimeError( - f"No matching rule to compile target(s): {', '.join(str(artifact.path.name) for artifact in artifacts if not artifact.is_available())}" - ) - return artifacts - - -def get_work_list(artifacts): - """ - Return a flattened artifact creation worklist in reverse topological order from dependencies. - The returned list will start with leaf nodes (artifacts with no dependencies), and any following artifacts will only contain artifacts from earlier in the list. - """ - work_list = [] - todo = list(artifacts) - visited = set() - - def dfs_visit(artifact): - if artifact in visited: - # Thanks to uniquing of artifact objects, this avoids duplicate creation of the same artifacts - return - visited.add(artifact) - # First visit all dependencies, so put leaves first (post-order) ... - for dep in artifact.depends: - dfs_visit(dep) - # ... then put parent - if not artifact.is_available(): - work_list.append(artifact) - - for artifact in todo: - dfs_visit(artifact) - - return work_list diff --git a/operators/axpy/op.py b/operators/axpy/op.py index a64c8b85..8698741e 100644 --- a/operators/axpy/op.py +++ b/operators/axpy/op.py @@ -47,7 +47,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"axpy_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t_{self.scalar_factor}s" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_axpy", @@ -62,14 +62,14 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelObjectArtifact.new( + KernelObjectArtifact( f"axpy.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" @@ -80,8 +80,8 @@ def set_up_artifacts(self): ], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/common/aie_base.py b/operators/common/aie_base.py index a336bcb2..99ff2890 100644 --- a/operators/common/aie_base.py +++ b/operators/common/aie_base.py @@ -61,51 +61,32 @@ def get_default_context(cls): AIEOperatorBase._default_context = AIEContext() return AIEOperatorBase._default_context - def compile(self, dry_run=None): + def compile(self, dry_run=False): """ Set up the operator and compile any necessary artifacts. Subclasses are expected to overwrite set_up(); they may register any artifacts that they need to be compiled there. """ context = self.context self.set_up_artifacts() - self._move_artifact_paths() - work_list = comp.get_work_list(self.artifacts) compilation_rules = [ - comp.GenerateMLIRFromPythonCompilationRule(dry_run=dry_run), + comp.GenerateMLIRFromPythonCompilationRule(), comp.PeanoCompilationRule( - context.peano_dir, context.mlir_aie_dir, dry_run=dry_run + context.peano_dir, context.mlir_aie_dir ), - comp.ArchiveCompilationRule(context.peano_dir, dry_run=dry_run), - comp.AieccCompilationRule( + comp.ArchiveCompilationRule(context.peano_dir), + comp.AieccXclbinInstsCompilationRule( context.build_dir, context.peano_dir, context.mlir_aie_dir, - dry_run=dry_run, ), ] - if work_list: - logging.info( - f"Compiling {len(work_list)} new artifacts for AIE operator {self.__class__.__name__}: {', '.join(str(artifact.path.name) for artifact in work_list)}" - ) - comp.compile(compilation_rules, work_list) + artifacts = comp.CompilationArtifactGraph(self.artifacts) + comp.compile(compilation_rules, artifacts, context.build_dir, dry_run=dry_run) return self def add_artifacts(self, artifacts): self.artifacts.extend(artifacts) - def _move_artifact_paths(self): - """Make all artifacts paths point into the build directory (source artifacts into the ironclad source directory). This doesn't phyisically move files; this function is called before artifact generation.""" - context = self.context - todo = self.artifacts.copy() - while todo: - artifact = todo[0] - todo.pop(0) - if isinstance(artifact, comp.SourceArtifact): - artifact.set_path(context.base_dir / artifact.path) - else: - artifact.set_path(context.build_dir / artifact.path) - todo.extend(artifact.depends) - def sync_to_device(bos): for bo in bos: @@ -147,17 +128,20 @@ def get_artifacts(self): mlir_artifact = self.get_mlir_artifact() kernel_deps_inputs = self.get_kernel_artifacts() kernel_deps = [ - KernelArchiveArtifact.new( + KernelArchiveArtifact( self.get_kernel_archive_name(), - depends=kernel_deps_inputs, + dependencies=kernel_deps_inputs, ) ] if kernel_deps_inputs else [] - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{operator_name}.xclbin", - depends=[mlir_artifact] + kernel_deps, + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_deps, ) - insts_artifact = InstsBinArtifact.new( - f"{operator_name}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{operator_name}.bin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] ) return xclbin_artifact, insts_artifact @@ -169,9 +153,9 @@ def set_up_artifacts(self): def get_callable(self): return SingleXclbinCallable( - xclbin_path=self.xclbin_artifact.path, + xclbin_path=self.xclbin_artifact.filename, kernel_name=self.xclbin_artifact.kernel_name, - insts_bin_path=self.insts_artifact.path, + insts_bin_path=self.insts_artifact.filename, args_spec=self.get_arg_spec() ) diff --git a/operators/common/compilation.py b/operators/common/compilation.py index 46f77e87..a0565943 100644 --- a/operators/common/compilation.py +++ b/operators/common/compilation.py @@ -38,18 +38,20 @@ import importlib.util from contextlib import nullcontext from aie.extras.context import mlir_mod_ctx +import copy # Global Functions # ########################################################################## -def plan(rules, graph: CompilationArtifactGraph): +def plan(rules, graph): if all(artifact.is_available() for artifact in graph): return [] # Everything has been compiled for rule in rules: if rule.matches(graph): - commands, new_graph = rule.compile(graph) + new_graph = graph.copy() + commands = rule.compile(new_graph) break else: raise RuntimeError( @@ -60,7 +62,7 @@ def plan(rules, graph: CompilationArtifactGraph): def execute(plan): for rule, commands, _ in plan: - logging.debug(f"Executing rule: {rule.__class__.__name__}") + logging.debug(f"Applying rule: {rule.__class__.__name__}") for command in commands: logging.debug(f" Executing command: {command}") success = command.run() @@ -68,10 +70,17 @@ def execute(plan): raise RuntimeError(f"Command failed: {command}") -def compile(rules, artifacts): +def compile(rules, artifacts, build_dir="build", dry_run=False): + if not os.path.exists(build_dir) and not dry_run: + os.makedirs(build_dir) + artifacts.move_artifacts(build_dir) + artifacts.populate_availability_from_filesystem() plan_steps = plan(rules, artifacts) - print(plan_steps) - execute(plan_steps) + if not dry_run: + execute(plan_steps) + else: + print("\n".join("\n".join(map(str, cmds)) for _, cmds, _ in plan_steps)) + # Compilation Artifact Graph @@ -81,10 +90,31 @@ def compile(rules, artifacts): class CompilationArtifactGraph: def __init__(self, artifacts=None): self.artifacts = artifacts if artifacts is not None else [] + + def __repr__(self): + def format_artifact(artifact, indent=0): + prefix = " " * indent + avail = "[x] " if artifact.is_available() else "[ ] " + result = f"{prefix}{avail}{artifact.__class__.__name__}({Path(artifact.filename).name})\n" + for dep in artifact.dependencies: + result += format_artifact(dep, indent + 1) + return result + + result = "CompilationArtifactGraph(\n" + for artifact in self.artifacts: + result += format_artifact(artifact, indent=1) + result += ")" + return result def __iter__(self): return iter(self.artifacts) + def __len__(self): + return len(self.artifacts) + + def __getitem__(self, index): + return self.artifacts[index] + def dfs(self): return self._traverse(True) @@ -103,7 +133,19 @@ def _traverse(self, dfs): yield artifact def copy(self): - return CompilationArtifactGraph(artifacts=self.artifacts.copy()) + artifact_map = {} + + def copy_artifact(artifact): + if artifact in artifact_map: + return artifact_map[artifact] + new_artifact = copy.copy(artifact) + artifact_map[artifact] = new_artifact + new_deps = [copy_artifact(dep) for dep in artifact.dependencies] + new_artifact.dependencies = CompilationArtifactGraph(artifacts=new_deps) + return new_artifact + + new_artifacts = [copy_artifact(artifact) for artifact in self.artifacts] + return CompilationArtifactGraph(artifacts=new_artifacts) def replace(self, old_artifact, new_artifact): for i, artifact in enumerate(self.artifacts): @@ -115,18 +157,25 @@ def replace(self, old_artifact, new_artifact): def populate_availability_from_filesystem(self): for artifact in self.artifacts: + artifact.dependencies.populate_availability_from_filesystem() artifact.available = artifact.is_available_in_filesystem() def get_worklist(self, kind): """Return a list of artifacts of the given kind that can be built in the next step (dependencies available).""" return [ artifact - for artifact in self.artifacts.bfs() + for artifact in self.bfs() if isinstance(artifact, kind) and not artifact.is_available() and artifact.dependencies_available() ] + def move_artifacts(self, new_root): + """Make all artifacts paths point into a build directory""" + for artifact in self.bfs(): + if not os.path.isabs(artifact.filename): + artifact.filename = str(Path(new_root) / Path(artifact.filename).name) + # Compilation Artifacts # ########################################################################## @@ -134,10 +183,9 @@ def get_worklist(self, kind): class CompilationArtifact(ABC): def __init__(self, filename, dependencies=None, available=False): - self.filename = filename + self.filename = str(filename) self.dependencies: CompilationArtifactGraph = CompilationArtifactGraph(artifacts=dependencies if dependencies is not None else []) self.available = available - return self def __repr__(self): return f"{self.__class__.__name__}({self.filename})" @@ -168,16 +216,22 @@ class SourceArtifact(CompilationArtifact): class XclbinArtifact(CompilationArtifact): def __init__( - self, filename, dependencies, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None + self, filename, mlir_input, dependencies, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None ): + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] super().__init__(filename, dependencies) + self.mlir_input = mlir_input self.kernel_name = kernel_name self.extra_flags = extra_flags if extra_flags is not None else [] self.xclbin_input = xclbin_input class InstsBinArtifact(CompilationArtifact): - def __init__(self, filename, dependencies, extra_flags=None): + def __init__(self, filename, mlir_input, dependencies, extra_flags=None): + self.mlir_input = mlir_input + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] super().__init__(filename, dependencies) self.extra_flags = extra_flags if extra_flags is not None else [] @@ -203,11 +257,13 @@ def __init__( callback_kwargs=None, requires_context=False, ): + self.import_path = import_path self.callback_fn = callback_fn self.callback_args = callback_args if callback_args is not None else [] self.callback_kwargs = callback_kwargs if callback_kwargs is not None else {} self.requires_context = requires_context - super().__init__(filename, dependencies=[]) + dependencies = [SourceArtifact(import_path)] + super().__init__(filename, dependencies=dependencies) # Compilation Command @@ -244,7 +300,7 @@ def run(self) -> bool: return 0 == result.returncode def __repr__(self): - return f"Shell({self.command})" + return f"Shell({' '.join(self.command)})" class PythonCallbackCompilationCommand(CompilationCommand): @@ -252,7 +308,8 @@ def __init__(self, callback): self.callback = callback def run(self) -> bool: - return bool(self.callback()) + result = self.callback() + return bool(result) if result is not None else True def __repr__(self): return f"PythonCallback({self.callback})" @@ -280,23 +337,16 @@ def compile( class GenerateMLIRFromPythonCompilationRule(CompilationRule): def matches(self, graph): - return any( - isinstance(artifact, PythonGeneratedMLIRArtifact) - and len(artifact.dependencies) == 1 - and isinstance(artifact.dependencies[0], SourceArtifact) - and artifact.dependencies_available() - for artifact in graph.bfs(only_unavailable=True) - ) + return any(graph.get_worklist(PythonGeneratedMLIRArtifact)) def compile(self, graph): """Generate MLIR from a Python callback that uses the MLIR bindings""" commands = [] - for i, artifact in enumerate(graph.bfs(only_unavailable=True)): - if not isinstance(artifact, PythonGeneratedMLIRArtifact): - continue + worklist = graph.get_worklist(PythonGeneratedMLIRArtifact) + for artifact in worklist: assert len(artifact.dependencies) == 1 and isinstance(artifact.dependencies[0], SourceArtifact), "PythonGeneratedMLIRArtifact must depend on exactly one SourceArtifact" import_path = Path(artifact.dependencies[0].filename) - new_artifact = SourceArtifact.new(artifact.filename) + new_artifact = SourceArtifact(artifact.filename) callback = lambda: self.generate_mlir(new_artifact, import_path, artifact.callback_fn, artifact.callback_args, artifact.callback_kwargs, artifact.requires_context) commands.append(PythonCallbackCompilationCommand(callback)) new_artifact.available = True @@ -348,6 +398,7 @@ def compile(self, graph): worklist = graph.get_worklist((XclbinArtifact, InstsBinArtifact)) for artifact in worklist: mlir_dependency = artifact.mlir_input + mlir_sources.add(mlir_dependency) if isinstance(artifact, XclbinArtifact): mlir_sources_to_xclbins.setdefault(mlir_dependency, []).append(artifact) elif isinstance(artifact, InstsBinArtifact): @@ -393,20 +444,20 @@ def compile(self, graph): ] compile_cmd += [mlir_source.filename] - ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + commands.append(ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir))) # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts]: if sources_to.get(mlir_source, [])[1:]: copy_src = sources_to[mlir_source][0] for copy_dest in sources_to[mlir_source][1:]: - shutil.copy(copy_src.filename, copy_dest.filename) + commands.append(ShellCompilationCommand(['cp', copy_src.filename, copy_dest.filename])) # Update graph for artifact in worklist: artifact.available = True - return artifacts + return commands class PeanoCompilationRule(CompilationRule): @@ -416,23 +467,14 @@ def __init__(self, peano_dir, mlir_aie_dir, *args, **kwargs): super().__init__(*args, **kwargs) def matches(self, artifacts): - return any( - isinstance(artifact, KernelObjectArtifact) - and all( - isinstance(dependency, SourceArtifact) and dependency.is_available() - for dependency in artifact.dependencies - ) - for artifact in artifacts - ) + return any(artifacts.get_worklist(KernelObjectArtifact)) def compile(self, artifacts): clang_path = Path(self.peano_dir) / "bin" / "clang++" include_path = Path(self.mlir_aie_dir) / "include" - - for artifact in artifacts: - if not isinstance(artifact, KernelObjectArtifact): - continue - + worklist = artifacts.get_worklist(KernelObjectArtifact) + commands = [] + for artifact in worklist: if len(artifact.dependencies) != 1: raise RuntimeError( "Expected exactly one dependency (the C source code) for KernelObjectArtifact" @@ -459,21 +501,13 @@ def compile(self, artifacts): + artifact.extra_flags + ["-c", source_file.filename, "-o", artifact.filename] ) - logging.debug(f"Running compilation command: {' '.join(cmd)}") - - if self.dry_run is None: - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"Compilation failed: {result.stderr}") - logging.debug(f"Successfully compiled: {artifact.filename}") - else: - artifact.fake_available = True - self.dry_run.append(" ".join(cmd)) + commands.append(ShellCompilationCommand(cmd)) if artifact.rename_symbols: - self._rename_symbols(artifact) + commands.extend(self._rename_symbols(artifact)) + artifact.available = True - return artifacts + return commands def _rename_symbols(self, artifact): objcopy_path = "llvm-objcopy-18" @@ -486,17 +520,7 @@ def _rename_symbols(self, artifact): f"{old_sym}={new_sym}", ] cmd += [artifact.filename] - - logging.debug(f"Running renaming command: {' '.join(cmd)}") - if self.dry_run is None: - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode == 0: - logging.info(f"Successfully renamed symbols in: {artifact.filename}") - else: - raise RuntimeError(f"Symbol renaming failed: {result.stderr}") - else: - artifact.fake_available = True - self.dry_run.append(" ".join(cmd)) + return [ShellCompilationCommand(cmd)] class ArchiveCompilationRule(CompilationRule): @@ -505,17 +529,13 @@ def __init__(self, peano_dir, *args, **kwargs): super().__init__(*args, **kwargs) def matches(self, artifacts): - return any( - isinstance(artifact, KernelArchiveArtifact) and len(artifact.dependencies) > 0 - for artifact in artifacts - ) + return any(artifacts.get_worklist(KernelArchiveArtifact)) def compile(self, artifacts): """Create an archive (.a) from compiled object files""" - for artifact in artifacts: - if not isinstance(artifact, KernelArchiveArtifact): - continue - + worklist = artifacts.get_worklist(KernelArchiveArtifact) + commands = [] + for artifact in worklist: # Get archive filename from method archive_path = artifact.filename object_files = [ @@ -539,17 +559,7 @@ def compile(self, artifacts): ) cmd = [str(ar_path), "rcs", archive_path] + object_files + commands.append(ShellCompilationCommand(cmd)) + artifact.available = True - if self.dry_run is None: - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode == 0: - logging.debug( - f"Successfully created archive: {Path(archive_path).name}" - ) - else: - raise RuntimeError(f"Archive creation failed: {result.stderr}") - else: - artifact.fake_available = True - self.dry_run.append(" ".join(cmd)) - - return artifacts + return commands diff --git a/operators/dequant/op.py b/operators/dequant/op.py index f235cbdd..9f19f0cd 100644 --- a/operators/dequant/op.py +++ b/operators/dequant/op.py @@ -55,7 +55,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"dequant_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_dequant_kernel", @@ -71,10 +71,10 @@ def set_up_artifacts(self): ) # Build the kernel object file with the appropriate tile size and group size - kernel_artifact = KernelObjectArtifact.new( + kernel_artifact = KernelObjectArtifact( f"expand_aie2_{self.tile_size}.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" / "expand.cc" ) ], @@ -84,13 +84,13 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[mlir_artifact, kernel_artifact], + dependencies=[mlir_artifact, kernel_artifact], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/elementwise_add/op.py b/operators/elementwise_add/op.py index bfe04f82..d98b7da5 100644 --- a/operators/elementwise_add/op.py +++ b/operators/elementwise_add/op.py @@ -44,7 +44,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_eltwise_add", @@ -60,10 +60,10 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"add.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" / "add.cc" ) ], diff --git a/operators/elementwise_mul/op.py b/operators/elementwise_mul/op.py index 82cecd95..0afcd7fe 100644 --- a/operators/elementwise_mul/op.py +++ b/operators/elementwise_mul/op.py @@ -37,7 +37,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_eltwise_mul", @@ -53,10 +53,10 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"mul.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" / "mul.cc" ) ], diff --git a/operators/gelu/op.py b/operators/gelu/op.py index 8217017b..7733251b 100644 --- a/operators/gelu/op.py +++ b/operators/gelu/op.py @@ -41,7 +41,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"gelu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_gelu", @@ -55,14 +55,14 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelObjectArtifact.new( + KernelObjectArtifact( f"gelu.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" / "gelu.cc" ) ], @@ -70,8 +70,8 @@ def set_up_artifacts(self): ], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/gemm/op.py b/operators/gemm/op.py index bb7c3c79..db06b225 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -86,7 +86,7 @@ def get_mlir_artifact(self): use_scalar = self.gemm_args.get("use_scalar", False) round_conv_even = self.gemm_args.get("round_conv_even", True) separate_c_tiles = self.gemm_args.get("separate_c_tiles", False) - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{operator_name}.mlir", import_path=operator_dir / "design.py", callback_fn="my_matmul", @@ -140,19 +140,19 @@ def get_kernel_artifacts(self): if self.c_col_maj: kernel_flags.append("-DC_COL_MAJ") return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}.o", extra_flags=kernel_flags, - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( base_dir / "aie_kernels" / "aie2p" / "mm.cc" ) ], ), - KernelObjectArtifact.new( + KernelObjectArtifact( "convert_copy.o", [ - SourceArtifact.new( + SourceArtifact( base_dir / "aie_kernels" / "generic" diff --git a/operators/gemv/design.py b/operators/gemv/design.py index f988fc2e..b26fb648 100644 --- a/operators/gemv/design.py +++ b/operators/gemv/design.py @@ -12,7 +12,7 @@ import aie.dialects.memref as memref from aie.dialects.aie import * from aie.dialects.aiex import * -from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.dialects.scf import _for as range_ from aie.helpers.util import try_convert_np_type_to_mlir_type from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker from aie.iron.placers import SequentialPlacer diff --git a/operators/gemv/op.py b/operators/gemv/op.py index 95ac3fee..a938ca38 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -61,7 +61,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_matvec", @@ -84,10 +84,10 @@ def get_kernel_archive_name(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"mv_{self.K}k.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" / "mv.cc" ) ], diff --git a/operators/layer_norm/op.py b/operators/layer_norm/op.py index 36fb5256..0f83b99e 100644 --- a/operators/layer_norm/op.py +++ b/operators/layer_norm/op.py @@ -44,7 +44,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"layer_norm_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_layer_norm", @@ -58,14 +58,14 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelObjectArtifact.new( + KernelObjectArtifact( f"layer_norm.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" @@ -76,8 +76,8 @@ def set_up_artifacts(self): ], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/leaky_relu/op.py b/operators/leaky_relu/op.py index e9cbc413..1cde0773 100644 --- a/operators/leaky_relu/op.py +++ b/operators/leaky_relu/op.py @@ -45,7 +45,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"leaky_relu_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_leaky_relu", @@ -60,14 +60,14 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelObjectArtifact.new( + KernelObjectArtifact( f"leaky_relu.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" @@ -78,8 +78,8 @@ def set_up_artifacts(self): ], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/mem_copy/op.py b/operators/mem_copy/op.py index 80bda8be..806248b1 100644 --- a/operators/mem_copy/op.py +++ b/operators/mem_copy/op.py @@ -43,7 +43,7 @@ def set_up_artifacts(self): xclbin_base_name = f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_tile_{self.tile_size}_{self.bypass_str}" # Generate MLIR for xclbin (using dummy size) - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{xclbin_base_name}.mlir", import_path=operator_dir / "design.py", callback_fn="my_mem_copy", @@ -60,10 +60,10 @@ def set_up_artifacts(self): # Build kernel only if not bypass mode if not self.bypass: - kernel_artifact = KernelObjectArtifact.new( + kernel_artifact = KernelObjectArtifact( "mem_copy.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" @@ -71,20 +71,20 @@ def set_up_artifacts(self): ) ], ) - xclbin_depends = [mlir_artifact, kernel_artifact] + xclbin_dependencies = [mlir_artifact, kernel_artifact] else: - xclbin_depends = [mlir_artifact] + xclbin_dependencies = [mlir_artifact] - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{xclbin_base_name}.xclbin", - depends=xclbin_depends, + dependencies=xclbin_dependencies, extra_flags=["--dynamic-objFifos"], ) insts_file_name = f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_{self.size}_tile_{self.tile_size}_{self.bypass_str}" - insts_artifact = InstsBinArtifact.new( + insts_artifact = InstsBinArtifact( f"{insts_file_name}.bin", - depends=[mlir_artifact], + dependencies=[mlir_artifact], extra_flags=["--dynamic-objFifos"], ) diff --git a/operators/mha/op.py b/operators/mha/op.py index 463a3062..40de3477 100644 --- a/operators/mha/op.py +++ b/operators/mha/op.py @@ -83,7 +83,7 @@ def set_up_artifacts(self): "zero_scalar_bf16": "zero_scalar_bf16_rowmaj", } - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="fused_mha", @@ -102,35 +102,35 @@ def set_up_artifacts(self): }, ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"mha.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelArchiveArtifact.new( + KernelArchiveArtifact( f"mha_kernels.a", - depends=[ - KernelObjectArtifact.new( + dependencies=[ + KernelObjectArtifact( f"mha_mm.o", extra_flags=mm_defines_colmaj, - depends=[SourceArtifact.new(mm_source)], + dependencies=[SourceArtifact(mm_source)], ), - KernelObjectArtifact.new( + KernelObjectArtifact( f"mha_mm_rowmaj.o", extra_flags=mm_defines_rowmaj, - depends=[SourceArtifact.new(mm_source)], + dependencies=[SourceArtifact(mm_source)], rename_symbols=mm_rename_symbols, ), - KernelObjectArtifact.new( + KernelObjectArtifact( "mha_softmax.o", - depends=[SourceArtifact.new(softmax_source)], + dependencies=[SourceArtifact(softmax_source)], ), - KernelObjectArtifact.new( - "mha_mha.o", depends=[SourceArtifact.new(mha_source)] + KernelObjectArtifact( + "mha_mha.o", dependencies=[SourceArtifact(mha_source)] ), - KernelObjectArtifact.new( + KernelObjectArtifact( "mha_passThrough.o", extra_flags=["-DBIT_WIDTH=16"], - depends=[SourceArtifact.new(passthrough_source)], + dependencies=[SourceArtifact(passthrough_source)], ), ], ), @@ -138,8 +138,8 @@ def set_up_artifacts(self): extra_flags=["--dynamic-objFifos"], ) - insts_artifact = InstsBinArtifact.new( - f"mha.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"] + insts_artifact = InstsBinArtifact( + f"mha.bin", dependencies=[mlir_artifact], extra_flags=["--dynamic-objFifos"] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/relu/op.py b/operators/relu/op.py index dfab584b..8ef73a78 100644 --- a/operators/relu/op.py +++ b/operators/relu/op.py @@ -41,7 +41,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"relu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_relu", @@ -55,14 +55,14 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelObjectArtifact.new( + KernelObjectArtifact( f"relu.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" / "relu.cc" ) ], @@ -70,8 +70,8 @@ def set_up_artifacts(self): ], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/repeat/op.py b/operators/repeat/op.py index 13f8da00..88a1d1c9 100644 --- a/operators/repeat/op.py +++ b/operators/repeat/op.py @@ -43,7 +43,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="repeat", diff --git a/operators/rms_norm/op.py b/operators/rms_norm/op.py index 48442b9d..f961e907 100644 --- a/operators/rms_norm/op.py +++ b/operators/rms_norm/op.py @@ -57,7 +57,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design_weighted.py", callback_fn="my_weighted_rms_norm", @@ -76,10 +76,10 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"rms_norm.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" @@ -87,10 +87,10 @@ def get_kernel_artifacts(self): ) ], ), - KernelObjectArtifact.new( + KernelObjectArtifact( "mul.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" diff --git a/operators/rope/op.py b/operators/rope/op.py index b8d66487..2fdd9dca 100644 --- a/operators/rope/op.py +++ b/operators/rope/op.py @@ -45,7 +45,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="rope", @@ -63,10 +63,10 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"rope_{self.method_type}.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" / "rope.cc" ) ], diff --git a/operators/sigmoid/op.py b/operators/sigmoid/op.py index 33702eb4..2ca56bb5 100644 --- a/operators/sigmoid/op.py +++ b/operators/sigmoid/op.py @@ -42,7 +42,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"sigmoid_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_sigmoid", @@ -56,14 +56,14 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelObjectArtifact.new( + KernelObjectArtifact( f"sigmoid.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" @@ -74,8 +74,8 @@ def set_up_artifacts(self): ], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/silu/op.py b/operators/silu/op.py index 42ee4fcc..456dead3 100644 --- a/operators/silu/op.py +++ b/operators/silu/op.py @@ -31,7 +31,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_silu", @@ -47,10 +47,10 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"silu.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" / "silu.cc" ) ], diff --git a/operators/softmax/op.py b/operators/softmax/op.py index e2443d58..1da691bd 100644 --- a/operators/softmax/op.py +++ b/operators/softmax/op.py @@ -38,7 +38,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="softmax", @@ -55,10 +55,10 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"softmax.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" / "softmax.cc" ) ], diff --git a/operators/strided_copy/op.py b/operators/strided_copy/op.py index f08b023e..8400a640 100644 --- a/operators/strided_copy/op.py +++ b/operators/strided_copy/op.py @@ -54,7 +54,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="strided_copy", diff --git a/operators/swiglu_decode/op.py b/operators/swiglu_decode/op.py index 7ddf8d6d..d035992a 100644 --- a/operators/swiglu_decode/op.py +++ b/operators/swiglu_decode/op.py @@ -84,7 +84,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x902", ] silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.depends += [gemv_1_xclbin] + silu_xclbin.dependencies += [gemv_1_xclbin] artifacts.append(silu_insts) eltwise_mul = AIEElementwiseMul( @@ -104,7 +104,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x903", ] eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.depends += [silu_xclbin] + eltwise_mul_xclbin.dependencies += [silu_xclbin] artifacts.append(eltwise_mul_insts) gemv_2 = AIEGEMV( @@ -124,7 +124,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x904", ] gemv_2_xclbin.kernel_name = "swiglu_gemv_2" - gemv_2_xclbin.depends += [eltwise_mul_xclbin] + gemv_2_xclbin.dependencies += [eltwise_mul_xclbin] artifacts.append(gemv_2_xclbin) artifacts.append(gemv_2_insts) diff --git a/operators/swiglu_prefill/op.py b/operators/swiglu_prefill/op.py index 6a0d0c2b..cf6c0704 100644 --- a/operators/swiglu_prefill/op.py +++ b/operators/swiglu_prefill/op.py @@ -98,7 +98,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x902", ] silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.depends += [gemm_1_xclbin] + silu_xclbin.dependencies += [gemm_1_xclbin] artifacts.append(silu_insts) eltwise_mul = AIEElementwiseMul( @@ -119,7 +119,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x903", ] eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.depends += [silu_xclbin] + eltwise_mul_xclbin.dependencies += [silu_xclbin] artifacts.append(eltwise_mul_insts) gemm_2 = AIEGEMM( @@ -137,7 +137,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x904", ] gemm_2_xclbin.kernel_name = "swiglu_gemm_2" - gemm_2_xclbin.depends += [eltwise_mul_xclbin] + gemm_2_xclbin.dependencies += [eltwise_mul_xclbin] artifacts.append(gemm_2_xclbin) artifacts.append(gemm_2_insts) diff --git a/operators/tanh/op.py b/operators/tanh/op.py index 8133ae17..5edd5bda 100644 --- a/operators/tanh/op.py +++ b/operators/tanh/op.py @@ -42,7 +42,7 @@ def set_up_artifacts(self): operator_dir = Path(__file__).parent file_name_base = f"tanh_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - mlir_artifact = PythonGeneratedMLIRArtifact.new( + mlir_artifact = PythonGeneratedMLIRArtifact( f"{file_name_base}.mlir", import_path=operator_dir / "design.py", callback_fn="my_tanh", @@ -56,14 +56,14 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact.new( + xclbin_artifact = XclbinArtifact( f"{file_name_base}.xclbin", - depends=[ + dependencies=[ mlir_artifact, - KernelObjectArtifact.new( + KernelObjectArtifact( f"tanh.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "aie2p" / "tanh.cc" ) ], @@ -71,8 +71,8 @@ def set_up_artifacts(self): ], ) - insts_artifact = InstsBinArtifact.new( - f"{file_name_base}.bin", depends=[mlir_artifact] + insts_artifact = InstsBinArtifact( + f"{file_name_base}.bin", dependencies=[mlir_artifact] ) self.xclbin_artifact = xclbin_artifact diff --git a/operators/transpose/op.py b/operators/transpose/op.py index 92e21828..d3db0cb5 100644 --- a/operators/transpose/op.py +++ b/operators/transpose/op.py @@ -37,7 +37,7 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact.new( + return PythonGeneratedMLIRArtifact( f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="shuffle_transpose", @@ -55,10 +55,10 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): return [ - KernelObjectArtifact.new( + KernelObjectArtifact( f"transpose_{self.m}x{self.n}.o", - depends=[ - SourceArtifact.new( + dependencies=[ + SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" From d80b726cddfa782942cdd7316a75ac29d3058856 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 17:13:18 -0700 Subject: [PATCH 40/99] towards full fused ELF + some more refactoring --- applications/llama_3.2_1b/autofuse.py | 332 +++--------------- operators/common/__init__.py | 6 +- operators/common/{aie_base.py => base.py} | 29 +- operators/common/compilation/__init__.py | 2 + .../{compilation.py => compilation/base.py} | 83 ++++- operators/common/compilation/fusion.py | 201 +++++++++++ .../common/{aie_context.py => context.py} | 11 +- ...ie_device_manager.py => device_manager.py} | 0 operators/common/fusion.py | 229 ++++++++++++ operators/elementwise_add/design.py | 4 +- operators/elementwise_mul/design.py | 4 +- operators/elementwise_mul/op.py | 3 +- operators/gemv/op.py | 12 +- operators/rms_norm/design.py | 4 +- operators/rms_norm/design_weighted.py | 6 +- operators/rope/design.py | 8 +- operators/silu/design.py | 4 +- operators/silu/op.py | 5 +- 18 files changed, 586 insertions(+), 357 deletions(-) rename operators/common/{aie_base.py => base.py} (92%) create mode 100644 operators/common/compilation/__init__.py rename operators/common/{compilation.py => compilation/base.py} (88%) create mode 100644 operators/common/compilation/fusion.py rename operators/common/{aie_context.py => context.py} (71%) rename operators/common/{aie_device_manager.py => device_manager.py} (100%) create mode 100644 operators/common/fusion.py diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index c6e56af4..067fbf79 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -8,20 +8,16 @@ import ml_dtypes import logging import time -import importlib -from aie import ir -from aie.dialects import aie, aiex, memref -from aie.extras.context import mlir_mod_ctx -import logging logging.basicConfig(level=logging.DEBUG) repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) -from operators.common.aie_context import AIEContext +from operators.common.context import AIEContext from operators.common import AIEOperatorBase, AIEBuffer, SingleMLIRSourceOperator from operators.common.utils import torch_to_numpy, numpy_to_torch from operators.common.compilation import SourceArtifact, PythonGeneratedMLIRArtifact +from operators.common.fusion import FusedMLIROperator, FusedFullELFCallable from operators import AIEGEMV from operators.elementwise_mul.op import AIEElementwiseMul from operators.silu.op import AIESiLU @@ -78,10 +74,17 @@ # Separate xclbins # --- -gemv_ffn_up_gate = gemv_ffn_up_gate_op.compile().get_callable() -gemv_ffn_down = gemv_ffn_down_op.compile().get_callable() -silu_ffn = silu_ffn_op.compile().get_callable() -eltwise_mul_ffn = eltwise_mul_ffn_op.compile().get_callable() +gemv_ffn_up_gate = None +gemv_ffn_down = None +silu_ffn = None +eltwise_mul_ffn = None + +def setup_separate_xclbins(): + global gemv_ffn_up_gate, gemv_ffn_down, silu_ffn, eltwise_mul_ffn + gemv_ffn_up_gate = gemv_ffn_up_gate_op.compile().get_callable() + gemv_ffn_down = gemv_ffn_down_op.compile().get_callable() + silu_ffn = silu_ffn_op.compile().get_callable() + eltwise_mul_ffn = eltwise_mul_ffn_op.compile().get_callable() def run_separate_xclbins(): gemv_ffn_up_gate(buf_W_ffn_gate, buf_x_norm, buf_ffn_gate) # Gate projection @@ -95,280 +98,37 @@ def run_separate_xclbins(): # Autofused # --- -class FusedMLIROperator(SingleMLIRSourceOperator): - def __init__(self, name, runlist, input_args, output_args, *args, **kwargs): - assert all( - isinstance(op, SingleMLIRSourceOperator) and all(isinstance(buf, str) for buf in bufs) - for op, *bufs in runlist - ) - # Runlist is a list of operators and names for their buffer arguments. - # Shapes for the named buffer arguments are derived from the operator's argument specification. - # To pass data between operators, use the same buffer name in multiple operators. - # If the same buffer name is used in multiple operators, the required buffer shapes must match for each operator. - self.runlist = runlist - self.name = name - self.input_args = input_args - self.output_args = output_args - self.args = {} - self.input_buffer_size = 0 - self.output_buffer_size = 0 - self.scratch_buffer_size = 0 - self.buffer_map = {} # Maps buffer name -> {'type': 'input'|'output'|'scratch', 'offset': int, 'size': int} - self.populate_args() - AIEOperatorBase.__init__(self, *args, **kwargs) - - def populate_args(self): - for op, *bufs in self.runlist: - args_specs = op.get_arg_spec() - assert len(args_specs) == len(bufs), "Number of buffers must match operator argument specification" - for i, buf_name in enumerate(bufs): - args_spec = args_specs[i] - if buf_name not in self.args: - self.args[buf_name] = args_spec - else: - assert np.prod(self.args[buf_name].shape) == np.prod(args_spec.shape), f"Buffer {buf_name} has conflicting sizes between operators" - for arg in self.input_args: - assert arg in self.args, f"Input argument {arg} not found in runlist buffers" - for arg in self.output_args: - assert arg in self.args, f"Output argument {arg} not found in runlist buffers" - - # Calculate buffer sizes for input, output, and scratch buffers - # Scratch buffers are those that are neither input nor output - scratch_args = [arg for arg in self.args if arg not in self.input_args and arg not in self.output_args] - - # Build the buffer map with offsets and sizes for each buffer type - self.input_buffer_size = self._calculate_buffer_size('input', self.input_args) - self.output_buffer_size = self._calculate_buffer_size('output', self.output_args) - self.scratch_buffer_size = self._calculate_buffer_size('scratch', scratch_args) - - def _calculate_buffer_size(self, buffer_type, args_list): - """Calculate total buffer size and populate buffer_map for a given buffer type.""" - offset = 0 - for arg in args_list: - arg_spec = self.args[arg] - size_bytes = int(np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize) - self.buffer_map[arg] = { - 'type': buffer_type, - 'offset': offset, - 'size': size_bytes - } - offset += size_bytes - return offset - - def get_operator_name(self): - return self.name - - def get_kernel_artifacts(self): - kernel_artifacts = [] - for op, *bufs in self.runlist: - kernel_artifacts.extend(op.get_kernel_artifacts()) - return kernel_artifacts - - @staticmethod - def get_child_mlir_module(artifact): - assert isinstance(artifact, PythonGeneratedMLIRArtifact) - # Import the Python source file - spec = importlib.util.spec_from_file_location( - Path(artifact.import_path).name, artifact.import_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context - if artifact.requires_context: - raise NotImplementedError("Not handled, make your operator return a ctx.module") - callback_function = getattr(module, artifact.callback_fn) - mlir_module = callback_function( - *artifact.callback_args, **artifact.callback_kwargs - ) - return mlir_module - - def get_mlir_artifact(self): - device_mlir_strings = {} # op -> device str - device_ty = None - # FIXME: The proper way for this would be to create a new type of artifact (FusedMLIRArtifact) and a new compilation rule that does what this function steps _only if_ the fused MLIR file doesn't exist yet. - # As it stands, we're regenerating it on each run. - for runlist_op, *bufs in self.runlist: - if runlist_op in device_mlir_strings: - continue - artifact = runlist_op.get_mlir_artifact() - mlir_module = self.get_child_mlir_module(artifact) - for op in mlir_module.body.operations: - if not isinstance(op, aie.DeviceOp): - continue - if device_ty is None: - device_ty = op.device - # else: - # assert device_ty == op.device, "All operators in a fused operator must target the same type of AIE" - device_mlir_strings[runlist_op] = str(op) - - device_names = {} # op -> str - sequence_arg_types = {} # op -> list of expected arg types - with mlir_mod_ctx() as ctx: - for i, (runlist_op, device_str) in enumerate(device_mlir_strings.items()): - dev_op = aie.DeviceOp.parse(device_str) - device_names[runlist_op] = f"dev{i}" - dev_op.sym_name = ir.StringAttr.get(device_names[runlist_op]) - - # Extract the runtime sequence argument types - # Look for aie.runtime_sequence operations - found_sequence = False - for nested_op in dev_op.body_region.blocks[0].operations: - op_name = nested_op.operation.name - if op_name == 'aie.runtime_sequence': - # Found the runtime sequence - need to extract argument types - # The runtime_sequence contains a region with the actual function - if hasattr(nested_op, 'body') and hasattr(nested_op.body, 'blocks'): - # Look for the entry block which has the arguments - if len(nested_op.body.blocks) > 0: - entry_block = nested_op.body.blocks[0] - # Extract argument types from the block arguments - arg_types = [entry_block.arguments[i].type for i in range(len(entry_block.arguments))] - sequence_arg_types[runlist_op] = arg_types - found_sequence = True - break - - if not found_sequence: - raise RuntimeError(f"Could not find runtime sequence or extract argument types for operator {runlist_op}") - - ctx.module.body.append(dev_op) - @aie.device(device_ty) - def main(): - # Argument 0 is the input buffer, argument 1 is the output buffer, argument 2 is scratch space for intermediate values. - # All buffers are bf16, so convert byte sizes to element counts - bf16_itemsize = 2 - @aiex.runtime_sequence( - np.ndarray[(self.input_buffer_size // bf16_itemsize,), np.dtype[ml_dtypes.bfloat16]], - np.ndarray[(self.output_buffer_size // bf16_itemsize,), np.dtype[ml_dtypes.bfloat16]], - np.ndarray[(self.scratch_buffer_size // bf16_itemsize,), np.dtype[ml_dtypes.bfloat16]], - ) - def sequence(input_buf, output_buf, scratch_buf): - # Map buffer type to the appropriate consolidated buffer argument - consolidated_buffers = { - 'input': input_buf, - 'output': output_buf, - 'scratch': scratch_buf - } - - for runlist_op, *bufs in self.runlist: - configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(device_names[runlist_op]) - configure_op = aiex.ConfigureOp(configure_sym_ref_attr) - configure_body = configure_op.body.blocks.append() - with ir.InsertionPoint(configure_body): - # Get the expected argument types for this operator's sequence - expected_arg_types = sequence_arg_types.get(runlist_op, []) - - # Generate subviews and reinterpret_casts for each buffer argument - buffer_ssa_values = [] - bf16_itemsize = 2 - expected_arg_types = sequence_arg_types.get(runlist_op, None) - - if expected_arg_types is None: - raise RuntimeError(f"No runtime sequence argument types found for operator {runlist_op}") - - for idx, buf_name in enumerate(bufs): - buf_info = self.buffer_map[buf_name] - buf_spec = self.args[buf_name] - - # Get the consolidated buffer this belongs to (already bf16) - consolidated_buf = consolidated_buffers[buf_info['type']] - - # Convert byte offsets/sizes to bf16 element offsets/sizes - offset_elements = buf_info['offset'] // bf16_itemsize - size_elements = buf_info['size'] // bf16_itemsize - - # Create subview from the bf16 buffer - subview = memref.subview( - consolidated_buf, - [offset_elements], - [size_elements], - [1] - ) - - # Get target shape from the expected argument type - if idx >= len(expected_arg_types): - raise RuntimeError(f"No expected type for argument {idx} (buffer {buf_name}) of operator {runlist_op}") - - target_type = expected_arg_types[idx] - expected_memref = ir.MemRefType(target_type) - target_shape = [expected_memref.shape[i] for i in range(expected_memref.rank)] - - # Verify the size matches - expected_size = np.prod(target_shape) - if expected_size != size_elements: - raise ValueError(f"Size mismatch for buffer {buf_name}: expected {expected_size} elements, got {size_elements}") - - # Build strides (assuming row-major layout) - strides = [] - stride = 1 - for dim in reversed(target_shape): - strides.insert(0, stride) - stride *= dim - - # Build the result memref type with target shape (bf16) - result_type = ir.MemRefType.get(target_shape, ir.BF16Type.get()) - - # Reinterpret_cast to reset offset to 0 and reshape - reinterpreted = memref.reinterpret_cast( - result=result_type, - source=subview, - offsets=[], - sizes=[], - strides=[], - static_offsets=[0], - static_sizes=target_shape, - static_strides=strides - ) - - buffer_ssa_values.append(reinterpreted) - - # Run the sequence with the prepared buffers - sequence_sym_ref_attr = ir.FlatSymbolRefAttr.get("sequence") - run_op = aiex.RunOp(sequence_sym_ref_attr, buffer_ssa_values) - - filename = self.get_operator_name() + "_fused.mlir" - with open(filename, "w") as f: - f.write(str(ctx.module)) - src_artifact = SourceArtifact(Path(filename)) - src_artifact.fake_available = True - return src_artifact - - - def get_arg_spec(self): - pass - -class FusedFullELFCallable: - def __init__(self, op): - self.op = op - - def __call__(self, *kwargs): - assert all(kw in self.op.args for kw in kwargs), "at least one unknown argument passed" - assert all(kw in kwargs for kw in self.op.input_args), "not all input arguments passed" - assert all(kw in kwargs for kw in self.op.output_args), "not all output arguments passed" - - -swiglu_fused_op = FusedMLIROperator( - "swiglu", - [ - (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "inter_ffn_gate"), - (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "inter_ffn_up"), - (silu_ffn_op, "inter_ffn_gate", "inter_ffn_gate"), - (eltwise_mul_ffn_op, "inter_ffn_gate", "inter_ffn_up", "inter_ffn_hidden"), - (gemv_ffn_down_op, "W_ffn_down", "inter_ffn_hidden", "ffn_output"), - ], - input_args=[ - "x_norm", - "W_ffn_gate", - "W_ffn_up", - "W_ffn_down" - ], - output_args=[ - "ffn_output" - ] -) -#swiglu_fused = swiglu_fused_op.compile().get_callable() - -#def run_autofused(): -# swiglu_fused() +def setup_autofused(): + global swiglu_fused_op, swiglu_fused + swiglu_fused_op = FusedMLIROperator( + "swiglu", + [ + (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "inter_ffn_gate"), + (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "inter_ffn_up"), + (silu_ffn_op, "inter_ffn_gate", "inter_ffn_gate"), + (eltwise_mul_ffn_op, "inter_ffn_gate", "inter_ffn_up", "inter_ffn_hidden"), + (gemv_ffn_down_op, "W_ffn_down", "inter_ffn_hidden", "ffn_output"), + ], + input_args=[ + "x_norm", + "W_ffn_gate", + "W_ffn_up", + "W_ffn_down" + ], + output_args=[ + "ffn_output" + ] + ) + swiglu_fused = swiglu_fused_op.compile().get_callable() + +def run_autofused(): + swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = buf_x_norm.view_as_torch() + swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = buf_W_ffn_gate.view_as_torch() + swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = buf_W_ffn_up.view_as_torch() + swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = buf_W_ffn_down.view_as_torch() + swiglu_fused.get_buffer("ffn_output").view_as_torch()[:] = buf_ffn_output.view_as_torch() + swiglu_fused() + return swiglu_fused.get_buffer("ffn_output").view_as_torch() # CPU # --- @@ -391,6 +151,8 @@ def run_cpu(): # Main # --- -#print(run_autofused()) +setup_autofused() +print(run_autofused()) +setup_separate_xclbins() print(run_separate_xclbins()) print(run_cpu()) diff --git a/operators/common/__init__.py b/operators/common/__init__.py index b70aba2c..9351fd68 100644 --- a/operators/common/__init__.py +++ b/operators/common/__init__.py @@ -3,14 +3,14 @@ """Common utilities and base classes for IRON operators.""" -from .aie_base import ( +from .base import ( AIEOperatorBase, SingleMLIRSourceOperator, AIEBuffer, SingleXclbinCallable, AIERuntimeArgSpec, ) -from .aie_context import AIEContext +from .context import AIEContext from .compilation import ( XclbinArtifact, InstsBinArtifact, @@ -19,4 +19,4 @@ SourceArtifact, PythonGeneratedMLIRArtifact, ) -from .aie_device_manager import AIEDeviceManager +from .device_manager import AIEDeviceManager diff --git a/operators/common/aie_base.py b/operators/common/base.py similarity index 92% rename from operators/common/aie_base.py rename to operators/common/base.py index 99ff2890..05efad25 100644 --- a/operators/common/aie_base.py +++ b/operators/common/base.py @@ -12,8 +12,8 @@ import aie.utils.config from . import compilation as comp -from .aie_context import AIEContext -from .aie_device_manager import AIEDeviceManager, pyxrt +from .context import AIEContext +from .device_manager import AIEDeviceManager, pyxrt from .utils import numpy_to_torch, torch_to_numpy from .compilation import ( XclbinArtifact, @@ -66,22 +66,9 @@ def compile(self, dry_run=False): Set up the operator and compile any necessary artifacts. Subclasses are expected to overwrite set_up(); they may register any artifacts that they need to be compiled there. """ - context = self.context self.set_up_artifacts() - compilation_rules = [ - comp.GenerateMLIRFromPythonCompilationRule(), - comp.PeanoCompilationRule( - context.peano_dir, context.mlir_aie_dir - ), - comp.ArchiveCompilationRule(context.peano_dir), - comp.AieccXclbinInstsCompilationRule( - context.build_dir, - context.peano_dir, - context.mlir_aie_dir, - ), - ] artifacts = comp.CompilationArtifactGraph(self.artifacts) - comp.compile(compilation_rules, artifacts, context.build_dir, dry_run=dry_run) + comp.compile(self.context.compilation_rules, artifacts, self.context.build_dir, dry_run=dry_run) return self def add_artifacts(self, artifacts): @@ -106,6 +93,7 @@ def execute_runlist(runlist): class SingleMLIRSourceOperator(AIEOperatorBase, ABC): """Base class for AIE-accelerated operations""" def __init__(self, *args, **kwargs): + self.kernel_archive = f"{self.get_operator_name()}_kernels.a" AIEOperatorBase.__init__(self, *args, **kwargs) @abstractmethod @@ -120,16 +108,17 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): pass - def get_kernel_archive_name(self): - return self.get_operator_name() + ".a" - def get_artifacts(self): operator_name = self.get_operator_name() mlir_artifact = self.get_mlir_artifact() kernel_deps_inputs = self.get_kernel_artifacts() + if len(kernel_deps_inputs) > 0: + # FIXME: currently hard-coding that the design will accept this argument as an input if it uses kernels + # Also not handling name collisions of kernels with the same name + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive kernel_deps = [ KernelArchiveArtifact( - self.get_kernel_archive_name(), + self.kernel_archive, dependencies=kernel_deps_inputs, ) ] if kernel_deps_inputs else [] diff --git a/operators/common/compilation/__init__.py b/operators/common/compilation/__init__.py new file mode 100644 index 00000000..a0c2b126 --- /dev/null +++ b/operators/common/compilation/__init__.py @@ -0,0 +1,2 @@ +from .base import * +from .fusion import * \ No newline at end of file diff --git a/operators/common/compilation.py b/operators/common/compilation/base.py similarity index 88% rename from operators/common/compilation.py rename to operators/common/compilation/base.py index a0565943..9d02cccb 100644 --- a/operators/common/compilation.py +++ b/operators/common/compilation/base.py @@ -4,29 +4,31 @@ """ This file implements a simple Python-based build system. You specify what you want to compile (*artifacts*) through subclasses of `CompilationArtifact`. -Each artifact can have a list of depenencies of other artifacts that it relies -on. Each artifact corresponds to exactly one file. If a file with a matching -name already exists, and all its dependencies are built and older than the file, -then the existing file will be reused. - -For each file name, artifacts are singletons. You create artifacts by calling -the `new` class method of the appropriate class. This ensures that artifact -objects are uniqued, i.e., calling `new` twice with the same file name will -return the same object. +Multiple `CompilationArtifacts` form a `CompilationArtifactGraph`. Each artifact +can have a list (subgraph) of depenencies of other artifacts that it relies on. +Each artifact corresponds to exactly one file. There is a special artifact for source files that do not need to get generated, `SourceArtifact`. It is likely that in your compilation dependency graph, the leaf nodes will be `SourceArtifact`s. You specify how to generate (compile) an artifact through *rules*, which are -expressed as subclasses of `CompilationRule`. This class requires you to -implement two methods: `matches` and `compile`. During compilation, we will -call `matches` on the set of remaining artifacts to see if the given rule is -able to produce any of the artifacts not available yet. If this function -returns `True`, we will call `compile` on the rule to generate the artifact. -`compile` returns a new list of artifacts, which may be the same one as -before; however, if `matches()==True`, at least one of the artifacts in the -list must be made available after calling `compile()`. +expressed as subclasses of `CompilationRule`. Rules must implement two methods: +`matches` and `compile`. If a rule `matches` to an artifact graph, it can be +applied. Applying a rule is done by calling `compile`; this transforms the +artifact graph (in the simplest case, marks one of the artifacts as available) +and returns a list of compilation commands. + +At this point, we can print the compilation commands to the console (dry-run) +or actually run them to generate the artifacts. + +Before starting compilation, you may call +`populate_availability_from_filesystem()` -- this will check if any artifacts +are already available at the given file paths (and ensure that dependencies are +as old or older than the artifacts that depend on them). This way, you can avoid +recompiling artifacts that are already up-to-date on disk. If you wish to +regenerate everything, you can skip this step, but will at a minimum want to +mark the `SourceArtifact`s as available -- they cannot be generated. """ from abc import ABC, abstractmethod @@ -39,6 +41,7 @@ from contextlib import nullcontext from aie.extras.context import mlir_mod_ctx import copy +import sys # Global Functions @@ -214,6 +217,14 @@ class SourceArtifact(CompilationArtifact): pass +class FullElfArtifact(CompilationArtifact): + def __init__(self, filename, mlir_input, dependencies): + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] + super().__init__(filename, dependencies) + self.mlir_input = mlir_input + + class XclbinArtifact(CompilationArtifact): def __init__( self, filename, mlir_input, dependencies, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None @@ -256,6 +267,8 @@ def __init__( callback_args=None, callback_kwargs=None, requires_context=False, + uses_kernel_archive=False, + kernel_archive=None ): self.import_path = import_path self.callback_fn = callback_fn @@ -297,6 +310,9 @@ def run(self) -> bool: cwd=self.cwd, env=self.env, ) + if 0 != result.returncode: + print(result.stdout) + print(result.stderr, file=sys.stderr) return 0 == result.returncode def __repr__(self): @@ -380,13 +396,44 @@ def generate_mlir(output_artifact, import_path, callback_fn, callback_args=None, f.write(mlir_code) -class AieccXclbinInstsCompilationRule(CompilationRule): +class AieccCompilationRule(CompilationRule, ABC): def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): self.build_dir = build_dir self.aiecc_path = Path(mlir_aie_dir) / "bin" / "aiecc.py" self.peano_dir = peano_dir super().__init__(*args, **kwargs) +class AieccFullElfCompilationRule(AieccCompilationRule): + def matches(self, graph): + return any(graph.get_worklist(FullElfArtifact)) + + def compile(self, graph): + worklist = graph.get_worklist(FullElfArtifact) + commands = [] + + for artifact in worklist: + compile_cmd = [ + "python", + str(self.aiecc_path), + "--no-compile-host", + "--no-xchesscc", + "--no-xbridge", + "--peano", + str(self.peano_dir), + "--dynamic-objFifos", + "--expand-load-pdis", + "--generate-full-elf", + "--full-elf-name", + artifact.filename, + artifact.mlir_input.filename, + ] + commands.append(ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir))) + artifact.available = True + + return commands + + +class AieccXclbinInstsCompilationRule(AieccCompilationRule): def matches(self, graph): return any(graph.get_worklist((XclbinArtifact, InstsBinArtifact))) diff --git a/operators/common/compilation/fusion.py b/operators/common/compilation/fusion.py new file mode 100644 index 00000000..e60027c8 --- /dev/null +++ b/operators/common/compilation/fusion.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Temporal fusion of multiple MLIR modules into one module with multiple devices and a main runtime sequence that calls into them. +""" + +import numpy as np +import importlib.util +from pathlib import Path +from aie import ir +from aie.dialects import aie, aiex, memref +from aie.extras.context import mlir_mod_ctx +import ml_dtypes + +from . import ( + CompilationArtifact, + CompilationRule, + CompilationCommand, + PythonCallbackCompilationCommand, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) + + +# Compilation Artifacts +# ########################################################################## + + +class FusedMLIRSource(CompilationArtifact): + def __init__(self, filename, operator_mlir_map, runlist, subbuffer_layout, buffer_sizes): + dependencies = list(operator_mlir_map.values()) + super().__init__(filename, dependencies) + self.operator_mlir_map = operator_mlir_map + self.runlist = runlist + self.subbuffer_layout = subbuffer_layout + self.buffer_sizes = buffer_sizes + + +# Helper Functions +# ########################################################################## + + +def extract_runtime_sequence_arg_types(dev_op): + """MLIR helper: Extract argument types from a device operation's runtime sequence.""" + for nested_op in dev_op.body_region.blocks[0].operations: + op_name = nested_op.operation.name + if op_name == 'aie.runtime_sequence': + if hasattr(nested_op, 'body') and hasattr(nested_op.body, 'blocks'): + if len(nested_op.body.blocks) > 0: + entry_block = nested_op.body.blocks[0] + arg_types = [entry_block.arguments[i].type for i in range(len(entry_block.arguments))] + return arg_types + raise RuntimeError("Could not find runtime sequence in device operation") + + +def get_child_mlir_module(mlir_artifact): + """Extract MLIR module from a PythonGeneratedMLIRArtifact.""" + assert isinstance(mlir_artifact, PythonGeneratedMLIRArtifact) + spec = importlib.util.spec_from_file_location( + Path(mlir_artifact.import_path).name, mlir_artifact.import_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if mlir_artifact.requires_context: + raise NotImplementedError("Not handled, make your operator return a ctx.module") + + callback_function = getattr(module, mlir_artifact.callback_fn) + mlir_module = callback_function( + *mlir_artifact.callback_args, **mlir_artifact.callback_kwargs + ) + return mlir_module + + +def fuse_mlir(artifact): + """Fuse multiple MLIR modules by inlining their device operations and adding a new main device and runtime sequence that call into sequence of operations based on a runlist.""" + + input_buffer_size, output_buffer_size, scratch_buffer_size = artifact.buffer_sizes + + # Extract device operations from each operator's MLIR artifact + device_mlir_strings = {} + device_ty = None + sequence_arg_types = {} + for op_name, mlir_artifact in artifact.operator_mlir_map.items(): + mlir_module = get_child_mlir_module(mlir_artifact) + device_ops = [op for op in mlir_module.body.operations if isinstance(op, aie.DeviceOp)] + assert len(device_ops) == 1, f"Expected exactly one device operation in MLIR artifact for operator '{op_name}'" + device_op = device_ops[0] + if device_ty is None: + device_ty = device_op.device + device_mlir_strings[op_name] = str(device_op) + sequence_arg_types[op_name] = extract_runtime_sequence_arg_types(device_op) + + # Build fused MLIR module + with mlir_mod_ctx() as ctx: + + # Concatenate aie.device ops + for op_name, device_str in device_mlir_strings.items(): + dev_op = aie.DeviceOp.parse(device_str) + dev_op.sym_name = ir.StringAttr.get(op_name) + ctx.module.body.append(dev_op) + + # Create the main device -- this contains the runtime sequence calling into the other devices + @aie.device(device_ty) + def main(): + buf_dtype = np.dtype[ml_dtypes.bfloat16] # TODO: support for other data types + itemsize = 2 + + # RuntimeSequenceOp + @aiex.runtime_sequence( + np.ndarray[(input_buffer_size // itemsize,), buf_dtype], + np.ndarray[(output_buffer_size // itemsize,), buf_dtype], + np.ndarray[(scratch_buffer_size // itemsize,), buf_dtype], + ) + def sequence(input_buf, output_buf, scratch_buf): + consolidated_buffers = { + 'input': input_buf, + 'output': output_buf, + 'scratch': scratch_buf + } + + # Execute operations in runlist order + for op_name, *buffer_names in artifact.runlist: + expected_arg_types = sequence_arg_types[op_name] + + # Configure Op + configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(op_name) + configure_op = aiex.ConfigureOp(configure_sym_ref_attr) # TODO: optimization -- if previous op was in the same device, skip reconfiguration + configure_body = configure_op.body.blocks.append() + with ir.InsertionPoint(configure_body): + + # For each buffer, add subview and reinterpret_cast ops + buffer_ssa_values = [] + for idx, buf_name in enumerate(buffer_names): + buf_type, offset, length = artifact.subbuffer_layout[buf_name] + + # Subview Op + consolidated_buf = consolidated_buffers[buf_type] + offset_elements = offset // itemsize + size_elements = length // itemsize + subview = memref.subview( + consolidated_buf, + [offset_elements], + [size_elements], + [1] + ) + + # Reinterpret_cast Op + target_type = expected_arg_types[idx] + expected_memref = ir.MemRefType(target_type) + target_shape = [expected_memref.shape[i] for i in range(expected_memref.rank)] + expected_size = np.prod(target_shape) + assert expected_size == size_elements, f"Size mismatch for buffer '{buf_name}': MLIR runtime sequence expected {expected_size}, Python fused operator provided {size_elements}" + strides = [] + stride = 1 + for dim in reversed(target_shape): + strides.insert(0, stride) + stride *= dim + result_type = ir.MemRefType.get(target_shape, ir.BF16Type.get()) + reinterpreted = memref.reinterpret_cast( + result=result_type, + source=subview, + offsets=[], + sizes=[], + strides=[], + static_offsets=[0], + static_sizes=target_shape, + static_strides=strides + ) + buffer_ssa_values.append(reinterpreted) + + # Run Op + sequence_sym_ref_attr = ir.FlatSymbolRefAttr.get("sequence") + run_op = aiex.RunOp(sequence_sym_ref_attr, buffer_ssa_values) + + # Write the fused MLIR to file + with open(artifact.filename, "w") as f: + f.write(str(ctx.module)) + + +# Compilation Rules +# ########################################################################## + + +class FuseMLIRCompilationRule(CompilationRule): + """Compilation rule that fuses multiple MLIR modules into one.""" + + def matches(self, graph): + return any(graph.get_worklist(FusedMLIRSource)) + + def compile(self, graph): + commands = [] + worklist = graph.get_worklist(FusedMLIRSource) + for artifact in worklist: + callback = lambda: fuse_mlir(artifact) + commands.append(PythonCallbackCompilationCommand(callback)) + new_artifact = SourceArtifact(artifact.filename) + new_artifact.available = True + graph.replace(artifact, new_artifact) + return commands diff --git a/operators/common/aie_context.py b/operators/common/context.py similarity index 71% rename from operators/common/aie_context.py rename to operators/common/context.py index a8f6e516..46987da4 100644 --- a/operators/common/aie_context.py +++ b/operators/common/context.py @@ -6,7 +6,7 @@ from pathlib import Path import os -from .aie_device_manager import AIEDeviceManager, pyxrt +from .device_manager import AIEDeviceManager, pyxrt from . import compilation as comp import aie.utils.config @@ -24,7 +24,14 @@ def __init__(self, use_runlist=True): self.peano_dir = Path(aie.utils.config.peano_install_dir()) # Disable the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations for easier debugging (can tell which part of runlist execution failed) self.use_runlist = use_runlist - self._runtime_prepared = False + self.compilation_rules = [ + comp.FuseMLIRCompilationRule(), + comp.GenerateMLIRFromPythonCompilationRule(), + comp.PeanoCompilationRule(self.peano_dir, self.mlir_aie_dir), + comp.ArchiveCompilationRule(self.peano_dir), + comp.AieccXclbinInstsCompilationRule(self.build_dir, self.peano_dir, self.mlir_aie_dir), + comp.AieccFullElfCompilationRule(self.build_dir, self.peano_dir, self.mlir_aie_dir), + ] def register_operator(self, operator): """Register an operator with this context""" diff --git a/operators/common/aie_device_manager.py b/operators/common/device_manager.py similarity index 100% rename from operators/common/aie_device_manager.py rename to operators/common/device_manager.py diff --git a/operators/common/fusion.py b/operators/common/fusion.py new file mode 100644 index 00000000..53aef1d7 --- /dev/null +++ b/operators/common/fusion.py @@ -0,0 +1,229 @@ +import numpy as np +import ml_dtypes +import pyxrt +from . import compilation as comp +from .base import AIEOperatorBase, SingleMLIRSourceOperator, AIEBuffer +from .device_manager import AIEDeviceManager + +# Fused Operator +# ########################################################################## + + +class FusedMLIROperator(AIEOperatorBase): + """Operator that fuses multiple SingleMLIRSourceOperators into one.""" + + def __init__(self, name, runlist, input_args, output_args, *args, **kwargs): + assert all( + isinstance(op, SingleMLIRSourceOperator) and all(isinstance(buf, str) for buf in bufs) + for op, *bufs in runlist + ) + self.runlist = runlist + self.name = name + self.input_args = input_args + self.output_args = output_args + self.kernel_archive = "kernels.a" + super().__init__(*args, **kwargs) + + def get_operator_name(self): + return self.name + + def get_kernel_artifacts(self): + """Collect all kernel artifacts from child operators.""" + kernel_artifacts = [] + for op, *bufs in self.runlist: + kernel_artifacts.extend(op.get_kernel_artifacts()) + return kernel_artifacts + + def get_mlir_artifact(self): + # Build operator_mlir_map: {op_name -> PythonGeneratedMLIRArtifact} + operator_mlir_map = {} + mlir_dependencies = [] + comp_runlist = [] + + for idx, (op, *bufs) in enumerate(self.runlist): + mlir_artifact = op.get_mlir_artifact() + if len(op.get_kernel_artifacts()) > 0: + # FIXME: currently hard-coding that the design will accept this argument as an input if it uses kernels + # Also not handling name collisions of kernels with the same name + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive + op_name = f"{op.get_operator_name()}_{idx}" + operator_mlir_map[op_name] = mlir_artifact + comp_runlist.append((op_name, *bufs)) + + # Calculate buffer layout: {buffer_name -> (type, offset, length)} + self.subbuffer_layout, self.buffer_sizes = self._calculate_buffer_layout() + + filename = self.get_operator_name() + "_fused.mlir" + fused_artifact = comp.FusedMLIRSource( + filename, + operator_mlir_map=operator_mlir_map, + runlist=comp_runlist, + subbuffer_layout=self.subbuffer_layout, + buffer_sizes=self.buffer_sizes + ) + + return fused_artifact + + def _calculate_buffer_layout(self): + args = {} + + # Collect all buffer specs from operators + for op, *bufs in self.runlist: + args_specs = op.get_arg_spec() + assert len(args_specs) == len(bufs), "Number of buffers must match operator argument specification" + for i, buf_name in enumerate(bufs): + args_spec = args_specs[i] + if buf_name not in args: + args[buf_name] = args_spec + else: + assert np.prod(args[buf_name].shape) == np.prod(args_spec.shape), f"Buffer {buf_name} has conflicting sizes between operators" + + # Verify all input/output args are present + for arg in self.input_args: + assert arg in args, f"Input argument {arg} not found in runlist buffers" + for arg in self.output_args: + assert arg in args, f"Output argument {arg} not found in runlist buffers" + + # Determine buffer types + subbuffer_layout = {} + + def add_buffers(buffer_type, args_list): + offset = 0 + for arg in args_list: + arg_spec = args[arg] + length = int(np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize) + subbuffer_layout[arg] = (buffer_type, offset, length) + offset += length + return offset # == total length + + input_buffer_size = add_buffers('input', self.input_args) + output_buffer_size = add_buffers('output', self.output_args) + scratch_args = [arg for arg in args if arg not in self.input_args and arg not in self.output_args] + scratch_buffer_size = add_buffers('scratch', scratch_args) + + buffer_sizes = (input_buffer_size, output_buffer_size, scratch_buffer_size) + return subbuffer_layout, buffer_sizes + + def set_up_artifacts(self): + operator_name = self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_objects = self.get_kernel_artifacts() + kernel_dep = [comp.KernelArchiveArtifact( + self.kernel_archive, + dependencies=kernel_objects, + )] if kernel_objects else [] + full_elf_artifact = comp.FullElfArtifact( + f"{operator_name}.elf", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_dep, + ) + self.add_artifacts([full_elf_artifact]) + + def get_arg_spec(self): + pass + + def get_callable(self): + """Return a callable for the fused operator (stub for now).""" + return FusedFullELFCallable(self) + + +class FullELFCallable: + def __init__(self, op, device_name="main", sequence_name="sequence", device_manager=None): + # std::string kernelName = "main:sequence"; + # xrt::elf ctx_elf{"aie.elf"}; + # xrt::hw_context context = xrt::hw_context(device, ctx_elf); + # auto kernel = xrt::ext::kernel(context, kernelName); + self.device_manager = device_manager or AIEDeviceManager() + self.xrt_elf = pyxrt.elf(op.artifacts[0].filename) + self.xrt_module = pyxrt.module(self.xrt_elf) + #self.xrt_context = self.xrt_module.get_hw_context() + self.xrt_context = pyxrt.hw_context(self.device_manager.device, self.xrt_module.get_cfg_uuid()) + self.xrt_kernel = pyxrt.kernel(self.xrt_context, f"{device_name}:{sequence_name}") + + def __call__(self, *args): + pass + #run = pyxrt.run(self.xrt_kernel) + #for i, arg in enumerate(args): + # assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" + # run.set_arg(i, arg) + #run.start() + #run.wait2() + +class FusedFullELFCallable(FullELFCallable): + def __init__(self, op, device_manager=None): + super().__init__(op, device_manager=device_manager) + + self.subbuffer_layout = op.subbuffer_layout + self.buffer_sizes = op.buffer_sizes + + input_buffer_size, output_buffer_size, scratch_buffer_size = self.buffer_sizes + bf16_itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + + self.input_buffer = AIEBuffer( + shape=(input_buffer_size // bf16_itemsize,), + dtype=ml_dtypes.bfloat16 + ) if input_buffer_size > 0 else None + + self.output_buffer = AIEBuffer( + shape=(output_buffer_size // bf16_itemsize,), + dtype=ml_dtypes.bfloat16 + ) if output_buffer_size > 0 else None + + self.scratch_buffer = AIEBuffer( + shape=(scratch_buffer_size // bf16_itemsize,), + dtype=ml_dtypes.bfloat16 + ) if scratch_buffer_size > 0 else None + + self._buffer_cache = {} + + def get_buffer(self, buffer_name): + # Return cached buffer if already allocated + if buffer_name in self._buffer_cache: + return self._buffer_cache[buffer_name] + + # Look up buffer information + if buffer_name not in self.subbuffer_layout: + raise KeyError(f"Buffer '{buffer_name}' not found in buffer layout") + + buf_type, offset, length = self.subbuffer_layout[buffer_name] + + # Select the appropriate main buffer + if buf_type == 'input': + main_buffer = self.input_buffer + elif buf_type == 'output': + main_buffer = self.output_buffer + elif buf_type == 'scratch': + main_buffer = self.scratch_buffer + else: + raise ValueError(f"Unknown buffer type '{buf_type}' for buffer '{buffer_name}'") + + if main_buffer is None: + raise RuntimeError(f"Main buffer for type '{buf_type}' is not allocated") + + # Convert byte offset/length to element offset/length + bf16_itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + offset_elements = offset // bf16_itemsize + length_elements = length // bf16_itemsize + + # Create subbuffer with appropriate shape + # For now, use 1D shape; could be enhanced to use actual buffer shapes + sub_buffer = main_buffer.subbuffer( + length=length_elements, + offset=offset_elements, + shape=(length_elements,), + dtype=ml_dtypes.bfloat16 + ) + + # Cache and return + self._buffer_cache[buffer_name] = sub_buffer + return sub_buffer + + def __call__(self): + self.input_buffer.to("npu") + self.output_buffer.to("npu") + self.scratch_buffer.to("npu") + super().__call__( + self.input_buffer.xrt_bo if self.input_buffer else None, + self.output_buffer.xrt_bo if self.output_buffer else None, + self.scratch_buffer.xrt_bo if self.scratch_buffer else None, + ) diff --git a/operators/elementwise_add/design.py b/operators/elementwise_add/design.py index 0aa6d6b1..d8790dfa 100644 --- a/operators/elementwise_add/design.py +++ b/operators/elementwise_add/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, archive_name): +def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +37,7 @@ def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, archiv # AIE Core Function declaration eltwise_add_bf16_vector = Kernel( - "eltwise_add_bf16_vector", archive_name, [tile_ty, tile_ty, tile_ty, np.int32] + "eltwise_add_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile diff --git a/operators/elementwise_mul/design.py b/operators/elementwise_mul/design.py index f2a2a266..cd30bbef 100644 --- a/operators/elementwise_mul/design.py +++ b/operators/elementwise_mul/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, archive_name): +def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +37,7 @@ def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, archiv # AIE Core Function declaration eltwise_mul_bf16_vector = Kernel( - "eltwise_mul_bf16_vector", archive_name, [tile_ty, tile_ty, tile_ty, np.int32] + "eltwise_mul_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile diff --git a/operators/elementwise_mul/op.py b/operators/elementwise_mul/op.py index 0afcd7fe..d7557a2e 100644 --- a/operators/elementwise_mul/op.py +++ b/operators/elementwise_mul/op.py @@ -46,8 +46,7 @@ def get_mlir_artifact(self): self.size, self.num_aie_columns, self.tile_size, - 0, - self.get_kernel_archive_name(), + 0 ], ) diff --git a/operators/gemv/op.py b/operators/gemv/op.py index a938ca38..3771a3e3 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -56,7 +56,7 @@ def __init__( SingleMLIRSourceOperator.__init__(self, context=context) def get_operator_name(self): - return f"{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_batches}batch_{self.num_aie_columns}col" + return f"gemv_{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_batches}batch_{self.num_aie_columns}col" def get_mlir_artifact(self): operator_dir = Path(__file__).parent @@ -73,19 +73,13 @@ def get_mlir_artifact(self): self.tile_size_input, self.tile_size_output, self.num_batches, - ], - callback_kwargs={ - "kernel_archive": self.get_kernel_archive_name(), - } + ] ) - def get_kernel_archive_name(self): - return f"mv_{self.K}k.a" - def get_kernel_artifacts(self): return [ KernelObjectArtifact( - f"mv_{self.K}k.o", + f"gemv_{self.K}k.o", dependencies=[ SourceArtifact( self.context.base_dir / "aie_kernels" / "generic" / "mv.cc" diff --git a/operators/rms_norm/design.py b/operators/rms_norm/design.py index 499c52d5..4b3a2da4 100644 --- a/operators/rms_norm/design.py +++ b/operators/rms_norm/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size, archive_name="rms_norm.a"): +def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size, kernel_archive="rms_norm.a"): per_tile_elements = 8192 if tile_size > 8192 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -46,7 +46,7 @@ def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_s # AIE Core Function declaration rms_norm_kernel = Kernel( - "rms_norm_bf16_vector", archive_name, [tile_ty, tile_ty, np.int32] + "rms_norm_bf16_vector", kernel_archive, [tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile diff --git a/operators/rms_norm/design_weighted.py b/operators/rms_norm/design_weighted.py index 6a2929f8..8f525abe 100644 --- a/operators/rms_norm/design_weighted.py +++ b/operators/rms_norm/design_weighted.py @@ -16,7 +16,7 @@ def my_weighted_rms_norm( - dev, num_elements, num_columns, num_channels, weight_length, trace_size, archive_name="rms_norm.a" + dev, num_elements, num_columns, num_channels, weight_length, trace_size, kernel_archive="rms_norm.a" ): per_tile_elements = weight_length total_cores = num_columns # For each core that does rms norm, another core will take its output to do eltwise mul @@ -53,11 +53,11 @@ def my_weighted_rms_norm( # AIE Core Function declaration rms_norm_kernel = Kernel( - "rms_norm_bf16_vector", archive_name, [tile_ty, tile_ty, np.int32] + "rms_norm_bf16_vector", kernel_archive, [tile_ty, tile_ty, np.int32] ) eltwise_mul_kernel = Kernel( "eltwise_mul_bf16_vector", - archive_name, + kernel_archive, [tile_ty, weights_ty, tile_ty, np.int32], ) diff --git a/operators/rope/design.py b/operators/rope/design.py index 0e85e1ef..a2126783 100644 --- a/operators/rope/design.py +++ b/operators/rope/design.py @@ -38,14 +38,14 @@ def rope( num_aie_columns=1, trace_size=0, method_type=None, - archive_name=None, + kernel_archive=None, ): dtype = bfloat16 if angle_rows is None: angle_rows = rows - if archive_name is None: - archive_name = "rope" + (f"_{method_type}" if method_type is not None else "") + ".o" + if kernel_archive is None: + kernel_archive = "rope" + (f"_{method_type}" if method_type is not None else "") + ".o" assert cols % (16 * 2) == 0 and cols >= ( 16 * 2 @@ -78,7 +78,7 @@ def rope( # AIE Core Function declaration rope_kernel = Kernel( "rope", - archive_name, + kernel_archive, [tensor_tile_ty, angle_tile_ty, tensor_tile_ty, np.int32], ) diff --git a/operators/silu/design.py b/operators/silu/design.py index ce85517e..8e9532dd 100644 --- a/operators/silu/design.py +++ b/operators/silu/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_silu(dev, size, num_columns, tile_size, trace_size, archive_name): +def my_silu(dev, size, num_columns, tile_size, trace_size, kernel_archive): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] @@ -45,7 +45,7 @@ def my_silu(dev, size, num_columns, tile_size, trace_size, archive_name): # External, binary kernel definition silu_fcn = Kernel( "silu_bf16", - archive_name, + kernel_archive, [line_type, line_type, np.int32], ) diff --git a/operators/silu/op.py b/operators/silu/op.py index 456dead3..8adc6f4c 100644 --- a/operators/silu/op.py +++ b/operators/silu/op.py @@ -40,9 +40,8 @@ def get_mlir_artifact(self): self.size, self.num_aie_columns, self.tile_size, - 0, - self.get_kernel_archive_name(), - ], + 0 + ] ) def get_kernel_artifacts(self): From f0cda24b5a851250ed1c65c08852dbbf88e0d042 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 17:25:01 -0700 Subject: [PATCH 41/99] fixes; requires XRT PR 9560 to be merged --- applications/llama_3.2_1b/autofuse.py | 10 +++++----- operators/common/fusion.py | 28 +++++++++++---------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index 067fbf79..e7eb85a4 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -122,11 +122,11 @@ def setup_autofused(): swiglu_fused = swiglu_fused_op.compile().get_callable() def run_autofused(): - swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = buf_x_norm.view_as_torch() - swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = buf_W_ffn_gate.view_as_torch() - swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = buf_W_ffn_up.view_as_torch() - swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = buf_W_ffn_down.view_as_torch() - swiglu_fused.get_buffer("ffn_output").view_as_torch()[:] = buf_ffn_output.view_as_torch() + swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = buf_x_norm.view_as_torch().flatten() + swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = buf_W_ffn_gate.view_as_torch().flatten() + swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = buf_W_ffn_up.view_as_torch().flatten() + swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = buf_W_ffn_down.view_as_torch().flatten() + swiglu_fused.get_buffer("ffn_output").view_as_torch()[:] = buf_ffn_output.view_as_torch().flatten() swiglu_fused() return swiglu_fused.get_buffer("ffn_output").view_as_torch() diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 53aef1d7..f9eda2f4 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -129,25 +129,19 @@ def get_callable(self): class FullELFCallable: def __init__(self, op, device_name="main", sequence_name="sequence", device_manager=None): - # std::string kernelName = "main:sequence"; - # xrt::elf ctx_elf{"aie.elf"}; - # xrt::hw_context context = xrt::hw_context(device, ctx_elf); - # auto kernel = xrt::ext::kernel(context, kernelName); self.device_manager = device_manager or AIEDeviceManager() self.xrt_elf = pyxrt.elf(op.artifacts[0].filename) self.xrt_module = pyxrt.module(self.xrt_elf) - #self.xrt_context = self.xrt_module.get_hw_context() - self.xrt_context = pyxrt.hw_context(self.device_manager.device, self.xrt_module.get_cfg_uuid()) - self.xrt_kernel = pyxrt.kernel(self.xrt_context, f"{device_name}:{sequence_name}") + self.xrt_context = pyxrt.hw_context(self.device_manager.device, self.xrt_elf) + self.xrt_kernel = pyxrt.ext.kernel(self.xrt_context, f"{device_name}:{sequence_name}") def __call__(self, *args): - pass - #run = pyxrt.run(self.xrt_kernel) - #for i, arg in enumerate(args): - # assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" - # run.set_arg(i, arg) - #run.start() - #run.wait2() + run = pyxrt.run(self.xrt_kernel) + for i, arg in enumerate(args): + assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" + run.set_arg(i, arg) + run.start() + run.wait2() class FusedFullELFCallable(FullELFCallable): def __init__(self, op, device_manager=None): @@ -223,7 +217,7 @@ def __call__(self): self.output_buffer.to("npu") self.scratch_buffer.to("npu") super().__call__( - self.input_buffer.xrt_bo if self.input_buffer else None, - self.output_buffer.xrt_bo if self.output_buffer else None, - self.scratch_buffer.xrt_bo if self.scratch_buffer else None, + self.input_buffer.bo if self.input_buffer else None, + self.output_buffer.bo if self.output_buffer else None, + self.scratch_buffer.bo if self.scratch_buffer else None, ) From d715e48a6707e5771d038cacb5251c7eb9451545 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 18:19:48 -0700 Subject: [PATCH 42/99] fixes --- operators/common/compilation/base.py | 12 ++++++------ operators/common/context.py | 4 ++-- operators/common/fusion.py | 18 ++++++++++-------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/operators/common/compilation/base.py b/operators/common/compilation/base.py index 9d02cccb..65f277e1 100644 --- a/operators/common/compilation/base.py +++ b/operators/common/compilation/base.py @@ -424,8 +424,8 @@ def compile(self, graph): "--expand-load-pdis", "--generate-full-elf", "--full-elf-name", - artifact.filename, - artifact.mlir_input.filename, + os.path.abspath(artifact.filename), + os.path.abspath(artifact.mlir_input.filename), ] commands.append(ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir))) artifact.available = True @@ -472,12 +472,12 @@ def compile(self, graph): ] # TODO: this does not handle the case of multiple xclbins with different kernel names or flags from the same MLIR compile_cmd += first_xclbin.extra_flags + [ "--aie-generate-xclbin", - "--xclbin-name=" + first_xclbin.filename, + "--xclbin-name=" + os.path.abspath(first_xclbin.filename), "--xclbin-kernel-name=" + first_xclbin.kernel_name, ] if first_xclbin.xclbin_input is not None: compile_cmd += [ - "--xclbin-input=" + first_xclbin.xclbin_input.filename + "--xclbin-input=" + os.path.abspath(first_xclbin.xclbin_input.filename) ] if do_compile_insts_bin: first_insts_bin = mlir_sources_to_insts[mlir_source][ @@ -487,9 +487,9 @@ def compile(self, graph): compile_cmd += ["--no-compile"] compile_cmd += first_insts_bin.extra_flags + [ "--aie-generate-npu", - "--npu-insts-name=" + first_insts_bin.filename, + "--npu-insts-name=" + os.path.abspath(first_insts_bin.filename), ] - compile_cmd += [mlir_source.filename] + compile_cmd += [os.path.abspath(mlir_source.filename)] commands.append(ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir))) diff --git a/operators/common/context.py b/operators/common/context.py index 46987da4..06e224a7 100644 --- a/operators/common/context.py +++ b/operators/common/context.py @@ -14,12 +14,12 @@ class AIEContext: """Context for managing AIE operator compilation and runtime state""" - def __init__(self, use_runlist=True): + def __init__(self, use_runlist=True, build_dir=None): self.operators = [] self.static_data_pool = {} self.device_manager = AIEDeviceManager() self.base_dir = Path(__file__).parent.parent.parent - self.build_dir = Path(os.getcwd()) / "build" + self.build_dir = build_dir or Path(os.getcwd()) / "build" self.mlir_aie_dir = Path(aie.utils.config.root_path()) self.peano_dir = Path(aie.utils.config.peano_install_dir()) # Disable the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations for easier debugging (can tell which part of runlist execution failed) diff --git a/operators/common/fusion.py b/operators/common/fusion.py index f9eda2f4..f1fbb2fd 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -141,7 +141,9 @@ def __call__(self, *args): assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" run.set_arg(i, arg) run.start() - run.wait2() + ret_code = run.wait() + if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: + raise RuntimeError(f"Kernel execution failed with return code {retcode}") class FusedFullELFCallable(FullELFCallable): def __init__(self, op, device_manager=None): @@ -151,22 +153,22 @@ def __init__(self, op, device_manager=None): self.buffer_sizes = op.buffer_sizes input_buffer_size, output_buffer_size, scratch_buffer_size = self.buffer_sizes - bf16_itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize self.input_buffer = AIEBuffer( - shape=(input_buffer_size // bf16_itemsize,), + shape=(max(input_buffer_size, itemsize) // itemsize,), dtype=ml_dtypes.bfloat16 - ) if input_buffer_size > 0 else None + ) self.output_buffer = AIEBuffer( - shape=(output_buffer_size // bf16_itemsize,), + shape=(max(output_buffer_size, itemsize) // itemsize,), dtype=ml_dtypes.bfloat16 - ) if output_buffer_size > 0 else None + ) self.scratch_buffer = AIEBuffer( - shape=(scratch_buffer_size // bf16_itemsize,), + shape=(max(scratch_buffer_size, itemsize) // itemsize,), dtype=ml_dtypes.bfloat16 - ) if scratch_buffer_size > 0 else None + ) self._buffer_cache = {} From 684a725698a43fb64e3a5b28dd30bb14a71503a7 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 21:59:50 -0700 Subject: [PATCH 43/99] finally working --- applications/llama_3.2_1b/autofuse.py | 98 +++++++++++++++++++-------- operators/common/compilation/base.py | 24 ++++++- operators/common/fusion.py | 22 ++++-- operators/elementwise_mul/design.py | 4 +- operators/gemv/design.py | 12 ++-- operators/silu/design.py | 4 +- 6 files changed, 118 insertions(+), 46 deletions(-) diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index e7eb85a4..e33358ef 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -61,14 +61,30 @@ # Buffers # --- -buf_W_ffn_gate = AIEBuffer.from_torch(torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16)) -buf_W_ffn_up = AIEBuffer.from_torch(torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16)) -buf_W_ffn_down = AIEBuffer.from_torch(torch.randn(emb_dim, hidden_dim, dtype=torch.bfloat16)) -buf_x_norm = AIEBuffer.from_torch(torch.randn(emb_dim, dtype=torch.bfloat16)) +# Create identity matrix for W_ffn_gate (repeating pattern for hidden_dim x emb_dim) +# Each row i will pick element i % emb_dim from x_norm +W_ffn_gate = torch.zeros(hidden_dim, emb_dim, dtype=torch.bfloat16) +for i in range(hidden_dim): + W_ffn_gate[i, i % emb_dim] = 1.0 + +W_ffn_up = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) + +W_ffn_down = torch.zeros(emb_dim, hidden_dim, dtype=torch.bfloat16) +for i in range(emb_dim): + W_ffn_down[i, i] = 1.0 + +buf_W_ffn_gate = AIEBuffer.from_torch(W_ffn_gate) +buf_W_ffn_up = AIEBuffer.from_torch(W_ffn_up) +buf_W_ffn_down = AIEBuffer.from_torch(W_ffn_down) + +# Create x_norm as sequential indices: [0, 1, 2, 3, ..., emb_dim-1] +x_norm = torch.arange(emb_dim, dtype=torch.bfloat16) +buf_x_norm = AIEBuffer.from_torch(x_norm) buf_ffn_gate = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) buf_ffn_up = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) -buf_ffn_hidden = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) -buf_ffn_output = AIEBuffer.from_torch(torch.zeros(emb_dim, dtype=torch.bfloat16)) +ffn_hidden = torch.arange(hidden_dim, dtype=torch.bfloat16) +buf_ffn_hidden = AIEBuffer.from_torch(ffn_hidden) +buf_ffn_output = AIEBuffer.from_torch(-1 * torch.arange(emb_dim, dtype=torch.bfloat16)) #torch.zeros(emb_dim, dtype=torch.bfloat16)) # Separate xclbins @@ -81,6 +97,11 @@ def setup_separate_xclbins(): global gemv_ffn_up_gate, gemv_ffn_down, silu_ffn, eltwise_mul_ffn + ctx = AIEContext(build_dir="build_separate") + gemv_ffn_up_gate_op.context = ctx + gemv_ffn_down_op.context = ctx + silu_ffn_op.context = ctx + eltwise_mul_ffn_op.context = ctx gemv_ffn_up_gate = gemv_ffn_up_gate_op.compile().get_callable() gemv_ffn_down = gemv_ffn_down_op.compile().get_callable() silu_ffn = silu_ffn_op.compile().get_callable() @@ -99,15 +120,20 @@ def run_separate_xclbins(): # --- def setup_autofused(): + ctx = AIEContext(build_dir="build_autofused") + gemv_ffn_up_gate_op.context = ctx + gemv_ffn_down_op.context = ctx + silu_ffn_op.context = ctx + eltwise_mul_ffn_op.context = ctx global swiglu_fused_op, swiglu_fused swiglu_fused_op = FusedMLIROperator( "swiglu", [ - (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "inter_ffn_gate"), - (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "inter_ffn_up"), - (silu_ffn_op, "inter_ffn_gate", "inter_ffn_gate"), - (eltwise_mul_ffn_op, "inter_ffn_gate", "inter_ffn_up", "inter_ffn_hidden"), - (gemv_ffn_down_op, "W_ffn_down", "inter_ffn_hidden", "ffn_output"), + (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), + (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), + (silu_ffn_op, "ffn_gate", "ffn_gate"), + (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), + (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), ], input_args=[ "x_norm", @@ -117,42 +143,60 @@ def setup_autofused(): ], output_args=[ "ffn_output" - ] + ], ) + swiglu_fused_op.context = ctx swiglu_fused = swiglu_fused_op.compile().get_callable() + swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = x_norm.flatten() + swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() + swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = W_ffn_up.flatten() + swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = W_ffn_down.flatten() + def run_autofused(): - swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = buf_x_norm.view_as_torch().flatten() - swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = buf_W_ffn_gate.view_as_torch().flatten() - swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = buf_W_ffn_up.view_as_torch().flatten() - swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = buf_W_ffn_down.view_as_torch().flatten() - swiglu_fused.get_buffer("ffn_output").view_as_torch()[:] = buf_ffn_output.view_as_torch().flatten() swiglu_fused() - return swiglu_fused.get_buffer("ffn_output").view_as_torch() + return swiglu_fused.get_buffer("ffn_output").to("cpu").view_as_torch() # CPU # --- def run_cpu(): - x_norm = buf_x_norm.view_as_torch() - W_ffn_gate = buf_W_ffn_gate.view_as_torch() - W_ffn_up = buf_W_ffn_up.view_as_torch() - W_ffn_down = buf_W_ffn_down.view_as_torch() - ffn_gate = torch.matmul(W_ffn_gate, x_norm) ffn_up = torch.matmul(W_ffn_up, x_norm) ffn_gate = torch.nn.functional.silu(ffn_gate) ffn_hidden = ffn_gate * ffn_up ffn_output = torch.matmul(W_ffn_down, ffn_hidden) - return ffn_output # Main # --- +iters=100 + setup_autofused() -print(run_autofused()) +t_autofused_start = time.time() +for _ in range(iters): + res_npu = run_autofused() +t_autofused = time.time() - t_autofused_start + setup_separate_xclbins() -print(run_separate_xclbins()) -print(run_cpu()) +t_separate_start = time.time() +for _ in range(iters): + res_npu_s = run_separate_xclbins() +t_separate = time.time() - t_separate_start + +t_cpu_start = time.time() +for _ in range(iters): + res_cpu = run_cpu() +t_cpu = time.time() - t_cpu_start + +print(res_npu_s) +print(res_npu) +print(res_cpu) + + +print(f"Separate xclbins time: {t_separate/iters:.6f} seconds") +print(f"Autofused time: {t_autofused/iters:.6f} seconds") +print(f"CPU time: {t_cpu/iters:.6f} seconds") +assert(torch.allclose(res_npu[-1], res_cpu[-1], atol=0.7, rtol=0.07)) diff --git a/operators/common/compilation/base.py b/operators/common/compilation/base.py index 65f277e1..a2d05a16 100644 --- a/operators/common/compilation/base.py +++ b/operators/common/compilation/base.py @@ -248,11 +248,11 @@ def __init__(self, filename, mlir_input, dependencies, extra_flags=None): class KernelObjectArtifact(CompilationArtifact): - def __init__(self, filename, dependencies, extra_flags=None, rename_symbols=None): + def __init__(self, filename, dependencies, extra_flags=None, rename_symbols=None, prefix_symbols=None): super().__init__(filename, dependencies) self.extra_flags = extra_flags if extra_flags is not None else [] self.rename_symbols = rename_symbols if rename_symbols is not None else {} - + self.prefix_symbols = prefix_symbols class KernelArchiveArtifact(CompilationArtifact): pass @@ -552,6 +552,8 @@ def compile(self, artifacts): commands.append(ShellCompilationCommand(cmd)) if artifact.rename_symbols: commands.extend(self._rename_symbols(artifact)) + if artifact.prefix_symbols: + commands.extend(self._prefix_symbols(artifact, artifact.prefix_symbols)) artifact.available = True return commands @@ -568,6 +570,15 @@ def _rename_symbols(self, artifact): ] cmd += [artifact.filename] return [ShellCompilationCommand(cmd)] + + def _prefix_symbols(self, artifact, prefix): + objcopy_path = "llvm-objcopy-18" + cmd = [ + objcopy_path, + "--prefix-symbols=" + prefix, + artifact.filename, + ] + return [ShellCompilationCommand(cmd)] class ArchiveCompilationRule(CompilationRule): @@ -607,6 +618,15 @@ def compile(self, artifacts): cmd = [str(ar_path), "rcs", archive_path] + object_files commands.append(ShellCompilationCommand(cmd)) + + # Check for duplicate symbol definitions in the archive + check_cmd = [ + "sh", "-c", + f"nm {archive_path} | grep ' [TDR] ' | awk '{{print $3}}' | sort | uniq -d | " + f"if read sym; then echo \"Error: Duplicate symbol in archive: $sym\" >&2; exit 1; fi" + ] + commands.append(ShellCompilationCommand(check_cmd)) + artifact.available = True return commands diff --git a/operators/common/fusion.py b/operators/common/fusion.py index f1fbb2fd..4e5a083d 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -30,8 +30,12 @@ def get_operator_name(self): def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators.""" kernel_artifacts = [] - for op, *bufs in self.runlist: - kernel_artifacts.extend(op.get_kernel_artifacts()) + unique_ops = set(op for op, *_ in self.runlist) + for idx, op in enumerate(unique_ops): + objs = op.get_kernel_artifacts() + for obj in objs: + obj.prefix_symbols = f"op{idx}_" + kernel_artifacts.extend(objs) return kernel_artifacts def get_mlir_artifact(self): @@ -39,16 +43,22 @@ def get_mlir_artifact(self): operator_mlir_map = {} mlir_dependencies = [] comp_runlist = [] - - for idx, (op, *bufs) in enumerate(self.runlist): + op_names = {} # op -> op_name + + unique_operators = set(op for op, *_ in self.runlist) + for idx, op in enumerate(unique_operators): mlir_artifact = op.get_mlir_artifact() if len(op.get_kernel_artifacts()) > 0: # FIXME: currently hard-coding that the design will accept this argument as an input if it uses kernels # Also not handling name collisions of kernels with the same name mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive - op_name = f"{op.get_operator_name()}_{idx}" + mlir_artifact.callback_kwargs["func_prefix"] = f"op{idx}_" + op_name = f"op{idx}_{op.__class__.__name__}" + op_names[op] = op_name operator_mlir_map[op_name] = mlir_artifact - comp_runlist.append((op_name, *bufs)) + + for op, *bufs in self.runlist: + comp_runlist.append((op_names[op], *bufs)) # Calculate buffer layout: {buffer_name -> (type, offset, length)} self.subbuffer_layout, self.buffer_sizes = self._calculate_buffer_layout() diff --git a/operators/elementwise_mul/design.py b/operators/elementwise_mul/design.py index cd30bbef..8cae5ac8 100644 --- a/operators/elementwise_mul/design.py +++ b/operators/elementwise_mul/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive): +def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive, func_prefix=""): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +37,7 @@ def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, kernel # AIE Core Function declaration eltwise_mul_bf16_vector = Kernel( - "eltwise_mul_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] + f"{func_prefix}eltwise_mul_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile diff --git a/operators/gemv/design.py b/operators/gemv/design.py index b26fb648..e8cf8e12 100644 --- a/operators/gemv/design.py +++ b/operators/gemv/design.py @@ -34,7 +34,7 @@ """ -def my_matvec(dev, cols, M, K, m_input, m_output=None, num_batches=1, kernel_archive="mv.o"): +def my_matvec(dev, cols, M, K, m_input, m_output=None, num_batches=1, kernel_archive="mv.o", func_prefix=""): if m_output is None: m_output = m_input @@ -71,18 +71,16 @@ def my_matvec(dev, cols, M, K, m_input, m_output=None, num_batches=1, kernel_arc L1_C_ty = np.ndarray[(m_output,), dtype_out] L3_A_ty = np.ndarray[ ( - num_batches, - M, - K, + num_batches * M * K, ), dtype_in, ] - L3_B_ty = np.ndarray[(num_batches, K,), dtype_in] - L3_C_ty = np.ndarray[(num_batches, M,), dtype_out] + L3_B_ty = np.ndarray[(num_batches * K,), dtype_in] + L3_C_ty = np.ndarray[(num_batches * M,), dtype_out] func_type = "vectorized" if vectorized else "scalar" matvec = Kernel( - f"matvec_{func_type}_{dtype_in_str}_{dtype_out_str}", + f"{func_prefix}matvec_{func_type}_{dtype_in_str}_{dtype_out_str}", kernel_archive, [np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], ) diff --git a/operators/silu/design.py b/operators/silu/design.py index 8e9532dd..0d0e6d74 100644 --- a/operators/silu/design.py +++ b/operators/silu/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_silu(dev, size, num_columns, tile_size, trace_size, kernel_archive): +def my_silu(dev, size, num_columns, tile_size, trace_size, kernel_archive, func_prefix=""): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] @@ -44,7 +44,7 @@ def my_silu(dev, size, num_columns, tile_size, trace_size, kernel_archive): # External, binary kernel definition silu_fcn = Kernel( - "silu_bf16", + f"{func_prefix}silu_bf16", kernel_archive, [line_type, line_type, np.int32], ) From 447983c040e401613b82aadd5b99f3f82b219508 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 22:09:14 -0700 Subject: [PATCH 44/99] fixes --- operators/common/fusion.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 4e5a083d..3663d967 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -30,8 +30,11 @@ def get_operator_name(self): def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators.""" kernel_artifacts = [] - unique_ops = set(op for op, *_ in self.runlist) - for idx, op in enumerate(unique_ops): + unique_operators = [] + for op, *_ in self.runlist: + if op not in unique_operators: + unique_operators.append(op) + for idx, op in enumerate(unique_operators): objs = op.get_kernel_artifacts() for obj in objs: obj.prefix_symbols = f"op{idx}_" @@ -45,7 +48,10 @@ def get_mlir_artifact(self): comp_runlist = [] op_names = {} # op -> op_name - unique_operators = set(op for op, *_ in self.runlist) + unique_operators = [] + for op, *_ in self.runlist: + if op not in unique_operators: + unique_operators.append(op) for idx, op in enumerate(unique_operators): mlir_artifact = op.get_mlir_artifact() if len(op.get_kernel_artifacts()) > 0: From e025ac7e3c9425a83ff9aca6afeccbb55e9bc399 Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 22:29:08 -0700 Subject: [PATCH 45/99] optimize out reconfiguration --- applications/llama_3.2_1b/autofuse.py | 8 ++++---- operators/common/compilation/fusion.py | 14 ++++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index e33358ef..e2bf482c 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -32,7 +32,7 @@ gemv_ffn_up_gate_op = AIEGEMV( M=hidden_dim, K=emb_dim, - num_aie_columns=1, + num_aie_columns=8, tile_size_input=4, tile_size_output=hidden_dim // 8, ) @@ -40,7 +40,7 @@ gemv_ffn_down_op = AIEGEMV( M=emb_dim, K=hidden_dim, - num_aie_columns=1, + num_aie_columns=8, tile_size_input=1, tile_size_output=emb_dim // 8, ) @@ -48,13 +48,13 @@ silu_ffn_op = AIESiLU( size=hidden_dim, tile_size=hidden_dim // 8, - num_aie_columns=1, + num_aie_columns=8, ) eltwise_mul_ffn_op = AIEElementwiseMul( size=hidden_dim, tile_size=hidden_dim // 8, - num_aie_columns=1, + num_aie_columns=8, ) diff --git a/operators/common/compilation/fusion.py b/operators/common/compilation/fusion.py index e60027c8..70d41f0e 100644 --- a/operators/common/compilation/fusion.py +++ b/operators/common/compilation/fusion.py @@ -121,13 +121,19 @@ def sequence(input_buf, output_buf, scratch_buf): } # Execute operations in runlist order + configure_op = None + last_op_name = None for op_name, *buffer_names in artifact.runlist: expected_arg_types = sequence_arg_types[op_name] - # Configure Op - configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(op_name) - configure_op = aiex.ConfigureOp(configure_sym_ref_attr) # TODO: optimization -- if previous op was in the same device, skip reconfiguration - configure_body = configure_op.body.blocks.append() + # Avoid reconfiguring altogether if the same op is called multiple times consecutively + if configure_op is None or op_name != last_op_name: + # Configure Op + configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(op_name) + configure_op = aiex.ConfigureOp(configure_sym_ref_attr) # TODO: optimization -- if previous op was in the same device, skip reconfiguration + configure_body = configure_op.body.blocks.append() + last_op_name = op_name + with ir.InsertionPoint(configure_body): # For each buffer, add subview and reinterpret_cast ops From e3c0e64525663c68560a63c53322b77f7dcb2afa Mon Sep 17 00:00:00 2001 From: andrej Date: Wed, 21 Jan 2026 23:40:08 -0700 Subject: [PATCH 46/99] fix some compilation issues --- applications/llama_3.2_1b/autofuse.py | 29 ++++++++------------- operators/common/base.py | 8 +++--- operators/common/compilation/base.py | 35 ++++++++------------------ operators/common/compilation/fusion.py | 4 +-- operators/common/context.py | 2 +- 5 files changed, 28 insertions(+), 50 deletions(-) diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index e2bf482c..0632be7e 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -32,7 +32,7 @@ gemv_ffn_up_gate_op = AIEGEMV( M=hidden_dim, K=emb_dim, - num_aie_columns=8, + num_aie_columns=4, tile_size_input=4, tile_size_output=hidden_dim // 8, ) @@ -40,7 +40,7 @@ gemv_ffn_down_op = AIEGEMV( M=emb_dim, K=hidden_dim, - num_aie_columns=8, + num_aie_columns=4, tile_size_input=1, tile_size_output=emb_dim // 8, ) @@ -48,13 +48,13 @@ silu_ffn_op = AIESiLU( size=hidden_dim, tile_size=hidden_dim // 8, - num_aie_columns=8, + num_aie_columns=4, ) eltwise_mul_ffn_op = AIEElementwiseMul( size=hidden_dim, tile_size=hidden_dim // 8, - num_aie_columns=8, + num_aie_columns=4, ) @@ -63,28 +63,20 @@ # Create identity matrix for W_ffn_gate (repeating pattern for hidden_dim x emb_dim) # Each row i will pick element i % emb_dim from x_norm -W_ffn_gate = torch.zeros(hidden_dim, emb_dim, dtype=torch.bfloat16) -for i in range(hidden_dim): - W_ffn_gate[i, i % emb_dim] = 1.0 - +W_ffn_gate = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) W_ffn_up = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) - -W_ffn_down = torch.zeros(emb_dim, hidden_dim, dtype=torch.bfloat16) -for i in range(emb_dim): - W_ffn_down[i, i] = 1.0 - +W_ffn_down = torch.randn(emb_dim, hidden_dim, dtype=torch.bfloat16) buf_W_ffn_gate = AIEBuffer.from_torch(W_ffn_gate) buf_W_ffn_up = AIEBuffer.from_torch(W_ffn_up) buf_W_ffn_down = AIEBuffer.from_torch(W_ffn_down) # Create x_norm as sequential indices: [0, 1, 2, 3, ..., emb_dim-1] -x_norm = torch.arange(emb_dim, dtype=torch.bfloat16) +x_norm = torch.randn(emb_dim, dtype=torch.bfloat16) buf_x_norm = AIEBuffer.from_torch(x_norm) buf_ffn_gate = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) buf_ffn_up = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) -ffn_hidden = torch.arange(hidden_dim, dtype=torch.bfloat16) -buf_ffn_hidden = AIEBuffer.from_torch(ffn_hidden) -buf_ffn_output = AIEBuffer.from_torch(-1 * torch.arange(emb_dim, dtype=torch.bfloat16)) #torch.zeros(emb_dim, dtype=torch.bfloat16)) +buf_ffn_hidden = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) +buf_ffn_output = AIEBuffer.from_torch(torch.zeros(emb_dim, dtype=torch.bfloat16)) # Separate xclbins @@ -146,7 +138,8 @@ def setup_autofused(): ], ) swiglu_fused_op.context = ctx - swiglu_fused = swiglu_fused_op.compile().get_callable() + swiglu_fused_op = swiglu_fused_op.compile() + swiglu_fused = swiglu_fused_op.get_callable() swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = x_norm.flatten() swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() diff --git a/operators/common/base.py b/operators/common/base.py index 05efad25..395ee541 100644 --- a/operators/common/base.py +++ b/operators/common/base.py @@ -29,7 +29,7 @@ class AIEOperatorBase(ABC): """Base class for AIE-accelerated operations""" def __init__(self, context=None): - self.artifacts = ( + self.artifacts = comp.CompilationArtifactGraph( [] ) # CompilationArtifact objects are uniqued within the context if context is None: @@ -67,12 +67,12 @@ def compile(self, dry_run=False): Subclasses are expected to overwrite set_up(); they may register any artifacts that they need to be compiled there. """ self.set_up_artifacts() - artifacts = comp.CompilationArtifactGraph(self.artifacts) - comp.compile(self.context.compilation_rules, artifacts, self.context.build_dir, dry_run=dry_run) + comp.compile(self.context.compilation_rules, self.artifacts, self.context.build_dir, dry_run=dry_run) return self def add_artifacts(self, artifacts): - self.artifacts.extend(artifacts) + for artifact in artifacts: + self.artifacts.add(artifact) def sync_to_device(bos): diff --git a/operators/common/compilation/base.py b/operators/common/compilation/base.py index a2d05a16..3d599951 100644 --- a/operators/common/compilation/base.py +++ b/operators/common/compilation/base.py @@ -40,7 +40,6 @@ import importlib.util from contextlib import nullcontext from aie.extras.context import mlir_mod_ctx -import copy import sys @@ -53,18 +52,17 @@ def plan(rules, graph): return [] # Everything has been compiled for rule in rules: if rule.matches(graph): - new_graph = graph.copy() - commands = rule.compile(new_graph) + commands = rule.compile(graph) break else: raise RuntimeError( f"No matching rule to compile target(s): {', '.join(artifact.filename for artifact in graph)}" ) - return [(rule, commands, graph)] + plan(rules, new_graph) + return [(rule, commands)] + plan(rules, graph) -def execute(plan): - for rule, commands, _ in plan: +def execute(plan_steps): + for rule, commands in plan_steps: logging.debug(f"Applying rule: {rule.__class__.__name__}") for command in commands: logging.debug(f" Executing command: {command}") @@ -82,7 +80,7 @@ def compile(rules, artifacts, build_dir="build", dry_run=False): if not dry_run: execute(plan_steps) else: - print("\n".join("\n".join(map(str, cmds)) for _, cmds, _ in plan_steps)) + print("\n".join("\n".join(map(str, cmds)) for _, cmds in plan_steps)) @@ -135,21 +133,6 @@ def _traverse(self, dfs): todo.extend(artifact.dependencies) yield artifact - def copy(self): - artifact_map = {} - - def copy_artifact(artifact): - if artifact in artifact_map: - return artifact_map[artifact] - new_artifact = copy.copy(artifact) - artifact_map[artifact] = new_artifact - new_deps = [copy_artifact(dep) for dep in artifact.dependencies] - new_artifact.dependencies = CompilationArtifactGraph(artifacts=new_deps) - return new_artifact - - new_artifacts = [copy_artifact(artifact) for artifact in self.artifacts] - return CompilationArtifactGraph(artifacts=new_artifacts) - def replace(self, old_artifact, new_artifact): for i, artifact in enumerate(self.artifacts): if artifact == old_artifact: @@ -178,6 +161,9 @@ def move_artifacts(self, new_root): for artifact in self.bfs(): if not os.path.isabs(artifact.filename): artifact.filename = str(Path(new_root) / Path(artifact.filename).name) + + def add(self, artifact): + self.artifacts.append(artifact) # Compilation Artifacts @@ -360,10 +346,9 @@ def compile(self, graph): commands = [] worklist = graph.get_worklist(PythonGeneratedMLIRArtifact) for artifact in worklist: - assert len(artifact.dependencies) == 1 and isinstance(artifact.dependencies[0], SourceArtifact), "PythonGeneratedMLIRArtifact must depend on exactly one SourceArtifact" - import_path = Path(artifact.dependencies[0].filename) new_artifact = SourceArtifact(artifact.filename) - callback = lambda: self.generate_mlir(new_artifact, import_path, artifact.callback_fn, artifact.callback_args, artifact.callback_kwargs, artifact.requires_context) + # To make Python capture variables in this closure by value, not by reference, use default arguments + callback = lambda new_artifact=new_artifact, import_path=artifact.import_path, callback_fn=artifact.callback_fn, callback_args=artifact.callback_args, callback_kwargs=artifact.callback_kwargs, requires_context=artifact.requires_context: self.generate_mlir(new_artifact, import_path, callback_fn, callback_args, callback_kwargs, requires_context) commands.append(PythonCallbackCompilationCommand(callback)) new_artifact.available = True graph.replace(artifact, new_artifact) diff --git a/operators/common/compilation/fusion.py b/operators/common/compilation/fusion.py index 70d41f0e..ca241781 100644 --- a/operators/common/compilation/fusion.py +++ b/operators/common/compilation/fusion.py @@ -189,7 +189,7 @@ def sequence(input_buf, output_buf, scratch_buf): # ########################################################################## -class FuseMLIRCompilationRule(CompilationRule): +class FusePythonGeneratedMLIRCompilationRule(CompilationRule): """Compilation rule that fuses multiple MLIR modules into one.""" def matches(self, graph): @@ -199,7 +199,7 @@ def compile(self, graph): commands = [] worklist = graph.get_worklist(FusedMLIRSource) for artifact in worklist: - callback = lambda: fuse_mlir(artifact) + callback = lambda artifact=artifact: fuse_mlir(artifact) commands.append(PythonCallbackCompilationCommand(callback)) new_artifact = SourceArtifact(artifact.filename) new_artifact.available = True diff --git a/operators/common/context.py b/operators/common/context.py index 06e224a7..fc7880f6 100644 --- a/operators/common/context.py +++ b/operators/common/context.py @@ -25,7 +25,7 @@ def __init__(self, use_runlist=True, build_dir=None): # Disable the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations for easier debugging (can tell which part of runlist execution failed) self.use_runlist = use_runlist self.compilation_rules = [ - comp.FuseMLIRCompilationRule(), + comp.FusePythonGeneratedMLIRCompilationRule(), comp.GenerateMLIRFromPythonCompilationRule(), comp.PeanoCompilationRule(self.peano_dir, self.mlir_aie_dir), comp.ArchiveCompilationRule(self.peano_dir), From 34fc4c5c8f1694d271944cd2465c40f66fcf1a9d Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 22 Jan 2026 00:01:41 -0700 Subject: [PATCH 47/99] make all llama operators take a kernel archive and func prefix arg --- applications/llama_3.2_1b/llama_npu.py | 12 ++++++------ operators/common/base.py | 2 +- operators/elementwise_add/design.py | 4 ++-- operators/elementwise_add/op.py | 3 +-- operators/gemm/design.py | 15 ++++++++------- operators/gemm/op.py | 6 ------ operators/rms_norm/design_weighted.py | 6 +++--- operators/rms_norm/op.py | 2 +- operators/rope/design.py | 5 +++-- operators/rope/op.py | 1 - operators/softmax/design.py | 8 ++++---- operators/transpose/design.py | 6 ++++-- 12 files changed, 33 insertions(+), 37 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 03a532ec..aae1df32 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -13,12 +13,12 @@ repo_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(repo_root)) -from operators.common.aie_context import AIEContext +from operators.common.context import AIEContext from operators.common import ( AIEBuffer ) from operators.common.utils import torch_to_numpy, numpy_to_torch -from operators.common.aie_base import PatchableSingleXclbinCallable +from operators.common.base import PatchableSingleXclbinCallable from operators import ( AIERMSNorm, AIEGEMM, @@ -212,9 +212,9 @@ def __init__(self, config, prompt_len): ).compile() self.decode.softmax = PatchableSingleXclbinCallable( - xclbin_path=self.decode.softmax_compilable.xclbin_artifact.path, + xclbin_path=self.decode.softmax_compilable.xclbin_artifact.filename, kernel_name=self.decode.softmax_compilable.xclbin_artifact.kernel_name, - insts_bin_path=self.decode.softmax_compilable.insts_artifact.path, + insts_bin_path=self.decode.softmax_compilable.insts_artifact.filename, args_spec=self.decode.softmax_compilable.get_arg_spec() ) @@ -267,9 +267,9 @@ def __init__(self, config, prompt_len): # Create patchable callable for runtime offset updates self.decode.strided_copy_cache = PatchableSingleXclbinCallable( - xclbin_path=self.decode.strided_copy_cache_compilable.xclbin_artifact.path, + xclbin_path=self.decode.strided_copy_cache_compilable.xclbin_artifact.filename, kernel_name=self.decode.strided_copy_cache_compilable.xclbin_artifact.kernel_name, - insts_bin_path=self.decode.strided_copy_cache_compilable.insts_artifact.path, + insts_bin_path=self.decode.strided_copy_cache_compilable.insts_artifact.filename, args_spec=self.decode.strided_copy_cache_compilable.get_arg_spec() ) diff --git a/operators/common/base.py b/operators/common/base.py index 395ee541..0ef49770 100644 --- a/operators/common/base.py +++ b/operators/common/base.py @@ -286,5 +286,5 @@ def patch(self, patches): insts = self.insts_buffer.view_as_np() insts[:] = self.baseline_instructions for pos, (val, mask) in patches.items(): - insts[pos] = (insts[pos] & ~mask) | (val & mask) + insts[pos] = (np.int64(insts[pos]) & ~mask) | (val & mask) self.insts_buffer.to("npu") diff --git a/operators/elementwise_add/design.py b/operators/elementwise_add/design.py index d8790dfa..4a202e69 100644 --- a/operators/elementwise_add/design.py +++ b/operators/elementwise_add/design.py @@ -15,7 +15,7 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive): +def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive, func_prefix=""): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +37,7 @@ def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, kernel # AIE Core Function declaration eltwise_add_bf16_vector = Kernel( - "eltwise_add_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] + f"{func_prefix}eltwise_add_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] ) # Define a task that will run on a compute tile diff --git a/operators/elementwise_add/op.py b/operators/elementwise_add/op.py index d98b7da5..ee6d1c5c 100644 --- a/operators/elementwise_add/op.py +++ b/operators/elementwise_add/op.py @@ -53,8 +53,7 @@ def get_mlir_artifact(self): self.size, self.num_aie_columns, self.tile_size, - 0, - self.get_kernel_archive_name(), + 0 ], ) diff --git a/operators/gemm/design.py b/operators/gemm/design.py index 432474cf..6216079e 100644 --- a/operators/gemm/design.py +++ b/operators/gemm/design.py @@ -141,7 +141,8 @@ def my_matmul( prio_accuracy, separate_c_tiles, trace_size, - archive=None, + kernel_archive=None, + func_prefix="", generate_taps=False, ): n_aie_rows = 4 @@ -274,7 +275,7 @@ def my_matmul( # AIE Core Function declarations scalar_suffix = "_scalar" if use_scalar else "" - archive_name = f"gemm_{m}x{k}x{n}_archive.a" if archive is None else archive + kernel_archive = f"{func_prefix}gemm_{m}x{k}x{n}_archive.a" if kernel_archive is None else kernel_archive if use_larger_internal_buffer: # Fix fifo depth for C objfifo to 1 since 1 buffer will be used for accumulation # and another for transfer to L2 @@ -284,19 +285,19 @@ def my_matmul( # A kernel to convert from the internal f32 accumulation to bf16 for transfer to L2 is needed convert_copy_kernel = Kernel( f"convert_copy_f32_to_bf16", - archive_name, + kernel_archive, [C_l1_ty_internal, C_l1_ty, np.int32], ) # Fix the kernels to use f32 outputs zero_kernel = Kernel( f"zero{scalar_suffix}_f32", - archive_name, + kernel_archive, [C_l1_ty_internal], ) matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32" matmul_kernel = Kernel( matmul_func_name, - archive_name, + kernel_archive, [A_l1_ty, B_l1_ty, C_l1_ty_internal], ) else: @@ -305,13 +306,13 @@ def my_matmul( fifo_depth_out = fifo_depth zero_kernel = Kernel( f"zero{scalar_suffix}_{dtype_out_str}", - archive_name, + kernel_archive, [C_l1_ty], ) matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" matmul_kernel = Kernel( matmul_func_name, - archive_name, + kernel_archive, [A_l1_ty, B_l1_ty, C_l1_ty], ) diff --git a/operators/gemm/op.py b/operators/gemm/op.py index db06b225..06fed39b 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -108,7 +108,6 @@ def get_mlir_artifact(self): "prio_accuracy": prio_accuracy, "separate_c_tiles": int(separate_c_tiles), "trace_size": 0, - "archive": self.get_kernel_archive_name(), "generate_taps": False, }, requires_context=False, @@ -162,11 +161,6 @@ def get_kernel_artifacts(self): ) ] - def get_kernel_archive_name(self): - return ( - f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}.a" - ) - def get_arg_spec(self): return [ AIERuntimeArgSpec("in", (self.M, self.K)), # input A diff --git a/operators/rms_norm/design_weighted.py b/operators/rms_norm/design_weighted.py index 8f525abe..b6457ef3 100644 --- a/operators/rms_norm/design_weighted.py +++ b/operators/rms_norm/design_weighted.py @@ -16,7 +16,7 @@ def my_weighted_rms_norm( - dev, num_elements, num_columns, num_channels, weight_length, trace_size, kernel_archive="rms_norm.a" + dev, num_elements, num_columns, num_channels, weight_length, trace_size, kernel_archive="rms_norm.a", func_prefix="" ): per_tile_elements = weight_length total_cores = num_columns # For each core that does rms norm, another core will take its output to do eltwise mul @@ -53,10 +53,10 @@ def my_weighted_rms_norm( # AIE Core Function declaration rms_norm_kernel = Kernel( - "rms_norm_bf16_vector", kernel_archive, [tile_ty, tile_ty, np.int32] + f"{func_prefix}rms_norm_bf16_vector", kernel_archive, [tile_ty, tile_ty, np.int32] ) eltwise_mul_kernel = Kernel( - "eltwise_mul_bf16_vector", + f"{func_prefix}eltwise_mul_bf16_vector", kernel_archive, [tile_ty, weights_ty, tile_ty, np.int32], ) diff --git a/operators/rms_norm/op.py b/operators/rms_norm/op.py index f961e907..920b7b0e 100644 --- a/operators/rms_norm/op.py +++ b/operators/rms_norm/op.py @@ -70,7 +70,7 @@ def get_mlir_artifact(self): 0, ], callback_kwargs={ - "archive_name": f"{self.get_operator_name()}.a", + "kernel_archive": f"{self.get_operator_name()}.a", } ) diff --git a/operators/rope/design.py b/operators/rope/design.py index a2126783..2fa2c139 100644 --- a/operators/rope/design.py +++ b/operators/rope/design.py @@ -11,7 +11,7 @@ from aie.iron.placers import SequentialPlacer from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern -from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.dialects.scf import _for as range_ from ml_dtypes import bfloat16 @@ -39,6 +39,7 @@ def rope( trace_size=0, method_type=None, kernel_archive=None, + func_prefix="" ): dtype = bfloat16 @@ -77,7 +78,7 @@ def rope( # AIE Core Function declaration rope_kernel = Kernel( - "rope", + f"{func_prefix}rope", kernel_archive, [tensor_tile_ty, angle_tile_ty, tensor_tile_ty, np.int32], ) diff --git a/operators/rope/op.py b/operators/rope/op.py index 2fdd9dca..6ed264ff 100644 --- a/operators/rope/op.py +++ b/operators/rope/op.py @@ -57,7 +57,6 @@ def get_mlir_artifact(self): self.num_aie_columns, 0, self.method_type, - self.get_kernel_archive_name(), ], ) diff --git a/operators/softmax/design.py b/operators/softmax/design.py index d54a10ff..43ad46ad 100644 --- a/operators/softmax/design.py +++ b/operators/softmax/design.py @@ -11,11 +11,11 @@ from aie.iron.placers import SequentialPlacer from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern -from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.dialects.scf import _for as range_ from ml_dtypes import bfloat16 -def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_size, rtp_vector_size=None): +def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_size, rtp_vector_size=None, kernel_archive="softmax.a", func_prefix=""): per_tile_elements = tile_size if rtp_vector_size is None: rtp_vector_size = per_tile_elements @@ -45,8 +45,8 @@ def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_s ] # AIE Core Function declaration - softmax_kernel = Kernel("softmax_bf16", "softmax.o", [tile_ty, tile_ty, np.int32]) - mask_kernel = Kernel("mask_bf16", "softmax.o", [tile_ty, np.int32, np.int32]) + softmax_kernel = Kernel(f"{func_prefix}softmax_bf16", kernel_archive, [tile_ty, tile_ty, np.int32]) + mask_kernel = Kernel(f"{func_prefix}mask_bf16", kernel_archive, [tile_ty, np.int32, np.int32]) # Define a task that will run on a compute tile def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): diff --git a/operators/transpose/design.py b/operators/transpose/design.py index 9a0c3b29..dcf2636e 100644 --- a/operators/transpose/design.py +++ b/operators/transpose/design.py @@ -10,7 +10,7 @@ from aie.iron.controlflow import range_ -def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s): +def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, kernel_archive=None, func_prefix=""): num_elements = M * N per_tile_elements = m * n dtype = bfloat16 @@ -98,8 +98,10 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s): ] # AIE Core Function declaration + if kernel_archive is None: + kernel_archive = f"transpose_{s}x{s}.a" transpose_kernel = Kernel( - f"transpose_{s}x{s}", f"transpose_{m}x{n}.o", [tile_ty, tile_ty] + f"{func_prefix}transpose_{s}x{s}", kernel_archive, [tile_ty, tile_ty] ) # Define a task that will run on a compute tile From a197f2d7b8bba6902e26054af6c291ae196362e2 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 22 Jan 2026 10:09:09 -0700 Subject: [PATCH 48/99] txn-fused swiglu --- applications/llama_3.2_1b/autofuse.py | 59 +++++++++++++-------- applications/llama_3.2_1b/llama_npu.py | 72 +++++++++++++++++++------- 2 files changed, 91 insertions(+), 40 deletions(-) diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py index 0632be7e..bddaf7a7 100755 --- a/applications/llama_3.2_1b/autofuse.py +++ b/applications/llama_3.2_1b/autofuse.py @@ -61,17 +61,22 @@ # Buffers # --- -# Create identity matrix for W_ffn_gate (repeating pattern for hidden_dim x emb_dim) -# Each row i will pick element i % emb_dim from x_norm +x_norm = torch.randn(emb_dim, dtype=torch.bfloat16) W_ffn_gate = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) W_ffn_up = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) W_ffn_down = torch.randn(emb_dim, hidden_dim, dtype=torch.bfloat16) + +def init_random(): + global x_norm, W_ffn_gate, W_ffn_up, W_ffn_down + x_norm = torch.randn(emb_dim, dtype=torch.bfloat16) + W_ffn_gate = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) + W_ffn_up = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) + W_ffn_down = torch.randn(emb_dim, hidden_dim, dtype=torch.bfloat16) + buf_W_ffn_gate = AIEBuffer.from_torch(W_ffn_gate) buf_W_ffn_up = AIEBuffer.from_torch(W_ffn_up) buf_W_ffn_down = AIEBuffer.from_torch(W_ffn_down) -# Create x_norm as sequential indices: [0, 1, 2, 3, ..., emb_dim-1] -x_norm = torch.randn(emb_dim, dtype=torch.bfloat16) buf_x_norm = AIEBuffer.from_torch(x_norm) buf_ffn_gate = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) buf_ffn_up = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) @@ -141,12 +146,20 @@ def setup_autofused(): swiglu_fused_op = swiglu_fused_op.compile() swiglu_fused = swiglu_fused_op.get_callable() + #swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = x_norm.flatten() + #swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() + #swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = W_ffn_up.flatten() + #swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = W_ffn_down.flatten() + +def run_autofused(): + #swiglu_fused.output_buffer.view_as_torch()[:] = 0 + #swiglu_fused.scratch_buffer.view_as_torch()[:] = 0 + swiglu_fused.input_buffer.view_as_torch()[:] = 0 swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = x_norm.flatten() swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() + swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = W_ffn_up.flatten() swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = W_ffn_down.flatten() - -def run_autofused(): swiglu_fused() return swiglu_fused.get_buffer("ffn_output").to("cpu").view_as_torch() @@ -165,31 +178,35 @@ def run_cpu(): # Main # --- -iters=100 +iters=10 setup_autofused() -t_autofused_start = time.time() +#setup_separate_xclbins() for _ in range(iters): + init_random() + + t_autofused_start = time.time() res_npu = run_autofused() -t_autofused = time.time() - t_autofused_start + print("npu:") + print(res_npu) + t_autofused = time.time() - t_autofused_start -setup_separate_xclbins() -t_separate_start = time.time() -for _ in range(iters): - res_npu_s = run_separate_xclbins() -t_separate = time.time() - t_separate_start + #t_separate_start = time.time() + #for _ in range(iters): + # res_npu_s = run_separate_xclbins() + #t_separate = time.time() - t_separate_start -t_cpu_start = time.time() -for _ in range(iters): + t_cpu_start = time.time() res_cpu = run_cpu() -t_cpu = time.time() - t_cpu_start + print("cpu:") + print(res_cpu) + #assert(torch.allclose(res_npu[-1], res_cpu[-1], atol=0.7, rtol=0.07)) + t_cpu = time.time() - t_cpu_start -print(res_npu_s) -print(res_npu) -print(res_cpu) + #print(res_npu_s) -print(f"Separate xclbins time: {t_separate/iters:.6f} seconds") +#print(f"Separate xclbins time: {t_separate/iters:.6f} seconds") print(f"Autofused time: {t_autofused/iters:.6f} seconds") print(f"CPU time: {t_cpu/iters:.6f} seconds") assert(torch.allclose(res_npu[-1], res_cpu[-1], atol=0.7, rtol=0.07)) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index aae1df32..8c145e98 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -19,6 +19,7 @@ ) from operators.common.utils import torch_to_numpy, numpy_to_torch from operators.common.base import PatchableSingleXclbinCallable +from operators.common.fusion import FusedMLIROperator from operators import ( AIERMSNorm, AIEGEMM, @@ -150,38 +151,62 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - # Decode: GEMV for M=1 - self.decode.gemv_ffn_up_gate = AIEGEMV( + elf_ctx = AIEContext(build_dir="build_elf") + + # Fused SwiGLU operator for decode + gemv_ffn_up_gate_op = AIEGEMV( M=config.hidden_dim, K=config.emb_dim, num_aie_columns=8, tile_size_input=4, tile_size_output=config.hidden_dim // 8, - context=self.context - ).compile().get_callable() + context=elf_ctx + ) - self.decode.gemv_ffn_down = AIEGEMV( + gemv_ffn_down_op = AIEGEMV( M=config.emb_dim, K=config.hidden_dim, num_aie_columns=8, tile_size_input=1, tile_size_output=config.emb_dim // 8, - context=self.context - ).compile().get_callable() + context=elf_ctx + ) - self.decode.ffn_silu = AIESiLU( + silu_ffn_op = AIESiLU( size=config.hidden_dim, tile_size=config.hidden_dim // 8, - num_aie_columns=1, - context=self.context - ).compile().get_callable() + num_aie_columns=8, + context=elf_ctx + ) - self.decode.eltwise_mul_ffn = AIEElementwiseMul( + eltwise_mul_ffn_op = AIEElementwiseMul( size=config.hidden_dim, tile_size=config.hidden_dim // 8, num_aie_columns=8, - context=self.context - ).compile().get_callable() + context=elf_ctx + ) + + self.decode.swiglu_fused_op = FusedMLIROperator( + "swiglu_decode", + [ + (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), + (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), + (silu_ffn_op, "ffn_gate", "ffn_gate"), + (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), + (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), + ], + input_args=[ + "x_norm", + "W_ffn_gate", + "W_ffn_up", + "W_ffn_down" + ], + output_args=[ + "ffn_output" + ], + context=elf_ctx + ).compile() + self.decode.swiglu_fused = self.decode.swiglu_fused_op.get_callable() # Attention score scaling operators # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector @@ -938,11 +963,20 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i def swiglu_ffn_forward_decode(layer_idx): - aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_gate_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_gate) # Gate projection - aie_ops.decode.gemv_ffn_up_gate(aie_buffers.W_ffn_up_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.ffn_up) # Up projection - aie_ops.decode.ffn_silu(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_gate) # SiLU activation - aie_ops.decode.eltwise_mul_ffn(aie_buffers.decode.ffn_gate, aie_buffers.decode.ffn_up, aie_buffers.decode.ffn_hidden) # Gate application (eltwise mul) - aie_ops.decode.gemv_ffn_down(aie_buffers.W_ffn_down_decode[layer_idx], aie_buffers.decode.ffn_hidden, aie_buffers.decode.ffn_output) # Down projection + # fused + # Copy inputs to fused operator's internal buffers + fused_op = aie_ops.decode.swiglu_fused + fused_op.input_buffer.view_as_torch()[:] = 0 + fused_op.get_buffer("x_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x_norm.to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() + + # Execute fused operator + fused_op() + + # Copy output from fused operator's internal buffer to output buffer + aie_buffers.decode.ffn_output.to("cpu").view_as_torch()[:] = fused_op.get_buffer("ffn_output").to("cpu").view_as_torch()[:] # Main From e35ed7b58a289c4848651f44e370f6222b127f57 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 22 Jan 2026 12:11:09 -0700 Subject: [PATCH 49/99] bring up to speed after host runtime refactor --- operators/common/device_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/operators/common/device_manager.py b/operators/common/device_manager.py index e48c39cd..3ec6acf9 100644 --- a/operators/common/device_manager.py +++ b/operators/common/device_manager.py @@ -11,7 +11,7 @@ from pathlib import Path from typing import Dict, Optional, Any import pyxrt -from aie.iron.hostruntime.config import detect_npu_device +from aie.utils.hostruntime.xrtruntime.hostruntime import XRTHostRuntime from aie.iron.device import NPU1, NPU2 @@ -33,7 +33,7 @@ def __init__(self): AIEDeviceManager._initialized = True self.device = pyxrt.device(0) - self.device_type = detect_npu_device() + self.device_type = XRTHostRuntime().device() self.contexts = {} # xclbin_path -> (context, xclbin) self.kernels = {} # (xclbin_path, kernel_name) -> kernel From 2994ba28cab2dcf5d0b285a5a48f110069586023 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 22 Jan 2026 12:40:15 -0700 Subject: [PATCH 50/99] refactor symbol renaming to not clash with externally defined library functions; fused-txn of more of the transformer block --- applications/llama_3.2_1b/llama_npu.py | 65 +++++++++++++++----------- operators/common/compilation/base.py | 18 +++++-- operators/common/fusion.py | 1 + 3 files changed, 55 insertions(+), 29 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 8c145e98..cbb10439 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -153,7 +153,16 @@ def __init__(self, config, prompt_len): elf_ctx = AIEContext(build_dir="build_elf") - # Fused SwiGLU operator for decode + # Fused operator for post-norm + SwiGLU + residual (decode) + rms_norm_op = AIERMSNorm( + size=config.emb_dim, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=config.emb_dim, + context=elf_ctx + ) + gemv_ffn_up_gate_op = AIEGEMV( M=config.hidden_dim, K=config.emb_dim, @@ -186,27 +195,36 @@ def __init__(self, config, prompt_len): context=elf_ctx ) - self.decode.swiglu_fused_op = FusedMLIROperator( - "swiglu_decode", + residual_add_op = AIEElementwiseAdd( + size=config.emb_dim, + tile_size=config.emb_dim // 8, + context=elf_ctx + ) + + self.decode.post_attn_fused_op = FusedMLIROperator( + "post_attn_decode", [ + (rms_norm_op, "x_pre_norm", "W_norm2", "x_norm"), (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), (silu_ffn_op, "ffn_gate", "ffn_gate"), (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), + (residual_add_op, "x_pre_norm", "ffn_output", "x_out"), ], input_args=[ - "x_norm", + "x_pre_norm", + "W_norm2", "W_ffn_gate", "W_ffn_up", "W_ffn_down" ], output_args=[ - "ffn_output" + "x_out" ], context=elf_ctx ).compile() - self.decode.swiglu_fused = self.decode.swiglu_fused_op.get_callable() + self.decode.post_attn_fused = self.decode.post_attn_fused_op.get_callable() # Attention score scaling operators # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector @@ -912,9 +930,21 @@ def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx): aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm) # Step 1: RMS normalization grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx) # Step 2: Attention; results stored in attn_output aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.attn_output, aie_buffers.decode.x) # Step 3: Residual - aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm2[layer_idx], aie_buffers.decode.x_norm) # Step 4: Post-norm - swiglu_ffn_forward_decode(layer_idx) # Step 5: Feed-forward network - aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.ffn_output, aie_buffers.decode.x) # Step 6: Residual + + # Step 4-6: Fused post-norm + SwiGLU + residual + fused_op = aie_ops.decode.post_attn_fused + fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 + fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 + fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 + fused_op.get_buffer("x_pre_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_norm2").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() + + fused_op() + + aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x_out").to("cpu").view_as_torch()[:] def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx): @@ -962,23 +992,6 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i aie_ops.decode.gemv_attn_output(aie_buffers.W_attn_output_decode[layer_idx], aie_buffers.decode.attn_context, aie_buffers.decode.attn_output) -def swiglu_ffn_forward_decode(layer_idx): - # fused - # Copy inputs to fused operator's internal buffers - fused_op = aie_ops.decode.swiglu_fused - fused_op.input_buffer.view_as_torch()[:] = 0 - fused_op.get_buffer("x_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x_norm.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() - - # Execute fused operator - fused_op() - - # Copy output from fused operator's internal buffer to output buffer - aie_buffers.decode.ffn_output.to("cpu").view_as_torch()[:] = fused_op.get_buffer("ffn_output").to("cpu").view_as_torch()[:] - - # Main # ########################################################################## diff --git a/operators/common/compilation/base.py b/operators/common/compilation/base.py index 3d599951..e65d22b6 100644 --- a/operators/common/compilation/base.py +++ b/operators/common/compilation/base.py @@ -558,12 +558,24 @@ def _rename_symbols(self, artifact): def _prefix_symbols(self, artifact, prefix): objcopy_path = "llvm-objcopy-18" - cmd = [ + nm_path = "llvm-nm-18" + symbol_map_file = artifact.filename + ".symbol_map" + + # Extract defined symbols and create symbol map + nm_cmd = [ + "sh", "-c", + f"{nm_path} --defined-only --extern-only {artifact.filename} | " + f"awk '{{print $3 \" {prefix}\" $3}}' > {symbol_map_file}" + ] + + # Apply the renaming using the symbol map + objcopy_cmd = [ objcopy_path, - "--prefix-symbols=" + prefix, + "--redefine-syms=" + symbol_map_file, artifact.filename, ] - return [ShellCompilationCommand(cmd)] + + return [ShellCompilationCommand(nm_cmd), ShellCompilationCommand(objcopy_cmd)] class ArchiveCompilationRule(CompilationRule): diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 3663d967..432e86d7 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -37,6 +37,7 @@ def get_kernel_artifacts(self): for idx, op in enumerate(unique_operators): objs = op.get_kernel_artifacts() for obj in objs: + obj.filename = f"op{idx}_{obj.filename}" obj.prefix_symbols = f"op{idx}_" kernel_artifacts.extend(objs) return kernel_artifacts From 25605be210b6be27a6aa7de0ba7828061c50cff5 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 22 Jan 2026 12:48:48 -0700 Subject: [PATCH 51/99] fuse first part of attention --- applications/llama_3.2_1b/llama_npu.py | 83 +++++++++++++++++++++++--- operators/rope/design.py | 2 +- 2 files changed, 77 insertions(+), 8 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index cbb10439..a8779485 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -226,6 +226,64 @@ def __init__(self, config, prompt_len): ).compile() self.decode.post_attn_fused = self.decode.post_attn_fused_op.get_callable() + # Fused operator for attention projections + RoPE (decode) + gemv_attn_query_op = AIEGEMV( + M=config.n_heads * config.head_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.head_dim // 2, + context=elf_ctx + ) + + gemv_attn_key_value_op = AIEGEMV( + M=config.n_kv_groups * config.head_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.head_dim // 2, + context=elf_ctx + ) + + rope_queries_op = AIERope( + rows=1 * config.n_heads, + cols=config.head_dim, + angle_rows=1, + context=elf_ctx + ) + + rope_keys_op = AIERope( + rows=1 * config.n_kv_groups, + cols=config.head_dim, + angle_rows=1, + context=elf_ctx + ) + + self.decode.attn_proj_rope_fused_op = FusedMLIROperator( + "attn_proj_rope_decode", + [ + (gemv_attn_query_op, "W_attn_query", "x_norm", "queries"), + (gemv_attn_key_value_op, "W_attn_key", "x_norm", "keys"), + (gemv_attn_key_value_op, "W_attn_value", "x_norm", "values"), + (rope_queries_op, "queries", "rope_angles", "queries"), + (rope_keys_op, "keys", "rope_angles", "keys"), + ], + input_args=[ + "W_attn_query", + "W_attn_key", + "W_attn_value", + "x_norm", + "rope_angles" + ], + output_args=[ + "queries", + "keys", + "values" + ], + context=elf_ctx + ).compile() + self.decode.attn_proj_rope_fused = self.decode.attn_proj_rope_fused_op.get_callable() + # Attention score scaling operators # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector self.prefill.attn_scale = AIEElementwiseMul( @@ -951,14 +1009,25 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i context_len = num_preceding_tokens + 1 group_size = config.n_heads // config.n_kv_groups - # Step 1: Linear projections - write directly to queries/keys/values buffers - aie_ops.decode.gemv_attn_query(aie_buffers.W_attn_query_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.queries) - aie_ops.decode.gemv_attn_key_value(aie_buffers.W_attn_key_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.keys) - aie_ops.decode.gemv_attn_key_value(aie_buffers.W_attn_value_decode[layer_idx], aie_buffers.decode.x_norm, aie_buffers.decode.values) + # Step 1-2: Fused attention projections + RoPE + fused_op = aie_ops.decode.attn_proj_rope_fused + fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 + fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 + fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 + fused_op.get_buffer("W_attn_query").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_query_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_attn_key").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_key_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_attn_value").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_value_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("x_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x_norm.to("cpu").view_as_torch().flatten() + fused_op.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = aie_buffers.decode.rope_angles.to("cpu").view_as_torch().flatten() + + fused_op() - # Step 2: Apply RoPE - use same buffers for input and output - aie_ops.decode.rope_queries(aie_buffers.decode.queries, aie_buffers.decode.rope_angles, aie_buffers.decode.queries) - aie_ops.decode.rope_keys(aie_buffers.decode.keys, aie_buffers.decode.rope_angles, aie_buffers.decode.keys) + aie_buffers.decode.queries.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("queries").to("cpu").view_as_torch().flatten() + aie_buffers.decode.keys.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("keys").to("cpu").view_as_torch().flatten() + aie_buffers.decode.values.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("values").to("cpu").view_as_torch().flatten() + aie_buffers.decode.queries.to("npu") + aie_buffers.decode.keys.to("npu") + aie_buffers.decode.values.to("npu") # Step 3: Update cache using strided copy on NPU (transpose and concatenate) # Cache is already on NPU from prefill initialization or previous decode iteration diff --git a/operators/rope/design.py b/operators/rope/design.py index 2fa2c139..bbb5ed79 100644 --- a/operators/rope/design.py +++ b/operators/rope/design.py @@ -132,7 +132,7 @@ def core_body(of_in, of_lut, of_out, rope_kernel): # Runtime operations to move data to/from the AIE-array rt = Runtime() - with rt.sequence(tensor_ty, tensor_ty, tensor_ty) as (A, B, C): + with rt.sequence(tensor_ty, angle_ty, tensor_ty) as (A, B, C): rt.start(*my_workers) # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. From e01a6f0a13806e41f81946606989c57870edb6b3 Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 22 Jan 2026 15:40:07 -0700 Subject: [PATCH 52/99] make it possible to slice buffers in fused txn specification; fused-txn transpose - 3 TPS --- applications/llama_3.2_1b/llama_npu.py | 59 ++++++++++++++-- operators/common/compilation/fusion.py | 15 +++- operators/common/fusion.py | 94 ++++++++++++++++++++------ 3 files changed, 140 insertions(+), 28 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index a8779485..813a1892 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -284,8 +284,46 @@ def __init__(self, config, prompt_len): ).compile() self.decode.attn_proj_rope_fused = self.decode.attn_proj_rope_fused_op.get_callable() + # Fused transpose for all attention heads (decode) + transpose_values_op = AIETranspose( + M=prompt_len, + N=config.head_dim, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=elf_ctx + ) + + # Calculate buffer size for all heads' values + values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 + values_buffer_size = config.n_heads * values_per_head_buffer_size + + # Build runlist with sliced buffers for each head + transpose_runlist = [ + (transpose_values_op, + f"values_all[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", + f"values_transposed_all[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]" + ) + for h in range(config.n_heads) + ] + + self.decode.transpose_values_fused_op = FusedMLIROperator( + "transpose_values_decode", + transpose_runlist, + input_args=["values_all"], + output_args=["values_transposed_all"], + buffer_sizes={ + "values_all": values_buffer_size, + "values_transposed_all": values_buffer_size + }, + context=elf_ctx + ).compile() + self.decode.transpose_values_fused = self.decode.transpose_values_fused_op.get_callable() + # Attention score scaling operators - # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector + # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY self.prefill.attn_scale = AIEElementwiseMul( size=config.n_heads * prompt_len * prompt_len, tile_size=prompt_len, @@ -1048,12 +1086,19 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) # Step 8: Compute attention output on NPU - # Transpose values: (max_context_len, head_dim) -> (head_dim, max_context_len) for each head - for h in range(config.n_heads): - aie_ops.decode.transpose_values( - aie_buffers.decode.attn_scores_values_per_head[h], - aie_buffers.decode.attn_scores_values_transposed_per_head[h] - ) + # Transpose values: (max_context_len, head_dim) -> (head_dim, max_context_len) for each head using fused operator + fused_transpose = aie_ops.decode.transpose_values_fused + fused_transpose.input_buffer.view_as_torch().to("cpu")[:] = 0 + fused_transpose.output_buffer.view_as_torch().to("cpu")[:] = 0 + fused_transpose.scratch_buffer.view_as_torch().to("cpu")[:] = 0 + fused_transpose.get_buffer("values_all").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch().flatten() + + fused_transpose() + + # Reshape flat output to match expected 3D shape (n_heads, head_dim, max_context_len) + aie_buffers.decode.attn_scores_values_transposed.to("cpu").view_as_torch().flatten()[:] = fused_transpose.get_buffer("values_transposed_all").to("cpu").view_as_torch().flatten() + aie_buffers.decode.attn_scores_values_transposed.to("npu") + # GEMV: (n_heads, head_dim, max_context_len) @ (n_heads, max_context_len) -> (n_heads, head_dim) aie_ops.decode.gemv_attn_context(aie_buffers.decode.attn_scores_values_transposed, aie_buffers.decode.attn_weights, aie_buffers.decode.attn_context) diff --git a/operators/common/compilation/fusion.py b/operators/common/compilation/fusion.py index ca241781..19b560d8 100644 --- a/operators/common/compilation/fusion.py +++ b/operators/common/compilation/fusion.py @@ -28,13 +28,14 @@ class FusedMLIRSource(CompilationArtifact): - def __init__(self, filename, operator_mlir_map, runlist, subbuffer_layout, buffer_sizes): + def __init__(self, filename, operator_mlir_map, runlist, subbuffer_layout, buffer_sizes, slice_info=None): dependencies = list(operator_mlir_map.values()) super().__init__(filename, dependencies) self.operator_mlir_map = operator_mlir_map self.runlist = runlist self.subbuffer_layout = subbuffer_layout self.buffer_sizes = buffer_sizes + self.slice_info = slice_info or {} # Helper Functions @@ -139,7 +140,17 @@ def sequence(input_buf, output_buf, scratch_buf): # For each buffer, add subview and reinterpret_cast ops buffer_ssa_values = [] for idx, buf_name in enumerate(buffer_names): - buf_type, offset, length = artifact.subbuffer_layout[buf_name] + # Check if this is a sliced buffer + if buf_name in artifact.slice_info: + base_name, start, end = artifact.slice_info[buf_name] + # Get parent buffer info + buf_type, parent_offset, parent_length = artifact.subbuffer_layout[base_name] + # Calculate actual offset and length for slice + offset = parent_offset + start + length = end - start + else: + # Regular buffer + buf_type, offset, length = artifact.subbuffer_layout[buf_name] # Subview Op consolidated_buf = consolidated_buffers[buf_type] diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 432e86d7..c418f8f7 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -12,7 +12,7 @@ class FusedMLIROperator(AIEOperatorBase): """Operator that fuses multiple SingleMLIRSourceOperators into one.""" - def __init__(self, name, runlist, input_args, output_args, *args, **kwargs): + def __init__(self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs): assert all( isinstance(op, SingleMLIRSourceOperator) and all(isinstance(buf, str) for buf in bufs) for op, *bufs in runlist @@ -21,6 +21,7 @@ def __init__(self, name, runlist, input_args, output_args, *args, **kwargs): self.name = name self.input_args = input_args self.output_args = output_args + self.explicit_buffer_sizes = buffer_sizes or {} # Optional dict: buffer_name -> size_in_bytes self.kernel_archive = "kernels.a" super().__init__(*args, **kwargs) @@ -68,7 +69,7 @@ def get_mlir_artifact(self): comp_runlist.append((op_names[op], *bufs)) # Calculate buffer layout: {buffer_name -> (type, offset, length)} - self.subbuffer_layout, self.buffer_sizes = self._calculate_buffer_layout() + self.subbuffer_layout, self.buffer_sizes, self.slice_info = self._calculate_buffer_layout() filename = self.get_operator_name() + "_fused.mlir" fused_artifact = comp.FusedMLIRSource( @@ -76,13 +77,15 @@ def get_mlir_artifact(self): operator_mlir_map=operator_mlir_map, runlist=comp_runlist, subbuffer_layout=self.subbuffer_layout, - buffer_sizes=self.buffer_sizes + buffer_sizes=self.buffer_sizes, + slice_info=self.slice_info ) return fused_artifact def _calculate_buffer_layout(self): - args = {} + args = {} # base_buffer_name -> args_spec + sliced_buffers = {} # full_buffer_name (with slice) -> (base_name, start, end, args_spec) # Collect all buffer specs from operators for op, *bufs in self.runlist: @@ -90,36 +93,69 @@ def _calculate_buffer_layout(self): assert len(args_specs) == len(bufs), "Number of buffers must match operator argument specification" for i, buf_name in enumerate(bufs): args_spec = args_specs[i] - if buf_name not in args: - args[buf_name] = args_spec + + # Parse slice notation: "buffer_name[start:end]" + if '[' in buf_name and buf_name.endswith(']'): + base_name = buf_name[:buf_name.index('[')] + slice_part = buf_name[buf_name.index('[')+1:-1] + start, end = map(int, slice_part.split(':')) + sliced_buffers[buf_name] = (base_name, start, end, args_spec) + # Track that base buffer exists (size will be set later) + if base_name not in args and base_name not in self.explicit_buffer_sizes: + raise ValueError(f"Sliced buffer '{buf_name}' requires explicit size for base buffer '{base_name}' in buffer_sizes parameter") else: - assert np.prod(args[buf_name].shape) == np.prod(args_spec.shape), f"Buffer {buf_name} has conflicting sizes between operators" + # Regular buffer (no slice) + if buf_name not in args: + args[buf_name] = args_spec + else: + assert np.prod(args[buf_name].shape) == np.prod(args_spec.shape), f"Buffer {buf_name} has conflicting sizes between operators" - # Verify all input/output args are present + # Verify all input/output args are present (either as regular or sliced buffers) + all_buffer_names = set(args.keys()) | set(sliced_buffers.keys()) for arg in self.input_args: - assert arg in args, f"Input argument {arg} not found in runlist buffers" + # Check if it's a base buffer name in explicit_buffer_sizes + if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: + raise AssertionError(f"Input argument {arg} not found in runlist buffers") for arg in self.output_args: - assert arg in args, f"Output argument {arg} not found in runlist buffers" + if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: + raise AssertionError(f"Output argument {arg} not found in runlist buffers") - # Determine buffer types + # Determine buffer types and create layout subbuffer_layout = {} + slice_info = {} # full_buffer_name -> (base_name, start, end) def add_buffers(buffer_type, args_list): offset = 0 for arg in args_list: - arg_spec = args[arg] - length = int(np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize) - subbuffer_layout[arg] = (buffer_type, offset, length) - offset += length - return offset # == total length + if arg in self.explicit_buffer_sizes: + # Explicit size specified - this is a parent buffer for slices + length = self.explicit_buffer_sizes[arg] + subbuffer_layout[arg] = (buffer_type, offset, length) + offset += length + elif arg in args: + # Regular buffer with inferred size + arg_spec = args[arg] + length = int(np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize) + subbuffer_layout[arg] = (buffer_type, offset, length) + offset += length + # Note: sliced buffers are handled separately, not in args_list + return offset # == total length + + # Add sliced buffer entries to layout (they reference parent buffers) + for buf_name, (base_name, start, end, args_spec) in sliced_buffers.items(): + slice_info[buf_name] = (base_name, start, end) input_buffer_size = add_buffers('input', self.input_args) output_buffer_size = add_buffers('output', self.output_args) scratch_args = [arg for arg in args if arg not in self.input_args and arg not in self.output_args] + # Also include explicit buffers that are only used for slicing + for explicit_buf in self.explicit_buffer_sizes: + if explicit_buf not in self.input_args and explicit_buf not in self.output_args and explicit_buf not in scratch_args: + scratch_args.append(explicit_buf) scratch_buffer_size = add_buffers('scratch', scratch_args) buffer_sizes = (input_buffer_size, output_buffer_size, scratch_buffer_size) - return subbuffer_layout, buffer_sizes + return subbuffer_layout, buffer_sizes, slice_info def set_up_artifacts(self): operator_name = self.get_operator_name() @@ -168,6 +204,7 @@ def __init__(self, op, device_manager=None): self.subbuffer_layout = op.subbuffer_layout self.buffer_sizes = op.buffer_sizes + self.slice_info = op.slice_info input_buffer_size, output_buffer_size, scratch_buffer_size = self.buffer_sizes itemsize = np.dtype(ml_dtypes.bfloat16).itemsize @@ -194,7 +231,27 @@ def get_buffer(self, buffer_name): if buffer_name in self._buffer_cache: return self._buffer_cache[buffer_name] - # Look up buffer information + bf16_itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + + # Check if this is a sliced buffer + if buffer_name in self.slice_info: + base_name, start, end = self.slice_info[buffer_name] + # Get the parent buffer + parent_buffer = self.get_buffer(base_name) + # Create subbuffer from parent + start_elements = start // bf16_itemsize + length_elements = (end - start) // bf16_itemsize + sub_buffer = parent_buffer.subbuffer( + length=length_elements, + offset=start_elements, + shape=(length_elements,), + dtype=ml_dtypes.bfloat16 + ) + # Cache and return + self._buffer_cache[buffer_name] = sub_buffer + return sub_buffer + + # Look up buffer information for regular buffers if buffer_name not in self.subbuffer_layout: raise KeyError(f"Buffer '{buffer_name}' not found in buffer layout") @@ -214,7 +271,6 @@ def get_buffer(self, buffer_name): raise RuntimeError(f"Main buffer for type '{buf_type}' is not allocated") # Convert byte offset/length to element offset/length - bf16_itemsize = np.dtype(ml_dtypes.bfloat16).itemsize offset_elements = offset // bf16_itemsize length_elements = length // bf16_itemsize From e9cbc0048493f5657030db515ddde1552699912d Mon Sep 17 00:00:00 2001 From: andrej Date: Thu, 22 Jan 2026 16:08:44 -0700 Subject: [PATCH 53/99] discover patching locations automatically by use of magic values --- applications/llama_3.2_1b/llama_npu.py | 25 ++++++++++++++++++++----- operators/softmax/design.py | 6 ++++-- operators/softmax/op.py | 4 +++- operators/strided_copy/design.py | 16 +++++++++++----- operators/strided_copy/op.py | 5 ++++- 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 813a1892..b1c2d8a8 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -347,6 +347,7 @@ def __init__(self, config, prompt_len): num_aie_columns=1, num_channels=1, rtp_vector_size=prompt_len, # Compile with max size + mask_patch_value=0xBA5EBA11, # Magic value for patching context=self.context ).compile() @@ -356,6 +357,11 @@ def __init__(self, config, prompt_len): insts_bin_path=self.decode.softmax_compilable.insts_artifact.filename, args_spec=self.decode.softmax_compilable.get_arg_spec() ) + + self.decode.softmax_patch_offsets = [ + i for i, x in enumerate(self.decode.softmax.insts_buffer.view_as_np()) + if 0xBA5EBA11 == x + ] # RoPE operators # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) @@ -401,7 +407,8 @@ def __init__(self, config, prompt_len): input_buffer_size=1 * config.n_kv_groups * config.head_dim, output_buffer_size=config.n_kv_groups * prompt_len * config.head_dim, num_aie_channels=1, - context=self.context + context=self.context, + output_offset_patch_marker=0xDEADBEE0, ).compile() # Create patchable callable for runtime offset updates @@ -411,6 +418,11 @@ def __init__(self, config, prompt_len): insts_bin_path=self.decode.strided_copy_cache_compilable.insts_artifact.filename, args_spec=self.decode.strided_copy_cache_compilable.get_arg_spec() ) + self.decode.strided_copy_patch_offsets = [ + i for i, x in enumerate(self.decode.strided_copy_cache.insts_buffer.view_as_np()) + if (0xDEADBEE0 * 2) & 0xFFFFFFFF == x + ] + assert len(self.decode.strided_copy_patch_offsets) == 2, "Something else accidentally generated our magic offset" # Repeat interleave for keys: (n_kv_groups, context_len, head_dim) -> (n_heads, context_len, head_dim) # Compile with max context length, then patch at runtime for actual context_len @@ -979,14 +991,17 @@ def patch_operators_for_decode(config, num_preceding_tokens): # Patch strided copy operator for cache offset output_offset = num_preceding_tokens * config.head_dim offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset - strided_copy_patches = { - 39: (offset_val, 0xFFFFFFFF), - 56: (offset_val, 0xFFFFFFFF), + strided_copy_patches = { + i: (offset_val, 0xFFFFFFFF) + for i in aie_ops.decode.strided_copy_patch_offsets } aie_ops.decode.strided_copy_cache.patch(strided_copy_patches) # Patch softmax operator for actual context length - softmax_patches = {8: (context_len, 0xFFFFFFFF)} + softmax_patches = { + i: (context_len, 0xFFFFFFFF) + for i in aie_ops.decode.softmax_patch_offsets + } aie_ops.decode.softmax.patch(softmax_patches) diff --git a/operators/softmax/design.py b/operators/softmax/design.py index 43ad46ad..e7e12fc0 100644 --- a/operators/softmax/design.py +++ b/operators/softmax/design.py @@ -15,7 +15,7 @@ from ml_dtypes import bfloat16 -def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_size, rtp_vector_size=None, kernel_archive="softmax.a", func_prefix=""): +def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_size, rtp_vector_size=None, mask_patch_value=0, kernel_archive="softmax.a", func_prefix=""): per_tile_elements = tile_size if rtp_vector_size is None: rtp_vector_size = per_tile_elements @@ -118,7 +118,9 @@ def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): # Set run-time parameter for actual vector size (remainder is considered padding and ignored by the computation) def set_rtps(*args): for rtp in args: - rtp[0] = rtp_vector_size + rtp[0] = ( + rtp_vector_size if not mask_patch_value else mask_patch_value + ) rt.inline_ops(set_rtps, rtps) diff --git a/operators/softmax/op.py b/operators/softmax/op.py index 1da691bd..0f6c8b10 100644 --- a/operators/softmax/op.py +++ b/operators/softmax/op.py @@ -16,7 +16,7 @@ class AIESoftmax(SingleMLIRSourceOperator): """AIE-accelerated Softmax operation""" - def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, rtp_vector_size=None, context=None): + def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, rtp_vector_size=None, mask_patch_value=0, context=None): assert rows % 16 == 0, "rows must be multiple of 16" assert cols % 16 == 0, "cols must be multiple of 16" assert (rows * cols) % (num_aie_columns * cols) == 0, "size must be multiple of num_aie_columns * tile_size" @@ -27,6 +27,7 @@ def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, rtp_ self.num_aie_columns = num_aie_columns self.num_channels = num_channels self.rtp_vector_size = rtp_vector_size + self.mask_patch_value = mask_patch_value SingleMLIRSourceOperator.__init__(self, context=context) @@ -50,6 +51,7 @@ def get_mlir_artifact(self): 0, # trace_size self.cols, self.rtp_vector_size, + self.mask_patch_value ], ) diff --git a/operators/strided_copy/design.py b/operators/strided_copy/design.py index 5969dc95..57da2ec5 100644 --- a/operators/strided_copy/design.py +++ b/operators/strided_copy/design.py @@ -17,7 +17,7 @@ This can be useful for data layout manipulation and data copying such as: input[0, :, 0] -> output[:, 0, 0] """ -def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, input_offset, output_buffer_size, output_sizes, output_strides, output_offset, transfer_size=None, num_aie_channels=1): +def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, input_offset, output_buffer_size, output_sizes, output_strides, output_offset, transfer_size=None, num_aie_channels=1, input_offset_patch_marker=0, output_offset_patch_marker=0): assert len(input_sizes) == len(input_strides) assert len(output_sizes) == len(output_strides) @@ -42,8 +42,11 @@ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, inpu input_taps = [ TensorAccessPattern( - tensor_dims=(int(input_buffer_size),), - offset=input_offset + c * (input_sizes[input_highest_sz_idx] // num_aie_channels) * input_strides[input_highest_sz_idx], + tensor_dims=(int(input_buffer_size + input_offset_patch_marker),), + offset=( + input_offset_patch_marker if input_offset_patch_marker != 0 else + input_offset + c * (input_sizes[input_highest_sz_idx] // num_aie_channels) * input_strides[input_highest_sz_idx] + ), sizes=( input_sizes[:input_highest_sz_idx] + [input_sizes[input_highest_sz_idx] // num_aie_channels] @@ -56,8 +59,11 @@ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, inpu output_taps = [ TensorAccessPattern( - tensor_dims=(int(output_buffer_size),), - offset=output_offset + c * (output_sizes[output_highest_sz_idx] // num_aie_channels) * output_strides[output_highest_sz_idx], + tensor_dims=(int(output_buffer_size + output_offset_patch_marker),), + offset=( + output_offset_patch_marker if output_offset_patch_marker != 0 else + output_offset + c * (output_sizes[output_highest_sz_idx] // num_aie_channels) * output_strides[output_highest_sz_idx] + ), sizes=( output_sizes[:output_highest_sz_idx] + [output_sizes[output_highest_sz_idx] // num_aie_channels] diff --git a/operators/strided_copy/op.py b/operators/strided_copy/op.py index 8400a640..452743dd 100644 --- a/operators/strided_copy/op.py +++ b/operators/strided_copy/op.py @@ -32,6 +32,7 @@ def __init__( transfer_size=None, num_aie_channels=1, context=None, + **kwargs ): assert len(input_sizes) == len(input_strides) assert len(output_sizes) == len(output_strides) @@ -46,6 +47,7 @@ def __init__( self.dtype = dtype self.transfer_size = transfer_size self.num_aie_channels = num_aie_channels + self.kwargs = kwargs SingleMLIRSourceOperator.__init__(self, context=context) def get_operator_name(self): @@ -71,7 +73,8 @@ def get_mlir_artifact(self): self.output_offset, self.transfer_size, self.num_aie_channels, - ] + ], + callback_kwargs=self.kwargs, ) def get_kernel_artifacts(self): From b7d2834717998030bd251e930c5f81379da27a5a Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 14:59:24 -0700 Subject: [PATCH 54/99] make ELFs patchable; offload strided-copy for KV cache --- applications/llama_3.2_1b/llama_npu.py | 117 +++++++++++++++---------- operators/common/base.py | 1 + operators/common/compilation/base.py | 1 + operators/common/fusion.py | 94 +++++++++++--------- 4 files changed, 124 insertions(+), 89 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index b1c2d8a8..be2d37dd 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -19,7 +19,7 @@ ) from operators.common.utils import torch_to_numpy, numpy_to_torch from operators.common.base import PatchableSingleXclbinCallable -from operators.common.fusion import FusedMLIROperator +from operators.common.fusion import FusedMLIROperator, FusedFullELFCallable, load_elf, patch_elf from operators import ( AIERMSNorm, AIEGEMM, @@ -259,30 +259,79 @@ def __init__(self, config, prompt_len): context=elf_ctx ) - self.decode.attn_proj_rope_fused_op = FusedMLIROperator( - "attn_proj_rope_decode", + strided_copy_cache_magic = 0xDEADBEE0 + strided_copy_cache_op = AIEStridedCopy( + input_sizes=(config.n_kv_groups, config.head_dim), + input_strides=(config.head_dim, 1), + input_offset=0, + output_sizes=(1, config.n_kv_groups, config.head_dim), + output_strides=(0, prompt_len * config.head_dim, 1), + output_offset=7 * config.head_dim * 2, # Will be patched at runtime + input_buffer_size=1 * config.n_kv_groups * config.head_dim, + output_buffer_size=config.n_kv_groups * prompt_len * config.head_dim, + num_aie_channels=1, + output_offset_patch_marker=strided_copy_cache_magic, + context=elf_ctx + ) + + # Calculate buffer size for keys/values cache (same size, used as input to strided copy) + cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 + + self.decode.attn_fused_op = FusedMLIROperator( + "attn_fused_op", [ (gemv_attn_query_op, "W_attn_query", "x_norm", "queries"), (gemv_attn_key_value_op, "W_attn_key", "x_norm", "keys"), (gemv_attn_key_value_op, "W_attn_value", "x_norm", "values"), (rope_queries_op, "queries", "rope_angles", "queries"), (rope_keys_op, "keys", "rope_angles", "keys"), + (strided_copy_cache_op, "keys", "keys_cache"), + (strided_copy_cache_op, "values", "values_cache"), ], input_args=[ "W_attn_query", "W_attn_key", "W_attn_value", "x_norm", - "rope_angles" + "rope_angles", + "keys_cache", + "values_cache" ], output_args=[ "queries", "keys", "values" ], + buffer_sizes={ + "keys_cache": cache_buffer_size, + "values_cache": cache_buffer_size + }, context=elf_ctx ).compile() - self.decode.attn_proj_rope_fused = self.decode.attn_proj_rope_fused_op.get_callable() + + elf = load_elf(self.decode.attn_fused_op) + self.decode.attn_fused_elf_data = elf + + def get_patch_locs(elf_data, magic): + return [i for i, x in enumerate(elf_data) if magic & 0xFFFFFFFF == x ] + + # Extract patch offsets for strided_copy operations in fused operator + _, keys_cache_offs, _ = self.decode.attn_fused_op.get_layout_for_buffer("keys_cache") + _, values_cache_offs, _ = self.decode.attn_fused_op.get_layout_for_buffer("values_cache") + keys_patches = { + l: keys_cache_offs + for l in get_patch_locs(self.decode.attn_fused_elf_data, (keys_cache_offs + strided_copy_cache_magic * 2)) + } + values_patches = { + l: values_cache_offs + for l in get_patch_locs(self.decode.attn_fused_elf_data, (values_cache_offs + strided_copy_cache_magic * 2)) + } + no_offset_patches = { + l: 0 + for l in get_patch_locs(self.decode.attn_fused_elf_data, (strided_copy_cache_magic * 2)) + } + self.decode.attn_fused_patch_locations = {**keys_patches, **values_patches, **no_offset_patches} + assert len(self.decode.attn_fused_patch_locations) == 6 # Fused transpose for all attention heads (decode) transpose_values_op = AIETranspose( @@ -362,6 +411,7 @@ def __init__(self, config, prompt_len): i for i, x in enumerate(self.decode.softmax.insts_buffer.view_as_np()) if 0xBA5EBA11 == x ] + assert len(self.decode.softmax_patch_offsets) == 1, "Something else accidentally generated our magic mask value" # RoPE operators # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) @@ -395,35 +445,6 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - # Strided copy operators for cache update (transpose and concatenate) - # Keys: transpose from (1, n_kv_groups, head_dim) to (n_kv_groups, 1, head_dim) and write to cache - self.decode.strided_copy_cache_compilable = AIEStridedCopy( - input_sizes=(config.n_kv_groups, 1, config.head_dim), - input_strides=(config.head_dim, config.n_kv_groups * config.head_dim, 1), - input_offset=0, - output_sizes=(1, config.n_kv_groups, 1, config.head_dim), - output_strides=(0, prompt_len * config.head_dim, config.head_dim, 1), - output_offset=0, # Will be patched at runtime based on cached_prompt_len - input_buffer_size=1 * config.n_kv_groups * config.head_dim, - output_buffer_size=config.n_kv_groups * prompt_len * config.head_dim, - num_aie_channels=1, - context=self.context, - output_offset_patch_marker=0xDEADBEE0, - ).compile() - - # Create patchable callable for runtime offset updates - self.decode.strided_copy_cache = PatchableSingleXclbinCallable( - xclbin_path=self.decode.strided_copy_cache_compilable.xclbin_artifact.filename, - kernel_name=self.decode.strided_copy_cache_compilable.xclbin_artifact.kernel_name, - insts_bin_path=self.decode.strided_copy_cache_compilable.insts_artifact.filename, - args_spec=self.decode.strided_copy_cache_compilable.get_arg_spec() - ) - self.decode.strided_copy_patch_offsets = [ - i for i, x in enumerate(self.decode.strided_copy_cache.insts_buffer.view_as_np()) - if (0xDEADBEE0 * 2) & 0xFFFFFFFF == x - ] - assert len(self.decode.strided_copy_patch_offsets) == 2, "Something else accidentally generated our magic offset" - # Repeat interleave for keys: (n_kv_groups, context_len, head_dim) -> (n_heads, context_len, head_dim) # Compile with max context length, then patch at runtime for actual context_len self.decode.attn_repeat_interleave = AIERepeat( @@ -988,14 +1009,20 @@ def llama_forward_pass_prefill( def patch_operators_for_decode(config, num_preceding_tokens): context_len = num_preceding_tokens + 1 - # Patch strided copy operator for cache offset + # Patch fused operator for strided copy cache offset output_offset = num_preceding_tokens * config.head_dim offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset strided_copy_patches = { - i: (offset_val, 0xFFFFFFFF) - for i in aie_ops.decode.strided_copy_patch_offsets + i: (base + offset_val, 0xFFFFFFFF) + for i, base in aie_ops.decode.attn_fused_patch_locations.items() } - aie_ops.decode.strided_copy_cache.patch(strided_copy_patches) + patched_elf_data = aie_ops.decode.attn_fused_elf_data.copy() + patch_elf(patched_elf_data, strided_copy_patches) + + aie_ops.decode.attn_fused = FusedFullELFCallable( + aie_ops.decode.attn_fused_op, + elf_data=patched_elf_data + ) # Patch softmax operator for actual context length softmax_patches = { @@ -1062,8 +1089,8 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i context_len = num_preceding_tokens + 1 group_size = config.n_heads // config.n_kv_groups - # Step 1-2: Fused attention projections + RoPE - fused_op = aie_ops.decode.attn_proj_rope_fused + # Step 1-3: Fused attention projections + RoPE + cache update + fused_op = aie_ops.decode.attn_fused fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 @@ -1072,22 +1099,20 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i fused_op.get_buffer("W_attn_value").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_value_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("x_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x_norm.to("cpu").view_as_torch().flatten() fused_op.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = aie_buffers.decode.rope_angles.to("cpu").view_as_torch().flatten() + fused_op.get_buffer("keys_cache").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("values_cache").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() fused_op() aie_buffers.decode.queries.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("queries").to("cpu").view_as_torch().flatten() aie_buffers.decode.keys.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("keys").to("cpu").view_as_torch().flatten() aie_buffers.decode.values.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("values").to("cpu").view_as_torch().flatten() + aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten() + aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten() aie_buffers.decode.queries.to("npu") aie_buffers.decode.keys.to("npu") aie_buffers.decode.values.to("npu") - # Step 3: Update cache using strided copy on NPU (transpose and concatenate) - # Cache is already on NPU from prefill initialization or previous decode iteration - # Transpose and append new keys/values to this layer's cache on NPU - aie_ops.decode.strided_copy_cache(aie_buffers.decode.keys, aie_buffers.keys_cache[layer_idx]) - aie_ops.decode.strided_copy_cache(aie_buffers.decode.values, aie_buffers.values_cache[layer_idx]) - # Step 4: Repeat keys and values for grouped attention using AIERepeat on NPU aie_ops.decode.attn_repeat_interleave(aie_buffers.keys_cache[layer_idx], aie_buffers.decode.attn_scores_keys) aie_ops.decode.attn_repeat_interleave(aie_buffers.values_cache[layer_idx], aie_buffers.decode.attn_scores_values) diff --git a/operators/common/base.py b/operators/common/base.py index 0ef49770..71734992 100644 --- a/operators/common/base.py +++ b/operators/common/base.py @@ -239,6 +239,7 @@ def from_np(buffer): def from_torch(tensor): return AIEBuffer.from_np(torch_to_numpy(tensor)) + class SingleXclbinCallable: def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None): self.device_manager = device_manager or AIEDeviceManager() diff --git a/operators/common/compilation/base.py b/operators/common/compilation/base.py index e65d22b6..eb20318d 100644 --- a/operators/common/compilation/base.py +++ b/operators/common/compilation/base.py @@ -388,6 +388,7 @@ def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): self.peano_dir = peano_dir super().__init__(*args, **kwargs) + class AieccFullElfCompilationRule(AieccCompilationRule): def matches(self, graph): return any(graph.get_worklist(FullElfArtifact)) diff --git a/operators/common/fusion.py b/operators/common/fusion.py index c418f8f7..265ef420 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -1,6 +1,7 @@ import numpy as np import ml_dtypes import pyxrt +import ctypes from . import compilation as comp from .base import AIEOperatorBase, SingleMLIRSourceOperator, AIEBuffer from .device_manager import AIEDeviceManager @@ -176,18 +177,49 @@ def get_arg_spec(self): pass def get_callable(self): - """Return a callable for the fused operator (stub for now).""" return FusedFullELFCallable(self) + def get_layout_for_buffer(self, buffer_name): + if buffer_name in self.slice_info: + buf_name, start, end = self.slice_info[buffer_name] + buf_type, parent_start, parent_end = self.get_layout_for_buffer(buf_name) + return buf_type, parent_start + start, parent_start + end + + buf_type, offset, length = self.subbuffer_layout[buffer_name] + return buf_type, offset, length + + +def load_elf(op): + assert isinstance(op.artifacts[0], comp.FullElfArtifact) + elf_data = None + with open(op.artifacts[0].filename, 'rb') as f: + elf_data = np.frombuffer(f.read(), dtype=np.uint32) + return elf_data + + +def patch_elf(elf_data, patches): + for i, patch in patches.items(): + val, mask = patch + elf_data[i] = (elf_data[i] & ~mask) | (val & mask) + return elf_data + class FullELFCallable: - def __init__(self, op, device_name="main", sequence_name="sequence", device_manager=None): + def __init__(self, elf_data, device_name="main", sequence_name="sequence", device_manager=None): self.device_manager = device_manager or AIEDeviceManager() - self.xrt_elf = pyxrt.elf(op.artifacts[0].filename) - self.xrt_module = pyxrt.module(self.xrt_elf) - self.xrt_context = pyxrt.hw_context(self.device_manager.device, self.xrt_elf) - self.xrt_kernel = pyxrt.ext.kernel(self.xrt_context, f"{device_name}:{sequence_name}") - + # Create a PyCapsule from the numpy array pointer for pybind11 + elf_data_u8 = elf_data.view(dtype=np.uint8) + ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] + capsule = ctypes.pythonapi.PyCapsule_New( + elf_data_u8.ctypes.data, + None, + None + ) + xrt_elf = pyxrt.elf(capsule, elf_data.nbytes) + xrt_context = pyxrt.hw_context(self.device_manager.device, xrt_elf) + self.xrt_kernel = pyxrt.ext.kernel(xrt_context, f"{device_name}:{sequence_name}") + def __call__(self, *args): run = pyxrt.run(self.xrt_kernel) for i, arg in enumerate(args): @@ -196,17 +228,17 @@ def __call__(self, *args): run.start() ret_code = run.wait() if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: - raise RuntimeError(f"Kernel execution failed with return code {retcode}") + raise RuntimeError(f"Kernel execution failed with return code {ret_code}") + class FusedFullELFCallable(FullELFCallable): - def __init__(self, op, device_manager=None): - super().__init__(op, device_manager=device_manager) + def __init__(self, op, elf_data=None, device_manager=None): + if elf_data is None: + elf_data = load_elf(op) + super().__init__(elf_data, device_manager=device_manager) - self.subbuffer_layout = op.subbuffer_layout - self.buffer_sizes = op.buffer_sizes - self.slice_info = op.slice_info - - input_buffer_size, output_buffer_size, scratch_buffer_size = self.buffer_sizes + self.op = op + input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes itemsize = np.dtype(ml_dtypes.bfloat16).itemsize self.input_buffer = AIEBuffer( @@ -231,31 +263,7 @@ def get_buffer(self, buffer_name): if buffer_name in self._buffer_cache: return self._buffer_cache[buffer_name] - bf16_itemsize = np.dtype(ml_dtypes.bfloat16).itemsize - - # Check if this is a sliced buffer - if buffer_name in self.slice_info: - base_name, start, end = self.slice_info[buffer_name] - # Get the parent buffer - parent_buffer = self.get_buffer(base_name) - # Create subbuffer from parent - start_elements = start // bf16_itemsize - length_elements = (end - start) // bf16_itemsize - sub_buffer = parent_buffer.subbuffer( - length=length_elements, - offset=start_elements, - shape=(length_elements,), - dtype=ml_dtypes.bfloat16 - ) - # Cache and return - self._buffer_cache[buffer_name] = sub_buffer - return sub_buffer - - # Look up buffer information for regular buffers - if buffer_name not in self.subbuffer_layout: - raise KeyError(f"Buffer '{buffer_name}' not found in buffer layout") - - buf_type, offset, length = self.subbuffer_layout[buffer_name] + buf_type, offset, length = self.op.get_layout_for_buffer(buffer_name) # Select the appropriate main buffer if buf_type == 'input': @@ -271,11 +279,11 @@ def get_buffer(self, buffer_name): raise RuntimeError(f"Main buffer for type '{buf_type}' is not allocated") # Convert byte offset/length to element offset/length - offset_elements = offset // bf16_itemsize - length_elements = length // bf16_itemsize + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + offset_elements = offset // itemsize + length_elements = length // itemsize # Create subbuffer with appropriate shape - # For now, use 1D shape; could be enhanced to use actual buffer shapes sub_buffer = main_buffer.subbuffer( length=length_elements, offset=offset_elements, From 687cb2a3c292f493f0d9ce7977d9974bfbb80da0 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 15:18:12 -0700 Subject: [PATCH 55/99] fuse repeat_interleave and post attention residual onto other operators - 2.4 TPS --- applications/llama_3.2_1b/llama_npu.py | 41 ++++++++++++-------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index be2d37dd..c70360fe 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -201,26 +201,33 @@ def __init__(self, config, prompt_len): context=elf_ctx ) + repeat_interleave_op = AIERepeat( + rows=config.n_kv_groups, + cols=prompt_len * config.head_dim, # Max context length + repeat=config.n_heads // config.n_kv_groups, + transfer_size=config.head_dim, + context=self.context + ) + self.decode.post_attn_fused_op = FusedMLIROperator( "post_attn_decode", [ - (rms_norm_op, "x_pre_norm", "W_norm2", "x_norm"), + (residual_add_op, "x", "attn_output", "x"), + (rms_norm_op, "x", "W_norm2", "x_norm"), (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), (silu_ffn_op, "ffn_gate", "ffn_gate"), (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), - (residual_add_op, "x_pre_norm", "ffn_output", "x_out"), + (residual_add_op, "x", "ffn_output", "x"), ], input_args=[ - "x_pre_norm", "W_norm2", "W_ffn_gate", "W_ffn_up", "W_ffn_down" ], output_args=[ - "x_out" ], context=elf_ctx ).compile() @@ -287,6 +294,8 @@ def __init__(self, config, prompt_len): (rope_keys_op, "keys", "rope_angles", "keys"), (strided_copy_cache_op, "keys", "keys_cache"), (strided_copy_cache_op, "values", "values_cache"), + (repeat_interleave_op, "keys_cache", "attn_scores_keys"), + (repeat_interleave_op, "values_cache", "attn_scores_values"), ], input_args=[ "W_attn_query", @@ -304,7 +313,7 @@ def __init__(self, config, prompt_len): ], buffer_sizes={ "keys_cache": cache_buffer_size, - "values_cache": cache_buffer_size + "values_cache": cache_buffer_size, }, context=elf_ctx ).compile() @@ -445,16 +454,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - # Repeat interleave for keys: (n_kv_groups, context_len, head_dim) -> (n_heads, context_len, head_dim) - # Compile with max context length, then patch at runtime for actual context_len - self.decode.attn_repeat_interleave = AIERepeat( - rows=config.n_kv_groups, - cols=prompt_len * config.head_dim, # Max context length - repeat=config.n_heads // config.n_kv_groups, - transfer_size=config.head_dim, - context=self.context - ).compile().get_callable() - # Attention projection operators # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) self.prefill.attn_query = AIEGEMM( @@ -1067,14 +1066,14 @@ def llama_forward_pass_decode(config, state): def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx): aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm) # Step 1: RMS normalization grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx) # Step 2: Attention; results stored in attn_output - aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.attn_output, aie_buffers.decode.x) # Step 3: Residual # Step 4-6: Fused post-norm + SwiGLU + residual fused_op = aie_ops.decode.post_attn_fused fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 - fused_op.get_buffer("x_pre_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() + fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() + fused_op.get_buffer("attn_output").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_output.to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_norm2").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() @@ -1082,7 +1081,7 @@ def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx): fused_op() - aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x_out").to("cpu").view_as_torch()[:] + aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:] def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx): @@ -1109,14 +1108,12 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i aie_buffers.decode.values.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("values").to("cpu").view_as_torch().flatten() aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten() aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten() + aie_buffers.decode.attn_scores_keys.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_keys").to("cpu").view_as_torch().flatten() + aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_values").to("cpu").view_as_torch().flatten() aie_buffers.decode.queries.to("npu") aie_buffers.decode.keys.to("npu") aie_buffers.decode.values.to("npu") - # Step 4: Repeat keys and values for grouped attention using AIERepeat on NPU - aie_ops.decode.attn_repeat_interleave(aie_buffers.keys_cache[layer_idx], aie_buffers.decode.attn_scores_keys) - aie_ops.decode.attn_repeat_interleave(aie_buffers.values_cache[layer_idx], aie_buffers.decode.attn_scores_values) - # Step 5: Compute attention scores # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) From 8b0aaebe5d19f073fe408daa7baf380582be91f1 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 15:36:42 -0700 Subject: [PATCH 56/99] fused attn score and scaling onto end - 2.5 TPS --- applications/llama_3.2_1b/llama_npu.py | 51 ++++++++++++++------------ 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index c70360fe..f81260ad 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -281,6 +281,25 @@ def __init__(self, config, prompt_len): context=elf_ctx ) + # For decode: per head, (1, head_dim) @ (head_dim, max_context_len) + # Use GEMV: (max_context_len, head_dim) @ (head_dim,) = (max_context_len,) + gemv_attn_scores_op = AIEGEMV( + M=prompt_len, # max possible context length + K=config.head_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=prompt_len // 8, + num_batches=config.n_heads, + context=self.context + ) + + attn_scale_op = AIEElementwiseMul( + size=config.n_heads * prompt_len, + tile_size=prompt_len // 8, + num_aie_columns=8, + context=self.context + ) + # Calculate buffer size for keys/values cache (same size, used as input to strided copy) cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 @@ -296,6 +315,8 @@ def __init__(self, config, prompt_len): (strided_copy_cache_op, "values", "values_cache"), (repeat_interleave_op, "keys_cache", "attn_scores_keys"), (repeat_interleave_op, "values_cache", "attn_scores_values"), + (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), + (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores") ], input_args=[ "W_attn_query", @@ -304,12 +325,14 @@ def __init__(self, config, prompt_len): "x_norm", "rope_angles", "keys_cache", - "values_cache" + "values_cache", + "attn_scale_factor" ], output_args=[ "queries", "keys", - "values" + "values", + "attn_scores" ], buffer_sizes={ "keys_cache": cache_buffer_size, @@ -389,13 +412,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - self.decode.attn_scale = AIEElementwiseMul( - size=config.n_heads * prompt_len, - tile_size=prompt_len // 8, - num_aie_columns=8, - context=self.context - ).compile().get_callable() - # Softmax operators for attention weights # Prefill uses CPU softmax to reduce NPU operator count @@ -526,18 +542,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - # For decode: per head, (1, head_dim) @ (head_dim, max_context_len) - # Use GEMV: (max_context_len, head_dim) @ (head_dim,) = (max_context_len,) - self.decode.gemv_attn_scores = AIEGEMV( - M=prompt_len, # max possible context length - K=config.head_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=prompt_len // 8, - num_batches=config.n_heads, - context=self.context - ).compile().get_callable() - # Transpose values from (max_context_len, head_dim) to (head_dim, max_context_len) per head self.decode.transpose_values = AIETranspose( M=prompt_len, @@ -1100,6 +1104,7 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i fused_op.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = aie_buffers.decode.rope_angles.to("cpu").view_as_torch().flatten() fused_op.get_buffer("keys_cache").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("values_cache").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_scale_factor.to("cpu").view_as_torch().flatten() fused_op() @@ -1113,11 +1118,11 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i aie_buffers.decode.queries.to("npu") aie_buffers.decode.keys.to("npu") aie_buffers.decode.values.to("npu") + aie_buffers.decode.attn_scores.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores").to("cpu").view_as_torch().flatten() + aie_buffers.decode.attn_scores.to("npu") # Step 5: Compute attention scores # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV - aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores) - aie_ops.decode.attn_scale(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_scale_factor, aie_buffers.decode.attn_scores) # Step 7: Softmax on NPU (patched once at beginning of decode pass) aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) From 361f10e96c1ff3c068ed9a896230c1413cca067d Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 16:53:43 -0700 Subject: [PATCH 57/99] fuse on softmax as well --- applications/llama_3.2_1b/llama_npu.py | 89 +++++++++++--------------- 1 file changed, 37 insertions(+), 52 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index f81260ad..674abb08 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -151,7 +151,7 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - elf_ctx = AIEContext(build_dir="build_elf") + postattn_ctx = AIEContext(build_dir="build_postattn") # Fused operator for post-norm + SwiGLU + residual (decode) rms_norm_op = AIERMSNorm( @@ -160,7 +160,7 @@ def __init__(self, config, prompt_len): num_aie_columns=1, num_channels=2, tile_size=config.emb_dim, - context=elf_ctx + context=postattn_ctx ) gemv_ffn_up_gate_op = AIEGEMV( @@ -169,7 +169,7 @@ def __init__(self, config, prompt_len): num_aie_columns=8, tile_size_input=4, tile_size_output=config.hidden_dim // 8, - context=elf_ctx + context=postattn_ctx ) gemv_ffn_down_op = AIEGEMV( @@ -178,27 +178,27 @@ def __init__(self, config, prompt_len): num_aie_columns=8, tile_size_input=1, tile_size_output=config.emb_dim // 8, - context=elf_ctx + context=postattn_ctx ) silu_ffn_op = AIESiLU( size=config.hidden_dim, tile_size=config.hidden_dim // 8, num_aie_columns=8, - context=elf_ctx + context=postattn_ctx ) eltwise_mul_ffn_op = AIEElementwiseMul( size=config.hidden_dim, tile_size=config.hidden_dim // 8, num_aie_columns=8, - context=elf_ctx + context=postattn_ctx ) residual_add_op = AIEElementwiseAdd( size=config.emb_dim, tile_size=config.emb_dim // 8, - context=elf_ctx + context=postattn_ctx ) repeat_interleave_op = AIERepeat( @@ -229,10 +229,12 @@ def __init__(self, config, prompt_len): ], output_args=[ ], - context=elf_ctx + context=postattn_ctx ).compile() self.decode.post_attn_fused = self.decode.post_attn_fused_op.get_callable() + elf_ctx = AIEContext(build_dir="build_elf") + # Fused operator for attention projections + RoPE (decode) gemv_attn_query_op = AIEGEMV( M=config.n_heads * config.head_dim, @@ -300,6 +302,18 @@ def __init__(self, config, prompt_len): context=self.context ) + # Softmax operators for attention weights + softmax_magic = 0xBA5EBA11 + softmax_op = AIESoftmax( + rows=config.n_heads, + cols=prompt_len, + num_aie_columns=1, + num_channels=1, + rtp_vector_size=prompt_len, # Compile with max size + mask_patch_value=softmax_magic, # Magic value for patching + context=self.context + ) + # Calculate buffer size for keys/values cache (same size, used as input to strided copy) cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 @@ -316,7 +330,8 @@ def __init__(self, config, prompt_len): (repeat_interleave_op, "keys_cache", "attn_scores_keys"), (repeat_interleave_op, "values_cache", "attn_scores_values"), (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), - (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores") + (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), + (softmax_op, "attn_scores", "attn_weights") ], input_args=[ "W_attn_query", @@ -341,11 +356,10 @@ def __init__(self, config, prompt_len): context=elf_ctx ).compile() - elf = load_elf(self.decode.attn_fused_op) - self.decode.attn_fused_elf_data = elf + self.decode.attn_fused_elf_data = load_elf(self.decode.attn_fused_op) def get_patch_locs(elf_data, magic): - return [i for i, x in enumerate(elf_data) if magic & 0xFFFFFFFF == x ] + return [i for i, x in enumerate(elf_data) if magic & 0xFFFFFFFF == x] # Extract patch offsets for strided_copy operations in fused operator _, keys_cache_offs, _ = self.decode.attn_fused_op.get_layout_for_buffer("keys_cache") @@ -364,6 +378,9 @@ def get_patch_locs(elf_data, magic): } self.decode.attn_fused_patch_locations = {**keys_patches, **values_patches, **no_offset_patches} assert len(self.decode.attn_fused_patch_locations) == 6 + + self.decode.softmax_patch_offsets = get_patch_locs(self.decode.attn_fused_elf_data, softmax_magic) + assert len(self.decode.softmax_patch_offsets) == 2 # Fused transpose for all attention heads (decode) transpose_values_op = AIETranspose( @@ -412,32 +429,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - # Softmax operators for attention weights - # Prefill uses CPU softmax to reduce NPU operator count - - self.decode.softmax_compilable = AIESoftmax( - rows=config.n_heads, - cols=prompt_len, - num_aie_columns=1, - num_channels=1, - rtp_vector_size=prompt_len, # Compile with max size - mask_patch_value=0xBA5EBA11, # Magic value for patching - context=self.context - ).compile() - - self.decode.softmax = PatchableSingleXclbinCallable( - xclbin_path=self.decode.softmax_compilable.xclbin_artifact.filename, - kernel_name=self.decode.softmax_compilable.xclbin_artifact.kernel_name, - insts_bin_path=self.decode.softmax_compilable.insts_artifact.filename, - args_spec=self.decode.softmax_compilable.get_arg_spec() - ) - - self.decode.softmax_patch_offsets = [ - i for i, x in enumerate(self.decode.softmax.insts_buffer.view_as_np()) - if 0xBA5EBA11 == x - ] - assert len(self.decode.softmax_patch_offsets) == 1, "Something else accidentally generated our magic mask value" - # RoPE operators # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) @@ -1019,20 +1010,18 @@ def patch_operators_for_decode(config, num_preceding_tokens): i: (base + offset_val, 0xFFFFFFFF) for i, base in aie_ops.decode.attn_fused_patch_locations.items() } + softmax_patches = { + i: (context_len, 0xFFFFFFFF) + for i in aie_ops.decode.softmax_patch_offsets + } + patches = {**strided_copy_patches, **softmax_patches} patched_elf_data = aie_ops.decode.attn_fused_elf_data.copy() - patch_elf(patched_elf_data, strided_copy_patches) + patch_elf(patched_elf_data, patches) aie_ops.decode.attn_fused = FusedFullELFCallable( aie_ops.decode.attn_fused_op, elf_data=patched_elf_data ) - - # Patch softmax operator for actual context length - softmax_patches = { - i: (context_len, 0xFFFFFFFF) - for i in aie_ops.decode.softmax_patch_offsets - } - aie_ops.decode.softmax.patch(softmax_patches) def llama_forward_pass_decode(config, state): @@ -1120,12 +1109,8 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i aie_buffers.decode.values.to("npu") aie_buffers.decode.attn_scores.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores").to("cpu").view_as_torch().flatten() aie_buffers.decode.attn_scores.to("npu") - - # Step 5: Compute attention scores - # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV - - # Step 7: Softmax on NPU (patched once at beginning of decode pass) - aie_ops.decode.softmax(aie_buffers.decode.attn_scores, aie_buffers.decode.attn_weights) + aie_buffers.decode.attn_weights.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_weights").to("cpu").view_as_torch().flatten() + aie_buffers.decode.attn_weights.to("npu") # Step 8: Compute attention output on NPU # Transpose values: (max_context_len, head_dim) -> (head_dim, max_context_len) for each head using fused operator From 834f33b54d643954f54b94e9583b7bc394ec5c29 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 16:59:30 -0700 Subject: [PATCH 58/99] transpose fused onto the end; 2.6 TPS --- applications/llama_3.2_1b/llama_npu.py | 106 +++++++++---------------- 1 file changed, 39 insertions(+), 67 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 674abb08..f729b28f 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -314,25 +314,46 @@ def __init__(self, config, prompt_len): context=self.context ) - # Calculate buffer size for keys/values cache (same size, used as input to strided copy) + # Fused transpose for all attention heads (decode) + transpose_values_op = AIETranspose( + M=prompt_len, + N=config.head_dim, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=elf_ctx + ) + cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 + values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 + values_buffer_size = config.n_heads * values_per_head_buffer_size self.decode.attn_fused_op = FusedMLIROperator( "attn_fused_op", - [ - (gemv_attn_query_op, "W_attn_query", "x_norm", "queries"), - (gemv_attn_key_value_op, "W_attn_key", "x_norm", "keys"), - (gemv_attn_key_value_op, "W_attn_value", "x_norm", "values"), - (rope_queries_op, "queries", "rope_angles", "queries"), - (rope_keys_op, "keys", "rope_angles", "keys"), - (strided_copy_cache_op, "keys", "keys_cache"), - (strided_copy_cache_op, "values", "values_cache"), - (repeat_interleave_op, "keys_cache", "attn_scores_keys"), - (repeat_interleave_op, "values_cache", "attn_scores_values"), - (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), - (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), - (softmax_op, "attn_scores", "attn_weights") - ], + ( + [ + (gemv_attn_query_op, "W_attn_query", "x_norm", "queries"), + (gemv_attn_key_value_op, "W_attn_key", "x_norm", "keys"), + (gemv_attn_key_value_op, "W_attn_value", "x_norm", "values"), + (rope_queries_op, "queries", "rope_angles", "queries"), + (rope_keys_op, "keys", "rope_angles", "keys"), + (strided_copy_cache_op, "keys", "keys_cache"), + (strided_copy_cache_op, "values", "values_cache"), + (repeat_interleave_op, "keys_cache", "attn_scores_keys"), + (repeat_interleave_op, "values_cache", "attn_scores_values"), + (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), + (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), + (softmax_op, "attn_scores", "attn_weights") + ] + [ + (transpose_values_op, + f"attn_scores_values[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", + f"attn_scores_values_transposed[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]" + ) + for h in range(config.n_heads) + ] + ), input_args=[ "W_attn_query", "W_attn_key", @@ -352,6 +373,8 @@ def __init__(self, config, prompt_len): buffer_sizes={ "keys_cache": cache_buffer_size, "values_cache": cache_buffer_size, + "attn_scores_values": values_buffer_size, + "attn_scores_values_transposed": values_buffer_size }, context=elf_ctx ).compile() @@ -382,44 +405,6 @@ def get_patch_locs(elf_data, magic): self.decode.softmax_patch_offsets = get_patch_locs(self.decode.attn_fused_elf_data, softmax_magic) assert len(self.decode.softmax_patch_offsets) == 2 - # Fused transpose for all attention heads (decode) - transpose_values_op = AIETranspose( - M=prompt_len, - N=config.head_dim, - num_aie_columns=2, - num_channels=1, - m=256, - n=32, - s=8, - context=elf_ctx - ) - - # Calculate buffer size for all heads' values - values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 - values_buffer_size = config.n_heads * values_per_head_buffer_size - - # Build runlist with sliced buffers for each head - transpose_runlist = [ - (transpose_values_op, - f"values_all[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", - f"values_transposed_all[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]" - ) - for h in range(config.n_heads) - ] - - self.decode.transpose_values_fused_op = FusedMLIROperator( - "transpose_values_decode", - transpose_runlist, - input_args=["values_all"], - output_args=["values_transposed_all"], - buffer_sizes={ - "values_all": values_buffer_size, - "values_transposed_all": values_buffer_size - }, - context=elf_ctx - ).compile() - self.decode.transpose_values_fused = self.decode.transpose_values_fused_op.get_callable() - # Attention score scaling operators # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY self.prefill.attn_scale = AIEElementwiseMul( @@ -1111,20 +1096,7 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i aie_buffers.decode.attn_scores.to("npu") aie_buffers.decode.attn_weights.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_weights").to("cpu").view_as_torch().flatten() aie_buffers.decode.attn_weights.to("npu") - - # Step 8: Compute attention output on NPU - # Transpose values: (max_context_len, head_dim) -> (head_dim, max_context_len) for each head using fused operator - fused_transpose = aie_ops.decode.transpose_values_fused - fused_transpose.input_buffer.view_as_torch().to("cpu")[:] = 0 - fused_transpose.output_buffer.view_as_torch().to("cpu")[:] = 0 - fused_transpose.scratch_buffer.view_as_torch().to("cpu")[:] = 0 - fused_transpose.get_buffer("values_all").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch().flatten() - - fused_transpose() - - # Reshape flat output to match expected 3D shape (n_heads, head_dim, max_context_len) - aie_buffers.decode.attn_scores_values_transposed.to("cpu").view_as_torch().flatten()[:] = fused_transpose.get_buffer("values_transposed_all").to("cpu").view_as_torch().flatten() - aie_buffers.decode.attn_scores_values_transposed.to("npu") + aie_buffers.decode.attn_scores_values_transposed.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_values_transposed").to("cpu").view_as_torch().flatten() # GEMV: (n_heads, head_dim, max_context_len) @ (n_heads, max_context_len) -> (n_heads, head_dim) aie_ops.decode.gemv_attn_context(aie_buffers.decode.attn_scores_values_transposed, aie_buffers.decode.attn_weights, aie_buffers.decode.attn_context) From f345e8da74f28095e0ee1ad5b1a589d94db31d09 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 17:08:57 -0700 Subject: [PATCH 59/99] fuse attention context gemv - 2.7 TPS --- applications/llama_3.2_1b/llama_npu.py | 41 +++++++++----------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index f729b28f..580d5b32 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -326,6 +326,17 @@ def __init__(self, config, prompt_len): context=elf_ctx ) + # GEMV for attention context: (head_dim, max_context_len) @ (max_context_len,) = (head_dim,) per head + gemv_attn_context_op = AIEGEMV( + M=config.head_dim, + K=prompt_len, # max possible context length + num_aie_columns=8, + tile_size_input=4, + tile_size_output=4, + num_batches=config.n_heads, + context=self.context + ) + cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 values_buffer_size = config.n_heads * values_per_head_buffer_size @@ -352,6 +363,8 @@ def __init__(self, config, prompt_len): f"attn_scores_values_transposed[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]" ) for h in range(config.n_heads) + ] + [ + (gemv_attn_context_op, "attn_scores_values_transposed", "attn_weights", "attn_context") ] ), input_args=[ @@ -530,17 +543,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - # GEMV for attention context: (head_dim, max_context_len) @ (max_context_len,) = (head_dim,) per head - self.decode.gemv_attn_context = AIEGEMV( - M=config.head_dim, - K=prompt_len, # max possible context length - num_aie_columns=8, - tile_size_input=4, - tile_size_output=4, - num_batches=config.n_heads, - context=self.context - ).compile().get_callable() - # Output projection: (n_heads * head_dim,) @ (emb_dim, n_heads * head_dim)^T -> (emb_dim,) self.decode.gemv_attn_output = AIEGEMV( M=config.emb_dim, @@ -1082,24 +1084,9 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i fused_op() - aie_buffers.decode.queries.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("queries").to("cpu").view_as_torch().flatten() - aie_buffers.decode.keys.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("keys").to("cpu").view_as_torch().flatten() - aie_buffers.decode.values.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("values").to("cpu").view_as_torch().flatten() aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten() aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten() - aie_buffers.decode.attn_scores_keys.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_keys").to("cpu").view_as_torch().flatten() - aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_values").to("cpu").view_as_torch().flatten() - aie_buffers.decode.queries.to("npu") - aie_buffers.decode.keys.to("npu") - aie_buffers.decode.values.to("npu") - aie_buffers.decode.attn_scores.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores").to("cpu").view_as_torch().flatten() - aie_buffers.decode.attn_scores.to("npu") - aie_buffers.decode.attn_weights.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_weights").to("cpu").view_as_torch().flatten() - aie_buffers.decode.attn_weights.to("npu") - aie_buffers.decode.attn_scores_values_transposed.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_values_transposed").to("cpu").view_as_torch().flatten() - - # GEMV: (n_heads, head_dim, max_context_len) @ (n_heads, max_context_len) -> (n_heads, head_dim) - aie_ops.decode.gemv_attn_context(aie_buffers.decode.attn_scores_values_transposed, aie_buffers.decode.attn_weights, aie_buffers.decode.attn_context) + aie_buffers.decode.attn_context.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_context").to("cpu").view_as_torch().flatten() # Step 9: Project on NPU: (emb_dim, n_heads * head_dim) @ (n_heads * head_dim,) -> (emb_dim,) aie_ops.decode.gemv_attn_output(aie_buffers.W_attn_output_decode[layer_idx], aie_buffers.decode.attn_context, aie_buffers.decode.attn_output) From b46838fb706ebedc0ec8f91cce810006f7d448a7 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 17:13:38 -0700 Subject: [PATCH 60/99] fuse attn output onto end - 2.7 TPS --- applications/llama_3.2_1b/llama_npu.py | 29 +++++++++++++------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 580d5b32..84bb3527 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -336,6 +336,15 @@ def __init__(self, config, prompt_len): num_batches=config.n_heads, context=self.context ) + + gemv_attn_output_op = AIEGEMV( + M=config.emb_dim, + K=config.n_heads * config.head_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.emb_dim // 8, + context=self.context + ) cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 @@ -364,7 +373,8 @@ def __init__(self, config, prompt_len): ) for h in range(config.n_heads) ] + [ - (gemv_attn_context_op, "attn_scores_values_transposed", "attn_weights", "attn_context") + (gemv_attn_context_op, "attn_scores_values_transposed", "attn_weights", "attn_context"), + (gemv_attn_output_op, "W_attn_output_decode", "attn_context", "attn_output") ] ), input_args=[ @@ -543,15 +553,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - # Output projection: (n_heads * head_dim,) @ (emb_dim, n_heads * head_dim)^T -> (emb_dim,) - self.decode.gemv_attn_output = AIEGEMV( - M=config.emb_dim, - K=config.n_heads * config.head_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.emb_dim // 8, - context=self.context - ).compile().get_callable() # Allocate buffers shared with NPU @@ -1081,15 +1082,13 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i fused_op.get_buffer("keys_cache").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("values_cache").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_scale_factor.to("cpu").view_as_torch().flatten() - + fused_op.get_buffer("W_attn_output_decode").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_output_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op() aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten() aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten() - aie_buffers.decode.attn_context.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_context").to("cpu").view_as_torch().flatten() - - # Step 9: Project on NPU: (emb_dim, n_heads * head_dim) @ (n_heads * head_dim,) -> (emb_dim,) - aie_ops.decode.gemv_attn_output(aie_buffers.W_attn_output_decode[layer_idx], aie_buffers.decode.attn_context, aie_buffers.decode.attn_output) + aie_buffers.decode.attn_output.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_output").to("cpu").view_as_torch().flatten() # Main From 99ef9faf4c8e4aad4da51bf992e2ddadf27ba9a9 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 17:22:34 -0700 Subject: [PATCH 61/99] fuse GQA + post attention - 2.5 TPS --- applications/llama_3.2_1b/llama_npu.py | 195 +++++++++++-------------- 1 file changed, 84 insertions(+), 111 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 84bb3527..3df711a1 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -151,87 +151,6 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - postattn_ctx = AIEContext(build_dir="build_postattn") - - # Fused operator for post-norm + SwiGLU + residual (decode) - rms_norm_op = AIERMSNorm( - size=config.emb_dim, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=config.emb_dim, - context=postattn_ctx - ) - - gemv_ffn_up_gate_op = AIEGEMV( - M=config.hidden_dim, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.hidden_dim // 8, - context=postattn_ctx - ) - - gemv_ffn_down_op = AIEGEMV( - M=config.emb_dim, - K=config.hidden_dim, - num_aie_columns=8, - tile_size_input=1, - tile_size_output=config.emb_dim // 8, - context=postattn_ctx - ) - - silu_ffn_op = AIESiLU( - size=config.hidden_dim, - tile_size=config.hidden_dim // 8, - num_aie_columns=8, - context=postattn_ctx - ) - - eltwise_mul_ffn_op = AIEElementwiseMul( - size=config.hidden_dim, - tile_size=config.hidden_dim // 8, - num_aie_columns=8, - context=postattn_ctx - ) - - residual_add_op = AIEElementwiseAdd( - size=config.emb_dim, - tile_size=config.emb_dim // 8, - context=postattn_ctx - ) - - repeat_interleave_op = AIERepeat( - rows=config.n_kv_groups, - cols=prompt_len * config.head_dim, # Max context length - repeat=config.n_heads // config.n_kv_groups, - transfer_size=config.head_dim, - context=self.context - ) - - self.decode.post_attn_fused_op = FusedMLIROperator( - "post_attn_decode", - [ - (residual_add_op, "x", "attn_output", "x"), - (rms_norm_op, "x", "W_norm2", "x_norm"), - (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), - (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), - (silu_ffn_op, "ffn_gate", "ffn_gate"), - (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), - (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), - (residual_add_op, "x", "ffn_output", "x"), - ], - input_args=[ - "W_norm2", - "W_ffn_gate", - "W_ffn_up", - "W_ffn_down" - ], - output_args=[ - ], - context=postattn_ctx - ).compile() - self.decode.post_attn_fused = self.decode.post_attn_fused_op.get_callable() elf_ctx = AIEContext(build_dir="build_elf") @@ -292,14 +211,14 @@ def __init__(self, config, prompt_len): tile_size_input=4, tile_size_output=prompt_len // 8, num_batches=config.n_heads, - context=self.context + context=elf_ctx ) attn_scale_op = AIEElementwiseMul( size=config.n_heads * prompt_len, tile_size=prompt_len // 8, num_aie_columns=8, - context=self.context + context=elf_ctx ) # Softmax operators for attention weights @@ -311,7 +230,7 @@ def __init__(self, config, prompt_len): num_channels=1, rtp_vector_size=prompt_len, # Compile with max size mask_patch_value=softmax_magic, # Magic value for patching - context=self.context + context=elf_ctx ) # Fused transpose for all attention heads (decode) @@ -334,7 +253,7 @@ def __init__(self, config, prompt_len): tile_size_input=4, tile_size_output=4, num_batches=config.n_heads, - context=self.context + context=elf_ctx ) gemv_attn_output_op = AIEGEMV( @@ -343,7 +262,62 @@ def __init__(self, config, prompt_len): num_aie_columns=8, tile_size_input=4, tile_size_output=config.emb_dim // 8, - context=self.context + context=elf_ctx + ) + + rms_norm_op = AIERMSNorm( + size=config.emb_dim, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=config.emb_dim, + context=elf_ctx + ) + + gemv_ffn_up_gate_op = AIEGEMV( + M=config.hidden_dim, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=config.hidden_dim // 8, + context=elf_ctx + ) + + gemv_ffn_down_op = AIEGEMV( + M=config.emb_dim, + K=config.hidden_dim, + num_aie_columns=8, + tile_size_input=1, + tile_size_output=config.emb_dim // 8, + context=elf_ctx + ) + + silu_ffn_op = AIESiLU( + size=config.hidden_dim, + tile_size=config.hidden_dim // 8, + num_aie_columns=8, + context=elf_ctx + ) + + eltwise_mul_ffn_op = AIEElementwiseMul( + size=config.hidden_dim, + tile_size=config.hidden_dim // 8, + num_aie_columns=8, + context=elf_ctx + ) + + residual_add_op = AIEElementwiseAdd( + size=config.emb_dim, + tile_size=config.emb_dim // 8, + context=elf_ctx + ) + + repeat_interleave_op = AIERepeat( + rows=config.n_kv_groups, + cols=prompt_len * config.head_dim, # Max context length + repeat=config.n_heads // config.n_kv_groups, + transfer_size=config.head_dim, + context=elf_ctx ) cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 @@ -354,6 +328,7 @@ def __init__(self, config, prompt_len): "attn_fused_op", ( [ + # (gemv_attn_query_op, "W_attn_query", "x_norm", "queries"), (gemv_attn_key_value_op, "W_attn_key", "x_norm", "keys"), (gemv_attn_key_value_op, "W_attn_value", "x_norm", "values"), @@ -375,6 +350,18 @@ def __init__(self, config, prompt_len): ] + [ (gemv_attn_context_op, "attn_scores_values_transposed", "attn_weights", "attn_context"), (gemv_attn_output_op, "W_attn_output_decode", "attn_context", "attn_output") + # + ] + [ + # + (residual_add_op, "x", "attn_output", "x"), + (rms_norm_op, "x", "W_norm2", "x_norm"), + (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), + (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), + (silu_ffn_op, "ffn_gate", "ffn_gate"), + (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), + (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), + (residual_add_op, "x", "ffn_output", "x"), + # ] ), input_args=[ @@ -385,7 +372,11 @@ def __init__(self, config, prompt_len): "rope_angles", "keys_cache", "values_cache", - "attn_scale_factor" + "attn_scale_factor", + "W_norm2", + "W_ffn_gate", + "W_ffn_up", + "W_ffn_down" ], output_args=[ "queries", @@ -1046,34 +1037,12 @@ def llama_forward_pass_decode(config, state): def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx): aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm) # Step 1: RMS normalization - grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx) # Step 2: Attention; results stored in attn_output - - # Step 4-6: Fused post-norm + SwiGLU + residual - fused_op = aie_ops.decode.post_attn_fused - fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 - fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 - fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 - fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("attn_output").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_output.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_norm2").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() - - fused_op() - - aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:] - -def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx): - context_len = num_preceding_tokens + 1 - group_size = config.n_heads // config.n_kv_groups - - # Step 1-3: Fused attention projections + RoPE + cache update fused_op = aie_ops.decode.attn_fused fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 + fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_attn_query").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_query_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_attn_key").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_key_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_attn_value").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_value_decode[layer_idx].to("cpu").view_as_torch().flatten() @@ -1083,12 +1052,16 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i fused_op.get_buffer("values_cache").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_scale_factor.to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_attn_output_decode").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_output_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_norm2").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op() aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten() aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten() - aie_buffers.decode.attn_output.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_output").to("cpu").view_as_torch().flatten() + aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:] # Main From 8eaa9bc199020d73bad01046a0885f2c5d1b1b55 Mon Sep 17 00:00:00 2001 From: andrej Date: Fri, 23 Jan 2026 17:30:45 -0700 Subject: [PATCH 62/99] fuse rms norm onto beginning of transformer block -- full transformer block fused -- 2.5 TPS --- applications/llama_3.2_1b/llama_npu.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 3df711a1..d776e550 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -66,14 +66,7 @@ def __init__(self, config, prompt_len): tile_size=config.emb_dim, context=self.context ).compile().get_callable() - self.decode.rms_norm = AIERMSNorm( - size=config.emb_dim, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=config.emb_dim, - context=self.context - ).compile().get_callable() + # Residual additions self.prefill.residual_add = AIEElementwiseAdd( @@ -150,7 +143,14 @@ def __init__(self, config, prompt_len): num_aie_columns=8, context=self.context ).compile().get_callable() - + self.decode.rms_norm = AIERMSNorm( + size=config.emb_dim, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=config.emb_dim, + context=self.context + ).compile().get_callable() elf_ctx = AIEContext(build_dir="build_elf") @@ -328,6 +328,8 @@ def __init__(self, config, prompt_len): "attn_fused_op", ( [ + (rms_norm_op, "x", "W_norm1", "x_norm") # Step 1: RMS normalization + ] + [ # (gemv_attn_query_op, "W_attn_query", "x_norm", "queries"), (gemv_attn_key_value_op, "W_attn_key", "x_norm", "keys"), @@ -365,6 +367,7 @@ def __init__(self, config, prompt_len): ] ), input_args=[ + "W_norm1", "W_attn_query", "W_attn_key", "W_attn_value", @@ -1036,17 +1039,15 @@ def llama_forward_pass_decode(config, state): def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx): - aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm) # Step 1: RMS normalization - fused_op = aie_ops.decode.attn_fused fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() + fused_op.get_buffer("W_norm1").to("cpu").view_as_torch()[:] = aie_buffers.W_norm1[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_attn_query").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_query_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_attn_key").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_key_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("W_attn_value").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_value_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("x_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x_norm.to("cpu").view_as_torch().flatten() fused_op.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = aie_buffers.decode.rope_angles.to("cpu").view_as_torch().flatten() fused_op.get_buffer("keys_cache").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("values_cache").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() From 165a93bc02f74f2e716deab3d1502c7f9719011e Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 26 Jan 2026 11:35:22 -0700 Subject: [PATCH 63/99] [WRONG RESULTS] 16x-fused transformer block --- .../llama_3.2_1b/llama_inference_harness.py | 1 + applications/llama_3.2_1b/llama_npu.py | 176 ++++++++++-------- 2 files changed, 98 insertions(+), 79 deletions(-) diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py index 3d639627..47d2c085 100644 --- a/applications/llama_3.2_1b/llama_inference_harness.py +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -194,6 +194,7 @@ def generate( n_tokens_generated = 0 t_prefill_start = time.perf_counter() first_token, state = generate_token(config, forward_pass, state) + print(len(state.token_ids[0])) token_text = config.tokenizer.decode([first_token]) n_tokens_generated += 1 print(token_text, end='', flush=True) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index d776e550..56ebbc5e 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -323,26 +323,26 @@ def __init__(self, config, prompt_len): cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 values_buffer_size = config.n_heads * values_per_head_buffer_size - - self.decode.attn_fused_op = FusedMLIROperator( - "attn_fused_op", - ( + + runlist = [] + for layer_idx in range(config.n_layers): + runlist.extend( [ - (rms_norm_op, "x", "W_norm1", "x_norm") # Step 1: RMS normalization + (rms_norm_op, "x", f"W_norm1_{layer_idx}", "x_norm") # Step 1: RMS normalization ] + [ # - (gemv_attn_query_op, "W_attn_query", "x_norm", "queries"), - (gemv_attn_key_value_op, "W_attn_key", "x_norm", "keys"), - (gemv_attn_key_value_op, "W_attn_value", "x_norm", "values"), - (rope_queries_op, "queries", "rope_angles", "queries"), - (rope_keys_op, "keys", "rope_angles", "keys"), - (strided_copy_cache_op, "keys", "keys_cache"), - (strided_copy_cache_op, "values", "values_cache"), - (repeat_interleave_op, "keys_cache", "attn_scores_keys"), - (repeat_interleave_op, "values_cache", "attn_scores_values"), - (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), - (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), - (softmax_op, "attn_scores", "attn_weights") + (gemv_attn_query_op, f"W_attn_query_{layer_idx}", "x_norm", "queries"), + (gemv_attn_key_value_op, f"W_attn_key_{layer_idx}", "x_norm", "keys"), + (gemv_attn_key_value_op, f"W_attn_value_{layer_idx}", "x_norm", "values"), + (rope_queries_op, "queries", "rope_angles", "queries"), + (rope_keys_op, "keys", "rope_angles", "keys"), + (strided_copy_cache_op, "keys", "keys_cache"), + (strided_copy_cache_op, "values", "values_cache"), + (repeat_interleave_op, f"keys_cache_{layer_idx}", "attn_scores_keys"), + (repeat_interleave_op, f"values_cache_{layer_idx}", "attn_scores_values"), + (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), + (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), + (softmax_op, "attn_scores", "attn_weights") ] + [ (transpose_values_op, f"attn_scores_values[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", @@ -350,48 +350,64 @@ def __init__(self, config, prompt_len): ) for h in range(config.n_heads) ] + [ - (gemv_attn_context_op, "attn_scores_values_transposed", "attn_weights", "attn_context"), - (gemv_attn_output_op, "W_attn_output_decode", "attn_context", "attn_output") + (gemv_attn_context_op, "attn_scores_values_transposed", "attn_weights", "attn_context"), + (gemv_attn_output_op, f"W_attn_output_decode_{layer_idx}", "attn_context", "attn_output") # ] + [ # - (residual_add_op, "x", "attn_output", "x"), - (rms_norm_op, "x", "W_norm2", "x_norm"), - (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), - (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), - (silu_ffn_op, "ffn_gate", "ffn_gate"), - (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), - (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), - (residual_add_op, "x", "ffn_output", "x"), + (residual_add_op, "x", "attn_output", "x"), + (rms_norm_op, "x", f"W_norm2_{layer_idx}", "x_norm"), + (gemv_ffn_up_gate_op, f"W_ffn_gate_{layer_idx}", "x_norm", "ffn_gate"), + (gemv_ffn_up_gate_op, f"W_ffn_up_{layer_idx}", "x_norm", "ffn_up"), + (silu_ffn_op, "ffn_gate", "ffn_gate"), + (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), + (gemv_ffn_down_op, f"W_ffn_down_{layer_idx}", "ffn_hidden", "ffn_output"), + (residual_add_op, "x", "ffn_output", "x"), # ] - ), + ) + + self.decode.attn_fused_op = FusedMLIROperator( + "attn_fused_op", + runlist, input_args=[ - "W_norm1", - "W_attn_query", - "W_attn_key", - "W_attn_value", - "x_norm", + f"W_norm1_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"W_attn_query_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"W_attn_key_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"W_attn_value_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"W_norm2_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"W_ffn_gate_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"W_ffn_up_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"W_ffn_down_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"keys_cache_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ + f"values_cache_{layer_idx}" for layer_idx in range(config.n_layers) + ] + [ "rope_angles", - "keys_cache", - "values_cache", "attn_scale_factor", - "W_norm2", - "W_ffn_gate", - "W_ffn_up", - "W_ffn_down" - ], - output_args=[ - "queries", - "keys", - "values", - "attn_scores" ], + output_args=[], buffer_sizes={ - "keys_cache": cache_buffer_size, - "values_cache": cache_buffer_size, - "attn_scores_values": values_buffer_size, - "attn_scores_values_transposed": values_buffer_size + **{ + "keys_cache_{layer_idx}": cache_buffer_size + for layer_idx in range(config.n_layers) + }, + **{ + "values_cache_{layer_idx}": cache_buffer_size + for layer_idx in range(config.n_layers) + }, + **{ + "attn_scores_values": values_buffer_size, + "attn_scores_values_transposed": values_buffer_size + } }, context=elf_ctx ).compile() @@ -417,11 +433,11 @@ def get_patch_locs(elf_data, magic): for l in get_patch_locs(self.decode.attn_fused_elf_data, (strided_copy_cache_magic * 2)) } self.decode.attn_fused_patch_locations = {**keys_patches, **values_patches, **no_offset_patches} - assert len(self.decode.attn_fused_patch_locations) == 6 + assert len(self.decode.attn_fused_patch_locations) == 4 * config.n_layers + 2 self.decode.softmax_patch_offsets = get_patch_locs(self.decode.attn_fused_elf_data, softmax_magic) - assert len(self.decode.softmax_patch_offsets) == 2 - + assert len(self.decode.softmax_patch_offsets) == config.n_layers + 1 + # Attention score scaling operators # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY self.prefill.attn_scale = AIEElementwiseMul( @@ -982,7 +998,7 @@ def llama_forward_pass_prefill( # Decode # ########################################################################## -def patch_operators_for_decode(config, num_preceding_tokens): +def patch_operators_for_decode(ops, config, num_preceding_tokens): context_len = num_preceding_tokens + 1 # Patch fused operator for strided copy cache offset @@ -990,18 +1006,18 @@ def patch_operators_for_decode(config, num_preceding_tokens): offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset strided_copy_patches = { i: (base + offset_val, 0xFFFFFFFF) - for i, base in aie_ops.decode.attn_fused_patch_locations.items() + for i, base in ops.attn_fused_patch_locations.items() } softmax_patches = { i: (context_len, 0xFFFFFFFF) - for i in aie_ops.decode.softmax_patch_offsets + for i in ops.softmax_patch_offsets } patches = {**strided_copy_patches, **softmax_patches} - patched_elf_data = aie_ops.decode.attn_fused_elf_data.copy() + patched_elf_data = ops.attn_fused_elf_data.copy() patch_elf(patched_elf_data, patches) - aie_ops.decode.attn_fused = FusedFullELFCallable( - aie_ops.decode.attn_fused_op, + ops.attn_fused = FusedFullELFCallable( + ops.attn_fused_op, elf_data=patched_elf_data ) @@ -1010,7 +1026,7 @@ def llama_forward_pass_decode(config, state): batch, seq_len = state.token_ids.shape assert seq_len == 1 - patch_operators_for_decode(config, state.num_preceding_tokens) + patch_operators_for_decode(aie_ops.decode, config, state.num_preceding_tokens) # Step 1: Prefill RoPE angle look-up tables angles_slice = config.angles[state.num_preceding_tokens : state.num_preceding_tokens + seq_len] @@ -1022,12 +1038,10 @@ def llama_forward_pass_decode(config, state): aie_buffers.decode.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x # Step 3: Transformer blocks - for layer_idx in range(config.n_layers): - transformer_block_forward_decode( - config, - state.num_preceding_tokens, - layer_idx, - ) + transformer_blocks_forward_decode( + config, + state.num_preceding_tokens, + ) aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_final_norm, aie_buffers.decode.x) # Step 4: Final normalization aie_ops.decode.gemv_out_head(aie_buffers.W_out_head, aie_buffers.decode.x, aie_buffers.decode.logits) # Step 5: Output projection @@ -1038,30 +1052,34 @@ def llama_forward_pass_decode(config, state): return logits, state -def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx): +def transformer_blocks_forward_decode(config, num_preceding_tokens): fused_op = aie_ops.decode.attn_fused fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 + + for layer_idx in range(config.n_layers): + fused_op.get_buffer(f"W_norm1_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_norm1[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_attn_query_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_query_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_attn_key_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_key_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_attn_value_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_value_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_attn_output_decode_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_output_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_norm2_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_ffn_gate_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_ffn_up_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"W_ffn_down_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"keys_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer(f"values_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() + fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_norm1").to("cpu").view_as_torch()[:] = aie_buffers.W_norm1[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_attn_query").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_query_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_attn_key").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_key_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_attn_value").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_value_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = aie_buffers.decode.rope_angles.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("keys_cache").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("values_cache").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() fused_op.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_scale_factor.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_attn_output_decode").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_output_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_norm2").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() fused_op() - aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten() - aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten() + for layer_idx in range(config.n_layers): + aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer(f"keys_cache_{layer_idx}").to("cpu").view_as_torch().flatten() + aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer(f"values_cache_{layer_idx}").to("cpu").view_as_torch().flatten() aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:] From 86d7de8027920bd8334501f21f3d8db78aa87334 Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 26 Jan 2026 13:28:17 -0700 Subject: [PATCH 64/99] remove unnecessary syncs, remove unused ops -- 4.4 TPS --- .../llama_3.2_1b/llama_inference_harness.py | 1 - applications/llama_3.2_1b/llama_npu.py | 261 +++++------------- operators/common/fusion.py | 28 +- 3 files changed, 86 insertions(+), 204 deletions(-) diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py index 47d2c085..3d639627 100644 --- a/applications/llama_3.2_1b/llama_inference_harness.py +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -194,7 +194,6 @@ def generate( n_tokens_generated = 0 t_prefill_start = time.perf_counter() first_token, state = generate_token(config, forward_pass, state) - print(len(state.token_ids[0])) token_text = config.tokenizer.decode([first_token]) n_tokens_generated += 1 print(token_text, end='', flush=True) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 56ebbc5e..35f32388 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -336,8 +336,8 @@ def __init__(self, config, prompt_len): (gemv_attn_key_value_op, f"W_attn_value_{layer_idx}", "x_norm", "values"), (rope_queries_op, "queries", "rope_angles", "queries"), (rope_keys_op, "keys", "rope_angles", "keys"), - (strided_copy_cache_op, "keys", "keys_cache"), - (strided_copy_cache_op, "values", "values_cache"), + (strided_copy_cache_op, "keys", f"keys_cache_{layer_idx}"), + (strided_copy_cache_op, "values", f"values_cache_{layer_idx}"), (repeat_interleave_op, f"keys_cache_{layer_idx}", "attn_scores_keys"), (repeat_interleave_op, f"values_cache_{layer_idx}", "attn_scores_values"), (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), @@ -367,41 +367,21 @@ def __init__(self, config, prompt_len): ] ) - self.decode.attn_fused_op = FusedMLIROperator( - "attn_fused_op", + self.decode.fused_op = FusedMLIROperator( + "fused_op", runlist, - input_args=[ - f"W_norm1_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"W_attn_query_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"W_attn_key_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"W_attn_value_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"W_norm2_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"W_ffn_gate_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"W_ffn_up_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"W_ffn_down_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"keys_cache_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ - f"values_cache_{layer_idx}" for layer_idx in range(config.n_layers) - ] + [ + input_args=[ # arguments that change between invocations of the fused kernel and therefore need to be synced on each token + "x", # both input and output "rope_angles", - "attn_scale_factor", ], output_args=[], buffer_sizes={ **{ - "keys_cache_{layer_idx}": cache_buffer_size + f"keys_cache_{layer_idx}": cache_buffer_size for layer_idx in range(config.n_layers) }, **{ - "values_cache_{layer_idx}": cache_buffer_size + f"values_cache_{layer_idx}": cache_buffer_size for layer_idx in range(config.n_layers) }, **{ @@ -412,32 +392,57 @@ def __init__(self, config, prompt_len): context=elf_ctx ).compile() - self.decode.attn_fused_elf_data = load_elf(self.decode.attn_fused_op) + self.decode.fused_elf_data = load_elf(self.decode.fused_op) def get_patch_locs(elf_data, magic): - return [i for i, x in enumerate(elf_data) if magic & 0xFFFFFFFF == x] + magic = magic & 0xFFFFFFFF + return np.where(elf_data == magic)[0] # Extract patch offsets for strided_copy operations in fused operator - _, keys_cache_offs, _ = self.decode.attn_fused_op.get_layout_for_buffer("keys_cache") - _, values_cache_offs, _ = self.decode.attn_fused_op.get_layout_for_buffer("values_cache") - keys_patches = { - l: keys_cache_offs - for l in get_patch_locs(self.decode.attn_fused_elf_data, (keys_cache_offs + strided_copy_cache_magic * 2)) - } - values_patches = { - l: values_cache_offs - for l in get_patch_locs(self.decode.attn_fused_elf_data, (values_cache_offs + strided_copy_cache_magic * 2)) - } + keys_patches = {} + values_patches = {} + for layer_idx in range(config.n_layers): + _, keys_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer(f"keys_cache_{layer_idx}") + _, values_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer(f"values_cache_{layer_idx}") + keys_patches.update({ + int(l): keys_cache_offs + for l in get_patch_locs(self.decode.fused_elf_data, (keys_cache_offs + strided_copy_cache_magic * 2)) + }) + values_patches.update({ + int(l): values_cache_offs + for l in get_patch_locs(self.decode.fused_elf_data, (values_cache_offs + strided_copy_cache_magic * 2)) + }) no_offset_patches = { - l: 0 - for l in get_patch_locs(self.decode.attn_fused_elf_data, (strided_copy_cache_magic * 2)) + int(l): 0 + for l in get_patch_locs(self.decode.fused_elf_data, (strided_copy_cache_magic * 2)) } - self.decode.attn_fused_patch_locations = {**keys_patches, **values_patches, **no_offset_patches} - assert len(self.decode.attn_fused_patch_locations) == 4 * config.n_layers + 2 + self.decode.fused_patch_locations = {**keys_patches, **values_patches, **no_offset_patches} + assert len(self.decode.fused_patch_locations) == 4 * config.n_layers + 2 - self.decode.softmax_patch_offsets = get_patch_locs(self.decode.attn_fused_elf_data, softmax_magic) + self.decode.softmax_patch_offsets = get_patch_locs(self.decode.fused_elf_data, softmax_magic) assert len(self.decode.softmax_patch_offsets) == config.n_layers + 1 + self.decode.fused = FusedFullELFCallable( + self.decode.fused_op, + elf_data=self.decode.fused_elf_data + ) + + for layer_idx in range(config.n_layers): + self.decode.fused.get_buffer(f"W_norm1_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'].flatten() + self.decode.fused.get_buffer(f"W_attn_query_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'].flatten() + self.decode.fused.get_buffer(f"W_attn_key_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'].flatten() + self.decode.fused.get_buffer(f"W_attn_value_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'].flatten() + self.decode.fused.get_buffer(f"W_attn_output_decode_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'].flatten() + self.decode.fused.get_buffer(f"W_norm2_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'].flatten() + self.decode.fused.get_buffer(f"W_ffn_gate_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'].flatten() + self.decode.fused.get_buffer(f"W_ffn_up_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'].flatten() + self.decode.fused.get_buffer(f"W_ffn_down_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'].flatten() + scale_factor = 1.0 / math.sqrt(config.head_dim) + self.decode.fused.get_buffer(f"attn_scale_factor").to("cpu").view_as_torch()[:] = scale_factor + self.decode.fused.input_buffer.to("npu") + self.decode.fused.scratch_buffer.to("npu") + self.decode.fused.output_buffer.to("npu") + # Attention score scaling operators # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY self.prefill.attn_scale = AIEElementwiseMul( @@ -465,20 +470,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - self.decode.rope_queries = AIERope( - rows=1 * config.n_heads, - cols=config.head_dim, - angle_rows=1, - context=self.context - ).compile().get_callable() - - self.decode.rope_keys = AIERope( - rows=1 * config.n_kv_groups, - cols=config.head_dim, - angle_rows=1, - context=self.context - ).compile().get_callable() - # Attention projection operators # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) self.prefill.attn_query = AIEGEMM( @@ -493,15 +484,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - self.decode.gemv_attn_query = AIEGEMV( - M=config.n_heads * config.head_dim, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.head_dim // 2, - context=self.context - ).compile().get_callable() - # Key projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) self.prefill.attn_key = AIEGEMM( M=prompt_len, @@ -515,15 +497,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - self.decode.gemv_attn_key_value = AIEGEMV( - M=config.n_kv_groups * config.head_dim, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.head_dim // 2, - context=self.context - ).compile().get_callable() - # Value projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) self.prefill.attn_value = AIEGEMM( M=prompt_len, @@ -551,18 +524,6 @@ def get_patch_locs(elf_data, magic): context=self.context ).compile().get_callable() - # Transpose values from (max_context_len, head_dim) to (head_dim, max_context_len) per head - self.decode.transpose_values = AIETranspose( - M=prompt_len, - N=config.head_dim, - num_aie_columns=2, - num_channels=1, - m=256, - n=32, - s=8, - context=self.context - ).compile().get_callable() - # Allocate buffers shared with NPU @@ -627,48 +588,7 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d class AIEDecodeBuffers: def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_context_len): self.x = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) - self.x_norm = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) - self.attn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) - self.ffn_output = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) - # SwiGLU intermediate buffers - self.ffn_gate = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) - self.ffn_up = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) - self.ffn_hidden = AIEBuffer(shape=(1, hidden_dim), dtype=ml_dtypes.bfloat16) - # Attention buffers: queries and keys serve as both projection output and RoPE input/output - self.queries = AIEBuffer(shape=(1 * n_heads, head_dim), dtype=ml_dtypes.bfloat16) - self.keys = AIEBuffer(shape=(1 * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) - self.values = AIEBuffer(shape=(1, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) - self.rope_angles = AIEBuffer(shape=(1, head_dim), dtype=ml_dtypes.bfloat16) - # Attention score computation buffers (batched) - self.attn_scores_keys = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) - self.attn_scores_values = AIEBuffer(shape=(n_heads, max_context_len, head_dim), dtype=ml_dtypes.bfloat16) - self.attn_scores_values_transposed = AIEBuffer(shape=(n_heads, head_dim, max_context_len), dtype=ml_dtypes.bfloat16) - # Create per-head subbuffers for transpose operations (to avoid allocating in hot path) - self.attn_scores_values_per_head = [ - self.attn_scores_values.subbuffer( - length=max_context_len * head_dim, - offset=h * max_context_len * head_dim, - shape=(max_context_len, head_dim) - ) - for h in range(n_heads) - ] - self.attn_scores_values_transposed_per_head = [ - self.attn_scores_values_transposed.subbuffer( - length=head_dim * max_context_len, - offset=h * head_dim * max_context_len, - shape=(head_dim, max_context_len) - ) - for h in range(n_heads) - ] - self.attn_context = AIEBuffer(shape=(n_heads, head_dim), dtype=ml_dtypes.bfloat16) - self.attn_context_concat = AIEBuffer(shape=(n_heads * head_dim,), dtype=ml_dtypes.bfloat16) - self.attn_scores = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) - # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) - scale_factor = 1.0 / math.sqrt(head_dim) - self.attn_scale_factor = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) - self.attn_scale_factor.view_as_torch()[:] = scale_factor - self.attn_scale_factor.to("npu") - self.attn_weights = AIEBuffer(shape=(n_heads, max_context_len), dtype=ml_dtypes.bfloat16) + class AIELlamaBuffers: def __init__(self, config, prompt_len): @@ -711,36 +631,15 @@ def __init__(self, config, prompt_len): self.W_norm2.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight']).to("npu") ) - self.W_attn_query_decode.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight']).to("npu") - ) self.W_attn_query_prefill.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'].T).to("npu") ) - self.W_attn_key_decode.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight']).to("npu") - ) self.W_attn_key_prefill.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'].T).to("npu") ) - self.W_attn_value_decode.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight']).to("npu") - ) self.W_attn_value_prefill.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'].T).to("npu") ) - self.W_attn_output_decode.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']).to("npu") - ) - self.W_ffn_gate_decode.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight']).to("npu") - ) - self.W_ffn_up_decode.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight']).to("npu") - ) - self.W_ffn_down_decode.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight']).to("npu") - ) self.W_ffn_gate_prefill.append( AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'].T).to("npu") ) @@ -998,7 +897,7 @@ def llama_forward_pass_prefill( # Decode # ########################################################################## -def patch_operators_for_decode(ops, config, num_preceding_tokens): +def patch_fused_decode_operator(ops, config, num_preceding_tokens): context_len = num_preceding_tokens + 1 # Patch fused operator for strided copy cache offset @@ -1006,42 +905,38 @@ def patch_operators_for_decode(ops, config, num_preceding_tokens): offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset strided_copy_patches = { i: (base + offset_val, 0xFFFFFFFF) - for i, base in ops.attn_fused_patch_locations.items() + for i, base in ops.fused_patch_locations.items() } softmax_patches = { i: (context_len, 0xFFFFFFFF) for i in ops.softmax_patch_offsets } patches = {**strided_copy_patches, **softmax_patches} - patched_elf_data = ops.attn_fused_elf_data.copy() + patched_elf_data = ops.fused_elf_data.copy() patch_elf(patched_elf_data, patches) - ops.attn_fused = FusedFullELFCallable( - ops.attn_fused_op, - elf_data=patched_elf_data - ) + ops.fused.reload_elf(patched_elf_data) def llama_forward_pass_decode(config, state): batch, seq_len = state.token_ids.shape assert seq_len == 1 - patch_operators_for_decode(aie_ops.decode, config, state.num_preceding_tokens) + patch_fused_decode_operator(aie_ops.decode, config, state.num_preceding_tokens) # Step 1: Prefill RoPE angle look-up tables angles_slice = config.angles[state.num_preceding_tokens : state.num_preceding_tokens + seq_len] - aie_buffers.decode.rope_angles.view_as_torch()[:] = angles_slice + aie_ops.decode.fused.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = angles_slice + #scale_factor = 1.0 / math.sqrt(config.head_dim) + #aie_ops.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = scale_factor # Step 2: Token embedding (on CPU) tok_emb_weight = config.weights['model.embed_tokens.weight'] x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) - aie_buffers.decode.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x + aie_ops.decode.fused.get_buffer("x").view_as_torch().view(-1, config.emb_dim)[:seq_len, :] = x # Step 3: Transformer blocks - transformer_blocks_forward_decode( - config, - state.num_preceding_tokens, - ) + transformer_blocks_forward_decode(config, state.num_preceding_tokens) aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_final_norm, aie_buffers.decode.x) # Step 4: Final normalization aie_ops.decode.gemv_out_head(aie_buffers.W_out_head, aie_buffers.decode.x, aie_buffers.decode.logits) # Step 5: Output projection @@ -1053,33 +948,10 @@ def llama_forward_pass_decode(config, state): def transformer_blocks_forward_decode(config, num_preceding_tokens): - fused_op = aie_ops.decode.attn_fused - fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0 - fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0 - fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0 - - for layer_idx in range(config.n_layers): - fused_op.get_buffer(f"W_norm1_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_norm1[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_attn_query_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_query_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_attn_key_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_key_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_attn_value_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_value_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_attn_output_decode_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_attn_output_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_norm2_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_ffn_gate_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_ffn_up_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"W_ffn_down_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"keys_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() - fused_op.get_buffer(f"values_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() - - fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = aie_buffers.decode.rope_angles.to("cpu").view_as_torch().flatten() - fused_op.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_scale_factor.to("cpu").view_as_torch().flatten() + fused_op = aie_ops.decode.fused + fused_op.input_buffer.to("cpu") fused_op() - - for layer_idx in range(config.n_layers): - aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer(f"keys_cache_{layer_idx}").to("cpu").view_as_torch().flatten() - aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer(f"values_cache_{layer_idx}").to("cpu").view_as_torch().flatten() aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:] @@ -1090,10 +962,17 @@ def llama_forward_pass( config, state ): + global aie_ops, aie_buffers + batch, seq_len = state.token_ids.shape if seq_len > 1: ret = llama_forward_pass_prefill(config, state) state.num_preceding_tokens = state.token_ids.shape[1] + # Pass KV cache data onto fused decode operator + for layer_idx in range(config.n_layers): + aie_ops.decode.fused.get_buffer(f"keys_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() + aie_ops.decode.fused.get_buffer(f"values_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() + aie_ops.decode.fused.scratch_buffer.to("cpu") return ret else: ret = llama_forward_pass_decode(config, state) diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 265ef420..9f87076f 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -206,7 +206,22 @@ def patch_elf(elf_data, patches): class FullELFCallable: def __init__(self, elf_data, device_name="main", sequence_name="sequence", device_manager=None): + self.device_name = device_name + self.sequence_name = sequence_name self.device_manager = device_manager or AIEDeviceManager() + self.reload_elf(elf_data) + + def __call__(self, *args): + run = pyxrt.run(self.xrt_kernel) + for i, arg in enumerate(args): + assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" + run.set_arg(i, arg) + run.start() + ret_code = run.wait() + if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: + raise RuntimeError(f"Kernel execution failed with return code {ret_code}") + + def reload_elf(self, elf_data): # Create a PyCapsule from the numpy array pointer for pybind11 elf_data_u8 = elf_data.view(dtype=np.uint8) ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object @@ -218,17 +233,7 @@ def __init__(self, elf_data, device_name="main", sequence_name="sequence", devic ) xrt_elf = pyxrt.elf(capsule, elf_data.nbytes) xrt_context = pyxrt.hw_context(self.device_manager.device, xrt_elf) - self.xrt_kernel = pyxrt.ext.kernel(xrt_context, f"{device_name}:{sequence_name}") - - def __call__(self, *args): - run = pyxrt.run(self.xrt_kernel) - for i, arg in enumerate(args): - assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" - run.set_arg(i, arg) - run.start() - ret_code = run.wait() - if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: - raise RuntimeError(f"Kernel execution failed with return code {ret_code}") + self.xrt_kernel = pyxrt.ext.kernel(xrt_context, f"{self.device_name}:{self.sequence_name}") class FusedFullELFCallable(FullELFCallable): @@ -298,7 +303,6 @@ def get_buffer(self, buffer_name): def __call__(self): self.input_buffer.to("npu") self.output_buffer.to("npu") - self.scratch_buffer.to("npu") super().__call__( self.input_buffer.bo if self.input_buffer else None, self.output_buffer.bo if self.output_buffer else None, From 77bac5ae10abc4f6c5c6606328fc90760f85bddc Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 26 Jan 2026 13:40:49 -0700 Subject: [PATCH 65/99] [decode end-to-end fused] offload last rms norm and last linear layer -- 4.4 TPS --- applications/llama_3.2_1b/llama_npu.py | 245 +++++++++++++------------ 1 file changed, 125 insertions(+), 120 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 35f32388..b46fb0db 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -57,6 +57,10 @@ def __init__(self, config, prompt_len): self.prefill = AIEPrefillOperations() self.decode = AIEDecodeOperations() + + # ########################################################################## + # Prefill operators + # RMS Norm self.prefill.rms_norm = AIERMSNorm( size=prompt_len * config.emb_dim, @@ -95,14 +99,6 @@ def __init__(self, config, prompt_len): context=self.context ).compile() self.prefill.out_head = self.prefill.gemv_out_head_compilable.get_callable() - self.decode.gemv_out_head = AIEGEMV( - M=config.vocab_size, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=32, - context=self.context - ).compile().get_callable() # SwiGLU FFN operators # Prefill: M=prompt_len, K=emb_dim, N=hidden_dim @@ -143,15 +139,92 @@ def __init__(self, config, prompt_len): num_aie_columns=8, context=self.context ).compile().get_callable() - self.decode.rms_norm = AIERMSNorm( - size=config.emb_dim, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=config.emb_dim, + + # Attention score scaling operators + # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY + self.prefill.attn_scale = AIEElementwiseMul( + size=config.n_heads * prompt_len * prompt_len, + tile_size=prompt_len, + num_aie_columns=8, + context=self.context + ).compile().get_callable() + + # RoPE operators + # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) + # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) + # angle_rows=1 because all rows use the same angle row (angles are per position) + self.prefill.rope_queries = AIERope( + rows=prompt_len * config.n_heads, + cols=config.head_dim, + angle_rows=prompt_len, context=self.context - ).compile().get_callable() + ).compile().get_callable() + self.prefill.rope_keys = AIERope( + rows=prompt_len * config.n_kv_groups, + cols=config.head_dim, + angle_rows=prompt_len, + context=self.context + ).compile().get_callable() + + # Attention projection operators + # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) + self.prefill.attn_query = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_heads * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context + ).compile().get_callable() + + # Key projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) + self.prefill.attn_key = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context + ).compile().get_callable() + + # Value projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) + self.prefill.attn_value = AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context + ).compile().get_callable() + + # Attention score computation: Q @ K^T per head + # For prefill: (seq_len, head_dim) @ (head_dim, seq_len) = (seq_len, seq_len) per head + self.prefill.attn_scores = AIEGEMM( + M=prompt_len, + K=config.head_dim, + N=prompt_len, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context + ).compile().get_callable() + + + # Fused prefill operator + # ########################################################################## + elf_ctx = AIEContext(build_dir="build_elf") # Fused operator for attention projections + RoPE (decode) @@ -319,6 +392,17 @@ def __init__(self, config, prompt_len): transfer_size=config.head_dim, context=elf_ctx ) + + gemv_out_head_op = AIEGEMV( + M=config.vocab_size, + K=config.emb_dim, + num_aie_columns=8, + tile_size_input=4, + tile_size_output=32, + context=self.context + ) + + # Create fused operator cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 @@ -326,6 +410,7 @@ def __init__(self, config, prompt_len): runlist = [] for layer_idx in range(config.n_layers): + # runlist.extend( [ (rms_norm_op, "x", f"W_norm1_{layer_idx}", "x_norm") # Step 1: RMS normalization @@ -354,7 +439,6 @@ def __init__(self, config, prompt_len): (gemv_attn_output_op, f"W_attn_output_decode_{layer_idx}", "attn_context", "attn_output") # ] + [ - # (residual_add_op, "x", "attn_output", "x"), (rms_norm_op, "x", f"W_norm2_{layer_idx}", "x_norm"), (gemv_ffn_up_gate_op, f"W_ffn_gate_{layer_idx}", "x_norm", "ffn_gate"), @@ -363,18 +447,24 @@ def __init__(self, config, prompt_len): (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), (gemv_ffn_down_op, f"W_ffn_down_{layer_idx}", "ffn_hidden", "ffn_output"), (residual_add_op, "x", "ffn_output", "x"), - # ] ) + # + runlist += [ + (rms_norm_op, "x", "W_final_norm", "x"), + (gemv_out_head_op, "W_out_head", "x", "logits") + ] self.decode.fused_op = FusedMLIROperator( "fused_op", runlist, input_args=[ # arguments that change between invocations of the fused kernel and therefore need to be synced on each token - "x", # both input and output + "x", "rope_angles", ], - output_args=[], + output_args=[ + "logits" + ], buffer_sizes={ **{ f"keys_cache_{layer_idx}": cache_buffer_size @@ -392,8 +482,10 @@ def __init__(self, config, prompt_len): context=elf_ctx ).compile() + # Operator patching + self.decode.fused_elf_data = load_elf(self.decode.fused_op) - + def get_patch_locs(elf_data, magic): magic = magic & 0xFFFFFFFF return np.where(elf_data == magic)[0] @@ -427,6 +519,8 @@ def get_patch_locs(elf_data, magic): elf_data=self.decode.fused_elf_data ) + # Operator static buffers (weights, LUTs) + for layer_idx in range(config.n_layers): self.decode.fused.get_buffer(f"W_norm1_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'].flatten() self.decode.fused.get_buffer(f"W_attn_query_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'].flatten() @@ -438,92 +532,12 @@ def get_patch_locs(elf_data, magic): self.decode.fused.get_buffer(f"W_ffn_up_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'].flatten() self.decode.fused.get_buffer(f"W_ffn_down_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'].flatten() scale_factor = 1.0 / math.sqrt(config.head_dim) - self.decode.fused.get_buffer(f"attn_scale_factor").to("cpu").view_as_torch()[:] = scale_factor + self.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = scale_factor + self.decode.fused.get_buffer("W_final_norm").to("cpu").view_as_torch()[:] = config.weights['model.norm.weight'].flatten() + self.decode.fused.get_buffer("W_out_head").to("cpu").view_as_torch()[:] = config.weights['model.embed_tokens.weight'].flatten() self.decode.fused.input_buffer.to("npu") self.decode.fused.scratch_buffer.to("npu") - self.decode.fused.output_buffer.to("npu") - - # Attention score scaling operators - # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY - self.prefill.attn_scale = AIEElementwiseMul( - size=config.n_heads * prompt_len * prompt_len, - tile_size=prompt_len, - num_aie_columns=8, - context=self.context - ).compile().get_callable() - - # RoPE operators - # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) - # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) - # angle_rows=1 because all rows use the same angle row (angles are per position) - self.prefill.rope_queries = AIERope( - rows=prompt_len * config.n_heads, - cols=config.head_dim, - angle_rows=prompt_len, - context=self.context - ).compile().get_callable() - - self.prefill.rope_keys = AIERope( - rows=prompt_len * config.n_kv_groups, - cols=config.head_dim, - angle_rows=prompt_len, - context=self.context - ).compile().get_callable() - - # Attention projection operators - # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) - self.prefill.attn_query = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_heads * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - - # Key projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) - self.prefill.attn_key = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_kv_groups * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - - # Value projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) - self.prefill.attn_value = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_kv_groups * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - - # Attention score computation: Q @ K^T per head - # For prefill: (seq_len, head_dim) @ (head_dim, seq_len) = (seq_len, seq_len) per head - self.prefill.attn_scores = AIEGEMM( - M=prompt_len, - K=config.head_dim, - N=prompt_len, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - + self.decode.fused.output_buffer.to("npu") # Allocate buffers shared with NPU @@ -924,37 +938,28 @@ def llama_forward_pass_decode(config, state): patch_fused_decode_operator(aie_ops.decode, config, state.num_preceding_tokens) - # Step 1: Prefill RoPE angle look-up tables + # Prefill RoPE angle look-up tables angles_slice = config.angles[state.num_preceding_tokens : state.num_preceding_tokens + seq_len] aie_ops.decode.fused.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = angles_slice #scale_factor = 1.0 / math.sqrt(config.head_dim) #aie_ops.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = scale_factor - # Step 2: Token embedding (on CPU) + # Token embedding (on CPU) tok_emb_weight = config.weights['model.embed_tokens.weight'] x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) aie_ops.decode.fused.get_buffer("x").view_as_torch().view(-1, config.emb_dim)[:seq_len, :] = x - # Step 3: Transformer blocks - transformer_blocks_forward_decode(config, state.num_preceding_tokens) - aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_final_norm, aie_buffers.decode.x) # Step 4: Final normalization - aie_ops.decode.gemv_out_head(aie_buffers.W_out_head, aie_buffers.decode.x, aie_buffers.decode.logits) # Step 5: Output projection + # Fused NPU operator for all of decode + aie_ops.decode.fused.input_buffer.to("cpu") + aie_ops.decode.fused() + aie_ops.decode.fused.output_buffer.to("cpu") # Read outputs from NPU to CPU - aie_buffers.decode.logits.to("cpu") - logits = aie_buffers.decode.logits.view_as_torch().view(1, 1, config.vocab_size) + logits = aie_ops.decode.fused.get_buffer("logits").view_as_torch().view(1, 1, config.vocab_size) return logits, state -def transformer_blocks_forward_decode(config, num_preceding_tokens): - fused_op = aie_ops.decode.fused - - fused_op.input_buffer.to("cpu") - fused_op() - aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:] - - # Main # ########################################################################## From 6211124f1dd611bd34e7a5e82a5b447bea5e3c4b Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 26 Jan 2026 13:52:42 -0700 Subject: [PATCH 66/99] cleanup --- applications/llama_3.2_1b/llama_npu.py | 61 ++++++++++++-------------- operators/__init__.py | 12 ++--- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index b46fb0db..2b865959 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -1,5 +1,13 @@ #!/usr/bin/env python3 +# Next steps for decode performance: +# [ ] All decode operators operate on 2048-padded buffers; instead, should bin into shorter sequence lengths and call smaller operators +# [ ] Opportunity to fuse data layout transformations (e.g., transpose ops) onto end of other operations (e.g., transpose after RoPE) +# [ ] Some kernels are not optimized; e.g., softmax masking is using scalar cores +# [ ] Fine-tune parameters of operators (e.g., num AIE columns, tile sizes) +# [ ] Patching of operators (instantiating new xrt::elf for each token) is slow; find quicker way of patching instruction sequence in-memory +# [ ] Spatial fusion of operators + import torch import math from pathlib import Path @@ -14,32 +22,33 @@ sys.path.insert(0, str(repo_root)) from operators.common.context import AIEContext -from operators.common import ( - AIEBuffer -) -from operators.common.utils import torch_to_numpy, numpy_to_torch +from operators.common import AIEBuffer +from operators.common.utils import torch_to_numpy from operators.common.base import PatchableSingleXclbinCallable from operators.common.fusion import FusedMLIROperator, FusedFullELFCallable, load_elf, patch_elf from operators import ( AIERMSNorm, AIEGEMM, AIEGEMV, - AIEElementwiseAdd + AIEElementwiseAdd, + AIEElementwiseMul, + AIESiLU, + AIERope, + AIEStridedCopy, + AIERepeat, + AIESoftmax, + AIETranspose ) -from operators.elementwise_mul.op import AIEElementwiseMul -from operators.silu.op import AIESiLU -from operators.rope.op import AIERope -from operators.strided_copy.op import AIEStridedCopy -from operators.repeat.op import AIERepeat -from operators.softmax.op import AIESoftmax -from operators.transpose.op import AIETranspose logging.basicConfig(level=logging.DEBUG) +max_seq_len = 2048 + # AIE Operator Configuration # ########################################################################## + aie_ops = None class AIEPrefillOperations: @@ -57,11 +66,9 @@ def __init__(self, config, prompt_len): self.prefill = AIEPrefillOperations() self.decode = AIEDecodeOperations() - - # ########################################################################## + # ################################################################## # Prefill operators - # RMS Norm self.prefill.rms_norm = AIERMSNorm( size=prompt_len * config.emb_dim, eps=1e-5, @@ -71,8 +78,6 @@ def __init__(self, config, prompt_len): context=self.context ).compile().get_callable() - - # Residual additions self.prefill.residual_add = AIEElementwiseAdd( size=prompt_len * config.emb_dim, tile_size=config.emb_dim @@ -82,7 +87,6 @@ def __init__(self, config, prompt_len): tile_size=config.emb_dim // 8 ).compile().get_callable() - # Final GEMM min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N config.vocab_partitions = 4 @@ -222,12 +226,11 @@ def __init__(self, config, prompt_len): ).compile().get_callable() - # Fused prefill operator - # ########################################################################## + # Decode operator (everything temporally fused) + # ################################################################## elf_ctx = AIEContext(build_dir="build_elf") - # Fused operator for attention projections + RoPE (decode) gemv_attn_query_op = AIEGEMV( M=config.n_heads * config.head_dim, K=config.emb_dim, @@ -490,7 +493,6 @@ def get_patch_locs(elf_data, magic): magic = magic & 0xFFFFFFFF return np.where(elf_data == magic)[0] - # Extract patch offsets for strided_copy operations in fused operator keys_patches = {} values_patches = {} for layer_idx in range(config.n_layers): @@ -599,16 +601,11 @@ def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_d # Attention weights buffer (output of softmax) self.attn_weights = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) -class AIEDecodeBuffers: - def __init__(self, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim, max_context_len): - self.x = AIEBuffer(shape=(1, emb_dim), dtype=ml_dtypes.bfloat16) - class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) - self.decode = AIEDecodeBuffers(config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim, prompt_len) # Per-layer KV cache buffers on NPU (used by strided copy for transpose and concatenate) self.keys_cache = [ @@ -685,7 +682,6 @@ def __init__(self, config, prompt_len): ) for i in range(config.vocab_partitions) ] - self.decode.logits = AIEBuffer(shape=(config.vocab_size,)) # Prefill @@ -935,26 +931,23 @@ def patch_fused_decode_operator(ops, config, num_preceding_tokens): def llama_forward_pass_decode(config, state): batch, seq_len = state.token_ids.shape assert seq_len == 1 + assert state.num_preceding_tokens < max_seq_len patch_fused_decode_operator(aie_ops.decode, config, state.num_preceding_tokens) # Prefill RoPE angle look-up tables angles_slice = config.angles[state.num_preceding_tokens : state.num_preceding_tokens + seq_len] aie_ops.decode.fused.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = angles_slice - #scale_factor = 1.0 / math.sqrt(config.head_dim) - #aie_ops.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = scale_factor # Token embedding (on CPU) tok_emb_weight = config.weights['model.embed_tokens.weight'] x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) aie_ops.decode.fused.get_buffer("x").view_as_torch().view(-1, config.emb_dim)[:seq_len, :] = x - # Fused NPU operator for all of decode + # Fused NPU operator for all of decode (16 transformer blocks + final norm + final linear layer) aie_ops.decode.fused.input_buffer.to("cpu") aie_ops.decode.fused() aie_ops.decode.fused.output_buffer.to("cpu") - - # Read outputs from NPU to CPU logits = aie_ops.decode.fused.get_buffer("logits").view_as_torch().view(1, 1, config.vocab_size) return logits, state @@ -963,6 +956,7 @@ def llama_forward_pass_decode(config, state): # Main # ########################################################################## + def llama_forward_pass( config, state @@ -987,7 +981,6 @@ def llama_forward_pass( def main(): global aie_ops, aie_buffers - max_seq_len = 2048 prompt = "The capital of France is " #with open('prompt.txt', 'r') as f: # prompt = f.read() diff --git a/operators/__init__.py b/operators/__init__.py index 7da4b542..2696cbf1 100644 --- a/operators/__init__.py +++ b/operators/__init__.py @@ -4,7 +4,7 @@ #from .axpy.op import AIEAXPY #from .dequant.op import AIEDequant from .elementwise_add.op import AIEElementwiseAdd -#from .elementwise_mul.op import AIEElementwiseMul +from .elementwise_mul.op import AIEElementwiseMul #from .gelu.op import AIEGELU from .gemm.op import AIEGEMM from .gemv.op import AIEGEMV @@ -14,11 +14,13 @@ #from .mha.op import AIEMHA #from .relu.op import AIEReLU from .rms_norm.op import AIERMSNorm -#from .rope.op import AIERope +from .rope.op import AIERope #from .sigmoid.op import AIESigmoid -#from .silu.op import AIESiLU -#from .softmax.op import AIESoftmax +from .silu.op import AIESiLU +from .softmax.op import AIESoftmax #from .swiglu_decode.op import AIESwiGLUDecode #from .swiglu_prefill.op import AIESwiGLUPrefill #from .tanh.op import AIETanh -#from .transpose.op import AIETranspose +from .transpose.op import AIETranspose +from .strided_copy.op import AIEStridedCopy +from .repeat.op import AIERepeat From 2da438dfec353a9458c1e025163489990b48c1fb Mon Sep 17 00:00:00 2001 From: andrej Date: Mon, 26 Jan 2026 13:54:17 -0700 Subject: [PATCH 67/99] remove old llama implementation --- applications/llama_3.2_1b/analyze_profile.py | 358 ------------- applications/llama_3.2_1b/autofuse.py | 212 -------- applications/llama_3.2_1b/bar_plot_profile.py | 150 ------ .../llama_3.2_1b/configs/llama32_1b.json | 39 -- .../configs/llama32_1b.json.license | 7 - applications/llama_3.2_1b/custom_profile.py | 70 --- applications/llama_3.2_1b/inference.py | 352 ------------ .../llama_3.2_1b/profile_path_analyzer.py | 348 ------------ .../llama_3.2_1b/src/block/feed_forward.py | 250 --------- applications/llama_3.2_1b/src/block/gqa.py | 505 ------------------ .../llama_3.2_1b/src/block/transformer.py | 195 ------- .../llama_3.2_1b/src/model_with_json.py | 309 ----------- applications/llama_3.2_1b/src/tokenizer.py | 101 ---- applications/llama_3.2_1b/src/utils.py | 307 ----------- applications/llama_3.2_1b/test.py | 51 -- applications/llama_3.2_1b/torch_to_npy.py | 49 -- .../llama_3.2_1b/visualize_profile.py | 437 --------------- 17 files changed, 3740 deletions(-) delete mode 100644 applications/llama_3.2_1b/analyze_profile.py delete mode 100755 applications/llama_3.2_1b/autofuse.py delete mode 100755 applications/llama_3.2_1b/bar_plot_profile.py delete mode 100644 applications/llama_3.2_1b/configs/llama32_1b.json delete mode 100644 applications/llama_3.2_1b/configs/llama32_1b.json.license delete mode 100644 applications/llama_3.2_1b/custom_profile.py delete mode 100755 applications/llama_3.2_1b/inference.py delete mode 100755 applications/llama_3.2_1b/profile_path_analyzer.py delete mode 100644 applications/llama_3.2_1b/src/block/feed_forward.py delete mode 100644 applications/llama_3.2_1b/src/block/gqa.py delete mode 100644 applications/llama_3.2_1b/src/block/transformer.py delete mode 100644 applications/llama_3.2_1b/src/model_with_json.py delete mode 100644 applications/llama_3.2_1b/src/tokenizer.py delete mode 100644 applications/llama_3.2_1b/src/utils.py delete mode 100644 applications/llama_3.2_1b/test.py delete mode 100644 applications/llama_3.2_1b/torch_to_npy.py delete mode 100644 applications/llama_3.2_1b/visualize_profile.py diff --git a/applications/llama_3.2_1b/analyze_profile.py b/applications/llama_3.2_1b/analyze_profile.py deleted file mode 100644 index 7e2c76f1..00000000 --- a/applications/llama_3.2_1b/analyze_profile.py +++ /dev/null @@ -1,358 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Analyze profiling logs generated by inference.py - -This script parses the profile logs and provides statistics about function execution times. -The total times reported in the analysis results are the cumulative times of the functions, including subcalls. -""" - -import argparse -import re -from collections import defaultdict -from pathlib import Path -import sys -import csv -import statistics -from collections import deque - - -class FunctionStats: - def __init__(self, name): - self.name = name - self.call_count = 0 - self.total_time = 0.0 - self.min_time = float("inf") - self.max_time = 0.0 - self.durations = [] - - def add_duration(self, duration): - self.call_count += 1 - self.total_time += duration - self.min_time = min(self.min_time, duration) - self.max_time = max(self.max_time, duration) - self.durations.append(duration) - - @property - def avg_time(self): - if not self.durations: - return 0.0 - return statistics.mean(self.durations) - - @property - def median_time(self): - if not self.durations: - return 0.0 - return statistics.median(self.durations) - - -def parse_profile_log(log_file): - """ - Parse a profile log file and extract function timing information. - - Args: - log_file: Path to the profile log file - - Returns: - dict: Dictionary mapping function names to FunctionStats objects - """ - stats = defaultdict(lambda: FunctionStats("")) - function_stack = deque() # Track ongoing calls by function identifier - - # Regex patterns for parsing log lines - call_pattern = re.compile(r"\[CALL\] (.+?) started at ([\d.]+)") - return_pattern = re.compile(r"\[RETURN\] (.+?) ended at ([\d.]+)") - - with open(log_file, "r") as f: - for line in f: - # Try to match CALL pattern - call_match = call_pattern.search(line) - if call_match: - func_id = call_match.group(1) - timestamp = float(call_match.group(2)) - function_stack.append((func_id, timestamp)) - continue - - # Try to match RETURN pattern - return_match = return_pattern.search(line) - if return_match: - func_id = return_match.group(1) - timestamp = float(return_match.group(2)) - - # Use the full function identifier (filepath:function_name:line_no) - if func_id not in stats: - stats[func_id].name = func_id - - if function_stack: - stats[func_id].add_duration(timestamp - function_stack.pop()[1]) - else: - raise RuntimeError( - f"Stack empty, found a log for the return but missing a log for the call of {func_id}" - ) - - return dict(stats) - - -def print_summary(stats, sort_by="total", top_n=20, min_calls=1): - """ - Print a summary of function statistics. - - Args: - stats: Dictionary of function statistics - sort_by: Sort criterion ('total', 'avg', 'max', 'calls') - top_n: Number of top functions to display - min_calls: Minimum number of calls to include in results - """ - # Filter by minimum calls - filtered_stats = {name: s for name, s in stats.items() if s.call_count >= min_calls} - - if not filtered_stats: - print("No functions found matching the criteria.") - return - - # Sort functions - sort_keys = { - "total": lambda x: x[1].total_time, - "avg": lambda x: x[1].avg_time, - "max": lambda x: x[1].max_time, - "calls": lambda x: x[1].call_count, - } - - if sort_by not in sort_keys: - print(f"Invalid sort criterion: {sort_by}. Using 'total'.") - sort_by = "total" - - sorted_stats = sorted(filtered_stats.items(), key=sort_keys[sort_by], reverse=True) - - # Print header - print("\n" + "=" * 160) - print(f"FUNCTION PROFILING SUMMARY (sorted by {sort_by}, top {top_n})") - print("=" * 160) - print( - f"{'Function Identifier':<80} {'Calls':>8} {'Total (s)':>12} {'Avg (s)':>12} {'Min (s)':>12} {'Max (s)':>12} {'Median (s)':>12}" - ) - print("-" * 160) - - # Print top N functions - for func_name, func_stats in sorted_stats[:top_n]: - # Truncate long function identifiers for display - display_name = func_name if len(func_name) <= 80 else func_name[:77] + "..." - print( - f"{display_name:<80} {func_stats.call_count:>8} " - f"{func_stats.total_time:>12.6f} {func_stats.avg_time:>12.6f} " - f"{func_stats.min_time:>12.6f} {func_stats.max_time:>12.6f} " - f"{func_stats.median_time:>12.6f}" - ) - - print("-" * 160) - - -def print_function_details(stats, function_name): - """ - Print detailed statistics for functions matching the given name. - - Args: - stats: Dictionary of function statistics - function_name: Name or substring to search for in function identifiers - """ - # Find all function identifiers containing the function_name string - matching_funcs = { - func_id: func_stats - for func_id, func_stats in stats.items() - if function_name in func_id - } - - if not matching_funcs: - print(f"No functions found containing '{function_name}' in profile data.") - print(f"\nAvailable functions (showing first 20):") - for i, name in enumerate(sorted(stats.keys())[:20]): - print(f" - {name}") - if len(stats) > 20: - print(f" ... and {len(stats) - 20} more") - return - - print("\n" + "=" * 120) - print(f"DETAILED STATISTICS FOR FUNCTIONS CONTAINING: '{function_name}'") - print(f"Found {len(matching_funcs)} matching function(s)") - print("=" * 120) - - for func_id, func_stats in sorted( - matching_funcs.items(), key=lambda x: x[1].total_time, reverse=True - ): - print(f"\nFunction: {func_id}") - print("-" * 120) - print(f" Total calls: {func_stats.call_count:,}") - print(f" Total time: {func_stats.total_time:.6f} seconds") - print(f" Average time: {func_stats.avg_time:.6f} seconds") - print(f" Median time: {func_stats.median_time:.6f} seconds") - print(f" Min time: {func_stats.min_time:.6f} seconds") - print(f" Max time: {func_stats.max_time:.6f} seconds") - - if func_stats.call_count > 1: - std_dev = statistics.stdev(func_stats.durations) - print(f" Std deviation: {std_dev:.6f} seconds") - - print("=" * 120 + "\n") - - -def export_to_csv(stats, output_file, sort_by="total", min_calls=1): - """ - Export function statistics to a CSV file. - - Args: - stats: Dictionary of function statistics - output_file: Path to output CSV file - sort_by: Sort criterion ('total', 'avg', 'max', 'calls') - min_calls: Minimum number of calls to include in results - """ - # Filter by minimum calls - filtered_stats = {name: s for name, s in stats.items() if s.call_count >= min_calls} - - if not filtered_stats: - print("No functions found matching the criteria.") - return - - # Sort functions - sort_keys = { - "total": lambda x: x[1].total_time, - "avg": lambda x: x[1].avg_time, - "max": lambda x: x[1].max_time, - "calls": lambda x: x[1].call_count, - } - - if sort_by not in sort_keys: - print(f"Invalid sort criterion: {sort_by}. Using 'total'.") - sort_by = "total" - - sorted_stats = sorted(filtered_stats.items(), key=sort_keys[sort_by], reverse=True) - - # Write to CSV - with open(output_file, "w", newline="") as csvfile: - fieldnames = [ - "function_name", - "call_count", - "total_time_seconds", - "avg_time_seconds", - "median_time_seconds", - "min_time_seconds", - "max_time_seconds", - "std_dev_seconds", - ] - - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - - for func_name, func_stats in sorted_stats: - std_dev = ( - statistics.stdev(func_stats.durations) - if func_stats.call_count > 1 - else 0.0 - ) - - writer.writerow( - { - "function_name": func_name, - "call_count": func_stats.call_count, - "total_time_seconds": f"{func_stats.total_time:.9f}", - "avg_time_seconds": f"{func_stats.avg_time:.9f}", - "median_time_seconds": f"{func_stats.median_time:.9f}", - "min_time_seconds": f"{func_stats.min_time:.9f}", - "max_time_seconds": f"{func_stats.max_time:.9f}", - "std_dev_seconds": f"{std_dev:.9f}", - } - ) - - print(f"\nCSV file saved to: {output_file}") - print(f"Total functions exported: {len(sorted_stats)}") - - -def main(): - parser = argparse.ArgumentParser( - description="Analyze profiling logs from inference.py", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Analyze the most recent profile log - python analyze_profile.py - - # Analyze a specific log file - python analyze_profile.py logs/profile_20250110_160000.log - - # Sort by average time and show top 30 - python analyze_profile.py --sort avg --top 30 - - # Show details for a specific function - python analyze_profile.py --function inference - - # Filter functions with at least 10 calls - python analyze_profile.py --min-calls 10 - - # Export to CSV file - python analyze_profile.py --csv profile_stats.csv - - # Export to CSV with custom sorting and filtering - python analyze_profile.py --csv results.csv --sort avg --min-calls 5 - """, - ) - - parser.add_argument( - "log_file", - type=str, - help="Path to profile log file", - ) - parser.add_argument( - "--sort", - choices=["total", "avg", "max", "calls"], - default="total", - help="Sort criterion (default: total)", - ) - parser.add_argument( - "--top", - type=int, - default=20, - help="Number of top functions to display (default: 20)", - ) - parser.add_argument( - "--min-calls", - type=int, - default=1, - help="Minimum number of calls to include (default: 1)", - ) - parser.add_argument( - "--function", type=str, help="Show detailed statistics for a specific function" - ) - parser.add_argument( - "--csv", - type=str, - help="Export results to CSV file instead of printing to console", - ) - - args = parser.parse_args() - - # Parse the log file - log_file = Path(args.log_file) - print(f"Parsing {log_file}...") - stats = parse_profile_log(log_file) - - if not stats: - print("No profiling data found in log file.") - else: - print(f"Found {len(stats)} unique functions") - - # Show results - if args.csv: - # Export to CSV - export_to_csv(stats, args.csv, sort_by=args.sort, min_calls=args.min_calls) - elif args.function: - # Show detailed function statistics - print_function_details(stats, args.function) - else: - # Print summary to console - print_summary( - stats, sort_by=args.sort, top_n=args.top, min_calls=args.min_calls - ) - - -if __name__ == "__main__": - main() diff --git a/applications/llama_3.2_1b/autofuse.py b/applications/llama_3.2_1b/autofuse.py deleted file mode 100755 index bddaf7a7..00000000 --- a/applications/llama_3.2_1b/autofuse.py +++ /dev/null @@ -1,212 +0,0 @@ -#!/usr/bin/env python3 - -import torch -import math -from pathlib import Path -import sys -import numpy as np -import ml_dtypes -import logging -import time -logging.basicConfig(level=logging.DEBUG) - -repo_root = Path(__file__).parent.parent.parent -sys.path.insert(0, str(repo_root)) - -from operators.common.context import AIEContext -from operators.common import AIEOperatorBase, AIEBuffer, SingleMLIRSourceOperator -from operators.common.utils import torch_to_numpy, numpy_to_torch -from operators.common.compilation import SourceArtifact, PythonGeneratedMLIRArtifact -from operators.common.fusion import FusedMLIROperator, FusedFullELFCallable -from operators import AIEGEMV -from operators.elementwise_mul.op import AIEElementwiseMul -from operators.silu.op import AIESiLU - - -emb_dim = 2048 -hidden_dim = 8192 - -# Operator definitions -# --- - -gemv_ffn_up_gate_op = AIEGEMV( - M=hidden_dim, - K=emb_dim, - num_aie_columns=4, - tile_size_input=4, - tile_size_output=hidden_dim // 8, -) - -gemv_ffn_down_op = AIEGEMV( - M=emb_dim, - K=hidden_dim, - num_aie_columns=4, - tile_size_input=1, - tile_size_output=emb_dim // 8, -) - -silu_ffn_op = AIESiLU( - size=hidden_dim, - tile_size=hidden_dim // 8, - num_aie_columns=4, -) - -eltwise_mul_ffn_op = AIEElementwiseMul( - size=hidden_dim, - tile_size=hidden_dim // 8, - num_aie_columns=4, -) - - -# Buffers -# --- - -x_norm = torch.randn(emb_dim, dtype=torch.bfloat16) -W_ffn_gate = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) -W_ffn_up = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) -W_ffn_down = torch.randn(emb_dim, hidden_dim, dtype=torch.bfloat16) - -def init_random(): - global x_norm, W_ffn_gate, W_ffn_up, W_ffn_down - x_norm = torch.randn(emb_dim, dtype=torch.bfloat16) - W_ffn_gate = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) - W_ffn_up = torch.randn(hidden_dim, emb_dim, dtype=torch.bfloat16) - W_ffn_down = torch.randn(emb_dim, hidden_dim, dtype=torch.bfloat16) - -buf_W_ffn_gate = AIEBuffer.from_torch(W_ffn_gate) -buf_W_ffn_up = AIEBuffer.from_torch(W_ffn_up) -buf_W_ffn_down = AIEBuffer.from_torch(W_ffn_down) - -buf_x_norm = AIEBuffer.from_torch(x_norm) -buf_ffn_gate = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) -buf_ffn_up = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) -buf_ffn_hidden = AIEBuffer.from_torch(torch.zeros(hidden_dim, dtype=torch.bfloat16)) -buf_ffn_output = AIEBuffer.from_torch(torch.zeros(emb_dim, dtype=torch.bfloat16)) - - -# Separate xclbins -# --- - -gemv_ffn_up_gate = None -gemv_ffn_down = None -silu_ffn = None -eltwise_mul_ffn = None - -def setup_separate_xclbins(): - global gemv_ffn_up_gate, gemv_ffn_down, silu_ffn, eltwise_mul_ffn - ctx = AIEContext(build_dir="build_separate") - gemv_ffn_up_gate_op.context = ctx - gemv_ffn_down_op.context = ctx - silu_ffn_op.context = ctx - eltwise_mul_ffn_op.context = ctx - gemv_ffn_up_gate = gemv_ffn_up_gate_op.compile().get_callable() - gemv_ffn_down = gemv_ffn_down_op.compile().get_callable() - silu_ffn = silu_ffn_op.compile().get_callable() - eltwise_mul_ffn = eltwise_mul_ffn_op.compile().get_callable() - -def run_separate_xclbins(): - gemv_ffn_up_gate(buf_W_ffn_gate, buf_x_norm, buf_ffn_gate) # Gate projection - gemv_ffn_up_gate(buf_W_ffn_up, buf_x_norm, buf_ffn_up) # Up projection - silu_ffn(buf_ffn_gate, buf_ffn_gate) # SiLU activation - eltwise_mul_ffn(buf_ffn_gate, buf_ffn_up, buf_ffn_hidden) # Gate application (eltwise mul) - gemv_ffn_down(buf_W_ffn_down, buf_ffn_hidden, buf_ffn_output) # Down projection - return buf_ffn_output.to("cpu").view_as_torch() - - -# Autofused -# --- - -def setup_autofused(): - ctx = AIEContext(build_dir="build_autofused") - gemv_ffn_up_gate_op.context = ctx - gemv_ffn_down_op.context = ctx - silu_ffn_op.context = ctx - eltwise_mul_ffn_op.context = ctx - global swiglu_fused_op, swiglu_fused - swiglu_fused_op = FusedMLIROperator( - "swiglu", - [ - (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"), - (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"), - (silu_ffn_op, "ffn_gate", "ffn_gate"), - (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), - (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"), - ], - input_args=[ - "x_norm", - "W_ffn_gate", - "W_ffn_up", - "W_ffn_down" - ], - output_args=[ - "ffn_output" - ], - ) - swiglu_fused_op.context = ctx - swiglu_fused_op = swiglu_fused_op.compile() - swiglu_fused = swiglu_fused_op.get_callable() - - #swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = x_norm.flatten() - #swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() - #swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = W_ffn_up.flatten() - #swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = W_ffn_down.flatten() - -def run_autofused(): - #swiglu_fused.output_buffer.view_as_torch()[:] = 0 - #swiglu_fused.scratch_buffer.view_as_torch()[:] = 0 - swiglu_fused.input_buffer.view_as_torch()[:] = 0 - swiglu_fused.get_buffer("x_norm").view_as_torch()[:] = x_norm.flatten() - swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() - swiglu_fused.get_buffer("W_ffn_gate").view_as_torch()[:] = W_ffn_gate.flatten() - swiglu_fused.get_buffer("W_ffn_up").view_as_torch()[:] = W_ffn_up.flatten() - swiglu_fused.get_buffer("W_ffn_down").view_as_torch()[:] = W_ffn_down.flatten() - swiglu_fused() - return swiglu_fused.get_buffer("ffn_output").to("cpu").view_as_torch() - -# CPU -# --- - -def run_cpu(): - ffn_gate = torch.matmul(W_ffn_gate, x_norm) - ffn_up = torch.matmul(W_ffn_up, x_norm) - ffn_gate = torch.nn.functional.silu(ffn_gate) - ffn_hidden = ffn_gate * ffn_up - ffn_output = torch.matmul(W_ffn_down, ffn_hidden) - return ffn_output - - -# Main -# --- - -iters=10 - -setup_autofused() -#setup_separate_xclbins() -for _ in range(iters): - init_random() - - t_autofused_start = time.time() - res_npu = run_autofused() - print("npu:") - print(res_npu) - t_autofused = time.time() - t_autofused_start - - #t_separate_start = time.time() - #for _ in range(iters): - # res_npu_s = run_separate_xclbins() - #t_separate = time.time() - t_separate_start - - t_cpu_start = time.time() - res_cpu = run_cpu() - print("cpu:") - print(res_cpu) - #assert(torch.allclose(res_npu[-1], res_cpu[-1], atol=0.7, rtol=0.07)) - t_cpu = time.time() - t_cpu_start - - #print(res_npu_s) - - -#print(f"Separate xclbins time: {t_separate/iters:.6f} seconds") -print(f"Autofused time: {t_autofused/iters:.6f} seconds") -print(f"CPU time: {t_cpu/iters:.6f} seconds") -assert(torch.allclose(res_npu[-1], res_cpu[-1], atol=0.7, rtol=0.07)) diff --git a/applications/llama_3.2_1b/bar_plot_profile.py b/applications/llama_3.2_1b/bar_plot_profile.py deleted file mode 100755 index 05f4ab82..00000000 --- a/applications/llama_3.2_1b/bar_plot_profile.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -""" -Generate a bar plot showing the top 15 most expensive functions by cumulative time. -""" -import json -import argparse -import matplotlib.pyplot as plt -from collections import defaultdict - -def load_profile_data(json_file): - """Load profile data from JSON file.""" - with open(json_file, 'r') as f: - return json.load(f) - -def extract_function_name(full_identifier): - """Extract just the function name from the full identifier.""" - # Remove parameters if present - if '(' in full_identifier: - full_identifier = full_identifier.split('(')[0] - - # Split by '/' to get path components - path_parts = full_identifier.split('/') - - # Get the last part which contains filename:line:function - last_part = path_parts[-1] - parts = last_part.split(':') - - if len(parts) >= 3: - # Format: filename:line:function - return parts[-1].strip() - elif len(parts) >= 2: - # Format: filename:function or similar - return parts[-1].strip() - else: - return full_identifier.strip() - -def aggregate_time_by_function(profile_data): - """ - Aggregate cumulative time for each function across all call sites. - - Args: - profile_data: Dict {func: [time, {children}]} - - Returns: - Dict mapping function name to total cumulative time - """ - time_by_function = defaultdict(float) - - def process_node(func_id, node_data): - """Recursively process nodes and accumulate time.""" - if not isinstance(node_data, list) or len(node_data) != 2: - return - - time, children = node_data - - # Extract function name and add time - func_name = extract_function_name(func_id) - time_by_function[func_name] += time - - # Recurse to children - for child_id, child_data in children.items(): - process_node(child_id, child_data) - - # Process all root functions - for func_id, func_data in profile_data.items(): - process_node(func_id, func_data) - - return time_by_function - -def create_bar_plot(time_by_function, output_file, top_n=15): - """ - Create a bar plot showing the top N most expensive functions. - - Args: - time_by_function: Dict mapping function name to cumulative time - output_file: Path to save the plot - top_n: Number of top functions to display - """ - # Sort by time and get top N - sorted_functions = sorted(time_by_function.items(), key=lambda x: x[1], reverse=True) - top_functions = sorted_functions[:top_n] - - # Prepare data for plotting - function_names = [func for func, _ in top_functions] - times = [time for _, time in top_functions] - - # Create the plot - fig, ax = plt.subplots(figsize=(12, 8)) - - # Create horizontal bars (easier to read function names) - bars = ax.barh(range(len(function_names)), times, color='steelblue') - - # Customize the plot - ax.set_yticks(range(len(function_names))) - ax.set_yticklabels(function_names) - ax.set_xlabel('Cumulative Time (seconds)', fontsize=12) - ax.set_ylabel('Function Name', fontsize=12) - ax.set_title(f'Top {top_n} Most Expensive Functions by Cumulative Time', fontsize=14, fontweight='bold') - - # Add value labels on the bars - for i, (bar, time) in enumerate(zip(bars, times)): - width = bar.get_width() - ax.text(width, bar.get_y() + bar.get_height()/2, - f' {time:.3f}s', - ha='left', va='center', fontsize=10) - - # Invert y-axis so the highest time is at the top - ax.invert_yaxis() - - # Add grid for better readability - ax.grid(axis='x', alpha=0.3, linestyle='--') - ax.set_axisbelow(True) - - # Tight layout - plt.tight_layout() - - # Save the plot - plt.savefig(output_file, dpi=150, bbox_inches='tight') - print(f"Bar plot saved to {output_file}") - - # Print summary statistics - total_time = sum(times) - print(f"\nTop {top_n} Functions Summary:") - print(f"Total cumulative time (top {top_n}): {total_time:.3f}s") - for i, (func_name, time) in enumerate(top_functions, 1): - print(f"{i:2d}. {func_name:40s} {time:8.3f}s") - -def main(): - parser = argparse.ArgumentParser( - description='Generate a bar plot of the top N most expensive functions by cumulative time' - ) - parser.add_argument('input', help='Input profile JSON file') - parser.add_argument('-o', '--output', default='bar_plot.png', - help='Output image file (default: bar_plot.png)') - parser.add_argument('-n', '--top-n', type=int, default=15, - help='Number of top functions to display (default: 15)') - - args = parser.parse_args() - - # Load profile data - profile_data = load_profile_data(args.input) - - # Aggregate time by function name - time_by_function = aggregate_time_by_function(profile_data) - - # Create the bar plot - create_bar_plot(time_by_function, args.output, args.top_n) - -if __name__ == '__main__': - main() diff --git a/applications/llama_3.2_1b/configs/llama32_1b.json b/applications/llama_3.2_1b/configs/llama32_1b.json deleted file mode 100644 index ed6bc4bf..00000000 --- a/applications/llama_3.2_1b/configs/llama32_1b.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "model_config": { - "vocab_size": 128256, - "context_length": 131072, - "emb_dim": 2048, - "n_heads": 32, - "n_layers": 16, - "hidden_dim": 8192, - "n_kv_groups": 8, - "use_kv_cache": true, - "rope_base": 500000.0, - "dtype": "bfloat16", - "use_aie_final_norm": true, - "use_aie_ffn_gemm": false, - "use_aie_ffn_silu": false, - "use_aie_ffn_mul": false, - "use_aie_ffn_swiglu": true, - "use_aie_ffn_gemv": true, - "use_aie_attn_projection_gemm": true, - "use_aie_gqa_gemv": true, - "use_aie_rope": true, - "use_aie_norm1": true, - "use_aie_norm2": true, - "use_aie_residual": true, - "use_aie_regular_mha": false, - "use_aie_fused_mha": true, - "use_aie_final_gemm": true, - "use_aie_final_gemv": true, - "rope_freq": { - "factor": 32.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_context_length": 8192 - } - }, - "aie_config": { - "device": "npu2" - } -} \ No newline at end of file diff --git a/applications/llama_3.2_1b/configs/llama32_1b.json.license b/applications/llama_3.2_1b/configs/llama32_1b.json.license deleted file mode 100644 index 50daea92..00000000 --- a/applications/llama_3.2_1b/configs/llama32_1b.json.license +++ /dev/null @@ -1,7 +0,0 @@ -Copyright (c) Sebastian Raschka under Apache License 2.0. -Source for "Build a Large Language Model From Scratch" - - https://www.manning.com/books/build-a-large-language-model-from-scratch -Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb - -SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -SPDX-License-Identifier: Apache-2.0 diff --git a/applications/llama_3.2_1b/custom_profile.py b/applications/llama_3.2_1b/custom_profile.py deleted file mode 100644 index 89ed4a51..00000000 --- a/applications/llama_3.2_1b/custom_profile.py +++ /dev/null @@ -1,70 +0,0 @@ -import sys, time, inspect, json - -# The current call stack; for each active function call, we store (func_identifier, start_time) -call_stack = [] - -# The cumulative time spent in each call stack path -# Map {function identifier: tuple (cumulative_time, {sub_call_function_identifier: cumulative_time, ...}) } -time_per_path = [0.0, {}] - - -def profile_call(frame, event, arg): - func_name = frame.f_code.co_name - filename = frame.f_code.co_filename - line_no = frame.f_lineno - func_identifier = f"{str(frame.f_code.co_filename)}:{frame.f_code.co_firstlineno}:{frame.f_code.co_name}" - - if event == "call": - log_call(func_identifier) - elif event == "return": - log_return(func_identifier) - - -def log_call(func_identifier): - global call_stack - if func_identifier.endswith(":log_call") or func_identifier.endswith(":log_return") or func_identifier.endswith(":__enter__") or func_identifier.endswith(":__exit__"): - return - timestamp = time.perf_counter() - call_stack.append((func_identifier, timestamp)) - - -def log_return(func_identifier): - global call_stack - if func_identifier.endswith(":log_call") or func_identifier.endswith(":log_return") or func_identifier.endswith(":__enter__") or func_identifier.endswith(":__exit__"): - return - if 0 == len(call_stack): - return - timestamp = time.perf_counter() - last_func_identifier, start_time = call_stack[-1] - if last_func_identifier != func_identifier: - print(call_stack) - raise RuntimeError(f"Function return mismatch: expected {last_func_identifier}, got {func_identifier}") - elapsed = timestamp - start_time - - this_path_time = time_per_path - for f, _ in call_stack: - this_path_time = this_path_time[1].setdefault(f, [0.0, {}]) - this_path_time[0] += elapsed - - call_stack.pop() - - -def enable_profiling(): - sys.setprofile(profile_call) - - -def store_profile(path): - sys.setprofile(None) - with open(path, "w") as f: - json.dump(time_per_path[1], f, indent=2) - - -class CustomProfileContext: - def __init__(self, label): - self.label = label - - def __enter__(self): - log_call(self.label) - - def __exit__(self, exc_type, exc_value, traceback): - log_return(self.label) diff --git a/applications/llama_3.2_1b/inference.py b/applications/llama_3.2_1b/inference.py deleted file mode 100755 index 63b2c18e..00000000 --- a/applications/llama_3.2_1b/inference.py +++ /dev/null @@ -1,352 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import sys -from pathlib import Path - -# Add IRON repository root to Python path -repo_root = Path(__file__).parent.parent.parent -sys.path.insert(0, str(repo_root)) - -import argparse -import time -import torch -from src.model_with_json import Llama3ModelWithJSONConfig - -# from src.model import Llama3Model -from src.tokenizer import Tokenizer, ChatFormat -from safetensors.torch import load_file -import os -import shutil -import logging -from collections import deque - -from operators.common import AIEOperatorBase -from src.utils import ( - model_memory_size, - load_weights_into_llama, - text_to_token_ids, - token_ids_to_text, - clean_text, - generate, -) - -import custom_profile - - -_iron_chat = r""" - /$$$$$$ /$$$$$$$ /$$$$$$ /$$ /$$ - |_ $$_/| $$__ $$ /$$__ $$| $$$ | $$ - | $$ | $$ \ $$| $$ \ $$| $$$$| $$ - | $$ | $$$$$$$/| $$ | $$| $$ $$ $$ - | $$ | $$__ $$| $$ | $$| $$ $$$$ - | $$ | $$ \ $$| $$ | $$| $$\ $$$ - /$$$$$$| $$ | $$| $$$$$$/| $$ \ $$ - |______/|__/ |__/ \______/ |__/ \__/ - - - /$$ /$$ /$$$$$$ /$$ /$$ /$$$$$$ -| $$ | $$ /$$__ $$| $$$ /$$$ /$$__ $$ -| $$ | $$ | $$ \ $$| $$$$ /$$$$| $$ \ $$ -| $$ | $$ | $$$$$$$$| $$ $$/$$ $$| $$$$$$$$ -| $$ | $$ | $$__ $$| $$ $$$| $$| $$__ $$ -| $$ | $$ | $$ | $$| $$\ $ | $$| $$ | $$ -| $$$$$$$$| $$$$$$$$| $$ | $$| $$ \/ | $$| $$ | $$ -|________/|________/|__/ |__/|__/ |__/|__/ |__/ -""" - - -def setup_logging(verbosity): - """Set up logging based on verbosity level.""" - - # Ensure the logs directory is created in case of profiling - logs_dir_name = "logs" - if not os.path.exists(logs_dir_name): - os.makedirs(logs_dir_name) - - if verbosity != 0: - levels = { - 4: logging.DEBUG, - 3: logging.INFO, - 2: logging.WARNING, - # 1: log everything (DEBUG) to a file - } - - # Create log file - timestamp = time.strftime("%Y%m%d_%H%M%S") - log_file = f"logs/inference_{timestamp}.log" - - handlers = [logging.FileHandler(log_file)] - if verbosity > 0: - handlers.append(logging.StreamHandler(sys.stderr)) - handlers[-1].setLevel(levels[verbosity]) - - # Configure root logger - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=handlers, - force=True, # Override any existing configuration - ) - - return logs_dir_name - - -def save_layer_data(module, input, output, name, input_data_path, output_data_path): - for count, i in enumerate(input): - torch.save( - i.detach(), - f"{input_data_path}/{name}_input_{count}_{input[0].size()[1]}_toks.pt", - ) - torch.save( - output.detach(), f"{output_data_path}/{name}_output_{output.size()[1]}_toks.pt" - ) - - -def inference( - weights_file_path, - tokenizer_file_path, - num_tokens, - prompt, - use_prompt_template, - save_outputs, - chat: bool, - prompt_len: int = 64, -): - """ - Main function to load a Llama3 model, process input, and generate output text. - """ - logging.info("Weights file path: %s", weights_file_path) - logging.info("Tokenizer file path: %s", tokenizer_file_path) - logging.info("Number of tokens: %d", num_tokens) - logging.debug("Prompt: %s", prompt) - logging.info("Use prompt template: %s", use_prompt_template) - logging.info("Save outputs: %s", save_outputs) - torch.manual_seed(1608560892) - input_data_path = "results/inputs" - output_data_path = "results/outputs" - - tokenizer = Tokenizer(tokenizer_file_path) - - print(_iron_chat) - if chat: - prompt = input("Enter your prompt: ").strip() - print("") - - logging.info(f"Loading model and tokenizer...") - token_ids = text_to_token_ids(prompt, tokenizer)[:, :prompt_len] - truncated_prompt = token_ids_to_text(token_ids, tokenizer) - - script_dir = os.path.dirname(os.path.abspath(__file__)) - config_path = os.path.join(script_dir, "configs", "llama32_1b.json") - model = Llama3ModelWithJSONConfig( - config_path=config_path, - prompt_length=prompt_len, - num_tokens=num_tokens, - ) - logging.info("Model and tokenizer loaded.") - - # Important: Set the seed again after initialization of the model. Each - # call that initializes an nn.Linear layer updates the RNG state, because - # weights are initialized with random values. For different JSON - # configurations, we initialize a different number of linear layers, - # so different configurations result in a different RNG state here. Since - # we use random numbers to sample from the token distribution during - # inference, it is important to have the same RNG state between runs so we - # can have reproducible results across configurations. - torch.manual_seed(1608560892) - - hook_handles = [] - if save_outputs: - if os.path.exists(output_data_path): - shutil.rmtree(output_data_path) - os.makedirs(output_data_path) - if os.path.exists(input_data_path): - shutil.rmtree(input_data_path) - os.makedirs(input_data_path) - for name, module in model.named_modules(): - handle = module.register_forward_hook( - lambda module, input, output, name=name, input_data_path=input_data_path, output_data_path=output_data_path: ( - save_layer_data( - module, input, output, name, input_data_path, output_data_path - ) - ) - ) - hook_handles.append(handle) - - device = torch.device("cpu") - chat_tokenizer = ChatFormat(tokenizer) - - combined_weights = load_file(weights_file_path) - # Get parameters from model config - model_config = { - "n_layers": model.cfg["n_layers"], - "emb_dim": model.cfg["emb_dim"], - "n_heads": model.cfg["n_heads"], - "n_kv_groups": model.cfg["n_kv_groups"], - "vocab_size": model.cfg["vocab_size"], - "context_length": model.cfg["context_length"], - "hidden_dim": model.cfg["hidden_dim"], - "rope_base": model.cfg["rope_base"], - "dtype": model.cfg["dtype"], - "rope_freq": model.cfg["rope_freq"], - } - load_weights_into_llama(model, model_config, combined_weights) - del combined_weights - - logging.info("Preparing AIE operators...") - # At this point the model is fully described (operators and their dimensions and how to compile them) - AIEOperatorBase.get_default_context().compile_all() - AIEOperatorBase.get_default_context().prepare_runtime() - logging.info("AIE operator preparation completed.") - print(f"Starting text generation...") - print(f"Generating {num_tokens} tokens...") - print("=" * 55) - - prefill_end_time = None - - def set_prefill_time(): - nonlocal prefill_end_time - prefill_end_time = time.time() - - # Start total wall clock timing - start = time.time() - token_ids = generate( - model=model, - idx=token_ids.to(device), - max_new_tokens=num_tokens, - context_size=model.cfg["context_length"], - eos_id=tokenizer.special["<|end_of_text|>"], - hook_handles=hook_handles, - temperature=0.7, - top_k=50, - tokenizer=tokenizer, - prompt=truncated_prompt, - prefill_done_callback=set_prefill_time, - ) - end = time.time() - prefill_time = prefill_end_time - start - total_time = end - start - post_prefill_time = end - prefill_end_time if num_tokens > 0 else 0 - - tokens_per_second = (num_tokens - 1) / post_prefill_time if num_tokens > 1 else 0 - time_per_token = total_time / (num_tokens - 1) if num_tokens > 1 else prefill_time - - print("=" * 55) - print(" TIMING RESULTS:") - print(f" Total time: {total_time:.4f} seconds") - print(f" Prefill time: {prefill_time:.4f} seconds") - print(f" Tokens generated: {num_tokens}") - print(f" Tokens per second: {tokens_per_second:.2f}") - print( - f" Time per token: {time_per_token:.4f} seconds" - if num_tokens > 0 - else " Time per token: N/A" - ) - print("=" * 55) - - logging.info(f"Generation time: {total_time:.4f} sec") - logging.info(f"Total wall clock time: {total_time:.4f} sec") - logging.info(f"Tokens per second: {tokens_per_second:.2f}") - logging.info( - f"Time per token: {time_per_token:.4f} sec" - if num_tokens > 0 - else "Time per token: N/A" - ) - - output_text = token_ids_to_text(token_ids, tokenizer) - logging.info("Output text:\n %s", clean_text(output_text)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run Llama3 model inference.") - parser.add_argument( - "weights_file_path", - type=str, - help="Path to the weights file: model.safetensors", - ) - parser.add_argument( - "tokenizer_file_path", - type=str, - help="Path to the tokenizer file: tokenizer.model", - ) - parser.add_argument( - "--num_tokens", type=int, default=1, help="Number of tokens to predict." - ) - parser.add_argument( - "--prompt", - type=str, - default="", - help="Prompt for the model to generate text from.", - ) - parser.add_argument( - "--use_prompt_template", - action="store_true", - help="Use a prompt template for the model.", - ) - parser.add_argument( - "--save_outputs", - action="store_true", - help="Enable hooks to save outputs of the layers in the model", - ) - parser.add_argument( - "--chat", - action="store_true", - help="Enable interactive mode to enter your own prompt.", - ) - parser.add_argument( - "--prompt_len", - type=int, - default=2048, - help="Truncate prompt to this many tokens.", - ) - parser.add_argument( - "--profile", - action="store_true", - help="Use a custom profiler for performance measurements", - ) - parser.add_argument( - "-v", - action="count", - default=0, - help="Increase verbosity level (use -v (logs to file), -vv, -vvv, or -vvvv)", - ) - args = parser.parse_args() - - # Set up logging - logs_dir_name = setup_logging(args.v) - - # Enable function profiling - if args.profile: - custom_profile.enable_profiling() - - try: - prompt = args.prompt - if not prompt: - # Default prompt is text from Shakespeare's King Lear: https://shakespeare.mit.edu/lear/lear.1.1.html - prompt_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "prompt.txt" - ) - with open(prompt_path, "r", encoding="utf-8") as file: - prompt = file.read().strip() - - inference( - args.weights_file_path, - args.tokenizer_file_path, - args.num_tokens, - prompt, - args.use_prompt_template, - args.save_outputs, - args.chat, - args.prompt_len, - ) - finally: - if args.profile: - custom_profile.store_profile(Path(logs_dir_name) / "profile.json") - diff --git a/applications/llama_3.2_1b/profile_path_analyzer.py b/applications/llama_3.2_1b/profile_path_analyzer.py deleted file mode 100755 index 638d938d..00000000 --- a/applications/llama_3.2_1b/profile_path_analyzer.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -""" -Profile Path Analyzer - -Analyzes profile.json to aggregate time spent in selected functions across call paths. -Supports fuzzy matching and "zooming into" specific call paths. -""" - -import json -import argparse -import sys -from typing import Dict, List, Tuple, Set -import matplotlib.pyplot as plt -import matplotlib - - -def parse_function_name(full_name: str) -> Tuple[str, str, str]: - parts = full_name.rsplit(':', 2) - if len(parts) == 3: - return parts[0], parts[1], parts[2] - return full_name, '', '' - - -def fuzzy_match(pattern: str, full_name: str) -> bool: - return pattern.lower() in full_name.lower() - - -def get_all_paths(profile: Dict, current_path: List[str] = None) -> List[Tuple[List[str], float]]: - """ - Extract all call paths from the profile with their total times. - - Args: - profile: Profile dictionary - current_path: Current call path being explored - - Returns: - List of (call_path, total_time) tuples - """ - if current_path is None: - current_path = [] - - paths = [] - - for func_name, data in profile.items(): - time_spent = data[0] - children = data[1] if len(data) > 1 else {} - - new_path = current_path + [func_name] - paths.append((new_path, time_spent)) - - if children: - child_paths = get_all_paths(children, new_path) - paths.extend(child_paths) - - return paths - - -def is_subpath(path: List[str], zoom_path: List[str]) -> bool: - """ - Check if path is a subpath of zoom_path. - - Args: - path: The path to check - zoom_path: The zoom-in path - - Returns: - True if path starts with zoom_path - """ - if len(path) < len(zoom_path): - return False - - for i, func in enumerate(zoom_path): - if not fuzzy_match(func, path[i]): - return False - - return True - - -def find_matching_functions(profile: Dict, patterns: List[str]) -> Dict: - """ - Find all function names in the profile that match any of the patterns. - If multiple patterns match, uses the most specific (longest) pattern. - - Args: - profile: Profile dictionary - patterns: List of search patterns - - Returns: - Dict mapping function names to their matched pattern - """ - matching = {} - - def traverse(data: Dict): - for func_name, value in data.items(): - # Find all matching patterns and use the longest (most specific) - matched_patterns = [p for p in patterns if fuzzy_match(p, func_name)] - if matched_patterns: - # Use the longest pattern as it's the most specific - most_specific = max(matched_patterns, key=len) - matching[func_name] = most_specific - - if len(value) > 1 and value[1]: - traverse(value[1]) - - traverse(profile) - return matching - - -def aggregate_times(profile: Dict, selected_funcs: Dict, zoom_path: List[str]) -> Dict[str, float]: - aggregated = {func: 0.0 for func in selected_funcs.values()} - - def traverse(data: Dict, current_path: List[str], counted_parent: str | None): - for func_name, (time_spent, subprofile) in data.items(): - new_path = current_path + [func_name] - # Start with parent's pattern, will be overridden if this function is selected - counted_this_iter = counted_parent - - # Check if this path is within the zoom scope - if not zoom_path or is_subpath(new_path, zoom_path): - # If this function is selected, add its time - if func_name in selected_funcs: - pattern = selected_funcs[func_name] - if counted_parent is None: - aggregated[pattern] += time_spent - elif counted_parent != pattern: - raise RuntimeError(f"Double-counting detected: pattern '{counted_parent}' calls '{pattern}', which would lead to double-counting '{pattern}'.") - counted_this_iter = pattern - - # Recurse into children - if subprofile: - traverse(subprofile, new_path, counted_this_iter) - - traverse(profile, [], None) - return aggregated - - -def calculate_total_time(profile: Dict, zoom_path: List[str]) -> float: - """ - Calculate total time for the zoomed-in path. - - Args: - profile: Profile dictionary - zoom_path: Zoom-in path (empty list for entire profile) - - Returns: - Total time in seconds - """ - if not zoom_path: - # No zoom - return total of all top-level functions - return sum(value[0] for value in profile.values()) - - # Find the zoomed function and return its time - def find_zoom_time(data: Dict, path_remaining: List[str]) -> float: - if not path_remaining: - return 0.0 - - pattern = path_remaining[0] - - for func_name, value in data.items(): - if fuzzy_match(pattern, func_name): - if len(path_remaining) == 1: - # Found the zoom target - return value[0] - else: - # Continue searching deeper - if len(value) > 1 and value[1]: - result = find_zoom_time(value[1], path_remaining[1:]) - if result > 0: - return result - - return 0.0 - - return find_zoom_time(profile, zoom_path) - - -def print_results(total_time: float, aggregated: Dict[str, float], plot_file: str = None): - """ - Print the analysis results. - - Args: - total_time: Total time in the zoomed scope - aggregated: Dictionary of function times - plot_file: Optional path to save a matplotlib bar plot - """ - print("\n" + "="*80) - print("PROFILE ANALYSIS RESULTS") - print("="*80) - - print(f"\nTotal time (zoomed scope): {total_time:.6f} seconds") - print("-"*80) - - # Sort by time (descending) - sorted_funcs = sorted(aggregated.items(), key=lambda x: x[1], reverse=True) - - selected_total = 0.0 - for func_name, time in sorted_funcs: - selected_total += time - percentage = (time / total_time * 100) if total_time > 0 else 0 - print(f"{func_name}") - print(f" Time: {time:.6f}s ({percentage:.2f}%)") - print() - - # Calculate "other" time - other_time = total_time - selected_total - other_percentage = (other_time / total_time * 100) if total_time > 0 else 0 - - print("-"*80) - print(f"Selected functions total: {selected_total:.6f}s ({selected_total/total_time*100:.2f}%)") - print(f"Other (unselected): {other_time:.6f}s ({other_percentage:.2f}%)") - print("="*80) - - # Create plot if requested - if plot_file: - create_bar_plot(total_time, sorted_funcs, other_time, plot_file) - - -def create_bar_plot(total_time: float, sorted_funcs: List[Tuple[str, float]], other_time: float, output_file: str): - """ - Create a bar plot showing time spent in each selected pattern and "other". - - Args: - total_time: Total time in the zoomed scope - sorted_funcs: List of (pattern, time) tuples sorted by time - other_time: Time spent in unselected functions - output_file: Path to save the plot - """ - # Prepare data - labels = [func for func, _ in sorted_funcs] + ["Other"] - times = [time for _, time in sorted_funcs] + [other_time] - percentages = [(time / total_time * 100) if total_time > 0 else 0 for time in times] - - # Create figure - fig, ax = plt.subplots(figsize=(10, 6)) - - # Create bars - bars = ax.bar(range(len(labels)), times, color='steelblue', alpha=0.8) - - # Color the "Other" bar differently - bars[-1].set_color('lightgray') - - # Customize plot - ax.set_xlabel('Pattern', fontsize=12, fontweight='bold') - ax.set_ylabel('Time (seconds)', fontsize=12, fontweight='bold') - ax.set_title(f'Profile Analysis - Total Time: {total_time:.3f}s', fontsize=14, fontweight='bold') - ax.set_xticks(range(len(labels))) - ax.set_xticklabels(labels, rotation=45, ha='right') - - # Add value labels on bars - for i, (bar, time, pct) in enumerate(zip(bars, times, percentages)): - height = bar.get_height() - ax.text(bar.get_x() + bar.get_width()/2., height, - f'{time:.3f}s\n({pct:.1f}%)', - ha='center', va='bottom', fontsize=9) - - # Add grid for readability - ax.grid(axis='y', alpha=0.3, linestyle='--') - ax.set_axisbelow(True) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save plot - plt.savefig(output_file, dpi=300, bbox_inches='tight') - print(f"\nPlot saved to: {output_file}") - plt.close() - - -def main(): - parser = argparse.ArgumentParser( - description='Analyze profile.json to aggregate time spent in selected functions.', - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Analyze all torch calls - %(prog)s logs/profile.json --select torch - - # Analyze multiple patterns - %(prog)s logs/profile.json --select torch numpy inference - - # Zoom into a specific call path - %(prog)s logs/profile.json --select torch --zoom inference forward - - # List all unique functions - %(prog)s logs/profile.json --list-functions - """ - ) - - parser.add_argument('profile', help='Path to profile.json file') - parser.add_argument('--select', '-s', nargs='+', metavar='PATTERN', - help='Function patterns to select (fuzzy match)') - parser.add_argument('--zoom', '-z', nargs='+', metavar='PATTERN', - help='Call path to zoom into (fuzzy match sequence)') - parser.add_argument('--list-functions', '-l', action='store_true', - help='List all unique function names in the profile') - parser.add_argument('--plot', '-p', metavar='FILE', - help='Save a bar plot to the specified file (e.g., plot.png)') - - args = parser.parse_args() - - # Load profile - try: - with open(args.profile, 'r') as f: - profile = json.load(f) - except FileNotFoundError: - print(f"Error: Profile file '{args.profile}' not found.", file=sys.stderr) - sys.exit(1) - except json.JSONDecodeError as e: - print(f"Error: Invalid JSON in profile file: {e}", file=sys.stderr) - sys.exit(1) - - - # Require --select for analysis - if not args.select: - print("Error: --select is required for analysis", file=sys.stderr) - parser.print_help() - sys.exit(1) - - # Find matching functions - selected_funcs = find_matching_functions(profile, args.select) - - if not selected_funcs: - print(f"Error: No functions matched the patterns: {args.select}", file=sys.stderr) - sys.exit(1) - - print(f"\nMatched {len(selected_funcs)} function(s) from patterns: {args.select}") - - # Prepare zoom path - zoom_path = args.zoom if args.zoom else [] - - if zoom_path: - print(f"Zooming into path: {' -> '.join(zoom_path)}") - - # Calculate total time - total_time = calculate_total_time(profile, zoom_path) - - if total_time == 0: - print(f"Error: Could not find zoom path or it has zero time.", file=sys.stderr) - sys.exit(1) - - # Aggregate times - aggregated = aggregate_times(profile, selected_funcs, zoom_path) - - # Print results - print_results(total_time, aggregated, args.plot) - - -if __name__ == '__main__': - main() diff --git a/applications/llama_3.2_1b/src/block/feed_forward.py b/applications/llama_3.2_1b/src/block/feed_forward.py deleted file mode 100644 index 85305ad3..00000000 --- a/applications/llama_3.2_1b/src/block/feed_forward.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.nn as nn -from ..utils import assign -from operators import ( - AIEElementwiseMul, - AIEGEMM, - AIEGEMV, - AIESiLU, - AIESwiGLUPrefill, - AIESwiGLUDecode, -) -from ml_dtypes import bfloat16 - - -class FeedForward(nn.Module): - def __init__( - self, - cfg, - prompt_length=0, - num_tokens=1, - ): - super().__init__() - self.cfg = cfg.copy() - - assert ( - cfg["use_aie_ffn_swiglu"] - and not ( - cfg["use_aie_ffn_silu"] - or cfg["use_aie_ffn_gemm"] - or cfg["use_aie_ffn_mul"] - ) - or not cfg["use_aie_ffn_swiglu"] - ), "Cannot mix fused SwiGLU with individual AIE operators." - - self.emb_dim = cfg["emb_dim"] - self.hidden_dim = cfg["hidden_dim"] - - # Initialize SiLU activation - if self.cfg["use_aie_ffn_silu"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.hidden_dim - else: - max_prefill_size = (prompt_length + num_tokens) * self.hidden_dim - self.aie_silu_prefill = AIESiLU( - size=max_prefill_size, - num_aie_columns=8, - num_channels=2, - tile_size=self.hidden_dim, - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.hidden_dim # 1 token * emb_dim - self.aie_silu_decode = AIESiLU( - size=decode_size, - num_aie_columns=1, - num_channels=1, - tile_size=self.hidden_dim, - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_silu_decode = self.silu_prefill - else: - self.silu = nn.SiLU() - - if self.cfg["use_aie_ffn_swiglu"]: - self.aie_swiglu_prefill = AIESwiGLUPrefill( - seq_len=prompt_length, - embedding_dim=self.emb_dim, - hidden_dim=self.hidden_dim, - ) - if self.cfg["use_kv_cache"]: - self.aie_swiglu_decode = AIESwiGLUDecode( - embedding_dim=self.emb_dim, hidden_dim=self.hidden_dim - ) - - if self.cfg["use_aie_ffn_gemm"]: - if self.cfg["use_kv_cache"]: - M_prefill = prompt_length - else: - M_prefill = prompt_length + num_tokens - - aie_config_prefill = { - "num_aie_columns": 8, - "tile_m": 64, - "tile_k": 64, - "tile_n": 64, - "use_static_weight": True, - } - - self.fc1 = AIEGEMM( - M=M_prefill, K=self.emb_dim, N=self.hidden_dim, **aie_config_prefill - ) - self.fc2 = AIEGEMM( - M=M_prefill, K=self.emb_dim, N=self.hidden_dim, **aie_config_prefill - ) - self.fc3 = AIEGEMM( - M=M_prefill, K=self.hidden_dim, N=self.emb_dim, **aie_config_prefill - ) - else: - self.fc1 = nn.Linear( - cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False - ) - self.fc2 = nn.Linear( - cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False - ) - self.fc3 = nn.Linear( - cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False - ) - - if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]: - aie_gemv_config = {"num_aie_columns": 8, "is_mv": False} - # FC1 and FC2: emb_dim -> hidden_dim - self.aie_fc1_gemv = AIEGEMV( - M=self.hidden_dim, - K=self.emb_dim, - tile_size_input=1, - tile_size_output=self.hidden_dim // 16, - **aie_gemv_config, - ) - self.aie_fc2_gemv = AIEGEMV( - M=self.hidden_dim, - K=self.emb_dim, - tile_size_input=1, - tile_size_output=self.hidden_dim // 16, - **aie_gemv_config, - ) - # FC3: hidden_dim -> emb_dim - self.aie_fc3_gemv = AIEGEMV( - M=self.emb_dim, - K=self.hidden_dim, - tile_size_input=1, - tile_size_output=self.emb_dim // 16, - **aie_gemv_config, - ) - - # Initialize AIE elementwise multiply - if self.cfg["use_aie_ffn_mul"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.hidden_dim - else: - max_prefill_size = (prompt_length + num_tokens) * self.hidden_dim - - self.aie_mul_prefill = AIEElementwiseMul( - size=max_prefill_size, - num_aie_columns=8, - num_channels=2, - tile_size=self.hidden_dim, - ) - - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.hidden_dim # 1 token * emb_dim - self.aie_mul_decode = AIEElementwiseMul( - size=decode_size, - num_aie_columns=1, - num_channels=2, - tile_size=self.hidden_dim, - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_mul_decode = self.aie_mul_prefill - - def forward(self, x): - original_shape = x.shape - - # Check if input is a vector (decode phase) or matrix (prefill phase) - # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim) - is_vector = ( - len(x.shape) == 1 - or (len(x.shape) == 2 and x.shape[0] == 1) - or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1) - ) - - is_prefill = not is_vector or not self.cfg["use_kv_cache"] - is_decode_with_kv = is_vector and self.cfg["use_kv_cache"] - - if self.cfg["use_aie_ffn_swiglu"]: - if is_prefill: - return self.aie_swiglu_prefill(x) - else: - return self.aie_swiglu_decode(x) - - if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]: - x_fc1 = self.aie_fc1_gemv(x) - x_fc2 = self.aie_fc2_gemv(x) - else: - x_fc1 = self.fc1(x) - x_fc2 = self.fc2(x) - - if self.cfg["use_aie_ffn_silu"]: - if is_decode_with_kv: - x_fc1_silu = self.aie_silu_decode(x_fc1) - else: - x_fc1_silu = self.aie_silu_prefill(x_fc1) - else: - x_fc1_silu = self.silu(x_fc1) - - if self.cfg["use_aie_ffn_mul"]: - if is_decode_with_kv: - x = self.aie_mul_decode(x_fc1_silu, x_fc2) - else: - x = self.aie_mul_prefill(x_fc1_silu, x_fc2) - else: - x = x_fc1_silu * x_fc2 - - if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]: - result = self.aie_fc3_gemv(x) - return result.view(original_shape) - else: - return self.fc3(x).view(original_shape) - - def assign_weights(self, l, fc1, fc2, fc3): - if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]: - self.aie_fc1_gemv.weight = fc1 - self.aie_fc2_gemv.weight = fc2 - self.aie_fc3_gemv.weight = fc3 - - if self.cfg["use_aie_ffn_swiglu"]: - self.aie_swiglu_prefill.weights_1 = fc1 - self.aie_swiglu_prefill.weights_2 = fc2 - self.aie_swiglu_prefill.weights_3 = fc3 - if self.cfg["use_kv_cache"]: - self.aie_swiglu_decode.weights_1 = fc1 - self.aie_swiglu_decode.weights_2 = fc2 - self.aie_swiglu_decode.weights_3 = fc3 - return - - self.fc1.weight = assign( - self.fc1.weight, - fc1, - f"model.layers.{l}.mlp.gate_proj.weight", - ) - self.fc2.weight = assign( - self.fc2.weight, - fc2, - f"model.layers.{l}.mlp.up_proj.weight", - ) - self.fc3.weight = assign( - self.fc3.weight, - fc3, - f"model.layers.{l}.mlp.down_proj.weight", - ) diff --git a/applications/llama_3.2_1b/src/block/gqa.py b/applications/llama_3.2_1b/src/block/gqa.py deleted file mode 100644 index 05566814..00000000 --- a/applications/llama_3.2_1b/src/block/gqa.py +++ /dev/null @@ -1,505 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch -import torch.nn as nn - -from operators import AIERope, AIESoftmax, AIEMHA, AIEGEMM, AIEGEMV -from operators.rope.rope_utils import apply_rope - -from torchtune.modules import KVCache - -from ..utils import assign - - -class GroupedQueryAttention(nn.Module): - def __init__( - self, - d_in, - d_out, - num_heads, - num_kv_groups, - prompt_length=0, - num_tokens=1, - dtype=None, - max_batch_size=1, - max_seq_len=8192, - cfg=None, - ): - super().__init__() - assert d_out % num_heads == 0, "d_out must be divisible by num_heads" - assert ( - num_heads % num_kv_groups == 0 - ), "num_heads must be divisible by num_kv_groups" - - self.cfg = cfg.copy() if cfg is not None else {} - - self.d_out = d_out - self.num_heads = num_heads - self.head_dim = d_out // num_heads - - self.num_tokens = num_tokens - - # Weights for Attention layer - self.W_key = nn.Linear( - d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype - ) - self.W_value = nn.Linear( - d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype - ) - self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) - self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) - - self.num_kv_groups = num_kv_groups - self.group_size = num_heads // num_kv_groups - - self.prompt_length = prompt_length - - aie_gemm_config = { - "num_aie_columns": 8, - "tile_m": 64, - "tile_k": 64, - "tile_n": 64, - "use_static_weight": False, - } - - # Initialize KV Cache - if self.cfg["use_kv_cache"]: - self.kv_cache = KVCache( - batch_size=max_batch_size, - max_seq_len=max_seq_len, - num_kv_heads=self.num_kv_groups, - head_dim=self.head_dim, - dtype=torch.bfloat16, - ) - - # Initialize AIE Regular MHA operator - if self.cfg["use_aie_regular_mha"]: - self.aie_softmax = AIESoftmax( - num_aie_columns=1, - num_channels=1, - rows=prompt_length, - cols=prompt_length, - ) - M_for_gemm = prompt_length + num_tokens - self.aie_mha_gemm_qk = AIEGEMM( - M=M_for_gemm, K=self.head_dim, N=M_for_gemm, **aie_gemm_config - ) - self.aie_mha_gemm_pv = AIEGEMM( - M=M_for_gemm, K=M_for_gemm, N=self.head_dim, **aie_gemm_config - ) - - # Initialize AIE RoPE operator - if self.cfg["use_aie_rope"]: - self.aie_rope_prefill_k = AIERope( - rows=self.prompt_length * self.num_kv_groups, - cols=self.head_dim, - angle_rows=self.prompt_length, - ) - self.aie_rope_prefill_q = AIERope( - rows=self.prompt_length * self.num_heads, - cols=self.head_dim, - angle_rows=self.prompt_length, - ) - self.aie_rope_decode_k = AIERope( - rows=self.num_kv_groups, - cols=self.head_dim, - angle_rows=1, - ) - self.aie_rope_decode_q = AIERope( - rows=self.num_heads, - cols=self.head_dim, - angle_rows=1, - ) - - # Initialize fused AIE MHA operator - if self.cfg["use_aie_fused_mha"]: - self.aie_mha = AIEMHA( - num_heads=num_heads, - seq_len=prompt_length, - d=self.head_dim, - num_KV_heads=0, # Regular MHA since we feed repeated K/V - num_of_pipelines=8, - ) - - # Initialize AIE GEMV operators for decode phase (when using KV cache) - if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]: - - aie_gemv_config = { - "num_aie_columns": 8, - "is_mv": False, - "use_static_weight": True, - } - self.aie_query_gemv = AIEGEMV( - M=d_out, - K=d_in, - tile_size_input=1, - tile_size_output=d_out // 16, - **aie_gemv_config, - ) - kv_out_dim = num_kv_groups * self.head_dim - self.aie_key_gemv = AIEGEMV( - M=kv_out_dim, - K=d_in, - tile_size_input=1, - tile_size_output=kv_out_dim // 16, - **aie_gemv_config, - ) - self.aie_value_gemv = AIEGEMV( - M=kv_out_dim, - K=d_in, - tile_size_input=1, - tile_size_output=kv_out_dim // 16, - **aie_gemv_config, - ) - self.aie_out_proj_gemv = AIEGEMV( - M=d_out, - K=d_out, - tile_size_input=1, - tile_size_output=d_out // 16, - **aie_gemv_config, - ) - - # Initialize AIE GEMM operators - if self.cfg["use_aie_attn_projection_gemm"]: - if self.cfg["use_kv_cache"]: - M_for_gemm = self.prompt_length - else: - M_for_gemm = self.prompt_length + self.num_tokens - - # GEMMs for projection use weights - aie_gemm_config["use_static_weight"] = True - # Query: (batch_size, d_in) @ (d_in, d_out) -> (batch_size, d_out) - self.aie_query = AIEGEMM(M=M_for_gemm, K=d_in, N=d_out, **aie_gemm_config) - # Key: (batch_size, d_in) @ (d_in, num_kv_groups * head_dim) -> (batch_size, num_kv_groups * head_dim) - kv_out_dim = num_kv_groups * self.head_dim - self.aie_key = AIEGEMM( - M=M_for_gemm, K=d_in, N=kv_out_dim, **aie_gemm_config - ) - # Value: same dimensions as key - self.aie_value = AIEGEMM( - M=M_for_gemm, K=d_in, N=kv_out_dim, **aie_gemm_config - ) - # Output projection: (batch_size, d_out) @ (d_out, d_out) -> (batch_size, d_out) - self.aie_out_proj = AIEGEMM( - M=M_for_gemm, K=d_out, N=d_out, **aie_gemm_config - ) - - def forward(self, x, mask, angles, input_pos=None): - b, num_tokens, d_in = x.shape - is_prefill = input_pos is None - is_decode = input_pos is not None - - # Step 1. - # --- - # Linear projections -- calculate quries, keys and values by multiplying embedding vector (in decode) or matrix (in prefill) with weight matrices - - # Choose between GEMM (prefill) and GEMV (decode) based on KV cache usage - if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]: - # Decode phase with KV cache - use GEMV for single token - # weight.T @ input, which is vector-matrix multiplication (So, is_mv=False) - x_flat = x.reshape(1, -1) # Shape: (1, d_in) - - queries_flat = self.aie_query_gemv(x_flat) - queries = queries_flat.reshape(b, num_tokens, self.d_out) - - keys_flat = self.aie_key_gemv(x_flat) - keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim) - - values_flat = self.aie_value_gemv(x_flat) - values = values_flat.reshape( - b, num_tokens, self.num_kv_groups * self.head_dim - ) - - elif self.cfg["use_aie_attn_projection_gemm"]: - # Prefill phase - use GEMM for multiple tokens - x_flat = x.reshape(-1, d_in) - input_dtype = x.dtype - - queries_flat = self.aie_query(x_flat) - queries = queries_flat.reshape(b, num_tokens, self.d_out) - - keys_flat = self.aie_key(x_flat) - keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim) - - values_flat = self.aie_value(x_flat) - values = values_flat.reshape( - b, num_tokens, self.num_kv_groups * self.head_dim - ) - else: - queries = self.W_query(x) - keys = self.W_key(x) - values = self.W_value(x) - - # Each attention head gets its own slice of the embedding dimension. - # For each head, we have query, key and value. - # In grouped-query attention, the keys and values are shared across groups of heads. - # Therefore, we have self.num_heads queries, and self.num_kv_groups (== self.num_heads in case of regular attention) keys and values. - # Each head can be applied independently to its subslice of the embedding dimension. - keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim) - values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim) - queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) - - # Step 2. - # --- - # Apply positional encoding to keys and queries. - # The positional embedding is applied independently to each head. - # It modifies the embedding vectors to encode where in the sequence each token is located. - - # Determine angle slice based on KV cache usage and phase - if self.cfg["use_kv_cache"] and is_decode: - # Decode phase with KV cache: use single position - current_pos = input_pos.item() - angle_slice = angles[current_pos : current_pos + 1, :] - else: - # Prefill phase or no KV cache: use all tokens - angle_slice = angles[:num_tokens, :] - - # Apply RoPE with AIE - def apply_rope_and_transpose(aie_op, tensor, num_heads_dim, angle_slice): - angle_slice = angle_slice.to(dtype=tensor.dtype) - if self.cfg["use_aie_rope"]: - result = aie_op( - tensor.view(num_tokens * num_heads_dim, self.head_dim), angle_slice - ) - result = result.view( - b, num_tokens, num_heads_dim, self.head_dim - ).transpose(1, 2) - else: - transposed = ( - tensor.view(num_tokens, num_heads_dim, self.head_dim) - .transpose(0, 1) - .contiguous() - ) - result = apply_rope( - transposed.view(1, num_heads_dim, num_tokens, self.head_dim), - angle_slice, - ) - # ref = apply_rope(transposed.view(1, num_heads_dim, num_tokens, self.head_dim), angle_slice) - # assert torch.allclose(ref, result, atol=0.7, rtol=0.07), "AIE RoPE result does not match reference" - return result - - keys = apply_rope_and_transpose( - ( - (self.aie_rope_prefill_k if is_prefill else self.aie_rope_decode_k) - if self.cfg["use_aie_rope"] - else None - ), - keys, - self.num_kv_groups, - angle_slice, - ) - queries = apply_rope_and_transpose( - ( - (self.aie_rope_prefill_q if is_prefill else self.aie_rope_decode_q) - if self.cfg["use_aie_rope"] - else None - ), - queries, - self.num_heads, - angle_slice, - ) - values = values.transpose(1, 2) - - if self.cfg["use_kv_cache"]: - if is_prefill: - self.kv_cache.reset() - self.kv_cache.update(keys, values) - cached_keys, cached_values = keys, values - else: - self.kv_cache.update(keys, values) - current_seq_len = input_pos.item() + 1 - cached_keys = self.kv_cache.k_cache[:, :, :current_seq_len, :] - cached_values = self.kv_cache.v_cache[:, :, :current_seq_len, :] - - keys = cached_keys - values = cached_values - - # Step 3. - # --- - # Since the keys and values are shared across groups of heads in grouped-query attention, - # we now expand (repeat) the same keys and values so that each head has its own keys and values. - keys = keys.repeat_interleave(self.group_size, dim=1) - values = values.repeat_interleave(self.group_size, dim=1) - - # Step 4. - # --- - # Compute attention scores (indepdentently for each head), apply softmax to get attention weights, then apply those weights to the attention values to get output. - # Attention scores are the dot-product of queries and keys. - - # Use fused AIE MHA if enabled and conditions are met - if is_prefill or not self.cfg["use_kv_cache"]: - if ( - self.cfg["use_aie_fused_mha"] - and b == 1 - and num_tokens == self.prompt_length - and self.head_dim == 64 - ): - # TODO: Doesn't give good output ven with num_kv_groups set to 8 with kv_cache - # TODO: Doesn't match the output of CPU only when used without kv_cache - context_vec = self.aie_mha( - queries, keys, values - ) # Shape: (num_heads, num_tokens, head_dim) - - # Reshape context_vec to prepare for output projection - context_vec = context_vec.transpose(0, 1) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - - elif self.cfg["use_aie_regular_mha"]: - # attn_scores = queries @ keys.transpose(2, 3) - # Compute attention scores for each head separately since AIE GEMM doesn't support batched operations - attn_scores_list = [] - for head in range(self.num_heads): - q_head = queries[:, head, :, :] # Shape: (b, num_tokens, head_dim) - k_head = keys[:, head, :, :] # Shape: (b, num_tokens, head_dim) - - # Use 2D tensors directly (remove batch dimension if b=1) - q_2d = q_head.squeeze(0) # Shape: (num_tokens, head_dim) - k_2d = k_head.squeeze(0) # Shape: (num_tokens, head_dim) - - # Compute Q @ K^T for this head - attn_head = self.aie_mha_gemm_qk( - q_2d, k_2d.T - ) # Shape: (num_tokens, num_tokens) - attn_head = attn_head.unsqueeze(0).unsqueeze( - 0 - ) # Add batch and head dimensions - attn_scores_list.append( - attn_head - ) # Shape: (1, 1, num_tokens, num_tokens) - - attn_scores = torch.cat( - attn_scores_list, dim=1 - ) # Shape: (b, num_heads, num_tokens, num_tokens) - attn_scores = attn_scores.masked_fill(mask, -torch.inf) - scaled_scores = attn_scores / (self.head_dim**0.5) - - # TODO: Make softmax more configurable to run in any scenario - if ( - scaled_scores.shape[-1] == self.prompt_length - and scaled_scores.shape[-1] % 16 == 0 - ): - attn_weights = self.aie_softmax(scaled_scores) - else: - attn_weights = torch.nn.functional.softmax(scaled_scores, dim=-1) - - # Compute context vector for each head separately using AIE GEMM - context_vec_list = [] - for head in range(self.num_heads): - attn_head = attn_weights[ - :, head, :, : - ] # Shape: (b, num_tokens, num_tokens) - v_head = values[:, head, :, :] # Shape: (b, num_tokens, head_dim) - - # Use 2D tensors directly (remove batch dimension if b=1) - attn_2d = attn_head.squeeze(0) # Shape: (num_tokens, num_tokens) - v_2d = v_head.squeeze(0) # Shape: (num_tokens, head_dim) - - # Compute attn @ V for this head - context_head = self.aie_mha_gemm_pv( - attn_2d, v_2d - ) # Shape: (num_tokens, head_dim) - context_head = context_head.unsqueeze(0).unsqueeze( - 1 - ) # Add batch and head dimensions - context_vec_list.append( - context_head - ) # Shape: (1, 1, num_tokens, head_dim) - - context_vec = torch.cat( - context_vec_list, dim=1 - ) # Shape: (b, num_heads, num_tokens, head_dim) - context_vec = context_vec.transpose( - 1, 2 - ) # Shape: (b, num_tokens, num_heads, head_dim) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - else: - - def my_mha(queries, keys, values): - inv_scale = 1 / np.sqrt(values.shape[-1]) - context_vec = torch.nn.functional.scaled_dot_product_attention( - queries, - keys, - values, - dropout_p=0.0, - is_causal=True, - scale=inv_scale, - ) - return context_vec - - context_vec = my_mha(queries, keys, values) - context_vec = context_vec.transpose(1, 2) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - else: - attn_scores = queries @ keys.transpose(2, 3) - - if mask is not None: - attn_scores = attn_scores.masked_fill(mask, -torch.inf) - - scaled_scores = attn_scores / (self.head_dim**0.5) - - if ( - scaled_scores.shape[-1] == self.prompt_length - and self.cfg["use_aie_softmax"] - and scaled_scores.shape[-1] % 16 == 0 - ): - attn_weights = self.aie_softmax(scaled_scores) - else: - attn_weights = torch.nn.functional.softmax(scaled_scores, dim=-1) - - context_vec = (attn_weights @ values).transpose(1, 2) - context_vec = context_vec.reshape(b, num_tokens, self.d_out) - - # Choose output projection based on phase - if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]: - context_vec_flat = context_vec.reshape(1, -1) - output_flat = self.aie_out_proj_gemv(context_vec_flat) - context_vec = output_flat.reshape(b, num_tokens, self.d_out) - elif self.cfg["use_aie_attn_projection_gemm"]: - context_vec_flat = context_vec.reshape(-1, self.d_out) - output_flat = self.aie_out_proj(context_vec_flat) - context_vec = output_flat.reshape(b, num_tokens, self.d_out) - else: - context_vec = self.out_proj(context_vec) - - return context_vec - - def assign_weights(self, l, w_query, w_key, w_value, w_out_proj): - if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]: - self.aie_query_gemv.weight = w_query - self.aie_key_gemv.weight = w_key - self.aie_value_gemv.weight = w_value - self.aie_out_proj_gemv.weight = w_out_proj - - if self.cfg["use_aie_attn_projection_gemm"]: - self.aie_query.weight = w_query - self.aie_key.weight = w_key - self.aie_value.weight = w_value - self.aie_out_proj.weight = w_out_proj - - self.W_query.weight = assign( - self.W_query.weight, - w_query, - f"model.layers.{l}.self_attn.q_proj.weight", - ) - self.W_key.weight = assign( - self.W_key.weight, - w_key, - f"model.layers.{l}.self_attn.k_proj.weight", - ) - self.W_value.weight = assign( - self.W_value.weight, - w_value, - f"model.layers.{l}.self_attn.v_proj.weight", - ) - self.out_proj.weight = assign( - self.out_proj.weight, - w_out_proj, - f"model.layers.{l}.self_attn.o_proj.weight", - ) diff --git a/applications/llama_3.2_1b/src/block/transformer.py b/applications/llama_3.2_1b/src/block/transformer.py deleted file mode 100644 index 42a48146..00000000 --- a/applications/llama_3.2_1b/src/block/transformer.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.nn as nn -from ..utils import assign -from src.block.gqa import GroupedQueryAttention -from src.block.feed_forward import FeedForward -from operators import AIERMSNorm, AIEElementwiseAdd - - -class TransformerBlock(nn.Module): - def __init__( - self, - cfg, - prompt_length=42, - num_tokens=1, - ): - super().__init__() - self.cfg = cfg.copy() - - self.att = GroupedQueryAttention( - d_in=cfg["emb_dim"], - d_out=cfg["emb_dim"], - num_heads=cfg["n_heads"], - num_kv_groups=cfg["n_kv_groups"], - dtype=cfg["dtype"], - prompt_length=prompt_length, - cfg=cfg, - ) - self.ff = FeedForward( - cfg, - prompt_length=prompt_length, - num_tokens=num_tokens, - ) - - if self.cfg["use_aie_norm1"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] - self.aie_norm1_prefill = AIERMSNorm( - size=max_prefill_size, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.cfg["emb_dim"] # 1 token * emb_dim - self.aie_norm1_decode = AIERMSNorm( - size=decode_size, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_norm1_decode = self.aie_norm1_prefill - else: - self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) - - if self.cfg["use_aie_norm2"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] - self.aie_norm2_prefill = AIERMSNorm( - size=max_prefill_size, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.cfg["emb_dim"] # 1 token * emb_dim - self.aie_norm2_decode = AIERMSNorm( - size=decode_size, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_norm2_decode = self.aie_norm2_prefill - else: - self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) - - if self.cfg["use_aie_residual"]: - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * cfg["emb_dim"] - - self.aie_residual_add_prefill = AIEElementwiseAdd( - size=max_prefill_size, - num_aie_columns=8, - num_channels=2, - tile_size=cfg["emb_dim"], - ) - - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = cfg["emb_dim"] # 1 token * emb_dim - self.aie_residual_add_decode = AIEElementwiseAdd( - size=decode_size, - num_aie_columns=1, - num_channels=2, - tile_size=cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_residual_add_decode = self.aie_residual_add_prefill - - def forward(self, x, mask, angles, input_pos): - original_shape = x.shape - - # (batch, sequence, embedding) where sequence=1 indicates decode - if len(x.shape) == 3: - is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"] - elif len(x.shape) == 2: - is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"] - else: - is_decode_with_kv = False - - shortcut = x - if self.cfg["use_aie_norm1"]: - if is_decode_with_kv: - x = self.aie_norm1_decode(x) - else: - x = self.aie_norm1_prefill(x) - else: - x = self.norm1(x) - - x = self.att(x, mask, angles, input_pos) - - if self.cfg["use_aie_residual"]: - if is_decode_with_kv: - x = self.aie_residual_add_decode(x, shortcut) - else: - x = self.aie_residual_add_prefill(x, shortcut) - else: - x = x + shortcut - - # Shortcut connection for feed-forward block - shortcut = x - if self.cfg["use_aie_norm2"]: - if is_decode_with_kv: - x = self.aie_norm2_decode(x) - else: - x = self.aie_norm2_prefill(x) - else: - x = self.norm2(x) - x = self.ff(x) - - if self.cfg["use_aie_residual"]: - if is_decode_with_kv: - x = self.aie_residual_add_decode(x, shortcut) - else: - x = self.aie_residual_add_prefill(x, shortcut) - else: - x = x + shortcut - - return x - - def assign_weights(self, l, norm1, norm2): - if self.cfg["use_aie_norm1"]: - self.aie_norm1_prefill.weight = norm1 - if self.cfg["use_kv_cache"]: - self.aie_norm1_decode.weight = norm1 - if self.cfg["use_aie_norm2"]: - self.aie_norm2_prefill.weight = norm2 - if self.cfg["use_kv_cache"]: - self.aie_norm2_decode.weight = norm2 - return - - self.norm1.weight = assign( - self.norm1.weight, - norm1, - f"model.layers.{l}.input_layernorm.weight", - ) - self.norm2.weight = assign( - self.norm2.weight, - norm2, - f"model.layers.{l}.post_attention_layernorm.weight", - ) diff --git a/applications/llama_3.2_1b/src/model_with_json.py b/applications/llama_3.2_1b/src/model_with_json.py deleted file mode 100644 index 9984e0a3..00000000 --- a/applications/llama_3.2_1b/src/model_with_json.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.nn as nn -import json -from pathlib import Path -from src.block.transformer import TransformerBlock -from operators.rope.rope_utils import compute_rope_params -from operators import AIERMSNorm, AIEGEMM, AIEGEMV -from rich.console import Console -from rich.text import Text - -from .utils import assign - - -def dtype_from_string(inp): - if isinstance(inp, torch.dtype): - return inp - return {"bfloat16": torch.bfloat16, "float16": torch.float16}.get( - inp, torch.float32 - ) - - -# fmt: off -# Configuration flag key -> (type function, default value, description) -config_options = { - "dtype": (dtype_from_string, torch.float32, "Data type"), - "use_kv_cache": (bool, False, "[Model] KV Cache"), - "use_aie_rope": (bool, False, "[Attention] Rope"), - "use_aie_attn_projection_gemm": (bool, False, "[Attention] QKV GEMM"), - "use_aie_regular_mha": (bool, False, "[Attention] Regular MHA"), - "use_aie_fused_mha": (bool, False, "[Attention] Fused MHA"), - "use_aie_gqa_gemv": (bool, False, "[Attention] GEMV (Decode)"), - "use_aie_ffn_gemm": (bool, False, "[FFN] GEMM"), - "use_aie_ffn_mul": (bool, False, "[FFN] Elementwise Mul"), - "use_aie_ffn_silu": (bool, False, "[FFN] SiLU"), - "use_aie_ffn_swiglu": (bool, False, "[FFN] Runlist-based SwiGLU"), - "use_aie_ffn_gemv": (bool, False, "[FFN] GEMV (Decode)"), - "use_aie_residual": (bool, False, "[Transformer] Residual Addition"), - "use_aie_norm1": (bool, False, "[Transformer] Pre Norm"), - "use_aie_norm2": (bool, False, "[Transformer] Post Norm"), - "use_aie_final_norm": (bool, False, "[Transformer] Final Norm"), - "use_aie_final_gemm": (bool, False, "[Transformer] Final GEMM"), - "use_aie_final_gemv": (bool, False, "[Transformer] Final GEMV"), -} -# fmt: on - - -def load_llama_config(config_path=None): - """Load Llama configuration from JSON file""" - if config_path is None: - # Default to config.json in the llama directory - config_path = Path(__file__).parent.parent / "llama32_1b.json" - - with open(config_path, "r") as f: - config = json.load(f) - - model_config = config["model_config"].copy() - for key, (type_fn, default_value, description) in config_options.items(): - if key in model_config: - model_config[key] = type_fn(model_config[key]) - else: - model_config[key] = default_value - - return model_config - - -def print_config(cfg, console=Console()): - def format_option(name, value): - if isinstance(value, bool): - checkmark = "[green]✔[/green]" if value else "[red]✘[/red]" - return f"{name} {checkmark}" - return f"{name}: {value}" - - dont_print = {"dtype"} - # The following options are mutually exclusive, e.g. regular and fused MHA - # cannot be enabled at the same time. But it looks bad to have red Xs, - # indicating things are running on the CPU when they are not. So, we only - # print one of these mutually exclusive options. - if cfg["use_aie_fused_mha"]: - dont_print |= {"use_aie_regular_mha"} - else: - dont_print |= {"use_aie_fused_mha"} - if cfg["use_aie_ffn_swiglu"]: - dont_print |= { - "use_aie_ffn_gemm", - "use_aie_ffn_mul", - "use_aie_ffn_silu", - } - else: - dont_print |= {"use_aie_ffn_swiglu"} - - console.print( - "AIE Configuration ([green]✔[/green] = AIE NPU / [red]✘[/red] = CPU):", - style="bold underline", - ) - for option_key, (option_ty, option_default, option_name) in config_options.items(): - if option_key in dont_print: - continue - console.print(format_option(option_name, cfg.get(option_key, option_default))) - console.print("") - - -class Llama3ModelWithJSONConfig(nn.Module): - """Llama3 model that loads configuration from JSON file""" - - def __init__( - self, - config_path=None, - prompt_length=0, - num_tokens=1, - ): - super().__init__() - - # Load configuration from JSON - self.cfg = load_llama_config(config_path) - self.prompt_length = prompt_length - self.num_tokens = num_tokens - print_config(self.cfg) - - # Main model parameters - self.tok_emb = nn.Embedding( - self.cfg["vocab_size"], self.cfg["emb_dim"], dtype=self.cfg["dtype"] - ) - - self.trf_blocks = nn.ModuleList( - [ - TransformerBlock( - self.cfg, - prompt_length=prompt_length, - num_tokens=num_tokens, - ) - for i in range(self.cfg["n_layers"]) - ] - ) - - # Create final norm - either AIE or PyTorch - if self.cfg.get("use_aie_final_norm", False): - if self.cfg["use_kv_cache"]: - max_prefill_size = prompt_length * self.cfg["emb_dim"] - else: - max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] - self.aie_final_norm_prefill = AIERMSNorm( - size=max_prefill_size, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - # For decode phase - single token (only when using KV cache) - if self.cfg["use_kv_cache"]: - decode_size = self.cfg["emb_dim"] # 1 token * emb_dim - self.aie_final_norm_decode = AIERMSNorm( - size=decode_size, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=self.cfg["emb_dim"], - ) - else: - # When not using KV cache, use same operator for both phases - self.aie_final_norm_decode = self.aie_final_norm_prefill - else: - self.final_norm = nn.RMSNorm( - self.cfg["emb_dim"], eps=1e-5, dtype=self.cfg["dtype"] - ) - - # Offload final linear layer if enabled - if self.cfg.get("use_aie_final_gemm", False): - # Since this GEMM has such a large N dimension, partition the N dimension by 4, - # and GEMM will execute for a workload of that smaller N dimension across different buffers of B and C - aie_config_prefill = { - "num_aie_columns": 8, - "tile_m": 64, - "tile_k": 64, - "tile_n": 64, - "b_col_maj": True, - "use_static_weight": True, - "separate_c_tiles": True, - "partition_N": 4, - } - if self.cfg["use_kv_cache"]: - M_for_gemm = self.prompt_length - else: - M_for_gemm = self.prompt_length + self.num_tokens - self.out_head_prefill = AIEGEMM( - M=M_for_gemm, - K=self.cfg["emb_dim"], - N=self.cfg["vocab_size"], - **aie_config_prefill, - ) - aie_gemv_config = { - "num_aie_columns": 8, - "is_mv": True, - "use_static_weight": True, - "num_aie_columns": 8, - "tile_size_input": 4, - "tile_size_output": 32, - } - # FC1 and FC2: emb_dim -> hidden_dim - if self.cfg["use_aie_final_gemv"]: - self.out_head_decode = AIEGEMV( - M=self.cfg["vocab_size"], K=self.cfg["emb_dim"], **aie_gemv_config - ) - else: - self.out_head = nn.Linear( - self.cfg["emb_dim"], - self.cfg["vocab_size"], - bias=False, - dtype=self.cfg["dtype"], - ) - - # Reusable utilities - cos, sin = compute_rope_params( - head_dim=self.cfg["emb_dim"] // self.cfg["n_heads"], - theta_base=self.cfg["rope_base"], - context_length=self.cfg["context_length"], - freq_config=self.cfg["rope_freq"], - ) - angles = torch.cat([torch.empty_like(cos), torch.empty_like(cos)], dim=1) - angles[:, ::2] = cos - angles[:, 1::2] = sin - self.register_buffer("angles", angles, persistent=False) - - def forward(self, in_idx, input_pos=None, use_kv_cache=False): - # Forward pass - tok_embeds = self.tok_emb(in_idx) - x = tok_embeds - - # Check if input is a vector (decode phase) or matrix (prefill phase) - # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim) - is_vector = ( - len(x.shape) == 1 - or (len(x.shape) == 2 and x.shape[0] == 1) - or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1) - ) - - # (batch, sequence, embedding) where sequence=1 indicates decode - if len(x.shape) == 3: - is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"] - elif len(x.shape) == 2: - is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"] - else: - is_decode_with_kv = False - - num_tokens = x.shape[1] - - # During generation phase with KV cache, don't create a mask - # The attention layer will handle masking based on position - if use_kv_cache and input_pos is not None: - mask = None - else: - # During prefill, create standard causal mask - mask = torch.triu( - torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), - diagonal=1, - ) - - for block in self.trf_blocks: - x = block(x, mask, self.angles, input_pos) - - # Sequence length of 1 from input shape means we're in the decode stage, which can use KV cache - if self.cfg.get("use_aie_final_norm", False): - if (x.shape[-2] == 1) and self.cfg.get("use_kv_cache", False): - x = self.aie_final_norm_decode(x) - else: - x = self.aie_final_norm_prefill(x) - else: - x = self.final_norm(x) - - if self.cfg["use_aie_final_gemm"]: - if is_decode_with_kv and self.cfg["use_aie_final_gemv"]: - logits = self.out_head_decode(x) - else: - logits = self.out_head_prefill(x) - else: - logits = self.out_head(x) - - return logits - - def assign_weights(self, final_norm, out_head, out_head_name): - if self.cfg.get("use_aie_final_norm", False): - self.aie_final_norm_prefill.weight = final_norm - if self.cfg["use_kv_cache"]: - self.aie_final_norm_decode.weight = final_norm - else: - self.final_norm.weight = assign( - self.final_norm.weight, - final_norm, - f"model.norm.weight", - ) - - if self.cfg["use_aie_final_gemm"]: - # Want column-major for B - self.out_head_prefill.weight = out_head.T - if self.cfg["use_aie_final_gemv"]: - self.out_head_decode.weight = out_head.T - else: - self.out_head.weight = assign( - self.out_head.weight, - out_head, - out_head_name, - ) diff --git a/applications/llama_3.2_1b/src/tokenizer.py b/applications/llama_3.2_1b/src/tokenizer.py deleted file mode 100644 index 1a16cf57..00000000 --- a/applications/llama_3.2_1b/src/tokenizer.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import os -from pathlib import Path - -import tiktoken -from tiktoken.load import load_tiktoken_bpe - - -class Tokenizer: - """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.""" - - def __init__(self, model_path): - if not os.path.isfile(model_path): - raise FileNotFoundError(model_path) - - mergeable = load_tiktoken_bpe(model_path) - - # hard-coded from Meta's tokenizer.json - self.special = { - "<|begin_of_text|>": 128000, - "<|end_of_text|>": 128001, - "<|start_header_id|>": 128006, - "<|end_header_id|>": 128007, - "<|eot_id|>": 128009, - } - self.special.update( - { - f"<|reserved_{i}|>": 128002 + i - for i in range(256) - if 128002 + i not in self.special.values() - } - ) - - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" - r"|[^\r\n\p{L}\p{N}]?\p{L}+" - r"|\p{N}{1,3}" - r"| ?[^\s\p{L}\p{N}]+[\r\n]*" - r"|\s*[\r\n]+" - r"|\s+(?!\S)" - r"|\s+", - mergeable_ranks=mergeable, - special_tokens=self.special, - ) - - def encode(self, text, bos=False, eos=False): - ids = ([self.special["<|begin_of_text|>"]] if bos else []) + self.model.encode( - text - ) - if eos: - ids.append(self.special["<|end_of_text|>"]) - return ids - - def decode(self, ids): - return self.model.decode(ids) - - -class ChatFormat: - - def __init__( - self, tokenizer: Tokenizer, *, default_system="You are a helpful assistant." - ): - self.tok = tokenizer - self.default_system = default_system - - def _header(self, role): - """Encode <|start_header_id|>role<|end_header_id|>\n\n""" - return ( - [self.tok.special["<|start_header_id|>"]] - + self.tok.encode(role) - + [self.tok.special["<|end_header_id|>"]] - + self.tok.encode("\n\n") - ) - - def encode(self, user_message, system_message=None): - sys_msg = system_message if system_message is not None else self.default_system - - ids = [self.tok.special["<|begin_of_text|>"]] - - # system - ids += self._header("system") - ids += self.tok.encode(sys_msg) - ids += [self.tok.special["<|eot_id|>"]] - - # user - ids += self._header("user") - ids += self.tok.encode(user_message) - ids += [self.tok.special["<|eot_id|>"]] - - # assistant header (no content yet) - ids += self._header("assistant") - - return ids diff --git a/applications/llama_3.2_1b/src/utils.py b/applications/llama_3.2_1b/src/utils.py deleted file mode 100644 index 158b59df..00000000 --- a/applications/llama_3.2_1b/src/utils.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Sebastian Raschka under Apache License 2.0. -# Source for "Build a Large Language Model From Scratch" -# - https://www.manning.com/books/build-a-large-language-model-from-scratch -# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb -# -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import time -import torch -import numpy as np -from ml_dtypes import bfloat16 - - -def model_memory_size(model, input_dtype=torch.float32): - """ - Calculate the estimated memory size of a PyTorch model in gigabytes. - - This function computes the total memory required for the model's parameters, - gradients, and buffers based on the input data type. - - Args: - model (torch.nn.Module): The PyTorch model for which to calculate memory size. - input_dtype (torch.dtype, optional): The data type of the model's input. - Defaults to torch.float32. - - Returns: - float: The estimated memory size of the model in gigabytes. - """ - - total_params = 0 - total_grads = 0 - for param in model.parameters(): - # Calculate total number of elements per parameter - param_size = param.numel() - total_params += param_size - # Check if gradients are stored for this parameter - if param.requires_grad: - total_grads += param_size - - # Calculate buffer size (non-parameters that require memory) - total_buffers = sum(buf.numel() for buf in model.buffers()) - - # Size in bytes = (Number of elements) * (Size of each element in bytes) - # We assume parameters and gradients are stored in the same type as input dtype - element_size = torch.tensor(0, dtype=input_dtype).element_size() - total_memory_bytes = (total_params + total_grads + total_buffers) * element_size - - # Convert bytes to gigabytes - total_memory_gb = total_memory_bytes / (1024**3) - - return total_memory_gb - - -def assign(left, right, tensor_name="unknown"): - """ - Assigns the value of the right tensor to a new torch.nn.Parameter after validating shape compatibility. - - Parameters: - left (torch.Tensor or any): The tensor to compare shape with. - right (torch.Tensor or any): The tensor or value to be assigned. - tensor_name (str): The name of the tensor for error reporting (default is "unknown"). - - Returns: - torch.nn.Parameter: A new parameter containing the value of right. - - Raises: - ValueError: If the shapes of left and right do not match. - """ - - if left.shape != right.shape: - raise ValueError( - f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}" - ) - - if isinstance(right, torch.Tensor): - return torch.nn.Parameter(right.clone().detach()) - else: - return torch.nn.Parameter(torch.tensor(right)) - - -def load_weights_into_llama(model, param_config, params): - """ - Load weights into the LLaMA model from the provided parameters. - - This function assigns weights from the given parameters to the corresponding - layers of the LLaMA model. It handles the embedding layer, attention layers, - feedforward layers, and the output layer. The function also checks for weight - tying between the output head and the embedding layer. - - Args: - model: The LLaMA model instance into which weights will be loaded. - param_config (dict): A configuration dictionary containing model parameters, - including the number of layers (`n_layers`). - params (dict): A dictionary containing the weights to be loaded, with keys - corresponding to the model's architecture. - """ - model.tok_emb.weight = assign( - model.tok_emb.weight, - params["model.embed_tokens.weight"], - "model.embed_tokens.weight", - ) - - for l in range(param_config["n_layers"]): - - # Load attention weights - model.trf_blocks[l].att.assign_weights( - l, - params[f"model.layers.{l}.self_attn.q_proj.weight"], - params[f"model.layers.{l}.self_attn.k_proj.weight"], - params[f"model.layers.{l}.self_attn.v_proj.weight"], - params[f"model.layers.{l}.self_attn.o_proj.weight"], - ) - # Load FeedForward weights - model.trf_blocks[l].ff.assign_weights( - l, - fc1=params[f"model.layers.{l}.mlp.gate_proj.weight"], - fc2=params[f"model.layers.{l}.mlp.up_proj.weight"], - fc3=params[f"model.layers.{l}.mlp.down_proj.weight"], - ) - # Load RMS norm weights - model.trf_blocks[l].assign_weights( - l, - params[f"model.layers.{l}.input_layernorm.weight"], - params[f"model.layers.{l}.post_attention_layernorm.weight"], - ) - - # Load output layer weights - if "lm_head.weight" in params.keys(): - model.assign_weights( - params["model.norm.weight"], params["lm_head.weight"], "lm_head.weight" - ) - else: - model.assign_weights( - params["model.norm.weight"], - params["model.embed_tokens.weight"], - "model.embed_tokens.weight", - ) - - -def text_to_token_ids(text, tokenizer): - """ - Convert a given text into token IDs using the specified tokenizer. - - Args: - text (str): The input text to be tokenized. - tokenizer: An instance of a tokenizer that has an `encode` method. - - Returns: - torch.Tensor: A tensor containing the token IDs of the input text, - with an added batch dimension. - """ - encoded = tokenizer.encode(text) - encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension - return encoded_tensor - - -def token_ids_to_text(token_ids, tokenizer): - """ - Convert a tensor of token IDs to a human-readable text string. - - Args: - token_ids (torch.Tensor): A tensor containing token IDs, - typically with a batch dimension. - tokenizer (Tokenizer): An instance of a tokenizer that has a - decode method to convert token IDs to text. - - Returns: - str: The decoded text string corresponding to the input token IDs. - """ - flat = token_ids.squeeze(0) # remove batch dimension - return tokenizer.decode(flat.tolist()) - - -def generate( - model, - idx, - max_new_tokens, - context_size, - eos_id, - hook_handles, - temperature=0.0, - top_k=None, - tokenizer=None, - prompt=None, - do_print=True, - prefill_done_callback=None, -): - """ - Generate new tokens using the provided model based on the input sequence. - - Args: - model: The model used for generating tokens. It should accept input sequences and return logits. - idx (torch.Tensor): The input sequence of token indices (shape: (batch_size, sequence_length)). - max_new_tokens (int): The maximum number of new tokens to generate. - context_size (int): The number of tokens from the input sequence to consider for generation. - temperature (float, optional): The temperature for scaling logits. Higher values result in more random outputs. Default is 0.0 (no scaling). - top_k (int, optional): The number of top logits to consider for sampling. If None, all logits are used. Default is None. - eos_id (int, optional): The end-of-sequence token ID. If specified, generation will stop when this token is produced. Default is None. - - Returns: - torch.Tensor: The updated sequence of token indices after generation (shape: (batch_size, new_sequence_length)). - """ - # For-loop is the same as before: Get logits, and only focus on last time step - finished_prefill = False - - print(f"Starting prefill inference...") - - for i in range(max_new_tokens): - use_kv_cache = model.cfg["use_kv_cache"] - - if use_kv_cache: - if i == 0: - # Prefill phase - process entire sequence - idx_cond = idx[:, -context_size:] - input_pos = None - else: - # Generation phase with KV cache - single token, need to track position - # Extract only the last token - idx_cond = idx[:, -1:] - input_pos = torch.tensor([idx.shape[1] - 1], device=idx.device) - else: - # No KV cache - always process entire sequence (GEMM every time) - idx_cond = idx[:, -context_size:] - input_pos = None - with torch.no_grad(): - logits = model(idx_cond, input_pos=input_pos, use_kv_cache=use_kv_cache) - logits = logits[:, -1, :] - - # New: Filter logits with top_k sampling - if top_k is not None: - # Keep only top_k values - top_logits, _ = torch.topk(logits, top_k) - min_val = top_logits[:, -1] - logits = torch.where( - logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits - ) - - # New: Apply temperature scaling - if temperature > 0.0: - logits = logits / temperature - - # Apply softmax to get probabilities - probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) - - # Sample from the distribution - idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) - - # Otherwise same as before: get idx of the vocab entry with the highest logits value - else: - idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1) - - # Only run the forward hook for the prefill stage, remove it afterwards to speed up inference - if not finished_prefill: - if hook_handles: - for handle in hook_handles: - handle.remove() - finished_prefill = True - - if ( - idx_next == eos_id - ): # Stop generating early if end-of-sequence token is encountered and eos_id is specified - break - - # Same as before: append sampled index to the running sequence - idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1) - - # End timing the first iteration - if i == 0: - if prefill_done_callback is not None: - prefill_done_callback() - if do_print: - print(prompt) - - # print(f"\rGenerating token {i + 1}/{max_new_tokens}...", end="") - generated_text = token_ids_to_text(idx_next, tokenizer) - if do_print: - print(f"{generated_text}", end="", flush=True) - - print("\n\n") - return idx - - -def clean_text(text, header_end="assistant<|end_header_id|>\n\n"): - """ - Cleans the input text by removing the header portion defined by the header_end token. - - Parameters: - text (str): The input text to be cleaned. - header_end (str): The token that marks the end of the header. Defaults to "assistant<|end_header_id|>\n\n". - - Returns: - str: The cleaned text, which is the substring after the header_end token. - If the token is not found, the original text is returned. - """ - - # Find the index of the first occurrence of "<|end_header_id|>" - index = text.find(header_end) - - if index != -1: - # Return the substring starting after "<|end_header_id|>" - return text[ - index + len(header_end) : - ].strip() # Strip removes leading/trailing whitespace - else: - # If the token is not found, return the original text - return text diff --git a/applications/llama_3.2_1b/test.py b/applications/llama_3.2_1b/test.py deleted file mode 100644 index 933b7d5e..00000000 --- a/applications/llama_3.2_1b/test.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import subprocess -import pytest -from pathlib import Path - -test_dir = Path(__file__).parent -weights_dir = Path("/srv") - - -def generate_test_params(): - prompt_lengths = [2048, 13] - num_tokens_list = [40, 1] - - params = [] - names = [] - for prompt_len in prompt_lengths: - for num_tokens in num_tokens_list: - params.append((prompt_len, num_tokens)) - names.append(f"llama_3.2_1b_prompt_{prompt_len}_tokens_{num_tokens}") - return params, names - - -params, names = generate_test_params() - - -@pytest.mark.metrics( - TTFT=r"Prefill time: (?P[\d\.e\+-]+) seconds", - TPS=r"Tokens per second: (?P[\d\.e\+-]+)", - Num_Tokens=r"Tokens generated: (?P[\d\.e\+-]+)", -) -@pytest.mark.parametrize("prompt_len,num_tokens", params, ids=names) -def test_llama_3_2_1b(prompt_len, num_tokens): - command = f"python3 {test_dir}/inference.py {weights_dir}/llama3.2-1b/model.safetensors {weights_dir}/llama3.2-1b/tokenizer.model --prompt_len {prompt_len} --num_tokens {num_tokens}" - - result = subprocess.run( - command, - cwd=test_dir, - shell=True, - capture_output=True, - text=True, - timeout=300, - ) - - assert ( - result.returncode == 0 - ), f"Command failed with return code {result.returncode}\nStderr: {result.stderr}" - - print(result.stdout) diff --git a/applications/llama_3.2_1b/torch_to_npy.py b/applications/llama_3.2_1b/torch_to_npy.py deleted file mode 100644 index e7d06be0..00000000 --- a/applications/llama_3.2_1b/torch_to_npy.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import argparse -import numpy as np -import os -import shutil - - -def torch_to_npy(inp_file_path, outp_file_path): - # Load the torch file - data = torch.load(inp_file_path) - # Convert the tensor to a numpy array of floats - data_np = data.to(torch.float32).numpy() - # Compare the values between data and data_np - if not torch.equal(data, torch.from_numpy(data_np)): - raise ValueError("Mismatch between original data and converted numpy array.") - - # Save the array to a npy file - np.save(outp_file_path, data_np) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert torch files to npy files.") - parser.add_argument( - "file_path", - type=str, - help="Path to the torch file or directory containing torch files", - ) - args = parser.parse_args() - file_path = args.file_path - - output_dir = os.path.join("results", f"{os.path.basename(file_path)}_npy") - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - os.makedirs(output_dir) - - # Check if the file path is a directory - if os.path.isdir(file_path): - for file_name in os.listdir(file_path): - if file_name.endswith(".pt") or file_name.endswith(".pth"): - full_path = os.path.join(file_path, file_name) - output_file_path = os.path.join( - output_dir, file_name.replace(".pt", ".npy").replace(".pth", ".npy") - ) - torch_to_npy(full_path, output_file_path) - else: - torch_to_npy(file_path) diff --git a/applications/llama_3.2_1b/visualize_profile.py b/applications/llama_3.2_1b/visualize_profile.py deleted file mode 100644 index 442d7115..00000000 --- a/applications/llama_3.2_1b/visualize_profile.py +++ /dev/null @@ -1,437 +0,0 @@ -import json -import argparse -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches -from collections import defaultdict - -# Try to import seaborn, fall back to matplotlib if not available -try: - import seaborn as sns - HAS_SEABORN = True -except ImportError: - HAS_SEABORN = False - -def load_profile_data(json_file): - """Load profile data from JSON file.""" - with open(json_file, 'r') as f: - return json.load(f) - -def extract_function_info(full_name): - """Extract function name, filename, path components, and full path from the identifier.""" - # Remove parameters if present - if '(' in full_name: - full_name = full_name.split('(')[0] - - # Split by '/' to get path components - path_parts = full_name.split('/') - - # Get the last part which contains filename:line:function - last_part = path_parts[-1] - parts = last_part.split(':') - - if len(parts) >= 3: - # Format: filename:line:function - filename = parts[0].strip() - func_name = parts[-1].strip() - elif len(parts) >= 2: - # Format: filename:function or similar - filename = parts[0].strip() - func_name = parts[-1].strip() - else: - filename = "" - func_name = full_name.strip() - - # Store directory components (excluding the last part which is filename:line:func) - dir_parts = path_parts[:-1] if len(path_parts) > 1 else [] - - return { - 'func_name': func_name, - 'filename': filename, - 'dir_parts': dir_parts, - 'full_path': full_name - } - -def collect_unique_identifiers(profile_data, time_threshold_pct=1.0): - """ - Collect all unique function identifiers from the profile data. - - Args: - profile_data: Dict {func: [time, {children}]} - time_threshold_pct: Minimum percentage to consider - - Returns: - Set of unique function identifiers - """ - # Calculate total time - total_time = sum(child[0] for child in profile_data.values() if isinstance(child, list)) - if total_time == 0: - return set() - - threshold = (time_threshold_pct / 100.0) * total_time - identifiers = set() - - def collect_from_node(func_id, node_data): - """Recursively collect identifiers.""" - if not isinstance(node_data, list) or len(node_data) != 2: - return - - time, children = node_data - - # Add this identifier (regardless of threshold, we want all unique functions) - identifiers.add(func_id) - - # Recurse to children - for child_id, child_data in children.items(): - collect_from_node(child_id, child_data) - - # Process all root functions - for func_id, func_data in profile_data.items(): - collect_from_node(func_id, func_data) - - return identifiers - -def build_disambiguation_map(identifiers): - """ - Build a map from full identifier to minimal disambiguated name. - - Args: - identifiers: Set of unique function identifiers - - Returns: - Dict mapping full_identifier -> disambiguated_name - """ - from collections import Counter, defaultdict - - # Extract info for all identifiers - full_info = {} - func_name_groups = defaultdict(list) - - for full_id in identifiers: - info = extract_function_info(full_id) - full_info[full_id] = info - func_name_groups[info['func_name']].append(full_id) - - result = {} - - # Process each group of same-named functions - for func_name, id_list in func_name_groups.items(): - if len(id_list) == 1: - # Unique function name, use as-is - result[id_list[0]] = func_name - else: - # Multiple functions with same name, need disambiguation - # Try progressively longer path suffixes until we find something unique - max_dirs = max(len(full_info[full_id]['dir_parts']) for full_id in id_list) - - disambiguated = False - for num_dirs in range(0, max_dirs + 1): - # Build candidates with this many directory components - candidates = {} - for full_id in id_list: - info = full_info[full_id] - dir_parts = info['dir_parts'] - filename = info['filename'] - - if num_dirs == 0: - # Just filename - candidate = f"{filename}:{func_name}" if filename else func_name - else: - # Take last num_dirs directories + filename - path_suffix = dir_parts[-num_dirs:] if len(dir_parts) >= num_dirs else dir_parts - if path_suffix and filename: - candidate = "/".join(path_suffix) + f"/{filename}:{func_name}" - elif filename: - candidate = f"{filename}:{func_name}" - else: - candidate = func_name - - candidates[full_id] = candidate - - # Check if all candidates are unique - if len(set(candidates.values())) == len(candidates): - # Apply the disambiguation to all functions in this group - result.update(candidates) - disambiguated = True - break - - # Fallback to full path if still not unique (shouldn't happen) - if not disambiguated: - for full_id in id_list: - result[full_id] = full_info[full_id]['full_path'] - - return result - -def build_hierarchical_layout(profile_data, time_threshold_pct=1.0, zoom_path=None): - """ - Build hierarchical layout for flame graph with proper parent-child positioning. - - Args: - profile_data: Either dict {func: [time, {children}]} or [time, {child_calls}] - time_threshold_pct: Minimum percentage of total time to display - zoom_path: Optional list of disambiguated function names to zoom into (e.g., ['inference', 'generate']) - - Returns: - List of rectangles with (depth, x_start, width, func_name, time, pct) - """ - # Handle both formats: dict or [time, {children}] - if isinstance(profile_data, dict): - root_children = profile_data - elif isinstance(profile_data, list) and len(profile_data) == 2: - _, root_children = profile_data - else: - return [], 0.0 - - if root_children: - # Calculate total time from root level - total_time = sum(child[0] for child in root_children.values() if isinstance(child, list)) - if total_time == 0: - return [], 0.0 - - # Build disambiguation map for all unique functions - unique_identifiers = collect_unique_identifiers(root_children, time_threshold_pct) - disambig_map = build_disambiguation_map(unique_identifiers) - - # If zoom_path is specified, find the subtree to zoom into - if zoom_path: - # Build reverse map: disambiguated_name -> [full_identifiers] - reverse_map = defaultdict(list) - for full_id, disambig_name in disambig_map.items(): - reverse_map[disambig_name].append(full_id) - - # Navigate to the zoomed node - current_data = root_children - current_depth = 0 - - for target_name in zoom_path: - # Find matching function in current level - found = False - for func_id, func_data in current_data.items(): - disambig_name = disambig_map.get(func_id, func_id) - if disambig_name == target_name: - if isinstance(func_data, list) and len(func_data) == 2: - _, current_data = func_data - current_depth += 1 - found = True - break - - if not found: - print(f"Warning: Could not find '{target_name}' in zoom path. Available at this level:") - for func_id in list(current_data.keys())[:10]: - print(f" - {disambig_map.get(func_id, func_id)}") - return [], 0.0 - - # Use the zoomed subtree as root - root_children = current_data - # Recalculate total time for the zoomed view - total_time = sum(child[0] for child in root_children.values() if isinstance(child, list)) - if total_time == 0: - return [], 0.0 - - threshold = (time_threshold_pct / 100.0) * total_time - rectangles = [] - - def process_node(func_name, node_data, depth, x_start, parent_time=None): - """Recursively process nodes and position them.""" - if not isinstance(node_data, list) or len(node_data) != 2: - return - - time, children = node_data - - # Calculate width as proportion of total time - width = time / total_time - pct_total = (time / total_time) * 100 - - # Calculate percentage relative to parent (if parent exists) - if parent_time is not None and parent_time > 0: - pct_parent = (time / parent_time) * 100 - else: - pct_parent = 100.0 # Root nodes are 100% of themselves - - # Get disambiguated name from the map - display_name = disambig_map.get(func_name, func_name) - - # Add rectangle for this function - # Mark whether it should be labeled based on threshold - rectangles.append({ - 'depth': depth, - 'x_start': x_start, - 'width': width, - 'func_name': display_name, - 'full_identifier': func_name, - 'time': time, - 'pct': pct_parent, # Use parent-relative percentage - 'pct_total': pct_total, # Keep total percentage for reference - 'show_label': time >= threshold - }) - - # Process children with proper positioning - # Children should be positioned within this function's span - child_x = x_start - for child_name, child_data in children.items(): - if isinstance(child_data, list) and len(child_data) == 2: - child_time = child_data[0] - process_node(child_name, child_data, depth + 1, child_x, parent_time=time) - # Move position for next child - child_x += child_time / total_time - - # Process all root-level functions - x_pos = 0.0 - for func_name, func_data in root_children.items(): - if isinstance(func_data, list) and len(func_data) == 2: - func_time = func_data[0] - process_node(func_name, func_data, 0, x_pos) - x_pos += func_time / total_time - - return rectangles, total_time - - return [], 0.0 - -def draw_flame_graph(rectangles, total_time, output_file='flame_graph.png'): - """Draw flame graph visualization.""" - if not rectangles: - print("No data to visualize") - return - - # Calculate layout - max_depth = max(rect['depth'] for rect in rectangles) - fig, ax = plt.subplots(figsize=(20, max_depth + 2)) - - # Define base colors: blue and green alternating by row - import colorsys - blue_hue = 0.58 # Blue in HSV - green_hue = 0.33 # Green in HSV - - # Track x-position at each depth to determine column parity - depth_positions = {} - - for rect in rectangles: - depth = rect['depth'] - x_start = rect['x_start'] - width = rect['width'] - func_name = rect['func_name'] - pct = rect['pct'] - time_abs = rect['time'] - - # Convert to absolute time coordinates - x_start_abs = x_start * total_time - width_abs = width * total_time - - # Alternate hue between blue and green by depth (row) - hue = blue_hue if depth % 2 == 0 else green_hue - - # Track column index at this depth - if depth not in depth_positions: - depth_positions[depth] = [] - - # Find column index (how many rectangles we've seen at this depth) - column_idx = len(depth_positions[depth]) - depth_positions[depth].append(x_start_abs) - - # Alternate brightness by column: odd columns are lighter, even are darker - if column_idx % 2 == 0: - saturation = 0.6 - value = 0.85 - else: - saturation = 0.5 - value = 0.95 - - # Convert HSV to RGB - rgb = colorsys.hsv_to_rgb(hue, saturation, value) - - # Create rectangle with no vertical spacing (height=1.0) and only left/right borders - patch = mpatches.Rectangle( - (x_start_abs, depth), width_abs, 1.0, - facecolor=rgb, - edgecolor='none', - linewidth=0 - ) - ax.add_patch(patch) - - # Add left and right borders only - ax.plot([x_start_abs, x_start_abs], [depth, depth + 1.0], 'k-', linewidth=0.2, zorder=10) - ax.plot([x_start_abs + width_abs, x_start_abs + width_abs], [depth, depth + 1.0], 'k-', linewidth=0.2, zorder=10) - - # Add text label if above threshold AND width is sufficient - # Use absolute width for threshold check - if rect.get('show_label', True) and width_abs > 0.015 * total_time: # Threshold in absolute time - # Create wrapped text that fits within the rectangle - import textwrap - - # Calculate approximate character width based on rectangle width - # Rough estimate: each character is about 0.06 inches at fontsize 7 - fig_width_inches = 20 # From figsize - chars_per_inch = 14 # Approximate at fontsize 7 - rect_width_inches = (width_abs / total_time) * fig_width_inches - max_chars = int(rect_width_inches * chars_per_inch) - max_chars = max(max_chars, 3) # At least 3 characters - - # Wrap the function name - wrapped_name = '\n'.join(textwrap.wrap(func_name, width=max_chars, break_long_words=True, break_on_hyphens=False)) - - # Build label with wrapped name - label = f"{wrapped_name}\n{pct:.1f}%\n{time_abs:.3f}s" - - # Limit number of lines to fit in rectangle height (0.8 units) - max_lines = 3 # Approximately 3 lines fit in 0.8 height - label_lines = label.split('\n') - if len(label_lines) > max_lines: - label = '\n'.join(label_lines[:max_lines]) - - ax.text( - x_start_abs + width_abs/2, depth + 0.5, - label, - ha='center', va='center', - fontsize=7, - clip_on=True - ) - - ax.set_xlim(0, total_time) - ax.set_ylim(0, max_depth + 1) - ax.set_xlabel('Cumulative Time (seconds)', fontsize=12) - ax.set_ylabel('Call Stack Depth', fontsize=12) - ax.set_title('Flame Graph - Profile Visualization', fontsize=14, weight='bold') - ax.set_yticks(range(max_depth + 1)) - ax.grid(axis='y', alpha=0.3) - - plt.tight_layout() - plt.savefig(output_file, dpi=150, bbox_inches='tight') - print(f"Flame graph saved to {output_file}") - plt.show() - -def main(): - parser = argparse.ArgumentParser(description='Generate flame graph from profile JSON data') - parser.add_argument('input', nargs='?', default='profile.json', - help='Input JSON profile file (default: profile.json)') - parser.add_argument('-o', '--output', default='flame_graph.png', - help='Output flame graph image file (default: flame_graph.png)') - parser.add_argument('-t', '--threshold', type=float, default=1.0, - help='Time threshold percentage for displaying functions (default: 1.0)') - parser.add_argument('-z', '--zoom', type=str, default=None, - help='Zoom into a specific call path using disambiguated names separated by ">" (e.g., "inference>generate")') - - args = parser.parse_args() - - # Parse zoom path if provided - zoom_path = None - if args.zoom: - zoom_path = [name.strip() for name in args.zoom.split('>')] - print(f"Zooming into path: {' > '.join(zoom_path)}") - - # Load profile JSON - profile_data = load_profile_data(args.input) - - # Build hierarchical layout with specified threshold and zoom - rectangles, total_time = build_hierarchical_layout(profile_data, time_threshold_pct=args.threshold, zoom_path=zoom_path) - - if not rectangles: - print("No data to visualize") - return - - print(f"Total profiled time: {total_time:.2f}s") - print(f"Displaying {len(rectangles)} function calls above {args.threshold}% threshold") - - # Draw flame graph - draw_flame_graph(rectangles, total_time, output_file=args.output) - -if __name__ == '__main__': - main() From 38e1b4c02910e531a5edd3ae9586979a9fbc85bd Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Thu, 5 Feb 2026 15:04:34 -0700 Subject: [PATCH 68/99] naive attempt at porting operations --- operators/axpy/op.py | 114 +++++------------ operators/dequant/op.py | 104 ++++----------- operators/gelu/op.py | 110 +++++----------- operators/layer_norm/op.py | 110 +++++----------- operators/leaky_relu/op.py | 110 +++++----------- operators/mem_copy/op.py | 123 ++++++++---------- operators/mha/op.py | 253 +++++++++++++------------------------ operators/relu/op.py | 104 +++++---------- operators/sigmoid/op.py | 107 +++++----------- operators/tanh/op.py | 104 +++++---------- 10 files changed, 400 insertions(+), 839 deletions(-) diff --git a/operators/axpy/op.py b/operators/axpy/op.py index 8698741e..d1b1e269 100644 --- a/operators/axpy/op.py +++ b/operators/axpy/op.py @@ -7,17 +7,15 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEAXPY(AIEOperatorBase): +class AIEAXPY(SingleMLIRSourceOperator): """AIE-accelerated aX + Y operator""" def __init__( @@ -30,25 +28,26 @@ def __init__( context=None, ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns self.num_channels = num_channels self.scalar_factor = scalar_factor - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"axpy_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t_{self.scalar_factor}s" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"axpy_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t_{self.scalar_factor}s" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_axpy", callback_args=[ @@ -62,68 +61,21 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[ - mlir_artifact, - KernelObjectArtifact( - f"axpy.o", - dependencies=[ - SourceArtifact( - self.context.base_dir - / "aie_kernels" - / "generic" - / "axpy.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("x", self.size) - self.add_buffer("y", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "axpy", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("axpy", "x", "y", "output") - - def forward(self, x, y): - if x.numel() > self.size or y.numel() > self.size: - raise AIEOperatorConstraintError( - "AIEAXPY: input too large for configured size" - ) - if x.numel() != y.numel(): - raise AIEOperatorConstraintError("AIEAXPY: sizes of X and Y do not match") - - original_shape = x.shape - x_flat = x.reshape(-1) - y_flat = y.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - y_flat = torch.nn.functional.pad(y_flat, (0, pad_len)) - - self.write_buffer("x", x_flat) - self.write_buffer("y", y_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"axpy.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "axpy.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # x + AIERuntimeArgSpec("in", (self.size,)), # y + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/dequant/op.py b/operators/dequant/op.py index 9f19f0cd..94e39fb7 100644 --- a/operators/dequant/op.py +++ b/operators/dequant/op.py @@ -7,17 +7,15 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEDequant(AIEOperatorBase): +class AIEDequant(SingleMLIRSourceOperator): def __init__( self, @@ -46,17 +44,15 @@ def __init__( assert self.size % total_cores == 0, "Size must be divisible by total cores" assert total_cores <= 16, "Total cores (columns * channels) must be <= 16" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"dequant_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"dequant_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_dequant_kernel", callback_args=[ @@ -70,68 +66,24 @@ def set_up_artifacts(self): ], ) - # Build the kernel object file with the appropriate tile size and group size - kernel_artifact = KernelObjectArtifact( - f"expand_aie2_{self.tile_size}.o", - dependencies=[ - SourceArtifact( - self.context.base_dir / "aie_kernels" / "generic" / "expand.cc" - ) - ], - extra_flags=[ - f"-DTILE_SIZE={self.tile_size}", - f"-DGROUP_SIZE={self.group_size}", - ], - ) - - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[mlir_artifact, kernel_artifact], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Input buffer uses uint8 dtype, output uses bfloat16 - self.add_buffer("input", self.input_size, dtype=np.uint8) - self.add_buffer("output", self.output_size, dtype=bfloat16) - self.add_kernel( - "dequant", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("dequant", "input", "output") - - def forward(self, x_packed): - """ - Forward pass for dequantization. - - Args: - x_packed: Packed uint8 numpy array containing int4 data + scale factors - - Returns: - Dequantized bfloat16 torch tensor - """ - if x_packed.size != self.input_size: - raise AIEOperatorConstraintError( - f"AIEDequant: input size {x_packed.size} does not match expected size {self.input_size}" + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"expand_aie2_{self.tile_size}.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "expand.cc" + ) + ], + extra_flags=[ + f"-DTILE_SIZE={self.tile_size}", + f"-DGROUP_SIZE={self.group_size}", + ], ) + ] - # Write input and execute - self.write_buffer("input", x_packed.flatten()) - self.write_buffer("output", np.zeros(self.output_size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch( - "output", shape=(self.output_size,), dtype=bfloat16 - ) - - return result + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.input_size,), dtype=np.uint8), # input + AIERuntimeArgSpec("out", (self.output_size,), dtype=bfloat16), # output + ] diff --git a/operators/gelu/op.py b/operators/gelu/op.py index 7733251b..c59e141e 100644 --- a/operators/gelu/op.py +++ b/operators/gelu/op.py @@ -7,24 +7,25 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEGELU(AIEOperatorBase): +class AIEGELU(SingleMLIRSourceOperator): """AIE-accelerated GELU activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns self.num_channels = num_channels @@ -32,17 +33,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"gelu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"gelu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_gelu", callback_args=[ @@ -55,65 +54,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[ - mlir_artifact, - KernelObjectArtifact( - f"gelu.o", - dependencies=[ - SourceArtifact( - self.context.base_dir / "aie_kernels" / "aie2p" / "gelu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "gelu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("gelu", "input", "output") - - def forward(self, x): - """Forward pass for GELU activation""" - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIEGELU: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - # Pad if necessary - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - # Execute on AIE - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - # Remove padding and restore shape - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"gelu.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "gelu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/layer_norm/op.py b/operators/layer_norm/op.py index 0f83b99e..75771799 100644 --- a/operators/layer_norm/op.py +++ b/operators/layer_norm/op.py @@ -7,26 +7,27 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIELayerNorm(AIEOperatorBase): +class AIELayerNorm(SingleMLIRSourceOperator): """AIE-accelerated LAYER NORM operator""" def __init__( self, size, num_aie_columns, num_channels, tile_size, trace_size=0, context=None ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.trace_size = trace_size self.num_aie_columns = num_aie_columns @@ -35,17 +36,15 @@ def __init__( total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"layer_norm_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"layer_norm_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_layer_norm", callback_args=[ @@ -58,62 +57,23 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[ - mlir_artifact, - KernelObjectArtifact( - f"layer_norm.o", - dependencies=[ - SourceArtifact( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "layer_norm.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "layer_norm", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("layer_norm", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIELayerNorm: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"layer_norm.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "aie2p" + / "layer_norm.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/leaky_relu/op.py b/operators/leaky_relu/op.py index 1cde0773..461de61b 100644 --- a/operators/leaky_relu/op.py +++ b/operators/leaky_relu/op.py @@ -7,26 +7,27 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIELeakyReLU(AIEOperatorBase): +class AIELeakyReLU(SingleMLIRSourceOperator): """AIE-accelerated LEAKY RELU operator""" def __init__( self, size, num_aie_columns, num_channels, tile_size, alpha=0.01, context=None ): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -36,17 +37,15 @@ def __init__( total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"leaky_relu_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"leaky_relu_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_leaky_relu", callback_args=[ @@ -60,62 +59,23 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[ - mlir_artifact, - KernelObjectArtifact( - f"leaky_relu.o", - dependencies=[ - SourceArtifact( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "leaky_relu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "leaky_relu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("leaky_relu", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIELeakyReLU: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"leaky_relu.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "aie2p" + / "leaky_relu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/mem_copy/op.py b/operators/mem_copy/op.py index 806248b1..0481f2da 100644 --- a/operators/mem_copy/op.py +++ b/operators/mem_copy/op.py @@ -7,17 +7,18 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, + KernelArchiveArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEMemCopy(AIEOperatorBase): +class AIEMemCopy(SingleMLIRSourceOperator): def __init__(self, size, num_cores, num_channels, bypass, tile_size, context=None): self.size = size @@ -29,22 +30,16 @@ def __init__(self, size, num_cores, num_channels, bypass, tile_size, context=Non # For naming consistency with other operators self.bypass_str = "bypass" if bypass else "no_bypass" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_tile_{self.tile_size}_{self.bypass_str}" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - size = self.tile_size * self.num_cores - - # Xclbin base name (shared) - xclbin_base_name = f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_tile_{self.tile_size}_{self.bypass_str}" - - # Generate MLIR for xclbin (using dummy size) - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{xclbin_base_name}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_mem_copy", callback_args=[ @@ -58,67 +53,57 @@ def set_up_artifacts(self): ], ) - # Build kernel only if not bypass mode + def get_kernel_artifacts(self): if not self.bypass: - kernel_artifact = KernelObjectArtifact( - "mem_copy.o", - dependencies=[ - SourceArtifact( - self.context.base_dir - / "aie_kernels" - / "generic" - / "passThrough.cc" - ) - ], - ) - xclbin_dependencies = [mlir_artifact, kernel_artifact] + return [ + KernelObjectArtifact( + "mem_copy.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "generic" + / "passThrough.cc" + ) + ], + ) + ] else: - xclbin_dependencies = [mlir_artifact] - + return [] + + def get_artifacts(self): + # Override to add --dynamic-objFifos flag + operator_name = self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_deps_inputs = self.get_kernel_artifacts() + if len(kernel_deps_inputs) > 0: + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive + kernel_deps = ( + [ + KernelArchiveArtifact( + self.kernel_archive, + dependencies=kernel_deps_inputs, + ) + ] + if kernel_deps_inputs + else [] + ) xclbin_artifact = XclbinArtifact( - f"{xclbin_base_name}.xclbin", - dependencies=xclbin_dependencies, + f"{operator_name}.xclbin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_deps, extra_flags=["--dynamic-objFifos"], ) - - insts_file_name = f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_{self.size}_tile_{self.tile_size}_{self.bypass_str}" insts_artifact = InstsBinArtifact( - f"{insts_file_name}.bin", + f"{operator_name}.bin", + mlir_input=mlir_artifact, dependencies=[mlir_artifact], extra_flags=["--dynamic-objFifos"], ) + return xclbin_artifact, insts_artifact - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "mem_copy", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("mem_copy", "input", "output") - - def forward(self, x): - """Forward pass for memory copy""" - if x.numel() != self.size: - raise AIEOperatorConstraintError( - f"AIEMemCopy: input size {x.numel()} does not match expected size {self.size}" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - # Execute on AIE - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - return result.reshape(*original_shape) + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/mha/op.py b/operators/mha/op.py index 40de3477..1040ccff 100644 --- a/operators/mha/op.py +++ b/operators/mha/op.py @@ -8,8 +8,8 @@ from typing import Dict, List from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -20,7 +20,7 @@ from operators.common.utils import torch_to_numpy, numpy_to_torch -class AIEMHA(AIEOperatorBase): +class AIEMHA(SingleMLIRSourceOperator): def __init__( self, @@ -40,20 +40,34 @@ def __init__( self.num_of_pipelines = num_of_pipelines assert d == 64, "Only d=64 is supported in this version" - # Artifacts created by set_up_artifacts() - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + kv_heads = self.num_KV_heads if self.num_KV_heads > 0 else self.num_heads + return f"mha_{self.num_heads}h_{kv_heads}kv_{self.seq_len}s_{self.d}d" - def set_up_artifacts(self): - # Set up compilation artifacts - # --- + def get_mlir_artifact(self): operator_dir = Path(__file__).parent + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", + import_path=operator_dir / "design.py", + callback_fn="fused_mha", + callback_kwargs={ + "heads": self.num_heads, + "S_q": self.seq_len, + "S_kv": self.seq_len, + "d": self.d, + "B_q": self.B_q, + "B_kv": self.B_kv, + "num_KV_heads": self.num_KV_heads, + "number_of_pipelines": self.num_of_pipelines, + "emulate_bf16_mmul_with_bfp16": True, + "trace_size": 0, + "verbose": False, + }, + ) - kv_heads = self.num_KV_heads if self.num_KV_heads > 0 else self.num_heads - file_name_base = f"mha_{self.num_heads}h_{kv_heads}kv_{self.seq_len}s_{self.d}d" - + def get_kernel_artifacts(self): # Define source files mm_source = str(self.context.base_dir / "aie_kernels" / "aie2p" / "mm.cc") softmax_source = str( @@ -83,105 +97,72 @@ def set_up_artifacts(self): "zero_scalar_bf16": "zero_scalar_bf16_rowmaj", } - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", - import_path=operator_dir / "design.py", - callback_fn="fused_mha", - callback_kwargs={ - "heads": self.num_heads, - "S_q": self.seq_len, - "S_kv": self.seq_len, - "d": self.d, - "B_q": self.B_q, - "B_kv": self.B_kv, - "num_KV_heads": self.num_KV_heads, - "number_of_pipelines": self.num_of_pipelines, - "emulate_bf16_mmul_with_bfp16": True, - "trace_size": 0, - "verbose": False, - }, - ) + return [ + KernelObjectArtifact( + f"mha_mm.o", + extra_flags=mm_defines_colmaj, + dependencies=[SourceArtifact(mm_source)], + ), + KernelObjectArtifact( + f"mha_mm_rowmaj.o", + extra_flags=mm_defines_rowmaj, + dependencies=[SourceArtifact(mm_source)], + rename_symbols=mm_rename_symbols, + ), + KernelObjectArtifact( + "mha_softmax.o", + dependencies=[SourceArtifact(softmax_source)], + ), + KernelObjectArtifact( + "mha_mha.o", dependencies=[SourceArtifact(mha_source)] + ), + KernelObjectArtifact( + "mha_passThrough.o", + extra_flags=["-DBIT_WIDTH=16"], + dependencies=[SourceArtifact(passthrough_source)], + ), + ] - xclbin_artifact = XclbinArtifact( - f"mha.xclbin", - dependencies=[ - mlir_artifact, + def get_artifacts(self): + # Override to add --dynamic-objFifos flag + operator_name = self.get_operator_name() + mlir_artifact = self.get_mlir_artifact() + kernel_deps_inputs = self.get_kernel_artifacts() + if len(kernel_deps_inputs) > 0: + mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive + kernel_deps = ( + [ KernelArchiveArtifact( - f"mha_kernels.a", - dependencies=[ - KernelObjectArtifact( - f"mha_mm.o", - extra_flags=mm_defines_colmaj, - dependencies=[SourceArtifact(mm_source)], - ), - KernelObjectArtifact( - f"mha_mm_rowmaj.o", - extra_flags=mm_defines_rowmaj, - dependencies=[SourceArtifact(mm_source)], - rename_symbols=mm_rename_symbols, - ), - KernelObjectArtifact( - "mha_softmax.o", - dependencies=[SourceArtifact(softmax_source)], - ), - KernelObjectArtifact( - "mha_mha.o", dependencies=[SourceArtifact(mha_source)] - ), - KernelObjectArtifact( - "mha_passThrough.o", - extra_flags=["-DBIT_WIDTH=16"], - dependencies=[SourceArtifact(passthrough_source)], - ), - ], - ), - ], + self.kernel_archive, + dependencies=kernel_deps_inputs, + ) + ] + if kernel_deps_inputs + else [] + ) + xclbin_artifact = XclbinArtifact( + f"{operator_name}.xclbin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact] + kernel_deps, extra_flags=["--dynamic-objFifos"], ) - insts_artifact = InstsBinArtifact( - f"mha.bin", dependencies=[mlir_artifact], extra_flags=["--dynamic-objFifos"] + f"{operator_name}.bin", + mlir_input=mlir_artifact, + dependencies=[mlir_artifact], + extra_flags=["--dynamic-objFifos"], ) + return xclbin_artifact, insts_artifact - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - - artifacts = [xclbin_artifact, insts_artifact] - self.add_artifacts(artifacts) - - def set_up_runtime(self): - # Set up runtime - # --- - self.add_kernel( - "mha", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_buffer( - "Q", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_buffer( - "K", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_buffer( - "V", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_buffer( - "O", - self.num_heads - * self.d - * self._calculate_seq_padding(self.seq_len, self.num_of_pipelines), - ) - self.add_to_runlist("mha", "Q", "K", "V", "O") + def get_arg_spec(self): + seq_padding = self._calculate_seq_padding(self.seq_len, self.num_of_pipelines) + buffer_size = self.num_heads * self.d * seq_padding + return [ + AIERuntimeArgSpec("in", (buffer_size,)), # Q + AIERuntimeArgSpec("in", (buffer_size,)), # K + AIERuntimeArgSpec("in", (buffer_size,)), # V + AIERuntimeArgSpec("out", (buffer_size,)), # O + ] def _calculate_seq_padding(self, seq_len, num_pipeline=1): return ((seq_len + 63 * num_pipeline) // (64 * num_pipeline)) * ( @@ -190,7 +171,7 @@ def _calculate_seq_padding(self, seq_len, num_pipeline=1): def _pad_to_multiple_of_64(self, tensor, seq_dim, num_pipeline=1): seq_len = tensor.shape[seq_dim] - padded_seq_len = _calculate_seq_padding(seq_len, num_pipeline) + padded_seq_len = self._calculate_seq_padding(seq_len, num_pipeline) if padded_seq_len == seq_len: return tensor @@ -219,63 +200,3 @@ def _unpack_padded_to_compact( dst = np.zeros((H, S, D), dtype=src.dtype) dst = src[:H, :S, :D] return dst - - def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - applicable = ( - q.shape[-1] == self.d - and k.shape[-1] == self.d - and v.shape[-1] == self.d - and q.shape[-2] == self.seq_len - and k.shape[-2] == self.seq_len - and v.shape[-2] == self.seq_len - and self.seq_len % 64 == 0, # Sequence length must be multiple of 64 - ) - if not applicable: - raise AIEOperatorConstraintError( - "AIEElementwiseAdd: incompatible tensor shape(s)" - ) - - ret = self._execute_aie_operation(q, k, v) - return ret - - def _execute_aie_operation(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - # Convert to numpy - q_np = torch_to_numpy(q) - k_np = torch_to_numpy(k) - v_np = torch_to_numpy(v) - - # Calculate padded sequence length - S_pad = self._calculate_seq_padding(self.seq_len, self.num_of_pipelines) - - # Pack compact inputs to padded format - q_padded = self._pack_compact_to_padded( - q_np, self.num_heads, self.seq_len, S_pad, self.d - ) - k_padded = self._pack_compact_to_padded( - k_np, self.num_heads, self.seq_len, S_pad, self.d - ) - v_padded = self._pack_compact_to_padded( - v_np, self.num_heads, self.seq_len, S_pad, self.d - ) - - # Write padded buffers - self.write_buffer("Q", q_padded) - self.write_buffer("K", k_padded) - self.write_buffer("V", v_padded) - - # Execute - self.run_runlist() - - # Read padded output - o_padded = self.read_buffer( - "O", shape=(self.num_heads, S_pad, self.d), dtype=bfloat16 - ) - - # Unpack padded output to compact format - o_compact = self._unpack_padded_to_compact( - o_padded, self.num_heads, self.seq_len, S_pad, self.d - ) - - # Convert back to torch with correct shape - result = numpy_to_torch(o_compact) - return result diff --git a/operators/relu/op.py b/operators/relu/op.py index 8ef73a78..457a923b 100644 --- a/operators/relu/op.py +++ b/operators/relu/op.py @@ -7,24 +7,25 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIEReLU(AIEOperatorBase): +class AIEReLU(SingleMLIRSourceOperator): """AIE-accelerated ReLU activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns self.num_channels = num_channels @@ -32,17 +33,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"relu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"relu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_relu", callback_args=[ @@ -55,59 +54,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[ - mlir_artifact, - KernelObjectArtifact( - f"relu.o", - dependencies=[ - SourceArtifact( - self.context.base_dir / "aie_kernels" / "aie2p" / "relu.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "relu", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("relu", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIEReLU: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"relu.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "relu.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/sigmoid/op.py b/operators/sigmoid/op.py index 2ca56bb5..6f1c8456 100644 --- a/operators/sigmoid/op.py +++ b/operators/sigmoid/op.py @@ -7,24 +7,25 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIESigmoid(AIEOperatorBase): +class AIESigmoid(SingleMLIRSourceOperator): """AIE-accelerated Sigmoid activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -33,17 +34,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"sigmoid_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"sigmoid_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_sigmoid", callback_args=[ @@ -56,62 +55,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[ - mlir_artifact, - KernelObjectArtifact( - f"sigmoid.o", - dependencies=[ - SourceArtifact( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "sigmoid.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "sigmoid", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("sigmoid", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIESigmoid: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"sigmoid.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "sigmoid.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] diff --git a/operators/tanh/op.py b/operators/tanh/op.py index 5edd5bda..6a71f559 100644 --- a/operators/tanh/op.py +++ b/operators/tanh/op.py @@ -7,24 +7,25 @@ from pathlib import Path from operators.common import ( - AIEOperatorBase, - AIEOperatorConstraintError, - XclbinArtifact, - InstsBinArtifact, + SingleMLIRSourceOperator, + AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, ) -class AIETanh(AIEOperatorBase): +class AIETanh(SingleMLIRSourceOperator): """AIE-accelerated Tanh activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): max_multiple = num_aie_columns * tile_size - padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple - self.orig_size = size - self.size = padded_size + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" + assert size % tile_size == 0, "size must be multiple of tile_size" + + self.size = size self.tile_size = tile_size self.num_columns = num_aie_columns @@ -33,17 +34,15 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - self.xclbin_artifact = None - self.insts_artifact = None + SingleMLIRSourceOperator.__init__(self, context=context) - AIEOperatorBase.__init__(self, context=context) + def get_operator_name(self): + return f"tanh_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - def set_up_artifacts(self): + def get_mlir_artifact(self): operator_dir = Path(__file__).parent - file_name_base = f"tanh_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" - - mlir_artifact = PythonGeneratedMLIRArtifact( - f"{file_name_base}.mlir", + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", import_path=operator_dir / "design.py", callback_fn="my_tanh", callback_args=[ @@ -56,59 +55,20 @@ def set_up_artifacts(self): ], ) - xclbin_artifact = XclbinArtifact( - f"{file_name_base}.xclbin", - dependencies=[ - mlir_artifact, - KernelObjectArtifact( - f"tanh.o", - dependencies=[ - SourceArtifact( - self.context.base_dir / "aie_kernels" / "aie2p" / "tanh.cc" - ) - ], - ), - ], - ) - - insts_artifact = InstsBinArtifact( - f"{file_name_base}.bin", dependencies=[mlir_artifact] - ) - - self.xclbin_artifact = xclbin_artifact - self.insts_artifact = insts_artifact - self.add_artifacts([xclbin_artifact, insts_artifact]) - - def set_up_runtime(self): - self.add_buffer("input", self.size) - self.add_buffer("output", self.size) - self.add_kernel( - "tanh", - self.xclbin_artifact, - self.xclbin_artifact.kernel_name, - self.insts_artifact, - ) - self.add_to_runlist("tanh", "input", "output") - - def forward(self, x): - if x.numel() > self.size: - raise AIEOperatorConstraintError( - "AIETanh: input too large for configured size" - ) - - original_shape = x.shape - x_flat = x.reshape(-1) - - pad_len = self.size - x_flat.numel() - if pad_len > 0: - x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) - - self.write_buffer("input", x_flat) - self.write_buffer("output", np.zeros(self.size, dtype=bfloat16)) - self.run_runlist() - result = self.read_buffer_as_torch("output", shape=(self.size,), dtype=bfloat16) - - if pad_len > 0: - result = result[: x_flat.numel() - pad_len] - - return result.reshape(*original_shape) + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"tanh.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "aie2p" / "tanh.cc" + ) + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.size,)), # input + AIERuntimeArgSpec("out", (self.size,)), # output + ] From ba15525d66bb15d4b97c711284ecfad5da520af1 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Thu, 5 Feb 2026 16:26:35 -0700 Subject: [PATCH 69/99] most of the operators running in local tests --- aie_kernels/aie2p/softmax.cc | 5 +- aie_kernels/generic/mv.cc | 5 +- conftest.py | 4 +- operators/axpy/design.py | 9 ++- operators/common/device_manager.py | 3 +- operators/common/test_utils.py | 96 ++++++++++++++++++++---------- operators/dequant/design.py | 9 ++- operators/elementwise_add/test.py | 1 - operators/gelu/design.py | 4 +- operators/layer_norm/design.py | 10 +++- operators/leaky_relu/design.py | 11 +++- operators/mem_copy/design.py | 11 +++- operators/mha/design.py | 5 +- operators/relu/design.py | 4 +- operators/rms_norm/op.py | 78 ++++++++++++++---------- operators/rope/test.py | 7 +-- operators/sigmoid/design.py | 4 +- operators/silu/test.py | 1 - operators/softmax/test.py | 2 +- operators/tanh/design.py | 4 +- 20 files changed, 184 insertions(+), 89 deletions(-) diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 6837f570..64cca202 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -177,9 +177,10 @@ void partial_softmax_bf16(bfloat16 *restrict input, partial_softmax_alias_bf16(input, output, scale_buffer, input_size, row_idx, num_rows, scale); } -void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size) { +void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size) +{ // TODO: Optimize this to use vector code - for(int32 i = unmasked_size; i < total_size; i++) { + for (int32 i = unmasked_size; i < total_size; i++) { inout[i] = (bfloat16)(-INFINITY); } } diff --git a/aie_kernels/generic/mv.cc b/aie_kernels/generic/mv.cc index ff54af58..f632e8f0 100644 --- a/aie_kernels/generic/mv.cc +++ b/aie_kernels/generic/mv.cc @@ -45,10 +45,7 @@ Matrix-vector multiplication kernel - r: Vector size; data from the matrix and vector will be loaded in and processed in chunks of this size */ template -void matvec_vectorized(uint32_t m, - const bfloat16 *__restrict a, - const bfloat16 *__restrict b, - bfloat16 *__restrict c) +void matvec_vectorized(uint32_t m, const bfloat16 *__restrict a, const bfloat16 *__restrict b, bfloat16 *__restrict c) { ::aie::set_rounding(aie::rounding_mode::conv_even); bfloat16 *c_end = c + m; diff --git a/conftest.py b/conftest.py index f3140421..db863da2 100644 --- a/conftest.py +++ b/conftest.py @@ -17,7 +17,9 @@ @pytest.fixture def aie_context(): """Create a fresh AIEContext for each test""" - return AIEContext() + ctx = AIEContext() + yield ctx + ctx.device_manager.reset() def pytest_addoption(parser): diff --git a/operators/axpy/design.py b/operators/axpy/design.py index 69468940..bfa676f8 100644 --- a/operators/axpy/design.py +++ b/operators/axpy/design.py @@ -16,7 +16,14 @@ def my_axpy( - dev, num_elements, num_columns, num_channels, tile_size, trace_size, scalar_factor + dev, + num_elements, + num_columns, + num_channels, + tile_size, + trace_size, + scalar_factor, + kernel_archive=None, ): factor = scalar_factor per_tile_elements = 4096 if tile_size > 4096 else tile_size diff --git a/operators/common/device_manager.py b/operators/common/device_manager.py index 3ec6acf9..2ae18bfe 100644 --- a/operators/common/device_manager.py +++ b/operators/common/device_manager.py @@ -31,7 +31,7 @@ def __init__(self): if AIEDeviceManager._initialized: return AIEDeviceManager._initialized = True - + self.device = pyxrt.device(0) self.device_type = XRTHostRuntime().device() self.contexts = {} # xclbin_path -> (context, xclbin) @@ -104,3 +104,4 @@ def reset(self): """Reset the device manager (for debugging)""" self.cleanup() AIEDeviceManager._instance = None + AIEDeviceManager._initialized = False diff --git a/operators/common/test_utils.py b/operators/common/test_utils.py index dc19df5d..066a7981 100644 --- a/operators/common/test_utils.py +++ b/operators/common/test_utils.py @@ -6,6 +6,7 @@ from ml_dtypes import bfloat16 from .utils import torch_to_numpy import logging +from .base import SingleMLIRSourceOperator, AIEBuffer def nearly_equal( @@ -29,11 +30,11 @@ def nearly_equal( return diff < max(abs_tol, rel_tol * norm) -def verify_buffer(operator, buf_name, reference, rel_tol=0.04, abs_tol=1e-6): +def verify_buffer(output, buf_name, reference, rel_tol=0.04, abs_tol=1e-6): errors = [] expected_np = torch_to_numpy(reference).reshape((-1,)) - buf_size = operator.buffers[buf_name] // 2 - output = operator.read_buffer(buf_name, (buf_size,)) + output = output.reshape((-1,)) + if len(output) < len(expected_np): # Allow larger buffers - binning may have allocated more space than needed print( @@ -65,7 +66,7 @@ def run_test( Run operator test with specified input/output/intermediate buffers. Args: - operator: AIE operator instance with registered buffers + operator: AIE operator instance input_buffers: Dict mapping buffer names to input data arrays output_buffers: Dict mapping buffer names to reference output arrays intermediate_buffers: Optional dict mapping buffer names to reference arrays for validation @@ -83,45 +84,78 @@ def run_test( level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) - operator.context.compile_all() - operator.context.prepare_runtime() - # Run warmup iterations before writing to buffers (warmup iters might corrupt the buffers) + if not isinstance(operator, SingleMLIRSourceOperator): + raise ValueError("run_test only supports SingleMLIRSourceOperator") + + operator.compile() + op_func = operator.get_callable() + + args = [] + arg_spec = operator.get_arg_spec() + + input_iter = iter(input_buffers.items()) + output_iter = iter(output_buffers.items()) + output_map = {} + + total_bytes = 0 + + for spec in arg_spec: + if spec.direction == "in": + try: + name, data = next(input_iter) + except StopIteration: + raise ValueError("Not enough input buffers provided for arg spec") + data_np = torch_to_numpy(data) + buf = AIEBuffer.from_np(data_np) + args.append(buf) + total_bytes += buf.bo.size() + elif spec.direction == "out": + try: + name, expected = next(output_iter) + except StopIteration: + raise ValueError("Not enough output buffers provided for arg spec") + buf = AIEBuffer(shape=spec.shape, dtype=spec.dtype) + args.append(buf) + output_map[name] = buf + total_bytes += buf.bo.size() + else: + # Handle other directions if needed, or raise error + raise ValueError(f"Unsupported direction: {spec.direction}") + + # Run warmup iterations for _ in range(warmup_iters): - operator.run_runlist() # warmup run to configure - - # Write input buffers and zero outputs - for buf_name in output_buffers: - buf_size = operator.buffers[buf_name] - operator.write_buffer(buf_name, np.zeros(buf_size, dtype=np.uint8)) - # Operator may share the same buffer object for inputs and outputs; hence, write input after outputs - for buf_name, data in input_buffers.items(): - data_np = torch_to_numpy(data) - operator.write_buffer(buf_name, data_np) + op_func(*args) # Run operator - elapsed_total = 0 + start_time = time.time() for _ in range(timed_iters): - elapsed_total += operator.run_runlist() - elapsed = elapsed_total / timed_iters + op_func(*args) + end_time = time.time() + + elapsed = (end_time - start_time) / timed_iters latency_us = elapsed * 1e6 # Verify outputs errors = {} for buf_name, expected in output_buffers.items(): - buf_errors = verify_buffer(operator, buf_name, expected, rel_tol, abs_tol) - if buf_errors: - errors[buf_name] = buf_errors - - for buf_name, expected in intermediate_buffers.items(): - buf_errors = verify_buffer(operator, buf_name, expected, rel_tol, abs_tol) - if buf_errors: - errors[buf_name] = buf_errors + if buf_name in output_map: + buf = output_map[buf_name] + output_np = buf.view_as_np() + buf_errors = verify_buffer(output_np, buf_name, expected, rel_tol, abs_tol) + if buf_errors: + errors[buf_name] = buf_errors + else: + print(f"Warning: Output buffer {buf_name} not found in operator arguments") + + # Intermediate buffers are not supported in this generic run_test for SingleMLIRSourceOperator + # unless we expose them somehow. For now, ignore or warn. + if intermediate_buffers: + print( + "Warning: intermediate_buffers verification is not supported for SingleMLIRSourceOperator in run_test" + ) # Calculate bandwidth - input_bytes = sum(operator.buffers[buf_name] for buf_name in input_buffers) - output_bytes = sum(operator.buffers[buf_name] for buf_name in output_buffers) - total_bytes = input_bytes + output_bytes bandwidth_gbps = total_bytes / (latency_us * 1e-6) / 1e9 return errors, latency_us, bandwidth_gbps diff --git a/operators/dequant/design.py b/operators/dequant/design.py index 05cf2ddd..07c3e3bf 100644 --- a/operators/dequant/design.py +++ b/operators/dequant/design.py @@ -16,7 +16,14 @@ def my_dequant_kernel( - dev, num_elements, num_columns, num_channels, trace_size, tile_size, group_size + dev, + num_elements, + num_columns, + num_channels, + trace_size, + tile_size, + group_size, + kernel_archive=None, ): per_tile_elements = ( 16384 if tile_size > 16384 else tile_size diff --git a/operators/elementwise_add/test.py b/operators/elementwise_add/test.py index 1fdec179..aeb16db0 100755 --- a/operators/elementwise_add/test.py +++ b/operators/elementwise_add/test.py @@ -61,7 +61,6 @@ def test_elementwise_add( operator = AIEElementwiseAdd( size=input_length, num_aie_columns=num_aie_columns, - num_channels=num_channels, tile_size=tile_size, context=aie_context, ) diff --git a/operators/gelu/design.py b/operators/gelu/design.py index 7a110286..3ecd85a5 100644 --- a/operators/gelu/design.py +++ b/operators/gelu/design.py @@ -15,7 +15,9 @@ from aie.iron.controlflow import range_ -def my_gelu(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_gelu( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 8192 if tile_size > 8192 else tile_size fifodepth = 1 if line_size > 4096 else 2 diff --git a/operators/layer_norm/design.py b/operators/layer_norm/design.py index f48bb2d2..c5f088a4 100644 --- a/operators/layer_norm/design.py +++ b/operators/layer_norm/design.py @@ -15,7 +15,15 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_layer_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size): +def my_layer_norm( + dev, + num_elements, + num_columns, + num_channels, + trace_size, + tile_size, + kernel_archive=None, +): per_tile_elements = 8192 if tile_size > 8192 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: diff --git a/operators/leaky_relu/design.py b/operators/leaky_relu/design.py index 25cd580b..a5d5c534 100644 --- a/operators/leaky_relu/design.py +++ b/operators/leaky_relu/design.py @@ -14,7 +14,16 @@ from aie.iron.controlflow import range_ -def my_leaky_relu(dev, size, num_columns, num_channels, tile_size, trace_size, alpha): +def my_leaky_relu( + dev, + size, + num_columns, + num_channels, + tile_size, + trace_size, + alpha, + kernel_archive=None, +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/operators/mem_copy/design.py b/operators/mem_copy/design.py index ce807a48..73a0eca2 100644 --- a/operators/mem_copy/design.py +++ b/operators/mem_copy/design.py @@ -167,7 +167,16 @@ def create_partial_workload_config( # -def my_mem_copy(dev, size, num_cores, num_channels, bypass, tile_size, trace_size): +def my_mem_copy( + dev, + size, + num_cores, + num_channels, + bypass, + tile_size, + trace_size, + kernel_archive=None, +): # -------------------------------------------------------------------------- # Configuration # -------------------------------------------------------------------------- diff --git a/operators/mha/design.py b/operators/mha/design.py index 0f9fdf4e..9dc33b92 100644 --- a/operators/mha/design.py +++ b/operators/mha/design.py @@ -24,7 +24,7 @@ from aie.iron.device import NPU1Col1, NPU2, Tile from aie.iron.controlflow import range_ from aie.helpers.taplib import TensorTiler2D, TensorAccessSequence, TensorAccessPattern -from aie.helpers.dialects.ext.scf import if_, else_ +from aie.helpers.dialects.scf import if_, else_ base_dir = Path(__file__).parent @@ -115,6 +115,7 @@ def fused_mha( emulate_bf16_mmul_with_bfp16: bool, trace_size: int = 0, verbose: bool = False, + kernel_archive=None, ): of_depth = 2 @@ -205,7 +206,7 @@ def fused_mha( # AIE kernel declarations func_type = "" if vectorized else "_scalar" - bin_name = "mha_kernels.a" + bin_name = kernel_archive if kernel_archive else "mha_kernels.a" zero_kernel = Kernel(f"zero_{dtype_str}", bin_name, [qk_ty]) diff --git a/operators/relu/design.py b/operators/relu/design.py index 496bb443..5c46fbb9 100644 --- a/operators/relu/design.py +++ b/operators/relu/design.py @@ -14,7 +14,9 @@ from aie.iron.controlflow import range_ -def my_relu(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_relu( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/operators/rms_norm/op.py b/operators/rms_norm/op.py index 920b7b0e..a512869c 100644 --- a/operators/rms_norm/op.py +++ b/operators/rms_norm/op.py @@ -34,7 +34,9 @@ def __init__( context=None, ): max_multiple = num_aie_columns * tile_size - assert size % max_multiple == 0, "size must be multiple of num_aie_columns * tile_size" + assert ( + size % max_multiple == 0 + ), "size must be multiple of num_aie_columns * tile_size" assert size % tile_size == 0, "size must be multiple of tile_size" self.size = size @@ -57,52 +59,68 @@ def get_operator_name(self): def get_mlir_artifact(self): operator_dir = Path(__file__).parent - return PythonGeneratedMLIRArtifact( - f"{self.get_operator_name()}.mlir", - import_path=operator_dir / "design_weighted.py", - callback_fn="my_weighted_rms_norm", - callback_args=[ + if self.weighted: + import_path = operator_dir / "design_weighted.py" + callback_fn = "my_weighted_rms_norm" + callback_args = [ self.context.device_manager.device_type, self.size, self.num_columns, self.num_channels, self.tile_size, 0, - ], + ] + else: + import_path = operator_dir / "design.py" + callback_fn = "my_rms_norm" + callback_args = [ + self.context.device_manager.device_type, + self.size, + self.num_columns, + self.num_channels, + 0, # trace_size + self.tile_size, + ] + + return PythonGeneratedMLIRArtifact( + f"{self.get_operator_name()}.mlir", + import_path=import_path, + callback_fn=callback_fn, + callback_args=callback_args, callback_kwargs={ - "kernel_archive": f"{self.get_operator_name()}.a", - } + "kernel_archive": self.kernel_archive, + }, ) def get_kernel_artifacts(self): - return [ + artifacts = [ KernelObjectArtifact( f"rms_norm.o", dependencies=[ SourceArtifact( - self.context.base_dir - / "aie_kernels" - / "aie2p" - / "rms_norm.cc" - ) - ], - ), - KernelObjectArtifact( - "mul.o", - dependencies=[ - SourceArtifact( - self.context.base_dir - / "aie_kernels" - / "generic" - / "mul.cc" + self.context.base_dir / "aie_kernels" / "aie2p" / "rms_norm.cc" ) ], ), ] + if self.weighted: + artifacts.append( + KernelObjectArtifact( + "mul.o", + dependencies=[ + SourceArtifact( + self.context.base_dir / "aie_kernels" / "generic" / "mul.cc" + ) + ], + ) + ) + return artifacts def get_arg_spec(self): - return [ - AIERuntimeArgSpec("in", (self.size // self.tile_size, self.tile_size)), # input - AIERuntimeArgSpec("in", (self.tile_size,)), # weight - AIERuntimeArgSpec("out", (self.size // self.tile_size, self.tile_size)) # output - ] + specs = [AIERuntimeArgSpec("in", (self.size // self.tile_size, self.tile_size))] + if self.weighted: + specs.append(AIERuntimeArgSpec("in", (self.tile_size,))) + specs.append( + AIERuntimeArgSpec("out", (self.size // self.tile_size, self.tile_size)) + ) + return specs diff --git a/operators/rope/test.py b/operators/rope/test.py index 7399f78a..f033f58b 100755 --- a/operators/rope/test.py +++ b/operators/rope/test.py @@ -98,12 +98,7 @@ def test_rope(rows, cols, angle_rows, aie_columns, method_type, aie_context): operator, input_buffers, output_buffers, rel_tol=0.05, abs_tol=0.5 ) - print(golden_ref["C"]) - print( - operator.read_buffer_as_torch("output", (rows // angle_rows, angle_rows, cols)) - ) - print(f"\nLatency (us): {latency_us:.1f}") print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") - # assert not errors, f"Test failed with errors: {errors}" + assert not errors, f"Test failed with errors: {errors}" diff --git a/operators/sigmoid/design.py b/operators/sigmoid/design.py index 49d33502..927f9432 100644 --- a/operators/sigmoid/design.py +++ b/operators/sigmoid/design.py @@ -14,7 +14,9 @@ from aie.iron.controlflow import range_ -def my_sigmoid(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_sigmoid( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/operators/silu/test.py b/operators/silu/test.py index e9ac0fb4..79a65884 100755 --- a/operators/silu/test.py +++ b/operators/silu/test.py @@ -61,7 +61,6 @@ def test_silu(input_length, num_aie_columns, num_channels, tile_size, aie_contex operator = AIESiLU( size=input_length, num_aie_columns=num_aie_columns, - num_channels=num_channels, tile_size=tile_size, context=aie_context, ) diff --git a/operators/softmax/test.py b/operators/softmax/test.py index dd2c297e..818a5434 100755 --- a/operators/softmax/test.py +++ b/operators/softmax/test.py @@ -34,7 +34,7 @@ def get_optimal_columns_channels(input_length, tile_size): def generate_test_params(extensive=False): max_aie_columns = 8 num_channels = 2 - input_lengths = [4096] if not extensive else [] + input_lengths = [32768] if not extensive else [] tile_sizes = [1024, 512, 2048] params = [] diff --git a/operators/tanh/design.py b/operators/tanh/design.py index 0f78fc92..c3e0acad 100644 --- a/operators/tanh/design.py +++ b/operators/tanh/design.py @@ -14,7 +14,9 @@ from aie.iron.controlflow import range_ -def my_tanh(dev, size, num_columns, num_channels, tile_size, trace_size): +def my_tanh( + dev, size, num_columns, num_channels, tile_size, trace_size, kernel_archive=None +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] From d483d4fe7bad87aa23844ddd39a968c5453e7bda Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Thu, 5 Feb 2026 16:33:19 -0700 Subject: [PATCH 70/99] the great reformatting --- applications/llama_3.2_1b/llama_cpu.py | 196 +-- .../llama_3.2_1b/llama_inference_harness.py | 114 +- applications/llama_3.2_1b/llama_npu.py | 1101 +++++++++++------ operators/__init__.py | 28 +- operators/common/__init__.py | 2 +- operators/common/base.py | 117 +- operators/common/compilation/base.py | 160 ++- operators/common/context.py | 8 +- operators/common/fusion.py | 208 ++-- operators/common/utils.py | 1 - operators/dequant/reference.py | 1 - operators/elementwise_add/design.py | 14 +- operators/elementwise_add/op.py | 6 +- operators/elementwise_mul/design.py | 14 +- operators/elementwise_mul/op.py | 6 +- operators/gemm/design.py | 7 +- operators/gemm/op.py | 23 +- operators/gemv/design.py | 29 +- operators/gemv/op.py | 10 +- operators/repeat/design.py | 35 +- operators/repeat/op.py | 2 +- operators/rms_norm/design.py | 10 +- operators/rms_norm/design_weighted.py | 13 +- operators/rope/design.py | 7 +- operators/rope/op.py | 40 +- operators/silu/design.py | 4 +- operators/silu/op.py | 8 +- operators/softmax/design.py | 41 +- operators/softmax/op.py | 21 +- operators/softmax/test2.py | 22 +- operators/strided_copy/design.py | 86 +- operators/strided_copy/op.py | 12 +- operators/strided_copy/test2.py | 70 +- operators/transpose/design.py | 4 +- operators/transpose/op.py | 8 +- 35 files changed, 1588 insertions(+), 840 deletions(-) diff --git a/applications/llama_3.2_1b/llama_cpu.py b/applications/llama_3.2_1b/llama_cpu.py index 31849f08..7d5c8ed3 100755 --- a/applications/llama_3.2_1b/llama_cpu.py +++ b/applications/llama_3.2_1b/llama_cpu.py @@ -4,35 +4,35 @@ import math import llama_inference_harness as harness - # Operators # ########################################################################## + def rope_forward(x, angles): """Rotary positional embedding using precomputed angles""" # x: (batch, seq_len, num_heads, head_dim) after view and before transpose # angles: (context_length, head_dim) _, seq_len, _, head_dim = x.shape angles_slice = angles[:seq_len] # (seq_len, head_dim) - + # Split into even and odd dimensions x1 = x[..., : head_dim // 2] # (batch, seq_len, num_heads, head_dim//2) x2 = x[..., head_dim // 2 :] # (batch, seq_len, num_heads, head_dim//2) - + # Get cos and sin from angles cos = angles_slice[:, ::2] # (seq_len, head_dim//2) sin = angles_slice[:, 1::2] # (seq_len, head_dim//2) - + # Reshape for broadcasting: (1, seq_len, 1, head_dim//2) # (The same cosine and sine values are used across batch and heads.) cos = cos.unsqueeze(0).unsqueeze(2) sin = sin.unsqueeze(0).unsqueeze(2) - + # Rotate: [x1*cos - x2*sin, x1*sin + x2*cos] rotated = torch.empty_like(x) rotated[..., : head_dim // 2] = x1 * cos - x2 * sin rotated[..., head_dim // 2 :] = x1 * sin + x2 * cos - + return rotated @@ -45,10 +45,13 @@ def rms_norm_forward(x, weight, eps=1e-5): def grouped_query_attention_forward( - x, + x, keys_cache, values_cache, - W_query, W_key, W_value, W_out, + W_query, + W_key, + W_value, + W_out, angles, mask=None, num_heads=32, @@ -70,77 +73,95 @@ def grouped_query_attention_forward( # In particular, each token gets `num_heads` queries and `num_kv_groups` keys/values (keys/values shared for multiple queries). # Due to the structure of the matmul, all queries, keys and values are contiguous for each token. # Note that during the decode phase, seq_len=1, and we are only calculating the projections for the most recent token -- the keys and values of previous tokens will be concatenated in step 4. - queries = torch.nn.functional.linear(x, W_query) # (batch, seq_len, num_heads * head_dim) - keys = torch.nn.functional.linear(x, W_key) # (batch, seq_len, num_kv_groups * head_dim) - values = torch.nn.functional.linear(x, W_value) # (batch, seq_len, num_kv_groups * head_dim) - queries = queries.view(batch, seq_len, num_heads, head_dim) # (batch, seq_len, num_heads, head_dim) - keys = keys.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) - values = values.view(batch, seq_len, num_kv_groups, head_dim) # (batch, seq_len, num_kv_groups, head_dim) - + queries = torch.nn.functional.linear( + x, W_query + ) # (batch, seq_len, num_heads * head_dim) + keys = torch.nn.functional.linear( + x, W_key + ) # (batch, seq_len, num_kv_groups * head_dim) + values = torch.nn.functional.linear( + x, W_value + ) # (batch, seq_len, num_kv_groups * head_dim) + queries = queries.view( + batch, seq_len, num_heads, head_dim + ) # (batch, seq_len, num_heads, head_dim) + keys = keys.view( + batch, seq_len, num_kv_groups, head_dim + ) # (batch, seq_len, num_kv_groups, head_dim) + values = values.view( + batch, seq_len, num_kv_groups, head_dim + ) # (batch, seq_len, num_kv_groups, head_dim) + # Step 2: Apply RoPE - queries = rope_forward(queries, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) - keys = rope_forward(keys, angles[num_preceding_tokens : num_preceding_tokens + seq_len]) + queries = rope_forward( + queries, angles[num_preceding_tokens : num_preceding_tokens + seq_len] + ) + keys = rope_forward( + keys, angles[num_preceding_tokens : num_preceding_tokens + seq_len] + ) # Step 3: Transpose for attention computation # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. # Transpose so that heads are consecutive for attention computation: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) - keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) - values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. keys_cache = torch.cat([keys_cache, keys], dim=2) values_cache = torch.cat([values_cache, values], dim=2) keys = keys_cache values = values_cache - + # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value group_size = num_heads // num_kv_groups keys = keys.repeat_interleave(group_size, dim=1) values = values.repeat_interleave(group_size, dim=1) - + # Step 6: Compute attention scores # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, seq_len) # -> (batch, num_heads, seq_len, seq_len) # Entry at row i, column j, indicates how much token i's query attends to token j's key. scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim) - + # Step 7: Apply mask # This ensures causality, so that tokens in the future cannot attend to tokens in the past. if mask is not None: - scores = scores.masked_fill(mask, float('-inf')) - + scores = scores.masked_fill(mask, float("-inf")) + # Step 8: Apply softmax to squeeze scores into probabilities (0, 1) attention_weights = torch.nn.functional.softmax(scores, dim=-1) - + # Step 9: Compute attention output # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) # -> (batch, num_heads, seq_len, head_dim) context = torch.matmul(attention_weights, values) - + # Step 10: Concatenate heads and project # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) - + output = torch.nn.functional.linear(context, W_out) - + return output, keys_cache, values_cache def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): # Step 1: Parallel projections: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) gate = torch.nn.functional.linear(x, fc1_weight) # gate projection - up = torch.nn.functional.linear(x, fc2_weight) # up projection - + up = torch.nn.functional.linear(x, fc2_weight) # up projection + # Step 2: Apply SiLU activation - gate_activated = torch.nn.functional.silu(gate) # (batch, seq_len, swiglu_hidden_dim) - + gate_activated = torch.nn.functional.silu( + gate + ) # (batch, seq_len, swiglu_hidden_dim) + # Step 3: Element-wise multiplication (apply the 'gating') hidden = gate_activated * up # (batch, seq_len, swiglu_hidden_dim) - + # Step 4: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) output = torch.nn.functional.linear(hidden, fc3_weight) - + return output @@ -160,81 +181,104 @@ def transformer_block_forward( W_ffn_fc2, W_ffn_fc3, rope_angles, - attn_mask + attn_mask, ): # Step 1: RMS normalization x_norm = rms_norm_forward(x, W_norm1) - + # Step 2: Attention attn_output, attn_keys, attn_values = grouped_query_attention_forward( x_norm, attn_keys_cache, attn_values_cache, - W_attn_query, W_attn_key, W_attn_value, W_attn_out, + W_attn_query, + W_attn_key, + W_attn_value, + W_attn_out, rope_angles, attn_mask, num_heads, num_kv_groups, ) - + # Step 3: Residual x = x + attn_output - + # Step 4: Post-norm x_norm = rms_norm_forward(x, W_norm2) - + # Step 5: fully-connected feed-forward network ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) - + # Step 6: Residual x = x + ffn_output - + return x, attn_keys, attn_values -def llama_forward_pass( - config, - state -): +def llama_forward_pass(config, state): batch, seq_len = state.token_ids.shape - + # Step 1: Token embedding - tok_emb_weight = config.weights['model.embed_tokens.weight'] - x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) # (batch, seq_len, emb_dim) - + tok_emb_weight = config.weights["model.embed_tokens.weight"] + x = torch.nn.functional.embedding( + state.token_ids, tok_emb_weight + ) # (batch, seq_len, emb_dim) + # Step 2: Create causal mask attn_mask = torch.triu( - torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), - diagonal=1 + torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 ) - + # Step 3: Apply transformer blocks for layer_idx in range(config.n_layers): - x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward( - x, - state.attn_keys_caches[layer_idx], - state.attn_values_caches[layer_idx], - config.n_heads, - config.n_kv_groups, - W_norm1=config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'], - W_attn_query=config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'], - W_attn_key=config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'], - W_attn_value=config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'], - W_attn_out=config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'], - W_ffn_fc1=config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'], - W_ffn_fc2=config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'], - W_ffn_fc3=config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'], - W_norm2=config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'], - rope_angles=config.angles, - attn_mask=attn_mask, + x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = ( + transformer_block_forward( + x, + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], + config.n_heads, + config.n_kv_groups, + W_norm1=config.weights[ + f"model.layers.{layer_idx}.input_layernorm.weight" + ], + W_attn_query=config.weights[ + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ], + W_attn_key=config.weights[ + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ], + W_attn_value=config.weights[ + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ], + W_attn_out=config.weights[ + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ], + W_ffn_fc1=config.weights[ + f"model.layers.{layer_idx}.mlp.gate_proj.weight" + ], + W_ffn_fc2=config.weights[ + f"model.layers.{layer_idx}.mlp.up_proj.weight" + ], + W_ffn_fc3=config.weights[ + f"model.layers.{layer_idx}.mlp.down_proj.weight" + ], + W_norm2=config.weights[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ], + rope_angles=config.angles, + attn_mask=attn_mask, + ) ) - + # Step 4: Final normalization - final_norm_weight = config.weights['model.norm.weight'] + final_norm_weight = config.weights["model.norm.weight"] x = rms_norm_forward(x, final_norm_weight) - + # Step 5: Output projection - logits = torch.nn.functional.linear(x, config.weights['model.embed_tokens.weight']) # (batch, seq_len, vocab_size) + logits = torch.nn.functional.linear( + x, config.weights["model.embed_tokens.weight"] + ) # (batch, seq_len, vocab_size) return logits, state @@ -242,11 +286,13 @@ def llama_forward_pass( # Main # ########################################################################## + def main(): prompt = "The capital of France is " config, state = harness.init(prompt=prompt) - print(prompt, end='', flush=True) + print(prompt, end="", flush=True) harness.generate(config, state, llama_forward_pass) + if __name__ == "__main__": main() diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py index 3d639627..34b5cf5f 100644 --- a/applications/llama_3.2_1b/llama_inference_harness.py +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -16,10 +16,10 @@ import safetensors.torch import tiktoken, tiktoken.load - # Configuration # ########################################################################## + class LlamaConfig: def __init__(self, weights_path, tokenizer_path): # Model architecture @@ -30,16 +30,16 @@ def __init__(self, weights_path, tokenizer_path): self.n_kv_groups = 8 self.head_dim = self.emb_dim // self.n_heads # 64 self.hidden_dim = 8192 - + # RoPE self.rope_base = 500000.0 self.context_length = 131072 - + # Generation self.temperature = 0.7 self.top_k = 50 - # Tokenization + # Tokenization self.special_tokens = { "<|begin_of_text|>": 128000, "<|end_of_text|>": 128001, @@ -47,10 +47,12 @@ def __init__(self, weights_path, tokenizer_path): "<|end_header_id|>": 128007, "<|eot_id|>": 128009, } - self.special_tokens.update({ - f"<|reserved_{i}|>": i - for i in list(range(128002, 128006)) + list(range(128009, 128256)) - }) + self.special_tokens.update( + { + f"<|reserved_{i}|>": i + for i in list(range(128002, 128006)) + list(range(128009, 128256)) + } + ) # Load model weights and tokenizer self.weights = safetensors.torch.load_file(weights_path) @@ -59,9 +61,7 @@ def __init__(self, weights_path, tokenizer_path): # Compute RoPE angle look-up table self.angles = compute_rope_angles( - self.head_dim, - self.context_length, - self.rope_base + self.head_dim, self.context_length, self.rope_base ) @@ -75,11 +75,23 @@ def reset_kv_cache(self, config): # Set up KV cache -- initially empty # This is what passes information from previous tokens to the current token during generation self.attn_keys_caches = [ - torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=config.weights["model.layers.0.self_attn.k_proj.weight"].dtype) # (batch_size, n_kv_groups, seq_len, head_dim) + torch.empty( + 1, + config.n_kv_groups, + 0, + config.head_dim, + dtype=config.weights["model.layers.0.self_attn.k_proj.weight"].dtype, + ) # (batch_size, n_kv_groups, seq_len, head_dim) for _ in range(config.n_layers) ] self.attn_values_caches = [ - torch.empty(1, config.n_kv_groups, 0, config.head_dim, dtype=config.weights["model.layers.0.self_attn.v_proj.weight"].dtype) # (batch_size, n_kv_groups, seq_len, head_dim) + torch.empty( + 1, + config.n_kv_groups, + 0, + config.head_dim, + dtype=config.weights["model.layers.0.self_attn.v_proj.weight"].dtype, + ) # (batch_size, n_kv_groups, seq_len, head_dim) for _ in range(config.n_layers) ] @@ -87,16 +99,17 @@ def reset_kv_cache(self, config): # Utilities # ########################################################################## + def compute_rope_angles(head_dim, context_length, rope_base=500000.0): """Compute RoPE (Rotary Position Embedding) angles.""" # Precompute the frequency tensor inv_freq = 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim)) position = torch.arange(context_length).float() freqs = torch.outer(position, inv_freq) - + cos = torch.cos(freqs) sin = torch.sin(freqs) - + # Interleave cos and sin - create angles buffer angles = torch.empty(context_length, head_dim) angles[:, ::2] = cos @@ -123,36 +136,28 @@ def get_tokenizer(tokenizer_path, special_tokens): # Generation loop # ########################################################################## -def generate_token( - config, - forward_pass, - state -): + +def generate_token(config, forward_pass, state): generated_tokens = [] - + # Step 1: Forward pass - logits, state = forward_pass( - config, - state - ) - + logits, state = forward_pass(config, state) + # Step 2: Get logits for last token last_token_logits = logits[:, -1, :] # (batch, vocab_size) - + # Step 3: Temperature scaling if config.temperature > 0: last_token_logits = last_token_logits / config.temperature - + # Step 4: Top-k filtering if config.top_k is not None: top_logits, top_indices = torch.topk(last_token_logits, config.top_k) min_val = top_logits[:, -1:] last_token_logits = torch.where( - last_token_logits < min_val, - torch.tensor(float('-inf')), - last_token_logits + last_token_logits < min_val, torch.tensor(float("-inf")), last_token_logits ) - + # Step 5: Sample probs = torch.nn.functional.softmax(last_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) @@ -174,7 +179,9 @@ def init( # Tokenize prompt prompt_token_ids = [config.special_tokens["<|begin_of_text|>"]] prompt_token_ids += config.tokenizer.encode(prompt) - assert len(prompt_token_ids) <= config.context_length, "Prompt + new tokens to generate too long (exceed context)" + assert ( + len(prompt_token_ids) <= config.context_length + ), "Prompt + new tokens to generate too long (exceed context)" prompt_token_ids = torch.tensor([prompt_token_ids], dtype=torch.long) state.token_ids = prompt_token_ids @@ -182,21 +189,15 @@ def init( return config, state -def generate( - config, - state, - forward_pass, - num_tokens=100, - use_kv_cache=True -): +def generate(config, state, forward_pass, num_tokens=100, use_kv_cache=True): # Generate tokens # First token (prefill) n_tokens_generated = 0 t_prefill_start = time.perf_counter() first_token, state = generate_token(config, forward_pass, state) token_text = config.tokenizer.decode([first_token]) - n_tokens_generated += 1 - print(token_text, end='', flush=True) + n_tokens_generated += 1 + print(token_text, end="", flush=True) t_prefill_stop = time.perf_counter() # Remaining tokens (decode) @@ -204,30 +205,41 @@ def generate( state.token_ids = torch.tensor([[first_token]], dtype=torch.long) else: state.reset_kv_cache(config) - state.token_ids = torch.cat([state.token_ids, torch.tensor([[first_token]], dtype=torch.long)], dim=1) + state.token_ids = torch.cat( + [state.token_ids, torch.tensor([[first_token]], dtype=torch.long)], dim=1 + ) t_decode_start = time.perf_counter() - for _ in range(num_tokens-1): + for _ in range(num_tokens - 1): next_token, state = generate_token(config, forward_pass, state) token_text = config.tokenizer.decode([next_token]) n_tokens_generated += 1 - print(token_text, end='', flush=True) + print(token_text, end="", flush=True) if use_kv_cache: state.token_ids = torch.tensor([[next_token]], dtype=torch.long) else: state.reset_kv_cache(config) - state.token_ids = torch.cat([state.token_ids, torch.tensor([[next_token]], dtype=torch.long)], dim=1) + state.token_ids = torch.cat( + [state.token_ids, torch.tensor([[next_token]], dtype=torch.long)], dim=1 + ) t_decode_end = time.perf_counter() t_prefill = t_prefill_stop - t_prefill_start t_decode = t_decode_end - t_decode_start sys.stderr.write("\n\n=== Performance Statistics ===\n") sys.stderr.write(f"[Prefill] Time to first token: {t_prefill:7.3f} s\n") - sys.stderr.write(f"[Decode] Time per token (mean): {t_decode / (n_tokens_generated - 1):7.3f} s\n") - sys.stderr.write(f"[Decode] Tokens per second: {(n_tokens_generated - 1) / t_decode:7.3f}\n") - sys.stderr.write(f"[Total] Time per token (mean): {(t_prefill + t_decode) / n_tokens_generated:7.3f} s\n") - sys.stderr.write(f"[Total] Tokens per second: {n_tokens_generated / (t_prefill + t_decode):7.3f}\n") + sys.stderr.write( + f"[Decode] Time per token (mean): {t_decode / (n_tokens_generated - 1):7.3f} s\n" + ) + sys.stderr.write( + f"[Decode] Tokens per second: {(n_tokens_generated - 1) / t_decode:7.3f}\n" + ) + sys.stderr.write( + f"[Total] Time per token (mean): {(t_prefill + t_decode) / n_tokens_generated:7.3f} s\n" + ) + sys.stderr.write( + f"[Total] Tokens per second: {n_tokens_generated / (t_prefill + t_decode):7.3f}\n" + ) if __name__ == "__main__": main() - diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 2b865959..214ad45e 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -25,7 +25,12 @@ from operators.common import AIEBuffer from operators.common.utils import torch_to_numpy from operators.common.base import PatchableSingleXclbinCallable -from operators.common.fusion import FusedMLIROperator, FusedFullELFCallable, load_elf, patch_elf +from operators.common.fusion import ( + FusedMLIROperator, + FusedFullELFCallable, + load_elf, + patch_elf, +) from operators import ( AIERMSNorm, AIEGEMM, @@ -37,7 +42,7 @@ AIEStridedCopy, AIERepeat, AIESoftmax, - AIETranspose + AIETranspose, ) logging.basicConfig(level=logging.DEBUG) @@ -51,14 +56,17 @@ aie_ops = None + class AIEPrefillOperations: pass + class AIEDecodeOperations: pass + class AIELlamaOperators: - + def __init__(self, config, prompt_len): self.context = AIEContext() self.context.build_dir.mkdir(parents=True, exist_ok=True) @@ -69,23 +77,31 @@ def __init__(self, config, prompt_len): # ################################################################## # Prefill operators - self.prefill.rms_norm = AIERMSNorm( - size=prompt_len * config.emb_dim, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=config.emb_dim, - context=self.context - ).compile().get_callable() - - self.prefill.residual_add = AIEElementwiseAdd( - size=prompt_len * config.emb_dim, - tile_size=config.emb_dim - ).compile().get_callable() - self.decode.residual_add = AIEElementwiseAdd( - size=config.emb_dim, - tile_size=config.emb_dim // 8 - ).compile().get_callable() + self.prefill.rms_norm = ( + AIERMSNorm( + size=prompt_len * config.emb_dim, + eps=1e-5, + num_aie_columns=8, + num_channels=2, + tile_size=config.emb_dim, + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.residual_add = ( + AIEElementwiseAdd( + size=prompt_len * config.emb_dim, tile_size=config.emb_dim + ) + .compile() + .get_callable() + ) + self.decode.residual_add = ( + AIEElementwiseAdd(size=config.emb_dim, tile_size=config.emb_dim // 8) + .compile() + .get_callable() + ) min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N @@ -100,131 +116,174 @@ def __init__(self, config, prompt_len): tile_n=64, b_col_maj=True, separate_c_tiles=True, - context=self.context + context=self.context, ).compile() self.prefill.out_head = self.prefill.gemv_out_head_compilable.get_callable() - + # SwiGLU FFN operators # Prefill: M=prompt_len, K=emb_dim, N=hidden_dim - self.prefill.ffn_up_gate = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.hidden_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights - context=self.context - ).compile().get_callable() - - self.prefill.ffn_down = AIEGEMM( - M=prompt_len, - K=config.hidden_dim, - N=config.emb_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights - context=self.context - ).compile().get_callable() - - self.prefill.ffn_silu = AIESiLU( - size=prompt_len * config.hidden_dim, - tile_size=config.hidden_dim, - num_aie_columns=8, - context=self.context - ).compile().get_callable() - - self.prefill.eltwise_mul_ffn = AIEElementwiseMul( - size=prompt_len * config.hidden_dim, - tile_size=config.hidden_dim, - num_aie_columns=8, - context=self.context - ).compile().get_callable() + self.prefill.ffn_up_gate = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.hidden_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.ffn_down = ( + AIEGEMM( + M=prompt_len, + K=config.hidden_dim, + N=config.emb_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.ffn_silu = ( + AIESiLU( + size=prompt_len * config.hidden_dim, + tile_size=config.hidden_dim, + num_aie_columns=8, + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.eltwise_mul_ffn = ( + AIEElementwiseMul( + size=prompt_len * config.hidden_dim, + tile_size=config.hidden_dim, + num_aie_columns=8, + context=self.context, + ) + .compile() + .get_callable() + ) # Attention score scaling operators # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY - self.prefill.attn_scale = AIEElementwiseMul( - size=config.n_heads * prompt_len * prompt_len, - tile_size=prompt_len, - num_aie_columns=8, - context=self.context - ).compile().get_callable() - + self.prefill.attn_scale = ( + AIEElementwiseMul( + size=config.n_heads * prompt_len * prompt_len, + tile_size=prompt_len, + num_aie_columns=8, + context=self.context, + ) + .compile() + .get_callable() + ) + # RoPE operators # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) # angle_rows=1 because all rows use the same angle row (angles are per position) - self.prefill.rope_queries = AIERope( - rows=prompt_len * config.n_heads, - cols=config.head_dim, - angle_rows=prompt_len, - context=self.context - ).compile().get_callable() - - self.prefill.rope_keys = AIERope( - rows=prompt_len * config.n_kv_groups, - cols=config.head_dim, - angle_rows=prompt_len, - context=self.context - ).compile().get_callable() - + self.prefill.rope_queries = ( + AIERope( + rows=prompt_len * config.n_heads, + cols=config.head_dim, + angle_rows=prompt_len, + context=self.context, + ) + .compile() + .get_callable() + ) + + self.prefill.rope_keys = ( + AIERope( + rows=prompt_len * config.n_kv_groups, + cols=config.head_dim, + angle_rows=prompt_len, + context=self.context, + ) + .compile() + .get_callable() + ) + # Attention projection operators # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) - self.prefill.attn_query = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_heads * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - + self.prefill.attn_query = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_heads * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) + # Key projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) - self.prefill.attn_key = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_kv_groups * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - + self.prefill.attn_key = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) + # Value projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) - self.prefill.attn_value = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_kv_groups * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - + self.prefill.attn_value = ( + AIEGEMM( + M=prompt_len, + K=config.emb_dim, + N=config.n_kv_groups * config.head_dim, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) + # Attention score computation: Q @ K^T per head # For prefill: (seq_len, head_dim) @ (head_dim, seq_len) = (seq_len, seq_len) per head - self.prefill.attn_scores = AIEGEMM( - M=prompt_len, - K=config.head_dim, - N=prompt_len, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context - ).compile().get_callable() - + self.prefill.attn_scores = ( + AIEGEMM( + M=prompt_len, + K=config.head_dim, + N=prompt_len, + num_aie_columns=8, + tile_m=64, + tile_k=64, + tile_n=64, + b_col_maj=False, + context=self.context, + ) + .compile() + .get_callable() + ) # Decode operator (everything temporally fused) # ################################################################## @@ -237,32 +296,29 @@ def __init__(self, config, prompt_len): num_aie_columns=8, tile_size_input=4, tile_size_output=config.head_dim // 2, - context=elf_ctx + context=elf_ctx, ) - + gemv_attn_key_value_op = AIEGEMV( M=config.n_kv_groups * config.head_dim, K=config.emb_dim, num_aie_columns=8, tile_size_input=4, tile_size_output=config.head_dim // 2, - context=elf_ctx + context=elf_ctx, ) - + rope_queries_op = AIERope( - rows=1 * config.n_heads, - cols=config.head_dim, - angle_rows=1, - context=elf_ctx + rows=1 * config.n_heads, cols=config.head_dim, angle_rows=1, context=elf_ctx ) - + rope_keys_op = AIERope( rows=1 * config.n_kv_groups, cols=config.head_dim, angle_rows=1, - context=elf_ctx + context=elf_ctx, ) - + strided_copy_cache_magic = 0xDEADBEE0 strided_copy_cache_op = AIEStridedCopy( input_sizes=(config.n_kv_groups, config.head_dim), @@ -275,9 +331,9 @@ def __init__(self, config, prompt_len): output_buffer_size=config.n_kv_groups * prompt_len * config.head_dim, num_aie_channels=1, output_offset_patch_marker=strided_copy_cache_magic, - context=elf_ctx + context=elf_ctx, ) - + # For decode: per head, (1, head_dim) @ (head_dim, max_context_len) # Use GEMV: (max_context_len, head_dim) @ (head_dim,) = (max_context_len,) gemv_attn_scores_op = AIEGEMV( @@ -287,16 +343,16 @@ def __init__(self, config, prompt_len): tile_size_input=4, tile_size_output=prompt_len // 8, num_batches=config.n_heads, - context=elf_ctx + context=elf_ctx, ) - + attn_scale_op = AIEElementwiseMul( size=config.n_heads * prompt_len, tile_size=prompt_len // 8, num_aie_columns=8, - context=elf_ctx + context=elf_ctx, ) - + # Softmax operators for attention weights softmax_magic = 0xBA5EBA11 softmax_op = AIESoftmax( @@ -306,9 +362,9 @@ def __init__(self, config, prompt_len): num_channels=1, rtp_vector_size=prompt_len, # Compile with max size mask_patch_value=softmax_magic, # Magic value for patching - context=elf_ctx + context=elf_ctx, ) - + # Fused transpose for all attention heads (decode) transpose_values_op = AIETranspose( M=prompt_len, @@ -318,9 +374,9 @@ def __init__(self, config, prompt_len): m=256, n=32, s=8, - context=elf_ctx + context=elf_ctx, ) - + # GEMV for attention context: (head_dim, max_context_len) @ (max_context_len,) = (head_dim,) per head gemv_attn_context_op = AIEGEMV( M=config.head_dim, @@ -329,7 +385,7 @@ def __init__(self, config, prompt_len): tile_size_input=4, tile_size_output=4, num_batches=config.n_heads, - context=elf_ctx + context=elf_ctx, ) gemv_attn_output_op = AIEGEMV( @@ -338,62 +394,60 @@ def __init__(self, config, prompt_len): num_aie_columns=8, tile_size_input=4, tile_size_output=config.emb_dim // 8, - context=elf_ctx + context=elf_ctx, ) - + rms_norm_op = AIERMSNorm( size=config.emb_dim, eps=1e-5, num_aie_columns=1, num_channels=2, tile_size=config.emb_dim, - context=elf_ctx + context=elf_ctx, ) - + gemv_ffn_up_gate_op = AIEGEMV( M=config.hidden_dim, K=config.emb_dim, num_aie_columns=8, tile_size_input=4, tile_size_output=config.hidden_dim // 8, - context=elf_ctx + context=elf_ctx, ) - + gemv_ffn_down_op = AIEGEMV( M=config.emb_dim, K=config.hidden_dim, num_aie_columns=8, tile_size_input=1, tile_size_output=config.emb_dim // 8, - context=elf_ctx + context=elf_ctx, ) - + silu_ffn_op = AIESiLU( size=config.hidden_dim, tile_size=config.hidden_dim // 8, num_aie_columns=8, - context=elf_ctx + context=elf_ctx, ) - + eltwise_mul_ffn_op = AIEElementwiseMul( size=config.hidden_dim, tile_size=config.hidden_dim // 8, num_aie_columns=8, - context=elf_ctx + context=elf_ctx, ) - + residual_add_op = AIEElementwiseAdd( - size=config.emb_dim, - tile_size=config.emb_dim // 8, - context=elf_ctx + size=config.emb_dim, tile_size=config.emb_dim // 8, context=elf_ctx ) - + repeat_interleave_op = AIERepeat( rows=config.n_kv_groups, cols=prompt_len * config.head_dim, # Max context length repeat=config.n_heads // config.n_kv_groups, transfer_size=config.head_dim, - context=elf_ctx + context=elf_ctx, ) gemv_out_head_op = AIEGEMV( @@ -402,62 +456,119 @@ def __init__(self, config, prompt_len): num_aie_columns=8, tile_size_input=4, tile_size_output=32, - context=self.context + context=self.context, ) # Create fused operator - - cache_buffer_size = config.n_kv_groups * prompt_len * config.head_dim * 2 # * 2 for bfloat16 - values_per_head_buffer_size = prompt_len * config.head_dim * 2 # * 2 for bfloat16 + + cache_buffer_size = ( + config.n_kv_groups * prompt_len * config.head_dim * 2 + ) # * 2 for bfloat16 + values_per_head_buffer_size = ( + prompt_len * config.head_dim * 2 + ) # * 2 for bfloat16 values_buffer_size = config.n_heads * values_per_head_buffer_size runlist = [] - for layer_idx in range(config.n_layers): + for layer_idx in range(config.n_layers): # runlist.extend( [ - (rms_norm_op, "x", f"W_norm1_{layer_idx}", "x_norm") # Step 1: RMS normalization - ] + [ + ( + rms_norm_op, + "x", + f"W_norm1_{layer_idx}", + "x_norm", + ) # Step 1: RMS normalization + ] + + [ # - (gemv_attn_query_op, f"W_attn_query_{layer_idx}", "x_norm", "queries"), - (gemv_attn_key_value_op, f"W_attn_key_{layer_idx}", "x_norm", "keys"), - (gemv_attn_key_value_op, f"W_attn_value_{layer_idx}", "x_norm", "values"), - (rope_queries_op, "queries", "rope_angles", "queries"), - (rope_keys_op, "keys", "rope_angles", "keys"), - (strided_copy_cache_op, "keys", f"keys_cache_{layer_idx}"), - (strided_copy_cache_op, "values", f"values_cache_{layer_idx}"), - (repeat_interleave_op, f"keys_cache_{layer_idx}", "attn_scores_keys"), - (repeat_interleave_op, f"values_cache_{layer_idx}", "attn_scores_values"), - (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), - (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), - (softmax_op, "attn_scores", "attn_weights") - ] + [ - (transpose_values_op, + ( + gemv_attn_query_op, + f"W_attn_query_{layer_idx}", + "x_norm", + "queries", + ), + ( + gemv_attn_key_value_op, + f"W_attn_key_{layer_idx}", + "x_norm", + "keys", + ), + ( + gemv_attn_key_value_op, + f"W_attn_value_{layer_idx}", + "x_norm", + "values", + ), + (rope_queries_op, "queries", "rope_angles", "queries"), + (rope_keys_op, "keys", "rope_angles", "keys"), + (strided_copy_cache_op, "keys", f"keys_cache_{layer_idx}"), + (strided_copy_cache_op, "values", f"values_cache_{layer_idx}"), + ( + repeat_interleave_op, + f"keys_cache_{layer_idx}", + "attn_scores_keys", + ), + ( + repeat_interleave_op, + f"values_cache_{layer_idx}", + "attn_scores_values", + ), + (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), + (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), + (softmax_op, "attn_scores", "attn_weights"), + ] + + [ + ( + transpose_values_op, f"attn_scores_values[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", - f"attn_scores_values_transposed[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]" + f"attn_scores_values_transposed[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", ) for h in range(config.n_heads) - ] + [ - (gemv_attn_context_op, "attn_scores_values_transposed", "attn_weights", "attn_context"), - (gemv_attn_output_op, f"W_attn_output_decode_{layer_idx}", "attn_context", "attn_output") + ] + + [ + ( + gemv_attn_context_op, + "attn_scores_values_transposed", + "attn_weights", + "attn_context", + ), + ( + gemv_attn_output_op, + f"W_attn_output_decode_{layer_idx}", + "attn_context", + "attn_output", + ), # - ] + [ - (residual_add_op, "x", "attn_output", "x"), - (rms_norm_op, "x", f"W_norm2_{layer_idx}", "x_norm"), - (gemv_ffn_up_gate_op, f"W_ffn_gate_{layer_idx}", "x_norm", "ffn_gate"), - (gemv_ffn_up_gate_op, f"W_ffn_up_{layer_idx}", "x_norm", "ffn_up"), - (silu_ffn_op, "ffn_gate", "ffn_gate"), - (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), - (gemv_ffn_down_op, f"W_ffn_down_{layer_idx}", "ffn_hidden", "ffn_output"), - (residual_add_op, "x", "ffn_output", "x"), + ] + + [ + (residual_add_op, "x", "attn_output", "x"), + (rms_norm_op, "x", f"W_norm2_{layer_idx}", "x_norm"), + ( + gemv_ffn_up_gate_op, + f"W_ffn_gate_{layer_idx}", + "x_norm", + "ffn_gate", + ), + (gemv_ffn_up_gate_op, f"W_ffn_up_{layer_idx}", "x_norm", "ffn_up"), + (silu_ffn_op, "ffn_gate", "ffn_gate"), + (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), + ( + gemv_ffn_down_op, + f"W_ffn_down_{layer_idx}", + "ffn_hidden", + "ffn_output", + ), + (residual_add_op, "x", "ffn_output", "x"), ] ) # runlist += [ - (rms_norm_op, "x", "W_final_norm", "x"), - (gemv_out_head_op, "W_out_head", "x", "logits") + (rms_norm_op, "x", "W_final_norm", "x"), + (gemv_out_head_op, "W_out_head", "x", "logits"), ] - + self.decode.fused_op = FusedMLIROperator( "fused_op", runlist, @@ -465,9 +576,7 @@ def __init__(self, config, prompt_len): "x", "rope_angles", ], - output_args=[ - "logits" - ], + output_args=["logits"], buffer_sizes={ **{ f"keys_cache_{layer_idx}": cache_buffer_size @@ -479,10 +588,10 @@ def __init__(self, config, prompt_len): }, **{ "attn_scores_values": values_buffer_size, - "attn_scores_values_transposed": values_buffer_size - } + "attn_scores_values_transposed": values_buffer_size, + }, }, - context=elf_ctx + context=elf_ctx, ).compile() # Operator patching @@ -496,50 +605,113 @@ def get_patch_locs(elf_data, magic): keys_patches = {} values_patches = {} for layer_idx in range(config.n_layers): - _, keys_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer(f"keys_cache_{layer_idx}") - _, values_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer(f"values_cache_{layer_idx}") - keys_patches.update({ - int(l): keys_cache_offs - for l in get_patch_locs(self.decode.fused_elf_data, (keys_cache_offs + strided_copy_cache_magic * 2)) - }) - values_patches.update({ - int(l): values_cache_offs - for l in get_patch_locs(self.decode.fused_elf_data, (values_cache_offs + strided_copy_cache_magic * 2)) - }) + _, keys_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer( + f"keys_cache_{layer_idx}" + ) + _, values_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer( + f"values_cache_{layer_idx}" + ) + keys_patches.update( + { + int(l): keys_cache_offs + for l in get_patch_locs( + self.decode.fused_elf_data, + (keys_cache_offs + strided_copy_cache_magic * 2), + ) + } + ) + values_patches.update( + { + int(l): values_cache_offs + for l in get_patch_locs( + self.decode.fused_elf_data, + (values_cache_offs + strided_copy_cache_magic * 2), + ) + } + ) no_offset_patches = { int(l): 0 - for l in get_patch_locs(self.decode.fused_elf_data, (strided_copy_cache_magic * 2)) + for l in get_patch_locs( + self.decode.fused_elf_data, (strided_copy_cache_magic * 2) + ) + } + self.decode.fused_patch_locations = { + **keys_patches, + **values_patches, + **no_offset_patches, } - self.decode.fused_patch_locations = {**keys_patches, **values_patches, **no_offset_patches} assert len(self.decode.fused_patch_locations) == 4 * config.n_layers + 2 - self.decode.softmax_patch_offsets = get_patch_locs(self.decode.fused_elf_data, softmax_magic) + self.decode.softmax_patch_offsets = get_patch_locs( + self.decode.fused_elf_data, softmax_magic + ) assert len(self.decode.softmax_patch_offsets) == config.n_layers + 1 self.decode.fused = FusedFullELFCallable( - self.decode.fused_op, - elf_data=self.decode.fused_elf_data + self.decode.fused_op, elf_data=self.decode.fused_elf_data ) # Operator static buffers (weights, LUTs) for layer_idx in range(config.n_layers): - self.decode.fused.get_buffer(f"W_norm1_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.input_layernorm.weight'].flatten() - self.decode.fused.get_buffer(f"W_attn_query_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'].flatten() - self.decode.fused.get_buffer(f"W_attn_key_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'].flatten() - self.decode.fused.get_buffer(f"W_attn_value_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'].flatten() - self.decode.fused.get_buffer(f"W_attn_output_decode_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight'].flatten() - self.decode.fused.get_buffer(f"W_norm2_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight'].flatten() - self.decode.fused.get_buffer(f"W_ffn_gate_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'].flatten() - self.decode.fused.get_buffer(f"W_ffn_up_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'].flatten() - self.decode.fused.get_buffer(f"W_ffn_down_{layer_idx}").to("cpu").view_as_torch()[:] = config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'].flatten() + self.decode.fused.get_buffer(f"W_norm1_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.input_layernorm.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_query_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_key_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_value_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_attn_output_decode_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_norm2_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_ffn_gate_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.mlp.gate_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_ffn_up_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.mlp.up_proj.weight" + ].flatten() + self.decode.fused.get_buffer(f"W_ffn_down_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = config.weights[ + f"model.layers.{layer_idx}.mlp.down_proj.weight" + ].flatten() scale_factor = 1.0 / math.sqrt(config.head_dim) - self.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[:] = scale_factor - self.decode.fused.get_buffer("W_final_norm").to("cpu").view_as_torch()[:] = config.weights['model.norm.weight'].flatten() - self.decode.fused.get_buffer("W_out_head").to("cpu").view_as_torch()[:] = config.weights['model.embed_tokens.weight'].flatten() + self.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[ + : + ] = scale_factor + self.decode.fused.get_buffer("W_final_norm").to("cpu").view_as_torch()[:] = ( + config.weights["model.norm.weight"].flatten() + ) + self.decode.fused.get_buffer("W_out_head").to("cpu").view_as_torch()[:] = ( + config.weights["model.embed_tokens.weight"].flatten() + ) self.decode.fused.input_buffer.to("npu") self.decode.fused.scratch_buffer.to("npu") - self.decode.fused.output_buffer.to("npu") + self.decode.fused.output_buffer.to("npu") # Allocate buffers shared with NPU @@ -547,73 +719,115 @@ def get_patch_locs(elf_data, magic): aie_buffers = None + class AIEPrefillBuffers: def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): self.x = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) self.x_norm = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) - self.attn_output = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) - self.ffn_output = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) + self.attn_output = AIEBuffer( + shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16 + ) + self.ffn_output = AIEBuffer( + shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16 + ) # SwiGLU intermediate buffers - self.ffn_gate = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) - self.ffn_up = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) - self.ffn_hidden = AIEBuffer(shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16) + self.ffn_gate = AIEBuffer( + shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 + ) + self.ffn_up = AIEBuffer( + shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 + ) + self.ffn_hidden = AIEBuffer( + shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 + ) # Attention buffers: queries and keys serve as both projection output and RoPE input/output - self.queries = AIEBuffer(shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16) - self.keys = AIEBuffer(shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16) - self.values = AIEBuffer(shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16) - self.rope_angles = AIEBuffer(shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16) + self.queries = AIEBuffer( + shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16 + ) + self.keys = AIEBuffer( + shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16 + ) + self.values = AIEBuffer( + shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16 + ) + self.rope_angles = AIEBuffer( + shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16 + ) # Attention score computation buffers (per-head) - parent buffers with subbuffers # Parent buffer for all heads' queries: (n_heads, prompt_len, head_dim) stored contiguously - self.attn_scores_queries_all = AIEBuffer(shape=(n_heads * prompt_len, head_dim), dtype=ml_dtypes.bfloat16) + self.attn_scores_queries_all = AIEBuffer( + shape=(n_heads * prompt_len, head_dim), dtype=ml_dtypes.bfloat16 + ) self.attn_scores_queries_per_head = [ self.attn_scores_queries_all.subbuffer( length=prompt_len * head_dim, offset=h * prompt_len * head_dim, - shape=(prompt_len, head_dim) + shape=(prompt_len, head_dim), ) for h in range(n_heads) ] # Parent buffer for all KV groups' keys: (n_kv_groups, head_dim, prompt_len) stored contiguously - self.attn_scores_keys_all = AIEBuffer(shape=(n_kv_groups * head_dim, prompt_len), dtype=ml_dtypes.bfloat16) + self.attn_scores_keys_all = AIEBuffer( + shape=(n_kv_groups * head_dim, prompt_len), dtype=ml_dtypes.bfloat16 + ) self.attn_scores_keys_per_kv_group = [ self.attn_scores_keys_all.subbuffer( length=head_dim * prompt_len, offset=g * head_dim * prompt_len, - shape=(head_dim, prompt_len) + shape=(head_dim, prompt_len), ) for g in range(n_kv_groups) ] # Parent buffer for all heads' scores: (n_heads * prompt_len, prompt_len) - self.attn_scores = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) + self.attn_scores = AIEBuffer( + shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 + ) self.attn_scores_per_head = [ self.attn_scores.subbuffer( length=prompt_len * prompt_len, offset=h * prompt_len * prompt_len, - shape=(prompt_len, prompt_len) + shape=(prompt_len, prompt_len), ) for h in range(n_heads) ] # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) scale_factor = 1.0 / math.sqrt(head_dim) - self.attn_scale_factor = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) + self.attn_scale_factor = AIEBuffer( + shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 + ) self.attn_scale_factor.view_as_torch()[:] = scale_factor self.attn_scale_factor.to("npu") # Attention weights buffer (output of softmax) - self.attn_weights = AIEBuffer(shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16) + self.attn_weights = AIEBuffer( + shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 + ) class AIELlamaBuffers: def __init__(self, config, prompt_len): # Vector of the current token(s) being processed through the pipeline - self.prefill = AIEPrefillBuffers(prompt_len, config.emb_dim, config.hidden_dim, config.n_heads, config.n_kv_groups, config.head_dim) + self.prefill = AIEPrefillBuffers( + prompt_len, + config.emb_dim, + config.hidden_dim, + config.n_heads, + config.n_kv_groups, + config.head_dim, + ) # Per-layer KV cache buffers on NPU (used by strided copy for transpose and concatenate) self.keys_cache = [ - AIEBuffer(shape=(config.n_kv_groups, prompt_len, config.head_dim), dtype=ml_dtypes.bfloat16) + AIEBuffer( + shape=(config.n_kv_groups, prompt_len, config.head_dim), + dtype=ml_dtypes.bfloat16, + ) for _ in range(config.n_layers) ] self.values_cache = [ - AIEBuffer(shape=(config.n_kv_groups, prompt_len, config.head_dim), dtype=ml_dtypes.bfloat16) + AIEBuffer( + shape=(config.n_kv_groups, prompt_len, config.head_dim), + dtype=ml_dtypes.bfloat16, + ) for _ in range(config.n_layers) ] @@ -637,47 +851,86 @@ def __init__(self, config, prompt_len): self.W_ffn_down_decode = [] for layer_idx in range(config.n_layers): self.W_norm1.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.input_layernorm.weight']).to("npu") + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.input_layernorm.weight"] + ).to("npu") ) self.W_norm2.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.post_attention_layernorm.weight']).to("npu") + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ] + ).to("npu") ) self.W_attn_query_prefill.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.q_proj.weight'].T).to("npu") + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ].T + ).to("npu") ) self.W_attn_key_prefill.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.k_proj.weight'].T).to("npu") + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ].T + ).to("npu") ) self.W_attn_value_prefill.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.self_attn.v_proj.weight'].T).to("npu") + AIEBuffer.from_torch( + config.weights[ + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ].T + ).to("npu") ) self.W_ffn_gate_prefill.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.gate_proj.weight'].T).to("npu") + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.mlp.gate_proj.weight"].T + ).to("npu") ) self.W_ffn_up_prefill.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.up_proj.weight'].T).to("npu") + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.mlp.up_proj.weight"].T + ).to("npu") ) self.W_ffn_down_prefill.append( - AIEBuffer.from_torch(config.weights[f'model.layers.{layer_idx}.mlp.down_proj.weight'].T).to("npu") + AIEBuffer.from_torch( + config.weights[f"model.layers.{layer_idx}.mlp.down_proj.weight"].T + ).to("npu") ) # Final RMS norm weights - self.W_final_norm = AIEBuffer.from_torch(config.weights['model.norm.weight']).to("npu") + self.W_final_norm = AIEBuffer.from_torch( + config.weights["model.norm.weight"] + ).to("npu") # Final linear layer - self.W_out_head = AIEBuffer.from_torch(config.weights['model.embed_tokens.weight']).to("npu") # unpadded/unpartitioned, used by GEMV + self.W_out_head = AIEBuffer.from_torch( + config.weights["model.embed_tokens.weight"] + ).to( + "npu" + ) # unpadded/unpartitioned, used by GEMV W_out_head_parts = aie_ops.prefill.gemv_out_head_compilable.partition_B( - torch_to_numpy(config.weights['model.embed_tokens.weight']), - config.vocab_partitions + torch_to_numpy(config.weights["model.embed_tokens.weight"]), + config.vocab_partitions, ) self.W_out_head_parts = [ - AIEBuffer.from_np(W_out_head_part).to("npu") + AIEBuffer.from_np(W_out_head_part).to("npu") for W_out_head_part in W_out_head_parts - ] # partitioned, padded parts of weight, used by GEMM - self.prefill.logits = AIEBuffer(shape=(config.vocab_partitions, prompt_len, config.padded_vocab_size // config.vocab_partitions)).to("npu") + ] # partitioned, padded parts of weight, used by GEMM + self.prefill.logits = AIEBuffer( + shape=( + config.vocab_partitions, + prompt_len, + config.padded_vocab_size // config.vocab_partitions, + ) + ).to("npu") self.prefill.logits_parts = [ self.prefill.logits.subbuffer( - length=prompt_len * (config.padded_vocab_size // config.vocab_partitions), - offset=i * prompt_len * (config.padded_vocab_size // config.vocab_partitions), + length=prompt_len + * (config.padded_vocab_size // config.vocab_partitions), + offset=i + * prompt_len + * (config.padded_vocab_size // config.vocab_partitions), shape=(prompt_len, config.padded_vocab_size // config.vocab_partitions), ) for i in range(config.vocab_partitions) @@ -687,9 +940,10 @@ def __init__(self, config, prompt_len): # Prefill # ########################################################################## + def grouped_query_attention_forward_prefill( config, - x, + x, keys_cache, values_cache, layer_idx, @@ -699,125 +953,182 @@ def grouped_query_attention_forward_prefill( num_preceding_tokens = keys_cache.shape[2] # Step 1: Linear projections - aie_ops.prefill.attn_query(aie_buffers.prefill.x_norm, aie_buffers.W_attn_query_prefill[layer_idx], aie_buffers.prefill.queries) - aie_ops.prefill.attn_key(aie_buffers.prefill.x_norm, aie_buffers.W_attn_key_prefill[layer_idx], aie_buffers.prefill.keys) - aie_ops.prefill.attn_value(aie_buffers.prefill.x_norm, aie_buffers.W_attn_value_prefill[layer_idx], aie_buffers.prefill.values) - + aie_ops.prefill.attn_query( + aie_buffers.prefill.x_norm, + aie_buffers.W_attn_query_prefill[layer_idx], + aie_buffers.prefill.queries, + ) + aie_ops.prefill.attn_key( + aie_buffers.prefill.x_norm, + aie_buffers.W_attn_key_prefill[layer_idx], + aie_buffers.prefill.keys, + ) + aie_ops.prefill.attn_value( + aie_buffers.prefill.x_norm, + aie_buffers.W_attn_value_prefill[layer_idx], + aie_buffers.prefill.values, + ) + # Step 2: Apply RoPE to queries and keys - aie_ops.prefill.rope_queries(aie_buffers.prefill.queries, aie_buffers.prefill.rope_angles, aie_buffers.prefill.queries) - aie_ops.prefill.rope_keys(aie_buffers.prefill.keys, aie_buffers.prefill.rope_angles, aie_buffers.prefill.keys) - + aie_ops.prefill.rope_queries( + aie_buffers.prefill.queries, + aie_buffers.prefill.rope_angles, + aie_buffers.prefill.queries, + ) + aie_ops.prefill.rope_keys( + aie_buffers.prefill.keys, + aie_buffers.prefill.rope_angles, + aie_buffers.prefill.keys, + ) + # Read results from NPU - queries = aie_buffers.prefill.queries.to("cpu").view_as_torch()[:seq_len * config.n_heads, :] - keys = aie_buffers.prefill.keys.to("cpu").view_as_torch()[:seq_len * config.n_kv_groups, :] - values = aie_buffers.prefill.values.to("cpu").view_as_torch()[:seq_len, :] # (seq_len, n_kv_groups * head_dim) + queries = aie_buffers.prefill.queries.to("cpu").view_as_torch()[ + : seq_len * config.n_heads, : + ] + keys = aie_buffers.prefill.keys.to("cpu").view_as_torch()[ + : seq_len * config.n_kv_groups, : + ] + values = aie_buffers.prefill.values.to("cpu").view_as_torch()[ + :seq_len, : + ] # (seq_len, n_kv_groups * head_dim) queries = queries.view(batch, seq_len, config.n_heads, config.head_dim) keys = keys.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) - values = values.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) # (batch, seq_len, num_kv_groups, head_dim) + values = values.unsqueeze(0).view( + batch, seq_len, config.n_kv_groups, config.head_dim + ) # (batch, seq_len, num_kv_groups, head_dim) # Step 3: Transpose for attention computation # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. - # Transpose so that heads are consecutive for attention computation: + # Transpose so that heads are consecutive for attention computation: # (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) - keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) - values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) + values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. keys_cache = torch.cat([keys_cache, keys], dim=2) values_cache = torch.cat([values_cache, values], dim=2) keys = keys_cache values = values_cache - + # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value group_size = config.n_heads // config.n_kv_groups values = values.repeat_interleave(group_size, dim=1) context_len = keys.shape[2] - + # Step 6: Compute attention scores using NPU (per-head) # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, context_len) # -> (batch, num_heads, seq_len, context_len) - + queries_buf = aie_buffers.prefill.attn_scores_queries_all.view_as_torch().view( config.n_heads, -1, config.head_dim ) - queries_buf[:, :seq_len, :] = queries.squeeze(0)[:, :seq_len, :] # (num_heads, seq_len, head_dim) + queries_buf[:, :seq_len, :] = queries.squeeze(0)[ + :, :seq_len, : + ] # (num_heads, seq_len, head_dim) keys_buf = aie_buffers.prefill.attn_scores_keys_all.view_as_torch().view( config.n_kv_groups, config.head_dim, -1 ) - keys_buf[:, :, :context_len] = keys.squeeze(0).transpose(-2, -1) # (num_kv_groups, head_dim, context_len) - + keys_buf[:, :, :context_len] = keys.squeeze(0).transpose( + -2, -1 + ) # (num_kv_groups, head_dim, context_len) + # Transfer parent buffers to NPU once aie_buffers.prefill.attn_scores_queries_all.to("npu") aie_buffers.prefill.attn_scores_keys_all.to("npu") aie_buffers.prefill.attn_scores.to("npu") - + # Execute GEMM for each head using sub-buffers for h in range(config.n_heads): kv_group = h // group_size aie_ops.prefill.attn_scores( aie_buffers.prefill.attn_scores_queries_per_head[h], aie_buffers.prefill.attn_scores_keys_per_kv_group[kv_group], - aie_buffers.prefill.attn_scores_per_head[h] + aie_buffers.prefill.attn_scores_per_head[h], ) - + # Read back all results at once from parent buffer and apply scaling on NPU - aie_ops.prefill.attn_scale(aie_buffers.prefill.attn_scores, aie_buffers.prefill.attn_scale_factor, aie_buffers.prefill.attn_scores) + aie_ops.prefill.attn_scale( + aie_buffers.prefill.attn_scores, + aie_buffers.prefill.attn_scale_factor, + aie_buffers.prefill.attn_scores, + ) aie_buffers.prefill.attn_scores.to("cpu") # Buffer is (n_heads * max_seq_len, max_seq_len), view as (n_heads, max_seq_len, max_seq_len) then slice max_seq_len = aie_buffers.prefill.attn_scores.shape[0] // config.n_heads - scores = aie_buffers.prefill.attn_scores.view_as_torch().view(config.n_heads, max_seq_len, max_seq_len).unsqueeze(0)[:, :, :seq_len, :context_len] - + scores = ( + aie_buffers.prefill.attn_scores.view_as_torch() + .view(config.n_heads, max_seq_len, max_seq_len) + .unsqueeze(0)[:, :, :seq_len, :context_len] + ) + # Step 7: Apply mask # This ensures causality, so that tokens in the future cannot attend to tokens in the past. if mask is not None: - scores = scores.masked_fill(mask, float('-inf')) - + scores = scores.masked_fill(mask, float("-inf")) + # Step 8: Apply softmax on CPU scores = torch.softmax(scores.to(torch.float32), dim=-1).to(torch.bfloat16) attention_weights = scores - + # Step 9: Compute attention output # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) # -> (batch, num_heads, seq_len, head_dim) context = torch.matmul(attention_weights, values) - + # Step 10: Concatenate heads and project # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) - - output = torch.nn.functional.linear(context, config.weights[f'model.layers.{layer_idx}.self_attn.o_proj.weight']) - + + output = torch.nn.functional.linear( + context, config.weights[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] + ) + return output, keys_cache, values_cache def swiglu_ffn_forward_prefill(layer_idx): # Step 1: Gate projection - aie_ops.prefill.ffn_up_gate(aie_buffers.prefill.x_norm, aie_buffers.W_ffn_gate_prefill[layer_idx], aie_buffers.prefill.ffn_gate) - + aie_ops.prefill.ffn_up_gate( + aie_buffers.prefill.x_norm, + aie_buffers.W_ffn_gate_prefill[layer_idx], + aie_buffers.prefill.ffn_gate, + ) + # Step 2: Up projection - aie_ops.prefill.ffn_up_gate(aie_buffers.prefill.x_norm, aie_buffers.W_ffn_up_prefill[layer_idx], aie_buffers.prefill.ffn_up) - + aie_ops.prefill.ffn_up_gate( + aie_buffers.prefill.x_norm, + aie_buffers.W_ffn_up_prefill[layer_idx], + aie_buffers.prefill.ffn_up, + ) + # Step 3: Apply SiLU activation aie_ops.prefill.ffn_silu(aie_buffers.prefill.ffn_gate, aie_buffers.prefill.ffn_gate) - + # Step 4: Element-wise multiplication - aie_ops.prefill.eltwise_mul_ffn(aie_buffers.prefill.ffn_gate, aie_buffers.prefill.ffn_up, aie_buffers.prefill.ffn_hidden) - + aie_ops.prefill.eltwise_mul_ffn( + aie_buffers.prefill.ffn_gate, + aie_buffers.prefill.ffn_up, + aie_buffers.prefill.ffn_hidden, + ) + # Step 5: Down projection - aie_ops.prefill.ffn_down(aie_buffers.prefill.ffn_hidden, aie_buffers.W_ffn_down_prefill[layer_idx], aie_buffers.prefill.ffn_output) + aie_ops.prefill.ffn_down( + aie_buffers.prefill.ffn_hidden, + aie_buffers.W_ffn_down_prefill[layer_idx], + aie_buffers.prefill.ffn_output, + ) def transformer_block_forward_prefill( - config, - seq_len, - layer_idx, - attn_keys_cache, - attn_values_cache, - attn_mask + config, seq_len, layer_idx, attn_keys_cache, attn_values_cache, attn_mask ): # Step 1: RMS normalization - aie_ops.prefill.rms_norm(aie_buffers.prefill.x, aie_buffers.W_norm1[layer_idx], aie_buffers.prefill.x_norm) + aie_ops.prefill.rms_norm( + aie_buffers.prefill.x, + aie_buffers.W_norm1[layer_idx], + aie_buffers.prefill.x_norm, + ) aie_buffers.prefill.x_norm.to("cpu") x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] @@ -830,74 +1141,96 @@ def transformer_block_forward_prefill( layer_idx, attn_mask, ) - + # Step 3: Residual - aie_buffers.prefill.attn_output.view_as_torch().unsqueeze(0)[0, :seq_len, :] = attn_output - aie_ops.prefill.residual_add(aie_buffers.prefill.x, aie_buffers.prefill.attn_output, aie_buffers.prefill.x) + aie_buffers.prefill.attn_output.view_as_torch().unsqueeze(0)[ + 0, :seq_len, : + ] = attn_output + aie_ops.prefill.residual_add( + aie_buffers.prefill.x, aie_buffers.prefill.attn_output, aie_buffers.prefill.x + ) x = aie_buffers.prefill.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] - + # Step 4: Post-norm aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - aie_ops.prefill.rms_norm(aie_buffers.prefill.x, aie_buffers.W_norm2[layer_idx], aie_buffers.prefill.x_norm) + aie_ops.prefill.rms_norm( + aie_buffers.prefill.x, + aie_buffers.W_norm2[layer_idx], + aie_buffers.prefill.x_norm, + ) aie_buffers.prefill.x_norm.to("cpu") x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] - + # Step 5: Feed-forward network swiglu_ffn_forward_prefill(layer_idx) - + # Step 6: Residual - aie_ops.prefill.residual_add(aie_buffers.prefill.x, aie_buffers.prefill.ffn_output, aie_buffers.prefill.x) - + aie_ops.prefill.residual_add( + aie_buffers.prefill.x, aie_buffers.prefill.ffn_output, aie_buffers.prefill.x + ) + return attn_keys, attn_values -def llama_forward_pass_prefill( - config, - state -): +def llama_forward_pass_prefill(config, state): batch, seq_len = state.token_ids.shape - + # Step 1: RoPE angles num_preceding_tokens = state.attn_keys_caches[0].shape[2] angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] aie_buffers.prefill.rope_angles.view_as_torch()[:seq_len, :] = angles_slice # Step 2: Token embedding - tok_emb_weight = config.weights['model.embed_tokens.weight'] + tok_emb_weight = config.weights["model.embed_tokens.weight"] x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) attn_mask = torch.triu( - torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), - diagonal=1 + torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 ) aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x # Step 3: Transformer blocks for layer_idx in range(config.n_layers): - state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = transformer_block_forward_prefill( - config, - seq_len, - layer_idx, - state.attn_keys_caches[layer_idx], - state.attn_values_caches[layer_idx], - attn_mask=attn_mask, + state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = ( + transformer_block_forward_prefill( + config, + seq_len, + layer_idx, + state.attn_keys_caches[layer_idx], + state.attn_values_caches[layer_idx], + attn_mask=attn_mask, + ) ) # Step 4: Final normalization - aie_ops.prefill.rms_norm(aie_buffers.prefill.x, aie_buffers.W_final_norm, aie_buffers.prefill.x) - + aie_ops.prefill.rms_norm( + aie_buffers.prefill.x, aie_buffers.W_final_norm, aie_buffers.prefill.x + ) + # Step 5: Output projection for i in range(config.vocab_partitions): - aie_ops.prefill.out_head(aie_buffers.prefill.x, aie_buffers.W_out_head_parts[i], aie_buffers.prefill.logits_parts[i]) + aie_ops.prefill.out_head( + aie_buffers.prefill.x, + aie_buffers.W_out_head_parts[i], + aie_buffers.prefill.logits_parts[i], + ) aie_buffers.prefill.logits.to("cpu") logits_padded_partitioned = aie_buffers.prefill.logits.view_as_torch() - logits_padded = logits_padded_partitioned.transpose(0, 1).contiguous().view(-1, config.padded_vocab_size) - logits = logits_padded.unsqueeze(0)[:,:seq_len,:config.vocab_size] + logits_padded = ( + logits_padded_partitioned.transpose(0, 1) + .contiguous() + .view(-1, config.padded_vocab_size) + ) + logits = logits_padded.unsqueeze(0)[:, :seq_len, : config.vocab_size] # Step 6: Initialize per-layer NPU cache buffers with current cache state for decode phase for layer_idx in range(config.n_layers): cache_len = state.attn_keys_caches[layer_idx].shape[2] - aie_buffers.keys_cache[layer_idx].view_as_torch()[:, :cache_len, :] = state.attn_keys_caches[layer_idx].squeeze(0) - aie_buffers.values_cache[layer_idx].view_as_torch()[:, :cache_len, :] = state.attn_values_caches[layer_idx].squeeze(0) + aie_buffers.keys_cache[layer_idx].view_as_torch()[:, :cache_len, :] = ( + state.attn_keys_caches[layer_idx].squeeze(0) + ) + aie_buffers.values_cache[layer_idx].view_as_torch()[:, :cache_len, :] = ( + state.attn_values_caches[layer_idx].squeeze(0) + ) aie_buffers.keys_cache[layer_idx].to("npu") aie_buffers.values_cache[layer_idx].to("npu") @@ -907,20 +1240,18 @@ def llama_forward_pass_prefill( # Decode # ########################################################################## + def patch_fused_decode_operator(ops, config, num_preceding_tokens): context_len = num_preceding_tokens + 1 - + # Patch fused operator for strided copy cache offset output_offset = num_preceding_tokens * config.head_dim offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset - strided_copy_patches = { + strided_copy_patches = { i: (base + offset_val, 0xFFFFFFFF) for i, base in ops.fused_patch_locations.items() } - softmax_patches = { - i: (context_len, 0xFFFFFFFF) - for i in ops.softmax_patch_offsets - } + softmax_patches = {i: (context_len, 0xFFFFFFFF) for i in ops.softmax_patch_offsets} patches = {**strided_copy_patches, **softmax_patches} patched_elf_data = ops.fused_elf_data.copy() patch_elf(patched_elf_data, patches) @@ -930,25 +1261,35 @@ def patch_fused_decode_operator(ops, config, num_preceding_tokens): def llama_forward_pass_decode(config, state): batch, seq_len = state.token_ids.shape - assert seq_len == 1 + assert seq_len == 1 assert state.num_preceding_tokens < max_seq_len patch_fused_decode_operator(aie_ops.decode, config, state.num_preceding_tokens) # Prefill RoPE angle look-up tables - angles_slice = config.angles[state.num_preceding_tokens : state.num_preceding_tokens + seq_len] - aie_ops.decode.fused.get_buffer("rope_angles").to("cpu").view_as_torch()[:] = angles_slice + angles_slice = config.angles[ + state.num_preceding_tokens : state.num_preceding_tokens + seq_len + ] + aie_ops.decode.fused.get_buffer("rope_angles").to("cpu").view_as_torch()[ + : + ] = angles_slice # Token embedding (on CPU) - tok_emb_weight = config.weights['model.embed_tokens.weight'] + tok_emb_weight = config.weights["model.embed_tokens.weight"] x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) - aie_ops.decode.fused.get_buffer("x").view_as_torch().view(-1, config.emb_dim)[:seq_len, :] = x + aie_ops.decode.fused.get_buffer("x").view_as_torch().view(-1, config.emb_dim)[ + :seq_len, : + ] = x # Fused NPU operator for all of decode (16 transformer blocks + final norm + final linear layer) aie_ops.decode.fused.input_buffer.to("cpu") aie_ops.decode.fused() aie_ops.decode.fused.output_buffer.to("cpu") - logits = aie_ops.decode.fused.get_buffer("logits").view_as_torch().view(1, 1, config.vocab_size) + logits = ( + aie_ops.decode.fused.get_buffer("logits") + .view_as_torch() + .view(1, 1, config.vocab_size) + ) return logits, state @@ -957,20 +1298,25 @@ def llama_forward_pass_decode(config, state): # ########################################################################## -def llama_forward_pass( - config, - state -): +def llama_forward_pass(config, state): global aie_ops, aie_buffers batch, seq_len = state.token_ids.shape if seq_len > 1: ret = llama_forward_pass_prefill(config, state) state.num_preceding_tokens = state.token_ids.shape[1] - # Pass KV cache data onto fused decode operator + # Pass KV cache data onto fused decode operator for layer_idx in range(config.n_layers): - aie_ops.decode.fused.get_buffer(f"keys_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() - aie_ops.decode.fused.get_buffer(f"values_cache_{layer_idx}").to("cpu").view_as_torch()[:] = aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() + aie_ops.decode.fused.get_buffer(f"keys_cache_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = ( + aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() + ) + aie_ops.decode.fused.get_buffer(f"values_cache_{layer_idx}").to( + "cpu" + ).view_as_torch()[:] = ( + aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() + ) aie_ops.decode.fused.scratch_buffer.to("cpu") return ret else: @@ -982,17 +1328,18 @@ def llama_forward_pass( def main(): global aie_ops, aie_buffers prompt = "The capital of France is " - #with open('prompt.txt', 'r') as f: + # with open('prompt.txt', 'r') as f: # prompt = f.read() - #prompt = prompt[:max_seq_len] + # prompt = prompt[:max_seq_len] config, state = harness.init(prompt=prompt) aie_ops = AIELlamaOperators(config, max_seq_len) aie_buffers = AIELlamaBuffers(config, max_seq_len) - print(prompt, end='', flush=True) + print(prompt, end="", flush=True) harness.generate(config, state, llama_forward_pass, use_kv_cache=True) + if __name__ == "__main__": main() diff --git a/operators/__init__.py b/operators/__init__.py index 2696cbf1..5d55514b 100644 --- a/operators/__init__.py +++ b/operators/__init__.py @@ -1,26 +1,30 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -#from .axpy.op import AIEAXPY -#from .dequant.op import AIEDequant +# from .axpy.op import AIEAXPY +# from .dequant.op import AIEDequant from .elementwise_add.op import AIEElementwiseAdd from .elementwise_mul.op import AIEElementwiseMul -#from .gelu.op import AIEGELU + +# from .gelu.op import AIEGELU from .gemm.op import AIEGEMM from .gemv.op import AIEGEMV -#from .layer_norm.op import AIELayerNorm -#from .leaky_relu.op import AIELeakyReLU -#from .mem_copy.op import AIEMemCopy -#from .mha.op import AIEMHA -#from .relu.op import AIEReLU + +# from .layer_norm.op import AIELayerNorm +# from .leaky_relu.op import AIELeakyReLU +# from .mem_copy.op import AIEMemCopy +# from .mha.op import AIEMHA +# from .relu.op import AIEReLU from .rms_norm.op import AIERMSNorm from .rope.op import AIERope -#from .sigmoid.op import AIESigmoid + +# from .sigmoid.op import AIESigmoid from .silu.op import AIESiLU from .softmax.op import AIESoftmax -#from .swiglu_decode.op import AIESwiGLUDecode -#from .swiglu_prefill.op import AIESwiGLUPrefill -#from .tanh.op import AIETanh + +# from .swiglu_decode.op import AIESwiGLUDecode +# from .swiglu_prefill.op import AIESwiGLUPrefill +# from .tanh.op import AIETanh from .transpose.op import AIETranspose from .strided_copy.op import AIEStridedCopy from .repeat.op import AIERepeat diff --git a/operators/common/__init__.py b/operators/common/__init__.py index 9351fd68..89118abf 100644 --- a/operators/common/__init__.py +++ b/operators/common/__init__.py @@ -4,7 +4,7 @@ """Common utilities and base classes for IRON operators.""" from .base import ( - AIEOperatorBase, + AIEOperatorBase, SingleMLIRSourceOperator, AIEBuffer, SingleXclbinCallable, diff --git a/operators/common/base.py b/operators/common/base.py index 71734992..7688c249 100644 --- a/operators/common/base.py +++ b/operators/common/base.py @@ -45,7 +45,7 @@ def set_up_artifacts(self): Compilation will be handled automatically based on the provided description. """ pass - + @abstractmethod def get_arg_spec(self): pass @@ -67,7 +67,12 @@ def compile(self, dry_run=False): Subclasses are expected to overwrite set_up(); they may register any artifacts that they need to be compiled there. """ self.set_up_artifacts() - comp.compile(self.context.compilation_rules, self.artifacts, self.context.build_dir, dry_run=dry_run) + comp.compile( + self.context.compilation_rules, + self.artifacts, + self.context.build_dir, + dry_run=dry_run, + ) return self def add_artifacts(self, artifacts): @@ -92,6 +97,7 @@ def execute_runlist(runlist): class SingleMLIRSourceOperator(AIEOperatorBase, ABC): """Base class for AIE-accelerated operations""" + def __init__(self, *args, **kwargs): self.kernel_archive = f"{self.get_operator_name()}_kernels.a" AIEOperatorBase.__init__(self, *args, **kwargs) @@ -99,15 +105,15 @@ def __init__(self, *args, **kwargs): @abstractmethod def get_operator_name(self): pass - + @abstractmethod def get_mlir_artifact(self): pass - + @abstractmethod def get_kernel_artifacts(self): pass - + def get_artifacts(self): operator_name = self.get_operator_name() mlir_artifact = self.get_mlir_artifact() @@ -116,38 +122,43 @@ def get_artifacts(self): # FIXME: currently hard-coding that the design will accept this argument as an input if it uses kernels # Also not handling name collisions of kernels with the same name mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive - kernel_deps = [ + kernel_deps = ( + [ KernelArchiveArtifact( - self.kernel_archive, - dependencies=kernel_deps_inputs, - ) - ] if kernel_deps_inputs else [] + self.kernel_archive, + dependencies=kernel_deps_inputs, + ) + ] + if kernel_deps_inputs + else [] + ) xclbin_artifact = XclbinArtifact( f"{operator_name}.xclbin", mlir_input=mlir_artifact, dependencies=[mlir_artifact] + kernel_deps, ) insts_artifact = InstsBinArtifact( - f"{operator_name}.bin", + f"{operator_name}.bin", mlir_input=mlir_artifact, - dependencies=[mlir_artifact] + dependencies=[mlir_artifact], ) return xclbin_artifact, insts_artifact - + def set_up_artifacts(self): xclbin_artifact, insts_artifact = self.get_artifacts() self.xclbin_artifact = xclbin_artifact self.insts_artifact = insts_artifact self.add_artifacts([xclbin_artifact, insts_artifact]) - + def get_callable(self): return SingleXclbinCallable( xclbin_path=self.xclbin_artifact.filename, kernel_name=self.xclbin_artifact.kernel_name, insts_bin_path=self.insts_artifact.filename, - args_spec=self.get_arg_spec() + args_spec=self.get_arg_spec(), ) - + + class AIERuntimeArgSpec: def __init__(self, direction, shape, dtype=bfloat16): self.shape = shape @@ -155,6 +166,7 @@ def __init__(self, direction, shape, dtype=bfloat16): assert direction in {"in", "out", "inout"} self.direction = direction + class AIEBuffer: def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): size = np.prod(shape) * np.dtype(dtype).itemsize @@ -172,28 +184,38 @@ def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): ) self.memory_view = self.bo.map() self.subviews = [] - + def subbuffer(self, length, offset, shape, dtype=None): if dtype is None: dtype = self.dtype assert np.prod(shape) == length itemsize = np.dtype(dtype).itemsize - assert offset >= 0 + assert offset >= 0 assert offset * itemsize <= np.prod(self.shape) * np.dtype(self.dtype).itemsize - assert length * itemsize + offset * itemsize <= np.prod(self.shape) * np.dtype(self.dtype).itemsize + assert ( + length * itemsize + offset * itemsize + <= np.prod(self.shape) * np.dtype(self.dtype).itemsize + ) sub_bo = pyxrt.bo( - self.bo, # parent bo - length * itemsize, # size - offset * itemsize, # offset + self.bo, # parent bo + length * itemsize, # size + offset * itemsize, # offset + ) + sub_buffer = AIEBuffer( + shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager ) - sub_buffer = AIEBuffer(shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager) sub_buffer.on = self.on self.subviews.append(sub_buffer) return sub_buffer - + def view(self, shape): assert np.prod(shape) == np.prod(self.shape) - sub_buffer = AIEBuffer(shape=shape, dtype=self.dtype, bo=self.bo, device_manager=self.device_manager) + sub_buffer = AIEBuffer( + shape=shape, + dtype=self.dtype, + bo=self.bo, + device_manager=self.device_manager, + ) sub_buffer.on = self.on self.subviews.append(sub_buffer) return sub_buffer @@ -201,15 +223,17 @@ def view(self, shape): def view_as_np(self): self.to("cpu") # Interpret the buffer as a 1-dimensional array then change its view to the expected shape - return np.frombuffer(self.memory_view, dtype=self.dtype, count=np.prod(self.shape)).reshape(self.shape) + return np.frombuffer( + self.memory_view, dtype=self.dtype, count=np.prod(self.shape) + ).reshape(self.shape) def view_as_torch(self): return numpy_to_torch(self.view_as_np()) - + def to(self, dest): direction = { "npu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE, - "cpu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE + "cpu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE, } if dest not in direction: raise RuntimeError(f"Unknown destination for AIEBuffer.to(): {dest}") @@ -224,7 +248,7 @@ def to(self, dest): sub_buffer.on = self.on todo.extend(sub_buffer.subviews) return self - + @staticmethod def from_np(buffer): shape = buffer.shape @@ -234,14 +258,16 @@ def from_np(buffer): aie_buffer.view_as_np()[:] = buffer aie_buffer.to("npu") return aie_buffer - + @staticmethod def from_torch(tensor): return AIEBuffer.from_np(torch_to_numpy(tensor)) class SingleXclbinCallable: - def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None): + def __init__( + self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None + ): self.device_manager = device_manager or AIEDeviceManager() self.context, self.xrt_kernel = self.device_manager.get_context_and_kernel( str(xclbin_path), kernel_name @@ -255,33 +281,40 @@ def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_m self.xrt_kernel.group_id(1), ) insts_bo.write(instructions.view(np.uint8), 0) - self.insts_buffer = AIEBuffer(shape=(len(instructions),), dtype=np.uint32, bo=insts_bo) + self.insts_buffer = AIEBuffer( + shape=(len(instructions),), dtype=np.uint32, bo=insts_bo + ) self.insts_buffer.to("npu") self.args_spec = args_spec - + def __call__(self, *buffers): assert len(buffers) == len(self.args_spec) - #assert all( + # assert all( # np.prod(buffers[i].shape) >= np.prod(self.args_spec[i].shape) and buffers[i].dtype == self.args_spec[i].dtype # for i in range(len(buffers)) - #), "Input buffer shapes or dtypes do not match expected argument specification." + # ), "Input buffer shapes or dtypes do not match expected argument specification." self.insts_buffer.to("npu") for buf in buffers: buf.to("npu") opcode = 3 bos = [buffer.bo for buffer in buffers] - run = self.xrt_kernel(opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos) + run = self.xrt_kernel( + opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos + ) ret_code = run.wait() if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: - raise RuntimeError( - f"Kernel did not complete correctly: {ret_code}" - ) + raise RuntimeError(f"Kernel did not complete correctly: {ret_code}") + class PatchableSingleXclbinCallable(SingleXclbinCallable): - def __init__(self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None): - super().__init__(xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager) + def __init__( + self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None + ): + super().__init__( + xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager + ) self.baseline_instructions = self.insts_buffer.view_as_np().copy() - + def patch(self, patches): """Apply patches with masking: dict of {position: (value, mask)}.""" insts = self.insts_buffer.view_as_np() diff --git a/operators/common/compilation/base.py b/operators/common/compilation/base.py index eb20318d..8631ab4b 100644 --- a/operators/common/compilation/base.py +++ b/operators/common/compilation/base.py @@ -5,7 +5,7 @@ This file implements a simple Python-based build system. You specify what you want to compile (*artifacts*) through subclasses of `CompilationArtifact`. Multiple `CompilationArtifacts` form a `CompilationArtifactGraph`. Each artifact -can have a list (subgraph) of depenencies of other artifacts that it relies on. +can have a list (subgraph) of depenencies of other artifacts that it relies on. Each artifact corresponds to exactly one file. There is a special artifact for source files that do not need to get generated, @@ -14,15 +14,15 @@ You specify how to generate (compile) an artifact through *rules*, which are expressed as subclasses of `CompilationRule`. Rules must implement two methods: -`matches` and `compile`. If a rule `matches` to an artifact graph, it can be -applied. Applying a rule is done by calling `compile`; this transforms the -artifact graph (in the simplest case, marks one of the artifacts as available) +`matches` and `compile`. If a rule `matches` to an artifact graph, it can be +applied. Applying a rule is done by calling `compile`; this transforms the +artifact graph (in the simplest case, marks one of the artifacts as available) and returns a list of compilation commands. At this point, we can print the compilation commands to the console (dry-run) or actually run them to generate the artifacts. -Before starting compilation, you may call +Before starting compilation, you may call `populate_availability_from_filesystem()` -- this will check if any artifacts are already available at the given file paths (and ensure that dependencies are as old or older than the artifacts that depend on them). This way, you can avoid @@ -42,7 +42,6 @@ from aie.extras.context import mlir_mod_ctx import sys - # Global Functions # ########################################################################## @@ -83,7 +82,6 @@ def compile(rules, artifacts, build_dir="build", dry_run=False): print("\n".join("\n".join(map(str, cmds)) for _, cmds in plan_steps)) - # Compilation Artifact Graph # ########################################################################## @@ -100,28 +98,28 @@ def format_artifact(artifact, indent=0): for dep in artifact.dependencies: result += format_artifact(dep, indent + 1) return result - + result = "CompilationArtifactGraph(\n" for artifact in self.artifacts: result += format_artifact(artifact, indent=1) result += ")" return result - + def __iter__(self): return iter(self.artifacts) - + def __len__(self): return len(self.artifacts) - + def __getitem__(self, index): return self.artifacts[index] - + def dfs(self): return self._traverse(True) - + def bfs(self): return self._traverse(False) - + def _traverse(self, dfs): visited = set() todo = self.artifacts.copy() @@ -140,19 +138,19 @@ def replace(self, old_artifact, new_artifact): else: artifact.dependencies.replace(old_artifact, new_artifact) return self - + def populate_availability_from_filesystem(self): for artifact in self.artifacts: artifact.dependencies.populate_availability_from_filesystem() artifact.available = artifact.is_available_in_filesystem() - + def get_worklist(self, kind): """Return a list of artifacts of the given kind that can be built in the next step (dependencies available).""" return [ artifact for artifact in self.bfs() - if isinstance(artifact, kind) - and not artifact.is_available() + if isinstance(artifact, kind) + and not artifact.is_available() and artifact.dependencies_available() ] @@ -161,7 +159,7 @@ def move_artifacts(self, new_root): for artifact in self.bfs(): if not os.path.isabs(artifact.filename): artifact.filename = str(Path(new_root) / Path(artifact.filename).name) - + def add(self, artifact): self.artifacts.append(artifact) @@ -173,7 +171,9 @@ def add(self, artifact): class CompilationArtifact(ABC): def __init__(self, filename, dependencies=None, available=False): self.filename = str(filename) - self.dependencies: CompilationArtifactGraph = CompilationArtifactGraph(artifacts=dependencies if dependencies is not None else []) + self.dependencies: CompilationArtifactGraph = CompilationArtifactGraph( + artifacts=dependencies if dependencies is not None else [] + ) self.available = available def __repr__(self): @@ -183,23 +183,27 @@ def is_available(self): """'Conceptual' availability: during a dry-run or in the planning stage, available may be True even if the underlying file does not exist yet.""" # If any of our dependencies' dependencies are outdated, this artifact is also outdated return self.available and self.dependencies_available() - + def dependencies_available(self): return all(d.is_available() for d in self.dependencies) - + def is_available_in_filesystem(self): """'Real' availability: checks if the underlying file exists and is up-to-date with respect to dependencies.""" if not os.path.exists(self.filename): return False file_mtime = os.path.getmtime(self.filename) for dependency in self.dependencies: - if not dependency.is_available_in_filesystem() or os.path.getmtime(dependency.filename) > file_mtime: + if ( + not dependency.is_available_in_filesystem() + or os.path.getmtime(dependency.filename) > file_mtime + ): return False return True class SourceArtifact(CompilationArtifact): """Artifact representing a source file that does not need to be generated, is assumed to be there.""" + pass @@ -213,7 +217,13 @@ def __init__(self, filename, mlir_input, dependencies): class XclbinArtifact(CompilationArtifact): def __init__( - self, filename, mlir_input, dependencies, kernel_name="MLIR_AIE", extra_flags=None, xclbin_input=None + self, + filename, + mlir_input, + dependencies, + kernel_name="MLIR_AIE", + extra_flags=None, + xclbin_input=None, ): if mlir_input not in dependencies: dependencies = dependencies + [mlir_input] @@ -234,12 +244,20 @@ def __init__(self, filename, mlir_input, dependencies, extra_flags=None): class KernelObjectArtifact(CompilationArtifact): - def __init__(self, filename, dependencies, extra_flags=None, rename_symbols=None, prefix_symbols=None): + def __init__( + self, + filename, + dependencies, + extra_flags=None, + rename_symbols=None, + prefix_symbols=None, + ): super().__init__(filename, dependencies) self.extra_flags = extra_flags if extra_flags is not None else [] self.rename_symbols = rename_symbols if rename_symbols is not None else {} self.prefix_symbols = prefix_symbols + class KernelArchiveArtifact(CompilationArtifact): pass @@ -254,7 +272,7 @@ def __init__( callback_kwargs=None, requires_context=False, uses_kernel_archive=False, - kernel_archive=None + kernel_archive=None, ): self.import_path = import_path self.callback_fn = callback_fn @@ -271,20 +289,21 @@ def __init__( class CompilationCommand(ABC): """An abstraction for anything that can be executed to physically produce artifacts.""" + @abstractmethod def run(self) -> bool: pass - + @abstractmethod def __repr__(self): pass class ShellCompilationCommand(CompilationCommand): - def __init__(self, command: list[str], cwd=None, env='copy'): + def __init__(self, command: list[str], cwd=None, env="copy"): self.command = command self.cwd = cwd - if env == 'copy': + if env == "copy": env = os.environ.copy() self.env = env @@ -328,11 +347,9 @@ class CompilationRule(ABC): def matches(self, artifact: CompilationArtifactGraph) -> bool: """Return true if this rule can be applied to any artifact in the artifact graph.""" pass - + @abstractmethod - def compile( - self, artifacts: CompilationArtifactGraph - ) -> list[CompilationCommand]: + def compile(self, artifacts: CompilationArtifactGraph) -> list[CompilationCommand]: """Apply this rule to the artifact graph, returning compilation commands. This should modify the artifact graph in-place to reflect the newly generated artifacts.""" pass @@ -348,14 +365,28 @@ def compile(self, graph): for artifact in worklist: new_artifact = SourceArtifact(artifact.filename) # To make Python capture variables in this closure by value, not by reference, use default arguments - callback = lambda new_artifact=new_artifact, import_path=artifact.import_path, callback_fn=artifact.callback_fn, callback_args=artifact.callback_args, callback_kwargs=artifact.callback_kwargs, requires_context=artifact.requires_context: self.generate_mlir(new_artifact, import_path, callback_fn, callback_args, callback_kwargs, requires_context) + callback = lambda new_artifact=new_artifact, import_path=artifact.import_path, callback_fn=artifact.callback_fn, callback_args=artifact.callback_args, callback_kwargs=artifact.callback_kwargs, requires_context=artifact.requires_context: self.generate_mlir( + new_artifact, + import_path, + callback_fn, + callback_args, + callback_kwargs, + requires_context, + ) commands.append(PythonCallbackCompilationCommand(callback)) new_artifact.available = True graph.replace(artifact, new_artifact) return commands - - @staticmethod - def generate_mlir(output_artifact, import_path, callback_fn, callback_args=None, callback_kwargs=None, requires_context=False): + + @staticmethod + def generate_mlir( + output_artifact, + import_path, + callback_fn, + callback_args=None, + callback_kwargs=None, + requires_context=False, + ): # Import the Python source file spec = importlib.util.spec_from_file_location( Path(import_path).name, import_path @@ -363,14 +394,10 @@ def generate_mlir(output_artifact, import_path, callback_fn, callback_args=None, module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context - ctx_callback = lambda: ( - mlir_mod_ctx() if requires_context else nullcontext() - ) + ctx_callback = lambda: (mlir_mod_ctx() if requires_context else nullcontext()) with ctx_callback() as ctx: callback_function = getattr(module, callback_fn) - mlir_code = callback_function( - *callback_args, **callback_kwargs - ) + mlir_code = callback_function(*callback_args, **callback_kwargs) # Stringify the generated MLIR if requires_context: mlir_code = str(ctx.module) @@ -392,11 +419,11 @@ def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): class AieccFullElfCompilationRule(AieccCompilationRule): def matches(self, graph): return any(graph.get_worklist(FullElfArtifact)) - + def compile(self, graph): worklist = graph.get_worklist(FullElfArtifact) commands = [] - + for artifact in worklist: compile_cmd = [ "python", @@ -413,9 +440,11 @@ def compile(self, graph): os.path.abspath(artifact.filename), os.path.abspath(artifact.mlir_input.filename), ] - commands.append(ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir))) + commands.append( + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + ) artifact.available = True - + return commands @@ -463,7 +492,8 @@ def compile(self, graph): ] if first_xclbin.xclbin_input is not None: compile_cmd += [ - "--xclbin-input=" + os.path.abspath(first_xclbin.xclbin_input.filename) + "--xclbin-input=" + + os.path.abspath(first_xclbin.xclbin_input.filename) ] if do_compile_insts_bin: first_insts_bin = mlir_sources_to_insts[mlir_source][ @@ -477,15 +507,21 @@ def compile(self, graph): ] compile_cmd += [os.path.abspath(mlir_source.filename)] - commands.append(ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir))) + commands.append( + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + ) # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts]: if sources_to.get(mlir_source, [])[1:]: copy_src = sources_to[mlir_source][0] for copy_dest in sources_to[mlir_source][1:]: - commands.append(ShellCompilationCommand(['cp', copy_src.filename, copy_dest.filename])) - + commands.append( + ShellCompilationCommand( + ["cp", copy_src.filename, copy_dest.filename] + ) + ) + # Update graph for artifact in worklist: artifact.available = True @@ -556,26 +592,27 @@ def _rename_symbols(self, artifact): ] cmd += [artifact.filename] return [ShellCompilationCommand(cmd)] - + def _prefix_symbols(self, artifact, prefix): objcopy_path = "llvm-objcopy-18" nm_path = "llvm-nm-18" symbol_map_file = artifact.filename + ".symbol_map" - + # Extract defined symbols and create symbol map nm_cmd = [ - "sh", "-c", + "sh", + "-c", f"{nm_path} --defined-only --extern-only {artifact.filename} | " - f"awk '{{print $3 \" {prefix}\" $3}}' > {symbol_map_file}" + f"awk '{{print $3 \" {prefix}\" $3}}' > {symbol_map_file}", ] - + # Apply the renaming using the symbol map objcopy_cmd = [ objcopy_path, "--redefine-syms=" + symbol_map_file, artifact.filename, ] - + return [ShellCompilationCommand(nm_cmd), ShellCompilationCommand(objcopy_cmd)] @@ -616,15 +653,16 @@ def compile(self, artifacts): cmd = [str(ar_path), "rcs", archive_path] + object_files commands.append(ShellCompilationCommand(cmd)) - + # Check for duplicate symbol definitions in the archive check_cmd = [ - "sh", "-c", + "sh", + "-c", f"nm {archive_path} | grep ' [TDR] ' | awk '{{print $3}}' | sort | uniq -d | " - f"if read sym; then echo \"Error: Duplicate symbol in archive: $sym\" >&2; exit 1; fi" + f'if read sym; then echo "Error: Duplicate symbol in archive: $sym" >&2; exit 1; fi', ] commands.append(ShellCompilationCommand(check_cmd)) - + artifact.available = True return commands diff --git a/operators/common/context.py b/operators/common/context.py index fc7880f6..1cde7087 100644 --- a/operators/common/context.py +++ b/operators/common/context.py @@ -29,8 +29,12 @@ def __init__(self, use_runlist=True, build_dir=None): comp.GenerateMLIRFromPythonCompilationRule(), comp.PeanoCompilationRule(self.peano_dir, self.mlir_aie_dir), comp.ArchiveCompilationRule(self.peano_dir), - comp.AieccXclbinInstsCompilationRule(self.build_dir, self.peano_dir, self.mlir_aie_dir), - comp.AieccFullElfCompilationRule(self.build_dir, self.peano_dir, self.mlir_aie_dir), + comp.AieccXclbinInstsCompilationRule( + self.build_dir, self.peano_dir, self.mlir_aie_dir + ), + comp.AieccFullElfCompilationRule( + self.build_dir, self.peano_dir, self.mlir_aie_dir + ), ] def register_operator(self, operator): diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 9f87076f..42e7db61 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -12,23 +12,28 @@ class FusedMLIROperator(AIEOperatorBase): """Operator that fuses multiple SingleMLIRSourceOperators into one.""" - - def __init__(self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs): + + def __init__( + self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs + ): assert all( - isinstance(op, SingleMLIRSourceOperator) and all(isinstance(buf, str) for buf in bufs) + isinstance(op, SingleMLIRSourceOperator) + and all(isinstance(buf, str) for buf in bufs) for op, *bufs in runlist ) self.runlist = runlist self.name = name self.input_args = input_args self.output_args = output_args - self.explicit_buffer_sizes = buffer_sizes or {} # Optional dict: buffer_name -> size_in_bytes + self.explicit_buffer_sizes = ( + buffer_sizes or {} + ) # Optional dict: buffer_name -> size_in_bytes self.kernel_archive = "kernels.a" super().__init__(*args, **kwargs) - + def get_operator_name(self): return self.name - + def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators.""" kernel_artifacts = [] @@ -43,13 +48,13 @@ def get_kernel_artifacts(self): obj.prefix_symbols = f"op{idx}_" kernel_artifacts.extend(objs) return kernel_artifacts - + def get_mlir_artifact(self): # Build operator_mlir_map: {op_name -> PythonGeneratedMLIRArtifact} operator_mlir_map = {} mlir_dependencies = [] comp_runlist = [] - op_names = {} # op -> op_name + op_names = {} # op -> op_name unique_operators = [] for op, *_ in self.runlist: @@ -65,13 +70,15 @@ def get_mlir_artifact(self): op_name = f"op{idx}_{op.__class__.__name__}" op_names[op] = op_name operator_mlir_map[op_name] = mlir_artifact - + for op, *bufs in self.runlist: comp_runlist.append((op_names[op], *bufs)) - + # Calculate buffer layout: {buffer_name -> (type, offset, length)} - self.subbuffer_layout, self.buffer_sizes, self.slice_info = self._calculate_buffer_layout() - + self.subbuffer_layout, self.buffer_sizes, self.slice_info = ( + self._calculate_buffer_layout() + ) + filename = self.get_operator_name() + "_fused.mlir" fused_artifact = comp.FusedMLIRSource( filename, @@ -79,52 +86,67 @@ def get_mlir_artifact(self): runlist=comp_runlist, subbuffer_layout=self.subbuffer_layout, buffer_sizes=self.buffer_sizes, - slice_info=self.slice_info + slice_info=self.slice_info, ) - + return fused_artifact - + def _calculate_buffer_layout(self): args = {} # base_buffer_name -> args_spec - sliced_buffers = {} # full_buffer_name (with slice) -> (base_name, start, end, args_spec) - + sliced_buffers = ( + {} + ) # full_buffer_name (with slice) -> (base_name, start, end, args_spec) + # Collect all buffer specs from operators for op, *bufs in self.runlist: args_specs = op.get_arg_spec() - assert len(args_specs) == len(bufs), "Number of buffers must match operator argument specification" + assert len(args_specs) == len( + bufs + ), "Number of buffers must match operator argument specification" for i, buf_name in enumerate(bufs): args_spec = args_specs[i] - + # Parse slice notation: "buffer_name[start:end]" - if '[' in buf_name and buf_name.endswith(']'): - base_name = buf_name[:buf_name.index('[')] - slice_part = buf_name[buf_name.index('[')+1:-1] - start, end = map(int, slice_part.split(':')) + if "[" in buf_name and buf_name.endswith("]"): + base_name = buf_name[: buf_name.index("[")] + slice_part = buf_name[buf_name.index("[") + 1 : -1] + start, end = map(int, slice_part.split(":")) sliced_buffers[buf_name] = (base_name, start, end, args_spec) # Track that base buffer exists (size will be set later) - if base_name not in args and base_name not in self.explicit_buffer_sizes: - raise ValueError(f"Sliced buffer '{buf_name}' requires explicit size for base buffer '{base_name}' in buffer_sizes parameter") + if ( + base_name not in args + and base_name not in self.explicit_buffer_sizes + ): + raise ValueError( + f"Sliced buffer '{buf_name}' requires explicit size for base buffer '{base_name}' in buffer_sizes parameter" + ) else: # Regular buffer (no slice) if buf_name not in args: args[buf_name] = args_spec else: - assert np.prod(args[buf_name].shape) == np.prod(args_spec.shape), f"Buffer {buf_name} has conflicting sizes between operators" - + assert np.prod(args[buf_name].shape) == np.prod( + args_spec.shape + ), f"Buffer {buf_name} has conflicting sizes between operators" + # Verify all input/output args are present (either as regular or sliced buffers) all_buffer_names = set(args.keys()) | set(sliced_buffers.keys()) for arg in self.input_args: # Check if it's a base buffer name in explicit_buffer_sizes if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: - raise AssertionError(f"Input argument {arg} not found in runlist buffers") + raise AssertionError( + f"Input argument {arg} not found in runlist buffers" + ) for arg in self.output_args: if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: - raise AssertionError(f"Output argument {arg} not found in runlist buffers") - + raise AssertionError( + f"Output argument {arg} not found in runlist buffers" + ) + # Determine buffer types and create layout subbuffer_layout = {} slice_info = {} # full_buffer_name -> (base_name, start, end) - + def add_buffers(buffer_type, args_list): offset = 0 for arg in args_list: @@ -136,46 +158,62 @@ def add_buffers(buffer_type, args_list): elif arg in args: # Regular buffer with inferred size arg_spec = args[arg] - length = int(np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize) + length = int( + np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize + ) subbuffer_layout[arg] = (buffer_type, offset, length) offset += length # Note: sliced buffers are handled separately, not in args_list return offset # == total length - + # Add sliced buffer entries to layout (they reference parent buffers) for buf_name, (base_name, start, end, args_spec) in sliced_buffers.items(): slice_info[buf_name] = (base_name, start, end) - - input_buffer_size = add_buffers('input', self.input_args) - output_buffer_size = add_buffers('output', self.output_args) - scratch_args = [arg for arg in args if arg not in self.input_args and arg not in self.output_args] + + input_buffer_size = add_buffers("input", self.input_args) + output_buffer_size = add_buffers("output", self.output_args) + scratch_args = [ + arg + for arg in args + if arg not in self.input_args and arg not in self.output_args + ] # Also include explicit buffers that are only used for slicing for explicit_buf in self.explicit_buffer_sizes: - if explicit_buf not in self.input_args and explicit_buf not in self.output_args and explicit_buf not in scratch_args: + if ( + explicit_buf not in self.input_args + and explicit_buf not in self.output_args + and explicit_buf not in scratch_args + ): scratch_args.append(explicit_buf) - scratch_buffer_size = add_buffers('scratch', scratch_args) - + scratch_buffer_size = add_buffers("scratch", scratch_args) + buffer_sizes = (input_buffer_size, output_buffer_size, scratch_buffer_size) return subbuffer_layout, buffer_sizes, slice_info - + def set_up_artifacts(self): operator_name = self.get_operator_name() mlir_artifact = self.get_mlir_artifact() kernel_objects = self.get_kernel_artifacts() - kernel_dep = [comp.KernelArchiveArtifact( - self.kernel_archive, - dependencies=kernel_objects, - )] if kernel_objects else [] + kernel_dep = ( + [ + comp.KernelArchiveArtifact( + self.kernel_archive, + dependencies=kernel_objects, + ) + ] + if kernel_objects + else [] + ) full_elf_artifact = comp.FullElfArtifact( f"{operator_name}.elf", mlir_input=mlir_artifact, dependencies=[mlir_artifact] + kernel_dep, ) self.add_artifacts([full_elf_artifact]) - + def get_arg_spec(self): pass - + def get_callable(self): return FusedFullELFCallable(self) @@ -192,7 +230,7 @@ def get_layout_for_buffer(self, buffer_name): def load_elf(op): assert isinstance(op.artifacts[0], comp.FullElfArtifact) elf_data = None - with open(op.artifacts[0].filename, 'rb') as f: + with open(op.artifacts[0].filename, "rb") as f: elf_data = np.frombuffer(f.read(), dtype=np.uint32) return elf_data @@ -205,7 +243,13 @@ def patch_elf(elf_data, patches): class FullELFCallable: - def __init__(self, elf_data, device_name="main", sequence_name="sequence", device_manager=None): + def __init__( + self, + elf_data, + device_name="main", + sequence_name="sequence", + device_manager=None, + ): self.device_name = device_name self.sequence_name = sequence_name self.device_manager = device_manager or AIEDeviceManager() @@ -220,20 +264,22 @@ def __call__(self, *args): ret_code = run.wait() if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: raise RuntimeError(f"Kernel execution failed with return code {ret_code}") - + def reload_elf(self, elf_data): # Create a PyCapsule from the numpy array pointer for pybind11 elf_data_u8 = elf_data.view(dtype=np.uint8) ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] - capsule = ctypes.pythonapi.PyCapsule_New( - elf_data_u8.ctypes.data, - None, - None - ) + ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_void_p, + ] + capsule = ctypes.pythonapi.PyCapsule_New(elf_data_u8.ctypes.data, None, None) xrt_elf = pyxrt.elf(capsule, elf_data.nbytes) xrt_context = pyxrt.hw_context(self.device_manager.device, xrt_elf) - self.xrt_kernel = pyxrt.ext.kernel(xrt_context, f"{self.device_name}:{self.sequence_name}") + self.xrt_kernel = pyxrt.ext.kernel( + xrt_context, f"{self.device_name}:{self.sequence_name}" + ) class FusedFullELFCallable(FullELFCallable): @@ -245,61 +291,63 @@ def __init__(self, op, elf_data=None, device_manager=None): self.op = op input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes itemsize = np.dtype(ml_dtypes.bfloat16).itemsize - + self.input_buffer = AIEBuffer( shape=(max(input_buffer_size, itemsize) // itemsize,), - dtype=ml_dtypes.bfloat16 - ) - + dtype=ml_dtypes.bfloat16, + ) + self.output_buffer = AIEBuffer( shape=(max(output_buffer_size, itemsize) // itemsize,), - dtype=ml_dtypes.bfloat16 - ) - + dtype=ml_dtypes.bfloat16, + ) + self.scratch_buffer = AIEBuffer( shape=(max(scratch_buffer_size, itemsize) // itemsize,), - dtype=ml_dtypes.bfloat16 - ) - + dtype=ml_dtypes.bfloat16, + ) + self._buffer_cache = {} - + def get_buffer(self, buffer_name): # Return cached buffer if already allocated if buffer_name in self._buffer_cache: return self._buffer_cache[buffer_name] - + buf_type, offset, length = self.op.get_layout_for_buffer(buffer_name) - + # Select the appropriate main buffer - if buf_type == 'input': + if buf_type == "input": main_buffer = self.input_buffer - elif buf_type == 'output': + elif buf_type == "output": main_buffer = self.output_buffer - elif buf_type == 'scratch': + elif buf_type == "scratch": main_buffer = self.scratch_buffer else: - raise ValueError(f"Unknown buffer type '{buf_type}' for buffer '{buffer_name}'") - + raise ValueError( + f"Unknown buffer type '{buf_type}' for buffer '{buffer_name}'" + ) + if main_buffer is None: raise RuntimeError(f"Main buffer for type '{buf_type}' is not allocated") - + # Convert byte offset/length to element offset/length itemsize = np.dtype(ml_dtypes.bfloat16).itemsize offset_elements = offset // itemsize length_elements = length // itemsize - + # Create subbuffer with appropriate shape sub_buffer = main_buffer.subbuffer( length=length_elements, offset=offset_elements, shape=(length_elements,), - dtype=ml_dtypes.bfloat16 + dtype=ml_dtypes.bfloat16, ) - + # Cache and return self._buffer_cache[buffer_name] = sub_buffer return sub_buffer - + def __call__(self): self.input_buffer.to("npu") self.output_buffer.to("npu") diff --git a/operators/common/utils.py b/operators/common/utils.py index f7834e7f..9037fbd8 100644 --- a/operators/common/utils.py +++ b/operators/common/utils.py @@ -11,7 +11,6 @@ import numpy as np from ml_dtypes import bfloat16 - torch_dtype_map = { "bf16": torch.bfloat16, "f32": torch.float32, diff --git a/operators/dequant/reference.py b/operators/dequant/reference.py index a0a853ce..69a736ef 100644 --- a/operators/dequant/reference.py +++ b/operators/dequant/reference.py @@ -5,7 +5,6 @@ import numpy as np from ml_dtypes import bfloat16 - tensor_type_to_quant = {torch.uint8: torch.quint8} diff --git a/operators/elementwise_add/design.py b/operators/elementwise_add/design.py index 4a202e69..246331b7 100644 --- a/operators/elementwise_add/design.py +++ b/operators/elementwise_add/design.py @@ -15,7 +15,15 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive, func_prefix=""): +def my_eltwise_add( + dev, + num_elements, + num_columns, + tile_size, + trace_size, + kernel_archive, + func_prefix="", +): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +45,9 @@ def my_eltwise_add(dev, num_elements, num_columns, tile_size, trace_size, kernel # AIE Core Function declaration eltwise_add_bf16_vector = Kernel( - f"{func_prefix}eltwise_add_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] + f"{func_prefix}eltwise_add_bf16_vector", + kernel_archive, + [tile_ty, tile_ty, tile_ty, np.int32], ) # Define a task that will run on a compute tile diff --git a/operators/elementwise_add/op.py b/operators/elementwise_add/op.py index ee6d1c5c..095d3d03 100644 --- a/operators/elementwise_add/op.py +++ b/operators/elementwise_add/op.py @@ -29,7 +29,9 @@ def __init__( num_aie_columns=8, context=None, ): - assert size % (num_aie_columns * tile_size) == 0, "size must be multiple of num_aie_columns * tile_size" + assert ( + size % (num_aie_columns * tile_size) == 0 + ), "size must be multiple of num_aie_columns * tile_size" self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns @@ -53,7 +55,7 @@ def get_mlir_artifact(self): self.size, self.num_aie_columns, self.tile_size, - 0 + 0, ], ) diff --git a/operators/elementwise_mul/design.py b/operators/elementwise_mul/design.py index 8cae5ac8..51319004 100644 --- a/operators/elementwise_mul/design.py +++ b/operators/elementwise_mul/design.py @@ -15,7 +15,15 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, kernel_archive, func_prefix=""): +def my_eltwise_mul( + dev, + num_elements, + num_columns, + tile_size, + trace_size, + kernel_archive, + func_prefix="", +): per_tile_elements = 4096 if tile_size > 4096 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: @@ -37,7 +45,9 @@ def my_eltwise_mul(dev, num_elements, num_columns, tile_size, trace_size, kernel # AIE Core Function declaration eltwise_mul_bf16_vector = Kernel( - f"{func_prefix}eltwise_mul_bf16_vector", kernel_archive, [tile_ty, tile_ty, tile_ty, np.int32] + f"{func_prefix}eltwise_mul_bf16_vector", + kernel_archive, + [tile_ty, tile_ty, tile_ty, np.int32], ) # Define a task that will run on a compute tile diff --git a/operators/elementwise_mul/op.py b/operators/elementwise_mul/op.py index d7557a2e..c36dc4c0 100644 --- a/operators/elementwise_mul/op.py +++ b/operators/elementwise_mul/op.py @@ -22,7 +22,9 @@ def __init__( num_aie_columns=8, context=None, ): - assert size % (num_aie_columns * tile_size) == 0, "size must be multiple of num_aie_columns * tile_size" + assert ( + size % (num_aie_columns * tile_size) == 0 + ), "size must be multiple of num_aie_columns * tile_size" self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns @@ -46,7 +48,7 @@ def get_mlir_artifact(self): self.size, self.num_aie_columns, self.tile_size, - 0 + 0, ], ) diff --git a/operators/gemm/design.py b/operators/gemm/design.py index 6216079e..230c7991 100644 --- a/operators/gemm/design.py +++ b/operators/gemm/design.py @@ -23,7 +23,6 @@ from aie.helpers.taplib import TensorAccessSequence, TensorTiler2D, TensorAccessPattern from aie.iron.controlflow import range_ - microkernel_mac_dim_map = { "npu": { "bf16": (4, 8, 4), @@ -275,7 +274,11 @@ def my_matmul( # AIE Core Function declarations scalar_suffix = "_scalar" if use_scalar else "" - kernel_archive = f"{func_prefix}gemm_{m}x{k}x{n}_archive.a" if kernel_archive is None else kernel_archive + kernel_archive = ( + f"{func_prefix}gemm_{m}x{k}x{n}_archive.a" + if kernel_archive is None + else kernel_archive + ) if use_larger_internal_buffer: # Fix fifo depth for C objfifo to 1 since 1 buffer will be used for accumulation # and another for transfer to L2 diff --git a/operators/gemm/op.py b/operators/gemm/op.py index 06fed39b..0851d266 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -112,7 +112,7 @@ def get_mlir_artifact(self): }, requires_context=False, ) - + def get_kernel_artifacts(self): base_dir = self.context.base_dir emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( @@ -143,29 +143,28 @@ def get_kernel_artifacts(self): f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}.o", extra_flags=kernel_flags, dependencies=[ - SourceArtifact( - base_dir / "aie_kernels" / "aie2p" / "mm.cc" - ) + SourceArtifact(base_dir / "aie_kernels" / "aie2p" / "mm.cc") ], ), KernelObjectArtifact( "convert_copy.o", [ SourceArtifact( - base_dir - / "aie_kernels" - / "generic" - / "convert_copy.cc" + base_dir / "aie_kernels" / "generic" / "convert_copy.cc" ) ], - ) + ), ] - + def get_arg_spec(self): return [ AIERuntimeArgSpec("in", (self.M, self.K)), # input A - AIERuntimeArgSpec("in", (self.K, self.N) if not self.b_col_maj else (self.N, self.K)), # input B (weights) - AIERuntimeArgSpec("out", (self.M, self.N) if not self.c_col_maj else (self.N, self.M)), # output C + AIERuntimeArgSpec( + "in", (self.K, self.N) if not self.b_col_maj else (self.N, self.K) + ), # input B (weights) + AIERuntimeArgSpec( + "out", (self.M, self.N) if not self.c_col_maj else (self.N, self.M) + ), # output C ] # def _get_B_dims(self, B_shape): diff --git a/operators/gemv/design.py b/operators/gemv/design.py index e8cf8e12..6d48aa6d 100644 --- a/operators/gemv/design.py +++ b/operators/gemv/design.py @@ -18,7 +18,6 @@ from aie.iron.placers import SequentialPlacer from aie.iron.device import NPU1, NPU2 - """ Matrix-vector design @@ -34,7 +33,17 @@ """ -def my_matvec(dev, cols, M, K, m_input, m_output=None, num_batches=1, kernel_archive="mv.o", func_prefix=""): +def my_matvec( + dev, + cols, + M, + K, + m_input, + m_output=None, + num_batches=1, + kernel_archive="mv.o", + func_prefix="", +): if m_output is None: m_output = m_input @@ -70,9 +79,7 @@ def my_matvec(dev, cols, M, K, m_input, m_output=None, num_batches=1, kernel_arc L1_B_ty = np.ndarray[(K,), dtype_in] L1_C_ty = np.ndarray[(m_output,), dtype_out] L3_A_ty = np.ndarray[ - ( - num_batches * M * K, - ), + (num_batches * M * K,), dtype_in, ] L3_B_ty = np.ndarray[(num_batches * K,), dtype_in] @@ -174,9 +181,17 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): for batch in range(num_batches): tg_ac = rt.task_group() for col in range(cols): - rt.fill(A_L3L1_fifos[col].prod(), A, A_taps[col][batch], task_group=tg_ac) + rt.fill( + A_L3L1_fifos[col].prod(), A, A_taps[col][batch], task_group=tg_ac + ) for col in range(cols): - rt.drain(C_L1L3_fifos[col].cons(), C, C_taps[col][batch], task_group=tg_ac, wait=True) + rt.drain( + C_L1L3_fifos[col].cons(), + C, + C_taps[col][batch], + task_group=tg_ac, + wait=True, + ) rt.finish_task_group(tg_ac) rt.finish_task_group(tg_b) diff --git a/operators/gemv/op.py b/operators/gemv/op.py index 3771a3e3..268f7666 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -48,7 +48,9 @@ def __init__( self.tile_size_output = tile_size_output self.num_batches = num_batches self.kernel_vector_size = kernel_vector_size - assert K >= kernel_vector_size and K % kernel_vector_size == 0, "K must be multiple of kernel_vector_size" + assert ( + K >= kernel_vector_size and K % kernel_vector_size == 0 + ), "K must be multiple of kernel_vector_size" self.xclbin_artifact = None self.insts_artifact = None @@ -73,9 +75,9 @@ def get_mlir_artifact(self): self.tile_size_input, self.tile_size_output, self.num_batches, - ] + ], ) - + def get_kernel_artifacts(self): return [ KernelObjectArtifact( @@ -88,7 +90,7 @@ def get_kernel_artifacts(self): extra_flags=[ f"-DDIM_K={self.K}", f"-DVEC_SIZE={self.kernel_vector_size}", - ] + ], ), ] diff --git a/operators/repeat/design.py b/operators/repeat/design.py index 71b98c5f..a3539caa 100644 --- a/operators/repeat/design.py +++ b/operators/repeat/design.py @@ -2,18 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -#from aie.extras.context import mlir_mod_ctx -#from aie.ir import StridedLayoutAttr, ShapedType -#from aie.dialects.aie import * -#from aie.dialects.aiex import * + +# from aie.extras.context import mlir_mod_ctx +# from aie.ir import StridedLayoutAttr, ShapedType +# from aie.dialects.aie import * +# from aie.dialects.aiex import * from aie.dialects.aiex import TensorAccessPattern from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker from aie.iron.placers import SequentialPlacer - """ Repeat interleave """ + + def repeat(dev, dtype, rows, cols, repeat, transfer_size=None): dtype = np.dtype[dtype] @@ -25,22 +27,33 @@ def repeat(dev, dtype, rows, cols, repeat, transfer_size=None): cols_split = divisor break else: - raise ValueError(f"Cannot split cols={cols} into chunks <= 1023; hardware limits cols to not exceed 1023") + raise ValueError( + f"Cannot split cols={cols} into chunks <= 1023; hardware limits cols to not exceed 1023" + ) assert cols_split <= 1023, "cols is too large, can't split into smaller transfers" if transfer_size is None: transfer_size = cols - - inp_ty = np.ndarray[(rows, cols), dtype,] - out_ty = np.ndarray[(rows * repeat, cols), dtype,] - transfer_ty = np.ndarray[(transfer_size,), dtype,] + + inp_ty = np.ndarray[ + (rows, cols), + dtype, + ] + out_ty = np.ndarray[ + (rows * repeat, cols), + dtype, + ] + transfer_ty = np.ndarray[ + (transfer_size,), + dtype, + ] input_tap = TensorAccessPattern( tensor_dims=(rows, cols), offset=0, sizes=[repeat, rows, cols // cols_split, cols_split], strides=[0, cols, cols_split, 1], - ) + ) output_tap = TensorAccessPattern( tensor_dims=(rows * repeat, cols), diff --git a/operators/repeat/op.py b/operators/repeat/op.py index 88a1d1c9..66e30b20 100644 --- a/operators/repeat/op.py +++ b/operators/repeat/op.py @@ -54,7 +54,7 @@ def get_mlir_artifact(self): self.cols, self.repeat, self.transfer_size, - ] + ], ) def get_kernel_artifacts(self): diff --git a/operators/rms_norm/design.py b/operators/rms_norm/design.py index 4b3a2da4..583ca8f6 100644 --- a/operators/rms_norm/design.py +++ b/operators/rms_norm/design.py @@ -15,7 +15,15 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_rms_norm(dev, num_elements, num_columns, num_channels, trace_size, tile_size, kernel_archive="rms_norm.a"): +def my_rms_norm( + dev, + num_elements, + num_columns, + num_channels, + trace_size, + tile_size, + kernel_archive="rms_norm.a", +): per_tile_elements = 8192 if tile_size > 8192 else tile_size n = per_tile_elements * num_columns if num_elements % n != 0: diff --git a/operators/rms_norm/design_weighted.py b/operators/rms_norm/design_weighted.py index b6457ef3..fab3caac 100644 --- a/operators/rms_norm/design_weighted.py +++ b/operators/rms_norm/design_weighted.py @@ -16,7 +16,14 @@ def my_weighted_rms_norm( - dev, num_elements, num_columns, num_channels, weight_length, trace_size, kernel_archive="rms_norm.a", func_prefix="" + dev, + num_elements, + num_columns, + num_channels, + weight_length, + trace_size, + kernel_archive="rms_norm.a", + func_prefix="", ): per_tile_elements = weight_length total_cores = num_columns # For each core that does rms norm, another core will take its output to do eltwise mul @@ -53,7 +60,9 @@ def my_weighted_rms_norm( # AIE Core Function declaration rms_norm_kernel = Kernel( - f"{func_prefix}rms_norm_bf16_vector", kernel_archive, [tile_ty, tile_ty, np.int32] + f"{func_prefix}rms_norm_bf16_vector", + kernel_archive, + [tile_ty, tile_ty, np.int32], ) eltwise_mul_kernel = Kernel( f"{func_prefix}eltwise_mul_bf16_vector", diff --git a/operators/rope/design.py b/operators/rope/design.py index bbb5ed79..f486071d 100644 --- a/operators/rope/design.py +++ b/operators/rope/design.py @@ -14,7 +14,6 @@ from aie.helpers.dialects.scf import _for as range_ from ml_dtypes import bfloat16 - """ Rotary Positional Encoding (RoPE) design @@ -39,14 +38,16 @@ def rope( trace_size=0, method_type=None, kernel_archive=None, - func_prefix="" + func_prefix="", ): dtype = bfloat16 if angle_rows is None: angle_rows = rows if kernel_archive is None: - kernel_archive = "rope" + (f"_{method_type}" if method_type is not None else "") + ".o" + kernel_archive = ( + "rope" + (f"_{method_type}" if method_type is not None else "") + ".o" + ) assert cols % (16 * 2) == 0 and cols >= ( 16 * 2 diff --git a/operators/rope/op.py b/operators/rope/op.py index 6ed264ff..a3535f75 100644 --- a/operators/rope/op.py +++ b/operators/rope/op.py @@ -25,12 +25,18 @@ def __init__( ): if angle_rows is None: angle_rows = rows - - assert cols % (16 * 2) == 0 and cols >= (16 * 2), "cols must be multiple of 32 and >= 32" + + assert cols % (16 * 2) == 0 and cols >= ( + 16 * 2 + ), "cols must be multiple of 32 and >= 32" assert rows % num_aie_columns == 0, "rows must be divisible by num_aie_columns" - assert angle_rows <= rows and rows % angle_rows == 0, "angle_rows must divide rows" - assert angle_rows >= num_aie_columns and angle_rows % num_aie_columns == 0, "angle_rows must be divisible by num_aie_columns" - + assert ( + angle_rows <= rows and rows % angle_rows == 0 + ), "angle_rows must divide rows" + assert ( + angle_rows >= num_aie_columns and angle_rows % num_aie_columns == 0 + ), "angle_rows must be divisible by num_aie_columns" + self.rows = rows self.cols = cols self.angle_rows = angle_rows @@ -77,7 +83,25 @@ def get_kernel_artifacts(self): def get_arg_spec(self): return [ - AIERuntimeArgSpec("in", (self.rows, self.cols,)), # input tensor - AIERuntimeArgSpec("in", (self.angle_rows, self.cols,)), # angles - AIERuntimeArgSpec("out", (self.rows, self.cols,)), # output + AIERuntimeArgSpec( + "in", + ( + self.rows, + self.cols, + ), + ), # input tensor + AIERuntimeArgSpec( + "in", + ( + self.angle_rows, + self.cols, + ), + ), # angles + AIERuntimeArgSpec( + "out", + ( + self.rows, + self.cols, + ), + ), # output ] diff --git a/operators/silu/design.py b/operators/silu/design.py index 0d0e6d74..4c041afb 100644 --- a/operators/silu/design.py +++ b/operators/silu/design.py @@ -15,7 +15,9 @@ from aie.helpers.util import np_ndarray_type_get_shape -def my_silu(dev, size, num_columns, tile_size, trace_size, kernel_archive, func_prefix=""): +def my_silu( + dev, size, num_columns, tile_size, trace_size, kernel_archive, func_prefix="" +): xfr_dtype = bfloat16 line_size = 4096 if tile_size > 4096 else tile_size line_type = np.ndarray[(line_size,), np.dtype[xfr_dtype]] diff --git a/operators/silu/op.py b/operators/silu/op.py index 8adc6f4c..942ebe12 100644 --- a/operators/silu/op.py +++ b/operators/silu/op.py @@ -16,7 +16,9 @@ class AIESiLU(SingleMLIRSourceOperator): """AIE-accelerated SiLU activation function""" def __init__(self, size, tile_size, num_aie_columns=8, context=None): - assert size % (num_aie_columns * tile_size) == 0, "size must be multiple of num_aie_columns * tile_size" + assert ( + size % (num_aie_columns * tile_size) == 0 + ), "size must be multiple of num_aie_columns * tile_size" self.size = size self.tile_size = tile_size self.num_aie_columns = num_aie_columns @@ -40,8 +42,8 @@ def get_mlir_artifact(self): self.size, self.num_aie_columns, self.tile_size, - 0 - ] + 0, + ], ) def get_kernel_artifacts(self): diff --git a/operators/softmax/design.py b/operators/softmax/design.py index e7e12fc0..567dbbc6 100644 --- a/operators/softmax/design.py +++ b/operators/softmax/design.py @@ -7,7 +7,15 @@ import argparse import sys -from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker, Buffer, WorkerRuntimeBarrier +from aie.iron import ( + Kernel, + ObjectFifo, + Program, + Runtime, + Worker, + Buffer, + WorkerRuntimeBarrier, +) from aie.iron.placers import SequentialPlacer from aie.iron.device import NPU1, NPU2 from aie.helpers.taplib.tap import TensorAccessPattern @@ -15,7 +23,18 @@ from ml_dtypes import bfloat16 -def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_size, rtp_vector_size=None, mask_patch_value=0, kernel_archive="softmax.a", func_prefix=""): +def softmax( + dev, + num_elements, + num_aie_columns, + num_channels, + trace_size, + tile_size, + rtp_vector_size=None, + mask_patch_value=0, + kernel_archive="softmax.a", + func_prefix="", +): per_tile_elements = tile_size if rtp_vector_size is None: rtp_vector_size = per_tile_elements @@ -45,8 +64,12 @@ def softmax(dev, num_elements, num_aie_columns, num_channels, trace_size, tile_s ] # AIE Core Function declaration - softmax_kernel = Kernel(f"{func_prefix}softmax_bf16", kernel_archive, [tile_ty, tile_ty, np.int32]) - mask_kernel = Kernel(f"{func_prefix}mask_bf16", kernel_archive, [tile_ty, np.int32, np.int32]) + softmax_kernel = Kernel( + f"{func_prefix}softmax_bf16", kernel_archive, [tile_ty, tile_ty, np.int32] + ) + mask_kernel = Kernel( + f"{func_prefix}mask_bf16", kernel_archive, [tile_ty, np.int32, np.int32] + ) # Define a task that will run on a compute tile def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): @@ -60,7 +83,7 @@ def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): softmax_kernel(elem_in1, elem_out, per_tile_elements) of_in1.release(1) of_out.release(1) - + rtps = [ Buffer( np.ndarray[(1,), np.dtype[np.int32]], @@ -87,7 +110,7 @@ def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): softmax_kernel, mask_kernel, rtps[i * num_channels + j], - barriers[i * num_channels + j] + barriers[i * num_channels + j], ], ) for i in range(num_aie_columns) @@ -118,12 +141,10 @@ def core_body(of_in1, of_out, softmax_kernel, mask_kernel, rtp, barrier): # Set run-time parameter for actual vector size (remainder is considered padding and ignored by the computation) def set_rtps(*args): for rtp in args: - rtp[0] = ( - rtp_vector_size if not mask_patch_value else mask_patch_value - ) + rtp[0] = rtp_vector_size if not mask_patch_value else mask_patch_value rt.inline_ops(set_rtps, rtps) - + for i in range(num_aie_columns * num_channels): rt.set_barrier(barriers[i], 1) diff --git a/operators/softmax/op.py b/operators/softmax/op.py index 0f6c8b10..93d0c299 100644 --- a/operators/softmax/op.py +++ b/operators/softmax/op.py @@ -16,11 +16,22 @@ class AIESoftmax(SingleMLIRSourceOperator): """AIE-accelerated Softmax operation""" - def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, rtp_vector_size=None, mask_patch_value=0, context=None): + def __init__( + self, + rows: int, + cols: int, + num_aie_columns=1, + num_channels=1, + rtp_vector_size=None, + mask_patch_value=0, + context=None, + ): assert rows % 16 == 0, "rows must be multiple of 16" assert cols % 16 == 0, "cols must be multiple of 16" - assert (rows * cols) % (num_aie_columns * cols) == 0, "size must be multiple of num_aie_columns * tile_size" - + assert (rows * cols) % ( + num_aie_columns * cols + ) == 0, "size must be multiple of num_aie_columns * tile_size" + self.rows = rows self.cols = cols self.size = rows * cols @@ -28,7 +39,7 @@ def __init__(self, rows: int, cols: int, num_aie_columns=1, num_channels=1, rtp_ self.num_channels = num_channels self.rtp_vector_size = rtp_vector_size self.mask_patch_value = mask_patch_value - + SingleMLIRSourceOperator.__init__(self, context=context) def get_operator_name(self): @@ -51,7 +62,7 @@ def get_mlir_artifact(self): 0, # trace_size self.cols, self.rtp_vector_size, - self.mask_patch_value + self.mask_patch_value, ], ) diff --git a/operators/softmax/test2.py b/operators/softmax/test2.py index a3dfb149..864413f8 100755 --- a/operators/softmax/test2.py +++ b/operators/softmax/test2.py @@ -4,20 +4,21 @@ from pathlib import Path import time import torch + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from operators.softmax.op import AIESoftmax from operators.common import AIEBuffer -max_context_len=2048 -prompt_len=8 -n_heads=32 +max_context_len = 2048 +prompt_len = 8 +n_heads = 32 -softmax_op = AIESoftmax( - rows=n_heads, - cols=max_context_len, - rtp_vector_size=prompt_len -).compile().get_callable() +softmax_op = ( + AIESoftmax(rows=n_heads, cols=max_context_len, rtp_vector_size=prompt_len) + .compile() + .get_callable() +) inp = AIEBuffer((n_heads, max_context_len)) out = AIEBuffer((n_heads, max_context_len)) @@ -53,7 +54,8 @@ if len(mismatches[0]) > 0: for i in range(min(10, len(mismatches[0]))): h, s = mismatches[0][i], mismatches[1][i] - print(f"Mismatch at head={h}, seq={s}: ref={out_ref[h,s]}, aie={aie_out[h,s]}, diff={diff[h,s]}") + print( + f"Mismatch at head={h}, seq={s}: ref={out_ref[h,s]}, aie={aie_out[h,s]}, diff={diff[h,s]}" + ) assert torch.allclose(out_ref, aie_out, atol=1e-2, rtol=1e-2) - diff --git a/operators/strided_copy/design.py b/operators/strided_copy/design.py index 57da2ec5..63b97e33 100644 --- a/operators/strided_copy/design.py +++ b/operators/strided_copy/design.py @@ -2,22 +2,39 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -#from aie.extras.context import mlir_mod_ctx -#from aie.ir import StridedLayoutAttr, ShapedType -#from aie.dialects.aie import * -#from aie.dialects.aiex import * + +# from aie.extras.context import mlir_mod_ctx +# from aie.ir import StridedLayoutAttr, ShapedType +# from aie.dialects.aie import * +# from aie.dialects.aiex import * from aie.dialects.aiex import TensorAccessPattern from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker from aie.iron.placers import SequentialPlacer - """ Strided copy design This can be useful for data layout manipulation and data copying such as: input[0, :, 0] -> output[:, 0, 0] """ -def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, input_offset, output_buffer_size, output_sizes, output_strides, output_offset, transfer_size=None, num_aie_channels=1, input_offset_patch_marker=0, output_offset_patch_marker=0): + + +def strided_copy( + dev, + dtype, + input_buffer_size, + input_sizes, + input_strides, + input_offset, + output_buffer_size, + output_sizes, + output_strides, + output_offset, + transfer_size=None, + num_aie_channels=1, + input_offset_patch_marker=0, + output_offset_patch_marker=0, +): assert len(input_sizes) == len(input_strides) assert len(output_sizes) == len(output_strides) @@ -29,31 +46,48 @@ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, inpu input_highest_sz_idx = max(idx for idx, sz in enumerate(input_sizes) if sz >= 1) output_highest_sz_idx = max(idx for idx, sz in enumerate(output_sizes) if sz >= 1) - assert input_sizes[input_highest_sz_idx] % num_aie_channels == 0, "Highest dimension of input_sizes must be divisible by num_aie_channels" - assert output_sizes[output_highest_sz_idx] % num_aie_channels == 0, "Highest dimension of output_sizes must be divisible by num_aie_channels" + assert ( + input_sizes[input_highest_sz_idx] % num_aie_channels == 0 + ), "Highest dimension of input_sizes must be divisible by num_aie_channels" + assert ( + output_sizes[output_highest_sz_idx] % num_aie_channels == 0 + ), "Highest dimension of output_sizes must be divisible by num_aie_channels" if transfer_size is None: transfer_size = int(np.prod(input_sizes)) assert np.prod(input_sizes) % transfer_size == 0 - transfer_ty = np.ndarray[(transfer_size,), np.dtype[dtype],] - - inp_ty = np.ndarray[(int(input_buffer_size),), np.dtype[dtype],] - out_ty = np.ndarray[(int(output_buffer_size),), np.dtype[dtype],] + transfer_ty = np.ndarray[ + (transfer_size,), + np.dtype[dtype], + ] + + inp_ty = np.ndarray[ + (int(input_buffer_size),), + np.dtype[dtype], + ] + out_ty = np.ndarray[ + (int(output_buffer_size),), + np.dtype[dtype], + ] input_taps = [ TensorAccessPattern( tensor_dims=(int(input_buffer_size + input_offset_patch_marker),), offset=( - input_offset_patch_marker if input_offset_patch_marker != 0 else - input_offset + c * (input_sizes[input_highest_sz_idx] // num_aie_channels) * input_strides[input_highest_sz_idx] + input_offset_patch_marker + if input_offset_patch_marker != 0 + else input_offset + + c + * (input_sizes[input_highest_sz_idx] // num_aie_channels) + * input_strides[input_highest_sz_idx] ), sizes=( input_sizes[:input_highest_sz_idx] + [input_sizes[input_highest_sz_idx] // num_aie_channels] - + input_sizes[input_highest_sz_idx+1:] + + input_sizes[input_highest_sz_idx + 1 :] ), strides=list(input_strides), - ) + ) for c in range(num_aie_channels) ] @@ -61,13 +95,17 @@ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, inpu TensorAccessPattern( tensor_dims=(int(output_buffer_size + output_offset_patch_marker),), offset=( - output_offset_patch_marker if output_offset_patch_marker != 0 else - output_offset + c * (output_sizes[output_highest_sz_idx] // num_aie_channels) * output_strides[output_highest_sz_idx] + output_offset_patch_marker + if output_offset_patch_marker != 0 + else output_offset + + c + * (output_sizes[output_highest_sz_idx] // num_aie_channels) + * output_strides[output_highest_sz_idx] ), sizes=( output_sizes[:output_highest_sz_idx] + [output_sizes[output_highest_sz_idx] // num_aie_channels] - + output_sizes[output_highest_sz_idx+1:] + + output_sizes[output_highest_sz_idx + 1 :] ), strides=list(output_strides), ) @@ -75,8 +113,14 @@ def strided_copy(dev, dtype, input_buffer_size, input_sizes, input_strides, inpu ] # Use smaller FIFOs for the transfer amount - fifos_in = [ObjectFifo(transfer_ty, name=f"fifo_in_{c}", depth=1) for c in range(num_aie_channels)] - fifos_out = [fifos_in[c].cons().forward(name=f"fifo_out_{c}", depth=1) for c in range(num_aie_channels)] + fifos_in = [ + ObjectFifo(transfer_ty, name=f"fifo_in_{c}", depth=1) + for c in range(num_aie_channels) + ] + fifos_out = [ + fifos_in[c].cons().forward(name=f"fifo_out_{c}", depth=1) + for c in range(num_aie_channels) + ] rt = Runtime() with rt.sequence(inp_ty, out_ty) as (inp, out): diff --git a/operators/strided_copy/op.py b/operators/strided_copy/op.py index 452743dd..c8d04bbf 100644 --- a/operators/strided_copy/op.py +++ b/operators/strided_copy/op.py @@ -20,11 +20,11 @@ class AIEStridedCopy(SingleMLIRSourceOperator): def __init__( self, - input_sizes, - input_strides, - input_offset, - output_sizes, - output_strides, + input_sizes, + input_strides, + input_offset, + output_sizes, + output_strides, output_offset, input_buffer_size, output_buffer_size, @@ -32,7 +32,7 @@ def __init__( transfer_size=None, num_aie_channels=1, context=None, - **kwargs + **kwargs, ): assert len(input_sizes) == len(input_strides) assert len(output_sizes) == len(output_strides) diff --git a/operators/strided_copy/test2.py b/operators/strided_copy/test2.py index 7205de16..46627022 100755 --- a/operators/strided_copy/test2.py +++ b/operators/strided_copy/test2.py @@ -4,6 +4,7 @@ from pathlib import Path import time import torch + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from operators.strided_copy.op import AIEStridedCopy @@ -15,32 +16,64 @@ head_dim = 64 num_heads = 32 -transpose_concat = AIEStridedCopy( - input_sizes=(1, num_heads, prompt_len, head_dim,), - input_strides=(0, head_dim, num_heads * head_dim, 1,), - input_offset=0, - output_sizes=(1, num_heads, prompt_len, head_dim,), - output_strides=(0, max_prompt_len * head_dim, head_dim, 1,), - output_offset=cached_prompt_len * head_dim, - input_buffer_size=prompt_len * num_heads * head_dim, - output_buffer_size=num_heads * max_prompt_len * head_dim, - num_aie_channels=1 -).compile().get_callable() +transpose_concat = ( + AIEStridedCopy( + input_sizes=( + 1, + num_heads, + prompt_len, + head_dim, + ), + input_strides=( + 0, + head_dim, + num_heads * head_dim, + 1, + ), + input_offset=0, + output_sizes=( + 1, + num_heads, + prompt_len, + head_dim, + ), + output_strides=( + 0, + max_prompt_len * head_dim, + head_dim, + 1, + ), + output_offset=cached_prompt_len * head_dim, + input_buffer_size=prompt_len * num_heads * head_dim, + output_buffer_size=num_heads * max_prompt_len * head_dim, + num_aie_channels=1, + ) + .compile() + .get_callable() +) value_cache_1 = AIEBuffer((num_heads, max_prompt_len, head_dim)) value_1 = AIEBuffer((prompt_len, num_heads, head_dim)) -value_cache_1.view_as_torch()[:, :cached_prompt_len, :] = torch.randn(num_heads, cached_prompt_len, head_dim) -value_1.view_as_torch()[:prompt_len, :, :] = torch.randn(prompt_len, num_heads, head_dim) +value_cache_1.view_as_torch()[:, :cached_prompt_len, :] = torch.randn( + num_heads, cached_prompt_len, head_dim +) +value_1.view_as_torch()[:prompt_len, :, :] = torch.randn( + prompt_len, num_heads, head_dim +) value_cache = AIEBuffer((num_heads, max_prompt_len, head_dim)) value = AIEBuffer((prompt_len, num_heads, head_dim)) -value_cache.view_as_torch()[:, :cached_prompt_len, :] = torch.randn(num_heads, cached_prompt_len, head_dim) +value_cache.view_as_torch()[:, :cached_prompt_len, :] = torch.randn( + num_heads, cached_prompt_len, head_dim +) value.view_as_torch()[:prompt_len, :, :] = torch.randn(prompt_len, num_heads, head_dim) t_cpu_start = time.perf_counter() value_transposed = value.view_as_torch().transpose(0, 1) -out_ref = torch.cat([value_cache.view_as_torch()[:, :cached_prompt_len, :], value_transposed], dim=1) +out_ref = torch.cat( + [value_cache.view_as_torch()[:, :cached_prompt_len, :], value_transposed], dim=1 +) t_cpu = time.perf_counter() - t_cpu_start transpose_concat(value_1, value_cache_1) @@ -53,7 +86,7 @@ print(out_ref) print(t_cpu) -aie_out = value_cache.view_as_torch()[:, :(cached_prompt_len + prompt_len), :] +aie_out = value_cache.view_as_torch()[:, : (cached_prompt_len + prompt_len), :] print(aie_out) print(t_aie) @@ -68,7 +101,8 @@ if len(mismatches[0]) > 0: for i in range(min(10, len(mismatches[0]))): h, s, d = mismatches[0][i], mismatches[1][i], mismatches[2][i] - print(f"Mismatch at head={h}, seq={s}, dim={d}: ref={out_ref[h,s,d]}, aie={aie_out[h,s,d]}, diff={diff[h,s,d]}") + print( + f"Mismatch at head={h}, seq={s}, dim={d}: ref={out_ref[h,s,d]}, aie={aie_out[h,s,d]}, diff={diff[h,s,d]}" + ) assert torch.allclose(out_ref, aie_out, atol=1e-2, rtol=1e-2) - diff --git a/operators/transpose/design.py b/operators/transpose/design.py index dcf2636e..03fad5d3 100644 --- a/operators/transpose/design.py +++ b/operators/transpose/design.py @@ -10,7 +10,9 @@ from aie.iron.controlflow import range_ -def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, kernel_archive=None, func_prefix=""): +def shuffle_transpose( + dev, M, N, num_columns, num_channels, m, n, s, kernel_archive=None, func_prefix="" +): num_elements = M * N per_tile_elements = m * n dtype = bfloat16 diff --git a/operators/transpose/op.py b/operators/transpose/op.py index d3db0cb5..8a2e3adc 100644 --- a/operators/transpose/op.py +++ b/operators/transpose/op.py @@ -20,8 +20,10 @@ def __init__(self, M, N, num_aie_columns, num_channels, m, n, s, context=None): assert N % n == 0, f"Matrix columns ({N}) must be a multiple of {n}" assert m % s == 0, f"AIE tile rows ({m}) must be a multiple of {s}" assert n % s == 0, f"AIE tile columns ({n}) must be a multiple of {s}" - assert M * N % (m * n * num_aie_columns * num_channels) == 0, "Transfer size must be divisible by m*n*num_columns*num_channels" - + assert ( + M * N % (m * n * num_aie_columns * num_channels) == 0 + ), "Transfer size must be divisible by m*n*num_columns*num_channels" + self.M = M self.N = N self.m = m @@ -29,7 +31,7 @@ def __init__(self, M, N, num_aie_columns, num_channels, m, n, s, context=None): self.s = s self.num_columns = num_aie_columns self.num_channels = num_channels - + SingleMLIRSourceOperator.__init__(self, context=context) def get_operator_name(self): From 9f29a92757222c446616d1bbe5d32be6121a050f Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Thu, 5 Feb 2026 16:35:21 -0700 Subject: [PATCH 71/99] a few more formatting fixes --- operators/common/compilation/__init__.py | 2 +- operators/common/compilation/fusion.py | 93 ++++++++++++++++-------- 2 files changed, 62 insertions(+), 33 deletions(-) diff --git a/operators/common/compilation/__init__.py b/operators/common/compilation/__init__.py index a0c2b126..9e924f5a 100644 --- a/operators/common/compilation/__init__.py +++ b/operators/common/compilation/__init__.py @@ -1,2 +1,2 @@ from .base import * -from .fusion import * \ No newline at end of file +from .fusion import * diff --git a/operators/common/compilation/fusion.py b/operators/common/compilation/fusion.py index 19b560d8..ea1d47e2 100644 --- a/operators/common/compilation/fusion.py +++ b/operators/common/compilation/fusion.py @@ -22,13 +22,20 @@ PythonGeneratedMLIRArtifact, ) - # Compilation Artifacts # ########################################################################## class FusedMLIRSource(CompilationArtifact): - def __init__(self, filename, operator_mlir_map, runlist, subbuffer_layout, buffer_sizes, slice_info=None): + def __init__( + self, + filename, + operator_mlir_map, + runlist, + subbuffer_layout, + buffer_sizes, + slice_info=None, + ): dependencies = list(operator_mlir_map.values()) super().__init__(filename, dependencies) self.operator_mlir_map = operator_mlir_map @@ -46,11 +53,14 @@ def extract_runtime_sequence_arg_types(dev_op): """MLIR helper: Extract argument types from a device operation's runtime sequence.""" for nested_op in dev_op.body_region.blocks[0].operations: op_name = nested_op.operation.name - if op_name == 'aie.runtime_sequence': - if hasattr(nested_op, 'body') and hasattr(nested_op.body, 'blocks'): + if op_name == "aie.runtime_sequence": + if hasattr(nested_op, "body") and hasattr(nested_op.body, "blocks"): if len(nested_op.body.blocks) > 0: entry_block = nested_op.body.blocks[0] - arg_types = [entry_block.arguments[i].type for i in range(len(entry_block.arguments))] + arg_types = [ + entry_block.arguments[i].type + for i in range(len(entry_block.arguments)) + ] return arg_types raise RuntimeError("Could not find runtime sequence in device operation") @@ -63,10 +73,10 @@ def get_child_mlir_module(mlir_artifact): ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + if mlir_artifact.requires_context: raise NotImplementedError("Not handled, make your operator return a ctx.module") - + callback_function = getattr(module, mlir_artifact.callback_fn) mlir_module = callback_function( *mlir_artifact.callback_args, **mlir_artifact.callback_kwargs @@ -85,14 +95,18 @@ def fuse_mlir(artifact): sequence_arg_types = {} for op_name, mlir_artifact in artifact.operator_mlir_map.items(): mlir_module = get_child_mlir_module(mlir_artifact) - device_ops = [op for op in mlir_module.body.operations if isinstance(op, aie.DeviceOp)] - assert len(device_ops) == 1, f"Expected exactly one device operation in MLIR artifact for operator '{op_name}'" + device_ops = [ + op for op in mlir_module.body.operations if isinstance(op, aie.DeviceOp) + ] + assert ( + len(device_ops) == 1 + ), f"Expected exactly one device operation in MLIR artifact for operator '{op_name}'" device_op = device_ops[0] if device_ty is None: device_ty = device_op.device device_mlir_strings[op_name] = str(device_op) sequence_arg_types[op_name] = extract_runtime_sequence_arg_types(device_op) - + # Build fused MLIR module with mlir_mod_ctx() as ctx: @@ -101,11 +115,13 @@ def fuse_mlir(artifact): dev_op = aie.DeviceOp.parse(device_str) dev_op.sym_name = ir.StringAttr.get(op_name) ctx.module.body.append(dev_op) - + # Create the main device -- this contains the runtime sequence calling into the other devices @aie.device(device_ty) def main(): - buf_dtype = np.dtype[ml_dtypes.bfloat16] # TODO: support for other data types + buf_dtype = np.dtype[ + ml_dtypes.bfloat16 + ] # TODO: support for other data types itemsize = 2 # RuntimeSequenceOp @@ -116,11 +132,11 @@ def main(): ) def sequence(input_buf, output_buf, scratch_buf): consolidated_buffers = { - 'input': input_buf, - 'output': output_buf, - 'scratch': scratch_buf + "input": input_buf, + "output": output_buf, + "scratch": scratch_buf, } - + # Execute operations in runlist order configure_op = None last_op_name = None @@ -131,12 +147,14 @@ def sequence(input_buf, output_buf, scratch_buf): if configure_op is None or op_name != last_op_name: # Configure Op configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(op_name) - configure_op = aiex.ConfigureOp(configure_sym_ref_attr) # TODO: optimization -- if previous op was in the same device, skip reconfiguration + configure_op = aiex.ConfigureOp( + configure_sym_ref_attr + ) # TODO: optimization -- if previous op was in the same device, skip reconfiguration configure_body = configure_op.body.blocks.append() last_op_name = op_name - + with ir.InsertionPoint(configure_body): - + # For each buffer, add subview and reinterpret_cast ops buffer_ssa_values = [] for idx, buf_name in enumerate(buffer_names): @@ -144,14 +162,18 @@ def sequence(input_buf, output_buf, scratch_buf): if buf_name in artifact.slice_info: base_name, start, end = artifact.slice_info[buf_name] # Get parent buffer info - buf_type, parent_offset, parent_length = artifact.subbuffer_layout[base_name] + buf_type, parent_offset, parent_length = ( + artifact.subbuffer_layout[base_name] + ) # Calculate actual offset and length for slice offset = parent_offset + start length = end - start else: # Regular buffer - buf_type, offset, length = artifact.subbuffer_layout[buf_name] - + buf_type, offset, length = artifact.subbuffer_layout[ + buf_name + ] + # Subview Op consolidated_buf = consolidated_buffers[buf_type] offset_elements = offset // itemsize @@ -160,21 +182,28 @@ def sequence(input_buf, output_buf, scratch_buf): consolidated_buf, [offset_elements], [size_elements], - [1] + [1], ) - + # Reinterpret_cast Op target_type = expected_arg_types[idx] expected_memref = ir.MemRefType(target_type) - target_shape = [expected_memref.shape[i] for i in range(expected_memref.rank)] + target_shape = [ + expected_memref.shape[i] + for i in range(expected_memref.rank) + ] expected_size = np.prod(target_shape) - assert expected_size == size_elements, f"Size mismatch for buffer '{buf_name}': MLIR runtime sequence expected {expected_size}, Python fused operator provided {size_elements}" + assert ( + expected_size == size_elements + ), f"Size mismatch for buffer '{buf_name}': MLIR runtime sequence expected {expected_size}, Python fused operator provided {size_elements}" strides = [] stride = 1 for dim in reversed(target_shape): strides.insert(0, stride) stride *= dim - result_type = ir.MemRefType.get(target_shape, ir.BF16Type.get()) + result_type = ir.MemRefType.get( + target_shape, ir.BF16Type.get() + ) reinterpreted = memref.reinterpret_cast( result=result_type, source=subview, @@ -183,14 +212,14 @@ def sequence(input_buf, output_buf, scratch_buf): strides=[], static_offsets=[0], static_sizes=target_shape, - static_strides=strides + static_strides=strides, ) buffer_ssa_values.append(reinterpreted) - + # Run Op sequence_sym_ref_attr = ir.FlatSymbolRefAttr.get("sequence") run_op = aiex.RunOp(sequence_sym_ref_attr, buffer_ssa_values) - + # Write the fused MLIR to file with open(artifact.filename, "w") as f: f.write(str(ctx.module)) @@ -202,10 +231,10 @@ def sequence(input_buf, output_buf, scratch_buf): class FusePythonGeneratedMLIRCompilationRule(CompilationRule): """Compilation rule that fuses multiple MLIR modules into one.""" - + def matches(self, graph): return any(graph.get_worklist(FusedMLIRSource)) - + def compile(self, graph): commands = [] worklist = graph.get_worklist(FusedMLIRSource) From 404211ecb70f6329fae2237ca485c39c9f318311 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Thu, 5 Feb 2026 17:09:20 -0700 Subject: [PATCH 72/99] some refactoring --- operators/axpy/op.py | 6 +- operators/common/__init__.py | 4 +- operators/common/base.py | 33 +++++- operators/common/fusion.py | 7 +- operators/common/test_utils.py | 12 +- operators/dequant/op.py | 6 +- operators/elementwise_add/op.py | 6 +- operators/elementwise_mul/op.py | 6 +- operators/gelu/op.py | 6 +- operators/gemm/op.py | 6 +- operators/gemv/op.py | 6 +- operators/layer_norm/op.py | 6 +- operators/leaky_relu/op.py | 6 +- operators/mem_copy/op.py | 6 +- operators/mha/op.py | 6 +- operators/relu/op.py | 6 +- operators/repeat/op.py | 6 +- operators/rms_norm/op.py | 6 +- operators/rope/op.py | 6 +- operators/sigmoid/op.py | 6 +- operators/silu/op.py | 6 +- operators/softmax/op.py | 6 +- operators/strided_copy/op.py | 6 +- operators/swiglu_decode/op.py | 155 ++++++++++++++------------ operators/swiglu_prefill/op.py | 191 +++++++++++++------------------- operators/tanh/op.py | 6 +- operators/transpose/op.py | 6 +- 27 files changed, 265 insertions(+), 263 deletions(-) diff --git a/operators/axpy/op.py b/operators/axpy/op.py index d1b1e269..91269ebc 100644 --- a/operators/axpy/op.py +++ b/operators/axpy/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIEAXPY(SingleMLIRSourceOperator): +class AIEAXPY(MLIROperator): """AIE-accelerated aX + Y operator""" def __init__( @@ -39,7 +39,7 @@ def __init__( self.num_channels = num_channels self.scalar_factor = scalar_factor - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"axpy_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t_{self.scalar_factor}s" diff --git a/operators/common/__init__.py b/operators/common/__init__.py index 89118abf..68cafac6 100644 --- a/operators/common/__init__.py +++ b/operators/common/__init__.py @@ -5,7 +5,9 @@ from .base import ( AIEOperatorBase, - SingleMLIRSourceOperator, + MLIROperator, + CompositeOperator, + CompositeCallable, AIEBuffer, SingleXclbinCallable, AIERuntimeArgSpec, diff --git a/operators/common/base.py b/operators/common/base.py index 7688c249..66c42c02 100644 --- a/operators/common/base.py +++ b/operators/common/base.py @@ -95,8 +95,8 @@ def execute_runlist(runlist): runlist.wait() -class SingleMLIRSourceOperator(AIEOperatorBase, ABC): - """Base class for AIE-accelerated operations""" +class MLIROperator(AIEOperatorBase, ABC): + """Base class for AIE-accelerated operations defined by a single MLIR source""" def __init__(self, *args, **kwargs): self.kernel_archive = f"{self.get_operator_name()}_kernels.a" @@ -159,6 +159,13 @@ def get_callable(self): ) +class CompositeOperator(AIEOperatorBase, ABC): + """Base class for composite operators that chain multiple sub-operators""" + + def __init__(self, context=None): + super().__init__(context) + + class AIERuntimeArgSpec: def __init__(self, direction, shape, dtype=bfloat16): self.shape = shape @@ -322,3 +329,25 @@ def patch(self, patches): for pos, (val, mask) in patches.items(): insts[pos] = (np.int64(insts[pos]) & ~mask) | (val & mask) self.insts_buffer.to("npu") + + +class CompositeCallable: + """Callable for executing a sequence of sub-operators""" + + def __init__(self, sequence, intermediate_buffers=None): + """ + Args: + sequence: List of (callable, args_indices) tuples. + args_indices is a list of indices into the combined list of [inputs, outputs, intermediates]. + intermediate_buffers: List of AIEBuffer objects for intermediate results. + """ + self.sequence = sequence + self.intermediate_buffers = intermediate_buffers or [] + + def __call__(self, *args): + # args contains inputs and outputs + all_buffers = list(args) + self.intermediate_buffers + + for op_callable, indices in self.sequence: + op_args = [all_buffers[i] for i in indices] + op_callable(*op_args) diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 42e7db61..748b55e9 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -3,7 +3,7 @@ import pyxrt import ctypes from . import compilation as comp -from .base import AIEOperatorBase, SingleMLIRSourceOperator, AIEBuffer +from .base import AIEOperatorBase, MLIROperator, AIEBuffer from .device_manager import AIEDeviceManager # Fused Operator @@ -11,14 +11,13 @@ class FusedMLIROperator(AIEOperatorBase): - """Operator that fuses multiple SingleMLIRSourceOperators into one.""" + """Operator that fuses multiple MLIROperators into one.""" def __init__( self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs ): assert all( - isinstance(op, SingleMLIRSourceOperator) - and all(isinstance(buf, str) for buf in bufs) + isinstance(op, MLIROperator) and all(isinstance(buf, str) for buf in bufs) for op, *bufs in runlist ) self.runlist = runlist diff --git a/operators/common/test_utils.py b/operators/common/test_utils.py index 066a7981..cb33afa2 100644 --- a/operators/common/test_utils.py +++ b/operators/common/test_utils.py @@ -6,7 +6,7 @@ from ml_dtypes import bfloat16 from .utils import torch_to_numpy import logging -from .base import SingleMLIRSourceOperator, AIEBuffer +from .base import MLIROperator, CompositeOperator, AIEBuffer def nearly_equal( @@ -85,8 +85,8 @@ def run_test( ) logger = logging.getLogger(__name__) - if not isinstance(operator, SingleMLIRSourceOperator): - raise ValueError("run_test only supports SingleMLIRSourceOperator") + if not isinstance(operator, (MLIROperator, CompositeOperator)): + raise ValueError("run_test only supports MLIROperator or CompositeOperator") operator.compile() op_func = operator.get_callable() @@ -148,12 +148,10 @@ def run_test( else: print(f"Warning: Output buffer {buf_name} not found in operator arguments") - # Intermediate buffers are not supported in this generic run_test for SingleMLIRSourceOperator + # Intermediate buffers are not supported in this generic run_test # unless we expose them somehow. For now, ignore or warn. if intermediate_buffers: - print( - "Warning: intermediate_buffers verification is not supported for SingleMLIRSourceOperator in run_test" - ) + print("Warning: intermediate_buffers verification is not supported in run_test") # Calculate bandwidth bandwidth_gbps = total_bytes / (latency_us * 1e-6) / 1e9 diff --git a/operators/dequant/op.py b/operators/dequant/op.py index 94e39fb7..47ef7dc8 100644 --- a/operators/dequant/op.py +++ b/operators/dequant/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIEDequant(SingleMLIRSourceOperator): +class AIEDequant(MLIROperator): def __init__( self, @@ -44,7 +44,7 @@ def __init__( assert self.size % total_cores == 0, "Size must be divisible by total cores" assert total_cores <= 16, "Total cores (columns * channels) must be <= 16" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"dequant_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/elementwise_add/op.py b/operators/elementwise_add/op.py index 095d3d03..26886794 100644 --- a/operators/elementwise_add/op.py +++ b/operators/elementwise_add/op.py @@ -8,7 +8,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, @@ -19,7 +19,7 @@ ) -class AIEElementwiseAdd(SingleMLIRSourceOperator): +class AIEElementwiseAdd(MLIROperator): """AIE-accelerated element-wise addition""" def __init__( @@ -39,7 +39,7 @@ def __init__( # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels total_shimdma_channels = self.num_aie_columns * 2 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"add_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" diff --git a/operators/elementwise_mul/op.py b/operators/elementwise_mul/op.py index c36dc4c0..7e107968 100644 --- a/operators/elementwise_mul/op.py +++ b/operators/elementwise_mul/op.py @@ -4,7 +4,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -12,7 +12,7 @@ ) -class AIEElementwiseMul(SingleMLIRSourceOperator): +class AIEElementwiseMul(MLIROperator): """AIE-accelerated element-wise multiplication""" def __init__( @@ -32,7 +32,7 @@ def __init__( # Maximum safe configuration: 8 columns × 2 channels = 16 ShimDMA channels total_shimdma_channels = self.num_aie_columns * 2 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"mul_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" diff --git a/operators/gelu/op.py b/operators/gelu/op.py index c59e141e..c06b9fc9 100644 --- a/operators/gelu/op.py +++ b/operators/gelu/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIEGELU(SingleMLIRSourceOperator): +class AIEGELU(MLIROperator): """AIE-accelerated GELU activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): @@ -33,7 +33,7 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"gelu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/gemm/op.py b/operators/gemm/op.py index 0851d266..c3d100c4 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -8,7 +8,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -18,7 +18,7 @@ from operators.common.utils import torch_to_numpy, numpy_to_torch -class AIEGEMM(SingleMLIRSourceOperator): +class AIEGEMM(MLIROperator): """AIE-accelerated General Matrix Multiplication (GEMM) layer""" def __init__( @@ -67,7 +67,7 @@ def __init__( assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}" assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"gemm_{self.M}x{self.K}x{self.N}_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}" diff --git a/operators/gemv/op.py b/operators/gemv/op.py index 268f7666..1a3788bd 100644 --- a/operators/gemv/op.py +++ b/operators/gemv/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, @@ -19,7 +19,7 @@ from operators.common.utils import torch_to_numpy -class AIEGEMV(SingleMLIRSourceOperator): +class AIEGEMV(MLIROperator): """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" def __init__( @@ -55,7 +55,7 @@ def __init__( self.xclbin_artifact = None self.insts_artifact = None - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"gemv_{self.M}x{self.K}_{self.tile_size_input}tsi_{self.tile_size_output}tso_{self.num_batches}batch_{self.num_aie_columns}col" diff --git a/operators/layer_norm/op.py b/operators/layer_norm/op.py index 75771799..383cb5cd 100644 --- a/operators/layer_norm/op.py +++ b/operators/layer_norm/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIELayerNorm(SingleMLIRSourceOperator): +class AIELayerNorm(MLIROperator): """AIE-accelerated LAYER NORM operator""" def __init__( @@ -36,7 +36,7 @@ def __init__( total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"layer_norm_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/leaky_relu/op.py b/operators/leaky_relu/op.py index 461de61b..1500479a 100644 --- a/operators/leaky_relu/op.py +++ b/operators/leaky_relu/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIELeakyReLU(SingleMLIRSourceOperator): +class AIELeakyReLU(MLIROperator): """AIE-accelerated LEAKY RELU operator""" def __init__( @@ -37,7 +37,7 @@ def __init__( total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"leaky_relu_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/mem_copy/op.py b/operators/mem_copy/op.py index 0481f2da..0234ae0e 100644 --- a/operators/mem_copy/op.py +++ b/operators/mem_copy/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, @@ -18,7 +18,7 @@ ) -class AIEMemCopy(SingleMLIRSourceOperator): +class AIEMemCopy(MLIROperator): def __init__(self, size, num_cores, num_channels, bypass, tile_size, context=None): self.size = size @@ -30,7 +30,7 @@ def __init__(self, size, num_cores, num_channels, bypass, tile_size, context=Non # For naming consistency with other operators self.bypass_str = "bypass" if bypass else "no_bypass" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"mem_copy_{self.num_cores}_cores_{self.num_channels}_chans_tile_{self.tile_size}_{self.bypass_str}" diff --git a/operators/mha/op.py b/operators/mha/op.py index 1040ccff..9f38e42e 100644 --- a/operators/mha/op.py +++ b/operators/mha/op.py @@ -8,7 +8,7 @@ from typing import Dict, List from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, @@ -20,7 +20,7 @@ from operators.common.utils import torch_to_numpy, numpy_to_torch -class AIEMHA(SingleMLIRSourceOperator): +class AIEMHA(MLIROperator): def __init__( self, @@ -40,7 +40,7 @@ def __init__( self.num_of_pipelines = num_of_pipelines assert d == 64, "Only d=64 is supported in this version" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): kv_heads = self.num_KV_heads if self.num_KV_heads > 0 else self.num_heads diff --git a/operators/relu/op.py b/operators/relu/op.py index 457a923b..a6a6c115 100644 --- a/operators/relu/op.py +++ b/operators/relu/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIEReLU(SingleMLIRSourceOperator): +class AIEReLU(MLIROperator): """AIE-accelerated ReLU activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): @@ -33,7 +33,7 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_aie_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"relu_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/repeat/op.py b/operators/repeat/op.py index 66e30b20..a0317b85 100644 --- a/operators/repeat/op.py +++ b/operators/repeat/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIERepeat(SingleMLIRSourceOperator): +class AIERepeat(MLIROperator): """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" def __init__( @@ -32,7 +32,7 @@ def __init__( self.repeat = repeat self.transfer_size = transfer_size self.dtype = dtype - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): name = f"repeat_{self.rows}x{self.cols}_by_{self.repeat}" diff --git a/operators/rms_norm/op.py b/operators/rms_norm/op.py index a512869c..1fb48070 100644 --- a/operators/rms_norm/op.py +++ b/operators/rms_norm/op.py @@ -8,7 +8,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, XclbinArtifact, InstsBinArtifact, @@ -20,7 +20,7 @@ from operators.common.utils import torch_to_numpy -class AIERMSNorm(SingleMLIRSourceOperator): +class AIERMSNorm(MLIROperator): """AIE-accelerated RMS Normalization layer""" def __init__( @@ -52,7 +52,7 @@ def __init__( total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"weighted_rms_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/rope/op.py b/operators/rope/op.py index a3535f75..49488d1e 100644 --- a/operators/rope/op.py +++ b/operators/rope/op.py @@ -4,7 +4,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -12,7 +12,7 @@ ) -class AIERope(SingleMLIRSourceOperator): +class AIERope(MLIROperator): def __init__( self, @@ -44,7 +44,7 @@ def __init__( self.method_type = method_type assert method_type in {0, 1} - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"rope_{self.num_aie_columns}col_{self.rows}rows_{self.cols}cols_{self.angle_rows}arows_{self.method_type}m" diff --git a/operators/sigmoid/op.py b/operators/sigmoid/op.py index 6f1c8456..4a56c6ff 100644 --- a/operators/sigmoid/op.py +++ b/operators/sigmoid/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIESigmoid(SingleMLIRSourceOperator): +class AIESigmoid(MLIROperator): """AIE-accelerated Sigmoid activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): @@ -34,7 +34,7 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"sigmoid_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/silu/op.py b/operators/silu/op.py index 942ebe12..84d37005 100644 --- a/operators/silu/op.py +++ b/operators/silu/op.py @@ -4,7 +4,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -12,7 +12,7 @@ ) -class AIESiLU(SingleMLIRSourceOperator): +class AIESiLU(MLIROperator): """AIE-accelerated SiLU activation function""" def __init__(self, size, tile_size, num_aie_columns=8, context=None): @@ -26,7 +26,7 @@ def __init__(self, size, tile_size, num_aie_columns=8, context=None): # Maximum safe configuration: 8 columns × 1 channel = 8 ShimDMA channels total_shimdma_channels = self.num_aie_columns * 1 assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"silu_{self.num_aie_columns}col_{self.size}_{self.tile_size}t" diff --git a/operators/softmax/op.py b/operators/softmax/op.py index 93d0c299..202d0db6 100644 --- a/operators/softmax/op.py +++ b/operators/softmax/op.py @@ -5,7 +5,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -13,7 +13,7 @@ ) -class AIESoftmax(SingleMLIRSourceOperator): +class AIESoftmax(MLIROperator): """AIE-accelerated Softmax operation""" def __init__( @@ -40,7 +40,7 @@ def __init__( self.rtp_vector_size = rtp_vector_size self.mask_patch_value = mask_patch_value - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): name = f"softmax_{self.num_aie_columns}col_{self.num_channels}ch_{self.size}_{self.cols}t" diff --git a/operators/strided_copy/op.py b/operators/strided_copy/op.py index c8d04bbf..feda0e7b 100644 --- a/operators/strided_copy/op.py +++ b/operators/strided_copy/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIEStridedCopy(SingleMLIRSourceOperator): +class AIEStridedCopy(MLIROperator): """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" def __init__( @@ -48,7 +48,7 @@ def __init__( self.transfer_size = transfer_size self.num_aie_channels = num_aie_channels self.kwargs = kwargs - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"strided_copy_{'x'.join(map(str, self.input_sizes))}sz_{'x'.join(map(str, self.input_strides))}st_{self.input_offset}off_to_{'x'.join(map(str, self.output_sizes))}sz_{'x'.join(map(str, self.output_strides))}st_{self.output_offset}off_{self.transfer_size if self.transfer_size is not None else 'auto'}tr_{self.num_aie_channels}ch" diff --git a/operators/swiglu_decode/op.py b/operators/swiglu_decode/op.py index d035992a..c2694728 100644 --- a/operators/swiglu_decode/op.py +++ b/operators/swiglu_decode/op.py @@ -7,13 +7,10 @@ from ml_dtypes import bfloat16 from operators.common import ( - AIEOperatorBase, - XclbinArtifact, - InstsBinArtifact, - KernelObjectArtifact, - KernelArchiveArtifact, - SourceArtifact, - PythonGeneratedMLIRArtifact, + CompositeOperator, + AIERuntimeArgSpec, + AIEBuffer, + SingleXclbinCallable, ) from operators.gemv.op import AIEGEMV from operators.silu.op import AIESiLU @@ -21,7 +18,77 @@ from operators.common.utils import torch_to_numpy -class AIESwiGLUDecode(AIEOperatorBase): +class SwiGLUDecodeCallable: + def __init__(self, op): + self.op = op + # Create callables for sub-operators + # We need to manually construct SingleXclbinCallable because sub-operators weren't "compiled" in the standard way + + # Helper to create callable from operator and artifacts + def create_callable(sub_op, xclbin_artifact, insts_artifact): + return SingleXclbinCallable( + xclbin_path=xclbin_artifact.filename, + kernel_name=xclbin_artifact.kernel_name, + insts_bin_path=insts_artifact.filename, + args_spec=sub_op.get_arg_spec(), + ) + + self.gemv_1_callable = create_callable( + op.gemv_1, op.combined_xclbin, op.gemv_1_insts + ) + self.silu_callable = create_callable(op.silu, op.combined_xclbin, op.silu_insts) + self.eltwise_mul_callable = create_callable( + op.eltwise_mul, op.combined_xclbin, op.eltwise_mul_insts + ) + self.gemv_2_callable = create_callable( + op.gemv_2, op.combined_xclbin, op.gemv_2_insts + ) + + # Allocate and upload weights + self.weights_1 = AIEBuffer.from_np(torch_to_numpy(op.weights_1)) + self.weights_2 = AIEBuffer.from_np(torch_to_numpy(op.weights_2)) + self.weights_3 = AIEBuffer.from_np(torch_to_numpy(op.weights_3)) + + # Allocate intermediate buffers + # left: output of gemv_1 (hidden_dim_padded) + self.left = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + # right: output of gemv_1 (hidden_dim_padded) + self.right = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + # left_swished: output of silu (hidden_dim_padded) + self.left_swished = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + # intermediate: output of eltwise_mul (hidden_dim_padded) + self.intermediate = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + + def __call__(self, input_buf, output_buf): + # Ensure inputs are on device + input_buf.to("npu") + output_buf.to("npu") + self.weights_1.to("npu") + self.weights_2.to("npu") + self.weights_3.to("npu") + self.left.to("npu") + self.right.to("npu") + self.left_swished.to("npu") + self.intermediate.to("npu") + + # Sequence: + # 1. GEMV(weights_1, input, left) + self.gemv_1_callable(self.weights_1, input_buf, self.left) + + # 2. GEMV(weights_2, input, right) + self.gemv_1_callable(self.weights_2, input_buf, self.right) + + # 3. SiLU(left, left_swished) + self.silu_callable(self.left, self.left_swished) + + # 4. EltwiseMul(left_swished, right, intermediate) + self.eltwise_mul_callable(self.left_swished, self.right, self.intermediate) + + # 5. GEMV(weights_3, intermediate, output) + self.gemv_2_callable(self.weights_3, self.intermediate, output_buf) + + +class AIESwiGLUDecode(CompositeOperator): def __init__(self, embedding_dim, hidden_dim, prio_accuracy=False, context=None): self.hidden_dim = hidden_dim @@ -140,69 +207,11 @@ def set_up_artifacts(self): self.add_artifacts(artifacts) - def set_up_runtime(self): - self.add_buffer("input", self.embedding_dim) - self.add_buffer( - "weights_1", - self.embedding_dim * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_1), - ) - self.add_buffer( - "weights_2", - self.embedding_dim * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_2), - ) - self.add_buffer( - "weights_3", - self.hidden_dim_padded * self.embedding_dim, - static_data=torch_to_numpy(self.weights_3), - ) - self.add_buffer("left", self.hidden_dim_padded) - self.add_buffer("left_swished", self.hidden_dim_padded) - self.add_buffer("right", self.hidden_dim_padded) - self.add_buffer("intermediate", self.hidden_dim_padded) - self.add_buffer("output", self.embedding_dim) - self.add_kernel( - "swiglu_gemv_1", - self.combined_xclbin, - self.gemv_1_xclbin.kernel_name, - self.gemv_1_insts, - ) - self.add_kernel( - "swiglu_silu", - self.combined_xclbin, - self.silu_xclbin.kernel_name, - self.silu_insts, - ) - self.add_kernel( - "swiglu_eltwise_mul", - self.combined_xclbin, - self.eltwise_mul_xclbin.kernel_name, - self.eltwise_mul_insts, - ) - self.add_kernel( - "swiglu_gemv_2", - self.combined_xclbin, - self.gemv_2_xclbin.kernel_name, - self.gemv_2_insts, - ) - self.add_to_runlist("swiglu_gemv_1", "weights_1", "input", "left") - self.add_to_runlist("swiglu_gemv_1", "weights_2", "input", "right") - self.add_to_runlist("swiglu_silu", "left", "left_swished") - self.add_to_runlist( - "swiglu_eltwise_mul", "left_swished", "right", "intermediate" - ) - self.add_to_runlist("swiglu_gemv_2", "weights_3", "intermediate", "output") - - def forward(self, x): - x_flat = x.reshape(x.shape[-1]) - assert x_flat.shape[0] == self.embedding_dim - - self.write_buffer("input", x_flat) - self.run_runlist() - result = self.read_buffer_as_torch( - "output", - (self.embedding_dim,), - ).view_as(x) + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.embedding_dim,)), + AIERuntimeArgSpec("out", (self.embedding_dim,)), + ] - return result + def get_callable(self): + return SwiGLUDecodeCallable(self) diff --git a/operators/swiglu_prefill/op.py b/operators/swiglu_prefill/op.py index cf6c0704..b6231b2f 100644 --- a/operators/swiglu_prefill/op.py +++ b/operators/swiglu_prefill/op.py @@ -7,13 +7,10 @@ from ml_dtypes import bfloat16 from operators.common import ( - AIEOperatorBase, - XclbinArtifact, - InstsBinArtifact, - KernelObjectArtifact, - KernelArchiveArtifact, - SourceArtifact, - PythonGeneratedMLIRArtifact, + CompositeOperator, + AIERuntimeArgSpec, + AIEBuffer, + SingleXclbinCallable, ) from operators.gemm.op import AIEGEMM from operators.silu.op import AIESiLU @@ -21,7 +18,71 @@ from operators.common.utils import torch_to_numpy -class AIESwiGLUPrefill(AIEOperatorBase): +class SwiGLUPrefillCallable: + def __init__(self, op): + self.op = op + + def create_callable(sub_op, xclbin_artifact, insts_artifact): + return SingleXclbinCallable( + xclbin_path=xclbin_artifact.filename, + kernel_name=xclbin_artifact.kernel_name, + insts_bin_path=insts_artifact.filename, + args_spec=sub_op.get_arg_spec(), + ) + + self.gemm_1_callable = create_callable( + op.gemm_1, op.combined_xclbin, op.gemm_1_insts + ) + self.silu_callable = create_callable(op.silu, op.combined_xclbin, op.silu_insts) + self.eltwise_mul_callable = create_callable( + op.eltwise_mul, op.combined_xclbin, op.eltwise_mul_insts + ) + self.gemm_2_callable = create_callable( + op.gemm_2, op.combined_xclbin, op.gemm_2_insts + ) + + # Allocate and upload weights + self.weights_1 = AIEBuffer.from_np(torch_to_numpy(op.weights_1.T)) + self.weights_2 = AIEBuffer.from_np(torch_to_numpy(op.weights_2.T)) + self.weights_3 = AIEBuffer.from_np(torch_to_numpy(op.weights_3.T)) + + # Allocate intermediate buffers + # Sizes are padded + size_hidden = op.seq_len_padded * op.hidden_dim_padded + self.left = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.right = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.left_swished = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.intermediate = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + + def __call__(self, input_buf, output_buf): + input_buf.to("npu") + output_buf.to("npu") + self.weights_1.to("npu") + self.weights_2.to("npu") + self.weights_3.to("npu") + self.left.to("npu") + self.right.to("npu") + self.left_swished.to("npu") + self.intermediate.to("npu") + + # Sequence: + # 1. GEMM(input, weights_1, left) + self.gemm_1_callable(input_buf, self.weights_1, self.left) + + # 2. GEMM(input, weights_2, right) + self.gemm_1_callable(input_buf, self.weights_2, self.right) + + # 3. SiLU(left, left_swished) + self.silu_callable(self.left, self.left_swished) + + # 4. EltwiseMul(left_swished, right, intermediate) + self.eltwise_mul_callable(self.left_swished, self.right, self.intermediate) + + # 5. GEMM(intermediate, weights_3, output) + self.gemm_2_callable(self.intermediate, self.weights_3, output_buf) + + +class AIESwiGLUPrefill(CompositeOperator): def __init__( self, seq_len, embedding_dim, hidden_dim, prio_accuracy=False, context=None @@ -153,109 +214,13 @@ def set_up_artifacts(self): self.add_artifacts(artifacts) - def set_up_runtime(self): - # Runtime setup - # --- - self.add_buffer("input", self.seq_len_padded * self.embedding_dim_padded) - self.add_buffer( - "weights_1", - self.embedding_dim_padded * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_1.T), - ) - self.add_buffer( - "weights_2", - self.embedding_dim_padded * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_2.T), - ) - self.add_buffer( - "weights_3", - self.hidden_dim_padded * self.embedding_dim_padded, - static_data=torch_to_numpy(self.weights_3.T), - ) - self.add_buffer("left", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("left_swished", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("right", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("intermediate", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("output", self.seq_len_padded * self.embedding_dim_padded) - self.add_kernel( - "swiglu_gemm_1", - self.combined_xclbin, - self.gemm_1_xclbin.kernel_name, - self.gemm_1_insts, - ) - self.add_kernel( - "swiglu_silu", - self.combined_xclbin, - self.silu_xclbin.kernel_name, - self.silu_insts, - ) - self.add_kernel( - "swiglu_eltwise_mul", - self.combined_xclbin, - self.eltwise_mul_xclbin.kernel_name, - self.eltwise_mul_insts, - ) - self.add_kernel( - "swiglu_gemm_2", - self.combined_xclbin, - self.gemm_2_xclbin.kernel_name, - self.gemm_2_insts, - ) - self.add_to_runlist("swiglu_gemm_1", "input", "weights_1", "left") - self.add_to_runlist("swiglu_gemm_1", "input", "weights_2", "right") - self.add_to_runlist("swiglu_silu", "left", "left_swished") - self.add_to_runlist( - "swiglu_eltwise_mul", "left_swished", "right", "intermediate" - ) - self.add_to_runlist("swiglu_gemm_2", "intermediate", "weights_3", "output") - - def forward(self, x): - """Forward pass for SwiGLU operation""" - - # Always flatten to [batch, orig_size] - original_shape = x.shape - batch = x.shape[0] if x.dim() > 1 else 1 - x_flat = x.reshape(batch, -1) - - out = self._execute_aie_operation(x_flat) - - # Restore original shape - out = out.reshape(*original_shape) - - return out - - def _execute_aie_operation(self, x): - # x is [batch, size] - batch = x.shape[0] if x.dim() > 1 else 1 - - # Flatten inputs for AIE processing - x_flat = x.view(-1) - - # Verify input size matches expected dimensions - expected_size = batch * self.seq_len * self.embedding_dim - assert x_flat.shape[0] == expected_size - - # Pad input if necessary to match GEMM requirements - if self.seq_len_padded * self.embedding_dim_padded > x_flat.shape[0]: - x_padded = torch.zeros( - self.seq_len_padded * self.embedding_dim_padded, - dtype=x_flat.dtype, - device=x_flat.device, - ) - x_padded[: x_flat.shape[0]] = x_flat - x_flat = x_padded - - self.write_buffer("input", x_flat) - self.run_runlist() - - # Read padded output buffer - result_padded = self.read_buffer_as_torch( - "output", - shape=(self.seq_len_padded * self.embedding_dim_padded,), - dtype=bfloat16, - ) - - # Extract only the unpadded portion - result = result_padded[:expected_size].view(batch, -1) + def get_arg_spec(self): + return [ + AIERuntimeArgSpec("in", (self.seq_len_padded * self.embedding_dim_padded,)), + AIERuntimeArgSpec( + "out", (self.seq_len_padded * self.embedding_dim_padded,) + ), + ] - return result + def get_callable(self): + return SwiGLUPrefillCallable(self) diff --git a/operators/tanh/op.py b/operators/tanh/op.py index 6a71f559..d1c42820 100644 --- a/operators/tanh/op.py +++ b/operators/tanh/op.py @@ -7,7 +7,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -15,7 +15,7 @@ ) -class AIETanh(SingleMLIRSourceOperator): +class AIETanh(MLIROperator): """AIE-accelerated Tanh activation function""" def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None): @@ -34,7 +34,7 @@ def __init__(self, size, num_aie_columns, num_channels, tile_size, context=None) total_shimdma_channels = self.num_columns * self.num_channels assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"tanh_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" diff --git a/operators/transpose/op.py b/operators/transpose/op.py index 8a2e3adc..7a8a8c73 100644 --- a/operators/transpose/op.py +++ b/operators/transpose/op.py @@ -4,7 +4,7 @@ from pathlib import Path from operators.common import ( - SingleMLIRSourceOperator, + MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, @@ -12,7 +12,7 @@ ) -class AIETranspose(SingleMLIRSourceOperator): +class AIETranspose(MLIROperator): """AIE-accelerated transpose operator""" def __init__(self, M, N, num_aie_columns, num_channels, m, n, s, context=None): @@ -32,7 +32,7 @@ def __init__(self, M, N, num_aie_columns, num_channels, m, n, s, context=None): self.num_columns = num_aie_columns self.num_channels = num_channels - SingleMLIRSourceOperator.__init__(self, context=context) + MLIROperator.__init__(self, context=context) def get_operator_name(self): return f"transpose_{self.num_columns}c_{self.num_channels}ch_{self.M}x{self.N}_{self.m}x{self.n}_{self.s}s" From 12b1dc63be2b9ed1d023b51a6144de9f0c97b346 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Thu, 5 Feb 2026 17:42:53 -0700 Subject: [PATCH 73/99] some work with swiglu --- operators/common/base.py | 4 ++-- operators/common/test_utils.py | 2 ++ operators/gemm/op.py | 4 ++-- operators/swiglu_decode/op.py | 22 +++++++--------------- operators/swiglu_decode/test.py | 2 +- operators/swiglu_prefill/op.py | 8 +++----- operators/swiglu_prefill/test.py | 6 +++--- 7 files changed, 20 insertions(+), 28 deletions(-) diff --git a/operators/common/base.py b/operators/common/base.py index 66c42c02..6f9eb821 100644 --- a/operators/common/base.py +++ b/operators/common/base.py @@ -114,8 +114,8 @@ def get_mlir_artifact(self): def get_kernel_artifacts(self): pass - def get_artifacts(self): - operator_name = self.get_operator_name() + def get_artifacts(self, prefix=""): + operator_name = prefix + self.get_operator_name() mlir_artifact = self.get_mlir_artifact() kernel_deps_inputs = self.get_kernel_artifacts() if len(kernel_deps_inputs) > 0: diff --git a/operators/common/test_utils.py b/operators/common/test_utils.py index cb33afa2..53f40e9f 100644 --- a/operators/common/test_utils.py +++ b/operators/common/test_utils.py @@ -139,6 +139,8 @@ def run_test( # Verify outputs errors = {} for buf_name, expected in output_buffers.items(): + if expected is None: + continue if buf_name in output_map: buf = output_map[buf_name] output_np = buf.view_as_np() diff --git a/operators/gemm/op.py b/operators/gemm/op.py index c3d100c4..f49a3ed7 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -82,7 +82,7 @@ def get_mlir_artifact(self): emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( "emulate_bf16_mmul_with_bfp16", True ) - prio_accuracy = self.gemm_args.get("prio_accuracy", False) + prio_accuracy = False # Force False for debugging use_scalar = self.gemm_args.get("use_scalar", False) round_conv_even = self.gemm_args.get("round_conv_even", True) separate_c_tiles = self.gemm_args.get("separate_c_tiles", False) @@ -118,7 +118,7 @@ def get_kernel_artifacts(self): emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( "emulate_bf16_mmul_with_bfp16", True ) - prio_accuracy = self.gemm_args.get("prio_accuracy", False) + prio_accuracy = False # Force False for debugging round_conv_even = self.gemm_args.get("round_conv_even", True) kernel_flags = [ f"-DDIM_M={self.tile_m}", diff --git a/operators/swiglu_decode/op.py b/operators/swiglu_decode/op.py index c2694728..a16ac76a 100644 --- a/operators/swiglu_decode/op.py +++ b/operators/swiglu_decode/op.py @@ -124,9 +124,7 @@ def set_up_artifacts(self): tile_size_output=self.hidden_dim // 8, ) self.gemv_1 = gemv_1 - gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts( - prefix="swiglu_decode_gemv_1_" - ) + gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts() gemv_1_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemv_1", "--xclbin-kernel-id=0x901", @@ -139,39 +137,35 @@ def set_up_artifacts(self): silu = AIESiLU( size=self.hidden_dim, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim // 16, ) self.silu = silu self.hidden_dim_padded = silu.size - silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_decode_silu_") + silu_xclbin, silu_insts = silu.get_artifacts() silu_xclbin.xclbin_input = gemv_1_xclbin silu_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_silu", "--xclbin-kernel-id=0x902", ] silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.dependencies += [gemv_1_xclbin] + silu_xclbin.dependencies.add(gemv_1_xclbin) artifacts.append(silu_insts) eltwise_mul = AIEElementwiseMul( size=self.hidden_dim, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim // 8, ) self.eltwise_mul = eltwise_mul assert self.hidden_dim <= eltwise_mul.size <= self.hidden_dim_padded - eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts( - prefix="swiglu_decode_eltwise_mul_" - ) + eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts() eltwise_mul_xclbin.xclbin_input = silu_xclbin eltwise_mul_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_eltwise_mul", "--xclbin-kernel-id=0x903", ] eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.dependencies += [silu_xclbin] + eltwise_mul_xclbin.dependencies.add(silu_xclbin) artifacts.append(eltwise_mul_insts) gemv_2 = AIEGEMV( @@ -182,16 +176,14 @@ def set_up_artifacts(self): tile_size_output=self.embedding_dim // 8, ) self.gemv_2 = gemv_2 - gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts( - prefix="swiglu_decode_gemv_2_" - ) + gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts() gemv_2_xclbin.xclbin_input = eltwise_mul_xclbin gemv_2_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemv_2", "--xclbin-kernel-id=0x904", ] gemv_2_xclbin.kernel_name = "swiglu_gemv_2" - gemv_2_xclbin.dependencies += [eltwise_mul_xclbin] + gemv_2_xclbin.dependencies.add(eltwise_mul_xclbin) artifacts.append(gemv_2_xclbin) artifacts.append(gemv_2_insts) diff --git a/operators/swiglu_decode/test.py b/operators/swiglu_decode/test.py index 9cb372c1..1970f6f5 100755 --- a/operators/swiglu_decode/test.py +++ b/operators/swiglu_decode/test.py @@ -54,7 +54,7 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): # Note that the previous intermediate result _is_ still verified up to the given tolerance. input_buffers = {"input": golden_ref["input"]} - output_buffers = {} + output_buffers = {"output": None} intermediate_buffers = { "left": golden_ref["left"], "left_swished": golden_ref["left_swished"], diff --git a/operators/swiglu_prefill/op.py b/operators/swiglu_prefill/op.py index b6231b2f..04e67d2d 100644 --- a/operators/swiglu_prefill/op.py +++ b/operators/swiglu_prefill/op.py @@ -146,7 +146,6 @@ def set_up_artifacts(self): silu = AIESiLU( size=self.seq_len_padded * self.hidden_dim_padded, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim_padded // 8, ) self.silu = silu @@ -159,13 +158,12 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x902", ] silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.dependencies += [gemm_1_xclbin] + silu_xclbin.dependencies.add(gemm_1_xclbin) artifacts.append(silu_insts) eltwise_mul = AIEElementwiseMul( size=self.seq_len_padded * self.hidden_dim_padded, num_aie_columns=8, - num_channels=2, tile_size=self.hidden_dim_padded // 8, ) self.eltwise_mul = eltwise_mul @@ -180,7 +178,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x903", ] eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.dependencies += [silu_xclbin] + eltwise_mul_xclbin.dependencies.add(silu_xclbin) artifacts.append(eltwise_mul_insts) gemm_2 = AIEGEMM( @@ -198,7 +196,7 @@ def set_up_artifacts(self): "--xclbin-kernel-id=0x904", ] gemm_2_xclbin.kernel_name = "swiglu_gemm_2" - gemm_2_xclbin.dependencies += [eltwise_mul_xclbin] + gemm_2_xclbin.dependencies.add(eltwise_mul_xclbin) artifacts.append(gemm_2_xclbin) artifacts.append(gemm_2_insts) diff --git a/operators/swiglu_prefill/test.py b/operators/swiglu_prefill/test.py index 872c9442..90c0fb38 100755 --- a/operators/swiglu_prefill/test.py +++ b/operators/swiglu_prefill/test.py @@ -15,8 +15,8 @@ def generate_test_params(extensive=False): # This operation is currently untested except for the integrated llama application tests. - params = [] - names = [] + params = [(256, 2048, 2048, False)] + names = [f"swiglu_prefill_256x{emb}x{hid}" for _, emb, hid, _ in params] return params, names @@ -54,7 +54,7 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c input_buffers = {"input": golden_ref["input"]} # output_buffers = {'output': golden_ref['output']} - output_buffers = {} + output_buffers = {"output": None} intermediate_buffers = { "left": golden_ref["left"], "left_swished": golden_ref["left_swished"], From 1aa050e5bce4c827ee5a13b298ceb4bf90cffc97 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Thu, 5 Feb 2026 17:47:18 -0700 Subject: [PATCH 74/99] add licenses --- applications/llama_3.2_1b/llama_cpu.py | 3 +++ applications/llama_3.2_1b/llama_inference_harness.py | 4 ++++ applications/llama_3.2_1b/llama_npu.py | 3 +++ operators/common/compilation/__init__.py | 3 +++ operators/common/fusion.py | 3 +++ operators/softmax/test2.py | 3 +++ operators/strided_copy/test2.py | 3 +++ 7 files changed, 22 insertions(+) diff --git a/applications/llama_3.2_1b/llama_cpu.py b/applications/llama_3.2_1b/llama_cpu.py index 7d5c8ed3..f4db4257 100755 --- a/applications/llama_3.2_1b/llama_cpu.py +++ b/applications/llama_3.2_1b/llama_cpu.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + import torch import math import llama_inference_harness as harness diff --git a/applications/llama_3.2_1b/llama_inference_harness.py b/applications/llama_3.2_1b/llama_inference_harness.py index 34b5cf5f..97cc5034 100644 --- a/applications/llama_3.2_1b/llama_inference_harness.py +++ b/applications/llama_3.2_1b/llama_inference_harness.py @@ -1,4 +1,8 @@ #!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + """ Inference harness -- all the necessary code _other_ than the actual model (forward pass). Exposes a 'harness' function that can be called with a 'forward_pass' function that implements the model. diff --git a/applications/llama_3.2_1b/llama_npu.py b/applications/llama_3.2_1b/llama_npu.py index 214ad45e..565d94f9 100755 --- a/applications/llama_3.2_1b/llama_npu.py +++ b/applications/llama_3.2_1b/llama_npu.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + # Next steps for decode performance: # [ ] All decode operators operate on 2048-padded buffers; instead, should bin into shorter sequence lengths and call smaller operators # [ ] Opportunity to fuse data layout transformations (e.g., transpose ops) onto end of other operations (e.g., transpose after RoPE) diff --git a/operators/common/compilation/__init__.py b/operators/common/compilation/__init__.py index 9e924f5a..405df6b0 100644 --- a/operators/common/compilation/__init__.py +++ b/operators/common/compilation/__init__.py @@ -1,2 +1,5 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + from .base import * from .fusion import * diff --git a/operators/common/fusion.py b/operators/common/fusion.py index 748b55e9..132a4b3c 100644 --- a/operators/common/fusion.py +++ b/operators/common/fusion.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import ml_dtypes import pyxrt diff --git a/operators/softmax/test2.py b/operators/softmax/test2.py index 864413f8..cbdfb93b 100755 --- a/operators/softmax/test2.py +++ b/operators/softmax/test2.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + import sys from pathlib import Path import time diff --git a/operators/strided_copy/test2.py b/operators/strided_copy/test2.py index 46627022..994d8ab3 100755 --- a/operators/strided_copy/test2.py +++ b/operators/strided_copy/test2.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + import sys from pathlib import Path import time From a8e07fcf2cb7cab7ef93c0578e6a67013cf6f892 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 06:38:26 -0700 Subject: [PATCH 75/99] Move remaining operators to iron/operators/ --- {operators => iron/operators}/repeat/design.py | 0 {operators => iron/operators}/repeat/op.py | 4 +++- {operators => iron/operators}/strided_copy/design.py | 0 {operators => iron/operators}/strided_copy/op.py | 4 +++- {operators => iron/operators}/strided_copy/test2.py | 0 5 files changed, 6 insertions(+), 2 deletions(-) rename {operators => iron/operators}/repeat/design.py (100%) rename {operators => iron/operators}/repeat/op.py (96%) rename {operators => iron/operators}/strided_copy/design.py (100%) rename {operators => iron/operators}/strided_copy/op.py (97%) rename {operators => iron/operators}/strided_copy/test2.py (100%) diff --git a/operators/repeat/design.py b/iron/operators/repeat/design.py similarity index 100% rename from operators/repeat/design.py rename to iron/operators/repeat/design.py diff --git a/operators/repeat/op.py b/iron/operators/repeat/op.py similarity index 96% rename from operators/repeat/op.py rename to iron/operators/repeat/op.py index a0317b85..b056f591 100644 --- a/operators/repeat/op.py +++ b/iron/operators/repeat/op.py @@ -6,12 +6,14 @@ from ml_dtypes import bfloat16 from pathlib import Path -from operators.common import ( +from iron.common import ( MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, + XclbinArtifact, + InstsBinArtifact, ) diff --git a/operators/strided_copy/design.py b/iron/operators/strided_copy/design.py similarity index 100% rename from operators/strided_copy/design.py rename to iron/operators/strided_copy/design.py diff --git a/operators/strided_copy/op.py b/iron/operators/strided_copy/op.py similarity index 97% rename from operators/strided_copy/op.py rename to iron/operators/strided_copy/op.py index feda0e7b..5996a90d 100644 --- a/operators/strided_copy/op.py +++ b/iron/operators/strided_copy/op.py @@ -6,12 +6,14 @@ from ml_dtypes import bfloat16 from pathlib import Path -from operators.common import ( +from iron.common import ( MLIROperator, AIERuntimeArgSpec, KernelObjectArtifact, SourceArtifact, PythonGeneratedMLIRArtifact, + XclbinArtifact, + InstsBinArtifact, ) diff --git a/operators/strided_copy/test2.py b/iron/operators/strided_copy/test2.py similarity index 100% rename from operators/strided_copy/test2.py rename to iron/operators/strided_copy/test2.py From 68a2e45cdf6a1196945ca62cee67211bd28e6061 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 09:21:01 -0700 Subject: [PATCH 76/99] try to minimize changes that are not central to the refactor --- .../llama_3.2_1b/analyze_profile.py | 358 +++++ .../llama_3.2_1b/configs/llama32_1b.json | 39 + .../configs/llama32_1b.json.license | 7 + iron/applications/llama_3.2_1b/inference.py | 445 ++++++ iron/applications/llama_3.2_1b/llama_cpu.py | 301 ---- .../llama_3.2_1b/llama_inference_harness.py | 249 --- iron/applications/llama_3.2_1b/llama_npu.py | 1348 ----------------- .../llama_3.2_1b/src/block/feed_forward.py | 250 +++ .../llama_3.2_1b/src/block/gqa.py | 505 ++++++ .../llama_3.2_1b/src/block/transformer.py | 195 +++ .../llama_3.2_1b/src/model_with_json.py | 309 ++++ .../llama_3.2_1b/src/tokenizer.py | 101 ++ iron/applications/llama_3.2_1b/src/utils.py | 307 ++++ iron/applications/llama_3.2_1b/test.py | 51 + .../applications/llama_3.2_1b/torch_to_npy.py | 49 + iron/operators/repeat/design.py | 76 - iron/operators/repeat/op.py | 69 - iron/operators/strided_copy/design.py | 133 -- iron/operators/strided_copy/op.py | 89 -- iron/operators/strided_copy/test2.py | 111 -- 20 files changed, 2616 insertions(+), 2376 deletions(-) create mode 100644 iron/applications/llama_3.2_1b/analyze_profile.py create mode 100644 iron/applications/llama_3.2_1b/configs/llama32_1b.json create mode 100644 iron/applications/llama_3.2_1b/configs/llama32_1b.json.license create mode 100755 iron/applications/llama_3.2_1b/inference.py delete mode 100755 iron/applications/llama_3.2_1b/llama_cpu.py delete mode 100644 iron/applications/llama_3.2_1b/llama_inference_harness.py delete mode 100755 iron/applications/llama_3.2_1b/llama_npu.py create mode 100644 iron/applications/llama_3.2_1b/src/block/feed_forward.py create mode 100644 iron/applications/llama_3.2_1b/src/block/gqa.py create mode 100644 iron/applications/llama_3.2_1b/src/block/transformer.py create mode 100644 iron/applications/llama_3.2_1b/src/model_with_json.py create mode 100644 iron/applications/llama_3.2_1b/src/tokenizer.py create mode 100644 iron/applications/llama_3.2_1b/src/utils.py create mode 100644 iron/applications/llama_3.2_1b/test.py create mode 100644 iron/applications/llama_3.2_1b/torch_to_npy.py delete mode 100644 iron/operators/repeat/design.py delete mode 100644 iron/operators/repeat/op.py delete mode 100644 iron/operators/strided_copy/design.py delete mode 100644 iron/operators/strided_copy/op.py delete mode 100755 iron/operators/strided_copy/test2.py diff --git a/iron/applications/llama_3.2_1b/analyze_profile.py b/iron/applications/llama_3.2_1b/analyze_profile.py new file mode 100644 index 00000000..7e2c76f1 --- /dev/null +++ b/iron/applications/llama_3.2_1b/analyze_profile.py @@ -0,0 +1,358 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Analyze profiling logs generated by inference.py + +This script parses the profile logs and provides statistics about function execution times. +The total times reported in the analysis results are the cumulative times of the functions, including subcalls. +""" + +import argparse +import re +from collections import defaultdict +from pathlib import Path +import sys +import csv +import statistics +from collections import deque + + +class FunctionStats: + def __init__(self, name): + self.name = name + self.call_count = 0 + self.total_time = 0.0 + self.min_time = float("inf") + self.max_time = 0.0 + self.durations = [] + + def add_duration(self, duration): + self.call_count += 1 + self.total_time += duration + self.min_time = min(self.min_time, duration) + self.max_time = max(self.max_time, duration) + self.durations.append(duration) + + @property + def avg_time(self): + if not self.durations: + return 0.0 + return statistics.mean(self.durations) + + @property + def median_time(self): + if not self.durations: + return 0.0 + return statistics.median(self.durations) + + +def parse_profile_log(log_file): + """ + Parse a profile log file and extract function timing information. + + Args: + log_file: Path to the profile log file + + Returns: + dict: Dictionary mapping function names to FunctionStats objects + """ + stats = defaultdict(lambda: FunctionStats("")) + function_stack = deque() # Track ongoing calls by function identifier + + # Regex patterns for parsing log lines + call_pattern = re.compile(r"\[CALL\] (.+?) started at ([\d.]+)") + return_pattern = re.compile(r"\[RETURN\] (.+?) ended at ([\d.]+)") + + with open(log_file, "r") as f: + for line in f: + # Try to match CALL pattern + call_match = call_pattern.search(line) + if call_match: + func_id = call_match.group(1) + timestamp = float(call_match.group(2)) + function_stack.append((func_id, timestamp)) + continue + + # Try to match RETURN pattern + return_match = return_pattern.search(line) + if return_match: + func_id = return_match.group(1) + timestamp = float(return_match.group(2)) + + # Use the full function identifier (filepath:function_name:line_no) + if func_id not in stats: + stats[func_id].name = func_id + + if function_stack: + stats[func_id].add_duration(timestamp - function_stack.pop()[1]) + else: + raise RuntimeError( + f"Stack empty, found a log for the return but missing a log for the call of {func_id}" + ) + + return dict(stats) + + +def print_summary(stats, sort_by="total", top_n=20, min_calls=1): + """ + Print a summary of function statistics. + + Args: + stats: Dictionary of function statistics + sort_by: Sort criterion ('total', 'avg', 'max', 'calls') + top_n: Number of top functions to display + min_calls: Minimum number of calls to include in results + """ + # Filter by minimum calls + filtered_stats = {name: s for name, s in stats.items() if s.call_count >= min_calls} + + if not filtered_stats: + print("No functions found matching the criteria.") + return + + # Sort functions + sort_keys = { + "total": lambda x: x[1].total_time, + "avg": lambda x: x[1].avg_time, + "max": lambda x: x[1].max_time, + "calls": lambda x: x[1].call_count, + } + + if sort_by not in sort_keys: + print(f"Invalid sort criterion: {sort_by}. Using 'total'.") + sort_by = "total" + + sorted_stats = sorted(filtered_stats.items(), key=sort_keys[sort_by], reverse=True) + + # Print header + print("\n" + "=" * 160) + print(f"FUNCTION PROFILING SUMMARY (sorted by {sort_by}, top {top_n})") + print("=" * 160) + print( + f"{'Function Identifier':<80} {'Calls':>8} {'Total (s)':>12} {'Avg (s)':>12} {'Min (s)':>12} {'Max (s)':>12} {'Median (s)':>12}" + ) + print("-" * 160) + + # Print top N functions + for func_name, func_stats in sorted_stats[:top_n]: + # Truncate long function identifiers for display + display_name = func_name if len(func_name) <= 80 else func_name[:77] + "..." + print( + f"{display_name:<80} {func_stats.call_count:>8} " + f"{func_stats.total_time:>12.6f} {func_stats.avg_time:>12.6f} " + f"{func_stats.min_time:>12.6f} {func_stats.max_time:>12.6f} " + f"{func_stats.median_time:>12.6f}" + ) + + print("-" * 160) + + +def print_function_details(stats, function_name): + """ + Print detailed statistics for functions matching the given name. + + Args: + stats: Dictionary of function statistics + function_name: Name or substring to search for in function identifiers + """ + # Find all function identifiers containing the function_name string + matching_funcs = { + func_id: func_stats + for func_id, func_stats in stats.items() + if function_name in func_id + } + + if not matching_funcs: + print(f"No functions found containing '{function_name}' in profile data.") + print(f"\nAvailable functions (showing first 20):") + for i, name in enumerate(sorted(stats.keys())[:20]): + print(f" - {name}") + if len(stats) > 20: + print(f" ... and {len(stats) - 20} more") + return + + print("\n" + "=" * 120) + print(f"DETAILED STATISTICS FOR FUNCTIONS CONTAINING: '{function_name}'") + print(f"Found {len(matching_funcs)} matching function(s)") + print("=" * 120) + + for func_id, func_stats in sorted( + matching_funcs.items(), key=lambda x: x[1].total_time, reverse=True + ): + print(f"\nFunction: {func_id}") + print("-" * 120) + print(f" Total calls: {func_stats.call_count:,}") + print(f" Total time: {func_stats.total_time:.6f} seconds") + print(f" Average time: {func_stats.avg_time:.6f} seconds") + print(f" Median time: {func_stats.median_time:.6f} seconds") + print(f" Min time: {func_stats.min_time:.6f} seconds") + print(f" Max time: {func_stats.max_time:.6f} seconds") + + if func_stats.call_count > 1: + std_dev = statistics.stdev(func_stats.durations) + print(f" Std deviation: {std_dev:.6f} seconds") + + print("=" * 120 + "\n") + + +def export_to_csv(stats, output_file, sort_by="total", min_calls=1): + """ + Export function statistics to a CSV file. + + Args: + stats: Dictionary of function statistics + output_file: Path to output CSV file + sort_by: Sort criterion ('total', 'avg', 'max', 'calls') + min_calls: Minimum number of calls to include in results + """ + # Filter by minimum calls + filtered_stats = {name: s for name, s in stats.items() if s.call_count >= min_calls} + + if not filtered_stats: + print("No functions found matching the criteria.") + return + + # Sort functions + sort_keys = { + "total": lambda x: x[1].total_time, + "avg": lambda x: x[1].avg_time, + "max": lambda x: x[1].max_time, + "calls": lambda x: x[1].call_count, + } + + if sort_by not in sort_keys: + print(f"Invalid sort criterion: {sort_by}. Using 'total'.") + sort_by = "total" + + sorted_stats = sorted(filtered_stats.items(), key=sort_keys[sort_by], reverse=True) + + # Write to CSV + with open(output_file, "w", newline="") as csvfile: + fieldnames = [ + "function_name", + "call_count", + "total_time_seconds", + "avg_time_seconds", + "median_time_seconds", + "min_time_seconds", + "max_time_seconds", + "std_dev_seconds", + ] + + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + + for func_name, func_stats in sorted_stats: + std_dev = ( + statistics.stdev(func_stats.durations) + if func_stats.call_count > 1 + else 0.0 + ) + + writer.writerow( + { + "function_name": func_name, + "call_count": func_stats.call_count, + "total_time_seconds": f"{func_stats.total_time:.9f}", + "avg_time_seconds": f"{func_stats.avg_time:.9f}", + "median_time_seconds": f"{func_stats.median_time:.9f}", + "min_time_seconds": f"{func_stats.min_time:.9f}", + "max_time_seconds": f"{func_stats.max_time:.9f}", + "std_dev_seconds": f"{std_dev:.9f}", + } + ) + + print(f"\nCSV file saved to: {output_file}") + print(f"Total functions exported: {len(sorted_stats)}") + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze profiling logs from inference.py", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Analyze the most recent profile log + python analyze_profile.py + + # Analyze a specific log file + python analyze_profile.py logs/profile_20250110_160000.log + + # Sort by average time and show top 30 + python analyze_profile.py --sort avg --top 30 + + # Show details for a specific function + python analyze_profile.py --function inference + + # Filter functions with at least 10 calls + python analyze_profile.py --min-calls 10 + + # Export to CSV file + python analyze_profile.py --csv profile_stats.csv + + # Export to CSV with custom sorting and filtering + python analyze_profile.py --csv results.csv --sort avg --min-calls 5 + """, + ) + + parser.add_argument( + "log_file", + type=str, + help="Path to profile log file", + ) + parser.add_argument( + "--sort", + choices=["total", "avg", "max", "calls"], + default="total", + help="Sort criterion (default: total)", + ) + parser.add_argument( + "--top", + type=int, + default=20, + help="Number of top functions to display (default: 20)", + ) + parser.add_argument( + "--min-calls", + type=int, + default=1, + help="Minimum number of calls to include (default: 1)", + ) + parser.add_argument( + "--function", type=str, help="Show detailed statistics for a specific function" + ) + parser.add_argument( + "--csv", + type=str, + help="Export results to CSV file instead of printing to console", + ) + + args = parser.parse_args() + + # Parse the log file + log_file = Path(args.log_file) + print(f"Parsing {log_file}...") + stats = parse_profile_log(log_file) + + if not stats: + print("No profiling data found in log file.") + else: + print(f"Found {len(stats)} unique functions") + + # Show results + if args.csv: + # Export to CSV + export_to_csv(stats, args.csv, sort_by=args.sort, min_calls=args.min_calls) + elif args.function: + # Show detailed function statistics + print_function_details(stats, args.function) + else: + # Print summary to console + print_summary( + stats, sort_by=args.sort, top_n=args.top, min_calls=args.min_calls + ) + + +if __name__ == "__main__": + main() diff --git a/iron/applications/llama_3.2_1b/configs/llama32_1b.json b/iron/applications/llama_3.2_1b/configs/llama32_1b.json new file mode 100644 index 00000000..ed6bc4bf --- /dev/null +++ b/iron/applications/llama_3.2_1b/configs/llama32_1b.json @@ -0,0 +1,39 @@ +{ + "model_config": { + "vocab_size": 128256, + "context_length": 131072, + "emb_dim": 2048, + "n_heads": 32, + "n_layers": 16, + "hidden_dim": 8192, + "n_kv_groups": 8, + "use_kv_cache": true, + "rope_base": 500000.0, + "dtype": "bfloat16", + "use_aie_final_norm": true, + "use_aie_ffn_gemm": false, + "use_aie_ffn_silu": false, + "use_aie_ffn_mul": false, + "use_aie_ffn_swiglu": true, + "use_aie_ffn_gemv": true, + "use_aie_attn_projection_gemm": true, + "use_aie_gqa_gemv": true, + "use_aie_rope": true, + "use_aie_norm1": true, + "use_aie_norm2": true, + "use_aie_residual": true, + "use_aie_regular_mha": false, + "use_aie_fused_mha": true, + "use_aie_final_gemm": true, + "use_aie_final_gemv": true, + "rope_freq": { + "factor": 32.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_context_length": 8192 + } + }, + "aie_config": { + "device": "npu2" + } +} \ No newline at end of file diff --git a/iron/applications/llama_3.2_1b/configs/llama32_1b.json.license b/iron/applications/llama_3.2_1b/configs/llama32_1b.json.license new file mode 100644 index 00000000..50daea92 --- /dev/null +++ b/iron/applications/llama_3.2_1b/configs/llama32_1b.json.license @@ -0,0 +1,7 @@ +Copyright (c) Sebastian Raschka under Apache License 2.0. +Source for "Build a Large Language Model From Scratch" + - https://www.manning.com/books/build-a-large-language-model-from-scratch +Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb + +SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: Apache-2.0 diff --git a/iron/applications/llama_3.2_1b/inference.py b/iron/applications/llama_3.2_1b/inference.py new file mode 100755 index 00000000..8109c543 --- /dev/null +++ b/iron/applications/llama_3.2_1b/inference.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python3 +# Copyright (c) Sebastian Raschka under Apache License 2.0. +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb +# +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import sys +from pathlib import Path + +import argparse +import time +import torch +from src.model_with_json import Llama3ModelWithJSONConfig + +# from src.model import Llama3Model +from src.tokenizer import Tokenizer, ChatFormat +from safetensors.torch import load_file +import os +import shutil +import logging +from collections import deque + +from iron.common import AIEOperatorBase +from src.utils import ( + model_memory_size, + load_weights_into_llama, + text_to_token_ids, + token_ids_to_text, + clean_text, + generate, +) + +# Global logger for profiling +_profile_logger = None + + +def profile_function_calls(frame, event, arg): + """ + Profile function that logs start and end times of every function call. + + Args: + frame: The current stack frame + event: The event type ('call', 'return', 'c_call', 'c_return', 'c_exception') + arg: Event-specific argument + """ + global _profile_logger + + if _profile_logger is None: + return + + func_name = frame.f_code.co_name + filename = frame.f_code.co_filename + line_no = frame.f_lineno + + # Create a readable function identifier + func_identifier = f"{filename}:{func_name}:{line_no}" + + if event == "call": + # Function is being called + timestamp = time.perf_counter() + _profile_logger.debug(f"[CALL] {func_identifier} started at {timestamp:.9f}") + + elif event == "return": + # Function is returning + timestamp = time.perf_counter() + _profile_logger.debug(f"[RETURN] {func_identifier} ended at {timestamp:.9f}") + + return profile_function_calls + + +def enable_profiling(logs_dir_name): + """Enable function call profiling using sys.setprofile.""" + global _profile_logger + + # Create a dedicated logger for profiling + _profile_logger = logging.getLogger("function_profiler") + _profile_logger.setLevel(logging.DEBUG) + # Prevent propagation to root logger to avoid console output + _profile_logger.propagate = False + + # Create log file for profiling data + timestamp = time.strftime("%Y%m%d_%H%M%S") + log_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + logs_dir_name, + f"profile_{timestamp}.log", + ) + + # Add file handler for profiling (only file, no console output) + profile_handler = logging.FileHandler(log_path) + profile_handler.setLevel(logging.DEBUG) + profile_formatter = logging.Formatter("%(asctime)s - %(message)s") + profile_handler.setFormatter(profile_formatter) + _profile_logger.addHandler(profile_handler) + + # Set the profile function + sys.setprofile(profile_function_calls) + _profile_logger.info("Function profiling enabled") + + # Explicitly call profile_function_calls to log this function's call + import inspect + + frame = inspect.currentframe() + profile_function_calls(frame, "call", None) + + +def disable_profiling(): + """Disable function call profiling.""" + global _profile_logger + + sys.setprofile(None) + if _profile_logger: + _profile_logger.info("Function profiling disabled") + # Close all handlers + for handler in _profile_logger.handlers[:]: + handler.close() + _profile_logger.removeHandler(handler) + + +_iron_chat = r""" + /$$$$$$ /$$$$$$$ /$$$$$$ /$$ /$$ + |_ $$_/| $$__ $$ /$$__ $$| $$$ | $$ + | $$ | $$ \ $$| $$ \ $$| $$$$| $$ + | $$ | $$$$$$$/| $$ | $$| $$ $$ $$ + | $$ | $$__ $$| $$ | $$| $$ $$$$ + | $$ | $$ \ $$| $$ | $$| $$\ $$$ + /$$$$$$| $$ | $$| $$$$$$/| $$ \ $$ + |______/|__/ |__/ \______/ |__/ \__/ + + + /$$ /$$ /$$$$$$ /$$ /$$ /$$$$$$ +| $$ | $$ /$$__ $$| $$$ /$$$ /$$__ $$ +| $$ | $$ | $$ \ $$| $$$$ /$$$$| $$ \ $$ +| $$ | $$ | $$$$$$$$| $$ $$/$$ $$| $$$$$$$$ +| $$ | $$ | $$__ $$| $$ $$$| $$| $$__ $$ +| $$ | $$ | $$ | $$| $$\ $ | $$| $$ | $$ +| $$$$$$$$| $$$$$$$$| $$ | $$| $$ \/ | $$| $$ | $$ +|________/|________/|__/ |__/|__/ |__/|__/ |__/ +""" + + +def setup_logging(verbosity): + """Set up logging based on verbosity level.""" + + # Ensure the logs directory is created in case of profiling + logs_dir_name = "logs" + if not os.path.exists(logs_dir_name): + os.makedirs(logs_dir_name) + + if verbosity != 0: + levels = { + 4: logging.DEBUG, + 3: logging.INFO, + 2: logging.WARNING, + # 1: log everything (DEBUG) to a file + } + + # Create log file + timestamp = time.strftime("%Y%m%d_%H%M%S") + log_file = f"logs/inference_{timestamp}.log" + + handlers = [logging.FileHandler(log_file)] + if verbosity > 0: + handlers.append(logging.StreamHandler(sys.stderr)) + handlers[-1].setLevel(levels[verbosity]) + + # Configure root logger + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=handlers, + force=True, # Override any existing configuration + ) + + return logs_dir_name + + +def save_layer_data(module, input, output, name, input_data_path, output_data_path): + for count, i in enumerate(input): + torch.save( + i.detach(), + f"{input_data_path}/{name}_input_{count}_{input[0].size()[1]}_toks.pt", + ) + torch.save( + output.detach(), f"{output_data_path}/{name}_output_{output.size()[1]}_toks.pt" + ) + + +def inference( + weights_file_path, + tokenizer_file_path, + num_tokens, + prompt, + use_prompt_template, + save_outputs, + chat: bool, + prompt_len: int = 64, +): + """ + Main function to load a Llama3 model, process input, and generate output text. + """ + logging.info("Weights file path: %s", weights_file_path) + logging.info("Tokenizer file path: %s", tokenizer_file_path) + logging.info("Number of tokens: %d", num_tokens) + logging.debug("Prompt: %s", prompt) + logging.info("Use prompt template: %s", use_prompt_template) + logging.info("Save outputs: %s", save_outputs) + torch.manual_seed(1608560892) + input_data_path = "results/inputs" + output_data_path = "results/outputs" + + tokenizer = Tokenizer(tokenizer_file_path) + + print(_iron_chat) + if chat: + prompt = input("Enter your prompt: ").strip() + print("") + + logging.info(f"Loading model and tokenizer...") + token_ids = text_to_token_ids(prompt, tokenizer)[:, :prompt_len] + truncated_prompt = token_ids_to_text(token_ids, tokenizer) + + script_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(script_dir, "configs", "llama32_1b.json") + model = Llama3ModelWithJSONConfig( + config_path=config_path, + prompt_length=prompt_len, + num_tokens=num_tokens, + ) + logging.info("Model and tokenizer loaded.") + + # Important: Set the seed again after initialization of the model. Each + # call that initializes an nn.Linear layer updates the RNG state, because + # weights are initialized with random values. For different JSON + # configurations, we initialize a different number of linear layers, + # so different configurations result in a different RNG state here. Since + # we use random numbers to sample from the token distribution during + # inference, it is important to have the same RNG state between runs so we + # can have reproducible results across configurations. + torch.manual_seed(1608560892) + + hook_handles = [] + if save_outputs: + if os.path.exists(output_data_path): + shutil.rmtree(output_data_path) + os.makedirs(output_data_path) + if os.path.exists(input_data_path): + shutil.rmtree(input_data_path) + os.makedirs(input_data_path) + for name, module in model.named_modules(): + handle = module.register_forward_hook( + lambda module, input, output, name=name, input_data_path=input_data_path, output_data_path=output_data_path: ( + save_layer_data( + module, input, output, name, input_data_path, output_data_path + ) + ) + ) + hook_handles.append(handle) + + device = torch.device("cpu") + model.to(device) + chat_tokenizer = ChatFormat(tokenizer) + + total_params = sum(p.numel() for p in model.parameters()) + total_params_normalized = total_params - model.tok_emb.weight.numel() + logging.info(f"Total number of parameters: {total_params:,}") + logging.info(f"Total number of unique parameters: {total_params_normalized:,}") + logging.info( + f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB" + ) + logging.info( + f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB" + ) + + combined_weights = load_file(weights_file_path) + # Get parameters from model config + model_config = { + "n_layers": model.cfg["n_layers"], + "emb_dim": model.cfg["emb_dim"], + "n_heads": model.cfg["n_heads"], + "n_kv_groups": model.cfg["n_kv_groups"], + "vocab_size": model.cfg["vocab_size"], + "context_length": model.cfg["context_length"], + "hidden_dim": model.cfg["hidden_dim"], + "rope_base": model.cfg["rope_base"], + "dtype": model.cfg["dtype"], + "rope_freq": model.cfg["rope_freq"], + } + load_weights_into_llama(model, model_config, combined_weights) + model.to(device) + del combined_weights + + logging.info("Preparing AIE operators...") + # At this point the model is fully described (operators and their dimensions and how to compile them) + AIEOperatorBase.get_default_context().compile_all() + AIEOperatorBase.get_default_context().prepare_runtime() + logging.info("AIE operator preparation completed.") + print(f"Starting text generation...") + print(f"Generating {num_tokens} tokens...") + print("=" * 55) + + prefill_end_time = None + + def set_prefill_time(): + nonlocal prefill_end_time + prefill_end_time = time.time() + + # Start total wall clock timing + start = time.time() + token_ids = generate( + model=model, + idx=token_ids.to(device), + max_new_tokens=num_tokens, + context_size=model.cfg["context_length"], + eos_id=tokenizer.special["<|end_of_text|>"], + hook_handles=hook_handles, + temperature=0.7, + top_k=50, + tokenizer=tokenizer, + prompt=truncated_prompt, + prefill_done_callback=set_prefill_time, + ) + end = time.time() + prefill_time = prefill_end_time - start + total_time = end - start + post_prefill_time = end - prefill_end_time if num_tokens > 0 else 0 + + tokens_per_second = (num_tokens - 1) / post_prefill_time if num_tokens > 1 else 0 + time_per_token = total_time / (num_tokens - 1) if num_tokens > 1 else prefill_time + + print("=" * 55) + print(" TIMING RESULTS:") + print(f" Total time: {total_time:.4f} seconds") + print(f" Prefill time: {prefill_time:.4f} seconds") + print(f" Tokens generated: {num_tokens}") + print(f" Tokens per second: {tokens_per_second:.2f}") + print( + f" Time per token: {time_per_token:.4f} seconds" + if num_tokens > 0 + else " Time per token: N/A" + ) + print("=" * 55) + + logging.info(f"Generation time: {total_time:.4f} sec") + logging.info(f"Total wall clock time: {total_time:.4f} sec") + logging.info(f"Tokens per second: {tokens_per_second:.2f}") + logging.info( + f"Time per token: {time_per_token:.4f} sec" + if num_tokens > 0 + else "Time per token: N/A" + ) + + output_text = token_ids_to_text(token_ids, tokenizer) + logging.info("Output text:\n %s", clean_text(output_text)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run Llama3 model inference.") + parser.add_argument( + "weights_file_path", + type=str, + help="Path to the weights file: model.safetensors", + ) + parser.add_argument( + "tokenizer_file_path", + type=str, + help="Path to the tokenizer file: tokenizer.model", + ) + parser.add_argument( + "--num_tokens", type=int, default=1, help="Number of tokens to predict." + ) + parser.add_argument( + "--prompt", + type=str, + default="", + help="Prompt for the model to generate text from.", + ) + parser.add_argument( + "--use_prompt_template", + action="store_true", + help="Use a prompt template for the model.", + ) + parser.add_argument( + "--save_outputs", + action="store_true", + help="Enable hooks to save outputs of the layers in the model", + ) + parser.add_argument( + "--chat", + action="store_true", + help="Enable interactive mode to enter your own prompt.", + ) + parser.add_argument( + "--prompt_len", + type=int, + default=2048, + help="Truncate prompt to this many tokens.", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use a custom profiler for performance measurements", + ) + parser.add_argument( + "-v", + action="count", + default=0, + help="Increase verbosity level (use -v (logs to file), -vv, -vvv, or -vvvv)", + ) + args = parser.parse_args() + + # Set up logging + logs_dir_name = setup_logging(args.v) + + # Enable function profiling + if args.profile: + enable_profiling(logs_dir_name) + + try: + prompt = args.prompt + if not prompt: + # Default prompt is text from Shakespeare's King Lear: https://shakespeare.mit.edu/lear/lear.1.1.html + prompt_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "prompt.txt" + ) + with open(prompt_path, "r", encoding="utf-8") as file: + prompt = file.read().strip() + + inference( + args.weights_file_path, + args.tokenizer_file_path, + args.num_tokens, + prompt, + args.use_prompt_template, + args.save_outputs, + args.chat, + args.prompt_len, + ) + finally: + if args.profile: + # Disable profiling when done + disable_profiling() diff --git a/iron/applications/llama_3.2_1b/llama_cpu.py b/iron/applications/llama_3.2_1b/llama_cpu.py deleted file mode 100755 index f4db4257..00000000 --- a/iron/applications/llama_3.2_1b/llama_cpu.py +++ /dev/null @@ -1,301 +0,0 @@ -#!/usr/bin/env python3 - -# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import math -import llama_inference_harness as harness - -# Operators -# ########################################################################## - - -def rope_forward(x, angles): - """Rotary positional embedding using precomputed angles""" - # x: (batch, seq_len, num_heads, head_dim) after view and before transpose - # angles: (context_length, head_dim) - _, seq_len, _, head_dim = x.shape - angles_slice = angles[:seq_len] # (seq_len, head_dim) - - # Split into even and odd dimensions - x1 = x[..., : head_dim // 2] # (batch, seq_len, num_heads, head_dim//2) - x2 = x[..., head_dim // 2 :] # (batch, seq_len, num_heads, head_dim//2) - - # Get cos and sin from angles - cos = angles_slice[:, ::2] # (seq_len, head_dim//2) - sin = angles_slice[:, 1::2] # (seq_len, head_dim//2) - - # Reshape for broadcasting: (1, seq_len, 1, head_dim//2) - # (The same cosine and sine values are used across batch and heads.) - cos = cos.unsqueeze(0).unsqueeze(2) - sin = sin.unsqueeze(0).unsqueeze(2) - - # Rotate: [x1*cos - x2*sin, x1*sin + x2*cos] - rotated = torch.empty_like(x) - rotated[..., : head_dim // 2] = x1 * cos - x2 * sin - rotated[..., head_dim // 2 :] = x1 * sin + x2 * cos - - return rotated - - -def rms_norm_forward(x, weight, eps=1e-5): - """Root Mean Square Layer Normalization""" - # x: (batch, seq_len, dim) - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + eps) - return weight * x - - -def grouped_query_attention_forward( - x, - keys_cache, - values_cache, - W_query, - W_key, - W_value, - W_out, - angles, - mask=None, - num_heads=32, - num_kv_groups=8, -): - batch, seq_len, d_in = x.shape - assert W_query.shape[0] >= num_heads and W_query.shape[0] % num_heads == 0 - head_dim = W_query.shape[0] // num_heads - assert W_key.shape[0] == num_kv_groups * head_dim - assert W_value.shape[0] == num_kv_groups * head_dim - num_preceding_tokens = keys_cache.shape[2] - assert keys_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) - assert values_cache.shape == (batch, num_kv_groups, num_preceding_tokens, head_dim) - - # Step 1: Linear projections - # This multiplication produces queries, keys and values for all tokens in the sequence. - # The weight matrix is such that multiple queries, keys and values are generated for each token. - # For each token, each head corresponds to one query. - # In particular, each token gets `num_heads` queries and `num_kv_groups` keys/values (keys/values shared for multiple queries). - # Due to the structure of the matmul, all queries, keys and values are contiguous for each token. - # Note that during the decode phase, seq_len=1, and we are only calculating the projections for the most recent token -- the keys and values of previous tokens will be concatenated in step 4. - queries = torch.nn.functional.linear( - x, W_query - ) # (batch, seq_len, num_heads * head_dim) - keys = torch.nn.functional.linear( - x, W_key - ) # (batch, seq_len, num_kv_groups * head_dim) - values = torch.nn.functional.linear( - x, W_value - ) # (batch, seq_len, num_kv_groups * head_dim) - queries = queries.view( - batch, seq_len, num_heads, head_dim - ) # (batch, seq_len, num_heads, head_dim) - keys = keys.view( - batch, seq_len, num_kv_groups, head_dim - ) # (batch, seq_len, num_kv_groups, head_dim) - values = values.view( - batch, seq_len, num_kv_groups, head_dim - ) # (batch, seq_len, num_kv_groups, head_dim) - - # Step 2: Apply RoPE - queries = rope_forward( - queries, angles[num_preceding_tokens : num_preceding_tokens + seq_len] - ) - keys = rope_forward( - keys, angles[num_preceding_tokens : num_preceding_tokens + seq_len] - ) - - # Step 3: Transpose for attention computation - # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. - # Transpose so that heads are consecutive for attention computation: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) - queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) - keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) - values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) - - # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. - keys_cache = torch.cat([keys_cache, keys], dim=2) - values_cache = torch.cat([values_cache, values], dim=2) - keys = keys_cache - values = values_cache - - # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value - group_size = num_heads // num_kv_groups - keys = keys.repeat_interleave(group_size, dim=1) - values = values.repeat_interleave(group_size, dim=1) - - # Step 6: Compute attention scores - # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, seq_len) - # -> (batch, num_heads, seq_len, seq_len) - # Entry at row i, column j, indicates how much token i's query attends to token j's key. - scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim) - - # Step 7: Apply mask - # This ensures causality, so that tokens in the future cannot attend to tokens in the past. - if mask is not None: - scores = scores.masked_fill(mask, float("-inf")) - - # Step 8: Apply softmax to squeeze scores into probabilities (0, 1) - attention_weights = torch.nn.functional.softmax(scores, dim=-1) - - # Step 9: Compute attention output - # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) - # -> (batch, num_heads, seq_len, head_dim) - context = torch.matmul(attention_weights, values) - - # Step 10: Concatenate heads and project - # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) - context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) - - output = torch.nn.functional.linear(context, W_out) - - return output, keys_cache, values_cache - - -def swiglu_ffn_forward(x, fc1_weight, fc2_weight, fc3_weight): - # Step 1: Parallel projections: (batch, seq_len, embedding_dim) -> (batch, seq_len, swiglu_hidden_dim) - gate = torch.nn.functional.linear(x, fc1_weight) # gate projection - up = torch.nn.functional.linear(x, fc2_weight) # up projection - - # Step 2: Apply SiLU activation - gate_activated = torch.nn.functional.silu( - gate - ) # (batch, seq_len, swiglu_hidden_dim) - - # Step 3: Element-wise multiplication (apply the 'gating') - hidden = gate_activated * up # (batch, seq_len, swiglu_hidden_dim) - - # Step 4: Down projection: (batch, seq_len, swiglu_hidden_dim) -> (batch, seq_len, embedding_dim) - output = torch.nn.functional.linear(hidden, fc3_weight) - - return output - - -def transformer_block_forward( - x, - attn_keys_cache, - attn_values_cache, - num_heads, - num_kv_groups, - W_norm1, - W_attn_query, - W_attn_key, - W_attn_value, - W_attn_out, - W_norm2, - W_ffn_fc1, - W_ffn_fc2, - W_ffn_fc3, - rope_angles, - attn_mask, -): - # Step 1: RMS normalization - x_norm = rms_norm_forward(x, W_norm1) - - # Step 2: Attention - attn_output, attn_keys, attn_values = grouped_query_attention_forward( - x_norm, - attn_keys_cache, - attn_values_cache, - W_attn_query, - W_attn_key, - W_attn_value, - W_attn_out, - rope_angles, - attn_mask, - num_heads, - num_kv_groups, - ) - - # Step 3: Residual - x = x + attn_output - - # Step 4: Post-norm - x_norm = rms_norm_forward(x, W_norm2) - - # Step 5: fully-connected feed-forward network - ffn_output = swiglu_ffn_forward(x_norm, W_ffn_fc1, W_ffn_fc2, W_ffn_fc3) - - # Step 6: Residual - x = x + ffn_output - - return x, attn_keys, attn_values - - -def llama_forward_pass(config, state): - batch, seq_len = state.token_ids.shape - - # Step 1: Token embedding - tok_emb_weight = config.weights["model.embed_tokens.weight"] - x = torch.nn.functional.embedding( - state.token_ids, tok_emb_weight - ) # (batch, seq_len, emb_dim) - - # Step 2: Create causal mask - attn_mask = torch.triu( - torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 - ) - - # Step 3: Apply transformer blocks - for layer_idx in range(config.n_layers): - x, state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = ( - transformer_block_forward( - x, - state.attn_keys_caches[layer_idx], - state.attn_values_caches[layer_idx], - config.n_heads, - config.n_kv_groups, - W_norm1=config.weights[ - f"model.layers.{layer_idx}.input_layernorm.weight" - ], - W_attn_query=config.weights[ - f"model.layers.{layer_idx}.self_attn.q_proj.weight" - ], - W_attn_key=config.weights[ - f"model.layers.{layer_idx}.self_attn.k_proj.weight" - ], - W_attn_value=config.weights[ - f"model.layers.{layer_idx}.self_attn.v_proj.weight" - ], - W_attn_out=config.weights[ - f"model.layers.{layer_idx}.self_attn.o_proj.weight" - ], - W_ffn_fc1=config.weights[ - f"model.layers.{layer_idx}.mlp.gate_proj.weight" - ], - W_ffn_fc2=config.weights[ - f"model.layers.{layer_idx}.mlp.up_proj.weight" - ], - W_ffn_fc3=config.weights[ - f"model.layers.{layer_idx}.mlp.down_proj.weight" - ], - W_norm2=config.weights[ - f"model.layers.{layer_idx}.post_attention_layernorm.weight" - ], - rope_angles=config.angles, - attn_mask=attn_mask, - ) - ) - - # Step 4: Final normalization - final_norm_weight = config.weights["model.norm.weight"] - x = rms_norm_forward(x, final_norm_weight) - - # Step 5: Output projection - logits = torch.nn.functional.linear( - x, config.weights["model.embed_tokens.weight"] - ) # (batch, seq_len, vocab_size) - - return logits, state - - -# Main -# ########################################################################## - - -def main(): - prompt = "The capital of France is " - config, state = harness.init(prompt=prompt) - print(prompt, end="", flush=True) - harness.generate(config, state, llama_forward_pass) - - -if __name__ == "__main__": - main() diff --git a/iron/applications/llama_3.2_1b/llama_inference_harness.py b/iron/applications/llama_3.2_1b/llama_inference_harness.py deleted file mode 100644 index 97cc5034..00000000 --- a/iron/applications/llama_3.2_1b/llama_inference_harness.py +++ /dev/null @@ -1,249 +0,0 @@ -#!/usr/bin/env python3 - -# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Inference harness -- all the necessary code _other_ than the actual model (forward pass). -Exposes a 'harness' function that can be called with a 'forward_pass' function that implements the model. -The 'harness' function does the following: -1. Load and set up model weights, tokenizer, and RoPE angle look-up table. -2. Tokenize the provided input prompt. -3. Run the generation loop to produce new tokens; this calls the provided forward_pass function. Decode and print each generated token. -""" - -import torch -import math -import sys -import time - -import safetensors.torch -import tiktoken, tiktoken.load - -# Configuration -# ########################################################################## - - -class LlamaConfig: - def __init__(self, weights_path, tokenizer_path): - # Model architecture - self.vocab_size = 128256 - self.emb_dim = 2048 - self.n_layers = 16 - self.n_heads = 32 - self.n_kv_groups = 8 - self.head_dim = self.emb_dim // self.n_heads # 64 - self.hidden_dim = 8192 - - # RoPE - self.rope_base = 500000.0 - self.context_length = 131072 - - # Generation - self.temperature = 0.7 - self.top_k = 50 - - # Tokenization - self.special_tokens = { - "<|begin_of_text|>": 128000, - "<|end_of_text|>": 128001, - "<|start_header_id|>": 128006, - "<|end_header_id|>": 128007, - "<|eot_id|>": 128009, - } - self.special_tokens.update( - { - f"<|reserved_{i}|>": i - for i in list(range(128002, 128006)) + list(range(128009, 128256)) - } - ) - - # Load model weights and tokenizer - self.weights = safetensors.torch.load_file(weights_path) - self.tokenizer = get_tokenizer(tokenizer_path, self.special_tokens) - # TODO: Assert that weight dimensions match config - - # Compute RoPE angle look-up table - self.angles = compute_rope_angles( - self.head_dim, self.context_length, self.rope_base - ) - - -class LlamaModelState: - def __init__(self, config): - # Current IDs of tokens being processed (most recent token for decode; all prompt tokens for prefill) - self.token_ids = torch.empty(0, dtype=torch.long) - self.reset_kv_cache(config) - - def reset_kv_cache(self, config): - # Set up KV cache -- initially empty - # This is what passes information from previous tokens to the current token during generation - self.attn_keys_caches = [ - torch.empty( - 1, - config.n_kv_groups, - 0, - config.head_dim, - dtype=config.weights["model.layers.0.self_attn.k_proj.weight"].dtype, - ) # (batch_size, n_kv_groups, seq_len, head_dim) - for _ in range(config.n_layers) - ] - self.attn_values_caches = [ - torch.empty( - 1, - config.n_kv_groups, - 0, - config.head_dim, - dtype=config.weights["model.layers.0.self_attn.v_proj.weight"].dtype, - ) # (batch_size, n_kv_groups, seq_len, head_dim) - for _ in range(config.n_layers) - ] - - -# Utilities -# ########################################################################## - - -def compute_rope_angles(head_dim, context_length, rope_base=500000.0): - """Compute RoPE (Rotary Position Embedding) angles.""" - # Precompute the frequency tensor - inv_freq = 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim)) - position = torch.arange(context_length).float() - freqs = torch.outer(position, inv_freq) - - cos = torch.cos(freqs) - sin = torch.sin(freqs) - - # Interleave cos and sin - create angles buffer - angles = torch.empty(context_length, head_dim) - angles[:, ::2] = cos - angles[:, 1::2] = sin - return angles - - -def get_tokenizer(tokenizer_path, special_tokens): - mergeable = tiktoken.load.load_tiktoken_bpe(tokenizer_path) - return tiktoken.Encoding( - name="llama3.2-1b", - pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" - r"|[^\r\n\p{L}\p{N}]?\p{L}+" - r"|\p{N}{1,3}" - r"| ?[^\s\p{L}\p{N}]+[\r\n]*" - r"|\s*[\r\n]+" - r"|\s+(?!\S)" - r"|\s+", - mergeable_ranks=mergeable, - special_tokens=special_tokens, - ) - - -# Generation loop -# ########################################################################## - - -def generate_token(config, forward_pass, state): - generated_tokens = [] - - # Step 1: Forward pass - logits, state = forward_pass(config, state) - - # Step 2: Get logits for last token - last_token_logits = logits[:, -1, :] # (batch, vocab_size) - - # Step 3: Temperature scaling - if config.temperature > 0: - last_token_logits = last_token_logits / config.temperature - - # Step 4: Top-k filtering - if config.top_k is not None: - top_logits, top_indices = torch.topk(last_token_logits, config.top_k) - min_val = top_logits[:, -1:] - last_token_logits = torch.where( - last_token_logits < min_val, torch.tensor(float("-inf")), last_token_logits - ) - - # Step 5: Sample - probs = torch.nn.functional.softmax(last_token_logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - - return next_token.item(), state - - -def init( - weights_path="/scratch/roesti/models/llama3.2-1b/model.safetensors", - tokenizer_path="/scratch/roesti/models/llama3.2-1b/tokenizer.model", - prompt="The capital of France is ", -): - config = LlamaConfig(weights_path, tokenizer_path) - state = LlamaModelState(config) - - seed = 1608560892 - torch.manual_seed(seed) - - # Tokenize prompt - prompt_token_ids = [config.special_tokens["<|begin_of_text|>"]] - prompt_token_ids += config.tokenizer.encode(prompt) - assert ( - len(prompt_token_ids) <= config.context_length - ), "Prompt + new tokens to generate too long (exceed context)" - prompt_token_ids = torch.tensor([prompt_token_ids], dtype=torch.long) - - state.token_ids = prompt_token_ids - - return config, state - - -def generate(config, state, forward_pass, num_tokens=100, use_kv_cache=True): - # Generate tokens - # First token (prefill) - n_tokens_generated = 0 - t_prefill_start = time.perf_counter() - first_token, state = generate_token(config, forward_pass, state) - token_text = config.tokenizer.decode([first_token]) - n_tokens_generated += 1 - print(token_text, end="", flush=True) - t_prefill_stop = time.perf_counter() - - # Remaining tokens (decode) - if use_kv_cache: - state.token_ids = torch.tensor([[first_token]], dtype=torch.long) - else: - state.reset_kv_cache(config) - state.token_ids = torch.cat( - [state.token_ids, torch.tensor([[first_token]], dtype=torch.long)], dim=1 - ) - t_decode_start = time.perf_counter() - for _ in range(num_tokens - 1): - next_token, state = generate_token(config, forward_pass, state) - token_text = config.tokenizer.decode([next_token]) - n_tokens_generated += 1 - print(token_text, end="", flush=True) - if use_kv_cache: - state.token_ids = torch.tensor([[next_token]], dtype=torch.long) - else: - state.reset_kv_cache(config) - state.token_ids = torch.cat( - [state.token_ids, torch.tensor([[next_token]], dtype=torch.long)], dim=1 - ) - t_decode_end = time.perf_counter() - - t_prefill = t_prefill_stop - t_prefill_start - t_decode = t_decode_end - t_decode_start - sys.stderr.write("\n\n=== Performance Statistics ===\n") - sys.stderr.write(f"[Prefill] Time to first token: {t_prefill:7.3f} s\n") - sys.stderr.write( - f"[Decode] Time per token (mean): {t_decode / (n_tokens_generated - 1):7.3f} s\n" - ) - sys.stderr.write( - f"[Decode] Tokens per second: {(n_tokens_generated - 1) / t_decode:7.3f}\n" - ) - sys.stderr.write( - f"[Total] Time per token (mean): {(t_prefill + t_decode) / n_tokens_generated:7.3f} s\n" - ) - sys.stderr.write( - f"[Total] Tokens per second: {n_tokens_generated / (t_prefill + t_decode):7.3f}\n" - ) - - -if __name__ == "__main__": - main() diff --git a/iron/applications/llama_3.2_1b/llama_npu.py b/iron/applications/llama_3.2_1b/llama_npu.py deleted file mode 100755 index 565d94f9..00000000 --- a/iron/applications/llama_3.2_1b/llama_npu.py +++ /dev/null @@ -1,1348 +0,0 @@ -#!/usr/bin/env python3 - -# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -# Next steps for decode performance: -# [ ] All decode operators operate on 2048-padded buffers; instead, should bin into shorter sequence lengths and call smaller operators -# [ ] Opportunity to fuse data layout transformations (e.g., transpose ops) onto end of other operations (e.g., transpose after RoPE) -# [ ] Some kernels are not optimized; e.g., softmax masking is using scalar cores -# [ ] Fine-tune parameters of operators (e.g., num AIE columns, tile sizes) -# [ ] Patching of operators (instantiating new xrt::elf for each token) is slow; find quicker way of patching instruction sequence in-memory -# [ ] Spatial fusion of operators - -import torch -import math -from pathlib import Path -import sys -import numpy as np -import ml_dtypes -import llama_inference_harness as harness -import logging -import time - -repo_root = Path(__file__).parent.parent.parent -sys.path.insert(0, str(repo_root)) - -from operators.common.context import AIEContext -from operators.common import AIEBuffer -from operators.common.utils import torch_to_numpy -from operators.common.base import PatchableSingleXclbinCallable -from operators.common.fusion import ( - FusedMLIROperator, - FusedFullELFCallable, - load_elf, - patch_elf, -) -from operators import ( - AIERMSNorm, - AIEGEMM, - AIEGEMV, - AIEElementwiseAdd, - AIEElementwiseMul, - AIESiLU, - AIERope, - AIEStridedCopy, - AIERepeat, - AIESoftmax, - AIETranspose, -) - -logging.basicConfig(level=logging.DEBUG) - -max_seq_len = 2048 - - -# AIE Operator Configuration -# ########################################################################## - - -aie_ops = None - - -class AIEPrefillOperations: - pass - - -class AIEDecodeOperations: - pass - - -class AIELlamaOperators: - - def __init__(self, config, prompt_len): - self.context = AIEContext() - self.context.build_dir.mkdir(parents=True, exist_ok=True) - - self.prefill = AIEPrefillOperations() - self.decode = AIEDecodeOperations() - - # ################################################################## - # Prefill operators - - self.prefill.rms_norm = ( - AIERMSNorm( - size=prompt_len * config.emb_dim, - eps=1e-5, - num_aie_columns=8, - num_channels=2, - tile_size=config.emb_dim, - context=self.context, - ) - .compile() - .get_callable() - ) - - self.prefill.residual_add = ( - AIEElementwiseAdd( - size=prompt_len * config.emb_dim, tile_size=config.emb_dim - ) - .compile() - .get_callable() - ) - self.decode.residual_add = ( - AIEElementwiseAdd(size=config.emb_dim, tile_size=config.emb_dim // 8) - .compile() - .get_callable() - ) - - min_N = 64 * 8 * 4 # tile_n * num_aie_columns * partition_N - config.padded_vocab_size = (config.vocab_size + min_N - 1) // min_N * min_N - config.vocab_partitions = 4 - self.prefill.gemv_out_head_compilable = AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.padded_vocab_size // config.vocab_partitions, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=True, - separate_c_tiles=True, - context=self.context, - ).compile() - self.prefill.out_head = self.prefill.gemv_out_head_compilable.get_callable() - - # SwiGLU FFN operators - # Prefill: M=prompt_len, K=emb_dim, N=hidden_dim - self.prefill.ffn_up_gate = ( - AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.hidden_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights - context=self.context, - ) - .compile() - .get_callable() - ) - - self.prefill.ffn_down = ( - AIEGEMM( - M=prompt_len, - K=config.hidden_dim, - N=config.emb_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, # exceeds stride dimensions otherwise; just transpose weights - context=self.context, - ) - .compile() - .get_callable() - ) - - self.prefill.ffn_silu = ( - AIESiLU( - size=prompt_len * config.hidden_dim, - tile_size=config.hidden_dim, - num_aie_columns=8, - context=self.context, - ) - .compile() - .get_callable() - ) - - self.prefill.eltwise_mul_ffn = ( - AIEElementwiseMul( - size=prompt_len * config.hidden_dim, - tile_size=config.hidden_dim, - num_aie_columns=8, - context=self.context, - ) - .compile() - .get_callable() - ) - - # Attention score scaling operators - # FIXME: Using elementwise mul is very wasteful (of bandwidth) here since it's the same scalar factor for all values; need a kernel that allows scalar multiplication of a vector; maybe use AXPY - self.prefill.attn_scale = ( - AIEElementwiseMul( - size=config.n_heads * prompt_len * prompt_len, - tile_size=prompt_len, - num_aie_columns=8, - context=self.context, - ) - .compile() - .get_callable() - ) - - # RoPE operators - # For queries: (seq_len, num_heads * head_dim) = (seq_len, 2048) - # For keys: (seq_len, num_kv_groups * head_dim) = (seq_len, 512) - # angle_rows=1 because all rows use the same angle row (angles are per position) - self.prefill.rope_queries = ( - AIERope( - rows=prompt_len * config.n_heads, - cols=config.head_dim, - angle_rows=prompt_len, - context=self.context, - ) - .compile() - .get_callable() - ) - - self.prefill.rope_keys = ( - AIERope( - rows=prompt_len * config.n_kv_groups, - cols=config.head_dim, - angle_rows=prompt_len, - context=self.context, - ) - .compile() - .get_callable() - ) - - # Attention projection operators - # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim) - self.prefill.attn_query = ( - AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_heads * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context, - ) - .compile() - .get_callable() - ) - - # Key projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) - self.prefill.attn_key = ( - AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_kv_groups * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context, - ) - .compile() - .get_callable() - ) - - # Value projection: (seq_len, emb_dim) -> (seq_len, n_kv_groups * head_dim) - self.prefill.attn_value = ( - AIEGEMM( - M=prompt_len, - K=config.emb_dim, - N=config.n_kv_groups * config.head_dim, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context, - ) - .compile() - .get_callable() - ) - - # Attention score computation: Q @ K^T per head - # For prefill: (seq_len, head_dim) @ (head_dim, seq_len) = (seq_len, seq_len) per head - self.prefill.attn_scores = ( - AIEGEMM( - M=prompt_len, - K=config.head_dim, - N=prompt_len, - num_aie_columns=8, - tile_m=64, - tile_k=64, - tile_n=64, - b_col_maj=False, - context=self.context, - ) - .compile() - .get_callable() - ) - - # Decode operator (everything temporally fused) - # ################################################################## - - elf_ctx = AIEContext(build_dir="build_elf") - - gemv_attn_query_op = AIEGEMV( - M=config.n_heads * config.head_dim, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.head_dim // 2, - context=elf_ctx, - ) - - gemv_attn_key_value_op = AIEGEMV( - M=config.n_kv_groups * config.head_dim, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.head_dim // 2, - context=elf_ctx, - ) - - rope_queries_op = AIERope( - rows=1 * config.n_heads, cols=config.head_dim, angle_rows=1, context=elf_ctx - ) - - rope_keys_op = AIERope( - rows=1 * config.n_kv_groups, - cols=config.head_dim, - angle_rows=1, - context=elf_ctx, - ) - - strided_copy_cache_magic = 0xDEADBEE0 - strided_copy_cache_op = AIEStridedCopy( - input_sizes=(config.n_kv_groups, config.head_dim), - input_strides=(config.head_dim, 1), - input_offset=0, - output_sizes=(1, config.n_kv_groups, config.head_dim), - output_strides=(0, prompt_len * config.head_dim, 1), - output_offset=7 * config.head_dim * 2, # Will be patched at runtime - input_buffer_size=1 * config.n_kv_groups * config.head_dim, - output_buffer_size=config.n_kv_groups * prompt_len * config.head_dim, - num_aie_channels=1, - output_offset_patch_marker=strided_copy_cache_magic, - context=elf_ctx, - ) - - # For decode: per head, (1, head_dim) @ (head_dim, max_context_len) - # Use GEMV: (max_context_len, head_dim) @ (head_dim,) = (max_context_len,) - gemv_attn_scores_op = AIEGEMV( - M=prompt_len, # max possible context length - K=config.head_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=prompt_len // 8, - num_batches=config.n_heads, - context=elf_ctx, - ) - - attn_scale_op = AIEElementwiseMul( - size=config.n_heads * prompt_len, - tile_size=prompt_len // 8, - num_aie_columns=8, - context=elf_ctx, - ) - - # Softmax operators for attention weights - softmax_magic = 0xBA5EBA11 - softmax_op = AIESoftmax( - rows=config.n_heads, - cols=prompt_len, - num_aie_columns=1, - num_channels=1, - rtp_vector_size=prompt_len, # Compile with max size - mask_patch_value=softmax_magic, # Magic value for patching - context=elf_ctx, - ) - - # Fused transpose for all attention heads (decode) - transpose_values_op = AIETranspose( - M=prompt_len, - N=config.head_dim, - num_aie_columns=2, - num_channels=1, - m=256, - n=32, - s=8, - context=elf_ctx, - ) - - # GEMV for attention context: (head_dim, max_context_len) @ (max_context_len,) = (head_dim,) per head - gemv_attn_context_op = AIEGEMV( - M=config.head_dim, - K=prompt_len, # max possible context length - num_aie_columns=8, - tile_size_input=4, - tile_size_output=4, - num_batches=config.n_heads, - context=elf_ctx, - ) - - gemv_attn_output_op = AIEGEMV( - M=config.emb_dim, - K=config.n_heads * config.head_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.emb_dim // 8, - context=elf_ctx, - ) - - rms_norm_op = AIERMSNorm( - size=config.emb_dim, - eps=1e-5, - num_aie_columns=1, - num_channels=2, - tile_size=config.emb_dim, - context=elf_ctx, - ) - - gemv_ffn_up_gate_op = AIEGEMV( - M=config.hidden_dim, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=config.hidden_dim // 8, - context=elf_ctx, - ) - - gemv_ffn_down_op = AIEGEMV( - M=config.emb_dim, - K=config.hidden_dim, - num_aie_columns=8, - tile_size_input=1, - tile_size_output=config.emb_dim // 8, - context=elf_ctx, - ) - - silu_ffn_op = AIESiLU( - size=config.hidden_dim, - tile_size=config.hidden_dim // 8, - num_aie_columns=8, - context=elf_ctx, - ) - - eltwise_mul_ffn_op = AIEElementwiseMul( - size=config.hidden_dim, - tile_size=config.hidden_dim // 8, - num_aie_columns=8, - context=elf_ctx, - ) - - residual_add_op = AIEElementwiseAdd( - size=config.emb_dim, tile_size=config.emb_dim // 8, context=elf_ctx - ) - - repeat_interleave_op = AIERepeat( - rows=config.n_kv_groups, - cols=prompt_len * config.head_dim, # Max context length - repeat=config.n_heads // config.n_kv_groups, - transfer_size=config.head_dim, - context=elf_ctx, - ) - - gemv_out_head_op = AIEGEMV( - M=config.vocab_size, - K=config.emb_dim, - num_aie_columns=8, - tile_size_input=4, - tile_size_output=32, - context=self.context, - ) - - # Create fused operator - - cache_buffer_size = ( - config.n_kv_groups * prompt_len * config.head_dim * 2 - ) # * 2 for bfloat16 - values_per_head_buffer_size = ( - prompt_len * config.head_dim * 2 - ) # * 2 for bfloat16 - values_buffer_size = config.n_heads * values_per_head_buffer_size - - runlist = [] - for layer_idx in range(config.n_layers): - # - runlist.extend( - [ - ( - rms_norm_op, - "x", - f"W_norm1_{layer_idx}", - "x_norm", - ) # Step 1: RMS normalization - ] - + [ - # - ( - gemv_attn_query_op, - f"W_attn_query_{layer_idx}", - "x_norm", - "queries", - ), - ( - gemv_attn_key_value_op, - f"W_attn_key_{layer_idx}", - "x_norm", - "keys", - ), - ( - gemv_attn_key_value_op, - f"W_attn_value_{layer_idx}", - "x_norm", - "values", - ), - (rope_queries_op, "queries", "rope_angles", "queries"), - (rope_keys_op, "keys", "rope_angles", "keys"), - (strided_copy_cache_op, "keys", f"keys_cache_{layer_idx}"), - (strided_copy_cache_op, "values", f"values_cache_{layer_idx}"), - ( - repeat_interleave_op, - f"keys_cache_{layer_idx}", - "attn_scores_keys", - ), - ( - repeat_interleave_op, - f"values_cache_{layer_idx}", - "attn_scores_values", - ), - (gemv_attn_scores_op, "attn_scores_keys", "queries", "attn_scores"), - (attn_scale_op, "attn_scores", "attn_scale_factor", "attn_scores"), - (softmax_op, "attn_scores", "attn_weights"), - ] - + [ - ( - transpose_values_op, - f"attn_scores_values[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", - f"attn_scores_values_transposed[{h * values_per_head_buffer_size}:{(h + 1) * values_per_head_buffer_size}]", - ) - for h in range(config.n_heads) - ] - + [ - ( - gemv_attn_context_op, - "attn_scores_values_transposed", - "attn_weights", - "attn_context", - ), - ( - gemv_attn_output_op, - f"W_attn_output_decode_{layer_idx}", - "attn_context", - "attn_output", - ), - # - ] - + [ - (residual_add_op, "x", "attn_output", "x"), - (rms_norm_op, "x", f"W_norm2_{layer_idx}", "x_norm"), - ( - gemv_ffn_up_gate_op, - f"W_ffn_gate_{layer_idx}", - "x_norm", - "ffn_gate", - ), - (gemv_ffn_up_gate_op, f"W_ffn_up_{layer_idx}", "x_norm", "ffn_up"), - (silu_ffn_op, "ffn_gate", "ffn_gate"), - (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"), - ( - gemv_ffn_down_op, - f"W_ffn_down_{layer_idx}", - "ffn_hidden", - "ffn_output", - ), - (residual_add_op, "x", "ffn_output", "x"), - ] - ) - # - runlist += [ - (rms_norm_op, "x", "W_final_norm", "x"), - (gemv_out_head_op, "W_out_head", "x", "logits"), - ] - - self.decode.fused_op = FusedMLIROperator( - "fused_op", - runlist, - input_args=[ # arguments that change between invocations of the fused kernel and therefore need to be synced on each token - "x", - "rope_angles", - ], - output_args=["logits"], - buffer_sizes={ - **{ - f"keys_cache_{layer_idx}": cache_buffer_size - for layer_idx in range(config.n_layers) - }, - **{ - f"values_cache_{layer_idx}": cache_buffer_size - for layer_idx in range(config.n_layers) - }, - **{ - "attn_scores_values": values_buffer_size, - "attn_scores_values_transposed": values_buffer_size, - }, - }, - context=elf_ctx, - ).compile() - - # Operator patching - - self.decode.fused_elf_data = load_elf(self.decode.fused_op) - - def get_patch_locs(elf_data, magic): - magic = magic & 0xFFFFFFFF - return np.where(elf_data == magic)[0] - - keys_patches = {} - values_patches = {} - for layer_idx in range(config.n_layers): - _, keys_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer( - f"keys_cache_{layer_idx}" - ) - _, values_cache_offs, _ = self.decode.fused_op.get_layout_for_buffer( - f"values_cache_{layer_idx}" - ) - keys_patches.update( - { - int(l): keys_cache_offs - for l in get_patch_locs( - self.decode.fused_elf_data, - (keys_cache_offs + strided_copy_cache_magic * 2), - ) - } - ) - values_patches.update( - { - int(l): values_cache_offs - for l in get_patch_locs( - self.decode.fused_elf_data, - (values_cache_offs + strided_copy_cache_magic * 2), - ) - } - ) - no_offset_patches = { - int(l): 0 - for l in get_patch_locs( - self.decode.fused_elf_data, (strided_copy_cache_magic * 2) - ) - } - self.decode.fused_patch_locations = { - **keys_patches, - **values_patches, - **no_offset_patches, - } - assert len(self.decode.fused_patch_locations) == 4 * config.n_layers + 2 - - self.decode.softmax_patch_offsets = get_patch_locs( - self.decode.fused_elf_data, softmax_magic - ) - assert len(self.decode.softmax_patch_offsets) == config.n_layers + 1 - - self.decode.fused = FusedFullELFCallable( - self.decode.fused_op, elf_data=self.decode.fused_elf_data - ) - - # Operator static buffers (weights, LUTs) - - for layer_idx in range(config.n_layers): - self.decode.fused.get_buffer(f"W_norm1_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.input_layernorm.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_attn_query_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.self_attn.q_proj.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_attn_key_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.self_attn.k_proj.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_attn_value_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.self_attn.v_proj.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_attn_output_decode_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.self_attn.o_proj.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_norm2_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.post_attention_layernorm.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_ffn_gate_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.mlp.gate_proj.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_ffn_up_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.mlp.up_proj.weight" - ].flatten() - self.decode.fused.get_buffer(f"W_ffn_down_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = config.weights[ - f"model.layers.{layer_idx}.mlp.down_proj.weight" - ].flatten() - scale_factor = 1.0 / math.sqrt(config.head_dim) - self.decode.fused.get_buffer("attn_scale_factor").to("cpu").view_as_torch()[ - : - ] = scale_factor - self.decode.fused.get_buffer("W_final_norm").to("cpu").view_as_torch()[:] = ( - config.weights["model.norm.weight"].flatten() - ) - self.decode.fused.get_buffer("W_out_head").to("cpu").view_as_torch()[:] = ( - config.weights["model.embed_tokens.weight"].flatten() - ) - self.decode.fused.input_buffer.to("npu") - self.decode.fused.scratch_buffer.to("npu") - self.decode.fused.output_buffer.to("npu") - - -# Allocate buffers shared with NPU -# ########################################################################## - -aie_buffers = None - - -class AIEPrefillBuffers: - def __init__(self, prompt_len, emb_dim, hidden_dim, n_heads, n_kv_groups, head_dim): - self.x = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) - self.x_norm = AIEBuffer(shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16) - self.attn_output = AIEBuffer( - shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16 - ) - self.ffn_output = AIEBuffer( - shape=(prompt_len, emb_dim), dtype=ml_dtypes.bfloat16 - ) - # SwiGLU intermediate buffers - self.ffn_gate = AIEBuffer( - shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 - ) - self.ffn_up = AIEBuffer( - shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 - ) - self.ffn_hidden = AIEBuffer( - shape=(prompt_len, hidden_dim), dtype=ml_dtypes.bfloat16 - ) - # Attention buffers: queries and keys serve as both projection output and RoPE input/output - self.queries = AIEBuffer( - shape=(prompt_len * n_heads, head_dim), dtype=ml_dtypes.bfloat16 - ) - self.keys = AIEBuffer( - shape=(prompt_len * n_kv_groups, head_dim), dtype=ml_dtypes.bfloat16 - ) - self.values = AIEBuffer( - shape=(prompt_len, n_kv_groups * head_dim), dtype=ml_dtypes.bfloat16 - ) - self.rope_angles = AIEBuffer( - shape=(prompt_len, head_dim), dtype=ml_dtypes.bfloat16 - ) - # Attention score computation buffers (per-head) - parent buffers with subbuffers - # Parent buffer for all heads' queries: (n_heads, prompt_len, head_dim) stored contiguously - self.attn_scores_queries_all = AIEBuffer( - shape=(n_heads * prompt_len, head_dim), dtype=ml_dtypes.bfloat16 - ) - self.attn_scores_queries_per_head = [ - self.attn_scores_queries_all.subbuffer( - length=prompt_len * head_dim, - offset=h * prompt_len * head_dim, - shape=(prompt_len, head_dim), - ) - for h in range(n_heads) - ] - # Parent buffer for all KV groups' keys: (n_kv_groups, head_dim, prompt_len) stored contiguously - self.attn_scores_keys_all = AIEBuffer( - shape=(n_kv_groups * head_dim, prompt_len), dtype=ml_dtypes.bfloat16 - ) - self.attn_scores_keys_per_kv_group = [ - self.attn_scores_keys_all.subbuffer( - length=head_dim * prompt_len, - offset=g * head_dim * prompt_len, - shape=(head_dim, prompt_len), - ) - for g in range(n_kv_groups) - ] - # Parent buffer for all heads' scores: (n_heads * prompt_len, prompt_len) - self.attn_scores = AIEBuffer( - shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 - ) - self.attn_scores_per_head = [ - self.attn_scores.subbuffer( - length=prompt_len * prompt_len, - offset=h * prompt_len * prompt_len, - shape=(prompt_len, prompt_len), - ) - for h in range(n_heads) - ] - # Attention score scaling buffer (pre-initialized with 1/sqrt(head_dim)) - scale_factor = 1.0 / math.sqrt(head_dim) - self.attn_scale_factor = AIEBuffer( - shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 - ) - self.attn_scale_factor.view_as_torch()[:] = scale_factor - self.attn_scale_factor.to("npu") - # Attention weights buffer (output of softmax) - self.attn_weights = AIEBuffer( - shape=(n_heads * prompt_len, prompt_len), dtype=ml_dtypes.bfloat16 - ) - - -class AIELlamaBuffers: - def __init__(self, config, prompt_len): - # Vector of the current token(s) being processed through the pipeline - self.prefill = AIEPrefillBuffers( - prompt_len, - config.emb_dim, - config.hidden_dim, - config.n_heads, - config.n_kv_groups, - config.head_dim, - ) - - # Per-layer KV cache buffers on NPU (used by strided copy for transpose and concatenate) - self.keys_cache = [ - AIEBuffer( - shape=(config.n_kv_groups, prompt_len, config.head_dim), - dtype=ml_dtypes.bfloat16, - ) - for _ in range(config.n_layers) - ] - self.values_cache = [ - AIEBuffer( - shape=(config.n_kv_groups, prompt_len, config.head_dim), - dtype=ml_dtypes.bfloat16, - ) - for _ in range(config.n_layers) - ] - - # Transformer block layer-wise RMS norm - self.W_norm1 = [] - self.W_norm2 = [] - # Attention projection weights - self.W_attn_query_prefill = [] - self.W_attn_query_decode = [] - self.W_attn_key_prefill = [] - self.W_attn_key_decode = [] - self.W_attn_value_prefill = [] - self.W_attn_value_decode = [] - self.W_attn_output_decode = [] - # SwiGLU FFN weights - self.W_ffn_gate_prefill = [] - self.W_ffn_up_prefill = [] - self.W_ffn_down_prefill = [] - self.W_ffn_gate_decode = [] - self.W_ffn_up_decode = [] - self.W_ffn_down_decode = [] - for layer_idx in range(config.n_layers): - self.W_norm1.append( - AIEBuffer.from_torch( - config.weights[f"model.layers.{layer_idx}.input_layernorm.weight"] - ).to("npu") - ) - self.W_norm2.append( - AIEBuffer.from_torch( - config.weights[ - f"model.layers.{layer_idx}.post_attention_layernorm.weight" - ] - ).to("npu") - ) - self.W_attn_query_prefill.append( - AIEBuffer.from_torch( - config.weights[ - f"model.layers.{layer_idx}.self_attn.q_proj.weight" - ].T - ).to("npu") - ) - self.W_attn_key_prefill.append( - AIEBuffer.from_torch( - config.weights[ - f"model.layers.{layer_idx}.self_attn.k_proj.weight" - ].T - ).to("npu") - ) - self.W_attn_value_prefill.append( - AIEBuffer.from_torch( - config.weights[ - f"model.layers.{layer_idx}.self_attn.v_proj.weight" - ].T - ).to("npu") - ) - self.W_ffn_gate_prefill.append( - AIEBuffer.from_torch( - config.weights[f"model.layers.{layer_idx}.mlp.gate_proj.weight"].T - ).to("npu") - ) - self.W_ffn_up_prefill.append( - AIEBuffer.from_torch( - config.weights[f"model.layers.{layer_idx}.mlp.up_proj.weight"].T - ).to("npu") - ) - self.W_ffn_down_prefill.append( - AIEBuffer.from_torch( - config.weights[f"model.layers.{layer_idx}.mlp.down_proj.weight"].T - ).to("npu") - ) - - # Final RMS norm weights - self.W_final_norm = AIEBuffer.from_torch( - config.weights["model.norm.weight"] - ).to("npu") - # Final linear layer - self.W_out_head = AIEBuffer.from_torch( - config.weights["model.embed_tokens.weight"] - ).to( - "npu" - ) # unpadded/unpartitioned, used by GEMV - W_out_head_parts = aie_ops.prefill.gemv_out_head_compilable.partition_B( - torch_to_numpy(config.weights["model.embed_tokens.weight"]), - config.vocab_partitions, - ) - self.W_out_head_parts = [ - AIEBuffer.from_np(W_out_head_part).to("npu") - for W_out_head_part in W_out_head_parts - ] # partitioned, padded parts of weight, used by GEMM - self.prefill.logits = AIEBuffer( - shape=( - config.vocab_partitions, - prompt_len, - config.padded_vocab_size // config.vocab_partitions, - ) - ).to("npu") - self.prefill.logits_parts = [ - self.prefill.logits.subbuffer( - length=prompt_len - * (config.padded_vocab_size // config.vocab_partitions), - offset=i - * prompt_len - * (config.padded_vocab_size // config.vocab_partitions), - shape=(prompt_len, config.padded_vocab_size // config.vocab_partitions), - ) - for i in range(config.vocab_partitions) - ] - - -# Prefill -# ########################################################################## - - -def grouped_query_attention_forward_prefill( - config, - x, - keys_cache, - values_cache, - layer_idx, - mask=None, -): - batch, seq_len, emb_dim = x.shape - num_preceding_tokens = keys_cache.shape[2] - - # Step 1: Linear projections - aie_ops.prefill.attn_query( - aie_buffers.prefill.x_norm, - aie_buffers.W_attn_query_prefill[layer_idx], - aie_buffers.prefill.queries, - ) - aie_ops.prefill.attn_key( - aie_buffers.prefill.x_norm, - aie_buffers.W_attn_key_prefill[layer_idx], - aie_buffers.prefill.keys, - ) - aie_ops.prefill.attn_value( - aie_buffers.prefill.x_norm, - aie_buffers.W_attn_value_prefill[layer_idx], - aie_buffers.prefill.values, - ) - - # Step 2: Apply RoPE to queries and keys - aie_ops.prefill.rope_queries( - aie_buffers.prefill.queries, - aie_buffers.prefill.rope_angles, - aie_buffers.prefill.queries, - ) - aie_ops.prefill.rope_keys( - aie_buffers.prefill.keys, - aie_buffers.prefill.rope_angles, - aie_buffers.prefill.keys, - ) - - # Read results from NPU - queries = aie_buffers.prefill.queries.to("cpu").view_as_torch()[ - : seq_len * config.n_heads, : - ] - keys = aie_buffers.prefill.keys.to("cpu").view_as_torch()[ - : seq_len * config.n_kv_groups, : - ] - values = aie_buffers.prefill.values.to("cpu").view_as_torch()[ - :seq_len, : - ] # (seq_len, n_kv_groups * head_dim) - queries = queries.view(batch, seq_len, config.n_heads, config.head_dim) - keys = keys.unsqueeze(0).view(batch, seq_len, config.n_kv_groups, config.head_dim) - values = values.unsqueeze(0).view( - batch, seq_len, config.n_kv_groups, config.head_dim - ) # (batch, seq_len, num_kv_groups, head_dim) - - # Step 3: Transpose for attention computation - # As a result of the attention projections, the queries, keys and values for each head are interspersed with each other. - # Transpose so that heads are consecutive for attention computation: - # (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) - queries = queries.transpose(1, 2) # (batch, num_heads, seq_len, head_dim) - keys = keys.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) - values = values.transpose(1, 2) # (batch, num_kv_groups, seq_len, head_dim) - - # Step 4: Combine newly computed keys/values for most recent token with cache; these values are used as the updated cache and will be returned to use in the next iteration. - keys_cache = torch.cat([keys_cache, keys], dim=2) - values_cache = torch.cat([values_cache, values], dim=2) - keys = keys_cache - values = values_cache - - # Step 5: Repeat keys and values for grouped attention -- multiple queries get the same key/value - group_size = config.n_heads // config.n_kv_groups - values = values.repeat_interleave(group_size, dim=1) - context_len = keys.shape[2] - - # Step 6: Compute attention scores using NPU (per-head) - # (batch, num_heads, seq_len, head_dim) @ (batch, num_heads, head_dim, context_len) - # -> (batch, num_heads, seq_len, context_len) - - queries_buf = aie_buffers.prefill.attn_scores_queries_all.view_as_torch().view( - config.n_heads, -1, config.head_dim - ) - queries_buf[:, :seq_len, :] = queries.squeeze(0)[ - :, :seq_len, : - ] # (num_heads, seq_len, head_dim) - keys_buf = aie_buffers.prefill.attn_scores_keys_all.view_as_torch().view( - config.n_kv_groups, config.head_dim, -1 - ) - keys_buf[:, :, :context_len] = keys.squeeze(0).transpose( - -2, -1 - ) # (num_kv_groups, head_dim, context_len) - - # Transfer parent buffers to NPU once - aie_buffers.prefill.attn_scores_queries_all.to("npu") - aie_buffers.prefill.attn_scores_keys_all.to("npu") - aie_buffers.prefill.attn_scores.to("npu") - - # Execute GEMM for each head using sub-buffers - for h in range(config.n_heads): - kv_group = h // group_size - aie_ops.prefill.attn_scores( - aie_buffers.prefill.attn_scores_queries_per_head[h], - aie_buffers.prefill.attn_scores_keys_per_kv_group[kv_group], - aie_buffers.prefill.attn_scores_per_head[h], - ) - - # Read back all results at once from parent buffer and apply scaling on NPU - aie_ops.prefill.attn_scale( - aie_buffers.prefill.attn_scores, - aie_buffers.prefill.attn_scale_factor, - aie_buffers.prefill.attn_scores, - ) - aie_buffers.prefill.attn_scores.to("cpu") - # Buffer is (n_heads * max_seq_len, max_seq_len), view as (n_heads, max_seq_len, max_seq_len) then slice - max_seq_len = aie_buffers.prefill.attn_scores.shape[0] // config.n_heads - scores = ( - aie_buffers.prefill.attn_scores.view_as_torch() - .view(config.n_heads, max_seq_len, max_seq_len) - .unsqueeze(0)[:, :, :seq_len, :context_len] - ) - - # Step 7: Apply mask - # This ensures causality, so that tokens in the future cannot attend to tokens in the past. - if mask is not None: - scores = scores.masked_fill(mask, float("-inf")) - - # Step 8: Apply softmax on CPU - scores = torch.softmax(scores.to(torch.float32), dim=-1).to(torch.bfloat16) - attention_weights = scores - - # Step 9: Compute attention output - # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim) - # -> (batch, num_heads, seq_len, head_dim) - context = torch.matmul(attention_weights, values) - - # Step 10: Concatenate heads and project - # (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim) - context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1) - - output = torch.nn.functional.linear( - context, config.weights[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] - ) - - return output, keys_cache, values_cache - - -def swiglu_ffn_forward_prefill(layer_idx): - # Step 1: Gate projection - aie_ops.prefill.ffn_up_gate( - aie_buffers.prefill.x_norm, - aie_buffers.W_ffn_gate_prefill[layer_idx], - aie_buffers.prefill.ffn_gate, - ) - - # Step 2: Up projection - aie_ops.prefill.ffn_up_gate( - aie_buffers.prefill.x_norm, - aie_buffers.W_ffn_up_prefill[layer_idx], - aie_buffers.prefill.ffn_up, - ) - - # Step 3: Apply SiLU activation - aie_ops.prefill.ffn_silu(aie_buffers.prefill.ffn_gate, aie_buffers.prefill.ffn_gate) - - # Step 4: Element-wise multiplication - aie_ops.prefill.eltwise_mul_ffn( - aie_buffers.prefill.ffn_gate, - aie_buffers.prefill.ffn_up, - aie_buffers.prefill.ffn_hidden, - ) - - # Step 5: Down projection - aie_ops.prefill.ffn_down( - aie_buffers.prefill.ffn_hidden, - aie_buffers.W_ffn_down_prefill[layer_idx], - aie_buffers.prefill.ffn_output, - ) - - -def transformer_block_forward_prefill( - config, seq_len, layer_idx, attn_keys_cache, attn_values_cache, attn_mask -): - # Step 1: RMS normalization - aie_ops.prefill.rms_norm( - aie_buffers.prefill.x, - aie_buffers.W_norm1[layer_idx], - aie_buffers.prefill.x_norm, - ) - aie_buffers.prefill.x_norm.to("cpu") - x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] - - # Step 2: Attention - attn_output, attn_keys, attn_values = grouped_query_attention_forward_prefill( - config, - x_norm, - attn_keys_cache, - attn_values_cache, - layer_idx, - attn_mask, - ) - - # Step 3: Residual - aie_buffers.prefill.attn_output.view_as_torch().unsqueeze(0)[ - 0, :seq_len, : - ] = attn_output - aie_ops.prefill.residual_add( - aie_buffers.prefill.x, aie_buffers.prefill.attn_output, aie_buffers.prefill.x - ) - x = aie_buffers.prefill.x.to("cpu").view_as_torch().unsqueeze(0)[:, :seq_len, :] - - # Step 4: Post-norm - aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - aie_ops.prefill.rms_norm( - aie_buffers.prefill.x, - aie_buffers.W_norm2[layer_idx], - aie_buffers.prefill.x_norm, - ) - aie_buffers.prefill.x_norm.to("cpu") - x_norm = aie_buffers.prefill.x_norm.view_as_torch().unsqueeze(0)[:, :seq_len, :] - - # Step 5: Feed-forward network - swiglu_ffn_forward_prefill(layer_idx) - - # Step 6: Residual - aie_ops.prefill.residual_add( - aie_buffers.prefill.x, aie_buffers.prefill.ffn_output, aie_buffers.prefill.x - ) - - return attn_keys, attn_values - - -def llama_forward_pass_prefill(config, state): - batch, seq_len = state.token_ids.shape - - # Step 1: RoPE angles - num_preceding_tokens = state.attn_keys_caches[0].shape[2] - angles_slice = config.angles[num_preceding_tokens : num_preceding_tokens + seq_len] - aie_buffers.prefill.rope_angles.view_as_torch()[:seq_len, :] = angles_slice - - # Step 2: Token embedding - tok_emb_weight = config.weights["model.embed_tokens.weight"] - x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) - attn_mask = torch.triu( - torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1 - ) - aie_buffers.prefill.x.view_as_torch().unsqueeze(0)[0, :seq_len, :] = x - - # Step 3: Transformer blocks - for layer_idx in range(config.n_layers): - state.attn_keys_caches[layer_idx], state.attn_values_caches[layer_idx] = ( - transformer_block_forward_prefill( - config, - seq_len, - layer_idx, - state.attn_keys_caches[layer_idx], - state.attn_values_caches[layer_idx], - attn_mask=attn_mask, - ) - ) - - # Step 4: Final normalization - aie_ops.prefill.rms_norm( - aie_buffers.prefill.x, aie_buffers.W_final_norm, aie_buffers.prefill.x - ) - - # Step 5: Output projection - for i in range(config.vocab_partitions): - aie_ops.prefill.out_head( - aie_buffers.prefill.x, - aie_buffers.W_out_head_parts[i], - aie_buffers.prefill.logits_parts[i], - ) - aie_buffers.prefill.logits.to("cpu") - logits_padded_partitioned = aie_buffers.prefill.logits.view_as_torch() - logits_padded = ( - logits_padded_partitioned.transpose(0, 1) - .contiguous() - .view(-1, config.padded_vocab_size) - ) - logits = logits_padded.unsqueeze(0)[:, :seq_len, : config.vocab_size] - - # Step 6: Initialize per-layer NPU cache buffers with current cache state for decode phase - for layer_idx in range(config.n_layers): - cache_len = state.attn_keys_caches[layer_idx].shape[2] - aie_buffers.keys_cache[layer_idx].view_as_torch()[:, :cache_len, :] = ( - state.attn_keys_caches[layer_idx].squeeze(0) - ) - aie_buffers.values_cache[layer_idx].view_as_torch()[:, :cache_len, :] = ( - state.attn_values_caches[layer_idx].squeeze(0) - ) - aie_buffers.keys_cache[layer_idx].to("npu") - aie_buffers.values_cache[layer_idx].to("npu") - - return logits, state - - -# Decode -# ########################################################################## - - -def patch_fused_decode_operator(ops, config, num_preceding_tokens): - context_len = num_preceding_tokens + 1 - - # Patch fused operator for strided copy cache offset - output_offset = num_preceding_tokens * config.head_dim - offset_val = output_offset * 2 # Multiply by 2 for bfloat16 byte offset - strided_copy_patches = { - i: (base + offset_val, 0xFFFFFFFF) - for i, base in ops.fused_patch_locations.items() - } - softmax_patches = {i: (context_len, 0xFFFFFFFF) for i in ops.softmax_patch_offsets} - patches = {**strided_copy_patches, **softmax_patches} - patched_elf_data = ops.fused_elf_data.copy() - patch_elf(patched_elf_data, patches) - - ops.fused.reload_elf(patched_elf_data) - - -def llama_forward_pass_decode(config, state): - batch, seq_len = state.token_ids.shape - assert seq_len == 1 - assert state.num_preceding_tokens < max_seq_len - - patch_fused_decode_operator(aie_ops.decode, config, state.num_preceding_tokens) - - # Prefill RoPE angle look-up tables - angles_slice = config.angles[ - state.num_preceding_tokens : state.num_preceding_tokens + seq_len - ] - aie_ops.decode.fused.get_buffer("rope_angles").to("cpu").view_as_torch()[ - : - ] = angles_slice - - # Token embedding (on CPU) - tok_emb_weight = config.weights["model.embed_tokens.weight"] - x = torch.nn.functional.embedding(state.token_ids, tok_emb_weight) - aie_ops.decode.fused.get_buffer("x").view_as_torch().view(-1, config.emb_dim)[ - :seq_len, : - ] = x - - # Fused NPU operator for all of decode (16 transformer blocks + final norm + final linear layer) - aie_ops.decode.fused.input_buffer.to("cpu") - aie_ops.decode.fused() - aie_ops.decode.fused.output_buffer.to("cpu") - logits = ( - aie_ops.decode.fused.get_buffer("logits") - .view_as_torch() - .view(1, 1, config.vocab_size) - ) - - return logits, state - - -# Main -# ########################################################################## - - -def llama_forward_pass(config, state): - global aie_ops, aie_buffers - - batch, seq_len = state.token_ids.shape - if seq_len > 1: - ret = llama_forward_pass_prefill(config, state) - state.num_preceding_tokens = state.token_ids.shape[1] - # Pass KV cache data onto fused decode operator - for layer_idx in range(config.n_layers): - aie_ops.decode.fused.get_buffer(f"keys_cache_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = ( - aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten() - ) - aie_ops.decode.fused.get_buffer(f"values_cache_{layer_idx}").to( - "cpu" - ).view_as_torch()[:] = ( - aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten() - ) - aie_ops.decode.fused.scratch_buffer.to("cpu") - return ret - else: - ret = llama_forward_pass_decode(config, state) - state.num_preceding_tokens += 1 - return ret - - -def main(): - global aie_ops, aie_buffers - prompt = "The capital of France is " - # with open('prompt.txt', 'r') as f: - # prompt = f.read() - # prompt = prompt[:max_seq_len] - - config, state = harness.init(prompt=prompt) - - aie_ops = AIELlamaOperators(config, max_seq_len) - aie_buffers = AIELlamaBuffers(config, max_seq_len) - - print(prompt, end="", flush=True) - harness.generate(config, state, llama_forward_pass, use_kv_cache=True) - - -if __name__ == "__main__": - main() diff --git a/iron/applications/llama_3.2_1b/src/block/feed_forward.py b/iron/applications/llama_3.2_1b/src/block/feed_forward.py new file mode 100644 index 00000000..8bae36ec --- /dev/null +++ b/iron/applications/llama_3.2_1b/src/block/feed_forward.py @@ -0,0 +1,250 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0. +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb +# +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn as nn +from ..utils import assign +from iron.operators import ( + AIEElementwiseMul, + AIEGEMM, + AIEGEMV, + AIESiLU, + AIESwiGLUPrefill, + AIESwiGLUDecode, +) +from ml_dtypes import bfloat16 + + +class FeedForward(nn.Module): + def __init__( + self, + cfg, + prompt_length=0, + num_tokens=1, + ): + super().__init__() + self.cfg = cfg.copy() + + assert ( + cfg["use_aie_ffn_swiglu"] + and not ( + cfg["use_aie_ffn_silu"] + or cfg["use_aie_ffn_gemm"] + or cfg["use_aie_ffn_mul"] + ) + or not cfg["use_aie_ffn_swiglu"] + ), "Cannot mix fused SwiGLU with individual AIE operators." + + self.emb_dim = cfg["emb_dim"] + self.hidden_dim = cfg["hidden_dim"] + + # Initialize SiLU activation + if self.cfg["use_aie_ffn_silu"]: + if self.cfg["use_kv_cache"]: + max_prefill_size = prompt_length * self.hidden_dim + else: + max_prefill_size = (prompt_length + num_tokens) * self.hidden_dim + self.aie_silu_prefill = AIESiLU( + size=max_prefill_size, + num_aie_columns=8, + num_channels=2, + tile_size=self.hidden_dim, + ) + # For decode phase - single token (only when using KV cache) + if self.cfg["use_kv_cache"]: + decode_size = self.hidden_dim # 1 token * emb_dim + self.aie_silu_decode = AIESiLU( + size=decode_size, + num_aie_columns=1, + num_channels=1, + tile_size=self.hidden_dim, + ) + else: + # When not using KV cache, use same operator for both phases + self.aie_silu_decode = self.silu_prefill + else: + self.silu = nn.SiLU() + + if self.cfg["use_aie_ffn_swiglu"]: + self.aie_swiglu_prefill = AIESwiGLUPrefill( + seq_len=prompt_length, + embedding_dim=self.emb_dim, + hidden_dim=self.hidden_dim, + ) + if self.cfg["use_kv_cache"]: + self.aie_swiglu_decode = AIESwiGLUDecode( + embedding_dim=self.emb_dim, hidden_dim=self.hidden_dim + ) + + if self.cfg["use_aie_ffn_gemm"]: + if self.cfg["use_kv_cache"]: + M_prefill = prompt_length + else: + M_prefill = prompt_length + num_tokens + + aie_config_prefill = { + "num_aie_columns": 8, + "tile_m": 64, + "tile_k": 64, + "tile_n": 64, + "use_static_weight": True, + } + + self.fc1 = AIEGEMM( + M=M_prefill, K=self.emb_dim, N=self.hidden_dim, **aie_config_prefill + ) + self.fc2 = AIEGEMM( + M=M_prefill, K=self.emb_dim, N=self.hidden_dim, **aie_config_prefill + ) + self.fc3 = AIEGEMM( + M=M_prefill, K=self.hidden_dim, N=self.emb_dim, **aie_config_prefill + ) + else: + self.fc1 = nn.Linear( + cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False + ) + self.fc2 = nn.Linear( + cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False + ) + self.fc3 = nn.Linear( + cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False + ) + + if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]: + aie_gemv_config = {"num_aie_columns": 8, "is_mv": False} + # FC1 and FC2: emb_dim -> hidden_dim + self.aie_fc1_gemv = AIEGEMV( + M=self.hidden_dim, + K=self.emb_dim, + tile_size_input=1, + tile_size_output=self.hidden_dim // 16, + **aie_gemv_config, + ) + self.aie_fc2_gemv = AIEGEMV( + M=self.hidden_dim, + K=self.emb_dim, + tile_size_input=1, + tile_size_output=self.hidden_dim // 16, + **aie_gemv_config, + ) + # FC3: hidden_dim -> emb_dim + self.aie_fc3_gemv = AIEGEMV( + M=self.emb_dim, + K=self.hidden_dim, + tile_size_input=1, + tile_size_output=self.emb_dim // 16, + **aie_gemv_config, + ) + + # Initialize AIE elementwise multiply + if self.cfg["use_aie_ffn_mul"]: + if self.cfg["use_kv_cache"]: + max_prefill_size = prompt_length * self.hidden_dim + else: + max_prefill_size = (prompt_length + num_tokens) * self.hidden_dim + + self.aie_mul_prefill = AIEElementwiseMul( + size=max_prefill_size, + num_aie_columns=8, + num_channels=2, + tile_size=self.hidden_dim, + ) + + # For decode phase - single token (only when using KV cache) + if self.cfg["use_kv_cache"]: + decode_size = self.hidden_dim # 1 token * emb_dim + self.aie_mul_decode = AIEElementwiseMul( + size=decode_size, + num_aie_columns=1, + num_channels=2, + tile_size=self.hidden_dim, + ) + else: + # When not using KV cache, use same operator for both phases + self.aie_mul_decode = self.aie_mul_prefill + + def forward(self, x): + original_shape = x.shape + + # Check if input is a vector (decode phase) or matrix (prefill phase) + # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim) + is_vector = ( + len(x.shape) == 1 + or (len(x.shape) == 2 and x.shape[0] == 1) + or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1) + ) + + is_prefill = not is_vector or not self.cfg["use_kv_cache"] + is_decode_with_kv = is_vector and self.cfg["use_kv_cache"] + + if self.cfg["use_aie_ffn_swiglu"]: + if is_prefill: + return self.aie_swiglu_prefill(x) + else: + return self.aie_swiglu_decode(x) + + if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]: + x_fc1 = self.aie_fc1_gemv(x) + x_fc2 = self.aie_fc2_gemv(x) + else: + x_fc1 = self.fc1(x) + x_fc2 = self.fc2(x) + + if self.cfg["use_aie_ffn_silu"]: + if is_decode_with_kv: + x_fc1_silu = self.aie_silu_decode(x_fc1) + else: + x_fc1_silu = self.aie_silu_prefill(x_fc1) + else: + x_fc1_silu = self.silu(x_fc1) + + if self.cfg["use_aie_ffn_mul"]: + if is_decode_with_kv: + x = self.aie_mul_decode(x_fc1_silu, x_fc2) + else: + x = self.aie_mul_prefill(x_fc1_silu, x_fc2) + else: + x = x_fc1_silu * x_fc2 + + if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]: + result = self.aie_fc3_gemv(x) + return result.view(original_shape) + else: + return self.fc3(x).view(original_shape) + + def assign_weights(self, l, fc1, fc2, fc3): + if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]: + self.aie_fc1_gemv.weight = fc1 + self.aie_fc2_gemv.weight = fc2 + self.aie_fc3_gemv.weight = fc3 + + if self.cfg["use_aie_ffn_swiglu"]: + self.aie_swiglu_prefill.weights_1 = fc1 + self.aie_swiglu_prefill.weights_2 = fc2 + self.aie_swiglu_prefill.weights_3 = fc3 + if self.cfg["use_kv_cache"]: + self.aie_swiglu_decode.weights_1 = fc1 + self.aie_swiglu_decode.weights_2 = fc2 + self.aie_swiglu_decode.weights_3 = fc3 + return + + self.fc1.weight = assign( + self.fc1.weight, + fc1, + f"model.layers.{l}.mlp.gate_proj.weight", + ) + self.fc2.weight = assign( + self.fc2.weight, + fc2, + f"model.layers.{l}.mlp.up_proj.weight", + ) + self.fc3.weight = assign( + self.fc3.weight, + fc3, + f"model.layers.{l}.mlp.down_proj.weight", + ) diff --git a/iron/applications/llama_3.2_1b/src/block/gqa.py b/iron/applications/llama_3.2_1b/src/block/gqa.py new file mode 100644 index 00000000..1a712ff9 --- /dev/null +++ b/iron/applications/llama_3.2_1b/src/block/gqa.py @@ -0,0 +1,505 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0. +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb +# +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch +import torch.nn as nn + +from iron.operators import AIERope, AIESoftmax, AIEMHA, AIEGEMM, AIEGEMV +from iron.operators.rope.rope_utils import apply_rope + +from torchtune.modules import KVCache + +from ..utils import assign + + +class GroupedQueryAttention(nn.Module): + def __init__( + self, + d_in, + d_out, + num_heads, + num_kv_groups, + prompt_length=0, + num_tokens=1, + dtype=None, + max_batch_size=1, + max_seq_len=8192, + cfg=None, + ): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + assert ( + num_heads % num_kv_groups == 0 + ), "num_heads must be divisible by num_kv_groups" + + self.cfg = cfg.copy() if cfg is not None else {} + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads + + self.num_tokens = num_tokens + + # Weights for Attention layer + self.W_key = nn.Linear( + d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype + ) + self.W_value = nn.Linear( + d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype + ) + self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) + self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) + + self.num_kv_groups = num_kv_groups + self.group_size = num_heads // num_kv_groups + + self.prompt_length = prompt_length + + aie_gemm_config = { + "num_aie_columns": 8, + "tile_m": 64, + "tile_k": 64, + "tile_n": 64, + "use_static_weight": False, + } + + # Initialize KV Cache + if self.cfg["use_kv_cache"]: + self.kv_cache = KVCache( + batch_size=max_batch_size, + max_seq_len=max_seq_len, + num_kv_heads=self.num_kv_groups, + head_dim=self.head_dim, + dtype=torch.bfloat16, + ) + + # Initialize AIE Regular MHA operator + if self.cfg["use_aie_regular_mha"]: + self.aie_softmax = AIESoftmax( + num_aie_columns=1, + num_channels=1, + rows=prompt_length, + cols=prompt_length, + ) + M_for_gemm = prompt_length + num_tokens + self.aie_mha_gemm_qk = AIEGEMM( + M=M_for_gemm, K=self.head_dim, N=M_for_gemm, **aie_gemm_config + ) + self.aie_mha_gemm_pv = AIEGEMM( + M=M_for_gemm, K=M_for_gemm, N=self.head_dim, **aie_gemm_config + ) + + # Initialize AIE RoPE operator + if self.cfg["use_aie_rope"]: + self.aie_rope_prefill_k = AIERope( + rows=self.prompt_length * self.num_kv_groups, + cols=self.head_dim, + angle_rows=self.prompt_length, + ) + self.aie_rope_prefill_q = AIERope( + rows=self.prompt_length * self.num_heads, + cols=self.head_dim, + angle_rows=self.prompt_length, + ) + self.aie_rope_decode_k = AIERope( + rows=self.num_kv_groups, + cols=self.head_dim, + angle_rows=1, + ) + self.aie_rope_decode_q = AIERope( + rows=self.num_heads, + cols=self.head_dim, + angle_rows=1, + ) + + # Initialize fused AIE MHA operator + if self.cfg["use_aie_fused_mha"]: + self.aie_mha = AIEMHA( + num_heads=num_heads, + seq_len=prompt_length, + d=self.head_dim, + num_KV_heads=0, # Regular MHA since we feed repeated K/V + num_of_pipelines=8, + ) + + # Initialize AIE GEMV operators for decode phase (when using KV cache) + if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]: + + aie_gemv_config = { + "num_aie_columns": 8, + "is_mv": False, + "use_static_weight": True, + } + self.aie_query_gemv = AIEGEMV( + M=d_out, + K=d_in, + tile_size_input=1, + tile_size_output=d_out // 16, + **aie_gemv_config, + ) + kv_out_dim = num_kv_groups * self.head_dim + self.aie_key_gemv = AIEGEMV( + M=kv_out_dim, + K=d_in, + tile_size_input=1, + tile_size_output=kv_out_dim // 16, + **aie_gemv_config, + ) + self.aie_value_gemv = AIEGEMV( + M=kv_out_dim, + K=d_in, + tile_size_input=1, + tile_size_output=kv_out_dim // 16, + **aie_gemv_config, + ) + self.aie_out_proj_gemv = AIEGEMV( + M=d_out, + K=d_out, + tile_size_input=1, + tile_size_output=d_out // 16, + **aie_gemv_config, + ) + + # Initialize AIE GEMM operators + if self.cfg["use_aie_attn_projection_gemm"]: + if self.cfg["use_kv_cache"]: + M_for_gemm = self.prompt_length + else: + M_for_gemm = self.prompt_length + self.num_tokens + + # GEMMs for projection use weights + aie_gemm_config["use_static_weight"] = True + # Query: (batch_size, d_in) @ (d_in, d_out) -> (batch_size, d_out) + self.aie_query = AIEGEMM(M=M_for_gemm, K=d_in, N=d_out, **aie_gemm_config) + # Key: (batch_size, d_in) @ (d_in, num_kv_groups * head_dim) -> (batch_size, num_kv_groups * head_dim) + kv_out_dim = num_kv_groups * self.head_dim + self.aie_key = AIEGEMM( + M=M_for_gemm, K=d_in, N=kv_out_dim, **aie_gemm_config + ) + # Value: same dimensions as key + self.aie_value = AIEGEMM( + M=M_for_gemm, K=d_in, N=kv_out_dim, **aie_gemm_config + ) + # Output projection: (batch_size, d_out) @ (d_out, d_out) -> (batch_size, d_out) + self.aie_out_proj = AIEGEMM( + M=M_for_gemm, K=d_out, N=d_out, **aie_gemm_config + ) + + def forward(self, x, mask, angles, input_pos=None): + b, num_tokens, d_in = x.shape + is_prefill = input_pos is None + is_decode = input_pos is not None + + # Step 1. + # --- + # Linear projections -- calculate quries, keys and values by multiplying embedding vector (in decode) or matrix (in prefill) with weight matrices + + # Choose between GEMM (prefill) and GEMV (decode) based on KV cache usage + if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]: + # Decode phase with KV cache - use GEMV for single token + # weight.T @ input, which is vector-matrix multiplication (So, is_mv=False) + x_flat = x.reshape(1, -1) # Shape: (1, d_in) + + queries_flat = self.aie_query_gemv(x_flat) + queries = queries_flat.reshape(b, num_tokens, self.d_out) + + keys_flat = self.aie_key_gemv(x_flat) + keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim) + + values_flat = self.aie_value_gemv(x_flat) + values = values_flat.reshape( + b, num_tokens, self.num_kv_groups * self.head_dim + ) + + elif self.cfg["use_aie_attn_projection_gemm"]: + # Prefill phase - use GEMM for multiple tokens + x_flat = x.reshape(-1, d_in) + input_dtype = x.dtype + + queries_flat = self.aie_query(x_flat) + queries = queries_flat.reshape(b, num_tokens, self.d_out) + + keys_flat = self.aie_key(x_flat) + keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim) + + values_flat = self.aie_value(x_flat) + values = values_flat.reshape( + b, num_tokens, self.num_kv_groups * self.head_dim + ) + else: + queries = self.W_query(x) + keys = self.W_key(x) + values = self.W_value(x) + + # Each attention head gets its own slice of the embedding dimension. + # For each head, we have query, key and value. + # In grouped-query attention, the keys and values are shared across groups of heads. + # Therefore, we have self.num_heads queries, and self.num_kv_groups (== self.num_heads in case of regular attention) keys and values. + # Each head can be applied independently to its subslice of the embedding dimension. + keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim) + values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + # Step 2. + # --- + # Apply positional encoding to keys and queries. + # The positional embedding is applied independently to each head. + # It modifies the embedding vectors to encode where in the sequence each token is located. + + # Determine angle slice based on KV cache usage and phase + if self.cfg["use_kv_cache"] and is_decode: + # Decode phase with KV cache: use single position + current_pos = input_pos.item() + angle_slice = angles[current_pos : current_pos + 1, :] + else: + # Prefill phase or no KV cache: use all tokens + angle_slice = angles[:num_tokens, :] + + # Apply RoPE with AIE + def apply_rope_and_transpose(aie_op, tensor, num_heads_dim, angle_slice): + angle_slice = angle_slice.to(dtype=tensor.dtype) + if self.cfg["use_aie_rope"]: + result = aie_op( + tensor.view(num_tokens * num_heads_dim, self.head_dim), angle_slice + ) + result = result.view( + b, num_tokens, num_heads_dim, self.head_dim + ).transpose(1, 2) + else: + transposed = ( + tensor.view(num_tokens, num_heads_dim, self.head_dim) + .transpose(0, 1) + .contiguous() + ) + result = apply_rope( + transposed.view(1, num_heads_dim, num_tokens, self.head_dim), + angle_slice, + ) + # ref = apply_rope(transposed.view(1, num_heads_dim, num_tokens, self.head_dim), angle_slice) + # assert torch.allclose(ref, result, atol=0.7, rtol=0.07), "AIE RoPE result does not match reference" + return result + + keys = apply_rope_and_transpose( + ( + (self.aie_rope_prefill_k if is_prefill else self.aie_rope_decode_k) + if self.cfg["use_aie_rope"] + else None + ), + keys, + self.num_kv_groups, + angle_slice, + ) + queries = apply_rope_and_transpose( + ( + (self.aie_rope_prefill_q if is_prefill else self.aie_rope_decode_q) + if self.cfg["use_aie_rope"] + else None + ), + queries, + self.num_heads, + angle_slice, + ) + values = values.transpose(1, 2) + + if self.cfg["use_kv_cache"]: + if is_prefill: + self.kv_cache.reset() + self.kv_cache.update(keys, values) + cached_keys, cached_values = keys, values + else: + self.kv_cache.update(keys, values) + current_seq_len = input_pos.item() + 1 + cached_keys = self.kv_cache.k_cache[:, :, :current_seq_len, :] + cached_values = self.kv_cache.v_cache[:, :, :current_seq_len, :] + + keys = cached_keys + values = cached_values + + # Step 3. + # --- + # Since the keys and values are shared across groups of heads in grouped-query attention, + # we now expand (repeat) the same keys and values so that each head has its own keys and values. + keys = keys.repeat_interleave(self.group_size, dim=1) + values = values.repeat_interleave(self.group_size, dim=1) + + # Step 4. + # --- + # Compute attention scores (indepdentently for each head), apply softmax to get attention weights, then apply those weights to the attention values to get output. + # Attention scores are the dot-product of queries and keys. + + # Use fused AIE MHA if enabled and conditions are met + if is_prefill or not self.cfg["use_kv_cache"]: + if ( + self.cfg["use_aie_fused_mha"] + and b == 1 + and num_tokens == self.prompt_length + and self.head_dim == 64 + ): + # TODO: Doesn't give good output ven with num_kv_groups set to 8 with kv_cache + # TODO: Doesn't match the output of CPU only when used without kv_cache + context_vec = self.aie_mha( + queries, keys, values + ) # Shape: (num_heads, num_tokens, head_dim) + + # Reshape context_vec to prepare for output projection + context_vec = context_vec.transpose(0, 1) + context_vec = context_vec.reshape(b, num_tokens, self.d_out) + + elif self.cfg["use_aie_regular_mha"]: + # attn_scores = queries @ keys.transpose(2, 3) + # Compute attention scores for each head separately since AIE GEMM doesn't support batched operations + attn_scores_list = [] + for head in range(self.num_heads): + q_head = queries[:, head, :, :] # Shape: (b, num_tokens, head_dim) + k_head = keys[:, head, :, :] # Shape: (b, num_tokens, head_dim) + + # Use 2D tensors directly (remove batch dimension if b=1) + q_2d = q_head.squeeze(0) # Shape: (num_tokens, head_dim) + k_2d = k_head.squeeze(0) # Shape: (num_tokens, head_dim) + + # Compute Q @ K^T for this head + attn_head = self.aie_mha_gemm_qk( + q_2d, k_2d.T + ) # Shape: (num_tokens, num_tokens) + attn_head = attn_head.unsqueeze(0).unsqueeze( + 0 + ) # Add batch and head dimensions + attn_scores_list.append( + attn_head + ) # Shape: (1, 1, num_tokens, num_tokens) + + attn_scores = torch.cat( + attn_scores_list, dim=1 + ) # Shape: (b, num_heads, num_tokens, num_tokens) + attn_scores = attn_scores.masked_fill(mask, -torch.inf) + scaled_scores = attn_scores / (self.head_dim**0.5) + + # TODO: Make softmax more configurable to run in any scenario + if ( + scaled_scores.shape[-1] == self.prompt_length + and scaled_scores.shape[-1] % 16 == 0 + ): + attn_weights = self.aie_softmax(scaled_scores) + else: + attn_weights = torch.nn.functional.softmax(scaled_scores, dim=-1) + + # Compute context vector for each head separately using AIE GEMM + context_vec_list = [] + for head in range(self.num_heads): + attn_head = attn_weights[ + :, head, :, : + ] # Shape: (b, num_tokens, num_tokens) + v_head = values[:, head, :, :] # Shape: (b, num_tokens, head_dim) + + # Use 2D tensors directly (remove batch dimension if b=1) + attn_2d = attn_head.squeeze(0) # Shape: (num_tokens, num_tokens) + v_2d = v_head.squeeze(0) # Shape: (num_tokens, head_dim) + + # Compute attn @ V for this head + context_head = self.aie_mha_gemm_pv( + attn_2d, v_2d + ) # Shape: (num_tokens, head_dim) + context_head = context_head.unsqueeze(0).unsqueeze( + 1 + ) # Add batch and head dimensions + context_vec_list.append( + context_head + ) # Shape: (1, 1, num_tokens, head_dim) + + context_vec = torch.cat( + context_vec_list, dim=1 + ) # Shape: (b, num_heads, num_tokens, head_dim) + context_vec = context_vec.transpose( + 1, 2 + ) # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = context_vec.reshape(b, num_tokens, self.d_out) + else: + + def my_mha(queries, keys, values): + inv_scale = 1 / np.sqrt(values.shape[-1]) + context_vec = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + dropout_p=0.0, + is_causal=True, + scale=inv_scale, + ) + return context_vec + + context_vec = my_mha(queries, keys, values) + context_vec = context_vec.transpose(1, 2) + context_vec = context_vec.reshape(b, num_tokens, self.d_out) + else: + attn_scores = queries @ keys.transpose(2, 3) + + if mask is not None: + attn_scores = attn_scores.masked_fill(mask, -torch.inf) + + scaled_scores = attn_scores / (self.head_dim**0.5) + + if ( + scaled_scores.shape[-1] == self.prompt_length + and self.cfg["use_aie_softmax"] + and scaled_scores.shape[-1] % 16 == 0 + ): + attn_weights = self.aie_softmax(scaled_scores) + else: + attn_weights = torch.nn.functional.softmax(scaled_scores, dim=-1) + + context_vec = (attn_weights @ values).transpose(1, 2) + context_vec = context_vec.reshape(b, num_tokens, self.d_out) + + # Choose output projection based on phase + if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]: + context_vec_flat = context_vec.reshape(1, -1) + output_flat = self.aie_out_proj_gemv(context_vec_flat) + context_vec = output_flat.reshape(b, num_tokens, self.d_out) + elif self.cfg["use_aie_attn_projection_gemm"]: + context_vec_flat = context_vec.reshape(-1, self.d_out) + output_flat = self.aie_out_proj(context_vec_flat) + context_vec = output_flat.reshape(b, num_tokens, self.d_out) + else: + context_vec = self.out_proj(context_vec) + + return context_vec + + def assign_weights(self, l, w_query, w_key, w_value, w_out_proj): + if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]: + self.aie_query_gemv.weight = w_query + self.aie_key_gemv.weight = w_key + self.aie_value_gemv.weight = w_value + self.aie_out_proj_gemv.weight = w_out_proj + + if self.cfg["use_aie_attn_projection_gemm"]: + self.aie_query.weight = w_query + self.aie_key.weight = w_key + self.aie_value.weight = w_value + self.aie_out_proj.weight = w_out_proj + + self.W_query.weight = assign( + self.W_query.weight, + w_query, + f"model.layers.{l}.self_attn.q_proj.weight", + ) + self.W_key.weight = assign( + self.W_key.weight, + w_key, + f"model.layers.{l}.self_attn.k_proj.weight", + ) + self.W_value.weight = assign( + self.W_value.weight, + w_value, + f"model.layers.{l}.self_attn.v_proj.weight", + ) + self.out_proj.weight = assign( + self.out_proj.weight, + w_out_proj, + f"model.layers.{l}.self_attn.o_proj.weight", + ) diff --git a/iron/applications/llama_3.2_1b/src/block/transformer.py b/iron/applications/llama_3.2_1b/src/block/transformer.py new file mode 100644 index 00000000..f2b46cdf --- /dev/null +++ b/iron/applications/llama_3.2_1b/src/block/transformer.py @@ -0,0 +1,195 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0. +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb +# +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn as nn +from ..utils import assign +from src.block.gqa import GroupedQueryAttention +from src.block.feed_forward import FeedForward +from iron.operators import AIERMSNorm, AIEElementwiseAdd + + +class TransformerBlock(nn.Module): + def __init__( + self, + cfg, + prompt_length=42, + num_tokens=1, + ): + super().__init__() + self.cfg = cfg.copy() + + self.att = GroupedQueryAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + num_kv_groups=cfg["n_kv_groups"], + dtype=cfg["dtype"], + prompt_length=prompt_length, + cfg=cfg, + ) + self.ff = FeedForward( + cfg, + prompt_length=prompt_length, + num_tokens=num_tokens, + ) + + if self.cfg["use_aie_norm1"]: + if self.cfg["use_kv_cache"]: + max_prefill_size = prompt_length * self.cfg["emb_dim"] + else: + max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] + self.aie_norm1_prefill = AIERMSNorm( + size=max_prefill_size, + eps=1e-5, + num_aie_columns=8, + num_channels=2, + tile_size=self.cfg["emb_dim"], + ) + # For decode phase - single token (only when using KV cache) + if self.cfg["use_kv_cache"]: + decode_size = self.cfg["emb_dim"] # 1 token * emb_dim + self.aie_norm1_decode = AIERMSNorm( + size=decode_size, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=self.cfg["emb_dim"], + ) + else: + # When not using KV cache, use same operator for both phases + self.aie_norm1_decode = self.aie_norm1_prefill + else: + self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) + + if self.cfg["use_aie_norm2"]: + if self.cfg["use_kv_cache"]: + max_prefill_size = prompt_length * self.cfg["emb_dim"] + else: + max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] + self.aie_norm2_prefill = AIERMSNorm( + size=max_prefill_size, + eps=1e-5, + num_aie_columns=8, + num_channels=2, + tile_size=self.cfg["emb_dim"], + ) + # For decode phase - single token (only when using KV cache) + if self.cfg["use_kv_cache"]: + decode_size = self.cfg["emb_dim"] # 1 token * emb_dim + self.aie_norm2_decode = AIERMSNorm( + size=decode_size, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=self.cfg["emb_dim"], + ) + else: + # When not using KV cache, use same operator for both phases + self.aie_norm2_decode = self.aie_norm2_prefill + else: + self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) + + if self.cfg["use_aie_residual"]: + if self.cfg["use_kv_cache"]: + max_prefill_size = prompt_length * cfg["emb_dim"] + else: + max_prefill_size = (prompt_length + num_tokens) * cfg["emb_dim"] + + self.aie_residual_add_prefill = AIEElementwiseAdd( + size=max_prefill_size, + num_aie_columns=8, + num_channels=2, + tile_size=cfg["emb_dim"], + ) + + # For decode phase - single token (only when using KV cache) + if self.cfg["use_kv_cache"]: + decode_size = cfg["emb_dim"] # 1 token * emb_dim + self.aie_residual_add_decode = AIEElementwiseAdd( + size=decode_size, + num_aie_columns=1, + num_channels=2, + tile_size=cfg["emb_dim"], + ) + else: + # When not using KV cache, use same operator for both phases + self.aie_residual_add_decode = self.aie_residual_add_prefill + + def forward(self, x, mask, angles, input_pos): + original_shape = x.shape + + # (batch, sequence, embedding) where sequence=1 indicates decode + if len(x.shape) == 3: + is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"] + elif len(x.shape) == 2: + is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"] + else: + is_decode_with_kv = False + + shortcut = x + if self.cfg["use_aie_norm1"]: + if is_decode_with_kv: + x = self.aie_norm1_decode(x) + else: + x = self.aie_norm1_prefill(x) + else: + x = self.norm1(x) + + x = self.att(x, mask, angles, input_pos) + + if self.cfg["use_aie_residual"]: + if is_decode_with_kv: + x = self.aie_residual_add_decode(x, shortcut) + else: + x = self.aie_residual_add_prefill(x, shortcut) + else: + x = x + shortcut + + # Shortcut connection for feed-forward block + shortcut = x + if self.cfg["use_aie_norm2"]: + if is_decode_with_kv: + x = self.aie_norm2_decode(x) + else: + x = self.aie_norm2_prefill(x) + else: + x = self.norm2(x) + x = self.ff(x) + + if self.cfg["use_aie_residual"]: + if is_decode_with_kv: + x = self.aie_residual_add_decode(x, shortcut) + else: + x = self.aie_residual_add_prefill(x, shortcut) + else: + x = x + shortcut + + return x + + def assign_weights(self, l, norm1, norm2): + if self.cfg["use_aie_norm1"]: + self.aie_norm1_prefill.weight = norm1 + if self.cfg["use_kv_cache"]: + self.aie_norm1_decode.weight = norm1 + if self.cfg["use_aie_norm2"]: + self.aie_norm2_prefill.weight = norm2 + if self.cfg["use_kv_cache"]: + self.aie_norm2_decode.weight = norm2 + return + + self.norm1.weight = assign( + self.norm1.weight, + norm1, + f"model.layers.{l}.input_layernorm.weight", + ) + self.norm2.weight = assign( + self.norm2.weight, + norm2, + f"model.layers.{l}.post_attention_layernorm.weight", + ) diff --git a/iron/applications/llama_3.2_1b/src/model_with_json.py b/iron/applications/llama_3.2_1b/src/model_with_json.py new file mode 100644 index 00000000..856fb048 --- /dev/null +++ b/iron/applications/llama_3.2_1b/src/model_with_json.py @@ -0,0 +1,309 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0. +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb +# +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn as nn +import json +from pathlib import Path +from src.block.transformer import TransformerBlock +from iron.operators.rope.rope_utils import compute_rope_params +from iron.operators import AIERMSNorm, AIEGEMM, AIEGEMV +from rich.console import Console +from rich.text import Text + +from .utils import assign + + +def dtype_from_string(inp): + if isinstance(inp, torch.dtype): + return inp + return {"bfloat16": torch.bfloat16, "float16": torch.float16}.get( + inp, torch.float32 + ) + + +# fmt: off +# Configuration flag key -> (type function, default value, description) +config_options = { + "dtype": (dtype_from_string, torch.float32, "Data type"), + "use_kv_cache": (bool, False, "[Model] KV Cache"), + "use_aie_rope": (bool, False, "[Attention] Rope"), + "use_aie_attn_projection_gemm": (bool, False, "[Attention] QKV GEMM"), + "use_aie_regular_mha": (bool, False, "[Attention] Regular MHA"), + "use_aie_fused_mha": (bool, False, "[Attention] Fused MHA"), + "use_aie_gqa_gemv": (bool, False, "[Attention] GEMV (Decode)"), + "use_aie_ffn_gemm": (bool, False, "[FFN] GEMM"), + "use_aie_ffn_mul": (bool, False, "[FFN] Elementwise Mul"), + "use_aie_ffn_silu": (bool, False, "[FFN] SiLU"), + "use_aie_ffn_swiglu": (bool, False, "[FFN] Runlist-based SwiGLU"), + "use_aie_ffn_gemv": (bool, False, "[FFN] GEMV (Decode)"), + "use_aie_residual": (bool, False, "[Transformer] Residual Addition"), + "use_aie_norm1": (bool, False, "[Transformer] Pre Norm"), + "use_aie_norm2": (bool, False, "[Transformer] Post Norm"), + "use_aie_final_norm": (bool, False, "[Transformer] Final Norm"), + "use_aie_final_gemm": (bool, False, "[Transformer] Final GEMM"), + "use_aie_final_gemv": (bool, False, "[Transformer] Final GEMV"), +} +# fmt: on + + +def load_llama_config(config_path=None): + """Load Llama configuration from JSON file""" + if config_path is None: + # Default to config.json in the llama directory + config_path = Path(__file__).parent.parent / "llama32_1b.json" + + with open(config_path, "r") as f: + config = json.load(f) + + model_config = config["model_config"].copy() + for key, (type_fn, default_value, description) in config_options.items(): + if key in model_config: + model_config[key] = type_fn(model_config[key]) + else: + model_config[key] = default_value + + return model_config + + +def print_config(cfg, console=Console()): + def format_option(name, value): + if isinstance(value, bool): + checkmark = "[green]✔[/green]" if value else "[red]✘[/red]" + return f"{name} {checkmark}" + return f"{name}: {value}" + + dont_print = {"dtype"} + # The following options are mutually exclusive, e.g. regular and fused MHA + # cannot be enabled at the same time. But it looks bad to have red Xs, + # indicating things are running on the CPU when they are not. So, we only + # print one of these mutually exclusive options. + if cfg["use_aie_fused_mha"]: + dont_print |= {"use_aie_regular_mha"} + else: + dont_print |= {"use_aie_fused_mha"} + if cfg["use_aie_ffn_swiglu"]: + dont_print |= { + "use_aie_ffn_gemm", + "use_aie_ffn_mul", + "use_aie_ffn_silu", + } + else: + dont_print |= {"use_aie_ffn_swiglu"} + + console.print( + "AIE Configuration ([green]✔[/green] = AIE NPU / [red]✘[/red] = CPU):", + style="bold underline", + ) + for option_key, (option_ty, option_default, option_name) in config_options.items(): + if option_key in dont_print: + continue + console.print(format_option(option_name, cfg.get(option_key, option_default))) + console.print("") + + +class Llama3ModelWithJSONConfig(nn.Module): + """Llama3 model that loads configuration from JSON file""" + + def __init__( + self, + config_path=None, + prompt_length=0, + num_tokens=1, + ): + super().__init__() + + # Load configuration from JSON + self.cfg = load_llama_config(config_path) + self.prompt_length = prompt_length + self.num_tokens = num_tokens + print_config(self.cfg) + + # Main model parameters + self.tok_emb = nn.Embedding( + self.cfg["vocab_size"], self.cfg["emb_dim"], dtype=self.cfg["dtype"] + ) + + self.trf_blocks = nn.ModuleList( + [ + TransformerBlock( + self.cfg, + prompt_length=prompt_length, + num_tokens=num_tokens, + ) + for i in range(self.cfg["n_layers"]) + ] + ) + + # Create final norm - either AIE or PyTorch + if self.cfg.get("use_aie_final_norm", False): + if self.cfg["use_kv_cache"]: + max_prefill_size = prompt_length * self.cfg["emb_dim"] + else: + max_prefill_size = (prompt_length + num_tokens) * self.cfg["emb_dim"] + self.aie_final_norm_prefill = AIERMSNorm( + size=max_prefill_size, + eps=1e-5, + num_aie_columns=8, + num_channels=2, + tile_size=self.cfg["emb_dim"], + ) + # For decode phase - single token (only when using KV cache) + if self.cfg["use_kv_cache"]: + decode_size = self.cfg["emb_dim"] # 1 token * emb_dim + self.aie_final_norm_decode = AIERMSNorm( + size=decode_size, + eps=1e-5, + num_aie_columns=1, + num_channels=2, + tile_size=self.cfg["emb_dim"], + ) + else: + # When not using KV cache, use same operator for both phases + self.aie_final_norm_decode = self.aie_final_norm_prefill + else: + self.final_norm = nn.RMSNorm( + self.cfg["emb_dim"], eps=1e-5, dtype=self.cfg["dtype"] + ) + + # Offload final linear layer if enabled + if self.cfg.get("use_aie_final_gemm", False): + # Since this GEMM has such a large N dimension, partition the N dimension by 4, + # and GEMM will execute for a workload of that smaller N dimension across different buffers of B and C + aie_config_prefill = { + "num_aie_columns": 8, + "tile_m": 64, + "tile_k": 64, + "tile_n": 64, + "b_col_maj": True, + "use_static_weight": True, + "separate_c_tiles": True, + "partition_N": 4, + } + if self.cfg["use_kv_cache"]: + M_for_gemm = self.prompt_length + else: + M_for_gemm = self.prompt_length + self.num_tokens + self.out_head_prefill = AIEGEMM( + M=M_for_gemm, + K=self.cfg["emb_dim"], + N=self.cfg["vocab_size"], + **aie_config_prefill, + ) + aie_gemv_config = { + "num_aie_columns": 8, + "is_mv": True, + "use_static_weight": True, + "num_aie_columns": 8, + "tile_size_input": 4, + "tile_size_output": 32, + } + # FC1 and FC2: emb_dim -> hidden_dim + if self.cfg["use_aie_final_gemv"]: + self.out_head_decode = AIEGEMV( + M=self.cfg["vocab_size"], K=self.cfg["emb_dim"], **aie_gemv_config + ) + else: + self.out_head = nn.Linear( + self.cfg["emb_dim"], + self.cfg["vocab_size"], + bias=False, + dtype=self.cfg["dtype"], + ) + + # Reusable utilities + cos, sin = compute_rope_params( + head_dim=self.cfg["emb_dim"] // self.cfg["n_heads"], + theta_base=self.cfg["rope_base"], + context_length=self.cfg["context_length"], + freq_config=self.cfg["rope_freq"], + ) + angles = torch.cat([torch.empty_like(cos), torch.empty_like(cos)], dim=1) + angles[:, ::2] = cos + angles[:, 1::2] = sin + self.register_buffer("angles", angles, persistent=False) + + def forward(self, in_idx, input_pos=None, use_kv_cache=False): + # Forward pass + tok_embeds = self.tok_emb(in_idx) + x = tok_embeds + + # Check if input is a vector (decode phase) or matrix (prefill phase) + # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim) + is_vector = ( + len(x.shape) == 1 + or (len(x.shape) == 2 and x.shape[0] == 1) + or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1) + ) + + # (batch, sequence, embedding) where sequence=1 indicates decode + if len(x.shape) == 3: + is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"] + elif len(x.shape) == 2: + is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"] + else: + is_decode_with_kv = False + + num_tokens = x.shape[1] + + # During generation phase with KV cache, don't create a mask + # The attention layer will handle masking based on position + if use_kv_cache and input_pos is not None: + mask = None + else: + # During prefill, create standard causal mask + mask = torch.triu( + torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), + diagonal=1, + ) + + for block in self.trf_blocks: + x = block(x, mask, self.angles, input_pos) + + # Sequence length of 1 from input shape means we're in the decode stage, which can use KV cache + if self.cfg.get("use_aie_final_norm", False): + if (x.shape[-2] == 1) and self.cfg.get("use_kv_cache", False): + x = self.aie_final_norm_decode(x) + else: + x = self.aie_final_norm_prefill(x) + else: + x = self.final_norm(x) + + if self.cfg["use_aie_final_gemm"]: + if is_decode_with_kv and self.cfg["use_aie_final_gemv"]: + logits = self.out_head_decode(x) + else: + logits = self.out_head_prefill(x) + else: + logits = self.out_head(x) + + return logits + + def assign_weights(self, final_norm, out_head, out_head_name): + if self.cfg.get("use_aie_final_norm", False): + self.aie_final_norm_prefill.weight = final_norm + if self.cfg["use_kv_cache"]: + self.aie_final_norm_decode.weight = final_norm + else: + self.final_norm.weight = assign( + self.final_norm.weight, + final_norm, + f"model.norm.weight", + ) + + if self.cfg["use_aie_final_gemm"]: + # Want column-major for B + self.out_head_prefill.weight = out_head.T + if self.cfg["use_aie_final_gemv"]: + self.out_head_decode.weight = out_head.T + else: + self.out_head.weight = assign( + self.out_head.weight, + out_head, + out_head_name, + ) diff --git a/iron/applications/llama_3.2_1b/src/tokenizer.py b/iron/applications/llama_3.2_1b/src/tokenizer.py new file mode 100644 index 00000000..1a16cf57 --- /dev/null +++ b/iron/applications/llama_3.2_1b/src/tokenizer.py @@ -0,0 +1,101 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0. +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb +# +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path + +import tiktoken +from tiktoken.load import load_tiktoken_bpe + + +class Tokenizer: + """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.""" + + def __init__(self, model_path): + if not os.path.isfile(model_path): + raise FileNotFoundError(model_path) + + mergeable = load_tiktoken_bpe(model_path) + + # hard-coded from Meta's tokenizer.json + self.special = { + "<|begin_of_text|>": 128000, + "<|end_of_text|>": 128001, + "<|start_header_id|>": 128006, + "<|end_header_id|>": 128007, + "<|eot_id|>": 128009, + } + self.special.update( + { + f"<|reserved_{i}|>": 128002 + i + for i in range(256) + if 128002 + i not in self.special.values() + } + ) + + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" + r"|[^\r\n\p{L}\p{N}]?\p{L}+" + r"|\p{N}{1,3}" + r"| ?[^\s\p{L}\p{N}]+[\r\n]*" + r"|\s*[\r\n]+" + r"|\s+(?!\S)" + r"|\s+", + mergeable_ranks=mergeable, + special_tokens=self.special, + ) + + def encode(self, text, bos=False, eos=False): + ids = ([self.special["<|begin_of_text|>"]] if bos else []) + self.model.encode( + text + ) + if eos: + ids.append(self.special["<|end_of_text|>"]) + return ids + + def decode(self, ids): + return self.model.decode(ids) + + +class ChatFormat: + + def __init__( + self, tokenizer: Tokenizer, *, default_system="You are a helpful assistant." + ): + self.tok = tokenizer + self.default_system = default_system + + def _header(self, role): + """Encode <|start_header_id|>role<|end_header_id|>\n\n""" + return ( + [self.tok.special["<|start_header_id|>"]] + + self.tok.encode(role) + + [self.tok.special["<|end_header_id|>"]] + + self.tok.encode("\n\n") + ) + + def encode(self, user_message, system_message=None): + sys_msg = system_message if system_message is not None else self.default_system + + ids = [self.tok.special["<|begin_of_text|>"]] + + # system + ids += self._header("system") + ids += self.tok.encode(sys_msg) + ids += [self.tok.special["<|eot_id|>"]] + + # user + ids += self._header("user") + ids += self.tok.encode(user_message) + ids += [self.tok.special["<|eot_id|>"]] + + # assistant header (no content yet) + ids += self._header("assistant") + + return ids diff --git a/iron/applications/llama_3.2_1b/src/utils.py b/iron/applications/llama_3.2_1b/src/utils.py new file mode 100644 index 00000000..158b59df --- /dev/null +++ b/iron/applications/llama_3.2_1b/src/utils.py @@ -0,0 +1,307 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0. +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb +# +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import time +import torch +import numpy as np +from ml_dtypes import bfloat16 + + +def model_memory_size(model, input_dtype=torch.float32): + """ + Calculate the estimated memory size of a PyTorch model in gigabytes. + + This function computes the total memory required for the model's parameters, + gradients, and buffers based on the input data type. + + Args: + model (torch.nn.Module): The PyTorch model for which to calculate memory size. + input_dtype (torch.dtype, optional): The data type of the model's input. + Defaults to torch.float32. + + Returns: + float: The estimated memory size of the model in gigabytes. + """ + + total_params = 0 + total_grads = 0 + for param in model.parameters(): + # Calculate total number of elements per parameter + param_size = param.numel() + total_params += param_size + # Check if gradients are stored for this parameter + if param.requires_grad: + total_grads += param_size + + # Calculate buffer size (non-parameters that require memory) + total_buffers = sum(buf.numel() for buf in model.buffers()) + + # Size in bytes = (Number of elements) * (Size of each element in bytes) + # We assume parameters and gradients are stored in the same type as input dtype + element_size = torch.tensor(0, dtype=input_dtype).element_size() + total_memory_bytes = (total_params + total_grads + total_buffers) * element_size + + # Convert bytes to gigabytes + total_memory_gb = total_memory_bytes / (1024**3) + + return total_memory_gb + + +def assign(left, right, tensor_name="unknown"): + """ + Assigns the value of the right tensor to a new torch.nn.Parameter after validating shape compatibility. + + Parameters: + left (torch.Tensor or any): The tensor to compare shape with. + right (torch.Tensor or any): The tensor or value to be assigned. + tensor_name (str): The name of the tensor for error reporting (default is "unknown"). + + Returns: + torch.nn.Parameter: A new parameter containing the value of right. + + Raises: + ValueError: If the shapes of left and right do not match. + """ + + if left.shape != right.shape: + raise ValueError( + f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}" + ) + + if isinstance(right, torch.Tensor): + return torch.nn.Parameter(right.clone().detach()) + else: + return torch.nn.Parameter(torch.tensor(right)) + + +def load_weights_into_llama(model, param_config, params): + """ + Load weights into the LLaMA model from the provided parameters. + + This function assigns weights from the given parameters to the corresponding + layers of the LLaMA model. It handles the embedding layer, attention layers, + feedforward layers, and the output layer. The function also checks for weight + tying between the output head and the embedding layer. + + Args: + model: The LLaMA model instance into which weights will be loaded. + param_config (dict): A configuration dictionary containing model parameters, + including the number of layers (`n_layers`). + params (dict): A dictionary containing the weights to be loaded, with keys + corresponding to the model's architecture. + """ + model.tok_emb.weight = assign( + model.tok_emb.weight, + params["model.embed_tokens.weight"], + "model.embed_tokens.weight", + ) + + for l in range(param_config["n_layers"]): + + # Load attention weights + model.trf_blocks[l].att.assign_weights( + l, + params[f"model.layers.{l}.self_attn.q_proj.weight"], + params[f"model.layers.{l}.self_attn.k_proj.weight"], + params[f"model.layers.{l}.self_attn.v_proj.weight"], + params[f"model.layers.{l}.self_attn.o_proj.weight"], + ) + # Load FeedForward weights + model.trf_blocks[l].ff.assign_weights( + l, + fc1=params[f"model.layers.{l}.mlp.gate_proj.weight"], + fc2=params[f"model.layers.{l}.mlp.up_proj.weight"], + fc3=params[f"model.layers.{l}.mlp.down_proj.weight"], + ) + # Load RMS norm weights + model.trf_blocks[l].assign_weights( + l, + params[f"model.layers.{l}.input_layernorm.weight"], + params[f"model.layers.{l}.post_attention_layernorm.weight"], + ) + + # Load output layer weights + if "lm_head.weight" in params.keys(): + model.assign_weights( + params["model.norm.weight"], params["lm_head.weight"], "lm_head.weight" + ) + else: + model.assign_weights( + params["model.norm.weight"], + params["model.embed_tokens.weight"], + "model.embed_tokens.weight", + ) + + +def text_to_token_ids(text, tokenizer): + """ + Convert a given text into token IDs using the specified tokenizer. + + Args: + text (str): The input text to be tokenized. + tokenizer: An instance of a tokenizer that has an `encode` method. + + Returns: + torch.Tensor: A tensor containing the token IDs of the input text, + with an added batch dimension. + """ + encoded = tokenizer.encode(text) + encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension + return encoded_tensor + + +def token_ids_to_text(token_ids, tokenizer): + """ + Convert a tensor of token IDs to a human-readable text string. + + Args: + token_ids (torch.Tensor): A tensor containing token IDs, + typically with a batch dimension. + tokenizer (Tokenizer): An instance of a tokenizer that has a + decode method to convert token IDs to text. + + Returns: + str: The decoded text string corresponding to the input token IDs. + """ + flat = token_ids.squeeze(0) # remove batch dimension + return tokenizer.decode(flat.tolist()) + + +def generate( + model, + idx, + max_new_tokens, + context_size, + eos_id, + hook_handles, + temperature=0.0, + top_k=None, + tokenizer=None, + prompt=None, + do_print=True, + prefill_done_callback=None, +): + """ + Generate new tokens using the provided model based on the input sequence. + + Args: + model: The model used for generating tokens. It should accept input sequences and return logits. + idx (torch.Tensor): The input sequence of token indices (shape: (batch_size, sequence_length)). + max_new_tokens (int): The maximum number of new tokens to generate. + context_size (int): The number of tokens from the input sequence to consider for generation. + temperature (float, optional): The temperature for scaling logits. Higher values result in more random outputs. Default is 0.0 (no scaling). + top_k (int, optional): The number of top logits to consider for sampling. If None, all logits are used. Default is None. + eos_id (int, optional): The end-of-sequence token ID. If specified, generation will stop when this token is produced. Default is None. + + Returns: + torch.Tensor: The updated sequence of token indices after generation (shape: (batch_size, new_sequence_length)). + """ + # For-loop is the same as before: Get logits, and only focus on last time step + finished_prefill = False + + print(f"Starting prefill inference...") + + for i in range(max_new_tokens): + use_kv_cache = model.cfg["use_kv_cache"] + + if use_kv_cache: + if i == 0: + # Prefill phase - process entire sequence + idx_cond = idx[:, -context_size:] + input_pos = None + else: + # Generation phase with KV cache - single token, need to track position + # Extract only the last token + idx_cond = idx[:, -1:] + input_pos = torch.tensor([idx.shape[1] - 1], device=idx.device) + else: + # No KV cache - always process entire sequence (GEMM every time) + idx_cond = idx[:, -context_size:] + input_pos = None + with torch.no_grad(): + logits = model(idx_cond, input_pos=input_pos, use_kv_cache=use_kv_cache) + logits = logits[:, -1, :] + + # New: Filter logits with top_k sampling + if top_k is not None: + # Keep only top_k values + top_logits, _ = torch.topk(logits, top_k) + min_val = top_logits[:, -1] + logits = torch.where( + logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits + ) + + # New: Apply temperature scaling + if temperature > 0.0: + logits = logits / temperature + + # Apply softmax to get probabilities + probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) + + # Sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) + + # Otherwise same as before: get idx of the vocab entry with the highest logits value + else: + idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1) + + # Only run the forward hook for the prefill stage, remove it afterwards to speed up inference + if not finished_prefill: + if hook_handles: + for handle in hook_handles: + handle.remove() + finished_prefill = True + + if ( + idx_next == eos_id + ): # Stop generating early if end-of-sequence token is encountered and eos_id is specified + break + + # Same as before: append sampled index to the running sequence + idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1) + + # End timing the first iteration + if i == 0: + if prefill_done_callback is not None: + prefill_done_callback() + if do_print: + print(prompt) + + # print(f"\rGenerating token {i + 1}/{max_new_tokens}...", end="") + generated_text = token_ids_to_text(idx_next, tokenizer) + if do_print: + print(f"{generated_text}", end="", flush=True) + + print("\n\n") + return idx + + +def clean_text(text, header_end="assistant<|end_header_id|>\n\n"): + """ + Cleans the input text by removing the header portion defined by the header_end token. + + Parameters: + text (str): The input text to be cleaned. + header_end (str): The token that marks the end of the header. Defaults to "assistant<|end_header_id|>\n\n". + + Returns: + str: The cleaned text, which is the substring after the header_end token. + If the token is not found, the original text is returned. + """ + + # Find the index of the first occurrence of "<|end_header_id|>" + index = text.find(header_end) + + if index != -1: + # Return the substring starting after "<|end_header_id|>" + return text[ + index + len(header_end) : + ].strip() # Strip removes leading/trailing whitespace + else: + # If the token is not found, return the original text + return text diff --git a/iron/applications/llama_3.2_1b/test.py b/iron/applications/llama_3.2_1b/test.py new file mode 100644 index 00000000..933b7d5e --- /dev/null +++ b/iron/applications/llama_3.2_1b/test.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import subprocess +import pytest +from pathlib import Path + +test_dir = Path(__file__).parent +weights_dir = Path("/srv") + + +def generate_test_params(): + prompt_lengths = [2048, 13] + num_tokens_list = [40, 1] + + params = [] + names = [] + for prompt_len in prompt_lengths: + for num_tokens in num_tokens_list: + params.append((prompt_len, num_tokens)) + names.append(f"llama_3.2_1b_prompt_{prompt_len}_tokens_{num_tokens}") + return params, names + + +params, names = generate_test_params() + + +@pytest.mark.metrics( + TTFT=r"Prefill time: (?P[\d\.e\+-]+) seconds", + TPS=r"Tokens per second: (?P[\d\.e\+-]+)", + Num_Tokens=r"Tokens generated: (?P[\d\.e\+-]+)", +) +@pytest.mark.parametrize("prompt_len,num_tokens", params, ids=names) +def test_llama_3_2_1b(prompt_len, num_tokens): + command = f"python3 {test_dir}/inference.py {weights_dir}/llama3.2-1b/model.safetensors {weights_dir}/llama3.2-1b/tokenizer.model --prompt_len {prompt_len} --num_tokens {num_tokens}" + + result = subprocess.run( + command, + cwd=test_dir, + shell=True, + capture_output=True, + text=True, + timeout=300, + ) + + assert ( + result.returncode == 0 + ), f"Command failed with return code {result.returncode}\nStderr: {result.stderr}" + + print(result.stdout) diff --git a/iron/applications/llama_3.2_1b/torch_to_npy.py b/iron/applications/llama_3.2_1b/torch_to_npy.py new file mode 100644 index 00000000..e7d06be0 --- /dev/null +++ b/iron/applications/llama_3.2_1b/torch_to_npy.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import argparse +import numpy as np +import os +import shutil + + +def torch_to_npy(inp_file_path, outp_file_path): + # Load the torch file + data = torch.load(inp_file_path) + # Convert the tensor to a numpy array of floats + data_np = data.to(torch.float32).numpy() + # Compare the values between data and data_np + if not torch.equal(data, torch.from_numpy(data_np)): + raise ValueError("Mismatch between original data and converted numpy array.") + + # Save the array to a npy file + np.save(outp_file_path, data_np) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert torch files to npy files.") + parser.add_argument( + "file_path", + type=str, + help="Path to the torch file or directory containing torch files", + ) + args = parser.parse_args() + file_path = args.file_path + + output_dir = os.path.join("results", f"{os.path.basename(file_path)}_npy") + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir) + + # Check if the file path is a directory + if os.path.isdir(file_path): + for file_name in os.listdir(file_path): + if file_name.endswith(".pt") or file_name.endswith(".pth"): + full_path = os.path.join(file_path, file_name) + output_file_path = os.path.join( + output_dir, file_name.replace(".pt", ".npy").replace(".pth", ".npy") + ) + torch_to_npy(full_path, output_file_path) + else: + torch_to_npy(file_path) diff --git a/iron/operators/repeat/design.py b/iron/operators/repeat/design.py deleted file mode 100644 index a3539caa..00000000 --- a/iron/operators/repeat/design.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np - -# from aie.extras.context import mlir_mod_ctx -# from aie.ir import StridedLayoutAttr, ShapedType -# from aie.dialects.aie import * -# from aie.dialects.aiex import * -from aie.dialects.aiex import TensorAccessPattern -from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker -from aie.iron.placers import SequentialPlacer - -""" -Repeat interleave -""" - - -def repeat(dev, dtype, rows, cols, repeat, transfer_size=None): - dtype = np.dtype[dtype] - - # Try to work around hardware size limitations by breaking transfers into smaller chunks - cols_split = 1 - if cols > 1023: - for divisor in range(2, cols + 1): - if cols % divisor == 0 and cols // divisor <= 1023: - cols_split = divisor - break - else: - raise ValueError( - f"Cannot split cols={cols} into chunks <= 1023; hardware limits cols to not exceed 1023" - ) - assert cols_split <= 1023, "cols is too large, can't split into smaller transfers" - - if transfer_size is None: - transfer_size = cols - - inp_ty = np.ndarray[ - (rows, cols), - dtype, - ] - out_ty = np.ndarray[ - (rows * repeat, cols), - dtype, - ] - transfer_ty = np.ndarray[ - (transfer_size,), - dtype, - ] - - input_tap = TensorAccessPattern( - tensor_dims=(rows, cols), - offset=0, - sizes=[repeat, rows, cols // cols_split, cols_split], - strides=[0, cols, cols_split, 1], - ) - - output_tap = TensorAccessPattern( - tensor_dims=(rows * repeat, cols), - offset=0, - sizes=[repeat, rows, cols // cols_split, cols_split], - strides=[cols, cols * repeat, cols_split, 1], - ) - - # Use smaller FIFOs for the transfer amount - fifo_in = ObjectFifo(transfer_ty, name="fifo_in", depth=2) - fifo_out = fifo_in.cons().forward(name="fifo_out", depth=2) - - rt = Runtime() - with rt.sequence(inp_ty, out_ty) as (inp, out): - tg = rt.task_group() - rt.fill(fifo_in.prod(), inp, input_tap, task_group=tg) - rt.drain(fifo_out.cons(), out, output_tap, task_group=tg, wait=True) - rt.finish_task_group(tg) - - return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/repeat/op.py b/iron/operators/repeat/op.py deleted file mode 100644 index b056f591..00000000 --- a/iron/operators/repeat/op.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import numpy as np -from ml_dtypes import bfloat16 -from pathlib import Path - -from iron.common import ( - MLIROperator, - AIERuntimeArgSpec, - KernelObjectArtifact, - SourceArtifact, - PythonGeneratedMLIRArtifact, - XclbinArtifact, - InstsBinArtifact, -) - - -class AIERepeat(MLIROperator): - """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" - - def __init__( - self, - rows, - cols, - repeat, - transfer_size=None, - dtype=bfloat16, - context=None, - ): - self.rows = rows - self.cols = cols - self.repeat = repeat - self.transfer_size = transfer_size - self.dtype = dtype - MLIROperator.__init__(self, context=context) - - def get_operator_name(self): - name = f"repeat_{self.rows}x{self.cols}_by_{self.repeat}" - if self.transfer_size is not None: - name += f"_{self.transfer_size}ts" - return name - - def get_mlir_artifact(self): - operator_dir = Path(__file__).parent - - return PythonGeneratedMLIRArtifact( - f"{self.get_operator_name()}.mlir", - import_path=operator_dir / "design.py", - callback_fn="repeat", - callback_args=[ - self.context.device_manager.device_type, - self.dtype, - self.rows, - self.cols, - self.repeat, - self.transfer_size, - ], - ) - - def get_kernel_artifacts(self): - return [] - - def get_arg_spec(self): - return [ - AIERuntimeArgSpec("in", (self.rows, self.cols)), - AIERuntimeArgSpec("out", (self.rows * self.repeat, self.cols)), - ] diff --git a/iron/operators/strided_copy/design.py b/iron/operators/strided_copy/design.py deleted file mode 100644 index 63b97e33..00000000 --- a/iron/operators/strided_copy/design.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np - -# from aie.extras.context import mlir_mod_ctx -# from aie.ir import StridedLayoutAttr, ShapedType -# from aie.dialects.aie import * -# from aie.dialects.aiex import * -from aie.dialects.aiex import TensorAccessPattern -from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker -from aie.iron.placers import SequentialPlacer - -""" -Strided copy design - -This can be useful for data layout manipulation and data copying such as: -input[0, :, 0] -> output[:, 0, 0] -""" - - -def strided_copy( - dev, - dtype, - input_buffer_size, - input_sizes, - input_strides, - input_offset, - output_buffer_size, - output_sizes, - output_strides, - output_offset, - transfer_size=None, - num_aie_channels=1, - input_offset_patch_marker=0, - output_offset_patch_marker=0, -): - assert len(input_sizes) == len(input_strides) - assert len(output_sizes) == len(output_strides) - - # Pad out dimensions to 4D; dropping leading dimensions leads to compiler not initializing these registers, causing hard-to-debug errors - input_sizes = [1] * (4 - len(input_sizes)) + list(input_sizes) - input_strides = [0] * (4 - len(input_strides)) + list(input_strides) - output_sizes = [1] * (4 - len(output_sizes)) + list(output_sizes) - output_strides = [0] * (4 - len(output_strides)) + list(output_strides) - - input_highest_sz_idx = max(idx for idx, sz in enumerate(input_sizes) if sz >= 1) - output_highest_sz_idx = max(idx for idx, sz in enumerate(output_sizes) if sz >= 1) - assert ( - input_sizes[input_highest_sz_idx] % num_aie_channels == 0 - ), "Highest dimension of input_sizes must be divisible by num_aie_channels" - assert ( - output_sizes[output_highest_sz_idx] % num_aie_channels == 0 - ), "Highest dimension of output_sizes must be divisible by num_aie_channels" - - if transfer_size is None: - transfer_size = int(np.prod(input_sizes)) - assert np.prod(input_sizes) % transfer_size == 0 - transfer_ty = np.ndarray[ - (transfer_size,), - np.dtype[dtype], - ] - - inp_ty = np.ndarray[ - (int(input_buffer_size),), - np.dtype[dtype], - ] - out_ty = np.ndarray[ - (int(output_buffer_size),), - np.dtype[dtype], - ] - - input_taps = [ - TensorAccessPattern( - tensor_dims=(int(input_buffer_size + input_offset_patch_marker),), - offset=( - input_offset_patch_marker - if input_offset_patch_marker != 0 - else input_offset - + c - * (input_sizes[input_highest_sz_idx] // num_aie_channels) - * input_strides[input_highest_sz_idx] - ), - sizes=( - input_sizes[:input_highest_sz_idx] - + [input_sizes[input_highest_sz_idx] // num_aie_channels] - + input_sizes[input_highest_sz_idx + 1 :] - ), - strides=list(input_strides), - ) - for c in range(num_aie_channels) - ] - - output_taps = [ - TensorAccessPattern( - tensor_dims=(int(output_buffer_size + output_offset_patch_marker),), - offset=( - output_offset_patch_marker - if output_offset_patch_marker != 0 - else output_offset - + c - * (output_sizes[output_highest_sz_idx] // num_aie_channels) - * output_strides[output_highest_sz_idx] - ), - sizes=( - output_sizes[:output_highest_sz_idx] - + [output_sizes[output_highest_sz_idx] // num_aie_channels] - + output_sizes[output_highest_sz_idx + 1 :] - ), - strides=list(output_strides), - ) - for c in range(num_aie_channels) - ] - - # Use smaller FIFOs for the transfer amount - fifos_in = [ - ObjectFifo(transfer_ty, name=f"fifo_in_{c}", depth=1) - for c in range(num_aie_channels) - ] - fifos_out = [ - fifos_in[c].cons().forward(name=f"fifo_out_{c}", depth=1) - for c in range(num_aie_channels) - ] - - rt = Runtime() - with rt.sequence(inp_ty, out_ty) as (inp, out): - tg = rt.task_group() - for c in range(num_aie_channels): - rt.fill(fifos_in[c].prod(), inp, input_taps[c], task_group=tg) - rt.drain(fifos_out[c].cons(), out, output_taps[c], task_group=tg, wait=True) - rt.finish_task_group(tg) - - return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/strided_copy/op.py b/iron/operators/strided_copy/op.py deleted file mode 100644 index 5996a90d..00000000 --- a/iron/operators/strided_copy/op.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import torch -import numpy as np -from ml_dtypes import bfloat16 -from pathlib import Path - -from iron.common import ( - MLIROperator, - AIERuntimeArgSpec, - KernelObjectArtifact, - SourceArtifact, - PythonGeneratedMLIRArtifact, - XclbinArtifact, - InstsBinArtifact, -) - - -class AIEStridedCopy(MLIROperator): - """AIE-accelerated General Matrix-Vector/Vector-Matrix Multiplication layer""" - - def __init__( - self, - input_sizes, - input_strides, - input_offset, - output_sizes, - output_strides, - output_offset, - input_buffer_size, - output_buffer_size, - dtype=bfloat16, - transfer_size=None, - num_aie_channels=1, - context=None, - **kwargs, - ): - assert len(input_sizes) == len(input_strides) - assert len(output_sizes) == len(output_strides) - self.input_sizes = input_sizes - self.input_strides = input_strides - self.input_offset = input_offset - self.output_sizes = output_sizes - self.output_strides = output_strides - self.output_offset = output_offset - self.input_buffer_size = input_buffer_size - self.output_buffer_size = output_buffer_size - self.dtype = dtype - self.transfer_size = transfer_size - self.num_aie_channels = num_aie_channels - self.kwargs = kwargs - MLIROperator.__init__(self, context=context) - - def get_operator_name(self): - return f"strided_copy_{'x'.join(map(str, self.input_sizes))}sz_{'x'.join(map(str, self.input_strides))}st_{self.input_offset}off_to_{'x'.join(map(str, self.output_sizes))}sz_{'x'.join(map(str, self.output_strides))}st_{self.output_offset}off_{self.transfer_size if self.transfer_size is not None else 'auto'}tr_{self.num_aie_channels}ch" - - def get_mlir_artifact(self): - operator_dir = Path(__file__).parent - - return PythonGeneratedMLIRArtifact( - f"{self.get_operator_name()}.mlir", - import_path=operator_dir / "design.py", - callback_fn="strided_copy", - callback_args=[ - self.context.device_manager.device_type, - self.dtype, - self.input_buffer_size, - self.input_sizes, - self.input_strides, - self.input_offset, - self.output_buffer_size, - self.output_sizes, - self.output_strides, - self.output_offset, - self.transfer_size, - self.num_aie_channels, - ], - callback_kwargs=self.kwargs, - ) - - def get_kernel_artifacts(self): - return [] - - def get_arg_spec(self): - return [ - AIERuntimeArgSpec("in", self.input_buffer_size), # matrix - AIERuntimeArgSpec("out", self.output_buffer_size), # output - ] diff --git a/iron/operators/strided_copy/test2.py b/iron/operators/strided_copy/test2.py deleted file mode 100755 index 994d8ab3..00000000 --- a/iron/operators/strided_copy/test2.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 - -# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import sys -from pathlib import Path -import time -import torch - -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -from operators.strided_copy.op import AIEStridedCopy -from operators.common import AIEBuffer - -max_prompt_len = 2048 -cached_prompt_len = 9 -prompt_len = 7 -head_dim = 64 -num_heads = 32 - -transpose_concat = ( - AIEStridedCopy( - input_sizes=( - 1, - num_heads, - prompt_len, - head_dim, - ), - input_strides=( - 0, - head_dim, - num_heads * head_dim, - 1, - ), - input_offset=0, - output_sizes=( - 1, - num_heads, - prompt_len, - head_dim, - ), - output_strides=( - 0, - max_prompt_len * head_dim, - head_dim, - 1, - ), - output_offset=cached_prompt_len * head_dim, - input_buffer_size=prompt_len * num_heads * head_dim, - output_buffer_size=num_heads * max_prompt_len * head_dim, - num_aie_channels=1, - ) - .compile() - .get_callable() -) - -value_cache_1 = AIEBuffer((num_heads, max_prompt_len, head_dim)) -value_1 = AIEBuffer((prompt_len, num_heads, head_dim)) -value_cache_1.view_as_torch()[:, :cached_prompt_len, :] = torch.randn( - num_heads, cached_prompt_len, head_dim -) -value_1.view_as_torch()[:prompt_len, :, :] = torch.randn( - prompt_len, num_heads, head_dim -) - -value_cache = AIEBuffer((num_heads, max_prompt_len, head_dim)) -value = AIEBuffer((prompt_len, num_heads, head_dim)) - -value_cache.view_as_torch()[:, :cached_prompt_len, :] = torch.randn( - num_heads, cached_prompt_len, head_dim -) -value.view_as_torch()[:prompt_len, :, :] = torch.randn(prompt_len, num_heads, head_dim) - -t_cpu_start = time.perf_counter() -value_transposed = value.view_as_torch().transpose(0, 1) -out_ref = torch.cat( - [value_cache.view_as_torch()[:, :cached_prompt_len, :], value_transposed], dim=1 -) -t_cpu = time.perf_counter() - t_cpu_start - -transpose_concat(value_1, value_cache_1) -value_cache.to("npu") -value.to("npu") -t_aie_start = time.perf_counter() -transpose_concat(value, value_cache) -t_aie = time.perf_counter() - t_aie_start -value_cache.to("cpu") - -print(out_ref) -print(t_cpu) -aie_out = value_cache.view_as_torch()[:, : (cached_prompt_len + prompt_len), :] -print(aie_out) -print(t_aie) - -# Check which elements differ -diff = torch.abs(out_ref - aie_out) -max_diff = diff.max() -print(f"Max diff: {max_diff}") -print(f"Number of mismatches (> 1e-2): {(diff > 1e-2).sum()}") - -# Find first mismatch -mismatches = torch.where(diff > 1e-2) -if len(mismatches[0]) > 0: - for i in range(min(10, len(mismatches[0]))): - h, s, d = mismatches[0][i], mismatches[1][i], mismatches[2][i] - print( - f"Mismatch at head={h}, seq={s}, dim={d}: ref={out_ref[h,s,d]}, aie={aie_out[h,s,d]}, diff={diff[h,s,d]}" - ) - -assert torch.allclose(out_ref, aie_out, atol=1e-2, rtol=1e-2) From b9cbcba5008ab169952c1c0659e74c1248b8d71a Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 10:22:52 -0700 Subject: [PATCH 77/99] remove extra test file --- iron/operators/softmax/test2.py | 64 --------------------------------- 1 file changed, 64 deletions(-) delete mode 100755 iron/operators/softmax/test2.py diff --git a/iron/operators/softmax/test2.py b/iron/operators/softmax/test2.py deleted file mode 100755 index cbdfb93b..00000000 --- a/iron/operators/softmax/test2.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 - -# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import sys -from pathlib import Path -import time -import torch - -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -from operators.softmax.op import AIESoftmax -from operators.common import AIEBuffer - -max_context_len = 2048 -prompt_len = 8 -n_heads = 32 - -softmax_op = ( - AIESoftmax(rows=n_heads, cols=max_context_len, rtp_vector_size=prompt_len) - .compile() - .get_callable() -) - -inp = AIEBuffer((n_heads, max_context_len)) -out = AIEBuffer((n_heads, max_context_len)) - -inp.view_as_torch()[:] = torch.randn(n_heads, max_context_len) -out.view_as_torch()[:] = torch.zeros(n_heads, max_context_len) - -t_cpu_start = time.perf_counter() -out_ref = inp.view_as_torch()[:, :prompt_len].softmax(dim=-1) -t_cpu = time.perf_counter() - t_cpu_start - -inp.to("npu") -out.to("npu") -t_aie_start = time.perf_counter() -softmax_op(inp, out) -t_aie = time.perf_counter() - t_aie_start -out.to("cpu") - -print(out_ref) -print(t_cpu) -aie_out = out.view_as_torch()[:, :prompt_len] -print(aie_out) -print(t_aie) - -# Check which elements differ -diff = torch.abs(out_ref - aie_out) -max_diff = diff.max() -print(f"Max diff: {max_diff}") -print(f"Number of mismatches (> 1e-2): {(diff > 1e-2).sum()}") - -# Find first mismatch -mismatches = torch.where(diff > 1e-2) -if len(mismatches[0]) > 0: - for i in range(min(10, len(mismatches[0]))): - h, s = mismatches[0][i], mismatches[1][i] - print( - f"Mismatch at head={h}, seq={s}: ref={out_ref[h,s]}, aie={aie_out[h,s]}, diff={diff[h,s]}" - ) - -assert torch.allclose(out_ref, aie_out, atol=1e-2, rtol=1e-2) From b8815abd64f3fb82ab8aef297a2af680ba328ba6 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 10:25:16 -0700 Subject: [PATCH 78/99] fixup imports --- iron/operators/__init__.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/iron/operators/__init__.py b/iron/operators/__init__.py index 5d55514b..79ffec8d 100644 --- a/iron/operators/__init__.py +++ b/iron/operators/__init__.py @@ -1,30 +1,13 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# from .axpy.op import AIEAXPY -# from .dequant.op import AIEDequant from .elementwise_add.op import AIEElementwiseAdd from .elementwise_mul.op import AIEElementwiseMul - -# from .gelu.op import AIEGELU from .gemm.op import AIEGEMM from .gemv.op import AIEGEMV - -# from .layer_norm.op import AIELayerNorm -# from .leaky_relu.op import AIELeakyReLU -# from .mem_copy.op import AIEMemCopy -# from .mha.op import AIEMHA -# from .relu.op import AIEReLU from .rms_norm.op import AIERMSNorm from .rope.op import AIERope - -# from .sigmoid.op import AIESigmoid from .silu.op import AIESiLU from .softmax.op import AIESoftmax - -# from .swiglu_decode.op import AIESwiGLUDecode -# from .swiglu_prefill.op import AIESwiGLUPrefill -# from .tanh.op import AIETanh from .transpose.op import AIETranspose -from .strided_copy.op import AIEStridedCopy from .repeat.op import AIERepeat From 93e04a58f86d2fe3adcd7918cb85348adb902de8 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 10:30:44 -0700 Subject: [PATCH 79/99] fix another import path --- iron/operators/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/iron/operators/__init__.py b/iron/operators/__init__.py index 79ffec8d..66435428 100644 --- a/iron/operators/__init__.py +++ b/iron/operators/__init__.py @@ -10,4 +10,3 @@ from .silu.op import AIESiLU from .softmax.op import AIESoftmax from .transpose.op import AIETranspose -from .repeat.op import AIERepeat From d950db18f0c4e2ac21ec5570ccda7015a115db3d Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 15:03:49 -0700 Subject: [PATCH 80/99] Fix import issue --- iron/operators/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/iron/operators/__init__.py b/iron/operators/__init__.py index 66435428..9da2febf 100644 --- a/iron/operators/__init__.py +++ b/iron/operators/__init__.py @@ -5,6 +5,7 @@ from .elementwise_mul.op import AIEElementwiseMul from .gemm.op import AIEGEMM from .gemv.op import AIEGEMV +from .mha.op import AIEMHA from .rms_norm.op import AIERMSNorm from .rope.op import AIERope from .silu.op import AIESiLU From c4a948e078db6881b00df3a95996fb9558240406 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 15:20:52 -0700 Subject: [PATCH 81/99] GEMM fixup --- iron/operators/gemm/design.py | 1 + iron/operators/gemm/op.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index 230c7991..e5b4d748 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -106,6 +106,7 @@ def main(): args.separate_c_tiles, args.trace_size, args.archive, + "", args.generate_taps, ) diff --git a/iron/operators/gemm/op.py b/iron/operators/gemm/op.py index 380f83aa..d5f0d88c 100644 --- a/iron/operators/gemm/op.py +++ b/iron/operators/gemm/op.py @@ -84,7 +84,7 @@ def get_mlir_artifact(self): emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( "emulate_bf16_mmul_with_bfp16", True ) - prio_accuracy = False # Force False for debugging + prio_accuracy = self.gemm_args.get("prio_accuracy", False) use_scalar = self.gemm_args.get("use_scalar", False) round_conv_even = self.gemm_args.get("round_conv_even", True) separate_c_tiles = self.gemm_args.get("separate_c_tiles", False) @@ -120,7 +120,7 @@ def get_kernel_artifacts(self): emulate_bf16_mmul_with_bfp16 = self.gemm_args.get( "emulate_bf16_mmul_with_bfp16", True ) - prio_accuracy = False # Force False for debugging + prio_accuracy = self.gemm_args.get("prio_accuracy", False) round_conv_even = self.gemm_args.get("round_conv_even", True) kernel_flags = [ f"-DDIM_M={self.tile_m}", From 9c926b9e434fea8ecdbb30425cf585713696b2ba Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 15:28:31 -0700 Subject: [PATCH 82/99] fix rms norm w/ weights --- iron/operators/rms_norm/test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/iron/operators/rms_norm/test.py b/iron/operators/rms_norm/test.py index e6dd012d..be18b202 100755 --- a/iron/operators/rms_norm/test.py +++ b/iron/operators/rms_norm/test.py @@ -97,6 +97,7 @@ def test_rms_norm( input_buffers = {"input1": golden_ref["input"]} if weighted: operator.weight = golden_ref["weight"] + input_buffers["weight"] = golden_ref["weight"] output_buffers = {"output": golden_ref["output"]} errors, latency_us, bandwidth_gbps = run_test( From 5f51d2dd0a18d870240a1c023d364f2c55fa728a Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 15:45:55 -0700 Subject: [PATCH 83/99] fixup swiglu decode --- iron/operators/swiglu_decode/op.py | 28 +++++++++++---- iron/operators/swiglu_decode/test.py | 54 ++++++++++++++-------------- 2 files changed, 49 insertions(+), 33 deletions(-) diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index 613ba1a2..9293d972 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -31,23 +31,37 @@ def __init__(self, op): # We need to manually construct SingleXclbinCallable because sub-operators weren't "compiled" in the standard way # Helper to create callable from operator and artifacts - def create_callable(sub_op, xclbin_artifact, insts_artifact): + def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): return SingleXclbinCallable( - xclbin_path=xclbin_artifact.filename, - kernel_name=xclbin_artifact.kernel_name, + xclbin_path=xclbin_path, + kernel_name=kernel_name, insts_bin_path=insts_artifact.filename, args_spec=sub_op.get_arg_spec(), ) self.gemv_1_callable = create_callable( - op.gemv_1, op.combined_xclbin, op.gemv_1_insts + op.gemv_1, + op.combined_xclbin.filename, + op.gemv_1_xclbin.kernel_name, + op.gemv_1_insts, + ) + self.silu_callable = create_callable( + op.silu, + op.combined_xclbin.filename, + op.silu_xclbin.kernel_name, + op.silu_insts, ) - self.silu_callable = create_callable(op.silu, op.combined_xclbin, op.silu_insts) self.eltwise_mul_callable = create_callable( - op.eltwise_mul, op.combined_xclbin, op.eltwise_mul_insts + op.eltwise_mul, + op.combined_xclbin.filename, + op.eltwise_mul_xclbin.kernel_name, + op.eltwise_mul_insts, ) self.gemv_2_callable = create_callable( - op.gemv_2, op.combined_xclbin, op.gemv_2_insts + op.gemv_2, + op.combined_xclbin.filename, + op.gemv_2_xclbin.kernel_name, + op.gemv_2_insts, ) # Allocate and upload weights diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index 5e726e5f..19f20a66 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -7,9 +7,12 @@ from pathlib import Path +from ml_dtypes import bfloat16 +from iron.common.base import AIEBuffer +from iron.common.utils import torch_to_numpy from iron.operators.swiglu_decode.op import AIESwiGLUDecode from iron.operators.swiglu_decode.reference import generate_golden_reference -from iron.common.test_utils import run_test, verify_buffer +from iron.common.test_utils import verify_buffer def generate_test_params(extensive=False): @@ -52,33 +55,32 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): # This reference is based on the previous intermediate result read back from the AIE operator, "resetting" the accumulated error to zero. # Note that the previous intermediate result _is_ still verified up to the given tolerance. - input_buffers = {"input": golden_ref["input"]} - output_buffers = {"output": None} - intermediate_buffers = { - "left": golden_ref["left"], - "left_swished": golden_ref["left_swished"], - "right": golden_ref["right"], - "intermediate": golden_ref["intermediate"], - } - - errors, latency_us, bandwidth_gbps = run_test( - operator, - input_buffers, - output_buffers, - intermediate_buffers, + operator.compile() + op_func = operator.get_callable() + + input_buf = AIEBuffer.from_np(torch_to_numpy(golden_ref["input"])) + output_buf = AIEBuffer(shape=(1, embedding_dim), dtype=bfloat16) + + op_func(input_buf, output_buf) + + errors = {} + # Verify intermediate result + intermediate = op_func.intermediate.view_as_torch().reshape((1, hidden_dim)) + errors_intermediate = verify_buffer( + intermediate, + "intermediate", + golden_ref["intermediate"], rel_tol=0.07, abs_tol=0.7, ) - - ref_2 = ( - operator.read_buffer_as_torch("intermediate", (1, hidden_dim)) - @ golden_ref["w_down"] - ) - errors_2 = verify_buffer(operator, "output", ref_2, rel_tol=0.04, abs_tol=0.4) - if errors_2: - errors["output"] = errors_2 - - print(f"\nLatency (us): {latency_us:.1f}") - print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + if errors_intermediate: + errors["intermediate"] = errors_intermediate + + # Verify output using intermediate result + ref_2 = intermediate @ golden_ref["w_down"] + output = output_buf.view_as_torch().reshape((1, embedding_dim)) + errors_output = verify_buffer(output, "output", ref_2, rel_tol=0.04, abs_tol=0.4) + if errors_output: + errors["output"] = errors_output assert not errors, f"Test failed with errors: {errors}" From ba338d33d110628323fa499a965f3a890ee9206a Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 16:01:28 -0700 Subject: [PATCH 84/99] small steps --- iron/common/compilation/base.py | 14 ++++++ iron/operators/swiglu_prefill/op.py | 30 ++++++++++--- iron/operators/swiglu_prefill/test.py | 62 +++++++++++++-------------- 3 files changed, 67 insertions(+), 39 deletions(-) diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index 8631ab4b..fb4b2c4d 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -507,6 +507,20 @@ def compile(self, graph): ] compile_cmd += [os.path.abspath(mlir_source.filename)] + # If the MLIR source depends on a kernel archive, pass it to aiecc.py so it can be linked + if ( + isinstance(mlir_source, PythonGeneratedMLIRArtifact) + and "kernel_archive" in mlir_source.callback_kwargs + ): + compile_cmd.append( + os.path.abspath( + os.path.join( + self.build_dir, + mlir_source.callback_kwargs["kernel_archive"], + ) + ) + ) + commands.append( ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) ) diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index be8f7548..d572c21f 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -28,23 +28,37 @@ class SwiGLUPrefillCallable: def __init__(self, op): self.op = op - def create_callable(sub_op, xclbin_artifact, insts_artifact): + def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): return SingleXclbinCallable( - xclbin_path=xclbin_artifact.filename, - kernel_name=xclbin_artifact.kernel_name, + xclbin_path=xclbin_path, + kernel_name=kernel_name, insts_bin_path=insts_artifact.filename, args_spec=sub_op.get_arg_spec(), ) self.gemm_1_callable = create_callable( - op.gemm_1, op.combined_xclbin, op.gemm_1_insts + op.gemm_1, + op.combined_xclbin.filename, + op.gemm_1_xclbin.kernel_name, + op.gemm_1_insts, + ) + self.silu_callable = create_callable( + op.silu, + op.combined_xclbin.filename, + op.silu_xclbin.kernel_name, + op.silu_insts, ) - self.silu_callable = create_callable(op.silu, op.combined_xclbin, op.silu_insts) self.eltwise_mul_callable = create_callable( - op.eltwise_mul, op.combined_xclbin, op.eltwise_mul_insts + op.eltwise_mul, + op.combined_xclbin.filename, + op.eltwise_mul_xclbin.kernel_name, + op.eltwise_mul_insts, ) self.gemm_2_callable = create_callable( - op.gemm_2, op.combined_xclbin, op.gemm_2_insts + op.gemm_2, + op.combined_xclbin.filename, + op.gemm_2_xclbin.kernel_name, + op.gemm_2_insts, ) # Allocate and upload weights @@ -59,8 +73,10 @@ def create_callable(sub_op, xclbin_artifact, insts_artifact): self.right = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) self.left_swished = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) self.intermediate = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.last_output_buf = None def __call__(self, input_buf, output_buf): + self.last_output_buf = output_buf input_buf.to("npu") output_buf.to("npu") self.weights_1.to("npu") diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 25204cf5..cab789b1 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -7,9 +7,12 @@ from pathlib import Path +from ml_dtypes import bfloat16 +from iron.common.base import AIEBuffer +from iron.common.utils import torch_to_numpy from iron.operators.swiglu_prefill.op import AIESwiGLUPrefill from iron.operators.swiglu_decode.reference import generate_golden_reference -from iron.common.test_utils import run_test, verify_buffer +from iron.common.test_utils import verify_buffer def generate_test_params(extensive=False): @@ -51,41 +54,36 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c operator.weights_2 = golden_ref["w_up"].T operator.weights_3 = golden_ref["w_down"].T - input_buffers = {"input": golden_ref["input"]} - # output_buffers = {'output': golden_ref['output']} - output_buffers = {"output": None} - intermediate_buffers = { - "left": golden_ref["left"], - "left_swished": golden_ref["left_swished"], - "right": golden_ref["right"], - # 'intermediate': golden_ref['intermediate'] - } - - errors, latency_us, bandwidth_gbps = run_test( - operator, - input_buffers, - output_buffers, - intermediate_buffers, - rel_tol=0.07, - abs_tol=0.7, - ) + operator.compile() + op_func = operator.get_callable() + + input_buf = AIEBuffer.from_np(torch_to_numpy(golden_ref["input"])) + output_buf = AIEBuffer( + shape=(seq_len * embedding_dim,), dtype=bfloat16 + ) # Output is flattened + + op_func(input_buf, output_buf) + + errors = {} - ref_2 = operator.read_buffer_as_torch( - "left_swished", (seq_len, hidden_dim) - ) * operator.read_buffer_as_torch("right", (seq_len, hidden_dim)) - errors_2 = verify_buffer(operator, "intermediate", ref_2, rel_tol=0.04, abs_tol=0.4) + # Verify intermediate result (left_swished * right) + left_swished = op_func.left_swished.view_as_torch().reshape((seq_len, hidden_dim)) + right = op_func.right.view_as_torch().reshape((seq_len, hidden_dim)) + ref_2 = left_swished * right + + # Note: intermediate buffer in op_func stores the result of eltwise_mul + intermediate = op_func.intermediate.view_as_torch().reshape((seq_len, hidden_dim)) + errors_2 = verify_buffer( + intermediate, "intermediate", ref_2, rel_tol=0.04, abs_tol=0.4 + ) if errors_2: errors["intermediate"] = errors_2 - ref_3 = ( - operator.read_buffer_as_torch("intermediate", (seq_len, hidden_dim)) - @ golden_ref["w_down"] - ) - errors_3 = verify_buffer(operator, "output", ref_3, rel_tol=0.04, abs_tol=0.4) + # Verify output using intermediate result + ref_3 = intermediate @ golden_ref["w_down"] + output = output_buf.view_as_torch().reshape((seq_len, embedding_dim)) + errors_3 = verify_buffer(output, "output", ref_3, rel_tol=0.04, abs_tol=0.4) if errors_3: - errors["output"] = errors_2 - - print(f"\nLatency (us): {latency_us:.1f}") - print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + errors["output"] = errors_3 assert not errors, f"Test failed with errors: {errors}" From 515556d5642666d6f9586e0f2bd851c833292c4f Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 16:09:02 -0700 Subject: [PATCH 85/99] Fixup paths a bit. --- iron/operators/__init__.py | 2 ++ iron/operators/gemm/op.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/iron/operators/__init__.py b/iron/operators/__init__.py index 9da2febf..1ad3044e 100644 --- a/iron/operators/__init__.py +++ b/iron/operators/__init__.py @@ -10,4 +10,6 @@ from .rope.op import AIERope from .silu.op import AIESiLU from .softmax.op import AIESoftmax +from .swiglu_decode.op import AIESwiGLUDecode +from .swiglu_prefill.op import AIESwiGLUPrefill from .transpose.op import AIETranspose diff --git a/iron/operators/gemm/op.py b/iron/operators/gemm/op.py index d5f0d88c..841fef03 100644 --- a/iron/operators/gemm/op.py +++ b/iron/operators/gemm/op.py @@ -140,9 +140,13 @@ def get_kernel_artifacts(self): kernel_flags.append("-DB_COL_MAJ") if self.c_col_maj: kernel_flags.append("-DC_COL_MAJ") + + # Include flags in the filename to avoid stale builds when flags change + flags_suffix = f"_{int(prio_accuracy)}_{int(emulate_bf16_mmul_with_bfp16)}_{int(round_conv_even)}" + return [ KernelObjectArtifact( - f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}.o", + f"gemm_{self.tile_m}x{self.tile_k}x{self.tile_n}_{int(self.b_col_maj)}_{int(self.c_col_maj)}{flags_suffix}.o", extra_flags=kernel_flags, dependencies=[ SourceArtifact(base_dir / "aie_kernels" / "aie2p" / "mm.cc") From 21bbe8e7a8a87b56c8793ac1cb899af0165d69bf Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 6 Feb 2026 17:07:13 -0700 Subject: [PATCH 86/99] swiglu decode working locally --- iron/operators/swiglu_decode/op.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index 9293d972..05496634 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -144,7 +144,7 @@ def set_up_artifacts(self): tile_size_output=self.hidden_dim // 8, ) self.gemv_1 = gemv_1 - gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts() + gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts(prefix="swiglu_gemv_1_") gemv_1_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemv_1", "--xclbin-kernel-id=0x901", @@ -161,7 +161,7 @@ def set_up_artifacts(self): ) self.silu = silu self.hidden_dim_padded = silu.size - silu_xclbin, silu_insts = silu.get_artifacts() + silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_silu_") silu_xclbin.xclbin_input = gemv_1_xclbin silu_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_silu", @@ -178,7 +178,9 @@ def set_up_artifacts(self): ) self.eltwise_mul = eltwise_mul assert self.hidden_dim <= eltwise_mul.size <= self.hidden_dim_padded - eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts() + eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts( + prefix="swiglu_eltwise_mul_" + ) eltwise_mul_xclbin.xclbin_input = silu_xclbin eltwise_mul_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_eltwise_mul", @@ -196,7 +198,7 @@ def set_up_artifacts(self): tile_size_output=self.embedding_dim // 8, ) self.gemv_2 = gemv_2 - gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts() + gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts(prefix="swiglu_gemv_2_") gemv_2_xclbin.xclbin_input = eltwise_mul_xclbin gemv_2_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemv_2", From 2523442ce3b187eb33a085f3b2c1b384cc167565 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 07:54:31 -0700 Subject: [PATCH 87/99] try to integrate tensor more --- iron/common/base.py | 68 +++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/iron/common/base.py b/iron/common/base.py index 6f9eb821..8da910f8 100644 --- a/iron/common/base.py +++ b/iron/common/base.py @@ -11,6 +11,8 @@ from ml_dtypes import bfloat16 import aie.utils.config +from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor +from aie.utils.hostruntime.tensor_class import Tensor from . import compilation as comp from .context import AIEContext from .device_manager import AIEDeviceManager, pyxrt @@ -174,24 +176,33 @@ def __init__(self, direction, shape, dtype=bfloat16): self.direction = direction -class AIEBuffer: +class AIEBuffer(XRTTensor): def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): - size = np.prod(shape) * np.dtype(dtype).itemsize - self.shape = shape - self.dtype = dtype - self.bo = bo - self.on = "cpu" self.device_manager = device_manager or AIEDeviceManager() - if not self.bo: - self.bo = pyxrt.bo( - self.device_manager.device, - size, - pyxrt.bo.host_only, - 0x10000, - ) - self.memory_view = self.bo.map() self.subviews = [] + if bo is not None: + Tensor.__init__(self, shape, dtype=dtype, device="cpu") + self._shape = shape + self.xrt_device = self.device_manager.device + self._bo = bo + ptr = self._bo.map() + self._data = np.frombuffer(ptr, dtype=self.dtype).reshape(self._shape) + else: + super().__init__(shape, dtype=dtype, device="cpu") + + @property + def bo(self): + return self._bo + + @property + def on(self): + return self.device + + @on.setter + def on(self, value): + self.device = value + def subbuffer(self, length, offset, shape, dtype=None): if dtype is None: dtype = self.dtype @@ -228,41 +239,24 @@ def view(self, shape): return sub_buffer def view_as_np(self): - self.to("cpu") - # Interpret the buffer as a 1-dimensional array then change its view to the expected shape - return np.frombuffer( - self.memory_view, dtype=self.dtype, count=np.prod(self.shape) - ).reshape(self.shape) + return self.numpy() def view_as_torch(self): - return numpy_to_torch(self.view_as_np()) + return numpy_to_torch(self.numpy()) def to(self, dest): - direction = { - "npu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE, - "cpu": pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE, - } - if dest not in direction: - raise RuntimeError(f"Unknown destination for AIEBuffer.to(): {dest}") - if self.on == dest: - return self - direction = direction[dest] - self.bo.sync(direction) - self.on = dest + super().to(dest) todo = self.subviews.copy() while todo: sub_buffer = todo.pop() - sub_buffer.on = self.on + sub_buffer.device = dest todo.extend(sub_buffer.subviews) return self @staticmethod def from_np(buffer): - shape = buffer.shape - dtype = buffer.dtype - size = np.prod(shape) * np.dtype(dtype).itemsize - aie_buffer = AIEBuffer(shape=shape, dtype=dtype) - aie_buffer.view_as_np()[:] = buffer + aie_buffer = AIEBuffer(buffer.shape, dtype=buffer.dtype) + aie_buffer.data[:] = buffer aie_buffer.to("npu") return aie_buffer From 27aadcfe44f460a2011dcf0e0f488e2d30bad8c6 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 08:03:54 -0700 Subject: [PATCH 88/99] try to fix llama keywords --- iron/applications/llama_3.2_1b/src/block/feed_forward.py | 2 +- iron/applications/llama_3.2_1b/src/block/gqa.py | 1 - iron/applications/llama_3.2_1b/src/model_with_json.py | 2 -- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/iron/applications/llama_3.2_1b/src/block/feed_forward.py b/iron/applications/llama_3.2_1b/src/block/feed_forward.py index 8bae36ec..b7dc8cf2 100644 --- a/iron/applications/llama_3.2_1b/src/block/feed_forward.py +++ b/iron/applications/llama_3.2_1b/src/block/feed_forward.py @@ -116,7 +116,7 @@ def __init__( ) if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]: - aie_gemv_config = {"num_aie_columns": 8, "is_mv": False} + aie_gemv_config = {"num_aie_columns": 8} # FC1 and FC2: emb_dim -> hidden_dim self.aie_fc1_gemv = AIEGEMV( M=self.hidden_dim, diff --git a/iron/applications/llama_3.2_1b/src/block/gqa.py b/iron/applications/llama_3.2_1b/src/block/gqa.py index 1a712ff9..2267cd28 100644 --- a/iron/applications/llama_3.2_1b/src/block/gqa.py +++ b/iron/applications/llama_3.2_1b/src/block/gqa.py @@ -133,7 +133,6 @@ def __init__( aie_gemv_config = { "num_aie_columns": 8, - "is_mv": False, "use_static_weight": True, } self.aie_query_gemv = AIEGEMV( diff --git a/iron/applications/llama_3.2_1b/src/model_with_json.py b/iron/applications/llama_3.2_1b/src/model_with_json.py index 856fb048..ba240ffc 100644 --- a/iron/applications/llama_3.2_1b/src/model_with_json.py +++ b/iron/applications/llama_3.2_1b/src/model_with_json.py @@ -197,9 +197,7 @@ def __init__( ) aie_gemv_config = { "num_aie_columns": 8, - "is_mv": True, "use_static_weight": True, - "num_aie_columns": 8, "tile_size_input": 4, "tile_size_output": 32, } From 1872ed28ea44b97740e8f8265b2408c3a206038c Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 08:28:07 -0700 Subject: [PATCH 89/99] fix another arg --- iron/applications/llama_3.2_1b/src/block/transformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/iron/applications/llama_3.2_1b/src/block/transformer.py b/iron/applications/llama_3.2_1b/src/block/transformer.py index f2b46cdf..fd6f9e58 100644 --- a/iron/applications/llama_3.2_1b/src/block/transformer.py +++ b/iron/applications/llama_3.2_1b/src/block/transformer.py @@ -104,7 +104,6 @@ def __init__( self.aie_residual_add_prefill = AIEElementwiseAdd( size=max_prefill_size, num_aie_columns=8, - num_channels=2, tile_size=cfg["emb_dim"], ) @@ -114,7 +113,6 @@ def __init__( self.aie_residual_add_decode = AIEElementwiseAdd( size=decode_size, num_aie_columns=1, - num_channels=2, tile_size=cfg["emb_dim"], ) else: From c7ff382f1f34efab62fb9a9c7e35e62712d82efc Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 08:32:35 -0700 Subject: [PATCH 90/99] Increment IRON --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0072a859..c849253f 100755 --- a/requirements.txt +++ b/requirements.txt @@ -6,11 +6,11 @@ # version of torch (don't need CUDA), so we give this index precedence over the # main PyPI. These indices are consulted in order of precedence by pip. --index-url https://download.pytorch.org/whl/cpu ---extra-index-url https://github.com/Xilinx/mlir-aie/releases/expanded_assets/v1.2.0 +--extra-index-url https://github.com/Xilinx/mlir-aie/releases/expanded_assets/v1.2.1 --extra-index-url https://github.com/Xilinx/llvm-aie/releases/expanded_assets/nightly --extra-index-url https://pypi.org/simple -mlir_aie==v1.2.0 +mlir_aie==v1.2.1 llvm-aie black From 884da88bc961e68a7cc28c5ea46124c8e816ae46 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 09:44:10 -0700 Subject: [PATCH 91/99] cleanup pytest config a bit --- iron/operators/axpy/test.py | 43 ++++++------- iron/operators/dequant/test.py | 40 ++++++------ iron/operators/elementwise_add/test.py | 37 ++++++----- iron/operators/elementwise_mul/test.py | 36 +++++------ iron/operators/gelu/test.py | 39 +++++------- iron/operators/gemm/test.py | 87 ++++++++++++-------------- iron/operators/gemv/test.py | 30 ++++----- iron/operators/layer_norm/test.py | 39 +++++------- iron/operators/leaky_relu/test.py | 20 +----- iron/operators/mem_copy/test.py | 36 ++++------- iron/operators/mha/test.py | 23 +++---- iron/operators/relu/test.py | 39 ++++++------ iron/operators/rms_norm/test.py | 38 ++++------- iron/operators/rope/test.py | 60 ++++++++---------- iron/operators/sigmoid/test.py | 39 ++++++------ iron/operators/silu/test.py | 39 ++++++------ iron/operators/softmax/test.py | 32 ++++------ iron/operators/swiglu_decode/test.py | 32 +++------- iron/operators/swiglu_prefill/test.py | 26 +++----- iron/operators/tanh/test.py | 39 ++++++------ iron/operators/transpose/test.py | 46 +++++++------- 21 files changed, 346 insertions(+), 474 deletions(-) diff --git a/iron/operators/axpy/test.py b/iron/operators/axpy/test.py index b91e802f..b37fabe2 100755 --- a/iron/operators/axpy/test.py +++ b/iron/operators/axpy/test.py @@ -12,40 +12,37 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 2048, 4096, 8192] - scalar_factors = [3.0] if not extensive else [3.0, 10.0] + input_lengths = [1024, 2048, 4096, 8192] + scalar_factors = [3.0, 10.0] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns if tile_size * num_aie_columns != input_length: continue for scalar in scalar_factors: - names.append( - f"axpy_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}_{scalar}" - ) - params.append( - (input_length, num_aie_columns, num_channels, tile_size, scalar) - ) - return params, names + name = f"axpy_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}_{scalar}" + # Determine if this is a regular test case + is_regular = input_length == 2048 and scalar == 3.0 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + scalar, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -54,7 +51,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,scalar_factor", - all_params, + get_params(), ) def test_axpy( input_length, num_aie_columns, num_channels, tile_size, scalar_factor, aie_context diff --git a/iron/operators/dequant/test.py b/iron/operators/dequant/test.py index 03b037f4..4ab904c0 100644 --- a/iron/operators/dequant/test.py +++ b/iron/operators/dequant/test.py @@ -12,12 +12,11 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - input_lengths = [2048] if not extensive else [1024, 2048, 4096, 8192] +def get_params(): + input_lengths = [1024, 2048, 4096, 8192] group_size = 32 params = [] - names = [] for input_length in input_lengths: for num_columns in range(1, 9): # 1 to 8 columns for num_channels in range(1, 3): # 1 or 2 channels @@ -30,26 +29,23 @@ def generate_test_params(extensive=False): # Only proceed if tile_size * total_cores == input_length (exact division) if tile_size * total_cores == input_length: - names.append( - f"dequant_{num_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append( - (input_length, num_columns, num_channels, tile_size, group_size) - ) - return params, names + name = f"dequant_{num_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_columns, + num_channels, + tile_size, + group_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -58,7 +54,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,group_size", - all_params, + get_params(), ) def test_dequant( input_length, num_aie_columns, num_channels, tile_size, group_size, aie_context diff --git a/iron/operators/elementwise_add/test.py b/iron/operators/elementwise_add/test.py index 9309ca9c..5794a2c4 100755 --- a/iron/operators/elementwise_add/test.py +++ b/iron/operators/elementwise_add/test.py @@ -12,36 +12,35 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + # Combine all lengths + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns if tile_size * num_aie_columns != input_length: continue - names.append( - f"eltwise_add_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names + name = f"eltwise_add_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -50,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_elementwise_add( input_length, num_aie_columns, num_channels, tile_size, aie_context diff --git a/iron/operators/elementwise_mul/test.py b/iron/operators/elementwise_mul/test.py index 4f7956ba..82b34a9b 100755 --- a/iron/operators/elementwise_mul/test.py +++ b/iron/operators/elementwise_mul/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,23 @@ def generate_test_params(extensive=False): tile_size = 4096 if tile_size * num_aie_columns != input_length: continue - names.append( - f"eltwise_mul_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names + name = f"eltwise_mul_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -52,7 +50,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_elementwise_mul( input_length, num_aie_columns, num_channels, tile_size, aie_context diff --git a/iron/operators/gelu/test.py b/iron/operators/gelu/test.py index d91a9e7a..69b4519d 100755 --- a/iron/operators/gelu/test.py +++ b/iron/operators/gelu/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels_choices = [1, 2] - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): for num_channels in num_channels_choices: @@ -28,26 +27,22 @@ def generate_test_params(extensive=False): tile_size = 8192 check_length = tile_size * total_cores if check_length == input_length: - names.append( - f"gelu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append( - (input_length, num_aie_columns, num_channels, tile_size) - ) - return params, names + name = f"gelu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -56,7 +51,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_gelu(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/gemm/test.py b/iron/operators/gemm/test.py index 6480aeff..b1dc8194 100755 --- a/iron/operators/gemm/test.py +++ b/iron/operators/gemm/test.py @@ -12,10 +12,10 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): # fmt: off - params = [ - # M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size, partition_N + # M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size, partition_N + regular_params = [ (2048, 2048, 2048, 1, False, False, 64, 64, 64, 0, 1), (2048, 2048, 2048, 2, True, False, 64, 64, 64, 0, 1), (2048, 2048, 2048, 8, True, True, 64, 64, 64, 0, 1), @@ -44,48 +44,43 @@ def generate_test_params(extensive=False): ] # fmt: on - if extensive: - params = extensive_params - - names = [] - for ( - M, - K, - N, - num_aie_columns, - b_col_maj, - c_col_maj, - m, - k, - n, - trace_size, - partition_N, - ) in params: - name = f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}cols" - if b_col_maj: - name += "_bcolmaj" - if c_col_maj: - name += "_ccolmaj" - if partition_N > 1: - name += f"_{partition_N}npart" - if trace_size > 0: - name += f"_{trace_size}trace" - names.append(name) - - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + + # Helper to generate name and append param + def add_params(param_list, is_extensive): + for p in param_list: + ( + M, + K, + N, + num_aie_columns, + b_col_maj, + c_col_maj, + m, + k, + n, + trace_size, + partition_N, + ) = p + + name = f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}cols" + if b_col_maj: + name += "_bcolmaj" + if c_col_maj: + name += "_ccolmaj" + if partition_N > 1: + name += f"_{partition_N}npart" + if trace_size > 0: + name += f"_{trace_size}trace" + + marks = [pytest.mark.extensive] if is_extensive else [] + + params.append(pytest.param(*p, id=name, marks=marks)) + + add_params(regular_params, is_extensive=False) + add_params(extensive_params, is_extensive=True) + + return params @pytest.mark.metrics( @@ -95,7 +90,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "M,K,N,num_aie_columns,b_col_maj,c_col_maj,m,k,n,trace_size,partition_N", - all_params, + get_params(), ) def test_gemm( M, diff --git a/iron/operators/gemv/test.py b/iron/operators/gemv/test.py index 2dd4a8e6..c26fb1f4 100755 --- a/iron/operators/gemv/test.py +++ b/iron/operators/gemv/test.py @@ -12,8 +12,8 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [ +def get_params(): + params_list = [ (128, 128, 1, 32, 128), (2048, 8192, 1, 1, 2048), (8192, 2048, 1, 4, 1024), @@ -24,24 +24,16 @@ def generate_test_params(extensive=False): (2048, 8192, 8, 1, 256), (8192, 2048, 8, 4, 1024), ] - names = [ - f"matrix_vector_mul_{M}x{K}_{tile_size_input}tsi_{tile_size_output}tso_{num_aie_columns}col" - for M, K, num_aie_columns, tile_size_input, tile_size_output in params - ] - return params, names - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) + params = [] + for p in params_list: + M, K, num_aie_columns, tile_size_input, tile_size_output = p + name = f"matrix_vector_mul_{M}x{K}_{tile_size_input}tsi_{tile_size_output}tso_{num_aie_columns}col" -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + # All tests are considered regular here as per original code structure + # (original code returned same list for both regular and extensive) + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( @@ -50,7 +42,7 @@ def generate_test_params(extensive=False): Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", ) @pytest.mark.parametrize( - "M,K,num_aie_columns,tile_size_input,tile_size_output", all_params + "M,K,num_aie_columns,tile_size_input,tile_size_output", get_params() ) def test_gemv(M, K, num_aie_columns, tile_size_input, tile_size_output, aie_context): golden_ref = generate_golden_reference(M=M, K=K) diff --git a/iron/operators/layer_norm/test.py b/iron/operators/layer_norm/test.py index 2b14641c..360da0a1 100755 --- a/iron/operators/layer_norm/test.py +++ b/iron/operators/layer_norm/test.py @@ -12,11 +12,10 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): for num_channels_layer in range(1, 3): # 1 or 2 @@ -26,26 +25,22 @@ def generate_test_params(extensive=False): tile_size = 8192 check_length = tile_size * total_cores if check_length == input_length: - names.append( - f"layer_norm_{num_aie_columns}_cols_{num_channels_layer}_channels_{input_length}_tile_{tile_size}" - ) - params.append( - (input_length, num_aie_columns, num_channels_layer, tile_size) - ) - return params, names + name = f"layer_norm_{num_aie_columns}_cols_{num_channels_layer}_channels_{input_length}_tile_{tile_size}" + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels_layer, + tile_size, + id=name, + marks=marks, + ) + ) + return params @pytest.mark.metrics( @@ -54,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_layer_norm( input_length, num_aie_columns, num_channels, tile_size, aie_context diff --git a/iron/operators/leaky_relu/test.py b/iron/operators/leaky_relu/test.py index cac577ad..6adb8d4d 100755 --- a/iron/operators/leaky_relu/test.py +++ b/iron/operators/leaky_relu/test.py @@ -12,24 +12,10 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): # Leaky ReLU is currently broken (#36); leave it untested params = [] - names = [] - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -38,7 +24,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,alpha", - all_params, + get_params(), ) def test_leaky_relu( input_length, num_aie_columns, num_channels, tile_size, alpha, aie_context diff --git a/iron/operators/mem_copy/test.py b/iron/operators/mem_copy/test.py index afd7f540..f6314e5b 100644 --- a/iron/operators/mem_copy/test.py +++ b/iron/operators/mem_copy/test.py @@ -12,12 +12,11 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - input_lengths = [2048] if not extensive else [1024, 2048, 4096, 8192] - bypass_modes = [False] if not extensive else [False, True] +def get_params(): + input_lengths = [1024, 2048, 4096, 8192] + bypass_modes = [False, True] params = [] - names = [] for input_length in input_lengths: for num_cores in range(1, 17): # 1 to 16 cores @@ -35,33 +34,24 @@ def generate_test_params(extensive=False): # Only proceed if tile_size * num_cores == input_length (exact division) if tile_size * num_cores == input_length: - names.append( - f"mem_copy_{num_cores}_cores_{num_channels}_chans_{input_length}_tile_{tile_size}_{str(bypass)}" - ) + name = f"mem_copy_{num_cores}_cores_{num_channels}_chans_{input_length}_tile_{tile_size}_{str(bypass)}" + + is_regular = input_length == 2048 and bypass == False + marks = [] if is_regular else [pytest.mark.extensive] + params.append( - ( + pytest.param( input_length, num_cores, num_channels, bypass, tile_size, + id=name, + marks=marks, ) ) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -70,7 +60,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_cores,num_channels,bypass,tile_size", - all_params, + get_params(), ) def test_mem_copy( input_length, num_cores, num_channels, bypass, tile_size, aie_context diff --git a/iron/operators/mha/test.py b/iron/operators/mha/test.py index 35c5087f..b1871e42 100755 --- a/iron/operators/mha/test.py +++ b/iron/operators/mha/test.py @@ -12,30 +12,21 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [(16384, 64, 1, 8)] +def get_params(): + params_list = [(16384, 64, 1, 8)] names = ["mha"] - return params, names - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + for p, name in zip(params_list, names): + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", all_params) +@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", get_params()) def test_mha(seq_len, dim, num_heads, num_pipelines, aie_context): golden_ref = generate_golden_reference( S_q=seq_len, diff --git a/iron/operators/relu/test.py b/iron/operators/relu/test.py index 3194c8c0..4bea4584 100755 --- a/iron/operators/relu/test.py +++ b/iron/operators/relu/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"relu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"relu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_relu(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/rms_norm/test.py b/iron/operators/rms_norm/test.py index be18b202..f7c183f9 100755 --- a/iron/operators/rms_norm/test.py +++ b/iron/operators/rms_norm/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for weighted in [False, True]: for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): @@ -37,37 +36,26 @@ def generate_test_params(extensive=False): check_length = tile_size * num_aie_columns if check_length == input_length: if not weighted: - names.append( - f"rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_tile_{tile_size}" - ) + name = f"rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_tile_{tile_size}" else: - names.append( - f"weighted_rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_weights_{tile_size}" - ) + name = f"weighted_rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_weights_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + params.append( - ( + pytest.param( input_length, num_aie_columns, num_channels_rms, tile_size, weighted, + id=name, + marks=marks, ) ) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -76,7 +64,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size,weighted", - all_params, + get_params(), ) def test_rms_norm( input_length, num_aie_columns, num_channels, tile_size, weighted, aie_context diff --git a/iron/operators/rope/test.py b/iron/operators/rope/test.py index 5a02edbc..d7156a7b 100755 --- a/iron/operators/rope/test.py +++ b/iron/operators/rope/test.py @@ -12,55 +12,49 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [] - names = [] - +def get_params(): num_aie_columns_options = [1, 2, 8] - if not extensive: - input_rows = [32] - input_cols = [512] - input_angle_rows = [8, 32] - method_types = [0] # 0: Two-halves method - else: - input_rows = [32, 64] - input_cols = [128] - input_angle_rows = [8, 16, 32] - method_types = [0, 1] # 0: Two-halves method, 1: interleaved method + # Combine all options + input_rows = [32, 64] + input_cols = [128, 512] + input_angle_rows = [8, 16, 32] + method_types = [0, 1] # 0: Two-halves method, 1: interleaved method + params = [] for num_aie_columns in num_aie_columns_options: for n_rows in input_rows: for n_angle_rows in input_angle_rows: for n_cols in input_cols: for method_type in method_types: - names.append( - f"rope_{num_aie_columns}c_{n_rows}rows_{n_cols}cols_{n_angle_rows}arows_{method_type}m" + name = f"rope_{num_aie_columns}c_{n_rows}rows_{n_cols}cols_{n_angle_rows}arows_{method_type}m" + + is_regular = ( + n_rows == 32 + and n_cols == 512 + and n_angle_rows in [8, 32] + and method_type == 0 ) + + is_extensive_valid = n_cols == 128 + + if not is_regular and not is_extensive_valid: + continue + + marks = [] if is_regular else [pytest.mark.extensive] + params.append( - ( + pytest.param( n_rows, n_cols, n_angle_rows, num_aie_columns, method_type, + id=name, + marks=marks, ) ) - - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -69,7 +63,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "rows,cols,angle_rows,aie_columns,method_type", - all_params, + get_params(), ) def test_rope(rows, cols, angle_rows, aie_columns, method_type, aie_context): golden_ref = generate_golden_reference( diff --git a/iron/operators/sigmoid/test.py b/iron/operators/sigmoid/test.py index 1dc5b99d..641fca96 100755 --- a/iron/operators/sigmoid/test.py +++ b/iron/operators/sigmoid/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"sigmoid_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"sigmoid_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_sigmoid(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/silu/test.py b/iron/operators/silu/test.py index e211d55b..6eb22f20 100755 --- a/iron/operators/silu/test.py +++ b/iron/operators/silu/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"silu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"silu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_silu(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/softmax/test.py b/iron/operators/softmax/test.py index d965ab93..093610a7 100755 --- a/iron/operators/softmax/test.py +++ b/iron/operators/softmax/test.py @@ -30,37 +30,27 @@ def get_optimal_columns_channels(input_length, tile_size): return 2, 2 # Default fallback -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 2 - input_lengths = [32768] if not extensive else [] + input_lengths = [32768] tile_sizes = [1024, 512, 2048] params = [] - names = [] for input_length in input_lengths: for tile_size in tile_sizes: optimal_columns, optimal_channels = get_optimal_columns_channels( input_length, tile_size ) - names.append( - f"softmax_{optimal_columns}_cols_{optimal_channels}_channels_{input_length}_tile_{tile_size}" - ) - params.append((input_length, optimal_columns, optimal_channels, tile_size)) - return params, names - + name = f"softmax_{optimal_columns}_cols_{optimal_channels}_channels_{input_length}_tile_{tile_size}" -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + # All tests are regular as extensive list was empty in original code + params.append( + pytest.param( + input_length, optimal_columns, optimal_channels, tile_size, id=name + ) + ) + return params @pytest.mark.metrics( @@ -69,7 +59,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_softmax(input_length, num_aie_columns, num_channels, tile_size, aie_context): diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index 19f20a66..8d4a51d2 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -15,30 +15,22 @@ from iron.common.test_utils import verify_buffer -def generate_test_params(extensive=False): - params = [(2048, 2048)] - names = [f"swiglu_decode_1x{emb}x{hid}" for emb, hid in params] - return params, names +def get_params(): + params_list = [(2048, 2048)] - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + for p in params_list: + emb, hid = p + name = f"swiglu_decode_1x{emb}x{hid}" + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("embedding_dim,hidden_dim", all_params) +@pytest.mark.parametrize("embedding_dim,hidden_dim", get_params()) def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): golden_ref = generate_golden_reference(M=1, K=embedding_dim, N=hidden_dim) @@ -49,12 +41,6 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): operator.weights_2 = golden_ref["w_up"].T operator.weights_3 = golden_ref["w_down"].T - # In the following, some buffers are commented out. - # Because this operator calls multiple kernels in sequence, rounding errors due to the smaller bf16 data type accumulate, which can cause it to fail verification. - # So, instead of verifying the final output buffers against the float32-calculated reference, we calculate another reference for the final output: - # This reference is based on the previous intermediate result read back from the AIE operator, "resetting" the accumulated error to zero. - # Note that the previous intermediate result _is_ still verified up to the given tolerance. - operator.compile() op_func = operator.get_callable() diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index cab789b1..10df9243 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -15,31 +15,23 @@ from iron.common.test_utils import verify_buffer -def generate_test_params(extensive=False): +def get_params(): # This operation is currently untested except for the integrated llama application tests. - params = [(256, 2048, 2048, False)] - names = [f"swiglu_prefill_256x{emb}x{hid}" for _, emb, hid, _ in params] - return params, names + params_list = [(256, 2048, 2048, False)] - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + params = [] + for p in params_list: + _, emb, hid, _ = p + name = f"swiglu_prefill_256x{emb}x{hid}" + params.append(pytest.param(*p, id=name)) + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("seq_len,embedding_dim,hidden_dim,prio_accuracy", all_params) +@pytest.mark.parametrize("seq_len,embedding_dim,hidden_dim,prio_accuracy", get_params()) def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_context): golden_ref = generate_golden_reference(M=seq_len, K=embedding_dim, N=hidden_dim) diff --git a/iron/operators/tanh/test.py b/iron/operators/tanh/test.py index f9986bb3..0a50b183 100755 --- a/iron/operators/tanh/test.py +++ b/iron/operators/tanh/test.py @@ -12,13 +12,12 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): +def get_params(): max_aie_columns = 8 num_channels = 1 # 1 channel for 1 input - input_lengths = [2048] if not extensive else [1024, 4096, 8192] + input_lengths = [1024, 2048, 4096, 8192] params = [] - names = [] for input_length in input_lengths: for num_aie_columns in range(1, max_aie_columns + 1): tile_size = input_length // num_aie_columns @@ -26,24 +25,22 @@ def generate_test_params(extensive=False): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - names.append( - f"tanh_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + name = f"tanh_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + + is_regular = input_length == 2048 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + input_length, + num_aie_columns, + num_channels, + tile_size, + id=name, + marks=marks, + ) ) - params.append((input_length, num_aie_columns, num_channels, tile_size)) - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) - -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( @@ -52,7 +49,7 @@ def generate_test_params(extensive=False): ) @pytest.mark.parametrize( "input_length,num_aie_columns,num_channels,tile_size", - all_params, + get_params(), ) def test_tanh(input_length, num_aie_columns, num_channels, tile_size, aie_context): golden_ref = generate_golden_reference(input_length=input_length) diff --git a/iron/operators/transpose/test.py b/iron/operators/transpose/test.py index 8f0d9981..00cf562b 100755 --- a/iron/operators/transpose/test.py +++ b/iron/operators/transpose/test.py @@ -12,16 +12,15 @@ from iron.common.test_utils import run_test -def generate_test_params(extensive=False): - params = [] - names = [] +def get_params(): max_aie_columns = 8 - input_lengths = [2048] if not extensive else [64, 2048] - n_list = [64] if not extensive else [64, 128, 256, 512] + input_lengths = [64, 2048] + n_list = [64, 128, 256, 512] s_list = [8] m = 64 n = 64 + params = [] for M in input_lengths: for N in n_list: for s in s_list: @@ -37,32 +36,33 @@ def generate_test_params(extensive=False): length = M * N if check_length != length: continue - names.append( - f"transpose_{M}_M_{N}_N_{num_aie_columns}_cols_{num_channels}_channels_{m}_m_{n}_n_{s}_s" + name = f"transpose_{M}_M_{N}_N_{num_aie_columns}_cols_{num_channels}_channels_{m}_m_{n}_n_{s}_s" + + is_regular = M == 2048 and N == 64 + marks = [] if is_regular else [pytest.mark.extensive] + + params.append( + pytest.param( + M, + N, + num_aie_columns, + num_channels, + m, + n, + s, + id=name, + marks=marks, + ) ) - params.append((M, N, num_aie_columns, num_channels, m, n, s)) - - return params, names - - -regular_params, regular_names = generate_test_params(extensive=False) -extensive_params, extensive_names = generate_test_params(extensive=True) -# Combine params with marks - extensive params get pytest.mark.extensive -all_params = [ - pytest.param(*params, id=name) - for params, name in zip(regular_params, regular_names) -] + [ - pytest.param(*params, marks=pytest.mark.extensive, id=name) - for params, name in zip(extensive_params, extensive_names) -] + return params @pytest.mark.metrics( Latency=r"Latency \(us\): (?P[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s", all_params) +@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s", get_params()) def test_transpose(M, N, aie_columns, channels, m, n, s, aie_context): golden_ref = generate_golden_reference(rows=M, cols=N) From a8caa1c82de0e2a55ed71fc3f09f762176e8caa1 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 13:37:28 -0700 Subject: [PATCH 92/99] try to cleanup pytest code --- conftest.py | 5 +++++ iron/operators/axpy/test.py | 3 --- iron/operators/dequant/test.py | 3 --- iron/operators/elementwise_add/test.py | 3 --- iron/operators/elementwise_mul/test.py | 3 --- iron/operators/gelu/test.py | 3 --- iron/operators/gemm/test.py | 27 +------------------------- iron/operators/gemv/test.py | 5 +---- iron/operators/layer_norm/test.py | 3 --- iron/operators/mem_copy/test.py | 3 --- iron/operators/mha/test.py | 5 ++--- iron/operators/relu/test.py | 3 --- iron/operators/rms_norm/test.py | 6 ------ iron/operators/rope/test.py | 3 --- iron/operators/sigmoid/test.py | 3 --- iron/operators/silu/test.py | 3 --- iron/operators/softmax/test.py | 5 +---- iron/operators/swiglu_decode/test.py | 4 +--- iron/operators/swiglu_prefill/test.py | 4 +--- iron/operators/tanh/test.py | 3 --- iron/operators/transpose/test.py | 2 -- 21 files changed, 12 insertions(+), 87 deletions(-) diff --git a/conftest.py b/conftest.py index 1a3c0e89..e9269f09 100644 --- a/conftest.py +++ b/conftest.py @@ -168,3 +168,8 @@ def pytest_generate_tests(metafunc): if iterations > 1: metafunc.fixturenames.append("_iteration") metafunc.parametrize("_iteration", range(iterations), ids=lambda i: f"iter{i}") + + +def pytest_make_parametrize_id(config, val, argname): + """Automatically generate test IDs with parameter names""" + return f"{argname}_{val}" diff --git a/iron/operators/axpy/test.py b/iron/operators/axpy/test.py index b37fabe2..8fc84ef0 100755 --- a/iron/operators/axpy/test.py +++ b/iron/operators/axpy/test.py @@ -25,8 +25,6 @@ def get_params(): if tile_size * num_aie_columns != input_length: continue for scalar in scalar_factors: - name = f"axpy_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}_{scalar}" - # Determine if this is a regular test case is_regular = input_length == 2048 and scalar == 3.0 marks = [] if is_regular else [pytest.mark.extensive] @@ -38,7 +36,6 @@ def get_params(): num_channels, tile_size, scalar, - id=name, marks=marks, ) ) diff --git a/iron/operators/dequant/test.py b/iron/operators/dequant/test.py index 4ab904c0..a4678199 100644 --- a/iron/operators/dequant/test.py +++ b/iron/operators/dequant/test.py @@ -29,8 +29,6 @@ def get_params(): # Only proceed if tile_size * total_cores == input_length (exact division) if tile_size * total_cores == input_length: - name = f"dequant_{num_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -41,7 +39,6 @@ def get_params(): num_channels, tile_size, group_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/elementwise_add/test.py b/iron/operators/elementwise_add/test.py index 5794a2c4..87cb5c1f 100755 --- a/iron/operators/elementwise_add/test.py +++ b/iron/operators/elementwise_add/test.py @@ -25,8 +25,6 @@ def get_params(): if tile_size * num_aie_columns != input_length: continue - name = f"eltwise_add_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -36,7 +34,6 @@ def get_params(): num_aie_columns, num_channels, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/elementwise_mul/test.py b/iron/operators/elementwise_mul/test.py index 82b34a9b..163ff0e4 100755 --- a/iron/operators/elementwise_mul/test.py +++ b/iron/operators/elementwise_mul/test.py @@ -26,8 +26,6 @@ def get_params(): if tile_size * num_aie_columns != input_length: continue - name = f"eltwise_mul_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -37,7 +35,6 @@ def get_params(): num_aie_columns, num_channels, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/gelu/test.py b/iron/operators/gelu/test.py index 69b4519d..f74a2e73 100755 --- a/iron/operators/gelu/test.py +++ b/iron/operators/gelu/test.py @@ -27,8 +27,6 @@ def get_params(): tile_size = 8192 check_length = tile_size * total_cores if check_length == input_length: - name = f"gelu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -38,7 +36,6 @@ def get_params(): num_aie_columns, num_channels, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/gemm/test.py b/iron/operators/gemm/test.py index b1dc8194..b9da6f10 100755 --- a/iron/operators/gemm/test.py +++ b/iron/operators/gemm/test.py @@ -49,33 +49,8 @@ def get_params(): # Helper to generate name and append param def add_params(param_list, is_extensive): for p in param_list: - ( - M, - K, - N, - num_aie_columns, - b_col_maj, - c_col_maj, - m, - k, - n, - trace_size, - partition_N, - ) = p - - name = f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}cols" - if b_col_maj: - name += "_bcolmaj" - if c_col_maj: - name += "_ccolmaj" - if partition_N > 1: - name += f"_{partition_N}npart" - if trace_size > 0: - name += f"_{trace_size}trace" - marks = [pytest.mark.extensive] if is_extensive else [] - - params.append(pytest.param(*p, id=name, marks=marks)) + params.append(pytest.param(*p, marks=marks)) add_params(regular_params, is_extensive=False) add_params(extensive_params, is_extensive=True) diff --git a/iron/operators/gemv/test.py b/iron/operators/gemv/test.py index c26fb1f4..493d51c0 100755 --- a/iron/operators/gemv/test.py +++ b/iron/operators/gemv/test.py @@ -27,12 +27,9 @@ def get_params(): params = [] for p in params_list: - M, K, num_aie_columns, tile_size_input, tile_size_output = p - name = f"matrix_vector_mul_{M}x{K}_{tile_size_input}tsi_{tile_size_output}tso_{num_aie_columns}col" - # All tests are considered regular here as per original code structure # (original code returned same list for both regular and extensive) - params.append(pytest.param(*p, id=name)) + params.append(pytest.param(*p)) return params diff --git a/iron/operators/layer_norm/test.py b/iron/operators/layer_norm/test.py index 360da0a1..57fffea4 100755 --- a/iron/operators/layer_norm/test.py +++ b/iron/operators/layer_norm/test.py @@ -25,8 +25,6 @@ def get_params(): tile_size = 8192 check_length = tile_size * total_cores if check_length == input_length: - name = f"layer_norm_{num_aie_columns}_cols_{num_channels_layer}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -36,7 +34,6 @@ def get_params(): num_aie_columns, num_channels_layer, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/mem_copy/test.py b/iron/operators/mem_copy/test.py index f6314e5b..d65028c2 100644 --- a/iron/operators/mem_copy/test.py +++ b/iron/operators/mem_copy/test.py @@ -34,8 +34,6 @@ def get_params(): # Only proceed if tile_size * num_cores == input_length (exact division) if tile_size * num_cores == input_length: - name = f"mem_copy_{num_cores}_cores_{num_channels}_chans_{input_length}_tile_{tile_size}_{str(bypass)}" - is_regular = input_length == 2048 and bypass == False marks = [] if is_regular else [pytest.mark.extensive] @@ -46,7 +44,6 @@ def get_params(): num_channels, bypass, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/mha/test.py b/iron/operators/mha/test.py index b1871e42..ae0275cf 100755 --- a/iron/operators/mha/test.py +++ b/iron/operators/mha/test.py @@ -14,11 +14,10 @@ def get_params(): params_list = [(16384, 64, 1, 8)] - names = ["mha"] params = [] - for p, name in zip(params_list, names): - params.append(pytest.param(*p, id=name)) + for p in params_list: + params.append(pytest.param(*p)) return params diff --git a/iron/operators/relu/test.py b/iron/operators/relu/test.py index 4bea4584..f3236c26 100755 --- a/iron/operators/relu/test.py +++ b/iron/operators/relu/test.py @@ -25,8 +25,6 @@ def get_params(): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - name = f"relu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -36,7 +34,6 @@ def get_params(): num_aie_columns, num_channels, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/rms_norm/test.py b/iron/operators/rms_norm/test.py index f7c183f9..7f736021 100755 --- a/iron/operators/rms_norm/test.py +++ b/iron/operators/rms_norm/test.py @@ -35,11 +35,6 @@ def get_params(): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - if not weighted: - name = f"rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_tile_{tile_size}" - else: - name = f"weighted_rms_norm_{num_aie_columns}_cols_{num_channels_rms}_channels_{input_length}_weights_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -50,7 +45,6 @@ def get_params(): num_channels_rms, tile_size, weighted, - id=name, marks=marks, ) ) diff --git a/iron/operators/rope/test.py b/iron/operators/rope/test.py index d7156a7b..6459e28a 100755 --- a/iron/operators/rope/test.py +++ b/iron/operators/rope/test.py @@ -27,8 +27,6 @@ def get_params(): for n_angle_rows in input_angle_rows: for n_cols in input_cols: for method_type in method_types: - name = f"rope_{num_aie_columns}c_{n_rows}rows_{n_cols}cols_{n_angle_rows}arows_{method_type}m" - is_regular = ( n_rows == 32 and n_cols == 512 @@ -50,7 +48,6 @@ def get_params(): n_angle_rows, num_aie_columns, method_type, - id=name, marks=marks, ) ) diff --git a/iron/operators/sigmoid/test.py b/iron/operators/sigmoid/test.py index 641fca96..a9b6b596 100755 --- a/iron/operators/sigmoid/test.py +++ b/iron/operators/sigmoid/test.py @@ -25,8 +25,6 @@ def get_params(): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - name = f"sigmoid_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -36,7 +34,6 @@ def get_params(): num_aie_columns, num_channels, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/silu/test.py b/iron/operators/silu/test.py index 6eb22f20..267f7669 100755 --- a/iron/operators/silu/test.py +++ b/iron/operators/silu/test.py @@ -25,8 +25,6 @@ def get_params(): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - name = f"silu_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -36,7 +34,6 @@ def get_params(): num_aie_columns, num_channels, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/softmax/test.py b/iron/operators/softmax/test.py index 093610a7..a75c37af 100755 --- a/iron/operators/softmax/test.py +++ b/iron/operators/softmax/test.py @@ -42,13 +42,10 @@ def get_params(): optimal_columns, optimal_channels = get_optimal_columns_channels( input_length, tile_size ) - name = f"softmax_{optimal_columns}_cols_{optimal_channels}_channels_{input_length}_tile_{tile_size}" # All tests are regular as extensive list was empty in original code params.append( - pytest.param( - input_length, optimal_columns, optimal_channels, tile_size, id=name - ) + pytest.param(input_length, optimal_columns, optimal_channels, tile_size) ) return params diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index 8d4a51d2..65af7cf2 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -20,9 +20,7 @@ def get_params(): params = [] for p in params_list: - emb, hid = p - name = f"swiglu_decode_1x{emb}x{hid}" - params.append(pytest.param(*p, id=name)) + params.append(pytest.param(*p)) return params diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 10df9243..8bb8d14e 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -21,9 +21,7 @@ def get_params(): params = [] for p in params_list: - _, emb, hid, _ = p - name = f"swiglu_prefill_256x{emb}x{hid}" - params.append(pytest.param(*p, id=name)) + params.append(pytest.param(*p)) return params diff --git a/iron/operators/tanh/test.py b/iron/operators/tanh/test.py index 0a50b183..6888474d 100755 --- a/iron/operators/tanh/test.py +++ b/iron/operators/tanh/test.py @@ -25,8 +25,6 @@ def get_params(): tile_size = 4096 check_length = tile_size * num_aie_columns if check_length == input_length: - name = f"tanh_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" - is_regular = input_length == 2048 marks = [] if is_regular else [pytest.mark.extensive] @@ -36,7 +34,6 @@ def get_params(): num_aie_columns, num_channels, tile_size, - id=name, marks=marks, ) ) diff --git a/iron/operators/transpose/test.py b/iron/operators/transpose/test.py index 00cf562b..f151f8df 100755 --- a/iron/operators/transpose/test.py +++ b/iron/operators/transpose/test.py @@ -36,7 +36,6 @@ def get_params(): length = M * N if check_length != length: continue - name = f"transpose_{M}_M_{N}_N_{num_aie_columns}_cols_{num_channels}_channels_{m}_m_{n}_n_{s}_s" is_regular = M == 2048 and N == 64 marks = [] if is_regular else [pytest.mark.extensive] @@ -50,7 +49,6 @@ def get_params(): m, n, s, - id=name, marks=marks, ) ) From 7aca3dc52ed09d163a4f881a531fdc2dce7a047f Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 13:46:49 -0700 Subject: [PATCH 93/99] remove unused code --- iron/common/aie_device_manager.py | 53 ----- iron/common/fusion.py | 360 ------------------------------ 2 files changed, 413 deletions(-) delete mode 100644 iron/common/aie_device_manager.py delete mode 100644 iron/common/fusion.py diff --git a/iron/common/aie_device_manager.py b/iron/common/aie_device_manager.py deleted file mode 100644 index fda4d0cb..00000000 --- a/iron/common/aie_device_manager.py +++ /dev/null @@ -1,53 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Global AIE Device Manager for resource sharing and cleanup -""" - -import logging -import os -import sys -from pathlib import Path -from typing import Dict, Optional, Any -import pyxrt -from aie.utils import DefaultNPURuntime -from aie.utils.npukernel import NPUKernel -from aie.iron.device import NPU1, NPU2 - - -class AIEDeviceManager: - """Singleton manager for AIE XRT resources""" - - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - self.runtime = DefaultNPURuntime - # Expose device for AIEContext buffer allocation - # Accessing protected member _device as AIEContext needs pyxrt.device - self.device = self.runtime._device - self.device_type = self.runtime.device() - - def get_kernel_handle(self, xclbin_path: str, kernel_name: str, insts_path: str): - """Get kernel handle using HostRuntime""" - npu_kernel = NPUKernel( - xclbin_path=xclbin_path, insts_path=insts_path, kernel_name=kernel_name - ) - return self.runtime.load(npu_kernel) - - def device_str(self) -> str: - return self.device_type.resolve().name - - def cleanup(self): - """Clean up all XRT resources""" - # HostRuntime handles cleanup - pass - - def reset(self): - """Reset the device manager (for debugging)""" - pass diff --git a/iron/common/fusion.py b/iron/common/fusion.py deleted file mode 100644 index 132a4b3c..00000000 --- a/iron/common/fusion.py +++ /dev/null @@ -1,360 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import ml_dtypes -import pyxrt -import ctypes -from . import compilation as comp -from .base import AIEOperatorBase, MLIROperator, AIEBuffer -from .device_manager import AIEDeviceManager - -# Fused Operator -# ########################################################################## - - -class FusedMLIROperator(AIEOperatorBase): - """Operator that fuses multiple MLIROperators into one.""" - - def __init__( - self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs - ): - assert all( - isinstance(op, MLIROperator) and all(isinstance(buf, str) for buf in bufs) - for op, *bufs in runlist - ) - self.runlist = runlist - self.name = name - self.input_args = input_args - self.output_args = output_args - self.explicit_buffer_sizes = ( - buffer_sizes or {} - ) # Optional dict: buffer_name -> size_in_bytes - self.kernel_archive = "kernels.a" - super().__init__(*args, **kwargs) - - def get_operator_name(self): - return self.name - - def get_kernel_artifacts(self): - """Collect all kernel artifacts from child operators.""" - kernel_artifacts = [] - unique_operators = [] - for op, *_ in self.runlist: - if op not in unique_operators: - unique_operators.append(op) - for idx, op in enumerate(unique_operators): - objs = op.get_kernel_artifacts() - for obj in objs: - obj.filename = f"op{idx}_{obj.filename}" - obj.prefix_symbols = f"op{idx}_" - kernel_artifacts.extend(objs) - return kernel_artifacts - - def get_mlir_artifact(self): - # Build operator_mlir_map: {op_name -> PythonGeneratedMLIRArtifact} - operator_mlir_map = {} - mlir_dependencies = [] - comp_runlist = [] - op_names = {} # op -> op_name - - unique_operators = [] - for op, *_ in self.runlist: - if op not in unique_operators: - unique_operators.append(op) - for idx, op in enumerate(unique_operators): - mlir_artifact = op.get_mlir_artifact() - if len(op.get_kernel_artifacts()) > 0: - # FIXME: currently hard-coding that the design will accept this argument as an input if it uses kernels - # Also not handling name collisions of kernels with the same name - mlir_artifact.callback_kwargs["kernel_archive"] = self.kernel_archive - mlir_artifact.callback_kwargs["func_prefix"] = f"op{idx}_" - op_name = f"op{idx}_{op.__class__.__name__}" - op_names[op] = op_name - operator_mlir_map[op_name] = mlir_artifact - - for op, *bufs in self.runlist: - comp_runlist.append((op_names[op], *bufs)) - - # Calculate buffer layout: {buffer_name -> (type, offset, length)} - self.subbuffer_layout, self.buffer_sizes, self.slice_info = ( - self._calculate_buffer_layout() - ) - - filename = self.get_operator_name() + "_fused.mlir" - fused_artifact = comp.FusedMLIRSource( - filename, - operator_mlir_map=operator_mlir_map, - runlist=comp_runlist, - subbuffer_layout=self.subbuffer_layout, - buffer_sizes=self.buffer_sizes, - slice_info=self.slice_info, - ) - - return fused_artifact - - def _calculate_buffer_layout(self): - args = {} # base_buffer_name -> args_spec - sliced_buffers = ( - {} - ) # full_buffer_name (with slice) -> (base_name, start, end, args_spec) - - # Collect all buffer specs from operators - for op, *bufs in self.runlist: - args_specs = op.get_arg_spec() - assert len(args_specs) == len( - bufs - ), "Number of buffers must match operator argument specification" - for i, buf_name in enumerate(bufs): - args_spec = args_specs[i] - - # Parse slice notation: "buffer_name[start:end]" - if "[" in buf_name and buf_name.endswith("]"): - base_name = buf_name[: buf_name.index("[")] - slice_part = buf_name[buf_name.index("[") + 1 : -1] - start, end = map(int, slice_part.split(":")) - sliced_buffers[buf_name] = (base_name, start, end, args_spec) - # Track that base buffer exists (size will be set later) - if ( - base_name not in args - and base_name not in self.explicit_buffer_sizes - ): - raise ValueError( - f"Sliced buffer '{buf_name}' requires explicit size for base buffer '{base_name}' in buffer_sizes parameter" - ) - else: - # Regular buffer (no slice) - if buf_name not in args: - args[buf_name] = args_spec - else: - assert np.prod(args[buf_name].shape) == np.prod( - args_spec.shape - ), f"Buffer {buf_name} has conflicting sizes between operators" - - # Verify all input/output args are present (either as regular or sliced buffers) - all_buffer_names = set(args.keys()) | set(sliced_buffers.keys()) - for arg in self.input_args: - # Check if it's a base buffer name in explicit_buffer_sizes - if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: - raise AssertionError( - f"Input argument {arg} not found in runlist buffers" - ) - for arg in self.output_args: - if arg not in all_buffer_names and arg not in self.explicit_buffer_sizes: - raise AssertionError( - f"Output argument {arg} not found in runlist buffers" - ) - - # Determine buffer types and create layout - subbuffer_layout = {} - slice_info = {} # full_buffer_name -> (base_name, start, end) - - def add_buffers(buffer_type, args_list): - offset = 0 - for arg in args_list: - if arg in self.explicit_buffer_sizes: - # Explicit size specified - this is a parent buffer for slices - length = self.explicit_buffer_sizes[arg] - subbuffer_layout[arg] = (buffer_type, offset, length) - offset += length - elif arg in args: - # Regular buffer with inferred size - arg_spec = args[arg] - length = int( - np.prod(arg_spec.shape) * np.dtype(arg_spec.dtype).itemsize - ) - subbuffer_layout[arg] = (buffer_type, offset, length) - offset += length - # Note: sliced buffers are handled separately, not in args_list - return offset # == total length - - # Add sliced buffer entries to layout (they reference parent buffers) - for buf_name, (base_name, start, end, args_spec) in sliced_buffers.items(): - slice_info[buf_name] = (base_name, start, end) - - input_buffer_size = add_buffers("input", self.input_args) - output_buffer_size = add_buffers("output", self.output_args) - scratch_args = [ - arg - for arg in args - if arg not in self.input_args and arg not in self.output_args - ] - # Also include explicit buffers that are only used for slicing - for explicit_buf in self.explicit_buffer_sizes: - if ( - explicit_buf not in self.input_args - and explicit_buf not in self.output_args - and explicit_buf not in scratch_args - ): - scratch_args.append(explicit_buf) - scratch_buffer_size = add_buffers("scratch", scratch_args) - - buffer_sizes = (input_buffer_size, output_buffer_size, scratch_buffer_size) - return subbuffer_layout, buffer_sizes, slice_info - - def set_up_artifacts(self): - operator_name = self.get_operator_name() - mlir_artifact = self.get_mlir_artifact() - kernel_objects = self.get_kernel_artifacts() - kernel_dep = ( - [ - comp.KernelArchiveArtifact( - self.kernel_archive, - dependencies=kernel_objects, - ) - ] - if kernel_objects - else [] - ) - full_elf_artifact = comp.FullElfArtifact( - f"{operator_name}.elf", - mlir_input=mlir_artifact, - dependencies=[mlir_artifact] + kernel_dep, - ) - self.add_artifacts([full_elf_artifact]) - - def get_arg_spec(self): - pass - - def get_callable(self): - return FusedFullELFCallable(self) - - def get_layout_for_buffer(self, buffer_name): - if buffer_name in self.slice_info: - buf_name, start, end = self.slice_info[buffer_name] - buf_type, parent_start, parent_end = self.get_layout_for_buffer(buf_name) - return buf_type, parent_start + start, parent_start + end - - buf_type, offset, length = self.subbuffer_layout[buffer_name] - return buf_type, offset, length - - -def load_elf(op): - assert isinstance(op.artifacts[0], comp.FullElfArtifact) - elf_data = None - with open(op.artifacts[0].filename, "rb") as f: - elf_data = np.frombuffer(f.read(), dtype=np.uint32) - return elf_data - - -def patch_elf(elf_data, patches): - for i, patch in patches.items(): - val, mask = patch - elf_data[i] = (elf_data[i] & ~mask) | (val & mask) - return elf_data - - -class FullELFCallable: - def __init__( - self, - elf_data, - device_name="main", - sequence_name="sequence", - device_manager=None, - ): - self.device_name = device_name - self.sequence_name = sequence_name - self.device_manager = device_manager or AIEDeviceManager() - self.reload_elf(elf_data) - - def __call__(self, *args): - run = pyxrt.run(self.xrt_kernel) - for i, arg in enumerate(args): - assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" - run.set_arg(i, arg) - run.start() - ret_code = run.wait() - if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: - raise RuntimeError(f"Kernel execution failed with return code {ret_code}") - - def reload_elf(self, elf_data): - # Create a PyCapsule from the numpy array pointer for pybind11 - elf_data_u8 = elf_data.view(dtype=np.uint8) - ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_void_p, - ] - capsule = ctypes.pythonapi.PyCapsule_New(elf_data_u8.ctypes.data, None, None) - xrt_elf = pyxrt.elf(capsule, elf_data.nbytes) - xrt_context = pyxrt.hw_context(self.device_manager.device, xrt_elf) - self.xrt_kernel = pyxrt.ext.kernel( - xrt_context, f"{self.device_name}:{self.sequence_name}" - ) - - -class FusedFullELFCallable(FullELFCallable): - def __init__(self, op, elf_data=None, device_manager=None): - if elf_data is None: - elf_data = load_elf(op) - super().__init__(elf_data, device_manager=device_manager) - - self.op = op - input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes - itemsize = np.dtype(ml_dtypes.bfloat16).itemsize - - self.input_buffer = AIEBuffer( - shape=(max(input_buffer_size, itemsize) // itemsize,), - dtype=ml_dtypes.bfloat16, - ) - - self.output_buffer = AIEBuffer( - shape=(max(output_buffer_size, itemsize) // itemsize,), - dtype=ml_dtypes.bfloat16, - ) - - self.scratch_buffer = AIEBuffer( - shape=(max(scratch_buffer_size, itemsize) // itemsize,), - dtype=ml_dtypes.bfloat16, - ) - - self._buffer_cache = {} - - def get_buffer(self, buffer_name): - # Return cached buffer if already allocated - if buffer_name in self._buffer_cache: - return self._buffer_cache[buffer_name] - - buf_type, offset, length = self.op.get_layout_for_buffer(buffer_name) - - # Select the appropriate main buffer - if buf_type == "input": - main_buffer = self.input_buffer - elif buf_type == "output": - main_buffer = self.output_buffer - elif buf_type == "scratch": - main_buffer = self.scratch_buffer - else: - raise ValueError( - f"Unknown buffer type '{buf_type}' for buffer '{buffer_name}'" - ) - - if main_buffer is None: - raise RuntimeError(f"Main buffer for type '{buf_type}' is not allocated") - - # Convert byte offset/length to element offset/length - itemsize = np.dtype(ml_dtypes.bfloat16).itemsize - offset_elements = offset // itemsize - length_elements = length // itemsize - - # Create subbuffer with appropriate shape - sub_buffer = main_buffer.subbuffer( - length=length_elements, - offset=offset_elements, - shape=(length_elements,), - dtype=ml_dtypes.bfloat16, - ) - - # Cache and return - self._buffer_cache[buffer_name] = sub_buffer - return sub_buffer - - def __call__(self): - self.input_buffer.to("npu") - self.output_buffer.to("npu") - super().__call__( - self.input_buffer.bo if self.input_buffer else None, - self.output_buffer.bo if self.output_buffer else None, - self.scratch_buffer.bo if self.scratch_buffer else None, - ) From 782576913b7544f8086a9b2c82e6d460770ba95b Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 14:00:09 -0700 Subject: [PATCH 94/99] remove more unused things --- iron/common/base.py | 18 -- iron/common/compilation/__init__.py | 1 - iron/common/compilation/fusion.py | 247 ---------------------------- iron/common/context.py | 1 - 4 files changed, 267 deletions(-) delete mode 100644 iron/common/compilation/fusion.py diff --git a/iron/common/base.py b/iron/common/base.py index 8da910f8..746cb43f 100644 --- a/iron/common/base.py +++ b/iron/common/base.py @@ -307,24 +307,6 @@ def __call__(self, *buffers): raise RuntimeError(f"Kernel did not complete correctly: {ret_code}") -class PatchableSingleXclbinCallable(SingleXclbinCallable): - def __init__( - self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None - ): - super().__init__( - xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager - ) - self.baseline_instructions = self.insts_buffer.view_as_np().copy() - - def patch(self, patches): - """Apply patches with masking: dict of {position: (value, mask)}.""" - insts = self.insts_buffer.view_as_np() - insts[:] = self.baseline_instructions - for pos, (val, mask) in patches.items(): - insts[pos] = (np.int64(insts[pos]) & ~mask) | (val & mask) - self.insts_buffer.to("npu") - - class CompositeCallable: """Callable for executing a sequence of sub-operators""" diff --git a/iron/common/compilation/__init__.py b/iron/common/compilation/__init__.py index 405df6b0..20823d02 100644 --- a/iron/common/compilation/__init__.py +++ b/iron/common/compilation/__init__.py @@ -2,4 +2,3 @@ # SPDX-License-Identifier: Apache-2.0 from .base import * -from .fusion import * diff --git a/iron/common/compilation/fusion.py b/iron/common/compilation/fusion.py deleted file mode 100644 index ea1d47e2..00000000 --- a/iron/common/compilation/fusion.py +++ /dev/null @@ -1,247 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Temporal fusion of multiple MLIR modules into one module with multiple devices and a main runtime sequence that calls into them. -""" - -import numpy as np -import importlib.util -from pathlib import Path -from aie import ir -from aie.dialects import aie, aiex, memref -from aie.extras.context import mlir_mod_ctx -import ml_dtypes - -from . import ( - CompilationArtifact, - CompilationRule, - CompilationCommand, - PythonCallbackCompilationCommand, - SourceArtifact, - PythonGeneratedMLIRArtifact, -) - -# Compilation Artifacts -# ########################################################################## - - -class FusedMLIRSource(CompilationArtifact): - def __init__( - self, - filename, - operator_mlir_map, - runlist, - subbuffer_layout, - buffer_sizes, - slice_info=None, - ): - dependencies = list(operator_mlir_map.values()) - super().__init__(filename, dependencies) - self.operator_mlir_map = operator_mlir_map - self.runlist = runlist - self.subbuffer_layout = subbuffer_layout - self.buffer_sizes = buffer_sizes - self.slice_info = slice_info or {} - - -# Helper Functions -# ########################################################################## - - -def extract_runtime_sequence_arg_types(dev_op): - """MLIR helper: Extract argument types from a device operation's runtime sequence.""" - for nested_op in dev_op.body_region.blocks[0].operations: - op_name = nested_op.operation.name - if op_name == "aie.runtime_sequence": - if hasattr(nested_op, "body") and hasattr(nested_op.body, "blocks"): - if len(nested_op.body.blocks) > 0: - entry_block = nested_op.body.blocks[0] - arg_types = [ - entry_block.arguments[i].type - for i in range(len(entry_block.arguments)) - ] - return arg_types - raise RuntimeError("Could not find runtime sequence in device operation") - - -def get_child_mlir_module(mlir_artifact): - """Extract MLIR module from a PythonGeneratedMLIRArtifact.""" - assert isinstance(mlir_artifact, PythonGeneratedMLIRArtifact) - spec = importlib.util.spec_from_file_location( - Path(mlir_artifact.import_path).name, mlir_artifact.import_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - if mlir_artifact.requires_context: - raise NotImplementedError("Not handled, make your operator return a ctx.module") - - callback_function = getattr(module, mlir_artifact.callback_fn) - mlir_module = callback_function( - *mlir_artifact.callback_args, **mlir_artifact.callback_kwargs - ) - return mlir_module - - -def fuse_mlir(artifact): - """Fuse multiple MLIR modules by inlining their device operations and adding a new main device and runtime sequence that call into sequence of operations based on a runlist.""" - - input_buffer_size, output_buffer_size, scratch_buffer_size = artifact.buffer_sizes - - # Extract device operations from each operator's MLIR artifact - device_mlir_strings = {} - device_ty = None - sequence_arg_types = {} - for op_name, mlir_artifact in artifact.operator_mlir_map.items(): - mlir_module = get_child_mlir_module(mlir_artifact) - device_ops = [ - op for op in mlir_module.body.operations if isinstance(op, aie.DeviceOp) - ] - assert ( - len(device_ops) == 1 - ), f"Expected exactly one device operation in MLIR artifact for operator '{op_name}'" - device_op = device_ops[0] - if device_ty is None: - device_ty = device_op.device - device_mlir_strings[op_name] = str(device_op) - sequence_arg_types[op_name] = extract_runtime_sequence_arg_types(device_op) - - # Build fused MLIR module - with mlir_mod_ctx() as ctx: - - # Concatenate aie.device ops - for op_name, device_str in device_mlir_strings.items(): - dev_op = aie.DeviceOp.parse(device_str) - dev_op.sym_name = ir.StringAttr.get(op_name) - ctx.module.body.append(dev_op) - - # Create the main device -- this contains the runtime sequence calling into the other devices - @aie.device(device_ty) - def main(): - buf_dtype = np.dtype[ - ml_dtypes.bfloat16 - ] # TODO: support for other data types - itemsize = 2 - - # RuntimeSequenceOp - @aiex.runtime_sequence( - np.ndarray[(input_buffer_size // itemsize,), buf_dtype], - np.ndarray[(output_buffer_size // itemsize,), buf_dtype], - np.ndarray[(scratch_buffer_size // itemsize,), buf_dtype], - ) - def sequence(input_buf, output_buf, scratch_buf): - consolidated_buffers = { - "input": input_buf, - "output": output_buf, - "scratch": scratch_buf, - } - - # Execute operations in runlist order - configure_op = None - last_op_name = None - for op_name, *buffer_names in artifact.runlist: - expected_arg_types = sequence_arg_types[op_name] - - # Avoid reconfiguring altogether if the same op is called multiple times consecutively - if configure_op is None or op_name != last_op_name: - # Configure Op - configure_sym_ref_attr = ir.FlatSymbolRefAttr.get(op_name) - configure_op = aiex.ConfigureOp( - configure_sym_ref_attr - ) # TODO: optimization -- if previous op was in the same device, skip reconfiguration - configure_body = configure_op.body.blocks.append() - last_op_name = op_name - - with ir.InsertionPoint(configure_body): - - # For each buffer, add subview and reinterpret_cast ops - buffer_ssa_values = [] - for idx, buf_name in enumerate(buffer_names): - # Check if this is a sliced buffer - if buf_name in artifact.slice_info: - base_name, start, end = artifact.slice_info[buf_name] - # Get parent buffer info - buf_type, parent_offset, parent_length = ( - artifact.subbuffer_layout[base_name] - ) - # Calculate actual offset and length for slice - offset = parent_offset + start - length = end - start - else: - # Regular buffer - buf_type, offset, length = artifact.subbuffer_layout[ - buf_name - ] - - # Subview Op - consolidated_buf = consolidated_buffers[buf_type] - offset_elements = offset // itemsize - size_elements = length // itemsize - subview = memref.subview( - consolidated_buf, - [offset_elements], - [size_elements], - [1], - ) - - # Reinterpret_cast Op - target_type = expected_arg_types[idx] - expected_memref = ir.MemRefType(target_type) - target_shape = [ - expected_memref.shape[i] - for i in range(expected_memref.rank) - ] - expected_size = np.prod(target_shape) - assert ( - expected_size == size_elements - ), f"Size mismatch for buffer '{buf_name}': MLIR runtime sequence expected {expected_size}, Python fused operator provided {size_elements}" - strides = [] - stride = 1 - for dim in reversed(target_shape): - strides.insert(0, stride) - stride *= dim - result_type = ir.MemRefType.get( - target_shape, ir.BF16Type.get() - ) - reinterpreted = memref.reinterpret_cast( - result=result_type, - source=subview, - offsets=[], - sizes=[], - strides=[], - static_offsets=[0], - static_sizes=target_shape, - static_strides=strides, - ) - buffer_ssa_values.append(reinterpreted) - - # Run Op - sequence_sym_ref_attr = ir.FlatSymbolRefAttr.get("sequence") - run_op = aiex.RunOp(sequence_sym_ref_attr, buffer_ssa_values) - - # Write the fused MLIR to file - with open(artifact.filename, "w") as f: - f.write(str(ctx.module)) - - -# Compilation Rules -# ########################################################################## - - -class FusePythonGeneratedMLIRCompilationRule(CompilationRule): - """Compilation rule that fuses multiple MLIR modules into one.""" - - def matches(self, graph): - return any(graph.get_worklist(FusedMLIRSource)) - - def compile(self, graph): - commands = [] - worklist = graph.get_worklist(FusedMLIRSource) - for artifact in worklist: - callback = lambda artifact=artifact: fuse_mlir(artifact) - commands.append(PythonCallbackCompilationCommand(callback)) - new_artifact = SourceArtifact(artifact.filename) - new_artifact.available = True - graph.replace(artifact, new_artifact) - return commands diff --git a/iron/common/context.py b/iron/common/context.py index 1cde7087..d33c6888 100644 --- a/iron/common/context.py +++ b/iron/common/context.py @@ -25,7 +25,6 @@ def __init__(self, use_runlist=True, build_dir=None): # Disable the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations for easier debugging (can tell which part of runlist execution failed) self.use_runlist = use_runlist self.compilation_rules = [ - comp.FusePythonGeneratedMLIRCompilationRule(), comp.GenerateMLIRFromPythonCompilationRule(), comp.PeanoCompilationRule(self.peano_dir, self.mlir_aie_dir), comp.ArchiveCompilationRule(self.peano_dir), From f4651ef0839408c8e937277a3f8084be73cec2c8 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 16:57:58 -0700 Subject: [PATCH 95/99] use XRTTensor instead of AIEBuffer --- iron/common/__init__.py | 2 +- iron/common/base.py | 113 +++----------------------- iron/common/test_utils.py | 13 +-- iron/operators/swiglu_decode/op.py | 19 +++-- iron/operators/swiglu_decode/test.py | 13 +-- iron/operators/swiglu_prefill/op.py | 19 +++-- iron/operators/swiglu_prefill/test.py | 23 ++++-- 7 files changed, 62 insertions(+), 140 deletions(-) diff --git a/iron/common/__init__.py b/iron/common/__init__.py index 68cafac6..d51b5c65 100644 --- a/iron/common/__init__.py +++ b/iron/common/__init__.py @@ -3,12 +3,12 @@ """Common utilities and base classes for IRON operators.""" +from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor from .base import ( AIEOperatorBase, MLIROperator, CompositeOperator, CompositeCallable, - AIEBuffer, SingleXclbinCallable, AIERuntimeArgSpec, ) diff --git a/iron/common/base.py b/iron/common/base.py index 746cb43f..48d083cd 100644 --- a/iron/common/base.py +++ b/iron/common/base.py @@ -176,95 +176,6 @@ def __init__(self, direction, shape, dtype=bfloat16): self.direction = direction -class AIEBuffer(XRTTensor): - def __init__(self, shape, dtype=bfloat16, bo=None, device_manager=None): - self.device_manager = device_manager or AIEDeviceManager() - self.subviews = [] - - if bo is not None: - Tensor.__init__(self, shape, dtype=dtype, device="cpu") - self._shape = shape - self.xrt_device = self.device_manager.device - self._bo = bo - ptr = self._bo.map() - self._data = np.frombuffer(ptr, dtype=self.dtype).reshape(self._shape) - else: - super().__init__(shape, dtype=dtype, device="cpu") - - @property - def bo(self): - return self._bo - - @property - def on(self): - return self.device - - @on.setter - def on(self, value): - self.device = value - - def subbuffer(self, length, offset, shape, dtype=None): - if dtype is None: - dtype = self.dtype - assert np.prod(shape) == length - itemsize = np.dtype(dtype).itemsize - assert offset >= 0 - assert offset * itemsize <= np.prod(self.shape) * np.dtype(self.dtype).itemsize - assert ( - length * itemsize + offset * itemsize - <= np.prod(self.shape) * np.dtype(self.dtype).itemsize - ) - sub_bo = pyxrt.bo( - self.bo, # parent bo - length * itemsize, # size - offset * itemsize, # offset - ) - sub_buffer = AIEBuffer( - shape=shape, dtype=dtype, bo=sub_bo, device_manager=self.device_manager - ) - sub_buffer.on = self.on - self.subviews.append(sub_buffer) - return sub_buffer - - def view(self, shape): - assert np.prod(shape) == np.prod(self.shape) - sub_buffer = AIEBuffer( - shape=shape, - dtype=self.dtype, - bo=self.bo, - device_manager=self.device_manager, - ) - sub_buffer.on = self.on - self.subviews.append(sub_buffer) - return sub_buffer - - def view_as_np(self): - return self.numpy() - - def view_as_torch(self): - return numpy_to_torch(self.numpy()) - - def to(self, dest): - super().to(dest) - todo = self.subviews.copy() - while todo: - sub_buffer = todo.pop() - sub_buffer.device = dest - todo.extend(sub_buffer.subviews) - return self - - @staticmethod - def from_np(buffer): - aie_buffer = AIEBuffer(buffer.shape, dtype=buffer.dtype) - aie_buffer.data[:] = buffer - aie_buffer.to("npu") - return aie_buffer - - @staticmethod - def from_torch(tensor): - return AIEBuffer.from_np(torch_to_numpy(tensor)) - - class SingleXclbinCallable: def __init__( self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None @@ -275,17 +186,12 @@ def __init__( ) with open(str(insts_bin_path), "rb") as f: instructions = np.frombuffer(f.read(), dtype=np.uint32) - insts_bo = pyxrt.bo( - self.device_manager.device, - instructions.nbytes, - pyxrt.bo.cacheable, - self.xrt_kernel.group_id(1), - ) - insts_bo.write(instructions.view(np.uint8), 0) - self.insts_buffer = AIEBuffer( - shape=(len(instructions),), dtype=np.uint32, bo=insts_bo + self.insts_buffer = XRTTensor( + instructions, + dtype=np.uint32, + flags=pyxrt.bo.cacheable, + group_id=self.xrt_kernel.group_id(1), ) - self.insts_buffer.to("npu") self.args_spec = args_spec def __call__(self, *buffers): @@ -298,9 +204,12 @@ def __call__(self, *buffers): for buf in buffers: buf.to("npu") opcode = 3 - bos = [buffer.bo for buffer in buffers] + bos = [buffer.buffer_object() for buffer in buffers] run = self.xrt_kernel( - opcode, self.insts_buffer.bo, self.insts_buffer.shape[0], *bos + opcode, + self.insts_buffer.buffer_object(), + self.insts_buffer.shape[0], + *bos, ) ret_code = run.wait() if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: @@ -315,7 +224,7 @@ def __init__(self, sequence, intermediate_buffers=None): Args: sequence: List of (callable, args_indices) tuples. args_indices is a list of indices into the combined list of [inputs, outputs, intermediates]. - intermediate_buffers: List of AIEBuffer objects for intermediate results. + intermediate_buffers: List of XRTTensor objects for intermediate results. """ self.sequence = sequence self.intermediate_buffers = intermediate_buffers or [] diff --git a/iron/common/test_utils.py b/iron/common/test_utils.py index 53f40e9f..d04aa14f 100644 --- a/iron/common/test_utils.py +++ b/iron/common/test_utils.py @@ -6,7 +6,8 @@ from ml_dtypes import bfloat16 from .utils import torch_to_numpy import logging -from .base import MLIROperator, CompositeOperator, AIEBuffer +from .base import MLIROperator, CompositeOperator +from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor def nearly_equal( @@ -107,18 +108,18 @@ def run_test( except StopIteration: raise ValueError("Not enough input buffers provided for arg spec") data_np = torch_to_numpy(data) - buf = AIEBuffer.from_np(data_np) + buf = XRTTensor(data_np, dtype=data_np.dtype) args.append(buf) - total_bytes += buf.bo.size() + total_bytes += buf.buffer_object().size() elif spec.direction == "out": try: name, expected = next(output_iter) except StopIteration: raise ValueError("Not enough output buffers provided for arg spec") - buf = AIEBuffer(shape=spec.shape, dtype=spec.dtype) + buf = XRTTensor(spec.shape, dtype=spec.dtype) args.append(buf) output_map[name] = buf - total_bytes += buf.bo.size() + total_bytes += buf.buffer_object().size() else: # Handle other directions if needed, or raise error raise ValueError(f"Unsupported direction: {spec.direction}") @@ -143,7 +144,7 @@ def run_test( continue if buf_name in output_map: buf = output_map[buf_name] - output_np = buf.view_as_np() + output_np = buf.numpy() buf_errors = verify_buffer(output_np, buf_name, expected, rel_tol, abs_tol) if buf_errors: errors[buf_name] = buf_errors diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index 05496634..bc405c08 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -6,10 +6,10 @@ import numpy as np from ml_dtypes import bfloat16 +from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor from iron.common import ( CompositeOperator, AIERuntimeArgSpec, - AIEBuffer, SingleXclbinCallable, XclbinArtifact, InstsBinArtifact, @@ -65,19 +65,22 @@ def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): ) # Allocate and upload weights - self.weights_1 = AIEBuffer.from_np(torch_to_numpy(op.weights_1)) - self.weights_2 = AIEBuffer.from_np(torch_to_numpy(op.weights_2)) - self.weights_3 = AIEBuffer.from_np(torch_to_numpy(op.weights_3)) + w1 = torch_to_numpy(op.weights_1) + self.weights_1 = XRTTensor(w1, dtype=w1.dtype) + w2 = torch_to_numpy(op.weights_2) + self.weights_2 = XRTTensor(w2, dtype=w2.dtype) + w3 = torch_to_numpy(op.weights_3) + self.weights_3 = XRTTensor(w3, dtype=w3.dtype) # Allocate intermediate buffers # left: output of gemv_1 (hidden_dim_padded) - self.left = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + self.left = XRTTensor((op.hidden_dim_padded,), dtype=bfloat16) # right: output of gemv_1 (hidden_dim_padded) - self.right = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + self.right = XRTTensor((op.hidden_dim_padded,), dtype=bfloat16) # left_swished: output of silu (hidden_dim_padded) - self.left_swished = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + self.left_swished = XRTTensor((op.hidden_dim_padded,), dtype=bfloat16) # intermediate: output of eltwise_mul (hidden_dim_padded) - self.intermediate = AIEBuffer(shape=(op.hidden_dim_padded,), dtype=bfloat16) + self.intermediate = XRTTensor((op.hidden_dim_padded,), dtype=bfloat16) def __call__(self, input_buf, output_buf): # Ensure inputs are on device diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index 65af7cf2..191653b4 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -8,8 +8,8 @@ from ml_dtypes import bfloat16 -from iron.common.base import AIEBuffer -from iron.common.utils import torch_to_numpy +from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor +from iron.common.utils import torch_to_numpy, numpy_to_torch from iron.operators.swiglu_decode.op import AIESwiGLUDecode from iron.operators.swiglu_decode.reference import generate_golden_reference from iron.common.test_utils import verify_buffer @@ -42,14 +42,15 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): operator.compile() op_func = operator.get_callable() - input_buf = AIEBuffer.from_np(torch_to_numpy(golden_ref["input"])) - output_buf = AIEBuffer(shape=(1, embedding_dim), dtype=bfloat16) + input_np = torch_to_numpy(golden_ref["input"]) + input_buf = XRTTensor(input_np, dtype=input_np.dtype) + output_buf = XRTTensor((1, embedding_dim), dtype=bfloat16) op_func(input_buf, output_buf) errors = {} # Verify intermediate result - intermediate = op_func.intermediate.view_as_torch().reshape((1, hidden_dim)) + intermediate = numpy_to_torch(op_func.intermediate.numpy()).reshape((1, hidden_dim)) errors_intermediate = verify_buffer( intermediate, "intermediate", @@ -62,7 +63,7 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): # Verify output using intermediate result ref_2 = intermediate @ golden_ref["w_down"] - output = output_buf.view_as_torch().reshape((1, embedding_dim)) + output = numpy_to_torch(output_buf.numpy()).reshape((1, embedding_dim)) errors_output = verify_buffer(output, "output", ref_2, rel_tol=0.04, abs_tol=0.4) if errors_output: errors["output"] = errors_output diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index d572c21f..e5a72213 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -6,10 +6,10 @@ import numpy as np from ml_dtypes import bfloat16 +from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor from iron.common import ( CompositeOperator, AIERuntimeArgSpec, - AIEBuffer, SingleXclbinCallable, XclbinArtifact, InstsBinArtifact, @@ -62,17 +62,20 @@ def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): ) # Allocate and upload weights - self.weights_1 = AIEBuffer.from_np(torch_to_numpy(op.weights_1.T)) - self.weights_2 = AIEBuffer.from_np(torch_to_numpy(op.weights_2.T)) - self.weights_3 = AIEBuffer.from_np(torch_to_numpy(op.weights_3.T)) + w1 = torch_to_numpy(op.weights_1.T) + self.weights_1 = XRTTensor(w1, dtype=w1.dtype) + w2 = torch_to_numpy(op.weights_2.T) + self.weights_2 = XRTTensor(w2, dtype=w2.dtype) + w3 = torch_to_numpy(op.weights_3.T) + self.weights_3 = XRTTensor(w3, dtype=w3.dtype) # Allocate intermediate buffers # Sizes are padded size_hidden = op.seq_len_padded * op.hidden_dim_padded - self.left = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) - self.right = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) - self.left_swished = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) - self.intermediate = AIEBuffer(shape=(size_hidden,), dtype=bfloat16) + self.left = XRTTensor((size_hidden,), dtype=bfloat16) + self.right = XRTTensor((size_hidden,), dtype=bfloat16) + self.left_swished = XRTTensor((size_hidden,), dtype=bfloat16) + self.intermediate = XRTTensor((size_hidden,), dtype=bfloat16) self.last_output_buf = None def __call__(self, input_buf, output_buf): diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 8bb8d14e..605d30fe 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -8,8 +8,8 @@ from ml_dtypes import bfloat16 -from iron.common.base import AIEBuffer -from iron.common.utils import torch_to_numpy +from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor +from iron.common.utils import torch_to_numpy, numpy_to_torch from iron.operators.swiglu_prefill.op import AIESwiGLUPrefill from iron.operators.swiglu_decode.reference import generate_golden_reference from iron.common.test_utils import verify_buffer @@ -47,9 +47,10 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c operator.compile() op_func = operator.get_callable() - input_buf = AIEBuffer.from_np(torch_to_numpy(golden_ref["input"])) - output_buf = AIEBuffer( - shape=(seq_len * embedding_dim,), dtype=bfloat16 + input_np = torch_to_numpy(golden_ref["input"]) + input_buf = XRTTensor(input_np, dtype=input_np.dtype) + output_buf = XRTTensor( + (seq_len * embedding_dim,), dtype=bfloat16 ) # Output is flattened op_func(input_buf, output_buf) @@ -57,12 +58,16 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c errors = {} # Verify intermediate result (left_swished * right) - left_swished = op_func.left_swished.view_as_torch().reshape((seq_len, hidden_dim)) - right = op_func.right.view_as_torch().reshape((seq_len, hidden_dim)) + left_swished = numpy_to_torch(op_func.left_swished.numpy()).reshape( + (seq_len, hidden_dim) + ) + right = numpy_to_torch(op_func.right.numpy()).reshape((seq_len, hidden_dim)) ref_2 = left_swished * right # Note: intermediate buffer in op_func stores the result of eltwise_mul - intermediate = op_func.intermediate.view_as_torch().reshape((seq_len, hidden_dim)) + intermediate = numpy_to_torch(op_func.intermediate.numpy()).reshape( + (seq_len, hidden_dim) + ) errors_2 = verify_buffer( intermediate, "intermediate", ref_2, rel_tol=0.04, abs_tol=0.4 ) @@ -71,7 +76,7 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c # Verify output using intermediate result ref_3 = intermediate @ golden_ref["w_down"] - output = output_buf.view_as_torch().reshape((seq_len, embedding_dim)) + output = numpy_to_torch(output_buf.numpy()).reshape((seq_len, embedding_dim)) errors_3 = verify_buffer(output, "output", ref_3, rel_tol=0.04, abs_tol=0.4) if errors_3: errors["output"] = errors_3 From cb32ff784c5e82b1ade393aa267e4f1f9913ed5b Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 17:08:37 -0700 Subject: [PATCH 96/99] more simplifications --- iron/common/__init__.py | 1 - iron/common/base.py | 48 +++-------------------------- iron/operators/swiglu_decode/op.py | 7 ++--- iron/operators/swiglu_prefill/op.py | 7 ++--- 4 files changed, 10 insertions(+), 53 deletions(-) diff --git a/iron/common/__init__.py b/iron/common/__init__.py index d51b5c65..45013acf 100644 --- a/iron/common/__init__.py +++ b/iron/common/__init__.py @@ -9,7 +9,6 @@ MLIROperator, CompositeOperator, CompositeCallable, - SingleXclbinCallable, AIERuntimeArgSpec, ) from .context import AIEContext diff --git a/iron/common/base.py b/iron/common/base.py index 48d083cd..78574d3b 100644 --- a/iron/common/base.py +++ b/iron/common/base.py @@ -13,9 +13,10 @@ import aie.utils.config from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor from aie.utils.hostruntime.tensor_class import Tensor +from aie.utils.npukernel import NPUKernel from . import compilation as comp from .context import AIEContext -from .device_manager import AIEDeviceManager, pyxrt +from .device_manager import pyxrt from .utils import numpy_to_torch, torch_to_numpy from .compilation import ( XclbinArtifact, @@ -153,11 +154,10 @@ def set_up_artifacts(self): self.add_artifacts([xclbin_artifact, insts_artifact]) def get_callable(self): - return SingleXclbinCallable( + return NPUKernel( xclbin_path=self.xclbin_artifact.filename, kernel_name=self.xclbin_artifact.kernel_name, - insts_bin_path=self.insts_artifact.filename, - args_spec=self.get_arg_spec(), + insts_path=self.insts_artifact.filename, ) @@ -176,46 +176,6 @@ def __init__(self, direction, shape, dtype=bfloat16): self.direction = direction -class SingleXclbinCallable: - def __init__( - self, xclbin_path, kernel_name, insts_bin_path, args_spec, device_manager=None - ): - self.device_manager = device_manager or AIEDeviceManager() - self.context, self.xrt_kernel = self.device_manager.get_context_and_kernel( - str(xclbin_path), kernel_name - ) - with open(str(insts_bin_path), "rb") as f: - instructions = np.frombuffer(f.read(), dtype=np.uint32) - self.insts_buffer = XRTTensor( - instructions, - dtype=np.uint32, - flags=pyxrt.bo.cacheable, - group_id=self.xrt_kernel.group_id(1), - ) - self.args_spec = args_spec - - def __call__(self, *buffers): - assert len(buffers) == len(self.args_spec) - # assert all( - # np.prod(buffers[i].shape) >= np.prod(self.args_spec[i].shape) and buffers[i].dtype == self.args_spec[i].dtype - # for i in range(len(buffers)) - # ), "Input buffer shapes or dtypes do not match expected argument specification." - self.insts_buffer.to("npu") - for buf in buffers: - buf.to("npu") - opcode = 3 - bos = [buffer.buffer_object() for buffer in buffers] - run = self.xrt_kernel( - opcode, - self.insts_buffer.buffer_object(), - self.insts_buffer.shape[0], - *bos, - ) - ret_code = run.wait() - if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: - raise RuntimeError(f"Kernel did not complete correctly: {ret_code}") - - class CompositeCallable: """Callable for executing a sequence of sub-operators""" diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index bc405c08..2b454ac6 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -7,10 +7,10 @@ from ml_dtypes import bfloat16 from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor +from aie.utils.npukernel import NPUKernel from iron.common import ( CompositeOperator, AIERuntimeArgSpec, - SingleXclbinCallable, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -32,11 +32,10 @@ def __init__(self, op): # Helper to create callable from operator and artifacts def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): - return SingleXclbinCallable( + return NPUKernel( xclbin_path=xclbin_path, kernel_name=kernel_name, - insts_bin_path=insts_artifact.filename, - args_spec=sub_op.get_arg_spec(), + insts_path=insts_artifact.filename, ) self.gemv_1_callable = create_callable( diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index e5a72213..cf936c31 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -7,10 +7,10 @@ from ml_dtypes import bfloat16 from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor +from aie.utils.npukernel import NPUKernel from iron.common import ( CompositeOperator, AIERuntimeArgSpec, - SingleXclbinCallable, XclbinArtifact, InstsBinArtifact, KernelObjectArtifact, @@ -29,11 +29,10 @@ def __init__(self, op): self.op = op def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): - return SingleXclbinCallable( + return NPUKernel( xclbin_path=xclbin_path, kernel_name=kernel_name, - insts_bin_path=insts_artifact.filename, - args_spec=sub_op.get_arg_spec(), + insts_path=insts_artifact.filename, ) self.gemm_1_callable = create_callable( From a5c0c2f18c6c7bc5bb45bac0c71d1a1354dc77c4 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 17:09:07 -0700 Subject: [PATCH 97/99] update comment --- iron/operators/swiglu_decode/op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index 2b454ac6..45e8bd4b 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -28,7 +28,7 @@ class SwiGLUDecodeCallable: def __init__(self, op): self.op = op # Create callables for sub-operators - # We need to manually construct SingleXclbinCallable because sub-operators weren't "compiled" in the standard way + # We need to manually construct NPUKernel because sub-operators weren't "compiled" in the standard way # Helper to create callable from operator and artifacts def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): From da3a0bfa0ea1dc9c3dc718d45afe4d420220cd71 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 17:19:48 -0700 Subject: [PATCH 98/99] simplify tensor conversion --- iron/common/test_utils.py | 3 +-- iron/operators/swiglu_decode/op.py | 9 +++------ iron/operators/swiglu_decode/test.py | 3 +-- iron/operators/swiglu_prefill/op.py | 9 +++------ iron/operators/swiglu_prefill/test.py | 3 +-- 5 files changed, 9 insertions(+), 18 deletions(-) diff --git a/iron/common/test_utils.py b/iron/common/test_utils.py index d04aa14f..07308e03 100644 --- a/iron/common/test_utils.py +++ b/iron/common/test_utils.py @@ -107,8 +107,7 @@ def run_test( name, data = next(input_iter) except StopIteration: raise ValueError("Not enough input buffers provided for arg spec") - data_np = torch_to_numpy(data) - buf = XRTTensor(data_np, dtype=data_np.dtype) + buf = XRTTensor.from_torch(data) args.append(buf) total_bytes += buf.buffer_object().size() elif spec.direction == "out": diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index 45e8bd4b..ec990ede 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -64,12 +64,9 @@ def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): ) # Allocate and upload weights - w1 = torch_to_numpy(op.weights_1) - self.weights_1 = XRTTensor(w1, dtype=w1.dtype) - w2 = torch_to_numpy(op.weights_2) - self.weights_2 = XRTTensor(w2, dtype=w2.dtype) - w3 = torch_to_numpy(op.weights_3) - self.weights_3 = XRTTensor(w3, dtype=w3.dtype) + self.weights_1 = XRTTensor.from_torch(op.weights_1) + self.weights_2 = XRTTensor.from_torch(op.weights_2) + self.weights_3 = XRTTensor.from_torch(op.weights_3) # Allocate intermediate buffers # left: output of gemv_1 (hidden_dim_padded) diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index 191653b4..f6565af0 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -42,8 +42,7 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): operator.compile() op_func = operator.get_callable() - input_np = torch_to_numpy(golden_ref["input"]) - input_buf = XRTTensor(input_np, dtype=input_np.dtype) + input_buf = XRTTensor.from_torch(golden_ref["input"]) output_buf = XRTTensor((1, embedding_dim), dtype=bfloat16) op_func(input_buf, output_buf) diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index cf936c31..9e7eb387 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -61,12 +61,9 @@ def create_callable(sub_op, xclbin_path, kernel_name, insts_artifact): ) # Allocate and upload weights - w1 = torch_to_numpy(op.weights_1.T) - self.weights_1 = XRTTensor(w1, dtype=w1.dtype) - w2 = torch_to_numpy(op.weights_2.T) - self.weights_2 = XRTTensor(w2, dtype=w2.dtype) - w3 = torch_to_numpy(op.weights_3.T) - self.weights_3 = XRTTensor(w3, dtype=w3.dtype) + self.weights_1 = XRTTensor.from_torch(op.weights_1.T) + self.weights_2 = XRTTensor.from_torch(op.weights_2.T) + self.weights_3 = XRTTensor.from_torch(op.weights_3.T) # Allocate intermediate buffers # Sizes are padded diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 605d30fe..3ae0804a 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -47,8 +47,7 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c operator.compile() op_func = operator.get_callable() - input_np = torch_to_numpy(golden_ref["input"]) - input_buf = XRTTensor(input_np, dtype=input_np.dtype) + input_buf = XRTTensor.from_torch(golden_ref["input"]) output_buf = XRTTensor( (seq_len * embedding_dim,), dtype=bfloat16 ) # Output is flattened From fa00154f500e8b350472b234e9b14480f56b2273 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Mon, 9 Feb 2026 17:55:35 -0700 Subject: [PATCH 99/99] fix compilation todos, split up compilation code, continue to clean up tensor logic --- iron/common/base.py | 16 - iron/common/compilation/__init__.py | 1 + iron/common/compilation/aie.py | 507 ++++++++++++++++++++++++++ iron/common/compilation/base.py | 414 +-------------------- iron/common/test_utils.py | 22 +- iron/common/utils.py | 54 ++- iron/operators/gemm/op.py | 2 - iron/operators/gemv/op.py | 1 - iron/operators/mha/op.py | 1 - iron/operators/rms_norm/op.py | 1 - iron/operators/swiglu_decode/op.py | 1 - iron/operators/swiglu_decode/test.py | 6 +- iron/operators/swiglu_prefill/op.py | 1 - iron/operators/swiglu_prefill/test.py | 14 +- 14 files changed, 571 insertions(+), 470 deletions(-) create mode 100644 iron/common/compilation/aie.py diff --git a/iron/common/base.py b/iron/common/base.py index 78574d3b..49d9bc25 100644 --- a/iron/common/base.py +++ b/iron/common/base.py @@ -17,7 +17,6 @@ from . import compilation as comp from .context import AIEContext from .device_manager import pyxrt -from .utils import numpy_to_torch, torch_to_numpy from .compilation import ( XclbinArtifact, InstsBinArtifact, @@ -83,21 +82,6 @@ def add_artifacts(self, artifacts): self.artifacts.add(artifact) -def sync_to_device(bos): - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - - -def sync_from_device(bos): - for bo in bos: - bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) - - -def execute_runlist(runlist): - runlist.execute() - runlist.wait() - - class MLIROperator(AIEOperatorBase, ABC): """Base class for AIE-accelerated operations defined by a single MLIR source""" diff --git a/iron/common/compilation/__init__.py b/iron/common/compilation/__init__.py index 20823d02..9cca95c6 100644 --- a/iron/common/compilation/__init__.py +++ b/iron/common/compilation/__init__.py @@ -2,3 +2,4 @@ # SPDX-License-Identifier: Apache-2.0 from .base import * +from .aie import * diff --git a/iron/common/compilation/aie.py b/iron/common/compilation/aie.py new file mode 100644 index 00000000..7e784b1c --- /dev/null +++ b/iron/common/compilation/aie.py @@ -0,0 +1,507 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +import os.path +import importlib.util +from contextlib import nullcontext +from aie.extras.context import mlir_mod_ctx +from .base import ( + CompilationArtifact, + SourceArtifact, + CompilationRule, + ShellCompilationCommand, + PythonCallbackCompilationCommand, +) + +# AIE Artifacts +# ########################################################################## + + +class FullElfArtifact(CompilationArtifact): + def __init__(self, filename, mlir_input, dependencies): + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] + super().__init__(filename, dependencies) + self.mlir_input = mlir_input + + +class XclbinArtifact(CompilationArtifact): + def __init__( + self, + filename, + mlir_input, + dependencies, + kernel_name="MLIR_AIE", + extra_flags=None, + xclbin_input=None, + ): + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] + super().__init__(filename, dependencies) + self.mlir_input = mlir_input + self.kernel_name = kernel_name + self.extra_flags = extra_flags if extra_flags is not None else [] + self.xclbin_input = xclbin_input + + +class InstsBinArtifact(CompilationArtifact): + def __init__(self, filename, mlir_input, dependencies, extra_flags=None): + self.mlir_input = mlir_input + if mlir_input not in dependencies: + dependencies = dependencies + [mlir_input] + super().__init__(filename, dependencies) + self.extra_flags = extra_flags if extra_flags is not None else [] + + +class KernelObjectArtifact(CompilationArtifact): + def __init__( + self, + filename, + dependencies, + extra_flags=None, + rename_symbols=None, + prefix_symbols=None, + ): + super().__init__(filename, dependencies) + self.extra_flags = extra_flags if extra_flags is not None else [] + self.rename_symbols = rename_symbols if rename_symbols is not None else {} + self.prefix_symbols = prefix_symbols + + +class KernelArchiveArtifact(CompilationArtifact): + pass + + +class PythonGeneratedMLIRArtifact(CompilationArtifact): + def __init__( + self, + filename, + import_path, + callback_fn, + callback_args=None, + callback_kwargs=None, + requires_context=False, + uses_kernel_archive=False, + kernel_archive=None, + ): + self.import_path = import_path + self.callback_fn = callback_fn + self.callback_args = callback_args if callback_args is not None else [] + self.callback_kwargs = callback_kwargs if callback_kwargs is not None else {} + self.requires_context = requires_context + dependencies = [SourceArtifact(import_path)] + super().__init__(filename, dependencies=dependencies) + + +# AIE Rules +# ########################################################################## + + +class GenerateMLIRFromPythonCompilationRule(CompilationRule): + def matches(self, graph): + return any(graph.get_worklist(PythonGeneratedMLIRArtifact)) + + def compile(self, graph): + """Generate MLIR from a Python callback that uses the MLIR bindings""" + commands = [] + worklist = graph.get_worklist(PythonGeneratedMLIRArtifact) + for artifact in worklist: + new_artifact = SourceArtifact(artifact.filename) + # To make Python capture variables in this closure by value, not by reference, use default arguments + callback = lambda new_artifact=new_artifact, import_path=artifact.import_path, callback_fn=artifact.callback_fn, callback_args=artifact.callback_args, callback_kwargs=artifact.callback_kwargs, requires_context=artifact.requires_context: self.generate_mlir( + new_artifact, + import_path, + callback_fn, + callback_args, + callback_kwargs, + requires_context, + ) + commands.append(PythonCallbackCompilationCommand(callback)) + new_artifact.available = True + graph.replace(artifact, new_artifact) + return commands + + @staticmethod + def generate_mlir( + output_artifact, + import_path, + callback_fn, + callback_args=None, + callback_kwargs=None, + requires_context=False, + ): + # Import the Python source file + spec = importlib.util.spec_from_file_location( + Path(import_path).name, import_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context + ctx_callback = lambda: (mlir_mod_ctx() if requires_context else nullcontext()) + with ctx_callback() as ctx: + callback_function = getattr(module, callback_fn) + mlir_code = callback_function(*callback_args, **callback_kwargs) + # Stringify the generated MLIR + if requires_context: + mlir_code = str(ctx.module) + else: + mlir_code = str(mlir_code) + + with open(output_artifact.filename, "w") as f: + f.write(mlir_code) + + +class AieccCompilationRule(CompilationRule): + def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): + self.build_dir = build_dir + self.aiecc_path = Path(mlir_aie_dir) / "bin" / "aiecc.py" + self.peano_dir = peano_dir + super().__init__(*args, **kwargs) + + +class AieccFullElfCompilationRule(AieccCompilationRule): + def matches(self, graph): + return any(graph.get_worklist(FullElfArtifact)) + + def compile(self, graph): + worklist = graph.get_worklist(FullElfArtifact) + commands = [] + + for artifact in worklist: + compile_cmd = [ + "python", + str(self.aiecc_path), + "--no-compile-host", + "--no-xchesscc", + "--no-xbridge", + "--peano", + str(self.peano_dir), + "--dynamic-objFifos", + "--expand-load-pdis", + "--generate-full-elf", + "--full-elf-name", + os.path.abspath(artifact.filename), + os.path.abspath(artifact.mlir_input.filename), + ] + commands.append( + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + ) + artifact.available = True + + return commands + + +class AieccXclbinInstsCompilationRule(AieccCompilationRule): + def matches(self, graph): + return any(graph.get_worklist((XclbinArtifact, InstsBinArtifact))) + + def compile(self, graph): + # Group artifacts by their unique compilation configuration + xclbin_configs = {} + insts_configs = {} + worklist = graph.get_worklist((XclbinArtifact, InstsBinArtifact)) + + for artifact in worklist: + mlir_dependency = artifact.mlir_input + if isinstance(artifact, XclbinArtifact): + key = ( + mlir_dependency, + artifact.kernel_name, + tuple(artifact.extra_flags), + artifact.xclbin_input, + ) + xclbin_configs.setdefault(key, []).append(artifact) + elif isinstance(artifact, InstsBinArtifact): + key = (mlir_dependency, tuple(artifact.extra_flags)) + insts_configs.setdefault(key, []).append(artifact) + + commands = [] + handled_insts_configs = set() + + # Iterate through XCLBIN configurations + for xclbin_key, xclbin_artifacts in xclbin_configs.items(): + mlir_source, kernel_name, xclbin_flags, xclbin_input = xclbin_key + + # Try to find a matching InstsBin configuration (same MLIR source) + matching_insts_key = None + for insts_key in insts_configs: + if ( + insts_key not in handled_insts_configs + and insts_key[0] == mlir_source + ): + matching_insts_key = insts_key + break + + compile_cmd = [ + "python", + str(self.aiecc_path), + "--no-compile-host", + "--no-xchesscc", + "--no-xbridge", + "--peano", + str(self.peano_dir), + "--dynamic-objFifos", + ] + + # Add XCLBIN flags + first_xclbin = xclbin_artifacts[0] + compile_cmd += list(xclbin_flags) + [ + "--aie-generate-xclbin", + "--xclbin-name=" + os.path.abspath(first_xclbin.filename), + "--xclbin-kernel-name=" + kernel_name, + ] + if xclbin_input is not None: + compile_cmd += [ + "--xclbin-input=" + os.path.abspath(xclbin_input.filename) + ] + + # Add InstsBin flags if matching config found + if matching_insts_key: + handled_insts_configs.add(matching_insts_key) + insts_artifacts = insts_configs[matching_insts_key] + first_insts = insts_artifacts[0] + compile_cmd += list(matching_insts_key[1]) + [ + "--aie-generate-npu", + "--npu-insts-name=" + os.path.abspath(first_insts.filename), + ] + + compile_cmd += [os.path.abspath(mlir_source.filename)] + + # If the MLIR source depends on a kernel archive, pass it to aiecc.py so it can be linked + if ( + isinstance(mlir_source, PythonGeneratedMLIRArtifact) + and "kernel_archive" in mlir_source.callback_kwargs + ): + compile_cmd.append( + os.path.abspath( + os.path.join( + self.build_dir, + mlir_source.callback_kwargs["kernel_archive"], + ) + ) + ) + + commands.append( + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + ) + + # Copy for other XCLBIN artifacts with same config + if len(xclbin_artifacts) > 1: + for copy_dest in xclbin_artifacts[1:]: + commands.append( + ShellCompilationCommand( + ["cp", first_xclbin.filename, copy_dest.filename] + ) + ) + + # Copy for other InstsBin artifacts with same config (if matched) + if matching_insts_key: + insts_artifacts = insts_configs[matching_insts_key] + if len(insts_artifacts) > 1: + first_insts = insts_artifacts[0] + for copy_dest in insts_artifacts[1:]: + commands.append( + ShellCompilationCommand( + ["cp", first_insts.filename, copy_dest.filename] + ) + ) + + # Handle remaining InstsBin configurations + for insts_key, insts_artifacts in insts_configs.items(): + if insts_key in handled_insts_configs: + continue + + mlir_source, insts_flags = insts_key + first_insts = insts_artifacts[0] + + compile_cmd = [ + "python", + str(self.aiecc_path), + "--no-compile-host", + "--no-xchesscc", + "--no-xbridge", + "--peano", + str(self.peano_dir), + "--dynamic-objFifos", + "--no-compile", + ] + + compile_cmd += list(insts_flags) + [ + "--aie-generate-npu", + "--npu-insts-name=" + os.path.abspath(first_insts.filename), + ] + + compile_cmd += [os.path.abspath(mlir_source.filename)] + + # If the MLIR source depends on a kernel archive, pass it to aiecc.py so it can be linked + if ( + isinstance(mlir_source, PythonGeneratedMLIRArtifact) + and "kernel_archive" in mlir_source.callback_kwargs + ): + compile_cmd.append( + os.path.abspath( + os.path.join( + self.build_dir, + mlir_source.callback_kwargs["kernel_archive"], + ) + ) + ) + + commands.append( + ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + ) + + # Copy for other InstsBin artifacts with same config + if len(insts_artifacts) > 1: + for copy_dest in insts_artifacts[1:]: + commands.append( + ShellCompilationCommand( + ["cp", first_insts.filename, copy_dest.filename] + ) + ) + + # Update graph + for artifact in worklist: + artifact.available = True + + return commands + + +class PeanoCompilationRule(CompilationRule): + def __init__(self, peano_dir, mlir_aie_dir, *args, **kwargs): + self.peano_dir = peano_dir + self.mlir_aie_dir = mlir_aie_dir + super().__init__(*args, **kwargs) + + def matches(self, artifacts): + return any(artifacts.get_worklist(KernelObjectArtifact)) + + def compile(self, artifacts): + clang_path = Path(self.peano_dir) / "bin" / "clang++" + include_path = Path(self.mlir_aie_dir) / "include" + worklist = artifacts.get_worklist(KernelObjectArtifact) + commands = [] + for artifact in worklist: + if len(artifact.dependencies) != 1: + raise RuntimeError( + "Expected exactly one dependency (the C source code) for KernelObjectArtifact" + ) + source_file = artifact.dependencies[0] + if not isinstance(source_file, SourceArtifact): + raise RuntimeError( + "Expected KernelObject dependency to be a C source file" + ) + + cmd = ( + [ + str(clang_path), + "-O2", + "-std=c++20", + "--target=aie2p-none-unknown-elf", + "-Wno-parentheses", + "-Wno-attributes", + "-Wno-macro-redefined", + "-Wno-empty-body", + "-Wno-missing-template-arg-list-after-template-kw", + f"-I{str(include_path)}", + ] + + artifact.extra_flags + + ["-c", source_file.filename, "-o", artifact.filename] + ) + + commands.append(ShellCompilationCommand(cmd)) + if artifact.rename_symbols: + commands.extend(self._rename_symbols(artifact)) + if artifact.prefix_symbols: + commands.extend(self._prefix_symbols(artifact, artifact.prefix_symbols)) + artifact.available = True + + return commands + + def _rename_symbols(self, artifact): + objcopy_path = "llvm-objcopy-18" + cmd = [ + objcopy_path, + ] + for old_sym, new_sym in artifact.rename_symbols.items(): + cmd += [ + "--redefine-sym", + f"{old_sym}={new_sym}", + ] + cmd += [artifact.filename] + return [ShellCompilationCommand(cmd)] + + def _prefix_symbols(self, artifact, prefix): + objcopy_path = "llvm-objcopy-18" + nm_path = "llvm-nm-18" + symbol_map_file = artifact.filename + ".symbol_map" + + # Extract defined symbols and create symbol map + nm_cmd = [ + "sh", + "-c", + f"{nm_path} --defined-only --extern-only {artifact.filename} | " + f"awk '{{print $3 \" {prefix}\" $3}}' > {symbol_map_file}", + ] + + # Apply the renaming using the symbol map + objcopy_cmd = [ + objcopy_path, + "--redefine-syms=" + symbol_map_file, + artifact.filename, + ] + + return [ShellCompilationCommand(nm_cmd), ShellCompilationCommand(objcopy_cmd)] + + +class ArchiveCompilationRule(CompilationRule): + def __init__(self, peano_dir, *args, **kwargs): + self.peano_dir = peano_dir + super().__init__(*args, **kwargs) + + def matches(self, artifacts): + return any(artifacts.get_worklist(KernelArchiveArtifact)) + + def compile(self, artifacts): + """Create an archive (.a) from compiled object files""" + worklist = artifacts.get_worklist(KernelArchiveArtifact) + commands = [] + for artifact in worklist: + # Get archive filename from method + archive_path = artifact.filename + object_files = [ + dep.filename + for dep in artifact.dependencies + if isinstance(dep, KernelObjectArtifact) + ] + + # Try to find ar tool from PEANO, then system + ar_path = None + + if self.peano_dir: + # Peano has llvm-ar for archiving + peano_ar = Path(self.peano_dir) / "bin" / "llvm-ar" + if os.path.exists(peano_ar): + ar_path = peano_ar + + if ar_path is None: + raise RuntimeError( + "Could not find 'ar' tool in PEANO installation or system PATH" + ) + + cmd = [str(ar_path), "rcs", archive_path] + object_files + commands.append(ShellCompilationCommand(cmd)) + + # Check for duplicate symbol definitions in the archive + check_cmd = [ + "sh", + "-c", + f"nm {archive_path} | grep ' [TDR] ' | awk '{{print $3}}' | sort | uniq -d | " + f'if read sym; then echo "Error: Duplicate symbol in archive: $sym" >&2; exit 1; fi', + ] + commands.append(ShellCompilationCommand(check_cmd)) + + artifact.available = True + + return commands diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index fb4b2c4d..84db532b 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -34,12 +34,8 @@ from abc import ABC, abstractmethod from pathlib import Path import os.path -import zlib import logging import subprocess -import importlib.util -from contextlib import nullcontext -from aie.extras.context import mlir_mod_ctx import sys # Global Functions @@ -207,82 +203,6 @@ class SourceArtifact(CompilationArtifact): pass -class FullElfArtifact(CompilationArtifact): - def __init__(self, filename, mlir_input, dependencies): - if mlir_input not in dependencies: - dependencies = dependencies + [mlir_input] - super().__init__(filename, dependencies) - self.mlir_input = mlir_input - - -class XclbinArtifact(CompilationArtifact): - def __init__( - self, - filename, - mlir_input, - dependencies, - kernel_name="MLIR_AIE", - extra_flags=None, - xclbin_input=None, - ): - if mlir_input not in dependencies: - dependencies = dependencies + [mlir_input] - super().__init__(filename, dependencies) - self.mlir_input = mlir_input - self.kernel_name = kernel_name - self.extra_flags = extra_flags if extra_flags is not None else [] - self.xclbin_input = xclbin_input - - -class InstsBinArtifact(CompilationArtifact): - def __init__(self, filename, mlir_input, dependencies, extra_flags=None): - self.mlir_input = mlir_input - if mlir_input not in dependencies: - dependencies = dependencies + [mlir_input] - super().__init__(filename, dependencies) - self.extra_flags = extra_flags if extra_flags is not None else [] - - -class KernelObjectArtifact(CompilationArtifact): - def __init__( - self, - filename, - dependencies, - extra_flags=None, - rename_symbols=None, - prefix_symbols=None, - ): - super().__init__(filename, dependencies) - self.extra_flags = extra_flags if extra_flags is not None else [] - self.rename_symbols = rename_symbols if rename_symbols is not None else {} - self.prefix_symbols = prefix_symbols - - -class KernelArchiveArtifact(CompilationArtifact): - pass - - -class PythonGeneratedMLIRArtifact(CompilationArtifact): - def __init__( - self, - filename, - import_path, - callback_fn, - callback_args=None, - callback_kwargs=None, - requires_context=False, - uses_kernel_archive=False, - kernel_archive=None, - ): - self.import_path = import_path - self.callback_fn = callback_fn - self.callback_args = callback_args if callback_args is not None else [] - self.callback_kwargs = callback_kwargs if callback_kwargs is not None else {} - self.requires_context = requires_context - dependencies = [SourceArtifact(import_path)] - super().__init__(filename, dependencies=dependencies) - - # Compilation Command # ########################################################################## @@ -354,329 +274,31 @@ def compile(self, artifacts: CompilationArtifactGraph) -> list[CompilationComman pass -class GenerateMLIRFromPythonCompilationRule(CompilationRule): - def matches(self, graph): - return any(graph.get_worklist(PythonGeneratedMLIRArtifact)) +class BatchRule(CompilationRule): + """ + A helper class for rules that process all available artifacts of a certain type in one go. + Subclasses should define `artifact_type` and implement `create_commands`. + """ - def compile(self, graph): - """Generate MLIR from a Python callback that uses the MLIR bindings""" - commands = [] - worklist = graph.get_worklist(PythonGeneratedMLIRArtifact) - for artifact in worklist: - new_artifact = SourceArtifact(artifact.filename) - # To make Python capture variables in this closure by value, not by reference, use default arguments - callback = lambda new_artifact=new_artifact, import_path=artifact.import_path, callback_fn=artifact.callback_fn, callback_args=artifact.callback_args, callback_kwargs=artifact.callback_kwargs, requires_context=artifact.requires_context: self.generate_mlir( - new_artifact, - import_path, - callback_fn, - callback_args, - callback_kwargs, - requires_context, - ) - commands.append(PythonCallbackCompilationCommand(callback)) - new_artifact.available = True - graph.replace(artifact, new_artifact) - return commands + artifact_type = None - @staticmethod - def generate_mlir( - output_artifact, - import_path, - callback_fn, - callback_args=None, - callback_kwargs=None, - requires_context=False, - ): - # Import the Python source file - spec = importlib.util.spec_from_file_location( - Path(import_path).name, import_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # We only initiate an MLIR context if requested; otherwise, it is expected that the callback creates the context - ctx_callback = lambda: (mlir_mod_ctx() if requires_context else nullcontext()) - with ctx_callback() as ctx: - callback_function = getattr(module, callback_fn) - mlir_code = callback_function(*callback_args, **callback_kwargs) - # Stringify the generated MLIR - if requires_context: - mlir_code = str(ctx.module) - else: - mlir_code = str(mlir_code) - - with open(output_artifact.filename, "w") as f: - f.write(mlir_code) - - -class AieccCompilationRule(CompilationRule, ABC): - def __init__(self, build_dir, peano_dir, mlir_aie_dir, *args, **kwargs): - self.build_dir = build_dir - self.aiecc_path = Path(mlir_aie_dir) / "bin" / "aiecc.py" - self.peano_dir = peano_dir - super().__init__(*args, **kwargs) - - -class AieccFullElfCompilationRule(AieccCompilationRule): def matches(self, graph): - return any(graph.get_worklist(FullElfArtifact)) - - def compile(self, graph): - worklist = graph.get_worklist(FullElfArtifact) - commands = [] - - for artifact in worklist: - compile_cmd = [ - "python", - str(self.aiecc_path), - "--no-compile-host", - "--no-xchesscc", - "--no-xbridge", - "--peano", - str(self.peano_dir), - "--dynamic-objFifos", - "--expand-load-pdis", - "--generate-full-elf", - "--full-elf-name", - os.path.abspath(artifact.filename), - os.path.abspath(artifact.mlir_input.filename), - ] - commands.append( - ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) + if self.artifact_type is None: + raise NotImplementedError( + "Subclasses of BatchRule must define artifact_type" ) - artifact.available = True - - return commands - - -class AieccXclbinInstsCompilationRule(AieccCompilationRule): - def matches(self, graph): - return any(graph.get_worklist((XclbinArtifact, InstsBinArtifact))) + return any(graph.get_worklist(self.artifact_type)) def compile(self, graph): - # If there are both xclbin and insts.bin targets based on the same source MLIR code, we can combine them into one single `aiecc.py` invocation. - mlir_sources = set() - mlir_sources_to_xclbins = {} - mlir_sources_to_insts = {} - worklist = graph.get_worklist((XclbinArtifact, InstsBinArtifact)) - for artifact in worklist: - mlir_dependency = artifact.mlir_input - mlir_sources.add(mlir_dependency) - if isinstance(artifact, XclbinArtifact): - mlir_sources_to_xclbins.setdefault(mlir_dependency, []).append(artifact) - elif isinstance(artifact, InstsBinArtifact): - mlir_sources_to_insts.setdefault(mlir_dependency, []).append(artifact) - - commands = [] - # Now we know for each mlir source if we need to generate an xclbin, an insts.bin or both for it - for mlir_source in mlir_sources: - compile_cmd = [ - "python", - str(self.aiecc_path), - "--no-compile-host", - "--no-xchesscc", - "--no-xbridge", - "--peano", - str(self.peano_dir), - "--dynamic-objFifos", - ] - do_compile_xclbin = mlir_source in mlir_sources_to_xclbins - do_compile_insts_bin = mlir_source in mlir_sources_to_insts - if do_compile_xclbin: - first_xclbin = mlir_sources_to_xclbins[mlir_source][ - 0 - ] # TODO: this does not handle the case of multiple xclbins with different kernel names or flags from the same MLIR - compile_cmd += first_xclbin.extra_flags + [ - "--aie-generate-xclbin", - "--xclbin-name=" + os.path.abspath(first_xclbin.filename), - "--xclbin-kernel-name=" + first_xclbin.kernel_name, - ] - if first_xclbin.xclbin_input is not None: - compile_cmd += [ - "--xclbin-input=" - + os.path.abspath(first_xclbin.xclbin_input.filename) - ] - if do_compile_insts_bin: - first_insts_bin = mlir_sources_to_insts[mlir_source][ - 0 - ] # TODO: this does not handle the case of multiple insts.bins with different flags from the same MLIR - if not do_compile_xclbin: - compile_cmd += ["--no-compile"] - compile_cmd += first_insts_bin.extra_flags + [ - "--aie-generate-npu", - "--npu-insts-name=" + os.path.abspath(first_insts_bin.filename), - ] - compile_cmd += [os.path.abspath(mlir_source.filename)] - - # If the MLIR source depends on a kernel archive, pass it to aiecc.py so it can be linked - if ( - isinstance(mlir_source, PythonGeneratedMLIRArtifact) - and "kernel_archive" in mlir_source.callback_kwargs - ): - compile_cmd.append( - os.path.abspath( - os.path.join( - self.build_dir, - mlir_source.callback_kwargs["kernel_archive"], - ) - ) - ) - - commands.append( - ShellCompilationCommand(compile_cmd, cwd=str(self.build_dir)) - ) - - # There may be multiple targets that require an xclbin/insts.bin from the same MLIR with different names; copy them - for sources_to in [mlir_sources_to_xclbins, mlir_sources_to_insts]: - if sources_to.get(mlir_source, [])[1:]: - copy_src = sources_to[mlir_source][0] - for copy_dest in sources_to[mlir_source][1:]: - commands.append( - ShellCompilationCommand( - ["cp", copy_src.filename, copy_dest.filename] - ) - ) - - # Update graph + worklist = graph.get_worklist(self.artifact_type) + commands = self.create_commands(worklist) for artifact in worklist: artifact.available = True - return commands - -class PeanoCompilationRule(CompilationRule): - def __init__(self, peano_dir, mlir_aie_dir, *args, **kwargs): - self.peano_dir = peano_dir - self.mlir_aie_dir = mlir_aie_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any(artifacts.get_worklist(KernelObjectArtifact)) - - def compile(self, artifacts): - clang_path = Path(self.peano_dir) / "bin" / "clang++" - include_path = Path(self.mlir_aie_dir) / "include" - worklist = artifacts.get_worklist(KernelObjectArtifact) - commands = [] - for artifact in worklist: - if len(artifact.dependencies) != 1: - raise RuntimeError( - "Expected exactly one dependency (the C source code) for KernelObjectArtifact" - ) - source_file = artifact.dependencies[0] - if not isinstance(source_file, SourceArtifact): - raise RuntimeError( - "Expected KernelObject dependency to be a C source file" - ) - - cmd = ( - [ - str(clang_path), - "-O2", - "-std=c++20", - "--target=aie2p-none-unknown-elf", - "-Wno-parentheses", - "-Wno-attributes", - "-Wno-macro-redefined", - "-Wno-empty-body", - "-Wno-missing-template-arg-list-after-template-kw", - f"-I{str(include_path)}", - ] - + artifact.extra_flags - + ["-c", source_file.filename, "-o", artifact.filename] - ) - - commands.append(ShellCompilationCommand(cmd)) - if artifact.rename_symbols: - commands.extend(self._rename_symbols(artifact)) - if artifact.prefix_symbols: - commands.extend(self._prefix_symbols(artifact, artifact.prefix_symbols)) - artifact.available = True - - return commands - - def _rename_symbols(self, artifact): - objcopy_path = "llvm-objcopy-18" - cmd = [ - objcopy_path, - ] - for old_sym, new_sym in artifact.rename_symbols.items(): - cmd += [ - "--redefine-sym", - f"{old_sym}={new_sym}", - ] - cmd += [artifact.filename] - return [ShellCompilationCommand(cmd)] - - def _prefix_symbols(self, artifact, prefix): - objcopy_path = "llvm-objcopy-18" - nm_path = "llvm-nm-18" - symbol_map_file = artifact.filename + ".symbol_map" - - # Extract defined symbols and create symbol map - nm_cmd = [ - "sh", - "-c", - f"{nm_path} --defined-only --extern-only {artifact.filename} | " - f"awk '{{print $3 \" {prefix}\" $3}}' > {symbol_map_file}", - ] - - # Apply the renaming using the symbol map - objcopy_cmd = [ - objcopy_path, - "--redefine-syms=" + symbol_map_file, - artifact.filename, - ] - - return [ShellCompilationCommand(nm_cmd), ShellCompilationCommand(objcopy_cmd)] - - -class ArchiveCompilationRule(CompilationRule): - def __init__(self, peano_dir, *args, **kwargs): - self.peano_dir = peano_dir - super().__init__(*args, **kwargs) - - def matches(self, artifacts): - return any(artifacts.get_worklist(KernelArchiveArtifact)) - - def compile(self, artifacts): - """Create an archive (.a) from compiled object files""" - worklist = artifacts.get_worklist(KernelArchiveArtifact) - commands = [] - for artifact in worklist: - # Get archive filename from method - archive_path = artifact.filename - object_files = [ - dep.filename - for dep in artifact.dependencies - if isinstance(dep, KernelObjectArtifact) - ] - - # Try to find ar tool from PEANO, then system - ar_path = None - - if self.peano_dir: - # Peano has llvm-ar for archiving - peano_ar = Path(self.peano_dir) / "bin" / "llvm-ar" - if os.path.exists(peano_ar): - ar_path = peano_ar - - if ar_path is None: - raise RuntimeError( - "Could not find 'ar' tool in PEANO installation or system PATH" - ) - - cmd = [str(ar_path), "rcs", archive_path] + object_files - commands.append(ShellCompilationCommand(cmd)) - - # Check for duplicate symbol definitions in the archive - check_cmd = [ - "sh", - "-c", - f"nm {archive_path} | grep ' [TDR] ' | awk '{{print $3}}' | sort | uniq -d | " - f'if read sym; then echo "Error: Duplicate symbol in archive: $sym" >&2; exit 1; fi', - ] - commands.append(ShellCompilationCommand(check_cmd)) - - artifact.available = True - - return commands + def create_commands(self, artifacts): + """ + Create compilation commands for the given list of artifacts. + Must be implemented by subclasses. + """ + raise NotImplementedError diff --git a/iron/common/test_utils.py b/iron/common/test_utils.py index 07308e03..0b8c8da2 100644 --- a/iron/common/test_utils.py +++ b/iron/common/test_utils.py @@ -4,7 +4,7 @@ import time import numpy as np from ml_dtypes import bfloat16 -from .utils import torch_to_numpy +from .utils import xrt_to_torch import logging from .base import MLIROperator, CompositeOperator from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor @@ -33,22 +33,22 @@ def nearly_equal( def verify_buffer(output, buf_name, reference, rel_tol=0.04, abs_tol=1e-6): errors = [] - expected_np = torch_to_numpy(reference).reshape((-1,)) + expected = reference.reshape((-1,)) output = output.reshape((-1,)) - if len(output) < len(expected_np): + if len(output) < len(expected): # Allow larger buffers - binning may have allocated more space than needed print( - f"Buffer size mismatch for {buf_name}: expected {len(expected_np)}, got {len(output)}" + f"Buffer size mismatch for {buf_name}: expected {len(expected)}, got {len(output)}" ) - errors.extend(i for i in range(abs(len(output) - len(expected_np)))) - compare_len = min(len(output), len(expected_np)) + errors.extend(i for i in range(abs(len(output) - len(expected)))) + compare_len = min(len(output), len(expected)) for i in range(compare_len): - if not nearly_equal(float(output[i]), float(expected_np[i]), rel_tol, abs_tol): + if not nearly_equal(float(output[i]), float(expected[i]), rel_tol, abs_tol): errors.append(i) if len(errors) <= 10: print( - f"Mismatch in {buf_name}[{i}]: expected {float(expected_np[i]):.6f}, got {float(output[i]):.6f}" + f"Mismatch in {buf_name}[{i}]: expected {float(expected[i]):.6f}, got {float(output[i]):.6f}" ) return errors @@ -143,8 +143,10 @@ def run_test( continue if buf_name in output_map: buf = output_map[buf_name] - output_np = buf.numpy() - buf_errors = verify_buffer(output_np, buf_name, expected, rel_tol, abs_tol) + output_torch = xrt_to_torch(buf) + buf_errors = verify_buffer( + output_torch, buf_name, expected, rel_tol, abs_tol + ) if buf_errors: errors[buf_name] = buf_errors else: diff --git a/iron/common/utils.py b/iron/common/utils.py index 9037fbd8..9966b1dd 100644 --- a/iron/common/utils.py +++ b/iron/common/utils.py @@ -21,32 +21,28 @@ } -def torch_to_numpy(tensor: torch.Tensor) -> np.ndarray: - # Detach (to drop grad) and ensure on CPU - t = tensor.detach() - if t.device.type != "cpu": - t = t.cpu() - # Ensure contiguous for safe view operations - if not t.is_contiguous(): - t = t.contiguous() - - if t.dtype == torch.bfloat16: - # View the same memory as uint16, then as NumPy bfloat16 - # This avoids numeric conversion and extra passes over memory. - u16_np = t.view(torch.uint16).numpy() # shares memory - return u16_np.view(np.dtype("bfloat16")) # reinterpret - - return t.numpy() - - -def numpy_to_torch(array: np.ndarray) -> torch.Tensor: - # Ensure contiguous to let from_numpy create a view - if not array.flags["C_CONTIGUOUS"]: - array = np.ascontiguousarray(array) - - if array.dtype == np.dtype("bfloat16"): - # reinterpret the same memory as uint16, then view as torch.bfloat16 - t_u16 = torch.from_numpy(array.view(np.uint16)) - return t_u16.view(torch.bfloat16) # view - - return torch.from_numpy(array) +def xrt_to_torch(xrttensor) -> torch.Tensor: + """ + Convert an XRTTensor (or compatible object with buffer_object()) to a Torch tensor + without intermediate numpy array creation, supporting bfloat16. + """ + dtype_map = { + np.dtype("float32"): torch.float32, + np.dtype("int32"): torch.int32, + np.dtype("int16"): torch.int16, + np.dtype("int8"): torch.int8, + np.dtype("uint8"): torch.uint8, + np.dtype("float16"): torch.float16, + np.dtype(bfloat16): torch.bfloat16, + bfloat16: torch.bfloat16, + } + + torch_dtype = dtype_map.get(xrttensor.dtype) + if torch_dtype is None: + raise ValueError(f"Unsupported dtype: {xrttensor.dtype}") + + xrttensor.to("cpu") + bo = xrttensor.buffer_object() + mem = bo.map() + t = torch.frombuffer(mem, dtype=torch_dtype) + return t.reshape(xrttensor.shape) diff --git a/iron/operators/gemm/op.py b/iron/operators/gemm/op.py index 841fef03..0c087ad8 100644 --- a/iron/operators/gemm/op.py +++ b/iron/operators/gemm/op.py @@ -17,8 +17,6 @@ PythonGeneratedMLIRArtifact, ) -from iron.common.utils import torch_to_numpy, numpy_to_torch - class AIEGEMM(MLIROperator): """AIE-accelerated General Matrix Multiplication (GEMM) layer""" diff --git a/iron/operators/gemv/op.py b/iron/operators/gemv/op.py index 45c691f0..a96910ac 100644 --- a/iron/operators/gemv/op.py +++ b/iron/operators/gemv/op.py @@ -16,7 +16,6 @@ SourceArtifact, PythonGeneratedMLIRArtifact, ) -from iron.common.utils import torch_to_numpy class AIEGEMV(MLIROperator): diff --git a/iron/operators/mha/op.py b/iron/operators/mha/op.py index 950614a8..9fa8c0ba 100644 --- a/iron/operators/mha/op.py +++ b/iron/operators/mha/op.py @@ -17,7 +17,6 @@ SourceArtifact, PythonGeneratedMLIRArtifact, ) -from iron.common.utils import torch_to_numpy, numpy_to_torch class AIEMHA(MLIROperator): diff --git a/iron/operators/rms_norm/op.py b/iron/operators/rms_norm/op.py index 5ca06e88..fe8c8aa9 100644 --- a/iron/operators/rms_norm/op.py +++ b/iron/operators/rms_norm/op.py @@ -17,7 +17,6 @@ SourceArtifact, PythonGeneratedMLIRArtifact, ) -from iron.common.utils import torch_to_numpy class AIERMSNorm(MLIROperator): diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index ec990ede..0fb0969b 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -21,7 +21,6 @@ from iron.operators.gemv.op import AIEGEMV from iron.operators.silu.op import AIESiLU from iron.operators.elementwise_mul.op import AIEElementwiseMul -from iron.common.utils import torch_to_numpy class SwiGLUDecodeCallable: diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index f6565af0..606c082f 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -9,7 +9,7 @@ from ml_dtypes import bfloat16 from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor -from iron.common.utils import torch_to_numpy, numpy_to_torch +from iron.common.utils import xrt_to_torch from iron.operators.swiglu_decode.op import AIESwiGLUDecode from iron.operators.swiglu_decode.reference import generate_golden_reference from iron.common.test_utils import verify_buffer @@ -49,7 +49,7 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): errors = {} # Verify intermediate result - intermediate = numpy_to_torch(op_func.intermediate.numpy()).reshape((1, hidden_dim)) + intermediate = xrt_to_torch(op_func.intermediate).reshape((1, hidden_dim)) errors_intermediate = verify_buffer( intermediate, "intermediate", @@ -62,7 +62,7 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): # Verify output using intermediate result ref_2 = intermediate @ golden_ref["w_down"] - output = numpy_to_torch(output_buf.numpy()).reshape((1, embedding_dim)) + output = xrt_to_torch(output_buf).reshape((1, embedding_dim)) errors_output = verify_buffer(output, "output", ref_2, rel_tol=0.04, abs_tol=0.4) if errors_output: errors["output"] = errors_output diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index 9e7eb387..9a19bb20 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -21,7 +21,6 @@ from iron.operators.gemm.op import AIEGEMM from iron.operators.silu.op import AIESiLU from iron.operators.elementwise_mul.op import AIEElementwiseMul -from iron.common.utils import torch_to_numpy class SwiGLUPrefillCallable: diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 3ae0804a..583dbdcf 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -9,7 +9,7 @@ from ml_dtypes import bfloat16 from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor -from iron.common.utils import torch_to_numpy, numpy_to_torch +from iron.common.utils import xrt_to_torch from iron.operators.swiglu_prefill.op import AIESwiGLUPrefill from iron.operators.swiglu_decode.reference import generate_golden_reference from iron.common.test_utils import verify_buffer @@ -57,16 +57,12 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c errors = {} # Verify intermediate result (left_swished * right) - left_swished = numpy_to_torch(op_func.left_swished.numpy()).reshape( - (seq_len, hidden_dim) - ) - right = numpy_to_torch(op_func.right.numpy()).reshape((seq_len, hidden_dim)) + left_swished = xrt_to_torch(op_func.left_swished).reshape((seq_len, hidden_dim)) + right = xrt_to_torch(op_func.right).reshape((seq_len, hidden_dim)) ref_2 = left_swished * right # Note: intermediate buffer in op_func stores the result of eltwise_mul - intermediate = numpy_to_torch(op_func.intermediate.numpy()).reshape( - (seq_len, hidden_dim) - ) + intermediate = xrt_to_torch(op_func.intermediate).reshape((seq_len, hidden_dim)) errors_2 = verify_buffer( intermediate, "intermediate", ref_2, rel_tol=0.04, abs_tol=0.4 ) @@ -75,7 +71,7 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c # Verify output using intermediate result ref_3 = intermediate @ golden_ref["w_down"] - output = numpy_to_torch(output_buf.numpy()).reshape((seq_len, embedding_dim)) + output = xrt_to_torch(output_buf).reshape((seq_len, embedding_dim)) errors_3 = verify_buffer(output, "output", ref_3, rel_tol=0.04, abs_tol=0.4) if errors_3: errors["output"] = errors_3