Created
April 4, 2012 06:27
-
-
Save blackwithwhite666/2298905 to your computer and use it in GitHub Desktop.
Base primitives for txAMQP
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from abc import ABCMeta, abstractmethod | |
| from google.protobuf.message import Message | |
| from twisted.application.internet import _AbstractClient, _maybeGlobalReactor | |
| from twisted.internet import protocol | |
| from twisted.internet.defer import inlineCallbacks, Deferred, returnValue, \ | |
| maybeDeferred, DeferredLock | |
| from twisted.python import log | |
| from txamqp.client import Closed as ClientClosed, TwistedDelegate | |
| from txamqp.content import Content | |
| from txamqp.protocol import AMQClient | |
| from txamqp.queue import Closed as QueueClosed | |
| from txamqp.spec import load | |
| from uuid import uuid4 | |
| import os | |
class ProtobufContent(Content):
    """txAMQP ``Content`` whose body is backed by a protobuf ``Message``.

    Reading ``body`` serializes the current message; assigning a non-empty
    ``body`` parses it into the current message in place.  A message must
    always be supplied: the ``message`` setter asserts it is a ``Message``.
    """

    def __init__(self, body="", children=None,
                 properties=None, message=None):
        # Start with both backing slots empty.
        self._message = None
        self._body = None
        # Install the message before the base initializer runs: the body
        # setter below parses into it (the tests rely on a constructor
        # ``body`` being parsed into ``message``).
        self.message = message
        super(ProtobufContent, self).__init__(body, children, properties)

    def getMessage(self):
        """Return the protobuf message backing this content."""
        return self._message

    def setMessage(self, message):
        """Replace the backing message; it must be a protobuf ``Message``."""
        assert isinstance(message, Message)
        self._message = message

    message = property(fget=getMessage, fset=setMessage)

    def getBody(self):
        """Return the wire bytes of the current message."""
        return self.message.SerializeToString()

    def setBody(self, body):
        """Remember *body* and, when non-empty, parse it into the message."""
        self._body = body
        if not body:
            return
        self._message.ParseFromString(body)

    body = property(fget=getBody, fset=setBody)

    @classmethod
    def create(cls, content, message):
        """Wrap an existing ``Content`` so its body is parsed into *message*."""
        return cls(body=content.body,
                   children=content.children,
                   properties=content.properties,
                   message=message)
class AmqpProtocol(AMQClient):
    """AMQClient that logs in on connect and hands itself to its factory.

    ``factory`` is assigned by the factory's ``buildProtocol``; it must
    provide ``user``, ``password`` and a ``deferred`` that is fired with this
    protocol once the AMQP ``start`` handshake has completed.
    """

    def __init__(self, reactor, *args, **kwargs):
        self.factory = None
        self.connected = False
        self._shutdownReactor = reactor
        # Close the AMQP connection before the reactor shuts down.  The
        # trigger id is kept so connectionLost() can unregister it; otherwise
        # every reconnect would leave a trigger behind that keeps this (dead)
        # protocol instance alive and calls disconnect() on it at shutdown.
        self._shutdownTrigger = reactor.addSystemEventTrigger(
            "before", "shutdown", self._disconnectOnShutdown)
        AMQClient.__init__(self, *args, **kwargs)

    @inlineCallbacks
    def connectionMade(self):
        """Authenticate with the factory's credentials, then fire its deferred."""
        AMQClient.connectionMade(self)
        yield self.start({"LOGIN": self.factory.user,
                          "PASSWORD": self.factory.password})
        self.connected = True
        self.factory.deferred.callback(self)

    def connectionLost(self, reason):
        """Mark the protocol disconnected and drop the shutdown trigger."""
        self.connected = False
        self._removeShutdownTrigger()
        AMQClient.connectionLost(self, reason)
        del self.factory

    def _disconnectOnShutdown(self):
        """Shutdown trigger: the trigger is firing, so forget its id first."""
        self._shutdownTrigger = None
        return self.disconnect()

    def _removeShutdownTrigger(self):
        """Unregister the shutdown trigger if it is still registered."""
        trigger, self._shutdownTrigger = self._shutdownTrigger, None
        if trigger is None:
            return
        try:
            self._shutdownReactor.removeSystemEventTrigger(trigger)
        except ValueError:
            # The reactor no longer knows this trigger (already fired);
            # nothing left to clean up.
            pass

    @inlineCallbacks
    def disconnect(self):
        """Close the AMQP connection via channel 0 if it is still open."""
        if self.connected:
            chan0 = yield self.channel(0)
            try:
                yield chan0.connection_close()
            except ClientClosed:
                # The connection is already closed; that is the goal anyway.
                pass
            self.connected = False
