Skip to content

Commit

Permalink
Execute socket->read() inside TcpSession client code protected by mutex
Browse files Browse the repository at this point in the history
Instead of letting read happen in boost asio library in the thread which runs
the even, defer actual socket read call to TcpSsession (Or SslSession) code
which is (already) protected by mutex. This ensures that socket read and write
never happens in parallel. This is a requirement as boost::asio library is not
thread safe per-se for parallel reads/writes over the same asio socket

Change-Id: I25cc6b149d26579c1eb1f75965227135c26834e7
Closes-Bug: 1608731
  • Loading branch information
ananth-at-camphor-networks committed Aug 4, 2016
1 parent c8b3333 commit 7853e1c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 100 deletions.
46 changes: 11 additions & 35 deletions src/io/ssl_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,42 +83,18 @@ TcpSession::Socket *SslSession::socket() const {
return &ssl_socket_->next_layer();
}

bool SslSession::AsyncReadHandlerProcess(mutable_buffer buffer,
size_t *bytes_transferred,
error_code &error) {
// no processing needed if ssl handshake is not complete.
if (!IsSslHandShakeSuccessLocked()) {
return false;
}
size_t SslSession::ReadSome(mutable_buffer buffer, error_code &error) {
// Read data from the tcp socket or from the ssl socket, as appropriate.
assert(!ssl_handshake_in_progress_);
if (!IsSslHandShakeSuccessLocked())
return TcpSession::ReadSome(buffer, error);

// do ssl read here in IO context, ignore errors
*bytes_transferred = ssl_socket_->read_some(mutable_buffers_1(buffer),
error);

return true;
}

void SslSession::AsyncReadSome(mutable_buffer buffer) {
if (IsSslHandShakeSuccessLocked()) {
// trigger read with null buffer to get indication for data available
// on the socket and then do the actuall socket read in
// AsyncReadHandlerProcess
socket()->async_read_some(null_buffers(),
bind(&TcpSession::AsyncReadHandler, SslSessionPtr(this), buffer,
placeholders::error, placeholders::bytes_transferred));
} else {
// No tcp socket read/write while ssl handshake is ongoing
if (!ssl_handshake_in_progress_) {
socket()->async_read_some(mutable_buffers_1(buffer),
bind(&TcpSession::AsyncReadHandler, SslSessionPtr(this),
buffer, placeholders::error,
placeholders::bytes_transferred));
}
}
return ssl_socket_->read_some(mutable_buffers_1(buffer), error);
}

size_t SslSession::WriteSome(const uint8_t *data, size_t len,
error_code &error) {
error_code &error) {

if (IsSslHandShakeSuccessLocked()) {
return ssl_socket_->write_some(buffer(data, len), error);
Expand All @@ -130,16 +106,16 @@ size_t SslSession::WriteSome(const uint8_t *data, size_t len,
void SslSession::AsyncWrite(const u_int8_t *data, size_t size) {
if (IsSslHandShakeSuccessLocked()) {
async_write(*ssl_socket_.get(), buffer(data, size),
bind(&TcpSession::AsyncWriteHandler,
TcpSessionPtr(this), placeholders::error));
bind(&TcpSession::AsyncWriteHandler,
TcpSessionPtr(this), placeholders::error));
} else {
return (TcpSession::AsyncWrite(data, size));
}
}

void SslSession::SslHandShakeCallback(SslHandShakeCallbackHandler cb,
SslSessionPtr session,
const error_code &error) {
SslSessionPtr session,
const error_code &error) {

session->ssl_handshake_in_progress_ = false;
if (!error) {
Expand Down
7 changes: 2 additions & 5 deletions src/io/ssl_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,8 @@ class SslSession : public TcpSession {
// SslSession do actual ssl socket read for data in this context with
// session mutex held, to avoid concurrent read and write operations
// on same socket.
bool AsyncReadHandlerProcess(boost::asio::mutable_buffer buffer,
size_t *bytes_transferred,
boost::system::error_code &error);

void AsyncReadSome(boost::asio::mutable_buffer buffer);
size_t ReadSome(boost::asio::mutable_buffer buffer,
boost::system::error_code &error);
std::size_t WriteSome(const uint8_t *data, std::size_t len,
boost::system::error_code &error);
void AsyncWrite(const u_int8_t *data, std::size_t size);
Expand Down
76 changes: 27 additions & 49 deletions src/io/tcp_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,6 @@ void TcpSession::ReleaseBufferLocked(Buffer buffer) {
assert(false);
}

bool TcpSession::AsyncReadHandlerProcess(mutable_buffer buffer,
size_t *bytes_transferred,
error_code &error) {
return false;
}

void TcpSession::AsyncReadStartInternal(TcpSessionPtr session) {
// Update socket read block time.
if (stats_.read_block_start_time) {
Expand All @@ -157,19 +151,13 @@ void TcpSession::AsyncReadStartInternal(TcpSessionPtr session) {
stats_.read_blocked_duration_usecs += blocked_usecs;
server_->stats_.read_blocked_duration_usecs += blocked_usecs;
}
mutable_buffer buffer = AllocateBuffer();
tbb::mutex::scoped_lock lock(mutex_);
if (!established_) {
ReleaseBufferLocked(buffer);
return;
}
AsyncReadSome(buffer);
AsyncReadSome();
}

void TcpSession::AsyncReadStart() {
if (io_strand_) {
io_strand_->post(bind(&TcpSession::AsyncReadStartInternal, this,
TcpSessionPtr(this)));
TcpSessionPtr(this)));
}
}

Expand All @@ -193,12 +181,12 @@ void TcpSession::DeferWriter() {
placeholders::error, UTCTimestampUsec()));
}

void TcpSession::AsyncReadSome(mutable_buffer buffer) {
socket()->async_read_some(mutable_buffers_1(buffer),
bind(&TcpSession::AsyncReadHandler,
TcpSessionPtr(this), buffer,
placeholders::error,
placeholders::bytes_transferred));
void TcpSession::AsyncReadSome() {
tbb::mutex::scoped_lock lock(mutex_);
if (established_) {
socket()->async_read_some(null_buffers(),
bind(&TcpSession::AsyncReadHandler, TcpSessionPtr(this)));
}
}

size_t TcpSession::WriteSome(const uint8_t *data, size_t len,
Expand All @@ -208,8 +196,8 @@ size_t TcpSession::WriteSome(const uint8_t *data, size_t len,

void TcpSession::AsyncWrite(const u_int8_t *data, size_t size) {
async_write(*socket(), buffer(data, size),
bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this),
placeholders::error));
bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this),
placeholders::error));
}

TcpSession::Endpoint TcpSession::local_endpoint() const {
Expand Down Expand Up @@ -416,7 +404,8 @@ bool TcpSession::Send(const u_int8_t *data, size_t size, size_t *sent) {
CloseInternal(error, true);
return false;
}
if (len < 0 || (size_t)len != size) ret = false;
if ((size_t) len != size)
ret = false;
if (sent) *sent = (len > 0) ? len : 0;
} else {
AsyncWrite(data, size);
Expand All @@ -434,46 +423,35 @@ Task* TcpSession::CreateReaderTask(mutable_buffer buffer,
return (task);
}

void TcpSession::AsyncReadHandler(
TcpSessionPtr session, mutable_buffer buffer,
const error_code &error, size_t bytes_transferred) {
size_t TcpSession::ReadSome(mutable_buffer buffer, error_code &error) {
return socket()->read_some(mutable_buffers_1(buffer), error);
}

void TcpSession::AsyncReadHandler(TcpSessionPtr session) {
mutable_buffer buffer = session->AllocateBuffer();

tbb::mutex::scoped_lock lock(session->mutex_);
if (session->closed_) {
session->ReleaseBufferLocked(buffer);
return;
}

error_code error;
size_t bytes_transferred = session->ReadSome(buffer, error);
if (IsSocketErrorHard(error)) {
session->ReleaseBufferLocked(buffer);
// eof is returned when the peer closed the socket, no need to log err
if (error != error::eof) {
TCP_SESSION_LOG_ERROR(session, TCP_DIR_IN,
"Read failed due to error " << error.value()
<< " : " << error.message());
}
// eof is returned when the peer closed the socket, no need to log err
if (error != error::eof) {
TCP_SESSION_LOG_ERROR(session, TCP_DIR_IN,
"Read failed due to error " << error.value()
<< " : " << error.message());
}

lock.release();
session->CloseInternal(error, true);
return;
}

error_code err;
if (session->AsyncReadHandlerProcess(buffer, &bytes_transferred, err)) {
// check error code if session needs to be closed
if (IsSocketErrorHard(err)) {
session->ReleaseBufferLocked(buffer);
// eof is returned when the peer has closed the socket
if (err != error::eof) {
TCP_SESSION_LOG_ERROR(session, TCP_DIR_IN,
"Read failed due to error " << err.value()
<< " : " << err.message());
}
lock.release();
session->CloseInternal(err, true);
return;
}
}

// Update read statistics.
session->stats_.read_calls++;
session->stats_.read_bytes += bytes_transferred;
Expand Down
15 changes: 4 additions & 11 deletions src/io/tcp_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,10 @@ class TcpSession {

protected:
typedef boost::intrusive_ptr<TcpSession> TcpSessionPtr;
static void AsyncReadHandler(TcpSessionPtr session,
boost::asio::mutable_buffer buffer,
const boost::system::error_code &error,
size_t size);
static void AsyncReadHandler(TcpSessionPtr session);
static void AsyncWriteHandler(TcpSessionPtr session,
const boost::system::error_code &error);

// returns true if Processing done, used by SslSession to do actual
// synchronous read for data.
virtual bool AsyncReadHandlerProcess(boost::asio::mutable_buffer buffer,
size_t *bytes_transferred,
boost::system::error_code &error);

void AsyncReadStartInternal(TcpSessionPtr session);
virtual Task* CreateReaderTask(boost::asio::mutable_buffer, size_t);

Expand All @@ -179,7 +170,9 @@ class TcpSession {
// Callback after socket is ready for write.
virtual void WriteReady(const boost::system::error_code &error);

virtual void AsyncReadSome(boost::asio::mutable_buffer buffer);
void AsyncReadSome();
virtual size_t ReadSome(boost::asio::mutable_buffer buffer,
boost::system::error_code &error);
virtual std::size_t WriteSome(const uint8_t *data, std::size_t len,
boost::system::error_code &error);
virtual void AsyncWrite(const u_int8_t *data, std::size_t size);
Expand Down

0 comments on commit 7853e1c

Please sign in to comment.