Browse Source

AMBARI-6608. Ambari server should inform the agent whether to use two-way ssl when registering (dlysnichenko)

Lisnichenko Dmitro 11 years ago
parent
commit
92c4d45f6e
22 changed files with 347 additions and 230 deletions
  1. 82 24
      ambari-agent/src/main/python/ambari_agent/AmbariConfig.py
  2. 41 42
      ambari-agent/src/main/python/ambari_agent/Controller.py
  3. 3 3
      ambari-agent/src/main/python/ambari_agent/CustomServiceOrchestrator.py
  4. 10 10
      ambari-agent/src/main/python/ambari_agent/Heartbeat.py
  5. 4 5
      ambari-agent/src/main/python/ambari_agent/HostCleanup.py
  6. 2 1
      ambari-agent/src/main/python/ambari_agent/HostInfo.py
  7. 30 25
      ambari-agent/src/main/python/ambari_agent/NetUtil.py
  8. 5 6
      ambari-agent/src/main/python/ambari_agent/Register.py
  9. 8 8
      ambari-agent/src/main/python/ambari_agent/hostname.py
  10. 20 15
      ambari-agent/src/main/python/ambari_agent/main.py
  11. 46 42
      ambari-agent/src/main/python/ambari_agent/security.py
  12. 4 2
      ambari-agent/src/test/python/ambari_agent/TestActionQueue.py
  13. 5 5
      ambari-agent/src/test/python/ambari_agent/TestCertGeneration.py
  14. 2 1
      ambari-agent/src/test/python/ambari_agent/TestController.py
  15. 11 11
      ambari-agent/src/test/python/ambari_agent/TestHostname.py
  16. 2 2
      ambari-agent/src/test/python/ambari_agent/TestMain.py
  17. 7 9
      ambari-agent/src/test/python/ambari_agent/TestNetUtil.py
  18. 1 1
      ambari-agent/src/test/python/ambari_agent/TestSecurity.py
  19. 5 17
      ambari-server/src/main/java/org/apache/ambari/server/controller/AmbariServer.java
  20. 4 0
      ambari-server/src/main/java/org/apache/ambari/server/security/SecurityFilter.java
  21. 54 0
      ambari-server/src/main/java/org/apache/ambari/server/security/unsecured/rest/ConnectionInfo.java
  22. 1 1
      ambari-server/src/test/java/org/apache/ambari/server/security/CertGenerationTest.java

+ 82 - 24
ambari-agent/src/main/python/ambari_agent/AmbariConfig.py

@@ -20,8 +20,9 @@ limitations under the License.
 
 
 import ConfigParser
 import ConfigParser
 import StringIO
 import StringIO
+import json
+from NetUtil import NetUtil
 
 
-config = ConfigParser.RawConfigParser()
 content = """
 content = """
 
 
 [server]
 [server]
@@ -58,8 +59,6 @@ rpms=glusterfs,openssl,wget,net-snmp,ntpd,ganglia,nagios,glusterfs
 log_lines_count=300
 log_lines_count=300
 
 
 """
 """
