from utils import GetterReturnType

import torch
import torch.distributions as dist
from torch import Tensor

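# Each getter below returns, per the GetterReturnType contract imported from
# utils, a `forward` callable together with a tuple of example input tensors
# (with requires_grad set) for the benchmark to differentiate through.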
def get_simple_regression(device: torch.device) -> GetterReturnType:
    N = 10
    K = 10

    loc_beta = 0.0
    scale_beta = 1.0

    beta_prior = dist.Normal(loc_beta, scale_beta)

    X = torch.rand(N, K + 1, device=device)
    Y = torch.rand(N, 1, device=device)

    # X.shape: (N, K + 1), Y.shape: (N, 1), beta_value.shape: (K + 1, 1)
    beta_value = beta_prior.sample((K + 1, 1))
    beta_value.requires_grad_(True)

    def forward(beta_value: Tensor) -> Tensor:
        mu = X.mm(beta_value)

        # We need to compute the first and second gradients of this score with
        # respect to beta_value. Bernoulli validation is disabled because Y is
        # a relaxed value in [0, 1] rather than a strict {0, 1} label.
        score = (
            dist.Bernoulli(logits=mu, validate_args=False).log_prob(Y).sum()
            + beta_prior.log_prob(beta_value).sum()
        )
        return score

    return forward, (beta_value.to(device),)

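
# A minimal usage sketch, not part of the benchmark harness: it shows how the
# (forward, inputs) pair returned above yields the first and second gradients
# mentioned in the comments. The `_demo_*` helper name is illustrative, not
# from the original file.
def _demo_simple_regression_grads(device: torch.device = torch.device("cpu")) -> None:
    forward, (beta_value,) = get_simple_regression(device)
    score = forward(beta_value)
    # First gradient of the score w.r.t. beta_value; create_graph=True keeps
    # the graph alive so we can differentiate a second time.
    (grad_beta,) = torch.autograd.grad(score, beta_value, create_graph=True)
    # Second gradient: differentiate one scalar entry of the first gradient.
    (grad2_beta,) = torch.autograd.grad(grad_beta[0, 0], beta_value)

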
def get_robust_regression(device: torch.device) -> GetterReturnType:
    N = 10
    K = 10

    # X.shape: (N, K + 1), Y.shape: (N, 1)
    X = torch.rand(N, K + 1, device=device)
    Y = torch.rand(N, 1, device=device)

    # Predefined nu_alpha and nu_beta; nu_alpha.shape: (1, 1), nu_beta.shape: (1, 1)
    nu_alpha = torch.rand(1, 1, device=device)
    nu_beta = torch.rand(1, 1, device=device)
    nu = dist.Gamma(nu_alpha, nu_beta)

    # Predefined sigma_rate; sigma_rate.shape: (N, 1)
    sigma_rate = torch.rand(N, 1, device=device)
    sigma = dist.Exponential(sigma_rate)

    # Predefined beta_mean and beta_sigma; beta_mean.shape: (K + 1, 1),
    # beta_sigma.shape: (K + 1, 1)
    beta_mean = torch.rand(K + 1, 1, device=device)
    beta_sigma = torch.rand(K + 1, 1, device=device)
    beta = dist.Normal(beta_mean, beta_sigma)

    nu_value = nu.sample()
    nu_value.requires_grad_(True)

    sigma_value = sigma.sample()
    # sigma is parametrized in log space so it stays positive; the score below
    # carries the matching Jacobian correction.
    sigma_unconstrained_value = sigma_value.log()
    sigma_unconstrained_value.requires_grad_(True)

    beta_value = beta.sample()
    beta_value.requires_grad_(True)

    def forward(
        nu_value: Tensor, sigma_unconstrained_value: Tensor, beta_value: Tensor
    ) -> Tensor:
        sigma_constrained_value = sigma_unconstrained_value.exp()
        mu = X.mm(beta_value)

        # For this model, we need to compute the following three scores.

        # We need to compute the first and second gradients of this score with
        # respect to nu_value.
        nu_score = (
            dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum()
            + nu.log_prob(nu_value)
        )

        # We need to compute the first and second gradients of this score with
        # respect to sigma_unconstrained_value. The trailing
        # sigma_unconstrained_value term is the log-abs-det Jacobian of the exp
        # transform, i.e. the change-of-variables correction from sigma to
        # log(sigma).
        sigma_score = (
            dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum()
            + sigma.log_prob(sigma_constrained_value)
            + sigma_unconstrained_value
        )

        # We need to compute the first and second gradients of this score with
        # respect to beta_value.
        beta_score = (
            dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum()
            + beta.log_prob(beta_value)
        )

        return nu_score.sum() + sigma_score.sum() + beta_score.sum()

    return forward, (
        nu_value.to(device),
        sigma_unconstrained_value.to(device),
        beta_value.to(device),
    )
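
# A minimal usage sketch, not part of the benchmark harness: the robust model
# takes three parameter tensors, and its score is differentiated with respect
# to each of them. The `_demo_*` helper name is illustrative, not from the
# original file.
def _demo_robust_regression_grads(device: torch.device = torch.device("cpu")) -> None:
    forward, inputs = get_robust_regression(device)
    score = forward(*inputs)
    # One first-gradient tensor per input: nu, unconstrained sigma, and beta.
    grads = torch.autograd.grad(score, inputs, create_graph=True)
    # A second gradient w.r.t. nu_value, taken through the first nu gradient.
    (grad2_nu,) = torch.autograd.grad(grads[0].sum(), inputs[0])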