Procházet zdrojové kódy

Improve SSL error handling and reformat code

bel před 9 roky
rodič
revize
202341e3cd
1 změnil soubory, kde provedl 94 přidání a 65 odebrání
  1. 94 65
      src/civetweb.c

+ 94 - 65
src/civetweb.c

@@ -309,7 +309,7 @@ typedef long off_t;
 #define close(x) (_close(x))
 #define dlsym(x, y) (GetProcAddress((HINSTANCE)(x), (y)))
 #define RTLD_LAZY (0)
-#define fseeko(x, y, z) (_lseeki64(_fileno(x), (y), (z)) == -1 ? -1 : 0)
+#define fseeko(x, y, z) ((_lseeki64(_fileno(x), (y), (z)) == -1) ? -1 : 0)
 #define fdopen(x, y) (_fdopen((x), (y)))
 #define write(x, y, z) (_write((x), (y), (unsigned)z))
 #define read(x, y, z) (_read((x), (y), (unsigned)z))
@@ -623,7 +623,7 @@ localtime_s(const time_t *ptime, struct tm *ptm)
 	ptm->tm_sec = st.wSecond;
 	ptm->tm_yday = 0; /* hope nobody uses this */
 	ptm->tm_isdst =
-	    GetTimeZoneInformation(&tzinfo) == TIME_ZONE_ID_DAYLIGHT ? 1 : 0;
+	    (GetTimeZoneInformation(&tzinfo) == TIME_ZONE_ID_DAYLIGHT) ? 1 : 0;
 
 	return ptm;
 }
@@ -1042,6 +1042,17 @@ typedef struct x509_store_ctx_st X509_STORE_CTX;
 #define SSL_OP_NO_TLSv1_1 (0x10000000L)
 #define SSL_OP_SINGLE_DH_USE (0x00100000L)
 
