1
0

rpc_connection.cc 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "rpc_engine.h"
  19. #include "RpcHeader.pb.h"
  20. #include "ProtobufRpcEngine.pb.h"
  21. #include "IpcConnectionContext.pb.h"
  22. #include "common/logging.h"
  23. #include "common/util.h"
  24. #include <asio/read.hpp>
  25. namespace hdfs {
  26. namespace pb = ::google::protobuf;
  27. namespace pbio = ::google::protobuf::io;
  28. using namespace ::hadoop::common;
  29. using namespace ::std::placeholders;
  30. static const int kNoRetry = -1;
  31. static void AddHeadersToPacket(
  32. std::string *res, std::initializer_list<const pb::MessageLite *> headers,
  33. const std::string *payload) {
  34. int len = 0;
  35. std::for_each(
  36. headers.begin(), headers.end(),
  37. [&len](const pb::MessageLite *v) { len += DelimitedPBMessageSize(v); });
  38. if (payload) {
  39. len += payload->size();
  40. }
  41. int net_len = htonl(len);
  42. res->reserve(res->size() + sizeof(net_len) + len);
  43. pbio::StringOutputStream ss(res);
  44. pbio::CodedOutputStream os(&ss);
  45. os.WriteRaw(reinterpret_cast<const char *>(&net_len), sizeof(net_len));
  46. uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
  47. assert(buf);
  48. std::for_each(
  49. headers.begin(), headers.end(), [&buf](const pb::MessageLite *v) {
  50. buf = pbio::CodedOutputStream::WriteVarint32ToArray(v->ByteSize(), buf);
  51. buf = v->SerializeWithCachedSizesToArray(buf);
  52. });
  53. if (payload) {
  54. buf = os.WriteStringToArray(*payload, buf);
  55. }
  56. }
  57. static void ConstructPayload(std::string *res, const pb::MessageLite *header) {
  58. int len = DelimitedPBMessageSize(header);
  59. res->reserve(len);
  60. pbio::StringOutputStream ss(res);
  61. pbio::CodedOutputStream os(&ss);
  62. uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
  63. assert(buf);
  64. buf = pbio::CodedOutputStream::WriteVarint32ToArray(header->ByteSize(), buf);
  65. buf = header->SerializeWithCachedSizesToArray(buf);
  66. }
  67. static void ConstructPayload(std::string *res, const std::string *request) {
  68. int len =
  69. pbio::CodedOutputStream::VarintSize32(request->size()) + request->size();
  70. res->reserve(len);
  71. pbio::StringOutputStream ss(res);
  72. pbio::CodedOutputStream os(&ss);
  73. uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
  74. assert(buf);
  75. buf = pbio::CodedOutputStream::WriteVarint32ToArray(request->size(), buf);
  76. buf = os.WriteStringToArray(*request, buf);
  77. }
  78. static void SetRequestHeader(LockFreeRpcEngine *engine, int call_id,
  79. const std::string &method_name, int retry_count,
  80. RpcRequestHeaderProto *rpc_header,
  81. RequestHeaderProto *req_header) {
  82. rpc_header->set_rpckind(RPC_PROTOCOL_BUFFER);
  83. rpc_header->set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
  84. rpc_header->set_callid(call_id);
  85. if (retry_count != kNoRetry)
  86. rpc_header->set_retrycount(retry_count);
  87. rpc_header->set_clientid(engine->client_name());
  88. req_header->set_methodname(method_name);
  89. req_header->set_declaringclassprotocolname(engine->protocol_name());
  90. req_header->set_clientprotocolversion(engine->protocol_version());
  91. }
  92. RpcConnection::~RpcConnection() {}
  93. Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
  94. const std::string &request, Handler &&handler)
  95. : engine_(engine),
  96. method_name_(method_name),
  97. call_id_(engine->NextCallId()),
  98. timer_(engine->io_service()),
  99. handler_(std::move(handler)),
  100. retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
  101. ConstructPayload(&payload_, &request);
  102. }
  103. Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
  104. const pb::MessageLite *request, Handler &&handler)
  105. : engine_(engine),
  106. method_name_(method_name),
  107. call_id_(engine->NextCallId()),
  108. timer_(engine->io_service()),
  109. handler_(std::move(handler)),
  110. retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
  111. ConstructPayload(&payload_, request);
  112. }
  113. Request::Request(LockFreeRpcEngine *engine, Handler &&handler)
  114. : engine_(engine),
  115. call_id_(-1),
  116. timer_(engine->io_service()),
  117. handler_(std::move(handler)),
  118. retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
  119. }
  120. void Request::GetPacket(std::string *res) const {
  121. if (payload_.empty())
  122. return;
  123. RpcRequestHeaderProto rpc_header;
  124. RequestHeaderProto req_header;
  125. SetRequestHeader(engine_, call_id_, method_name_, retry_count_, &rpc_header,
  126. &req_header);
  127. AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
  128. }
  129. void Request::OnResponseArrived(pbio::CodedInputStream *is,
  130. const Status &status) {
  131. handler_(is, status);
  132. }
  133. RpcConnection::RpcConnection(LockFreeRpcEngine *engine)
  134. : engine_(engine),
  135. connected_(false) {}
  136. ::asio::io_service &RpcConnection::io_service() {
  137. return engine_->io_service();
  138. }
  139. void RpcConnection::StartReading() {
  140. io_service().post(std::bind(&RpcConnection::OnRecvCompleted, this,
  141. ::asio::error_code(), 0));
  142. }
  143. void RpcConnection::AsyncFlushPendingRequests() {
  144. std::shared_ptr<RpcConnection> shared_this = shared_from_this();
  145. io_service().post([shared_this, this]() {
  146. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  147. if (!request_over_the_wire_) {
  148. FlushPendingRequests();
  149. }
  150. });
  151. }
  152. void RpcConnection::HandleRpcResponse(std::shared_ptr<Response> response) {
  153. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  154. response->ar.reset(new pbio::ArrayInputStream(&response->data_[0], response->data_.size()));
  155. response->in.reset(new pbio::CodedInputStream(response->ar.get()));
  156. response->in->PushLimit(response->data_.size());
  157. RpcResponseHeaderProto h;
  158. ReadDelimitedPBMessage(response->in.get(), &h);
  159. auto req = RemoveFromRunningQueue(h.callid());
  160. if (!req) {
  161. LOG_WARN(kRPC, << "RPC response with Unknown call id " << h.callid());
  162. return;
  163. }
  164. Status status;
  165. if (h.has_exceptionclassname()) {
  166. status =
  167. Status::Exception(h.exceptionclassname().c_str(), h.errormsg().c_str());
  168. }
  169. io_service().post([req, response, status]() {
  170. req->OnResponseArrived(response->in.get(), status); // Never call back while holding a lock
  171. });
  172. }
  173. void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
  174. const ::asio::error_code &ec) {
  175. if (ec.value() == asio::error::operation_aborted) {
  176. return;
  177. }
  178. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  179. auto r = RemoveFromRunningQueue(req->call_id());
  180. if (!r) {
  181. // The RPC might have been finished and removed from the queue
  182. return;
  183. }
  184. Status stat = ToStatus(ec ? ec : make_error_code(::asio::error::timed_out));
  185. r->OnResponseArrived(nullptr, stat);
  186. }
  187. std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
  188. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  189. /** From Client.java:
  190. *
  191. * Write the connection header - this is sent when connection is established
  192. * +----------------------------------+
  193. * | "hrpc" 4 bytes |
  194. * +----------------------------------+
  195. * | Version (1 byte) |
  196. * +----------------------------------+
  197. * | Service Class (1 byte) |
  198. * +----------------------------------+
  199. * | AuthProtocol (1 byte) |
  200. * +----------------------------------+
  201. *
  202. * AuthProtocol: 0->none, -33->SASL
  203. */
  204. static const char kHandshakeHeader[] = {'h', 'r', 'p', 'c',
  205. RpcEngine::kRpcVersion, 0, 0};
  206. auto res =
  207. std::make_shared<std::string>(kHandshakeHeader, sizeof(kHandshakeHeader));
  208. RpcRequestHeaderProto h;
  209. h.set_rpckind(RPC_PROTOCOL_BUFFER);
  210. h.set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
  211. h.set_callid(RpcEngine::kCallIdConnectionContext);
  212. h.set_clientid(engine_->client_name());
  213. IpcConnectionContextProto handshake;
  214. handshake.set_protocol(engine_->protocol_name());
  215. const std::string & user_name = engine()->user_name();
  216. if (!user_name.empty()) {
  217. *handshake.mutable_userinfo()->mutable_effectiveuser() = user_name;
  218. }
  219. AddHeadersToPacket(res.get(), {&h, &handshake}, nullptr);
  220. return res;
  221. }
  222. void RpcConnection::AsyncRpc(
  223. const std::string &method_name, const ::google::protobuf::MessageLite *req,
  224. std::shared_ptr<::google::protobuf::MessageLite> resp,
  225. const RpcCallback &handler) {
  226. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  227. auto wrapped_handler =
  228. [resp, handler](pbio::CodedInputStream *is, const Status &status) {
  229. if (status.ok()) {
  230. if (is) { // Connect messages will not have an is
  231. ReadDelimitedPBMessage(is, resp.get());
  232. }
  233. }
  234. handler(status);
  235. };
  236. auto r = std::make_shared<Request>(engine_, method_name, req,
  237. std::move(wrapped_handler));
  238. pending_requests_.push_back(r);
  239. FlushPendingRequests();
  240. }
  241. void RpcConnection::AsyncRawRpc(const std::string &method_name,
  242. const std::string &req,
  243. std::shared_ptr<std::string> resp,
  244. RpcCallback &&handler) {
  245. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  246. std::shared_ptr<RpcConnection> shared_this = shared_from_this();
  247. auto wrapped_handler = [shared_this, this, resp, handler](
  248. pbio::CodedInputStream *is, const Status &status) {
  249. if (status.ok()) {
  250. uint32_t size = 0;
  251. is->ReadVarint32(&size);
  252. auto limit = is->PushLimit(size);
  253. is->ReadString(resp.get(), limit);
  254. is->PopLimit(limit);
  255. }
  256. handler(status);
  257. };
  258. auto r = std::make_shared<Request>(engine_, method_name, req,
  259. std::move(wrapped_handler));
  260. pending_requests_.push_back(r);
  261. FlushPendingRequests();
  262. }
  263. void RpcConnection::PreEnqueueRequests(
  264. std::vector<std::shared_ptr<Request>> requests) {
  265. // Public method - acquire lock
  266. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  267. assert(!connected_);
  268. pending_requests_.insert(pending_requests_.end(), requests.begin(),
  269. requests.end());
  270. // Don't start sending yet; will flush when connected
  271. }
  272. void RpcConnection::SetEventHandlers(std::shared_ptr<LibhdfsEvents> event_handlers) {
  273. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  274. event_handlers_ = event_handlers;
  275. }
  276. void RpcConnection::SetClusterName(std::string cluster_name) {
  277. std::lock_guard<std::mutex> state_lock(connection_state_lock_);
  278. cluster_name_ = cluster_name;
  279. }
  280. void RpcConnection::CommsError(const Status &status) {
  281. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  282. Disconnect();
  283. // Anything that has been queued to the connection (on the fly or pending)
  284. // will get dinged for a retry
  285. std::vector<std::shared_ptr<Request>> requestsToReturn;
  286. std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
  287. std::back_inserter(requestsToReturn),
  288. std::bind(&RequestOnFlyMap::value_type::second, _1));
  289. requests_on_fly_.clear();
  290. requestsToReturn.insert(requestsToReturn.end(),
  291. std::make_move_iterator(pending_requests_.begin()),
  292. std::make_move_iterator(pending_requests_.end()));
  293. pending_requests_.clear();
  294. engine_->AsyncRpcCommsError(status, requestsToReturn);
  295. }
  296. void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
  297. Disconnect();
  298. std::vector<std::shared_ptr<Request>> requests;
  299. std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
  300. std::back_inserter(requests),
  301. std::bind(&RequestOnFlyMap::value_type::second, _1));
  302. requests_on_fly_.clear();
  303. requests.insert(requests.end(),
  304. std::make_move_iterator(pending_requests_.begin()),
  305. std::make_move_iterator(pending_requests_.end()));
  306. pending_requests_.clear();
  307. for (const auto &req : requests) {
  308. req->OnResponseArrived(nullptr, ToStatus(ec));
  309. }
  310. }
  311. std::shared_ptr<Request> RpcConnection::RemoveFromRunningQueue(int call_id) {
  312. assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
  313. auto it = requests_on_fly_.find(call_id);
  314. if (it == requests_on_fly_.end()) {
  315. return std::shared_ptr<Request>();
  316. }
  317. auto req = it->second;
  318. requests_on_fly_.erase(it);
  319. return req;
  320. }
  321. }