From df4fc0a97e4c77c1c9ac3a2e3b45b6d4db97b531 Mon Sep 17 00:00:00 2001 From: Prabhjot Singh Sethi Date: Mon, 26 Jan 2015 22:07:45 -0800 Subject: [PATCH] Support for SSL server/session infrastructure Following changes are done to extend existing tcp infra to support SSL server. - code movement to make TCP message writer to be independent of socket, and all operations to be trigered on session ptr - introduced virtual methods to override socket operations to use relavant socket structure. - hooks installed in SSL server to trigger SSL handshake before triggering tcp session connected and accepted state machine Added basic connect, send and recv test code for ssl server infra. Change-Id: I8e50a400e6b80cef42e852f5da2038f44ce4b082 --- src/io/SConscript | 2 + src/io/ssl_server.cc | 54 ++++++ src/io/ssl_server.h | 43 +++++ src/io/ssl_session.cc | 95 ++++++++++ src/io/ssl_session.h | 46 +++++ src/io/tcp_message_write.cc | 66 +------ src/io/tcp_message_write.h | 13 +- src/io/tcp_server.cc | 47 +++-- src/io/tcp_server.h | 7 + src/io/tcp_session.cc | 127 ++++++++++---- src/io/tcp_session.h | 34 ++-- src/io/test/SConscript | 9 +- src/io/test/ssl_server_test.cc | 311 +++++++++++++++++++++++++++++++++ 13 files changed, 727 insertions(+), 127 deletions(-) create mode 100644 src/io/ssl_server.cc create mode 100644 src/io/ssl_server.h create mode 100644 src/io/ssl_session.cc create mode 100644 src/io/ssl_session.h create mode 100644 src/io/test/ssl_server_test.cc diff --git a/src/io/SConscript b/src/io/SConscript index 1ba9453239c..4e409b2afed 100644 --- a/src/io/SConscript +++ b/src/io/SConscript @@ -23,6 +23,8 @@ libio = env.Library('io', EventManagerSrc + [ 'io_utils.cc', + 'ssl_server.cc', + 'ssl_session.cc', 'tcp_message_write.cc', 'tcp_server.cc', 'tcp_session.cc', diff --git a/src/io/ssl_server.cc b/src/io/ssl_server.cc new file mode 100644 index 00000000000..055893177b0 --- /dev/null +++ b/src/io/ssl_server.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#include "ssl_server.h" +#include "ssl_session.h" + +#include "io/event_manager.h" + +SslServer::SslServer(EventManager *evm, boost::asio::ssl::context::method m) + : TcpServer(evm), context_(*evm->io_service(), m) { + boost::system::error_code ec; + // By default set verify mode to none, to be set by derived class later. + context_.set_verify_mode(boost::asio::ssl::context::verify_none, ec); + assert(ec.value() == 0); + context_.set_options(boost::asio::ssl::context::default_workarounds, ec); + assert(ec.value() == 0); +} + +SslServer::~SslServer() { +} + +boost::asio::ssl::context *SslServer::context() { + return &context_; +} + +TcpSession *SslServer::AllocSession(bool server_session) { + SslSession *session; + if (server_session) { + session = AllocSession(so_ssl_accept_.get()); + + // if session allocate succeeds release ownership to so_accept. + if (session != NULL) { + so_ssl_accept_.release(); + } + } else { + SslSocket *socket = new SslSocket(*event_manager()->io_service(), + context_); + session = AllocSession(socket); + } + + return session; +} + +TcpServer::Socket *SslServer::accept_socket() const { + // return tcp socket + return &(so_ssl_accept_->next_layer()); +} + +void SslServer::set_accept_socket() { + so_ssl_accept_.reset(new SslSocket(*event_manager()->io_service(), + context_)); +} + diff --git a/src/io/ssl_server.h b/src/io/ssl_server.h new file mode 100644 index 00000000000..6da053c1626 --- /dev/null +++ b/src/io/ssl_server.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#ifndef __src_io_ssl_server_h__ +#define __src_io_ssl_server_h__ + +#include + +#include "io/tcp_server.h" + +class SslSession; + +class SslServer : public TcpServer { +public: + typedef boost::asio::ssl::stream SslSocket; + + explicit SslServer(EventManager *evm, boost::asio::ssl::context::method m); + virtual ~SslServer(); + +protected: + // given SSL socket, Create a session object. + virtual SslSession *AllocSession(SslSocket *socket) = 0; + + // boost ssl context accessor to setup ssl context variables. + boost::asio::ssl::context *context(); + +private: + // suppress AllocSession method using tcp socket, not valid for + // ssl server. + TcpSession *AllocSession(Socket *socket) { return NULL; } + + TcpSession *AllocSession(bool server_session); + + Socket *accept_socket() const; + void set_accept_socket(); + + boost::asio::ssl::context context_; + std::auto_ptr so_ssl_accept_; // SSL socket used in async_accept + DISALLOW_COPY_AND_ASSIGN(SslServer); +}; + +#endif //__src_io_ssl_server_h__ diff --git a/src/io/ssl_session.cc b/src/io/ssl_session.cc new file mode 100644 index 00000000000..8c945ade59f --- /dev/null +++ b/src/io/ssl_session.cc @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#include +#include + +#include "ssl_session.h" + +using namespace boost::asio; + +SslSession::SslSession(SslServer *server, SslSocket *socket, + bool async_read_ready) : + TcpSession(server, NULL, async_read_ready), + ssl_socket_(socket) { +} + +SslSession::~SslSession() { +} + +TcpSession::Socket *SslSession::socket() const { + // return tcp socket + return &ssl_socket_->next_layer(); +} + +bool SslSession::Connected(Endpoint remote) { + if (IsClosed()) { + return false; + } + + // trigger ssl client handshake + std::srand(std::time(0)); + ssl_socket_->async_handshake + (boost::asio::ssl::stream_base::client, + boost::bind(&SslSession::ConnectHandShakeHandler, TcpSessionPtr(this), + remote, boost::asio::placeholders::error)); + return true; +} + +void SslSession::Accepted() { + // trigger ssl server handshake + std::srand(std::time(0)); + ssl_socket_->async_handshake + (boost::asio::ssl::stream_base::server, + boost::bind(&SslSession::AcceptHandShakeHandler, TcpSessionPtr(this), + boost::asio::placeholders::error)); +} + +void SslSession::AcceptHandShakeHandler(TcpSessionPtr session, + const boost::system::error_code& error) { + SslSession *ssl_session = static_cast(session.get()); + if (!error) { + // on successful handshake continue with tcp session state machine. + ssl_session->TcpSession::Accepted(); + } else { + // close session on failure + ssl_session->CloseInternal(false); + } +} + +void SslSession::ConnectHandShakeHandler(TcpSessionPtr session, Endpoint remote, + const boost::system::error_code& error) { + SslSession *ssl_session = static_cast(session.get()); + bool ret = false; + if (!error) { + // on successful handshake continue with tcp session state machine. + ret = ssl_session->TcpSession::Connected(remote); + } + if (ret == false) { + // report connect failure and close the session + ssl_session->ConnectFailed(); + ssl_session->CloseInternal(false); + } +} + + +void SslSession::AsyncReadSome(boost::asio::mutable_buffer buffer) { + ssl_socket_->async_read_some(mutable_buffers_1(buffer), + boost::bind(&TcpSession::AsyncReadHandler, TcpSessionPtr(this), buffer, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred)); +} + +std::size_t SslSession::WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error) { + return ssl_socket_->write_some(boost::asio::buffer(data, len), error); +} + +void SslSession::AsyncWrite(const u_int8_t *data, std::size_t size) { + boost::asio::async_write( + *ssl_socket_.get(), buffer(data, size), + boost::bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this), + boost::asio::placeholders::error)); +} + diff --git a/src/io/ssl_session.h b/src/io/ssl_session.h new file mode 100644 index 00000000000..81dbc8b1e54 --- /dev/null +++ b/src/io/ssl_session.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#ifndef __src_io_ssl_session_h__ +#define __src_io_ssl_session_h__ + +#include "io/tcp_session.h" +#include "io/ssl_server.h" + +class SslSession : public TcpSession { +public: + typedef boost::asio::ssl::stream SslSocket; + + // SslSession constructor takes ownership of socket. + SslSession(SslServer *server, SslSocket *socket, + bool async_read_ready = true); + + virtual Socket *socket() const; + + // Override to trigger handshake + virtual bool Connected(Endpoint remote); + + // Override to trigger handshake + virtual void Accepted(); + + +protected: + virtual ~SslSession(); + +private: + static void AcceptHandShakeHandler(TcpSessionPtr session, + const boost::system::error_code& error); + static void ConnectHandShakeHandler(TcpSessionPtr session, Endpoint remote, + const boost::system::error_code& error); + + void AsyncReadSome(boost::asio::mutable_buffer buffer); + std::size_t WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error); + void AsyncWrite(const u_int8_t *data, std::size_t size); + + boost::scoped_ptr ssl_socket_; + DISALLOW_COPY_AND_ASSIGN(SslSession); +}; + +#endif // __src_io_ssl_session_h__ diff --git a/src/io/tcp_message_write.cc b/src/io/tcp_message_write.cc index 1c267e8ff9a..b28bdb00ed2 100644 --- a/src/io/tcp_message_write.cc +++ b/src/io/tcp_message_write.cc @@ -13,8 +13,8 @@ using namespace boost::asio; using namespace boost::system; using tbb::mutex; -TcpMessageWriter::TcpMessageWriter(Socket *socket, TcpSession *session) : - socket_(socket), offset_(0), session_(session) { +TcpMessageWriter::TcpMessageWriter(TcpSession *session) : + offset_(0), session_(session) { } TcpMessageWriter::~TcpMessageWriter() { @@ -36,7 +36,7 @@ int TcpMessageWriter::Send(const uint8_t *data, size_t len, error_code &ec) { session_->server_->stats_.write_bytes += len; if (buffer_queue_.empty()) { - wrote = socket_->write_some(boost::asio::buffer(data, len), ec); + wrote = session_->WriteSome(data, len, ec); if (TcpSession::IsSocketErrorHard(ec)) return -1; assert(wrote >= 0); @@ -45,7 +45,7 @@ int TcpMessageWriter::Send(const uint8_t *data, size_t len, error_code &ec) { "Encountered partial send of " << wrote << " bytes when " "sending " << len << " bytes, Error: " << ec); BufferAppend(data + wrote, len - wrote); - DeferWrite(); + session_->DeferWriter(); } } else { TCP_SESSION_LOG_UT_DEBUG(session_, TCP_DIR_OUT, @@ -55,55 +55,20 @@ int TcpMessageWriter::Send(const uint8_t *data, size_t len, error_code &ec) { return wrote; } -void TcpMessageWriter::DeferWrite() { - - // Update socket write block count. - session_->stats_.write_blocked++; - session_->server_->stats_.write_blocked++; - socket_->async_write_some( - boost::asio::null_buffers(), - boost::bind(&TcpMessageWriter::HandleWriteReady, this, - TcpSessionPtr(session_), - placeholders::error, UTCTimestampUsec())); - return; -} - -// Socket is ready for write. Flush any pending data and notify -// clients aboout it. -void TcpMessageWriter::HandleWriteReady(TcpSessionPtr session_ptr, - const error_code &error, - uint64_t block_start_time) { - mutex::scoped_lock lock(session_->mutex()); - - // Update socket write block time. - uint64_t blocked_usecs = UTCTimestampUsec() - block_start_time; - session_->stats_.write_blocked_duration_usecs += blocked_usecs; - session_->server_->stats_.write_blocked_duration_usecs += blocked_usecs; - - if (TcpSession::IsSocketErrorHard(error)) { - goto done; - } - - // - // Ignore if connection is already closed. - // - if (session_->IsClosedLocked()) return; - +// Socket is ready for write. Flush any pending data +void TcpMessageWriter::HandleWriteReady(error_code &error) { while (!buffer_queue_.empty()) { boost::asio::mutable_buffer head = buffer_queue_.front(); const uint8_t *data = buffer_cast(head) + offset_; int remaining = buffer_size(head) - offset_; - error_code ec; - int wrote = socket_->write_some(buffer(data, remaining), ec); - if (TcpSession::IsSocketErrorHard(ec)) { - lock.release(); - if (!cb_.empty()) cb_(ec); + int wrote = session_->WriteSome(data, remaining, error); + if (TcpSession::IsSocketErrorHard(error)) { return; } assert(wrote >= 0); if (wrote != remaining) { offset_ += wrote; - DeferWrite(); + session_->DeferWriter(); return; } else { offset_ = 0; @@ -112,16 +77,6 @@ void TcpMessageWriter::HandleWriteReady(TcpSessionPtr session_ptr, } } buffer_queue_.clear(); - -done: - lock.release(); - // The session object is implicitly accessed in by cb_. This is - // safe because this function currently holds a refcount on the session - // via TcpSessionPtr. - if (!cb_.empty()) { - cb_(error); - } - return; } void TcpMessageWriter::BufferAppend(const uint8_t *src, int bytes) { @@ -137,6 +92,3 @@ void TcpMessageWriter::DeleteBuffer(mutable_buffer buffer) { return; } -void TcpMessageWriter::RegisterNotification(SendReadyCb cb) { - cb_ = cb; -} diff --git a/src/io/tcp_message_write.h b/src/io/tcp_message_write.h index 385cbf101db..914d26d1678 100644 --- a/src/io/tcp_message_write.h +++ b/src/io/tcp_message_write.h @@ -20,29 +20,22 @@ class TcpSession; class TcpMessageWriter { public: - typedef boost::asio::ip::tcp::socket Socket; static const int kDefaultBufferSize = 4 * 1024; - explicit TcpMessageWriter(Socket *, TcpSession *session); + explicit TcpMessageWriter(TcpSession *session); ~TcpMessageWriter(); // return false for send int Send(const uint8_t *msg, size_t len, error_code &ec); - typedef boost::function SendReadyCb; - void RegisterNotification(SendReadyCb); - private: + friend class TcpSession; typedef boost::intrusive_ptr TcpSessionPtr; typedef std::list BufferQueue; void BufferAppend(const uint8_t *data, int len); void DeleteBuffer(boost::asio::mutable_buffer buffer); - void DeferWrite(); - void HandleWriteReady(TcpSessionPtr session_ref, const error_code &ec, - uint64_t block_start_time); + void HandleWriteReady(boost::system::error_code &ec); BufferQueue buffer_queue_; - SendReadyCb cb_; - Socket *socket_; int offset_; TcpSession *session_; }; diff --git a/src/io/tcp_server.cc b/src/io/tcp_server.cc index 127a4a2dbff..0b60e84c93d 100644 --- a/src/io/tcp_server.cc +++ b/src/io/tcp_server.cc @@ -138,8 +138,7 @@ void TcpServer::ClearSessions() { } TcpSession *TcpServer::CreateSession() { - Socket *socket = new Socket(*evm_->io_service()); - TcpSession *session = AllocSession(socket); + TcpSession *session = AllocSession(false); { tbb::mutex::scoped_lock lock(mutex_); session_ref_.insert(TcpSessionPtr(session)); @@ -213,8 +212,8 @@ void TcpServer::AsyncAccept() { if (acceptor_ == NULL) { return; } - so_accept_.reset(new Socket(*evm_->io_service())); - acceptor_->async_accept(*so_accept_.get(), + set_accept_socket(); + acceptor_->async_accept(*accept_socket(), boost::bind(&TcpServer::AcceptHandlerInternal, this, TcpServerPtr(this), boost::asio::placeholders::error)); } @@ -240,7 +239,7 @@ bool TcpServer::HasSessions() const { bool TcpServer::HasSessionReadAvailable() const { tbb::mutex::scoped_lock lock(mutex_); boost::system::error_code error; - if (so_accept_->available(error) > 0) { + if (accept_socket()->available(error) > 0) { return true; } for (SessionMap::const_iterator iter = session_map_.begin(); @@ -266,6 +265,31 @@ TcpServer::Endpoint TcpServer::LocalEndpoint() const { return local; } +TcpSession *TcpServer::AllocSession(bool server_session) { + TcpSession *session; + if (server_session) { + session = AllocSession(so_accept_.get()); + + // if session allocate succeeds release ownership to so_accept. + if (session != NULL) { + so_accept_.release(); + } + } else { + Socket *socket = new Socket(*evm_->io_service()); + session = AllocSession(socket); + } + + return session; +} + +TcpServer::Socket *TcpServer::accept_socket() const { + return so_accept_.get(); +} + +void TcpServer::set_accept_socket() { + so_accept_.reset(new Socket(*evm_->io_service())); +} + bool TcpServer::AcceptSession(TcpSession *session) { return true; } @@ -281,15 +305,13 @@ void TcpServer::AcceptHandlerInternal(TcpServerPtr server, tcp::endpoint remote; boost::system::error_code ec; TcpSessionPtr session; - auto_ptr socket; bool need_close = false; if (error) { goto done; } - socket.reset(so_accept_.release()); - remote = socket->remote_endpoint(ec); + remote = accept_socket()->remote_endpoint(ec); if (ec) { TCP_SERVER_LOG_ERROR(this, TCP_DIR_IN, "Accept: No remote endpoint: " << ec.message()); @@ -301,16 +323,15 @@ void TcpServer::AcceptHandlerInternal(TcpServerPtr server, "Session accepted after server shutdown: " << remote.address().to_string() << ":" << remote.port()); - socket->close(ec); + accept_socket()->close(ec); goto done; } - session.reset(AllocSession(socket.get())); + session.reset(AllocSession(true)); if (session == NULL) { TCP_SERVER_LOG_DEBUG(this, TCP_DIR_IN, "Session not created"); goto done; } - socket.release(); ec = session->SetSocketOptions(); if (ec) { @@ -340,9 +361,7 @@ void TcpServer::AcceptHandlerInternal(TcpServerPtr server, } } - if (session->read_on_connect_) { - session->AsyncReadStart(); - } + session->Accepted(); done: if (need_close) { diff --git a/src/io/tcp_server.h b/src/io/tcp_server.h index 08280c87b34..4b8dbe9d92e 100644 --- a/src/io/tcp_server.h +++ b/src/io/tcp_server.h @@ -83,6 +83,13 @@ class TcpServer { // Create a session object. virtual TcpSession *AllocSession(Socket *socket) = 0; + // Only SslServer overrides this method, to manage server with SSL + // socket instead of TCP socket + virtual TcpSession *AllocSession(bool server_session); + + virtual Socket *accept_socket() const; + virtual void set_accept_socket(); + // // Passively accepted a new session. Returns true if the session is // accepted, false otherwise. diff --git a/src/io/tcp_session.cc b/src/io/tcp_session.cc index 44db6b87268..da132c3e1fb 100644 --- a/src/io/tcp_session.cc +++ b/src/io/tcp_session.cc @@ -54,10 +54,8 @@ TcpSession::TcpSession( established_(false), closed_(false), direction_(ACTIVE), - writer_(new TcpMessageWriter(socket, this)) { + writer_(new TcpMessageWriter(this)) { refcount_ = 0; - writer_->RegisterNotification( - boost::bind(&TcpSession::WriteReadyInternal, this, _1)); if (reader_task_id_ == -1) { TaskScheduler *scheduler = TaskScheduler::GetInstance(); reader_task_id_ = scheduler->GetTaskId("io::ReaderTask"); @@ -125,18 +123,43 @@ void TcpSession::AsyncReadStart() { ReleaseBufferLocked(buffer); return; } + AsyncReadSome(buffer); +} + +void TcpSession::DeferWriter() { + // Update socket write block count. + stats_.write_blocked++; + server_->stats_.write_blocked++; + socket()->async_write_some(boost::asio::null_buffers(), + boost::bind(&TcpSession::WriteReadyInternal, TcpSessionPtr(this), + placeholders::error, UTCTimestampUsec())); +} + +void TcpSession::AsyncReadSome(boost::asio::mutable_buffer buffer) { socket_->async_read_some(mutable_buffers_1(buffer), boost::bind(&TcpSession::AsyncReadHandler, TcpSessionPtr(this), buffer, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); } +std::size_t TcpSession::WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error) { + return socket_->write_some(boost::asio::buffer(data, len), error); +} + +void TcpSession::AsyncWrite(const u_int8_t *data, std::size_t size) { + boost::asio::async_write( + *socket_.get(), buffer(data, size), + boost::bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this), + boost::asio::placeholders::error)); +} + TcpSession::Endpoint TcpSession::local_endpoint() const { tbb::mutex::scoped_lock lock(mutex_); if (!established_) { return Endpoint(); } boost::system::error_code err; - Endpoint local = socket_->local_endpoint(err); + Endpoint local = socket()->local_endpoint(err); if (err) { return Endpoint(); } @@ -153,7 +176,7 @@ void TcpSession::SetName() { boost::system::error_code err; Endpoint local; - local = socket_->local_endpoint(err); + local = socket()->local_endpoint(err); out << local.address().to_string() << ":" << local.port() << "::"; out << remote_.address().to_string() << ":" << remote_.port(); @@ -168,6 +191,21 @@ void TcpSession::SessionEstablished(Endpoint remote, SetName(); } +void TcpSession::Accepted() { + TCP_SESSION_LOG_DEBUG(this, TCP_DIR_OUT, + "Passive session Accept complete"); + { + tbb::mutex::scoped_lock obs_lock(obs_mutex_); + if (observer_) { + observer_(this, ACCEPT); + } + } + + if (read_on_connect_) { + AsyncReadStart(); + } +} + bool TcpSession::Connected(Endpoint remote) { assert(refcount_); { @@ -205,9 +243,9 @@ void TcpSession::ConnectFailed() { void TcpSession::CloseInternal(bool call_observer, bool notify_server) { tbb::mutex::scoped_lock lock(mutex_); - if (socket_.get() != NULL && !closed_) { + if (socket() != NULL && !closed_) { boost::system::error_code err; - socket_->close(err); + socket()->close(err); } closed_ = true; if (!established_) { @@ -240,16 +278,42 @@ void TcpSession::Close() { void TcpSession::WriteReady(const boost::system::error_code &error) { } -void TcpSession::WriteReadyInternal(const boost::system::error_code &error) { - if (IsSocketErrorHard(error)) { - TCP_SESSION_LOG_INFO(this, TCP_DIR_OUT, "Write failed due to error: " - << error.value() - << " category: " << error.category().name() - << " message: " << error.message()); - CloseInternal(true); - return; +void TcpSession::WriteReadyInternal(TcpSessionPtr session, + const boost::system::error_code &error, + uint64_t block_start_time) { + boost::system::error_code ec = error; + tbb::mutex::scoped_lock lock(session->mutex_); + + // Update socket write block time. + uint64_t blocked_usecs = UTCTimestampUsec() - block_start_time; + session->stats_.write_blocked_duration_usecs += blocked_usecs; + session->server_->stats_.write_blocked_duration_usecs += blocked_usecs; + + if (session->IsSocketErrorHard(ec)) { + goto session_error; + } + + // + // Ignore if connection is already closed. + // + if (session->IsClosedLocked()) return; + + session->writer_->HandleWriteReady(ec); + if (session->IsSocketErrorHard(ec)) { + goto session_error; } - WriteReady(error); + + lock.release(); + session->WriteReady(ec); + return; + +session_error: + lock.release(); + TCP_SESSION_LOG_INFO(session.get(), TCP_DIR_OUT, + "Write failed due to error: " << ec.value() + << " category: " << ec.category().name() + << " message: " << ec.message()); + session->CloseInternal(true); } void TcpSession::AsyncWriteHandler(TcpSessionPtr session, @@ -274,7 +338,7 @@ bool TcpSession::Send(const u_int8_t *data, size_t size, size_t *sent) { // if (!established_) return false; - if (socket_->non_blocking()) { + if (socket()->non_blocking()) { boost::system::error_code error; int len = writer_->Send(data, size, error); lock.release(); @@ -288,10 +352,7 @@ bool TcpSession::Send(const u_int8_t *data, size_t size, size_t *sent) { if (len < 0 || (size_t)len != size) ret = false; if (sent) *sent = (len > 0) ? len : 0; } else { - boost::asio::async_write( - *socket_.get(), buffer(data, size), - boost::bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this), - boost::asio::placeholders::error)); + AsyncWrite(data, size); if (sent) *sent = size; } return ret; @@ -337,11 +398,11 @@ int TcpSession::GetSessionInstance() const { int32_t TcpSession::local_port() const { - if (socket_.get() == NULL) { + if (socket() == NULL) { return -1; } boost::system::error_code error; - Endpoint local = socket_->local_endpoint(error); + Endpoint local = socket()->local_endpoint(error); if (IsSocketErrorHard(error)) { return -1; } @@ -349,11 +410,11 @@ int32_t TcpSession::local_port() const { } int32_t TcpSession::remote_port() const { - if (socket_.get() == NULL) { + if (socket() == NULL) { return -1; } boost::system::error_code error; - Endpoint remote = socket_->remote_endpoint(error); + Endpoint remote = socket()->remote_endpoint(error); if (IsSocketErrorHard(error)) { return -1; } @@ -550,7 +611,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti int keepalive_intvl, int keepalive_probes) { boost::system::error_code ec; socket_base::keep_alive keep_alive_option(true); - socket_->set_option(keep_alive_option, ec); + socket()->set_option(keep_alive_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keep_alive set error: " << ec); @@ -559,7 +620,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPIDLE typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPIDLE > keepalive_idle_time; keepalive_idle_time keepalive_idle_time_option(keepalive_time); - socket_->set_option(keepalive_idle_time_option, ec); + socket()->set_option(keepalive_idle_time_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_idle_time: " << keepalive_time << " set error: " << ec); @@ -569,7 +630,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPALIVE typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPALIVE > keepalive_idle_time; keepalive_idle_time keepalive_idle_time_option(keepalive_time); - socket_->set_option(keepalive_idle_time_option, ec); + socket()->set_option(keepalive_idle_time_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_idle_time: " << keepalive_time << " set error: " << ec); @@ -579,7 +640,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPINTVL typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPINTVL > keepalive_interval; keepalive_interval keepalive_interval_option(keepalive_intvl); - socket_->set_option(keepalive_interval_option, ec); + socket()->set_option(keepalive_interval_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_interval: " << keepalive_intvl << " set error: " << ec); @@ -589,7 +650,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPCNT typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPCNT > keepalive_count; keepalive_count keepalive_count_option(keepalive_probes); - socket_->set_option(keepalive_count_option, ec); + socket()->set_option(keepalive_count_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_probes: " << keepalive_probes << " set error: " << ec); @@ -605,7 +666,7 @@ boost::system::error_code TcpSession::SetSocketOptions() { // // Make socket write non-blocking // - socket_->non_blocking(true, ec); + socket()->non_blocking(true, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_NA, "Cannot set socket non blocking: " << ec); @@ -625,7 +686,7 @@ boost::system::error_code TcpSession::SetSocketOptions() { // sends more deterministically // socket_base::send_buffer_size send_buffer_size_option(sz); - socket_->set_option(send_buffer_size_option, ec); + socket()->set_option(send_buffer_size_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "send_buffer_size set error: " << ec); @@ -633,7 +694,7 @@ boost::system::error_code TcpSession::SetSocketOptions() { } socket_base::receive_buffer_size receive_buffer_size_option(sz); - socket_->set_option(receive_buffer_size_option, ec); + socket()->set_option(receive_buffer_size_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_IN, "receive_buffer_size set error: " << ec); diff --git a/src/io/tcp_session.h b/src/io/tcp_session.h index 04f65ac71b6..454e7fb434e 100644 --- a/src/io/tcp_session.h +++ b/src/io/tcp_session.h @@ -62,6 +62,9 @@ class TcpSession { // Called by TcpServer to trigger async read. virtual bool Connected(Endpoint remote); + // Called by TcpServer to trigger async read. + virtual void Accepted(); + void ConnectFailed(); void Close(); @@ -71,7 +74,7 @@ class TcpSession { void SetBufferSize(int buffer_size); // Getters and setters - Socket *socket() { return socket_.get(); } + virtual Socket *socket() const { return socket_.get(); } TcpServer *server() { return server_.get(); } int32_t local_port() const; int32_t remote_port() const; @@ -132,6 +135,13 @@ class TcpSession { void GetTxSocketStats(SocketIOStats &socket_stats) const; protected: + typedef boost::intrusive_ptr TcpSessionPtr; + static void AsyncReadHandler(TcpSessionPtr session, + boost::asio::mutable_buffer buffer, + const boost::system::error_code &error, + size_t size); + static void AsyncWriteHandler(TcpSessionPtr session, + const boost::system::error_code &error); virtual ~TcpSession(); // Read handler. Called from a TBB task. @@ -139,6 +149,11 @@ class TcpSession { // Callback after socket is ready for write. virtual void WriteReady(const boost::system::error_code &error); + virtual void AsyncReadSome(boost::asio::mutable_buffer buffer); + virtual std::size_t WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error); + virtual void AsyncWrite(const u_int8_t *data, std::size_t size); + virtual int reader_task_id() const { return reader_task_id_; } @@ -147,28 +162,24 @@ class TcpSession { boost::system::error_code SetSocketKeepaliveOptions(int keepalive_time, int keepalive_intvl, int keepalive_probes); + void CloseInternal(bool call_observer, bool notify_server = true); + private: class Reader; friend class TcpServer; friend class TcpMessageWriter; friend void intrusive_ptr_add_ref(TcpSession *session); friend void intrusive_ptr_release(TcpSession *session); - typedef boost::intrusive_ptr TcpSessionPtr; typedef std::list BufferQueue; - static void AsyncReadHandler(TcpSessionPtr session, - boost::asio::mutable_buffer buffer, - const boost::system::error_code &error, - size_t size); - static void AsyncWriteHandler(TcpSessionPtr session, - const boost::system::error_code &error); + static void WriteReadyInternal(TcpSessionPtr session, + const boost::system::error_code &error, + uint64_t block_start_time); + void DeferWriter(); void ReleaseBufferLocked(Buffer buffer); - void CloseInternal(bool call_observer, bool notify_server = true); void SetEstablished(Endpoint remote, Direction dir); - tbb::mutex &mutex() { return mutex_; } - bool IsClosedLocked() const { return closed_; } @@ -176,7 +187,6 @@ class TcpSession { boost::asio::mutable_buffer AllocateBuffer(); void DeleteBuffer(boost::asio::mutable_buffer buffer); - void WriteReadyInternal(const boost::system::error_code &); static int reader_task_id_; diff --git a/src/io/test/SConscript b/src/io/test/SConscript index 0d68181114a..e8a4ba3540f 100644 --- a/src/io/test/SConscript +++ b/src/io/test/SConscript @@ -19,7 +19,7 @@ env.Append(LIBPATH = env['TOP'] + '/base/test') env.Prepend(LIBS = ['gunit', 'task_test', 'io', 'sandesh', 'http', 'sandeshvns', 'process_info', 'io', 'base', 'http_parser', 'curl', - 'boost_program_options', 'pugixml']) + 'boost_program_options', 'pugixml', 'ssl', 'crypto']) if sys.platform != 'darwin': env.Append(LIBS = ['rt']) @@ -36,6 +36,12 @@ tcp_server_test = env.UnitTest('tcp_server_test', env.Alias('src/io:tcp_server_test', tcp_server_test) +ssl_server_test = env.UnitTest('ssl_server_test', + ['ssl_server_test.cc'], + ) + +env.Alias('src/io:ssl_server_test', ssl_server_test) + tcp_io_test = env.UnitTest('tcp_io_test', ['tcp_io_test.cc'], ) @@ -64,6 +70,7 @@ test_suite = [ ] flaky_test_suite = [ + ssl_server_test, tcp_io_test, tcp_server_test, tcp_stress_test, diff --git a/src/io/test/ssl_server_test.cc b/src/io/test/ssl_server_test.cc new file mode 100644 index 00000000000..cb7febb53ac --- /dev/null +++ b/src/io/test/ssl_server_test.cc @@ -0,0 +1,311 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include "testing/gunit.h" + +#include "base/logging.h" +#include "base/parse_object.h" +#include "base/test/task_test_util.h" +#include "io/event_manager.h" +#include "io/ssl_server.h" +#include "io/ssl_session.h" +#include "io/test/event_manager_test.h" +#include "io/io_log.h" + +using namespace std; + +namespace { + +class EchoServer; + +class EchoSession : public SslSession { + public: + EchoSession(EchoServer *server, SslSocket *socket); + + protected: + virtual void OnRead(Buffer buffer) { + const u_int8_t *data = BufferData(buffer); + const size_t len = BufferSize(buffer); + TCP_UT_LOG_DEBUG("Received " << BufferData(buffer) << " " << len << " bytes"); + Send(data, len, NULL); + } + private: + void OnEvent(TcpSession *session, Event event) { + if (event == ACCEPT) { + TCP_UT_LOG_DEBUG("Accept"); + } + if (event == CLOSE) { + TCP_UT_LOG_DEBUG("Close"); + } + } +}; + +class EchoServer : public SslServer { +public: + explicit EchoServer(EventManager *evm) : + SslServer(evm, boost::asio::ssl::context::tlsv1_server), session_(NULL) { + boost::asio::ssl::context *ctx = context(); + boost::system::error_code ec; + ctx->set_verify_mode(boost::asio::ssl::context::verify_none, ec); + assert(ec.value() == 0); + ctx->use_certificate_chain_file + ("controller/src/ifmap/client/test/newcert.pem", ec); + assert(ec.value() == 0); + ctx->use_private_key_file("controller/src/ifmap/client/test/server.pem", + boost::asio::ssl::context::pem, ec); + assert(ec.value() == 0); + ctx->add_verify_path("controller/src/ifmap/client/test/", ec); + assert(ec.value() == 0); + ctx->load_verify_file("controller/src/ifmap/client/test/newcert.pem", + ec); + assert(ec.value() == 0); + } + ~EchoServer() { + } + virtual SslSession *AllocSession(SslSocket *socket) { + session_ = new EchoSession(this, socket); + return session_; + } + + TcpSession *CreateSession() { + TcpSession *session = SslServer::CreateSession(); + Socket *socket = session->socket(); + + boost::system::error_code err; + socket->open(boost::asio::ip::tcp::v4(), err); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "open failed: " << err.message()); + } + err = session->SetSocketOptions(); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "sockopt: " << err.message()); + } + return session; + } + + EchoSession *GetSession() const { return session_; } + +private: + EchoSession *session_; +}; + +EchoSession::EchoSession(EchoServer *server, SslSocket *socket) + : SslSession(server, socket) { + set_observer(boost::bind(&EchoSession::OnEvent, this, _1, _2)); +} + +class SslClient; + +class ClientSession : public SslSession { + public: + ClientSession(SslClient *server, SslSocket *socket); + + void OnEvent(TcpSession *session, Event event) { + if (event == CONNECT_COMPLETE) { + const u_int8_t *data = (const u_int8_t *)"Hello there !"; + size_t len = 14; + Send(data, len, NULL); + } + } + + std::size_t &len() { return len_; } + + protected: + virtual void OnRead(Buffer buffer) { + const u_int8_t *data = BufferData(buffer); + const size_t len = BufferSize(buffer); + TCP_UT_LOG_DEBUG("Received " << BufferData(buffer) << " " << len << " bytes"); + len_ += len; + } + + private: + std::size_t len_; +}; + +class SslClient : public SslServer { +public: + explicit SslClient(EventManager *evm) : + SslServer(evm, boost::asio::ssl::context::tlsv1), session_(NULL) { + boost::asio::ssl::context *ctx = context(); + boost::system::error_code ec; + ctx->set_verify_mode(boost::asio::ssl::context::verify_none, ec); + assert(ec.value() == 0); + ctx->use_certificate_chain_file + ("controller/src/ifmap/client/test/newcert.pem", ec); + assert(ec.value() == 0); + ctx->use_private_key_file("controller/src/ifmap/client/test/server.pem", + boost::asio::ssl::context::pem, ec); + assert(ec.value() == 0); + ctx->add_verify_path("controller/src/ifmap/client/test/", ec); + assert(ec.value() == 0); + ctx->load_verify_file("controller/src/ifmap/client/test/newcert.pem", + ec); + assert(ec.value() == 0); + } + ~SslClient() { + } + virtual SslSession *AllocSession(SslSocket *socket) { + session_ = new ClientSession(this, socket); + return session_; + } + + TcpSession *CreateSession() { + TcpSession *session = SslServer::CreateSession(); + Socket *socket = session->socket(); + + boost::system::error_code err; + socket->open(boost::asio::ip::tcp::v4(), err); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "open failed: " << err.message()); + } + err = session->SetSocketOptions(); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "sockopt: " << err.message()); + } + return session; + } + + ClientSession *GetSession() const { return session_; } + +private: + ClientSession *session_; +}; + +ClientSession::ClientSession(SslClient *server, SslSocket *socket) + : SslSession(server, socket) { + set_observer(boost::bind(&ClientSession::OnEvent, this, _1, _2)); +} + +class SslEchoServerTest : public ::testing::Test { +public: + void OnEvent(TcpSession *session, SslSession::Event event) { + boost::system::error_code ec; + timer_.cancel(ec); + ClientSession *client_session = static_cast(session); + client_session->OnEvent(session, event); + if (event == SslSession::CONNECT_FAILED) { + connect_fail_++; + session->Close(); + } + if (event == SslSession::CONNECT_COMPLETE) { + connect_success_++; + } + } + +protected: + SslEchoServerTest() + : evm_(new EventManager()), timer_(*evm_->io_service()), + connect_success_(0), connect_fail_(0), connect_abort_(0) { + } + virtual void SetUp() { + server_ = new EchoServer(evm_.get()); + thread_.reset(new ServerThread(evm_.get())); + session_ = NULL; + } + + virtual void TearDown() { + if (server_->GetSession()) { + server_->GetSession()->Close(); + } + if (session_) session_->Close(); + task_util::WaitForIdle(); + server_->Shutdown(); + server_->ClearSessions(); + task_util::WaitForIdle(); + TcpServerManager::DeleteServer(server_); + server_ = NULL; + evm_->Shutdown(); + if (thread_.get() != NULL) { + thread_->Join(); + } + task_util::WaitForIdle(); + } + + + void DummyTimerHandler(TcpSession *session, + const boost::system::error_code &error) { + if (error) { + return; + } + if (!session->IsClosed()) { + connect_abort_++; + } + session->Close(); + } + + void StartConnectTimer(TcpSession *session, int sec) { + boost::system::error_code ec; + timer_.expires_from_now(boost::posix_time::seconds(sec), ec); + timer_.async_wait( + boost::bind(&SslEchoServerTest::DummyTimerHandler, this, session, + boost::asio::placeholders::error)); + } + auto_ptr thread_; + auto_ptr evm_; + EchoServer *server_; + boost::asio::deadline_timer timer_; + EchoSession *session_; + int connect_success_; + int connect_fail_; + int connect_abort_; +}; + +TEST_F(SslEchoServerTest, msg_send_recv) { + SslClient *client = new SslClient(evm_.get()); + + task_util::WaitForIdle(); + server_->Initialize(0); + task_util::WaitForIdle(); + thread_->Start(); // Must be called after initialization + + connect_success_ = connect_fail_ = connect_abort_ = 0; + ClientSession *session = static_cast(client->CreateSession()); + session->set_observer(boost::bind(&SslEchoServerTest::OnEvent, this, _1, _2)); + boost::asio::ip::tcp::endpoint endpoint; + boost::system::error_code ec; + endpoint.address(boost::asio::ip::address::from_string("127.0.0.1", ec)); + endpoint.port(server_->GetPort()); + client->Connect(session, endpoint); + task_util::WaitForIdle(); + StartConnectTimer(session, 10); + TASK_UTIL_EXPECT_TRUE(session->IsEstablished()); + TASK_UTIL_EXPECT_FALSE(session->IsClosed()); + TASK_UTIL_EXPECT_EQ(1, connect_success_); + TASK_UTIL_EXPECT_EQ(connect_fail_, 0); + TASK_UTIL_EXPECT_EQ(connect_abort_, 0); + // wait for on connect message to come back from echo server. + TASK_UTIL_EXPECT_EQ(session->len(), 14); + + session->Close(); + client->DeleteSession(session); + + client->Shutdown(); + task_util::WaitForIdle(); + TcpServerManager::DeleteServer(client); + client = NULL; +} + +} // namespace + +int main(int argc, char **argv) { + LoggingInit(); + Sandesh::SetLoggingParams(true, "", SandeshLevel::UT_DEBUG); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}