Pārlūkot izejas kodu

AMBARI-6978. Uncaught exception at ambari agent - it may die on connection error (dlysnichenko)

Lisnichenko Dmitro 11 gadi atpakaļ
vecāks
revīzija
8f02714e5b

+ 26 - 29
ambari-agent/src/main/python/ambari_agent/Controller.py

@@ -78,17 +78,20 @@ class Controller(threading.Thread):
   def __del__(self):
   def __del__(self):
     logger.info("Server connection disconnected.")
     logger.info("Server connection disconnected.")
     pass
     pass
-  
+
   def registerWithServer(self):
   def registerWithServer(self):
+    """
+    :return: returning from current method without setting self.isRegistered
+    to True will lead to agent termination.
+    """
     LiveStatus.SERVICES = []
     LiveStatus.SERVICES = []
     LiveStatus.CLIENT_COMPONENTS = []
     LiveStatus.CLIENT_COMPONENTS = []
     LiveStatus.COMPONENTS = []
     LiveStatus.COMPONENTS = []
-    id = -1
     ret = {}
     ret = {}
 
 
     while not self.isRegistered:
     while not self.isRegistered:
       try:
       try:
-        data = json.dumps(self.register.build(id))
+        data = json.dumps(self.register.build())
         prettyData = pprint.pformat(data)
         prettyData = pprint.pformat(data)
 
 
         try:
         try:
@@ -111,8 +114,7 @@ class Controller(threading.Thread):
           # log - message, which will be printed to agents log
           # log - message, which will be printed to agents log
           if 'log' in ret.keys():
           if 'log' in ret.keys():
             log = ret['log']
             log = ret['log']
-
-          logger.error(log)
+            logger.error(log)
           self.isRegistered = False
           self.isRegistered = False
           self.repeatRegistration = False
           self.repeatRegistration = False
           return ret
           return ret
@@ -122,23 +124,22 @@ class Controller(threading.Thread):
         self.responseId = int(ret['responseId'])
         self.responseId = int(ret['responseId'])
         self.isRegistered = True
         self.isRegistered = True
         if 'statusCommands' in ret.keys():
         if 'statusCommands' in ret.keys():
-          logger.info("Got status commands on registration " + pprint.pformat(ret['statusCommands']) )
+          logger.info("Got status commands on registration " + pprint.pformat(ret['statusCommands']))
           self.addToStatusQueue(ret['statusCommands'])
           self.addToStatusQueue(ret['statusCommands'])
           pass
           pass
         else:
         else:
           self.hasMappedComponents = False
           self.hasMappedComponents = False
         pass
         pass
       except ssl.SSLError:
       except ssl.SSLError:
-        self.repeatRegistration=False
+        self.repeatRegistration = False
         self.isRegistered = False
         self.isRegistered = False
         return
         return
       except Exception:
       except Exception:
         # try a reconnect only after a certain amount of random time
         # try a reconnect only after a certain amount of random time
         delay = randint(0, self.range)
         delay = randint(0, self.range)
-        logger.error("Unable to connect to: " + self.registerUrl, exc_info = True)
+        logger.error("Unable to connect to: " + self.registerUrl, exc_info=True)
         """ Sleeping for {0} seconds and then retrying again """.format(delay)
         """ Sleeping for {0} seconds and then retrying again """.format(delay)
         time.sleep(delay)
         time.sleep(delay)
-        pass
       pass
       pass
     return ret
     return ret
 
 
@@ -147,7 +148,7 @@ class Controller(threading.Thread):
     if commands:
     if commands:
       self.actionQueue.cancel(commands)
       self.actionQueue.cancel(commands)
     pass
     pass
-  
+
   def addToQueue(self, commands):
   def addToQueue(self, commands):
     """Add to the queue for running the commands """
     """Add to the queue for running the commands """
     """ Put the required actions into the Queue """
     """ Put the required actions into the Queue """
