소스 검색

Allow multiple websocket clients with different callbacks and an additional user supplied argument

bel 10 년 전
부모
커밋
eac72465ad
4개의 변경된 파일153개의 추가작업 그리고 59개의 파일을 삭제
  1. 99 42
      examples/websocket_client/websocket_client.c
  2. 4 3
      include/civetweb.h
  3. 45 8
      src/civetweb.c
  4. 5 6
      test/unit_test.c

+ 99 - 42
examples/websocket_client/websocket_client.c

@@ -97,22 +97,24 @@ struct mg_context * start_websocket_server()
 /*************************************************************************************/
 /*************************************************************************************/
 /* WEBSOCKET CLIENT                                                                  */
 /* WEBSOCKET CLIENT                                                                  */
 /*************************************************************************************/
 /*************************************************************************************/
-struct {
+struct tclient_data {
     void * data;
     void * data;
     size_t len;
     size_t len;
-} client_data;
+};
 
 
 static int websocket_client_data_handler(struct mg_connection *conn, int flags, char *data, size_t data_len)
 static int websocket_client_data_handler(struct mg_connection *conn, int flags, char *data, size_t data_len)
 {
 {
+    struct mg_context *ctx = mg_get_context(conn);
+    struct tclient_data *pclient_data = (struct tclient_data *) mg_get_user_data(ctx);
+
     printf("From server: ");
     printf("From server: ");
     fwrite(data, 1, data_len, stdout);
     fwrite(data, 1, data_len, stdout);
     printf("\n");
     printf("\n");
 
 
-    /* TODO: extra arg (instead of global client_data) */
-    client_data.data = malloc(data_len);
-    assert(client_data.data != NULL);
-    memcpy(client_data.data, data, data_len);
-    client_data.len = data_len;
+    pclient_data->data = malloc(data_len);
+    assert(pclient_data->data != NULL);
+    memcpy(pclient_data->data, data, data_len);
+    pclient_data->len = data_len;
 
 
     return 1;
     return 1;
 }
 }
