# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

from transformers import AutoTokenizer, MobileBertModel  # @manual

from ..model_base import EagerModelBase


class MobileBertModelExample(EagerModelBase):
    """Example model wrapper around Hugging Face's ``google/mobilebert-uncased``.

    Provides the eager ``torch.nn.Module`` and a matching example input tuple.
    """

    # Hub identifier shared by the model and its tokenizer.
    _CHECKPOINT = "google/mobilebert-uncased"

    def __init__(self):
        # Nothing is stored; the model and tokenizer are loaded lazily by the
        # getters below.
        # NOTE(review): the base-class __init__ is intentionally not invoked —
        # confirm EagerModelBase does not require it.
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Load and return the pretrained MobileBERT model.

        The model is loaded with ``return_dict=False``, so its forward pass
        yields plain tuples rather than a ``ModelOutput`` object.
        """
        logging.info("loading mobilebert model")
        # pyre-ignore
        mobilebert = MobileBertModel.from_pretrained(
            self._CHECKPOINT, return_dict=False
        )
        logging.info("loaded mobilebert model")
        return mobilebert

    def get_example_inputs(self):
        """Return a 1-tuple holding the token ids of a fixed sample sentence.

        The ids come from the checkpoint's own tokenizer and are returned as a
        PyTorch tensor (``return_tensors="pt"``).
        """
        tokenizer = AutoTokenizer.from_pretrained(self._CHECKPOINT)
        encoded = tokenizer("Hello, my dog is cute", return_tensors="pt")
        input_ids = encoded["input_ids"]
        return (input_ids,)