// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/named_value.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/utils/variadic.h>

#include <optional>
#include <string>
#include <type_traits>
#include <utility>
6 
7 namespace torch::jit {
8 
9 struct Value;
10 
11 /**
12  * A value with optional extra name and location information. Used during
13  * schema matching to provide extra error information and resolve kwargs.
14  */
15 struct NamedValue {
NamedValueNamedValue16   NamedValue(const SourceRange& loc, const std::string& name, Value* value)
17       : loc_(loc), name_(name), value_(value) {}
NamedValueNamedValue18   NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {}
19 
NamedValueNamedValue20   /* implicit */ NamedValue(Value* value) : value_(value) {}
NamedValueNamedValue21   NamedValue(const std::string& name, Value* value)
22       : name_(name), value_(value) {}
23 
NamedValueNamedValue24   /* implicit */ NamedValue(IValue value) : ivalue_(std::move(value)) {}
25 
NamedValueNamedValue26   NamedValue(const std::string& name, IValue value)
27       : name_(name), ivalue_(std::move(value)) {}
28 
29   template <
30       typename T,
31       typename = std::enable_if_t<
32           (!std::is_same_v<std::decay_t<T>, NamedValue> &&
33            !std::is_same_v<std::decay_t<T>, Value*> &&
34            !std::is_same_v<std::decay_t<T>, IValue>)>>
35   // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
NamedValueNamedValue36   NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {}
37 
38   template <
39       typename T,
40       typename = std::enable_if_t<
41           (!std::is_same_v<std::decay_t<T>, Value*> &&
42            !std::is_same_v<std::decay_t<T>, IValue>)>>
NamedValueNamedValue43   NamedValue(const std::string& name, T&& t)
44       : NamedValue(name, IValue(std::forward<T>(t))) {}
45 
locOrNamedValue46   SourceRange locOr(const SourceRange& backup_location) const {
47     if (!loc_)
48       return backup_location;
49     return loc();
50   }
51 
52   // note: this will insert a constant node into the graph at the current
53   // insert point if this NamedValue is actually a constant
valueNamedValue54   Value* value(Graph& g) const {
55     if (!value_)
56       return insertConstant(
57           g, ivalue_); // use insertConstant to remove need to include ir.h here
58     return value_;
59   }
60 
nameNamedValue61   const std::string& name() const {
62     AT_ASSERT(name_);
63     return *name_;
64   }
65 
locNamedValue66   const SourceRange& loc() const {
67     AT_ASSERT(loc_);
68     return *loc_;
69   }
70 
71   at::TypePtr type() const;
72 
73  private:
74   std::optional<SourceRange> loc_;
75   std::optional<std::string> name_;
76   Value* value_{nullptr};
77   // only valid if value_ == nullptr;
78   IValue ivalue_;
79 };
80 
81 } // namespace torch::jit
82