class Client(object):
    """Base class for AMQP helpers that work on one channel of a connection.

    The constructor runs :meth:`initialize` through :meth:`setup`;
    ``deferred`` is fired with its outcome.  Subclasses must implement
    :meth:`initialize`.
    """
    __metaclass__ = ABCMeta

    def __init__(self, client):
        # Fired once initialize() has finished.
        self.deferred = Deferred()
        # The connected AMQP protocol every operation goes through.
        self.client = client
        # Channel opened lazily by channel() and reused afterwards.
        self._channel = None
        # Serializes channel() so overlapping first calls share one channel
        # instead of each opening a new one.  Must exist before setup(),
        # because initialize() may already call channel().
        self._channelLock = DeferredLock()
        self.setup()

    @inlineCallbacks
    def createChannel(self):
        """Open and return a new channel on ``self.client``."""
        channel = yield self.client.channel(len(self.client.channels) + 1)
        yield channel.channel_open()
        returnValue(channel)

    @inlineCallbacks
    def createExchange(self, name, options=None):
        """Declare exchange *name*; the default exchange ('') is skipped.

        Defaults to a durable, non-auto-delete direct exchange; *options*
        override those keyword arguments of ``exchange_declare``.
        """
        ch = yield self.channel()
        if name != '':
            _options = {"type": "direct",
                        "durable": True,
                        "auto_delete": False}
            _options.update(options or {})
            yield ch.exchange_declare(exchange=name, **_options)

    @inlineCallbacks
    def createQueue(self, name, exchange_name, routing_key, options=None):
        """Declare queue *name* and bind it to *exchange_name* (unless '').

        Defaults to a durable, non-exclusive, non-auto-delete queue; *options*
        override those.  Returns the first field of the ``queue_declare``
        reply, which is used as the queue name (so an empty *name* yields the
        name the broker chose).
        """
        ch = yield self.channel()
        _options = {"durable": True,
                    "exclusive": False,
                    "auto_delete": False}
        _options.update(options or {})
        method = yield ch.queue_declare(queue=name, **_options)
        name = method[0]
        if exchange_name != "":
            yield ch.queue_bind(queue=name,
                                exchange=exchange_name,
                                routing_key=routing_key)
        returnValue(name)

    def channel(self):
        """Return a Deferred firing with this object's (cached) channel."""
        return self._channelLock.run(self._getOrCreateChannel)

    @inlineCallbacks
    def _getOrCreateChannel(self):
        """Open the channel on first use; called only while holding the lock."""
        if self._channel is None:
            self._channel = yield self.createChannel()
        returnValue(self._channel)

    @inlineCallbacks
    def queue(self, name):
        """Return the client-side delivery queue for consumer tag *name*."""
        queue = yield self.client.queue(name)
        returnValue(queue)

    def error(self, failure):
        """Errback for queue reads: log closed queues quietly, others in full."""
        if failure.check(QueueClosed) is not None:
            log.msg("Queue closed")
        else:
            log.msg("Error reading item: ", failure)

    def setup(self):
        """Run initialize() and chain its outcome into ``self.deferred``."""
        log.msg('Register %r' % self)
        maybeDeferred(self.initialize).chainDeferred(self.deferred)

    @abstractmethod
    def initialize(self):
        """Declare whatever AMQP entities this helper needs."""
        raise NotImplementedError()

    @inlineCallbacks
    def wait(self):
        """Return a Deferred that fires once initialization has finished."""
        if self.deferred.called:
            yield
        else:
            yield self.deferred
