1load("//tools/build_defs:expect.bzl", "expect") 2load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native") 3load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") 4load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") 5 6IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build 7 8USED_PT_BACKENDS = [ 9 "CPU", 10 "QuantizedCPU", 11 "SparseCPU", # brings ~20 kb size regression 12] 13 14def pt_operator_library( 15 name, 16 ops = [], 17 exported_deps = [], 18 check_decl = True, 19 train = False, 20 model = None, 21 include_all_operators = False, 22 include_base_operators = True, 23 **kwargs): 24 (model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information( 25 name, 26 model, 27 ) 28 29 ops = [op.strip() for op in ops] 30 31 # If ops are specified, then we are in static selective build mode, so we append 32 # base ops to this list to avoid additional special case logic in subsequent code, 33 # unless include_base_operators is explicitly set to False (the default is True) 34 if len(ops) > 0 and include_base_operators: 35 ops.extend(PT_BASE_OPS) 36 37 labels = kwargs.pop("labels", []) 38 visibility = kwargs.pop("visibility", ["PUBLIC"]) 39 40 # Sanity check the model name and versions. While the input to both is an array, the 41 # codegen script only ever outputs a single item in the array so we can just assume that 42 # here. If you ever need to depends on more than one assets, just break it up into a separate 43 # BUCK targets. 44 if model_assets or model_versions: 45 if len(model_assets) != 1: 46 fail("Model assets must be of size 1") 47 if len(model_versions) != 1: 48 fail("Model versions must be of size 1") 49 50 # Is this a traced operator therefore has a YAML file with ops? 51 yaml_option = "" 52 if model_assets and len(model_assets) > 0: 53 # We know these lists are only of length 1 via earlier assert. 54 model_asset = model_assets[0] 55 model_version = model_versions[0] 56 57 # Pass the YAML file from this asset to the genrule below. 58 yaml_dep = "{}_v{}_yaml".format(model_asset, model_version) 59 fb_native.filegroup( 60 name = yaml_dep, 61 srcs = [ 62 model_asset + ".yaml", 63 ], 64 # The visibility is not set to PUBLIC as this an internal detail. If you see this error 65 # in your buck build flow, you are trying to use a hand-crafted "pt_operator_library" that 66 # with parameters not supported outside of codegen targets! 67 ) 68 69 # Since all selective traced ops are created by automation, we can assume they 70 # have a YAML file at this very location. If it doesn't exist, it means the targets 71 # was hand-crafted which is not a support workflow for traced ops. 72 yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset) 73 74 not_include_all_overloads_static_root_ops = kwargs.pop( 75 "not_include_all_overloads_static_root_ops", 76 False, 77 ) 78 79 not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False) 80 81 if False: 82 # TODO(nga): `yaml_option` is never `None`, but it is checked against `None` below. 83 # Typechecker (`--unstable-typecheck`) catches it. 84 yaml_option = None 85 86 fb_xplat_genrule( 87 name = name, 88 out = "model_operators.yaml", 89 cmd = ( 90 "$(exe {exe}) " + 91 "{optionally_root_ops} " + 92 "{optionally_training_root_ops} " + 93 "--rule_name {rule_name} " + 94 "--output_path \"${{OUT}}\" " + 95 "--model_name {model_name} " + 96 "--dep_graph_yaml_path {dep_graph_yaml} " + 97 "{optionally_model_yamls} " + 98 "{optionally_model_versions} " + 99 "{optionally_model_assets} " + 100 "{optionally_model_traced_backends} " + 101 "{optionally_include_all_operators}" + 102 "{not_include_all_overloads_static_root_ops}" + 103 "{not_include_all_overloads_closure_ops}" 104 ).format( 105 exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml", 106 rule_name = name, 107 model_name = model_name, 108 dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ", 109 optionally_model_yamls = "" if (IS_OSS or yaml_option == None) else yaml_option, 110 optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "", 111 optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "", 112 optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "", 113 optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "", 114 optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "", 115 optionally_include_all_operators = "--include_all_operators " if include_all_operators else "", 116 not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "", 117 not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "", 118 ), 119 labels = labels + [ 120 "pt_operator_library", 121 "supermodule:android/default/pytorch", 122 "supermodule:ios/default/public.pytorch", 123 ] + (["pt_train_operator_library"] if train else []), 124 visibility = visibility, 125 **kwargs 126 ) 127 128def validate_and_extract_model_information(name, model): 129 model_name = name 130 model_versions = None 131 model_assets = None 132 model_traced_backends = None 133 134 if model != None: 135 model_name = model.get("name") 136 expect(model_name != None, "Expected Model Name to be present") 137 model_versions = model.get("versions") 138 expect(is_list(model_versions), "Expected model versions to be a list of string") 139 for ver in model_versions or []: 140 expect(is_string(ver), "Expected version '{}' to be string".format(str(ver))) 141 model_assets = model.get("assets") 142 expect( 143 model_assets == None or is_list(model_assets), 144 "Expected model assets to be a list of string if specified", 145 ) 146 for asset_name in model_assets or []: 147 expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name))) 148 model_traced_backends = model.get("traced_backends") 149 expect( 150 model_traced_backends == None or is_list(model_traced_backends), 151 "Expected model traced backends to be a list of string if specified", 152 ) 153 154 if model_traced_backends != None: 155 for backend in model_traced_backends: 156 expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend))) 157 expect( 158 backend in USED_PT_BACKENDS, 159 "Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)), 160 ) 161 162 return (model_name, model_versions, model_assets, model_traced_backends) 163 164# This file keeps a list of PyTorch operators used by any targets in 165# @fbsource//xplat/... 166# The purpose of the list is to avoid generating large number of unused 167# operator registration code / BUCK rules at build time. 168# See more detail at: https://fb.quip.com/ZVh1AgOKW8Vv 169 170PT_OPS_PRIM = [ 171 "aten::str", 172 "aten::list", 173 "aten::__range_length", 174 "aten::__derive_index", 175 "prim::TupleUnpack", 176 "prim::unchecked_cast", 177 "aten::IntImplicit", 178 "aten::FloatImplicit", 179 "aten::ScalarImplicit", 180 "aten::Bool.Tensor", 181 "aten::Bool.int", 182 "aten::Bool.float", 183 "aten::Int.Tensor", 184 "aten::Int.Scalar", 185 "aten::Int.int", 186 "aten::Int.bool", 187 "aten::Int.str", 188 "aten::Float.Tensor", 189 "aten::Float.Scalar", 190 "aten::Float.int", 191 "aten::Float.bool", 192 "aten::Float.str", 193 "aten::format", 194 "prim::NumToTensor.Scalar", 195 "prim::RaiseException", 196 "aten::Size", 197 "aten::size", 198 "prim::EnumName", 199 "prim::EnumValue.int", 200 "prim::EnumValue.float", 201 "prim::EnumValue.str", 202 "prim::TupleIndex", 203 "aten::ne.int_list", 204 "prim::unchecked_unwrap_optional", 205 "prim::device", 206 "prim::dtype", 207 "aten::__not__", 208 "aten::__is__", 209 "aten::__isnot__", 210 "aten::element_size", 211 "aten::numel", 212 "aten::dim", 213 "aten::get_device", 214 "aten::storage_offset", 215 "aten::is_contiguous", 216 "aten::select.t", 217 "aten::__getitem__.t", 218 "aten::append.t", 219 "aten::reverse.t", 220 "aten::extend.t", 221 "aten::copy.t", 222 "aten::_set_item.t", 223 "aten::clear.t", 224 "aten::Delete.t", 225 "aten::insert.t", 226 "aten::pop.t", 227 "aten::add.t", 228 "aten::add_.t", 229 "aten::slice.t", 230 "aten::list.t", 231 "aten::mul.left_t", 232 "aten::mul.right_", 233 "aten::mul_.t", 234 "aten::len.t", 235 "aten::eq.int_list", 236 "prim::Uninitialized", 237 "prim::Print", 238 "aten::eq.enum", 239 "aten::ne.enum", 240 "aten::dequantize.tensor", 241 "aten::dequantize.any", 242 "aten::add.str", 243 "aten::eq.int", 244 "aten::eq.float", 245 "aten::eq.int_float", 246 "aten::eq.float_int", 247 "aten::eq", 248 "aten::eq.str", 249 "aten::ne.int", 250 "aten::ne.float", 251 "aten::ne.int_float", 252 "aten::ne.float_int", 253 "aten::ne", 254 "aten::ne.str", 255 "aten::lt.int", 256 "aten::lt.float", 257 "aten::lt.int_float", 258 "aten::lt.float_int", 259 "aten::lt", 260 "aten::lt.str", 261 "aten::gt.int", 262 "aten::gt.float", 263 "aten::gt.int_float", 264 "aten::gt.float_int", 265 "aten::gt", 266 "aten::gt.str", 267 "aten::le.int", 268 "aten::le.float", 269 "aten::le.int_float", 270 "aten::le.float_int", 271 "aten::le", 272 "aten::le.str", 273 "aten::ge.int", 274 "aten::ge.float", 275 "aten::ge.int_float", 276 "aten::ge.float_int", 277 "aten::ge", 278 "aten::ge.str", 279 "aten::add.int", 280 "aten::add.float", 281 "aten::add.int_float", 282 "aten::add.float_int", 283 "aten::add", 284 "aten::sub.int", 285 "aten::sub.float", 286 "aten::sub.int_float", 287 "aten::sub.float_int", 288 "aten::sub", 289 "aten::mul.int", 290 "aten::mul.float", 291 "aten::mul.int_float", 292 "aten::mul.float_int", 293 "aten::mul", 294 "aten::__and__.bool", 295 "aten::__or__.bool", 296 "aten::__xor__.bool", 297 "aten::floor.int", 298 "aten::floor.float", 299 "aten::floor.Scalar", 300 "aten::ceil.int", 301 "aten::ceil.float", 302 "aten::ceil.Scalar", 303 "aten::neg.int", 304 "aten::neg.float", 305 "aten::neg.Scalar", 306 "aten::exp.int", 307 "aten::exp.float", 308 "aten::exp.Scalar", 309 "aten::remainder.int", 310 "aten::remainder.float", 311 "aten::remainder.int_float", 312 "aten::remainder.float_int", 313 "aten::remainder", 314 "aten::div.int", 315 "aten::div.float", 316 "aten::div", 317 "aten::floordiv.int", 318 "aten::floordiv.float", 319 "aten::floordiv.int_float", 320 "aten::floordiv.float_int", 321 "aten::floordiv", 322 "aten::pow.int", 323 "aten::pow.float", 324 "aten::pow.int_float", 325 "aten::pow.float_int", 326 "aten::pow.Scalar_Scalar", 327 "aten::pow.int_to_int", 328 "prim::min.int", 329 "prim::min.float", 330 "prim::min.int_float", 331 "prim::min.float_int", 332 "prim::min", 333 "prim::max.int", 334 "prim::max.float", 335 "prim::max.int_float", 336 "prim::max.float_int", 337 "prim::max", 338 "prim::type", 339 "aten::len.Tensor", 340 "aten::ord", 341 "aten::lower", 342 "aten::__contains__.str_list", 343 "aten::len.str", 344 "aten::__getitem__.str", 345 "aten::copy_.Tensor", 346 "aten::copy_.int", 347 "aten::copy_.float", 348 "aten::backward", 349 "aten::index.Tensor_hacked_twin", 350 "aten::_unsafe_index.Tensor_hacked_twin", 351 "aten::_index_put_impl_.hacked_twin", 352 "aten::index_put_.hacked_twin", 353 "aten::index_put.hacked_twin", 354 "aten::_unsafe_index_put.hacked_twin", 355 "aten::to.prim_Device", 356 "aten::to.prim_dtype", 357 "prim::is_cuda", 358 "prim::data", 359 "prim::min.int_list", 360 "prim::max.int_list", 361 "prim::min.self_int", 362 "prim::max.self_int", 363 "prim::min.float_list", 364 "prim::max.float_list", 365 "prim::min.self_float", 366 "prim::max.self_float", 367 "prim::min.bool_list", 368 "prim::max.bool_list", 369 "prim::min.self_bool", 370 "prim::max.self_bool", 371 "aten::len.Dict_str", 372 "aten::keys.str", 373 "aten::values.str", 374 "aten::__getitem__.Dict_str", 375 "aten::get.str", 376 "aten::get.default_str", 377 "aten::setdefault.str", 378 "aten::Delete.Dict_str", 379 "aten::pop.Dict_str", 380 "aten::pop.Dict_default_str", 381 "aten::popitem.str", 382 "aten::clear.str", 383 "aten::update.str", 384 "aten::items.str", 385 "aten::copy.Dict_str", 386 "aten::__contains__.str", 387 "aten::_set_item.str", 388 "aten::dict.str", 389 "aten::len.Dict_int", 390 "aten::keys.int", 391 "aten::values.int", 392 "aten::__getitem__.Dict_int", 393 "aten::get.int", 394 "aten::get.default_int", 395 "aten::setdefault.int", 396 "aten::Delete.Dict_int", 397 "aten::pop.Dict_int", 398 "aten::pop.Dict_default_int", 399 "aten::popitem.int", 400 "aten::clear.int", 401 "aten::update.int", 402 "aten::items.int", 403 "aten::copy.Dict_int", 404 "aten::__contains__.int", 405 "aten::_set_item.int", 406 "aten::dict.int", 407 "aten::len.Dict_bool", 408 "aten::keys.bool", 409 "aten::values.bool", 410 "aten::__getitem__.Dict_bool", 411 "aten::get.bool", 412 "aten::get.default_bool", 413 "aten::setdefault.bool", 414 "aten::Delete.Dict_bool", 415 "aten::pop.Dict_bool", 416 "aten::pop.Dict_default_bool", 417 "aten::popitem.bool", 418 "aten::clear.bool", 419 "aten::update.bool", 420 "aten::items.bool", 421 "aten::copy.Dict_bool", 422 "aten::__contains__.bool", 423 "aten::_set_item.bool", 424 "aten::dict.bool", 425 "aten::len.Dict_float", 426 "aten::keys.float", 427 "aten::values.float", 428 "aten::__getitem__.Dict_float", 429 "aten::get.float", 430 "aten::get.default_float", 431 "aten::setdefault.float", 432 "aten::Delete.Dict_float", 433 "aten::pop.Dict_float", 434 "aten::pop.Dict_default_float", 435 "aten::popitem.float", 436 "aten::clear.float", 437 "aten::update.float", 438 "aten::items.float", 439 "aten::copy.Dict_float", 440 "aten::__contains__.float", 441 "aten::_set_item.float", 442 "aten::dict.float", 443 "aten::len.Dict_Tensor", 444 "aten::keys.Tensor", 445 "aten::values.Tensor", 446 "aten::__getitem__.Dict_Tensor", 447 "aten::get.Tensor", 448 "aten::get.default_Tensor", 449 "aten::setdefault.Tensor", 450 "aten::Delete.Dict_Tensor", 451 "aten::pop.Dict_Tensor", 452 "aten::pop.Dict_default_Tensor", 453 "aten::popitem.Tensor", 454 "aten::clear.Tensor", 455 "aten::update.Tensor", 456 "aten::items.Tensor", 457 "aten::copy.Dict_Tensor", 458 "aten::__contains__.Tensor", 459 "aten::_set_item.Tensor", 460 "aten::dict.Tensor", 461 "aten::__round_to_zero_floordiv.int", 462 "aten::mathremainder.int", 463 "aten::mathremainder.float", 464 "aten::mathremainder.int_float", 465 "aten::mathremainder.float_int", 466 "aten::mathremainder", 467 "aten::__and__.int", 468 "aten::__or__.int", 469 "aten::__xor__.int", 470 "aten::__lshift__.int", 471 "aten::__rshift__.int", 472 "aten::round.int", 473 "aten::round.float", 474 "aten::round.Scalar", 475 "aten::log.int", 476 "aten::log.float", 477 "aten::log.Scalar", 478 "aten::log.int_int", 479 "aten::log.float_float", 480 "aten::log.int_float", 481 "aten::log.float_int", 482 "aten::log.Scalar_Scalar", 483 "aten::log1p.int", 484 "aten::log1p.float", 485 "aten::log1p.Scalar", 486 "aten::log10.int", 487 "aten::log10.float", 488 "aten::log10.Scalar", 489 "aten::sqrt.int", 490 "aten::sqrt.float", 491 "aten::sqrt.Scalar", 492 "aten::acos.int", 493 "aten::acos.float", 494 "aten::acos.Scalar", 495 "aten::asin.int", 496 "aten::asin.float", 497 "aten::asin.Scalar", 498 "aten::atan.int", 499 "aten::atan.float", 500 "aten::atan.Scalar", 501 "aten::atan2.int", 502 "aten::atan2.float", 503 "aten::atan2.int_float", 504 "aten::atan2.float_int", 505 "aten::atan2.Scalar_Scalar", 506 "aten::cos.int", 507 "aten::cos.float", 508 "aten::cos.Scalar", 509 "aten::sin.int", 510 "aten::sin.float", 511 "aten::sin.Scalar", 512 "aten::tan.int", 513 "aten::tan.float", 514 "aten::tan.Scalar", 515 "aten::asinh.int", 516 "aten::asinh.float", 517 "aten::asinh.Scalar", 518 "aten::atanh.int", 519 "aten::atanh.float", 520 "aten::atanh.Scalar", 521 "aten::acosh.int", 522 "aten::acosh.float", 523 "aten::acosh.Scalar", 524 "aten::sinh.int", 525 "aten::sinh.float", 526 "aten::sinh.Scalar", 527 "aten::cosh.int", 528 "aten::cosh.float", 529 "aten::cosh.Scalar", 530 "aten::tanh.int", 531 "aten::tanh.float", 532 "aten::tanh.Scalar", 533 "aten::degrees.int", 534 "aten::degrees.float", 535 "aten::degrees.Scalar", 536 "aten::radians.int", 537 "aten::radians.float", 538 "aten::radians.Scalar", 539 "aten::fmod.int", 540 "aten::fmod.float", 541 "aten::fmod.int_float", 542 "aten::fmod.float_int", 543 "aten::fmod", 544 "aten::factorial.int", 545 "aten::isnan.float", 546 "aten::isfinite.float", 547 "aten::isinf.float", 548 "aten::gamma.int", 549 "aten::gamma.float", 550 "aten::gamma.Scalar", 551 "aten::erf.int", 552 "aten::erf.float", 553 "aten::erf.Scalar", 554 "aten::erfc.int", 555 "aten::erfc.float", 556 "aten::erfc.Scalar", 557 "aten::expm1.int", 558 "aten::expm1.float", 559 "aten::expm1.Scalar", 560 "aten::fabs.int", 561 "aten::fabs.float", 562 "aten::fabs.Scalar", 563 "aten::lgamma.int", 564 "aten::lgamma.float", 565 "aten::lgamma.Scalar", 566 "prim::abs.int", 567 "prim::abs.float", 568 "prim::abs.Scalar", 569 "aten::gcd.int", 570 "aten::copysign.int", 571 "aten::copysign.float", 572 "aten::copysign.int_float", 573 "aten::copysign.float_int", 574 "aten::copysign", 575 "aten::split", 576 "aten::tensor.float", 577 "aten::as_tensor.float", 578 "aten::tensor.int", 579 "aten::as_tensor.int", 580 "aten::tensor.bool", 581 "aten::as_tensor.bool", 582 "aten::_infer_size", 583 "aten::_no_grad_embedding_renorm_", 584 "aten::tensor", 585 "aten::as_tensor", 586 "aten::as_tensor.list", 587 "aten::_pack_sequence", 588 "aten::_get_tracing_state", 589 "aten::is_scripting", 590 "aten::_no_grad_uniform_", 591 "aten::_no_grad_normal_", 592 "aten::_no_grad_fill_", 593 "aten::_no_grad_zero_", 594] 595 596PT_BASE_OPS = [ 597 "aten::_coalesced_", 598 "aten::_copy_from", 599 "aten::_empty_affine_quantized", 600 "aten::_empty_per_channel_affine_quantized", 601 "aten::_indices", 602 "aten::_nnz", 603 "aten::_values", 604 "aten::add", 605 "aten::add_", 606 "aten::arange", 607 "aten::as_strided", 608 "aten::as_strided_", 609 "aten::cat", 610 "aten::clone", 611 "aten::coalesce", 612 "aten::contiguous", 613 "aten::copy_", 614 "aten::copy_sparse_to_sparse_", 615 "aten::dense_dim", 616 "aten::dequantize", 617 "aten::div", 618 "aten::div_", 619 "aten::empty", 620 "aten::empty_like", 621 "aten::empty_strided", 622 "aten::eq", 623 "aten::equal", 624 "aten::expand", 625 "aten::fill_", 626 "aten::is_coalesced", 627 "aten::is_complex", 628 "aten::is_floating_point", 629 "aten::is_leaf", 630 "aten::is_nonzero", 631 "aten::item", 632 "aten::max", 633 "aten::min", 634 "aten::mul", 635 "aten::mul_", 636 "aten::narrow", 637 "aten::ne", 638 "aten::permute", 639 "aten::q_per_channel_axis", 640 "aten::q_per_channel_scales", 641 "aten::q_per_channel_zero_points", 642 "aten::q_scale", 643 "aten::q_zero_point", 644 "aten::qscheme", 645 "aten::quantize_per_tensor", 646 "aten::reshape", 647 "aten::_reshape_alias", 648 "aten::resize_", 649 "aten::resize_as_", 650 "aten::scalar_tensor", 651 "aten::select", 652 "aten::set_", 653 "aten::size", 654 "aten::slice", 655 "aten::sparse_dim", 656 "aten::sparse_resize_and_clear_", 657 "aten::squeeze", 658 "aten::squeeze_", 659 "aten::stride", 660 "aten::sub", 661 "aten::sub_", 662 "aten::sum", 663 "aten::t", 664 "aten::to", 665 "aten::_to_copy", 666 "aten::unsqueeze", 667 "aten::view", 668 "aten::zero_", 669 "aten::zeros", 670 "aten::zeros_like", 671] 672