#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import logging
from typing import ClassVar, Dict, final, List, Tuple

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    get_node_visitors,
    NodeVisitor,
    process_output_node,
    process_placeholder_nodes,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    Buffer,
    DataSegment,
    MPSGraph,
    MPSTensor,
    OpType,
)
from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
    convert_to_flatbuffer,
)
from executorch.exir._serialize._program import Cord
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from torch.export.exported_program import ExportedProgram

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


@final
class MPSBackend(BackendDetails):
    """ExecuTorch backend that serializes an edge program into an MPSGraph payload.

    Layout of the bytes produced by ``preprocess``::

        [header: EXPECTED_LENGTH bytes]
            bytes [0, 4)            : zero
            MAGIC_IX                : EXPECTED_MAGIC
            DATA_SEGMENT_OFFSET_IX  : data segment offset (8 bytes, little-endian)
            DATA_SEGMENT_SIZE_IX    : data segment size   (8 bytes, little-endian)
        [MPSGraph flatbuffer]
        [zero padding so the data segment starts on a 16-byte boundary]
        [constant data segment, zero-padded to a multiple of 16 bytes]

    The recorded data segment size includes the segment's trailing padding.
    """

    @staticmethod
    def slice_len_max(s: slice) -> int:
        """Return ``(s.stop - s.start) // step`` clamped to at least 1.

        ``step`` defaults to 1 when ``s.step`` is ``None``. For the unit-step
        header slices below, this is the width of the field in bytes.

        Raises:
            AssertionError: If ``s.start`` or ``s.stop`` is ``None``.
        """
        assert s.start is not None
        assert s.stop is not None
        step = 1 if s.step is None else s.step
        return max((s.stop - s.start) // step, 1)

    # Byte ranges of the header fields within the processed payload.
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
    DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)

    # Magic bytes stored at MAGIC_IX, right after the 4 leading zero bytes of
    # the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"MP00"

    # The length of the header in bytes: 4 leading zero bytes plus the width
    # of each header field. Calling the staticmethod object directly inside
    # the class body relies on staticmethod being callable (Python 3.10+).
    EXPECTED_LENGTH: ClassVar[int] = (
        4
        + slice_len_max(MAGIC_IX)
        + slice_len_max(DATA_SEGMENT_OFFSET_IX)
        + slice_len_max(DATA_SEGMENT_SIZE_IX)
    )

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Lower ``edge_program`` to the serialized MPS delegate payload.

        Args:
            edge_program: The edge-dialect program to lower.
            compile_specs: Backend options. Only the ``"use_fp16"`` key is
                read; the first byte of its value is interpreted as a boolean.
                FP16 conversion is enabled when the key is absent.

        Returns:
            A ``PreprocessResult`` whose ``processed_bytes`` follow the layout
            described in the ``MPSBackend`` docstring.

        Raises:
            RuntimeError: If the graph contains a node kind, or a
                ``call_function`` target, that the MPS delegate does not
                support.
        """
        # The EdgeIR nodes are processed in the following order:
        # 1. Process first the input feeds to the graph (in the same
        #    order as args from forward(*args)), and generate a unique
        #    id for each input placeholder. Each input id is appended to
        #    `input_ids` array from the FlatBuffer schema.
        # 2. Process the nodes in the graph (e.g `call_function`). For each
        #    EdgeIR node, create an equivalent MPS node in the FlatBuffer,
        #    based on which the MPSGraph is constructed at runtime. During
        #    this process, any visited constant in the EdgeIR is added to the
        #    final MPS FlatBuffer schema. Each constant id is appended to the
        #    `constant_ids` FlatBuffer schema.
        # 3. After all the inputs, nodes and constants are added to the
        #    FlatBuffer graph, process the `output` nodes and add their id to
        #    the `output_ids` array in the schema.
        mps_graph = MPSGraph(
            version="0",
            mps_nodes=[],
            mps_values=[],
            input_ids=[],
            output_ids=[],
            constant_ids=[],
            graph_type=OpType.mps_graph,
            constant_segment=DataSegment(0, 0),
        )

        convert_model_to_fp16 = True
        for spec in compile_specs:
            if spec.key == "use_fp16":
                # A later "use_fp16" spec overrides an earlier one.
                convert_model_to_fp16 = bool(bytes(spec.value)[0])

        logging.debug("Convert model to FP16: %s", convert_model_to_fp16)

        node_visitors = get_node_visitors(edge_program, convert_model_to_fp16)
        # isEnabledFor also honors logging.disable(), unlike comparing levels.
        debug_enabled = logging.root.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            edge_program.graph.print_tabular()

        process_placeholder_nodes(
            edge_program,
            edge_program.graph_module,
            mps_graph,
            node_visitors["placeholder"],
        )

        op_handler = {
            "call_function": MPSBackend.handle_call_function,
            "placeholder": MPSBackend.handle_placeholder,
            "output": MPSBackend.handle_output,
            "get_attr": MPSBackend.handle_get_attr,
        }

        for node in edge_program.graph_module.graph.nodes:
            handler = op_handler.get(node.op)
            if handler is None:
                raise RuntimeError(f"{node.op} is not supported in MPS")
            handler(edge_program, node_visitors, node, mps_graph)

        segment_data, mps_graph = _extract_constant_segment(mps_graph)
        if debug_enabled:
            pretty_print(mps_graph)

        # Pad the constant segment so its total size is a multiple of 16.
        padding_length = _padding_required(len(segment_data), 16)
        if padding_length > 0:
            segment_data.append(b"\x00" * padding_length)

        graph_bytes = convert_to_flatbuffer(mps_graph)

        # The data segment follows the header and the flatbuffer, rounded up
        # to a 16-byte boundary (the segment itself starts unaligned; see
        # _extract_constant_segment).
        data_segment_offset: int = MPSBackend.EXPECTED_LENGTH + len(graph_bytes)
        graph_padding_length = _padding_required(data_segment_offset, 16)
        data_segment_offset += graph_padding_length
        data_segment_size = len(segment_data)

        header: bytes = (
            b"\x00\x00\x00\x00"
            + MPSBackend.EXPECTED_MAGIC
            + data_segment_offset.to_bytes(8, byteorder="little")
            + data_segment_size.to_bytes(8, byteorder="little")
        )
        assert len(header) == MPSBackend.EXPECTED_LENGTH

        # Combine the header, mps_graph and segment data.
        combined = Cord()
        combined.append(header)
        combined.append(graph_bytes)
        if graph_padding_length > 0:
            combined.append(b"\x00" * graph_padding_length)
        # Append the segment data to the end of the mps graph
        combined.append(segment_data)

        return PreprocessResult(processed_bytes=bytes(combined))

    @staticmethod
    def handle_call_function(
        _: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Lower one ``call_function`` node through its registered visitor.

        Visitors are looked up by ``node.target.__name__``. If the node's
        ``delegation_tag`` metadata contains ``"metal_kernel"``, the whole
        graph's ``graph_type`` is switched to ``OpType.metal_kernel``.

        Raises:
            RuntimeError: If no visitor is registered for the node's target.
                The partially built graph is pretty-printed first.
        """
        target_name = node.target.__name__
        logging.info("Visiting: %s, %s", node, target_name)
        if (
            "delegation_tag" in node.meta
            and "metal_kernel" in node.meta["delegation_tag"]
        ):
            logging.info(
                "Node '%s' was marked as a Metal kernel by the MPSPartitioner!",
                target_name,
            )
            mps_graph.graph_type = OpType.metal_kernel

        if target_name in node_visitors:
            node_visitors[target_name].define_node(node, mps_graph)
        else:
            pretty_print(mps_graph)
            raise RuntimeError(
                f"For {node}, {node.op}:{target_name} is not supported in MPS delegate"
            )

    @staticmethod
    def handle_placeholder(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """No-op: placeholders need no handling inside the node loop.

        ``preprocess`` already calls ``process_placeholder_nodes`` before the
        loop, and constants are handled directly when visiting the nodes that
        use them.
        """
        pass

    @staticmethod
    def handle_output(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Register every value returned by the ``output`` node as a graph output.

        Each element of ``node.args`` is iterated as a collection of output
        nodes. Each of those is passed to ``process_output_node`` together with
        the visitor registered under ``node.op``.
        """
        for output_nodes in node.args:
            for output_node in output_nodes:
                process_output_node(output_node, mps_graph, node_visitors[node.op])

    @staticmethod
    def handle_get_attr(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """No-op: ``get_attr`` nodes add nothing to the MPSGraph here."""
        pass


def _padding_required(offset: int, alignment: int) -> int:
    """Returns the padding required to align `offset` to `alignment`."""
    remainder: int = offset % alignment
    if remainder != 0:
        return alignment - remainder
    return 0


def _extract_constant_segment(mps_graph: MPSGraph) -> Tuple[Cord, MPSGraph]:
    """Move constant tensor data out of the MPSGraph into a separate segment.

    For every value with a non-zero ``constant_buffer_size``:
    - its ``constant_buffer.storage`` bytes are appended to the returned
      segment;
    - its buffer is replaced with an empty one;
    - its ``segment_offset`` is set to where those bytes start within the
      segment.

    The graph is modified in place and also returned.
    """
    # Note that the beginning of the segment data is not aligned; the caller
    # must place the segment at an aligned offset.
    segment_data = Cord()
    offset = 0
    for tensor in mps_graph.mps_values:
        if tensor.constant_buffer_size > 0:
            storage = tensor.constant_buffer.storage
            # The buffer is already force aligned, so no padding is inserted
            # between consecutive constants.
            segment_data.append(storage)
            # Reset buffer to empty
            tensor.constant_buffer = Buffer(storage=b"")
            tensor.segment_offset = offset
            # Advance by the bytes actually appended, so each offset points at
            # its own data even if `storage` and `constant_buffer_size` differ.
            offset += len(storage)

    return segment_data, mps_graph


def tensor_to_str(mps_tensor: MPSTensor) -> str:
    """Return a one-line, human-readable summary of ``mps_tensor``."""
    # `!s` forces str() so enum fields render exactly as str() would.
    return (
        "MPSTensor("
        f"datatype={mps_tensor.datatype!s}, "
        f"num_dims={mps_tensor.num_dims!s}, "
        f"dims={mps_tensor.dims!s}, "
        f"constant_buffer_size={mps_tensor.constant_buffer_size!s}, "
        f"segment_offset={mps_tensor.segment_offset!s}"
        ")"
    )


def pretty_print(mps_graph: MPSGraph) -> None:
    """Log every field of ``mps_graph`` at INFO level, one entry per line."""
    logging.info("Serialized MPSGraph:")
    logging.info(f" Version: {mps_graph.version}")
    logging.info(" MPS nodes: ")
    for i, mps_node in enumerate(mps_graph.mps_nodes):
        logging.info(f" [{i}]: {mps_node}")
    logging.info(" MPS values: ")
    for i, mps_value in enumerate(mps_graph.mps_values):
        logging.info(f" [{i}]: {tensor_to_str(mps_value)}")
    logging.info(" Input ids:")
    for in_id in mps_graph.input_ids:
        logging.info(f" {in_id}")
    logging.info(" Constant ids:")
    for constant_id in mps_graph.constant_ids:
        logging.info(f" {constant_id}")
    logging.info(" Output ids:")
    for out_id in mps_graph.output_ids:
        logging.info(f" {out_id}")
    logging.info(f" Constant segment: {mps_graph.constant_segment}")