Added support for the opensrf XID, which is a transaction string that's passed
[OpenSRF.git] / src / router / osrf_router.c
1 #include "osrf_router.h"
2
3 #define ROUTER_SOCKFD connection->session->sock_id
4 #define ROUTER_REGISTER "register"
5 #define ROUTER_UNREGISTER "unregister"
6
7
8 #define ROUTER_REQUEST_CLASS_LIST "opensrf.router.info.class.list"
9
10 osrfRouter* osrfNewRouter( 
11                 char* domain, char* name, 
12                 char* resource, char* password, int port, 
13                 osrfStringArray* trustedClients, osrfStringArray* trustedServers ) {
14
15         if(!( domain && name && resource && password && port && trustedClients && trustedServers )) return NULL;
16
17         osrfRouter* router      = safe_malloc(sizeof(osrfRouter));
18         router->domain                  = strdup(domain);
19         router->name                    = strdup(name);
20         router->password                = strdup(password);
21         router->resource                = strdup(resource);
22         router->port                    = port;
23
24         router->trustedClients = trustedClients;
25         router->trustedServers = trustedServers;
26
27         
28         router->classes = osrfNewHash(); 
29         router->classes->freeItem = &osrfRouterClassFree;
30
31         router->connection = client_init( domain, port, NULL, 0 );
32
33         return router;
34 }
35
36
37
38 int osrfRouterConnect( osrfRouter* router ) {
39         if(!router) return -1;
40         int ret = client_connect( router->connection, router->name, 
41                         router->password, router->resource, 10, AUTH_DIGEST );
42         if( ret == 0 ) return -1;
43         return 0;
44 }
45
46
47 void osrfRouterRun( osrfRouter* router ) {
48         if(!(router && router->classes)) return;
49
50         int routerfd = router->ROUTER_SOCKFD;
51         int selectret = 0;
52
53         while(1) {
54
55                 fd_set set;
56                 int maxfd = __osrfRouterFillFDSet( router, &set );
57                 int numhandled = 0;
58
59                 if( (selectret = select(maxfd + 1, &set, NULL, NULL, NULL)) < 0 ) {
60                         osrfLogWarning( OSRF_LOG_MARK, "Top level select call failed with errno %d", errno);
61                         continue;
62                 }
63
64                 /* see if there is a top level router message */
65
66                 if( FD_ISSET(routerfd, &set) ) {
67                         osrfLogDebug( OSRF_LOG_MARK, "Top router socket is active: %d", routerfd );
68                         numhandled++;
69                         osrfRouterHandleIncoming( router );
70                 }
71
72
73                 /* now check each of the connected classes and see if they have data to route */
74                 while( numhandled < selectret ) {
75
76                         osrfRouterClass* class;
77                         osrfHashIterator* itr = osrfNewHashIterator(router->classes);
78
79                         while( (class = osrfHashIteratorNext(itr)) ) {
80
81                                 char* classname = itr->current;
82
83                                 if( classname && (class = osrfRouterFindClass( router, classname )) ) {
84
85                                         osrfLogDebug( OSRF_LOG_MARK, "Checking %s for activity...", classname );
86
87                                         int sockfd = class->ROUTER_SOCKFD;
88                                         if(FD_ISSET( sockfd, &set )) {
89                                                 osrfLogDebug( OSRF_LOG_MARK, "Socket is active: %d", sockfd );
90                                                 numhandled++;
91                                                 osrfRouterClassHandleIncoming( router, classname, class );
92                                         }
93                                 }
94                         }
95
96                         osrfHashIteratorFree(itr);
97                 }
98         }
99 }
100
101
102 void osrfRouterHandleIncoming( osrfRouter* router ) {
103         if(!router) return;
104
105         transport_message* msg = NULL;
106
107         //if( (msg = client_recv( router->connection, 0 )) ) { 
108         while( (msg = client_recv( router->connection, 0 )) ) { 
109
110                 if( msg->sender ) {
111
112                         osrfLogDebug(OSRF_LOG_MARK, 
113                                 "osrfRouterHandleIncoming(): investigating message from %s", msg->sender);
114
115                         /* if the sender is not a trusted server, drop the message */
116                         int len = strlen(msg->sender) + 1;
117                         char domain[len];
118                         bzero(domain, len);
119                         jid_get_domain( msg->sender, domain, len - 1 );
120
121                         if(osrfStringArrayContains( router->trustedServers, domain)) 
122                                 osrfRouterHandleMessage( router, msg );
123                          else 
124                                 osrfLogWarning( OSRF_LOG_MARK, "Received message from un-trusted server domain %s", msg->sender);
125                 }
126
127                 message_free(msg);
128         }
129 }
130
131 int osrfRouterClassHandleIncoming( osrfRouter* router, char* classname, osrfRouterClass* class ) {
132         if(!(router && class)) return -1;
133
134         transport_message* msg;
135         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleIncoming()");
136
137         while( (msg = client_recv( class->connection, 0 )) ) {
138
139       osrfLogSetXid(msg->osrf_xid);
140
141                 if( msg->sender ) {
142
143                         osrfLogDebug(OSRF_LOG_MARK, 
144                                 "osrfRouterClassHandleIncoming(): investigating message from %s", msg->sender);
145
146                         /* if the client is not from a trusted domain, drop the message */
147                         int len = strlen(msg->sender) + 1;
148                         char domain[len];
149                         bzero(domain, len);
150                         jid_get_domain( msg->sender, domain, len - 1 );
151
152                         if(osrfStringArrayContains( router->trustedClients, domain)) {
153
154                                 transport_message* bouncedMessage = NULL;
155                                 if( msg->is_error )  {
156
157                                         /* handle bounced message */
158                                         if( !(bouncedMessage = osrfRouterClassHandleBounce( router, classname, class, msg )) ) 
159                                                 return -1; /* we have no one to send the requested message to */
160
161                                         message_free( msg );
162                                         msg = bouncedMessage;
163                                 }
164                                 osrfRouterClassHandleMessage( router, class, msg );
165
166                         } else {
167                                 osrfLogWarning( OSRF_LOG_MARK, "Received client message from untrusted client domain %s", domain );
168                         }
169                 }
170
171       osrfLogClearXid();
172                 message_free( msg );
173         }
174
175         return 0;
176 }
177
178
179
180
181 int osrfRouterHandleMessage( osrfRouter* router, transport_message* msg ) {
182         if(!(router && msg)) return -1;
183
184         if( !msg->router_command || !strcmp(msg->router_command,"")) 
185                 return osrfRouterHandleAppRequest( router, msg ); /* assume it's an app session level request */
186
187         if(!msg->router_class) return -1;
188
189         osrfRouterClass* class = NULL;
190         if(!strcmp(msg->router_command, ROUTER_REGISTER)) {
191                 class = osrfRouterFindClass( router, msg->router_class );
192
193                 osrfLogInfo( OSRF_LOG_MARK, "Registering class %s", msg->router_class );
194
195                 if(!class) class = osrfRouterAddClass( router, msg->router_class );
196
197                 if(class) { 
198
199                         if( osrfRouterClassFindNode( class, msg->sender ) )
200                                 return 0;
201                         else 
202                                 osrfRouterClassAddNode( class, msg->sender );
203
204                 } 
205
206         } else if( !strcmp( msg->router_command, ROUTER_UNREGISTER ) ) {
207
208                 if( msg->router_class && strcmp( msg->router_class, "") ) {
209                         osrfLogInfo( OSRF_LOG_MARK, "Unregistering router class %s", msg->router_class );
210                         osrfRouterClassRemoveNode( router, msg->router_class, msg->sender );
211                 }
212         }
213
214         return 0;
215 }
216
217
218
219 osrfRouterClass* osrfRouterAddClass( osrfRouter* router, char* classname ) {
220         if(!(router && router->classes && classname)) return NULL;
221
222         osrfRouterClass* class = safe_malloc(sizeof(osrfRouterClass));
223         class->nodes = osrfNewHash();
224         class->itr = osrfNewHashIterator(class->nodes);
225         class->nodes->freeItem = &osrfRouterNodeFree;
226         class->router   = router;
227
228         class->connection = client_init( router->domain, router->port, NULL, 0 );
229
230         if(!client_connect( class->connection, router->name, 
231                         router->password, classname, 10, AUTH_DIGEST ) ) {
232                 osrfRouterClassFree( classname, class );
233                 return NULL;
234         }
235         
236         osrfHashSet( router->classes, class, classname );
237         return class;
238 }
239
240
241 int osrfRouterClassAddNode( osrfRouterClass* rclass, char* remoteId ) {
242         if(!(rclass && rclass->nodes && remoteId)) return -1;
243
244         osrfLogInfo( OSRF_LOG_MARK, "Adding router node for remote id %s", remoteId );
245
246         osrfRouterNode* node = safe_malloc(sizeof(osrfRouterNode));
247         node->count = 0;
248         node->lastMessage = NULL;
249         node->remoteId = strdup(remoteId);
250
251         osrfHashSet( rclass->nodes, node, remoteId );
252         return 0;
253 }
254
255 /* copy off the lastMessage, remove the offending node, send error if it's tht last node 
256         ? return NULL if it's the last node ?
257  */
258
259 transport_message* osrfRouterClassHandleBounce( 
260                 osrfRouter* router, char* classname, osrfRouterClass* rclass, transport_message* msg ) {
261
262         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleBounce()");
263
264         osrfLogInfo( OSRF_LOG_MARK, "Received network layer error message from %s", msg->sender );
265         osrfRouterNode* node = osrfRouterClassFindNode( rclass, msg->sender );
266         transport_message* lastSent = NULL;
267
268         if( node && osrfHashGetCount(rclass->nodes) == 1 ) { /* the last node is dead */
269
270                 if( node->lastMessage ) {
271                         osrfLogWarning( OSRF_LOG_MARK, "We lost the last node in the class, responding with error and removing...");
272         
273                         transport_message* error = message_init( 
274                                 node->lastMessage->body, node->lastMessage->subject, 
275                                 node->lastMessage->thread, node->lastMessage->router_from, node->lastMessage->recipient );
276          message_set_osrf_xid(error, node->lastMessage->osrf_xid);
277                         set_msg_error( error, "cancel", 501 );
278         
279                         /* send the error message back to the original sender */
280                         client_send_message( rclass->connection, error );
281                         message_free( error );
282                 }
283         
284                 return NULL;
285         
286         } else { 
287
288                 if( node ) {
289                         if( node->lastMessage ) {
290                                 osrfLogDebug( OSRF_LOG_MARK, "Cloning lastMessage so next node can send it");
291                                 lastSent = message_init( node->lastMessage->body,
292                                         node->lastMessage->subject, node->lastMessage->thread, "", node->lastMessage->router_from );
293                                 message_set_router_info( lastSent, node->lastMessage->router_from, NULL, NULL, NULL, 0 );
294             message_set_osrf_xid( lastSent, node->lastMessage->osrf_xid );
295                         }
296                 } else {
297
298                         osrfLogInfo(OSRF_LOG_MARK, "network error occurred after we removed the class.. ignoring");
299                         return NULL;
300                 }
301         }
302
303         /* remove the dead node */
304         osrfRouterClassRemoveNode( router, classname, msg->sender);
305         return lastSent;
306 }
307
308
309 /**
310   If we get a regular message, we send it to the next node in the list of nodes
311   if we get an error, it's a bounce back from a previous attempt.  We take the
312   body and thread from the last sent on the node that had the bounced message
313   and propogate them on to the new message being sent
314   */
315 int osrfRouterClassHandleMessage( 
316                 osrfRouter* router, osrfRouterClass* rclass, transport_message* msg ) {
317         if(!(router && rclass && msg)) return -1;
318
319         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleMessage()");
320
321         osrfRouterNode* node = osrfHashIteratorNext( rclass->itr );
322         if(!node) {
323                 osrfHashIteratorReset(rclass->itr);
324                 node = osrfHashIteratorNext( rclass->itr );
325         }
326
327         if(node) {
328
329                 transport_message* new_msg= message_init(       msg->body, 
330                                 msg->subject, msg->thread, node->remoteId, msg->sender );
331                 message_set_router_info( new_msg, msg->sender, NULL, NULL, NULL, 0 );
332       message_set_osrf_xid( new_msg, msg->osrf_xid );
333
334                 osrfLogInfo( OSRF_LOG_MARK,  "Routing message:\nfrom: [%s]\nto: [%s]", 
335                                 new_msg->router_from, new_msg->recipient );
336
337                 message_free( node->lastMessage );
338                 node->lastMessage = new_msg;
339
340                 if ( client_send_message( rclass->connection, new_msg ) == 0 ) 
341                         node->count++;
342
343                 else {
344                         message_prepare_xml(new_msg);
345                         osrfLogWarning( OSRF_LOG_MARK, "Error sending message from %s to %s\n%s", 
346                                         new_msg->sender, new_msg->recipient, new_msg->msg_xml );
347                 }
348
349         } 
350
351         return 0;
352 }
353
354
355 int osrfRouterRemoveClass( osrfRouter* router, char* classname ) {
356         if(!(router && router->classes && classname)) return -1;
357         osrfLogInfo( OSRF_LOG_MARK, "Removing router class %s", classname );
358         osrfHashRemove( router->classes, classname );
359         return 0;
360 }
361
362
363 int osrfRouterClassRemoveNode( 
364                 osrfRouter* router, char* classname, char* remoteId ) {
365
366         if(!(router && router->classes && classname && remoteId)) return 0;
367
368         osrfLogInfo( OSRF_LOG_MARK, "Removing router node %s", remoteId );
369
370         osrfRouterClass* class = osrfRouterFindClass( router, classname );
371
372         if( class ) {
373
374                 osrfHashRemove( class->nodes, remoteId );
375                 if( osrfHashGetCount(class->nodes) == 0 ) {
376                         osrfRouterRemoveClass( router, classname );
377                         return 1;
378                 }
379
380                 return 0;
381         }
382
383         return -1;
384 }
385
386
387 void osrfRouterClassFree( char* classname, void* c ) {
388         if(!(classname && c)) return;
389         osrfRouterClass* rclass = (osrfRouterClass*) c;
390         client_disconnect( rclass->connection );        
391         client_free( rclass->connection );      
392
393         osrfHashIteratorReset( rclass->itr );
394         osrfRouterNode* node;
395
396         while( (node = osrfHashIteratorNext(rclass->itr)) ) 
397                 osrfRouterClassRemoveNode( rclass->router, classname, node->remoteId );
398
399    osrfHashIteratorFree(rclass->itr);
400    osrfHashFree(rclass->nodes);
401
402         free(rclass);
403 }
404
405
406 void osrfRouterNodeFree( char* remoteId, void* n ) {
407         if(!n) return;
408         osrfRouterNode* node = (osrfRouterNode*) n;
409         free(node->remoteId);
410         message_free(node->lastMessage);
411         free(node);
412 }
413
414
415 void osrfRouterFree( osrfRouter* router ) {
416         if(!router) return;
417
418         free(router->domain);           
419         free(router->name);
420         free(router->resource);
421         free(router->password);
422
423         osrfStringArrayFree( router->trustedClients );
424         osrfStringArrayFree( router->trustedServers );
425
426         client_free( router->connection );
427         free(router);
428 }
429
430
431
432 osrfRouterClass* osrfRouterFindClass( osrfRouter* router, char* classname ) {
433         if(!( router && router->classes && classname )) return NULL;
434         return (osrfRouterClass*) osrfHashGet( router->classes, classname );
435 }
436
437
438 osrfRouterNode* osrfRouterClassFindNode( osrfRouterClass* rclass, char* remoteId ) {
439         if(!(rclass && remoteId))  return NULL;
440         return (osrfRouterNode*) osrfHashGet( rclass->nodes, remoteId );
441 }
442
443
444 int __osrfRouterFillFDSet( osrfRouter* router, fd_set* set ) {
445         if(!(router && router->classes && set)) return -1;
446
447         FD_ZERO(set);
448         int maxfd = router->ROUTER_SOCKFD;
449         FD_SET(maxfd, set);
450
451         int sockid;
452
453         osrfRouterClass* class = NULL;
454         osrfHashIterator* itr = osrfNewHashIterator(router->classes);
455
456         while( (class = osrfHashIteratorNext(itr)) ) {
457                 char* classname = itr->current;
458
459                 if( classname && (class = osrfRouterFindClass( router, classname )) ) {
460                         sockid = class->ROUTER_SOCKFD;
461         
462                         if( osrfUtilsCheckFileDescriptor( sockid ) ) {
463
464                                 osrfLogWarning(OSRF_LOG_MARK, 
465                                         "Removing router class '%s' because of a bad top-level file descriptor [%d]", classname, sockid);
466                                 osrfRouterRemoveClass( router, classname );
467         
468                         } else {
469                                 if( sockid > maxfd ) maxfd = sockid;
470                                 FD_SET(sockid, set);
471                         }
472                 }
473         }
474
475         osrfHashIteratorFree(itr);
476         return maxfd;
477 }
478
479
480
481 int osrfRouterHandleAppRequest( osrfRouter* router, transport_message* msg ) {
482
483         int T = 32;
484         osrfMessage* arr[T];
485         memset(arr, 0, T );
486
487         int num_msgs = osrf_message_deserialize( msg->body, arr, T );
488         osrfMessage* omsg = NULL;
489
490         int i;
491         for( i = 0; i != num_msgs; i++ ) {
492
493                 if( !(omsg = arr[i]) ) continue;
494
495                 switch( omsg->m_type ) {
496
497                         case CONNECT:
498                                 osrfRouterRespondConnect( router, msg, omsg );
499                                 break;
500
501                         case REQUEST:
502                                 osrfRouterProcessAppRequest( router, msg, omsg );
503                                 break;
504
505                         default: break;
506                 }
507
508                 osrfMessageFree( omsg );
509         }
510
511         return 0;
512 }
513
514 int osrfRouterRespondConnect( osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
515         if(!(router && msg && omsg)) return -1;
516
517         osrfMessage* success = osrf_message_init( STATUS, omsg->thread_trace, omsg->protocol );
518
519         osrfLogDebug( OSRF_LOG_MARK, "router recevied a CONNECT message from %s", msg->sender );
520
521         osrf_message_set_status_info( 
522                 success, "osrfConnectStatus", "Connection Successful", OSRF_STATUS_OK );
523
524         char* data      = osrf_message_serialize(success);
525
526         transport_message* return_m = message_init( 
527                 data, "", msg->thread, msg->sender, "" );
528
529         client_send_message(router->connection, return_m);
530
531         free(data);
532         osrf_message_free(success);
533         message_free(return_m);
534
535         return 0;
536 }
537
538
539
540 int osrfRouterProcessAppRequest( osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
541
542         if(!(router && msg && omsg && omsg->method_name)) return -1;
543
544         osrfLogInfo( OSRF_LOG_MARK, "Router received app request: %s", omsg->method_name );
545
546         jsonObject* jresponse = NULL;
547         if(!strcmp( omsg->method_name, ROUTER_REQUEST_CLASS_LIST )) {
548
549                 int i;
550                 jresponse = jsonParseString("[]");
551
552                 osrfStringArray* keys = osrfHashKeys( router->classes );
553                 for( i = 0; i != keys->size; i++ )
554                         jsonObjectPush( jresponse, jsonNewObject(osrfStringArrayGetString( keys, i )) );
555                 osrfStringArrayFree(keys);
556
557
558         } else {
559
560                 return osrfRouterHandleMethodNFound( router, msg, omsg );
561         }
562
563
564         osrfRouterHandleAppResponse( router, msg, omsg, jresponse );
565         jsonObjectFree(jresponse); 
566
567         return 0;
568
569 }
570
571
572
573 int osrfRouterHandleMethodNFound( 
574                 osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
575
576         osrf_message* err = osrf_message_init( STATUS, omsg->thread_trace, 1);
577                 osrf_message_set_status_info( err, 
578                                 "osrfMethodException", "Router method not found", OSRF_STATUS_NOTFOUND );
579
580                 char* data =  osrf_message_serialize(err);
581
582                 transport_message* tresponse = message_init(
583                                 data, "", msg->thread, msg->sender, msg->recipient );
584
585                 client_send_message(router->connection, tresponse );
586
587                 free(data);
588                 osrf_message_free( err );
589                 message_free(tresponse);
590                 return 0;
591 }
592
593
594
595 int osrfRouterHandleAppResponse( osrfRouter* router, 
596         transport_message* msg, osrfMessage* omsg, jsonObject* response ) {
597
598         if( response ) { /* send the response message */
599
600                 osrfMessage* oresponse = osrf_message_init(
601                                 RESULT, omsg->thread_trace, omsg->protocol );
602         
603                 char* json = jsonObjectToJSON(response);
604                 osrf_message_set_result_content( oresponse, json);
605         
606                 char* data =  osrf_message_serialize(oresponse);
607                 osrfLogDebug( OSRF_LOG_MARK,  "Responding to client app request with data: \n%s\n", data );
608
609                 transport_message* tresponse = message_init(
610                                 data, "", msg->thread, msg->sender, msg->recipient );
611         
612                 client_send_message(router->connection, tresponse );
613
614                 osrfMessageFree(oresponse); 
615                 message_free(tresponse);
616                 free(json);
617                 free(data);
618         }
619
620
621         /* now send the 'request complete' message */
622         osrf_message* status = osrf_message_init( STATUS, omsg->thread_trace, 1);
623         osrf_message_set_status_info( status, "osrfConnectStatus", "Request Complete", OSRF_STATUS_COMPLETE );
624
625         char* statusdata = osrf_message_serialize(status);
626
627         transport_message* sresponse = message_init(
628                         statusdata, "", msg->thread, msg->sender, msg->recipient );
629         client_send_message(router->connection, sresponse );
630
631
632         free(statusdata);
633         osrfMessageFree(status);
634         message_free(sresponse);
635
636         return 0;
637 }
638
639
640
641