# Rev 30 | Blame | Compare with Previous | Last modification | View Log | RSS feed
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""Twisted integration for Thrift: framed client/server protocols and factories.

Messages travel as length-prefixed frames handled by Twisted's
``Int32StringReceiver``; each frame carries one serialized Thrift message.
"""

from zope.interface import implements, Interface, Attribute
from twisted.internet.protocol import (Protocol, ServerFactory, ClientFactory,
                                       connectionDone)
from twisted.internet import defer
from twisted.protocols import basic
from twisted.python import log
from thrift.transport import TTransport

from cStringIO import StringIO


class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-only transport that hands each flushed buffer to sendMessage().

    Bytes passed to write() accumulate in memory; flush() takes everything
    buffered so far, starts a fresh buffer, and calls sendMessage() once with
    the accumulated bytes. Subclasses must implement sendMessage().
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append ``buf`` to the pending outgoing message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one message and reset the buffer."""
        msg = self.__wbuf.getvalue()
        # Reset before sending so any write() made while sendMessage() runs
        # goes into a new message rather than this one.
        self.__wbuf = StringIO()
        self.sendMessage(msg)

    def sendMessage(self, message):
        """Deliver one complete message. Subclasses must override this."""
        raise NotImplementedError


class TCallbackTransport(TMessageSenderTransport):
    """TMessageSenderTransport that passes each flushed message to ``func``."""

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        self.func(message)


class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a framed Thrift connection.

    On connect, builds an instance of ``client_class`` whose output is sent as
    frames on this connection, then fires ``started`` with that client. Each
    received frame is routed to the client's ``recv_<method>`` handler.

    Attributes:
        client: the client instance, or None until connectionMade() runs.
        started: Deferred fired with the client in connectionMade().
        recv_map: cache of method name -> bound ``recv_<method>`` handler.
    """

    # Largest frame accepted from the peer (Int32StringReceiver limit).
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        # Without a dedicated output factory, use the input one for both.
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory
        self.client = None
        self.recv_map = {}
        self.started = defer.Deferred()

    def dispatch(self, msg):
        """Send one serialized message to the server as a single frame."""
        self.sendString(msg)

    def connectionMade(self):
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with TTransportException(END_OF_FILE).

        NOTE(review): assumes the generated Twisted client keeps its pending
        request Deferreds in the dict ``client._reqs``.
        """
        if self.client is None:
            # connectionMade() never built a client, so nothing is pending.
            return
        pending = self.client._reqs
        # Pop entries one at a time rather than iterating the live dict.
        # An errback runs caller code synchronously, and that code may issue
        # or complete requests, mutating ``_reqs``. Iterating would then
        # raise and leave the remaining Deferreds unfired. Popping also
        # guarantees each Deferred is failed exactly once.
        while pending:
            _, d = pending.popitem()
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed')
            d.errback(tex)

    def stringReceived(self, frame):
        """Decode one reply frame and hand it to ``client.recv_<name>``."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        # Resolve the handler once per method name and cache the bound method.
        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)


class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a framed Thrift connection.

    Each received frame is run through ``self.factory.processor``. Whatever
    the processor writes to the output protocol is sent back as one frame.
    ``self.factory`` is expected to provide IThriftServerFactory.
    """

    # Largest frame accepted from the peer (Int32StringReceiver limit).
    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send one serialized reply to the client as a single frame."""
        self.sendString(msg)

    def processError(self, error):
        """Errback for processor failures: log the failure, then disconnect."""
        # This errback returns None, which consumes the Failure. Log it first
        # so server-side errors are not lost silently.
        log.err(error)
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send the processor's reply buffer, if it wrote anything."""
        msg = tmo.getvalue()
        # Nothing is sent when the processor produced no output.
        if len(msg) > 0:
            self.dispatch(msg)

    def stringReceived(self, frame):
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        # process() returns a Deferred. The reply buffer is handed to
        # processOk() through callbackArgs.
        d = self.factory.processor.process(iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))


class IThriftServerFactory(Interface):
    """Attributes ThriftServerProtocol reads from its factory."""

    processor = Attribute("Thrift processor")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")


class IThriftClientFactory(Interface):
    """Attributes ThriftClientFactory uses to build client protocols."""

    client_class = Attribute("Thrift client class")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")


class ThriftServerFactory(ServerFactory):
    """Factory producing ThriftServerProtocol instances for one processor."""

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        self.processor = processor
        self.iprot_factory = iprot_factory
        # Without a dedicated output factory, use the input one for both.
        if oprot_factory is None:
            self.oprot_factory = iprot_factory
        else:
            self.oprot_factory = oprot_factory


class ThriftClientFactory(ClientFactory):
    """Factory producing ThriftClientProtocol instances for one client class."""

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        # Without a dedicated output factory, use the input one for both.
        if oprot_factory is None:
            self.oprot_factory = iprot_factory
        else:
            self.oprot_factory = oprot_factory

    def buildProtocol(self, addr):
        """Create a client protocol wired to this factory's settings."""
        p = self.protocol(self.client_class, self.iprot_factory,
                          self.oprot_factory)
        p.factory = self
        return p