Browse Source

Rewrite request parsing (Step 5/?)

bel2125 8 years ago
parent
commit
aa37d8e10d
3 changed files with 132 additions and 37 deletions
  1. 8 1
      include/civetweb.h
  2. 121 35
      src/civetweb.c
  3. 3 1
      src/handle_form.inl

+ 8 - 1
include/civetweb.h

@@ -602,11 +602,18 @@ CIVETWEB_API int mg_modify_passwords_file(const char *passwords_file_name,
                                           const char *password);
 
 
-/* Return information associated with the request. */
+/* Return information associated with the request.
+ * Use this function to implement a server and get data about a request
+ * from a HTTP/HTTPS client. */
 CIVETWEB_API const struct mg_request_info *
 mg_get_request_info(const struct mg_connection *);
 
 
+/* Return information associated with the request. */
+CIVETWEB_API const struct mg_response_info *
+mg_get_response_info(const struct mg_connection *);
+
+
 /* Send data to the client.
    Return:
     0   when the connection has been closed

+ 121 - 35
src/civetweb.c

@@ -2138,10 +2138,16 @@ get_memory_stat(struct mg_context *ctx)
 
 
 struct mg_connection {
+	int connection_type; /* 0 none
+	                      * 1 request (we are server, mg_request_info valid)
+	                      * 2 response (we are client, response_info valid)
+	                      */
+
 	struct mg_request_info request_info;
 	struct mg_response_info response_info;
 
 	struct mg_context *ctx;
+
 	SSL *ssl;                 /* SSL descriptor */
 	SSL_CTX *client_ssl_ctx;  /* SSL context for client connections */
 	struct socket client;     /* Connected client */
@@ -2994,7 +3000,9 @@ mg_cry(const struct mg_connection *conn, const char *fmt, ...)
 				fprintf(fi.access.fp,
 				        "%s %s: ",
 				        conn->request_info.request_method,
-				        conn->request_info.request_uri);
+				        conn->request_info.request_uri
+				            ? conn->request_info.request_uri
+				            : "");
 			}
 
 			fprintf(fi.access.fp, "%s", buf);
@@ -3032,10 +3040,39 @@ mg_get_request_info(const struct mg_connection *conn)
 	if (!conn) {
 		return NULL;
 	}
+#if 1 /* TODO: deal with legacy */
+	if (conn->connection_type == 2) {
+		static char txt[16];
+		sprintf(txt, "%03i", conn->response_info.status_code);
+		((struct mg_connection *)conn)->request_info.request_uri =
+		    txt; /* TODO: not thread safe */
+		((struct mg_connection *)conn)->request_info.num_headers =
+		    conn->response_info.num_headers;
+		memcpy(((struct mg_connection *)conn)->request_info.http_headers,
+		       conn->response_info.http_headers,
+		       sizeof(conn->response_info.http_headers));
+	} else
+#endif
+	    if (conn->connection_type != 1) {
+		return NULL;
+	}
 	return &conn->request_info;
 }
 
 
