LP#1066131: srfsh.py should not require opensrf.settings
[OpenSRF.git] / src / python / srfsh.py
1 #!/usr/bin/python
2 # vim:et:ts=4
3 """
4 srfsh.py - provides a basic shell for issuing OpenSRF requests
5
6   help
7     - show this menu
8
9   math_bench <count>
10     - runs <count> opensrf.math requests and prints the average time
11
12   request <service> <method> [<param1>, <param2>, ...]
13     - performs an opensrf request
14     - parameters are JSON strings
15
16   router <query>
17     - Queries the router.  Query options: services service-stats service-nodes
18
19   introspect <service> [<api_name_prefix>]
20     - List API calls for a service.  
21     - api_name_prefix is a bare string or JSON string.
22
23   set VAR=<value>
24     - sets an environment variable
25
26   get VAR
27     - Returns the value for the environment variable
28
29   Environment variables:
30     SRFSH_OUTPUT_NET_OBJ_KEYS  = true - If a network object is array-encoded and key registry exists for the object type, annotate the object with field names
31                                = false - Print JSON
32
33     SRFSH_OUTPUT_FORMAT_JSON = true - Use JSON pretty printer
34                              = false - Print raw JSON
35
36     SRFSH_OUTPUT_PAGED = true - Paged output.  Uses "less -EX"
37                        = false - Output is not paged
38
39     SRFSH_LOCALE = <locale> - request responses to be returned in locale <locale> if available
40
41 """
42 import os, sys, time, readline, atexit, re, pydoc, traceback
43 import osrf.json, osrf.system, osrf.ses, osrf.conf, osrf.log, osrf.net
44
45 class Srfsh(object):
46
47
48     def __init__(self, script_file=None):
49
50         # used for paging
51         self.output_buffer = '' 
52
53         # true if invoked with a script file 
54         self.reading_script = False
55
56         # multi-request sessions
57         self.active_session = None 
58
59         # default opensrf request timeout
60         self.timeout = 120
61
62         # map of command name to handler
63         self.command_map = {}
64
65         if script_file:
66             self.open_script(script_file)
67             self.reading_script = True
68
69         # map of router sub-commands to router API calls
70         self.router_command_map = {
71             'services'      : 'opensrf.router.info.class.list',
72             'service-stats' : 'opensrf.router.info.stats.class.node.all',
73             'service-nodes' : 'opensrf.router.info.stats.class.all'
74         }
75
76         # seed the tab completion word bank
77         self.tab_complete_words = self.router_command_map.keys() + [
78             'exit', 
79             'quit', 
80             'opensrf.settings', 
81             'opensrf.math',
82             'opensrf.dbmath',
83             'opensrf.py-example'
84         ]
85
86         # add the default commands
87         for command in ['request', 'router', 'help', 'set', 
88                 'get', 'math_bench', 'introspect', 'connect', 'disconnect' ]:
89
90             self.add_command(command = command, handler = getattr(Srfsh, 'handle_' + command))
91
92         # for compat w/ srfsh.c
93         self.add_command(command = 'open', handler = Srfsh.handle_connect)
94         self.add_command(command = 'close', handler = Srfsh.handle_disconnect)
95
96     def open_script(self, script_file):
97         ''' Opens the script file and redirects the contents to STDIN for reading. '''
98
99         try:
100             script = open(script_file, 'r')
101             os.dup2(script.fileno(), sys.stdin.fileno())
102             script.close()
103         except Exception, e:
104             self.report_error("Error opening script file '%s': %s" % (script_file, str(e)))
105             raise e
106
107
108     def main_loop(self):
109         ''' Main listen loop. '''
110
111         self.set_vars()
112         self.do_connect()
113         self.load_plugins()
114         self.setup_readline()
115
116         while True:
117
118             try:
119                 self.report("", True)
120                 line = raw_input("srfsh# ")
121
122                 if not len(line): 
123                     continue
124
125                 if re.search('^\s*#', line): # ignore lines starting with #
126                     continue
127
128                 if str.lower(line) == 'exit' or str.lower(line) == 'quit': 
129                     break
130
131                 parts = str.split(line)
132                 command = parts.pop(0)
133
134                 if command not in self.command_map:
135                     self.report("unknown command: '%s'\n" % command)
136                     continue
137
138                 self.command_map[command](self, parts)
139
140             except EOFError: # ctrl-d
141                 break
142
143             except KeyboardInterrupt: # ctrl-c
144                 self.report("\n")
145
146             except Exception, e:
147                 self.report("%s\n" % traceback.format_exc())
148
149         self.cleanup()
150
151     def handle_connect(self, parts):
152         ''' Opens a connected session to an opensrf service '''
153
154         if len(parts) == 0:
155             self.report("usage: connect <service>")
156             return
157
158         service = parts.pop(0)
159
160         if self.active_session:
161             if self.active_session['service'] == service:
162                 return # use the existing active session
163             else:
164                 # currently, we only support one active session at a time
165                 self.handle_disconnect([self.active_session['service']])
166
167         self.active_session = {
168             'ses' : osrf.ses.ClientSession(service, locale = self.__get_locale()),
169             'service' : service
170         }
171
172         self.active_session['ses'].connect()
173
174     def handle_disconnect(self, parts):
175         ''' Disconnects the currently active session. '''
176
177         if len(parts) == 0:
178             self.report("usage: disconnect <service>")
179             return
180
181         service = parts.pop(0)
182
183         if self.active_session:
184             if self.active_session['service'] == service:
185                 self.active_session['ses'].disconnect()
186                 self.active_session['ses'].cleanup()
187                 self.active_session = None
188             else:
189                 self.report_error("There is no open connection for service '%s'" % service)
190
191     def handle_introspect(self, parts):
192         ''' Introspect an opensrf service. '''
193
194         if len(parts) == 0:
195             self.report("usage: introspect <service> [api_prefix]\n")
196             return
197
198         service = parts.pop(0)
199         args = [service, 'opensrf.system.method']
200
201         if len(parts) > 0:
202             api_pfx = parts[0]
203             if api_pfx[0] != '"': # json-encode if necessary
204                 api_pfx = '"%s"' % api_pfx
205             args.append(api_pfx)
206         else:
207             args[1] += '.all'
208
209         return self.handle_request(args)
210
211
212     def handle_router(self, parts):
213         ''' Send requests to the router. '''
214
215         if len(parts) == 0:
216             self.report("usage: router <query>\n")
217             return
218
219         query = parts[0]
220
221         if query not in self.router_command_map:
222             self.report("router query options: %s\n" % ','.join(self.router_command_map.keys()))
223             return
224
225         return self.handle_request(['router', self.router_command_map[query]])
226
227     def handle_set(self, parts):
228         ''' Set env variables to control srfsh behavior. '''
229
230         cmd = "".join(parts)
231         pattern = re.compile('(.*)=(.*)').match(cmd)
232         key = pattern.group(1)
233         val = pattern.group(2)
234         self.set_var(key, val)
235         self.report("%s = %s\n" % (key, val))
236
237     def handle_get(self, parts):
238         ''' Returns environment variable value '''
239         try:
240             self.report("%s=%s\n" % (parts[0], self.get_var(parts[0])))
241         except:
242             self.report("\n")
243
244
245     def handle_help(self, foo):
246         ''' Prints help info '''
247         self.report(__doc__)
248
249     def handle_request(self, parts):
250         ''' Performs an OpenSRF request and reports the results. '''
251
252         if len(parts) < 2:
253             self.report("usage: request <service> <api_name> [<param1>, <param2>, ...]\n")
254             return
255
256         self.report("\n")
257
258         service = parts.pop(0)
259         method = parts.pop(0)
260         locale = self.__get_locale()
261         jstr = '[%s]' % "".join(parts)
262         params = None
263
264         try:
265             params = osrf.json.to_object(jstr)
266         except:
267             self.report("Error parsing JSON: %s\n" % jstr)
268             return
269
270         using_active = False
271         if self.active_session and self.active_session['service'] == service:
272             # if we have an open connection to the same service, use it
273             ses = self.active_session['ses']
274             using_active = True
275         else:
276             ses = osrf.ses.ClientSession(service, locale=locale)
277
278         start = time.time()
279
280         req = ses.request2(method, tuple(params))
281
282         last_content = None
283         total = 0
284         while True:
285             resp = None
286
287             try:
288                 resp = req.recv(timeout=self.timeout)
289             except osrf.net.XMPPNoRecipient:
290                 self.report("Unable to communicate with %s\n" % service)
291                 break
292             except osrf.ex.OSRFServiceException, e:
293                 self.report("Server exception occurred: %s" % e)
294                 break
295
296             total = time.time() - start
297
298             if not resp: break
299
300             content = resp.content()
301
302             if content is not None:
303                 last_content = content
304                 if self.get_var('SRFSH_OUTPUT_NET_OBJ_KEYS') == 'true':
305                     self.report("Received Data: %s\n" % osrf.json.debug_net_object(content))
306                 else:
307                     if self.get_var('SRFSH_OUTPUT_FORMAT_JSON') == 'true':
308                         self.report("Received Data: %s\n" % osrf.json.pprint(osrf.json.to_json(content)))
309                     else:
310                         self.report("Received Data: %s\n" % osrf.json.to_json(content))
311
312         req.cleanup()
313         if not using_active:
314             ses.cleanup()
315
316         self.report("\n" + '-'*60 + "\n")
317         self.report("Total request time: %f\n" % total)
318         self.report('-'*60 + "\n")
319
320         return last_content
321
322
323     def handle_math_bench(self, parts):
324         ''' Sends a series of request to the opensrf.math service and collects timing stats. '''
325
326         count = int(parts.pop(0))
327         ses = osrf.ses.ClientSession('opensrf.math')
328         times = []
329
330         for cnt in range(100):
331             if cnt % 10:
332                 sys.stdout.write('.')
333             else:
334                 sys.stdout.write( str( cnt / 10 ) )
335         print ""
336
337         for cnt in range(count):
338         
339             starttime = time.time()
340             req = ses.request('add', 1, 2)
341             resp = req.recv(timeout=2)
342             endtime = time.time()
343         
344             if resp.content() == 3:
345                 sys.stdout.write("+")
346                 sys.stdout.flush()
347                 times.append( endtime - starttime )
348             else:
349                 print "What happened? %s" % str(resp.content())
350         
351             req.cleanup()
352             if not ( (cnt + 1) % 100):
353                 print ' [%d]' % (cnt + 1)
354         
355         ses.cleanup()
356         total = 0
357         for cnt in times:
358             total += cnt 
359         print "\naverage time %f" % (total / len(times))
360
361
362
363
364     def setup_readline(self):
365         ''' Initialize readline history and tab completion. '''
366
367         class SrfshCompleter(object):
368
369             def __init__(self, words):
370                 self.words = words
371                 self.prefix = None
372         
373             def complete(self, prefix, index):
374
375                 if prefix != self.prefix:
376
377                     self.prefix = prefix
378
379                     # find all words that start with this prefix
380                     self.matching_words = [
381                         w for w in self.words if w.startswith(prefix)
382                     ]
383
384                     if len(self.matching_words) == 0:
385                         return None
386
387                     if len(self.matching_words) == 1:
388                         return self.matching_words[0]
389
390                     # re-print the prompt w/ all of the possible word completions
391                     sys.stdout.write('\n%s\nsrfsh# %s' % 
392                         (' '.join(self.matching_words), readline.get_line_buffer()))
393
394                     return None
395
396         completer = SrfshCompleter(tuple(self.tab_complete_words))
397         readline.parse_and_bind("tab: complete")
398         readline.set_completer(completer.complete)
399
400         histfile = os.path.join(self.get_var('HOME'), ".srfsh_history")
401         try:
402             readline.read_history_file(histfile)
403         except IOError:
404             pass
405         atexit.register(readline.write_history_file, histfile)
406
407         readline.set_completer_delims(readline.get_completer_delims().replace('-',''))
408
409
410     def do_connect(self):
411         ''' Connects this instance to the OpenSRF network. '''
412
413         osrf.ses.Session.ingress('srfsh')
414         file = os.path.join(self.get_var('HOME'), ".srfsh.xml")
415         osrf.system.System.net_connect(config_file=file, config_context='srfsh')
416
417     def add_command(self, **kwargs):
418         ''' Adds a new command to the supported srfsh commands.
419
420         Command is also added to the tab-completion word bank.
421
422         kwargs :
423             command : the command name
424             handler : reference to a two-argument function.  
425                 Arguments are Srfsh instance and command arguments.
426         '''
427
428         command = kwargs['command']
429         self.command_map[command] = kwargs['handler']
430         self.tab_complete_words.append(command)
431
432
433     def load_plugins(self):
434         ''' Load plugin modules from the srfsh configuration file '''
435
436         try:
437             plugins = osrf.conf.get('plugins.plugin')
438         except:
439             return
440
441         if not isinstance(plugins, list):
442             plugins = [plugins]
443
444         for plugin in plugins:
445             module = plugin['module']
446             init = plugin.get('init', 'load')
447             self.report("Loading module %s..." % module, True, True)
448
449             try:
450                 mod = __import__(module, fromlist=' ')
451                 getattr(mod, init)(self, plugin)
452                 self.report("OK.\n", True, True)
453
454             except Exception, e:
455                 self.report_error("Error importing plugin '%s' : %s\n" % (module, traceback.format_exc()))
456
457     def cleanup(self):
458         ''' Disconnects from opensrf. '''
459         osrf.system.System.net_disconnect()
460
461     def report_error(self, msg):
462         ''' Log to stderr. '''
463         sys.stderr.write("%s\n" % msg)
464         sys.stderr.flush()
465         
466     def report(self, text, flush=False, no_page=False):
467         ''' Logs to the pager or stdout, depending on env vars and context '''
468
469         if self.reading_script or no_page or self.get_var('SRFSH_OUTPUT_PAGED') != 'true':
470             sys.stdout.write(text)
471             if flush:
472                 sys.stdout.flush()
473         else:
474             self.output_buffer += text
475
476             if flush and self.output_buffer != '':
477                 pipe = os.popen('less -EX', 'w') 
478                 pipe.write(self.output_buffer)
479                 pipe.close()
480                 self.output_buffer = ''
481
482     def set_vars(self):
483         ''' Set defaults for environment variables. '''
484
485         if not self.get_var('SRFSH_OUTPUT_NET_OBJ_KEYS'):
486             self.set_var('SRFSH_OUTPUT_NET_OBJ_KEYS', 'false')
487
488         if not self.get_var('SRFSH_OUTPUT_FORMAT_JSON'):
489             self.set_var('SRFSH_OUTPUT_FORMAT_JSON', 'true')
490
491         if not self.get_var('SRFSH_OUTPUT_PAGED'):
492             self.set_var('SRFSH_OUTPUT_PAGED', 'true')
493
494         # XXX Do we need to differ between LANG and LC_MESSAGES?
495         if not self.get_var('SRFSH_LOCALE'):
496             self.set_var('SRFSH_LOCALE', self.get_var('LC_ALL'))
497
498     def set_var(self, key, val):
499         ''' Sets an environment variable's value. '''
500         os.environ[key] = val
501
502     def get_var(self, key):
503         ''' Returns an environment variable's value. '''
504         return os.environ.get(key, '')
505         
506     def __get_locale(self):
507         """
508         Return the defined locale for this srfsh session.
509
510         A locale in OpenSRF is currently defined as a [a-z]{2}-[A-Z]{2} pattern.
511         This function munges the LC_ALL setting to conform to that pattern; for
512         example, trimming en_CA.UTF-8 to en-CA.
513
514         >>> import srfsh
515         >>> shell = srfsh.Srfsh()
516         >>> shell.set_var('SRFSH_LOCALE', 'zz-ZZ')
517         >>> print shell.__get_locale()
518         zz-ZZ
519         >>> shell.set_var('SRFSH_LOCALE', 'en_CA.UTF-8')
520         >>> print shell.__get_locale()
521         en-CA
522         """
523
524         env_locale = self.get_var('SRFSH_LOCALE')
525         if env_locale:
526             pattern = re.compile(r'^\s*([a-z]+)[^a-zA-Z]([A-Z]+)').search(env_locale)
527             lang = pattern.group(1)
528             region = pattern.group(2)
529             locale = "%s-%s" % (lang, region)
530         else:
531             locale = 'en-US'
532
533         return locale
534     
535 if __name__ == '__main__':
536     script = sys.argv[1] if len(sys.argv) > 1 else None
537     Srfsh(script).main_loop()
538