added pointer check to prevent the router from crashing when it receives the
[OpenSRF.git] / src / router / osrf_router.c
1 #include "osrf_router.h"
2
/* Member-access suffix naming the socket descriptor behind an object's
   transport connection: written as router->ROUTER_SOCKFD or
   class->ROUTER_SOCKFD, it expands to ...->connection->session->sock_id. */
3 #define ROUTER_SOCKFD connection->session->sock_id
/* msg->router_command values recognized by osrfRouterHandleMessage(). */
4 #define ROUTER_REGISTER "register"
5 #define ROUTER_UNREGISTER "unregister"
6
7
/* Router-level method name answered with the list of registered class names. */
8 #define ROUTER_REQUEST_CLASS_LIST "opensrf.router.info.class.list"
9
10 osrfRouter* osrfNewRouter( 
11                 char* domain, char* name, 
12                 char* resource, char* password, int port, 
13                 osrfStringArray* trustedClients, osrfStringArray* trustedServers ) {
14
15         if(!( domain && name && resource && password && port && trustedClients && trustedServers )) return NULL;
16
17         osrfRouter* router      = safe_malloc(sizeof(osrfRouter));
18         router->domain                  = strdup(domain);
19         router->name                    = strdup(name);
20         router->password                = strdup(password);
21         router->resource                = strdup(resource);
22         router->port                    = port;
23
24         router->trustedClients = trustedClients;
25         router->trustedServers = trustedServers;
26
27         
28         router->classes = osrfNewHash(); 
29         router->classes->freeItem = &osrfRouterClassFree;
30
31         router->connection = client_init( domain, port, NULL, 0 );
32
33         return router;
34 }
35
36
37
38 int osrfRouterConnect( osrfRouter* router ) {
39         if(!router) return -1;
40         int ret = client_connect( router->connection, router->name, 
41                         router->password, router->resource, 10, AUTH_DIGEST );
42         if( ret == 0 ) return -1;
43         return 0;
44 }
45
46
47 void osrfRouterRun( osrfRouter* router ) {
48         if(!(router && router->classes)) return;
49
50         int routerfd = router->ROUTER_SOCKFD;
51         int selectret = 0;
52
53         while(1) {
54
55                 fd_set set;
56                 int maxfd = __osrfRouterFillFDSet( router, &set );
57                 int numhandled = 0;
58
59                 if( (selectret = select(maxfd + 1, &set, NULL, NULL, NULL)) < 0 ) {
60                         osrfLogWarning( OSRF_LOG_MARK, "Top level select call failed with errno %d", errno);
61                         continue;
62                 }
63
64                 /* see if there is a top level router message */
65
66                 if( FD_ISSET(routerfd, &set) ) {
67                         osrfLogDebug( OSRF_LOG_MARK, "Top router socket is active: %d", routerfd );
68                         numhandled++;
69                         osrfRouterHandleIncoming( router );
70                 }
71
72
73                 /* now check each of the connected classes and see if they have data to route */
74                 while( numhandled < selectret ) {
75
76                         osrfRouterClass* class;
77                         osrfHashIterator* itr = osrfNewHashIterator(router->classes);
78
79                         while( (class = osrfHashIteratorNext(itr)) ) {
80
81                                 char* classname = itr->current;
82
83                                 if( classname && (class = osrfRouterFindClass( router, classname )) ) {
84
85                                         osrfLogDebug( OSRF_LOG_MARK, "Checking %s for activity...", classname );
86
87                                         int sockfd = class->ROUTER_SOCKFD;
88                                         if(FD_ISSET( sockfd, &set )) {
89                                                 osrfLogDebug( OSRF_LOG_MARK, "Socket is active: %d", sockfd );
90                                                 numhandled++;
91                                                 osrfRouterClassHandleIncoming( router, classname, class );
92                                         }
93                                 }
94                         }
95
96                         osrfHashIteratorFree(itr);
97                 }
98         }
99 }
100
101
102 void osrfRouterHandleIncoming( osrfRouter* router ) {
103         if(!router) return;
104
105         transport_message* msg = NULL;
106
107         if( (msg = client_recv( router->connection, 0 )) ) { 
108
109                 if( msg->sender ) {
110
111                         /* if the sender is not a trusted server, drop the message */
112                         int len = strlen(msg->sender) + 1;
113                         char domain[len];
114                         bzero(domain, len);
115                         jid_get_domain( msg->sender, domain, len - 1 );
116
117                         if(osrfStringArrayContains( router->trustedServers, domain)) 
118                                 osrfRouterHandleMessage( router, msg );
119                          else 
120                                 osrfLogWarning( OSRF_LOG_MARK, "Received message from un-trusted server domain %s", msg->sender);
121                 }
122
123                 message_free(msg);
124         }
125 }
126
127 int osrfRouterClassHandleIncoming( osrfRouter* router, char* classname, osrfRouterClass* class ) {
128         if(!(router && class)) return -1;
129
130         transport_message* msg;
131         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleIncoming()");
132
133         if( (msg = client_recv( class->connection, 0 )) ) {
134
135                 if( msg->sender ) {
136
137                         /* if the client is not from a trusted domain, drop the message */
138                         int len = strlen(msg->sender) + 1;
139                         char domain[len];
140                         bzero(domain, len);
141                         jid_get_domain( msg->sender, domain, len - 1 );
142
143                         if(osrfStringArrayContains( router->trustedClients, domain)) {
144
145                                 transport_message* bouncedMessage = NULL;
146                                 if( msg->is_error )  {
147
148                                         /* handle bounced message */
149                                         if( !(bouncedMessage = osrfRouterClassHandleBounce( router, classname, class, msg )) ) 
150                                                 return -1; /* we have no one to send the requested message to */
151
152                                         message_free( msg );
153                                         msg = bouncedMessage;
154                                 }
155                                 osrfRouterClassHandleMessage( router, class, msg );
156
157                         } else {
158                                 osrfLogWarning( OSRF_LOG_MARK, "Received client message from untrusted client domain %s", domain );
159                         }
160                 }
161
162                 message_free( msg );
163         }
164
165         return 0;
166 }
167
168
169
170
171 int osrfRouterHandleMessage( osrfRouter* router, transport_message* msg ) {
172         if(!(router && msg)) return -1;
173
174         if( !msg->router_command || !strcmp(msg->router_command,"")) 
175                 return osrfRouterHandleAppRequest( router, msg ); /* assume it's an app session level request */
176
177         if(!msg->router_class) return -1;
178
179         osrfRouterClass* class = NULL;
180         if(!strcmp(msg->router_command, ROUTER_REGISTER)) {
181                 class = osrfRouterFindClass( router, msg->router_class );
182
183                 osrfLogInfo( OSRF_LOG_MARK, "Registering class %s", msg->router_class );
184
185                 if(!class) class = osrfRouterAddClass( router, msg->router_class );
186
187                 if(class) { 
188
189                         if( osrfRouterClassFindNode( class, msg->sender ) )
190                                 return 0;
191                         else 
192                                 osrfRouterClassAddNode( class, msg->sender );
193
194                 } 
195
196         } else if( !strcmp( msg->router_command, ROUTER_UNREGISTER ) ) {
197
198                 if( msg->router_class && strcmp( msg->router_class, "") ) {
199                         osrfLogInfo( OSRF_LOG_MARK, "Unregistering router class %s", msg->router_class );
200                         osrfRouterClassRemoveNode( router, msg->router_class, msg->sender );
201                 }
202         }
203
204         return 0;
205 }
206
207
208
209 osrfRouterClass* osrfRouterAddClass( osrfRouter* router, char* classname ) {
210         if(!(router && router->classes && classname)) return NULL;
211
212         osrfRouterClass* class = safe_malloc(sizeof(osrfRouterClass));
213         class->nodes = osrfNewHash();
214         class->itr = osrfNewHashIterator(class->nodes);
215         class->nodes->freeItem = &osrfRouterNodeFree;
216         class->router   = router;
217
218         class->connection = client_init( router->domain, router->port, NULL, 0 );
219
220         if(!client_connect( class->connection, router->name, 
221                         router->password, classname, 10, AUTH_DIGEST ) ) {
222                 osrfRouterClassFree( classname, class );
223                 return NULL;
224         }
225         
226         osrfHashSet( router->classes, class, classname );
227         return class;
228 }
229
230
231 int osrfRouterClassAddNode( osrfRouterClass* rclass, char* remoteId ) {
232         if(!(rclass && rclass->nodes && remoteId)) return -1;
233
234         osrfLogInfo( OSRF_LOG_MARK, "Adding router node for remote id %s", remoteId );
235
236         osrfRouterNode* node = safe_malloc(sizeof(osrfRouterNode));
237         node->count = 0;
238         node->lastMessage = NULL;
239         node->remoteId = strdup(remoteId);
240
241         osrfHashSet( rclass->nodes, node, remoteId );
242         return 0;
243 }
244
245 /* copy off the lastMessage, remove the offending node, send error if it's tht last node 
246         ? return NULL if it's the last node ?
247  */
248
249 transport_message* osrfRouterClassHandleBounce( 
250                 osrfRouter* router, char* classname, osrfRouterClass* rclass, transport_message* msg ) {
251
252         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleBounce()");
253
254         osrfLogInfo( OSRF_LOG_MARK, "Received network layer error message from %s", msg->sender );
255         osrfRouterNode* node = osrfRouterClassFindNode( rclass, msg->sender );
256         transport_message* lastSent = NULL;
257
258         if( node && osrfHashGetCount(rclass->nodes) == 1 ) { /* the last node is dead */
259
260                 if( node->lastMessage ) {
261                         osrfLogWarning( OSRF_LOG_MARK, "We lost the last node in the class, responding with error and removing...");
262         
263                         transport_message* error = message_init( 
264                                 node->lastMessage->body, node->lastMessage->subject, 
265                                 node->lastMessage->thread, node->lastMessage->router_from, node->lastMessage->recipient );
266                         set_msg_error( error, "cancel", 501 );
267         
268                         /* send the error message back to the original sender */
269                         client_send_message( rclass->connection, error );
270                         message_free( error );
271                 }
272         
273                 return NULL;
274         
275         } else { 
276
277                 if( node ) {
278                         if( node->lastMessage ) {
279                                 osrfLogDebug( OSRF_LOG_MARK, "Cloning lastMessage so next node can send it");
280                                 lastSent = message_init( node->lastMessage->body,
281                                         node->lastMessage->subject, node->lastMessage->thread, "", node->lastMessage->router_from );
282                                 message_set_router_info( lastSent, node->lastMessage->router_from, NULL, NULL, NULL, 0 );
283                         }
284                 } else {
285
286                         osrfLogInfo(OSRF_LOG_MARK, "network error occurred after we removed the class.. ignoring");
287                         return NULL;
288                 }
289         }
290
291         /* remove the dead node */
292         osrfRouterClassRemoveNode( router, classname, msg->sender);
293         return lastSent;
294 }
295
296
297 /**
298   If we get a regular message, we send it to the next node in the list of nodes
299   if we get an error, it's a bounce back from a previous attempt.  We take the
300   body and thread from the last sent on the node that had the bounced message
301   and propogate them on to the new message being sent
302   */
303 int osrfRouterClassHandleMessage( 
304                 osrfRouter* router, osrfRouterClass* rclass, transport_message* msg ) {
305         if(!(router && rclass && msg)) return -1;
306
307         osrfLogDebug( OSRF_LOG_MARK, "osrfRouterClassHandleMessage()");
308
309         osrfRouterNode* node = osrfHashIteratorNext( rclass->itr );
310         if(!node) {
311                 osrfHashIteratorReset(rclass->itr);
312                 node = osrfHashIteratorNext( rclass->itr );
313         }
314
315         if(node) {
316
317                 transport_message* new_msg= message_init(       msg->body, 
318                                 msg->subject, msg->thread, node->remoteId, msg->sender );
319                 message_set_router_info( new_msg, msg->sender, NULL, NULL, NULL, 0 );
320
321                 osrfLogInfo( OSRF_LOG_MARK,  "Routing message:\nfrom: [%s]\nto: [%s]", 
322                                 new_msg->router_from, new_msg->recipient );
323
324                 message_free( node->lastMessage );
325                 node->lastMessage = new_msg;
326
327                 if ( client_send_message( rclass->connection, new_msg ) == 0 ) 
328                         node->count++;
329
330                 else {
331                         message_prepare_xml(new_msg);
332                         osrfLogWarning( OSRF_LOG_MARK, "Error sending message from %s to %s\n%s", 
333                                         new_msg->sender, new_msg->recipient, new_msg->msg_xml );
334                 }
335
336         } 
337
338         return 0;
339 }
340
341
342 int osrfRouterRemoveClass( osrfRouter* router, char* classname ) {
343         if(!(router && router->classes && classname)) return -1;
344         osrfLogInfo( OSRF_LOG_MARK, "Removing router class %s", classname );
345         osrfHashRemove( router->classes, classname );
346         return 0;
347 }
348
349
350 int osrfRouterClassRemoveNode( 
351                 osrfRouter* router, char* classname, char* remoteId ) {
352
353         if(!(router && router->classes && classname && remoteId)) return 0;
354
355         osrfLogInfo( OSRF_LOG_MARK, "Removing router node %s", remoteId );
356
357         osrfRouterClass* class = osrfRouterFindClass( router, classname );
358
359         if( class ) {
360
361                 osrfHashRemove( class->nodes, remoteId );
362                 if( osrfHashGetCount(class->nodes) == 0 ) {
363                         osrfRouterRemoveClass( router, classname );
364                         return 1;
365                 }
366
367                 return 0;
368         }
369
370         return -1;
371 }
372
373
374 void osrfRouterClassFree( char* classname, void* c ) {
375         if(!(classname && c)) return;
376         osrfRouterClass* rclass = (osrfRouterClass*) c;
377         client_disconnect( rclass->connection );        
378         client_free( rclass->connection );      
379
380         osrfHashIteratorReset( rclass->itr );
381         osrfRouterNode* node;
382
383         while( (node = osrfHashIteratorNext(rclass->itr)) ) 
384                 osrfRouterClassRemoveNode( rclass->router, classname, node->remoteId );
385
386         free(rclass);
387 }
388
389
390 void osrfRouterNodeFree( char* remoteId, void* n ) {
391         if(!n) return;
392         osrfRouterNode* node = (osrfRouterNode*) n;
393         free(node->remoteId);
394         message_free(node->lastMessage);
395         free(node);
396 }
397
398
399 void osrfRouterFree( osrfRouter* router ) {
400         if(!router) return;
401
402         free(router->domain);           
403         free(router->name);
404         free(router->resource);
405         free(router->password);
406
407         osrfStringArrayFree( router->trustedClients );
408         osrfStringArrayFree( router->trustedServers );
409
410         client_free( router->connection );
411         free(router);
412 }
413
414
415
416 osrfRouterClass* osrfRouterFindClass( osrfRouter* router, char* classname ) {
417         if(!( router && router->classes && classname )) return NULL;
418         return (osrfRouterClass*) osrfHashGet( router->classes, classname );
419 }
420
421
422 osrfRouterNode* osrfRouterClassFindNode( osrfRouterClass* rclass, char* remoteId ) {
423         if(!(rclass && remoteId))  return NULL;
424         return (osrfRouterNode*) osrfHashGet( rclass->nodes, remoteId );
425 }
426
427
428 int __osrfRouterFillFDSet( osrfRouter* router, fd_set* set ) {
429         if(!(router && router->classes && set)) return -1;
430
431         FD_ZERO(set);
432         int maxfd = router->ROUTER_SOCKFD;
433         FD_SET(maxfd, set);
434
435         int sockid;
436
437         osrfRouterClass* class = NULL;
438         osrfHashIterator* itr = osrfNewHashIterator(router->classes);
439
440         while( (class = osrfHashIteratorNext(itr)) ) {
441                 char* classname = itr->current;
442
443                 if( classname && (class = osrfRouterFindClass( router, classname )) ) {
444                         sockid = class->ROUTER_SOCKFD;
445         
446                         if( osrfUtilsCheckFileDescriptor( sockid ) ) {
447                                 osrfRouterRemoveClass( router, classname );
448         
449                         } else {
450                                 if( sockid > maxfd ) maxfd = sockid;
451                                 FD_SET(sockid, set);
452                         }
453                 }
454         }
455
456         osrfHashIteratorFree(itr);
457         return maxfd;
458 }
459
460
461
462 int osrfRouterHandleAppRequest( osrfRouter* router, transport_message* msg ) {
463
464         int T = 32;
465         osrfMessage* arr[T];
466         memset(arr, 0, T );
467
468         int num_msgs = osrf_message_deserialize( msg->body, arr, T );
469         osrfMessage* omsg = NULL;
470
471         int i;
472         for( i = 0; i != num_msgs; i++ ) {
473
474                 if( !(omsg = arr[i]) ) continue;
475
476                 switch( omsg->m_type ) {
477
478                         case CONNECT:
479                                 osrfRouterRespondConnect( router, msg, omsg );
480                                 break;
481
482                         case REQUEST:
483                                 osrfRouterProcessAppRequest( router, msg, omsg );
484                                 break;
485
486                         default: break;
487                 }
488
489                 osrfMessageFree( omsg );
490         }
491
492         return 0;
493 }
494
495 int osrfRouterRespondConnect( osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
496         if(!(router && msg && omsg)) return -1;
497
498         osrfMessage* success = osrf_message_init( STATUS, omsg->thread_trace, omsg->protocol );
499
500         osrfLogDebug( OSRF_LOG_MARK, "router recevied a CONNECT message from %s", msg->sender );
501
502         osrf_message_set_status_info( 
503                 success, "osrfConnectStatus", "Connection Successful", OSRF_STATUS_OK );
504
505         char* data      = osrf_message_serialize(success);
506
507         transport_message* return_m = message_init( 
508                 data, "", msg->thread, msg->sender, "" );
509
510         client_send_message(router->connection, return_m);
511
512         free(data);
513         osrf_message_free(success);
514         message_free(return_m);
515
516         return 0;
517 }
518
519
520
521 int osrfRouterProcessAppRequest( osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
522
523         if(!(router && msg && omsg && omsg->method_name)) return -1;
524
525         osrfLogInfo( OSRF_LOG_MARK, "Router received app request: %s", omsg->method_name );
526
527         jsonObject* jresponse = NULL;
528         if(!strcmp( omsg->method_name, ROUTER_REQUEST_CLASS_LIST )) {
529
530                 int i;
531                 jresponse = jsonParseString("[]");
532
533                 osrfStringArray* keys = osrfHashKeys( router->classes );
534                 for( i = 0; i != keys->size; i++ )
535                         jsonObjectPush( jresponse, jsonNewObject(osrfStringArrayGetString( keys, i )) );
536                 osrfStringArrayFree(keys);
537
538
539         } else {
540
541                 return osrfRouterHandleMethodNFound( router, msg, omsg );
542         }
543
544
545         osrfRouterHandleAppResponse( router, msg, omsg, jresponse );
546         jsonObjectFree(jresponse); 
547
548         return 0;
549
550 }
551
552
553
554 int osrfRouterHandleMethodNFound( 
555                 osrfRouter* router, transport_message* msg, osrfMessage* omsg ) {
556
557         osrf_message* err = osrf_message_init( STATUS, omsg->thread_trace, 1);
558                 osrf_message_set_status_info( err, 
559                                 "osrfMethodException", "Router method not found", OSRF_STATUS_NOTFOUND );
560
561                 char* data =  osrf_message_serialize(err);
562
563                 transport_message* tresponse = message_init(
564                                 data, "", msg->thread, msg->sender, msg->recipient );
565
566                 client_send_message(router->connection, tresponse );
567
568                 free(data);
569                 osrf_message_free( err );
570                 message_free(tresponse);
571                 return 0;
572 }
573
574
575
576 int osrfRouterHandleAppResponse( osrfRouter* router, 
577         transport_message* msg, osrfMessage* omsg, jsonObject* response ) {
578
579         if( response ) { /* send the response message */
580
581                 osrfMessage* oresponse = osrf_message_init(
582                                 RESULT, omsg->thread_trace, omsg->protocol );
583         
584                 char* json = jsonObjectToJSON(response);
585                 osrf_message_set_result_content( oresponse, json);
586         
587                 char* data =  osrf_message_serialize(oresponse);
588                 osrfLogDebug( OSRF_LOG_MARK,  "Responding to client app request with data: \n%s\n", data );
589
590                 transport_message* tresponse = message_init(
591                                 data, "", msg->thread, msg->sender, msg->recipient );
592         
593                 client_send_message(router->connection, tresponse );
594
595                 osrfMessageFree(oresponse); 
596                 message_free(tresponse);
597                 free(json);
598                 free(data);
599         }
600
601
602         /* now send the 'request complete' message */
603         osrf_message* status = osrf_message_init( STATUS, omsg->thread_trace, 1);
604         osrf_message_set_status_info( status, "osrfConnectStatus", "Request Complete", OSRF_STATUS_COMPLETE );
605
606         char* statusdata = osrf_message_serialize(status);
607
608         transport_message* sresponse = message_init(
609                         statusdata, "", msg->thread, msg->sender, msg->recipient );
610         client_send_message(router->connection, sresponse );
611
612
613         free(statusdata);
614         osrfMessageFree(status);
615         message_free(sresponse);
616
617         return 0;
618 }
619
620
621
622