//
//  Copyright (c) 2023 Apple Inc. All rights reserved.
//  Provided subject to the LICENSE file in the top level directory.
//

namespace mpsgraph;

// File identifier written into serialized buffers.
// Update after any BC breaking changes.
file_identifier "MP00";

// Datatype for mps-values, stored as a `short`.
enum MPSDataType : short {
  mps_data_type_invalid = 0,

  // Floating point.
  mps_data_type_float16 = 1,
  mps_data_type_float32 = 2,
  mps_data_type_float64 = 3,
  mps_data_type_bfloat16 = 4,

  // Signed integers.
  mps_data_type_int4 = 5,
  mps_data_type_int8 = 6,
  mps_data_type_int16 = 7,
  mps_data_type_int32 = 8,
  mps_data_type_int64 = 9,

  // Unsigned integers. range: [0, UTYPE_MAX]
  mps_data_type_uint4 = 10,
  mps_data_type_uint8 = 11,
  mps_data_type_uint16 = 12,
  mps_data_type_uint32 = 13,
  mps_data_type_uint64 = 14,

  // Boolean and complex.
  mps_data_type_bool = 15,
  mps_data_type_complex_float16 = 16,
  mps_data_type_complex_float32 = 17,
}

// Ops like index.Tensor and index.put are currently implemented as
// Metal kernels for unsupported MPSGraph cases.
enum OpType : short {
  mps_graph,
  metal_kernel
}

// Helper classes to define the number of input and output tensors for a node.
// Not meant to be used directly.
// NOTE(review): the *_id fields look like indices into MPSGraph.mps_values;
// confirm against the serializer.

// A node with one input and one output.
table _MPSNode1x1 {
  input1_id:int;
  output_id:int;
}

// A node with two inputs and one output.
table _MPSNode2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
}

// Two inputs and one output, plus a rounding mode given as a string.
table _MPSDivNode2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
  rounding_mode:string;
}

// Two inputs and one output, plus a float `alpha` value.
table _MPSNodeWithAlpha2x1 {
  input1_id:int;
  input2_id:int;
  output_id:int;
  alpha:float;
}

// A node with three inputs and one output.
table _MPSNode3x1 {
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
}

// A pair of float bounds; referenced by MPSNode.min_max.
table MPSMinMax {
  min_value:float;
  max_value:float;
}

// Shared payload for the 2D pooling ops: one input, two outputs, and the
// kernel / stride / padding / dilation parameters.
table MPSPooling2D {
  input1_id:int;
  kernel_height:int;
  kernel_width:int;
  stride_height:int;
  stride_width:int;
  padding_left:int;
  padding_right:int;
  padding_top:int;
  padding_bottom:int;
  dilation_height:int;
  dilation_width:int;
  ceil_mode:bool;
  count_include_pad:bool;
  divisor_override:int;
  output1_id:int;
  output2_id:int;
}

// Activation ops.
table MPSHardTanh {
  input1_id:int;
  output_id:int;
  min_value:float;
  max_value:float;
}

table MPSGELU {
  input1_id:int;
  output_id:int;
  approximate:string;
}

table MPSLeakyReLU {
  input1_id:int;
  output_id:int;
  negative_slope:float;
}

table MPSSoftmax {
  input1_id:int;
  output_id:int;
  dim:int;
  half_to_float:bool;
}

// Clamp ops.
// This table carries no bounds of its own.
// NOTE(review): the limits presumably travel in MPSNode.min_max; confirm.
table MPSClamp {
  input1_id:int;
  output_id:int;
}

// Reduce ops.
table MPSMean {
  input1_id:int;
  output_id:int;
  num_dims:int;
  dims:[int];
  keep_dims:bool;
}

// Indexing ops.
table MPSIndexSelect {
  input1_id:int;
  output_id:int;
  dim:int;
  index_id:int;
}

table MPSEmbedding {
  input1_id:int;
  input2_id:int;
  output_id:int;
  padding_idx:int;
  scale_grad_by_freq:bool;
  sparse:bool;
}

table MPSIndexTensor {
  input1_id:int;
  indices_id:[int];
  output_id:int;
}

table MPSIndexPut {
  input1_id:int;
  indices_id:[int];
  values_shape:[int];
  values_id:int;
  output_id:int;
}

// Note: `dim` is a `long` here, unlike the `int` dims used by most tables.
table MPSScatter {
  input1_id:int;
  output_id:int;
  dim:long;
  idx_id:int;
  src_id:int;
}

// Shape ops.
table MPSPermute {
  input1_id:int;
  output_id:int;
  num_dims:int;
  perm:[int];
}

// Also used as the payload type for MPSExpand in MPSNodeUnion.
table MPSView {
  input1_id:int;
  output_id:int;
  num_dims:int;
  shape:[int];
}

