/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.alts.internal;

import static java.nio.charset.StandardCharsets.UTF_8;

import com.google.common.base.Preconditions;
import io.grpc.ChannelLogger;
import io.grpc.alts.internal.TsiPeer.Property;
import io.grpc.internal.TestUtils.NoopChannelLogger;
import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A fake handshaker compatible with security/transport_security/fake_transport_security.h See
 * {@link TsiHandshaker} for documentation.
 *
 * <p>The handshake is a fixed exchange of four framed messages whose payload is the name of the
 * {@link State} being entered: {@code CLIENT_INIT}, {@code SERVER_INIT}, {@code CLIENT_FINISHED}
 * and {@code SERVER_FINISHED}, in that order. Each side tracks the last state it sent and the last
 * state it received, and only sends its next message once the peer's preceding message has been
 * received.
 */
public class FakeTsiHandshaker implements TsiHandshaker {
  private static final Logger logger = Logger.getLogger(FakeTsiHandshaker.class.getName());

  private static final TsiHandshakerFactory clientHandshakerFactory =
      new TsiHandshakerFactory() {
        @Override
        public TsiHandshaker newHandshaker(String authority, ChannelLogger channelLogger) {
          // Neither argument is used by the fake handshaker.
          return new FakeTsiHandshaker(true);
        }
      };

  private static final TsiHandshakerFactory serverHandshakerFactory =
      new TsiHandshakerFactory() {
        @Override
        public TsiHandshaker newHandshaker(String authority, ChannelLogger channelLogger) {
          // Neither argument is used by the fake handshaker.
          return new FakeTsiHandshaker(false);
        }
      };

  private final boolean isClient;
  // Partially written outgoing frame, or null when no frame is currently being sent.
  private ByteBuffer sendBuffer = null;
  private final AltsFraming.Parser frameParser = new AltsFraming.Parser();

  // Last state announced to the peer (or the *_NONE state before anything was sent).
  private State sendState;
  // Last state announced by the peer (or the *_NONE state before anything was received).
  private State receiveState;

  /** Handshake states, declared in the order in which they occur on the wire. */
  enum State {
    CLIENT_NONE,
    SERVER_NONE,
    CLIENT_INIT,
    SERVER_INIT,
    CLIENT_FINISHED,
    SERVER_FINISHED;

    /**
     * Returns the state declared immediately after this one.
     *
     * <p>In order to advance to sendState=N, receiveState must be N-1; callers compare {@code
     * sendState.next()} with the receive state to decide whether it is their turn to send.
     *
     * @throws UnsupportedOperationException if this is the last declared state
     */
    public State next() {
      if (ordinal() + 1 < values().length) {
        return values()[ordinal() + 1];
      }
      throw new UnsupportedOperationException("Can't call next() on last element: " + this);
    }
  }

  /** Returns the shared factory that creates client-side fake handshakers. */
  public static TsiHandshakerFactory clientHandshakerFactory() {
    return clientHandshakerFactory;
  }

  /** Returns the shared factory that creates server-side fake handshakers. */
  public static TsiHandshakerFactory serverHandshakerFactory() {
    return serverHandshakerFactory;
  }

  /** Creates a new client-side fake handshaker using a no-op channel logger. */
  public static TsiHandshaker newFakeHandshakerClient() {
    NoopChannelLogger channelLogger = new NoopChannelLogger();
    return clientHandshakerFactory.newHandshaker(null, channelLogger);
  }

  /** Creates a new server-side fake handshaker using a no-op channel logger. */
  public static TsiHandshaker newFakeHandshakerServer() {
    NoopChannelLogger channelLogger = new NoopChannelLogger();
    return serverHandshakerFactory.newHandshaker(null, channelLogger);
  }

  /**
   * Creates a handshaker in its initial state.
   *
   * @param isClient {@code true} for the client side of the handshake, {@code false} for the
   *     server side
   */
  protected FakeTsiHandshaker(boolean isClient) {
    this.isClient = isClient;
    if (isClient) {
      sendState = State.CLIENT_NONE;
      receiveState = State.SERVER_NONE;
    } else {
      sendState = State.SERVER_NONE;
      receiveState = State.CLIENT_NONE;
    }
  }

  /** Returns whether {@code state} is one of the two terminal states. */
  private static boolean isFinished(State state) {
    return state == State.CLIENT_FINISHED || state == State.SERVER_FINISHED;
  }

  /**
   * Returns the state that follows {@code state} for the same side (client or server), or {@code
   * null} if {@code state} is already a terminal state.
   */
  private State getNextState(State state) {
    switch (state) {
      case CLIENT_NONE:
        return State.CLIENT_INIT;
      case SERVER_NONE:
        return State.SERVER_INIT;
      case CLIENT_INIT:
        return State.CLIENT_FINISHED;
      case SERVER_INIT:
        return State.SERVER_FINISHED;
      default:
        return null;
    }
  }

  /** Returns the payload of the next message to send, or "BAD STATE" if there is none. */
  private String getNextMessage() {
    State result = getNextState(sendState);
    return result == null ? "BAD STATE" : result.toString();
  }

