rpc_connection.cc 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "rpc_engine.h"
  19. #include "sasl_protocol.h"
  20. #include "RpcHeader.pb.h"
  21. #include "ProtobufRpcEngine.pb.h"
  22. #include "IpcConnectionContext.pb.h"
  23. #include "common/logging.h"
  24. #include "common/util.h"
  25. #include <asio/read.hpp>
  26. namespace hdfs {
  27. namespace pb = ::google::protobuf;
  28. namespace pbio = ::google::protobuf::io;
  29. using namespace ::hadoop::common;
  30. using namespace ::std::placeholders;
  // Sentinel retry count: when a request carries this value, SetRequestHeader
  // leaves the retry-count field of the RPC header unset.
  static constexpr int kNoRetry = -1;
  32. static void AddHeadersToPacket(
  33. std::string *res, std::initializer_list<const pb::MessageLite *> headers,
  34. const std::string *payload) {
  35. int len = 0;
  36. std::for_each(
  37. headers.begin(), headers.end(),
  38. [&len](const pb::MessageLite *v) { len += DelimitedPBMessageSize(v); });
  39. if (payload) {
  40. len += payload->size();
  41. }
  42. int net_len = htonl(len);
  43. res->reserve(res->size() + sizeof(net_len) + len);
  44. pbio::StringOutputStream ss(res);
  45. pbio::CodedOutputStream os(&ss);
  46. os.WriteRaw(reinterpret_cast<const char *>(&net_len), sizeof(net_len));
  47. uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
  48. assert(buf);
  49. std::for_each(
  50. headers.begin(), headers.end(), [&buf](const pb::MessageLite *v) {
  51. buf = pbio::CodedOutputStream::WriteVarint32ToArray(v->ByteSize(), buf);
  52. buf = v->SerializeWithCachedSizesToArray(buf);
  53. });
  54. if (payload) {
  55. buf = os.WriteStringToArray(*payload, buf);
  56. }
  57. }
  58. static void ConstructPayload(std::string *res, const pb::MessageLite *header) {
  59. int len = DelimitedPBMessageSize(header);
  60. res->reserve(len);
  61. pbio::StringOutputStream ss(res);
  62. pbio::CodedOutputStream os(&ss);
  63. uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
  64. assert(buf);
  65. buf = pbio::CodedOutputStream::WriteVarint32ToArray(header->ByteSize(), buf);
  66. buf = header->SerializeWithCachedSizesToArray(buf);
  67. }
  68. static void ConstructPayload(std::string *res, const std::string *request) {
  69. int len =
  70. pbio::CodedOutputStream::VarintSize32(request->size()) + request->size();
  71. res->reserve(len);
  72. pbio::StringOutputStream ss(res);
  73. pbio::CodedOutputStream os(&ss);
  74. uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
  75. assert(buf);
  76. buf = pbio::CodedOutputStream::WriteVarint32ToArray(request->size(), buf);
  77. buf = os.WriteStringToArray(*request, buf);
  78. }
  79. static void SetRequestHeader(LockFreeRpcEngine *engine, int call_id,
  80. const std::string &method_name, int retry_count,
  81. RpcRequestHeaderProto *rpc_header,
  82. RequestHeaderProto *req_header) {
  83. rpc_header->set_rpckind(RPC_PROTOCOL_BUFFER);
  84. rpc_header->set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
  85. rpc_header->set_callid(call_id);
  86. if (retry_count != kNoRetry)
  87. rpc_header->set_retrycount(retry_count);
  88. rpc_header->set_clientid(engine->client_id());
  89. req_header->set_methodname(method_name);
  90. req_header->set_declaringclassprotocolname(engine->protocol_name());
  91. req_header->set_clientprotocolversion(engine->protocol_version());
  92. }
  93. RpcConnection::~RpcConnection() {}
  94. Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
  95. const std::string &request, Handler &&handler)
  96. : engine_(engine),
  97. method_name_(method_name),
  98. call_id_(call_id),
  99. timer_(engine->io_service()),
  100. handler_(std::move(handler)),
  101. retry_count_(engine->retry_policy() ? 0 : kNoRetry),
  102. failover_count_(0) {
  103. ConstructPayload(&payload_, &request);
  104. }
  105. Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
  106. const pb::MessageLite *request, Handler &&handler)
  107. : engine_(engine),
  108. method_name_(method_name),
  109. call_id_(call_id),
  110. timer_(engine->io_service()),
  111. handler_(std::move(handler)),
  112. retry_count_(engine->retry_policy() ? 0 : kNoRetry),
  113. failover_count_(0) {
  114. ConstructPayload(&payload_, request);
  115. }
  116. Request::Request(LockFreeRpcEngine *engine, Handler &&handler)
  117. : engine_(engine),
  118. call_id_(-1),
  119. timer_(engine->io_service()),
  120. handler_(std::move(handler)),
  121. retry_count_(engine->retry_policy() ? 0 : kNoRetry),
  122. failover_count_(0) {
  123. }
  124. void Request::GetPacket(std::string *res) const {
  125. LOG_TRACE(kRPC, << "Request::GetPacket called");
  126. if (payload_.empty())
  127. return;
  128. RpcRequestHeaderProto rpc_header;
  129. RequestHeaderProto req_header;
  130. SetRequestHeader(engine_, call_id_, method_name_, retry_count_, &rpc_header,
  131. &req_header);
  132. // SASL messages don't have a request header
  133. if (method_name_ != SASL_METHOD_NAME)
  134. AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
  135. else
  136. AddHeadersToPacket(res, {&rpc_header}, &payload_);
  137. }
  138. void Request::OnResponseArrived(pbio::CodedInputStream *is,
  139. const Status &status) {
  140. LOG_TRACE(kRPC, << "Request::OnResponseArrived called");
  141. handler_(is, status);
  142. }
  143. std::string Request::GetDebugString() const {
  144. // Basic description of this object, aimed at debugging
  145. std::stringstream ss;
  146. ss << "\nRequest Object:\n";
  147. ss << "\tMethod name = \"" << method_name_ << "\"\n";
  148. ss << "\tCall id = " << call_id_ << "\n";
  149. ss << "\tRetry Count = " << retry_count_ << "\n";
  150. ss << "\tFailover count = " << failover_count_ << "\n";
  151. return ss.str();
  152. }
  153. int Request::IncrementFailoverCount() {
  154. // reset retry count when failing over
  155. retry_count_ = 0;
  156. return failover_count_++;
  157. }
  158. RpcConnection::RpcConnection(LockFreeRpcEngine *engine)
  159. : engine_(engine),
  160. connected_(kNotYetConnected) {}
  161. ::asio::io_service &RpcConnection::io_service() {
  162. return engine_->io_service();
  163. }
  164. void RpcConnection::StartReading() {
  165. auto shared_this = shared_from_this();
  166. io_service().post([shared_this, this] () {
  167. OnRecvCompleted(::asio::error_code(), 0);
  168. });
  169. }
  170. void RpcConnection::HandshakeComplete(const Status &s) {
  171. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  172. LOG_TRACE(kRPC, << "RpcConnectionImpl::HandshakeComplete called");
  173. if (s.ok()) {
  174. if (connected_ == kHandshaking) {
  175. auto shared_this = shared_from_this();
  176. connected_ = kAuthenticating;
  177. if (auth_info_.useSASL()) {
  178. #ifdef USE_SASL
  179. sasl_protocol_ = std::make_shared<SaslProtocol>(cluster_name_, auth_info_, shared_from_this());
  180. sasl_protocol_->SetEventHandlers(event_handlers_);
  181. sasl_protocol_->Authenticate([shared_this, this](
  182. const Status & status, const AuthInfo & new_auth_info) {
  183. AuthComplete(status, new_auth_info); } );
  184. #else
  185. AuthComplete_locked(Status::Error("SASL is required, but no SASL library was found"), auth_info_);
  186. #endif
  187. } else {
  188. AuthComplete_locked(Status::OK(), auth_info_);
  189. }
  190. }
  191. } else {
  192. CommsError(s);
  193. };
  194. }
  195. void RpcConnection::AuthComplete(const Status &s, const AuthInfo & new_auth_info) {
  196. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  197. AuthComplete_locked(s, new_auth_info);
  198. }
  199. void RpcConnection::AuthComplete_locked(const Status &s, const AuthInfo & new_auth_info) {
  200. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  201. LOG_TRACE(kRPC, << "RpcConnectionImpl::AuthComplete called");
  202. // Free the sasl_protocol object
  203. sasl_protocol_.reset();
  204. if (s.ok()) {
  205. auth_info_ = new_auth_info;
  206. auto shared_this = shared_from_this();
  207. SendContext([shared_this, this](const Status & s) {
  208. ContextComplete(s);
  209. });
  210. } else {
  211. CommsError(s);
  212. };
  213. }
  214. void RpcConnection::ContextComplete(const Status &s) {
  215. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  216. LOG_TRACE(kRPC, << "RpcConnectionImpl::ContextComplete called");
  217. if (s.ok()) {
  218. if (connected_ == kAuthenticating) {
  219. connected_ = kConnected;
  220. }
  221. FlushPendingRequests();
  222. } else {
  223. CommsError(s);
  224. };
  225. }
  226. void RpcConnection::AsyncFlushPendingRequests() {
  227. std::shared_ptr<RpcConnection> shared_this = shared_from_this();
  228. io_service().post([shared_this, this]() {
  229. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  230. LOG_TRACE(kRPC, << "RpcConnection::AsyncFlushPendingRequests called (connected=" << ToString(connected_) << ")");
  231. if (!request_over_the_wire_) {
  232. FlushPendingRequests();
  233. }
  234. });
  235. }
  236. Status RpcConnection::HandleRpcResponse(std::shared_ptr<Response> response) {
  237. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  238. response->ar.reset(new pbio::ArrayInputStream(&response->data_[0], response->data_.size()));
  239. response->in.reset(new pbio::CodedInputStream(response->ar.get()));
  240. response->in->PushLimit(response->data_.size());
  241. RpcResponseHeaderProto h;
  242. ReadDelimitedPBMessage(response->in.get(), &h);
  243. auto req = RemoveFromRunningQueue(h.callid());
  244. if (!req) {
  245. LOG_WARN(kRPC, << "RPC response with Unknown call id " << h.callid());
  246. if((int32_t)h.callid() == RpcEngine::kCallIdSasl) {
  247. return Status::AuthenticationFailed("You have an unsecured client connecting to a secured server");
  248. } else {
  249. return Status::Error("Rpc response with unknown call id");
  250. }
  251. }
  252. Status status;
  253. if(event_handlers_) {
  254. event_response event_resp = event_handlers_->call(FS_NN_READ_EVENT, cluster_name_.c_str(), 0);
  255. #ifndef LIBHDFSPP_SIMULATE_ERROR_DISABLED
  256. if (event_resp.response() == event_response::kTest_Error) {
  257. status = event_resp.status();
  258. }
  259. #endif
  260. }
  261. if (status.ok() && h.has_exceptionclassname()) {
  262. status =
  263. Status::Exception(h.exceptionclassname().c_str(), h.errormsg().c_str());
  264. }
  265. if(status.get_server_exception_type() == Status::kStandbyException) {
  266. LOG_WARN(kRPC, << "Tried to connect to standby. status = " << status.ToString());
  267. // We got the request back, but it needs to be resent to the other NN
  268. std::vector<std::shared_ptr<Request>> reqs_to_redirect = {req};
  269. PrependRequests_locked(reqs_to_redirect);
  270. CommsError(status);
  271. return status;
  272. }
  273. io_service().post([req, response, status]() {
  274. req->OnResponseArrived(response->in.get(), status); // Never call back while holding a lock
  275. });
  276. return Status::OK();
  277. }
  278. void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
  279. const ::asio::error_code &ec) {
  280. if (ec.value() == asio::error::operation_aborted) {
  281. return;
  282. }
  283. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  284. auto r = RemoveFromRunningQueue(req->call_id());
  285. if (!r) {
  286. // The RPC might have been finished and removed from the queue
  287. return;
  288. }
  289. Status stat = ToStatus(ec ? ec : make_error_code(::asio::error::timed_out));
  290. r->OnResponseArrived(nullptr, stat);
  291. }
  292. std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
  293. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  294. /** From Client.java:
  295. *
  296. * Write the connection header - this is sent when connection is established
  297. * +----------------------------------+
  298. * | "hrpc" 4 bytes |
  299. * +----------------------------------+
  300. * | Version (1 byte) |
  301. * +----------------------------------+
  302. * | Service Class (1 byte) |
  303. * +----------------------------------+
  304. * | AuthProtocol (1 byte) |
  305. * +----------------------------------+
  306. *
  307. * AuthProtocol: 0->none, -33->SASL
  308. */
  309. char auth_protocol = auth_info_.useSASL() ? -33 : 0;
  310. const char handshake_header[] = {'h', 'r', 'p', 'c',
  311. RpcEngine::kRpcVersion, 0, auth_protocol};
  312. auto res =
  313. std::make_shared<std::string>(handshake_header, sizeof(handshake_header));
  314. return res;
  315. }
  316. std::shared_ptr<std::string> RpcConnection::PrepareContextPacket() {
  317. // This needs to be send after the SASL handshake, and
  318. // after the SASL handshake (if any)
  319. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  320. auto res = std::make_shared<std::string>();
  321. RpcRequestHeaderProto h;
  322. h.set_rpckind(RPC_PROTOCOL_BUFFER);
  323. h.set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
  324. h.set_callid(RpcEngine::kCallIdConnectionContext);
  325. h.set_clientid(engine_->client_name());
  326. IpcConnectionContextProto handshake;
  327. handshake.set_protocol(engine_->protocol_name());
  328. const std::string & user_name = auth_info_.getUser();
  329. if (!user_name.empty()) {
  330. *handshake.mutable_userinfo()->mutable_effectiveuser() = user_name;
  331. }
  332. AddHeadersToPacket(res.get(), {&h, &handshake}, nullptr);
  333. return res;
  334. }
  335. void RpcConnection::AsyncRpc(
  336. const std::string &method_name, const ::google::protobuf::MessageLite *req,
  337. std::shared_ptr<::google::protobuf::MessageLite> resp,
  338. const RpcCallback &handler) {
  339. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  340. AsyncRpc_locked(method_name, req, resp, handler);
  341. }
  342. void RpcConnection::AsyncRpc_locked(
  343. const std::string &method_name, const ::google::protobuf::MessageLite *req,
  344. std::shared_ptr<::google::protobuf::MessageLite> resp,
  345. const RpcCallback &handler) {
  346. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  347. auto wrapped_handler =
  348. [resp, handler](pbio::CodedInputStream *is, const Status &status) {
  349. if (status.ok()) {
  350. if (is) { // Connect messages will not have an is
  351. ReadDelimitedPBMessage(is, resp.get());
  352. }
  353. }
  354. handler(status);
  355. };
  356. int call_id = (method_name != SASL_METHOD_NAME ? engine_->NextCallId() : RpcEngine::kCallIdSasl);
  357. auto r = std::make_shared<Request>(engine_, method_name, call_id, req,
  358. std::move(wrapped_handler));
  359. auto r_vector = std::vector<std::shared_ptr<Request> > (1, r);
  360. SendRpcRequests(r_vector);
  361. }
  362. void RpcConnection::AsyncRpc(const std::vector<std::shared_ptr<Request> > & requests) {
  363. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  364. SendRpcRequests(requests);
  365. }
  366. void RpcConnection::SendRpcRequests(const std::vector<std::shared_ptr<Request> > & requests) {
  367. LOG_TRACE(kRPC, << "RpcConnection::SendRpcRequests[] called; connected=" << ToString(connected_));
  368. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  369. if (connected_ == kDisconnected) {
  370. // Oops. The connection failed _just_ before the engine got a chance
  371. // to send it. Register it as a failure
  372. Status status = Status::ResourceUnavailable("RpcConnection closed before send.");
  373. engine_->AsyncRpcCommsError(status, shared_from_this(), requests);
  374. } else {
  375. for (auto r: requests) {
  376. if (r->method_name() != SASL_METHOD_NAME)
  377. pending_requests_.push_back(r);
  378. else
  379. auth_requests_.push_back(r);
  380. }
  381. if (connected_ == kConnected || connected_ == kHandshaking || connected_ == kAuthenticating) { // Dont flush if we're waiting or handshaking
  382. FlushPendingRequests();
  383. }
  384. }
  385. }
  386. void RpcConnection::PreEnqueueRequests(
  387. std::vector<std::shared_ptr<Request>> requests) {
  388. // Public method - acquire lock
  389. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  390. LOG_DEBUG(kRPC, << "RpcConnection::PreEnqueueRequests called");
  391. assert(connected_ == kNotYetConnected);
  392. pending_requests_.insert(pending_requests_.end(), requests.begin(),
  393. requests.end());
  394. // Don't start sending yet; will flush when connected
  395. }
  396. // Only call when already holding conn state lock
  397. void RpcConnection::PrependRequests_locked( std::vector<std::shared_ptr<Request>> requests) {
  398. LOG_DEBUG(kRPC, << "RpcConnection::PrependRequests called");
  399. pending_requests_.insert(pending_requests_.begin(), requests.begin(),
  400. requests.end());
  401. // Don't start sending yet; will flush when connected
  402. }
  403. void RpcConnection::SetEventHandlers(std::shared_ptr<LibhdfsEvents> event_handlers) {
  404. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  405. event_handlers_ = event_handlers;
  406. if (sasl_protocol_) {
  407. sasl_protocol_->SetEventHandlers(event_handlers);
  408. }
  409. }
  410. void RpcConnection::SetClusterName(std::string cluster_name) {
  411. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  412. cluster_name_ = cluster_name;
  413. }
  414. void RpcConnection::CommsError(const Status &status) {
  415. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  416. LOG_DEBUG(kRPC, << "RpcConnection::CommsError called");
  417. Disconnect();
  418. // Anything that has been queued to the connection (on the fly or pending)
  419. // will get dinged for a retry
  420. std::vector<std::shared_ptr<Request>> requestsToReturn;
  421. std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
  422. std::back_inserter(requestsToReturn),
  423. std::bind(&RequestOnFlyMap::value_type::second, _1));
  424. requests_on_fly_.clear();
  425. requestsToReturn.insert(requestsToReturn.end(),
  426. std::make_move_iterator(pending_requests_.begin()),
  427. std::make_move_iterator(pending_requests_.end()));
  428. pending_requests_.clear();
  429. engine_->AsyncRpcCommsError(status, shared_from_this(), requestsToReturn);
  430. }
  431. void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
  432. Disconnect();
  433. std::vector<std::shared_ptr<Request>> requests;
  434. std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
  435. std::back_inserter(requests),
  436. std::bind(&RequestOnFlyMap::value_type::second, _1));
  437. requests_on_fly_.clear();
  438. requests.insert(requests.end(),
  439. std::make_move_iterator(pending_requests_.begin()),
  440. std::make_move_iterator(pending_requests_.end()));
  441. pending_requests_.clear();
  442. for (const auto &req : requests) {
  443. req->OnResponseArrived(nullptr, ToStatus(ec));
  444. }
  445. }
  446. std::shared_ptr<Request> RpcConnection::RemoveFromRunningQueue(int call_id) {
  447. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  448. auto it = requests_on_fly_.find(call_id);
  449. if (it == requests_on_fly_.end()) {
  450. return std::shared_ptr<Request>();
  451. }
  452. auto req = it->second;
  453. requests_on_fly_.erase(it);
  454. return req;
  455. }
  456. std::string RpcConnection::ToString(ConnectedState connected) {
  457. switch(connected) {
  458. case kNotYetConnected: return "NotYetConnected";
  459. case kConnecting: return "Connecting";
  460. case kHandshaking: return "Handshaking";
  461. case kAuthenticating: return "Authenticating";
  462. case kConnected: return "Connected";
  463. case kDisconnected: return "Disconnected";
  464. default: return "Invalid ConnectedState";
  465. }
  466. }
  467. }