// Takes a list of input ids rather than a fixed number of inputs.
table MPSCat {
  input_ids:[int];
  output_id:int;
  dim:int;
}

table MPSSqueeze {
  input1_id:int;
  output_id:int;
  dims:[int];
}

table MPSUnsqueeze {
  input1_id:int;
  output_id:int;
  dim:int;
}

table MPSSelect {
  input1_id:int;
  output_id:int;
  dim:int;
  index:int;
}

// All slice parameters are `long`.
table MPSSlice {
  input1_id:int;
  output_id:int;
  dim:long;
  start:long;
  end:long;
  step:long;
}

table MPSPixelShuffle {
  input1_id:int;
  output_id:int;
  upscale_factor:int;
}

// One input, a list of output ids, and the requested split sizes.
table MPSSplitWithSizes {
  input1_id:int;
  output_ids:[int];
  split_sizes:[int];
  dim:int;
}

table MPSCast {
  input1_id:int;
  output_id:int;
  dtype:MPSDataType;
}

// Linear algebra ops.
table MPSAddmm {
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
  beta:float;
  alpha:float;
}

// Constant ops.
// Payload type for both MPSFull and MPSFullLike in MPSNodeUnion.
table _MPSFull {
  input1_id:int;
  output_id:int;
  shape:[int];
  fill_value:float;
  dtype:MPSDataType;
}

// Convolution ops.
// Payload type for both MPSConv2D and MPSDepthwiseConv2D in MPSNodeUnion.
table MPSConv {
  input1_id:int;
  input2_id:int;
  input3_id:int;
  output_id:int;
  stride_x:int;
  stride_y:int;
  dilation_x:int;
  dilation_y:int;
  groups:int;
  padding_left:int;
  padding_right:int;
  padding_top:int;
  padding_bottom:int;
}

// Normalization ops.
// NOTE(review): output2_id is declared before output1_id. Declaration order
// determines where each field is stored, so do not reorder these fields
// unless a BC break is intended.
table MPSBatchNorm {
  input_id:int;
  mean_id:int;
  var_id:int;
  weight_id:int;
  bias_id:int;
  momentum:float;
  epsilon:float;
  output2_id:int;
  output1_id:int;
  output3_id:int;
}

// NOTE(review): as in MPSBatchNorm, output2_id precedes output1_id; keep the
// declared order unless a BC break is intended.
table MPSLayerNorm {
  input1_id:int;
  normalized_shape:[int];
  weight_id:int;
  bias_id:int;
  eps:float;
  output2_id:int;
  output1_id:int;
  output3_id:int;
}

// Pooling ops: see MPSPooling2D, declared earlier in this file.

// Pad ops.
table MPSConstantPadND {
  input1_id:int;
  output_id:int;
  pad:[int];
  value:float;
}

// Range ops.
// Has an output id only; there is no input tensor field.
table MPSArange {
  output_id:int;
  start:float;
  end:float;
  step:float;
  dtype:MPSDataType;
}

// Quant - Dequant ops.
table MPSDequantizePerChannelGroup {
  input1_id:int;
  output_id:int;
  scales_id:int;
  zero_points_id:int;
  quant_min:int;
  quant_max:int;
  dtype:MPSDataType;
  group_size:int;
  output_dtype:MPSDataType;
}

// Every op a node can hold. `Name: Table` entries alias a shared payload
// table under an op-specific name.
// The position of a member determines its type tag, so add new members at
// the end only.
union MPSNodeUnion {
  // Activation ops
  MPSHardTanh,
  // NOTE(review): MPSReLU uses the two-input helper although it is grouped
  // with the single-input activations. Confirm whether input2_id is read
  // before changing it; a different table type changes the serialized layout.
  MPSReLU: _MPSNode2x1,
  MPSGELU,
  MPSLeakyReLU,
  MPSSoftmax,
  MPSLogSoftmax: MPSSoftmax,

  // Binary ops
  MPSAdd: _MPSNodeWithAlpha2x1,
  MPSSub: _MPSNodeWithAlpha2x1,
  MPSMul: _MPSNode2x1,
  MPSDiv: _MPSDivNode2x1,
  MPSFmod: _MPSDivNode2x1,
  MPSRemainder: _MPSDivNode2x1,
  MPSMin: _MPSNode2x1,
  MPSMax: _MPSNode2x1,
  MPSPow: _MPSNode2x1,
  MPSAtan2: _MPSNode2x1,
  MPSBitwiseAnd: _MPSNode2x1,
  MPSBitwiseOr: _MPSNode2x1,
  MPSBitwiseXor: _MPSNode2x1,
  MPSMinimum: _MPSNode2x1,

