
HDFS-9228. libhdfs++ should respect NN retry configuration settings. Contributed by Bob Hansen

James, 9 years ago
commit 325ad8a0c1
16 changed files with 1169 additions and 260 deletions
  1. + 1 - 1
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/include/libhdfspp/options.h
  2. + 1 - 1
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/CMakeLists.txt
  3. + 1 - 1
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/options.cc
  4. + 47 - 0
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/retry_policy.cc
  5. + 91 - 0
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/retry_policy.h
  6. + 14 - 0
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/util.h
  7. + 0 - 3
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/fs/filesystem.cc
  8. + 158 - 75
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_connection.cc
  9. + 155 - 64
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_connection.h
  10. + 124 - 23
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_engine.cc
  11. + 186 - 65
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_engine.h
  12. + 4 - 0
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/CMakeLists.txt
  13. + 11 - 0
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/mock_connection.cc
  14. + 67 - 1
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/mock_connection.h
  15. + 63 - 0
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/retry_policy_test.cc
  16. + 246 - 26
      hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/rpc_engine_test.cc

+ 1 - 1
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/include/libhdfspp/options.h

@@ -33,7 +33,7 @@ struct Options {
   /**
    * Maximum number of retries for RPC operations
    **/
-  const static int NO_RPC_RETRY = -1;
+  const static int kNoRetry = -1;
   int max_rpc_retries;
 
   /**

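As a sketch of how a client opts in to the new behavior (the field names are the ones this patch touches; the values are illustrative):

  hdfs::Options options;
  options.max_rpc_retries = 3;        // retry each failed RPC up to 3 times
  options.rpc_retry_delay_ms = 1000;  // wait 1s between attempts
  // The default remains Options::kNoRetry (-1), which disables the retry policy.
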
+ 1 - 1
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/CMakeLists.txt

@@ -15,4 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-add_library(common base64.cc status.cc sasl_digest_md5.cc hdfs_public_api.cc options.cc configuration.cc util.cc)
+add_library(common base64.cc status.cc sasl_digest_md5.cc hdfs_public_api.cc options.cc configuration.cc util.cc retry_policy.cc)

+ 1 - 1
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/options.cc

@@ -20,6 +20,6 @@
 
 namespace hdfs {
 
-Options::Options() : rpc_timeout(30000), max_rpc_retries(0),
+Options::Options() : rpc_timeout(30000), max_rpc_retries(kNoRetry),
                      rpc_retry_delay_ms(10000), host_exclusion_duration(600000) {}
 }

+ 47 - 0
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/retry_policy.cc

@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/retry_policy.h"
+
+namespace hdfs {
+
+RetryAction FixedDelayRetryPolicy::ShouldRetry(
+    const Status &s, uint64_t retries, uint64_t failovers,
+    bool isIdempotentOrAtMostOnce) const {
+  (void)s;
+  (void)isIdempotentOrAtMostOnce;
+  if (retries + failovers >= max_retries_) {
+  return RetryAction::fail(
+        "Retries and failovers (" + std::to_string(retries + failovers) +
+        ") exceeded maximum retries (" + std::to_string(max_retries_) + ")");
+  } else {
+    return RetryAction::retry(delay_);
+  }
+}
+
+RetryAction NoRetryPolicy::ShouldRetry(
+    const Status &s, uint64_t retries, uint64_t failovers,
+    bool isIdempotentOrAtMostOnce) const {
+  (void)s;
+  (void)retries;
+  (void)failovers;
+  (void)isIdempotentOrAtMostOnce;
+  return RetryAction::fail("No retry");
+}
+
+}

+ 91 - 0
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/retry_policy.h

@@ -0,0 +1,91 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LIB_COMMON_RETRY_POLICY_H_
+#define LIB_COMMON_RETRY_POLICY_H_
+
+#include "common/util.h"
+
+#include <string>
+#include <stdint.h>
+
+namespace hdfs {
+
+class RetryAction {
+ public:
+  enum RetryDecision { FAIL, RETRY, FAILOVER_AND_RETRY };
+
+  RetryDecision action;
+  uint64_t delayMillis;
+  std::string reason;
+
+  RetryAction(RetryDecision in_action, uint64_t in_delayMillis,
+              const std::string &in_reason)
+      : action(in_action), delayMillis(in_delayMillis), reason(in_reason) {}
+
+  static RetryAction fail(const std::string &reason) {
+    return RetryAction(FAIL, 0, reason);
+  }
+  static RetryAction retry(uint64_t delay) {
+    return RetryAction(RETRY, delay, "");
+  }
+  static RetryAction failover() {
+    return RetryAction(FAILOVER_AND_RETRY, 0, "");
+  }
+};
+
+class RetryPolicy {
+ public:
+  /*
+   * If there was an error in communications, responds with the configured
+   * action to take.
+   */
+  virtual RetryAction ShouldRetry(const Status &s, uint64_t retries,
+                                            uint64_t failovers,
+                                            bool isIdempotentOrAtMostOnce) const = 0;
+
+  virtual ~RetryPolicy() {}
+};
+
+/*
+ * Returns a fixed delay up to a certain number of retries
+ */
+class FixedDelayRetryPolicy : public RetryPolicy {
+ public:
+  FixedDelayRetryPolicy(uint64_t delay, uint64_t max_retries)
+      : delay_(delay), max_retries_(max_retries) {}
+
+  RetryAction ShouldRetry(const Status &s, uint64_t retries,
+                          uint64_t failovers,
+                          bool isIdempotentOrAtMostOnce) const override;
+ private:
+  uint64_t delay_;
+  uint64_t max_retries_;
+};
+
+/*
+ * Never retries
+ */
+class NoRetryPolicy : public RetryPolicy {
+ public:
+  RetryAction ShouldRetry(const Status &s, uint64_t retries,
+                          uint64_t failovers,
+                          bool isIdempotentOrAtMostOnce) const override;
+};
+}
+
+#endif
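
For orientation, a hedged sketch of driving the interface above (the sample Status and the control flow are illustrative, not part of the patch):

  FixedDelayRetryPolicy policy(10000 /* delay ms */, 3 /* max retries */);
  Status err = Status::Exception("java.net.ConnectException", "refused");
  RetryAction action =
      policy.ShouldRetry(err, /*retries=*/0, /*failovers=*/0,
                         /*isIdempotentOrAtMostOnce=*/true);
  if (action.action == RetryAction::RETRY) {
    // wait action.delayMillis, then resend the request
  } else {  // RetryAction::FAIL
    // give up and surface action.reason to the caller
  }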

+ 14 - 0
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/common/util.h

@@ -60,6 +60,20 @@ std::string Base64Encode(const std::string &src);
  * Returns a new high-entropy client name
  */
 std::string GetRandomClientName();
+
+/* Returns true if _someone_ is holding the lock (not necessarily this thread;
+ * a std::mutex doesn't track which thread is holding it)
+ */
+template<class T>
+bool lock_held(T & mutex) {
+  bool result = !mutex.try_lock();
+  if (!result)
+    mutex.unlock();
+  return result;
+}
+
+
+
 }
 
 #endif
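
Because a std::mutex cannot report its owner, lock_held() is only suitable for assertions, which is how the rest of this patch uses it. A minimal sketch:

  #include <cassert>
  #include <mutex>

  std::mutex m;
  {
    std::lock_guard<std::mutex> guard(m);
    assert(lock_held(m));   // passes: the guard owns the lock
  }
  assert(!lock_held(m));    // passes: nobody holds it anymore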

+ 0 - 3
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/fs/filesystem.cc

@@ -51,9 +51,6 @@ void NameNodeOperations::Connect(const std::string &server,
         engine_.Connect(m->state().front(), next);
       }));
   m->Run([this, handler](const Status &status, const State &) {
-    if (status.ok()) {
-      engine_.Start();
-    }
     handler(status);
   });
 }
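
(The explicit engine_.Start() is no longer needed here: with this patch, RpcConnectionImpl::ConnectAndFlush calls StartReading() itself once the socket connects, so the connection begins reading without any help from the filesystem layer.)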

+ 158 - 75
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_connection.cc

@@ -26,9 +26,6 @@
 
 #include <asio/read.hpp>
 
-#include <google/protobuf/io/coded_stream.h>
-#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
-
 namespace hdfs {
 
 namespace pb = ::google::protobuf;
@@ -37,17 +34,18 @@ namespace pbio = ::google::protobuf::io;
 using namespace ::hadoop::common;
 using namespace ::std::placeholders;
 
-static void
-ConstructPacket(std::string *res,
-                std::initializer_list<const pb::MessageLite *> headers,
-                const std::string *request) {
+static const int kNoRetry = -1;
+
+static void AddHeadersToPacket(
+    std::string *res, std::initializer_list<const pb::MessageLite *> headers,
+    const std::string *payload) {
   int len = 0;
   std::for_each(
       headers.begin(), headers.end(),
       [&len](const pb::MessageLite *v) { len += DelimitedPBMessageSize(v); });
-  if (request) {
-    len += pbio::CodedOutputStream::VarintSize32(request->size()) +
-           request->size();
+
+  if (payload) {
+    len += payload->size();
   }
 
   int net_len = htonl(len);
@@ -58,6 +56,7 @@ ConstructPacket(std::string *res,
   os.WriteRaw(reinterpret_cast<const char *>(&net_len), sizeof(net_len));
 
   uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
+  assert(buf);
 
   std::for_each(
       headers.begin(), headers.end(), [&buf](const pb::MessageLite *v) {
@@ -65,19 +64,43 @@ ConstructPacket(std::string *res,
         buf = v->SerializeWithCachedSizesToArray(buf);
       });
 
-  if (request) {
-    buf = pbio::CodedOutputStream::WriteVarint32ToArray(request->size(), buf);
-    buf = os.WriteStringToArray(*request, buf);
+  if (payload) {
+    buf = os.WriteStringToArray(*payload, buf);
   }
 }
 
-static void SetRequestHeader(RpcEngine *engine, int call_id,
-                             const std::string &method_name,
+static void ConstructPayload(std::string *res, const pb::MessageLite *header) {
+  int len = DelimitedPBMessageSize(header);
+  res->reserve(len);
+  pbio::StringOutputStream ss(res);
+  pbio::CodedOutputStream os(&ss);
+  uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
+  assert(buf);
+  buf = pbio::CodedOutputStream::WriteVarint32ToArray(header->ByteSize(), buf);
+  buf = header->SerializeWithCachedSizesToArray(buf);
+}
+
+static void ConstructPayload(std::string *res, const std::string *request) {
+  int len =
+      pbio::CodedOutputStream::VarintSize32(request->size()) + request->size();
+  res->reserve(len);
+  pbio::StringOutputStream ss(res);
+  pbio::CodedOutputStream os(&ss);
+  uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
+  assert(buf);
+  buf = pbio::CodedOutputStream::WriteVarint32ToArray(request->size(), buf);
+  buf = os.WriteStringToArray(*request, buf);
+}
+
+static void SetRequestHeader(LockFreeRpcEngine *engine, int call_id,
+                             const std::string &method_name, int retry_count,
                              RpcRequestHeaderProto *rpc_header,
                              RequestHeaderProto *req_header) {
   rpc_header->set_rpckind(RPC_PROTOCOL_BUFFER);
   rpc_header->set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
   rpc_header->set_callid(call_id);
+  if (retry_count != kNoRetry)
+    rpc_header->set_retrycount(retry_count);
   rpc_header->set_clientid(engine->client_name());
 
   req_header->set_methodname(method_name);
@@ -87,64 +110,84 @@ static void SetRequestHeader(RpcEngine *engine, int call_id,
 
 RpcConnection::~RpcConnection() {}
 
-RpcConnection::Request::Request(RpcConnection *parent,
-                                const std::string &method_name,
-                                const std::string &request, Handler &&handler)
-    : call_id_(parent->engine_->NextCallId()), timer_(parent->io_service()),
-      handler_(std::move(handler)) {
-  RpcRequestHeaderProto rpc_header;
-  RequestHeaderProto req_header;
-  SetRequestHeader(parent->engine_, call_id_, method_name, &rpc_header,
-                   &req_header);
-  ConstructPacket(&payload_, {&rpc_header, &req_header}, &request);
+Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
+                 const std::string &request, Handler &&handler)
+    : engine_(engine),
+      method_name_(method_name),
+      call_id_(engine->NextCallId()),
+      timer_(engine->io_service()),
+      handler_(std::move(handler)),
+      retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
+  ConstructPayload(&payload_, &request);
+}
+
+Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
+                 const pb::MessageLite *request, Handler &&handler)
+    : engine_(engine),
+      method_name_(method_name),
+      call_id_(engine->NextCallId()),
+      timer_(engine->io_service()),
+      handler_(std::move(handler)),
+      retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
+  ConstructPayload(&payload_, request);
 }
 
-RpcConnection::Request::Request(RpcConnection *parent,
-                                const std::string &method_name,
-                                const pb::MessageLite *request,
-                                Handler &&handler)
-    : call_id_(parent->engine_->NextCallId()), timer_(parent->io_service()),
-      handler_(std::move(handler)) {
+Request::Request(LockFreeRpcEngine *engine, Handler &&handler)
+    : engine_(engine),
+      call_id_(-1),
+      timer_(engine->io_service()),
+      handler_(std::move(handler)),
+      retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
+}
+
+void Request::GetPacket(std::string *res) const {
+  if (payload_.empty())
+    return;
+
   RpcRequestHeaderProto rpc_header;
   RequestHeaderProto req_header;
-  SetRequestHeader(parent->engine_, call_id_, method_name, &rpc_header,
+  SetRequestHeader(engine_, call_id_, method_name_, retry_count_, &rpc_header,
                    &req_header);
-  ConstructPacket(&payload_, {&rpc_header, &req_header, request}, nullptr);
+  AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
 }
 
-void RpcConnection::Request::OnResponseArrived(pbio::CodedInputStream *is,
-                                               const Status &status) {
+void Request::OnResponseArrived(pbio::CodedInputStream *is,
+                                const Status &status) {
   handler_(is, status);
 }
 
-RpcConnection::RpcConnection(RpcEngine *engine)
-    : engine_(engine), resp_state_(kReadLength), resp_length_(0) {}
+RpcConnection::RpcConnection(LockFreeRpcEngine *engine)
+    : engine_(engine),
+      connected_(false) {}
 
 ::asio::io_service &RpcConnection::io_service() {
   return engine_->io_service();
 }
 
-void RpcConnection::Start() {
+void RpcConnection::StartReading() {
   io_service().post(std::bind(&RpcConnection::OnRecvCompleted, this,
                               ::asio::error_code(), 0));
 }
 
-void RpcConnection::FlushPendingRequests() {
-  io_service().post([this]() {
+void RpcConnection::AsyncFlushPendingRequests() {
+  std::shared_ptr<RpcConnection> shared_this = shared_from_this();
+  io_service().post([shared_this, this]() {
+    std::lock_guard<std::mutex> state_lock(connection_state_lock_);
+
     if (!request_over_the_wire_) {
-      OnSendCompleted(::asio::error_code(), 0);
+      FlushPendingRequests();
     }
   });
 }
 
-void RpcConnection::HandleRpcResponse(const std::vector<char> &data) {
-  /* assumed to be called from a context that has already acquired the
-   * engine_state_lock */
-  pbio::ArrayInputStream ar(&data[0], data.size());
-  pbio::CodedInputStream in(&ar);
-  in.PushLimit(data.size());
+void RpcConnection::HandleRpcResponse(std::shared_ptr<Response> response) {
+  assert(lock_held(connection_state_lock_));  // Must be holding lock before calling
+
+  response->ar.reset(new pbio::ArrayInputStream(&response->data_[0], response->data_.size()));
+  response->in.reset(new pbio::CodedInputStream(response->ar.get()));
+  response->in->PushLimit(response->data_.size());
   RpcResponseHeaderProto h;
-  ReadDelimitedPBMessage(&in, &h);
+  ReadDelimitedPBMessage(response->in.get(), &h);
 
   auto req = RemoveFromRunningQueue(h.callid());
   if (!req) {
@@ -152,12 +195,15 @@ void RpcConnection::HandleRpcResponse(const std::vector<char> &data) {
     return;
   }
 
-  Status stat;
+  Status status;
   if (h.has_exceptionclassname()) {
-    stat =
+    status =
         Status::Exception(h.exceptionclassname().c_str(), h.errormsg().c_str());
   }
-  req->OnResponseArrived(&in, stat);
+
+  io_service().post([req, response, status]() {
+    req->OnResponseArrived(response->in.get(), status);  // Never call back while holding a lock
+  });
 }
 
 void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
@@ -166,7 +212,7 @@ void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
     return;
   }
 
-  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+  std::lock_guard<std::mutex> state_lock(connection_state_lock_);
   auto r = RemoveFromRunningQueue(req->call_id());
   if (!r) {
     // The RPC might have been finished and removed from the queue
@@ -179,6 +225,8 @@ void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
 }
 
 std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
+  assert(lock_held(connection_state_lock_));  // Must be holding lock before calling
+
   static const char kHandshakeHeader[] = {'h', 'r', 'p', 'c',
                                           RpcEngine::kRpcVersion, 0, 0};
   auto res =
@@ -192,25 +240,27 @@ std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
 
   IpcConnectionContextProto handshake;
   handshake.set_protocol(engine_->protocol_name());
-  ConstructPacket(res.get(), {&h, &handshake}, nullptr);
+  AddHeadersToPacket(res.get(), {&h, &handshake}, nullptr);
   return res;
 }
 
 void RpcConnection::AsyncRpc(
     const std::string &method_name, const ::google::protobuf::MessageLite *req,
     std::shared_ptr<::google::protobuf::MessageLite> resp,
-    const Callback &handler) {
-  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+    const RpcCallback &handler) {
+  std::lock_guard<std::mutex> state_lock(connection_state_lock_);
 
   auto wrapped_handler =
       [resp, handler](pbio::CodedInputStream *is, const Status &status) {
         if (status.ok()) {
-          ReadDelimitedPBMessage(is, resp.get());
+          if (is) {  // Connect messages will not have an input stream
+            ReadDelimitedPBMessage(is, resp.get());
+          }
         }
         handler(status);
       };
 
-  auto r = std::make_shared<Request>(this, method_name, req,
+  auto r = std::make_shared<Request>(engine_, method_name, req,
                                      std::move(wrapped_handler));
   pending_requests_.push_back(r);
   FlushPendingRequests();
@@ -219,29 +269,62 @@ void RpcConnection::AsyncRpc(
 void RpcConnection::AsyncRawRpc(const std::string &method_name,
                                 const std::string &req,
                                 std::shared_ptr<std::string> resp,
-                                Callback &&handler) {
-  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
-
-  auto wrapped_handler =
-      [this, resp, handler](pbio::CodedInputStream *is, const Status &status) {
-        if (status.ok()) {
-          uint32_t size = 0;
-          is->ReadVarint32(&size);
-          auto limit = is->PushLimit(size);
-          is->ReadString(resp.get(), limit);
-          is->PopLimit(limit);
-        }
-        handler(status);
-      };
+                                RpcCallback &&handler) {
+  std::lock_guard<std::mutex> state_lock(connection_state_lock_);
+
+  std::shared_ptr<RpcConnection> shared_this = shared_from_this();
+  auto wrapped_handler = [shared_this, this, resp, handler](
+      pbio::CodedInputStream *is, const Status &status) {
+    if (status.ok()) {
+      uint32_t size = 0;
+      is->ReadVarint32(&size);
+      auto limit = is->PushLimit(size);
+      is->ReadString(resp.get(), limit);
+      is->PopLimit(limit);
+    }
+    handler(status);
+  };
 
-  auto r = std::make_shared<Request>(this, method_name, req,
+  auto r = std::make_shared<Request>(engine_, method_name, req,
                                      std::move(wrapped_handler));
   pending_requests_.push_back(r);
   FlushPendingRequests();
 }
 
+void RpcConnection::PreEnqueueRequests(
+    std::vector<std::shared_ptr<Request>> requests) {
+  // Public method - acquire lock
+  std::lock_guard<std::mutex> state_lock(connection_state_lock_);
+  assert(!connected_);
+
+  pending_requests_.insert(pending_requests_.end(), requests.begin(),
+                           requests.end());
+  // Don't start sending yet; will flush when connected
+}
+
+void RpcConnection::CommsError(const Status &status) {
+  assert(lock_held(connection_state_lock_));  // Must be holding lock before calling
+
+  Disconnect();
+
+  // Anything that has been queued to the connection (on the fly or pending)
+  //    will get dinged for a retry
+  std::vector<std::shared_ptr<Request>> requestsToReturn;
+  std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
+                 std::back_inserter(requestsToReturn),
+                 std::bind(&RequestOnFlyMap::value_type::second, _1));
+  requests_on_fly_.clear();
+
+  requestsToReturn.insert(requestsToReturn.end(),
+                          std::make_move_iterator(pending_requests_.begin()),
+                          std::make_move_iterator(pending_requests_.end()));
+  pending_requests_.clear();
+
+  engine_->AsyncRpcCommsError(status, requestsToReturn);
+}
+
 void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
-  Shutdown();
+  Disconnect();
   std::vector<std::shared_ptr<Request>> requests;
   std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
                  std::back_inserter(requests),
@@ -256,8 +339,8 @@ void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
   }
 }
 
-std::shared_ptr<RpcConnection::Request>
-RpcConnection::RemoveFromRunningQueue(int call_id) {
+std::shared_ptr<Request> RpcConnection::RemoveFromRunningQueue(int call_id) {
+  assert(lock_held(connection_state_lock_));  // Must be holding lock before calling
   auto it = requests_on_fly_.find(call_id);
   if (it == requests_on_fly_.end()) {
     return std::shared_ptr<Request>();

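For reference, the packet layout that ConstructPayload and AddHeadersToPacket now produce between them, reconstructed from the code above as a comment sketch:

  // [4-byte total length, big-endian (htonl)]
  // [varint size][RpcRequestHeaderProto]   (rebuilt per attempt by GetPacket,
  //                                         so callid/retrycount stay current)
  // [varint size][RequestHeaderProto]      (methodname + protocol)
  // [payload from ConstructPayload]        (varint size + request bytes,
  //                                         serialized once per Request)

Serializing the payload once but the headers on every send is what lets a retried request carry an updated retry_count in its RpcRequestHeaderProto.
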
+ 155 - 64
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_connection.h

@@ -29,46 +29,90 @@
 
 namespace hdfs {
 
-template <class NextLayer> class RpcConnectionImpl : public RpcConnection {
+template <class NextLayer>
+class RpcConnectionImpl : public RpcConnection {
 public:
   RpcConnectionImpl(RpcEngine *engine);
   virtual void Connect(const ::asio::ip::tcp::endpoint &server,
-                       Callback &&handler) override;
-  virtual void Handshake(Callback &&handler) override;
-  virtual void Shutdown() override;
+                       RpcCallback &handler);
+  virtual void ConnectAndFlush(
+      const ::asio::ip::tcp::endpoint &server) override;
+  virtual void Handshake(RpcCallback &handler) override;
+  virtual void Disconnect() override;
   virtual void OnSendCompleted(const ::asio::error_code &ec,
                                size_t transferred) override;
   virtual void OnRecvCompleted(const ::asio::error_code &ec,
                                size_t transferred) override;
+  virtual void FlushPendingRequests() override;
+
 
   NextLayer &next_layer() { return next_layer_; }
-private:
+
+  void TEST_set_connected(bool new_value) { connected_ = new_value; }
+
+ private:
   const Options options_;
   NextLayer next_layer_;
 };
 
 template <class NextLayer>
 RpcConnectionImpl<NextLayer>::RpcConnectionImpl(RpcEngine *engine)
-    : RpcConnection(engine), options_(engine->options()),
+    : RpcConnection(engine),
+      options_(engine->options()),
       next_layer_(engine->io_service()) {}
 
 template <class NextLayer>
 void RpcConnectionImpl<NextLayer>::Connect(
-    const ::asio::ip::tcp::endpoint &server, Callback &&handler) {
-  next_layer_.async_connect(server,
-      [handler](const ::asio::error_code &ec) {
-        handler(ToStatus(ec));
+    const ::asio::ip::tcp::endpoint &server, RpcCallback &handler) {
+  auto connectionSuccessfulReq = std::make_shared<Request>(
+      engine_, [handler](::google::protobuf::io::CodedInputStream *is,
+                         const Status &status) {
+        (void)is;
+        handler(status);
       });
+  pending_requests_.push_back(connectionSuccessfulReq);
+  this->ConnectAndFlush(server);  // need "this" so the compiler can resolve ConnectAndFlush
+}
+
+template <class NextLayer>
+void RpcConnectionImpl<NextLayer>::ConnectAndFlush(
+    const ::asio::ip::tcp::endpoint &server) {
+  std::shared_ptr<RpcConnection> shared_this = shared_from_this();
+  next_layer_.async_connect(server,
+                            [shared_this, this](const ::asio::error_code &ec) {
+                              std::lock_guard<std::mutex> state_lock(connection_state_lock_);
+                              Status status = ToStatus(ec);
+                              if (status.ok()) {
+                                StartReading();
+                                Handshake([shared_this, this](const Status &s) {
+                                  std::lock_guard<std::mutex> state_lock(connection_state_lock_);
+                                  if (s.ok()) {
+                                    FlushPendingRequests();
+                                  } else {
+                                    CommsError(s);
+                                  };
+                                });
+                              } else {
+                                CommsError(status);
+                              }
+                            });
 }
 
 template <class NextLayer>
-void RpcConnectionImpl<NextLayer>::Handshake(Callback &&handler) {
+void RpcConnectionImpl<NextLayer>::Handshake(RpcCallback &handler) {
+  assert(lock_held(connection_state_lock_));  // Must be holding lock before calling
+
+  auto shared_this = shared_from_this();
   auto handshake_packet = PrepareHandshakePacket();
-  ::asio::async_write(
-      next_layer_, asio::buffer(*handshake_packet),
-      [handshake_packet, handler](const ::asio::error_code &ec, size_t) {
-        handler(ToStatus(ec));
-      });
+  ::asio::async_write(next_layer_, asio::buffer(*handshake_packet),
+                      [handshake_packet, handler, shared_this, this](
+                          const ::asio::error_code &ec, size_t) {
+                        Status status = ToStatus(ec);
+                        if (status.ok()) {
+                          connected_ = true;
+                        }
+                        handler(status);
+                      });
 }
 
 template <class NextLayer>
@@ -76,82 +120,129 @@ void RpcConnectionImpl<NextLayer>::OnSendCompleted(const ::asio::error_code &ec,
                                                    size_t) {
   using std::placeholders::_1;
   using std::placeholders::_2;
-  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+  std::lock_guard<std::mutex> state_lock(connection_state_lock_);
 
   request_over_the_wire_.reset();
   if (ec) {
-    // Current RPC has failed -- abandon the
-    // connection and do proper clean up
-    ClearAndDisconnect(ec);
+    LOG_WARN() << "Network error during RPC write: " << ec.message();
+    CommsError(ToStatus(ec));
     return;
   }
 
-  if (!pending_requests_.size()) {
+  FlushPendingRequests();
+}
+
+template <class NextLayer>
+void RpcConnectionImpl<NextLayer>::FlushPendingRequests() {
+  using namespace ::std::placeholders;
+
+  // Lock should be held
+  assert(lock_held(connection_state_lock_));
+
+  if (pending_requests_.empty()) {
+    return;
+  }
+
+  if (!connected_) {
+    return;
+  }
+
+  // Don't send if we don't need to
+  if (request_over_the_wire_) {
     return;
   }
 
   std::shared_ptr<Request> req = pending_requests_.front();
   pending_requests_.erase(pending_requests_.begin());
-  requests_on_fly_[req->call_id()] = req;
-  request_over_the_wire_ = req;
-
-  req->timer().expires_from_now(
-      std::chrono::milliseconds(options_.rpc_timeout));
-  req->timer().async_wait(std::bind(
-      &RpcConnectionImpl<NextLayer>::HandleRpcTimeout, this, req, _1));
 
-  asio::async_write(
-      next_layer_, asio::buffer(req->payload()),
-      std::bind(&RpcConnectionImpl<NextLayer>::OnSendCompleted, this, _1, _2));
+  std::shared_ptr<RpcConnection> shared_this = shared_from_this();
+  std::shared_ptr<std::string> payload = std::make_shared<std::string>();
+  req->GetPacket(payload.get());
+  if (!payload->empty()) {
+    requests_on_fly_[req->call_id()] = req;
+    request_over_the_wire_ = req;
+
+    req->timer().expires_from_now(
+        std::chrono::milliseconds(options_.rpc_timeout));
+    req->timer().async_wait(std::bind(
+      &RpcConnection::HandleRpcTimeout, this, req, _1));
+
+    asio::async_write(next_layer_, asio::buffer(*payload),
+                      [shared_this, this, payload](const ::asio::error_code &ec,
+                                                   size_t size) {
+                        OnSendCompleted(ec, size);
+                      });
+  } else {  // Nothing to send for this request, inform the handler immediately
+    io_service().post(
+        // Never hold locks when calling a callback
+        [req]() { req->OnResponseArrived(nullptr, Status::OK()); }
+    );
+
+    // Reschedule to flush the next one
+    AsyncFlushPendingRequests();
+  }
 }
 
+
 template <class NextLayer>
 void RpcConnectionImpl<NextLayer>::OnRecvCompleted(const ::asio::error_code &ec,
                                                    size_t) {
   using std::placeholders::_1;
   using std::placeholders::_2;
-  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+  std::lock_guard<std::mutex> state_lock(connection_state_lock_);
+
+  std::shared_ptr<RpcConnection> shared_this = shared_from_this();
 
   switch (ec.value()) {
-  case 0:
-    // No errors
-    break;
-  case asio::error::operation_aborted:
-    // The event loop has been shut down. Ignore the error.
-    return;
-  default:
-    LOG_WARN() << "Network error during RPC: " << ec.message();
-    ClearAndDisconnect(ec);
-    return;
+    case 0:
+      // No errors
+      break;
+    case asio::error::operation_aborted:
+      // The event loop has been shut down. Ignore the error.
+      return;
+    default:
+      LOG_WARN() << "Network error during RPC read: " << ec.message();
+      CommsError(ToStatus(ec));
+      return;
+  }
+
+  if (!response_) { /* start a new one */
+    response_ = std::make_shared<Response>();
   }
 
-  if (resp_state_ == kReadLength) {
-    resp_state_ = kReadContent;
-    auto buf = ::asio::buffer(reinterpret_cast<char *>(&resp_length_),
-                              sizeof(resp_length_));
-    asio::async_read(next_layer_, buf,
-                     std::bind(&RpcConnectionImpl<NextLayer>::OnRecvCompleted,
-                               this, _1, _2));
-
-  } else if (resp_state_ == kReadContent) {
-    resp_state_ = kParseResponse;
-    resp_length_ = ntohl(resp_length_);
-    resp_data_.resize(resp_length_);
-    asio::async_read(next_layer_, ::asio::buffer(resp_data_),
-                     std::bind(&RpcConnectionImpl<NextLayer>::OnRecvCompleted,
-                               this, _1, _2));
-
-  } else if (resp_state_ == kParseResponse) {
-    resp_state_ = kReadLength;
-    HandleRpcResponse(resp_data_);
-    resp_data_.clear();
-    Start();
+  if (response_->state_ == Response::kReadLength) {
+    response_->state_ = Response::kReadContent;
+    auto buf = ::asio::buffer(reinterpret_cast<char *>(&response_->length_),
+                              sizeof(response_->length_));
+    asio::async_read(
+        next_layer_, buf,
+        [shared_this, this](const ::asio::error_code &ec, size_t size) {
+          OnRecvCompleted(ec, size);
+        });
+  } else if (response_->state_ == Response::kReadContent) {
+    response_->state_ = Response::kParseResponse;
+    response_->length_ = ntohl(response_->length_);
+    response_->data_.resize(response_->length_);
+    asio::async_read(
+        next_layer_, ::asio::buffer(response_->data_),
+        [shared_this, this](const ::asio::error_code &ec, size_t size) {
+          OnRecvCompleted(ec, size);
+        });
+  } else if (response_->state_ == Response::kParseResponse) {
+    HandleRpcResponse(response_);
+    response_ = nullptr;
+    StartReading();
   }
 }
 
-template <class NextLayer> void RpcConnectionImpl<NextLayer>::Shutdown() {
+template <class NextLayer>
+void RpcConnectionImpl<NextLayer>::Disconnect() {
+  assert(lock_held(connection_state_lock_));  // Must be holding lock before calling
+
+  request_over_the_wire_.reset();
   next_layer_.cancel();
   next_layer_.close();
+  connected_ = false;
 }
 }
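
The connect path in ConnectAndFlush above boils down to this sequence (a summary sketch, not new code):

  // async_connect ok -> StartReading(); Handshake(...)
  // handshake ok     -> connected_ = true; FlushPendingRequests()
  // any failure      -> CommsError(status): queued requests are handed back
  //                     to the engine, which may retry them on a new connection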
 

+ 124 - 23
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_engine.cc

@@ -18,52 +18,71 @@
 #include "rpc_engine.h"
 #include "rpc_connection.h"
 #include "common/util.h"
+#include "optional.hpp"
 
 #include <future>
 
 namespace hdfs {
 
+template <class T>
+using optional = std::experimental::optional<T>;
+
 RpcEngine::RpcEngine(::asio::io_service *io_service, const Options &options,
                      const std::string &client_name, const char *protocol_name,
                      int protocol_version)
-    : io_service_(io_service), options_(options), client_name_(client_name),
-      protocol_name_(protocol_name), protocol_version_(protocol_version),
-      call_id_(0) {
-}
+    : io_service_(io_service),
+      options_(options),
+      client_name_(client_name),
+      protocol_name_(protocol_name),
+      protocol_version_(protocol_version),
+      retry_policy_(std::move(MakeRetryPolicy(options))),
+      call_id_(0),
+      retry_timer(*io_service) {}
 
 void RpcEngine::Connect(const ::asio::ip::tcp::endpoint &server,
-                        const std::function<void(const Status &)> &handler) {
-  conn_.reset(new RpcConnectionImpl<::asio::ip::tcp::socket>(this));
-  conn_->Connect(server, [this, handler](const Status &stat) {
-    if (!stat.ok()) {
-      handler(stat);
-    } else {
-      conn_->Handshake([handler](const Status &s) { handler(s); });
-    }
-  });
-}
+                        RpcCallback &handler) {
+  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+  last_endpoint_ = server;
 
-void RpcEngine::Start() { conn_->Start(); }
+  conn_ = NewConnection();
+  conn_->Connect(server, handler);
+}
 
 void RpcEngine::Shutdown() {
-  io_service_->post([this]() { conn_->Shutdown(); });
+  io_service_->post([this]() {
+    std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+    conn_->Disconnect();
+    conn_.reset();
+  });
 }
 
-void RpcEngine::TEST_SetRpcConnection(std::unique_ptr<RpcConnection> *conn) {
-  conn_.reset(conn->release());
+std::unique_ptr<const RetryPolicy> RpcEngine::MakeRetryPolicy(const Options &options) {
+  if (options.max_rpc_retries > 0) {
+    return std::unique_ptr<RetryPolicy>(new FixedDelayRetryPolicy(options.rpc_retry_delay_ms, options.max_rpc_retries));
+  } else {
+    return nullptr;
+  }
+}
+
+void RpcEngine::TEST_SetRpcConnection(std::shared_ptr<RpcConnection> conn) {
+  conn_ = conn;
 }
 
 void RpcEngine::AsyncRpc(
     const std::string &method_name, const ::google::protobuf::MessageLite *req,
     const std::shared_ptr<::google::protobuf::MessageLite> &resp,
     const std::function<void(const Status &)> &handler) {
+  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+  if (!conn_) {
+    conn_ = NewConnection();
+    conn_->ConnectAndFlush(last_endpoint_);
+  }
   conn_->AsyncRpc(method_name, req, resp, handler);
 }
 
-Status
-RpcEngine::Rpc(const std::string &method_name,
-               const ::google::protobuf::MessageLite *req,
-               const std::shared_ptr<::google::protobuf::MessageLite> &resp) {
+Status RpcEngine::Rpc(
+    const std::string &method_name, const ::google::protobuf::MessageLite *req,
+    const std::shared_ptr<::google::protobuf::MessageLite> &resp) {
   auto stat = std::make_shared<std::promise<Status>>();
   std::future<Status> future(stat->get_future());
   AsyncRpc(method_name, req, resp,
@@ -71,13 +90,95 @@ RpcEngine::Rpc(const std::string &method_name,
   return future.get();
 }
 
+std::shared_ptr<RpcConnection> RpcEngine::NewConnection()
+{
+  return std::make_shared<RpcConnectionImpl<::asio::ip::tcp::socket>>(this);
+}
+
+
 Status RpcEngine::RawRpc(const std::string &method_name, const std::string &req,
                          std::shared_ptr<std::string> resp) {
+  std::shared_ptr<RpcConnection> conn;
+  {
+    std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+    if (!conn_) {
+        conn_ = NewConnection();
+        conn_->ConnectAndFlush(last_endpoint_);
+      }
+    conn = conn_;
+  }
+
   auto stat = std::make_shared<std::promise<Status>>();
   std::future<Status> future(stat->get_future());
-  conn_->AsyncRawRpc(method_name, req, resp,
+  conn->AsyncRawRpc(method_name, req, resp,
                      [stat](const Status &status) { stat->set_value(status); });
   return future.get();
 }
 
+void RpcEngine::AsyncRpcCommsError(
+    const Status &status,
+    std::vector<std::shared_ptr<Request>> pendingRequests) {
+  io_service().post([this, status, pendingRequests]() {
+    RpcCommsError(status, pendingRequests);
+  });
+}
+
+void RpcEngine::RpcCommsError(
+    const Status &status,
+    std::vector<std::shared_ptr<Request>> pendingRequests) {
+  (void)status;
+
+  std::lock_guard<std::mutex> state_lock(engine_state_lock_);
+
+  auto head_action = optional<RetryAction>();
+
+  // Filter out anything with too many retries already
+  for (auto it = pendingRequests.begin(); it < pendingRequests.end();) {
+    auto req = *it;
+
+    RetryAction retry = RetryAction::fail(""); // Default to fail
+    if (retry_policy()) {
+      retry = retry_policy()->ShouldRetry(status, req->IncrementRetryCount(), 0, true);
+    }
+
+    if (retry.action == RetryAction::FAIL) {
+      // If we've exceeded the maximum retries, take the latest error and pass
+      //    it on.  There might be a good argument for caching the first error
+      //    rather than the last one, but that gets messy
+
+      io_service().post([req, status]() {
+        req->OnResponseArrived(nullptr, status);  // Never call back while holding a lock
+      });
+      it = pendingRequests.erase(it);
+    } else {
+      if (!head_action) {
+        head_action = retry;
+      }
+
+      ++it;
+    }
+  }
+
+  // Close the connection and retry any requests that might have been sent to
+  //    the NN
+  if (!pendingRequests.empty() &&
+          head_action && head_action->action != RetryAction::FAIL) {
+    conn_ = NewConnection();
+
+    conn_->PreEnqueueRequests(pendingRequests);
+    if (head_action->delayMillis > 0) {
+      retry_timer.expires_from_now(
+          std::chrono::milliseconds(options_.rpc_retry_delay_ms));
+      retry_timer.async_wait([this](asio::error_code ec) {
+        if (!ec) conn_->ConnectAndFlush(last_endpoint_);
+      });
+    } else {
+      conn_->ConnectAndFlush(last_endpoint_);
+    }
+  } else {
+    // Connection will try again if someone calls AsyncRpc
+    conn_.reset();
+  }
+}
+
 }
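
Putting the pieces together, the recovery path above reads as follows (a trace of the code, assuming a FixedDelayRetryPolicy is configured):

  // 1. A send or receive fails; RpcConnection::CommsError collects on-the-fly
  //    and pending requests and calls engine_->AsyncRpcCommsError(status, reqs).
  // 2. RpcCommsError consults retry_policy()->ShouldRetry() per request;
  //    IncrementRetryCount() makes the next attempt carry a retrycount in its
  //    RpcRequestHeaderProto.
  // 3. Requests the policy keeps are PreEnqueueRequests()'d onto a fresh
  //    connection, which is ConnectAndFlush()'d after retry_timer expires
  //    (immediately if the policy's delay is zero).
  // 4. Requests the policy rejects get OnResponseArrived(nullptr, status)
  //    posted to the io_service; callbacks are never invoked while a lock is held.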

+ 186 - 65
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/lib/rpc/rpc_engine.h

@@ -21,7 +21,11 @@
 #include "libhdfspp/options.h"
 #include "libhdfspp/status.h"
 
+#include "common/retry_policy.h"
+
 #include <google/protobuf/message_lite.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
 
 #include <asio/ip/tcp.hpp>
 #include <asio/deadline_timer.hpp>
@@ -34,77 +38,146 @@
 
 namespace hdfs {
 
-class RpcEngine;
-class RpcConnection {
-public:
-  typedef std::function<void(const Status &)> Callback;
+  /*
+   *        NOTE ABOUT LOCKING MODELS
+   *
+   * To prevent deadlocks, anything that might acquire multiple locks must
+   * acquire the lock on the RpcEngine first, then the RpcConnection.  Callbacks
+   * will never be called while holding any locks, so the components are free
+   * to take locks when servicing a callback.
+   *
+   * An RpcRequest or RpcConnection should never call any methods on the RpcEngine
+   * except for those that are exposed through the LockFreeRpcEngine interface.
+   */
+
+typedef const std::function<void(const Status &)> RpcCallback;
+
+class LockFreeRpcEngine;
+class RpcConnection;
+
+/*
+ * Internal bookkeeping for an outstanding request from the consumer.
+ *
+ * Threading model: not thread-safe; should only be accessed from a single
+ *   thread at a time
+ */
+class Request {
+ public:
+  typedef std::function<void(::google::protobuf::io::CodedInputStream *is,
+                             const Status &status)> Handler;
+
+  Request(LockFreeRpcEngine *engine, const std::string &method_name,
+          const std::string &request, Handler &&callback);
+  Request(LockFreeRpcEngine *engine, const std::string &method_name,
+          const ::google::protobuf::MessageLite *request, Handler &&callback);
+
+  // Null request (with no actual message) used to track the state of an
+  //    initial Connect call
+  Request(LockFreeRpcEngine *engine, Handler &&handler);
+
+  int call_id() const { return call_id_; }
+  ::asio::deadline_timer &timer() { return timer_; }
+  int IncrementRetryCount() { return retry_count_++; }
+  void GetPacket(std::string *res) const;
+  void OnResponseArrived(::google::protobuf::io::CodedInputStream *is,
+                         const Status &status);
+
+ private:
+  LockFreeRpcEngine *const engine_;
+  const std::string method_name_;
+  const int call_id_;
+
+  ::asio::deadline_timer timer_;
+  std::string payload_;
+  const Handler handler_;
+
+  int retry_count_;
+};
+
+/*
+ * Encapsulates a persistent connection to the NameNode, and the sending of
+ * RPC requests and evaluating their responses.
+ *
+ * Can have multiple RPC requests in-flight simultaneously, but they are
+ * evaluated in-order on the server side in a blocking manner.
+ *
+ * Threading model: public interface is thread-safe
+ * All handlers passed in to method calls will be called from an asio thread,
+ *   and will not be holding any internal RpcConnection locks.
+ */
+class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
+ public:
+  RpcConnection(LockFreeRpcEngine *engine);
   virtual ~RpcConnection();
-  RpcConnection(RpcEngine *engine);
+
   virtual void Connect(const ::asio::ip::tcp::endpoint &server,
-                       Callback &&handler) = 0;
-  virtual void Handshake(Callback &&handler) = 0;
-  virtual void Shutdown() = 0;
+                       RpcCallback &handler) = 0;
+  virtual void ConnectAndFlush(const ::asio::ip::tcp::endpoint &server) = 0;
+  virtual void Handshake(RpcCallback &handler) = 0;
+  virtual void Disconnect() = 0;
 
-  void Start();
+  void StartReading();
   void AsyncRpc(const std::string &method_name,
                 const ::google::protobuf::MessageLite *req,
                 std::shared_ptr<::google::protobuf::MessageLite> resp,
-                const Callback &handler);
+                const RpcCallback &handler);
 
   void AsyncRawRpc(const std::string &method_name, const std::string &request,
-                   std::shared_ptr<std::string> resp, Callback &&handler);
+                   std::shared_ptr<std::string> resp, RpcCallback &&handler);
+
+  // Enqueue requests before the connection is connected.  Will be flushed
+  //   on connect
+  void PreEnqueueRequests(std::vector<std::shared_ptr<Request>> requests);
+
+  LockFreeRpcEngine *engine() { return engine_; }
+  ::asio::io_service &io_service();
+
+ protected:
+  struct Response {
+    enum ResponseState {
+      kReadLength,
+      kReadContent,
+      kParseResponse,
+    } state_;
+    unsigned length_;
+    std::vector<char> data_;
+
+    std::unique_ptr<::google::protobuf::io::ArrayInputStream> ar;
+    std::unique_ptr<::google::protobuf::io::CodedInputStream> in;
+
+    Response() : state_(kReadLength), length_(0) {}
+  };
 
-protected:
-  class Request;
-  RpcEngine *const engine_;
+
+  LockFreeRpcEngine *const engine_;
   virtual void OnSendCompleted(const ::asio::error_code &ec,
                                size_t transferred) = 0;
   virtual void OnRecvCompleted(const ::asio::error_code &ec,
                                size_t transferred) = 0;
+  virtual void FlushPendingRequests() = 0;    // Synchronously write the next request
+
+  void AsyncFlushPendingRequests(); // Queue requests to be flushed at a later time
+
+
 
-  ::asio::io_service &io_service();
   std::shared_ptr<std::string> PrepareHandshakePacket();
-  static std::string
-  SerializeRpcRequest(const std::string &method_name,
-                      const ::google::protobuf::MessageLite *req);
-  void HandleRpcResponse(const std::vector<char> &data);
+  static std::string SerializeRpcRequest(
+      const std::string &method_name,
+      const ::google::protobuf::MessageLite *req);
+  void HandleRpcResponse(std::shared_ptr<Response> response);
   void HandleRpcTimeout(std::shared_ptr<Request> req,
                         const ::asio::error_code &ec);
-  void FlushPendingRequests();
+  void CommsError(const Status &status);
+
   void ClearAndDisconnect(const ::asio::error_code &ec);
   std::shared_ptr<Request> RemoveFromRunningQueue(int call_id);
 
-  enum ResponseState {
-    kReadLength,
-    kReadContent,
-    kParseResponse,
-  } resp_state_;
-  unsigned resp_length_;
-  std::vector<char> resp_data_;
-
-  class Request {
-  public:
-    typedef std::function<void(::google::protobuf::io::CodedInputStream *is,
-                               const Status &status)> Handler;
-    Request(RpcConnection *parent, const std::string &method_name,
-            const std::string &request, Handler &&callback);
-    Request(RpcConnection *parent, const std::string &method_name,
-            const ::google::protobuf::MessageLite *request, Handler &&callback);
-
-    int call_id() const { return call_id_; }
-    ::asio::deadline_timer &timer() { return timer_; }
-    const std::string &payload() const { return payload_; }
-    void OnResponseArrived(::google::protobuf::io::CodedInputStream *is,
-                           const Status &status);
-
-  private:
-    const int call_id_;
-    ::asio::deadline_timer timer_;
-    std::string payload_;
-    Handler handler_;
-  };
+  std::shared_ptr<Response> response_;
 
-  // The request being sent over the wire
+  // The connection may not be established yet, especially while we're
+  //   pausing between retry attempts
+  bool connected_;
+  // The request being sent over the wire; will also be in requests_on_fly_
   std::shared_ptr<Request> request_over_the_wire_;
   // Requests to be sent over the wire
   std::vector<std::shared_ptr<Request>> pending_requests_;
@@ -112,11 +185,40 @@ protected:
   typedef std::unordered_map<int, std::shared_ptr<Request>> RequestOnFlyMap;
   RequestOnFlyMap requests_on_fly_;
   // Lock for mutable parts of this class that need to be thread safe
-  std::mutex engine_state_lock_;
+  std::mutex connection_state_lock_;
 };
 
-class RpcEngine {
+
+/*
+ * These methods of the RpcEngine will never acquire locks, and are safe for
+ * RpcConnections to call while holding a ConnectionLock.
+ */
+class LockFreeRpcEngine {
 public:
+  /* Enqueues a CommsError without acquiring a lock */
+  virtual void AsyncRpcCommsError(const Status &status,
+                      std::vector<std::shared_ptr<Request>> pendingRequests) = 0;
+
+
+  virtual const RetryPolicy * retry_policy() const = 0;
+  virtual int NextCallId() = 0;
+
+  virtual const std::string &client_name() const = 0;
+  virtual const std::string &protocol_name() const = 0;
+  virtual int protocol_version() const = 0;
+  virtual ::asio::io_service &io_service() = 0;
+  virtual const Options &options() const = 0;
+};
+
+/*
+ * An engine for reliable communication with a NameNode.  Handles connection,
+ * retry, and (someday) failover of the requested messages.
+ *
+ * Threading model: thread-safe.  All callbacks will be called back from
+ *   an asio pool and will not hold any internal locks
+ */
+class RpcEngine : public LockFreeRpcEngine {
+ public:
   enum { kRpcVersion = 9 };
   enum {
     kCallIdAuthorizationFailed = -1,
@@ -129,6 +231,8 @@ public:
             const std::string &client_name, const char *protocol_name,
             int protocol_version);
 
+  void Connect(const ::asio::ip::tcp::endpoint &server, RpcCallback &handler);
+
   void AsyncRpc(const std::string &method_name,
                 const ::google::protobuf::MessageLite *req,
                 const std::shared_ptr<::google::protobuf::MessageLite> &resp,
@@ -143,29 +247,46 @@ public:
    **/
   Status RawRpc(const std::string &method_name, const std::string &req,
                 std::shared_ptr<std::string> resp);
-  void Connect(const ::asio::ip::tcp::endpoint &server,
-               const std::function<void(const Status &)> &handler);
   void Start();
   void Shutdown();
-  void TEST_SetRpcConnection(std::unique_ptr<RpcConnection> *conn);
 
-  int NextCallId() { return ++call_id_; }
+  /* Enqueues a CommsError without acquiring a lock */
+  void AsyncRpcCommsError(const Status &status,
+                     std::vector<std::shared_ptr<Request>> pendingRequests) override;
+  void RpcCommsError(const Status &status,
+                     std::vector<std::shared_ptr<Request>> pendingRequests);
 
-  const std::string &client_name() const { return client_name_; }
-  const std::string &protocol_name() const { return protocol_name_; }
-  int protocol_version() const { return protocol_version_; }
-  ::asio::io_service &io_service() { return *io_service_; }
-  const Options &options() { return options_; }
-  static std::string GetRandomClientName();
 
+  const RetryPolicy * retry_policy() const override { return retry_policy_.get(); }
+  int NextCallId() override { return ++call_id_; }
+
+  void TEST_SetRpcConnection(std::shared_ptr<RpcConnection> conn);
+
+  const std::string &client_name() const override { return client_name_; }
+  const std::string &protocol_name() const override { return protocol_name_; }
+  int protocol_version() const override { return protocol_version_; }
+  ::asio::io_service &io_service() override { return *io_service_; }
+  const Options &options() const override { return options_; }
+  static std::string GetRandomClientName();
+ protected:
+  std::shared_ptr<RpcConnection> conn_;
+  virtual std::shared_ptr<RpcConnection> NewConnection();
+  virtual std::unique_ptr<const RetryPolicy> MakeRetryPolicy(const Options &options);
 private:
-  ::asio::io_service *io_service_;
-  Options options_;
+  ::asio::io_service * const io_service_;
+  const Options options_;
   const std::string client_name_;
   const std::string protocol_name_;
   const int protocol_version_;
+  const std::unique_ptr<const RetryPolicy> retry_policy_;  // null --> no retry
   std::atomic_int call_id_;
-  std::unique_ptr<RpcConnection> conn_;
+  ::asio::deadline_timer retry_timer;
+
+  // Remember the last endpoint in case we need to reconnect to retry
+  ::asio::ip::tcp::endpoint last_endpoint_;
+
+  std::mutex engine_state_lock_;
+
 };
 }
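
Restating the locking note at the top of this header as a rule of thumb (illustrative comment only):

  // OK:  engine_state_lock_ -> connection_state_lock_   (engine lock first)
  // BAD: connection_state_lock_ -> any locking RpcEngine method
  // RpcConnection therefore reaches back into the engine only through the
  // LockFreeRpcEngine interface (NextCallId, retry_policy, io_service,
  // AsyncRpcCommsError, ...), none of which take the engine lock.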
 

+ 4 - 0
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/CMakeLists.txt

@@ -51,6 +51,10 @@ add_executable(sasl_digest_md5_test sasl_digest_md5_test.cc)
 target_link_libraries(sasl_digest_md5_test common ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
 add_test(sasl_digest_md5 sasl_digest_md5_test)
 
+add_executable(retry_policy_test retry_policy_test.cc)
+target_link_libraries(retry_policy_test common gmock_main ${CMAKE_THREAD_LIBS_INIT})
+add_test(retry_policy retry_policy_test)
+
 include_directories(${CMAKE_CURRENT_BINARY_DIR})
 add_executable(rpc_engine_test rpc_engine_test.cc ${PROTO_TEST_SRCS} ${PROTO_TEST_HDRS} $<TARGET_OBJECTS:test_common>)
 target_link_libraries(rpc_engine_test rpc proto common ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})

+ 11 - 0
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/mock_connection.cc

@@ -26,4 +26,15 @@ MockConnectionBase::MockConnectionBase(::asio::io_service *io_service)
 
 MockConnectionBase::~MockConnectionBase() {}
 
+ProducerResult SharedMockConnection::Produce() {
+  if (auto shared_producer = shared_connection_data_.lock()) {
+    return shared_producer->Produce();
+  } else {
+    assert(false && "No producer registered");
+    return std::make_pair(asio::error_code(), "");
+  }
+}
+
+std::weak_ptr<SharedConnectionData> SharedMockConnection::shared_connection_data_;
+
 }

+ 67 - 1
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/mock_connection.h

@@ -29,7 +29,21 @@
 
 namespace hdfs {
 
-class MockConnectionBase : public AsyncStream{
+typedef std::pair<asio::error_code, std::string> ProducerResult;
+class AsioProducer {
+public:
+  /*
+   *  Return either:
+   *     (::asio::error_code(), <some data>) for a good result
+   *     (<an ::asio::error instance>, <anything>) to pass an error to the caller
+   *     (::asio::error::would_block, <anything>) to block the next call forever
+   */
+
+  virtual ProducerResult Produce() = 0;
+};
+
+
+class MockConnectionBase : public AsioProducer, public AsyncStream {
 public:
   MockConnectionBase(::asio::io_service *io_service);
   virtual ~MockConnectionBase();
@@ -40,6 +54,9 @@ public:
                                  std::size_t bytes_transferred) > handler) override {
     if (produced_.size() == 0) {
       ProducerResult r = Produce();
+      if (r.first == asio::error::would_block) {
+        return; // No more reads to do
+      }
       if (r.first) {
         io_service_->post(std::bind(handler, r.first, 0));
         return;
@@ -62,6 +79,13 @@ public:
     io_service_->post(std::bind(handler, asio::error_code(), asio::buffer_size(buf)));
   }
 
+  template <class Endpoint, class Callback>
+  void async_connect(const Endpoint &, Callback &&handler) {
+    io_service_->post([handler]() { handler(::asio::error_code()); });
+  }
+
+  virtual void cancel() {}
+  virtual void close() {}
 protected:
   virtual ProducerResult Produce() = 0;
   ::asio::io_service *io_service_;
@@ -69,6 +93,48 @@ protected:
 private:
   asio::streambuf produced_;
 };
+
+
+
+
+class SharedConnectionData : public AsioProducer {
+ public:
+  bool checkProducerForConnect = false;
+
+  MOCK_METHOD0(Produce, ProducerResult());
+};
+
+class SharedMockConnection : public MockConnectionBase {
+public:
+  using MockConnectionBase::MockConnectionBase;
+
+  template <class Endpoint, class Callback>
+  void async_connect(const Endpoint &, Callback &&handler) {
+    auto data = shared_connection_data_.lock();
+    assert(data);
+
+    if (!data->checkProducerForConnect) {
+      io_service_->post([handler]() { handler(::asio::error_code()); });
+    } else {
+      ProducerResult result = Produce();
+      if (result.first == asio::error::would_block) {
+        return; // Connect will hang
+      } else {
+        io_service_->post([handler, result]() { handler(result.first); });
+      }
+    }
+  }
+
+  static void SetSharedConnectionData(std::shared_ptr<SharedConnectionData> new_producer) {
+    shared_connection_data_ = new_producer; // get a weak reference to it
+  }
+
+protected:
+  ProducerResult Produce() override;
+
+  static std::weak_ptr<SharedConnectionData> shared_connection_data_;
+};
+
 }
 
 #endif
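
A hedged sketch of how a test can drive SharedMockConnection (the expectation values are illustrative; Return comes from gmock):

  using ::testing::Return;

  auto producer = std::make_shared<SharedConnectionData>();
  producer->checkProducerForConnect = true;
  EXPECT_CALL(*producer, Produce())
      .WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
  SharedMockConnection::SetSharedConnectionData(producer);
  // Keep `producer` alive for the whole test: SetSharedConnectionData only
  // stores a weak_ptr.  Every SharedMockConnection created afterwards consults
  // this producer; would_block makes the next connect attempt hang forever
  // (see "Connect will hang" above), which a test can use to hold a
  // connection in the not-yet-connected state.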

+ 63 - 0
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/retry_policy_test.cc

@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/retry_policy.h"
+
+#include <gmock/gmock.h>
+
+using namespace hdfs;
+
+TEST(RetryPolicyTest, TestNoRetry) {
+  NoRetryPolicy policy;
+  EXPECT_EQ(RetryAction::FAIL, policy.ShouldRetry(Status::Unimplemented(), 0, 0, true).action);
+}
+
+TEST(RetryPolicyTest, TestFixedDelay) {
+  static const uint64_t DELAY = 100;
+  FixedDelayRetryPolicy policy(DELAY, 10);
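+  // Per the expectations below: wait DELAY ms between attempts, give up after 10.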
+
+  // No failures yet
+  RetryAction result = policy.ShouldRetry(Status::Unimplemented(), 0, 0, true);
+  EXPECT_EQ(RetryAction::RETRY, result.action);
+  EXPECT_EQ(DELAY, result.delayMillis);
+
+  // A few failures
+  result = policy.ShouldRetry(Status::Unimplemented(), 2, 2, true);
+  EXPECT_EQ(RetryAction::RETRY, result.action);
+  EXPECT_EQ(DELAY, result.delayMillis);
+
+  result = policy.ShouldRetry(Status::Unimplemented(), 9, 0, true);
+  EXPECT_EQ(RetryAction::RETRY, result.action);
+  EXPECT_EQ(DELAY, result.delayMillis);
+
+  // Too many failures
+  result = policy.ShouldRetry(Status::Unimplemented(), 10, 0, true);
+  EXPECT_EQ(RetryAction::FAIL, result.action);
+  EXPECT_FALSE(result.reason.empty());  // expect some error message
+
+  result = policy.ShouldRetry(Status::Unimplemented(), 0, 10, true);
+  EXPECT_EQ(RetryAction::FAIL, result.action);
+  EXPECT_FALSE(result.reason.empty());  // expect some error message
+}
+
+int main(int argc, char *argv[]) {
+  // The following line must be executed to initialize Google Mock
+  // (and Google Test) before running the tests.
+  ::testing::InitGoogleMock(&argc, argv);
+  return RUN_ALL_TESTS();
+}

+ 246 - 26
hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfspp/tests/rpc_engine_test.cc

@@ -44,21 +44,33 @@ namespace pbio = ::google::protobuf::io;
 namespace hdfs {
 
 class MockRPCConnection : public MockConnectionBase {
-public:
+ public:
   MockRPCConnection(::asio::io_service &io_service)
       : MockConnectionBase(&io_service) {}
   MOCK_METHOD0(Produce, ProducerResult());
-  template <class Endpoint, class Callback>
-  void async_connect(const Endpoint &, Callback &&handler) {
-    handler(::asio::error_code());
+};
+
+class SharedMockRPCConnection : public SharedMockConnection {
+ public:
+  SharedMockRPCConnection(::asio::io_service &io_service)
+      : SharedMockConnection(&io_service) {}
+};
+
+class SharedConnectionEngine : public RpcEngine {
+  using RpcEngine::RpcEngine;
+
+protected:
+  std::shared_ptr<RpcConnection> NewConnection() override {
+    return std::make_shared<RpcConnectionImpl<SharedMockRPCConnection>>(this);
   }
-  void cancel() {}
-  void close() {}
+
 };
 
-static inline std::pair<error_code, string>
-RpcResponse(const RpcResponseHeaderProto &h, const std::string &data,
-            const ::asio::error_code &ec = error_code()) {
+}
+
+static inline std::pair<error_code, string> RpcResponse(
+    const RpcResponseHeaderProto &h, const std::string &data,
+    const ::asio::error_code &ec = error_code()) {
   uint32_t payload_length =
       pbio::CodedOutputStream::VarintSize32(h.ByteSize()) +
       pbio::CodedOutputStream::VarintSize32(data.size()) + h.ByteSize() +
@@ -77,7 +89,7 @@ RpcResponse(const RpcResponseHeaderProto &h, const std::string &data,
 
   return std::make_pair(ec, std::move(res));
 }
-}
+
 
 using namespace hdfs;
 
@@ -87,6 +99,9 @@ TEST(RpcEngineTest, TestRoundTrip) {
   RpcEngine engine(&io_service, options, "foo", "protocol", 1);
   RpcConnectionImpl<MockRPCConnection> *conn =
       new RpcConnectionImpl<MockRPCConnection>(&engine);
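+  // TEST_set_connected(true) skips the connect/handshake so the test can drive reads directly.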
+  conn->TEST_set_connected(true);
+  conn->StartReading();
+
   EchoResponseProto server_resp;
   server_resp.set_message("foo");
 
@@ -96,27 +111,34 @@ TEST(RpcEngineTest, TestRoundTrip) {
   EXPECT_CALL(conn->next_layer(), Produce())
       .WillOnce(Return(RpcResponse(h, server_resp.SerializeAsString())));
 
-  std::unique_ptr<RpcConnection> conn_ptr(conn);
-  engine.TEST_SetRpcConnection(&conn_ptr);
+  std::shared_ptr<RpcConnection> conn_ptr(conn);
+  engine.TEST_SetRpcConnection(conn_ptr);
+
+  bool complete = false;
 
   EchoRequestProto req;
   req.set_message("foo");
   std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
-  engine.AsyncRpc("test", &req, resp, [resp, &io_service](const Status &stat) {
+  engine.AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
     ASSERT_TRUE(stat.ok());
     ASSERT_EQ("foo", resp->message());
+    complete = true;
     io_service.stop();
   });
-  conn->Start();
   io_service.run();
+  ASSERT_TRUE(complete);
 }
 
-TEST(RpcEngineTest, TestConnectionReset) {
+TEST(RpcEngineTest, TestConnectionResetAndFail) {
   ::asio::io_service io_service;
   Options options;
   RpcEngine engine(&io_service, options, "foo", "protocol", 1);
   RpcConnectionImpl<MockRPCConnection> *conn =
       new RpcConnectionImpl<MockRPCConnection>(&engine);
+  conn->TEST_set_connected(true);
+  conn->StartReading();
+
+  bool complete = false;
 
   RpcResponseHeaderProto h;
   h.set_callid(1);
@@ -125,23 +147,213 @@ TEST(RpcEngineTest, TestConnectionReset) {
       .WillOnce(Return(RpcResponse(
           h, "", make_error_code(::asio::error::connection_reset))));
 
-  std::unique_ptr<RpcConnection> conn_ptr(conn);
-  engine.TEST_SetRpcConnection(&conn_ptr);
+  std::shared_ptr<RpcConnection> conn_ptr(conn);
+  engine.TEST_SetRpcConnection(conn_ptr);
 
   EchoRequestProto req;
   req.set_message("foo");
   std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
 
-  engine.AsyncRpc("test", &req, resp, [&io_service](const Status &stat) {
+  engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
+    complete = true;
+    io_service.stop();
     ASSERT_FALSE(stat.ok());
   });
+  io_service.run();
+  ASSERT_TRUE(complete);
+}
+
+
+TEST(RpcEngineTest, TestConnectionResetAndRecover) {
+  ::asio::io_service io_service;
+  Options options;
+  options.max_rpc_retries = 1;
+  options.rpc_retry_delay_ms = 0;
+  SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
+
+  EchoResponseProto server_resp;
+  server_resp.set_message("foo");
+
+  bool complete = false;
+
+  auto producer = std::make_shared<SharedConnectionData>();
+  RpcResponseHeaderProto h;
+  h.set_callid(1);
+  h.set_status(RpcResponseHeaderProto::SUCCESS);
+  EXPECT_CALL(*producer, Produce())
+      .WillOnce(Return(RpcResponse(
+          h, "", make_error_code(::asio::error::connection_reset))))
+      .WillOnce(Return(RpcResponse(h, server_resp.SerializeAsString())));
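+  // Script: the first read fails with connection_reset; with max_rpc_retries = 1
+  // the engine reconnects and the retried call succeeds.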
+  SharedMockConnection::SetSharedConnectionData(producer);
+
+  EchoRequestProto req;
+  req.set_message("foo");
+  std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
+
+  engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
+    complete = true;
+    io_service.stop();
+    ASSERT_TRUE(stat.ok());
+  });
+  io_service.run();
+  ASSERT_TRUE(complete);
+}
+
+TEST(RpcEngineTest, TestConnectionResetAndRecoverWithDelay) {
+  ::asio::io_service io_service;
+  Options options;
+  options.max_rpc_retries = 1;
+  options.rpc_retry_delay_ms = 1;
+  SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
+
+  EchoResponseProto server_resp;
+  server_resp.set_message("foo");
+
+  bool complete = false;
+
+  auto producer = std::make_shared<SharedConnectionData>();
+  RpcResponseHeaderProto h;
+  h.set_callid(1);
+  h.set_status(RpcResponseHeaderProto::SUCCESS);
+  EXPECT_CALL(*producer, Produce())
+      .WillOnce(Return(RpcResponse(
+          h, "", make_error_code(::asio::error::connection_reset))))
+      .WillOnce(Return(RpcResponse(h, server_resp.SerializeAsString())));
+  SharedMockConnection::SetSharedConnectionData(producer);
+
+  EchoRequestProto req;
+  req.set_message("foo");
+  std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
 
-  engine.AsyncRpc("test", &req, resp, [&io_service](const Status &stat) {
+  engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
+    complete = true;
+    io_service.stop();
+    ASSERT_TRUE(stat.ok());
+  });
+
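+  // Watchdog timer: never expected to fire; it fails the test outright if the retry path hangs.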
+  ::asio::deadline_timer timer(io_service);
+  timer.expires_from_now(std::chrono::hours(100));
+  timer.async_wait([](const asio::error_code &err) { (void)err; ASSERT_FALSE("Timed out"); });
+
+  io_service.run();
+  ASSERT_TRUE(complete);
+}
+
+TEST(RpcEngineTest, TestConnectionFailure) {
+  auto producer = std::make_shared<SharedConnectionData>();
+  producer->checkProducerForConnect = true;
+  SharedMockConnection::SetSharedConnectionData(producer);
+
+  // Error and no retry
+  ::asio::io_service io_service;
+
+  bool complete = false;
+
+  Options options;
+  options.max_rpc_retries = 0;
+  options.rpc_retry_delay_ms = 0;
+  SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
+  EXPECT_CALL(*producer, Produce())
+      .WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")));
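+  // One connect attempt fails; with max_rpc_retries = 0 the engine gives up immediately.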
+
+  engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
+    complete = true;
+    io_service.stop();
+    ASSERT_FALSE(stat.ok());
+  });
+  io_service.run();
+  ASSERT_TRUE(complete);
+}
+
+TEST(RpcEngineTest, TestConnectionFailureRetryAndFailure) {
+  auto producer = std::make_shared<SharedConnectionData>();
+  producer->checkProducerForConnect = true;
+  SharedMockConnection::SetSharedConnectionData(producer);
+
+  ::asio::io_service io_service;
+
+  bool complete = false;
+
+  Options options;
+  options.max_rpc_retries = 2;
+  options.rpc_retry_delay_ms = 0;
+  SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
+  EXPECT_CALL(*producer, Produce())
+      .WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
+      .WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
+      .WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")));
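+  // All three connect attempts (one initial try plus two retries) fail, so the error is surfaced to the caller.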
+
+  engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
+    complete = true;
     io_service.stop();
     ASSERT_FALSE(stat.ok());
   });
-  conn->Start();
   io_service.run();
+  ASSERT_TRUE(complete);
+}
+
+TEST(RpcEngineTest, TestConnectionFailureAndRecover) {
+  auto producer = std::make_shared<SharedConnectionData>();
+  producer->checkProducerForConnect = true;
+  SharedMockConnection::SetSharedConnectionData(producer);
+
+  ::asio::io_service io_service;
+
+  bool complete = false;
+
+  Options options;
+  options.max_rpc_retries = 1;
+  options.rpc_retry_delay_ms = 0;
+  SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
+  EXPECT_CALL(*producer, Produce())
+      .WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
+      .WillOnce(Return(std::make_pair(::asio::error_code(), "")))
+      .WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
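+  // Script: the first connect fails, the second succeeds, and would_block then
+  // parks the read loop so nothing else completes.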
+
+  engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
+    complete = true;
+    io_service.stop();
+    ASSERT_TRUE(stat.ok());
+  });
+  io_service.run();
+  ASSERT_TRUE(complete);
+}
+
+TEST(RpcEngineTest, TestConnectionFailureAndAsyncRecover) {
+  // Error and async recover
+  auto producer = std::make_shared<SharedConnectionData>();
+  producer->checkProducerForConnect = true;
+  SharedMockConnection::SetSharedConnectionData(producer);
+
+  ::asio::io_service io_service;
+
+  bool complete = false;
+
+  Options options;
+  options.max_rpc_retries = 1;
+  options.rpc_retry_delay_ms = 1;
+  SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
+  EXPECT_CALL(*producer, Produce())
+      .WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
+      .WillOnce(Return(std::make_pair(::asio::error_code(), "")))
+      .WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
+
+  engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
+    complete = true;
+    io_service.stop();
+    ASSERT_TRUE(stat.ok());
+  });
+
+  ::asio::deadline_timer timer(io_service);
+  timer.expires_from_now(std::chrono::hours(100));
+  timer.async_wait([](const asio::error_code &err) { (void)err; ASSERT_FALSE("Timed out"); });
+
+  io_service.run();
+  ASSERT_TRUE(complete);
 }
 
 TEST(RpcEngineTest, TestTimeout) {
@@ -151,24 +363,32 @@ TEST(RpcEngineTest, TestTimeout) {
   RpcEngine engine(&io_service, options, "foo", "protocol", 1);
   RpcConnectionImpl<MockRPCConnection> *conn =
       new RpcConnectionImpl<MockRPCConnection>(&engine);
+  conn->TEST_set_connected(true);
+  conn->StartReading();
+
+  EXPECT_CALL(conn->next_layer(), Produce())
+      .WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
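+  // would_block parks the read forever, so the pending call can only finish via the rpc_timeout path.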
 
-  EXPECT_CALL(conn->next_layer(), Produce()).Times(0);
+  std::shared_ptr<RpcConnection> conn_ptr(conn);
+  engine.TEST_SetRpcConnection(conn_ptr);
 
-  std::unique_ptr<RpcConnection> conn_ptr(conn);
-  engine.TEST_SetRpcConnection(&conn_ptr);
+  bool complete = false;
 
   EchoRequestProto req;
   req.set_message("foo");
   std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
-  engine.AsyncRpc("test", &req, resp, [resp, &io_service](const Status &stat) {
+  engine.AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
+    complete = true;
     io_service.stop();
     ASSERT_FALSE(stat.ok());
   });
 
   ::asio::deadline_timer timer(io_service);
-  timer.expires_from_now(std::chrono::milliseconds(options.rpc_timeout * 2));
-  timer.async_wait(std::bind(&RpcConnection::Start, conn));
+  timer.expires_from_now(std::chrono::hours(100));
+  timer.async_wait([](const asio::error_code &err) { (void)err; ASSERT_FALSE("Timed out"); });
+
   io_service.run();
+  ASSERT_TRUE(complete);
 }
 
 int main(int argc, char *argv[]) {