Explorar el Código

Lua websockets should use the cooperative Lua threading model with timeouts

bel hace 11 años
padre
commit
d4ee6a1dc6
Se han modificado 3 ficheros con 152 adiciones y 100 borrados
  1. 3 2
      src/civetweb.c
  2. 136 97
      src/mod_lua.inl
  3. 13 1
      test/websocket.lua

+ 3 - 2
src/civetweb.c

@@ -4785,8 +4785,9 @@ static void handle_websocket_request(struct mg_connection *conn, const char *pat
             conn->lua_websocket_state = new_lua_websocket(path, conn);
             if (conn->lua_websocket_state) {
                 send_websocket_handshake(conn);
-                lua_websocket_ready(conn);
-                read_websocket(conn);
+                if (lua_websocket_ready(conn)) {
+                    read_websocket(conn);
+                }
             }
         } else
 #endif

+ 136 - 97
src/mod_lua.inl

@@ -1,9 +1,10 @@
 #include <lua.h>
 #include <lauxlib.h>
+#include <setjmp.h>
 
 #ifdef _WIN32
 static void *mmap(void *addr, int64_t len, int prot, int flags, int fd,
-                  int offset)
+    int offset)
 {
     HANDLE fh = (HANDLE) _get_osfhandle(fd);
     HANDLE mh = CreateFileMapping(fh, 0, PAGE_READONLY, 0, 0, 0);
@@ -29,7 +30,7 @@ static const char *LUASOCKET = "luasocket";
 /* Forward declarations */
 static void handle_request(struct mg_connection *);
 static int handle_lsp_request(struct mg_connection *, const char *,
-                              struct file *, struct lua_State *);
+struct file *, struct lua_State *);
 
 static void reg_string(struct lua_State *L, const char *name, const char *val)
 {
@@ -59,7 +60,7 @@ static void reg_boolean(struct lua_State *L, const char *name, int val)
 }
 
 static void reg_function(struct lua_State *L, const char *name,
-                         lua_CFunction func, struct mg_connection *conn)
+    lua_CFunction func, struct mg_connection *conn)
 {
     if (name!=NULL && func!=NULL) {
         lua_pushstring(L, name);
@@ -136,7 +137,7 @@ static int lsp_connect(lua_State *L)
 
     if (lua_isstring(L, -3) && lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
         sock = conn2(NULL, lua_tostring(L, -3), (int) lua_tonumber(L, -2),
-                     (int) lua_tonumber(L, -1), ebuf, sizeof(ebuf));
+            (int) lua_tonumber(L, -1), ebuf, sizeof(ebuf));
         if (sock == INVALID_SOCKET) {
             return luaL_error(L, ebuf);
         } else {
@@ -174,7 +175,7 @@ static void lsp_abort(lua_State *L)
 }
 
 static int lsp(struct mg_connection *conn, const char *path,
-               const char *p, int64_t len, lua_State *L)
+    const char *p, int64_t len, lua_State *L)
 {
     int i, j, pos = 0, lines = 1, lualines = 0;
     char chunkname[MG_BUF_LEN];
@@ -254,7 +255,7 @@ static int lsp_include(lua_State *L)
     struct file file = STRUCT_FILE_INITIALIZER;
     if (handle_lsp_request(conn, lua_tostring(L, -1), &file, L)) {
         /* handle_lsp_request returned an error code, meaning an error occured in
-           the included page and mg.onerror returned non-zero. Stop processing. */
+        the included page and mg.onerror returned non-zero. Stop processing. */
         lsp_abort(L);
     }
     return 0;
@@ -309,7 +310,7 @@ static int lwebsock_write(lua_State *L)
             mg_websocket_write(conn, WEBSOCKET_OPCODE_TEXT, str, size);
         }
     }
-#endif    
+#endif
     return 0;
 }
 
@@ -409,7 +410,7 @@ static void prepare_lua_environment(struct mg_connection *conn, lua_State *L, co
 
     /* Register default mg.onerror function */
     IGNORE_UNUSED_RESULT(luaL_dostring(L, "mg.onerror = function(e) mg.write('\\nLua error:\\n', "
-                                       "debug.traceback(e, 1)) end"));
+        "debug.traceback(e, 1)) end"));
 }
 
 static int lua_error_handler(lua_State *L)
