xref: /aosp_15_r20/external/executorch/examples/models/mobilebert/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport logging
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport torch
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerfrom transformers import AutoTokenizer, MobileBertModel  # @manual
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerfrom ..model_base import EagerModelBase
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker
class MobileBertModelExample(EagerModelBase):
    """Example wrapper around the pretrained MobileBERT model.

    Provides the eager model and a matching sample input. Both the model
    weights and the tokenizer are loaded from the checkpoint named in
    ``_CHECKPOINT``.
    """

    # Checkpoint identifier shared by the model and tokenizer loaders; kept in
    # one place so the two can never drift apart.
    _CHECKPOINT = "google/mobilebert-uncased"

    def __init__(self):
        # Nothing to set up here: loading happens in the getters below.
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Load and return the pretrained MobileBERT model.

        ``return_dict=False`` is forwarded to ``from_pretrained``; per the
        Hugging Face documentation this makes the model's forward pass return
        plain tuples rather than a ``ModelOutput`` object.

        Returns:
            The model object produced by ``MobileBertModel.from_pretrained``.
        """
        logging.info("loading mobilebert model")
        # pyre-ignore
        model = MobileBertModel.from_pretrained(self._CHECKPOINT, return_dict=False)
        logging.info("loaded mobilebert model")
        return model

    def get_example_inputs(self):
        """Build a sample input by tokenizing a fixed sentence.

        Returns:
            A one-element tuple holding the ``input_ids`` entry of the
            tokenizer output, requested as PyTorch tensors
            (``return_tensors="pt"``).
        """
        tokenizer = AutoTokenizer.from_pretrained(self._CHECKPOINT)
        encoded = tokenizer("Hello, my dog is cute", return_tensors="pt")
        # Only the token ids are passed on; other tokenizer outputs are dropped.
        return (encoded["input_ids"],)
32