瀏覽代碼

Implement -T option, throttling functionality

Sergey Lyubka 12 年之前
父節點
當前提交
dfc0f00478
共有 3 個文件被更改,包括 140 次插入29 次删除
  1. 12 0
      mongoose.1
  2. 98 29
      mongoose.c
  3. 30 0
      test/unit_test.c

+ 12 - 0
mongoose.1

@@ -91,6 +91,18 @@ Authorization realm. Default: "mydomain.com"
 All files that fully match ssi_pattern are treated as SSI.
 All files that fully match ssi_pattern are treated as SSI.
 Unknown SSI directives are silently ignored. Currently, two SSI directives
 Unknown SSI directives are silently ignored. Currently, two SSI directives
 are supported, "include" and "exec".  Default: "**.shtml$|**.shtm$"
 are supported, "include" and "exec".  Default: "**.shtml$|**.shtm$"
+.It Fl T Ar throttle
+Limit download speed for clients.
+.Ar throttle
+is a comma-separated list of key=value pairs, where
+key could be a '*' character (limit for all connections), a subnet in form
+x.x.x.x/mask (limit for a given subnet, for example 10.0.0.0/8), or an
+URI prefix pattern (limit for the set of URIs, for example /foo/**). The value
+is a floating-point number of bytes per second, optionally followed by a
+`k' or `m' character, meaning kilobytes and megabytes respectively. A limit
+of 0 means unlimited rate. The last matching rule wins. For example,
+"*=1k,10.0.0.0/8" means limit everybody to 1 kilobyte per second, but give
+people from 10/8 subnet unlimited speed. Default: ""
 .It Fl a Ar access_log_file
 .It Fl a Ar access_log_file
 Access log file. Default: "", no logging is done.
 Access log file. Default: "", no logging is done.
 .It Fl d Ar enable_directory_listing
 .It Fl d Ar enable_directory_listing

+ 98 - 29
mongoose.c

@@ -131,6 +131,7 @@ typedef long off_t;
 #define read(x, y, z) _read((x), (y), (unsigned) z)
 #define read(x, y, z) _read((x), (y), (unsigned) z)
 #define flockfile(x) EnterCriticalSection(&global_log_file_lock)
 #define flockfile(x) EnterCriticalSection(&global_log_file_lock)
 #define funlockfile(x) LeaveCriticalSection(&global_log_file_lock)
 #define funlockfile(x) LeaveCriticalSection(&global_log_file_lock)
+#define sleep(x) Sleep((x) * 1000)
 
 
 #if !defined(fileno)
 #if !defined(fileno)
 #define fileno(x) _fileno(x)
 #define fileno(x) _fileno(x)
@@ -426,7 +427,7 @@ struct socket {
 // NOTE(lsm): this enum shoulds be in sync with the config_options below.
 // NOTE(lsm): this enum shoulds be in sync with the config_options below.
 enum {
 enum {
   CGI_EXTENSIONS, CGI_ENVIRONMENT, PUT_DELETE_PASSWORDS_FILE, CGI_INTERPRETER,
   CGI_EXTENSIONS, CGI_ENVIRONMENT, PUT_DELETE_PASSWORDS_FILE, CGI_INTERPRETER,
-  PROTECT_URI, AUTHENTICATION_DOMAIN, SSI_EXTENSIONS,
+  PROTECT_URI, AUTHENTICATION_DOMAIN, SSI_EXTENSIONS, THROTTLE,
   ACCESS_LOG_FILE, ENABLE_DIRECTORY_LISTING, ERROR_LOG_FILE,
   ACCESS_LOG_FILE, ENABLE_DIRECTORY_LISTING, ERROR_LOG_FILE,
   GLOBAL_PASSWORDS_FILE, INDEX_FILES, ENABLE_KEEP_ALIVE, ACCESS_CONTROL_LIST,
   GLOBAL_PASSWORDS_FILE, INDEX_FILES, ENABLE_KEEP_ALIVE, ACCESS_CONTROL_LIST,
   EXTRA_MIME_TYPES, LISTENING_PORTS, DOCUMENT_ROOT, SSL_CERTIFICATE,
   EXTRA_MIME_TYPES, LISTENING_PORTS, DOCUMENT_ROOT, SSL_CERTIFICATE,
@@ -442,6 +443,7 @@ static const char *config_options[] = {
   "P", "protect_uri", NULL,
   "P", "protect_uri", NULL,
   "R", "authentication_domain", "mydomain.com",
   "R", "authentication_domain", "mydomain.com",
   "S", "ssi_pattern", "**.shtml$|**.shtm$",
   "S", "ssi_pattern", "**.shtml$|**.shtm$",
+  "T", "throttle", NULL,
   "a", "access_log_file", NULL,
   "a", "access_log_file", NULL,
   "d", "enable_directory_listing", "yes",
   "d", "enable_directory_listing", "yes",
   "e", "error_log_file", NULL,
   "e", "error_log_file", NULL,
@@ -453,7 +455,7 @@ static const char *config_options[] = {
   "p", "listening_ports", "8080",
   "p", "listening_ports", "8080",
   "r", "document_root",  ".",
   "r", "document_root",  ".",
   "s", "ssl_certificate", NULL,
   "s", "ssl_certificate", NULL,
-  "t", "num_threads", "10",
+  "t", "num_threads", "20",
   "u", "run_as_user", NULL,
   "u", "run_as_user", NULL,
   "w", "url_rewrite_patterns", NULL,
   "w", "url_rewrite_patterns", NULL,
   "x", "hide_files_patterns", NULL,
   "x", "hide_files_patterns", NULL,
@@ -499,6 +501,9 @@ struct mg_connection {
   int request_len;            // Size of the request + headers in a buffer
   int request_len;            // Size of the request + headers in a buffer
   int data_len;               // Total size of data in a buffer
   int data_len;               // Total size of data in a buffer
   int status_code;            // HTTP reply status code, e.g. 200
   int status_code;            // HTTP reply status code, e.g. 200
+  int throttle;               // Throttling, bytes/sec. <= 0 means no throttle
+  time_t last_throttle_time;  // Last time throttled data was sent
+  int64_t last_throttle_bytes;// Bytes sent this second
 };
 };
 
 
 const char **mg_get_valid_option_names(void) {
 const char **mg_get_valid_option_names(void) {
@@ -1507,8 +1512,41 @@ int mg_read(struct mg_connection *conn, void *buf, size_t len) {
 }
 }
 
 
 int mg_write(struct mg_connection *conn, const void *buf, size_t len) {
 int mg_write(struct mg_connection *conn, const void *buf, size_t len) {
-  return (int) push(NULL, conn->client.sock, conn->ssl, (const char *) buf,
-                    (int64_t) len);
+  time_t now;
+  int64_t n, total, allowed;
+
+  if (conn->throttle > 0) {
+    if ((now = time(NULL)) != conn->last_throttle_time) {
+      conn->last_throttle_time = now;
+      conn->last_throttle_bytes = 0;
+    }
+    allowed = conn->throttle - conn->last_throttle_bytes;
+    if (allowed > (int64_t) len) {
+      allowed = len;
+    }
+    if ((total = push(NULL, conn->client.sock, conn->ssl, (const char *) buf,
+                      (int64_t) allowed)) == allowed) {
+      buf = (char *) buf + total;
+      conn->last_throttle_bytes += total;
+      while (total < (int64_t) len && conn->ctx->stop_flag == 0) {
+        allowed = conn->throttle > (int64_t) len - total ?
+          len - total : conn->throttle;
+        if ((n = push(NULL, conn->client.sock, conn->ssl, (const char *) buf,
+                      (int64_t) allowed)) != allowed) {
+          break;
+        }
+        sleep(1);
+        conn->last_throttle_bytes = allowed;
+        conn->last_throttle_time = time(NULL);
+        buf = (char *) buf + n;
+        total += n;
+      }
+    }
+  } else {
+    total = push(NULL, conn->client.sock, conn->ssl, (const char *) buf,
+                 (int64_t) len);
+  }
+  return (int) total;
 }
 }
 
 
 int mg_printf(struct mg_connection *conn, const char *fmt, ...) {
 int mg_printf(struct mg_connection *conn, const char *fmt, ...) {
@@ -3745,6 +3783,53 @@ static int is_websocket_request(const struct mg_connection *conn) {
 }
 }
 #endif // !USE_WEBSOCKET
 #endif // !USE_WEBSOCKET
 
 
+static int isbyte(int n) {
+  return n >= 0 && n <= 255;
+}
+
+static int parse_net(const char *spec, uint32_t *net, uint32_t *mask) {
+  int n, a, b, c, d, slash = 32, len = 0;
+
+  if ((sscanf(spec, "%d.%d.%d.%d/%d%n", &a, &b, &c, &d, &slash, &n) == 5 ||
+      sscanf(spec, "%d.%d.%d.%d%n", &a, &b, &c, &d, &n) == 4) &&
+      isbyte(a) && isbyte(b) && isbyte(c) && isbyte(d) &&
+      slash >= 0 && slash < 33) {
+    len = n;
+    *net = ((uint32_t)a << 24) | ((uint32_t)b << 16) | ((uint32_t)c << 8) | d;
+    *mask = slash ? 0xffffffffU << (32 - slash) : 0;
+  }
+
+  return len;
+}
+
+static int set_throttle(const char *spec, uint32_t remote_ip, const char *uri) {
+  int throttle = 0;
+  struct vec vec, val;
+  uint32_t net, mask;
+  char mult;
+  double v;
+
+  while ((spec = next_option(spec, &vec, &val)) != NULL) {
+    mult = ',';
+    if (sscanf(val.ptr, "%lf%c", &v, &mult) < 1 || v < 0 ||
+        (lowercase(&mult) != 'k' && lowercase(&mult) != 'm' && mult != ',')) {
+      continue;
+    }
+    v *= lowercase(&mult) == 'k' ? 1024 : lowercase(&mult) == 'm' ? 1048576 : 1;
+    if (vec.len == 1 && vec.ptr[0] == '*') {
+      throttle = (int) v;
+    } else if (parse_net(vec.ptr, &net, &mask) > 0) {
+      if ((remote_ip & mask) == net) {
+        throttle = (int) v;
+      }
+    } else if (match_prefix(vec.ptr, vec.len, uri) > 0) {
+      throttle = (int) v;
+    }
+  }
+
+  return throttle;
+}
+
 // This is the heart of the Mongoose's logic.
 // This is the heart of the Mongoose's logic.
 // This function is called when the request is read, parsed and validated,
 // This function is called when the request is read, parsed and validated,
 // and Mongoose must decide what action to take: serve a file, or
 // and Mongoose must decide what action to take: serve a file, or
@@ -3762,6 +3847,10 @@ static void handle_request(struct mg_connection *conn) {
   url_decode(ri->uri, (size_t)uri_len, ri->uri, (size_t)(uri_len + 1), 0);
   url_decode(ri->uri, (size_t)uri_len, ri->uri, (size_t)(uri_len + 1), 0);
   remove_double_dots_and_double_slashes(ri->uri);
   remove_double_dots_and_double_slashes(ri->uri);
   stat_result = convert_uri_to_file_name(conn, path, sizeof(path), &st);
   stat_result = convert_uri_to_file_name(conn, path, sizeof(path), &st);
+  conn->throttle = set_throttle(conn->ctx->config[THROTTLE],
+                                ntohl(* (uint32_t *)
+                                      &conn->client.rsa.sin.sin_addr),
+                                ri->uri);
 
 
   DEBUG_TRACE(("%s", ri->uri));
   DEBUG_TRACE(("%s", ri->uri));
   if (!check_authorization(conn, path)) {
   if (!check_authorization(conn, path)) {
@@ -3973,15 +4062,10 @@ static void log_access(const struct mg_connection *conn) {
   fclose(fp);
   fclose(fp);
 }
 }
 
 
-static int isbyte(int n) {
-  return n >= 0 && n <= 255;
-}
-
 // Verify given socket address against the ACL.
 // Verify given socket address against the ACL.
 // Return -1 if ACL is malformed, 0 if address is disallowed, 1 if allowed.
 // Return -1 if ACL is malformed, 0 if address is disallowed, 1 if allowed.
 static int check_acl(struct mg_context *ctx, const union usa *usa) {
 static int check_acl(struct mg_context *ctx, const union usa *usa) {
-  int a, b, c, d, n, mask, allowed;
-  char flag;
+  int allowed, flag;
   uint32_t acl_subnet, acl_mask, remote_ip;
   uint32_t acl_subnet, acl_mask, remote_ip;
   struct vec vec;
   struct vec vec;
   const char *list = ctx->config[ACCESS_CONTROL_LIST];
   const char *list = ctx->config[ACCESS_CONTROL_LIST];
@@ -3996,28 +4080,13 @@ static int check_acl(struct mg_context *ctx, const union usa *usa) {
   allowed = '-';
   allowed = '-';
 
 
   while ((list = next_option(list, &vec, NULL)) != NULL) {
   while ((list = next_option(list, &vec, NULL)) != NULL) {
-    mask = 32;
-
-    if (sscanf(vec.ptr, "%c%d.%d.%d.%d%n", &flag, &a, &b, &c, &d, &n) != 5) {
+    flag = vec.ptr[0];
+    if (flag != '+' && flag != '-' &&
+        parse_net(&vec.ptr[1], &acl_subnet, &acl_mask) == 0) {
       cry(fc(ctx), "%s: subnet must be [+|-]x.x.x.x[/x]", __func__);
       cry(fc(ctx), "%s: subnet must be [+|-]x.x.x.x[/x]", __func__);
       return -1;
       return -1;
-    } else if (flag != '+' && flag != '-') {
-      cry(fc(ctx), "%s: flag must be + or -: [%s]", __func__, vec.ptr);
-      return -1;
-    } else if (!isbyte(a)||!isbyte(b)||!isbyte(c)||!isbyte(d)) {
-      cry(fc(ctx), "%s: bad ip address: [%s]", __func__, vec.ptr);
-      return -1;
-    } else if (sscanf(vec.ptr + n, "/%d", &mask) == 0) {
-      // Do nothing, no mask specified
-    } else if (mask < 0 || mask > 32) {
-      cry(fc(ctx), "%s: bad subnet mask: %d [%s]", __func__, n, vec.ptr);
-      return -1;
     }
     }
 
 
-    acl_subnet = ((uint32_t) a << 24) | ((uint32_t) b << 16) |
-      ((uint32_t) c << 8) | d;
-    acl_mask = mask ? 0xffffffffU << (32 - mask) : 0;
-
     if (acl_subnet == (ntohl(remote_ip) & acl_mask)) {
     if (acl_subnet == (ntohl(remote_ip) & acl_mask)) {
       allowed = flag;
       allowed = flag;
     }
     }
@@ -4205,7 +4274,7 @@ static void reset_per_request_attributes(struct mg_connection *conn) {
   conn->path_info = conn->log_message = NULL;
   conn->path_info = conn->log_message = NULL;
   conn->num_bytes_sent = conn->consumed_content = 0;
   conn->num_bytes_sent = conn->consumed_content = 0;
   conn->status_code = -1;
   conn->status_code = -1;
-  conn->must_close = conn->request_len = 0;
+  conn->must_close = conn->request_len = conn->throttle = 0;
 }
 }
 
 
 static void close_socket_gracefully(struct mg_connection *conn) {
 static void close_socket_gracefully(struct mg_connection *conn) {

+ 30 - 0
test/unit_test.c

@@ -227,6 +227,34 @@ static void test_mg_get_var(void) {
   ASSERT(mg_get_var(post[0], strlen(post[0]), "x", buf, 0) == -2);
   ASSERT(mg_get_var(post[0], strlen(post[0]), "x", buf, 0) == -2);
 }
 }
 
 
+static void test_set_throttle(void) {
+  ASSERT(set_throttle(NULL, 0x0a000001, "/") == 0);
+  ASSERT(set_throttle("10.0.0.0/8=20", 0x0a000001, "/") == 20);
+  ASSERT(set_throttle("10.0.0.0/8=0.5k", 0x0a000001, "/") == 512);
+  ASSERT(set_throttle("10.0.0.0/8=17m", 0x0a000001, "/") == 1048576 * 17);
+  ASSERT(set_throttle("10.0.0.0/8=1x", 0x0a000001, "/") == 0);
+  ASSERT(set_throttle("10.0.0.0/8=5,0.0.0.0/0=10", 0x0a000001, "/") == 10);
+  ASSERT(set_throttle("10.0.0.0/8=5,/foo/**=7", 0x0a000001, "/index") == 5);
+  ASSERT(set_throttle("10.0.0.0/8=5,/foo/**=7", 0x0a000001, "/foo/x") == 7);
+  ASSERT(set_throttle("10.0.0.0/8=5,/foo/**=7", 0x0b000001, "/foxo/x") == 0);
+  ASSERT(set_throttle("10.0.0.0/8=5,*=1", 0x0b000001, "/foxo/x") == 1);
+}
+
+static void test_next_option(void) {
+  const char *p, *list = "x/8,/y**=1;2k,z";
+  struct vec a, b;
+  int i;
+
+  ASSERT(next_option(NULL, &a, &b) == NULL);
+  for (i = 0, p = list; (p = next_option(p, &a, &b)) != NULL; i++) {
+    ASSERT(i != 0 || (a.ptr == list && a.len == 3 && b.len == 0));
+    ASSERT(i != 1 || (a.ptr == list + 4 && a.len == 4 && b.ptr == list + 9 &&
+                      b.len == 4));
+
+    ASSERT(i != 2 || (a.ptr == list + 14 && a.len == 1 && b.len == 0));
+  }
+}
+
 int main(void) {
 int main(void) {
   test_base64_encode();
   test_base64_encode();
   test_match_prefix();
   test_match_prefix();
@@ -235,5 +263,7 @@ int main(void) {
   test_parse_http_request();
   test_parse_http_request();
   test_mg_fetch();
   test_mg_fetch();
   test_mg_get_var();
   test_mg_get_var();
+  test_set_throttle();
+  test_next_option();
   return 0;
   return 0;
 }
 }