From e33cb01b79e5b50c28ed267c29a88163cae4d777 Mon Sep 17 00:00:00 2001
From: Min RK <benjaminrk@gmail.com>
Date: Wed, 20 Apr 2016 11:37:07 +0200
Subject: thread auth fixes

fix timeouts in zmq.tests.test_auth

Bug: https://github.com/zeromq/pyzmq/pull/839
Patch-Name: thread_auth_fixes
---
 zmq/auth/base.py       |  5 +++++
 zmq/auth/thread.py     |  9 +++++++--
 zmq/tests/test_auth.py | 17 ++++++++++-------
 3 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/zmq/auth/base.py b/zmq/auth/base.py
index 72fccb1..06952a0 100644
--- a/zmq/auth/base.py
+++ b/zmq/auth/base.py
@@ -49,6 +49,7 @@ class Authenticator(object):
         self.zap_socket = self.context.socket(zmq.REP)
         self.zap_socket.linger = 1
         self.zap_socket.bind("inproc://zeromq.zap.01")
+        self.log.debug("Starting")
 
     def stop(self):
         """Close the ZAP socket"""
@@ -68,6 +69,7 @@ class Authenticator(object):
         """
         if self.blacklist:
             raise ValueError("Only use a whitelist or a blacklist, not both")
+        self.log.debug("Allowing %s", ','.join(addresses))
         self.whitelist.update(addresses)
 
     def deny(self, *addresses):
@@ -79,6 +81,7 @@ class Authenticator(object):
         """
         if self.whitelist:
             raise ValueError("Only use a whitelist or a blacklist, not both")
+        self.log.debug("Denying %s", ','.join(addresses))
         self.blacklist.update(addresses)
 
     def configure_plain(self, domain='*', passwords=None):
@@ -90,6 +93,7 @@ class Authenticator(object):
         """
         if passwords:
             self.passwords[domain] = passwords
+        self.log.debug("Configure plain: %s", domain)
 
     def configure_curve(self, domain='*', location=None):
         """Configure CURVE authentication for a given domain.
@@ -105,6 +109,7 @@ class Authenticator(object):
         """
         # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise
         # treat location as a directory that holds the certificates.
+        self.log.debug("Configure curve: %s[%s]", domain, location)
         if location == CURVE_ALLOW_ANY:
             self.allow_any = True
         else:
diff --git a/zmq/auth/thread.py b/zmq/auth/thread.py
index 24457d6..e26916c 100644
--- a/zmq/auth/thread.py
+++ b/zmq/auth/thread.py
@@ -6,8 +6,9 @@
 # Copyright (C) PyZMQ Developers
 # Distributed under the terms of the Modified BSD License.
 
+import time
 import logging
-from threading import Thread
+from threading import Thread, Event
 
 import zmq
 from zmq.utils import jsonapi
@@ -26,6 +27,7 @@ class AuthenticationThread(Thread):
         self.context = context or zmq.Context.instance()
         self.encoding = encoding
         self.log = log = log or logging.getLogger('zmq.auth')
+        self.started = Event()
         self.authenticator = authenticator or Authenticator(context, encoding=encoding, log=log)
 
         # create a socket to communicate back to main thread.
@@ -34,8 +36,9 @@ class AuthenticationThread(Thread):
         self.pipe.connect(endpoint)
 
     def run(self):
-        """ Start the Authentication Agent thread task """
+        """Start the Authentication Agent thread task"""
         self.authenticator.start()
+        self.started.set()
         zap = self.authenticator.zap_socket
         poller = zmq.Poller()
         poller.register(self.pipe, zmq.POLLIN)
@@ -161,6 +164,8 @@ class ThreadAuthenticator(object):
         self.pipe.bind(self.pipe_endpoint)
         self.thread = AuthenticationThread(self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log)
         self.thread.start()
+        if not self.thread.started.wait(timeout=10):
+            raise RuntimeError("Authenticator thread failed to start")
 
     def stop(self):
         """Stop the authentication thread"""
diff --git a/zmq/tests/test_auth.py b/zmq/tests/test_auth.py
index beff925..af521ea 100644
--- a/zmq/tests/test_auth.py
+++ b/zmq/tests/test_auth.py
@@ -101,7 +101,8 @@ class TestThreadAuthentication(BaseAuthTestCase):
         port = server.bind_to_random_port(iface)
         client.connect("%s:%i" % (iface, port))
         msg = [b"Hello World"]
-        server.send_multipart(msg)
+        if server.poll(1000, zmq.POLLOUT):
+            server.send_multipart(msg)
         if client.poll(1000):
             rcvd_msg = client.recv_multipart()
             self.assertEqual(rcvd_msg, msg)
@@ -114,6 +115,8 @@ class TestThreadAuthentication(BaseAuthTestCase):
         # go through our authentication infrastructure at all.
         self.auth.stop()
         self.auth = None
+        # use a new context, so ZAP isn't inherited
+        self.context = self.Context()
         
         server = self.socket(zmq.PUSH)
         client = self.socket(zmq.PULL)
@@ -242,20 +245,20 @@ def with_ioloop(method, expect_success=True):
     """decorator for running tests with an IOLoop"""
     def test_method(self):
         r = method(self)
-        
+
         loop = self.io_loop
         if expect_success:
             self.pullstream.on_recv(self.on_message_succeed)
         else:
             self.pullstream.on_recv(self.on_message_fail)
         
-        t = loop.time()
-        loop.add_callback(self.attempt_connection)
-        loop.add_callback(self.send_msg)
+        loop.call_later(1, self.attempt_connection)
+        loop.call_later(1.2, self.send_msg)
+        
         if expect_success:
-            loop.add_timeout(t + 1, self.on_test_timeout_fail)
+            loop.call_later(2, self.on_test_timeout_fail)
         else:
-            loop.add_timeout(t + 1, self.on_test_timeout_succeed)
+            loop.call_later(2, self.on_test_timeout_succeed)
         
         loop.start()
         if self.fail_msg:
