created a shallow clone method
[OpenSRF.git] / src / python / osrf / net_obj.py
1 from osrf.const import OSRF_JSON_PAYLOAD_KEY, OSRF_JSON_CLASS_KEY
2 import re
3 from xml.sax import saxutils
4
5
6 # -----------------------------------------------------------
7 # Define the global network-class registry
8 # -----------------------------------------------------------
9
10
11 class NetworkRegistry(object):
12     ''' Network-serializable objects must be registered.  The class
13         hint maps to a set (ordered in the case of array-base objects)
14         of field names (keys).  
15         '''
16
17     # Global object registry 
18     registry = {}
19
20     def __init__(self, hint, keys, protocol):
21         self.hint = hint
22         self.keys = keys
23         self.protocol = protocol
24         NetworkRegistry.registry[hint] = self
25     
26     @staticmethod
27     def get_registry(hint):
28         return NetworkRegistry.registry.get(hint)
29
30 # -----------------------------------------------------------
31 # Define the base class for all network-serializable objects
32 # -----------------------------------------------------------
33
34 class NetworkObject(object):
35     ''' Base class for all network serializable objects '''
36
37     # link to our registry object for this registered class
38     registry = None
39
40     def __init__(self, data=None):
41         ''' If this is an array, we pull data out of the data array
42             (if there is any) and translate that into a hash internally '''
43
44         self._data = data
45         if not data: self._data = {}
46         if isinstance(data, list):
47             self.import_array_data(list)
48
49     def import_array_data(self, data):
50         ''' If an array-based object is created with an array
51             of data, cycle through and load the data '''
52
53         self._data = {}
54         if len(data) == 0:
55             return
56
57         reg = self.get_registry()
58         if reg.protocol == 'array':
59             for entry in range(len(reg.keys)):
60                 if len(data) > entry:
61                     break
62                 self.set_field(reg.keys[entry], data[entry])
63
64     def get_data(self):
65         ''' Returns the full dataset for this object as a dict '''
66         return self._data
67
68     def set_field(self, field, value):
69         self._data[field] = value
70
71     def get_field(self, field):
72         return self._data.get(field)
73
74     def get_registry(self):
75         ''' Returns the registry object for this registered class '''
76         return self.__class__.registry
77
78     def shallow_clone(self):
79         ''' Makes a shallow copy '''
80         reg = self.get_registry()
81         obj = new_object_from_hint(reg.hint)
82         for field in reg.keys:
83             obj.set_field(field, self.get_field(field))
84         return obj
85             
86
87
88 def new_object_from_hint(hint):
89     ''' Given a hint, this will create a new object of that 
90         type and return it.  If this hint is not registered,
91         an object of type NetworkObject.__unknown is returned'''
92     try:
93         obj = None
94         exec('obj = NetworkObject.%s()' % hint)
95         return obj
96     except AttributeError:
97         return NetworkObject.__unknown()
98
99 def __make_network_accessor(cls, key):
100     ''' Creates and accessor/mutator method for the given class.  
101         'key' is the name the method will have and represents
102         the field on the object whose data we are accessing ''' 
103     def accessor(self, *args):
104         if len(args) != 0:
105             self.set_field(key, args[0])
106         return self.get_field(key)
107     setattr(cls, key, accessor)
108
109
110 def register_hint(hint, keys, type='hash'):
111     ''' Registers a new network-serializable object class.
112
113         'hint' is the class hint
114         'keys' is the list of field names on the object
115             If this is an array-based object, the field names
116             must be sorted to reflect the encoding order of the fields
117         'type' is the wire-protocol of the object.  hash or array.
118         '''
119
120     # register the class with the global registry
121     registry = NetworkRegistry(hint, keys, type)
122
123     # create the new class locally with the given hint name
124     exec('class %s(NetworkObject):\n\tpass' % hint)
125
126     # give the new registered class a local handle
127     cls = None
128     exec('cls = %s' % hint)
129
130     # assign an accessor/mutator for each field on the object
131     for k in keys:
132         __make_network_accessor(cls, k)
133
134     # attach our new class to the NetworkObject 
135     # class so others can access it
136     setattr(NetworkObject, hint , cls)
137     cls.registry = registry
138
139
140
141
142 # create a unknown object to handle unregistred types
143 register_hint('__unknown', [], 'hash')
144
145 # -------------------------------------------------------------------
146 # Define the custom object parsing behavior 
147 # -------------------------------------------------------------------
148 def parse_net_object(obj):
149     
150     try:
151         hint = obj[OSRF_JSON_CLASS_KEY]
152         sub_object = obj[OSRF_JSON_PAYLOAD_KEY]
153         reg = NetworkRegistry.get_registry(hint)
154
155         obj = {}
156
157         if reg.protocol == 'array':
158             for entry in range(len(reg.keys)):
159                 if len(sub_object) > entry:
160                     obj[reg.keys[entry]] = parse_net_object(sub_object[entry])
161                 else:
162                     obj[reg.keys[entry]] = None
163         else:
164             for key in reg.keys:
165                 obj[key] = parse_net_object(sub_object.get(key))
166
167         estr = 'obj = NetworkObject.%s(obj)' % hint
168         try:
169             exec(estr)
170         except:
171             # this object has not been registered, shove it into the default container
172             obj = NetworkObject.__unknown(obj)
173
174         return obj
175
176     except:
177         pass
178
179     # the current object does not have a class hint
180     if isinstance(obj, list):
181         for entry in range(len(obj)):
182             obj[entry] = parse_net_object(obj[entry])
183
184     else:
185         if isinstance(obj, dict):
186             for key, value in obj.iteritems():
187                 obj[key] = parse_net_object(value)
188
189     return obj
190
191
192 def to_xml(obj):
193     """ Returns the XML representation of an internal object."""
194     chars = []
195     __to_xml(obj, chars)
196     return ''.join(chars)
197
198 def __to_xml(obj, chars):
199     """ Turns an internal object into OpenSRF XML """
200
201     if obj is None:
202         chars.append('<null/>')
203         return
204
205     if isinstance(obj, unicode) or isinstance(obj, str):
206         chars.append('<string>%s</string>' % saxutils.escape(obj))
207         return
208
209     if isinstance(obj, int)  or isinstance(obj, long):
210         chars.append('<number>%d</number>' % obj)
211         return
212
213     if isinstance(obj, float):
214         chars.append('<number>%f</number>' % obj)
215         return
216
217     if isinstance(obj, NetworkObject): 
218
219         registry = obj.get_registry()
220         data = obj.get_data()
221         hint = saxutils.escape(registry.hint)
222
223         if registry.protocol == 'array':
224             chars.append("<array class_hint='%s'>" % hint)
225             for key in registry.keys:
226                 __to_xml(data.get(key), chars)
227             chars.append('</array>')
228
229         else:
230             if registry.protocol == 'hash':
231                 chars.append("<object class_hint='%s'>" % hint)
232                 for key, value in data.items():
233                     chars.append("<element key='%s'>" % saxutils.escape(key))
234                     __to_xml(value, chars)
235                     chars.append('</element>')
236                 chars.append('</object>')
237                 
238
239     if isinstance(obj, list):
240         chars.append('<array>')
241         for entry in obj:
242             __to_xml(entry, chars)
243         chars.append('</array>')
244         return
245
246     if isinstance(obj, dict):
247         chars.append('<object>')
248         for key, value in obj.items():
249             chars.append("<element key='%s'>" % saxutils.escape(key))
250             __to_xml(value, chars)
251             chars.append('</element>')
252         chars.append('</object>')
253         return
254
255     if isinstance(obj, bool):
256         val = 'false'
257         if obj:
258             val = 'true'
259         chars.append("<boolean value='%s'/>" % val)
260         return
261
262 def find_object_path(obj, path, idx=None):
263     """Searches an object along the given path for a value to return.
264
265     Path separators can be '/' or '.', '/' is tried first."""
266
267     parts = []
268
269     if re.search('/', path):
270         parts = path.split('/')
271     else:
272         parts = path.split('.')
273
274     for part in parts:
275         try:
276             val = obj[part]
277         except:
278             return None
279         if isinstance(val, str): 
280             return val
281         if isinstance(val, list):
282             if idx != None:
283                 return val[idx]
284             return val
285         if isinstance(val, dict):
286             obj = val
287         else:
288             return val
289
290     return obj