1 #pragma once 2 3 #include <cstdint> 4 #include <functional> 5 #include <memory> 6 7 #include <c10/util/hash.h> 8 9 namespace torch::autograd { 10 11 struct Node; 12 13 /// Represents a particular input of a function. 14 struct Edge { EdgeEdge15 Edge() noexcept : function(nullptr), input_nr(0) {} 16 EdgeEdge17 Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept 18 : function(std::move(function_)), input_nr(input_nr_) {} 19 20 /// Convenience method to test if an edge is valid. is_validEdge21 bool is_valid() const noexcept { 22 return function != nullptr; 23 } 24 25 // Required for use in associative containers. 26 bool operator==(const Edge& other) const noexcept { 27 return this->function == other.function && this->input_nr == other.input_nr; 28 } 29 30 bool operator!=(const Edge& other) const noexcept { 31 return !(*this == other); 32 } 33 34 /// The function this `Edge` points to. 35 std::shared_ptr<Node> function; 36 37 /// The identifier of a particular input to the function. 38 uint32_t input_nr; 39 }; 40 } // namespace torch::autograd 41 42 // The idiomatic way of enabling use of a custom type as the key of hash 43 // containers in C++11. This method removes the requirement of having to pass 44 // a custom hasher to std::unordered_{map, set}. 45 // See http://en.cppreference.com/w/cpp/utility/hash for more information. 46 namespace std { 47 template <> 48 struct hash<torch::autograd::Edge> { 49 // These type aliases are required by the standard. 50 using argument_type = torch::autograd::Edge; 51 using return_type = size_t; 52 return_type operator()(const argument_type& edge) const noexcept { 53 return c10::get_hash(edge.function, edge.input_nr); 54 } 55 }; 56 } // namespace std 57