xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/QuantDequant.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2//  Copyright (c) 2024 Apple Inc. All rights reserved.
3//  Provided subject to the LICENSE file in the top level directory.
4//
5
6#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
7
8namespace executorch {
9namespace backends {
10namespace mps {
11namespace delegate {
12
/**
 * Adds a per-channel-group dequantize node to the MPSGraph being built.
 *
 * Reads the MPSDequantizePerChannelGroup payload from `nodePtr`, looks up the
 * input and scales tensors by id, and registers the dequantized result under
 * the node's output id in `_idToMPSGraphTensor`. The result is requested as
 * MPSDataTypeFloat16.
 *
 * NOTE(review): `zero_points_id()` is only logged here; the dequantize call is
 * fed a constant scalar 0 zero point of type MPSDataTypeInt4 instead of the
 * serialized zero-points tensor. Presumably the quantization scheme is
 * symmetric -- confirm against the exporter before relying on zero points.
 *
 * @param nodePtr Serialized graph node; expected to hold an
 *                MPSDequantizePerChannelGroup union member.
 * @return Error::Ok when the node was added; Error::NotImplemented when the
 *         running OS does not provide the required dequantize API.
 */
Error
MPSGraphBuilder::mpsDequantizePerChannelGroupOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSDequantizePerChannelGroup();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->scales_id(),
    graphNode->zero_points_id(),
    graphNode->output_id()
  );

  // Runtime gate: bail out before touching the graph on older OS versions.
  ET_CHECK_OR_RETURN_ERROR(
    is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS),
    NotImplemented,
    "[ERROR] Operation %s is supported starting with macOS 15.0+ | iOS 18.0 + | iPadOS 18+ | tvOS 18+ | visionOS 2.0+ !",
    mpsgraph::EnumNameMPSNodeUnion(nodePtr->mpsnode_union_type()));

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* scalesTensor = getMPSGraphTensor(graphNode->scales_id());

  // The @available guard is still needed so the compiler accepts the call to
  // the newer MPSGraph API, even though the runtime check above already ran.
  if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, *)) {
    // Constant zero point (see NOTE(review) in the header comment).
    MPSGraphTensor *zpTensor = [_mpsGraph constantWithScalar:0
                                                    dataType:MPSDataTypeInt4];
    MPSGraphTensor *wDqTensor = [_mpsGraph dequantizeTensor:inputTensor
                                                scaleTensor:scalesTensor
                                            zeroPointTensor:zpTensor
                                                   dataType:MPSDataTypeFloat16
                                                       name:nil];
    _idToMPSGraphTensor[graphNode->output_id()] = wDqTensor;
  } else {
    // Previously this branch stored a nil output tensor and still reported
    // Error::Ok, which would hide the failure until a later node consumed the
    // nil tensor. Report the unsupported platform explicitly instead.
    ET_LOG(
      Error,
      "%s: dequantize API is unavailable on this OS version",
      __FUNCTION__);
    return Error::NotImplemented;
  }

  return Error::Ok;
}
48
49} // namespace delegate
50} // namespace mps
51} // namespace backends
52} // namespace executorch
53