@@ -433,7 +434,7 @@ static int lua_error_handler(lua_State *L)
 }
 
 void mg_exec_lua_script(struct mg_connection *conn, const char *path,
-                        const void **exports)
+    const void **exports)
 {
     int i;
     lua_State *L;
@@ -461,7 +462,7 @@ void mg_exec_lua_script(struct mg_connection *conn, const char *path,
 }
 
 static void lsp_send_err(struct mg_connection *conn, struct lua_State *L,
-                         const char *fmt, ...)
+    const char *fmt, ...)
 {
     char buf[MG_BUF_LEN];
     va_list ap;
@@ -480,7 +481,7 @@ static void lsp_send_err(struct mg_connection *conn, struct lua_State *L,
 }
 
 static int handle_lsp_request(struct mg_connection *conn, const char *path,
-                              struct file *filep, struct lua_State *ls)
+struct file *filep, struct lua_State *ls)
 {
     void *p = NULL;
     lua_State *L = NULL;
@@ -490,10 +491,10 @@ static int handle_lsp_request(struct mg_connection *conn, const char *path,
     if (!mg_stat(conn, path, filep) || !mg_fopen(conn, path, "r", filep)) {
         lsp_send_err(conn, ls, "File [%s] not found", path);
     } else if (filep->membuf == NULL &&
-               (p = mmap(NULL, (size_t) filep->size, PROT_READ, MAP_PRIVATE,
-                         fileno(filep->fp), 0)) == MAP_FAILED) {
-        lsp_send_err(conn, ls, "mmap(%s, %zu, %d): %s", path, (size_t) filep->size,
-                     fileno(filep->fp), strerror(errno));
+        (p = mmap(NULL, (size_t) filep->size, PROT_READ, MAP_PRIVATE,
+        fileno(filep->fp), 0)) == MAP_FAILED) {
+            lsp_send_err(conn, ls, "mmap(%s, %zu, %d): %s", path, (size_t) filep->size,
+                fileno(filep->fp), strerror(errno));
     } else if ((L = ls != NULL ? ls : luaL_newstate()) == NULL) {
         send_http_error(conn, 500, http_500_error, "%s", "luaL_newstate failed");
     } else {
@@ -505,7 +506,7 @@ static int handle_lsp_request(struct mg_connection *conn, const char *path,
             }
         }
         error = lsp(conn, path, filep->membuf == NULL ? p : filep->membuf,
-                    filep->size, L);
+            filep->size, L);
     }
 
     if (L != NULL && ls == NULL) lua_close(L);
@@ -515,113 +516,151 @@ static int handle_lsp_request(struct mg_connection *conn, const char *path,
     return error;
 }
 
