Home
last modified time | relevance | path

Searched full:output_differentiability (Results 1 – 8 of 8) sorted by relevance

/aosp_15_r20/external/pytorch/tools/autograd/
H A Dderivatives.yaml30 # - Optional entry with key 'output_differentiability' and value a list of the
80 # 'output_differentiability' entry (see above).
293 output_differentiability: [False]
296 output_differentiability: [False]
299 output_differentiability: [False]
308 output_differentiability: [False]
311 output_differentiability: [False]
314 output_differentiability: [False]
454 output_differentiability: ["!dtype || isDifferentiableType(*dtype)"]
502 output_differentiability: [False]
[all …]
H A Dload_derivatives.py153 output_differentiability = defn_dict.pop(
154 "output_differentiability", None
157 if output_differentiability:
158 defn_dict["output_differentiability"] = output_differentiability
541 # output_differentiability is captured from the enclosed
550 differentiability = output_differentiability or [True] * len(f.func.returns)
623 # NB: Removes 'output_differentiability' from defn dictionary
625 output_differentiability = defn_dict.pop("output_differentiability", None)
627 if output_differentiability and any(
628 isinstance(diff, str) for diff in output_differentiability
[all …]
/aosp_15_r20/external/pytorch/torchgen/api/
H A Dautograd.py154 output_differentiability: list[bool] | None
156 # output_differentiability in derivatives.yaml can be a list of
160 # output_differentiability gets populated with True for each condition,
201 output_differentiability=self.output_differentiability,
244 # `output_differentiability` field defined in derivatives.yaml (if specified),
575 output_differentiability=None,
848 output_differentiability = info.output_differentiability if info else None
849 if output_differentiability is not None:
850 if len(output_differentiability) != len(outputs):
852 f"The length of output_differentiability ({len(output_differentiability)}), "
[all …]
/aosp_15_r20/external/pytorch/torch/_custom_op/
H A Dautograd.py60 def mark_non_differentiable(ctx, output, output_differentiability): argument
66 if output_differentiability is not None:
71 assert len(output_differentiability) == len(tuple_output)
73 for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
84 f"With output_differentiability={output_differentiability}. "
87 f"output_differentiability.")
94 output_differentiability, argument
118 mark_non_differentiable(ctx, output, output_differentiability)
H A Dimpl.py403 def impl_backward(self, output_differentiability=None, _stacklevel=2): argument
407 if output_differentiability is not None:
410 f"impl_backward(output_differentiability): expected "
411 f"output_differentiability to be a list of bools with "
413 f"got: {output_differentiability}")
415 if not isinstance(output_differentiability, list):
417 for diff in output_differentiability:
420 if len(self._schema.returns) != len(output_differentiability):
429 self._output_differentiability = output_differentiability
/aosp_15_r20/external/pytorch/torch/
H A D_custom_ops.py274 def impl_backward(qualname, output_differentiability=None, *, func=None): argument
313 custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
/aosp_15_r20/external/pytorch/test/
H A Dtest_custom_ops.py1385 with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
1388 f"{TestCustomOp.test_ns}::foo", output_differentiability=True
1398 with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
1401 f"{TestCustomOp.test_ns}::foo", output_differentiability=[True]
1420 f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
1446 f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
/aosp_15_r20/external/pytorch/tools/test/
H A Dtest_codegen.py160 "output_differentiability": [True, False, True],