xref: /aosp_15_r20/external/grpc-grpc-java/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java (revision e07d83d3ffcef9ecfc9f7f475418ec639ff0e5fe)
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static java.nio.charset.StandardCharsets.UTF_8;
20 
21 import com.google.common.base.Preconditions;
22 import io.grpc.ChannelLogger;
23 import io.grpc.alts.internal.TsiPeer.Property;
24 import io.grpc.internal.TestUtils.NoopChannelLogger;
25 import io.netty.buffer.ByteBufAllocator;
26 import java.nio.ByteBuffer;
27 import java.security.GeneralSecurityException;
28 import java.util.Collections;
29 import java.util.logging.Level;
30 import java.util.logging.Logger;
31 
32 /**
33  * A fake handshaker compatible with security/transport_security/fake_transport_security.h See
34  * {@link TsiHandshaker} for documentation.
35  */
36 public class FakeTsiHandshaker implements TsiHandshaker {
37   private static final Logger logger = Logger.getLogger(FakeTsiHandshaker.class.getName());
38 
39   private static final TsiHandshakerFactory clientHandshakerFactory =
40       new TsiHandshakerFactory() {
41         @Override
42         public TsiHandshaker newHandshaker(String authority, ChannelLogger logger) {
43           return new FakeTsiHandshaker(true);
44         }
45       };
46 
47   private static final TsiHandshakerFactory serverHandshakerFactory =
48       new TsiHandshakerFactory() {
49         @Override
50         public TsiHandshaker newHandshaker(String authority, ChannelLogger logger) {
51           return new FakeTsiHandshaker(false);
52         }
53       };
54 
55   private boolean isClient;
56   private ByteBuffer sendBuffer = null;
57   private AltsFraming.Parser frameParser = new AltsFraming.Parser();
58 
59   private State sendState;
60   private State receiveState;
61 
62   enum State {
63     CLIENT_NONE,
64     SERVER_NONE,
65     CLIENT_INIT,
66     SERVER_INIT,
67     CLIENT_FINISHED,
68     SERVER_FINISHED;
69 
70     // Returns the next State. In order to advance to sendState=N, receiveState must be N-1.
next()71     public State next() {
72       if (ordinal() + 1 < values().length) {
73         return values()[ordinal() + 1];
74       }
75       throw new UnsupportedOperationException("Can't call next() on last element: " + this);
76     }
77   }
78 
clientHandshakerFactory()79   public static TsiHandshakerFactory clientHandshakerFactory() {
80     return clientHandshakerFactory;
81   }
82 
serverHandshakerFactory()83   public static TsiHandshakerFactory serverHandshakerFactory() {
84     return serverHandshakerFactory;
85   }
86 
newFakeHandshakerClient()87   public static TsiHandshaker newFakeHandshakerClient() {
88     NoopChannelLogger channelLogger = new NoopChannelLogger();
89     return clientHandshakerFactory.newHandshaker(null, channelLogger);
90   }
91 
newFakeHandshakerServer()92   public static TsiHandshaker newFakeHandshakerServer() {
93     NoopChannelLogger channelLogger = new NoopChannelLogger();
94     return serverHandshakerFactory.newHandshaker(null, channelLogger);
95   }
96 
FakeTsiHandshaker(boolean isClient)97   protected FakeTsiHandshaker(boolean isClient) {
98     this.isClient = isClient;
99     if (isClient) {
100       sendState = State.CLIENT_NONE;
101       receiveState = State.SERVER_NONE;
102     } else {
103       sendState = State.SERVER_NONE;
104       receiveState = State.CLIENT_NONE;
105     }
106   }
107 
getNextState(State state)108   private State getNextState(State state) {
109     switch (state) {
110       case CLIENT_NONE:
111         return State.CLIENT_INIT;
112       case SERVER_NONE:
113         return State.SERVER_INIT;
114       case CLIENT_INIT:
115         return State.CLIENT_FINISHED;
116       case SERVER_INIT:
117         return State.SERVER_FINISHED;
118       default:
119         return null;
120     }
121   }
122 
getNextMessage()123   private String getNextMessage() {
124     State result = getNextState(sendState);
125     return result == null ? "BAD STATE" : result.toString();
126   }
127 
getExpectedMessage()128   private String getExpectedMessage() {
129     State result = getNextState(receiveState);
130     return result == null ? "BAD STATE" : result.toString();
131   }
132 
incrementSendState()133   private void incrementSendState() {
134     sendState = getNextState(sendState);
135   }
136 
incrementReceiveState()137   private void incrementReceiveState() {
138     receiveState = getNextState(receiveState);
139   }
140 
141   @Override
getBytesToSendToPeer(ByteBuffer bytes)142   public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
143     Preconditions.checkNotNull(bytes);
144 
145     // If we're done, return nothing.
146     if (sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED) {
147       return;
148     }
149 
150     // Prepare the next message, if neeeded.
151     if (sendBuffer == null) {
152       if (sendState.next() != receiveState) {
153         // We're still waiting for bytes from the peer, so bail.
154         return;
155       }
156       ByteBuffer payload = ByteBuffer.wrap(getNextMessage().getBytes(UTF_8));
157       sendBuffer = AltsFraming.toFrame(payload, payload.remaining());
158       logger.log(Level.FINE, "Buffered message: {0}", getNextMessage());
159     }
160     while (bytes.hasRemaining() && sendBuffer.hasRemaining()) {
161       bytes.put(sendBuffer.get());
162     }
163     if (!sendBuffer.hasRemaining()) {
164       // Get ready to send the next message.
165       sendBuffer = null;
166       incrementSendState();
167     }
168   }
169 
170   @Override
processBytesFromPeer(ByteBuffer bytes)171   public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
172     Preconditions.checkNotNull(bytes);
173 
174     frameParser.readBytes(bytes);
175     if (frameParser.isComplete()) {
176       ByteBuffer messageBytes = frameParser.getRawFrame();
177       int offset = AltsFraming.getFramingOverhead();
178       int length = messageBytes.limit() - offset;
179       @SuppressWarnings("ByteBufferBackingArray") // ByteBuffer is created using allocate()
180       String message = new String(messageBytes.array(), offset, length, UTF_8);
181       logger.log(Level.FINE, "Read message: {0}", message);
182 
183       if (!message.equals(getExpectedMessage())) {
184         throw new IllegalArgumentException(
185             "Bad handshake message. Got "
186                 + message
187                 + " (length = "
188                 + message.length()
189                 + ") expected "
190                 + getExpectedMessage()
191                 + " (length = "
192                 + getExpectedMessage().length()
193                 + ")");
194       }
195       incrementReceiveState();
196       return true;
197     }
198     return false;
199   }
200 
201   @Override
isInProgress()202   public boolean isInProgress() {
203     boolean finishedReceiving =
204         receiveState == State.CLIENT_FINISHED || receiveState == State.SERVER_FINISHED;
205     boolean finishedSending =
206         sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED;
207     return !finishedSending || !finishedReceiving;
208   }
209 
210   @Override
extractPeer()211   public TsiPeer extractPeer() {
212     return new TsiPeer(Collections.<Property<?>>emptyList());
213   }
214 
215   @Override
extractPeerObject()216   public Object extractPeerObject() {
217     return AltsInternalContext.getDefaultInstance();
218   }
219 
220   @Override
createFrameProtector(int maxFrameSize, ByteBufAllocator alloc)221   public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
222     Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
223 
224     // We use an all-zero key, since this is the fake handshaker.
225     byte[] key = new byte[AltsChannelCrypter.getKeyLength()];
226     return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
227   }
228 
229   @Override
createFrameProtector(ByteBufAllocator alloc)230   public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
231     return createFrameProtector(AltsTsiFrameProtector.getMinFrameSize(), alloc);
232   }
233 
234   @Override
close()235   public void close() {
236     // No-op
237   }
238 }
239