-s = StringIO.StringIO(content)
-config.readfp(s)
 
 
 imports = [
 imports = [
   "hdp/manifests/*.pp",
   "hdp/manifests/*.pp",
@@ -145,7 +144,7 @@ serviceStates = {
 }
 }
 
 
 servicesToPidNames = {
 servicesToPidNames = {
-  'GLUSTERFS' : 'glusterd.pid$',    
+  'GLUSTERFS' : 'glusterd.pid$',
   'NAMENODE': 'hadoop-{USER}-namenode.pid$',
   'NAMENODE': 'hadoop-{USER}-namenode.pid$',
   'SECONDARY_NAMENODE': 'hadoop-{USER}-secondarynamenode.pid$',
   'SECONDARY_NAMENODE': 'hadoop-{USER}-secondarynamenode.pid$',
   'DATANODE': 'hadoop-{USER}-datanode.pid$',
   'DATANODE': 'hadoop-{USER}-datanode.pid$',
@@ -192,43 +191,65 @@ servicesToLinuxUser = {
 
 
 pidPathesVars = [
 pidPathesVars = [
   {'var' : 'glusterfs_pid_dir_prefix',
   {'var' : 'glusterfs_pid_dir_prefix',
-   'defaultValue' : '/var/run'},      
+   'defaultValue' : '/var/run'},
   {'var' : 'hadoop_pid_dir_prefix',
   {'var' : 'hadoop_pid_dir_prefix',
    'defaultValue' : '/var/run/hadoop'},
    'defaultValue' : '/var/run/hadoop'},
   {'var' : 'hadoop_pid_dir_prefix',
   {'var' : 'hadoop_pid_dir_prefix',
-   'defaultValue' : '/var/run/hadoop'},                 
+   'defaultValue' : '/var/run/hadoop'},
   {'var' : 'ganglia_runtime_dir',
   {'var' : 'ganglia_runtime_dir',
-   'defaultValue' : '/var/run/ganglia/hdp'},                 
+   'defaultValue' : '/var/run/ganglia/hdp'},
   {'var' : 'hbase_pid_dir',
   {'var' : 'hbase_pid_dir',
-   'defaultValue' : '/var/run/hbase'},                
+   'defaultValue' : '/var/run/hbase'},
   {'var' : '',
   {'var' : '',
-   'defaultValue' : '/var/run/nagios'},                    
+   'defaultValue' : '/var/run/nagios'},
   {'var' : 'zk_pid_dir',
   {'var' : 'zk_pid_dir',
-   'defaultValue' : '/var/run/zookeeper'},             
+   'defaultValue' : '/var/run/zookeeper'},
   {'var' : 'oozie_pid_dir',
   {'var' : 'oozie_pid_dir',
-   'defaultValue' : '/var/run/oozie'},             
+   'defaultValue' : '/var/run/oozie'},
   {'var' : 'hcat_pid_dir',
   {'var' : 'hcat_pid_dir',
-   'defaultValue' : '/var/run/webhcat'},                       
+   'defaultValue' : '/var/run/webhcat'},
   {'var' : 'hive_pid_dir',
   {'var' : 'hive_pid_dir',
-   'defaultValue' : '/var/run/hive'},                      
+   'defaultValue' : '/var/run/hive'},
   {'var' : 'mysqld_pid_dir',
   {'var' : 'mysqld_pid_dir',
    'defaultValue' : '/var/run/mysqld'},
    'defaultValue' : '/var/run/mysqld'},
   {'var' : 'hcat_pid_dir',
   {'var' : 'hcat_pid_dir',
-   'defaultValue' : '/var/run/webhcat'},                      
+   'defaultValue' : '/var/run/webhcat'},
   {'var' : 'yarn_pid_dir_prefix',
   {'var' : 'yarn_pid_dir_prefix',
    'defaultValue' : '/var/run/hadoop-yarn'},
    'defaultValue' : '/var/run/hadoop-yarn'},
   {'var' : 'mapred_pid_dir_prefix',
   {'var' : 'mapred_pid_dir_prefix',
    'defaultValue' : '/var/run/hadoop-mapreduce'},
    'defaultValue' : '/var/run/hadoop-mapreduce'},
 ]
 ]
 
 
+
 class AmbariConfig:
 class AmbariConfig:
-  def getConfig(self):
-    global config
-    return config
+  TWO_WAY_SSL_PROPERTY = "security.server.two_way_ssl"
+  CONFIG_FILE = "/etc/ambari-agent/conf/ambari-agent.ini"
+  SERVER_CONNECTION_INFO = "{0}/connection_info"
+  CONNECTION_PROTOCOL = "https"
 
 
-  def getImports(self):
-    global imports
-    return imports
+  config = None
+  net = None
+
+  def __init__(self):
+    global content
+    self.config = ConfigParser.RawConfigParser()
+    self.net = NetUtil()
+    self.config.readfp(StringIO.StringIO(content))
+
+  def get(self, section, value):
+    return self.config.get(section, value)
+
+  def set(self, section, option, value):
+    self.config.set(section, option, value)
+
+  def add_section(self, section):
+    self.config.add_section(section)
+
+  def setConfig(self, customConfig):
+    self.config = customConfig
+
+  def getConfig(self):
+    return self.config
 
 
   def getRolesToClass(self):
   def getRolesToClass(self):
     global rolesToClass
     global rolesToClass
@@ -242,18 +263,55 @@ class AmbariConfig:
     global servicesToPidNames
     global servicesToPidNames
     return servicesToPidNames
     return servicesToPidNames
 
 
+  def getImports(self):
+    global imports
+    return imports
+
   def getPidPathesVars(self):
   def getPidPathesVars(self):
     global pidPathesVars
     global pidPathesVars
     return pidPathesVars
     return pidPathesVars
 
 
+  def has_option(self, section, option):
+    return self.config.has_option(section, option)
+
+  def remove_option(self, section, option):
+    return self.config.remove_option(section, option)
+
+  def load(self, data):
+    self.config = ConfigParser.RawConfigParser(data)
+
+  def read(self, filename):
+    self.config.read(filename)
+
+  def getServerOption(self, url, name, default=None):
+    status, response = self.net.checkURL(url)
+    if status is True:
+      try:
+        data = json.loads(response)
+        if name in data:
+          return data[name]
+      except:
+        pass
+    return default
+
+  def get_api_url(self):
+    return "%s://%s:%s" % (self.CONNECTION_PROTOCOL,
+                           self.get('server', 'hostname'),
+                           self.get('server', 'url_port'))
 
 
-def setConfig(customConfig):
-  global config
-  config = customConfig
+  def isTwoWaySSLConnection(self):
+    req_url = self.get_api_url()
+    response = self.getServerOption(self.SERVER_CONNECTION_INFO.format(req_url), self.TWO_WAY_SSL_PROPERTY, 'false')
+    if response is None:
+      return False
+    elif response.lower() == "true":
+      return True
+    else:
+      return False
 
 
 
 
 def main():
 def main():
-  print config
+  print AmbariConfig().config
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
   main()
   main()

+ 41 - 42
ambari-agent/src/main/python/ambari_agent/Controller.py

@@ -54,7 +54,7 @@ class Controller(threading.Thread):
     self.safeMode = True
     self.safeMode = True
     self.credential = None
     self.credential = None
     self.config = config
     self.config = config
-    self.hostname = hostname.hostname()
+    self.hostname = hostname.hostname(config)
     self.serverHostname = config.get('server', 'hostname')
     self.serverHostname = config.get('server', 'hostname')
     server_secured_url = 'https://' + self.serverHostname + \
     server_secured_url = 'https://' + self.serverHostname + \
                          ':' + config.get('server', 'secured_url_port')
                          ':' + config.get('server', 'secured_url_port')
@@ -78,7 +78,7 @@ class Controller(threading.Thread):
   def __del__(self):
   def __del__(self):
     logger.info("Server connection disconnected.")
     logger.info("Server connection disconnected.")
     pass
     pass
-  
+
   def registerWithServer(self):
   def registerWithServer(self):
     LiveStatus.SERVICES = []
     LiveStatus.SERVICES = []
     LiveStatus.CLIENT_COMPONENTS = []
     LiveStatus.CLIENT_COMPONENTS = []
@@ -87,36 +87,36 @@ class Controller(threading.Thread):
     ret = {}
     ret = {}
 
 
     while not self.isRegistered:
     while not self.isRegistered:
-      try:                
+      try:
         data = json.dumps(self.register.build(id))
         data = json.dumps(self.register.build(id))
         prettyData = pprint.pformat(data)
         prettyData = pprint.pformat(data)
-        
+
         try:
         try:
           server_ip = socket.gethostbyname(self.hostname)
           server_ip = socket.gethostbyname(self.hostname)
           logger.info("Registering with %s (%s) (agent=%s)", self.hostname, server_ip, prettyData)
           logger.info("Registering with %s (%s) (agent=%s)", self.hostname, server_ip, prettyData)
-        except socket.error:          
-          logger.warn("Unable to determine the IP address of '%s', agent registration may fail (agent=%s)", 
+        except socket.error:
+          logger.warn("Unable to determine the IP address of '%s', agent registration may fail (agent=%s)",
                       self.hostname, prettyData)
                       self.hostname, prettyData)
-                
+
         ret = self.sendRequest(self.registerUrl, data)
         ret = self.sendRequest(self.registerUrl, data)
-        
+
         # exitstatus is a code of error which was rised on server side.
         # exitstatus is a code of error which was rised on server side.
         # exitstatus = 0 (OK - Default)
         # exitstatus = 0 (OK - Default)
         # exitstatus = 1 (Registration failed because different version of agent and server)
         # exitstatus = 1 (Registration failed because different version of agent and server)
         exitstatus = 0
         exitstatus = 0
         if 'exitstatus' in ret.keys():
         if 'exitstatus' in ret.keys():
           exitstatus = int(ret['exitstatus'])
           exitstatus = int(ret['exitstatus'])
-                
+
         if exitstatus == 1:
         if exitstatus == 1:
-          # log - message, which will be printed to agents log  
+          # log - message, which will be printed to agents log
           if 'log' in ret.keys():
           if 'log' in ret.keys():
-            log = ret['log']          
-          
+            log = ret['log']
+
           logger.error(log)
           logger.error(log)
           self.isRegistered = False
           self.isRegistered = False
-          self.repeatRegistration=False
+          self.repeatRegistration = False
           return ret
           return ret
-        
+
         logger.info("Registration Successful (response=%s)", pprint.pformat(ret))
         logger.info("Registration Successful (response=%s)", pprint.pformat(ret))
 
 
         self.responseId = int(ret['responseId'])
         self.responseId = int(ret['responseId'])
@@ -139,10 +139,10 @@ class Controller(threading.Thread):
         """ Sleeping for {0} seconds and then retrying again """.format(delay)
         """ Sleeping for {0} seconds and then retrying again """.format(delay)
         time.sleep(delay)
         time.sleep(delay)
         pass
         pass
-      pass  
+      pass
     return ret
     return ret
-  
-  
+
+
   def addToQueue(self, commands):
   def addToQueue(self, commands):
     """Add to the queue for running the commands """
     """Add to the queue for running the commands """
     """ Put the required actions into the Queue """
     """ Put the required actions into the Queue """
@@ -174,8 +174,7 @@ class Controller(threading.Thread):
     retry = False
     retry = False
     certVerifFailed = False
     certVerifFailed = False
 
 
-    config = AmbariConfig.config
-    hb_interval = config.get('heartbeat', 'state_interval')
+    hb_interval = self.config.get('heartbeat', 'state_interval')
 
 
     #TODO make sure the response id is monotonically increasing
     #TODO make sure the response id is monotonically increasing
     id = 0
     id = 0
@@ -190,22 +189,22 @@ class Controller(threading.Thread):
 
 
         if logger.isEnabledFor(logging.DEBUG):
         if logger.isEnabledFor(logging.DEBUG):
           logger.debug("Sending Heartbeat (id = %s): %s", self.responseId, data)
           logger.debug("Sending Heartbeat (id = %s): %s", self.responseId, data)
-        
+
         response = self.sendRequest(self.heartbeatUrl, data)
         response = self.sendRequest(self.heartbeatUrl, data)
-        
+
         exitStatus = 0
         exitStatus = 0
         if 'exitstatus' in response.keys():
         if 'exitstatus' in response.keys():
-          exitStatus = int(response['exitstatus'])   
-        
+          exitStatus = int(response['exitstatus'])
+
         if exitStatus != 0:
         if exitStatus != 0:
           raise Exception(response)
           raise Exception(response)
-        
+
         serverId = int(response['responseId'])
         serverId = int(response['responseId'])
 
 
         if logger.isEnabledFor(logging.DEBUG):
         if logger.isEnabledFor(logging.DEBUG):
           logger.debug('Heartbeat response (id = %s): %s', serverId, pprint.pformat(response))
           logger.debug('Heartbeat response (id = %s): %s', serverId, pprint.pformat(response))
         else:
         else:
-          logger.info('Heartbeat response received (id = %s)', serverId)                
+          logger.info('Heartbeat response received (id = %s)', serverId)
 
 
         if 'hasMappedComponents' in response.keys():
         if 'hasMappedComponents' in response.keys():
           self.hasMappedComponents = response['hasMappedComponents'] != False
           self.hasMappedComponents = response['hasMappedComponents'] != False
@@ -227,11 +226,11 @@ class Controller(threading.Thread):
         if 'executionCommands' in response.keys():
         if 'executionCommands' in response.keys():
           self.addToQueue(response['executionCommands'])
           self.addToQueue(response['executionCommands'])
           pass
           pass
-        
+
         if 'statusCommands' in response.keys():
         if 'statusCommands' in response.keys():
           self.addToStatusQueue(response['statusCommands'])
           self.addToStatusQueue(response['statusCommands'])
           pass
           pass
-        
+
         if "true" == response['restartAgent']:
         if "true" == response['restartAgent']:
           logger.error("Received the restartAgent command")
           logger.error("Received the restartAgent command")
           self.restartAgent()
           self.restartAgent()
@@ -241,7 +240,7 @@ class Controller(threading.Thread):
 
 
         if retry:
         if retry:
           logger.info("Reconnected to %s", self.heartbeatUrl)
           logger.info("Reconnected to %s", self.heartbeatUrl)
-          
+
         retry=False
         retry=False
         certVerifFailed = False
         certVerifFailed = False
         self.DEBUG_SUCCESSFULL_HEARTBEATS += 1
         self.DEBUG_SUCCESSFULL_HEARTBEATS += 1
@@ -255,29 +254,29 @@ class Controller(threading.Thread):
         #randomize the heartbeat
         #randomize the heartbeat
         delay = randint(0, self.range)
         delay = randint(0, self.range)
         time.sleep(delay)
         time.sleep(delay)
-        
+
         if "code" in err:
         if "code" in err:
           logger.error(err.code)
           logger.error(err.code)
         else:
         else:
           logException = False
           logException = False
           if logger.isEnabledFor(logging.DEBUG):
           if logger.isEnabledFor(logging.DEBUG):
             logException = True
             logException = True
-          
+
           exceptionMessage = str(err)
           exceptionMessage = str(err)
           errorMessage = "Unable to reconnect to {0} (attempts={1}, details={2})".format(self.heartbeatUrl, self.DEBUG_HEARTBEAT_RETRIES, exceptionMessage)
           errorMessage = "Unable to reconnect to {0} (attempts={1}, details={2})".format(self.heartbeatUrl, self.DEBUG_HEARTBEAT_RETRIES, exceptionMessage)
-          
+
           if not retry:
           if not retry:
             errorMessage = "Connection to {0} was lost (details={1})".format(self.serverHostname, exceptionMessage)
             errorMessage = "Connection to {0} was lost (details={1})".format(self.serverHostname, exceptionMessage)
-          
+
           logger.error(errorMessage, exc_info=logException)
           logger.error(errorMessage, exc_info=logException)
-            
+
           if 'certificate verify failed' in str(err) and not certVerifFailed:
           if 'certificate verify failed' in str(err) and not certVerifFailed:
             logger.warn("Server certificate verify failed. Did you regenerate server certificate?")
             logger.warn("Server certificate verify failed. Did you regenerate server certificate?")
             certVerifFailed = True
             certVerifFailed = True
-            
+
         self.cachedconnect = None # Previous connection is broken now
         self.cachedconnect = None # Previous connection is broken now
         retry=True
         retry=True
-        
+
       # Sleep for some time
       # Sleep for some time
       timeout = self.netutil.HEARTBEAT_IDDLE_INTERVAL_SEC \
       timeout = self.netutil.HEARTBEAT_IDDLE_INTERVAL_SEC \
                 - self.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS
                 - self.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS
@@ -308,12 +307,12 @@ class Controller(threading.Thread):
     registerResponse = self.registerWithServer()
     registerResponse = self.registerWithServer()
     message = registerResponse['response']
     message = registerResponse['response']
     logger.info("Registration response from %s was %s", self.serverHostname, message)
     logger.info("Registration response from %s was %s", self.serverHostname, message)
-    
+
     if self.isRegistered:
     if self.isRegistered:
       # Process callbacks
       # Process callbacks
       for callback in self.registration_listeners:
       for callback in self.registration_listeners:
         callback()
         callback()
-        
+
       time.sleep(self.netutil.HEARTBEAT_IDDLE_INTERVAL_SEC)
       time.sleep(self.netutil.HEARTBEAT_IDDLE_INTERVAL_SEC)
       self.heartbeatWithServer()
       self.heartbeatWithServer()
 
 
@@ -323,11 +322,11 @@ class Controller(threading.Thread):
 
 
   def sendRequest(self, url, data):
   def sendRequest(self, url, data):
     response = None
     response = None
-    
+
     try:
     try:
       if self.cachedconnect is None: # Lazy initialization
       if self.cachedconnect is None: # Lazy initialization
         self.cachedconnect = security.CachedHTTPSConnection(self.config)
         self.cachedconnect = security.CachedHTTPSConnection(self.config)
-      req = urllib2.Request(url, data, {'Content-Type': 'application/json'})      
+      req = urllib2.Request(url, data, {'Content-Type': 'application/json'})
       response = self.cachedconnect.request(req)
       response = self.cachedconnect.request(req)
       return json.loads(response)
       return json.loads(response)
     except Exception, exception:
     except Exception, exception:
@@ -342,10 +341,10 @@ class Controller(threading.Thread):
 
 
   def updateComponents(self, cluster_name):
   def updateComponents(self, cluster_name):
     logger.info("Updating components map of cluster " + cluster_name)
     logger.info("Updating components map of cluster " + cluster_name)
-    
-    response = self.sendRequest(self.componentsUrl + cluster_name, None)    
+
+    response = self.sendRequest(self.componentsUrl + cluster_name, None)
     logger.debug("Response from %s was %s", self.serverHostname, str(response))
     logger.debug("Response from %s was %s", self.serverHostname, str(response))
-    
+
     for service, components in response['components'].items():
     for service, components in response['components'].items():
       LiveStatus.SERVICES.append(service)
       LiveStatus.SERVICES.append(service)
       for component, category in components.items():
       for component, category in components.items():

+ 3 - 3
ambari-agent/src/main/python/ambari_agent/CustomServiceOrchestrator.py

@@ -59,7 +59,7 @@ class CustomServiceOrchestrator():
                                                'status_command_stdout.txt')
                                                'status_command_stdout.txt')
     self.status_commands_stderr = os.path.join(self.tmp_dir,
     self.status_commands_stderr = os.path.join(self.tmp_dir,
                                                'status_command_stderr.txt')
                                                'status_command_stderr.txt')
-    self.public_fqdn = hostname.public_hostname()
+    self.public_fqdn = hostname.public_hostname(config)
     # cache reset will be called on every agent registration
     # cache reset will be called on every agent registration
     controller.registration_listeners.append(self.file_cache.reset)
     controller.registration_listeners.append(self.file_cache.reset)
     # Clean up old status command files if any
     # Clean up old status command files if any
@@ -122,12 +122,12 @@ class CustomServiceOrchestrator():
       py_file_list = [pre_hook_tuple, script_tuple, post_hook_tuple]
       py_file_list = [pre_hook_tuple, script_tuple, post_hook_tuple]
       # filter None values
       # filter None values
       filtered_py_file_list = [i for i in py_file_list if i]
       filtered_py_file_list = [i for i in py_file_list if i]
-      
+
       logger_level = logging.getLevelName(logger.level)
       logger_level = logging.getLevelName(logger.level)
 
 
       # Executing hooks and script
       # Executing hooks and script
       ret = None
       ret = None
-      
+
       for py_file, current_base_dir in filtered_py_file_list:
       for py_file, current_base_dir in filtered_py_file_list:
         script_params = [command_name, json_path, current_base_dir]
         script_params = [command_name, json_path, current_base_dir]
         ret = self.python_executor.run_file(py_file, script_params,
         ret = self.python_executor.run_file(py_file, script_params,

+ 10 - 10
ambari-agent/src/main/python/ambari_agent/Heartbeat.py

@@ -45,23 +45,23 @@ class Heartbeat:
     timestamp = int(time.time()*1000)
     timestamp = int(time.time()*1000)
     queueResult = self.actionQueue.result()
     queueResult = self.actionQueue.result()
 
 
-    
+
     nodeStatus = { "status" : "HEALTHY",
     nodeStatus = { "status" : "HEALTHY",
                    "cause" : "NONE" }
                    "cause" : "NONE" }
     nodeStatus["alerts"] = []
     nodeStatus["alerts"] = []
-    
-    
-    
+
+
+
     heartbeat = { 'responseId'        : int(id),
     heartbeat = { 'responseId'        : int(id),
                   'timestamp'         : timestamp,
                   'timestamp'         : timestamp,
-                  'hostname'          : hostname.hostname(),
+                  'hostname'          : hostname.hostname(self.config),
                   'nodeStatus'        : nodeStatus
                   'nodeStatus'        : nodeStatus
                 }
                 }
 
 
     commandsInProgress = False
     commandsInProgress = False
     if not self.actionQueue.commandQueue.empty():
     if not self.actionQueue.commandQueue.empty():
       commandsInProgress = True
       commandsInProgress = True
-      
+
     if len(queueResult) != 0:
     if len(queueResult) != 0:
       heartbeat['reports'] = queueResult['reports']
       heartbeat['reports'] = queueResult['reports']
       heartbeat['componentStatus'] = queueResult['componentStatus']
       heartbeat['componentStatus'] = queueResult['componentStatus']
@@ -74,9 +74,9 @@ class Heartbeat:
     if int(id) == 0:
     if int(id) == 0:
       componentsMapped = False
       componentsMapped = False
 
 
-    logger.info("Building Heartbeat: {responseId = %s, timestamp = %s, commandsInProgress = %s, componentsMapped = %s}", 
+    logger.info("Building Heartbeat: {responseId = %s, timestamp = %s, commandsInProgress = %s, componentsMapped = %s}",
         str(id), str(timestamp), repr(commandsInProgress), repr(componentsMapped))
         str(id), str(timestamp), repr(commandsInProgress), repr(componentsMapped))
-    
+
     if logger.isEnabledFor(logging.DEBUG):
     if logger.isEnabledFor(logging.DEBUG):
       logger.debug("Heartbeat: %s", pformat(heartbeat))
       logger.debug("Heartbeat: %s", pformat(heartbeat))
 
 
@@ -85,11 +85,11 @@ class Heartbeat:
       nodeInfo = { }
       nodeInfo = { }
       # for now, just do the same work as registration
       # for now, just do the same work as registration
       # this must be the last step before returning heartbeat
       # this must be the last step before returning heartbeat
-      hostInfo.register(nodeInfo, componentsMapped, commandsInProgress)      
+      hostInfo.register(nodeInfo, componentsMapped, commandsInProgress)
       heartbeat['agentEnv'] = nodeInfo
       heartbeat['agentEnv'] = nodeInfo
       mounts = Hardware.osdisks()
       mounts = Hardware.osdisks()
       heartbeat['mounts'] = mounts
       heartbeat['mounts'] = mounts
-            
+
       if logger.isEnabledFor(logging.DEBUG):
       if logger.isEnabledFor(logging.DEBUG):
         logger.debug("agentEnv: %s", str(nodeInfo))
         logger.debug("agentEnv: %s", str(nodeInfo))
         logger.debug("mounts: %s", str(mounts))
         logger.debug("mounts: %s", str(mounts))

+ 4 - 5
ambari-agent/src/main/python/ambari_agent/HostCleanup.py

@@ -82,10 +82,9 @@ PACKAGES_BLACK_LIST = ["ambari-server", "ambari-agent"]
 class HostCleanup:
 class HostCleanup:
   def resolve_ambari_config(self):
   def resolve_ambari_config(self):
     try:
     try:
-      config = AmbariConfig.config
+      config = AmbariConfig.AmbariConfig()
       if os.path.exists(configFile):
       if os.path.exists(configFile):
         config.read(configFile)
         config.read(configFile)
-        AmbariConfig.setConfig(config)
       else:
       else:
         raise Exception("No config found, use default")
         raise Exception("No config found, use default")
 
 
@@ -99,13 +98,13 @@ class HostCleanup:
     for patern in DIRNAME_PATTERNS:
     for patern in DIRNAME_PATTERNS:
       dirList.add(os.path.dirname(patern))
       dirList.add(os.path.dirname(patern))
 
 
-    for folder in dirList:  
+    for folder in dirList:
       for dirs in os.walk(folder):
       for dirs in os.walk(folder):
         for dir in dirs:
         for dir in dirs:
           for patern in DIRNAME_PATTERNS:
           for patern in DIRNAME_PATTERNS:
             if patern in dir:
             if patern in dir:
              resultList.append(dir)
              resultList.append(dir)
-    return resultList         
+    return resultList
 
 
   def do_cleanup(self, argMap=None):
   def do_cleanup(self, argMap=None):
     if argMap:
     if argMap:
@@ -136,7 +135,7 @@ class HostCleanup:
         self.do_erase_dir_silent(dirList)
         self.do_erase_dir_silent(dirList)
       if additionalDirList and not ADDITIONAL_DIRS in SKIP_LIST:
       if additionalDirList and not ADDITIONAL_DIRS in SKIP_LIST:
         logger.info("\n" + "Deleting additional directories: " + str(dirList))
         logger.info("\n" + "Deleting additional directories: " + str(dirList))
-        self.do_erase_dir_silent(additionalDirList)        
+        self.do_erase_dir_silent(additionalDirList)
       if repoList and not REPO_SECTION in SKIP_LIST:
       if repoList and not REPO_SECTION in SKIP_LIST:
         repoFiles = self.find_repo_files_for_repos(repoList)
         repoFiles = self.find_repo_files_for_repos(repoList)
         logger.info("\n" + "Deleting repo files: " + str(repoFiles))
         logger.info("\n" + "Deleting repo files: " + str(repoFiles))

+ 2 - 1
ambari-agent/src/main/python/ambari_agent/HostInfo.py

@@ -128,6 +128,7 @@ class HostInfo:
 
 
   def __init__(self, config=None):
   def __init__(self, config=None):
     self.packages = PackagesAnalyzer()
     self.packages = PackagesAnalyzer()
+    self.config = config
     self.reportFileHandler = HostCheckReportFileHandler(config)
     self.reportFileHandler = HostCheckReportFileHandler(config)
 
 
   def dirType(self, path):
   def dirType(self, path):
@@ -232,7 +233,7 @@ class HostInfo:
       'instance': None,
       'instance': None,
       'service': 'AMBARI',
       'service': 'AMBARI',
       'component': 'host',
       'component': 'host',
-      'host': hostname.hostname(),
+      'host': hostname.hostname(self.config),
       'state': 'OK',
       'state': 'OK',
       'label': 'Disk space',
       'label': 'Disk space',
       'text': 'Used disk space less than 80%'}
       'text': 'Used disk space less than 80%'}

+ 30 - 25
ambari-agent/src/main/python/ambari_agent/NetUtil.py

@@ -20,8 +20,13 @@ import logging
 import httplib
 import httplib
 from ssl import SSLError
 from ssl import SSLError
 
 
+ERROR_SSL_WRONG_VERSION = "SSLError: Failed to connect. Please check openssl library versions. \n" +\
+              "Refer to: https://bugzilla.redhat.com/show_bug.cgi?id=1022468 for more details."
+LOG_REQUEST_MESSAGE = "GET %s -> %s, body: %s"
+
 logger = logging.getLogger()
 logger = logging.getLogger()
 
 
+
 class NetUtil:
 class NetUtil:
 
 
   CONNECT_SERVER_RETRY_INTERVAL_SEC = 10
   CONNECT_SERVER_RETRY_INTERVAL_SEC = 10
@@ -30,54 +35,55 @@ class NetUtil:
 
 
   # Url within server to request during status check. This url
   # Url within server to request during status check. This url
   # should return HTTP code 200
   # should return HTTP code 200
-  SERVER_STATUS_REQUEST = "{0}/cert/ca"
-
+  SERVER_STATUS_REQUEST = "{0}/ca"
   # For testing purposes
   # For testing purposes
   DEBUG_STOP_RETRIES_FLAG = False
   DEBUG_STOP_RETRIES_FLAG = False
 
 
   def checkURL(self, url):
   def checkURL(self, url):
     """Try to connect to a given url. Result is True if url returns HTTP code 200, in any other case
     """Try to connect to a given url. Result is True if url returns HTTP code 200, in any other case
-    (like unreachable server or wrong HTTP code) result will be False
+    (like unreachable server or wrong HTTP code) result will be False.
+
+       Additionally returns body of request, if available
     """
     """
-    logger.info("Connecting to " + url);
-    
+    logger.info("Connecting to " + url)
+    responseBody = ""
+
     try:
     try:
       parsedurl = urlparse(url)
       parsedurl = urlparse(url)
       ca_connection = httplib.HTTPSConnection(parsedurl[1])
       ca_connection = httplib.HTTPSConnection(parsedurl[1])
-      ca_connection.request("HEAD", parsedurl[2])
-      response = ca_connection.getresponse()  
-      status = response.status    
-      
-      requestLogMessage = "HEAD %s -> %s"
-      
+      ca_connection.request("GET", parsedurl[2])
+      response = ca_connection.getresponse()
+      status = response.status
+
       if status == 200:
       if status == 200:
-        logger.debug(requestLogMessage, url, str(status) ) 
-        return True
-      else: 
-        logger.warning(requestLogMessage, url, str(status) )
-        return False
+        responseBody = response.read()
+        logger.debug(LOG_REQUEST_MESSAGE, url, str(status), responseBody)
+        return True, responseBody
+      else:
+        logger.warning(LOG_REQUEST_MESSAGE, url, str(status), responseBody)
+        return False, responseBody
     except SSLError as slerror:
     except SSLError as slerror:
       logger.error(str(slerror))
       logger.error(str(slerror))
-      logger.error("SSLError: Failed to connect. Please check openssl library versions. \n" +
-                   "Refer to: https://bugzilla.redhat.com/show_bug.cgi?id=1022468 for more details.")
-      return False
-    
+      logger.error(ERROR_SSL_WRONG_VERSION)
+      return False, responseBody
+
     except Exception, e:
     except Exception, e:
       logger.warning("Failed to connect to " + str(url) + " due to " + str(e) + "  ")
       logger.warning("Failed to connect to " + str(url) + " due to " + str(e) + "  ")
-      return False
+      return False, responseBody
 
 
-  def try_to_connect(self, server_url, max_retries, logger = None):
+  def try_to_connect(self, server_url, max_retries, logger=None):
     """Try to connect to a given url, sleeping for CONNECT_SERVER_RETRY_INTERVAL_SEC seconds
     """Try to connect to a given url, sleeping for CONNECT_SERVER_RETRY_INTERVAL_SEC seconds
     between retries. No more than max_retries is performed. If max_retries is -1, connection
     between retries. No more than max_retries is performed. If max_retries is -1, connection
     attempts will be repeated forever until server is not reachable
     attempts will be repeated forever until server is not reachable
+
     Returns count of retries
     Returns count of retries
     """
     """
     if logger is not None:
     if logger is not None:
       logger.debug("Trying to connect to %s", server_url)
       logger.debug("Trying to connect to %s", server_url)
-      
+
     retries = 0
     retries = 0
     while (max_retries == -1 or retries < max_retries) and not self.DEBUG_STOP_RETRIES_FLAG:
     while (max_retries == -1 or retries < max_retries) and not self.DEBUG_STOP_RETRIES_FLAG:
-      server_is_up = self.checkURL(self.SERVER_STATUS_REQUEST.format(server_url))
+      server_is_up, responseBody = self.checkURL(self.SERVER_STATUS_REQUEST.format(server_url))
       if server_is_up:
       if server_is_up:
         break
         break
       else:
       else:
@@ -87,4 +93,3 @@ class NetUtil:
         retries += 1
         retries += 1
         time.sleep(self.CONNECT_SERVER_RETRY_INTERVAL_SEC)
         time.sleep(self.CONNECT_SERVER_RETRY_INTERVAL_SEC)
     return retries
     return retries
-

+ 5 - 6
ambari-agent/src/main/python/ambari_agent/Register.py

@@ -28,7 +28,7 @@ from HostInfo import HostInfo
 
 
 firstContact = True
 firstContact = True
 class Register:
 class Register:
-  """ Registering with the server. Get the hardware profile and 
+  """ Registering with the server. Get the hardware profile and
   declare success for now """
   declare success for now """
   def __init__(self, config):
   def __init__(self, config):
     self.hardware = Hardware()
     self.hardware = Hardware()
@@ -37,19 +37,19 @@ class Register:
   def build(self, id='-1'):
   def build(self, id='-1'):
     global clusterId, clusterDefinitionRevision, firstContact
     global clusterId, clusterDefinitionRevision, firstContact
     timestamp = int(time.time()*1000)
     timestamp = int(time.time()*1000)
-   
+
     hostInfo = HostInfo(self.config)
     hostInfo = HostInfo(self.config)
     agentEnv = { }
     agentEnv = { }
     hostInfo.register(agentEnv, False, False)
     hostInfo.register(agentEnv, False, False)
 
 
     version = self.read_agent_version()
     version = self.read_agent_version()
     current_ping_port = self.config.get('agent','current_ping_port')
     current_ping_port = self.config.get('agent','current_ping_port')
-    
+
     register = { 'responseId'        : int(id),
     register = { 'responseId'        : int(id),
                  'timestamp'         : timestamp,
                  'timestamp'         : timestamp,
-                 'hostname'          : hostname.hostname(),
+                 'hostname'          : hostname.hostname(self.config),
                  'currentPingPort'   : int(current_ping_port),
                  'currentPingPort'   : int(current_ping_port),
-                 'publicHostname'    : hostname.public_hostname(),
+                 'publicHostname'    : hostname.public_hostname(self.config),
                  'hardwareProfile'   : self.hardware.get(),
                  'hardwareProfile'   : self.hardware.get(),
                  'agentEnv'          : agentEnv,
                  'agentEnv'          : agentEnv,
                  'agentVersion'      : version,
                  'agentVersion'      : version,
@@ -64,4 +64,3 @@ class Register:
     version = f.read().strip()
     version = f.read().strip()
     f.close()
     f.close()
     return version
     return version
-  

+ 8 - 8
ambari-agent/src/main/python/ambari_agent/hostname.py

@@ -30,15 +30,15 @@ logger = logging.getLogger()
 cached_hostname = None
 cached_hostname = None
 cached_public_hostname = None
 cached_public_hostname = None
 
 
-def hostname():
+
+def hostname(config):
   global cached_hostname
   global cached_hostname
   if cached_hostname is not None:
   if cached_hostname is not None:
     return cached_hostname
     return cached_hostname
 
 
-  config = AmbariConfig.config
   try:
   try:
     scriptname = config.get('agent', 'hostname_script')
     scriptname = config.get('agent', 'hostname_script')
-    try: 
+    try:
       osStat = subprocess.Popen([scriptname], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
       osStat = subprocess.Popen([scriptname], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
       out, err = osStat.communicate()
       out, err = osStat.communicate()
       if (0 == osStat.returncode and 0 != len(out.strip())):
       if (0 == osStat.returncode and 0 != len(out.strip())):
@@ -51,12 +51,12 @@ def hostname():
     cached_hostname = socket.getfqdn()
     cached_hostname = socket.getfqdn()
   return cached_hostname
   return cached_hostname
 
 
-def public_hostname():
+
+def public_hostname(config):
   global cached_public_hostname
   global cached_public_hostname
   if cached_public_hostname is not None:
   if cached_public_hostname is not None:
     return cached_public_hostname
     return cached_public_hostname
 
 
-  config = AmbariConfig.config
   out = ''
   out = ''
   err = ''
   err = ''
   try:
   try:
@@ -68,12 +68,12 @@ def public_hostname():
         cached_public_hostname = out.strip()
         cached_public_hostname = out.strip()
         return cached_public_hostname
         return cached_public_hostname
   except:
   except:
-    #ignore for now. 
+    #ignore for now.
     trace_info = traceback.format_exc()
     trace_info = traceback.format_exc()
-    logger.info("Error using the scriptname:" +  trace_info 
+    logger.info("Error using the scriptname:" +  trace_info
                 + " :out " + out + " :err " + err)
                 + " :out " + out + " :err " + err)
     logger.info("Defaulting to fqdn.")
     logger.info("Defaulting to fqdn.")
-    
+
   # future - do an agent entry for this too
   # future - do an agent entry for this too
   try:
   try:
     handle = urllib2.urlopen('http://169.254.169.254/latest/meta-data/public-hostname', '', 2)
     handle = urllib2.urlopen('http://169.254.169.254/latest/meta-data/public-hostname', '', 2)

+ 20 - 15
ambari-agent/src/main/python/ambari_agent/main.py

@@ -28,7 +28,7 @@ import time
 import ConfigParser
 import ConfigParser
 import ProcessHelper
 import ProcessHelper
 from Controller import Controller
 from Controller import Controller
-import AmbariConfig
+from AmbariConfig import AmbariConfig
 from NetUtil import NetUtil
 from NetUtil import NetUtil
 from PingPortListener import PingPortListener
 from PingPortListener import PingPortListener
 import hostname
 import hostname
@@ -38,7 +38,9 @@ import socket
 logger = logging.getLogger()
 logger = logging.getLogger()
 formatstr = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d - %(message)s"
 formatstr = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d - %(message)s"
 agentPid = os.getpid()
 agentPid = os.getpid()
-configFile = "/etc/ambari-agent/conf/ambari-agent.ini"
+config = AmbariConfig()
+configFile = config.CONFIG_FILE
+two_way_ssl_property = config.TWO_WAY_SSL_PROPERTY
 
 
 if 'AMBARI_LOG_DIR' in os.environ:
 if 'AMBARI_LOG_DIR' in os.environ:
   logfile = os.environ['AMBARI_LOG_DIR'] + "/ambari-agent.log"
   logfile = os.environ['AMBARI_LOG_DIR'] + "/ambari-agent.log"
@@ -104,12 +106,12 @@ def bind_signal_handlers():
   signal.signal(signal.SIGUSR1, debug)
   signal.signal(signal.SIGUSR1, debug)
 
 
 
 
+#  ToDo: move that function inside AmbariConfig
 def resolve_ambari_config():
 def resolve_ambari_config():
+  global config
   try:
   try:
-    config = AmbariConfig.config
     if os.path.exists(configFile):
     if os.path.exists(configFile):
-      config.read(configFile)
-      AmbariConfig.setConfig(config)
+        config.read(configFile)
     else:
     else:
       raise Exception("No config found, use default")
       raise Exception("No config found, use default")
 
 
@@ -121,8 +123,10 @@ def resolve_ambari_config():
 def perform_prestart_checks(expected_hostname):
 def perform_prestart_checks(expected_hostname):
   # Check if current hostname is equal to expected one (got from the server
   # Check if current hostname is equal to expected one (got from the server
   # during bootstrap.
   # during bootstrap.
+  global config
+
   if expected_hostname is not None:
   if expected_hostname is not None:
-    current_hostname = hostname.hostname()
+    current_hostname = hostname.hostname(config)
     if current_hostname != expected_hostname:
     if current_hostname != expected_hostname:
       print("Determined hostname does not match expected. Please check agent "
       print("Determined hostname does not match expected. Please check agent "
             "log for details")
             "log for details")
@@ -151,7 +155,7 @@ def daemonize():
   # and agent only dumps self pid to file
   # and agent only dumps self pid to file
   if not os.path.exists(ProcessHelper.piddir):
   if not os.path.exists(ProcessHelper.piddir):
     os.makedirs(ProcessHelper.piddir, 0755)
     os.makedirs(ProcessHelper.piddir, 0755)
-  
+
   pid = str(os.getpid())
   pid = str(os.getpid())
   file(ProcessHelper.pidfile, 'w').write(pid)
   file(ProcessHelper.pidfile, 'w').write(pid)
 
 
@@ -189,11 +193,12 @@ def main():
 
 
   setup_logging(options.verbose)
   setup_logging(options.verbose)
 
 
-  default_cfg = { 'agent' : { 'prefix' : '/home/ambari' } }
-  config = ConfigParser.RawConfigParser(default_cfg)
+  default_cfg = {'agent': {'prefix': '/home/ambari'}}
+  config.load(default_cfg)
+
   bind_signal_handlers()
   bind_signal_handlers()
 
 
-  if (len(sys.argv) >1) and sys.argv[1]=='stop':
+  if (len(sys.argv) > 1) and sys.argv[1] == 'stop':
     stop_agent()
     stop_agent()
 
 
   # Check for ambari configuration file.
   # Check for ambari configuration file.
@@ -201,7 +206,7 @@ def main():
 
 
   # Starting data cleanup daemon
   # Starting data cleanup daemon
   data_cleaner = None
   data_cleaner = None
-  if int(config.get('agent','data_cleanup_interval')) > 0:
+  if int(config.get('agent', 'data_cleanup_interval')) > 0:
     data_cleaner = DataCleaner(config)
     data_cleaner = DataCleaner(config)
     data_cleaner.start()
     data_cleaner.start()
 
 
@@ -213,7 +218,7 @@ def main():
     ping_port_listener = PingPortListener(config)
     ping_port_listener = PingPortListener(config)
   except Exception as ex:
   except Exception as ex:
     err_message = "Failed to start ping port listener of: " + str(ex)
     err_message = "Failed to start ping port listener of: " + str(ex)
-    logger.error(err_message);
+    logger.error(err_message)
     sys.stderr.write(err_message)
     sys.stderr.write(err_message)
     sys.exit(1)
     sys.exit(1)
   ping_port_listener.start()
   ping_port_listener.start()
@@ -221,13 +226,13 @@ def main():
   update_log_level(config)
   update_log_level(config)
 
 
   server_hostname = config.get('server', 'hostname')
   server_hostname = config.get('server', 'hostname')
-  server_url = 'https://' + server_hostname + ':' + config.get('server', 'url_port') 
-  
+  server_url = config.get_api_url()
+
   try:
   try:
     server_ip = socket.gethostbyname(server_hostname)
     server_ip = socket.gethostbyname(server_hostname)
     logger.info('Connecting to Ambari server at %s (%s)', server_url, server_ip)
     logger.info('Connecting to Ambari server at %s (%s)', server_url, server_ip)
   except socket.error:
   except socket.error:
-    logger.warn("Unable to determine the IP address of the Ambari server '%s'", server_hostname)  
+    logger.warn("Unable to determine the IP address of the Ambari server '%s'", server_hostname)
 
 
   # Wait until server is reachable
   # Wait until server is reachable
   netutil = NetUtil()
   netutil = NetUtil()

+ 46 - 42
ambari-agent/src/main/python/ambari_agent/security.py

@@ -30,8 +30,8 @@ import hostname
 
 
 logger = logging.getLogger()
 logger = logging.getLogger()
 
 
-GEN_AGENT_KEY="openssl req -new -newkey rsa:1024 -nodes -keyout %(keysdir)s/%(hostname)s.key\
-	-subj /OU=%(hostname)s/\
+GEN_AGENT_KEY = "openssl req -new -newkey rsa:1024 -nodes -keyout %(keysdir)s/%(hostname)s.key\
+  -subj /OU=%(hostname)s/\
         -out %(keysdir)s/%(hostname)s.csr"
         -out %(keysdir)s/%(hostname)s.csr"
 
 
 
 
@@ -39,30 +39,34 @@ class VerifiedHTTPSConnection(httplib.HTTPSConnection):
   """ Connecting using ssl wrapped sockets """
   """ Connecting using ssl wrapped sockets """
   def __init__(self, host, port=None, config=None):
   def __init__(self, host, port=None, config=None):
     httplib.HTTPSConnection.__init__(self, host, port=port)
     httplib.HTTPSConnection.__init__(self, host, port=port)
-    self.config=config
-    self.two_way_ssl_required=False
+    self.two_way_ssl_required = False
+    self.config = config
 
 
   def connect(self):
   def connect(self):
+    self.two_way_ssl_required = self.config.isTwoWaySSLConnection()
+    logger.debug("Server two-way SSL authentication required: %s" % str(self.two_way_ssl_required))
+    if self.two_way_ssl_required is True:
+      logger.info('Server require two-way SSL authentication. Use it instead of one-way...')
 
 
     if not self.two_way_ssl_required:
     if not self.two_way_ssl_required:
       try:
       try:
-        sock=self.create_connection()
+        sock = self.create_connection()
         self.sock = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_NONE)
         self.sock = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_NONE)
         logger.info('SSL connection established. Two-way SSL authentication is '
         logger.info('SSL connection established. Two-way SSL authentication is '
                     'turned off on the server.')
                     'turned off on the server.')
       except (ssl.SSLError, AttributeError):
       except (ssl.SSLError, AttributeError):
-        self.two_way_ssl_required=True
+        self.two_way_ssl_required = True
         logger.info('Insecure connection to https://' + self.host + ':' + self.port +
         logger.info('Insecure connection to https://' + self.host + ':' + self.port +
                     '/ failed. Reconnecting using two-way SSL authentication..')
                     '/ failed. Reconnecting using two-way SSL authentication..')
 
 
     if self.two_way_ssl_required:
     if self.two_way_ssl_required:
-      self.certMan=CertificateManager(self.config)
+      self.certMan = CertificateManager(self.config)
       self.certMan.initSecurity()
       self.certMan.initSecurity()
       agent_key = self.certMan.getAgentKeyName()
       agent_key = self.certMan.getAgentKeyName()
       agent_crt = self.certMan.getAgentCrtName()
       agent_crt = self.certMan.getAgentCrtName()
       server_crt = self.certMan.getSrvrCrtName()
       server_crt = self.certMan.getSrvrCrtName()
 
 
-      sock=self.create_connection()
+      sock = self.create_connection()
 
 
       try:
       try:
         self.sock = ssl.wrap_socket(sock,
         self.sock = ssl.wrap_socket(sock,
@@ -88,41 +92,40 @@ class VerifiedHTTPSConnection(httplib.HTTPSConnection):
       self.sock.close()
       self.sock.close()
     logger.info("SSL Connect being called.. connecting to the server")
     logger.info("SSL Connect being called.. connecting to the server")
     sock = socket.create_connection((self.host, self.port), 60)
     sock = socket.create_connection((self.host, self.port), 60)
-    sock.setsockopt( socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
     if self._tunnel_host:
     if self._tunnel_host:
       self.sock = sock
       self.sock = sock
       self._tunnel()
       self._tunnel()
 
 
     return sock
     return sock
 
 
+
 class CachedHTTPSConnection:
 class CachedHTTPSConnection:
   """ Caches a ssl socket and uses a single https connection to the server. """
   """ Caches a ssl socket and uses a single https connection to the server. """
-  
+
   def __init__(self, config):
   def __init__(self, config):
-    self.connected = False;
+    self.connected = False
     self.config = config
     self.config = config
     self.server = config.get('server', 'hostname')
     self.server = config.get('server', 'hostname')
     self.port = config.get('server', 'secured_url_port')
     self.port = config.get('server', 'secured_url_port')
     self.connect()
     self.connect()
-  
+
   def connect(self):
   def connect(self):
-    if  not self.connected:
+    if not self.connected:
       self.httpsconn = VerifiedHTTPSConnection(self.server, self.port, self.config)
       self.httpsconn = VerifiedHTTPSConnection(self.server, self.port, self.config)
       self.httpsconn.connect()
       self.httpsconn.connect()
       self.connected = True
       self.connected = True
     # possible exceptions are caught and processed in Controller
     # possible exceptions are caught and processed in Controller
 
 
-
-  
   def forceClear(self):
   def forceClear(self):
     self.httpsconn = VerifiedHTTPSConnection(self.server, self.port, self.config)
     self.httpsconn = VerifiedHTTPSConnection(self.server, self.port, self.config)
     self.connect()
     self.connect()
-    
-  def request(self, req): 
+
+  def request(self, req):
     self.connect()
     self.connect()
     try:
     try:
-      self.httpsconn.request(req.get_method(), req.get_full_url(), 
-                                  req.get_data(), req.headers)
+      self.httpsconn.request(req.get_method(), req.get_full_url(),
+                             req.get_data(), req.headers)
       response = self.httpsconn.getresponse()
       response = self.httpsconn.getresponse()
       readResponse = response.read()
       readResponse = response.read()
     except Exception as ex:
     except Exception as ex:
@@ -133,59 +136,60 @@ class CachedHTTPSConnection:
       self.connected = False
       self.connected = False
       raise IOError("Error occured during connecting to the server: " + str(ex))
       raise IOError("Error occured during connecting to the server: " + str(ex))
     return readResponse
     return readResponse
-  
+
+
 class CertificateManager():
 class CertificateManager():
   def __init__(self, config):
   def __init__(self, config):
     self.config = config
     self.config = config
     self.keysdir = self.config.get('security', 'keysdir')
     self.keysdir = self.config.get('security', 'keysdir')
-    self.server_crt=self.config.get('security', 'server_crt')
+    self.server_crt = self.config.get('security', 'server_crt')
     self.server_url = 'https://' + self.config.get('server', 'hostname') + ':' \
     self.server_url = 'https://' + self.config.get('server', 'hostname') + ':' \
        + self.config.get('server', 'url_port')
        + self.config.get('server', 'url_port')
-    
+
   def getAgentKeyName(self):
   def getAgentKeyName(self):
     keysdir = self.config.get('security', 'keysdir')
     keysdir = self.config.get('security', 'keysdir')
-    return keysdir + os.sep + hostname.hostname() + ".key"
+    return keysdir + os.sep + hostname.hostname(self.config) + ".key"
 
 
   def getAgentCrtName(self):
   def getAgentCrtName(self):
     keysdir = self.config.get('security', 'keysdir')
     keysdir = self.config.get('security', 'keysdir')
-    return keysdir + os.sep + hostname.hostname() + ".crt"
+    return keysdir + os.sep + hostname.hostname(self.config) + ".crt"
 
 
   def getAgentCrtReqName(self):
   def getAgentCrtReqName(self):
     keysdir = self.config.get('security', 'keysdir')
     keysdir = self.config.get('security', 'keysdir')
-    return keysdir + os.sep + hostname.hostname() + ".csr"
+    return keysdir + os.sep + hostname.hostname(self.config) + ".csr"
 
 
   def getSrvrCrtName(self):
   def getSrvrCrtName(self):
     keysdir = self.config.get('security', 'keysdir')
     keysdir = self.config.get('security', 'keysdir')
     return keysdir + os.sep + "ca.crt"
     return keysdir + os.sep + "ca.crt"
-    
+
   def checkCertExists(self):
   def checkCertExists(self):
-    
+
     s = self.config.get('security', 'keysdir') + os.sep + "ca.crt"
     s = self.config.get('security', 'keysdir') + os.sep + "ca.crt"
 
 
     server_crt_exists = os.path.exists(s)
     server_crt_exists = os.path.exists(s)
-    
+
     if not server_crt_exists:
     if not server_crt_exists:
       logger.info("Server certicate not exists, downloading")
       logger.info("Server certicate not exists, downloading")
       self.loadSrvrCrt()
       self.loadSrvrCrt()
     else:
     else:
       logger.info("Server certicate exists, ok")
       logger.info("Server certicate exists, ok")
-      
+
     agent_key_exists = os.path.exists(self.getAgentKeyName())
     agent_key_exists = os.path.exists(self.getAgentKeyName())
-    
+
     if not agent_key_exists:
     if not agent_key_exists:
       logger.info("Agent key not exists, generating request")
       logger.info("Agent key not exists, generating request")
       self.genAgentCrtReq()
       self.genAgentCrtReq()
     else:
     else:
       logger.info("Agent key exists, ok")
       logger.info("Agent key exists, ok")
-      
+
     agent_crt_exists = os.path.exists(self.getAgentCrtName())
     agent_crt_exists = os.path.exists(self.getAgentCrtName())
-    
+
     if not agent_crt_exists:
     if not agent_crt_exists:
       logger.info("Agent certificate not exists, sending sign request")
       logger.info("Agent certificate not exists, sending sign request")
       self.reqSignCrt()
       self.reqSignCrt()
     else:
     else:
       logger.info("Agent certificate exists, ok")
       logger.info("Agent certificate exists, ok")
-            
+
   def loadSrvrCrt(self):
   def loadSrvrCrt(self):
     get_ca_url = self.server_url + '/cert/ca/'
     get_ca_url = self.server_url + '/cert/ca/'
     logger.info("Downloading server cert from " + get_ca_url)
     logger.info("Downloading server cert from " + get_ca_url)
@@ -196,15 +200,15 @@ class CertificateManager():
     stream.close()
     stream.close()
     srvr_crt_f = open(self.getSrvrCrtName(), 'w+')
     srvr_crt_f = open(self.getSrvrCrtName(), 'w+')
     srvr_crt_f.write(response)
     srvr_crt_f.write(response)
-      
+
   def reqSignCrt(self):
   def reqSignCrt(self):
-    sign_crt_req_url = self.server_url + '/certs/' + hostname.hostname()
+    sign_crt_req_url = self.server_url + '/certs/' + hostname.hostname(self.config)
     agent_crt_req_f = open(self.getAgentCrtReqName())
     agent_crt_req_f = open(self.getAgentCrtReqName())
     agent_crt_req_content = agent_crt_req_f.read()
     agent_crt_req_content = agent_crt_req_f.read()
     passphrase_env_var = self.config.get('security', 'passphrase_env_var_name')
     passphrase_env_var = self.config.get('security', 'passphrase_env_var_name')
     passphrase = os.environ[passphrase_env_var]
     passphrase = os.environ[passphrase_env_var]
-    register_data = {'csr'       : agent_crt_req_content,
-                    'passphrase' : passphrase}
+    register_data = {'csr': agent_crt_req_content,
+                    'passphrase': passphrase}
     data = json.dumps(register_data)
     data = json.dumps(register_data)
     proxy_handler = urllib2.ProxyHandler({})
     proxy_handler = urllib2.ProxyHandler({})
     opener = urllib2.build_opener(proxy_handler)
     opener = urllib2.build_opener(proxy_handler)
@@ -219,9 +223,9 @@ class CertificateManager():
     except Exception:
     except Exception:
       logger.warn("Malformed response! data: " + str(data))
       logger.warn("Malformed response! data: " + str(data))
       data = {'result': 'ERROR'}
       data = {'result': 'ERROR'}
-    result=data['result']
+    result = data['result']
     if result == 'OK':
     if result == 'OK':
-      agentCrtContent=data['signedCa']
+      agentCrtContent = data['signedCa']
       agentCrtF = open(self.getAgentCrtName(), "w")
       agentCrtF = open(self.getAgentCrtName(), "w")
       agentCrtF.write(agentCrtContent)
       agentCrtF.write(agentCrtContent)
     else:
     else:
@@ -235,11 +239,11 @@ class CertificateManager():
       raise ssl.SSLError
       raise ssl.SSLError
 
 
   def genAgentCrtReq(self):
   def genAgentCrtReq(self):
-    generate_script = GEN_AGENT_KEY % {'hostname': hostname.hostname(),
-                                     'keysdir' : self.config.get('security', 'keysdir')}
+    generate_script = GEN_AGENT_KEY % {'hostname': hostname.hostname(self.config),
+                                     'keysdir': self.config.get('security', 'keysdir')}
     logger.info(generate_script)
     logger.info(generate_script)
     p = subprocess.Popen([generate_script], shell=True, stdout=subprocess.PIPE)
     p = subprocess.Popen([generate_script], shell=True, stdout=subprocess.PIPE)
     p.communicate()
     p.communicate()
-      
+
   def initSecurity(self):
   def initSecurity(self):
     self.checkCertExists()
     self.checkCertExists()

+ 4 - 2
ambari-agent/src/test/python/ambari_agent/TestActionQueue.py

@@ -178,7 +178,9 @@ class TestActionQueue(TestCase):
   def test_process_command(self, execute_status_command_mock,
   def test_process_command(self, execute_status_command_mock,
                            execute_command_mock, print_exc_mock):
                            execute_command_mock, print_exc_mock):
     dummy_controller = MagicMock()
     dummy_controller = MagicMock()
-    actionQueue = ActionQueue(AmbariConfig().getConfig(), dummy_controller)
+    config = AmbariConfig()
+    config.set('agent', 'tolerate_download_failures', "true")
+    actionQueue = ActionQueue(config, dummy_controller)
     execution_command = {
     execution_command = {
       'commandType' : ActionQueue.EXECUTION_COMMAND,
       'commandType' : ActionQueue.EXECUTION_COMMAND,
     }
     }
@@ -243,7 +245,7 @@ class TestActionQueue(TestCase):
         return self.original_open(file, mode)
         return self.original_open(file, mode)
     open_mock.side_effect = open_side_effect
     open_mock.side_effect = open_side_effect
 
 
-    config = AmbariConfig().getConfig()
+    config = AmbariConfig()
     tempdir = tempfile.gettempdir()
     tempdir = tempfile.gettempdir()
     config.set('agent', 'prefix', tempdir)
     config.set('agent', 'prefix', tempdir)
     config.set('agent', 'cache_dir', "/var/lib/ambari-agent/cache")
     config.set('agent', 'cache_dir', "/var/lib/ambari-agent/cache")

+ 5 - 5
ambari-agent/src/test/python/ambari_agent/TestCertGeneration.py

@@ -29,20 +29,20 @@ from ambari_agent import AmbariConfig
 class TestCertGeneration(TestCase):
 class TestCertGeneration(TestCase):
   def setUp(self):
   def setUp(self):
     self.tmpdir = tempfile.mkdtemp()
     self.tmpdir = tempfile.mkdtemp()
-    config = ConfigParser.RawConfigParser()
-    config.add_section('server')
+    config = AmbariConfig.AmbariConfig()
+    #config.add_section('server')
     config.set('server', 'hostname', 'example.com')
     config.set('server', 'hostname', 'example.com')
     config.set('server', 'url_port', '777')
     config.set('server', 'url_port', '777')
-    config.add_section('security')
+    #config.add_section('security')
     config.set('security', 'keysdir', self.tmpdir)
     config.set('security', 'keysdir', self.tmpdir)
     config.set('security', 'server_crt', 'ca.crt')
     config.set('security', 'server_crt', 'ca.crt')
     self.certMan = CertificateManager(config)
     self.certMan = CertificateManager(config)
-    
+
   def test_generation(self):
   def test_generation(self):
     self.certMan.genAgentCrtReq()
     self.certMan.genAgentCrtReq()
     self.assertTrue(os.path.exists(self.certMan.getAgentKeyName()))
     self.assertTrue(os.path.exists(self.certMan.getAgentKeyName()))
     self.assertTrue(os.path.exists(self.certMan.getAgentCrtReqName()))
     self.assertTrue(os.path.exists(self.certMan.getAgentCrtReqName()))
   def tearDown(self):
   def tearDown(self):
     shutil.rmtree(self.tmpdir)
     shutil.rmtree(self.tmpdir)
-    
+
 
 

+ 2 - 1
ambari-agent/src/test/python/ambari_agent/TestController.py

@@ -53,7 +53,8 @@ class TestController(unittest.TestCase):
 
 
 
 
     config = MagicMock()
     config = MagicMock()
-    config.get.return_value = "something"
+    #config.get.return_value = "something"
+    config.get.return_value = "5"
 
 
     self.controller = Controller.Controller(config)
     self.controller = Controller.Controller(config)
     self.controller.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS = 0.1
     self.controller.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS = 0.1

+ 11 - 11
ambari-agent/src/test/python/ambari_agent/TestHostname.py

@@ -21,7 +21,7 @@ limitations under the License.
 from unittest import TestCase
 from unittest import TestCase
 import unittest
 import unittest
 import ambari_agent.hostname as hostname
 import ambari_agent.hostname as hostname
-import ambari_agent.AmbariConfig as AmbariConfig
+from ambari_agent.AmbariConfig import AmbariConfig
 import socket
 import socket
 import tempfile
 import tempfile
 import shutil
 import shutil
@@ -33,7 +33,8 @@ class TestHostname(TestCase):
   def test_hostname(self):
   def test_hostname(self):
     hostname.cached_hostname = None
     hostname.cached_hostname = None
     hostname.cached_public_hostname = None
     hostname.cached_public_hostname = None
-    self.assertEquals(hostname.hostname(), socket.getfqdn(), 
+    config = AmbariConfig()
+    self.assertEquals(hostname.hostname(config), socket.getfqdn(),
                       "hostname should equal the socket-based hostname")
                       "hostname should equal the socket-based hostname")
     pass
     pass
 
 
@@ -46,14 +47,14 @@ class TestHostname(TestCase):
     os.chmod(tmpname, os.stat(tmpname).st_mode | stat.S_IXUSR)
     os.chmod(tmpname, os.stat(tmpname).st_mode | stat.S_IXUSR)
 
 
     tmpfile = file(tmpname, "w+")
     tmpfile = file(tmpname, "w+")
-    config = AmbariConfig.config
+    config = AmbariConfig()
     try:
     try:
       tmpfile.write("#!/bin/sh\n\necho 'test.example.com'")
       tmpfile.write("#!/bin/sh\n\necho 'test.example.com'")
       tmpfile.close()
       tmpfile.close()
 
 
       config.set('agent', 'hostname_script', tmpname)
       config.set('agent', 'hostname_script', tmpname)
 
 
-      self.assertEquals(hostname.hostname(), 'test.example.com', "expected hostname 'test.example.com'")
+      self.assertEquals(hostname.hostname(config), 'test.example.com', "expected hostname 'test.example.com'")
     finally:
     finally:
       os.remove(tmpname)
       os.remove(tmpname)
       config.remove_option('agent', 'hostname_script')
       config.remove_option('agent', 'hostname_script')
@@ -66,17 +67,17 @@ class TestHostname(TestCase):
     tmpname = fd[1]
     tmpname = fd[1]
     os.close(fd[0])
     os.close(fd[0])
     os.chmod(tmpname, os.stat(tmpname).st_mode | stat.S_IXUSR)
     os.chmod(tmpname, os.stat(tmpname).st_mode | stat.S_IXUSR)
-   
+
     tmpfile = file(tmpname, "w+")
     tmpfile = file(tmpname, "w+")
 
 
-    config = AmbariConfig.config
+    config = AmbariConfig()
     try:
     try:
       tmpfile.write("#!/bin/sh\n\necho 'test.example.com'")
       tmpfile.write("#!/bin/sh\n\necho 'test.example.com'")
       tmpfile.close()
       tmpfile.close()
 
 
       config.set('agent', 'public_hostname_script', tmpname)
       config.set('agent', 'public_hostname_script', tmpname)
 
 
-      self.assertEquals(hostname.public_hostname(), 'test.example.com', 
+      self.assertEquals(hostname.public_hostname(config), 'test.example.com',
                         "expected hostname 'test.example.com'")
                         "expected hostname 'test.example.com'")
     finally:
     finally:
       os.remove(tmpname)
       os.remove(tmpname)
@@ -87,9 +88,10 @@ class TestHostname(TestCase):
   def test_caching(self, getfqdn_mock):
   def test_caching(self, getfqdn_mock):
     hostname.cached_hostname = None
     hostname.cached_hostname = None
     hostname.cached_public_hostname = None
     hostname.cached_public_hostname = None
+    config = AmbariConfig()
     getfqdn_mock.side_effect = ["test.example.com", "test2.example.com'"]
     getfqdn_mock.side_effect = ["test.example.com", "test2.example.com'"]
-    self.assertEquals(hostname.hostname(), "test.example.com")
-    self.assertEquals(hostname.hostname(), "test.example.com")
+    self.assertEquals(hostname.hostname(config), "test.example.com")
+    self.assertEquals(hostname.hostname(config), "test.example.com")
     self.assertEqual(getfqdn_mock.call_count, 1)
     self.assertEqual(getfqdn_mock.call_count, 1)
     pass
     pass
 
 
@@ -97,5 +99,3 @@ if __name__ == "__main__":
   unittest.main(verbosity=2)
   unittest.main(verbosity=2)
 
 
 
 
-
-

+ 2 - 2
ambari-agent/src/test/python/ambari_agent/TestMain.py

@@ -242,11 +242,11 @@ class TestMain(unittest.TestCase):
   @patch.object(DataCleaner,"start")
   @patch.object(DataCleaner,"start")
   @patch.object(DataCleaner,"__init__")
   @patch.object(DataCleaner,"__init__")
   @patch.object(PingPortListener,"start")
   @patch.object(PingPortListener,"start")
-  @patch.object(PingPortListener,"__init__")  
+  @patch.object(PingPortListener,"__init__")
   def test_main(self, ping_port_init_mock, ping_port_start_mock, data_clean_init_mock,data_clean_start_mock,
   def test_main(self, ping_port_init_mock, ping_port_start_mock, data_clean_init_mock,data_clean_start_mock,
                 parse_args_mock, join_mock, start_mock, Controller_init_mock, try_to_connect_mock,
                 parse_args_mock, join_mock, start_mock, Controller_init_mock, try_to_connect_mock,
                 update_log_level_mock, daemonize_mock, perform_prestart_checks_mock,
                 update_log_level_mock, daemonize_mock, perform_prestart_checks_mock,
-                resolve_ambari_config_mock, stop_mock, bind_signal_handlers_mock, 
+                resolve_ambari_config_mock, stop_mock, bind_signal_handlers_mock,
                 setup_logging_mock, socket_mock):
                 setup_logging_mock, socket_mock):
     data_clean_init_mock.return_value = None
     data_clean_init_mock.return_value = None
     Controller_init_mock.return_value = None
     Controller_init_mock.return_value = None

+ 7 - 9
ambari-agent/src/test/python/ambari_agent/TestNetUtil.py

@@ -38,16 +38,16 @@ class TestNetUtil(unittest.TestCase):
 
 
     # test 200
     # test 200
     netutil = NetUtil.NetUtil()
     netutil = NetUtil.NetUtil()
-    self.assertTrue(netutil.checkURL("url"))
+    self.assertTrue(netutil.checkURL("url")[0])
 
 
     # test fail
     # test fail
     response.status = 404
     response.status = 404
-    self.assertFalse(netutil.checkURL("url"))
+    self.assertFalse(netutil.checkURL("url")[0])
 
 
     # test Exception
     # test Exception
     response.status = 200
     response.status = 200
     httpsConMock.side_effect = Exception("test")
     httpsConMock.side_effect = Exception("test")
-    self.assertFalse(netutil.checkURL("url"))
+    self.assertFalse(netutil.checkURL("url")[0])
 
 
 
 
   @patch("time.sleep")
   @patch("time.sleep")
@@ -55,15 +55,15 @@ class TestNetUtil(unittest.TestCase):
 
 
     netutil = NetUtil.NetUtil()
     netutil = NetUtil.NetUtil()
     checkURL = MagicMock(name="checkURL")
     checkURL = MagicMock(name="checkURL")
-    checkURL.return_value = True
+    checkURL.return_value = True, "test"
     netutil.checkURL = checkURL
     netutil.checkURL = checkURL
-    l = MagicMock()
 
 
     # one successful get
     # one successful get
     self.assertEqual(0, netutil.try_to_connect("url", 10))
     self.assertEqual(0, netutil.try_to_connect("url", 10))
 
 
     # got successful after N retries
     # got successful after N retries
-    gets = [True, False, False]
+    gets = [[True, ""], [False, ""], [False, ""]]
+
     def side_effect(*args):
     def side_effect(*args):
       return gets.pop()
       return gets.pop()
     checkURL.side_effect = side_effect
     checkURL.side_effect = side_effect
@@ -71,7 +71,5 @@ class TestNetUtil(unittest.TestCase):
 
 
     # max retries
     # max retries
     checkURL.side_effect = None
     checkURL.side_effect = None
-    checkURL.return_value = False
+    checkURL.return_value = False, "test"
     self.assertEqual(5, netutil.try_to_connect("url", 5))
     self.assertEqual(5, netutil.try_to_connect("url", 5))
-
-

+ 1 - 1
ambari-agent/src/test/python/ambari_agent/TestSecurity.py

@@ -45,7 +45,7 @@ class TestSecurity(unittest.TestCase):
     out = StringIO.StringIO()
     out = StringIO.StringIO()
     sys.stdout = out
     sys.stdout = out
     # Create config
     # Create config
-    self.config = AmbariConfig().getConfig()
+    self.config = AmbariConfig()
     # Instantiate CachedHTTPSConnection (skip connect() call)
     # Instantiate CachedHTTPSConnection (skip connect() call)
     with patch.object(security.VerifiedHTTPSConnection, "connect"):
     with patch.object(security.VerifiedHTTPSConnection, "connect"):
       self.cachedHTTPSConnection = security.CachedHTTPSConnection(self.config)
       self.cachedHTTPSConnection = security.CachedHTTPSConnection(self.config)

+ 5 - 17
ambari-server/src/main/java/org/apache/ambari/server/controller/AmbariServer.java

@@ -87,6 +87,7 @@ import org.apache.ambari.server.security.authorization.internal.AmbariInternalAu
 import org.apache.ambari.server.security.authorization.internal.InternalTokenAuthenticationFilter;
 import org.apache.ambari.server.security.authorization.internal.InternalTokenAuthenticationFilter;
 import org.apache.ambari.server.security.unsecured.rest.CertificateDownload;
 import org.apache.ambari.server.security.unsecured.rest.CertificateDownload;
 import org.apache.ambari.server.security.unsecured.rest.CertificateSign;
 import org.apache.ambari.server.security.unsecured.rest.CertificateSign;
+import org.apache.ambari.server.security.unsecured.rest.ConnectionInfo;
 import org.apache.ambari.server.state.Clusters;
 import org.apache.ambari.server.state.Clusters;
 import org.apache.ambari.server.state.ConfigHelper;
 import org.apache.ambari.server.state.ConfigHelper;
 import org.apache.ambari.server.utils.StageUtils;
 import org.apache.ambari.server.utils.StageUtils;
@@ -267,32 +268,18 @@ public class AmbariServer {
       sslConnectorTwoWay.setTruststoreType("PKCS12");
       sslConnectorTwoWay.setTruststoreType("PKCS12");
       sslConnectorTwoWay.setNeedClientAuth(configs.getTwoWaySsl());
       sslConnectorTwoWay.setNeedClientAuth(configs.getTwoWaySsl());
 
 
-      //Secured connector for 1-way auth
-      //SslSelectChannelConnector sslConnectorOneWay = new SslSelectChannelConnector();
+      //SSL Context Factory
       SslContextFactory contextFactory = new SslContextFactory(true);
       SslContextFactory contextFactory = new SslContextFactory(true);
-      //sslConnectorOneWay.setPort(AGENT_ONE_WAY_AUTH);
       contextFactory.setKeyStorePath(keystore);
       contextFactory.setKeyStorePath(keystore);
-      // sslConnectorOneWay.setKeystore(keystore);
       contextFactory.setTrustStore(keystore);
       contextFactory.setTrustStore(keystore);
-      // sslConnectorOneWay.setTruststore(keystore);
       contextFactory.setKeyStorePassword(srvrCrtPass);
       contextFactory.setKeyStorePassword(srvrCrtPass);
-      // sslConnectorOneWay.setPassword(srvrCrtPass);
-
       contextFactory.setKeyManagerPassword(srvrCrtPass);
       contextFactory.setKeyManagerPassword(srvrCrtPass);
-
-      // sslConnectorOneWay.setKeyPassword(srvrCrtPass);
-
       contextFactory.setTrustStorePassword(srvrCrtPass);
       contextFactory.setTrustStorePassword(srvrCrtPass);
-      //sslConnectorOneWay.setTrustPassword(srvrCrtPass);
-
       contextFactory.setKeyStoreType("PKCS12");
       contextFactory.setKeyStoreType("PKCS12");
-      //sslConnectorOneWay.setKeystoreType("PKCS12");
       contextFactory.setTrustStoreType("PKCS12");
       contextFactory.setTrustStoreType("PKCS12");
-
-      //sslConnectorOneWay.setTruststoreType("PKCS12");
       contextFactory.setNeedClientAuth(false);
       contextFactory.setNeedClientAuth(false);
-      // sslConnectorOneWay.setWantClientAuth(false);
-      // sslConnectorOneWay.setNeedClientAuth(false);
+
+      //Secured connector for 1-way auth
       SslSelectChannelConnector sslConnectorOneWay = new SslSelectChannelConnector(contextFactory);
       SslSelectChannelConnector sslConnectorOneWay = new SslSelectChannelConnector(contextFactory);
       sslConnectorOneWay.setPort(configs.getOneWayAuthPort());
       sslConnectorOneWay.setPort(configs.getOneWayAuthPort());
       sslConnectorOneWay.setAcceptors(2);
       sslConnectorOneWay.setAcceptors(2);
@@ -530,6 +517,7 @@ public class AmbariServer {
   public void performStaticInjection() {
   public void performStaticInjection() {
     AgentResource.init(injector.getInstance(HeartBeatHandler.class));
     AgentResource.init(injector.getInstance(HeartBeatHandler.class));
     CertificateDownload.init(injector.getInstance(CertificateManager.class));
     CertificateDownload.init(injector.getInstance(CertificateManager.class));
+    ConnectionInfo.init(injector.getInstance(Configuration.class));
     CertificateSign.init(injector.getInstance(CertificateManager.class));
     CertificateSign.init(injector.getInstance(CertificateManager.class));
     GetResource.init(injector.getInstance(ResourceManager.class));
     GetResource.init(injector.getInstance(ResourceManager.class));
     PersistKeyValueService.init(injector.getInstance(PersistKeyValueImpl.class));
     PersistKeyValueService.init(injector.getInstance(PersistKeyValueImpl.class));

+ 4 - 0
ambari-server/src/main/java/org/apache/ambari/server/security/SecurityFilter.java

@@ -84,6 +84,10 @@ public class SecurityFilter implements Filter {
         return true;
         return true;
       }
       }
 
 
+      if (Pattern.matches("/connection_info", url.getPath())) {
+          return true;
+      }
+
       if (Pattern.matches("/certs/[^/0-9][^/]*", url.getPath())) {
       if (Pattern.matches("/certs/[^/0-9][^/]*", url.getPath())) {
         return true;
         return true;
       }
       }

+ 54 - 0
ambari-server/src/main/java/org/apache/ambari/server/security/unsecured/rest/ConnectionInfo.java

@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ambari.server.security.unsecured.rest;
+
+import javax.ws.rs.GET;
+import javax.ws.rs.Path;
+import javax.ws.rs.Produces;
+import javax.ws.rs.core.MediaType;
+
+import org.apache.ambari.server.configuration.Configuration;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import com.google.inject.Inject;
+
+import java.util.HashMap;
+import java.util.Map;
+
+
+@Path("/connection_info")
+public class ConnectionInfo {
+    private static Log LOG = LogFactory.getLog(ConnectionInfo.class);
+    private static HashMap<String,String> response=new HashMap<String,String>();
+    private static Configuration conf;
+
+
+    @Inject
+    public static void init(Configuration instance){
+        conf = instance;
+        response.put(Configuration.SRVR_TWO_WAY_SSL_KEY,String.valueOf(conf.getTwoWaySsl()));
+    }
+
+    @GET
+    @Produces({MediaType.APPLICATION_JSON})
+    public Map<String,String> connectionType() {
+        return response;
+    }
+}

+ 1 - 1
ambari-server/src/test/java/org/apache/ambari/server/security/CertGenerationTest.java

@@ -162,7 +162,7 @@ public class CertGenerationTest {
     Map<String,String> config = certMan.configs.getConfigsMap();
     Map<String,String> config = certMan.configs.getConfigsMap();
     config.put(Configuration.PASSPHRASE_KEY,"passphrase");
     config.put(Configuration.PASSPHRASE_KEY,"passphrase");
 
 
-    String agentHostname = "agent_hostname1";
+    String agentHostname = "agent_hostname";
     SignCertResponse scr = certMan.signAgentCrt(agentHostname,
     SignCertResponse scr = certMan.signAgentCrt(agentHostname,
       "incorrect_agentCrtReqContent", "passphrase");
       "incorrect_agentCrtReqContent", "passphrase");
     //Revoke command wasn't executed
     //Revoke command wasn't executed