Support CAST expressions, taking care to avoid SQL injection.
[working/Evergreen.git] / Open-ILS / src / c-apps / oils_storedq.c
index 8aab5ee..c4f512a 100644 (file)
@@ -5,6 +5,7 @@
 
 #include <stdlib.h>
 #include <string.h>
+#include <ctype.h>
 #include <dbi/dbi.h>
 #include "opensrf/utils.h"
 #include "opensrf/log.h"
@@ -45,6 +46,10 @@ static CaseBranch* getCaseBranchList( BuildSQLState* state, int parent_id );
 static CaseBranch* constructCaseBranch( BuildSQLState* state, dbi_result result );
 static void freeBranchList( CaseBranch* branch );
 
+static Datatype* getDatatype( BuildSQLState* state, int id );
+static Datatype* constructDatatype( BuildSQLState* state, dbi_result result );
+static void datatypeFree( Datatype* datatype );
+
 static Expression* getExpression( BuildSQLState* state, int id );
 static Expression* constructExpression( BuildSQLState* state, dbi_result result );
 static void expressionListFree( Expression* exp );
@@ -65,6 +70,7 @@ static FromRelation* free_from_relation_list = NULL;
 static SelectItem* free_select_item_list = NULL;
 static BindVar* free_bindvar_list = NULL;
 static CaseBranch* free_branch_list = NULL;
+static Datatype* free_datatype_list = NULL;
 static Expression* free_expression_list = NULL;
 static IdNode* free_id_node_list = NULL;
 static QSeq* free_qseq_list = NULL;
@@ -189,7 +195,9 @@ StoredQ* getStoredQuery( BuildSQLState* state, int query_id ) {
                                osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
                                        "Unable to build a query for id = %d", query_id ));
                } else {
-                       sqlAddMsg( state, "Stored query not found for id %d", query_id );
+                       osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "Stored query not found for id %d", query_id ));
+                       state->error = 1;
                }
 
                dbi_result_free( result );
@@ -585,6 +593,7 @@ static FromRelation* getFromRelation( BuildSQLState* state, int id ) {
                } else {
                        osrfLogWarning( OSRF_LOG_MARK, sqlAddMsg( state,
                                "FROM relation not found for id = %d", id ));
+                       state->error = 1;
                }
                dbi_result_free( result );
        } else {
@@ -910,6 +919,10 @@ static SelectItem* getSelectList( BuildSQLState* state, int query_id ) {
                                if( !dbi_result_next_row( result ) )
                                        break;
                        };
+               } else {
+                       osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "No SELECT list found for query # %d", query_id ));
+                       state->error = 1;
                }
        } else {
                const char* msg;
@@ -1040,6 +1053,7 @@ static BindVar* getBindVar( BuildSQLState* state, const char* name ) {
                } else {
                        osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
                                "No bind variable found with name \"%s\"", name ));
+                       state->error = 1;
                }
        } else {
                const char* msg;
@@ -1209,6 +1223,10 @@ static CaseBranch* getCaseBranchList( BuildSQLState* state, int parent_id ) {
                                if( !dbi_result_next_row( result ) )
                                        break;
                        };  // end while
+               } else {
+                       osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "No branches found for CASE expression %d", parent_id ));
+                       state->error = 1;
                }
 
                // Make sure that at least one branch includes a condition
