// Source: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/unique.h
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/**
 * Unique in this file is adapted from PyTorch/XLA
 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/unique.h
 */

#pragma once

#include <functional>
#include <optional>
#include <set>
#include <utility>

namespace torch {
namespace lazy {

// Tracks a value that may be supplied zero or more times but must always be
// the same value: the first call to set() stores it, and every later call
// checks (with TORCH_CHECK) that the new value compares equal to the stored
// one according to the comparator C.
//
// NOTE(review): TORCH_CHECK is not declared by any header included here; this
// header presumably relies on its includer to provide it
// (c10/util/Exception.h in PyTorch) — confirm.
template <typename T, typename C = std::equal_to<T>>
class Unique {
 public:
  // Stores a copy of `value` if nothing has been stored yet; otherwise checks
  // that `value` compares equal (per C) to the stored value.
  //
  // Returns {true, stored} when this call stored the value, and
  // {false, stored} when a value was already present. The returned reference
  // refers to the value held inside this object.
  std::pair<bool, const T&> set(const T& value) {
    return SetImpl(value);
  }

  // Same contract as set(const T&), but moves `value` into storage instead of
  // copying it. `value` is only moved from when it is actually stored.
  std::pair<bool, const T&> set(T&& value) {
    return SetImpl(std::move(value));
  }

  // True once a value has been stored through set().
  operator bool() const {
    return value_.has_value();
  }

  // Accessors for the stored value. Precondition: a value has been stored
  // (operator bool() is true); these accessors do not check it themselves.
  operator const T&() const {
    return *value_;
  }
  const T& operator*() const {
    return *value_;
  }
  const T* operator->() const {
    return value_.operator->();
  }

  // Returns a std::set containing the stored value, or an empty set if no
  // value has been stored yet.
  std::set<T> AsSet() const {
    std::set<T> vset;
    if (value_.has_value()) {
      vset.insert(*value_);
    }
    return vset;
  }

 private:
  // Shared implementation of both set() overloads. `value` is forwarded (and
  // therefore possibly moved from) only on the branch that stores it; the
  // comparison branch only reads it.
  template <typename U>
  std::pair<bool, const T&> SetImpl(U&& value) {
    if (value_) {
      // Both values are quoted so the failure message reads 'a' vs 'b'.
      TORCH_CHECK(C()(*value_, value), "'", *value_, "' vs '", value, "'");
      return std::pair<bool, const T&>(false, *value_);
    }
    value_ = std::forward<U>(value);
    return std::pair<bool, const T&>(true, *value_);
  }

  std::optional<T> value_;
};

} // namespace lazy
} // namespace torch