Home
last modified time | relevance | path

Searched refs:axis_tensor (Results 1 – 12 of 12) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/core/kernels/
H A Dunique_op.cc100 const Tensor& axis_tensor = context->input(1); in Compute() local
101 OP_REQUIRES(context, TensorShapeUtils::IsVector(axis_tensor.shape()), in Compute()
104 context, axis_tensor.NumElements() <= 1, in Compute()
107 if (axis_tensor.NumElements() == 0) { in Compute()
112 (axis_tensor.dtype() == DT_INT32 || in Compute()
113 axis_tensor.dtype() == DT_INT64), in Compute()
116 DataTypeString(axis_tensor.dtype()))); in Compute()
117 if (axis_tensor.dtype() == DT_INT32) { in Compute()
118 axis = internal::SubtleMustCopy(axis_tensor.scalar<int32>()()); in Compute()
120 axis = internal::SubtleMustCopy(axis_tensor.scalar<int64_t>()()); in Compute()
H A Dgather_op.cc67 const Tensor& axis_tensor = c->input(2); in Compute() local
68 OP_REQUIRES(c, TensorShapeUtils::IsScalar(axis_tensor.shape()), in Compute()
71 if (axis_tensor.dtype() == DT_INT32) { in Compute()
72 axis = axis_tensor.scalar<int32>()(); in Compute()
73 } else if (axis_tensor.dtype() == DT_INT64) { in Compute()
74 axis = axis_tensor.scalar<int64_t>()(); in Compute()
/aosp_15_r20/external/tensorflow/tensorflow/lite/testing/op_tests/
H A Droll.py112 axis_tensor = tf.compat.v1.placeholder(
114 outs = tf.roll(input_tensor, shift_tensor, axis_tensor)
115 return [input_tensor, shift_tensor, axis_tensor], [outs]
/aosp_15_r20/external/tensorflow/tensorflow/lite/experimental/mlir/testing/op_tests/
H A Droll.py116 axis_tensor = tf.compat.v1.placeholder(
118 outs = tf.roll(input_tensor, shift_tensor, axis_tensor)
119 return [input_tensor, shift_tensor, axis_tensor], [outs]
/aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/
H A Dreverse.cc77 const TfLiteTensor* axis_tensor; in Eval() local
79 GetInputSafe(context, node, kAxisTensor, &axis_tensor)); in Eval()
80 int axis = GetTensorData<int32_t>(axis_tensor)[0]; in Eval()
H A Dcumsum.cc57 const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor); in Eval() local
63 int axis = *GetTensorData<int>(axis_tensor); in Eval()
/aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/nnapi/
H A Dnnapi_delegate.cc617 const auto& axis_tensor = context->tensors[node->inputs->data[2]]; in ComputeSplitVUnknownSplitSize() local
631 int axis = axis_tensor.data.i32[0]; in ComputeSplitVUnknownSplitSize()
2328 const auto& axis_tensor = context->tensors[node->inputs->data[1]]; in Validate() local
2329 if (axis_tensor.type == kTfLiteInt64) { in Validate()
2331 axis_tensor.allocation_type == kTfLiteMmapRo && in Validate()
2332 *axis_tensor.data.i64 <= std::numeric_limits<int32_t>::max() && in Validate()
2333 *axis_tensor.data.i64 >= std::numeric_limits<int32_t>::min(), in Validate()
2340 Expect(axis_tensor.type == kTfLiteInt32, in Validate()
5894 const TfLiteTensor& axis_tensor = context->tensors[axis_id]; in AddOpsAndTensors() local
5895 switch (axis_tensor.type) { in AddOpsAndTensors()
[all …]
/aosp_15_r20/external/tensorflow/tensorflow/core/ops/
H A Darray_ops.cc1021 const Tensor* axis_tensor = c->input_tensor(1); in __anon38bbb0e81402() local
1022 if (axis_tensor != nullptr && c->RankKnown(input)) { in __anon38bbb0e81402()
1025 if (axis_tensor->dtype() == DT_INT32) { in __anon38bbb0e81402()
1026 axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements()); in __anon38bbb0e81402()
1029 AsInt64<int64_t>(axis_tensor, axis_tensor->NumElements()); in __anon38bbb0e81402()
/aosp_15_r20/external/pytorch/torch/ao/nn/quantized/reference/modules/
H A Drnn.py125 axis_tensor = (
130 self.register_buffer(key + "_axis", axis_tensor)
/aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/
H A Darithmetic_optimizer.cc4214 Tensor axis_tensor; in IsAxis0() local
4215 if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor)) in IsAxis0()
4217 if (axis_tensor.NumElements() != 1) return false; in IsAxis0()
4218 if (axis_tensor.dtype() == DT_INT32) { in IsAxis0()
4219 return axis_tensor.flat<int32>()(0) == 0; in IsAxis0()
4220 } else if (axis_tensor.dtype() == DT_INT64) { in IsAxis0()
4221 return axis_tensor.flat<int64_t>()(0) == 0; in IsAxis0()
H A Dconstant_folding.cc3832 Tensor axis_tensor; in GetConcatAxis() local
3833 if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) { in GetConcatAxis()
3836 *axis = axis_tensor.dtype() == DT_INT64 in GetConcatAxis()
3837 ? static_cast<int>(axis_tensor.scalar<int64_t>()()) in GetConcatAxis()
3838 : axis_tensor.scalar<int32>()(); in GetConcatAxis()
/aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/
H A Dmodel_builder.cc781 const TfLiteTensor* axis_tensor = reader->GetInputTensor(1); in Parse() local
783 const int tflite_axis = GetTensorData<int32_t>(axis_tensor)[0]; in Parse()
2275 const TfLiteTensor* axis_tensor = reader->GetInputTensor(0); in Parse() local
2278 ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis)); in Parse()
2317 const TfLiteTensor* axis_tensor = reader->GetInputTensor(2); in Parse() local
2320 ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis)); in Parse()