Subversion Repositories SmartDukaan

Rev

Rev 30 | Details | Compare with Previous | Last modification | View Log | RSS feed

Rev Author Line No. Line
30 ashish 1
#!/usr/bin/env python
2
 
3
#
4
# Licensed to the Apache Software Foundation (ASF) under one
5
# or more contributor license agreements. See the NOTICE file
6
# distributed with this work for additional information
7
# regarding copyright ownership. The ASF licenses this file
8
# to you under the Apache License, Version 2.0 (the
9
# "License"); you may not use this file except in compliance
10
# with the License. You may obtain a copy of the License at
11
#
12
#   http://www.apache.org/licenses/LICENSE-2.0
13
#
14
# Unless required by applicable law or agreed to in writing,
15
# software distributed under the License is distributed on an
16
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17
# KIND, either express or implied. See the License for the
18
# specific language governing permissions and limitations
19
# under the License.
20
#
21
 
22
import sys, glob
23
sys.path.insert(0, './gen-py')
24
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
25
 
26
from ThriftTest.ttypes import *
27
from thrift.transport import TTransport
28
from thrift.transport import TSocket
29
from thrift.protocol import TBinaryProtocol
30
import unittest
31
import time
32
 
33
class AbstractTest(unittest.TestCase):
  """Versioning round-trip tests shared by every protocol under test.

  Concrete subclasses set the class attribute ``protocol_factory`` to a
  Thrift protocol factory. Each test serializes one version of the
  VersioningTest struct with that protocol, deserializes the bytes as the
  other version, and checks that the fields present in both versions
  (``begin_in_both`` / ``end_in_both``) come back unchanged.
  """

  # Overridden by concrete subclasses. Defaulting to None lets test
  # discovery that collects this base class directly skip it, instead of
  # erroring with AttributeError in _serialize/_deserialize.
  protocol_factory = None

  def setUp(self):
    """Build one V1 and one V2 fixture that share begin/end field values."""
    if self.protocol_factory is None:
      self.skipTest("abstract base class: no protocol_factory configured")

    self.v1obj = VersioningTestV1(
        begin_in_both=12345,
        old_string='aaa',
        end_in_both=54321,
        )

    # V2 fills in one field of each kind of type it declares beyond the
    # shared ones, so that the V1 reader in testBackwards has to cope with
    # every kind of extra field.
    self.v2obj = VersioningTestV2(
        begin_in_both=12345,
        newint=1,
        newbyte=2,
        newshort=3,
        newlong=4,
        newdouble=5.0,
        newstruct=Bonk(message="Hello!", type=123),
        newlist=[7,8,9],
        newset=[42,1,8],
        newmap={1:2,2:3},
        newstring="Hola!",
        end_in_both=54321,
        )

  def _serialize(self, obj):
    """Write ``obj`` with this test's protocol and return the raw bytes."""
    trans = TTransport.TMemoryBuffer()
    prot = self.protocol_factory.getProtocol(trans)
    obj.write(prot)
    return trans.getvalue()

  def _deserialize(self, objtype, data):
    """Read a new ``objtype`` instance from ``data`` using this test's protocol."""
    prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
    ret = objtype()
    ret.read(prot)
    return ret

  def testForwards(self):
    """Data written as V1 can be read back as V2."""
    obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
    self.assertEqual(obj.begin_in_both, self.v1obj.begin_in_both)
    self.assertEqual(obj.end_in_both, self.v1obj.end_in_both)

  def testBackwards(self):
    """Data written as V2 can be read back as V1."""
    # Presumably V1 does not declare the new* fields, so its reader must
    # skip them and still reach end_in_both intact.
    obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
    self.assertEqual(obj.begin_in_both, self.v2obj.begin_in_both)
    self.assertEqual(obj.end_in_both, self.v2obj.end_in_both)
78
 
79
 
80
class NormalBinaryTest(AbstractTest):
  """AbstractTest round-trip checks using the plain (non-accelerated)
  TBinaryProtocol factory."""

  protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
82
 
83
class AcceleratedBinaryTest(AbstractTest):
  """AbstractTest round-trip checks using the accelerated binary protocol
  factory (TBinaryProtocolAcceleratedFactory)."""

  protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
85
 
86
 
87
class AcceleratedFramedTest(unittest.TestCase):
  """Checks TFramedTransport together with the accelerated binary protocol."""

  def testSplit(self):
    """Test FramedTransport and BinaryProtocolAccelerated

    Tests that TBinaryProtocolAccelerated and TFramedTransport
    play nicely together when a read spans a frame"""

    protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
    # The lowercase ASCII alphabet "a".."z".
    bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z")+1))

    # Serialize an i32, a string and an i16 into one contiguous payload.
    databuf = TTransport.TMemoryBuffer()
    prot = protocol_factory.getProtocol(databuf)
    prot.writeI32(42)
    prot.writeString(bigstring)
    prot.writeI16(24)
    data = databuf.getvalue()

    # Floor division: on Python 3 ``/`` yields a float, and slicing with a
    # float raises TypeError.
    cutpoint = len(data) // 2
    parts = [data[:cutpoint], data[cutpoint:]]

    # Flushing after each half writes two separate frames, so the string
    # read below straddles a frame boundary.
    framed_buffer = TTransport.TMemoryBuffer()
    framed_writer = TTransport.TFramedTransport(framed_buffer)
    for part in parts:
      framed_writer.write(part)
      framed_writer.flush()
    # Two frames, each adding a 4-byte length header.
    self.assertEqual(len(framed_buffer.getvalue()), len(data) + 8)

    # Recreate framed_buffer so we can read from it.
    framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue())
    framed_reader = TTransport.TFramedTransport(framed_buffer)
    prot = protocol_factory.getProtocol(framed_reader)
    self.assertEqual(prot.readI32(), 42)
    self.assertEqual(prot.readString(), bigstring)
    self.assertEqual(prot.readI16(), 24)
120
 
121
 
122
 
123
def suite():
  """Return a TestSuite that aggregates every concrete test-case class.

  Intended as the ``defaultTest`` entry point for ``unittest.main``.
  """
  loader = unittest.TestLoader()
  all_tests = unittest.TestSuite()
  case_classes = (NormalBinaryTest, AcceleratedBinaryTest, AcceleratedFramedTest)
  for case_class in case_classes:
    all_tests.addTest(loader.loadTestsFromTestCase(case_class))
  return all_tests
131
 
132
if __name__ == "__main__":
  # Run the aggregated suite() with verbose, one-line-per-test output.
  runner = unittest.TextTestRunner(verbosity=2)
  unittest.main(defaultTest="suite", testRunner=runner)