@@ -178,11 +179,8 @@ class Controller(threading.Thread):
     self.DEBUG_SUCCESSFULL_HEARTBEATS = 0
     self.DEBUG_SUCCESSFULL_HEARTBEATS = 0
     retry = False
     retry = False
     certVerifFailed = False
     certVerifFailed = False
-
     hb_interval = self.config.get('heartbeat', 'state_interval')
     hb_interval = self.config.get('heartbeat', 'state_interval')
 
 
-    #TODO make sure the response id is monotonically increasing
-    id = 0
     while not self.DEBUG_STOP_HEARTBEATING:
     while not self.DEBUG_STOP_HEARTBEATING:
       try:
       try:
         if not retry:
         if not retry:
@@ -212,7 +210,7 @@ class Controller(threading.Thread):
           logger.info('Heartbeat response received (id = %s)', serverId)
           logger.info('Heartbeat response received (id = %s)', serverId)
 
 
         if 'hasMappedComponents' in response.keys():
         if 'hasMappedComponents' in response.keys():
-          self.hasMappedComponents = response['hasMappedComponents'] != False
+          self.hasMappedComponents = response['hasMappedComponents'] is not False
 
 
         if 'registrationCommand' in response.keys():
         if 'registrationCommand' in response.keys():
           # check if the registration command is None. If none skip
           # check if the registration command is None. If none skip
@@ -226,7 +224,7 @@ class Controller(threading.Thread):
           logger.error("Error in responseId sequence - restarting")
           logger.error("Error in responseId sequence - restarting")
           self.restartAgent()
           self.restartAgent()
         else:
         else:
-          self.responseId=serverId
+          self.responseId = serverId
 
 
         if 'cancelCommands' in response.keys():
         if 'cancelCommands' in response.keys():
           self.cancelCommandInQueue(response['cancelCommands'])
           self.cancelCommandInQueue(response['cancelCommands'])
@@ -250,7 +248,7 @@ class Controller(threading.Thread):
         if retry:
         if retry:
           logger.info("Reconnected to %s", self.heartbeatUrl)
           logger.info("Reconnected to %s", self.heartbeatUrl)
 
 
-        retry=False
+        retry = False
         certVerifFailed = False
         certVerifFailed = False
         self.DEBUG_SUCCESSFULL_HEARTBEATS += 1
         self.DEBUG_SUCCESSFULL_HEARTBEATS += 1
         self.DEBUG_HEARTBEAT_RETRIES = 0
         self.DEBUG_HEARTBEAT_RETRIES = 0
@@ -260,10 +258,6 @@ class Controller(threading.Thread):
         self.isRegistered = False
         self.isRegistered = False
         return
         return
       except Exception, err:
       except Exception, err:
-        #randomize the heartbeat
-        delay = randint(0, self.range)
-        time.sleep(delay)
-
         if "code" in err:
         if "code" in err:
           logger.error(err.code)
           logger.error(err.code)
         else:
         else:
@@ -283,13 +277,17 @@ class Controller(threading.Thread):
             logger.warn("Server certificate verify failed. Did you regenerate server certificate?")
             logger.warn("Server certificate verify failed. Did you regenerate server certificate?")
             certVerifFailed = True
             certVerifFailed = True
 
 
-        self.cachedconnect = None # Previous connection is broken now
-        retry=True
+        self.cachedconnect = None  # Previous connection is broken now
+        retry = True
+
+        #randomize the heartbeat
+        delay = randint(0, self.range)
+        time.sleep(delay)
 
 
       # Sleep for some time
       # Sleep for some time
       timeout = self.netutil.HEARTBEAT_IDDLE_INTERVAL_SEC \
       timeout = self.netutil.HEARTBEAT_IDDLE_INTERVAL_SEC \
                 - self.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS
                 - self.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS
