Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose Netty Connection Read Timeout #477

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 59 additions & 2 deletions src/main/java/com/basho/riak/client/core/RiakNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public enum State

private volatile Bootstrap bootstrap;
private volatile boolean ownsBootstrap;
private volatile RiakChannelInitializer riakChannelInitializer;
private volatile ScheduledExecutorService executor;
private volatile boolean ownsExecutor;
private volatile State state;
Expand All @@ -84,6 +85,7 @@ public enum State
private volatile int minConnections;
private volatile long idleTimeoutInNanos;
private volatile int connectionTimeout;
private volatile int readTimeout;
private volatile boolean blockOnMaxConnections;

private HealthCheckFactory healthCheckFactory;
Expand Down Expand Up @@ -175,6 +177,7 @@ private RiakNode(Builder builder) throws UnknownHostException
this.executor = builder.executor;
this.connectionTimeout = builder.connectionTimeout;
this.idleTimeoutInNanos = TimeUnit.NANOSECONDS.convert(builder.idleTimeout, TimeUnit.MILLISECONDS);
this.readTimeout = builder.readTimeout;
this.minConnections = builder.minConnections;
this.port = builder.port;
this.remoteAddress = builder.remoteAddress;
Expand Down Expand Up @@ -242,7 +245,9 @@ public synchronized RiakNode start()
ownsBootstrap = true;
}

bootstrap.handler(new RiakChannelInitializer(this))

riakChannelInitializer = new RiakChannelInitializer(this, readTimeout);
bootstrap.handler(riakChannelInitializer)
.remoteAddress(new InetSocketAddress(remoteAddress, port));

if (connectionTimeout > 0)
Expand Down Expand Up @@ -516,6 +521,33 @@ public int getConnectionTimeout()
return connectionTimeout;
}

/**
* Sets the read timeout in milliseconds.
*
* @param readTimeoutInMillis the read timeout to set
* @return a reference to this RiakNode
* @see Builder#withReadTimeout(int)
*/
public RiakNode setReadTimeout(int readTimeoutInMillis)
{
stateCheck(State.CREATED, State.RUNNING, State.HEALTH_CHECKING);
this.readTimeout = readTimeoutInMillis;
riakChannelInitializer.setReadTimeout(readTimeout);
return this;
}

/**
* Returns the read timeout in milliseconds.
*
* @return the readTimeout
* @see Builder#withReadTimeout(int)
*/
public int getReadTimeout()
{
stateCheck(State.CREATED, State.RUNNING, State.HEALTH_CHECKING);
return readTimeout;
}

/**
* Returns the number of permits currently available.
* The number of available permits indicates how many additional
Expand Down Expand Up @@ -663,6 +695,7 @@ private Channel doGetConnection() throws ConnectionFailedException

try
{
logger.debug("Waiting for new connection from channel future to {}:{}", remoteAddress, port);
f.await();
}
catch (InterruptedException ex)
Expand All @@ -680,12 +713,15 @@ private Channel doGetConnection() throws ConnectionFailedException
consecutiveFailedConnectionAttempts.incrementAndGet();
throw new ConnectionFailedException(f.cause());
}

logger.debug("Connection to {}:{} successful", remoteAddress, port);

consecutiveFailedConnectionAttempts.set(0);
Channel c = f.channel();

if (trustStore != null)
{
logger.debug("trustStore set starting TLS");
SSLContext context;
try
{
Expand Down Expand Up @@ -720,11 +756,12 @@ else if (protocols.contains("TLSv1.1"))
}

engine.setUseClientMode(true);
RiakSecurityDecoder decoder = new RiakSecurityDecoder(engine, username, password);
RiakSecurityDecoder decoder = new RiakSecurityDecoder(remoteAddress, port, engine, username, password);
c.pipeline().addFirst(decoder);

try
{
logger.debug("Waiting for authentication to complete with {}:{}", remoteAddress, port);
DefaultPromise<Void> promise = decoder.getPromise();
promise.await();

Expand Down Expand Up @@ -1199,6 +1236,12 @@ public static class Builder
* @see #withConnectionTimeout(int)
*/
public final static int DEFAULT_CONNECTION_TIMEOUT = 0;
/**
* The default so timeout in milliseconds if not specified: {@value #DEFAULT_READ_TIMEOUT}
*
* @see #withReadTimeout(int)
*/
public final static int DEFAULT_READ_TIMEOUT = 0;

