ソースを参照

added thread safety for websockets

This was submitted by Morgan McGuire via proxy.

Hereis his comment...
The actual diff against civetweb is pretty small--I added thread safety
for websockets and rewrote and commented a very confusing function that
 was previously buggy.  The thread safety is critical since there is no
other way for the application to synchronize civet's receipt of a websocket
packet with its own desire to send one. All of this was tested pretty
thoroughly in my own apps and I had a pull request to Sergey that he never
got to before the fork.
Thomas Davis 12 年 前
コミット
2a10518fb7
2 ファイル変更106 行追加70 行削除
  1. 90 68
      civetweb.c
  2. 16 2
      civetweb.h

+ 90 - 68
civetweb.c

@@ -528,6 +528,7 @@ struct mg_connection {
   int throttle;               // Throttling, bytes/sec. <= 0 means no throttle
   time_t last_throttle_time;  // Last time throttled data was sent
   int64_t last_throttle_bytes;// Bytes sent this second
+  pthread_mutex_t mutex;      // Used by mg_lock/mg_unlock to ensure atomic transmissions for websockets
 };
 
 // Directory entry
@@ -3850,28 +3851,40 @@ static void send_websocket_handshake(struct mg_connection *conn) {
             "Sec-WebSocket-Accept: ", b64_sha, "\r\n\r\n");
 }
 
