xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ops/to_copy.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

// This IR was copied from code-generated output, but the entire _to_copy
// operator cannot be trivially code-generated: it is only desirable to
// capture IR for certain permutations of _to_copy (e.g. dtype), and for the
// others it is difficult to even invoke the aten/eager fallback, so the
// correct to(device) behavior has to be implemented directly.
class ToCopy : public torch::lazy::TsNode {
 public:
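  // The OpKind identifying this node in the lazy IR: aten::_to_copy.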
  static OpKind ClassOpKind() {
    return OpKind(at::aten::_to_copy);
  }

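  // Captures the _to_copy arguments alongside the input operand and folds
  // them into the node hash so that identical copies map to the same IR node.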
  ToCopy(
      const torch::lazy::Value& self,
      const std::optional<at::ScalarType>& dtype,
      const std::optional<at::Layout>& layout,
      const std::optional<at::Device>& device,
      const std::optional<bool>& pin_memory,
      const bool& non_blocking,
      const std::optional<at::MemoryFormat>& memory_format,
      std::vector<torch::lazy::Shape>&& shapes)
      : torch::lazy::TsNode(
            ClassOpKind(),
            {self},
            std::move(shapes),
            /* num_outputs */ 1,
            torch::lazy::MHash(
                dtype,
                layout,
                device,
                pin_memory,
                non_blocking,
                memory_format)),
        dtype(dtype),
        layout(layout),
        device(device),
        pin_memory(pin_memory),
        non_blocking(non_blocking),
        memory_format(memory_format) {}

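  // Used by the trace cache: an existing node can be reused only when the
  // operand and every captured attribute match the requested copy.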
  bool CanBeReused(
      const torch::lazy::Value& self,
      const std::optional<at::ScalarType>& dtype,
      const std::optional<at::Layout>& layout,
      const std::optional<at::Device>& device,
      const std::optional<bool>& pin_memory,
      const bool& non_blocking,
      const std::optional<at::MemoryFormat>& memory_format) const {
    size_t i = 0;
    return (
        operand(i++) == self && this->dtype == dtype &&
        this->layout == layout && this->device == device &&
        this->pin_memory == pin_memory && this->non_blocking == non_blocking &&
        this->memory_format == memory_format);
  }

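  // Appends the captured attributes to the base node's string representation;
  // unset optionals are printed as "null".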
  std::string ToString() const override {
    std::stringstream ss;
    ss << torch::lazy::TsNode::ToString();
    if (dtype.has_value()) {
      ss << ", dtype=" << dtype.value();
    } else {
      ss << ", dtype=null";
    }
    if (layout.has_value()) {
      ss << ", layout=" << layout.value();
    } else {
      ss << ", layout=null";
    }
    if (device.has_value()) {
      ss << ", device=" << device.value();
    } else {
      ss << ", device=null";
    }
    if (pin_memory.has_value()) {
      ss << ", pin_memory=" << pin_memory.value();
    } else {
      ss << ", pin_memory=null";
    }
    ss << ", non_blocking=" << non_blocking;
    if (memory_format.has_value()) {
      ss << ", memory_format=" << memory_format.value();
    } else {
      ss << ", memory_format=null";
    }
    return ss.str();
  }

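  // Lowers this node to a TorchScript call to the _to_copy builtin: the input
  // operand becomes the sole positional argument and the captured attributes
  // are passed as keyword arguments.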
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override {
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve(1);
    kwarguments.reserve(6);
    size_t i = 0;
    arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
    kwarguments.emplace_back("dtype", dtype);
    kwarguments.emplace_back("layout", layout);
    kwarguments.emplace_back("device", device);
    kwarguments.emplace_back("pin_memory", pin_memory);
    kwarguments.emplace_back("non_blocking", non_blocking);
    kwarguments.emplace_back("memory_format", memory_format);
    torch::lazy::TSOpVector _to_copy_out =
        torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ(_to_copy_out.size(), 1);

    return _to_copy_out;
  }

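  // Captured _to_copy arguments; these mirror the values hashed in the
  // constructor and compared in CanBeReused.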
  std::optional<at::ScalarType> dtype;
  std::optional<at::Layout> layout;
  std::optional<at::Device> device;
  std::optional<bool> pin_memory;
  bool non_blocking;
  std::optional<at::MemoryFormat> memory_format;
};

} // namespace lazy
} // namespace torch