1 #pragma once 2 3 #include <torch/csrc/lazy/ts_backend/ts_node.h> 4 5 namespace torch { 6 namespace lazy { 7 8 // This IR was copied from code-generated output, but the entire _to_copy 9 // operator cannot be trivially code genereated since it is only desirable to 10 // capture IR for certain permutaions of _to_copy (e.g. dtype), and for the 11 // others it is difficult to even invoke the aten/eager fallback necessitating 12 // directly implementing the right to(device) behavior 13 class ToCopy : public torch::lazy::TsNode { 14 public: ClassOpKind()15 static OpKind ClassOpKind() { 16 return OpKind(at::aten::_to_copy); 17 } 18 ToCopy(const torch::lazy::Value & self,const std::optional<at::ScalarType> & dtype,const std::optional<at::Layout> & layout,const std::optional<at::Device> & device,const std::optional<bool> & pin_memory,const bool & non_blocking,const std::optional<at::MemoryFormat> & memory_format,std::vector<torch::lazy::Shape> && shapes)19 ToCopy( 20 const torch::lazy::Value& self, 21 const std::optional<at::ScalarType>& dtype, 22 const std::optional<at::Layout>& layout, 23 const std::optional<at::Device>& device, 24 const std::optional<bool>& pin_memory, 25 const bool& non_blocking, 26 const std::optional<at::MemoryFormat>& memory_format, 27 std::vector<torch::lazy::Shape>&& shapes) 28 : torch::lazy::TsNode( 29 ClassOpKind(), 30 {self}, 31 std::move(shapes), 32 /* num_outputs */ 1, 33 torch::lazy::MHash( 34 dtype, 35 layout, 36 device, 37 pin_memory, 38 non_blocking, 39 memory_format)), 40 41 dtype(dtype), 42 layout(layout), 43 device(device), 44 pin_memory(pin_memory), 45 non_blocking(non_blocking), 46 memory_format(memory_format) {} 47 CanBeReused(const torch::lazy::Value & self,const std::optional<at::ScalarType> & dtype,const std::optional<at::Layout> & layout,const std::optional<at::Device> & device,const std::optional<bool> & pin_memory,const bool & non_blocking,const std::optional<at::MemoryFormat> & memory_format)48 bool CanBeReused( 49 const torch::lazy::Value& self, 50 const std::optional<at::ScalarType>& dtype, 51 const std::optional<at::Layout>& layout, 52 const std::optional<at::Device>& device, 53 const std::optional<bool>& pin_memory, 54 const bool& non_blocking, 55 const std::optional<at::MemoryFormat>& memory_format) const { 56 size_t i = 0; 57 return ( 58 operand(i++) == self && this->dtype == dtype && 59 this->layout == layout && this->device == device && 60 this->pin_memory == pin_memory && this->non_blocking == non_blocking && 61 this->memory_format == memory_format); 62 } 63 ToString()64 std::string ToString() const override { 65 std::stringstream ss; 66 ss << torch::lazy::TsNode::ToString(); 67 if (dtype.has_value()) { 68 ss << ", dtype=" << dtype.value(); 69 } else { 70 ss << ", dtype=null"; 71 } 72 if (layout.has_value()) { 73 ss << ", layout=" << layout.value(); 74 } else { 75 ss << ", layout=null"; 76 } 77 if (device.has_value()) { 78 ss << ", device=" << device.value(); 79 } else { 80 ss << ", device=null"; 81 } 82 if (pin_memory.has_value()) { 83 ss << ", pin_memory=" << pin_memory.value(); 84 } else { 85 ss << ", pin_memory=null"; 86 } 87 ss << ", non_blocking=" << non_blocking; 88 if (memory_format.has_value()) { 89 ss << ", memory_format=" << memory_format.value(); 90 } else { 91 ss << ", memory_format=null"; 92 } 93 return ss.str(); 94 } 95 Lower(std::shared_ptr<torch::jit::GraphFunction> function,torch::lazy::TSLoweringContext * loctx)96 torch::lazy::TSOpVector Lower( 97 std::shared_ptr<torch::jit::GraphFunction> function, 98 torch::lazy::TSLoweringContext* loctx) const override { 99 std::vector<torch::jit::NamedValue> arguments; 100 std::vector<torch::jit::NamedValue> kwarguments; 101 arguments.reserve(1); 102 kwarguments.reserve(6); 103 size_t i = 0; 104 arguments.emplace_back(loctx->GetOutputOp(operand(i++))); 105 kwarguments.emplace_back("dtype", dtype); 106 kwarguments.emplace_back("layout", layout); 107 kwarguments.emplace_back("device", device); 108 kwarguments.emplace_back("pin_memory", pin_memory); 109 kwarguments.emplace_back("non_blocking", non_blocking); 110 kwarguments.emplace_back("memory_format", memory_format); 111 torch::lazy::TSOpVector _to_copy_out = 112 torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); 113 TORCH_CHECK_EQ(_to_copy_out.size(), 1); 114 115 return _to_copy_out; 116 } 117 118 std::optional<at::ScalarType> dtype; 119 std::optional<at::Layout> layout; 120 std::optional<at::Device> device; 121 std::optional<bool> pin_memory; 122 bool non_blocking; 123 std::optional<at::MemoryFormat> memory_format; 124 }; 125 126 } // namespace lazy 127 } // namespace torch 128