/**
* The default HealthCheckFactory.
Expand All @@ -1216,6 +1259,7 @@ public static class Builder
private int maxConnections = DEFAULT_MAX_CONNECTIONS;
private int idleTimeout = DEFAULT_IDLE_TIMEOUT;
private int connectionTimeout = DEFAULT_CONNECTION_TIMEOUT;
private int readTimeout = DEFAULT_READ_TIMEOUT;
private HealthCheckFactory healthCheckFactory = DEFAULT_HEALTHCHECK_FACTORY;
private Bootstrap bootstrap;
private ScheduledExecutorService executor;
Expand Down Expand Up @@ -1331,6 +1375,19 @@ public Builder withConnectionTimeout(int connectionTimeoutInMillis)
return this;
}

/**
* Set the read timeout used when waiting for a response on the underlying sockets
*
* @param readTimeoutMillis
* @return this
* @see #DEFAULT_READ_TIMEOUT
*/
public Builder withReadTimeout(int readTimeoutMillis)
{
this.readTimeout = readTimeoutMillis;
return this;
}

/**
* Provides an executor for this node to use for internal maintenance tasks.
* If not provided one will be created via
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;

import java.util.concurrent.TimeUnit;

/**
*
Expand All @@ -29,10 +33,13 @@
public class RiakChannelInitializer extends ChannelInitializer<SocketChannel>
{
private final RiakResponseListener listener;
public RiakChannelInitializer(RiakResponseListener listener)
private volatile int readTimeout;

public RiakChannelInitializer(RiakResponseListener listener, int readTimeoutMillis)
{
super();
this.listener = listener;
this.readTimeout = readTimeoutMillis;
}

@Override
Expand All @@ -42,6 +49,16 @@ public void initChannel(SocketChannel ch) throws Exception
p.addLast(Constants.MESSAGE_CODEC, new RiakMessageCodec());
p.addLast(Constants.OPERATION_ENCODER, new RiakOperationEncoder());
p.addLast(Constants.RESPONSE_HANDLER, new RiakResponseHandler(listener));
p.addLast(Constants.READ_TIMEOUT_HANDLER, new ReadTimeoutHandler(readTimeout, TimeUnit.MILLISECONDS));
}

public int getReadTimeout()
{
return readTimeout;
}

public void setReadTimeout(int readTimeoutMillis)
{
readTimeout = readTimeoutMillis;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ public class RiakSecurityDecoder extends ByteToMessageDecoder
private final String username;
private final String password;
private final Logger logger = LoggerFactory.getLogger(RiakSecurityDecoder.class);
private volatile DefaultPromise<Void> promise;
private final String remoteAddr;
private final int remotePort;
private volatile DefaultPromise<Void> promise;

private enum State { TLS_START, TLS_WAIT, SSL_WAIT, AUTH_WAIT }

private volatile State state = State.TLS_START;

public RiakSecurityDecoder(SSLEngine engine, String username, String password)
public RiakSecurityDecoder(String remoteAddress, int port, SSLEngine engine, String username, String password)
{
this.remoteAddr = remoteAddress;
this.remotePort = port;
this.sslEngine = engine;
this.username = username;
this.password = password;
Expand Down Expand Up @@ -88,7 +92,7 @@ protected void decode(ChannelHandlerContext chc, ByteBuf in, List<Object> out) t
switch(code)
{
case RiakMessageCodes.MSG_StartTls:
logger.debug("Received MSG_RpbStartTls reply");
logger.debug("Received MSG_RpbStartTls reply from {}:{}", remoteAddr, remotePort);
// change state
this.state = State.SSL_WAIT;
// insert SSLHandler
Expand All @@ -101,10 +105,11 @@ protected void decode(ChannelHandlerContext chc, ByteBuf in, List<Object> out) t
chc.channel().pipeline().addFirst(Constants.SSL_HANDLER, sslHandler);
break;
case RiakMessageCodes.MSG_ErrorResp:
logger.debug("Received MSG_ErrorResp reply to startTls");
logger.debug("Received MSG_ErrorResp reply to startTls from {}:{}", remoteAddr, remotePort);
promise.tryFailure((riakErrorToException(protobuf)));
break;
default:
logger.debug("Invalid return code during StartTLS from {}:{} code", remoteAddr, remotePort, code);
promise.tryFailure(new RiakResponseException(0,
"Invalid return code during StartTLS; " + code));
}
Expand All @@ -114,21 +119,22 @@ protected void decode(ChannelHandlerContext chc, ByteBuf in, List<Object> out) t
switch(code)
{
case RiakMessageCodes.MSG_AuthResp:
logger.debug("Received MSG_RpbAuthResp reply");
logger.debug("Received MSG_RpbAuthResp reply from {}:{}", remoteAddr, remotePort);
promise.trySuccess(null);
break;
case RiakMessageCodes.MSG_ErrorResp:
logger.debug("Received MSG_ErrorResp reply to auth");
logger.debug("Received MSG_ErrorResp reply to Auth from {}:{}", remoteAddr, remotePort);
promise.tryFailure(riakErrorToException(protobuf));
break;
default:
logger.debug("Invalid return code during Auth from {}:{}", remoteAddr, remotePort);
promise.tryFailure(new RiakResponseException(0,
"Invalid return code during Auth; " + code));
}
break;
default:
// WTF?
logger.error("Received message while not in TLS_WAIT or AUTH_WAIT");
logger.error("Received message while not in TLS_WAIT or AUTH_WAIT from {}:{}", remoteAddr, remotePort);
promise.tryFailure(new IllegalStateException("Received message while not in TLS_WAIT or AUTH_WAIT"));
}
}
Expand Down Expand Up @@ -208,6 +214,7 @@ public void operationComplete(Future<Channel> future) throws Exception
{
if (future.isSuccess())
{
logger.debug("SSLHandshake Completed with {}:{}. Authenticating.", remoteAddr, remotePort);
Channel c = future.getNow();
state = State.AUTH_WAIT;
RiakPB.RpbAuthReq authReq =
Expand All @@ -221,6 +228,7 @@ public void operationComplete(Future<Channel> future) throws Exception
}
else
{
logger.warn("SSLHandshake Failed with {}:{}.", remoteAddr, remotePort, future.cause());
promise.tryFailure(future.cause());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public interface Constants {
public static final String MESSAGE_CODEC = "codec";
public static final String OPERATION_ENCODER = "operationEncoder";
public static final String RESPONSE_HANDLER = "responseHandler";
public static final String READ_TIMEOUT_HANDLER = "readTimeoutHandler";
public static final String SSL_HANDLER = "sslHandler";
public static final String HEALTHCHECK_CODEC = "healthCheckCodec";

Expand Down
5 changes: 4 additions & 1 deletion src/test/java/com/basho/riak/client/core/RiakNodeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public void builderProducesDefaultNode() throws UnknownHostException
assertEquals(node.getMaxConnections(), Integer.MAX_VALUE);
assertEquals(node.getConnectionTimeout(), RiakNode.Builder.DEFAULT_CONNECTION_TIMEOUT);
assertEquals(node.getIdleTimeout(), RiakNode.Builder.DEFAULT_IDLE_TIMEOUT);
assertEquals(node.getReadTimeout(), RiakNode.Builder.DEFAULT_READ_TIMEOUT);
assertEquals(node.getMinConnections(), RiakNode.Builder.DEFAULT_MIN_CONNECTIONS);
assertEquals(node.availablePermits(), Integer.MAX_VALUE);
}
Expand All @@ -75,7 +76,7 @@ public void builderProducesCorrectNode() throws UnknownHostException
final int MIN_CONNECTIONS = 2002;
final int MAX_CONNECTIONS = 2003;
final int PORT = 2004;
final int READ_TIMEOUT = 2005;
final int READ_TIMEOUT = 2006;
final String REMOTE_ADDRESS = "localhost";
final ScheduledExecutorService EXECUTOR = Executors.newSingleThreadScheduledExecutor();
final Bootstrap BOOTSTRAP = PowerMockito.spy(new Bootstrap());
Expand All @@ -91,6 +92,7 @@ public void builderProducesCorrectNode() throws UnknownHostException
.withRemoteAddress(REMOTE_ADDRESS)
.withExecutor(EXECUTOR)
.withBootstrap(BOOTSTRAP)
.withReadTimeout(READ_TIMEOUT)
.build();

assertEquals(node.getRemoteAddress(), REMOTE_ADDRESS);
Expand All @@ -103,6 +105,7 @@ public void builderProducesCorrectNode() throws UnknownHostException
assertEquals(node.getRemoteAddress(), REMOTE_ADDRESS);
assertEquals(node.availablePermits(), MAX_CONNECTIONS);
assertEquals(node.getPort(), PORT);
assertEquals(node.getReadTimeout(), READ_TIMEOUT);

}

Expand Down