+#ifdef USE_WEBSOCKET
+struct lua_websock_data {
+    lua_State *main;
+    lua_State *thread;
+};
+
+static void websock_cry(struct mg_connection *conn, int err, lua_State * L, const char * ws_operation, const char * lua_operation)
+{
+    switch (err) {
+    case LUA_OK:
+    case LUA_YIELD:
+        break;
+    case LUA_ERRRUN:
+        mg_cry(conn, "%s: %s failed: runtime error: %s", ws_operation, lua_operation, lua_tostring(L, -1));
+        break;
+    case LUA_ERRSYNTAX:
+        mg_cry(conn, "%s: %s failed: syntax error: %s", ws_operation, lua_operation, lua_tostring(L, -1));
+        break;
+    case LUA_ERRMEM:
+        mg_cry(conn, "%s: %s failed: out of memory", ws_operation, lua_operation);
+        break;
+    case LUA_ERRGCMM:
+        mg_cry(conn, "%s: %s failed: error during garbage collection", ws_operation, lua_operation);
+        break;
+    case LUA_ERRERR:
+        mg_cry(conn, "%s: %s failed: error in error handling: %s", ws_operation, lua_operation, lua_tostring(L, -1));
+        break;
+    default:
+        mg_cry(conn, "%s: %s failed: error %i", ws_operation, lua_operation, err);
+        break;
+    }
+}
+
 static void * new_lua_websocket(const char * script, struct mg_connection *conn)
 {
-    lua_State *L = NULL;
+    struct lua_websock_data *lws_data;
     int ok = 0;
-    int err;
+    int err, nargs;
 
     assert(conn->lua_websocket_state == NULL);
-    L = luaL_newstate();
-    if (L) {
-        prepare_lua_environment(conn, L, script, LUA_ENV_TYPE_LUA_WEBSOCKET);
-        if (conn->ctx->callbacks.init_lua != NULL) {
-            conn->ctx->callbacks.init_lua(conn, L);
-        }
-        err = luaL_loadfile(L, script);
-        switch (err) {
-            case 0:
-                {
-                    err = lua_pcall(L, 0, LUA_MULTRET, 0);
-                    switch (err) {
-                        case 0:
-                            /* return nothing or true to continue, false to stop */
-                            ok = !lua_isboolean(L, -1) || lua_toboolean(L, -1);
-                            break;
-                        case LUA_ERRMEM:
-                            mg_cry(conn, "%s: lua_pcall failed: out of memory", __func__);
-                            break;
-                        case LUA_ERRRUN:
-                            mg_cry(conn, "%s: lua_pcall failed: runtime error: %s", __func__, lua_tostring(L, -1));
-                            break;
-                        case LUA_ERRERR:
-                            mg_cry(conn, "%s: lua_pcall failed: double fault: %s", __func__, lua_tostring(L, -1));
-                            break;
-                        default:
-                            mg_cry(conn, "%s: lua_pcall failed: error %i", __func__, err);
-                            break;
-                    }
+    lws_data = (struct lua_websock_data *) malloc(sizeof(*lws_data));
+
+    if (lws_data) {
+        lws_data->main = luaL_newstate();
+        if (lws_data->main) {
+            prepare_lua_environment(conn, lws_data->main, script, LUA_ENV_TYPE_LUA_WEBSOCKET);
+            if (conn->ctx->callbacks.init_lua != NULL) {
+                conn->ctx->callbacks.init_lua(conn, lws_data->main);
+            }
+            lws_data->thread = lua_newthread(lws_data->main);
+            err = luaL_loadfile(lws_data->thread, script);
+            if (err==LUA_OK) {
+                /* Activate the Lua script. */
+                err = lua_resume(lws_data->thread, NULL, 0);
+                if (err!=LUA_YIELD) {
+                    websock_cry(conn, err, lws_data->thread, __func__, "lua_resume");
+                } else {
+                    nargs = lua_gettop(lws_data->thread);
+                    ok = (nargs==1) && lua_isboolean(lws_data->thread, 1) && lua_toboolean(lws_data->thread, 1);
                 }
-                break;
-            case LUA_ERRMEM:
-                mg_cry(conn, "%s: luaL_loadfile failed: out of memory", __func__);
-                break;
-            case LUA_ERRFILE:
-                mg_cry(conn, "%s: luaL_loadfile failed: file %s not found", __func__, script);
-                break;
-            case LUA_ERRSYNTAX:
-                mg_cry(conn, "%s: luaL_loadfile failed: syntax error: %s", __func__, lua_tostring(L, -1));
-                break;
-            default:
-                mg_cry(conn, "%s: luaL_loadfile failed: error %i", __func__, err);
-                break;
+            } else {
+                websock_cry(conn, err, lws_data->thread, __func__, "lua_loadfile");
+            }
+
+        } else {
+            mg_cry(conn, "%s: luaL_newstate failed", __func__);
         }
+
         if (!ok) {
-            lua_close(L);
-            L = NULL;
+            if (lws_data->main) lua_close(lws_data->main);
+            free(lws_data);
+            lws_data=0;
         }
-    } else {
-        mg_cry(conn, "%s: luaL_newstate failed", __func__);
     }
 
-    return L;
-}
-
-#ifdef USE_WEBSOCKET
-static void lua_websocket_ready(struct mg_connection *conn)
-{
-    lua_State *L = (lua_State*)(conn->lua_websocket_state);
-
-    assert(L != NULL);
-
-    lua_getglobal(L, "ready");
-    if (lua_pcall(L, 0, 0, 0) != 0) {
-        mg_cry(conn, "%s: error running function `ready': %s", lua_tostring(L, -1));
-    }
+    return lws_data;
 }
 
 static int lua_websocket_data(struct mg_connection *conn, int bits, char *data, size_t data_len)
 {
-    lua_State *L = (lua_State*)(conn->lua_websocket_state);
-    int ok = 0;
-
-    assert(L != NULL);
+    struct lua_websock_data *lws_data = (struct lua_websock_data *)(conn->lua_websocket_state);
+    int err, nargs, ok=0, retry;
+    lua_Number delay;
+
+    assert(lws_data != NULL);
+    assert(lws_data->main != NULL);
+    assert(lws_data->thread != NULL);
+
+    do {
+        retry=0;
+        lua_pushboolean(lws_data->thread, 1);
+        if (bits > 0) {
+            lua_pushinteger(lws_data->thread, bits);
+            if (data) {
+                lua_pushlstring(lws_data->thread, data, data_len);
+                err = lua_resume(lws_data->thread, NULL, 3);
+            } else {
+                err = lua_resume(lws_data->thread, NULL, 2);
+            }
+        } else {
+            err = lua_resume(lws_data->thread, NULL, 1);
+        }
 
-    lua_getglobal(L, "data");
-    lua_pushinteger(L, bits);
-    lua_pushlstring(L, data, data_len);
-    if (lua_pcall(L, 2, 1, 0) != 0) {
-        mg_cry(conn, "%s: error running function `data': %s", lua_tostring(L, -1));
-    } else {
-        ok = lua_isboolean(L, -1) && lua_toboolean(L, -1);
-    }
+        if (err!=LUA_YIELD) {
+            websock_cry(conn, err, lws_data->thread, __func__, "lua_resume");
+        } else {
+            nargs = lua_gettop(lws_data->thread);
+            ok = (nargs>=1) && lua_isboolean(lws_data->thread, 1) && lua_toboolean(lws_data->thread, 1);
+            delay = (nargs>=2) && lua_isnumber(lws_data->thread, 2) ? lua_tonumber(lws_data->thread, 2) : -1.0;
+            if (ok && delay>0) {
+                fd_set rfds;
+                struct timeval tv;
+
+                FD_ZERO(&rfds);
+                FD_SET(conn->client.sock, &rfds);
+
+                tv.tv_sec = (unsigned long)delay;
+                tv.tv_usec = (unsigned long)(((double)delay - (double)((unsigned long)delay))*1000000.0);
+                retry = (0==select(1, &rfds, NULL, NULL, &tv));
+            }
+        }
+    } while (retry);
 
     return ok;
 }
 
+static int lua_websocket_ready(struct mg_connection *conn)
+{
+    return lua_websocket_data(conn, -1, NULL, 0);
+}
 
 static void lua_websocket_close(struct mg_connection *conn)
 {
-    lua_State *L = (lua_State*)(conn->lua_websocket_state);
+    struct lua_websock_data *lws_data = (struct lua_websock_data *)(conn->lua_websocket_state);
+    int err;
 
-    assert(L != NULL);
+    assert(lws_data != NULL);
+    assert(lws_data->main != NULL);
+    assert(lws_data->thread != NULL);
 
-    lua_getglobal(L, "close");
-    if (lua_pcall(L, 0, 0, 0) != 0) {
-        mg_cry(conn, "%s: error running function `close': %s", lua_tostring(L, -1));
-    }
+    lua_pushboolean(lws_data->thread, 0);
+    err = lua_resume(lws_data->thread, NULL, 1);
 
-    lua_close(L);
+    lua_close(lws_data->main);
+    free(lws_data);
     conn->lua_websocket_state = NULL;
 }
 #endif

+ 13 - 1
test/websocket.lua

@@ -46,5 +46,17 @@ function close()
 end
 
 
+-- Websocket with coroutines
 logDB("WEBSOCKET PREPARE")
-return true; -- could return false to reject the connection before the websocket handshake
+
+coroutine.yield(true); -- first yield returns (true) or (false) to accept or reject the connection
+ready()
+repeat
+    local cont, bits, content = coroutine.yield(true, 1.0)
+    if bits and content then
+        data(bits, content)
+    end
+until not cont;
+
+mg.write("text", "end")
+close()