class Consumer(Client):
    """Consumes messages from a queue bound to an exchange.

    Each delivered message's content is passed to ``callback``; exceptions
    raised by the callback are logged and do not stop consumption.
    """

    def __init__(self, client, exchange, routing_key, callback,
                 queue=None, consumer_tag=None,
                 exchange_options=None, queue_options=None):
        self.exchange_name = exchange
        self.exchange_options = exchange_options
        # For now use the exchange name as the queue name.
        self.queue_name = queue or exchange
        self.queue_options = queue_options
        # Use the exchange name for the consumer tag for now.
        self.consumer_tag = consumer_tag or exchange
        self.routing_key = routing_key
        self.callback = callback
        # Runs initialize(), which reads the attributes above: keep it last.
        super(Consumer, self).__init__(client)

    @inlineCallbacks
    def initialize(self):
        """Declare exchange and queue, bind them, and start consuming."""
        # Declare the exchange in case it doesn't exist.
        yield self.createExchange(self.exchange_name, self.exchange_options)
        # Declare the queue and bind to it; the returned name may have been
        # chosen by the broker when queue_name is empty.
        self.queue_name = yield self.createQueue(self.queue_name,
                                                 self.exchange_name,
                                                 self.routing_key,
                                                 self.queue_options)
        # Fall back to the (possibly broker-generated) queue name.
        self.consumer_tag = self.consumer_tag or self.queue_name
        ch = yield self.channel()
        yield ch.basic_consume(queue=self.queue_name,
                               no_ack=True,
                               consumer_tag=self.consumer_tag)
        queue = yield self.queue(self.consumer_tag)
        # Not yielded: receive() waits for initialization itself.
        self.receive(queue)

    @inlineCallbacks
    def receive(self, queue):
        """Read items from *queue* forever, handing each to process().

        The next get() is requested before the current item is processed;
        the loop ends at the first read error, which error() logs.
        """
        yield self.wait()

        def _process(item):
            if item is not None:
                return self.process(item.content)

        def _get(item=None):
            d = queue.get()
            d.addCallback(_get)
            d.addErrback(self.error)
            d.addCallback(_process)
            return item

        _get()

    @inlineCallbacks
    def process(self, item):
        """Invoke the user callback on *item*, logging any exception."""
        assert item is not None
        try:
            yield maybeDeferred(self.callback, item)
        except Exception:
            # Keep consuming even if one message's handler fails.
            log.err()

    def __repr__(self):
        return ('<%s: exchange "%s", routing_key "%s">' %
                (self.__class__.__name__,
                 self.exchange_name,
                 self.routing_key))
class ProtobufConsumer(Consumer):
    """Consumer that decodes each message body into ``message_cls``.

    Requires a ``message_cls`` keyword argument: a protobuf message class
    instantiated once per delivered message.
    """

    def __init__(self, *args, **kwargs):
        # Validate explicitly: an ``assert`` would vanish under ``python -O``.
        try:
            self.message_cls = kwargs.pop('message_cls')
        except KeyError:
            raise TypeError(
                "ProtobufConsumer requires a 'message_cls' keyword argument")
        super(ProtobufConsumer, self).__init__(*args, **kwargs)

    def process(self, item):
        """Wrap *item* in ProtobufContent (fresh message) before processing."""
        item = ProtobufContent.create(item, self.message_cls())
        return super(ProtobufConsumer, self).process(item)
class Producer(Client):
    """Publishes ``Content`` objects to a single exchange."""

    def __init__(self, client, exchange, exchange_options=None):
        # Both attributes are read by initialize(), which the base
        # constructor runs, so they are set first.
        self.exchange_name, self.exchange_options = exchange, exchange_options
        super(Producer, self).__init__(client)

    @inlineCallbacks
    def initialize(self):
        """Declare the target exchange in case it does not exist yet."""
        yield self.createExchange(self.exchange_name, self.exchange_options)

    @inlineCallbacks
    def send(self, content, routing_key=''):
        """Publish *content* with *routing_key* once initialization is done."""
        assert isinstance(content, Content)
        yield self.wait()
        channel = yield self.channel()
        yield channel.basic_publish(exchange=self.exchange_name,
                                    routing_key=routing_key,
                                    content=content)

    def __repr__(self):
        return '<%s: exchange "%s">' % (type(self).__name__,
                                        self.exchange_name)
class ProtobufProducer(Producer):
    """Producer that accepts protobuf messages instead of raw ``Content``."""

    def send(self, message, routing_key='', **kwargs):
        """Serialize *message* via ProtobufContent and publish it.

        Extra keyword arguments (e.g. ``properties``) are forwarded to the
        ProtobufContent constructor.
        """
        assert isinstance(message, Message)
        wrapped = ProtobufContent(message=message, **kwargs)
        parent = super(ProtobufProducer, self)
        return parent.send(wrapped, routing_key)
