Browse Source

Symmetric implementation for push/pull

bel 10 years ago
parent
commit
0589118f27
1 changed files with 116 additions and 45 deletions
  1. 116 45
      src/civetweb.c

+ 116 - 45
src/civetweb.c

@@ -2918,45 +2918,100 @@ static int set_non_blocking_mode(SOCKET sock)
 
 /* Write data to the IO channel - opened file descriptor, socket or SSL
  * descriptor. Return number of bytes written. */
-static int64_t
-push(FILE *fp, SOCKET sock, SSL *ssl, const char *buf, int64_t len)
+static int push(struct mg_context *ctx,
+                FILE *fp,
+                SOCKET sock,
+                SSL *ssl,
+                const char *buf,
+                int len,
+                double timeout)
 {
-	int64_t sent;
-	int n, k;
+	struct timespec start, now;
+	int n;
+
+#ifdef _WIN32
+	typedef int len_t;
+#else
+	typedef size_t len_t;
+#endif
 
-	(void)ssl; /* Get rid of warning */
-	sent = 0;
-	while (sent < len) {
-		/* How many bytes we send in this iteration */
-		k = len - sent > INT_MAX ? INT_MAX : (int)(len - sent);
+	memset(&start, 0, sizeof(start));
+	memset(&now, 0, sizeof(now));
+
+	if (timeout > 0) {
+		clock_gettime(CLOCK_MONOTONIC, &start);
+	}
+
+	do {
 
 #ifndef NO_SSL
 		if (ssl != NULL) {
-			n = SSL_write(ssl, buf + sent, k);
+			n = SSL_write(ssl, buf, len);
 		} else
 #endif
 		    if (fp != NULL) {
-			n = (int)fwrite(buf + sent, 1, (size_t)k, fp);
+			n = (int)fwrite(buf, 1, (size_t)len, fp);
 			if (ferror(fp))
 				n = -1;
 		} else {
-#ifdef _WIN32
-			typedef int len_t;
-#else
-			typedef size_t len_t;
-#endif
-			n = (int)send(sock, buf + sent, (len_t)k, MSG_NOSIGNAL);
+			n = (int)send(sock, buf, (len_t)len, MSG_NOSIGNAL);
 		}
 
-		if (n <= 0)
-			break;
+		if (ctx->stop_flag) {
+			return -1;
+		}
+		if ((n > 0) || (n == 0 && len == 0)) {
+			/* some data has been read, or no data was requested */
+			return n;
+		}
+		if (n == 0) {
+			/* shutdown of the socket at client side */
+			return -1;
+		}
+		if (n < 0) {
+			/* socket error - check errno */
+			DEBUG_TRACE("send() failed, error %d", ERRNO);
+			return -1;
+		}
+		if (timeout > 0) {
+			clock_gettime(CLOCK_MONOTONIC, &now);
+		}
+	} while ((timeout <= 0) || (mg_difftimespec(&now, &start) <= timeout));
+
+	return -1;
+}
+
+static int64_t push_all(struct mg_context *ctx,
+                        FILE *fp,
+                        SOCKET sock,
+                        SSL *ssl,
+                        const char *buf,
+                        int64_t len)
+{
+	double timeout = -1.0;
+	int64_t n, nwritten = 0;
 
-		sent += n;
+	if (ctx->config[REQUEST_TIMEOUT]) {
+		timeout = atoi(ctx->config[REQUEST_TIMEOUT]) / 1000.0;
+	}
+
+	while (len > 0 && ctx->stop_flag == 0) {
+		n = push(ctx, fp, sock, ssl, buf + nwritten, (int)len, timeout);
+		if (n < 0) {
+			nwritten = n; /* Propagate the error */
+			break;
+		} else if (n == 0) {
+			break; /* No more data to write */
+		} else {
+			nwritten += n;
+			len -= n;
+		}
 	}
 
-	return sent;
+	return nwritten;
 }
 
+
 /* Read from IO channel - opened file descriptor, socket, or SSL descriptor.
  * Return negative value on error, or number of bytes read on success. */
 static int
@@ -2965,6 +3020,12 @@ pull(FILE *fp, struct mg_connection *conn, char *buf, int len, double timeout)
 	int nread;
 	struct timespec start, now;
 
+#ifdef _WIN32
+	typedef int len_t;
+#else
+	typedef size_t len_t;
+#endif
+
 	memset(&start, 0, sizeof(start));
 	memset(&now, 0, sizeof(now));
 
@@ -2984,19 +3045,24 @@ pull(FILE *fp, struct mg_connection *conn, char *buf, int len, double timeout)
 			nread = SSL_read(conn->ssl, buf, len);
 #endif
 		} else {
-#ifdef _WIN32
-			typedef int len_t;
-#else
-			typedef size_t len_t;
-#endif
 			nread = (int)recv(conn->client.sock, buf, (len_t)len, 0);
 		}
 		if (conn->ctx->stop_flag) {
 			return -1;
 		}
-		if (nread >= 0) {
+		if ((nread > 0) || (nread == 0 && len == 0)) {
+			/* some data has been read, or no data was requested */
 			return nread;
 		}
+		if (nread == 0) {
+			/* shutdown of the socket at client side */
+			return -1;
+		}
+		if (nread < 0) {
+			/* socket error - check errno */
+			DEBUG_TRACE("recv() failed, error %d", ERRNO);
+			return -1;
+		}
 		if (timeout > 0) {
 			clock_gettime(CLOCK_MONOTONIC, &now);
 		}
@@ -3006,6 +3072,7 @@ pull(FILE *fp, struct mg_connection *conn, char *buf, int len, double timeout)
 	return -1;
 }
 
+
 static int pull_all(FILE *fp, struct mg_connection *conn, char *buf, int len)
 {
 	int n, nread = 0;
@@ -3232,22 +3299,24 @@ int mg_write(struct mg_connection *conn, const void *buf, size_t len)
 		if (allowed > (int64_t)len) {
 			allowed = (int64_t)len;
 		}
-		if ((total = push(NULL,
-		                  conn->client.sock,
-		                  conn->ssl,
-		                  (const char *)buf,
-		                  (int64_t)allowed)) == allowed) {
+		if ((total = push_all(conn->ctx,
+		                      NULL,
+		                      conn->client.sock,
+		                      conn->ssl,
+		                      (const char *)buf,
+		                      (int64_t)allowed)) == allowed) {
 			buf = (char *)buf + total;
 			conn->last_throttle_bytes += total;
 			while (total < (int64_t)len && conn->ctx->stop_flag == 0) {
 				allowed = conn->throttle > (int64_t)len - total
 				              ? (int64_t)len - total
 				              : conn->throttle;
-				if ((n = push(NULL,
-				              conn->client.sock,
-				              conn->ssl,
-				              (const char *)buf,
-				              (int64_t)allowed)) != allowed) {
+				if ((n = push_all(conn->ctx,
+				                  NULL,
+				                  conn->client.sock,
+				                  conn->ssl,
+				                  (const char *)buf,
+				                  (int64_t)allowed)) != allowed) {
 					break;
 				}
 				sleep(1);
@@ -3258,11 +3327,12 @@ int mg_write(struct mg_connection *conn, const void *buf, size_t len)
 			}
 		}
 	} else {
-		total = push(NULL,
-		             conn->client.sock,
-		             conn->ssl,
-		             (const char *)buf,
-		             (int64_t)len);
+		total = push_all(conn->ctx,
+		                 NULL,
+		                 conn->client.sock,
+		                 conn->ssl,
+		                 (const char *)buf,
+		                 (int64_t)len);
 	}
 	return (int)total;
 }
@@ -5493,7 +5563,7 @@ forward_body_data(struct mg_connection *conn, FILE *fp, SOCKET sock, SSL *ssl)
 				buffered_len = (int)conn->content_len;
 			}
 			body = conn->buf + conn->request_len + conn->consumed_content;
-			push(fp, sock, ssl, body, (int64_t)buffered_len);
+			push_all(conn->ctx, fp, sock, ssl, body, (int64_t)buffered_len);
 			conn->consumed_content += buffered_len;
 		}
 
@@ -5504,7 +5574,8 @@ forward_body_data(struct mg_connection *conn, FILE *fp, SOCKET sock, SSL *ssl)
 				to_read = (int)(conn->content_len - conn->consumed_content);
 			}
 			nread = pull(NULL, conn, buf, to_read, timeout);
-			if (nread <= 0 || push(fp, sock, ssl, buf, nread) != nread) {
+			if (nread <= 0 ||
+			    push_all(conn->ctx, fp, sock, ssl, buf, nread) != nread) {
 				break;
 			}
 			conn->consumed_content += nread;
@@ -8189,7 +8260,7 @@ static int set_ports_option(struct mg_context *ctx)
 			            * if someone already has the socket -- DTL */
 			           setsockopt(so.sock,
 			                      SOL_SOCKET,
-			                      SO_EXCLUSIVEADDRUSE,
+			                      SO_REUSEADDR /* TODO(high): check with unit test -> SO_EXCLUSIVEADDRUSE */,
 			                      (SOCK_OPT_TYPE)&on,
 			                      sizeof(on)) != 0 ||
 #else