1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport dataclasses 4*da0073e9SAndroid Build Coastguard Workerimport typing 5*da0073e9SAndroid Build Coastguard Workerimport unittest 6*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport yaml 9*da0073e9SAndroid Build Coastguard Workerfrom tools.autograd import gen_autograd_functions, load_derivatives 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerfrom torchgen import dest 12*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import CppSignatureGroup, DispatcherSignature 13*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import native_function_manager 14*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen import ( 15*da0073e9SAndroid Build Coastguard Worker get_native_function_declarations, 16*da0073e9SAndroid Build Coastguard Worker get_native_function_schema_registrations, 17*da0073e9SAndroid Build Coastguard Worker LineLoader, 18*da0073e9SAndroid Build Coastguard Worker static_dispatch, 19*da0073e9SAndroid Build Coastguard Worker) 20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 21*da0073e9SAndroid Build Coastguard Worker BackendIndex, 22*da0073e9SAndroid Build Coastguard Worker BackendMetadata, 23*da0073e9SAndroid Build Coastguard Worker DispatchKey, 24*da0073e9SAndroid Build Coastguard Worker FunctionSchema, 25*da0073e9SAndroid Build Coastguard Worker Location, 26*da0073e9SAndroid Build Coastguard Worker NativeFunction, 27*da0073e9SAndroid Build Coastguard Worker OperatorName, 28*da0073e9SAndroid Build Coastguard Worker) 29*da0073e9SAndroid Build Coastguard Workerfrom torchgen.native_function_generation import add_generated_native_functions 30*da0073e9SAndroid Build Coastguard Workerfrom 
torchgen.selective_build.selector import SelectiveBuilder 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Workerclass TestCreateDerivative(unittest.TestCase): 34*da0073e9SAndroid Build Coastguard Worker def test_named_grads(self) -> None: 35*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse( 36*da0073e9SAndroid Build Coastguard Worker "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" 37*da0073e9SAndroid Build Coastguard Worker ) 38*da0073e9SAndroid Build Coastguard Worker native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker derivative = load_derivatives.create_derivative( 41*da0073e9SAndroid Build Coastguard Worker native_function, 42*da0073e9SAndroid Build Coastguard Worker formula="func_backward(grad_x, grad_y)", 43*da0073e9SAndroid Build Coastguard Worker var_names=(), 44*da0073e9SAndroid Build Coastguard Worker available_named_gradients=["grad_x", "grad_y"], 45*da0073e9SAndroid Build Coastguard Worker ) 46*da0073e9SAndroid Build Coastguard Worker self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"}) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker def test_non_differentiable_output(self) -> None: 49*da0073e9SAndroid Build Coastguard Worker specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" 50*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse(specification) 51*da0073e9SAndroid Build Coastguard Worker native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker _, differentiability_info = load_derivatives.create_differentiability_info( 54*da0073e9SAndroid Build Coastguard Worker defn_dict={ 55*da0073e9SAndroid Build Coastguard Worker "name": specification, 
56*da0073e9SAndroid Build Coastguard Worker "dispatch": {"Default": {"a": "grads[0]", "b": "grads[2]"}}, 57*da0073e9SAndroid Build Coastguard Worker }, 58*da0073e9SAndroid Build Coastguard Worker functions_by_signature={schema.signature(): [native_function]}, 59*da0073e9SAndroid Build Coastguard Worker functions_by_schema={specification: native_function}, 60*da0073e9SAndroid Build Coastguard Worker op_counter=typing.Counter[str](), 61*da0073e9SAndroid Build Coastguard Worker used_dispatch_keys=set(), 62*da0073e9SAndroid Build Coastguard Worker ) 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual( 65*da0073e9SAndroid Build Coastguard Worker differentiability_info["Default"].available_named_gradients, 66*da0073e9SAndroid Build Coastguard Worker # grad_y is not present because y is a 67*da0073e9SAndroid Build Coastguard Worker # bool and thus not differentiable. 68*da0073e9SAndroid Build Coastguard Worker ["grad_x", "grad_z"], 69*da0073e9SAndroid Build Coastguard Worker ) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker def test_indexed_grads(self) -> None: 72*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse( 73*da0073e9SAndroid Build Coastguard Worker "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" 74*da0073e9SAndroid Build Coastguard Worker ) 75*da0073e9SAndroid Build Coastguard Worker native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker derivative = load_derivatives.create_derivative( 78*da0073e9SAndroid Build Coastguard Worker native_function, 79*da0073e9SAndroid Build Coastguard Worker formula="func_backward(grads[0], grads[1])", 80*da0073e9SAndroid Build Coastguard Worker var_names=(), 81*da0073e9SAndroid Build Coastguard Worker available_named_gradients=["grad_x", "grad_y"], 82*da0073e9SAndroid Build Coastguard Worker ) 
83*da0073e9SAndroid Build Coastguard Worker self.assertSetEqual(derivative.named_gradients, set()) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker def test_named_grads_and_indexed_grads(self) -> None: 86*da0073e9SAndroid Build Coastguard Worker specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" 87*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse(specification) 88*da0073e9SAndroid Build Coastguard Worker native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 91*da0073e9SAndroid Build Coastguard Worker RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"' 92*da0073e9SAndroid Build Coastguard Worker ): 93*da0073e9SAndroid Build Coastguard Worker load_derivatives.create_differentiability_info( 94*da0073e9SAndroid Build Coastguard Worker defn_dict={ 95*da0073e9SAndroid Build Coastguard Worker "name": specification, 96*da0073e9SAndroid Build Coastguard Worker # Uh-oh, the derivatives reference gradients by 97*da0073e9SAndroid Build Coastguard Worker # name and by index. 
98*da0073e9SAndroid Build Coastguard Worker "dispatch": { 99*da0073e9SAndroid Build Coastguard Worker "Default": { 100*da0073e9SAndroid Build Coastguard Worker "a": "grad_x", 101*da0073e9SAndroid Build Coastguard Worker "b": "grads[1]", 102*da0073e9SAndroid Build Coastguard Worker } 103*da0073e9SAndroid Build Coastguard Worker }, 104*da0073e9SAndroid Build Coastguard Worker }, 105*da0073e9SAndroid Build Coastguard Worker functions_by_signature={schema.signature(): [native_function]}, 106*da0073e9SAndroid Build Coastguard Worker functions_by_schema={specification: native_function}, 107*da0073e9SAndroid Build Coastguard Worker op_counter=typing.Counter[str](), 108*da0073e9SAndroid Build Coastguard Worker used_dispatch_keys=set(), 109*da0073e9SAndroid Build Coastguard Worker ) 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Workerclass TestGenAutogradFunctions(unittest.TestCase): 113*da0073e9SAndroid Build Coastguard Worker def test_non_differentiable_output_invalid_type(self) -> None: 114*da0073e9SAndroid Build Coastguard Worker specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" 115*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse(specification) 116*da0073e9SAndroid Build Coastguard Worker native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker _, differentiability_info = load_derivatives.create_differentiability_info( 119*da0073e9SAndroid Build Coastguard Worker defn_dict={ 120*da0073e9SAndroid Build Coastguard Worker "name": specification, 121*da0073e9SAndroid Build Coastguard Worker "dispatch": { 122*da0073e9SAndroid Build Coastguard Worker "Default": { 123*da0073e9SAndroid Build Coastguard Worker "a": "grad_x", 124*da0073e9SAndroid Build Coastguard Worker "b": "grad_z", 125*da0073e9SAndroid Build Coastguard Worker } 
126*da0073e9SAndroid Build Coastguard Worker }, 127*da0073e9SAndroid Build Coastguard Worker }, 128*da0073e9SAndroid Build Coastguard Worker functions_by_signature={schema.signature(): [native_function]}, 129*da0073e9SAndroid Build Coastguard Worker functions_by_schema={specification: native_function}, 130*da0073e9SAndroid Build Coastguard Worker op_counter=typing.Counter[str](), 131*da0073e9SAndroid Build Coastguard Worker used_dispatch_keys=set(), 132*da0073e9SAndroid Build Coastguard Worker ) 133*da0073e9SAndroid Build Coastguard Worker definition = gen_autograd_functions.process_function( 134*da0073e9SAndroid Build Coastguard Worker differentiability_info["Default"], 135*da0073e9SAndroid Build Coastguard Worker gen_autograd_functions.FUNCTION_DEFINITION, 136*da0073e9SAndroid Build Coastguard Worker ) 137*da0073e9SAndroid Build Coastguard Worker # grad_z should map to grads[1], not grads[2] because output 1 138*da0073e9SAndroid Build Coastguard Worker # (y) is not differentiable. 139*da0073e9SAndroid Build Coastguard Worker assert "grad_z = grads[2]" not in definition 140*da0073e9SAndroid Build Coastguard Worker assert "grad_z = grads[1]" in definition 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker def test_non_differentiable_output_output_differentiability(self) -> None: 143*da0073e9SAndroid Build Coastguard Worker specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)" 144*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse(specification) 145*da0073e9SAndroid Build Coastguard Worker native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker _, differentiability_info = load_derivatives.create_differentiability_info( 148*da0073e9SAndroid Build Coastguard Worker defn_dict={ 149*da0073e9SAndroid Build Coastguard Worker "name": specification, 150*da0073e9SAndroid Build 
Coastguard Worker "dispatch": { 151*da0073e9SAndroid Build Coastguard Worker "Default": { 152*da0073e9SAndroid Build Coastguard Worker "a": "grad_x", 153*da0073e9SAndroid Build Coastguard Worker "b": "grad_z", 154*da0073e9SAndroid Build Coastguard Worker }, 155*da0073e9SAndroid Build Coastguard Worker "AutogradNestedTensor": { 156*da0073e9SAndroid Build Coastguard Worker "a": "grad_z", 157*da0073e9SAndroid Build Coastguard Worker "b": "grad_x", 158*da0073e9SAndroid Build Coastguard Worker }, 159*da0073e9SAndroid Build Coastguard Worker }, 160*da0073e9SAndroid Build Coastguard Worker "output_differentiability": [True, False, True], 161*da0073e9SAndroid Build Coastguard Worker }, 162*da0073e9SAndroid Build Coastguard Worker functions_by_signature={schema.signature(): [native_function]}, 163*da0073e9SAndroid Build Coastguard Worker functions_by_schema={specification: native_function}, 164*da0073e9SAndroid Build Coastguard Worker op_counter=typing.Counter[str](), 165*da0073e9SAndroid Build Coastguard Worker used_dispatch_keys=set(), 166*da0073e9SAndroid Build Coastguard Worker ) 167*da0073e9SAndroid Build Coastguard Worker default_definition = gen_autograd_functions.process_function( 168*da0073e9SAndroid Build Coastguard Worker differentiability_info["Default"], 169*da0073e9SAndroid Build Coastguard Worker gen_autograd_functions.FUNCTION_DEFINITION, 170*da0073e9SAndroid Build Coastguard Worker ) 171*da0073e9SAndroid Build Coastguard Worker # grad_z should map to grads[1], not grads[2] because output 1 172*da0073e9SAndroid Build Coastguard Worker # (y) is not differentiable. 
173*da0073e9SAndroid Build Coastguard Worker assert "grad_z = grads[2]" not in default_definition 174*da0073e9SAndroid Build Coastguard Worker assert "grad_z = grads[1]" in default_definition 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker nested_tensor_definition = gen_autograd_functions.process_function( 177*da0073e9SAndroid Build Coastguard Worker differentiability_info["AutogradNestedTensor"], 178*da0073e9SAndroid Build Coastguard Worker gen_autograd_functions.FUNCTION_DEFINITION, 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker assert "grad_z = grads[2]" not in nested_tensor_definition 181*da0073e9SAndroid Build Coastguard Worker assert "grad_z = grads[1]" in nested_tensor_definition 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker def test_register_bogus_dispatch_key(self) -> None: 184*da0073e9SAndroid Build Coastguard Worker specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" 185*da0073e9SAndroid Build Coastguard Worker schema = FunctionSchema.parse(specification) 186*da0073e9SAndroid Build Coastguard Worker native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 189*da0073e9SAndroid Build Coastguard Worker RuntimeError, 190*da0073e9SAndroid Build Coastguard Worker "Invalid dispatch key AutogradRandomTensor in derivatives.yaml for", 191*da0073e9SAndroid Build Coastguard Worker ): 192*da0073e9SAndroid Build Coastguard Worker load_derivatives.create_differentiability_info( 193*da0073e9SAndroid Build Coastguard Worker defn_dict={ 194*da0073e9SAndroid Build Coastguard Worker "name": specification, 195*da0073e9SAndroid Build Coastguard Worker "dispatch": { 196*da0073e9SAndroid Build Coastguard Worker "Default": { 197*da0073e9SAndroid Build Coastguard Worker "a": "grad_x", 
198*da0073e9SAndroid Build Coastguard Worker "b": "grad_z", 199*da0073e9SAndroid Build Coastguard Worker }, 200*da0073e9SAndroid Build Coastguard Worker "AutogradRandomTensor": { 201*da0073e9SAndroid Build Coastguard Worker "a": "grad_x", 202*da0073e9SAndroid Build Coastguard Worker "b": "grad_z", 203*da0073e9SAndroid Build Coastguard Worker }, 204*da0073e9SAndroid Build Coastguard Worker }, 205*da0073e9SAndroid Build Coastguard Worker }, 206*da0073e9SAndroid Build Coastguard Worker functions_by_signature={schema.signature(): [native_function]}, 207*da0073e9SAndroid Build Coastguard Worker functions_by_schema={specification: native_function}, 208*da0073e9SAndroid Build Coastguard Worker op_counter=typing.Counter[str](), 209*da0073e9SAndroid Build Coastguard Worker used_dispatch_keys=set(), 210*da0073e9SAndroid Build Coastguard Worker ) 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Workerclass TestGenSchemaRegistration(unittest.TestCase): 214*da0073e9SAndroid Build Coastguard Worker def setUp(self) -> None: 215*da0073e9SAndroid Build Coastguard Worker self.selector = SelectiveBuilder.get_nop_selector() 216*da0073e9SAndroid Build Coastguard Worker self.custom_native_function, _ = NativeFunction.from_yaml( 217*da0073e9SAndroid Build Coastguard Worker {"func": "custom::func() -> bool"}, 218*da0073e9SAndroid Build Coastguard Worker loc=Location(__file__, 1), 219*da0073e9SAndroid Build Coastguard Worker valid_tags=set(), 220*da0073e9SAndroid Build Coastguard Worker ) 221*da0073e9SAndroid Build Coastguard Worker ( 222*da0073e9SAndroid Build Coastguard Worker self.fragment_custom_native_function, 223*da0073e9SAndroid Build Coastguard Worker _, 224*da0073e9SAndroid Build Coastguard Worker ) = NativeFunction.from_yaml( 225*da0073e9SAndroid Build Coastguard Worker {"func": "quantized_decomposed::func() -> bool"}, 226*da0073e9SAndroid Build Coastguard Worker loc=Location(__file__, 1), 
227*da0073e9SAndroid Build Coastguard Worker valid_tags=set(), 228*da0073e9SAndroid Build Coastguard Worker ) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker def test_default_namespace_schema_registration_code_valid(self) -> None: 231*da0073e9SAndroid Build Coastguard Worker native_functions = [DEFAULT_NATIVE_FUNCTION] 232*da0073e9SAndroid Build Coastguard Worker registrations, _ = get_native_function_schema_registrations( 233*da0073e9SAndroid Build Coastguard Worker native_functions=native_functions, 234*da0073e9SAndroid Build Coastguard Worker schema_selector=self.selector, 235*da0073e9SAndroid Build Coastguard Worker ) 236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(registrations, ['m.def("func() -> bool", {});\n']) 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker def test_custom_namespace_schema_registration_code_valid(self) -> None: 239*da0073e9SAndroid Build Coastguard Worker _, registrations = get_native_function_schema_registrations( 240*da0073e9SAndroid Build Coastguard Worker native_functions=[self.custom_native_function], 241*da0073e9SAndroid Build Coastguard Worker schema_selector=self.selector, 242*da0073e9SAndroid Build Coastguard Worker ) 243*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 244*da0073e9SAndroid Build Coastguard Worker registrations, 245*da0073e9SAndroid Build Coastguard Worker """ 246*da0073e9SAndroid Build Coastguard WorkerTORCH_LIBRARY(custom, m) { 247*da0073e9SAndroid Build Coastguard Worker m.def("func() -> bool", {}); 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker};""", 250*da0073e9SAndroid Build Coastguard Worker ) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None: 253*da0073e9SAndroid Build Coastguard Worker """Sometimes we want to extend an existing namespace, 
for example quantized 254*da0073e9SAndroid Build Coastguard Worker namespace, which is already defined in native/quantized/library.cpp 255*da0073e9SAndroid Build Coastguard Worker """ 256*da0073e9SAndroid Build Coastguard Worker _, registrations = get_native_function_schema_registrations( 257*da0073e9SAndroid Build Coastguard Worker native_functions=[self.fragment_custom_native_function], 258*da0073e9SAndroid Build Coastguard Worker schema_selector=self.selector, 259*da0073e9SAndroid Build Coastguard Worker ) 260*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 261*da0073e9SAndroid Build Coastguard Worker registrations, 262*da0073e9SAndroid Build Coastguard Worker """ 263*da0073e9SAndroid Build Coastguard WorkerTORCH_LIBRARY_FRAGMENT(quantized_decomposed, m) { 264*da0073e9SAndroid Build Coastguard Worker m.def("func() -> bool", {}); 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker};""", 267*da0073e9SAndroid Build Coastguard Worker ) 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker def test_mixed_namespace_schema_registration_code_valid(self) -> None: 270*da0073e9SAndroid Build Coastguard Worker ( 271*da0073e9SAndroid Build Coastguard Worker aten_registrations, 272*da0073e9SAndroid Build Coastguard Worker custom_registrations, 273*da0073e9SAndroid Build Coastguard Worker ) = get_native_function_schema_registrations( 274*da0073e9SAndroid Build Coastguard Worker native_functions=[DEFAULT_NATIVE_FUNCTION, self.custom_native_function], 275*da0073e9SAndroid Build Coastguard Worker schema_selector=self.selector, 276*da0073e9SAndroid Build Coastguard Worker ) 277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n']) 278*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 279*da0073e9SAndroid Build Coastguard Worker custom_registrations, 280*da0073e9SAndroid Build Coastguard Worker """ 281*da0073e9SAndroid Build 
Coastguard WorkerTORCH_LIBRARY(custom, m) { 282*da0073e9SAndroid Build Coastguard Worker m.def("func() -> bool", {}); 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker};""", 285*da0073e9SAndroid Build Coastguard Worker ) 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker def test_3_namespaces_schema_registration_code_valid(self) -> None: 288*da0073e9SAndroid Build Coastguard Worker custom2_native_function, _ = NativeFunction.from_yaml( 289*da0073e9SAndroid Build Coastguard Worker {"func": "custom2::func() -> bool"}, 290*da0073e9SAndroid Build Coastguard Worker loc=Location(__file__, 1), 291*da0073e9SAndroid Build Coastguard Worker valid_tags=set(), 292*da0073e9SAndroid Build Coastguard Worker ) 293*da0073e9SAndroid Build Coastguard Worker ( 294*da0073e9SAndroid Build Coastguard Worker aten_registrations, 295*da0073e9SAndroid Build Coastguard Worker custom_registrations, 296*da0073e9SAndroid Build Coastguard Worker ) = get_native_function_schema_registrations( 297*da0073e9SAndroid Build Coastguard Worker native_functions=[ 298*da0073e9SAndroid Build Coastguard Worker DEFAULT_NATIVE_FUNCTION, 299*da0073e9SAndroid Build Coastguard Worker self.custom_native_function, 300*da0073e9SAndroid Build Coastguard Worker custom2_native_function, 301*da0073e9SAndroid Build Coastguard Worker ], 302*da0073e9SAndroid Build Coastguard Worker schema_selector=self.selector, 303*da0073e9SAndroid Build Coastguard Worker ) 304*da0073e9SAndroid Build Coastguard Worker self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n']) 305*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 306*da0073e9SAndroid Build Coastguard Worker custom_registrations, 307*da0073e9SAndroid Build Coastguard Worker """ 308*da0073e9SAndroid Build Coastguard WorkerTORCH_LIBRARY(custom, m) { 309*da0073e9SAndroid Build Coastguard Worker m.def("func() -> bool", {}); 310*da0073e9SAndroid Build Coastguard Worker 
311*da0073e9SAndroid Build Coastguard Worker}; 312*da0073e9SAndroid Build Coastguard WorkerTORCH_LIBRARY(custom2, m) { 313*da0073e9SAndroid Build Coastguard Worker m.def("func() -> bool", {}); 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker};""", 316*da0073e9SAndroid Build Coastguard Worker ) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Workerclass TestGenNativeFunctionDeclaration(unittest.TestCase): 320*da0073e9SAndroid Build Coastguard Worker def setUp(self) -> None: 321*da0073e9SAndroid Build Coastguard Worker self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml( 322*da0073e9SAndroid Build Coastguard Worker {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, 323*da0073e9SAndroid Build Coastguard Worker loc=Location(__file__, 1), 324*da0073e9SAndroid Build Coastguard Worker valid_tags=set(), 325*da0073e9SAndroid Build Coastguard Worker ) 326*da0073e9SAndroid Build Coastguard Worker self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml( 327*da0073e9SAndroid Build Coastguard Worker { 328*da0073e9SAndroid Build Coastguard Worker "func": "op_2() -> bool", 329*da0073e9SAndroid Build Coastguard Worker "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"}, 330*da0073e9SAndroid Build Coastguard Worker }, 331*da0073e9SAndroid Build Coastguard Worker loc=Location(__file__, 1), 332*da0073e9SAndroid Build Coastguard Worker valid_tags=set(), 333*da0073e9SAndroid Build Coastguard Worker ) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = { 336*da0073e9SAndroid Build Coastguard Worker DispatchKey.CPU: {}, 337*da0073e9SAndroid Build Coastguard Worker DispatchKey.QuantizedCPU: {}, 338*da0073e9SAndroid Build Coastguard Worker } 339*da0073e9SAndroid Build Coastguard Worker 
BackendIndex.grow_index(backend_indices, op_1_backend_index) 340*da0073e9SAndroid Build Coastguard Worker BackendIndex.grow_index(backend_indices, op_2_backend_index) 341*da0073e9SAndroid Build Coastguard Worker self.backend_indices = { 342*da0073e9SAndroid Build Coastguard Worker k: BackendIndex( 343*da0073e9SAndroid Build Coastguard Worker dispatch_key=k, 344*da0073e9SAndroid Build Coastguard Worker use_out_as_primary=True, 345*da0073e9SAndroid Build Coastguard Worker external=False, 346*da0073e9SAndroid Build Coastguard Worker device_guard=False, 347*da0073e9SAndroid Build Coastguard Worker index=backend_indices[k], 348*da0073e9SAndroid Build Coastguard Worker ) 349*da0073e9SAndroid Build Coastguard Worker for k in backend_indices 350*da0073e9SAndroid Build Coastguard Worker } 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker def test_native_function_declaration_1_op_2_ns_error(self) -> None: 353*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 354*da0073e9SAndroid Build Coastguard Worker get_native_function_declarations( 355*da0073e9SAndroid Build Coastguard Worker grouped_native_functions=[ 356*da0073e9SAndroid Build Coastguard Worker self.op_1_native_function, 357*da0073e9SAndroid Build Coastguard Worker self.op_2_native_function, 358*da0073e9SAndroid Build Coastguard Worker ], 359*da0073e9SAndroid Build Coastguard Worker backend_indices=self.backend_indices, 360*da0073e9SAndroid Build Coastguard Worker native_function_decl_gen=dest.compute_native_function_declaration, 361*da0073e9SAndroid Build Coastguard Worker ) 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker def test_native_function_declaration_1_op_1_ns_valid(self) -> None: 364*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(self.op_1_native_function, NativeFunction) 365*da0073e9SAndroid Build Coastguard Worker declaration = get_native_function_declarations( 
366*da0073e9SAndroid Build Coastguard Worker grouped_native_functions=[ 367*da0073e9SAndroid Build Coastguard Worker self.op_1_native_function, 368*da0073e9SAndroid Build Coastguard Worker ], 369*da0073e9SAndroid Build Coastguard Worker backend_indices=self.backend_indices, 370*da0073e9SAndroid Build Coastguard Worker native_function_decl_gen=dest.compute_native_function_declaration, 371*da0073e9SAndroid Build Coastguard Worker ) 372*da0073e9SAndroid Build Coastguard Worker target = """ 373*da0073e9SAndroid Build Coastguard Workernamespace at { 374*da0073e9SAndroid Build Coastguard Workernamespace native { 375*da0073e9SAndroid Build Coastguard WorkerTORCH_API bool kernel_1(); 376*da0073e9SAndroid Build Coastguard Worker} // namespace native 377*da0073e9SAndroid Build Coastguard Worker} // namespace at 378*da0073e9SAndroid Build Coastguard Worker """ 379*da0073e9SAndroid Build Coastguard Worker self.assertEqual("\n".join(declaration), target) 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker# Test for native_function_generation 383*da0073e9SAndroid Build Coastguard Workerclass TestNativeFunctionGeneratrion(unittest.TestCase): 384*da0073e9SAndroid Build Coastguard Worker def setUp(self) -> None: 385*da0073e9SAndroid Build Coastguard Worker self.native_functions: list[NativeFunction] = [] 386*da0073e9SAndroid Build Coastguard Worker self.backend_indices: dict[ 387*da0073e9SAndroid Build Coastguard Worker DispatchKey, dict[OperatorName, BackendMetadata] 388*da0073e9SAndroid Build Coastguard Worker ] = defaultdict(dict) 389*da0073e9SAndroid Build Coastguard Worker yaml_entry = """ 390*da0073e9SAndroid Build Coastguard Worker- func: op(Tensor self) -> Tensor 391*da0073e9SAndroid Build Coastguard Worker dispatch: 392*da0073e9SAndroid Build Coastguard Worker CompositeExplicitAutograd: op 393*da0073e9SAndroid Build Coastguard Worker autogen: op.out 394*da0073e9SAndroid Build Coastguard 
Worker """ 395*da0073e9SAndroid Build Coastguard Worker es = yaml.load(yaml_entry, Loader=LineLoader) 396*da0073e9SAndroid Build Coastguard Worker self.one_return_func, m = NativeFunction.from_yaml( 397*da0073e9SAndroid Build Coastguard Worker es[0], loc=Location(__file__, 1), valid_tags=set() 398*da0073e9SAndroid Build Coastguard Worker ) 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker BackendIndex.grow_index(self.backend_indices, m) 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker self.two_returns_func, two_returns_backend_index = NativeFunction.from_yaml( 403*da0073e9SAndroid Build Coastguard Worker { 404*da0073e9SAndroid Build Coastguard Worker "func": "op_2() -> (Tensor, Tensor)", 405*da0073e9SAndroid Build Coastguard Worker "dispatch": {"CPU": "kernel_1"}, 406*da0073e9SAndroid Build Coastguard Worker "autogen": "op_2.out", 407*da0073e9SAndroid Build Coastguard Worker }, 408*da0073e9SAndroid Build Coastguard Worker loc=Location(__file__, 1), 409*da0073e9SAndroid Build Coastguard Worker valid_tags=set(), 410*da0073e9SAndroid Build Coastguard Worker ) 411*da0073e9SAndroid Build Coastguard Worker BackendIndex.grow_index(self.backend_indices, two_returns_backend_index) 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker def test_functional_variant_autogen_out_variant(self) -> None: 414*da0073e9SAndroid Build Coastguard Worker native_functions = [self.one_return_func] 415*da0073e9SAndroid Build Coastguard Worker add_generated_native_functions(native_functions, self.backend_indices) 416*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(native_functions), 2) 417*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 418*da0073e9SAndroid Build Coastguard Worker str(native_functions[1].func), 419*da0073e9SAndroid Build Coastguard Worker "op.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", 420*da0073e9SAndroid Build Coastguard Worker ) 421*da0073e9SAndroid Build Coastguard Worker op_name = native_functions[1].func.name 422*da0073e9SAndroid Build Coastguard Worker backend_metadata = self.backend_indices[DispatchKey.CompositeExplicitAutograd][ 423*da0073e9SAndroid Build Coastguard Worker op_name 424*da0073e9SAndroid Build Coastguard Worker ] 425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(backend_metadata.kernel, "op_out") 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker def test_functional_variant_autogen_out_variant_two_returns(self) -> None: 428*da0073e9SAndroid Build Coastguard Worker native_functions = [self.two_returns_func] 429*da0073e9SAndroid Build Coastguard Worker add_generated_native_functions(native_functions, self.backend_indices) 430*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(native_functions), 2) 431*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 432*da0073e9SAndroid Build Coastguard Worker str(native_functions[1].func), 433*da0073e9SAndroid Build Coastguard Worker "op_2.out(*, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", 434*da0073e9SAndroid Build Coastguard Worker ) 435*da0073e9SAndroid Build Coastguard Worker op_name = native_functions[1].func.name 436*da0073e9SAndroid Build Coastguard Worker backend_metadata = self.backend_indices[DispatchKey.CompositeExplicitAutograd][ 437*da0073e9SAndroid Build Coastguard Worker op_name 438*da0073e9SAndroid Build Coastguard Worker ] 439*da0073e9SAndroid Build Coastguard Worker self.assertEqual(backend_metadata.kernel, "op_2_out") 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker# Test for static_dispatch 443*da0073e9SAndroid Build Coastguard Workerclass TestStaticDispatchGeneratrion(unittest.TestCase): 444*da0073e9SAndroid Build Coastguard Worker def setUp(self) -> None: 445*da0073e9SAndroid Build Coastguard Worker self.backend_indices: dict[ 446*da0073e9SAndroid Build Coastguard Worker DispatchKey, dict[OperatorName, BackendMetadata] 447*da0073e9SAndroid Build Coastguard Worker ] = defaultdict(dict) 448*da0073e9SAndroid Build Coastguard Worker yaml_entry = """ 449*da0073e9SAndroid Build Coastguard Worker- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
450*da0073e9SAndroid Build Coastguard Worker dispatch: 451*da0073e9SAndroid Build Coastguard Worker CompositeExplicitAutograd: op 452*da0073e9SAndroid Build Coastguard Worker """ 453*da0073e9SAndroid Build Coastguard Worker es = yaml.load(yaml_entry, Loader=LineLoader) 454*da0073e9SAndroid Build Coastguard Worker self.one_return_func, m = NativeFunction.from_yaml( 455*da0073e9SAndroid Build Coastguard Worker es[0], loc=Location(__file__, 1), valid_tags=set() 456*da0073e9SAndroid Build Coastguard Worker ) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker BackendIndex.grow_index(self.backend_indices, m) 459*da0073e9SAndroid Build Coastguard Worker dispatch_key = DispatchKey.CompositeExplicitAutograd 460*da0073e9SAndroid Build Coastguard Worker self.assertTrue(dispatch_key in self.backend_indices) 461*da0073e9SAndroid Build Coastguard Worker self.indices = [ 462*da0073e9SAndroid Build Coastguard Worker BackendIndex( 463*da0073e9SAndroid Build Coastguard Worker dispatch_key=dispatch_key, 464*da0073e9SAndroid Build Coastguard Worker use_out_as_primary=True, 465*da0073e9SAndroid Build Coastguard Worker external=False, 466*da0073e9SAndroid Build Coastguard Worker device_guard=False, 467*da0073e9SAndroid Build Coastguard Worker index=self.backend_indices[dispatch_key], 468*da0073e9SAndroid Build Coastguard Worker ) 469*da0073e9SAndroid Build Coastguard Worker ] 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker def test_op_with_1_backend_generates_static_dispatch(self) -> None: 472*da0073e9SAndroid Build Coastguard Worker disp_sig = DispatcherSignature.from_schema(self.one_return_func.func) 473*da0073e9SAndroid Build Coastguard Worker with native_function_manager(self.one_return_func): 474*da0073e9SAndroid Build Coastguard Worker out = static_dispatch( 475*da0073e9SAndroid Build Coastguard Worker sig=disp_sig, 476*da0073e9SAndroid Build Coastguard Worker f=self.one_return_func, 
477*da0073e9SAndroid Build Coastguard Worker backend_indices=self.indices, 478*da0073e9SAndroid Build Coastguard Worker ) 479*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 480*da0073e9SAndroid Build Coastguard Worker out, "return at::compositeexplicitautograd::op_out(out, self);" 481*da0073e9SAndroid Build Coastguard Worker ) 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker def test_op_with_cpp_sig_generates_static_dispatch(self) -> None: 484*da0073e9SAndroid Build Coastguard Worker sig_group = CppSignatureGroup.from_native_function( 485*da0073e9SAndroid Build Coastguard Worker self.one_return_func, 486*da0073e9SAndroid Build Coastguard Worker method=False, 487*da0073e9SAndroid Build Coastguard Worker fallback_binding=self.one_return_func.manual_cpp_binding, 488*da0073e9SAndroid Build Coastguard Worker ) 489*da0073e9SAndroid Build Coastguard Worker # cpp signature puts out at the front 490*da0073e9SAndroid Build Coastguard Worker with native_function_manager(self.one_return_func): 491*da0073e9SAndroid Build Coastguard Worker out = static_dispatch( 492*da0073e9SAndroid Build Coastguard Worker sig=sig_group.signature, 493*da0073e9SAndroid Build Coastguard Worker f=self.one_return_func, 494*da0073e9SAndroid Build Coastguard Worker backend_indices=self.indices, 495*da0073e9SAndroid Build Coastguard Worker ) 496*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 497*da0073e9SAndroid Build Coastguard Worker out, "return at::compositeexplicitautograd::op_out(out, self);" 498*da0073e9SAndroid Build Coastguard Worker ) 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker# Represents the most basic NativeFunction. Use dataclasses.replace() 502*da0073e9SAndroid Build Coastguard Worker# to edit for use. 
# from_yaml returns a pair; the second element is the kernel metadata that
# this file elsewhere feeds to BackendIndex.grow_index. The default function
# does not need it, so it is discarded.
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    dict(func="func() -> bool"),  # no arguments, a single bool return
    valid_tags=set(),
    loc=Location(__file__, 1),
)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()