class RPCServer(object):
    """Serves request/response methods over AMQP with protobuf payloads."""

    def __init__(self, factory):
        self.factory = factory

    @inlineCallbacks
    def register(self, exchange, name, callback, request_cls, response_cls):
        """Expose *callback* as method *name* on *exchange*.

        Creates a producer on the default exchange ('') for replies and a
        consumer bound to *exchange* with routing key *name* that decodes
        requests as *request_cls*.  *callback* is called as
        ``callback(request, response)`` with a fresh ``response_cls()`` to
        fill in.  Fires once the consumer has finished initializing.
        """
        log.msg("Register method '%s' in exchange '%s'" % (name, exchange))
        producer = yield self.factory.producer(exchange='')

        def _cb(content):
            return self.process(content=content,
                                name=name,
                                producer=producer,
                                callback=callback,
                                response=response_cls())

        consumer = yield self.factory.consumer(exchange=exchange,
                                               routing_key=name,
                                               callback=_cb,
                                               message_cls=request_cls)
        yield consumer.wait()

    @inlineCallbacks
    def process(self, content, name, producer, callback, response):
        """Run *callback* for one request and publish *response* as the reply.

        The reply is routed to the request's ``reply to`` property and carries
        the same ``correlation id``.
        """
        properties = content.properties
        assert "correlation id" in properties and "reply to" in properties
        log.msg("Handle request '%s' with id '%s'" %
                (name, properties['correlation id']))
        try:
            yield maybeDeferred(callback, content.message, response)
        except Exception:
            # Log and still reply below: the caller waits for a response
            # with this correlation id and has no timeout.
            log.err()
        log.msg("Send response with id '%s' to '%s'" %
                (properties['correlation id'], properties['reply to']))
        yield producer.send(response,
                            routing_key=properties['reply to'],
                            properties={'correlation id':
                                        properties['correlation id']})
class RPCClient(object):
    """Calls RPCServer methods on *exchange* and waits for their replies."""

    def __init__(self, exchange, factory):
        self.exchange = exchange
        self.factory = factory
        self.producer = self.consumer = None
        # correlation id -> Deferred fired with the raw reply body.
        self.requests = {}

    @inlineCallbacks
    def initialize(self):
        """Create the request producer and an exclusive reply queue consumer."""
        self.producer = yield self.factory.producer(exchange=self.exchange)
        yield self.producer.wait()
        # Empty exchange/queue names: a broker-named exclusive queue on the
        # default exchange, whose name is used as the ``reply to`` address.
        self.consumer = yield self.factory.consumer(exchange='', queue='',
                                                    routing_key='',
                                                    queue_options={'exclusive': True},
                                                    callback=self.receive,
                                                    kls=Consumer)
        yield self.consumer.wait()

    @inlineCallbacks
    def invoke(self, name, request, response):
        """Send *request* to method *name* and parse the reply into *response*.

        Returns (a Deferred firing with) *response*.
        """
        uuid = str(uuid4())
        # Register before publishing so a reply can never arrive for a
        # correlation id that is not known yet (it would be dropped as stale).
        reply = self.register(uuid)
        try:
            yield self.producer.send(request, routing_key=name,
                                     properties={'correlation id': uuid,
                                                 'reply to': self.consumer.queue_name})
        except Exception:
            # Nothing will answer an unsent request; drop its entry.
            self.requests.pop(uuid, None)
            raise
        body = yield reply
        response.ParseFromString(body)
        returnValue(response)

    def register(self, uuid):
        """Return a Deferred that receive() fires with the reply for *uuid*."""
        d = Deferred()
        self.requests[uuid] = d
        return d

    def receive(self, content):
        """Consumer callback: resolve the pending request matching *content*."""
        properties = content.properties
        assert "correlation id" in properties
        correlation_id = properties["correlation id"]
        # pop() first so the entry is removed even if the callback raises.
        d = self.requests.pop(correlation_id, None)
        if d is None:
            log.msg("Stale response with id '%s'" % correlation_id)
            return
        d.callback(content.body)
