Browse Source

Allow to stop websocket connectionss async from the server side (#515)

bel2125 7 years ago
parent
commit
0adf463b92
1 changed files with 107 additions and 58 deletions
  1. 107 58
      src/civetweb.c

+ 107 - 58
src/civetweb.c

@@ -2270,15 +2270,21 @@ struct mg_handler_info {
 };
 
 
+enum {
+	CONTEXT_INVALID,
+	CONTEXT_SERVER,
+	CONTEXT_HTTP_CLIENT,
+	CONTEXT_WS_CLIENT
+};
+
+
 struct mg_context {
 	volatile int stop_flag;        /* Should we stop event loop */
 	SSL_CTX *ssl_ctx;              /* SSL context */
 	char *config[NUM_OPTIONS];     /* Civetweb configuration parameters */
 	struct mg_callbacks callbacks; /* User-defined callback function */
 	void *user_data;               /* User-defined data */
-	int context_type;              /* 1 = server context,
-	                                * 2 = ws/wss client context,
-	                                */
+	int context_type;              /* See CONTEXT_* above */
 
 	struct socket *listening_sockets;
 	struct pollfd *listening_socket_fds;
@@ -2356,12 +2362,14 @@ get_memory_stat(struct mg_context *ctx)
 }
 #endif
 
+enum {
+	CONNECTION_TYPE_INVALID,
+	CONNECTION_TYPE_REQUEST,
+	CONNECTION_TYPE_RESPONSE
+};
 
 struct mg_connection {
-	int connection_type; /* 0 none
-	                      * 1 request (we are server, mg_request_info valid)
-	                      * 2 response (we are client, response_info valid)
-	                      */
+	int connection_type; /* see CONNECTION_TYPE_* above */
 
 	struct mg_request_info request_info;
 	struct mg_response_info response_info;
@@ -2394,19 +2402,23 @@ struct mg_connection {
 	char *buf;                /* Buffer for received data */
 	char *path_info;          /* PATH_INFO part of the URL */
 
-	int must_close;            /* 1 if connection must be closed */
-	int accept_gzip;           /* 1 if gzip encoding is accepted */
-	int in_error_handler;      /* 1 if in handler for user defined error
-	                            * pages */
-	int handled_requests;      /* Number of requests handled by this connection
-	                              */
-	int buf_size;              /* Buffer size */
-	int request_len;           /* Size of the request + headers in a buffer */
-	int data_len;              /* Total size of data in a buffer */
-	int status_code;           /* HTTP reply status code, e.g. 200 */
-	int throttle;              /* Throttling, bytes/sec. <= 0 means no
-	                            * throttle */
-	time_t last_throttle_time; /* Last time throttled data was sent */
+	int must_close;       /* 1 if connection must be closed */
+	int accept_gzip;      /* 1 if gzip encoding is accepted */
+	int in_error_handler; /* 1 if in handler for user defined error
+	                       * pages */
+#if defined(USE_WEBSOCKET)
+	int in_websocket_handling; /* 1 if in read_websocket */
+#endif
+	int handled_requests; /* Number of requests handled by this connection
+	                         */
+	int buf_size;         /* Buffer size */
+	int request_len;      /* Size of the request + headers in a buffer */
+	int data_len;         /* Total size of data in a buffer */
+	int status_code;      /* HTTP reply status code, e.g. 200 */
+	int throttle;         /* Throttling, bytes/sec. <= 0 means no
+	                       * throttle */
+
+	time_t last_throttle_time;   /* Last time throttled data was sent */
 	int64_t last_throttle_bytes; /* Bytes sent this second */
 	pthread_mutex_t mutex;       /* Used by mg_(un)lock_connection to ensure
 	                              * atomic transmissions for websockets */
@@ -3359,7 +3371,7 @@ mg_get_request_info(const struct mg_connection *conn)
 		return NULL;
 	}
 #if 1 /* TODO: deal with legacy */
-	if (conn->connection_type == 2) {
+	if (conn->connection_type == CONNECTION_TYPE_RESPONSE) {
 		static char txt[16];
 		sprintf(txt, "%03i", conn->response_info.status_code);
 
@@ -3374,7 +3386,7 @@ mg_get_request_info(const struct mg_connection *conn)
 		       sizeof(conn->response_info.http_headers));
 	} else
 #endif
-	    if (conn->connection_type != 1) {
+	    if (conn->connection_type != CONNECTION_TYPE_REQUEST) {
 		return NULL;
 	}
 	return &conn->request_info;
@@ -3387,7 +3399,7 @@ mg_get_response_info(const struct mg_connection *conn)
 	if (!conn) {
 		return NULL;
 	}
-	if (conn->connection_type != 2) {
+	if (conn->connection_type != CONNECTION_TYPE_RESPONSE) {
 		return NULL;
 	}
 	return &conn->response_info;
@@ -3619,12 +3631,12 @@ mg_get_header(const struct mg_connection *conn, const char *name)
 		return NULL;
 	}
 
-	if (conn->connection_type == 1) {
+	if (conn->connection_type == CONNECTION_TYPE_REQUEST) {
 		return get_header(conn->request_info.http_headers,
 		                  conn->request_info.num_headers,
 		                  name);
 	}
-	if (conn->connection_type == 2) {
+	if (conn->connection_type == CONNECTION_TYPE_RESPONSE) {
 		return get_header(conn->response_info.http_headers,
 		                  conn->request_info.num_headers,
 		                  name);
@@ -3640,10 +3652,10 @@ get_http_version(const struct mg_connection *conn)
 		return NULL;
 	}
 
-	if (conn->connection_type == 1) {
+	if (conn->connection_type == CONNECTION_TYPE_REQUEST) {
 		return conn->request_info.http_version;
 	}
-	if (conn->connection_type == 2) {
+	if (conn->connection_type == CONNECTION_TYPE_RESPONSE) {
 		return conn->response_info.http_version;
 	}
 	return NULL;
@@ -11225,11 +11237,12 @@ read_websocket(struct mg_connection *conn,
 		timeout = atoi(conn->ctx->config[REQUEST_TIMEOUT]) / 1000.0;
 	}
 
+	conn->in_websocket_handling = 1;
 	mg_set_thread_name("wsock");
 
 	/* Loop continuously, reading messages from the socket, invoking the
 	 * callback, and waiting repeatedly until an error occurs. */
-	while (!conn->ctx->stop_flag) {
+	while (!conn->ctx->stop_flag && !conn->must_close) {
 		header_len = 0;
 		assert(conn->data_len >= conn->request_len);
 		if ((body_len = (size_t)(conn->data_len - conn->request_len)) >= 2) {
@@ -11390,6 +11403,7 @@ read_websocket(struct mg_connection *conn,
 	}
 
 	mg_set_thread_name("worker");
+	conn->in_websocket_handling = 0;
 }
 
 
@@ -14495,7 +14509,8 @@ reset_per_request_attributes(struct mg_connection *conn)
 	if (!conn) {
 		return;
 	}
-	conn->connection_type = 0; /* Not yet a valid request/response */
+	conn->connection_type =
+	    CONNECTION_TYPE_INVALID; /* Not yet a valid request/response */
 
 	conn->num_bytes_sent = conn->consumed_content = 0;
 
@@ -14721,12 +14736,15 @@ close_connection(struct mg_connection *conn)
 #endif
 
 	mg_lock_connection(conn);
+
+	/* Set close flag, so keep-alive loops will stop */
 	conn->must_close = 1;
 
 	/* call the connection_close callback if assigned */
-	if ((conn->ctx->callbacks.connection_close != NULL)
-	    && (conn->ctx->context_type == 1)) {
-		conn->ctx->callbacks.connection_close(conn);
+	if (conn->ctx->callbacks.connection_close != NULL) {
+		if (conn->ctx->context_type == CONTEXT_SERVER) {
+			conn->ctx->callbacks.connection_close(conn);
+		}
 	}
 
 	/* Reset user data, after close callback is called.
@@ -14778,7 +14796,15 @@ mg_close_connection(struct mg_connection *conn)
 	}
 
 #if defined(USE_WEBSOCKET)
-	if (conn->ctx->context_type == 2) {
+	if (conn->ctx->context_type == CONTEXT_SERVER) {
+		if (conn->in_websocket_handling) {
+			/* Set close flag, so the server thread can exit. */
+			conn->must_close = 1;
+			return;
+		}
+	}
+	if (conn->ctx->context_type == CONTEXT_WS_CLIENT) {
+
 		unsigned int i;
 
 		/* ws/wss client */
@@ -14786,6 +14812,7 @@ mg_close_connection(struct mg_connection *conn)
 
 		/* client context: loops must end */
 		conn->ctx->stop_flag = 1;
+		conn->must_close = 1;
 
 		/* We need to get the client thread out of the select/recv call
 		 * here. */
@@ -14816,17 +14843,18 @@ mg_close_connection(struct mg_connection *conn)
 		mg_free(client_ctx);
 		(void)pthread_mutex_destroy(&conn->mutex);
 		mg_free(conn);
-	} else if (conn->ctx->context_type == 0) { /* Client */
+	} else if (conn->ctx->context_type == CONTEXT_HTTP_CLIENT) {
 		mg_free(conn);
 	}
 #else
-	if (conn->ctx->context_type == 0) { /* Client */
+	if (conn->ctx->context_type == CONTEXT_HTTP_CLIENT) { /* Client */
 		mg_free(conn);
 	}
 #endif /* defined(USE_WEBSOCKET) */
 }
 
 
+/* Only for memory statistics */
 static struct mg_context common_client_context;
 
 
@@ -14845,17 +14873,6 @@ mg_connect_client_impl(const struct mg_client_options *client_options,
 	unsigned max_req_size =
 	    (unsigned)atoi(config_options[MAX_REQUEST_SIZE].default_value);
 
-	if (!connect_socket(&common_client_context,
-	                    client_options->host,
-	                    client_options->port,
-	                    use_ssl,
-	                    ebuf,
-	                    ebuf_len,
-	                    &sock,
-	                    &sa)) {
-		return NULL;
-	}
-
 	conn = (struct mg_connection *)mg_calloc_ctx(1,
 	                                             sizeof(*conn) + max_req_size,
 	                                             &common_client_context);
@@ -14867,7 +14884,39 @@ mg_connect_client_impl(const struct mg_client_options *client_options,
 		            ebuf_len,
 		            "calloc(): %s",
 		            strerror(ERRNO));
-		closesocket(sock);
+		return NULL;
+	}
+
+	conn->ctx =
+	    (struct mg_context *)mg_malloc_ctx(sizeof(common_client_context),
+	                                       &common_client_context);
+
+	if (conn == NULL) {
+		mg_snprintf(NULL,
+		            NULL, /* No truncation check for ebuf */
+		            ebuf,
+		            ebuf_len,
+		            "calloc(): %s",
+		            strerror(ERRNO));
+		mg_free(conn);
+		return NULL;
+	}
+
+	*(conn->ctx) = common_client_context;
+	conn->ctx->context_type = CONTEXT_HTTP_CLIENT;
+
+	if (!connect_socket(&common_client_context,
+	                    client_options->host,
+	                    client_options->port,
+	                    use_ssl,
+	                    ebuf,
+	                    ebuf_len,
+	                    &sock,
+	                    &sa)) {
+		/* ebuf is set by connect_socket,
+		 * free all memory and return NULL; */
+		mg_free(conn->ctx);
+		mg_free(conn);
 		return NULL;
 	}
 
@@ -14914,7 +14963,6 @@ mg_connect_client_impl(const struct mg_client_options *client_options,
 
 	conn->buf_size = (int)max_req_size;
 	conn->buf = (char *)(conn + 1);
-	conn->ctx = &common_client_context;
 	conn->client.sock = sock;
 	conn->client.lsa = sa;
 
@@ -15375,7 +15423,7 @@ get_request(struct mg_connection *conn, char *ebuf, size_t ebuf_len, int *err)
 		conn->content_len = 0; /* No content */
 	}
 
-	conn->connection_type = 1; /* Valid request */
+	conn->connection_type = CONNECTION_TYPE_REQUEST; /* Valid request */
 	return 1;
 }
 
@@ -15434,7 +15482,7 @@ get_response(struct mg_connection *conn, char *ebuf, size_t ebuf_len, int *err)
 		conn->content_len = -1; /* unknown content length */
 	}
 
-	conn->connection_type = 2; /* Valid response */
+	conn->connection_type = CONNECTION_TYPE_RESPONSE; /* Valid response */
 	return 1;
 }
 
@@ -15690,7 +15738,7 @@ mg_connect_websocket_client(const char *host,
 	newctx = (struct mg_context *)mg_malloc(sizeof(struct mg_context));
 	memcpy(newctx, conn->ctx, sizeof(struct mg_context));
 	newctx->user_data = user_data;
-	newctx->context_type = 2;       /* ws/wss client context type */
+	newctx->context_type = CONTEXT_WS_CLIENT; /* ws/wss client context */
 	newctx->cfg_worker_threads = 1; /* one worker thread will be created */
 	newctx->worker_threadids =
 	    (pthread_t *)mg_calloc_ctx(newctx->cfg_worker_threads,
@@ -15759,12 +15807,13 @@ init_connection(struct mg_connection *conn)
 	conn->conn_state = 2; /* init */
 #endif
 
-	/* call the connection_close callback if assigned */
-	if ((conn->ctx->callbacks.init_connection != NULL)
-	    && (conn->ctx->context_type == 1)) {
-		void *conn_data = NULL;
-		conn->ctx->callbacks.init_connection(conn, &conn_data);
-		mg_set_user_connection_data(conn, conn_data);
+	/* call the init_connection callback if assigned */
+	if (conn->ctx->callbacks.init_connection != NULL) {
+		if (conn->ctx->context_type == CONTEXT_SERVER) {
+			void *conn_data = NULL;
+			conn->ctx->callbacks.init_connection(conn, &conn_data);
+			mg_set_user_connection_data(conn, conn_data);
+		}
 	}
 }
 
@@ -16912,7 +16961,7 @@ mg_start(const struct mg_callbacks *callbacks,
 		ctx->callbacks.init_context(ctx);
 	}
 	ctx->callbacks.exit_context = exit_callback;
-	ctx->context_type = 1; /* server context */
+	ctx->context_type = CONTEXT_SERVER; /* server context */
 
 	/* Start master (listening) thread */
 	mg_start_thread_with_id(master_thread, ctx, &ctx->masterthreadid);