diff --git a/src/io/SConscript b/src/io/SConscript index 1ba9453239c..4e409b2afed 100644 --- a/src/io/SConscript +++ b/src/io/SConscript @@ -23,6 +23,8 @@ libio = env.Library('io', EventManagerSrc + [ 'io_utils.cc', + 'ssl_server.cc', + 'ssl_session.cc', 'tcp_message_write.cc', 'tcp_server.cc', 'tcp_session.cc', diff --git a/src/io/ssl_server.cc b/src/io/ssl_server.cc new file mode 100644 index 00000000000..055893177b0 --- /dev/null +++ b/src/io/ssl_server.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#include "ssl_server.h" +#include "ssl_session.h" + +#include "io/event_manager.h" + +SslServer::SslServer(EventManager *evm, boost::asio::ssl::context::method m) + : TcpServer(evm), context_(*evm->io_service(), m) { + boost::system::error_code ec; + // By default set verify mode to none, to be set by derived class later. + context_.set_verify_mode(boost::asio::ssl::context::verify_none, ec); + assert(ec.value() == 0); + context_.set_options(boost::asio::ssl::context::default_workarounds, ec); + assert(ec.value() == 0); +} + +SslServer::~SslServer() { +} + +boost::asio::ssl::context *SslServer::context() { + return &context_; +} + +TcpSession *SslServer::AllocSession(bool server_session) { + SslSession *session; + if (server_session) { + session = AllocSession(so_ssl_accept_.get()); + + // if session allocate succeeds release ownership to so_accept. + if (session != NULL) { + so_ssl_accept_.release(); + } + } else { + SslSocket *socket = new SslSocket(*event_manager()->io_service(), + context_); + session = AllocSession(socket); + } + + return session; +} + +TcpServer::Socket *SslServer::accept_socket() const { + // return tcp socket + return &(so_ssl_accept_->next_layer()); +} + +void SslServer::set_accept_socket() { + so_ssl_accept_.reset(new SslSocket(*event_manager()->io_service(), + context_)); +} + diff --git a/src/io/ssl_server.h b/src/io/ssl_server.h new file mode 100644 index 00000000000..6da053c1626 --- /dev/null +++ b/src/io/ssl_server.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#ifndef __src_io_ssl_server_h__ +#define __src_io_ssl_server_h__ + +#include + +#include "io/tcp_server.h" + +class SslSession; + +class SslServer : public TcpServer { +public: + typedef boost::asio::ssl::stream SslSocket; + + explicit SslServer(EventManager *evm, boost::asio::ssl::context::method m); + virtual ~SslServer(); + +protected: + // given SSL socket, Create a session object. + virtual SslSession *AllocSession(SslSocket *socket) = 0; + + // boost ssl context accessor to setup ssl context variables. + boost::asio::ssl::context *context(); + +private: + // suppress AllocSession method using tcp socket, not valid for + // ssl server. + TcpSession *AllocSession(Socket *socket) { return NULL; } + + TcpSession *AllocSession(bool server_session); + + Socket *accept_socket() const; + void set_accept_socket(); + + boost::asio::ssl::context context_; + std::auto_ptr so_ssl_accept_; // SSL socket used in async_accept + DISALLOW_COPY_AND_ASSIGN(SslServer); +}; + +#endif //__src_io_ssl_server_h__ diff --git a/src/io/ssl_session.cc b/src/io/ssl_session.cc new file mode 100644 index 00000000000..8c945ade59f --- /dev/null +++ b/src/io/ssl_session.cc @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#include +#include + +#include "ssl_session.h" + +using namespace boost::asio; + +SslSession::SslSession(SslServer *server, SslSocket *socket, + bool async_read_ready) : + TcpSession(server, NULL, async_read_ready), + ssl_socket_(socket) { +} + +SslSession::~SslSession() { +} + +TcpSession::Socket *SslSession::socket() const { + // return tcp socket + return &ssl_socket_->next_layer(); +} + +bool SslSession::Connected(Endpoint remote) { + if (IsClosed()) { + return false; + } + + // trigger ssl client handshake + std::srand(std::time(0)); + ssl_socket_->async_handshake + (boost::asio::ssl::stream_base::client, + boost::bind(&SslSession::ConnectHandShakeHandler, TcpSessionPtr(this), + remote, boost::asio::placeholders::error)); + return true; +} + +void SslSession::Accepted() { + // trigger ssl server handshake + std::srand(std::time(0)); + ssl_socket_->async_handshake + (boost::asio::ssl::stream_base::server, + boost::bind(&SslSession::AcceptHandShakeHandler, TcpSessionPtr(this), + boost::asio::placeholders::error)); +} + +void SslSession::AcceptHandShakeHandler(TcpSessionPtr session, + const boost::system::error_code& error) { + SslSession *ssl_session = static_cast(session.get()); + if (!error) { + // on successful handshake continue with tcp session state machine. + ssl_session->TcpSession::Accepted(); + } else { + // close session on failure + ssl_session->CloseInternal(false); + } +} + +void SslSession::ConnectHandShakeHandler(TcpSessionPtr session, Endpoint remote, + const boost::system::error_code& error) { + SslSession *ssl_session = static_cast(session.get()); + bool ret = false; + if (!error) { + // on successful handshake continue with tcp session state machine. + ret = ssl_session->TcpSession::Connected(remote); + } + if (ret == false) { + // report connect failure and close the session + ssl_session->ConnectFailed(); + ssl_session->CloseInternal(false); + } +} + + +void SslSession::AsyncReadSome(boost::asio::mutable_buffer buffer) { + ssl_socket_->async_read_some(mutable_buffers_1(buffer), + boost::bind(&TcpSession::AsyncReadHandler, TcpSessionPtr(this), buffer, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred)); +} + +std::size_t SslSession::WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error) { + return ssl_socket_->write_some(boost::asio::buffer(data, len), error); +} + +void SslSession::AsyncWrite(const u_int8_t *data, std::size_t size) { + boost::asio::async_write( + *ssl_socket_.get(), buffer(data, size), + boost::bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this), + boost::asio::placeholders::error)); +} + diff --git a/src/io/ssl_session.h b/src/io/ssl_session.h new file mode 100644 index 00000000000..81dbc8b1e54 --- /dev/null +++ b/src/io/ssl_session.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#ifndef __src_io_ssl_session_h__ +#define __src_io_ssl_session_h__ + +#include "io/tcp_session.h" +#include "io/ssl_server.h" + +class SslSession : public TcpSession { +public: + typedef boost::asio::ssl::stream SslSocket; + + // SslSession constructor takes ownership of socket. + SslSession(SslServer *server, SslSocket *socket, + bool async_read_ready = true); + + virtual Socket *socket() const; + + // Override to trigger handshake + virtual bool Connected(Endpoint remote); + + // Override to trigger handshake + virtual void Accepted(); + + +protected: + virtual ~SslSession(); + +private: + static void AcceptHandShakeHandler(TcpSessionPtr session, + const boost::system::error_code& error); + static void ConnectHandShakeHandler(TcpSessionPtr session, Endpoint remote, + const boost::system::error_code& error); + + void AsyncReadSome(boost::asio::mutable_buffer buffer); + std::size_t WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error); + void AsyncWrite(const u_int8_t *data, std::size_t size); + + boost::scoped_ptr ssl_socket_; + DISALLOW_COPY_AND_ASSIGN(SslSession); +}; + +#endif // __src_io_ssl_session_h__ diff --git a/src/io/tcp_message_write.cc b/src/io/tcp_message_write.cc index 1c267e8ff9a..b28bdb00ed2 100644 --- a/src/io/tcp_message_write.cc +++ b/src/io/tcp_message_write.cc @@ -13,8 +13,8 @@ using namespace boost::asio; using namespace boost::system; using tbb::mutex; -TcpMessageWriter::TcpMessageWriter(Socket *socket, TcpSession *session) : - socket_(socket), offset_(0), session_(session) { +TcpMessageWriter::TcpMessageWriter(TcpSession *session) : + offset_(0), session_(session) { } TcpMessageWriter::~TcpMessageWriter() { @@ -36,7 +36,7 @@ int TcpMessageWriter::Send(const uint8_t *data, size_t len, error_code &ec) { session_->server_->stats_.write_bytes += len; if (buffer_queue_.empty()) { - wrote = socket_->write_some(boost::asio::buffer(data, len), ec); + wrote = session_->WriteSome(data, len, ec); if (TcpSession::IsSocketErrorHard(ec)) return -1; assert(wrote >= 0); @@ -45,7 +45,7 @@ int TcpMessageWriter::Send(const uint8_t *data, size_t len, error_code &ec) { "Encountered partial send of " << wrote << " bytes when " "sending " << len << " bytes, Error: " << ec); BufferAppend(data + wrote, len - wrote); - DeferWrite(); + session_->DeferWriter(); } } else { TCP_SESSION_LOG_UT_DEBUG(session_, TCP_DIR_OUT, @@ -55,55 +55,20 @@ int TcpMessageWriter::Send(const uint8_t *data, size_t len, error_code &ec) { return wrote; } -void TcpMessageWriter::DeferWrite() { - - // Update socket write block count. - session_->stats_.write_blocked++; - session_->server_->stats_.write_blocked++; - socket_->async_write_some( - boost::asio::null_buffers(), - boost::bind(&TcpMessageWriter::HandleWriteReady, this, - TcpSessionPtr(session_), - placeholders::error, UTCTimestampUsec())); - return; -} - -// Socket is ready for write. Flush any pending data and notify -// clients aboout it. -void TcpMessageWriter::HandleWriteReady(TcpSessionPtr session_ptr, - const error_code &error, - uint64_t block_start_time) { - mutex::scoped_lock lock(session_->mutex()); - - // Update socket write block time. - uint64_t blocked_usecs = UTCTimestampUsec() - block_start_time; - session_->stats_.write_blocked_duration_usecs += blocked_usecs; - session_->server_->stats_.write_blocked_duration_usecs += blocked_usecs; - - if (TcpSession::IsSocketErrorHard(error)) { - goto done; - } - - // - // Ignore if connection is already closed. - // - if (session_->IsClosedLocked()) return; - +// Socket is ready for write. Flush any pending data +void TcpMessageWriter::HandleWriteReady(error_code &error) { while (!buffer_queue_.empty()) { boost::asio::mutable_buffer head = buffer_queue_.front(); const uint8_t *data = buffer_cast(head) + offset_; int remaining = buffer_size(head) - offset_; - error_code ec; - int wrote = socket_->write_some(buffer(data, remaining), ec); - if (TcpSession::IsSocketErrorHard(ec)) { - lock.release(); - if (!cb_.empty()) cb_(ec); + int wrote = session_->WriteSome(data, remaining, error); + if (TcpSession::IsSocketErrorHard(error)) { return; } assert(wrote >= 0); if (wrote != remaining) { offset_ += wrote; - DeferWrite(); + session_->DeferWriter(); return; } else { offset_ = 0; @@ -112,16 +77,6 @@ void TcpMessageWriter::HandleWriteReady(TcpSessionPtr session_ptr, } } buffer_queue_.clear(); - -done: - lock.release(); - // The session object is implicitly accessed in by cb_. This is - // safe because this function currently holds a refcount on the session - // via TcpSessionPtr. - if (!cb_.empty()) { - cb_(error); - } - return; } void TcpMessageWriter::BufferAppend(const uint8_t *src, int bytes) { @@ -137,6 +92,3 @@ void TcpMessageWriter::DeleteBuffer(mutable_buffer buffer) { return; } -void TcpMessageWriter::RegisterNotification(SendReadyCb cb) { - cb_ = cb; -} diff --git a/src/io/tcp_message_write.h b/src/io/tcp_message_write.h index 385cbf101db..914d26d1678 100644 --- a/src/io/tcp_message_write.h +++ b/src/io/tcp_message_write.h @@ -20,29 +20,22 @@ class TcpSession; class TcpMessageWriter { public: - typedef boost::asio::ip::tcp::socket Socket; static const int kDefaultBufferSize = 4 * 1024; - explicit TcpMessageWriter(Socket *, TcpSession *session); + explicit TcpMessageWriter(TcpSession *session); ~TcpMessageWriter(); // return false for send int Send(const uint8_t *msg, size_t len, error_code &ec); - typedef boost::function SendReadyCb; - void RegisterNotification(SendReadyCb); - private: + friend class TcpSession; typedef boost::intrusive_ptr TcpSessionPtr; typedef std::list BufferQueue; void BufferAppend(const uint8_t *data, int len); void DeleteBuffer(boost::asio::mutable_buffer buffer); - void DeferWrite(); - void HandleWriteReady(TcpSessionPtr session_ref, const error_code &ec, - uint64_t block_start_time); + void HandleWriteReady(boost::system::error_code &ec); BufferQueue buffer_queue_; - SendReadyCb cb_; - Socket *socket_; int offset_; TcpSession *session_; }; diff --git a/src/io/tcp_server.cc b/src/io/tcp_server.cc index f7f94d00310..dc2c42c9386 100644 --- a/src/io/tcp_server.cc +++ b/src/io/tcp_server.cc @@ -138,8 +138,7 @@ void TcpServer::ClearSessions() { } TcpSession *TcpServer::CreateSession() { - Socket *socket = new Socket(*evm_->io_service()); - TcpSession *session = AllocSession(socket); + TcpSession *session = AllocSession(false); { tbb::mutex::scoped_lock lock(mutex_); session_ref_.insert(TcpSessionPtr(session)); @@ -213,8 +212,8 @@ void TcpServer::AsyncAccept() { if (acceptor_ == NULL) { return; } - so_accept_.reset(new Socket(*evm_->io_service())); - acceptor_->async_accept(*so_accept_.get(), + set_accept_socket(); + acceptor_->async_accept(*accept_socket(), boost::bind(&TcpServer::AcceptHandlerInternal, this, TcpServerPtr(this), boost::asio::placeholders::error)); } @@ -240,7 +239,7 @@ bool TcpServer::HasSessions() const { bool TcpServer::HasSessionReadAvailable() const { tbb::mutex::scoped_lock lock(mutex_); boost::system::error_code error; - if (so_accept_->available(error) > 0) { + if (accept_socket()->available(error) > 0) { return true; } for (SessionMap::const_iterator iter = session_map_.begin(); @@ -266,6 +265,31 @@ TcpServer::Endpoint TcpServer::LocalEndpoint() const { return local; } +TcpSession *TcpServer::AllocSession(bool server_session) { + TcpSession *session; + if (server_session) { + session = AllocSession(so_accept_.get()); + + // if session allocate succeeds release ownership to so_accept. + if (session != NULL) { + so_accept_.release(); + } + } else { + Socket *socket = new Socket(*evm_->io_service()); + session = AllocSession(socket); + } + + return session; +} + +TcpServer::Socket *TcpServer::accept_socket() const { + return so_accept_.get(); +} + +void TcpServer::set_accept_socket() { + so_accept_.reset(new Socket(*evm_->io_service())); +} + bool TcpServer::AcceptSession(TcpSession *session) { return true; } @@ -281,15 +305,13 @@ void TcpServer::AcceptHandlerInternal(TcpServerPtr server, tcp::endpoint remote; boost::system::error_code ec; TcpSessionPtr session; - auto_ptr socket; bool need_close = false; if (error) { goto done; } - socket.reset(so_accept_.release()); - remote = socket->remote_endpoint(ec); + remote = accept_socket()->remote_endpoint(ec); if (ec) { TCP_SERVER_LOG_ERROR(this, TCP_DIR_IN, "Accept: No remote endpoint: " << ec.message()); @@ -301,16 +323,15 @@ void TcpServer::AcceptHandlerInternal(TcpServerPtr server, "Session accepted after server shutdown: " << remote.address().to_string() << ":" << remote.port()); - socket->close(ec); + accept_socket()->close(ec); goto done; } - session.reset(AllocSession(socket.get())); + session.reset(AllocSession(true)); if (session == NULL) { TCP_SERVER_LOG_DEBUG(this, TCP_DIR_IN, "Session not created"); goto done; } - socket.release(); ec = session->SetSocketOptions(); if (ec) { @@ -340,9 +361,7 @@ void TcpServer::AcceptHandlerInternal(TcpServerPtr server, } } - if (session->read_on_connect_) { - session->AsyncReadStart(); - } + session->Accepted(); done: if (need_close) { diff --git a/src/io/tcp_server.h b/src/io/tcp_server.h index 49261319364..534acef8c3c 100644 --- a/src/io/tcp_server.h +++ b/src/io/tcp_server.h @@ -83,6 +83,13 @@ class TcpServer { // Create a session object. virtual TcpSession *AllocSession(Socket *socket) = 0; + // Only SslServer overrides this method, to manage server with SSL + // socket instead of TCP socket + virtual TcpSession *AllocSession(bool server_session); + + virtual Socket *accept_socket() const; + virtual void set_accept_socket(); + // // Passively accepted a new session. Returns true if the session is // accepted, false otherwise. diff --git a/src/io/tcp_session.cc b/src/io/tcp_session.cc index d6988e1f4aa..d0d45d551db 100644 --- a/src/io/tcp_session.cc +++ b/src/io/tcp_session.cc @@ -54,10 +54,8 @@ TcpSession::TcpSession( established_(false), closed_(false), direction_(ACTIVE), - writer_(new TcpMessageWriter(socket, this)) { + writer_(new TcpMessageWriter(this)) { refcount_ = 0; - writer_->RegisterNotification( - boost::bind(&TcpSession::WriteReadyInternal, this, _1)); if (reader_task_id_ == -1) { TaskScheduler *scheduler = TaskScheduler::GetInstance(); reader_task_id_ = scheduler->GetTaskId("io::ReaderTask"); @@ -125,18 +123,43 @@ void TcpSession::AsyncReadStart() { ReleaseBufferLocked(buffer); return; } + AsyncReadSome(buffer); +} + +void TcpSession::DeferWriter() { + // Update socket write block count. + stats_.write_blocked++; + server_->stats_.write_blocked++; + socket()->async_write_some(boost::asio::null_buffers(), + boost::bind(&TcpSession::WriteReadyInternal, TcpSessionPtr(this), + placeholders::error, UTCTimestampUsec())); +} + +void TcpSession::AsyncReadSome(boost::asio::mutable_buffer buffer) { socket_->async_read_some(mutable_buffers_1(buffer), boost::bind(&TcpSession::AsyncReadHandler, TcpSessionPtr(this), buffer, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); } +std::size_t TcpSession::WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error) { + return socket_->write_some(boost::asio::buffer(data, len), error); +} + +void TcpSession::AsyncWrite(const u_int8_t *data, std::size_t size) { + boost::asio::async_write( + *socket_.get(), buffer(data, size), + boost::bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this), + boost::asio::placeholders::error)); +} + TcpSession::Endpoint TcpSession::local_endpoint() const { tbb::mutex::scoped_lock lock(mutex_); if (!established_) { return Endpoint(); } boost::system::error_code err; - Endpoint local = socket_->local_endpoint(err); + Endpoint local = socket()->local_endpoint(err); if (err) { return Endpoint(); } @@ -153,7 +176,7 @@ void TcpSession::SetName() { boost::system::error_code err; Endpoint local; - local = socket_->local_endpoint(err); + local = socket()->local_endpoint(err); out << local.address().to_string() << ":" << local.port() << "::"; out << remote_.address().to_string() << ":" << remote_.port(); @@ -168,6 +191,21 @@ void TcpSession::SessionEstablished(Endpoint remote, SetName(); } +void TcpSession::Accepted() { + TCP_SESSION_LOG_DEBUG(this, TCP_DIR_OUT, + "Passive session Accept complete"); + { + tbb::mutex::scoped_lock obs_lock(obs_mutex_); + if (observer_) { + observer_(this, ACCEPT); + } + } + + if (read_on_connect_) { + AsyncReadStart(); + } +} + bool TcpSession::Connected(Endpoint remote) { assert(refcount_); { @@ -205,9 +243,9 @@ void TcpSession::ConnectFailed() { void TcpSession::CloseInternal(bool call_observer, bool notify_server) { tbb::mutex::scoped_lock lock(mutex_); - if (socket_.get() != NULL && !closed_) { + if (socket() != NULL && !closed_) { boost::system::error_code err; - socket_->close(err); + socket()->close(err); } closed_ = true; if (!established_) { @@ -240,16 +278,42 @@ void TcpSession::Close() { void TcpSession::WriteReady(const boost::system::error_code &error) { } -void TcpSession::WriteReadyInternal(const boost::system::error_code &error) { - if (IsSocketErrorHard(error)) { - TCP_SESSION_LOG_INFO(this, TCP_DIR_OUT, "Write failed due to error: " - << error.value() - << " category: " << error.category().name() - << " message: " << error.message()); - CloseInternal(true); - return; +void TcpSession::WriteReadyInternal(TcpSessionPtr session, + const boost::system::error_code &error, + uint64_t block_start_time) { + boost::system::error_code ec = error; + tbb::mutex::scoped_lock lock(session->mutex_); + + // Update socket write block time. + uint64_t blocked_usecs = UTCTimestampUsec() - block_start_time; + session->stats_.write_blocked_duration_usecs += blocked_usecs; + session->server_->stats_.write_blocked_duration_usecs += blocked_usecs; + + if (session->IsSocketErrorHard(ec)) { + goto session_error; + } + + // + // Ignore if connection is already closed. + // + if (session->IsClosedLocked()) return; + + session->writer_->HandleWriteReady(ec); + if (session->IsSocketErrorHard(ec)) { + goto session_error; } - WriteReady(error); + + lock.release(); + session->WriteReady(ec); + return; + +session_error: + lock.release(); + TCP_SESSION_LOG_INFO(session.get(), TCP_DIR_OUT, + "Write failed due to error: " << ec.value() + << " category: " << ec.category().name() + << " message: " << ec.message()); + session->CloseInternal(true); } void TcpSession::AsyncWriteHandler(TcpSessionPtr session, @@ -274,7 +338,7 @@ bool TcpSession::Send(const u_int8_t *data, size_t size, size_t *sent) { // if (!established_) return false; - if (socket_->non_blocking()) { + if (socket()->non_blocking()) { boost::system::error_code error; int len = writer_->Send(data, size, error); lock.release(); @@ -288,10 +352,7 @@ bool TcpSession::Send(const u_int8_t *data, size_t size, size_t *sent) { if (len < 0 || (size_t)len != size) ret = false; if (sent) *sent = (len > 0) ? len : 0; } else { - boost::asio::async_write( - *socket_.get(), buffer(data, size), - boost::bind(&TcpSession::AsyncWriteHandler, TcpSessionPtr(this), - boost::asio::placeholders::error)); + AsyncWrite(data, size); if (sent) *sent = size; } return ret; @@ -337,11 +398,11 @@ int TcpSession::GetSessionInstance() const { int32_t TcpSession::local_port() const { - if (socket_.get() == NULL) { + if (socket() == NULL) { return -1; } boost::system::error_code error; - Endpoint local = socket_->local_endpoint(error); + Endpoint local = socket()->local_endpoint(error); if (IsSocketErrorHard(error)) { return -1; } @@ -349,11 +410,11 @@ int32_t TcpSession::local_port() const { } int32_t TcpSession::remote_port() const { - if (socket_.get() == NULL) { + if (socket() == NULL) { return -1; } boost::system::error_code error; - Endpoint remote = socket_->remote_endpoint(error); + Endpoint remote = socket()->remote_endpoint(error); if (IsSocketErrorHard(error)) { return -1; } @@ -546,7 +607,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti int keepalive_intvl, int keepalive_probes) { boost::system::error_code ec; socket_base::keep_alive keep_alive_option(true); - socket_->set_option(keep_alive_option, ec); + socket()->set_option(keep_alive_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keep_alive set error: " << ec); @@ -555,7 +616,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPIDLE typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPIDLE > keepalive_idle_time; keepalive_idle_time keepalive_idle_time_option(keepalive_time); - socket_->set_option(keepalive_idle_time_option, ec); + socket()->set_option(keepalive_idle_time_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_idle_time: " << keepalive_time << " set error: " << ec); @@ -565,7 +626,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPALIVE typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPALIVE > keepalive_idle_time; keepalive_idle_time keepalive_idle_time_option(keepalive_time); - socket_->set_option(keepalive_idle_time_option, ec); + socket()->set_option(keepalive_idle_time_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_idle_time: " << keepalive_time << " set error: " << ec); @@ -575,7 +636,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPINTVL typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPINTVL > keepalive_interval; keepalive_interval keepalive_interval_option(keepalive_intvl); - socket_->set_option(keepalive_interval_option, ec); + socket()->set_option(keepalive_interval_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_interval: " << keepalive_intvl << " set error: " << ec); @@ -585,7 +646,7 @@ boost::system::error_code TcpSession::SetSocketKeepaliveOptions(int keepalive_ti #ifdef TCP_KEEPCNT typedef boost::asio::detail::socket_option::integer< IPPROTO_TCP, TCP_KEEPCNT > keepalive_count; keepalive_count keepalive_count_option(keepalive_probes); - socket_->set_option(keepalive_count_option, ec); + socket()->set_option(keepalive_count_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "keepalive_probes: " << keepalive_probes << " set error: " << ec); @@ -601,7 +662,7 @@ boost::system::error_code TcpSession::SetSocketOptions() { // // Make socket write non-blocking // - socket_->non_blocking(true, ec); + socket()->non_blocking(true, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_NA, "Cannot set socket non blocking: " << ec); @@ -621,7 +682,7 @@ boost::system::error_code TcpSession::SetSocketOptions() { // sends more deterministically // socket_base::send_buffer_size send_buffer_size_option(sz); - socket_->set_option(send_buffer_size_option, ec); + socket()->set_option(send_buffer_size_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_OUT, "send_buffer_size set error: " << ec); @@ -629,7 +690,7 @@ boost::system::error_code TcpSession::SetSocketOptions() { } socket_base::receive_buffer_size receive_buffer_size_option(sz); - socket_->set_option(receive_buffer_size_option, ec); + socket()->set_option(receive_buffer_size_option, ec); if (ec) { TCP_SESSION_LOG_ERROR(this, TCP_DIR_IN, "receive_buffer_size set error: " << ec); diff --git a/src/io/tcp_session.h b/src/io/tcp_session.h index ceb0c87fd7d..9ae138ff840 100644 --- a/src/io/tcp_session.h +++ b/src/io/tcp_session.h @@ -62,6 +62,9 @@ class TcpSession { // Called by TcpServer to trigger async read. virtual bool Connected(Endpoint remote); + // Called by TcpServer to trigger async read. + virtual void Accepted(); + void ConnectFailed(); void Close(); @@ -71,7 +74,7 @@ class TcpSession { void SetBufferSize(int buffer_size); // Getters and setters - Socket *socket() { return socket_.get(); } + virtual Socket *socket() const { return socket_.get(); } TcpServer *server() { return server_.get(); } int32_t local_port() const; int32_t remote_port() const; @@ -132,6 +135,13 @@ class TcpSession { void GetTxSocketStats(SocketIOStats &socket_stats) const; protected: + typedef boost::intrusive_ptr TcpSessionPtr; + static void AsyncReadHandler(TcpSessionPtr session, + boost::asio::mutable_buffer buffer, + const boost::system::error_code &error, + size_t size); + static void AsyncWriteHandler(TcpSessionPtr session, + const boost::system::error_code &error); virtual ~TcpSession(); // Read handler. Called from a TBB task. @@ -139,6 +149,11 @@ class TcpSession { // Callback after socket is ready for write. virtual void WriteReady(const boost::system::error_code &error); + virtual void AsyncReadSome(boost::asio::mutable_buffer buffer); + virtual std::size_t WriteSome(const uint8_t *data, std::size_t len, + boost::system::error_code &error); + virtual void AsyncWrite(const u_int8_t *data, std::size_t size); + virtual int reader_task_id() const { return reader_task_id_; } @@ -147,28 +162,24 @@ class TcpSession { boost::system::error_code SetSocketKeepaliveOptions(int keepalive_time, int keepalive_intvl, int keepalive_probes); + void CloseInternal(bool call_observer, bool notify_server = true); + private: class Reader; friend class TcpServer; friend class TcpMessageWriter; friend void intrusive_ptr_add_ref(TcpSession *session); friend void intrusive_ptr_release(TcpSession *session); - typedef boost::intrusive_ptr TcpSessionPtr; typedef std::list BufferQueue; - static void AsyncReadHandler(TcpSessionPtr session, - boost::asio::mutable_buffer buffer, - const boost::system::error_code &error, - size_t size); - static void AsyncWriteHandler(TcpSessionPtr session, - const boost::system::error_code &error); + static void WriteReadyInternal(TcpSessionPtr session, + const boost::system::error_code &error, + uint64_t block_start_time); + void DeferWriter(); void ReleaseBufferLocked(Buffer buffer); - void CloseInternal(bool call_observer, bool notify_server = true); void SetEstablished(Endpoint remote, Direction dir); - tbb::mutex &mutex() { return mutex_; } - bool IsClosedLocked() const { return closed_; } @@ -176,7 +187,6 @@ class TcpSession { boost::asio::mutable_buffer AllocateBuffer(); void DeleteBuffer(boost::asio::mutable_buffer buffer); - void WriteReadyInternal(const boost::system::error_code &); static int reader_task_id_; diff --git a/src/io/test/SConscript b/src/io/test/SConscript index 0d68181114a..e8a4ba3540f 100644 --- a/src/io/test/SConscript +++ b/src/io/test/SConscript @@ -19,7 +19,7 @@ env.Append(LIBPATH = env['TOP'] + '/base/test') env.Prepend(LIBS = ['gunit', 'task_test', 'io', 'sandesh', 'http', 'sandeshvns', 'process_info', 'io', 'base', 'http_parser', 'curl', - 'boost_program_options', 'pugixml']) + 'boost_program_options', 'pugixml', 'ssl', 'crypto']) if sys.platform != 'darwin': env.Append(LIBS = ['rt']) @@ -36,6 +36,12 @@ tcp_server_test = env.UnitTest('tcp_server_test', env.Alias('src/io:tcp_server_test', tcp_server_test) +ssl_server_test = env.UnitTest('ssl_server_test', + ['ssl_server_test.cc'], + ) + +env.Alias('src/io:ssl_server_test', ssl_server_test) + tcp_io_test = env.UnitTest('tcp_io_test', ['tcp_io_test.cc'], ) @@ -64,6 +70,7 @@ test_suite = [ ] flaky_test_suite = [ + ssl_server_test, tcp_io_test, tcp_server_test, tcp_stress_test, diff --git a/src/io/test/ssl_server_test.cc b/src/io/test/ssl_server_test.cc new file mode 100644 index 00000000000..cb7febb53ac --- /dev/null +++ b/src/io/test/ssl_server_test.cc @@ -0,0 +1,311 @@ +/* + * Copyright (c) 2015 Juniper Networks, Inc. All rights reserved. + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include "testing/gunit.h" + +#include "base/logging.h" +#include "base/parse_object.h" +#include "base/test/task_test_util.h" +#include "io/event_manager.h" +#include "io/ssl_server.h" +#include "io/ssl_session.h" +#include "io/test/event_manager_test.h" +#include "io/io_log.h" + +using namespace std; + +namespace { + +class EchoServer; + +class EchoSession : public SslSession { + public: + EchoSession(EchoServer *server, SslSocket *socket); + + protected: + virtual void OnRead(Buffer buffer) { + const u_int8_t *data = BufferData(buffer); + const size_t len = BufferSize(buffer); + TCP_UT_LOG_DEBUG("Received " << BufferData(buffer) << " " << len << " bytes"); + Send(data, len, NULL); + } + private: + void OnEvent(TcpSession *session, Event event) { + if (event == ACCEPT) { + TCP_UT_LOG_DEBUG("Accept"); + } + if (event == CLOSE) { + TCP_UT_LOG_DEBUG("Close"); + } + } +}; + +class EchoServer : public SslServer { +public: + explicit EchoServer(EventManager *evm) : + SslServer(evm, boost::asio::ssl::context::tlsv1_server), session_(NULL) { + boost::asio::ssl::context *ctx = context(); + boost::system::error_code ec; + ctx->set_verify_mode(boost::asio::ssl::context::verify_none, ec); + assert(ec.value() == 0); + ctx->use_certificate_chain_file + ("controller/src/ifmap/client/test/newcert.pem", ec); + assert(ec.value() == 0); + ctx->use_private_key_file("controller/src/ifmap/client/test/server.pem", + boost::asio::ssl::context::pem, ec); + assert(ec.value() == 0); + ctx->add_verify_path("controller/src/ifmap/client/test/", ec); + assert(ec.value() == 0); + ctx->load_verify_file("controller/src/ifmap/client/test/newcert.pem", + ec); + assert(ec.value() == 0); + } + ~EchoServer() { + } + virtual SslSession *AllocSession(SslSocket *socket) { + session_ = new EchoSession(this, socket); + return session_; + } + + TcpSession *CreateSession() { + TcpSession *session = SslServer::CreateSession(); + Socket *socket = session->socket(); + + boost::system::error_code err; + socket->open(boost::asio::ip::tcp::v4(), err); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "open failed: " << err.message()); + } + err = session->SetSocketOptions(); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "sockopt: " << err.message()); + } + return session; + } + + EchoSession *GetSession() const { return session_; } + +private: + EchoSession *session_; +}; + +EchoSession::EchoSession(EchoServer *server, SslSocket *socket) + : SslSession(server, socket) { + set_observer(boost::bind(&EchoSession::OnEvent, this, _1, _2)); +} + +class SslClient; + +class ClientSession : public SslSession { + public: + ClientSession(SslClient *server, SslSocket *socket); + + void OnEvent(TcpSession *session, Event event) { + if (event == CONNECT_COMPLETE) { + const u_int8_t *data = (const u_int8_t *)"Hello there !"; + size_t len = 14; + Send(data, len, NULL); + } + } + + std::size_t &len() { return len_; } + + protected: + virtual void OnRead(Buffer buffer) { + const u_int8_t *data = BufferData(buffer); + const size_t len = BufferSize(buffer); + TCP_UT_LOG_DEBUG("Received " << BufferData(buffer) << " " << len << " bytes"); + len_ += len; + } + + private: + std::size_t len_; +}; + +class SslClient : public SslServer { +public: + explicit SslClient(EventManager *evm) : + SslServer(evm, boost::asio::ssl::context::tlsv1), session_(NULL) { + boost::asio::ssl::context *ctx = context(); + boost::system::error_code ec; + ctx->set_verify_mode(boost::asio::ssl::context::verify_none, ec); + assert(ec.value() == 0); + ctx->use_certificate_chain_file + ("controller/src/ifmap/client/test/newcert.pem", ec); + assert(ec.value() == 0); + ctx->use_private_key_file("controller/src/ifmap/client/test/server.pem", + boost::asio::ssl::context::pem, ec); + assert(ec.value() == 0); + ctx->add_verify_path("controller/src/ifmap/client/test/", ec); + assert(ec.value() == 0); + ctx->load_verify_file("controller/src/ifmap/client/test/newcert.pem", + ec); + assert(ec.value() == 0); + } + ~SslClient() { + } + virtual SslSession *AllocSession(SslSocket *socket) { + session_ = new ClientSession(this, socket); + return session_; + } + + TcpSession *CreateSession() { + TcpSession *session = SslServer::CreateSession(); + Socket *socket = session->socket(); + + boost::system::error_code err; + socket->open(boost::asio::ip::tcp::v4(), err); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "open failed: " << err.message()); + } + err = session->SetSocketOptions(); + if (err) { + TCP_SESSION_LOG_ERROR(session, TCP_DIR_OUT, + "sockopt: " << err.message()); + } + return session; + } + + ClientSession *GetSession() const { return session_; } + +private: + ClientSession *session_; +}; + +ClientSession::ClientSession(SslClient *server, SslSocket *socket) + : SslSession(server, socket) { + set_observer(boost::bind(&ClientSession::OnEvent, this, _1, _2)); +} + +class SslEchoServerTest : public ::testing::Test { +public: + void OnEvent(TcpSession *session, SslSession::Event event) { + boost::system::error_code ec; + timer_.cancel(ec); + ClientSession *client_session = static_cast(session); + client_session->OnEvent(session, event); + if (event == SslSession::CONNECT_FAILED) { + connect_fail_++; + session->Close(); + } + if (event == SslSession::CONNECT_COMPLETE) { + connect_success_++; + } + } + +protected: + SslEchoServerTest() + : evm_(new EventManager()), timer_(*evm_->io_service()), + connect_success_(0), connect_fail_(0), connect_abort_(0) { + } + virtual void SetUp() { + server_ = new EchoServer(evm_.get()); + thread_.reset(new ServerThread(evm_.get())); + session_ = NULL; + } + + virtual void TearDown() { + if (server_->GetSession()) { + server_->GetSession()->Close(); + } + if (session_) session_->Close(); + task_util::WaitForIdle(); + server_->Shutdown(); + server_->ClearSessions(); + task_util::WaitForIdle(); + TcpServerManager::DeleteServer(server_); + server_ = NULL; + evm_->Shutdown(); + if (thread_.get() != NULL) { + thread_->Join(); + } + task_util::WaitForIdle(); + } + + + void DummyTimerHandler(TcpSession *session, + const boost::system::error_code &error) { + if (error) { + return; + } + if (!session->IsClosed()) { + connect_abort_++; + } + session->Close(); + } + + void StartConnectTimer(TcpSession *session, int sec) { + boost::system::error_code ec; + timer_.expires_from_now(boost::posix_time::seconds(sec), ec); + timer_.async_wait( + boost::bind(&SslEchoServerTest::DummyTimerHandler, this, session, + boost::asio::placeholders::error)); + } + auto_ptr thread_; + auto_ptr evm_; + EchoServer *server_; + boost::asio::deadline_timer timer_; + EchoSession *session_; + int connect_success_; + int connect_fail_; + int connect_abort_; +}; + +TEST_F(SslEchoServerTest, msg_send_recv) { + SslClient *client = new SslClient(evm_.get()); + + task_util::WaitForIdle(); + server_->Initialize(0); + task_util::WaitForIdle(); + thread_->Start(); // Must be called after initialization + + connect_success_ = connect_fail_ = connect_abort_ = 0; + ClientSession *session = static_cast(client->CreateSession()); + session->set_observer(boost::bind(&SslEchoServerTest::OnEvent, this, _1, _2)); + boost::asio::ip::tcp::endpoint endpoint; + boost::system::error_code ec; + endpoint.address(boost::asio::ip::address::from_string("127.0.0.1", ec)); + endpoint.port(server_->GetPort()); + client->Connect(session, endpoint); + task_util::WaitForIdle(); + StartConnectTimer(session, 10); + TASK_UTIL_EXPECT_TRUE(session->IsEstablished()); + TASK_UTIL_EXPECT_FALSE(session->IsClosed()); + TASK_UTIL_EXPECT_EQ(1, connect_success_); + TASK_UTIL_EXPECT_EQ(connect_fail_, 0); + TASK_UTIL_EXPECT_EQ(connect_abort_, 0); + // wait for on connect message to come back from echo server. + TASK_UTIL_EXPECT_EQ(session->len(), 14); + + session->Close(); + client->DeleteSession(session); + + client->Shutdown(); + task_util::WaitForIdle(); + TcpServerManager::DeleteServer(client); + client = NULL; +} + +} // namespace + +int main(int argc, char **argv) { + LoggingInit(); + Sandesh::SetLoggingParams(true, "", SandeshLevel::UT_DEBUG); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}