@@ -120,68 +122,123 @@ static int websocket_client_data_handler(struct mg_connection *conn, int flags,
 
 
 int main(int argc, char *argv[])
 int main(int argc, char *argv[])
 {
 {
-    struct mg_context *ctx;
-    struct mg_connection* newconn;
-    char ebuf[100];
+    struct mg_context *ctx = NULL;
+    struct tclient_data client1_data = {NULL, 0};
+    struct tclient_data client2_data = {NULL, 0};
+    struct mg_connection* newconn1 = NULL;
+    struct mg_connection* newconn2 = NULL;
+    char ebuf[100] = {0};
 
 
     assert(websocket_welcome_msg_len == strlen(websocket_welcome_msg));
     assert(websocket_welcome_msg_len == strlen(websocket_welcome_msg));
 
 
     /* First set up a websocket server */
     /* First set up a websocket server */
     ctx = start_websocket_server();
     ctx = start_websocket_server();
     assert(ctx != NULL);
     assert(ctx != NULL);
+    printf("Server init\n\n");
 
 
     /* Then connect a client */
     /* Then connect a client */
-    newconn = mg_websocket_client_connect("localhost", atoi(PORT), 0,
-        ebuf, sizeof(ebuf),
-        "/websocket", NULL, websocket_client_data_handler /* TODO: extra arg (instead of global client_data) */);
+    newconn1 = mg_websocket_client_connect("localhost", atoi(PORT), 0, ebuf, sizeof(ebuf),
+        "/websocket", NULL, websocket_client_data_handler, &client1_data);
 
 
-    if (newconn == NULL)
+    if (newconn1 == NULL)
     {
     {
         printf("Error: %s", ebuf);
         printf("Error: %s", ebuf);
         return 1;
         return 1;
     }
     }
 
 
     sleep(1); /* Should get the websocket welcome message */
     sleep(1); /* Should get the websocket welcome message */
-    assert(client_data.data != NULL);
-    assert(client_data.len == websocket_welcome_msg_len);
-    assert(!memcmp(client_data.data, websocket_welcome_msg, websocket_welcome_msg_len));
-    free(client_data.data);
-    client_data.data = NULL;
-    client_data.len = 0;
+    assert(client2_data.data == NULL);
+    assert(client2_data.len == 0);
+    assert(client1_data.data != NULL);
+    assert(client1_data.len == websocket_welcome_msg_len);
+    assert(!memcmp(client1_data.data, websocket_welcome_msg, websocket_welcome_msg_len));
+    free(client1_data.data);
+    client1_data.data = NULL;
+    client1_data.len = 0;
 
 
-    mg_websocket_write(newconn, WEBSOCKET_OPCODE_TEXT, "data1", 5);
+    mg_websocket_write(newconn1, WEBSOCKET_OPCODE_TEXT, "data1", 5);
 
 
     sleep(1); /* Should get the acknowledge message */
     sleep(1); /* Should get the acknowledge message */
-    assert(client_data.data != NULL);
-    assert(client_data.len == websocket_acknowledge_msg_len);
-    assert(!memcmp(client_data.data, websocket_acknowledge_msg, websocket_acknowledge_msg_len));
-    free(client_data.data);
-    client_data.data = NULL;
-    client_data.len = 0;
+    assert(client2_data.data == NULL);
+    assert(client2_data.len == 0);
+    assert(client1_data.data != NULL);
+    assert(client1_data.len == websocket_acknowledge_msg_len);
+    assert(!memcmp(client1_data.data, websocket_acknowledge_msg, websocket_acknowledge_msg_len));
+    free(client1_data.data);
+    client1_data.data = NULL;
+    client1_data.len = 0;
+
+    /* Then connect a client */
+    newconn2 = mg_websocket_client_connect("localhost", atoi(PORT), 0, ebuf, sizeof(ebuf),
+        "/websocket", NULL, websocket_client_data_handler, &client2_data);
+
+    if (newconn2 == NULL)
+    {
+        printf("Error: %s", ebuf);
+        return 1;
+    }
+
+    sleep(1); /* Client 2 should get the websocket welcome message */
+    assert(client1_data.data == NULL);
+    assert(client1_data.len == 0);
+    assert(client2_data.data != NULL);
+    assert(client2_data.len == websocket_welcome_msg_len);
+    assert(!memcmp(client2_data.data, websocket_welcome_msg, websocket_welcome_msg_len));
+    free(client2_data.data);
+    client2_data.data = NULL;
+    client2_data.len = 0;
 
 
-    mg_websocket_write(newconn, WEBSOCKET_OPCODE_TEXT, "data2", 5);
+    mg_websocket_write(newconn1, WEBSOCKET_OPCODE_TEXT, "data2", 5);
 
 
     sleep(1); /* Should get the acknowledge message */
     sleep(1); /* Should get the acknowledge message */
-    assert(client_data.data != NULL);
-    assert(client_data.len == websocket_acknowledge_msg_len);
-    assert(!memcmp(client_data.data, websocket_acknowledge_msg, websocket_acknowledge_msg_len));
-    free(client_data.data);
-    client_data.data = NULL;
-    client_data.len = 0;
+    assert(client2_data.data == NULL);
+    assert(client2_data.len == 0);
+    assert(client1_data.data != NULL);
+    assert(client1_data.len == websocket_acknowledge_msg_len);
+    assert(!memcmp(client1_data.data, websocket_acknowledge_msg, websocket_acknowledge_msg_len));
+    free(client1_data.data);
+    client1_data.data = NULL;
+    client1_data.len = 0;
+
+    mg_websocket_write(newconn1, WEBSOCKET_OPCODE_TEXT, "bye", 3);
+
+    sleep(1); /* Should get the goodbye message */
+    assert(client2_data.data == NULL);
+    assert(client2_data.len == 0);
+    assert(client1_data.data != NULL);
+    assert(client1_data.len == websocket_goodbye_msg_len);
+    assert(!memcmp(client1_data.data, websocket_goodbye_msg, websocket_goodbye_msg_len));
+    free(client1_data.data);
+    client1_data.data = NULL;
+    client1_data.len = 0;
+
+    mg_close_connection(newconn1);
+
+    sleep(1); /* Won't get any message */
+    assert(client1_data.data == NULL);
+    assert(client1_data.len == 0);
+    assert(client2_data.data == NULL);
+    assert(client2_data.len == 0);
 
 
-    mg_websocket_write(newconn, WEBSOCKET_OPCODE_TEXT, "bye", 3);
+    mg_websocket_write(newconn2, WEBSOCKET_OPCODE_TEXT, "bye", 3);
 
 
     sleep(1); /* Should get the goodbye message */
     sleep(1); /* Should get the goodbye message */
-    assert(client_data.data != NULL);
-    assert(client_data.len == websocket_goodbye_msg_len);
-    assert(!memcmp(client_data.data, websocket_goodbye_msg, websocket_goodbye_msg_len));
-    free(client_data.data);
-    client_data.data = NULL;
-    client_data.len = 0;
+    assert(client1_data.data == NULL);
+    assert(client1_data.len == 0);
+    assert(client2_data.data != NULL);
+    assert(client2_data.len == websocket_goodbye_msg_len);
+    assert(!memcmp(client2_data.data, websocket_goodbye_msg, websocket_goodbye_msg_len));
+    free(client2_data.data);
+    client2_data.data = NULL;
+    client2_data.len = 0;
 
 
-    mg_close_connection(newconn);
+    mg_close_connection(newconn2);
 
 
     sleep(1); /* Won't get any message */
     sleep(1); /* Won't get any message */
+    assert(client1_data.data == NULL);
+    assert(client1_data.len == 0);
+    assert(client2_data.data == NULL);
+    assert(client2_data.len == 0);
 
 
     mg_stop(ctx);
     mg_stop(ctx);
     printf("Server shutdown\n");
     printf("Server shutdown\n");

+ 4 - 3
include/civetweb.h

@@ -573,7 +573,7 @@ CIVETWEB_API void mg_cry(struct mg_connection *conn,
 /* utility method to compare two buffers, case incensitive. */
 /* utility method to compare two buffers, case incensitive. */
 CIVETWEB_API int mg_strncasecmp(const char *s1, const char *s2, size_t len);
 CIVETWEB_API int mg_strncasecmp(const char *s1, const char *s2, size_t len);
 
 
-/* Connect to a websocket as a client 
+/* Connect to a websocket as a client
    Parameters:
    Parameters:
      host: host to connect to, i.e. "echo.websocket.org" or "192.168.1.1" or "localhost"
      host: host to connect to, i.e. "echo.websocket.org" or "192.168.1.1" or "localhost"
      port: server port
      port: server port
@@ -586,12 +586,13 @@ CIVETWEB_API int mg_strncasecmp(const char *s1, const char *s2, size_t len);
    Return:
    Return:
      On success, valid mg_connection object.
      On success, valid mg_connection object.
      On error, NULL. */
      On error, NULL. */
-     
+
 typedef int  (*websocket_data_func)(struct mg_connection *, int bits,
 typedef int  (*websocket_data_func)(struct mg_connection *, int bits,
                            char *data, size_t data_len);
                            char *data, size_t data_len);
 CIVETWEB_API struct mg_connection *mg_websocket_client_connect(const char *host, int port, int use_ssl,
 CIVETWEB_API struct mg_connection *mg_websocket_client_connect(const char *host, int port, int use_ssl,
                                                char *error_buffer, size_t error_buffer_size,
                                                char *error_buffer, size_t error_buffer_size,
-                                               const char *path, const char *origin, websocket_data_func data_func);
+                                               const char *path, const char *origin,
+                                               websocket_data_func data_func, void * user_data);
 
 
 #ifdef __cplusplus
 #ifdef __cplusplus
 }
 }

+ 45 - 8
src/civetweb.c

@@ -779,6 +779,7 @@ struct mg_context {
     char *config[NUM_OPTIONS];      /* Civetweb configuration parameters */
     char *config[NUM_OPTIONS];      /* Civetweb configuration parameters */
     struct mg_callbacks callbacks;  /* User-defined callback function */
     struct mg_callbacks callbacks;  /* User-defined callback function */
     void *user_data;                /* User-defined data */
     void *user_data;                /* User-defined data */
+    int context_type;               /* 1 = server context, 2 = client context */
 
 
     struct socket *listening_sockets;
     struct socket *listening_sockets;
     in_port_t *listening_ports;
     in_port_t *listening_ports;
@@ -5146,7 +5147,7 @@ static void read_websocket(struct mg_connection *conn)
        callback, and waiting repeatedly until an error occurs. */
        callback, and waiting repeatedly until an error occurs. */
     /* TODO: Investigate if this next line is needed
     /* TODO: Investigate if this next line is needed
     assert(conn->content_len == 0); */
     assert(conn->content_len == 0); */
-    for (;;) {
+    while (!conn->ctx->stop_flag) {
         header_len = 0;
         header_len = 0;
         assert(conn->data_len >= conn->request_len);
         assert(conn->data_len >= conn->request_len);
         if ((body_len = conn->data_len - conn->request_len) >= 2) {
         if ((body_len = conn->data_len - conn->request_len) >= 2) {
@@ -6316,12 +6317,29 @@ static void close_connection(struct mg_connection *conn)
 
 
 void mg_close_connection(struct mg_connection *conn)
 void mg_close_connection(struct mg_connection *conn)
 {
 {
+    struct mg_context * client_ctx = NULL;
+    int i;
+
+    if (conn->ctx->context_type == 2) {
+        client_ctx = conn->ctx;
+        /* client context: loops must end */
+        conn->ctx->stop_flag = 1;
+    }
+
 #ifndef NO_SSL
 #ifndef NO_SSL
     if (conn->client_ssl_ctx != NULL) {
     if (conn->client_ssl_ctx != NULL) {
         SSL_CTX_free((SSL_CTX *) conn->client_ssl_ctx);
         SSL_CTX_free((SSL_CTX *) conn->client_ssl_ctx);
     }
     }
 #endif
 #endif
     close_connection(conn);
     close_connection(conn);
+    if (client_ctx != NULL) {
+        /* join worker thread and free context */
+        for (i = 0; i < client_ctx->workerthreadcount; i++) {
+            mg_join_thread(client_ctx->workerthreadids[i]);
+        }
+        mg_free(client_ctx->workerthreadids);
+        mg_free(client_ctx);
+    }
     (void) pthread_mutex_destroy(&conn->mutex);
     (void) pthread_mutex_destroy(&conn->mutex);
     mg_free(conn);
     mg_free(conn);
 }
 }
@@ -6446,26 +6464,35 @@ struct mg_connection *mg_download(const char *host, int port, int use_ssl,
 }
 }
 
 
 #if defined(USE_WEBSOCKET)
 #if defined(USE_WEBSOCKET)
+#ifdef _WIN32
+static unsigned __stdcall websocket_client_thread(void *data)
+#else
 static void* websocket_client_thread(void *data)
 static void* websocket_client_thread(void *data)
+#endif
 {
 {
     struct mg_connection* conn = (struct mg_connection*)data;
     struct mg_connection* conn = (struct mg_connection*)data;
     read_websocket(conn);
     read_websocket(conn);
 
 
     DEBUG_TRACE("Websocket client thread exited\n");
     DEBUG_TRACE("Websocket client thread exited\n");
 
 
+#ifdef _WIN32
+    return 0;
+#else
     return NULL;
     return NULL;
+#endif
 }
 }
 #endif
 #endif
 
 
 struct mg_connection *mg_websocket_client_connect(const char *host, int port, int use_ssl,
 struct mg_connection *mg_websocket_client_connect(const char *host, int port, int use_ssl,
                                                char *error_buffer, size_t error_buffer_size,
                                                char *error_buffer, size_t error_buffer_size,
-                                               const char *path, const char *origin, websocket_data_func data_func)
+                                               const char *path, const char *origin, websocket_data_func data_func, void * user_data)
 {
 {
     struct mg_connection* conn = NULL;
     struct mg_connection* conn = NULL;
+    struct mg_context * newctx = NULL;
 
 
 #if defined(USE_WEBSOCKET)
 #if defined(USE_WEBSOCKET)
     static const char *magic = "x3JJHMbDL1EzLkh9GBhXDw==";
     static const char *magic = "x3JJHMbDL1EzLkh9GBhXDw==";
-    static const char *handshake_req;    
+    static const char *handshake_req;
 
 
     if(origin != NULL)
     if(origin != NULL)
     {
     {
@@ -6502,15 +6529,24 @@ struct mg_connection *mg_websocket_client_connect(const char *host, int port, in
         return conn;
         return conn;
     }
     }
 
 
-    /* For client connections, mg_context is fake. Set the callback for websocket 
-     data manually here so that read_websocket will automatically call it */
-    conn->ctx->callbacks.websocket_data = data_func;
+    /* For client connections, mg_context is fake. Since we need to set a callback
+       function, we need to create a copy and modify it. */
+    newctx = (struct mg_context *) mg_malloc(sizeof(struct mg_context));
+    memcpy(newctx, conn->ctx, sizeof(struct mg_context));
+    newctx->callbacks.websocket_data = data_func; /* read_websocket will automatically call it */
+    newctx->user_data = user_data;
+    newctx->context_type = 2; /* client context type */
+    newctx->workerthreadcount = 1; /* one worker thread will be created */
+    newctx->workerthreadids = (pthread_t*) mg_calloc(newctx->workerthreadcount, sizeof(pthread_t));
+    conn->ctx = newctx;
 
 
     /* Start a thread to read the websocket client connection
     /* Start a thread to read the websocket client connection
-    This thread will automatically stop when mg_disconnect is 
+    This thread will automatically stop when mg_disconnect is
     called on the client connection */
     called on the client connection */
-    if(mg_start_thread(websocket_client_thread, (void*)conn) != 0)
+    if (mg_start_thread_with_id(websocket_client_thread, (void*)conn, newctx->workerthreadids) != 0)
     {
     {
+        mg_free((void*)newctx->workerthreadids);
+        mg_free((void*)newctx);
         mg_free((void*)conn);
         mg_free((void*)conn);
         conn = NULL;
         conn = NULL;
         DEBUG_TRACE("Websocket client connect thread could not be started\r\n");
         DEBUG_TRACE("Websocket client connect thread could not be started\r\n");
@@ -7134,6 +7170,7 @@ struct mg_context *mg_start(const struct mg_callbacks *callbacks,
         ctx->callbacks.init_context(ctx);
         ctx->callbacks.init_context(ctx);
     }
     }
     ctx->callbacks.exit_context = exit_callback;
     ctx->callbacks.exit_context = exit_callback;
+    ctx->context_type = 1; /* server context */
 
 
     /* Start master (listening) thread */
     /* Start master (listening) thread */
     mg_start_thread_with_id(master_thread, ctx, &ctx->masterthreadid);
     mg_start_thread_with_id(master_thread, ctx, &ctx->masterthreadid);

+ 5 - 6
test/unit_test.c

@@ -543,16 +543,15 @@ static void test_mg_websocket_client_connect(int use_ssl) {
     /* Invalid port test */
     /* Invalid port test */
     conn = mg_websocket_client_connect("localhost", 0, use_ssl,
     conn = mg_websocket_client_connect("localhost", 0, use_ssl,
                              ebuf, sizeof(ebuf),
                              ebuf, sizeof(ebuf),
-                             "/", "http://localhost",websocket_data_handler);
+                             "/", "http://localhost", websocket_data_handler, NULL);
     ASSERT(conn == NULL);
     ASSERT(conn == NULL);
 
 
     /* Should succeed, the default civetweb sever should complete the handshake */
     /* Should succeed, the default civetweb sever should complete the handshake */
     conn = mg_websocket_client_connect("localhost", port, use_ssl,
     conn = mg_websocket_client_connect("localhost", port, use_ssl,
                              ebuf, sizeof(ebuf),
                              ebuf, sizeof(ebuf),
-                             "/", "http://localhost",websocket_data_handler);
+                             "/", "http://localhost", websocket_data_handler, NULL);
     ASSERT(conn != NULL);
     ASSERT(conn != NULL);
 
 
-
     /* Try an external server test */
     /* Try an external server test */
     port = 80;
     port = 80;
     if (use_ssl) { port = 443; }
     if (use_ssl) { port = 443; }
@@ -560,19 +559,19 @@ static void test_mg_websocket_client_connect(int use_ssl) {
     /* Not a websocket server path */
     /* Not a websocket server path */
     conn = mg_websocket_client_connect("websocket.org", port, use_ssl,
     conn = mg_websocket_client_connect("websocket.org", port, use_ssl,
                              ebuf, sizeof(ebuf),
                              ebuf, sizeof(ebuf),
-                             "/", "http://websocket.org",websocket_data_handler);
+                             "/", "http://websocket.org", websocket_data_handler, NULL);
     ASSERT(conn == NULL);
     ASSERT(conn == NULL);
 
 
     /* Invalid port test */
     /* Invalid port test */
     conn = mg_websocket_client_connect("echo.websocket.org", 0, use_ssl,
     conn = mg_websocket_client_connect("echo.websocket.org", 0, use_ssl,
                              ebuf, sizeof(ebuf),
                              ebuf, sizeof(ebuf),
-                             "/", "http://websocket.org",websocket_data_handler);
+                             "/", "http://websocket.org", websocket_data_handler, NULL);
     ASSERT(conn == NULL);
     ASSERT(conn == NULL);
 
 
     /* Should succeed, echo.websocket.org echos the data back */
     /* Should succeed, echo.websocket.org echos the data back */
     conn = mg_websocket_client_connect("echo.websocket.org", port, use_ssl,
     conn = mg_websocket_client_connect("echo.websocket.org", port, use_ssl,
                              ebuf, sizeof(ebuf),
                              ebuf, sizeof(ebuf),
-                             "/", "http://websocket.org",websocket_data_handler);
+                             "/", "http://websocket.org", websocket_data_handler, NULL);
     ASSERT(conn != NULL);
     ASSERT(conn != NULL);
 
 
     mg_stop(ctx);
     mg_stop(ctx);