Require CORS Origin header to use https:// and match the entire hostname.

Also require the port number to match if specified in the accepted origins
list.
This commit is contained in:
Wesley Miaw
2020-03-27 15:45:23 -07:00
parent e7ccaec8ae
commit df63f0e6af
4 changed files with 80 additions and 49 deletions

View File

@@ -491,69 +491,100 @@ static void handle_dial_data(struct mg_connection *conn,
ds_unlock(ds);
}
static int ends_with(const char *str, const char *suffix) {
if (!str || !suffix)
static int host_matches(const char *origin_host, const char *host) {
if (!origin_host || !host)
return 0;
size_t lenstr = strlen(str);
size_t lensuffix = strlen(suffix);
if (lensuffix > lenstr)
return 0;
return strncmp(str + lenstr - lensuffix, suffix, lensuffix) == 0;
// If the host contains a port number (indicated by a colon) require an
// exact match.
const char * origin_colon = strrchr(origin_host, ':');
const char * host_colon = strrchr(host, ':');
if (host_colon != NULL) {
// If the host port number is 443, then accept the origin host if it
// does not have any port number under the assumption we already
// verified https://
if (strlen(host_colon) == 4 &&
strncmp(host_colon, ":443", 4) == 0 &&
origin_colon == NULL)
{
// Strip off the host port number and require an exact match.
size_t host_len = host_colon - host;
if (host_len == 0)
return 0;
return strncmp(origin_host, host, host_len) == 0;
}
// Other port numbers must match exactly.
if (strlen(origin_host) != strlen(host))
return 0;
return strncmp(origin_host, host, strlen(host)) == 0;
}
// Otherwise strip off any port number in the origin host, and require an
// exact match.
size_t origin_len = (origin_colon == NULL) ? strlen(origin_host) : origin_colon - origin_host;
if (origin_len == 0)
return 0;
if (origin_len != strlen(host))
return 0;
return strncmp(origin_host, host, origin_len) == 0;
}
// str contains a white space separated list of strings (only supports SPACE characters for now)
static int ends_with_in_list (const char *str, const char *list) {
if (!str || !list)
static int is_hostname_in_list(const char *origin_host, const char *list) {
if (!origin_host || !list)
return 0;
const char * scanPointer=list;
const char * scanPointer = list;
const char * spacePointer;
unsigned int substringSize = 257;
char *substring = (char *)malloc(substringSize);
if (!substring){
char *host = (char *)malloc(substringSize);
if (!host) {
return 0;
}
while ( (spacePointer =strchr(scanPointer, ' ')) != NULL) {
while ((spacePointer = strchr(scanPointer, ' ')) != NULL) {
int copyLength = spacePointer - scanPointer;
// protect against buffer overflow
if (copyLength>=substringSize){
substringSize=copyLength+1;
free(substring);
substring=(char *)malloc(substringSize);
if (!substring){
return 0;
}
}
// protect against buffer overflow
if (copyLength >= substringSize){
substringSize = copyLength + 1;
free(host);
host = (char *)malloc(substringSize);
if (!host) {
return 0;
}
}
memcpy(substring, scanPointer, copyLength);
substring[copyLength] = '\0';
//printf("found %s \n", substring);
if (ends_with(str, substring)) {
free(substring); substring = NULL;
memcpy(host, scanPointer, copyLength);
host[copyLength] = '\0';
//printf("found %s \n", host);
if (host_matches(origin_host, host)) {
free(host); host = NULL;
return 1;
}
scanPointer = scanPointer + copyLength + 1; // assumption: only 1 character
}
free(substring); substring = NULL;
return ends_with(str, scanPointer);
}
static int should_check_for_origin( char * origin ) {
const char * const CHECK_PROTOS[] = { "http:", "https:", "file:" };
for (int i = 0; i < 3; ++i) {
if (!strncmp(origin, CHECK_PROTOS[i], strlen(CHECK_PROTOS[i]) - 1)) {
return 1;
}
}
return 0;
free(host); host = NULL;
return host_matches(origin_host, scanPointer);
}
static int is_allowed_origin(DIALServer* ds, char * origin, const char * app_name) {
if (!origin || strlen(origin)==0 || !should_check_for_origin(origin)) {
const char * const HTTPS_PROTO = "https://";
fprintf(stderr, "checking %s for %s\n", origin, app_name);
if (!origin || strlen(origin)==0) {
return 1;
}
// Make sure the origin begins with HTTPS.
if (strncmp(origin, HTTPS_PROTO, strlen(HTTPS_PROTO)) != 0) {
return 0;
}
// For the rest of the check, we only care about the hostname and optional
// port number.
const char *origin_host = origin + strlen(HTTPS_PROTO);
if (!ds_lock(ds)) {
// If we can't check, fail in favor of safety.
@@ -564,7 +595,7 @@ static int is_allowed_origin(DIALServer* ds, char * origin, const char * app_nam
for (app = ds->apps; app != NULL; app = app->next) {
if (!strcmp(app->name, app_name)) {
if (!app->corsAllowedOrigin[0] ||
ends_with_in_list(origin, app->corsAllowedOrigin)) {
is_hostname_in_list(origin_host, app->corsAllowedOrigin)) {
result = 1;
break;
}