]> git.evergreen-ils.org Git - OpenSRF.git/blob - src/python/srfsh.py
Python srfsh enhancements
[OpenSRF.git] / src / python / srfsh.py
1 #!/usr/bin/python
2 # vim:et:ts=4
3 """
4 srfsh.py - provides a basic shell for issuing OpenSRF requests
5
6   help
7     - show this menu
8
9   math_bench <count>
10     - runs <count> opensrf.math requests and prints the average time
11
12   request <service> <method> [<param1>, <param2>, ...]
13     - performs an opensrf request
14     - parameters are JSON strings
15
16   router <query>
17     - Queries the router.  Query options: services service-stats service-nodes
18
19   introspect <service> [<api_name_prefix>]
20     - List API calls for a service.  
21     - api_name_prefix is a bare string or JSON string.
22
23   set VAR=<value>
24     - sets an environment variable
25
26   get VAR
27     - Returns the value for the environment variable
28
29   Environment variables:
30     SRFSH_OUTPUT_NET_OBJ_KEYS  = true - If a network object is array-encoded and key registry exists for the object type, annotate the object with field names
31                                = false - Print JSON
32
33     SRFSH_OUTPUT_FORMAT_JSON = true - Use JSON pretty printer
34                              = false - Print raw JSON
35
36     SRFSH_OUTPUT_PAGED = true - Paged output.  Uses "less -EX"
37                        = false - Output is not paged
38
39     SRFSH_LOCALE = <locale> - request responses to be returned in locale <locale> if available
40
41 """
42 import os, sys, time, readline, atexit, re, pydoc, traceback
43 import osrf.json, osrf.system, osrf.ses, osrf.conf, osrf.log, osrf.net
44
45 class Srfsh(object):
46
47
48     def __init__(self, script_file=None):
49
50         # used for paging
51         self.output_buffer = '' 
52
53         # true if invoked with a script file 
54         self.reading_script = False
55
56         # multi-request sessions
57         self.active_session = None 
58
59         # default opensrf request timeout
60         self.timeout = 120
61
62         # map of command name to handler
63         self.command_map = {}
64
65         if script_file:
66             self.open_script(script_file)
67             self.reading_script = True
68
69         # map of router sub-commands to router API calls
70         self.router_command_map = {
71             'services'      : 'opensrf.router.info.class.list',
72             'service-stats' : 'opensrf.router.info.stats.class.node.all',
73             'service-nodes' : 'opensrf.router.info.stats.class.all'
74         }
75
76         # seed the tab completion word bank
77         self.tab_complete_words = self.router_command_map.keys() + [
78             'exit', 
79             'quit', 
80             'opensrf.settings', 
81             'opensrf.math',
82             'opensrf.dbmath',
83             'opensrf.py-example'
84         ]
85
86         # add the default commands
87         for command in ['request', 'router', 'help', 'set', 
88                 'get', 'math_bench', 'introspect', 'connect', 'disconnect' ]:
89
90             self.add_command(command = command, handler = getattr(Srfsh, 'handle_' + command))
91
92         # for compat w/ srfsh.c
93         self.add_command(command = 'open', handler = Srfsh.handle_connect)
94         self.add_command(command = 'close', handler = Srfsh.handle_disconnect)
95
96     def open_script(self, script_file):
97         ''' Opens the script file and redirects the contents to STDIN for reading. '''
98
99         try:
100             script = open(script_file, 'r')
101             os.dup2(script.fileno(), sys.stdin.fileno())
102             script.close()
103         except Exception, e:
104             self.report_error("Error opening script file '%s': %s" % (script_file, str(e)))
105             raise e
106
107
108     def main_loop(self):
109         ''' Main listen loop. '''
110
111         self.set_vars()
112         self.do_connect()
113         self.load_plugins()
114         self.setup_readline()
115
116         while True:
117
118             try:
119                 self.report("", True)
120                 line = raw_input("srfsh# ")
121
122                 if not len(line): 
123                     continue
124
125                 if re.search('^\s*#', line): # ignore lines starting with #
126                     continue
127
128                 if str.lower(line) == 'exit' or str.lower(line) == 'quit': 
129                     break
130
131                 parts = str.split(line)
132                 command = parts.pop(0)
133
134                 if command not in self.command_map:
135                     self.report("unknown command: '%s'\n" % command)
136                     continue
137
138                 self.command_map[command](self, parts)
139
140             except EOFError: # ctrl-d
141                 break
142
143             except KeyboardInterrupt: # ctrl-c
144                 self.report("\n")
145
146             except Exception, e:
147                 self.report("%s\n" % traceback.format_exc())
148
149         self.cleanup()
150
151     def handle_connect(self, parts):
152         ''' Opens a connected session to an opensrf service '''
153
154         if len(parts) == 0:
155             self.report("usage: connect <service>")
156             return
157
158         service = parts.pop(0)
159
160         if self.active_session:
161             if self.active_session['service'] == service:
162                 return # use the existing active session
163             else:
164                 # currently, we only support one active session at a time
165                 self.handle_disconnect([self.active_session['service']])
166
167         self.active_session = {
168             'ses' : osrf.ses.ClientSession(service, locale = self.__get_locale()),
169             'service' : service
170         }
171
172         self.active_session['ses'].connect()
173
174     def handle_disconnect(self, parts):
175         ''' Disconnects the currently active session. '''
176
177         if len(parts) == 0:
178             self.report("usage: disconnect <service>")
179             return
180
181         service = parts.pop(0)
182
183         if self.active_session:
184             if self.active_session['service'] == service:
185                 self.active_session['ses'].disconnect()
186                 self.active_session['ses'].cleanup()
187                 self.active_session = None
188             else:
189                 self.report_error("There is no open connection for service '%s'" % service)
190
191     def handle_introspect(self, parts):
192         ''' Introspect an opensrf service. '''
193
194         if len(parts) == 0:
195             self.report("usage: introspect <service> [api_prefix]\n")
196             return
197
198         service = parts.pop(0)
199         args = [service, 'opensrf.system.method']
200
201         if len(parts) > 0:
202             api_pfx = parts[0]
203             if api_pfx[0] != '"': # json-encode if necessary
204                 api_pfx = '"%s"' % api_pfx
205             args.append(api_pfx)
206         else:
207             args[1] += '.all'
208
209         return handle_request(args)
210
211
212     def handle_router(self, parts):
213         ''' Send requests to the router. '''
214
215         if len(parts) == 0:
216             self.report("usage: router <query>\n")
217             return
218
219         query = parts[0]
220
221         if query not in self.router_command_map:
222             self.report("router query options: %s\n" % ','.join(self.router_command_map.keys()))
223             return
224
225         return handle_request(['router', self.router_command_map[query]])
226
227     def handle_set(self, parts):
228         ''' Set env variables to control srfsh behavior. '''
229
230         cmd = "".join(parts)
231         pattern = re.compile('(.*)=(.*)').match(cmd)
232         key = pattern.group(1)
233         val = pattern.group(2)
234         self.set_var(key, val)
235         self.report("%s = %s\n" % (key, val))
236
237     def handle_get(self, parts):
238         ''' Returns environment variable value '''
239         try:
240             self.report("%s=%s\n" % (parts[0], self.get_var(parts[0])))
241         except:
242             self.report("\n")
243
244
245     def handle_help(self, foo):
246         ''' Prints help info '''
247         self.report(__doc__)
248
249     def handle_request(self, parts):
250         ''' Performs an OpenSRF request and reports the results. '''
251
252         if len(parts) < 2:
253             self.report("usage: request <service> <api_name> [<param1>, <param2>, ...]\n")
254             return
255
256         self.report("\n")
257
258         service = parts.pop(0)
259         method = parts.pop(0)
260         locale = self.__get_locale()
261         jstr = '[%s]' % "".join(parts)
262         params = None
263
264         try:
265             params = osrf.json.to_object(jstr)
266         except:
267             self.report("Error parsing JSON: %s\n" % jstr)
268             return
269
270         using_active = False
271         if self.active_session and self.active_session['service'] == service:
272             # if we have an open connection to the same service, use it
273             ses = self.active_session['ses']
274             using_active = True
275         else:
276             ses = osrf.ses.ClientSession(service, locale=locale)
277
278         start = time.time()
279
280         req = ses.request2(method, tuple(params))
281
282         last_content = None
283         while True:
284             resp = None
285
286             try:
287                 resp = req.recv(timeout=self.timeout)
288             except osrf.net.XMPPNoRecipient:
289                 self.report("Unable to communicate with %s\n" % service)
290                 total = 0
291                 break
292
293             if not resp: break
294
295             total = time.time() - start
296             content = resp.content()
297
298             if content is not None:
299                 last_content = content
300                 if self.get_var('SRFSH_OUTPUT_NET_OBJ_KEYS') == 'true':
301                     self.report("Received Data: %s\n" % osrf.json.debug_net_object(content))
302                 else:
303                     if self.get_var('SRFSH_OUTPUT_FORMAT_JSON') == 'true':
304                         self.report("Received Data: %s\n" % osrf.json.pprint(osrf.json.to_json(content)))
305                     else:
306                         self.report("Received Data: %s\n" % osrf.json.to_json(content))
307
308         req.cleanup()
309         if not using_active:
310             ses.cleanup()
311
312         self.report("\n" + '-'*60 + "\n")
313         self.report("Total request time: %f\n" % total)
314         self.report('-'*60 + "\n")
315
316         return last_content
317
318
319     def handle_math_bench(self, parts):
320         ''' Sends a series of request to the opensrf.math service and collects timing stats. '''
321
322         count = int(parts.pop(0))
323         ses = osrf.ses.ClientSession('opensrf.math')
324         times = []
325
326         for cnt in range(100):
327             if cnt % 10:
328                 sys.stdout.write('.')
329             else:
330                 sys.stdout.write( str( cnt / 10 ) )
331         print ""
332
333         for cnt in range(count):
334         
335             starttime = time.time()
336             req = ses.request('add', 1, 2)
337             resp = req.recv(timeout=2)
338             endtime = time.time()
339         
340             if resp.content() == 3:
341                 sys.stdout.write("+")
342                 sys.stdout.flush()
343                 times.append( endtime - starttime )
344             else:
345                 print "What happened? %s" % str(resp.content())
346         
347             req.cleanup()
348             if not ( (cnt + 1) % 100):
349                 print ' [%d]' % (cnt + 1)
350         
351         ses.cleanup()
352         total = 0
353         for cnt in times:
354             total += cnt 
355         print "\naverage time %f" % (total / len(times))
356
357
358
359
360     def setup_readline(self):
361         ''' Initialize readline history and tab completion. '''
362
363         class SrfshCompleter(object):
364
365             def __init__(self, words):
366                 self.words = words
367                 self.prefix = None
368         
369             def complete(self, prefix, index):
370
371                 if prefix != self.prefix:
372
373                     self.prefix = prefix
374
375                     # find all words that start with this prefix
376                     self.matching_words = [
377                         w for w in self.words if w.startswith(prefix)
378                     ]
379
380                     if len(self.matching_words) == 0:
381                         return None
382
383                     if len(self.matching_words) == 1:
384                         return self.matching_words[0]
385
386                     # re-print the prompt w/ all of the possible word completions
387                     sys.stdout.write('\n%s\nsrfsh# %s' % 
388                         (' '.join(self.matching_words), readline.get_line_buffer()))
389
390                     return None
391
392         completer = SrfshCompleter(tuple(self.tab_complete_words))
393         readline.parse_and_bind("tab: complete")
394         readline.set_completer(completer.complete)
395
396         histfile = os.path.join(self.get_var('HOME'), ".srfsh_history")
397         try:
398             readline.read_history_file(histfile)
399         except IOError:
400             pass
401         atexit.register(readline.write_history_file, histfile)
402
403         readline.set_completer_delims(readline.get_completer_delims().replace('-',''))
404
405
406     def do_connect(self):
407         ''' Connects this instance to the OpenSRF network. '''
408
409         file = os.path.join(self.get_var('HOME'), ".srfsh.xml")
410         osrf.system.System.connect(config_file=file, config_context='srfsh')
411
412     def add_command(self, **kwargs):
413         ''' Adds a new command to the supported srfsh commands.
414
415         Command is also added to the tab-completion word bank.
416
417         kwargs :
418             command : the command name
419             handler : reference to a two-argument function.  
420                 Arguments are Srfsh instance and command arguments.
421         '''
422
423         command = kwargs['command']
424         self.command_map[command] = kwargs['handler']
425         self.tab_complete_words.append(command)
426
427
428     def load_plugins(self):
429         ''' Load plugin modules from the srfsh configuration file '''
430
431         try:
432             plugins = osrf.conf.get('plugins.plugin')
433         except:
434             return
435
436         if not isinstance(plugins, list):
437             plugins = [plugins]
438
439         for plugin in plugins:
440             module = plugin['module']
441             init = plugin.get('init', 'load')
442             self.report("Loading module %s..." % module, True, True)
443
444             try:
445                 mod = __import__(module, fromlist=' ')
446                 getattr(mod, init)(self, plugin)
447                 self.report("OK.\n", True, True)
448
449             except Exception, e:
450                 self.report_error("Error importing plugin '%s' : %s\n" % (module, traceback.format_exc()))
451
452     def cleanup(self):
453         ''' Disconnects from opensrf. '''
454         osrf.system.System.net_disconnect()
455
456     def report_error(self, msg):
457         ''' Log to stderr. '''
458         sys.stderr.write("%s\n" % msg)
459         sys.stderr.flush()
460         
461     def report(self, text, flush=False, no_page=False):
462         ''' Logs to the pager or stdout, depending on env vars and context '''
463
464         if self.reading_script or no_page or self.get_var('SRFSH_OUTPUT_PAGED') != 'true':
465             sys.stdout.write(text)
466             if flush:
467                 sys.stdout.flush()
468         else:
469             self.output_buffer += text
470
471             if flush and self.output_buffer != '':
472                 pipe = os.popen('less -EX', 'w') 
473                 pipe.write(self.output_buffer)
474                 pipe.close()
475                 self.output_buffer = ''
476
477     def set_vars(self):
478         ''' Set defaults for environment variables. '''
479
480         if not self.get_var('SRFSH_OUTPUT_NET_OBJ_KEYS'):
481             self.set_var('SRFSH_OUTPUT_NET_OBJ_KEYS', 'false')
482
483         if not self.get_var('SRFSH_OUTPUT_FORMAT_JSON'):
484             self.set_var('SRFSH_OUTPUT_FORMAT_JSON', 'true')
485
486         if not self.get_var('SRFSH_OUTPUT_PAGED'):
487             self.set_var('SRFSH_OUTPUT_PAGED', 'true')
488
489         # XXX Do we need to differ between LANG and LC_MESSAGES?
490         if not self.get_var('SRFSH_LOCALE'):
491             self.set_var('SRFSH_LOCALE', self.get_var('LC_ALL'))
492
493     def set_var(self, key, val):
494         ''' Sets an environment variable's value. '''
495         os.environ[key] = val
496
497     def get_var(self, key):
498         ''' Returns an environment variable's value. '''
499         return os.environ.get(key, '')
500         
501     def __get_locale(self):
502         """
503         Return the defined locale for this srfsh session.
504
505         A locale in OpenSRF is currently defined as a [a-z]{2}-[A-Z]{2} pattern.
506         This function munges the LC_ALL setting to conform to that pattern; for
507         example, trimming en_CA.UTF-8 to en-CA.
508
509         >>> import srfsh
510         >>> shell = srfsh.Srfsh()
511         >>> shell.set_var('SRFSH_LOCALE', 'zz-ZZ')
512         >>> print shell.__get_locale()
513         zz-ZZ
514         >>> shell.set_var('SRFSH_LOCALE', 'en_CA.UTF-8')
515         >>> print shell.__get_locale()
516         en-CA
517         """
518
519         env_locale = self.get_var('SRFSH_LOCALE')
520         if env_locale:
521             pattern = re.compile(r'^\s*([a-z]+)[^a-zA-Z]([A-Z]+)').search(env_locale)
522             lang = pattern.group(1)
523             region = pattern.group(2)
524             locale = "%s-%s" % (lang, region)
525         else:
526             locale = 'en-US'
527
528         return locale
529     
530 if __name__ == '__main__':
531     script = sys.argv[1] if len(sys.argv) > 1 else None
532     Srfsh(script).main_loop()
533