Fixed issue: when the router received two messages on a single socket that both fit...
[OpenSRF.git] / src / router / osrf_router.c
1 #include "osrf_router.h"
2
3 #define ROUTER_SOCKFD connection->session->sock_id
4 #define ROUTER_REGISTER "register"
5 #define ROUTER_UNREGISTER "unregister"
6
7
8 #define ROUTER_REQUEST_CLASS_LIST "opensrf.router.info.class.list"
9
10 osrfRouter* osrfNewRouter( 
11                 char* domain, char* name, 
12                 char* resource, char* password, int port, 
13                 osrfStringArray* trustedClients, osrfStringArray* trustedServers ) {
14
15         if(!( domain && name && resource && password && port && trustedClients && trustedServers )) return NULL;
16
17         osrfRouter* router      = safe_malloc(sizeof(osrfRouter));
18         router->domain                  = strdup(domain);
19         router->name                    = strdup(name);
20         router->password                = strdup(password);
21         router->resource                = strdup(resource);
22         router->port                    = port;
23
24         router->trustedClients = trustedClients;
25         router->trustedServers = trustedServers;
26
27         
28         router->classes = osrfNewHash(); 
29         router->classes->freeItem = &osrfRouterClassFree;
30
31         router->connection = client_init( domain, port, NULL, 0 );
32
33         return router;
34 }
35
36
37
38 int osrfRouterConnect( osrfRouter* router ) {
39         if(!router) return -1;
40         int ret = client_connect( router->connection, router->name, 
41                         router->password, router->resource, 10, AUTH_DIGEST );
42         if( ret == 0 ) return -1;
43         return 0;
44 }
45
46
47 void osrfRouterRun( osrfRouter* router ) {
48         if(!(router && router->classes)) return;
49
50         int routerfd = router->ROUTER_SOCKFD;
51         int selectret = 0;
52
53         while(1) {
54
55                 fd_set set;
56                 int maxfd = __osrfRouterFillFDSet( router, &set );
57                 int numhandled = 0;
58
59                 if( (selectret = select(maxfd + 1, &set, NULL, NULL, NULL)) < 0 ) {
60                         osrfLogWarning( OSRF_LOG_MARK, "Top level select call failed with errno %d", errno);
61                         continue;
62                 }
63
64                 /* see if there is a top level router message */
65
66                 if( FD_ISSET(routerfd, &set) ) {
67                         osrfLogDebug( OSRF_LOG_MARK, "Top router socket is active: %d", routerfd );
68                         numhandled++;
69                         osrfRouterHandleIncoming( router );
70                 }
71
72
73                 /* now check each of the connected classes and see if they have data to route */
74                 while( numhandled < selectret ) {
75
76                         osrfRouterClass* class;
77                         osrfHashIterator* itr = osrfNewHashIterator(router->classes);
78
79                         while( (class = osrfHashIteratorNext(itr)) ) {
80
81                                 char* classname = itr->current;
82
83                                 if( classname && (class = osrfRouterFindClass( router, classname )) ) {
84
85                                         osrfLogDebug( OSRF_LOG_MARK, "Checking %s for activity...", classname );
86
87                                         int sockfd = class->ROUTER_SOCKFD;
88                                         if(FD_ISSET( sockfd, &set )) {
89                                                 osrfLogDebug( OSRF_LOG_MARK, "Socket is active: %d", sockfd );
90                                                 numhandled++;
91                                                 osrfRouterClassHandleIncoming( router, classname, class );
92                                         }
93                                 }
94                         }
95
96                         osrfHashIteratorFree(itr);
97                 }
98         }
99 }
100
101
102 void osrfRouterHandleIncoming( osrfRouter* router ) {
103         if(!router) return;
104
105         transport_message* msg = NULL;
106
107         //if( (msg = client_recv( router->connection, 0 )) ) { 
108         while( (msg = client_recv( router->connection, 0 )) ) { 
109
110                 if( msg->sender ) {
111
112                         osrfLogDebug(OSRF_LOG_MARK, 
113                                 "osrfRouterHandleIncoming(): investigating message from %s", msg->sender);
114
115                         /* if the sender is not a trusted server, drop the message */
116                         int len = strlen(msg->sender) + 1;
117                         char domain[len];
118                         bzero(domain, len);
119                         jid_get_domain( msg->sender, domain, len - 1 );
120
121                         if(osrfStringArrayContains( router->trustedServers, domain)) 
122                                 osrfRouterHandleMessage( router, msg );
123                          else 
124                                 osrfLogWarning( OSRF_LOG_MARK, "Received message from un-trusted server domain %s", msg->sender);
125                 }
126
127                 message_free(msg);
128         }
129 }
130
131 int osrfRouterClassHandleIncoming( osrfRouter* router, char* classname, osrfRouterClass* class ) {
132         if(!(router && class)) return -1;
133
134         transport_message* msg;
135         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleIncoming()");
136
137         while( (msg = client_recv( class->connection, 0 )) ) {
138
139                 if( msg->sender ) {
140
141                         osrfLogDebug(OSRF_LOG_MARK, 
142                                 "osrfRouterClassHandleIncoming(): investigating message from %s", msg->sender);
143
144                         /* if the client is not from a trusted domain, drop the message */
145                         int len = strlen(msg->sender) + 1;
146                         char domain[len];
147                         bzero(domain, len);
148                         jid_get_domain( msg->sender, domain, len - 1 );
149
150                         if(osrfStringArrayContains( router->trustedClients, domain)) {
151
152                                 transport_message* bouncedMessage = NULL;
153                                 if( msg->is_error )  {
154
155                                         /* handle bounced message */
156                                         if( !(bouncedMessage = osrfRouterClassHandleBounce( router, classname, class, msg )) ) 
157                                                 return -1; /* we have no one to send the requested message to */
158
159                                         message_free( msg );
160                                         msg = bouncedMessage;
161                                 }
162                                 osrfRouterClassHandleMessage( router, class, msg );
163
164                         } else {
165                                 osrfLogWarning( OSRF_LOG_MARK, "Received client message from untrusted client domain %s", domain );
166                         }
167                 }
168
169                 message_free( msg );
170         }
171
172         return 0;
173 }
174
175
176
177
178 int osrfRouterHandleMessage( osrfRouter* router, transport_message* msg ) {
179         if(!(router && msg)) return -1;
180
181         if( !msg->router_command || !strcmp(msg->router_command,"")) 
182                 return osrfRouterHandleAppRequest( router, msg ); /* assume it's an app session level request */
183
184         if(!msg->router_class) return -1;
185
186         osrfRouterClass* class = NULL;
187         if(!strcmp(msg->router_command, ROUTER_REGISTER)) {
188                 class = osrfRouterFindClass( router, msg->router_class );
189
190                 osrfLogInfo( OSRF_LOG_MARK, "Registering class %s", msg->router_class );
191
192                 if(!class) class = osrfRouterAddClass( router, msg->router_class );
193
194                 if(class) { 
195
196                         if( osrfRouterClassFindNode( class, msg->sender ) )
197                                 return 0;
198                         else 
199                                 osrfRouterClassAddNode( class, msg->sender );
200
201                 } 
202
203         } else if( !strcmp( msg->router_command, ROUTER_UNREGISTER ) ) {
204
205                 if( msg->router_class && strcmp( msg->router_class, "") ) {
206                         osrfLogInfo( OSRF_LOG_MARK, "Unregistering router class %s", msg->router_class );
207                         osrfRouterClassRemoveNode( router, msg->router_class, msg->sender );
208                 }
209         }
210
211         return 0;
212 }
213
214
215
216 osrfRouterClass* osrfRouterAddClass( osrfRouter* router, char* classname ) {
217         if(!(router && router->classes && classname)) return NULL;
218
219         osrfRouterClass* class = safe_malloc(sizeof(osrfRouterClass));
220         class->nodes = osrfNewHash();
221         class->itr = osrfNewHashIterator(class->nodes);
222         class->nodes->freeItem = &osrfRouterNodeFree;
223         class->router   = router;
224
225         class->connection = client_init( router->domain, router->port, NULL, 0 );
226
227         if(!client_connect( class->connection, router->name, 
228                         router->password, classname, 10, AUTH_DIGEST ) ) {
229                 osrfRouterClassFree( classname, class );
230                 return NULL;
231         }
232         
233         osrfHashSet( router->classes, class, classname );
234         return class;
235 }
236
237
238 int osrfRouterClassAddNode( osrfRouterClass* rclass, char* remoteId ) {
239         if(!(rclass && rclass->nodes && remoteId)) return -1;
240
241         osrfLogInfo( OSRF_LOG_MARK, "Adding router node for remote id %s", remoteId );
242
243         osrfRouterNode* node = safe_malloc(sizeof(osrfRouterNode));
244         node->count = 0;
245         node->lastMessage = NULL;
246         node->remoteId = strdup(remoteId);
247
248         osrfHashSet( rclass->nodes, node, remoteId );
249         return 0;
250 }
251
252 /* copy off the lastMessage, remove the offending node, send error if it's tht last node 
253         ? return NULL if it's the last node ?
254  */
255
256 transport_message* osrfRouterClassHandleBounce( 
257                 osrfRouter* router, char* classname, osrfRouterClass* rclass, transport_message* msg ) {
258
259         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleBounce()");
260
261         osrfLogInfo( OSRF_LOG_MARK, "Received network layer error message from %s", msg->sender );
262         osrfRouterNode* node = osrfRouterClassFindNode( rclass, msg->sender );
263         transport_message* lastSent = NULL;
264
265         if( node && osrfHashGetCount(rclass->nodes) == 1 ) { /* the last node is dead */
266
267                 if( node->lastMessage ) {
268                         osrfLogWarning( OSRF_LOG_MARK, "We lost the last node in the class, responding with error and removing...");
269         
270                         transport_message* error = message_init( 
271                                 node->lastMessage->body, node->lastMessage->subject, 
272                                 node->lastMessage->thread, node->lastMessage->router_from, node->lastMessage->recipient );
273                         set_msg_error( error, "cancel", 501 );
274         
275                         /* send the error message back to the original sender */
276                         client_send_message( rclass->connection, error );
277                         message_free( error );
278                 }
279         
280                 return NULL;
281         
282         } else { 
283
284                 if( node ) {
285                         if( node->lastMessage ) {
286                                 osrfLogDebug( OSRF_LOG_MARK, "Cloning lastMessage so next node can send it");
287                                 lastSent = message_init( node->lastMessage->body,
288                                         node->lastMessage->subject, node->lastMessage->thread, "", node->lastMessage->router_from );
289                                 message_set_router_info( lastSent, node->lastMessage->router_from, NULL, NULL, NULL, 0 );
290                         }
291                 } else {
292
293                         osrfLogInfo(OSRF_LOG_MARK, "network error occurred after we removed the class.. ignoring");
294                         return NULL;
295                 }
296         }
297
298         /* remove the dead node */
299         osrfRouterClassRemoveNode( router, classname, msg->sender);
300         return lastSent;
301 }
302
303
304 /**
305   If we get a regular message, we send it to the next node in the list of nodes
306   if we get an error, it's a bounce back from a previous attempt.  We take the
307   body and thread from the last sent on the node that had the bounced message
308   and propogate them on to the new message being sent
309   */
310 int osrfRouterClassHandleMessage( 
311                 osrfRouter* router, osrfRouterClass* rclass, transport_message* msg ) {
312         if(!(router && rclass && msg)) return -1;
313
314         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleMessage()");
315
316         osrfRouterNode* node = osrfHashIteratorNext( rclass->itr );
317         if(!node) {
318                 osrfHashIteratorReset(rclass->itr);
319                 node = osrfHashIteratorNext( rclass->itr );
320         }
321
322         if(node) {
323
324                 transport_message* new_msg= message_init(       msg->body, 
325                                 msg->subject, msg->thread, node->remoteId, msg->sender );
326                 message_set_router_info( new_msg, msg->sender, NULL, NULL, NULL, 0 );
327
328                 osrfLogInfo( OSRF_LOG_MARK,  "Routing message:\nfrom: [%s]\nto: [%s]", 
329                                 new_msg->router_from, new_msg->recipient );
330
331                 message_free( node->lastMessage );
332                 node->lastMessage = new_msg;
333
334                 if ( client_send_message( rclass->connection, new_msg ) == 0 ) 
335                         node->count++;
336
337                 else {
338                         message_prepare_xml(new_msg);
339                         osrfLogWarning( OSRF_LOG_MARK, "Error sending message from %s to %s\n%s", 
340                                         new_msg->sender, new_msg->recipient, new_msg->msg_xml );
341                 }
342
343         } 
344
345         return 0;
346 }
347
348
349 int osrfRouterRemoveClass( osrfRouter* router, char* classname ) {
350         if(!(router && router->classes && classname)) return -1;
351         osrfLogInfo( OSRF_LOG_MARK, "Removing router class %s", classname );
352         osrfHashRemove( router->classes, classname );
353         return 0;
354 }
355
356
357 int osrfRouterClassRemoveNode( 
358                 osrfRouter* router, char* classname, char* remoteId ) {
359
360         if(!(router && router->classes && classname && remoteId)) return 0;
361
362         osrfLogInfo( OSRF_LOG_MARK, "Removing router node %s", remoteId );
363
364         osrfRouterClass* class = osrfRouterFindClass( router, classname );
365
366         if( class ) {
367
368                 osrfHashRemove( class->nodes, remoteId );
369                 if( osrfHashGetCount(class->nodes) == 0 ) {
370                         osrfRouterRemoveClass( router, classname );
371                         return 1;
372                 }
373
374                 return 0;
375         }
376
377         return -1;
378 }
379
380
381 void osrfRouterClassFree( char* classname, void* c ) {
382         if(!(classname && c)) return;
383         osrfRouterClass* rclass = (osrfRouterClass*) c;
384         client_disconnect( rclass->connection );        
385         client_free( rclass->connection );      
386
387         osrfHashIteratorReset( rclass->itr );
388         osrfRouterNode* node;
389
390         while( (node = osrfHashIteratorNext(rclass->itr)) ) 
391                 osrfRouterClassRemoveNode( rclass->router, classname, node->remoteId );
392
393         free(rclass);
394 }
395
396
397 void osrfRouterNodeFree( char* remoteId, void* n ) {
398         if(!n) return;
399         osrfRouterNode* node = (osrfRouterNode*) n;
400         free(node->remoteId);
401         message_free(node->lastMessage);
402         free(node);
403 }
404
405
406 void osrfRouterFree( osrfRouter* router ) {
407         if(!router) return;
408
409         free(router->domain);           
410         free(router->name);
411         free(router->resource);
412         free(router->password);
413
414         osrfStringArrayFree( router->trustedClients );
415         osrfStringArrayFree( router->trustedServers );
416
417         client_free( router->connection );
418         free(router);
419 }
420
421
422
423 osrfRouterClass* osrfRouterFindClass( osrfRouter* router, char* classname ) {
424         if(!( router && router->classes && classname )) return NULL;
425         return (osrfRouterClass*) osrfHashGet( router->classes, classname );
426 }
427
428
429 osrfRouterNode* osrfRouterClassFindNode( osrfRouterClass* rclass, char* remoteId ) {
430         if(!(rclass && remoteId))  return NULL;
431         return (osrfRouterNode*) osrfHashGet( rclass->nodes, remoteId );
432 }
433
434
435 int __osrfRouterFillFDSet( osrfRouter* router, fd_set* set ) {
436         if(!(router && router->classes && set)) return -1;
437
438         FD_ZERO(set);
439         int maxfd = router->ROUTER_SOCKFD;
440         FD_SET(maxfd, set);
441
442         int sockid;
443
444         osrfRouterClass* class = NULL;
445         osrfHashIterator* itr = osrfNewHashIterator(router->classes);
446
447         while( (class = osrfHashIteratorNext(itr)) ) {
448                 char* classname = itr->current;
449
450                 if( classname && (class = osrfRouterFindClass( router, classname )) ) {
451                         sockid = class->ROUTER_SOCKFD;
452         
453                         if( osrfUtilsCheckFileDescriptor( sockid ) ) {
454                                 osrfRouterRemoveClass( router, classname );
455         
456                         } else {
457                                 if( sockid > maxfd ) maxfd = sockid;
458                                 FD_SET(sockid, set);
459                         }
460                 }
461         }
462
463         osrfHashIteratorFree(itr);
464         return maxfd;
465 }
466
467
468
469 int osrfRouterHandleAppRequest( osrfRouter* router, transport_message* msg ) {
470
471         int T = 32;
472         osrfMessage* arr[T];
473         memset(arr, 0, T );
474
475         int num_msgs = osrf_message_deserialize( msg->body, arr, T );
476         osrfMessage* omsg = NULL;
477
478         int i;
479         for( i = 0; i != num_msgs; i++ ) {
480
481                 if( !(omsg = arr[i]) ) continue;
482
483                 switch( omsg->m_type ) {
484
485                         case CONNECT:
486                                 osrfRouterRespondConnect( router, msg, omsg );
487                                 break;
488
489                         case REQUEST:
490                                 osrfRouterProcessAppRequest( router, msg, omsg );
491                                 break;
492
493                         default: break;
494                 }
495
496                 osrfMessageFree( omsg );
497         }
498
499         return 0;
500 }
501
502 int osrfRouterRespondConnect( osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
503         if(!(router && msg && omsg)) return -1;
504
505         osrfMessage* success = osrf_message_init( STATUS, omsg->thread_trace, omsg->protocol );
506
507         osrfLogDebug( OSRF_LOG_MARK, "router recevied a CONNECT message from %s", msg->sender );
508
509         osrf_message_set_status_info( 
510                 success, "osrfConnectStatus", "Connection Successful", OSRF_STATUS_OK );
511
512         char* data      = osrf_message_serialize(success);
513
514         transport_message* return_m = message_init( 
515                 data, "", msg->thread, msg->sender, "" );
516
517         client_send_message(router->connection, return_m);
518
519         free(data);
520         osrf_message_free(success);
521         message_free(return_m);
522
523         return 0;
524 }
525
526
527
528 int osrfRouterProcessAppRequest( osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
529
530         if(!(router && msg && omsg && omsg->method_name)) return -1;
531
532         osrfLogInfo( OSRF_LOG_MARK, "Router received app request: %s", omsg->method_name );
533
534         jsonObject* jresponse = NULL;
535         if(!strcmp( omsg->method_name, ROUTER_REQUEST_CLASS_LIST )) {
536
537                 int i;
538                 jresponse = jsonParseString("[]");
539
540                 osrfStringArray* keys = osrfHashKeys( router->classes );
541                 for( i = 0; i != keys->size; i++ )
542                         jsonObjectPush( jresponse, jsonNewObject(osrfStringArrayGetString( keys, i )) );
543                 osrfStringArrayFree(keys);
544
545
546         } else {
547
548                 return osrfRouterHandleMethodNFound( router, msg, omsg );
549         }
550
551
552         osrfRouterHandleAppResponse( router, msg, omsg, jresponse );
553         jsonObjectFree(jresponse); 
554
555         return 0;
556
557 }
558
559
560
561 int osrfRouterHandleMethodNFound( 
562                 osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
563
564         osrf_message* err = osrf_message_init( STATUS, omsg->thread_trace, 1);
565                 osrf_message_set_status_info( err, 
566                                 "osrfMethodException", "Router method not found", OSRF_STATUS_NOTFOUND );
567
568                 char* data =  osrf_message_serialize(err);
569
570                 transport_message* tresponse = message_init(
571                                 data, "", msg->thread, msg->sender, msg->recipient );
572
573                 client_send_message(router->connection, tresponse );
574
575                 free(data);
576                 osrf_message_free( err );
577                 message_free(tresponse);
578                 return 0;
579 }
580
581
582
583 int osrfRouterHandleAppResponse( osrfRouter* router, 
584         transport_message* msg, osrfMessage* omsg, jsonObject* response ) {
585
586         if( response ) { /* send the response message */
587
588                 osrfMessage* oresponse = osrf_message_init(
589                                 RESULT, omsg->thread_trace, omsg->protocol );
590         
591                 char* json = jsonObjectToJSON(response);
592                 osrf_message_set_result_content( oresponse, json);
593         
594                 char* data =  osrf_message_serialize(oresponse);
595                 osrfLogDebug( OSRF_LOG_MARK,  "Responding to client app request with data: \n%s\n", data );
596
597                 transport_message* tresponse = message_init(
598                                 data, "", msg->thread, msg->sender, msg->recipient );
599         
600                 client_send_message(router->connection, tresponse );
601
602                 osrfMessageFree(oresponse); 
603                 message_free(tresponse);
604                 free(json);
605                 free(data);
606         }
607
608
609         /* now send the 'request complete' message */
610         osrf_message* status = osrf_message_init( STATUS, omsg->thread_trace, 1);
611         osrf_message_set_status_info( status, "osrfConnectStatus", "Request Complete", OSRF_STATUS_COMPLETE );
612
613         char* statusdata = osrf_message_serialize(status);
614
615         transport_message* sresponse = message_init(
616                         statusdata, "", msg->thread, msg->sender, msg->recipient );
617         client_send_message(router->connection, sresponse );
618
619
620         free(statusdata);
621         osrfMessageFree(status);
622         message_free(sresponse);
623
624         return 0;
625 }
626
627
628
629