class AmqpFactory(protocol.ReconnectingClientFactory):
    """Reconnecting client factory for authenticated AMQP connections.

    ``deferred`` is fired by the protocol with itself once its login
    handshake completes.  Use :meth:`client` to wait for the connection and
    :meth:`consumer` / :meth:`producer` to build helpers on it.
    """
    protocol = AmqpProtocol

    def __init__(self, reactor=None, vhost=None, host=None, port=None,
                 user=None, password=None):
        # AMQP 0-8 spec file, loaded from ../contrib relative to this module.
        self.spec = load(os.path.join(os.path.dirname(__file__),
                                      '..', 'contrib',
                                      'amqp0-8.stripped.rabbitmq.xml'))
        self.user = user or 'guest'
        self.password = password or 'guest'
        self.vhost = vhost or '/'
        self.host = host or 'localhost'
        self.port = port or 5672
        self.delegate = TwistedDelegate()
        self.deferred = Deferred()
        self.reactor = _maybeGlobalReactor(reactor)
        self._client = None

    def buildProtocol(self, addr):
        """Build an AmqpProtocol, remember it, and reset the retry delay."""
        self._client = p = self.protocol(self.reactor, self.delegate,
                                         self.vhost, self.spec)
        p.factory = self
        self.resetDelay()
        return p

    def doStop(self):
        """Forget the current connection before the next attempt.

        A new ``deferred`` is created only when the current one has already
        fired.  If it is still pending (the attempt failed or was lost before
        the handshake), it is kept so that callers already waiting in
        :meth:`client` are woken by the next successful connection instead
        of waiting forever.
        """
        self._client = None
        if self.deferred.called:
            self.deferred = Deferred()
        protocol.ReconnectingClientFactory.doStop(self)

    def client(self):
        """Return a Deferred firing with the connected protocol.

        Every caller gets its own Deferred.  Several inlineCallbacks
        generators yielding the shared ``deferred`` would not work: the first
        one's callback replaces the result with None, so later waiters would
        get None instead of the protocol.
        """
        waiter = Deferred()
        if self.deferred.called:
            waiter.callback(self._client)
            return waiter

        def _relay(client):
            waiter.callback(client)
            # Pass the protocol on so the next waiter receives it too.
            return client

        def _relayFailure(reason):
            waiter.errback(reason)
            return reason

        self.deferred.addCallbacks(_relay, _relayFailure)
        return waiter

    @inlineCallbacks
    def consumer(self, kls=None, **kwargs):
        """Build a consumer (``Consumer`` by default) on the connection."""
        kls = kls or Consumer
        returnValue(kls((yield self.client()), **kwargs))

    @inlineCallbacks
    def producer(self, kls=None, **kwargs):
        """Build a producer (``Producer`` by default) on the connection."""
        kls = kls or Producer
        returnValue(kls((yield self.client()), **kwargs))
class ProtobufAmqpFactory(AmqpFactory):
    """AmqpFactory defaulting to protobuf consumers/producers, with RPC helpers."""

    def __init__(self, *args, **kwargs):
        AmqpFactory.__init__(self, *args, **kwargs)
        # Created lazily by the rpc_server property.
        self._rpc_server = None

    def consumer(self, kls=None, **kwargs):
        """Build a consumer, ``ProtobufConsumer`` unless *kls* is given."""
        kls = kls or ProtobufConsumer
        return AmqpFactory.consumer(self, kls=kls, **kwargs)

    def producer(self, kls=None, **kwargs):
        """Build a producer, ``ProtobufProducer`` unless *kls* is given."""
        kls = kls or ProtobufProducer
        return AmqpFactory.producer(self, kls=kls, **kwargs)

    @property
    def rpc_server(self):
        """The single RPCServer bound to this factory."""
        if self._rpc_server is None:
            self._rpc_server = RPCServer(self)
        return self._rpc_server

    @inlineCallbacks
    def rpc_client(self, exchange):
        """Create and initialize an RPCClient for *exchange*."""
        client = RPCClient(exchange=exchange,
                           factory=self)
        yield client.initialize()
        returnValue(client)
class AmqpService(_AbstractClient):
    """Twisted service keeping a TCP connection to the factory's broker."""
    method = 'TCP'

    def __init__(self, factory):
        self.factory = factory
        _AbstractClient.__init__(self, factory.host, factory.port, factory)

    def stopService(self):
        """Stop reconnecting, reset the factory, then stop the service."""
        amqp_factory = self.factory
        amqp_factory.stopTrying()
        amqp_factory.doStop()
        _AbstractClient.stopService(self)