  /** Returns the payload expected from the peer next, or "BAD STATE" if none is expected. */
  private String getExpectedMessage() {
    State result = getNextState(receiveState);
    return result == null ? "BAD STATE" : result.toString();
  }

  private void incrementSendState() {
    sendState = getNextState(sendState);
  }

  private void incrementReceiveState() {
    receiveState = getNextState(receiveState);
  }

  /**
   * Copies as much of the pending handshake frame as fits into {@code bytes}.
   *
   * <p>Nothing is written if this side has finished sending, or if it is still waiting for the
   * peer's preceding message. A frame that does not fit is continued on the next call; the send
   * state only advances once the whole frame has been handed out.
   *
   * @param bytes destination buffer; must not be null
   * @throws GeneralSecurityException propagated from {@link AltsFraming} if it is thrown while
   *     framing the message
   */
  @Override
  public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
    Preconditions.checkNotNull(bytes);

    // If we're done, return nothing.
    if (isFinished(sendState)) {
      return;
    }

    // Prepare the next message, if needed.
    if (sendBuffer == null) {
      if (sendState.next() != receiveState) {
        // We're still waiting for bytes from the peer, so bail.
        return;
      }
      String nextMessage = getNextMessage();
      ByteBuffer payload = ByteBuffer.wrap(nextMessage.getBytes(UTF_8));
      sendBuffer = AltsFraming.toFrame(payload, payload.remaining());
      logger.log(Level.FINE, "Buffered message: {0}", nextMessage);
    }
    while (bytes.hasRemaining() && sendBuffer.hasRemaining()) {
      bytes.put(sendBuffer.get());
    }
    if (!sendBuffer.hasRemaining()) {
      // Get ready to send the next message.
      sendBuffer = null;
      incrementSendState();
    }
  }

  /**
   * Feeds bytes received from the peer into the frame parser and, once a full frame is available,
   * validates it against the expected handshake message.
   *
   * @param bytes bytes received from the peer; must not be null
   * @return {@code true} if a complete frame was parsed and accepted, {@code false} if more bytes
   *     are needed to complete the current frame
   * @throws IllegalArgumentException if a complete frame does not carry the expected message,
   *     including any frame that arrives after the peer's final message was already received
   * @throws GeneralSecurityException propagated from the frame parser if it is thrown
   */
  @Override
  public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
    Preconditions.checkNotNull(bytes);

    frameParser.readBytes(bytes);
    if (!frameParser.isComplete()) {
      return false;
    }

    ByteBuffer messageBytes = frameParser.getRawFrame();
    int offset = AltsFraming.getFramingOverhead();
    int length = messageBytes.limit() - offset;
    @SuppressWarnings("ByteBufferBackingArray") // ByteBuffer is created using allocate()
    String message = new String(messageBytes.array(), offset, length, UTF_8);
    logger.log(Level.FINE, "Read message: {0}", message);

    // When receiving is already finished there is no next state and the expected text is only the
    // "BAD STATE" placeholder. The frame must be rejected even if its payload happens to equal
    // that placeholder; otherwise receiveState would become null and later calls would fail with
    // a NullPointerException.
    boolean expectingMessage = getNextState(receiveState) != null;
    String expectedMessage = getExpectedMessage();
    if (!expectingMessage || !message.equals(expectedMessage)) {
      throw new IllegalArgumentException(
          "Bad handshake message. Got "
              + message
              + " (length = "
              + message.length()
              + ") expected "
              + expectedMessage
              + " (length = "
              + expectedMessage.length()
              + ")");
    }
    incrementReceiveState();
    return true;
  }

  /**
   * Returns {@code true} until this side has both sent and received a terminal ({@code
   * *_FINISHED}) message.
   */
  @Override
  public boolean isInProgress() {
    return !isFinished(sendState) || !isFinished(receiveState);
  }

  /** Returns a peer with no properties. */
  @Override
  public TsiPeer extractPeer() {
    return new TsiPeer(Collections.<Property<?>>emptyList());
  }

  /** Returns the default {@link AltsInternalContext} instance. */
  @Override
  public Object extractPeerObject() {
    return AltsInternalContext.getDefaultInstance();
  }

  /**
   * Creates a frame protector for the completed handshake.
   *
   * @param maxFrameSize maximum frame size passed to the protector
   * @param alloc buffer allocator passed to the protector
   * @throws IllegalStateException if the handshake is still in progress
   */
  @Override
  public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
    Preconditions.checkState(!isInProgress(), "Handshake is not complete.");

    // We use an all-zero key, since this is the fake handshaker.
    byte[] key = new byte[AltsChannelCrypter.getKeyLength()];
    return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
  }

  /**
   * Creates a frame protector using {@link AltsTsiFrameProtector#getMinFrameSize()} as the maximum
   * frame size.
   */
  @Override
  public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
    return createFrameProtector(AltsTsiFrameProtector.getMinFrameSize(), alloc);
  }

  @Override
  public void close() {
    // No-op
  }
}