|
@@ -16,6 +16,7 @@
|
|
|
* limitations under the License.
|
|
|
*/
|
|
|
#include "rpc_engine.h"
|
|
|
+#include "sasl_protocol.h"
|
|
|
|
|
|
#include "RpcHeader.pb.h"
|
|
|
#include "ProtobufRpcEngine.pb.h"
|
|
@@ -110,22 +111,22 @@ static void SetRequestHeader(LockFreeRpcEngine *engine, int call_id,
|
|
|
|
|
|
RpcConnection::~RpcConnection() {}
|
|
|
|
|
|
-Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
|
|
|
+Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
|
|
|
const std::string &request, Handler &&handler)
|
|
|
: engine_(engine),
|
|
|
method_name_(method_name),
|
|
|
- call_id_(engine->NextCallId()),
|
|
|
+ call_id_(call_id),
|
|
|
timer_(engine->io_service()),
|
|
|
handler_(std::move(handler)),
|
|
|
retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
|
|
|
ConstructPayload(&payload_, &request);
|
|
|
}
|
|
|
|
|
|
-Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
|
|
|
+Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
|
|
|
const pb::MessageLite *request, Handler &&handler)
|
|
|
: engine_(engine),
|
|
|
method_name_(method_name),
|
|
|
- call_id_(engine->NextCallId()),
|
|
|
+ call_id_(call_id),
|
|
|
timer_(engine->io_service()),
|
|
|
handler_(std::move(handler)),
|
|
|
retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
|
|
@@ -148,7 +149,12 @@ void Request::GetPacket(std::string *res) const {
|
|
|
RequestHeaderProto req_header;
|
|
|
SetRequestHeader(engine_, call_id_, method_name_, retry_count_, &rpc_header,
|
|
|
&req_header);
|
|
|
- AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
|
|
|
+
|
|
|
+ // SASL messages don't have a request header
|
|
|
+ if (method_name_ != SASL_METHOD_NAME)
|
|
|
+ AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
|
|
|
+ else
|
|
|
+ AddHeadersToPacket(res, {&rpc_header}, &payload_);
|
|
|
}
|
|
|
|
|
|
void Request::OnResponseArrived(pbio::CodedInputStream *is,
|
|
@@ -171,12 +177,80 @@ void RpcConnection::StartReading() {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
+void RpcConnection::HandshakeComplete(const Status &s) {
|
|
|
+ std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
|
|
+
|
|
|
+ LOG_TRACE(kRPC, << "RpcConnectionImpl::HandshakeComplete called");
|
|
|
+
|
|
|
+ if (s.ok()) {
|
|
|
+ if (connected_ == kConnecting) {
|
|
|
+ auto shared_this = shared_from_this();
|
|
|
+
|
|
|
+ connected_ = kAuthenticating;
|
|
|
+ if (auth_info_.useSASL()) {
|
|
|
+#ifdef USE_SASL
|
|
|
+ sasl_protocol_ = std::make_shared<SaslProtocol>(cluster_name_, auth_info_, shared_from_this());
|
|
|
+ sasl_protocol_->SetEventHandlers(event_handlers_);
|
|
|
+ sasl_protocol_->authenticate([shared_this, this](
|
|
|
+ const Status & status, const AuthInfo & new_auth_info) {
|
|
|
+ AuthComplete(status, new_auth_info); } );
|
|
|
+#else
|
|
|
+ AuthComplete_locked(Status::Error("SASL is required, but no SASL library was found"), auth_info_);
|
|
|
+#endif
|
|
|
+ } else {
|
|
|
+ AuthComplete_locked(Status::OK(), auth_info_);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ CommsError(s);
|
|
|
+ };
|
|
|
+}
|
|
|
+
|
|
|
+void RpcConnection::AuthComplete(const Status &s, const AuthInfo & new_auth_info) {
|
|
|
+ std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
|
|
+ AuthComplete_locked(s, new_auth_info);
|
|
|
+}
|
|
|
+
|
|
|
+void RpcConnection::AuthComplete_locked(const Status &s, const AuthInfo & new_auth_info) {
|
|
|
+ assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
|
|
+ LOG_TRACE(kRPC, << "RpcConnectionImpl::AuthComplete called");
|
|
|
+
|
|
|
+ // Free the sasl_protocol object
|
|
|
+ sasl_protocol_.reset();
|
|
|
+
|
|
|
+ if (s.ok()) {
|
|
|
+ auth_info_ = new_auth_info;
|
|
|
+
|
|
|
+ auto shared_this = shared_from_this();
|
|
|
+ SendContext([shared_this, this](const Status & s) {
|
|
|
+ ContextComplete(s);
|
|
|
+ });
|
|
|
+ } else {
|
|
|
+ CommsError(s);
|
|
|
+ };
|
|
|
+}
|
|
|
+
|
|
|
+void RpcConnection::ContextComplete(const Status &s) {
|
|
|
+ std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
|
|
+
|
|
|
+ LOG_TRACE(kRPC, << "RpcConnectionImpl::ContextComplete called");
|
|
|
+
|
|
|
+ if (s.ok()) {
|
|
|
+ if (connected_ == kAuthenticating) {
|
|
|
+ connected_ = kConnected;
|
|
|
+ }
|
|
|
+ FlushPendingRequests();
|
|
|
+ } else {
|
|
|
+ CommsError(s);
|
|
|
+ };
|
|
|
+}
|
|
|
+
|
|
|
void RpcConnection::AsyncFlushPendingRequests() {
|
|
|
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
|
|
|
io_service().post([shared_this, this]() {
|
|
|
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
|
|
|
|
|
- LOG_TRACE(kRPC, << "RpcConnection::AsyncRpc called (connected=" << ToString(connected_) << ")");
|
|
|
+ LOG_TRACE(kRPC, << "RpcConnection::AsyncFlushPendingRequests called (connected=" << ToString(connected_) << ")");
|
|
|
|
|
|
if (!request_over_the_wire_) {
|
|
|
FlushPendingRequests();
|
|
@@ -246,10 +320,22 @@ std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
|
|
|
*
|
|
|
* AuthProtocol: 0->none, -33->SASL
|
|
|
*/
|
|
|
- static const char kHandshakeHeader[] = {'h', 'r', 'p', 'c',
|
|
|
- RpcEngine::kRpcVersion, 0, 0};
|
|
|
+
|
|
|
+ char auth_protocol = auth_info_.useSASL() ? -33 : 0;
|
|
|
+ const char handshake_header[] = {'h', 'r', 'p', 'c',
|
|
|
+ RpcEngine::kRpcVersion, 0, auth_protocol};
|
|
|
auto res =
|
|
|
- std::make_shared<std::string>(kHandshakeHeader, sizeof(kHandshakeHeader));
|
|
|
+ std::make_shared<std::string>(handshake_header, sizeof(handshake_header));
|
|
|
+
|
|
|
+ return res;
|
|
|
+}
|
|
|
+
|
|
|
+std::shared_ptr<std::string> RpcConnection::PrepareContextPacket() {
|
|
|
+ // This needs to be send after the SASL handshake, and
|
|
|
+ // after the SASL handshake (if any)
|
|
|
+ assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
|
|
+
|
|
|
+ auto res = std::make_shared<std::string>();
|
|
|
|
|
|
RpcRequestHeaderProto h;
|
|
|
h.set_rpckind(RPC_PROTOCOL_BUFFER);
|
|
@@ -259,11 +345,12 @@ std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
|
|
|
|
|
|
IpcConnectionContextProto handshake;
|
|
|
handshake.set_protocol(engine_->protocol_name());
|
|
|
- const std::string & user_name = engine()->user_name();
|
|
|
+ const std::string & user_name = auth_info_.getUser();
|
|
|
if (!user_name.empty()) {
|
|
|
*handshake.mutable_userinfo()->mutable_effectiveuser() = user_name;
|
|
|
}
|
|
|
AddHeadersToPacket(res.get(), {&h, &handshake}, nullptr);
|
|
|
+
|
|
|
return res;
|
|
|
}
|
|
|
|
|
@@ -272,6 +359,14 @@ void RpcConnection::AsyncRpc(
|
|
|
std::shared_ptr<::google::protobuf::MessageLite> resp,
|
|
|
const RpcCallback &handler) {
|
|
|
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
|
|
+ AsyncRpc_locked(method_name, req, resp, handler);
|
|
|
+}
|
|
|
+
|
|
|
+void RpcConnection::AsyncRpc_locked(
|
|
|
+ const std::string &method_name, const ::google::protobuf::MessageLite *req,
|
|
|
+ std::shared_ptr<::google::protobuf::MessageLite> resp,
|
|
|
+ const RpcCallback &handler) {
|
|
|
+ assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
|
|
|
|
|
auto wrapped_handler =
|
|
|
[resp, handler](pbio::CodedInputStream *is, const Status &status) {
|
|
@@ -283,29 +378,21 @@ void RpcConnection::AsyncRpc(
|
|
|
handler(status);
|
|
|
};
|
|
|
|
|
|
- auto r = std::make_shared<Request>(engine_, method_name, req,
|
|
|
+ int call_id = (method_name != SASL_METHOD_NAME ? engine_->NextCallId() : RpcEngine::kCallIdSasl);
|
|
|
+ auto r = std::make_shared<Request>(engine_, method_name, call_id, req,
|
|
|
std::move(wrapped_handler));
|
|
|
-
|
|
|
- if (connected_ == kDisconnected) {
|
|
|
- // Oops. The connection failed _just_ before the engine got a chance
|
|
|
- // to send it. Register it as a failure
|
|
|
- Status status = Status::ResourceUnavailable("RpcConnection closed before send.");
|
|
|
- auto r_vector = std::vector<std::shared_ptr<Request> > (1, r);
|
|
|
- assert(r_vector[0].get() != nullptr);
|
|
|
-
|
|
|
- engine_->AsyncRpcCommsError(status, shared_from_this(), r_vector);
|
|
|
- } else {
|
|
|
- pending_requests_.push_back(r);
|
|
|
-
|
|
|
- if (connected_ == kConnected) { // Dont flush if we're waiting or handshaking
|
|
|
- FlushPendingRequests();
|
|
|
- }
|
|
|
- }
|
|
|
+ auto r_vector = std::vector<std::shared_ptr<Request> > (1, r);
|
|
|
+ SendRpcRequests(r_vector);
|
|
|
}
|
|
|
|
|
|
void RpcConnection::AsyncRpc(const std::vector<std::shared_ptr<Request> > & requests) {
|
|
|
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
|
|
- LOG_TRACE(kRPC, << "RpcConnection::AsyncRpc[] called; connected=" << ToString(connected_));
|
|
|
+ SendRpcRequests(requests);
|
|
|
+}
|
|
|
+
|
|
|
+void RpcConnection::SendRpcRequests(const std::vector<std::shared_ptr<Request> > & requests) {
|
|
|
+ LOG_TRACE(kRPC, << "RpcConnection::SendRpcRequests[] called; connected=" << ToString(connected_));
|
|
|
+ assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
|
|
|
|
|
if (connected_ == kDisconnected) {
|
|
|
// Oops. The connection failed _just_ before the engine got a chance
|
|
@@ -315,9 +402,12 @@ void RpcConnection::AsyncRpc(const std::vector<std::shared_ptr<Request> > & requ
|
|
|
} else {
|
|
|
pending_requests_.reserve(pending_requests_.size() + requests.size());
|
|
|
for (auto r: requests) {
|
|
|
- pending_requests_.push_back(r);
|
|
|
+ if (r->method_name() != SASL_METHOD_NAME)
|
|
|
+ pending_requests_.push_back(r);
|
|
|
+ else
|
|
|
+ auth_requests_.push_back(r);
|
|
|
}
|
|
|
- if (connected_ == kConnected) { // Dont flush if we're waiting or handshaking
|
|
|
+ if (connected_ == kConnected || connected_ == kAuthenticating) { // Dont flush if we're waiting or handshaking
|
|
|
FlushPendingRequests();
|
|
|
}
|
|
|
}
|
|
@@ -341,6 +431,9 @@ void RpcConnection::PreEnqueueRequests(
|
|
|
void RpcConnection::SetEventHandlers(std::shared_ptr<LibhdfsEvents> event_handlers) {
|
|
|
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
|
|
event_handlers_ = event_handlers;
|
|
|
+ if (sasl_protocol_) {
|
|
|
+ sasl_protocol_->SetEventHandlers(event_handlers);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
void RpcConnection::SetClusterName(std::string cluster_name) {
|
|
@@ -401,6 +494,7 @@ std::string RpcConnection::ToString(ConnectedState connected) {
|
|
|
switch(connected) {
|
|
|
case kNotYetConnected: return "NotYetConnected";
|
|
|
case kConnecting: return "Connecting";
|
|
|
+ case kAuthenticating: return "Authenticating";
|
|
|
case kConnected: return "Connected";
|
|
|
case kDisconnected: return "Disconnected";
|
|
|
default: return "Invalid ConnectedState";
|