+#define SSL_ERROR_NONE (0)
+#define SSL_ERROR_SSL (1)
+#define SSL_ERROR_WANT_READ (2)
+#define SSL_ERROR_WANT_WRITE (3)
+#define SSL_ERROR_WANT_X509_LOOKUP (4)
+#define SSL_ERROR_SYSCALL (5) /* see errno */
+#define SSL_ERROR_ZERO_RETURN (6)
+#define SSL_ERROR_WANT_CONNECT (7)
+#define SSL_ERROR_WANT_ACCEPT (8)
+
+
 struct ssl_func {
 	const char *name;  /* SSL function name */
 	void (*ptr)(void); /* Function pointer */
@@ -2328,17 +2339,17 @@ match_prefix(const char *pattern, size_t pattern_len, const char *str)
 
 	if ((or_str = (const char *)memchr(pattern, '|', pattern_len)) != NULL) {
 		res = match_prefix(pattern, (size_t)(or_str - pattern), str);
-		return res > 0 ? res : match_prefix(or_str + 1,
-		                                    (size_t)((pattern + pattern_len)
-		                                             - (or_str + 1)),
-		                                    str);
+		return (res > 0) ? res : match_prefix(or_str + 1,
+		                                      (size_t)((pattern + pattern_len)
+		                                               - (or_str + 1)),
+		                                      str);
 	}
 
 	for (i = 0, j = 0; i < pattern_len; i++, j++) {
 		if (pattern[i] == '?' && str[j] != '\0') {
 			continue;
 		} else if (pattern[i] == '$') {
-			return str[j] == '\0' ? j : -1;
+			return (str[j] == '\0') ? j : -1;
 		} else if (pattern[i] == '*') {
 			i++;
 			if (pattern[i] == '*') {
@@ -2353,7 +2364,7 @@ match_prefix(const char *pattern, size_t pattern_len, const char *str)
 			do {
 				res = match_prefix(pattern + i, pattern_len - i, str + j + len);
 			} while (res == -1 && len-- > 0);
-			return res == -1 ? -1 : j + res + len;
+			return (res == -1) ? -1 : j + res + len;
 		} else if (lowercase(&pattern[i]) != lowercase(&str[j])) {
 			return -1;
 		}
@@ -2786,21 +2797,21 @@ pthread_mutex_init(pthread_mutex_t *mutex, void *unused)
 {
 	(void)unused;
 	*mutex = CreateMutex(NULL, FALSE, NULL);
-	return *mutex == NULL ? -1 : 0;
+	return (*mutex == NULL) ? -1 : 0;
 }
 
 
 static int
 pthread_mutex_destroy(pthread_mutex_t *mutex)
 {
-	return CloseHandle(*mutex) == 0 ? -1 : 0;
+	return (CloseHandle(*mutex) == 0) ? -1 : 0;
 }
 
 
 static int
 pthread_mutex_lock(pthread_mutex_t *mutex)
 {
-	return WaitForSingleObject(*mutex, INFINITE) == WAIT_OBJECT_0 ? 0 : -1;
+	return (WaitForSingleObject(*mutex, INFINITE) == WAIT_OBJECT_0) ? 0 : -1;
 }
 
 
@@ -2822,7 +2833,7 @@ pthread_mutex_trylock(pthread_mutex_t *mutex)
 static int
 pthread_mutex_unlock(pthread_mutex_t *mutex)
 {
-	return ReleaseMutex(*mutex) == 0 ? -1 : 0;
+	return (ReleaseMutex(*mutex) == 0) ? -1 : 0;
 }
 
 
@@ -3903,10 +3914,9 @@ push(struct mg_context *ctx,
 			n = SSL_write(ssl, buf, len);
 			if (n <= 0) {
 				err = SSL_get_error(ssl, n);
-				if ((err == 5 /* SSL_ERROR_SYSCALL */) && (n == -1)) {
+				if ((err == SSL_ERROR_SYSCALL) && (n == -1)) {
 					err = ERRNO;
-				} else if ((err == 2 /* SSL_ERROR_WANT_READ */)
-				           || (err == 3 /* SSL_ERROR_WANT_READ */)) {
+				} else if (err == SSL_ERROR_WANT_WRITE) {
 					n = 0;
 				} else {
 					DEBUG_TRACE("SSL_write() failed, error %d", err);
@@ -4045,10 +4055,9 @@ pull(FILE *fp, struct mg_connection *conn, char *buf, int len, double timeout)
 			nread = SSL_read(conn->ssl, buf, len);
 			if (nread <= 0) {
 				err = SSL_get_error(conn->ssl, nread);
-				if ((err == 5 /* SSL_ERROR_SYSCALL */) && (nread == -1)) {
+				if ((err == SSL_ERROR_SYSCALL) && (nread == -1)) {
 					err = ERRNO;
-				} else if ((err == 2 /* SSL_ERROR_WANT_READ */)
-				           || (err == 3 /* SSL_ERROR_WANT_READ */)) {
+				} else if (err == SSL_ERROR_WANT_READ) {
 					nread = 0;
 				} else {
 					DEBUG_TRACE("SSL_read() failed, error %d", err);
@@ -4197,7 +4206,7 @@ mg_read_inner(struct mg_connection *conn, void *buf, size_t len)
 {
 	int64_t n, buffered_len, nread;
 	int64_t len64 =
-	    (int64_t)(len > INT_MAX ? INT_MAX : len); /* since the return value is
+	    (int64_t)((len > INT_MAX) ? INT_MAX : len); /* since the return value is
 	                                               * int, we may not read more
 	                                               * bytes */
 	const char *body;
@@ -4244,7 +4253,7 @@ mg_read_inner(struct mg_connection *conn, void *buf, size_t len)
 		if ((n = pull_all(NULL, conn, (char *)buf, (int)len64)) >= 0) {
 			nread += n;
 		} else {
-			nread = (nread > 0 ? nread : n);
+			nread = ((nread > 0) ? nread : n);
 		}
 	}
 	return (int)nread;
@@ -4383,7 +4392,7 @@ mg_write(struct mg_connection *conn, const void *buf, size_t len)
 			buf = (const char *)buf + total;
 			conn->last_throttle_bytes += total;
 			while (total < (int64_t)len && conn->ctx->stop_flag == 0) {
-				allowed = conn->throttle > (int64_t)len - total
+				allowed = (conn->throttle > ((int64_t)len - total))
 				              ? (int64_t)len - total
 				              : conn->throttle;
 				if ((n = push_all(conn->ctx,
@@ -4542,9 +4551,9 @@ mg_url_decode(const char *src,
               int is_form_url_encoded)
 {
 	int i, j, a, b;
-#define HEXTOI(x) (isdigit(x) ? x - '0' : x - 'W')
+#define HEXTOI(x) (isdigit(x) ? (x - '0') : (x - 'W'))
 
-	for (i = j = 0; i < src_len && j < dst_len - 1; i++, j++) {
+	for (i = j = 0; (i < src_len) && (j < (dst_len - 1)); i++, j++) {
 		if (i < src_len - 2 && src[i] == '%'
 		    && isxdigit(*(const unsigned char *)(src + i + 1))
 		    && isxdigit(*(const unsigned char *)(src + i + 2))) {
@@ -4561,7 +4570,7 @@ mg_url_decode(const char *src,
 
 	dst[j] = '\0'; /* Null-terminate the destination */
 
-	return i >= src_len ? j : -1;
+	return (i >= src_len) ? j : -1;
 }
 
 
@@ -4689,8 +4698,8 @@ base64_encode(const unsigned char *src, int src_len, char *dst)
 
 	for (i = j = 0; i < src_len; i += 3) {
 		a = src[i];
-		b = i + 1 >= src_len ? 0 : src[i + 1];
-		c = i + 2 >= src_len ? 0 : src[i + 2];
+		b = ((i + 1) >= src_len) ? 0 : src[i + 1];
+		c = ((i + 2) >= src_len) ? 0 : src[i + 2];
 
 		dst[j++] = b64[a >> 2];
 		dst[j++] = b64[((a & 3) << 4) | (b >> 4)];
@@ -4749,17 +4758,17 @@ base64_decode(const unsigned char *src, int src_len, char *dst, size_t *dst_len)
 			return i;
 		}
 
-		b = b64reverse(i + 1 >= src_len ? 0 : src[i + 1]);
+		b = b64reverse(((i + 1) >= src_len) ? 0 : src[i + 1]);
 		if (b >= 254) {
 			return i + 1;
 		}
 
-		c = b64reverse(i + 2 >= src_len ? 0 : src[i + 2]);
+		c = b64reverse(((i + 2) >= src_len) ? 0 : src[i + 2]);
 		if (c == 254) {
 			return i + 2;
 		}
 
-		d = b64reverse(i + 3 >= src_len ? 0 : src[i + 3]);
+		d = b64reverse(((i + 3) >= src_len) ? 0 : src[i + 3]);
 		if (d == 254) {
 			return i + 3;
 		}
@@ -5558,7 +5567,8 @@ mg_fgets(char *buf, size_t size, struct file *filep, char **p)
 		} else {
 			eof = memend; /* Copy remaining data */
 		}
-		len = (size_t)(eof - *p) > size - 1 ? size - 1 : (size_t)(eof - *p);
+		len =
+		    ((size_t)(eof - *p) > (size - 1)) ? (size - 1) : (size_t)(eof - *p);
 		memcpy(buf, *p, len);
 		buf[len] = '\0';
 		*p += len;
@@ -6025,7 +6035,7 @@ connect_socket(struct mg_context *ctx /* may be NULL */,
 		/* While getaddrinfo on Windows will work with [::1],
 		 * getaddrinfo on Linux only works with ::1 (without []). */
 		size_t l = strlen(host + 1);
-		char *h = l > 1 ? mg_strdup(host + 1) : NULL;
+		char *h = (l > 1) ? mg_strdup(host + 1) : NULL;
 		if (h) {
 			h[l - 1] = 0;
 			if (mg_inet_pton(AF_INET6, h, &sa->sin6, sizeof(sa->sin6))) {
@@ -6221,9 +6231,9 @@ compare_dir_entries(const void *p1, const void *p2)
 		} else if (*query_string == 'n') {
 			cmp_result = strcmp(a->file_name, b->file_name);
 		} else if (*query_string == 's') {
-			cmp_result = a->file.size == b->file.size
+			cmp_result = (a->file.size == b->file.size)
 			                 ? 0
-			                 : a->file.size > b->file.size ? 1 : -1;
+			                 : ((a->file.size > b->file.size) ? 1 : -1);
 		} else if (*query_string == 'd') {
 			cmp_result =
 			    (a->file.last_modified == b->file.last_modified)
@@ -6232,7 +6242,7 @@ compare_dir_entries(const void *p1, const void *p2)
 			                                                           : -1);
 		}
 
-		return query_string[1] == 'd' ? -cmp_result : cmp_result;
+		return (query_string[1] == 'd') ? -cmp_result : cmp_result;
 	}
 	return 0;
 }
@@ -6446,8 +6456,8 @@ handle_directory_request(struct mg_connection *conn, const char *dir)
 		return;
 	}
 
-	sort_direction = conn->request_info.query_string != NULL
-	                         && conn->request_info.query_string[1] == 'd'
+	sort_direction = ((conn->request_info.query_string != NULL)
+	                  && (conn->request_info.query_string[1] == 'd'))
 	                     ? 'a'
 	                     : 'd';
 
@@ -6520,8 +6530,8 @@ send_file_data(struct mg_connection *conn,
 	}
 
 	/* Sanity check the offset */
-	size = filep->size > INT64_MAX ? INT64_MAX : (int64_t)(filep->size);
-	offset = offset < 0 ? 0 : offset > size ? size : offset;
+	size = (filep->size > INT64_MAX) ? INT64_MAX : (int64_t)(filep->size);
+	offset = (offset < 0) ? 0 : ((offset > size) ? size : offset);
 
 	if (len > 0 && filep->membuf != NULL && size > 0) {
 		/* file stored in memory */
@@ -6733,7 +6743,7 @@ handle_static_file_request(struct mg_connection *conn,
 			return;
 		}
 		conn->status_code = 206;
-		cl = n == 2 ? (r2 > cl ? cl : r2) - r1 + 1 : cl - r1;
+		cl = (n == 2) ? (((r2 > cl) ? cl : r2) - r1 + 1) : (cl - r1);
 		mg_snprintf(conn,
 		            NULL, /* range buffer is big enough */
 		            range,
@@ -7173,7 +7183,7 @@ read_request(FILE *fp,
 		}
 	}
 
-	return (request_len <= 0 && n <= 0) ? -1 : request_len;
+	return ((request_len <= 0) && (n <= 0)) ? -1 : request_len;
 }
 
 #if !defined(NO_FILES)
@@ -7510,7 +7520,7 @@ prepare_cgi_environment(struct mg_connection *conn,
 		       conn->path_info);
 	}
 
-	addenv(env, "HTTPS=%s", conn->ssl == NULL ? "off" : "on");
+	addenv(env, "HTTPS=%s", (conn->ssl == NULL) ? "off" : "on");
 
 	if ((s = mg_get_header(conn, "Content-Type")) != NULL) {
 		addenv(env, "CONTENT_TYPE=%s", s);
@@ -8812,7 +8822,7 @@ SHA1Final(unsigned char digest[20], SHA1_CTX *context)
 	unsigned char finalcount[8], c;
 
 	for (i = 0; i < 8; i++) {
-		finalcount[i] = (unsigned char)((context->count[(i >= 4 ? 0 : 1)]
+		finalcount[i] = (unsigned char)((context->count[(i >= 4) ? 0 : 1]
 		                                 >> ((3 - (i & 3)) * 8)) & 255);
 	}
 	c = 0200;
@@ -8941,14 +8951,14 @@ read_websocket(struct mg_connection *conn,
 		assert(conn->data_len >= conn->request_len);
 		if ((body_len = (size_t)(conn->data_len - conn->request_len)) >= 2) {
 			len = buf[1] & 127;
-			mask_len = buf[1] & 128 ? 4 : 0;
-			if (len < 126 && body_len >= mask_len) {
+			mask_len = (buf[1] & 128) ? 4 : 0;
+			if ((len < 126) && (body_len >= mask_len)) {
 				data_len = len;
 				header_len = 2 + mask_len;
-			} else if (len == 126 && body_len >= 4 + mask_len) {
+			} else if ((len == 126) && (body_len >= (4 + mask_len))) {
 				header_len = 4 + mask_len;
 				data_len = ((((size_t)buf[2]) << 8) + buf[3]);
-			} else if (body_len >= 10 + mask_len) {
+			} else if (body_len >= (10 + mask_len)) {
 				header_len = 10 + mask_len;
 				data_len = (((uint64_t)ntohl(*(uint32_t *)(void *)&buf[2]))
 				            << 32) + ntohl(*(uint32_t *)(void *)&buf[6]);
@@ -9380,7 +9390,7 @@ parse_net(const char *spec, uint32_t *net, uint32_t *mask)
 		len = n;
 		*net = ((uint32_t)a << 24) | ((uint32_t)b << 16) | ((uint32_t)c << 8)
 		       | (uint32_t)d;
-		*mask = slash ? 0xffffffffU << (32 - slash) : 0;
+		*mask = slash ? (0xffffffffU << (32 - slash)) : 0;
 	}
 
 	return len;
@@ -9403,8 +9413,9 @@ set_throttle(const char *spec, uint32_t remote_ip, const char *uri)
 		        && mult != ',')) {
 			continue;
 		}
-		v *= lowercase(&mult) == 'k' ? 1024 : lowercase(&mult) == 'm' ? 1048576
-		                                                              : 1;
+		v *= (lowercase(&mult) == 'k')
+		         ? 1024
+		         : ((lowercase(&mult) == 'm') ? 1048576 : 1);
 		if (vec.len == 1 && vec.ptr[0] == '*') {
 			throttle = (int)v;
 		} else if (parse_net(vec.ptr, &net, &mask) > 0) {
@@ -10774,7 +10785,7 @@ log_access(const struct mg_connection *conn)
 	            sizeof(buf),
 	            "%s - %s [%s] \"%s %s%s%s HTTP/%s\" %d %" INT64_FMT " %s %s",
 	            src_addr,
-	            ri->remote_user == NULL ? "-" : ri->remote_user,
+	            (ri->remote_user == NULL) ? "-" : ri->remote_user,
 	            date,
 	            ri->request_method ? ri->request_method : "-",
 	            ri->request_uri ? ri->request_uri : "-",
@@ -10814,7 +10825,7 @@ check_acl(struct mg_context *ctx, uint32_t remote_ip)
 		const char *list = ctx->config[ACCESS_CONTROL_LIST];
 
 		/* If any ACL is set, deny by default */
-		allowed = list == NULL ? '+' : '-';
+		allowed = (list == NULL) ? '+' : '-';
 
 		while ((list = next_option(list, &vec, NULL)) != NULL) {
 			flag = vec.ptr[0];
@@ -11021,7 +11032,7 @@ static pthread_mutex_t *ssl_mutexes;
 static int
 sslize(struct mg_connection *conn, SSL_CTX *s, int (*func)(SSL *))
 {
-	int ret, err;
+	int ret, err, i;
 	int short_trust;
 
 	if (!conn) {
@@ -11050,21 +11061,40 @@ sslize(struct mg_connection *conn, SSL_CTX *s, int (*func)(SSL *))
 		(void)err; /* TODO: set some error message */
 		SSL_free(conn->ssl);
 		conn->ssl = NULL;
-		/* maybe not? CRYPTO_cleanup_all_ex_data(); */
-		/* see
+		/* Avoid CRYPTO_cleanup_all_ex_data(); See discussion:
 		 * https://wiki.openssl.org/index.php/Talk:Library_Initialization */
 		ERR_remove_state(0);
 		return 0;
 	}
 
-	ret = func(conn->ssl);
+	/* SSL functions may fail and require to be called again:
+	 * see https://www.openssl.org/docs/manmaster/ssl/SSL_get_error.html
+	 * Here "func" could be SSL_connect or SSL_accept. */
+	for (i = 0; i <= 16; i *= 2) {
+		ret = func(conn->ssl);
+		if (ret != 1) {
+			err = SSL_get_error(conn->ssl, ret);
+			if ((err == SSL_ERROR_WANT_CONNECT)
+			    || (err == SSL_ERROR_WANT_ACCEPT)) {
+				/* Retry */
+				mg_sleep(i);
+
+			} else {
+				/* This is an error */
+				/* TODO: set some error message */
+				break;
+			}
+
+		} else {
+			/* success */
+			break;
+		}
+	}
+
 	if (ret != 1) {
-		err = SSL_get_error(conn->ssl, ret);
-		(void)err; /* TODO: set some error message */
 		SSL_free(conn->ssl);
 		conn->ssl = NULL;
-		/* maybe not? CRYPTO_cleanup_all_ex_data(); */
-		/* see
+		/* Avoid CRYPTO_cleanup_all_ex_data(); See discussion:
 		 * https://wiki.openssl.org/index.php/Talk:Library_Initialization */
 		ERR_remove_state(0);
 		return 0;
@@ -11080,7 +11110,7 @@ ssl_error(void)
 {
 	unsigned long err;
 	err = ERR_get_error();
-	return err == 0 ? "" : ERR_error_string(err, NULL);
+	return ((err == 0) ? "" : ERR_error_string(err, NULL));
 }
 
 
@@ -11630,8 +11660,7 @@ close_connection(struct mg_connection *conn)
 		 */
 		SSL_shutdown(conn->ssl);
 		SSL_free(conn->ssl);
-		/* maybe not? CRYPTO_cleanup_all_ex_data(); */
-		/* see
+		/* Avoid CRYPTO_cleanup_all_ex_data(); See discussion:
 		 * https://wiki.openssl.org/index.php/Talk:Library_Initialization */
 		ERR_remove_state(0);
 		conn->ssl = NULL;
@@ -12474,9 +12503,9 @@ process_new_connection(struct mg_connection *conn)
 			             && conn->content_len >= 0 && should_keep_alive(conn);
 
 			/* Discard all buffered data for this request */
-			discard_len = conn->content_len >= 0 && conn->request_len > 0
-			                      && conn->request_len + conn->content_len
-			                             < (int64_t)conn->data_len
+			discard_len = ((conn->content_len >= 0) && (conn->request_len > 0)
+			               && ((conn->request_len + conn->content_len)
+			                   < (int64_t)conn->data_len))
 			                  ? (int)(conn->request_len + conn->content_len)
 			                  : conn->data_len;
 			/*assert(discard_len >= 0);*/