+void mg_lock(struct mg_connection* conn) {
+  (void) pthread_mutex_lock(&conn->mutex); 
+}
+
+void mg_unlock(struct mg_connection* conn) {
+  (void) pthread_mutex_unlock(&conn->mutex);
+}
+
 static void read_websocket(struct mg_connection *conn) {
-  // Pointer to the beginning of the portion of the incoming websocket message
-  // queue. The original websocket upgrade request is never removed,
-  // so the queue begins after it.
+  // Pointer to the beginning of the portion of the incoming websocket message queue.
+  // The original websocket upgrade request is never removed, so the queue begins after it.
   unsigned char *buf = (unsigned char *) conn->buf + conn->request_len;
-  int bits, n, stop = 0;
+  int n;
+
+  // body_len is the length of the entire queue in bytes
+  // len is the length of the current message
+  // data_len is the length of the current message's data payload
+  // header_len is the length of the current message's header
   size_t i, len, mask_len, data_len, header_len, body_len;
-  // data points to the place where the message is stored when passed to the
-  // websocket_data callback. This is either mem on the stack,
-  // or a dynamically allocated buffer if it is too large.
-  char mem[4 * 1024], mask[4], *data;
 
-  assert(conn->content_len == 0);
+  // "The masking key is a 32-bit value chosen at random by the client."
+  // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5
+  unsigned char mask[4];
+
+  // data points to the place where the message is stored when passed to the websocket_data
+  // callback.  This is either mem on the stack, or a dynamically allocated buffer if it is
+  // too large.
+  char mem[4 * 1024], *data;
 
   // Loop continuously, reading messages from the socket, invoking the callback,
   // and waiting repeatedly until an error occurs.
-  while (!stop) {
+  assert(conn->content_len == 0);
+  for (;;) {
     header_len = 0;
-    // body_len is the length of the entire queue in bytes
-    // len is the length of the current message
-    // data_len is the length of the current message's data payload
-    // header_len is the length of the current message's header
     if ((body_len = conn->data_len - conn->request_len) >= 2) {
       len = buf[1] & 127;
       mask_len = buf[1] & 128 ? 4 : 0;
@@ -3883,19 +3896,11 @@ static void read_websocket(struct mg_connection *conn) {
         data_len = ((((int) buf[2]) << 8) + buf[3]);
       } else if (body_len >= 10 + mask_len) {
         header_len = 10 + mask_len;
-        data_len = (((uint64_t) htonl(* (uint32_t *) &buf[2])) << 32) +
-          htonl(* (uint32_t *) &buf[6]);
+        data_len = (((uint64_t) ntohl(* (uint32_t *) &buf[2])) << 32) +
+          ntohl(* (uint32_t *) &buf[6]);
       }
     }
 
-    // Data layout is as follows:
-    //  conn->buf               buf
-    //     v                     v              frame1           | frame2
-    //     |---------------------|----------------|--------------|-------
-    //     |                     |<--header_len-->|<--data_len-->|
-    //     |<-conn->request_len->|<-----body_len----------->|
-    //     |<-------------------conn->data_len------------->|
-
     if (header_len > 0) {
       // Allocate space to hold websocket payload
       data = mem;
@@ -3905,38 +3910,50 @@ static void read_websocket(struct mg_connection *conn) {
         break;
       }
 
-      // Save mask and bits, otherwise it may be clobbered by memmove below
-      bits = buf[0];
-      memcpy(mask, buf + header_len - mask_len, mask_len);
+      // Copy the mask before we shift the queue and destroy it
+      if (mask_len > 0) {
+        *(uint32_t*)mask = *(uint32_t*)(buf + header_len - mask_len);
+      } else {
+        *(uint32_t*)mask = 0;
+      }
 
-      // Read frame payload into the allocated buffer.
+      // Read frame payload from the first message in the queue into data and
+      // advance the queue by moving the memory in place.
       assert(body_len >= header_len);
       if (data_len + header_len > body_len) {
+        // Overflow case
         len = body_len - header_len;
         memcpy(data, buf + header_len, len);
         // TODO: handle pull error
-        pull_all(NULL, conn, data + len, data_len - len);
-        conn->data_len = conn->request_len;
+        pull(NULL, conn, data + len, data_len - len);
+        conn->data_len = 0;
       } else {
+        // Length of the message being read at the front of the queue
         len = data_len + header_len;
+
+        // Copy the data payload into the data pointer for the callback
         memcpy(data, buf + header_len, data_len);
+
+        // Move the queue forward len bytes
         memmove(buf, buf + len, body_len - len);
+
+        // Mark the queue as advanced
         conn->data_len -= len;
       }
 
       // Apply mask if necessary
       if (mask_len > 0) {
-        for (i = 0; i < data_len; i++) {
-          data[i] ^= mask[i % 4];
+        for (i = 0; i < data_len; ++i) {
+          data[i] ^= mask[i & 3];
         }
       }
 
       // Exit the loop if callback signalled to exit,
       // or "connection close" opcode received.
-      if ((bits & WEBSOCKET_OPCODE_CONNECTION_CLOSE) ||
-          (conn->ctx->callbacks.websocket_data != NULL &&
-           !conn->ctx->callbacks.websocket_data(conn, bits, data, data_len))) {
-        stop = 1;
+      if ((conn->ctx->callbacks.websocket_data != NULL &&
+          !conn->ctx->callbacks.websocket_data(conn, buf[0], data, data_len)) ||
+          (buf[0] & 0xf) == 8) {  // Opcode == 8, connection close
+        break;
       }
 
       if (data != mem) {
@@ -3944,9 +3961,10 @@ static void read_websocket(struct mg_connection *conn) {
       }
       // Not breaking the loop, process next websocket frame.
     } else {
-      // Buffering websocket request
+      // Read from the socket into the next available location in the message queue.
       if ((n = pull(NULL, conn, conn->buf + conn->data_len,
                     conn->buf_size - conn->data_len)) <= 0) {
+        // Error, no bytes read
         break;
       }
       conn->data_len += n;
@@ -3954,44 +3972,40 @@ static void read_websocket(struct mg_connection *conn) {
   }
 }
 
-int mg_websocket_write(struct mg_connection* conn, int opcode,
-                       const char *data, size_t data_len) {
-    unsigned char *copy;
-    size_t copy_len = 0;
-    int retval = -1;
+int mg_websocket_write(struct mg_connection* conn, int opcode, const char* data, size_t dataLen) {
+    unsigned char header[10];
+    size_t headerLen = 1;
 
-    if ((copy = (unsigned char *) malloc(data_len + 10)) == NULL) {
-      return -1;
-    }
+    int retval = -1;
 
-    copy[0] = 0x80 + (opcode & 0x0f);
+    header[0] = 0x80 + (opcode & 0xF);
 
     // Frame format: http://tools.ietf.org/html/rfc6455#section-5.2
-    if (data_len < 126) {
-      // Inline 7-bit length field
-      copy[1] = data_len;
-      memcpy(copy + 2, data, data_len);
-      copy_len = 2 + data_len;
-    } else if (data_len <= 0xFFFF) {
-      // 16-bit length field
-      copy[1] = 126;
-      * (uint16_t *) (copy + 2) = htons(data_len);
-      memcpy(copy + 4, data, data_len);
-      copy_len = 4 + data_len;
+    if (dataLen < 126) {
+        // inline 7-bit length field
+        header[1] = dataLen;
+        headerLen = 2;
+    } else if (dataLen <= 0xFFFF) {
+        // 16-bit length field
+        header[1] = 126;
+        *(uint16_t*)(header + 2) = htons(dataLen);
+        headerLen = 4;
     } else {
-      // 64-bit length field
-      copy[1] = 127;
-      * (uint32_t *) (copy + 2) = htonl((uint64_t) data_len >> 32);
-      * (uint32_t *) (copy + 6) = htonl(data_len & 0xffffffff);
-      memcpy(copy + 10, data, data_len);
-      copy_len = 10 + data_len;
+        // 64-bit length field
+        header[1] = 127;
+        *(uint32_t*)(header + 2) = htonl((uint64_t)dataLen >> 32);
+        *(uint32_t*)(header + 6) = htonl(dataLen & 0xFFFFFFFF);
+        headerLen = 10;
     }
 
-    // Not thread safe
-    if (copy_len > 0) {
-      retval = mg_write(conn, copy, copy_len);
-    }
-    free(copy);
+    // Note that POSIX/Winsock's send() is threadsafe
+    // http://stackoverflow.com/questions/1981372/are-parallel-calls-to-send-recv-on-the-same-socket-valid
+    // but mongoose's mg_printf/mg_write is not (because of the loop in push(), although that is only
+    // a problem if the packet is large or outgoing buffer is full).
+    (void) mg_lock(conn);
+    retval = mg_write(conn, header, headerLen);
+    retval = mg_write(conn, data, dataLen);
+    mg_unlock(conn);
 
     return retval;
 }
@@ -4738,6 +4752,7 @@ static void close_socket_gracefully(struct mg_connection *conn) {
 }
 
 static void close_connection(struct mg_connection *conn) {
+  mg_lock(conn);
   conn->must_close = 1;
 
 #ifndef NO_SSL
@@ -4752,6 +4767,8 @@ static void close_connection(struct mg_connection *conn) {
     close_socket_gracefully(conn);
     conn->client.sock = INVALID_SOCKET;
   }
+
+  mg_unlock(conn);
 }
 
 void mg_close_connection(struct mg_connection *conn) {
@@ -4761,6 +4778,7 @@ void mg_close_connection(struct mg_connection *conn) {
   }
 #endif
   close_connection(conn);
+  (void) pthread_mutex_destroy(&conn->mutex);
   free(conn);
 }
 
@@ -4791,6 +4809,7 @@ struct mg_connection *mg_connect(const char *host, int port, int use_ssl,
     conn->client.sock = sock;
     getsockname(sock, &conn->client.rsa.sa, &len);
     conn->client.is_ssl = use_ssl;
+    (void) pthread_mutex_init(&conn->mutex, NULL);
 #ifndef NO_SSL
     if (use_ssl) {
       // SSL_CTX_set_verify call is needed to switch off server certificate
@@ -4962,6 +4981,9 @@ static void *worker_thread(void *thread_func_param) {
     conn->buf = (char *) (conn + 1);
     conn->ctx = ctx;
     conn->request_info.user_data = ctx->user_data;
+    // Allocate a mutex for this connection to allow communication both
+    // within the request handler and from elsewhere in the application
+    (void) pthread_mutex_init(&conn->mutex, NULL);
 
     // Call consume_socket() even when ctx->stop_flag > 0, to let it signal
     // sq_empty condvar to wake up the master waiting in produce_socket()

+ 16 - 2
civetweb.h

@@ -19,7 +19,7 @@
 // THE SOFTWARE.
 
 #ifndef CIVETWEB_HEADER_INCLUDED
-#define  CIVETWEB_HEADER_INCLUDED
+#define CIVETWEB_HEADER_INCLUDED
 
 #include <stdio.h>
 #include <stddef.h>
@@ -95,6 +95,11 @@ struct mg_callbacks {
   int  (*websocket_data)(struct mg_connection *, int bits,
                          char *data, size_t data_len);
 
+  // Called when civetweb is closing a connection.  The per-context mutex is locked when this
+  // is invoked.  This is primarily useful for noting when a websocket is closing and removing it
+  // from any application-maintained list of clients.
+  void (*connection_close)(struct mg_connection *);
+
   // Called when civetweb tries to open a file. Used to intercept file open
   // calls, and serve file data from memory instead.
   // Parameters:
@@ -210,8 +215,11 @@ struct mg_request_info *mg_get_request_info(struct mg_connection *);
 int mg_write(struct mg_connection *, const void *buf, size_t len);
 
 
+// Send data to a websocket client wrapped in a websocket frame.  Uses mg_lock to ensure
+// that the transmission is not interrupted, i.e., when the application is proactively
+// communicating and responding to a request simultaneously.
+//
 // Send data to a websocket client wrapped in a websocket frame.
-// It is unsafe to read/write to this connection from another thread.
 // This function is available when civetweb is compiled with -DUSE_WEBSOCKET
 //
 // Return:
@@ -221,6 +229,12 @@ int mg_write(struct mg_connection *, const void *buf, size_t len);
 int mg_websocket_write(struct mg_connection* conn, int opcode,
                        const char *data, size_t data_len);
 
+// Blocks until unique access is obtained to this connection. Intended for use with websockets only.
+// Invoke this before mg_write or mg_printf when communicating with a websocket if your code has
+// server-initiated communication as well as communication in direct response to a message.
+void mg_lock(struct mg_connection* conn);
+void mg_unlock(struct mg_connection* conn);
+
 // Opcodes, from http://tools.ietf.org/html/rfc6455
 enum {
   WEBSOCKET_OPCODE_CONTINUATION = 0x0,