1 // Copyright (c) Meta Platforms, Inc. and affiliates. 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <string> 9 #include <vector> 10 11 #include <torch/csrc/jit/ir/ir.h> 12 #include <torch/csrc/jit/runtime/static/impl.h> 13 14 namespace c10 { 15 struct IValue; 16 } 17 18 namespace torch { 19 namespace jit { 20 21 struct Node; 22 class StaticModule; 23 24 namespace test { 25 26 // Given a model/function in jit or IR script, run the model/function 27 // with the jit interpreter and static runtime, and compare the results 28 void testStaticRuntime( 29 const std::string& source, 30 const std::vector<c10::IValue>& args, 31 const std::vector<c10::IValue>& args2 = {}, 32 const bool use_allclose = false, 33 const bool use_equalnan = false, 34 const bool check_resize = true); 35 36 std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script); 37 38 std::shared_ptr<Graph> getGraphFromIR(const std::string& ir); 39 40 bool hasProcessedNodeWithName( 41 torch::jit::StaticModule& smodule, 42 const char* name); 43 44 at::Tensor getTensor(const at::IValue& ival); 45 46 Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind); 47 Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind); 48 49 bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind); 50 bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind); 51 52 void compareResultsWithJIT( 53 StaticRuntime& runtime, 54 const std::shared_ptr<Graph>& graph, 55 const std::vector<c10::IValue>& args, 56 const bool use_allclose = false, 57 const bool use_equalnan = false); 58 59 void compareResults( 60 const IValue& expect, 61 const IValue& actual, 62 const bool use_allclose = false, 63 const bool use_equalnan = false); 64 65 } // namespace test 66 } // namespace jit 67 } // namespace torch 68