Subversion Repositories SmartDukaan

Rev

Go to most recent revision | Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
30 ashish 1
#
2
# Licensed to the Apache Software Foundation (ASF) under one
3
# or more contributor license agreements. See the NOTICE file
4
# distributed with this work for additional information
5
# regarding copyright ownership. The ASF licenses this file
6
# to you under the Apache License, Version 2.0 (the
7
# "License"); you may not use this file except in compliance
8
# with the License. You may obtain a copy of the License at
9
#
10
#   http://www.apache.org/licenses/LICENSE-2.0
11
#
12
# Unless required by applicable law or agreed to in writing,
13
# software distributed under the License is distributed on an
14
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
# KIND, either express or implied. See the License for the
16
# specific language governing permissions and limitations
17
# under the License.
18
#
19
from zope.interface import implements, Interface, Attribute
20
from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
21
    connectionDone
22
from twisted.internet import defer
23
from twisted.protocols import basic
24
from twisted.python import log
25
 
26
 
27
from thrift.transport import TTransport
28
from cStringIO import StringIO
29
 
30
 
31
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-buffering transport that emits each flushed batch as one message.

    Writes accumulate in an in-memory buffer; ``flush`` hands the complete
    buffered payload to ``sendMessage`` in a single call. Subclasses decide
    what "sending" means by overriding ``sendMessage``.
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append ``buf`` to the pending outgoing message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send everything written since the previous flush as one message."""
        # The fresh buffer is installed before sendMessage runs, so any write
        # made from inside sendMessage goes into the next message, not this one.
        payload, self.__wbuf = self.__wbuf.getvalue(), StringIO()
        self.sendMessage(payload)

    def sendMessage(self, message):
        """Deliver one complete message; subclasses must override this."""
        raise NotImplementedError
46
 
47
 
48
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that passes every flushed message to a callable."""

    def __init__(self, func):
        """
        @param func: invoked with each complete message, i.e. the bytes
            accumulated between two flushes.
        """
        self.func = func
        TMessageSenderTransport.__init__(self)

    def sendMessage(self, message):
        """Forward ``message`` to the callback given at construction."""
        callback = self.func
        callback(message)
56
 
57
 
58
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a framed Thrift connection over Twisted.

    Every Thrift message travels as one length-prefixed frame. Once the
    connection is made, an instance of ``client_class`` is built whose
    writes are sent as frames, and ``started`` fires with that client.
    Incoming frames are routed to the client's ``recv_<name>`` method.
    """

    # 2 ** 31 - 1 is the largest signed 32-bit value.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        @param client_class: Thrift client class, instantiated in
            connectionMade as ``client_class(transport, oprot_factory)``.
        @param iprot_factory: protocol factory used to decode incoming frames.
        @param oprot_factory: protocol factory given to the client for
            outgoing calls; defaults to ``iprot_factory``.
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of message name -> bound recv_ method, filled lazily.
        self.recv_map = {}
        # Fires with the client instance once the connection is made.
        self.started = defer.Deferred()
        # Assigned in connectionMade; stays None if that never completes.
        self.client = None

    def dispatch(self, msg):
        """Send ``msg`` to the server as a single frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Build the client on top of this connection and fire ``started``."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with an END_OF_FILE error.

        Pending requests live in the client's ``_reqs`` dict. They are popped
        one at a time instead of iterated over: an errback may issue a new
        request or otherwise touch ``_reqs``, which would break a live
        iteration. Popping also guarantees each Deferred is errbacked once,
        and requests added during an errback are failed too instead of
        hanging.
        """
        if self.client is None:
            # connectionMade did not complete, so no requests can be pending.
            return
        pending = self.client._reqs
        while pending:
            _, d = pending.popitem()
            d.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed'))

    def stringReceived(self, frame):
        """Decode one reply frame and hand it to the matching recv_ method."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
100
 
101
 
102
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a framed Thrift connection over Twisted.

    Each incoming frame is decoded with the factory's ``iprot_factory`` and
    passed to the factory's ``processor``. Whatever the processor writes to
    the output buffer is sent back as one frame.
    """

    # 2 ** 31 - 1 is the largest signed 32-bit value.
    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send ``msg`` to the peer as a single frame."""
        self.sendString(msg)

    def processError(self, error):
        """Log a failed request, then drop the connection."""
        # Record the failure before closing so the cause of the disconnect
        # is not lost.
        log.err(error, 'Thrift request processing failed; closing connection')
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send the processor's output back, if it produced any.

        @param tmo: the memory buffer the processor wrote its reply into.
        """
        msg = tmo.getvalue()

        # Nothing is sent when the processor wrote nothing (presumably a
        # oneway call).
        if len(msg) > 0:
            self.dispatch(msg)

    def stringReceived(self, frame):
        """Run one request frame through the factory's processor."""
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        # maybeDeferred routes both a synchronous exception and a
        # non-Deferred return value through the same callback/errback path
        # as an asynchronous result.
        d = defer.maybeDeferred(self.factory.processor.process, iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
            callbackArgs=(tmo,))
128
 
129
 
130
class IThriftServerFactory(Interface):
    """Configuration a Thrift server factory exposes to its protocols."""

    # Object whose process(iprot, oprot) handles one request.
    processor = Attribute("Thrift processor")

    # Builds the protocol that decodes incoming frames.
    iprot_factory = Attribute("Input protocol factory")

    # Builds the protocol that encodes replies.
    oprot_factory = Attribute("Output protocol factory")
137
 
138
 
139
class IThriftClientFactory(Interface):
    """Configuration a Thrift client factory passes to the protocols it builds."""

    # Thrift client class instantiated once a connection is made.
    client_class = Attribute("Thrift client class")

    # Builds the protocol that decodes replies.
    iprot_factory = Attribute("Input protocol factory")

    # Builds the protocol that encodes outgoing calls.
    oprot_factory = Attribute("Output protocol factory")
146
 
147
 
148
class ThriftServerFactory(ServerFactory):
    """Twisted server factory that produces ThriftServerProtocol instances.

    The protocols it builds read ``processor``, ``iprot_factory`` and
    ``oprot_factory`` from this factory when handling each frame.
    """

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """
        @param processor: Thrift processor that handles decoded requests.
        @param iprot_factory: protocol factory for incoming frames.
        @param oprot_factory: protocol factory for replies; falls back to
            ``iprot_factory`` when omitted.
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
161
 
162
 
163
class ThriftClientFactory(ClientFactory):
    """Twisted client factory that produces ThriftClientProtocol instances."""

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        @param client_class: Thrift client class each protocol instantiates.
        @param iprot_factory: protocol factory for incoming replies.
        @param oprot_factory: protocol factory for outgoing calls; falls back
            to ``iprot_factory`` when omitted.
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol for a new connection and link it to this factory."""
        proto = self.protocol(self.client_class, self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto