Quellcode durchsuchen

Possible fix for #336 (more detailed tests required)

bel vor 8 Jahren
Ursprung
Commit
993c5ae783
2 geänderte Dateien mit 75 neuen und 38 gelöschten Zeilen
  1. 63 27
      src/civetweb.c
  2. 12 11
      test/public_server.c

+ 63 - 27
src/civetweb.c

@@ -4328,29 +4328,60 @@ pull(FILE *fp, struct mg_connection *conn, char *buf, int len, double timeout)
 
 #ifndef NO_SSL
 		} else if (conn->ssl != NULL) {
-			nread = SSL_read(conn->ssl, buf, len);
-			if (nread <= 0) {
-				err = SSL_get_error(conn->ssl, nread);
-				if ((err == SSL_ERROR_SYSCALL) && (nread == -1)) {
-					err = ERRNO;
-				} else if ((err == SSL_ERROR_WANT_READ)
-				           || (err == SSL_ERROR_WANT_WRITE)) {
-					nread = 0;
+
+			struct pollfd pfd[1];
+			int pollres;
+
+			pfd[0].fd = conn->client.sock;
+			pfd[0].events = POLLIN;
+			pollres = poll(pfd, 1, (int)(timeout * 1000.0));
+			if (pollres > 0) {
+				nread = SSL_read(conn->ssl, buf, len);
+				if (nread <= 0) {
+					err = SSL_get_error(conn->ssl, nread);
+					if ((err == SSL_ERROR_SYSCALL) && (nread == -1)) {
+						err = ERRNO;
+					} else if ((err == SSL_ERROR_WANT_READ)
+					           || (err == SSL_ERROR_WANT_WRITE)) {
+						nread = 0;
+					} else {
+						DEBUG_TRACE("SSL_read() failed, error %d", err);
+						return -1;
+					}
 				} else {
-					DEBUG_TRACE("SSL_read() failed, error %d", err);
-					return -1;
+					err = 0;
 				}
+
+			} else if (pollres < 0) {
+				/* Error */
+				return -1;
 			} else {
-				err = 0;
+				/* pollres = 0 means timeout */
+				nread = 0;
 			}
+
 #endif
 
 		} else {
-			nread = (int)recv(conn->client.sock, buf, (len_t)len, 0);
-			err = (nread < 0) ? ERRNO : 0;
-			if (nread == 0) {
-				/* shutdown of the socket at client side */
+			struct pollfd pfd[1];
+			int pollres;
+
+			pfd[0].fd = conn->client.sock;
+			pfd[0].events = POLLIN;
+			pollres = poll(pfd, 1, (int)(timeout * 1000.0));
+			if (pollres > 0) {
+				nread = (int)recv(conn->client.sock, buf, (len_t)len, 0);
+				err = (nread < 0) ? ERRNO : 0;
+				if (nread == 0) {
+					/* shutdown of the socket at client side */
+					return -1;
+				}
+			} else if (pollres < 0) {
+				/* error callint poll */
 				return -1;
+			} else {
+				/* pollres = 0 means timeout */
+				nread = 0;
 			}
 		}
 
@@ -12177,23 +12208,36 @@ void
 mg_close_connection(struct mg_connection *conn)
 {
 	struct mg_context *client_ctx = NULL;
-	unsigned int i;
 
 	if (conn == NULL) {
 		return;
 	}
 
+#if defined(USE_WEBSOCKET)
 	if (conn->ctx->context_type == 2) {
+		unsigned int i;
+
 		/* ws/wss client */
 		client_ctx = conn->ctx;
 
 		/* client context: loops must end */
 		conn->ctx->stop_flag = 1;
 
-		/* get client out of recv function */
-		// shutdown(conn->client.sock, SHUTDOWN_RD);
+		/* we need to get the client thread out of the select/recv call here */
+		mg_websocket_write(conn, WEBSOCKET_OPCODE_CONNECTION_CLOSE, "", 0);
+
+		/* join worker thread */
+		for (i = 0; i < client_ctx->cfg_worker_threads; i++) {
+			if (client_ctx->workerthreadids[i] != 0) {
+				mg_join_thread(client_ctx->workerthreadids[i]);
+			}
+		}
 	}
+#else
+	(void)client_ctx;
+#endif
 
+	close_connection(conn);
 
 #ifndef NO_SSL
 	if (conn->client_ssl_ctx != NULL) {
@@ -12201,15 +12245,7 @@ mg_close_connection(struct mg_connection *conn)
 	}
 #endif
 
-	close_connection(conn);
-
 	if (client_ctx != NULL) {
-		/* join worker thread */
-		for (i = 0; i < client_ctx->cfg_worker_threads; i++) {
-			if (client_ctx->workerthreadids[i] != 0) {
-				mg_join_thread(client_ctx->workerthreadids[i]);
-			}
-		}
 		/* free context */
 		mg_free(client_ctx->workerthreadids);
 		mg_free(client_ctx);
@@ -12816,7 +12852,7 @@ websocket_client_thread(void *data)
 	struct websocket_client_thread_data *cdata =
 	    (struct websocket_client_thread_data *)data;
 
-	mg_set_thread_name("ws-client");
+	mg_set_thread_name("ws-clnt");
 
 	if (cdata->conn->ctx) {
 		if (cdata->conn->ctx->callbacks.init_thread) {

+ 12 - 11
test/public_server.c

@@ -861,6 +861,7 @@ struct tclient_data {
 	void *data;
 	size_t len;
 	int closed;
+	int clientId;
 };
 
 
@@ -880,7 +881,7 @@ websocket_client_data_handler(struct mg_connection *conn,
 	ck_assert(pclient_data != NULL);
 	ck_assert_int_eq(flags, (int)(128 | 1));
 
-	printf("Client received data from server: ");
+	printf("Client %i received data from server: ", pclient_data->clientId);
 	fwrite(data, 1, data_len, stdout);
 	printf("\n");
 
@@ -905,7 +906,7 @@ websocket_client_close_handler(const struct mg_connection *conn,
 
 	ck_assert(pclient_data != NULL);
 
-	printf("Client: Close handler\n");
+	printf("Client %i: Close handler\n", pclient_data->clientId);
 	pclient_data->closed++;
 }
 #endif
@@ -959,10 +960,10 @@ START_TEST(test_request_handlers)
 #endif
 
 #if defined(USE_WEBSOCKET)
-	struct tclient_data ws_client1_data = {NULL, 0, 0};
-	struct tclient_data ws_client2_data = {NULL, 0, 0};
-	struct tclient_data ws_client3_data = {NULL, 0, 0};
-	struct tclient_data ws_client4_data = {NULL, 0, 0};
+	struct tclient_data ws_client1_data = {NULL, 0, 0, 1};
+	struct tclient_data ws_client2_data = {NULL, 0, 0, 2};
+	struct tclient_data ws_client3_data = {NULL, 0, 0, 3};
+	struct tclient_data ws_client4_data = {NULL, 0, 0, 4};
 	struct mg_connection *ws_client1_conn = NULL;
 	struct mg_connection *ws_client2_conn = NULL;
 	struct mg_connection *ws_client3_conn = NULL;
@@ -1727,7 +1728,7 @@ START_TEST(test_request_handlers)
 	ck_assert(ws_client3_conn != NULL);
 
 	wait_not_null(
-	    &(ws_client3_data.data)); /* Wait for the websocket welcome message */
+	    &(ws_client4_data.data)); /* Wait for the websocket welcome message */
 	ck_assert(ws_client1_data.closed == 1);
 	ck_assert(ws_client2_data.closed == 1);
 	ck_assert(ws_client3_data.closed == 1);
@@ -3016,11 +3017,11 @@ MAIN_PUBLIC_SERVER(void)
 	test_the_test_environment(0);
 	test_threading(0);
 	test_mg_start_stop_http_server(0);
-	// test_mg_start_stop_https_server(0);
+	test_mg_start_stop_https_server(0);
 	test_request_handlers(0);
-	// test_mg_server_and_client_tls(0);
-	// test_handle_form(0);
-	// test_http_auth(0);
+	test_mg_server_and_client_tls(0);
+	test_handle_form(0);
+	test_http_auth(0);
 	test_keep_alive(0);
 
 	printf("\nok: %i\nfailed: %i\n\n", chk_ok, chk_failed);