class ProducerMixin(object):
    """Mixin caching one producer per exchange name.

    The host class must provide an ``amqp`` attribute whose
    ``producer(exchange=...)`` returns (a Deferred firing with) a producer.
    """

    def __init__(self, *args, **kwargs):
        self.__producers = {}
        self.__lock = DeferredLock()
        super(ProducerMixin, self).__init__(*args, **kwargs)

    @inlineCallbacks
    def producer(self, exchange_name):
        """Return the cached producer for *exchange_name*, creating it once."""
        # acquire() returns a Deferred and must be waited on; otherwise a
        # concurrent caller runs without the lock and creates a duplicate.
        yield self.__lock.acquire()
        try:
            producer = self.__producers.get(exchange_name)
            if producer is None:
                producer = yield self.amqp.producer(exchange=exchange_name)
                self.__producers[exchange_name] = producer
            returnValue(producer)
        finally:
            self.__lock.release()
def rpc_handler(amqp, exchange, name,
                request_cls, response_cls,
                deferred=None):
    """Decorator factory registering the decorated function as an RPC method.

    The function is registered on ``amqp.rpc_server`` under *name* in
    *exchange* and returned unchanged.  When *deferred* is given, the
    registration's Deferred is chained into it.
    """
    def register_handler(func):
        registration = amqp.rpc_server.register(exchange=exchange,
                                                name=name,
                                                callback=func,
                                                request_cls=request_cls,
                                                response_cls=response_cls)
        if deferred is None:
            return func
        registration.chainDeferred(deferred)
        return func
    return register_handler
def consume(amqp, exchange, routing_key,
            message_cls, deferred=None,
            **kwargs):
    """Decorator factory subscribing the decorated function as a consumer.

    Calls ``amqp.consumer`` with the function as callback (extra keyword
    arguments are passed through) and returns the function unchanged.  When
    *deferred* is given, the consumer Deferred is chained into it.
    """
    def subscribe(func):
        pending = amqp.consumer(exchange=exchange,
                                routing_key=routing_key,
                                callback=func,
                                message_cls=message_cls,
                                **kwargs)
        if deferred is None:
            return func
        pending.chainDeferred(deferred)
        return func
    return subscribe
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from amqp import ProtobufContent, AmqpFactory, \ | |
| Producer, Consumer, ProtobufAmqpFactory, ProducerMixin, AmqpService | |
| from missed import ReaderRequest, ReaderResponse | |
| from twisted.internet import reactor, defer | |
| from twisted.trial import unittest | |
| from txamqp.content import Content | |
class TestProducerContainer(ProducerMixin):
    """Minimal ProducerMixin host exposing the factory as ``amqp``."""

    def __init__(self, amqp):
        # ProducerMixin.producer() reads self.amqp.
        self.amqp = amqp
        super(TestProducerContainer, self).__init__()
class ProtobufContentTest(unittest.TestCase):
    """Unit tests for ProtobufContent; no broker is needed."""

    def test_init(self):
        message = ReaderRequest(uri="test")
        content = ProtobufContent(message=message)
        self.assertIsNotNone(content)
        # Without a message the setter's isinstance assertion fails.
        self.assertRaises(AssertionError, ProtobufContent)

    def test_getMessage(self):
        message = ReaderRequest(uri="test")
        content = ProtobufContent(message=message)
        self.assertIdentical(message, content.message)

    def test_setMessage(self):
        message1 = ReaderRequest(uri="test1")
        message2 = ReaderRequest(uri="test2")
        content = ProtobufContent(message=message1)
        content.message = message2
        self.assertIdentical(message2, content.message)

    def test_getBody(self):
        source_message = ReaderRequest(uri="test1")
        content = ProtobufContent(body=source_message.SerializeToString(),
                                  message=ReaderRequest())
        # A constructor body is parsed into the message...
        self.assertEqual(source_message.uri, content.message.uri)
        # ...and reading ``body`` serializes the message back to bytes.
        self.assertEqual(source_message.SerializeToString(), content.body)

    def test_setBody(self):
        message1 = ReaderRequest(uri="test1")
        message2 = ReaderRequest(uri="test2")
        content = ProtobufContent(message=message1)
        content.body = message2.SerializeToString()
        self.assertEqual(message2.uri, content.message.uri)

    def test_create(self):
        source_message = ReaderRequest(uri="test")
        source_content = Content(body=source_message.SerializeToString())
        content = ProtobufContent.create(source_content, ReaderRequest())
        self.assertEqual(source_message.uri, content.message.uri)