-      self.heartbeat_wait_event.wait(timeout = timeout)
+      self.heartbeat_wait_event.wait(timeout=timeout)
       # Sleep a bit more to allow STATUS_COMMAND results to be collected
       # Sleep a bit more to allow STATUS_COMMAND results to be collected
       # and sent in one heartbeat. Also avoid server overload with heartbeats
       # and sent in one heartbeat. Also avoid server overload with heartbeats
       time.sleep(self.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS)
       time.sleep(self.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS)
@@ -345,17 +343,16 @@ class Controller(threading.Thread):
       return json.loads(response)
       return json.loads(response)
     except Exception, exception:
     except Exception, exception:
       if response is None:
       if response is None:
-        err_msg = 'Request to {0} failed due to {1}'.format(url, str(exception))
-        return {'exitstatus': 1, 'log': err_msg}
+        raise IOError('Request to {0} failed due to {1}'.format(url, str(exception)))
       else:
       else:
-        err_msg = ('Response parsing failed! Request data: ' + str(data)
-            + '; Response: ' + str(response))
-        logger.warn(err_msg)
-        return {'exitstatus': 1, 'log': err_msg}
+        raise IOError('Response parsing failed! Request data: ' + str(data)
+                      + '; Response: ' + str(response))
+
 
 
   def updateComponents(self, cluster_name):
   def updateComponents(self, cluster_name):
     logger.info("Updating components map of cluster " + cluster_name)
     logger.info("Updating components map of cluster " + cluster_name)
 
 
+    # May throw IOError on server connection error
     response = self.sendRequest(self.componentsUrl + cluster_name, None)
     response = self.sendRequest(self.componentsUrl + cluster_name, None)
     logger.debug("Response from %s was %s", self.serverHostname, str(response))
     logger.debug("Response from %s was %s", self.serverHostname, str(response))
 
 

+ 64 - 12
ambari-agent/src/test/python/ambari_agent/TestController.py

@@ -30,7 +30,7 @@ from threading import Event
 import json
 import json
 
 
 with patch("platform.linux_distribution", return_value = ('Suse','11','Final')):
 with patch("platform.linux_distribution", return_value = ('Suse','11','Final')):
-  from ambari_agent import Controller, ActionQueue
+  from ambari_agent import Controller, ActionQueue, Register
   from ambari_agent import hostname
   from ambari_agent import hostname
   from ambari_agent.Controller import AGENT_AUTO_RESTART_EXIT_CODE
   from ambari_agent.Controller import AGENT_AUTO_RESTART_EXIT_CODE
   from ambari_commons import OSCheck
   from ambari_commons import OSCheck
@@ -247,9 +247,9 @@ class TestController(unittest.TestCase):
     heartbeatWithServer.assert_called_once_with()
     heartbeatWithServer.assert_called_once_with()
 
 
     self.controller.registerWithServer =\
     self.controller.registerWithServer =\
-    Controller.Controller.registerWithServer
+      Controller.Controller.registerWithServer
     self.controller.heartbeatWithServer =\
     self.controller.heartbeatWithServer =\
-    Controller.Controller.registerWithServer
+      Controller.Controller.registerWithServer
 
 
   @patch("time.sleep")
   @patch("time.sleep")
   def test_registerAndHeartbeat(self, sleepMock):
   def test_registerAndHeartbeat(self, sleepMock):
@@ -300,6 +300,33 @@ class TestController(unittest.TestCase):
       Controller.Controller.registerWithServer
       Controller.Controller.registerWithServer
 
 
 
 
