#pragma once

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <torch/torch.h>

struct DeepAndWide : torch::nn::Module {
  DeepAndWide(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
  }

  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    auto wide_offset = wide + mu_;
    auto wide_normalized = wide_offset * sigma_;
    auto wide_noNaN = wide_normalized;
    // Placeholder for ReplaceNaN
    auto wide_preproc = torch::clamp(wide_noNaN, -10.0, 10.0);

    auto user_emb_t = torch::transpose(user_emb, 1, 2);
    auto dp_unflatten = torch::bmm(ad_emb_packed, user_emb_t);
    auto dp = torch::flatten(dp_unflatten, 1);
    auto input = torch::cat({dp, wide_preproc}, 1);
    auto fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
    auto pred = torch::sigmoid(fc1);
    return pred;
  }
  torch::Tensor mu_, sigma_, fc_w_, fc_b_;
};
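
// Illustrative usage sketch (batch size 1 and embedding size 32 below are
// arbitrary, not taken from any benchmark driver): ad_emb_packed and user_emb
// are {B, 1, E} tensors and wide is {B, num_features}, so bmm produces
// {B, 1, 1}, flatten gives {B, 1}, and the concatenated input has
// num_features + 1 columns, matching fc_w_.
//
//   DeepAndWide model(/*num_features=*/50);
//   auto ad_emb_packed = torch::randn({1, 1, 32});
//   auto user_emb = torch::randn({1, 1, 32});
//   auto wide = torch::randn({1, 50});
//   auto pred = model.forward(ad_emb_packed, user_emb, wide); // shape {1, 1}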

// Implementation using native functions and pre-allocated tensors.
// It can serve as a "speed of light" (best-case) baseline for the static
// runtime. A sketch of the intended calling pattern follows the struct.
struct DeepAndWideFast : torch::nn::Module {
  DeepAndWideFast(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
    allocated = false;
    prealloc_tensors = {};
  }

  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    torch::NoGradGuard no_grad;
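    // First call: run regular (allocating) ops and cache every intermediate
    // tensor so that later calls can reuse the same storage.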
    if (!allocated) {
      auto wide_offset = at::add(wide, mu_);
      auto wide_normalized = at::mul(wide_offset, sigma_);
      // Placeholder for ReplaceNaN
      auto wide_preproc = at::cpu::clamp(wide_normalized, -10.0, 10.0);

      auto user_emb_t = at::native::transpose(user_emb, 1, 2);
      auto dp_unflatten = at::cpu::bmm(ad_emb_packed, user_emb_t);
      // auto dp = at::native::flatten(dp_unflatten, 1);
      auto dp = dp_unflatten.view({dp_unflatten.size(0), 1});
      auto input = at::cpu::cat({dp, wide_preproc}, 1);

      // fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
      fc_w_t_ = torch::t(fc_w_);
      auto fc1 = torch::addmm(fc_b_, input, fc_w_t_);

      auto pred = at::cpu::sigmoid(fc1);

      prealloc_tensors = {
          wide_offset,
          wide_normalized,
          wide_preproc,
          user_emb_t,
          dp_unflatten,
          dp,
          input,
          fc1,
          pred};
      allocated = true;

      return pred;
    } else {
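      // Subsequent calls: write results directly into the cached tensors via
      // _out variants, so the steady-state path allocates nothing new.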
      // Potential optimization: add and mul could be fused together (e.g. with
      // Eigen).
      at::add_out(prealloc_tensors[0], wide, mu_);
      at::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_);

      at::native::clip_out(
          prealloc_tensors[1], -10.0, 10.0, prealloc_tensors[2]);

      // Potential optimization: original tensor could be pre-transposed.
      // prealloc_tensors[3] = at::native::transpose(user_emb, 1, 2);
      if (prealloc_tensors[3].data_ptr() != user_emb.data_ptr()) {
        auto sizes = user_emb.sizes();
        auto strides = user_emb.strides();
        prealloc_tensors[3].set_(
            user_emb.storage(),
            0,
            {sizes[0], sizes[2], sizes[1]},
            {strides[0], strides[2], strides[1]});
      }

      // Potential optimization: call MKLDNN directly.
      at::cpu::bmm_out(prealloc_tensors[4], ad_emb_packed, prealloc_tensors[3]);

      if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) {
        // In the unlikely case that the input tensor changed, we need to
        // reinitialize the view.
        prealloc_tensors[5] =
            prealloc_tensors[4].view({prealloc_tensors[4].size(0), 1});
      }

      // Potential optimization: we can replace cat with carefully constructed
      // tensor views on the output that are passed to the _out ops above.
      at::cpu::cat_outf(
          {prealloc_tensors[5], prealloc_tensors[2]}, 1, prealloc_tensors[6]);
      at::cpu::addmm_out(
          prealloc_tensors[7], fc_b_, prealloc_tensors[6], fc_w_t_, 1, 1);
      at::cpu::sigmoid_out(prealloc_tensors[8], prealloc_tensors[7]);

      return prealloc_tensors[8];
    }
  }
  torch::Tensor mu_, sigma_, fc_w_, fc_b_, fc_w_t_;
  std::vector<torch::Tensor> prealloc_tensors;
  bool allocated = false;
};
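
// Illustrative calling pattern (a sketch, not part of the benchmark driver;
// batch size 1, embedding size 32, and the loop count are assumptions): the
// first call to DeepAndWideFast::forward allocates and caches the
// intermediates, and every later call with same-shaped inputs reuses them
// through the _out ops above.
//
//   DeepAndWideFast model(/*num_features=*/50);
//   auto ad_emb_packed = torch::randn({1, 1, 32});
//   auto user_emb = torch::randn({1, 1, 32});
//   auto wide = torch::randn({1, 50});
//   model.forward(ad_emb_packed, user_emb, wide); // warm-up: allocates buffers
//   for (int i = 0; i < 1000; ++i) {
//     auto pred = model.forward(ad_emb_packed, user_emb, wide); // reuses them
//   }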

torch::jit::Module getDeepAndWideSciptModel(int num_features = 50);

torch::jit::Module getTrivialScriptModel();

torch::jit::Module getLeakyReLUScriptModel();

torch::jit::Module getLeakyReLUConstScriptModel();

torch::jit::Module getLongScriptModel();

torch::jit::Module getSignedLog1pModel();