#pragma once

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <torch/torch.h>

struct DeepAndWide : torch::nn::Module {
  DeepAndWide(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
  }

  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    auto wide_offset = wide + mu_;
    auto wide_normalized = wide_offset * sigma_;
    auto wide_noNaN = wide_normalized;
    // Placeholder for ReplaceNaN
    auto wide_preproc = torch::clamp(wide_noNaN, -10.0, 10.0);

    auto user_emb_t = torch::transpose(user_emb, 1, 2);
    auto dp_unflatten = torch::bmm(ad_emb_packed, user_emb_t);
    auto dp = torch::flatten(dp_unflatten, 1);
    auto input = torch::cat({dp, wide_preproc}, 1);
    auto fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
    auto pred = torch::sigmoid(fc1);
    return pred;
  }
  torch::Tensor mu_, sigma_, fc_w_, fc_b_;
};

// Implementation using native functions and pre-allocated tensors.
// It could be used as a "speed of light" for static runtime.
struct DeepAndWideFast : torch::nn::Module {
  DeepAndWideFast(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
    allocated = false;
    prealloc_tensors = {};
  }

  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    torch::NoGradGuard no_grad;
    if (!allocated) {
      auto wide_offset = at::add(wide, mu_);
      auto wide_normalized = at::mul(wide_offset, sigma_);
      // Placeholder for ReplaceNaN
      auto wide_preproc = at::cpu::clamp(wide_normalized, -10.0, 10.0);

      auto user_emb_t = at::native::transpose(user_emb, 1, 2);
      auto dp_unflatten = at::cpu::bmm(ad_emb_packed, user_emb_t);
      // auto dp = at::native::flatten(dp_unflatten, 1);
      auto dp = dp_unflatten.view({dp_unflatten.size(0), 1});
      auto input = at::cpu::cat({dp, wide_preproc}, 1);

      // fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
      fc_w_t_ = torch::t(fc_w_);
      auto fc1 = torch::addmm(fc_b_, input, fc_w_t_);

      auto pred = at::cpu::sigmoid(fc1);

      prealloc_tensors = {
          wide_offset,
          wide_normalized,
          wide_preproc,
          user_emb_t,
          dp_unflatten,
          dp,
          input,
          fc1,
          pred};
      allocated = true;

      return pred;
    } else {
      // Potential optimization: add and mul could be fused together (e.g. with
      // Eigen).
      at::add_out(prealloc_tensors[0], wide, mu_);
      at::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_);

      at::native::clip_out(
          prealloc_tensors[1], -10.0, 10.0, prealloc_tensors[2]);

      // Potential optimization: original tensor could be pre-transposed.
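      // The set_() call below is an allocation-free stand-in for the
      // commented-out transpose: it rebinds prealloc_tensors[3] to user_emb's
      // storage with dims 1 and 2 of the sizes and strides swapped, producing
      // the same transposed view without creating a new tensor on every call.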
      // prealloc_tensors[3] = at::native::transpose(user_emb, 1, 2);
      if (prealloc_tensors[3].data_ptr() != user_emb.data_ptr()) {
        auto sizes = user_emb.sizes();
        auto strides = user_emb.strides();
        prealloc_tensors[3].set_(
            user_emb.storage(),
            0,
            {sizes[0], sizes[2], sizes[1]},
            {strides[0], strides[2], strides[1]});
      }

      // Potential optimization: call MKLDNN directly.
      at::cpu::bmm_out(prealloc_tensors[4], ad_emb_packed, prealloc_tensors[3]);

      if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) {
        // In the unlikely case that the input tensor changed, we need to
        // reinitialize the view.
        prealloc_tensors[5] =
            prealloc_tensors[4].view({prealloc_tensors[4].size(0), 1});
      }

      // Potential optimization: we can replace cat with carefully constructed
      // tensor views on the output that are passed to the _out ops above.
      at::cpu::cat_outf(
          {prealloc_tensors[5], prealloc_tensors[2]}, 1, prealloc_tensors[6]);
      at::cpu::addmm_out(
          prealloc_tensors[7], fc_b_, prealloc_tensors[6], fc_w_t_, 1, 1);
      at::cpu::sigmoid_out(prealloc_tensors[8], prealloc_tensors[7]);

      return prealloc_tensors[8];
    }
  }
  torch::Tensor mu_, sigma_, fc_w_, fc_b_, fc_w_t_;
  std::vector<torch::Tensor> prealloc_tensors;
  bool allocated = false;
};

torch::jit::Module getDeepAndWideSciptModel(int num_features = 50);

torch::jit::Module getTrivialScriptModel();

torch::jit::Module getLeakyReLUScriptModel();

torch::jit::Module getLeakyReLUConstScriptModel();

torch::jit::Module getLongScriptModel();

torch::jit::Module getSignedLog1pModel();
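
// Example usage (illustrative sketch; the shape values are assumptions, not
// fixed by this header). forward() requires `wide` to be
// {batch_size, num_features} so it broadcasts against mu_/sigma_, and both
// embedding inputs to be {batch_size, 1, embedding_size} so that the bmm
// yields {batch_size, 1, 1}, which is then viewed as {batch_size, 1}:
//
//   const int batch_size = 1;
//   const int embedding_size = 32;
//   DeepAndWideFast model(/*num_features=*/50);
//   auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
//   auto user_emb = torch::randn({batch_size, 1, embedding_size});
//   auto wide = torch::randn({batch_size, 50});
//   // First call allocates the intermediates; later calls reuse them via the
//   // _out variants.
//   auto pred = model.forward(ad_emb_packed, user_emb, wide); // {batch_size, 1}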