@@ -1318,6 +1336,116 @@ static void freeBranchList( CaseBranch* branch ) {
 }
 
 /**
+       @brief Given an id for a row in query.datatype, build an Datatype struct.
+       @param Pointer to the query-building context.
+       @param id ID of a row in query.datatype.
+       @return Pointer to a newly-created Datatype if successful, or NULL if not.
+*/
+static Datatype* getDatatype( BuildSQLState* state, int id ) {
+       Datatype* datatype = NULL;
+       dbi_result result = dbi_conn_queryf( state->dbhandle,
+               "SELECT id, datatype_name, is_numeric, is_composite "
+               "FROM query.datatype WHERE id = %d", id );
+       if( result ) {
+               if( dbi_result_first_row( result ) ) {
+                       datatype = constructDatatype( state, result );
+                       if( datatype ) {
+                               PRINT( "Got a datatype\n" );
+                               PRINT( "\tid = %d\n", id );
+                               PRINT( "\tdatatype = %s\n", datatype->datatype_name );
+                       } else
+                               osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                                       "Unable to construct a Datatype for id = %d", id ));
+               } else {
+                       osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "No datatype found for id = %d", id ));
+                       state->error = 1;
+               }
+       } else {
+               const char* msg;
+               int errnum = dbi_conn_error( state->dbhandle, &msg );
+               osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                       "Unable to query query.datatype table: #%d %s",
+                       errnum, msg ? msg : "No description available" ));
+               state->error = 1;
+       }
+       return datatype;
+}
+
+/**
+       @brief Construct a Datatype.
+       @param state Pointer to the query-building context.
+       @param result Database cursor positioned at a row in query.datatype.
+       @return Pointer to a newly constructed Datatype, if successful, or NULL if not.
+
+       The calling code is responsible for freeing the Datatype by calling datatypeFree().
+*/
+static Datatype* constructDatatype( BuildSQLState* state, dbi_result result ) {
+       int id           = dbi_result_get_int_idx( result, 1 );
+       const char* datatype_name = dbi_result_get_string_idx( result, 2 );
+       int is_numeric   = oils_result_get_bool_idx( result, 3 );
+       int is_composite = oils_result_get_bool_idx( result, 4 );
+
+       if( !datatype_name || !*datatype_name ) {
+               osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                       "No datatype name provided for datatype # %d", id ));
+               state->error = 1;
+               return NULL;
+       }
+
+       // Make sure that the datatype name is composed entirely of certain approved
+       // characters.  This check is not an attempt to validate the datatype name, but
+       // only to prevent certain types of SQL injection.
+       const char* p = datatype_name;
+       while( *p ) {
+               unsigned char c = *p;
+               if( isalnum( c )
+                       || isspace( c )
+                       || ',' == c
+                       || '(' == c
+                       || ')' == c
+                       || '[' == c
+                       || ']' == c
+                       || '.' == c
+               )
+                       ++p;
+               else {
+                       osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "Invalid datatype name \"%s\" for datatype # %d; "
+                               "contains unexpected character \"%c\"", datatype_name, id, (char) c ));
+                       state->error = 1;
+                       return NULL;
+               }
+       }
+
+       // Allocate a Datatype: from the free list if possible, from the heap if necessary
+       Datatype* datatype = free_datatype_list;
+       if( datatype )
+               free_datatype_list = datatype->next;
+       else
+               datatype = safe_malloc( sizeof( Datatype ) );
+
+       datatype->next          = NULL;   // a live node must not keep a link into the free list
+       datatype->id            = id;
+       datatype->datatype_name = strdup( datatype_name );
+       datatype->is_numeric    = is_numeric;
+       datatype->is_composite  = is_composite;
+
+       return datatype;
+}
+
+/**
+       @brief Release a Datatype: free its name and put the node on the free list for reuse.
+       @param datatype Pointer to the Datatype to be released; a NULL pointer is ignored.
+*/
+static void datatypeFree( Datatype* datatype ) {
+       if( !datatype ) return;
+       free( datatype->datatype_name );
+       datatype->next = free_datatype_list;   // storedQCleanup() frees the listed nodes
+       free_datatype_list = datatype;
+}
+
+/**
        @brief Given an id for a row in query.expression, build an Expression struct.
        @param Pointer to the query-building context.
        @param id ID of a row in query.expression.
@@ -1357,6 +1485,10 @@ static Expression* getExpression( BuildSQLState* state, int id ) {
                        } else
                                osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
                                        "Unable to construct an Expression for id = %d", id ));
+               } else {
+                       osrfLogError( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "No expression found for id = %d", id ));
+                       state->error = 1;
                }
        } else {
                const char* msg;
@@ -1468,6 +1600,7 @@ static Expression* constructExpression( BuildSQLState* state, dbi_result result
        Expression* left_operand = NULL;
        Expression* right_operand = NULL;
        StoredQ* subquery = NULL;
+       Datatype* cast_type = NULL;
        BindVar* bind = NULL;
        CaseBranch* branch_list = NULL;
        Expression* subexp_list = NULL;
@@ -1573,6 +1706,38 @@ static Expression* constructExpression( BuildSQLState* state, dbi_result result
                        return NULL;
                }
 
+       } else if( EXP_CAST == type ) {
+               // Get the left operand
+               if( -1 == left_operand_id ) {
+                       osrfLogWarning( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "No left operand defined for CAST expression # %d", id ));
+                       state->error = 1;
+                       return NULL;
+               } else {
+                       left_operand = getExpression( state, left_operand_id );
+                       if( !left_operand ) {
+                               osrfLogWarning( OSRF_LOG_MARK, sqlAddMsg( state,
+                                       "Unable to get left operand for CAST expression # %d", id ));
+                               state->error = 1;
+                               return NULL;
+                       }
+               }
+
+               if( -1 == cast_type_id ) {
+                       osrfLogWarning( OSRF_LOG_MARK, sqlAddMsg( state,
+                               "No datatype specified for CAST expression # %d", id ));
+                       state->error = 1;
+                       return NULL;
+               } else {
+                       cast_type = getDatatype( state, cast_type_id );
+                       if( !cast_type ) {
+                               osrfLogWarning( OSRF_LOG_MARK, sqlAddMsg( state,
+                                       "Unable to get datatype for CAST expression # %d", id ));
+                               state->error = 1;
+                               return NULL;
+                       }
+               }
+
        } else if( EXP_EXIST == type ) {
                if( -1 == subquery_id ) {
                        osrfLogWarning( OSRF_LOG_MARK, sqlAddMsg( state,
@@ -1769,7 +1934,7 @@ static Expression* constructExpression( BuildSQLState* state, dbi_result result
        exp->right_operand = right_operand;
        exp->subquery_id = subquery_id;
        exp->subquery = subquery;
-       exp->cast_type_id = subquery_id;
+       exp->cast_type = cast_type;
        exp->negate = negate;
        exp->bind = bind;
        exp->branch_list = branch_list;
@@ -1822,6 +1987,10 @@ static void expressionFree( Expression* exp ) {
                        storedQFree( exp->subquery );
                        exp->subquery = NULL;
                }
+               if( exp->cast_type ) {
+                       datatypeFree( exp->cast_type );
+                       exp->cast_type = NULL;
+               }
 
                // We don't free the bind member here because the Expression doesn't own it;
                // the bindvar_list hash owns it, so that multiple Expressions can reference it.
@@ -2167,6 +2336,14 @@ void storedQCleanup( void ) {
                free( branch );
                branch = free_branch_list;
        }
+
+       // Free all the nodes in the datatype free list
+       Datatype* datatype = free_datatype_list;
+       while( datatype ) {
+               free_datatype_list = datatype->next;
+               free( datatype );
+               datatype = free_datatype_list;
+       }
 }
 
 /**