+  @patch("time.sleep")
+  @patch.object(Controller.Controller, "sendRequest")
+  def test_registerWithIOErrors(self, sendRequestMock, sleepMock):
+    # Check that the agent keeps retrying registration after connection errors
+    registerMock = MagicMock(name="Register")
+    registerMock.build.return_value = {}
+    actionQueue = MagicMock()
+    actionQueue.isIdle.return_value = True
+    self.controller.actionQueue = actionQueue
+    self.controller.register = registerMock
+    self.controller.responseId = 1
+    self.controller.TEST_IOERROR_COUNTER = 1
+    self.controller.isRegistered = False
+    def util_throw_IOErrors(*args, **kwargs):
+      """
+      Throws IOErrors 10 times and then stops heartbeats/registrations
+      """
+      if self.controller.TEST_IOERROR_COUNTER == 10:
+        self.controller.isRegistered = True
+      self.controller.TEST_IOERROR_COUNTER += 1
+      raise IOError("Sample error")
+    actionQueue.isIdle.return_value = False
+    sendRequestMock.side_effect = util_throw_IOErrors
+    self.controller.registerWithServer()
+    self.assertTrue(sendRequestMock.call_count > 5)
+
+
   @patch("os._exit")
   @patch("os._exit")
   def test_restartAgent(self, os_exit_mock):
   def test_restartAgent(self, os_exit_mock):
 
 
@@ -331,18 +358,22 @@ class TestController(unittest.TestCase):
       {'Content-Type': 'application/json'})
       {'Content-Type': 'application/json'})
 
 
     conMock.request.return_value = '{invalid_object}'
     conMock.request.return_value = '{invalid_object}'
-    actual = self.controller.sendRequest(url, data)
-    expected = {'exitstatus': 1, 'log': ('Response parsing failed! Request data: ' + data
-                                         + '; Response: {invalid_object}')}
-    self.assertEqual(actual, expected)
+
+    try:
+      self.controller.sendRequest(url, data)
+      self.fail("Should throw exception!")
+    except IOError, e: # Expected
+      self.assertEquals('Response parsing failed! Request data: ' + data +
+                        '; Response: {invalid_object}', e.message)
 
 
     exceptionMessage = "Connection Refused"
     exceptionMessage = "Connection Refused"
     conMock.request.side_effect = Exception(exceptionMessage)
     conMock.request.side_effect = Exception(exceptionMessage)
-    actual = self.controller.sendRequest(url, data)
-    expected = {'exitstatus': 1, 'log': 'Request to ' + url + ' failed due to ' + exceptionMessage}
-
-    self.assertEqual(actual, expected)
-
+    try:
+      self.controller.sendRequest(url, data)
+      self.fail("Should throw exception!")
+    except IOError, e: # Expected
+      self.assertEquals('Request to ' + url + ' failed due to ' +
+                        exceptionMessage, e.message)
 
 
 
 
   @patch.object(threading._Event, "wait")
   @patch.object(threading._Event, "wait")
@@ -480,6 +511,27 @@ class TestController(unittest.TestCase):
     response["restartAgent"] = "false"
     response["restartAgent"] = "false"
     self.controller.heartbeatWithServer()
     self.controller.heartbeatWithServer()
 
 
+    sleepMock.assert_called_with(
+      self.controller.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS)
+
+    # Check that the agent continues to heartbeat after connection errors
+    self.controller.responseId = 1
+    self.controller.TEST_IOERROR_COUNTER = 1
+    sendRequest.reset()
+    def util_throw_IOErrors(*args, **kwargs):
+      """
+      Throws IOErrors 10 times and then stops heartbeats/registrations
+      """
+      if self.controller.TEST_IOERROR_COUNTER == 10:
+        self.controller.DEBUG_STOP_HEARTBEATING = True
+      self.controller.TEST_IOERROR_COUNTER += 1
+      raise IOError("Sample error")
+    self.controller.DEBUG_STOP_HEARTBEATING = False
+    actionQueue.isIdle.return_value = False
+    sendRequest.side_effect = util_throw_IOErrors
+    self.controller.heartbeatWithServer()
+    self.assertTrue(sendRequest.call_count > 5)
+
     sleepMock.assert_called_with(
     sleepMock.assert_called_with(
       self.controller.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS)
       self.controller.netutil.MINIMUM_INTERVAL_BETWEEN_HEARTBEATS)