sasl_protocol.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "sasl_protocol.h"
  19. #include "sasl_engine.h"
  20. #include "rpc_engine.h"
  21. #include "common/logging.h"
  22. #include <optional.hpp>
  23. namespace hdfs {
  24. using namespace hadoop::common;
  25. using namespace google::protobuf;
  26. template <class T>
  27. using optional = std::experimental::optional<T>;
  /*****
   * Threading model: every entry point must acquire sasl_state_lock_ before
   * accessing members of the class.
   *
   * Lifecycle model: asio may have outstanding callbacks into this class for
   * arbitrary amounts of time, so any references to the class must be
   * shared_ptrs. The SaslProtocol keeps a weak_ptr to the owning RpcConnection,
   * which might go away, so the weak_ptr should be locked only long enough to
   * make callbacks into the RpcConnection.
   */
  38. SaslProtocol::SaslProtocol(const std::string & cluster_name,
  39. const AuthInfo & auth_info,
  40. std::shared_ptr<RpcConnection> connection) :
  41. state_(kUnstarted),
  42. cluster_name_(cluster_name),
  43. auth_info_(auth_info),
  44. connection_(connection)
  45. {
  46. }
  47. SaslProtocol::~SaslProtocol()
  48. {
  49. std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
  50. event_handlers_->call("SASL End", cluster_name_.c_str(), 0);
  51. }
  52. void SaslProtocol::SetEventHandlers(std::shared_ptr<LibhdfsEvents> event_handlers) {
  53. std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
  54. event_handlers_ = event_handlers;
  55. }
  56. void SaslProtocol::authenticate(std::function<void(const Status & status, const AuthInfo new_auth_info)> callback)
  57. {
  58. std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
  59. LOG_TRACE(kRPC, << "Authenticating as " << auth_info_.getUser());
  60. assert(state_ == kUnstarted);
  61. event_handlers_->call("SASL Start", cluster_name_.c_str(), 0);
  62. callback_ = callback;
  63. state_ = kNegotiate;
  64. std::shared_ptr<RpcSaslProto> req_msg = std::make_shared<RpcSaslProto>();
  65. req_msg->set_state(RpcSaslProto_SaslState_NEGOTIATE);
  66. // We cheat here since this is always called while holding the RpcConnection's lock
  67. std::shared_ptr<RpcConnection> connection = connection_.lock();
  68. if (!connection) {
  69. return;
  70. }
  71. std::shared_ptr<RpcSaslProto> resp_msg = std::make_shared<RpcSaslProto>();
  72. auto self(shared_from_this());
  73. connection->AsyncRpc_locked(SASL_METHOD_NAME, req_msg.get(), resp_msg,
  74. [self, req_msg, resp_msg] (const Status & status) { self->OnServerResponse(status, resp_msg.get()); } );
  75. }
  76. AuthInfo::AuthMethod ParseMethod(const std::string & method)
  77. {
  78. if (0 == strcasecmp(method.c_str(), "SIMPLE")) {
  79. return AuthInfo::kSimple;
  80. }
  81. else if (0 == strcasecmp(method.c_str(), "KERBEROS")) {
  82. return AuthInfo::kKerberos;
  83. }
  84. else if (0 == strcasecmp(method.c_str(), "TOKEN")) {
  85. return AuthInfo::kToken;
  86. }
  87. else {
  88. return AuthInfo::kUnknownAuth;
  89. }
  90. }
  91. void SaslProtocol::Negotiate(const hadoop::common::RpcSaslProto * response)
  92. {
  93. std::vector<SaslMethod> protocols;
  94. bool simple_available = false;
  95. #if defined USE_SASL
  96. #if defined USE_CYRUS_SASL
  97. sasl_engine_.reset(new CyrusSaslEngine());
  98. #elif defined USE_GSASL
  99. sasl_engine_.reset(new GSaslEngine());
  100. #else
  101. #error USE_SASL defined but no engine (USE_GSASL) defined
  102. #endif
  103. #endif
  104. if (auth_info_.getToken()) {
  105. sasl_engine_->setPasswordInfo(auth_info_.getToken().value().identifier,
  106. auth_info_.getToken().value().password);
  107. }
  108. sasl_engine_->setKerberosInfo(auth_info_.getUser()); // HDFS-10451 will look up principal by username
  109. auto auths = response->auths();
  110. for (int i = 0; i < auths.size(); ++i) {
  111. auto auth = auths.Get(i);
  112. AuthInfo::AuthMethod method = ParseMethod(auth.method());
  113. switch(method) {
  114. case AuthInfo::kToken:
  115. case AuthInfo::kKerberos: {
  116. SaslMethod new_method;
  117. new_method.mechanism = auth.mechanism();
  118. new_method.protocol = auth.protocol();
  119. new_method.serverid = auth.serverid();
  120. new_method.data = const_cast<RpcSaslProto_SaslAuth *>(&response->auths().Get(i));
  121. protocols.push_back(new_method);
  122. }
  123. break;
  124. case AuthInfo::kSimple:
  125. simple_available = true;
  126. break;
  127. case AuthInfo::kUnknownAuth:
  128. LOG_WARN(kRPC, << "Unknown auth method " << auth.method() << "; ignoring");
  129. break;
  130. default:
  131. LOG_WARN(kRPC, << "Invalid auth type: " << method << "; ignoring");
  132. break;
  133. }
  134. }
  135. if (!protocols.empty()) {
  136. auto init = sasl_engine_->start(protocols);
  137. if (init.first.ok()) {
  138. auto chosen_auth = reinterpret_cast<RpcSaslProto_SaslAuth *>(init.second.data);
  139. // Prepare initiate message
  140. RpcSaslProto initiate;
  141. initiate.set_state(RpcSaslProto_SaslState_INITIATE);
  142. RpcSaslProto_SaslAuth * respAuth = initiate.add_auths();
  143. respAuth->CopyFrom(*chosen_auth);
  144. LOG_TRACE(kRPC, << "Using auth: " << chosen_auth->protocol() << "/" <<
  145. chosen_auth->mechanism() << "/" << chosen_auth->serverid());
  146. std::string challenge = chosen_auth->has_challenge() ? chosen_auth->challenge() : "";
  147. auto sasl_challenge = sasl_engine_->step(challenge);
  148. if (sasl_challenge.first.ok()) {
  149. if (!sasl_challenge.second.empty()) {
  150. initiate.set_token(sasl_challenge.second);
  151. }
  152. std::shared_ptr<RpcSaslProto> return_msg = std::make_shared<RpcSaslProto>();
  153. SendSaslMessage(initiate);
  154. return;
  155. } else {
  156. AuthComplete(sasl_challenge.first, auth_info_);
  157. return;
  158. }
  159. } else if (!simple_available) {
  160. // If simple IS available, fall through to below
  161. AuthComplete(init.first, auth_info_);
  162. return;
  163. }
  164. }
  165. // There were no protocols, or the SaslEngine couldn't make one work
  166. if (simple_available) {
  167. // Simple was the only one we could use. That's OK.
  168. AuthComplete(Status::OK(), auth_info_);
  169. return;
  170. } else {
  171. // We didn't understand any of the protocols; give back some information
  172. std::stringstream ss;
  173. ss << "Client cannot authenticate via: ";
  174. for (int i = 0; i < auths.size(); ++i) {
  175. auto auth = auths.Get(i);
  176. ss << auth.mechanism() << ", ";
  177. }
  178. AuthComplete(Status::Error(ss.str().c_str()), auth_info_);
  179. return;
  180. }
  181. }
  182. void SaslProtocol::Challenge(const hadoop::common::RpcSaslProto * challenge)
  183. {
  184. if (!sasl_engine_) {
  185. AuthComplete(Status::Error("Received challenge before negotiate"), auth_info_);
  186. return;
  187. }
  188. RpcSaslProto response;
  189. response.CopyFrom(*challenge);
  190. response.set_state(RpcSaslProto_SaslState_RESPONSE);
  191. std::string challenge_token = challenge->has_token() ? challenge->token() : "";
  192. auto sasl_response = sasl_engine_->step(challenge_token);
  193. if (sasl_response.first.ok()) {
  194. response.set_token(sasl_response.second);
  195. std::shared_ptr<RpcSaslProto> return_msg = std::make_shared<RpcSaslProto>();
  196. SendSaslMessage(response);
  197. } else {
  198. AuthComplete(sasl_response.first, auth_info_);
  199. return;
  200. }
  201. }
  202. bool SaslProtocol::SendSaslMessage(RpcSaslProto & message)
  203. {
  204. assert(lock_held(sasl_state_lock_)); // Must be holding lock before calling
  205. // RpcConnection might have been freed when we weren't looking. Lock it
  206. // to make sure it's there long enough for us
  207. std::shared_ptr<RpcConnection> connection = connection_.lock();
  208. if (!connection) {
  209. LOG_DEBUG(kRPC, << "Tried sending a SASL Message but the RPC connection was gone");
  210. return false;
  211. }
  212. std::shared_ptr<RpcSaslProto> resp_msg = std::make_shared<RpcSaslProto>();
  213. auto self(shared_from_this());
  214. connection->AsyncRpc(SASL_METHOD_NAME, &message, resp_msg,
  215. [self, resp_msg] (const Status & status) {
  216. self->OnServerResponse(status, resp_msg.get());
  217. } );
  218. return true;
  219. }
  220. bool SaslProtocol::AuthComplete(const Status & status, const AuthInfo & auth_info)
  221. {
  222. assert(lock_held(sasl_state_lock_)); // Must be holding lock before calling
  223. // RpcConnection might have been freed when we weren't looking. Lock it
  224. // to make sure it's there long enough for us
  225. std::shared_ptr<RpcConnection> connection = connection_.lock();
  226. if (!connection) {
  227. LOG_DEBUG(kRPC, << "Tried sending an AuthComplete but the RPC connection was gone: " << status.ToString());
  228. return false;
  229. }
  230. if (!status.ok()) {
  231. auth_info_.setMethod(AuthInfo::kAuthFailed);
  232. }
  233. LOG_TRACE(kRPC, << "AuthComplete: " << status.ToString());
  234. connection->AuthComplete(status, auth_info);
  235. return true;
  236. }
  237. void SaslProtocol::OnServerResponse(const Status & status, const hadoop::common::RpcSaslProto * response)
  238. {
  239. std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
  240. LOG_TRACE(kRPC, << "Received SASL response: " << status.ToString());
  241. if (status.ok()) {
  242. switch(response->state()) {
  243. case RpcSaslProto_SaslState_NEGOTIATE:
  244. Negotiate(response);
  245. break;
  246. case RpcSaslProto_SaslState_CHALLENGE:
  247. Challenge(response);
  248. break;
  249. case RpcSaslProto_SaslState_SUCCESS:
  250. if (sasl_engine_) {
  251. sasl_engine_->finish();
  252. }
  253. AuthComplete(Status::OK(), auth_info_);
  254. break;
  255. case RpcSaslProto_SaslState_INITIATE: // Server side only
  256. case RpcSaslProto_SaslState_RESPONSE: // Server side only
  257. case RpcSaslProto_SaslState_WRAP:
  258. LOG_ERROR(kRPC, << "Invalid client-side SASL state: " << response->state());
  259. AuthComplete(Status::Error("Invalid client-side state"), auth_info_);
  260. break;
  261. default:
  262. LOG_ERROR(kRPC, << "Unknown client-side SASL state: " << response->state());
  263. AuthComplete(Status::Error("Unknown client-side state"), auth_info_);
  264. break;
  265. }
  266. } else {
  267. AuthComplete(status, auth_info_);
  268. }
  269. }
  270. }