class AmqpFactoryTest(unittest.TestCase):
    """Integration tests; they need a reachable AMQP broker.

    The broker host defaults to 192.168.1.17 and can be overridden with the
    ``AMQP_TEST_HOST`` environment variable.
    """

    def getFactory(self, kls=None):
        """Connect a factory of class *kls* and schedule its clean shutdown."""
        import os
        kls = kls or AmqpFactory
        factory = kls(host=os.environ.get("AMQP_TEST_HOST", "192.168.1.17"))
        connector = reactor.connectTCP(factory.host, factory.port, factory)

        @defer.inlineCallbacks
        def _cleanup():
            client = yield factory.client()
            # Stop ReconnectingClientFactory from scheduling a reconnect
            # (a leftover DelayedCall) once the connection is dropped.
            factory.stopTrying()
            yield client.disconnect()
            connector.disconnect()
        self.addCleanup(_cleanup)
        return factory

    @defer.inlineCallbacks
    def test_client(self):
        factory = self.getFactory()
        client = yield factory.client()
        self.assertTrue(client.connected)
        yield client.disconnect()
        self.assertFalse(client.connected)

    @defer.inlineCallbacks
    def test_consumer(self):
        factory = self.getFactory()
        message = "test"
        d = defer.Deferred()

        @defer.inlineCallbacks
        def _cb(content):
            self.assertEqual(content.body, message)
            yield
        d.addCallback(_cb)
        consumer = yield factory.consumer(exchange="test.exchange",
                                          queue="test.queue",
                                          routing_key="test",
                                          callback=d.callback)
        self.assertIsInstance(consumer, Consumer)
        yield consumer.wait()
        producer = yield factory.producer(exchange="test.exchange")
        self.assertIsInstance(producer, Producer)
        content = Content(body=message)
        yield producer.send(content, routing_key="test")
        yield d

    @defer.inlineCallbacks
    def test_protobuf(self):
        factory = self.getFactory(ProtobufAmqpFactory)
        message = ReaderRequest(uri="test")
        d = defer.Deferred()

        def _cb(content):
            self.assertEqual(content.message.uri, message.uri)
        d.addCallback(_cb)
        consumer = yield factory.consumer(exchange="test.proto_exchange",
                                          queue="test.proto_queue",
                                          routing_key="test",
                                          callback=d.callback,
                                          message_cls=ReaderRequest)
        self.assertIsInstance(consumer, Consumer)
        yield consumer.wait()
        producer = yield factory.producer(exchange="test.proto_exchange")
        self.assertIsInstance(producer, Producer)
        yield producer.send(message, routing_key="test")
        yield d

    @defer.inlineCallbacks
    def test_mixin(self):
        factory = self.getFactory()
        container = TestProducerContainer(factory)
        producer1 = yield container.producer("test.exchange")
        yield producer1.wait()
        producer2 = yield container.producer("test.exchange")
        yield producer2.wait()
        self.assertIdentical(producer1, producer2)
        producer3 = yield container.producer("test.proto_exchange")
        yield producer3.wait()
        self.assertNotIdentical(producer1, producer3)

    @defer.inlineCallbacks
    def test_rpc(self):
        factory = self.getFactory(ProtobufAmqpFactory)
        uri = 'test'

        def _cb(request, response):
            response.uri = request.uri
        yield factory.rpc_server.register(exchange='test.rpc',
                                          name='test',
                                          callback=_cb,
                                          request_cls=ReaderRequest,
                                          response_cls=ReaderResponse)
        rpc_client = yield factory.rpc_client(exchange='test.rpc')
        response = ReaderResponse()
        yield rpc_client.invoke(name='test',
                                request=ReaderRequest(uri=uri),
                                response=response)
        self.assertEqual(response.uri, uri)

    def test_service(self):
        factory = self.getFactory()
        service = AmqpService(factory)
        service.startService()
        service.stopService()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment