xref: /aosp_15_r20/external/pytorch/pt_ops.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1load("//tools/build_defs:expect.bzl", "expect")
2load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
3load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
4load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
5
6IS_OSS = read_config("pt", "is_oss", "0") == "1"  # True for OSS BUCK build, and False for internal BUCK build
7
8USED_PT_BACKENDS = [
9    "CPU",
10    "QuantizedCPU",
11    "SparseCPU",  # brings ~20 kb size regression
12]
13
14def pt_operator_library(
15        name,
16        ops = [],
17        exported_deps = [],
18        check_decl = True,
19        train = False,
20        model = None,
21        include_all_operators = False,
22        include_base_operators = True,
23        **kwargs):
24    (model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
25        name,
26        model,
27    )
28
29    ops = [op.strip() for op in ops]
30
31    # If ops are specified, then we are in static selective build mode, so we append
32    # base ops to this list to avoid additional special case logic in subsequent code,
33    # unless include_base_operators is explicitly set to False (the default is True)
34    if len(ops) > 0 and include_base_operators:
35        ops.extend(PT_BASE_OPS)
36
37    labels = kwargs.pop("labels", [])
38    visibility = kwargs.pop("visibility", ["PUBLIC"])
39
40    # Sanity check the model name and versions.  While the input to both is an array, the
41    # codegen script only ever outputs a single item in the array so we can just assume that
42    # here. If you ever need to depends on more than one assets, just break it up into a separate
43    # BUCK targets.
44    if model_assets or model_versions:
45        if len(model_assets) != 1:
46            fail("Model assets must be of size 1")
47        if len(model_versions) != 1:
48            fail("Model versions must be of size 1")
49
50    # Is this a traced operator therefore has a YAML file with ops?
51    yaml_option = ""
52    if model_assets and len(model_assets) > 0:
53        # We know these lists are only of length 1 via earlier assert.
54        model_asset = model_assets[0]
55        model_version = model_versions[0]
56
57        # Pass the YAML file from this asset to the genrule below.
58        yaml_dep = "{}_v{}_yaml".format(model_asset, model_version)
59        fb_native.filegroup(
60            name = yaml_dep,
61            srcs = [
62                model_asset + ".yaml",
63            ],
64            # The visibility is not set to PUBLIC as this an internal detail.  If you see this error
65            # in your buck build flow, you are trying to use a hand-crafted "pt_operator_library" that
66            # with parameters not supported outside of codegen targets!
67        )
68
69        # Since all selective traced ops are created by automation, we can assume they
70        # have a YAML file at this very location.  If it doesn't exist, it means the targets
71        # was hand-crafted which is not a support workflow for traced ops.
72        yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset)
73
74    not_include_all_overloads_static_root_ops = kwargs.pop(
75        "not_include_all_overloads_static_root_ops",
76        False,
77    )
78
79    not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False)
80
81    if False:
82        # TODO(nga): `yaml_option` is never `None`, but it is checked against `None` below.
83        #   Typechecker (`--unstable-typecheck`) catches it.
84        yaml_option = None
85
86    fb_xplat_genrule(
87        name = name,
88        out = "model_operators.yaml",
89        cmd = (
90            "$(exe {exe}) " +
91            "{optionally_root_ops} " +
92            "{optionally_training_root_ops} " +
93            "--rule_name {rule_name} " +
94            "--output_path \"${{OUT}}\" " +
95            "--model_name {model_name} " +
96            "--dep_graph_yaml_path {dep_graph_yaml} " +
97            "{optionally_model_yamls} " +
98            "{optionally_model_versions} " +
99            "{optionally_model_assets} " +
100            "{optionally_model_traced_backends} " +
101            "{optionally_include_all_operators}" +
102            "{not_include_all_overloads_static_root_ops}" +
103            "{not_include_all_overloads_closure_ops}"
104        ).format(
105            exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml",
106            rule_name = name,
107            model_name = model_name,
108            dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
109            optionally_model_yamls = "" if (IS_OSS or yaml_option == None) else yaml_option,
110            optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
111            optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
112            optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
113            optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
114            optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
115            optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
116            not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "",
117            not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "",
118        ),
119        labels = labels + [
120            "pt_operator_library",
121            "supermodule:android/default/pytorch",
122            "supermodule:ios/default/public.pytorch",
123        ] + (["pt_train_operator_library"] if train else []),
124        visibility = visibility,
125        **kwargs
126    )
127
128def validate_and_extract_model_information(name, model):
129    model_name = name
130    model_versions = None
131    model_assets = None
132    model_traced_backends = None
133
134    if model != None:
135        model_name = model.get("name")
136        expect(model_name != None, "Expected Model Name to be present")
137        model_versions = model.get("versions")
138        expect(is_list(model_versions), "Expected model versions to be a list of string")
139        for ver in model_versions or []:
140            expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
141        model_assets = model.get("assets")
142        expect(
143            model_assets == None or is_list(model_assets),
144            "Expected model assets to be a list of string if specified",
145        )
146        for asset_name in model_assets or []:
147            expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
148        model_traced_backends = model.get("traced_backends")
149        expect(
150            model_traced_backends == None or is_list(model_traced_backends),
151            "Expected model traced backends to be a list of string if specified",
152        )
153
154        if model_traced_backends != None:
155            for backend in model_traced_backends:
156                expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
157                expect(
158                    backend in USED_PT_BACKENDS,
159                    "Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
160                )
161
162    return (model_name, model_versions, model_assets, model_traced_backends)
163
164# This file keeps a list of PyTorch operators used by any targets in
165# @fbsource//xplat/...
166# The purpose of the list is to avoid generating large number of unused
167# operator registration code / BUCK rules at build time.
168# See more detail at: https://fb.quip.com/ZVh1AgOKW8Vv
169
170PT_OPS_PRIM = [
171    "aten::str",
172    "aten::list",
173    "aten::__range_length",
174    "aten::__derive_index",
175    "prim::TupleUnpack",
176    "prim::unchecked_cast",
177    "aten::IntImplicit",
178    "aten::FloatImplicit",
179    "aten::ScalarImplicit",
180    "aten::Bool.Tensor",
181    "aten::Bool.int",
182    "aten::Bool.float",
183    "aten::Int.Tensor",
184    "aten::Int.Scalar",
185    "aten::Int.int",
186    "aten::Int.bool",
187    "aten::Int.str",
188    "aten::Float.Tensor",
189    "aten::Float.Scalar",
190    "aten::Float.int",
191    "aten::Float.bool",
192    "aten::Float.str",
193    "aten::format",
194    "prim::NumToTensor.Scalar",
195    "prim::RaiseException",
196    "aten::Size",
197    "aten::size",
198    "prim::EnumName",
199    "prim::EnumValue.int",
200    "prim::EnumValue.float",
201    "prim::EnumValue.str",
202    "prim::TupleIndex",
203    "aten::ne.int_list",
204    "prim::unchecked_unwrap_optional",
205    "prim::device",
206    "prim::dtype",
207    "aten::__not__",
208    "aten::__is__",
209    "aten::__isnot__",
210    "aten::element_size",
211    "aten::numel",
212    "aten::dim",
213    "aten::get_device",
214    "aten::storage_offset",
215    "aten::is_contiguous",
216    "aten::select.t",
217    "aten::__getitem__.t",
218    "aten::append.t",
219    "aten::reverse.t",
220    "aten::extend.t",
221    "aten::copy.t",
222    "aten::_set_item.t",
223    "aten::clear.t",
224    "aten::Delete.t",
225    "aten::insert.t",
226    "aten::pop.t",
227    "aten::add.t",
228    "aten::add_.t",
229    "aten::slice.t",
230    "aten::list.t",
231    "aten::mul.left_t",
232    "aten::mul.right_",
233    "aten::mul_.t",
234    "aten::len.t",
235    "aten::eq.int_list",
236    "prim::Uninitialized",
237    "prim::Print",
238    "aten::eq.enum",
239    "aten::ne.enum",
240    "aten::dequantize.tensor",
241    "aten::dequantize.any",
242    "aten::add.str",
243    "aten::eq.int",
244    "aten::eq.float",
245    "aten::eq.int_float",
246    "aten::eq.float_int",
247    "aten::eq",
248    "aten::eq.str",
249    "aten::ne.int",
250    "aten::ne.float",
251    "aten::ne.int_float",
252    "aten::ne.float_int",
253    "aten::ne",
254    "aten::ne.str",
255    "aten::lt.int",
256    "aten::lt.float",
257    "aten::lt.int_float",
258    "aten::lt.float_int",
259    "aten::lt",
260    "aten::lt.str",
261    "aten::gt.int",
262    "aten::gt.float",
263    "aten::gt.int_float",
264    "aten::gt.float_int",
265    "aten::gt",
266    "aten::gt.str",
267    "aten::le.int",
268    "aten::le.float",
269    "aten::le.int_float",
270    "aten::le.float_int",
271    "aten::le",
272    "aten::le.str",
273    "aten::ge.int",
274    "aten::ge.float",
275    "aten::ge.int_float",
276    "aten::ge.float_int",
277    "aten::ge",
278    "aten::ge.str",
279    "aten::add.int",
280    "aten::add.float",
281    "aten::add.int_float",
282    "aten::add.float_int",
283    "aten::add",
284    "aten::sub.int",
285    "aten::sub.float",
286    "aten::sub.int_float",
287    "aten::sub.float_int",
288    "aten::sub",
289    "aten::mul.int",
290    "aten::mul.float",
291    "aten::mul.int_float",
292    "aten::mul.float_int",
293    "aten::mul",
294    "aten::__and__.bool",
295    "aten::__or__.bool",
296    "aten::__xor__.bool",
297    "aten::floor.int",
298    "aten::floor.float",
299    "aten::floor.Scalar",
300    "aten::ceil.int",
301    "aten::ceil.float",
302    "aten::ceil.Scalar",
303    "aten::neg.int",
304    "aten::neg.float",
305    "aten::neg.Scalar",
306    "aten::exp.int",
307    "aten::exp.float",
308    "aten::exp.Scalar",
309    "aten::remainder.int",
310    "aten::remainder.float",
311    "aten::remainder.int_float",
312    "aten::remainder.float_int",
313    "aten::remainder",
314    "aten::div.int",
315    "aten::div.float",
316    "aten::div",
317    "aten::floordiv.int",
318    "aten::floordiv.float",
319    "aten::floordiv.int_float",
320    "aten::floordiv.float_int",
321    "aten::floordiv",
322    "aten::pow.int",
323    "aten::pow.float",
324    "aten::pow.int_float",
325    "aten::pow.float_int",
326    "aten::pow.Scalar_Scalar",
327    "aten::pow.int_to_int",
328    "prim::min.int",
329    "prim::min.float",
330    "prim::min.int_float",
331    "prim::min.float_int",
332    "prim::min",
333    "prim::max.int",
334    "prim::max.float",
335    "prim::max.int_float",
336    "prim::max.float_int",
337    "prim::max",
338    "prim::type",
339    "aten::len.Tensor",
340    "aten::ord",
341    "aten::lower",
342    "aten::__contains__.str_list",
343    "aten::len.str",
344    "aten::__getitem__.str",
345    "aten::copy_.Tensor",
346    "aten::copy_.int",
347    "aten::copy_.float",
348    "aten::backward",
349    "aten::index.Tensor_hacked_twin",
350    "aten::_unsafe_index.Tensor_hacked_twin",
351    "aten::_index_put_impl_.hacked_twin",
352    "aten::index_put_.hacked_twin",
353    "aten::index_put.hacked_twin",
354    "aten::_unsafe_index_put.hacked_twin",
355    "aten::to.prim_Device",
356    "aten::to.prim_dtype",
357    "prim::is_cuda",
358    "prim::data",
359    "prim::min.int_list",
360    "prim::max.int_list",
361    "prim::min.self_int",
362    "prim::max.self_int",
363    "prim::min.float_list",
364    "prim::max.float_list",
365    "prim::min.self_float",
366    "prim::max.self_float",
367    "prim::min.bool_list",
368    "prim::max.bool_list",
369    "prim::min.self_bool",
370    "prim::max.self_bool",
371    "aten::len.Dict_str",
372    "aten::keys.str",
373    "aten::values.str",
374    "aten::__getitem__.Dict_str",
375    "aten::get.str",
376    "aten::get.default_str",
377    "aten::setdefault.str",
378    "aten::Delete.Dict_str",
379    "aten::pop.Dict_str",
380    "aten::pop.Dict_default_str",
381    "aten::popitem.str",
382    "aten::clear.str",
383    "aten::update.str",
384    "aten::items.str",
385    "aten::copy.Dict_str",
386    "aten::__contains__.str",
387    "aten::_set_item.str",
388    "aten::dict.str",
389    "aten::len.Dict_int",
390    "aten::keys.int",
391    "aten::values.int",
392    "aten::__getitem__.Dict_int",
393    "aten::get.int",
394    "aten::get.default_int",
395    "aten::setdefault.int",
396    "aten::Delete.Dict_int",
397    "aten::pop.Dict_int",
398    "aten::pop.Dict_default_int",
399    "aten::popitem.int",
400    "aten::clear.int",
401    "aten::update.int",
402    "aten::items.int",
403    "aten::copy.Dict_int",
404    "aten::__contains__.int",
405    "aten::_set_item.int",
406    "aten::dict.int",
407    "aten::len.Dict_bool",
408    "aten::keys.bool",
409    "aten::values.bool",
410    "aten::__getitem__.Dict_bool",
411    "aten::get.bool",
412    "aten::get.default_bool",
413    "aten::setdefault.bool",
414    "aten::Delete.Dict_bool",
415    "aten::pop.Dict_bool",
416    "aten::pop.Dict_default_bool",
417    "aten::popitem.bool",
418    "aten::clear.bool",
419    "aten::update.bool",
420    "aten::items.bool",
421    "aten::copy.Dict_bool",
422    "aten::__contains__.bool",
423    "aten::_set_item.bool",
424    "aten::dict.bool",
425    "aten::len.Dict_float",
426    "aten::keys.float",
427    "aten::values.float",
428    "aten::__getitem__.Dict_float",
429    "aten::get.float",
430    "aten::get.default_float",
431    "aten::setdefault.float",
432    "aten::Delete.Dict_float",
433    "aten::pop.Dict_float",
434    "aten::pop.Dict_default_float",
435    "aten::popitem.float",
436    "aten::clear.float",
437    "aten::update.float",
438    "aten::items.float",
439    "aten::copy.Dict_float",
440    "aten::__contains__.float",
441    "aten::_set_item.float",
442    "aten::dict.float",
443    "aten::len.Dict_Tensor",
444    "aten::keys.Tensor",
445    "aten::values.Tensor",
446    "aten::__getitem__.Dict_Tensor",
447    "aten::get.Tensor",
448    "aten::get.default_Tensor",
449    "aten::setdefault.Tensor",
450    "aten::Delete.Dict_Tensor",
451    "aten::pop.Dict_Tensor",
452    "aten::pop.Dict_default_Tensor",
453    "aten::popitem.Tensor",
454    "aten::clear.Tensor",
455    "aten::update.Tensor",
456    "aten::items.Tensor",
457    "aten::copy.Dict_Tensor",
458    "aten::__contains__.Tensor",
459    "aten::_set_item.Tensor",
460    "aten::dict.Tensor",
461    "aten::__round_to_zero_floordiv.int",
462    "aten::mathremainder.int",
463    "aten::mathremainder.float",
464    "aten::mathremainder.int_float",
465    "aten::mathremainder.float_int",
466    "aten::mathremainder",
467    "aten::__and__.int",
468    "aten::__or__.int",
469    "aten::__xor__.int",
470    "aten::__lshift__.int",
471    "aten::__rshift__.int",
472    "aten::round.int",
473    "aten::round.float",
474    "aten::round.Scalar",
475    "aten::log.int",
476    "aten::log.float",
477    "aten::log.Scalar",
478    "aten::log.int_int",
479    "aten::log.float_float",
480    "aten::log.int_float",
481    "aten::log.float_int",
482    "aten::log.Scalar_Scalar",
483    "aten::log1p.int",
484    "aten::log1p.float",
485    "aten::log1p.Scalar",
486    "aten::log10.int",
487    "aten::log10.float",
488    "aten::log10.Scalar",
489    "aten::sqrt.int",
490    "aten::sqrt.float",
491    "aten::sqrt.Scalar",
492    "aten::acos.int",
493    "aten::acos.float",
494    "aten::acos.Scalar",
495    "aten::asin.int",
496    "aten::asin.float",
497    "aten::asin.Scalar",
498    "aten::atan.int",
499    "aten::atan.float",
500    "aten::atan.Scalar",
501    "aten::atan2.int",
502    "aten::atan2.float",
503    "aten::atan2.int_float",
504    "aten::atan2.float_int",
505    "aten::atan2.Scalar_Scalar",
506    "aten::cos.int",
507    "aten::cos.float",
508    "aten::cos.Scalar",
509    "aten::sin.int",
510    "aten::sin.float",
511    "aten::sin.Scalar",
512    "aten::tan.int",
513    "aten::tan.float",
514    "aten::tan.Scalar",
515    "aten::asinh.int",
516    "aten::asinh.float",
517    "aten::asinh.Scalar",
518    "aten::atanh.int",
519    "aten::atanh.float",
520    "aten::atanh.Scalar",
521    "aten::acosh.int",
522    "aten::acosh.float",
523    "aten::acosh.Scalar",
524    "aten::sinh.int",
525    "aten::sinh.float",
526    "aten::sinh.Scalar",
527    "aten::cosh.int",
528    "aten::cosh.float",
529    "aten::cosh.Scalar",
530    "aten::tanh.int",
531    "aten::tanh.float",
532    "aten::tanh.Scalar",
533    "aten::degrees.int",
534    "aten::degrees.float",
535    "aten::degrees.Scalar",
536    "aten::radians.int",
537    "aten::radians.float",
538    "aten::radians.Scalar",
539    "aten::fmod.int",
540    "aten::fmod.float",
541    "aten::fmod.int_float",
542    "aten::fmod.float_int",
543    "aten::fmod",
544    "aten::factorial.int",
545    "aten::isnan.float",
546    "aten::isfinite.float",
547    "aten::isinf.float",
548    "aten::gamma.int",
549    "aten::gamma.float",
550    "aten::gamma.Scalar",
551    "aten::erf.int",
552    "aten::erf.float",
553    "aten::erf.Scalar",
554    "aten::erfc.int",
555    "aten::erfc.float",
556    "aten::erfc.Scalar",
557    "aten::expm1.int",
558    "aten::expm1.float",
559    "aten::expm1.Scalar",
560    "aten::fabs.int",
561    "aten::fabs.float",
562    "aten::fabs.Scalar",
563    "aten::lgamma.int",
564    "aten::lgamma.float",
565    "aten::lgamma.Scalar",
566    "prim::abs.int",
567    "prim::abs.float",
568    "prim::abs.Scalar",
569    "aten::gcd.int",
570    "aten::copysign.int",
571    "aten::copysign.float",
572    "aten::copysign.int_float",
573    "aten::copysign.float_int",
574    "aten::copysign",
575    "aten::split",
576    "aten::tensor.float",
577    "aten::as_tensor.float",
578    "aten::tensor.int",
579    "aten::as_tensor.int",
580    "aten::tensor.bool",
581    "aten::as_tensor.bool",
582    "aten::_infer_size",
583    "aten::_no_grad_embedding_renorm_",
584    "aten::tensor",
585    "aten::as_tensor",
586    "aten::as_tensor.list",
587    "aten::_pack_sequence",
588    "aten::_get_tracing_state",
589    "aten::is_scripting",
590    "aten::_no_grad_uniform_",
591    "aten::_no_grad_normal_",
592    "aten::_no_grad_fill_",
593    "aten::_no_grad_zero_",
594]
595
596PT_BASE_OPS = [
597    "aten::_coalesced_",
598    "aten::_copy_from",
599    "aten::_empty_affine_quantized",
600    "aten::_empty_per_channel_affine_quantized",
601    "aten::_indices",
602    "aten::_nnz",
603    "aten::_values",
604    "aten::add",
605    "aten::add_",
606    "aten::arange",
607    "aten::as_strided",
608    "aten::as_strided_",
609    "aten::cat",
610    "aten::clone",
611    "aten::coalesce",
612    "aten::contiguous",
613    "aten::copy_",
614    "aten::copy_sparse_to_sparse_",
615    "aten::dense_dim",
616    "aten::dequantize",
617    "aten::div",
618    "aten::div_",
619    "aten::empty",
620    "aten::empty_like",
621    "aten::empty_strided",
622    "aten::eq",
623    "aten::equal",
624    "aten::expand",
625    "aten::fill_",
626    "aten::is_coalesced",
627    "aten::is_complex",
628    "aten::is_floating_point",
629    "aten::is_leaf",
630    "aten::is_nonzero",
631    "aten::item",
632    "aten::max",
633    "aten::min",
634    "aten::mul",
635    "aten::mul_",
636    "aten::narrow",
637    "aten::ne",
638    "aten::permute",
639    "aten::q_per_channel_axis",
640    "aten::q_per_channel_scales",
641    "aten::q_per_channel_zero_points",
642    "aten::q_scale",
643    "aten::q_zero_point",
644    "aten::qscheme",
645    "aten::quantize_per_tensor",
646    "aten::reshape",
647    "aten::_reshape_alias",
648    "aten::resize_",
649    "aten::resize_as_",
650    "aten::scalar_tensor",
651    "aten::select",
652    "aten::set_",
653    "aten::size",
654    "aten::slice",
655    "aten::sparse_dim",
656    "aten::sparse_resize_and_clear_",
657    "aten::squeeze",
658    "aten::squeeze_",
659    "aten::stride",
660    "aten::sub",
661    "aten::sub_",
662    "aten::sum",
663    "aten::t",
664    "aten::to",
665    "aten::_to_copy",
666    "aten::unsqueeze",
667    "aten::view",
668    "aten::zero_",
669    "aten::zeros",
670    "aten::zeros_like",
671]
672