xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/remove_dropout.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/remove_dropout.h>
2 
3 namespace torch::jit {
4 
5 namespace {
isDropoutRemovable(const Node * node)6 bool isDropoutRemovable(const Node* node) {
7   const auto inputs = node->inputs();
8   TORCH_INTERNAL_ASSERT(inputs.size() == 3);
9   const Value* training_input = inputs[2];
10   auto optional_ivalue = toIValue(training_input);
11   if (!optional_ivalue) {
12     return false;
13   }
14   const IValue& val = optional_ivalue.value();
15   TORCH_INTERNAL_ASSERT(val.isBool());
16   const bool is_training = val.toBool();
17   return !is_training;
18 }
19 
removeDropoutImpl(Block * block)20 void removeDropoutImpl(Block* block) {
21   std::vector<Node*> deleted_nodes;
22 
23   for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) {
24     Node* node = *it;
25     for (auto block : node->blocks()) {
26       removeDropoutImpl(block);
27     }
28     if ((node->kind() == c10::Symbol::fromQualString("aten::dropout") ||
29          node->kind() == c10::Symbol::fromQualString("aten::dropout_") ||
30          node->kind() == c10::Symbol::fromQualString("aten::feature_dropout") ||
31          node->kind() ==
32              c10::Symbol::fromQualString("aten::feature_dropout_")) &&
33         isDropoutRemovable(*it)) {
34       // Input tensor of dropout.
35       Value* input_value = node->inputs()[0];
36       // Output tensor.
37       Value* output_value = node->outputs()[0];
38       output_value->replaceAllUsesWith(input_value);
39       deleted_nodes.push_back(node);
40     }
41   }
42   for (auto del_node : deleted_nodes) {
43     del_node->destroy();
44   }
45 }
46 } // namespace
47 
removeDropout(std::shared_ptr<Graph> & graph)48 void removeDropout(std::shared_ptr<Graph>& graph) {
49   removeDropoutImpl(graph->block());
50 }
51 
removeDropout(script::Module & module)52 void removeDropout(script::Module& module) {
53   TORCH_CHECK(
54       !module.hasattr("training") || !module.is_training(),
55       "Dropout removal module in training mode is not yet supported");
56   auto graph = module.get_method("forward").graph();
57   removeDropout(graph);
58 }
59 
60 } // namespace torch::jit
61