+const struct mg_response_info *
+mg_get_response_info(const struct mg_connection *conn)
+{
+	if (!conn) {
+		return NULL;
+	}
+	if (conn->connection_type != 2) {
+		return NULL;
+	}
+	return &conn->response_info;
+}
+
+
 int
 mg_get_request_link(const struct mg_connection *conn, char *buf, size_t buflen)
 {
@@ -3205,14 +3242,12 @@ skip_quoted(char **buf,
 
 /* Return HTTP header value, or NULL if not found. */
 static const char *
-get_header(const struct mg_request_info *ri, const char *name)
+get_header(const struct mg_header *hdr, int num_hdr, const char *name)
 {
 	int i;
-	if (ri) {
-		for (i = 0; i < ri->num_headers; i++) {
-			if (!mg_strcasecmp(name, ri->http_headers[i].name)) {
-				return ri->http_headers[i].value;
-			}
+	for (i = 0; i < num_hdr; i++) {
+		if (!mg_strcasecmp(name, hdr[i].name)) {
+			return hdr[i].value;
 		}
 	}
 
@@ -3224,10 +3259,10 @@ get_header(const struct mg_request_info *ri, const char *name)
 /* Retrieve requested HTTP header multiple values, and return the number of
  * found occurences */
 static int
-get_headers(const struct mg_request_info *ri,
-            const char *name,
-            const char **output,
-            int output_max_size)
+get_req_headers(const struct mg_request_info *ri,
+                const char *name,
+                const char **output,
+                int output_max_size)
 {
 	int i;
 	int cnt = 0;
@@ -3250,10 +3285,36 @@ mg_get_header(const struct mg_connection *conn, const char *name)
 		return NULL;
 	}
 
-	return get_header(&conn->request_info, name);
+	if (conn->connection_type == 1) {
+		return get_header(conn->request_info.http_headers,
+		                  conn->request_info.num_headers,
+		                  name);
+	}
+	if (conn->connection_type == 2) {
+		return get_header(conn->response_info.http_headers,
+		                  conn->request_info.num_headers,
+		                  name);
+	}
+	return NULL;
 }
 
 
+static const char *
+get_http_version(const struct mg_connection *conn)
+{
+	if (!conn) {
+		return NULL;
+	}
+
+	if (conn->connection_type == 1) {
+		return conn->request_info.http_version;
+	}
+	if (conn->connection_type == 2) {
+		return conn->response_info.http_version;
+	}
+	return NULL;
+}
+
 /* A helper function for traversing a comma separated list of values.
  * It returns a list pointer shifted to the next value, or NULL if the end
  * of the list found.
@@ -3383,7 +3444,7 @@ static int
 should_keep_alive(const struct mg_connection *conn)
 {
 	if (conn != NULL) {
-		const char *http_version = conn->request_info.http_version;
+		const char *http_version = get_http_version(conn);
 		const char *header = mg_get_header(conn, "Connection");
 		if (conn->must_close || (conn->status_code == 401)
 		    || mg_strcasecmp(conn->ctx->config[ENABLE_KEEP_ALIVE], "yes") != 0
@@ -9545,19 +9606,22 @@ handle_cgi_request(struct mg_connection *conn, const char *prog)
 
 	/* Make up and send the status line */
 	status_text = "OK";
-	if ((status = get_header(&ri, "Status")) != NULL) {
+	if ((status = get_header(ri.http_headers, ri.num_headers, "Status"))
+	    != NULL) {
 		conn->status_code = atoi(status);
 		status_text = status;
 		while (isdigit(*(const unsigned char *)status_text)
 		       || *status_text == ' ') {
 			status_text++;
 		}
-	} else if (get_header(&ri, "Location") != NULL) {
+	} else if (get_header(ri.http_headers, ri.num_headers, "Location")
+	           != NULL) {
 		conn->status_code = 302;
 	} else {
 		conn->status_code = 200;
 	}
-	connection_state = get_header(&ri, "Connection");
+	connection_state =
+	    get_header(ri.http_headers, ri.num_headers, "Connection");
 	if (!header_has_option(connection_state, "keep-alive")) {
 		conn->must_close = 1;
 	}
@@ -10824,10 +10888,10 @@ handle_websocket_request(struct mg_connection *conn,
 	if (is_callback_resource) {
 		/* Step 2.1 check and select subprotocol */
 		const char *protocols[64]; // max 64 headers
-		int nbSubprotocolHeader = get_headers(&conn->request_info,
-		                                      "Sec-WebSocket-Protocol",
-		                                      protocols,
-		                                      64);
+		int nbSubprotocolHeader = get_req_headers(&conn->request_info,
+		                                          "Sec-WebSocket-Protocol",
+		                                          protocols,
+		                                          64);
 		if ((nbSubprotocolHeader > 0) && subprotocols) {
 			int cnt = 0;
 			int idx;
@@ -11751,8 +11815,11 @@ handle_request(struct mg_connection *conn)
 		    conn->ctx->config[ACCESS_CONTROL_ALLOW_METHODS];
 		const char *cors_orig_cfg =
 		    conn->ctx->config[ACCESS_CONTROL_ALLOW_ORIGIN];
-		const char *cors_origin = get_header(ri, "Origin");
-		const char *cors_acrm = get_header(ri, "Access-Control-Request-Method");
+		const char *cors_origin =
+		    get_header(ri->http_headers, ri->num_headers, "Origin");
+		const char *cors_acrm = get_header(ri->http_headers,
+		                                   ri->num_headers,
+		                                   "Access-Control-Request-Method");
 
 		/* Todo: check if cors_origin is in cors_orig_cfg.
 		 * Or, let the client check this. */
@@ -11763,7 +11830,9 @@ handle_request(struct mg_connection *conn)
 			/* This is a valid CORS preflight, and the server is configured to
 			 * handle it automatically. */
 			const char *cors_acrh =
-			    get_header(ri, "Access-Control-Request-Headers");
+			    get_header(ri->http_headers,
+			               ri->num_headers,
+			               "Access-Control-Request-Headers");
 
 			gmt_time_string(date, sizeof(date), &curtime);
 			mg_printf(conn,
@@ -13659,24 +13728,33 @@ reset_per_request_attributes(struct mg_connection *conn)
 	if (!conn) {
 		return;
 	}
+	conn->connection_type = 0; /* Not yet a valid request/response */
+
 	conn->path_info = NULL;
 	conn->num_bytes_sent = conn->consumed_content = 0;
 	conn->status_code = -1;
 	conn->is_chunked = 0;
-	conn->must_close = conn->request_len = conn->throttle = 0;
-	conn->request_info.content_length = -1;
+	conn->must_close = 0;
+	conn->request_len = 0;
+	conn->throttle = 0;
+	conn->data_len = 0;
+	conn->chunk_remainder = 0;
+
+	conn->response_info.content_length = conn->request_info.content_length = -1;
+	conn->response_info.http_version = conn->request_info.http_version = NULL;
+	conn->response_info.num_headers = conn->request_info.num_headers = 0;
+	conn->response_info.status_text = NULL;
+	conn->response_info.status_code = 0;
+
 	conn->request_info.remote_user = NULL;
 	conn->request_info.request_method = NULL;
 	conn->request_info.request_uri = NULL;
 	conn->request_info.local_uri = NULL;
+
 #if defined(MG_LEGACY_INTERFACE)
 	/* Legacy before split into local_uri and request_uri */
 	conn->request_info.uri = NULL;
 #endif
-	conn->request_info.http_version = NULL;
-	conn->request_info.num_headers = 0;
-	conn->data_len = 0;
-	conn->chunk_remainder = 0;
 }
 
 
@@ -14455,7 +14533,9 @@ get_request(struct mg_connection *conn, char *ebuf, size_t ebuf_len, int *err)
 	}
 
 	/* Message is a valid request */
-	if ((cl = get_header(&conn->request_info, "Content-Length")) != NULL) {
+	if ((cl = get_header(conn->request_info.http_headers,
+	                     conn->request_info.num_headers,
+	                     "Content-Length")) != NULL) {
 		/* Request/response has content length set */
 		char *endptr = NULL;
 		conn->content_len = strtoll(cl, &endptr, 10);
@@ -14471,8 +14551,9 @@ get_request(struct mg_connection *conn, char *ebuf, size_t ebuf_len, int *err)
 		}
 		/* Publish the content length back to the request info. */
 		conn->request_info.content_length = conn->content_len;
-	} else if ((cl = get_header(&conn->request_info, "Transfer-Encoding"))
-	               != NULL
+	} else if ((cl = get_header(conn->request_info.http_headers,
+	                            conn->request_info.num_headers,
+	                            "Transfer-Encoding")) != NULL
 	           && !mg_strcasecmp(cl, "chunked")) {
 		conn->is_chunked = 1;
 	} else if (!mg_strcasecmp(conn->request_info.request_method, "POST")
@@ -14487,6 +14568,7 @@ get_request(struct mg_connection *conn, char *ebuf, size_t ebuf_len, int *err)
 		conn->content_len = 0;
 	}
 
+	conn->connection_type = 1; /* Valid request */
 	return 1;
 }
 
@@ -14513,7 +14595,9 @@ get_response(struct mg_connection *conn, char *ebuf, size_t ebuf_len, int *err)
 	}
 
 	/* Message is a valid response */
-	if ((cl = get_header(&conn->request_info, "Content-Length")) != NULL) {
+	if ((cl = get_header(conn->response_info.http_headers,
+	                     conn->response_info.num_headers,
+	                     "Content-Length")) != NULL) {
 		/* Request/response has content length set */
 		char *endptr = NULL;
 		conn->content_len = strtoll(cl, &endptr, 10);
@@ -14529,12 +14613,14 @@ get_response(struct mg_connection *conn, char *ebuf, size_t ebuf_len, int *err)
 		}
 		/* Publish the content length back to the request info. */
 		conn->request_info.content_length = conn->content_len;
-	} else if ((cl = get_header(&conn->request_info, "Transfer-Encoding"))
-	               != NULL
+	} else if ((cl = get_header(conn->response_info.http_headers,
+	                            conn->response_info.num_headers,
+	                            "Transfer-Encoding")) != NULL
 	           && !mg_strcasecmp(cl, "chunked")) {
 		conn->is_chunked = 1;
 	}
 
+	conn->connection_type = 2; /* Valid response */
 	return 1;
 }
 

+ 3 - 1
src/handle_form.inl

@@ -634,7 +634,9 @@ mg_handle_form_request(struct mg_connection *conn,
 
 			/* According to the RFC, every part has to have a header field like:
 			 * Content-Disposition: form-data; name="..." */
-			content_disp = get_header(&part_header, "Content-Disposition");
+			content_disp = get_header(part_header.http_headers,
+			                          part_header.num_headers,
+			                          "Content-Disposition");
 			if (!content_disp) {
 				/* Malformed request */
 				return -1;