1 #include <torch/csrc/jit/passes/remove_dropout.h> 2 3 namespace torch::jit { 4 5 namespace { isDropoutRemovable(const Node * node)6bool isDropoutRemovable(const Node* node) { 7 const auto inputs = node->inputs(); 8 TORCH_INTERNAL_ASSERT(inputs.size() == 3); 9 const Value* training_input = inputs[2]; 10 auto optional_ivalue = toIValue(training_input); 11 if (!optional_ivalue) { 12 return false; 13 } 14 const IValue& val = optional_ivalue.value(); 15 TORCH_INTERNAL_ASSERT(val.isBool()); 16 const bool is_training = val.toBool(); 17 return !is_training; 18 } 19 removeDropoutImpl(Block * block)20void removeDropoutImpl(Block* block) { 21 std::vector<Node*> deleted_nodes; 22 23 for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) { 24 Node* node = *it; 25 for (auto block : node->blocks()) { 26 removeDropoutImpl(block); 27 } 28 if ((node->kind() == c10::Symbol::fromQualString("aten::dropout") || 29 node->kind() == c10::Symbol::fromQualString("aten::dropout_") || 30 node->kind() == c10::Symbol::fromQualString("aten::feature_dropout") || 31 node->kind() == 32 c10::Symbol::fromQualString("aten::feature_dropout_")) && 33 isDropoutRemovable(*it)) { 34 // Input tensor of dropout. 35 Value* input_value = node->inputs()[0]; 36 // Output tensor. 37 Value* output_value = node->outputs()[0]; 38 output_value->replaceAllUsesWith(input_value); 39 deleted_nodes.push_back(node); 40 } 41 } 42 for (auto del_node : deleted_nodes) { 43 del_node->destroy(); 44 } 45 } 46 } // namespace 47 removeDropout(std::shared_ptr<Graph> & graph)48void removeDropout(std::shared_ptr<Graph>& graph) { 49 removeDropoutImpl(graph->block()); 50 } 51 removeDropout(script::Module & module)52void removeDropout(script::Module& module) { 53 TORCH_CHECK( 54 !module.hasattr("training") || !module.is_training(), 55 "Dropout removal module in training mode is not yet supported"); 56 auto graph = module.get_method("forward").graph(); 57 removeDropout(graph); 58 } 59 60 } // namespace torch::jit 61