LP#1409055 Support specific protocols for OpenSRF gateway requests
[OpenSRF.git] / src / python / osrf / gateway.py
index ee2fd36..8796ce2 100644 (file)
@@ -1,8 +1,8 @@
 from xml.dom import minidom
 from xml.sax import handler, make_parser, saxutils
-from json import *
-from net_obj import *
-from log import *
+from osrf.json import to_object
+from osrf.net_obj import NetworkObject, new_object_from_hint
+import osrf.log
 import urllib, urllib2, sys, re
 
 defaultHost = None
@@ -13,6 +13,7 @@ class GatewayRequest:
         self.method = method
         self.params = params
         self.path = 'gateway'
+        self.bytes_read = 0 # for now this, this is really characters read
 
     def setPath(self, path):
         self.path = path
@@ -25,7 +26,7 @@ class GatewayRequest:
             response =urllib2.urlopen(request)
         except urllib2.HTTPError, e:
             # log this?
-            sys.stderr.write('%s => %s?%s\n' % (str(e), self.buildURL(), params))
+            sys.stderr.write('%s => %s?%s\n' % (unicode(e), self.buildURL(), params))
             raise e
             
         return self.handleResponse(response)
@@ -49,6 +50,18 @@ class GatewayRequest:
     setDefaultHost = staticmethod(setDefaultHost)
 
     def buildURL(self):
+        """
+        Builds the URL for the OpenSRF gateway based on the host and path
+
+        Previous versions of the code assumed that the host would be a bare
+        hostname or IP address, and prepended the http:// protocol. However,
+        to enable more secure communications, now we check for the existence
+        of the HTTP or HTTPS prefix and use that if it has been supplied.
+        """
+
+        if defaultHost.lower().startswith(('http://', 'https://')):
+            return '%s/%s' % (defaultHost, self.path)
+
         return 'http://%s/%s' % (defaultHost, self.path)
 
 class JSONGatewayRequest(GatewayRequest):
@@ -62,20 +75,24 @@ class JSONGatewayRequest(GatewayRequest):
         return self.getFormat()
 
     def handleResponse(self, response):
-        s = response.read()
-        obj = osrfJSONToObject(s)
+
+        data = response.read()
+        self.bytes_read = len(str(response.headers)) + len(data)
+        obj = to_object(data)
+
         if obj['status'] != 200:
-            sys.stderr.write('JSON gateway returned status %d:\n%s\n' % (obj['status'], s))
+            sys.stderr.write('JSON gateway returned status %d:\n' % (obj['status']))
             return None
 
         # the gateway wraps responses in an array to handle streaming data
         # if there is only one item in the array, it (probably) wasn't a streaming request
         p = obj['payload']
         if len(p) > 1: return p
-        return p[0]
+        if len(p): return p[0]
+        return None
 
     def encodeParam(self, param):
-        return osrfObjectToJSON(param)
+        return osrf.json.to_json(param)
 
 class XMLGatewayRequest(GatewayRequest):
 
@@ -95,13 +112,13 @@ class XMLGatewayRequest(GatewayRequest):
         try:
             parser.parse(response)
         except Exception, e:
-            osrfLogErr('Error parsing gateway XML: %s' % str(e))
+            osrf.log.log_error('Error parsing gateway XML: %s' % unicode(e))
             return None
 
         return handler.getResult()
 
     def encodeParam(self, param):
-        return osrfObjectToXML(param);
+        return osrf.net_obj.to_xml(param);
 
 class XMLGatewayParser(handler.ContentHandler):
 
@@ -144,7 +161,7 @@ class XMLGatewayParser(handler.ContentHandler):
 
         hint = self.__getAttr(attrs, 'class_hint')
         if hint:
-            obj = osrfNewObjectFromHint(hint)
+            obj = new_object_from_hint(hint)
             self.appendChild(obj)
             self.objStack.append(obj)
             if name == 'array':
@@ -183,10 +200,10 @@ class XMLGatewayParser(handler.ContentHandler):
             if isinstance(parent, dict):
                 parent[self.keyStack.pop()] = child
             else:
-                if isinstance(parent, osrfNetworkObject):
+                if isinstance(parent, NetworkObject):
                     key = None
-                    if parent.getRegistry().wireProtocol == 'array':
-                        keys = parent.getRegistry().keys
+                    if parent.get_registry().protocol == 'array':
+                        keys = parent.get_registry().keys
                         i = self.posStack.pop()
                         key = keys[i]
                         if i+1 < len(keys):
@@ -194,7 +211,7 @@ class XMLGatewayParser(handler.ContentHandler):
                     else:
                         key = self.keyStack.pop()
 
-                    parent.setField(key, child)
+                    parent.set_field(key, child)
 
     def endElement(self, name):
         if name == 'array' or name == 'object':