# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A tff.Computation subclass for the demo Federated Computation platform."""

import functools
import re

import tensorflow_federated as tff

# A valid computation name is one or more runs of word characters separated by
# single "/" characters, e.g. "foo" or "foo/bar_2". Note that for str patterns
# `\w` matches Unicode word characters, not only ASCII.
COMPUTATION_NAME_REGEX = re.compile(r'\w+(/\w+)*')


class FederatedComputation(tff.Computation):
  """A tff.Computation meant to be run in a tff.program.FederatedContext.

  Wraps another tff.Computation together with a validated name, and exposes
  the MapReduceForm and DistributeAggregateForm derived from it. Calling an
  instance packs its arguments into a Struct and hands them to the current
  TFF context.
  """

  def __init__(self, comp: tff.Computation, *, name: str):
    """Wraps `comp` under `name`.

    Args:
      comp: The computation to run. It is expected to be compatible with both
        MapReduceForm and DistributeAggregateForm; only MapReduceForm
        compatibility is checked eagerly here.
      name: A name for the computation, intended to be unique (uniqueness is
        not enforced here). It must fully match `COMPUTATION_NAME_REGEX`.

    Raises:
      ValueError: If `name` does not fully match `COMPUTATION_NAME_REGEX`.
    """
    mapreduce = tff.backends.mapreduce
    # The compatibility check runs before name validation, so an incompatible
    # computation is reported even when the name is also invalid.
    mapreduce.check_computation_compatible_with_map_reduce_form(comp)  # pytype: disable=wrong-arg-types
    if COMPUTATION_NAME_REGEX.fullmatch(name) is None:
      pattern = COMPUTATION_NAME_REGEX.pattern
      raise ValueError(f'name must match "{pattern}".')
    self._comp = comp
    self._name = name

  @functools.cached_property
  def map_reduce_form(self) -> tff.backends.mapreduce.MapReduceForm:
    """MapReduceForm of the wrapped computation, built on first access."""
    converter = tff.backends.mapreduce.get_map_reduce_form_for_computation
    return converter(self._comp)  # pytype: disable=wrong-arg-types

  @functools.cached_property
  def distribute_aggregate_form(
      self,
  ) -> tff.backends.mapreduce.DistributeAggregateForm:
    """DistributeAggregateForm of the wrapped computation, built lazily."""
    converter = (
        tff.backends.mapreduce.get_distribute_aggregate_form_for_computation
    )
    return converter(self._comp)  # pytype: disable=wrong-arg-types

  @property
  def wrapped_computation(self) -> tff.Computation:
    """The tff.Computation that was passed to the constructor."""
    return self._comp

  @property
  def name(self) -> str:
    """The validated name supplied at construction time."""
    return self._name

  @property
  def type_signature(self) -> tff.Type:
    """The type signature of the wrapped computation."""
    return self._comp.type_signature

  def __call__(self, *args, **kwargs) -> ...:
    """Invokes this computation in the current TFF context.

    Positional arguments become unnamed Struct elements, followed by keyword
    arguments as named elements in the order given. A Struct is passed even
    when no arguments are supplied.
    """
    elements = [(None, value) for value in args]
    elements.extend(kwargs.items())
    packed = tff.structure.Struct(elements)
    context = tff.framework.get_context_stack().current
    return context.invoke(self, packed)

  def __hash__(self) -> int:
    """Hashes on the wrapped computation together with the name."""
    # NOTE(review): this class defines no __eq__, so equality falls back to
    # whatever the base class provides; confirm that is intended.
    identity = (self._comp, self._name)
    return hash(identity)