/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/extension/llm/sampler/sampler.h>
10 
11 #include <gtest/gtest.h>
12 #include <torch/torch.h>
13 
14 using namespace ::testing;
15 using ::executorch::extension::llm::Sampler;
16 
// Checks that a Sampler configured with temperature 0 returns the index of
// the largest float32 logit (greedy / argmax selection), which is the
// behavior this test pins down.
TEST(SamplerTest, TestArgMax) {
  // Shared by the sampler config, the logits shape, and the assertion so the
  // three cannot drift apart.
  constexpr int kVocabSize = 32000;
  constexpr int kExpectedToken = 396;

  Sampler sampler{
      /*vocab_size*/ kVocabSize,
      /*temperature*/ 0.0f,
      /*topp*/ 0.9f,
      /*rng_seed*/ 0};

  // Seed the global torch RNG so the random logits are reproducible when a
  // failure needs to be investigated.
  torch::manual_seed(0);
  // torch::rand samples from [0, 1); a logit of 2.0 is therefore strictly
  // greater than every other entry, making kExpectedToken the unique argmax
  // regardless of the draw.
  torch::Tensor input = torch::rand({1, 1, kVocabSize}, at::kFloat);
  input[0][0][kExpectedToken] = 2.0f;

  EXPECT_EQ(sampler.sample(input.data_ptr<float>()), kExpectedToken);
}
29 
// Checks that a Sampler configured with temperature 0 returns the index of
// the largest logit when the logits are stored as half precision
// (c10::Half), which is the behavior this test pins down.
TEST(SamplerTest, TestArgMaxWithFP16) {
  // Shared by the sampler config, the logits shape, and the assertion so the
  // three cannot drift apart.
  constexpr int kVocabSize = 32000;
  constexpr int kExpectedToken = 396;

  Sampler sampler{
      /*vocab_size*/ kVocabSize,
      /*temperature*/ 0.0f,
      /*topp*/ 0.9f,
      /*rng_seed*/ 0};

  // Seed the global torch RNG so the random logits are reproducible when a
  // failure needs to be investigated.
  torch::manual_seed(0);
  // torch::rand samples from [0, 1). Using 2.0 (exactly representable in
  // fp16) instead of the range's upper bound keeps kExpectedToken the unique
  // argmax even if reduced-precision rounding ever lands a draw on 1.0.
  torch::Tensor input = torch::rand({1, 1, kVocabSize}, at::kHalf);
  input[0][0][kExpectedToken] = 2.0f;

  EXPECT_EQ(sampler.sample(input.data_ptr<c10::Half>()), kExpectedToken);
}
42