1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/llm/sampler/sampler.h>
10
11 #include <gtest/gtest.h>
12 #include <torch/torch.h>
13
14 using namespace ::testing;
15 using ::executorch::extension::llm::Sampler;
16
// Greedy decoding (fp32 logits): with temperature 0 the sampler is expected to
// return the index of the largest logit. topp is presumably ignored in this
// mode — TODO confirm against Sampler's implementation.
TEST(SamplerTest, TestArgMax) {
  constexpr int kVocabSize = 32000;
  constexpr int kExpectedToken = 396;

  Sampler sampler{
      /*vocab_size*/ kVocabSize,
      /*temperature*/ 0.0f,
      /*topp*/ 0.9f,
      /*rng_seed*/ 0};

  // Seed libtorch's global generator so the random logits — and therefore any
  // failure — are reproducible from run to run.
  torch::manual_seed(0);
  // torch::rand samples uniformly from [0, 1), so planting 2.0 makes this entry
  // strictly greater than every other logit and the expected argmax is
  // unambiguous (no ties to break).
  torch::Tensor input = torch::rand({1, 1, kVocabSize}, at::kFloat);
  input[0][0][kExpectedToken] = 2.0f;

  EXPECT_EQ(sampler.sample(input.data_ptr<float>()), kExpectedToken);
}
29
// Greedy decoding (fp16 logits): same contract as the fp32 case, but exercising
// the Sampler::sample overload that reads c10::Half data. topp is presumably
// ignored when temperature is 0 — TODO confirm against Sampler's implementation.
TEST(SamplerTest, TestArgMaxWithFP16) {
  constexpr int kVocabSize = 32000;
  constexpr int kExpectedToken = 396;

  Sampler sampler{
      /*vocab_size*/ kVocabSize,
      /*temperature*/ 0.0f,
      /*topp*/ 0.9f,
      /*rng_seed*/ 0};

  // Seed libtorch's global generator so the random logits — and therefore any
  // failure — are reproducible from run to run.
  torch::manual_seed(0);
  // torch::rand samples uniformly from [0, 1); 2.0 is exactly representable in
  // half precision and strictly exceeds every other entry, so the expected
  // argmax is unambiguous even after fp16 rounding.
  torch::Tensor input = torch::rand({1, 1, kVocabSize}, at::kHalf);
  input[0][0][kExpectedToken] = 2.0f;

  EXPECT_EQ(sampler.sample(input.data_ptr<c10::Half>()), kExpectedToken);
}
42