  // Unary ops
  MPSExp: _MPSNode1x1,
  MPSExp2: _MPSNode1x1,
  MPSReciprocal: _MPSNode1x1,
  MPSSqrt: _MPSNode1x1,
  MPSNeg: _MPSNode1x1,
  MPSLog: _MPSNode1x1,
  MPSLog10: _MPSNode1x1,
  MPSLog2: _MPSNode1x1,
  MPSErf: _MPSNode1x1,
  MPSFloor: _MPSNode1x1,
  MPSCeil: _MPSNode1x1,
  MPSRsqrt: _MPSNode1x1,
  MPSSigmoid: _MPSNode1x1,
  MPSSin: _MPSNode1x1,
  MPSSign: _MPSNode1x1,
  MPSCos: _MPSNode1x1,
  MPSTan: _MPSNode1x1,
  MPSAbs: _MPSNode1x1,
  MPSAsin: _MPSNode1x1,
  MPSAcos: _MPSNode1x1,
  MPSAtan: _MPSNode1x1,
  MPSSinh: _MPSNode1x1,
  MPSCosh: _MPSNode1x1,
  MPSTanh: _MPSNode1x1,
  MPSAsinh: _MPSNode1x1,
  MPSAcosh: _MPSNode1x1,
  MPSAtanh: _MPSNode1x1,
  MPSBitwiseNot: _MPSNode1x1,
  MPSIsnan: _MPSNode1x1,
  MPSIsinf: _MPSNode1x1,
  MPSRound: _MPSNode1x1,
  MPSLogicalNot: _MPSNode1x1,

  // Linear algebra ops
  MPSMatMul: _MPSNode2x1,
  MPSAddmm,

  // Constant ops
  MPSFull: _MPSFull,
  MPSFullLike: _MPSFull,

  // Clamp ops
  MPSClamp,
  MPSWhere: _MPSNode3x1,

  // Indexing ops
  MPSIndexSelect,
  MPSEmbedding,
  MPSIndexTensor,
  MPSIndexPut,
  MPSScatter,

  // Reduce ops
  MPSMean,

  // Shape ops
  MPSPermute,
  MPSView,
  MPSExpand: MPSView,
  MPSCat,
  MPSSqueeze,
  MPSUnsqueeze,
  MPSSelect,
  MPSSlice,
  MPSPixelShuffle,
  MPSSplitWithSizes,
  MPSCast,

  // Convolution ops
  MPSConv2D: MPSConv,
  MPSDepthwiseConv2D: MPSConv,

  // Comparison ops
  MPSEq: _MPSNode2x1,
  MPSNe: _MPSNode2x1,
  MPSGe: _MPSNode2x1,
  MPSGt: _MPSNode2x1,
  MPSLe: _MPSNode2x1,
  MPSLt: _MPSNode2x1,

  // Normalization ops
  MPSBatchNorm,
  MPSLayerNorm,

  // Pooling ops
  MPSMaxPool2DWithIndices: MPSPooling2D,
  MPSAvgPool2D: MPSPooling2D,

  // Pad ops
  MPSConstantPadND,

  // Range ops
  MPSArange,

  // Quant-Dequant ops
  MPSDequantizePerChannelGroup,
}

// One graph node: the op payload plus an MPSMinMax value.
// NOTE(review): min_max is presumably only meaningful for clamp-style ops,
// since MPSClamp has no bound fields of its own; confirm.
table MPSNode {
  mpsnode_union:MPSNodeUnion;
  min_max:MPSMinMax;
}

// Taken from executorch.
// Data buffer abstraction.
// Deprecated.
table Buffer {
  storage:[ubyte] (force_align: 16);
}

// A tensor value: data type, rank, dimensions, and the size and location of
// its constant data.
table MPSTensor {
  datatype:MPSDataType;
  num_dims:int;
  dims:[int];
  constant_buffer_size:uint64;
  constant_buffer:Buffer; // deprecated
  segment_offset:uint64;
}

table DataSegment {
  // Segment offsets are relative to the segment base offset provided in
  // the extended file header. Segments will typically be aligned in a
  // way to make it possible to use mmap() to load them.
  offset:uint64;

  // The size in bytes of valid data starting at the offset. The segment
  // data may be followed by padding before the segment that follows it,
  // to make it easier to use mmap().
  size:uint64;
}

// Root object of the serialized file.
table MPSGraph {
  // Schema version.
  version:string;
  mps_nodes:[MPSNode];
  mps_values:[MPSTensor];
  input_ids:[int];
  output_ids:[int];
  constant_ids:[int];
  graph_type:OpType;
  constant_segment:DataSegment;
}

root_type MPSGraph;