
Memory leak: eval() creates new LuaRuntime on every call

Open • chrisguidry opened this issue 4 months ago • 1 comment

Describe the bug

Each call to EVAL or EVALSHA creates a new lupa.LuaRuntime() instance (line 237 of scripting_mixin.py), and these runtimes don't get garbage collected properly, causing unbounded memory growth in long-running processes.

I discovered this while investigating memory growth in my task queue library docket, which uses fakeredis[lua] for its in-memory backend. See chrisguidry/docket#258 for the investigation details.

To Reproduce

Run a loop that executes Lua scripts repeatedly:

import asyncio
import tracemalloc
from fakeredis import FakeAsyncRedis

tracemalloc.start()

async def main():
    redis = FakeAsyncRedis()

    for i in range(1000):
        # Each eval creates a new LuaRuntime that doesn't get GC'd
        await redis.eval(b"return 1", 0)

        if i % 100 == 0:
            current, peak = tracemalloc.get_traced_memory()
            print(f"Iteration {i}: current {current / 1024:.1f} KB, peak {peak / 1024:.1f} KB")

asyncio.run(main())

Memory grows continuously with each iteration.
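To pin down where the growth comes from, tracemalloc can also report the hottest allocation sites. Here is a standard-library-only sketch; the `retained` list is a stand-in for the leaked runtimes, and with fakeredis[lua] installed you would run the eval loop above instead:

```python
import tracemalloc

tracemalloc.start()

# Stand-in for the leak: objects allocated in a loop and kept alive.
# With fakeredis[lua] installed, run the eval loop above instead.
retained = [bytearray(1024) for _ in range(100)]

snapshot = tracemalloc.take_snapshot()
# Top allocation sites by cumulative size; the leaking line should
# dominate the list.
for stat in snapshot.statistics("lineno")[:3]:
    print(stat)
```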

Expected behavior

Memory should remain stable after an initial warmup. The LuaRuntime should be created once and reused across eval() calls.
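The reuse described above amounts to a cache-on-first-use pattern. A self-contained sketch, with a dummy class standing in for lupa.LuaRuntime, shows the difference in how many runtimes end up being created:

```python
import itertools

# Hypothetical stand-in for lupa.LuaRuntime: counts instantiations so
# we can observe how many runtimes a server ends up creating.
class FakeLuaRuntime:
    created = itertools.count()

    def __init__(self):
        self.id = next(FakeLuaRuntime.created)

class Server:
    def eval_leaky(self):
        # Current behavior: a fresh runtime on every call.
        return FakeLuaRuntime()

    def eval_cached(self):
        # Proposed behavior: create once, then reuse.
        if not hasattr(self, "_lua_runtime"):
            self._lua_runtime = FakeLuaRuntime()
        return self._lua_runtime

server = Server()
leaky = {server.eval_leaky().id for _ in range(100)}
cached = {server.eval_cached().id for _ in range(100)}
print(len(leaky), len(cached))  # 100 distinct runtimes vs. 1
```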

Desktop

  • OS: Linux (Ubuntu 24.04)
  • Python version: 3.12
  • fakeredis version: 2.32.1
  • lupa version: 2.5

Proposed Fix

Cache the LuaRuntime, set_globals function, and expected_globals set on the FakeServer instance so they're reused across calls. Here's a minimal diff:

diff --git a/fakeredis/commands_mixins/scripting_mixin.py b/fakeredis/commands_mixins/scripting_mixin.py
--- a/fakeredis/commands_mixins/scripting_mixin.py
+++ b/fakeredis/commands_mixins/scripting_mixin.py
@@ -234,42 +234,60 @@ class ScriptingCommandsMixin:
             raise SimpleError(msgs.NEGATIVE_KEYS_MSG)
         sha1 = hashlib.sha1(script).hexdigest().encode()
         self._server.script_cache[sha1] = script
-        lua_runtime: LUA_MODULE.LuaRuntime = LUA_MODULE.LuaRuntime(encoding=None, unpack_returned_tuples=True)
-        modules_import_str = "\n".join([f"{module} = require('{module}')" for module in self.load_lua_modules])
-        set_globals = lua_runtime.eval(
-            f"""
-            function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels, cjson_encode, cjson_decode, cjson_null)
-                redis = {{}}
-                redis.call = redis_call
-                redis.pcall = redis_pcall
-                redis.log = redis_log
-                for level, pylevel in python.iterex(redis_log_levels.items()) do
-                    redis[level] = pylevel
-                end
-                redis.error_reply = function(msg) return {{err=msg}} end
-                redis.status_reply = function(msg) return {{ok=msg}} end
-
-                cjson = {{}}
-                cjson.encode = cjson_encode
-                cjson.decode = cjson_decode
-                cjson.null = cjson_null
-
-                KEYS = keys
-                ARGV = argv
-                {modules_import_str}
-            end
-            """
-        )
-        expected_globals: Set[Any] = set()
-        set_globals(
-            lua_runtime.table_from(keys_and_args[:numkeys]),
-            lua_runtime.table_from(keys_and_args[numkeys:]),
-            functools.partial(self._lua_redis_call, lua_runtime, expected_globals),
-            functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals),
-            functools.partial(_lua_redis_log, lua_runtime, expected_globals),
-            LUA_MODULE.as_attrgetter(REDIS_LOG_LEVELS),
-            functools.partial(_lua_cjson_encode, lua_runtime, expected_globals),
-            functools.partial(_lua_cjson_decode, lua_runtime, expected_globals),
-            _lua_cjson_null,
-        )
-        expected_globals.update(lua_runtime.globals().keys())
+
+        # Cache LuaRuntime on the server to avoid memory leak from creating
+        # new runtimes on every eval call
+        if not hasattr(self._server, "_lua_runtime"):
+            self._server._lua_runtime = LUA_MODULE.LuaRuntime(
+                encoding=None, unpack_returned_tuples=True
+            )
+            modules_import_str = "\n".join(
+                [f"{module} = require('{module}')" for module in self.load_lua_modules]
+            )
+            self._server._lua_set_globals = self._server._lua_runtime.eval(
+                f"""
+                function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels, cjson_encode, cjson_decode, cjson_null)
+                    redis = {{}}
+                    redis.call = redis_call
+                    redis.pcall = redis_pcall
+                    redis.log = redis_log
+                    for level, pylevel in python.iterex(redis_log_levels.items()) do
+                        redis[level] = pylevel
+                    end
+                    redis.error_reply = function(msg) return {{err=msg}} end
+                    redis.status_reply = function(msg) return {{ok=msg}} end
+
+                    cjson = {{}}
+                    cjson.encode = cjson_encode
+                    cjson.decode = cjson_decode
+                    cjson.null = cjson_null
+
+                    KEYS = keys
+                    ARGV = argv
+                    {modules_import_str}
+                end
+                """
+            )
+            # Capture expected globals once after first setup
+            self._server._lua_expected_globals: Set[Any] = set()
+            self._server._lua_set_globals(
+                self._server._lua_runtime.table_from([]),
+                self._server._lua_runtime.table_from([]),
+                lambda *args: None, lambda *args: None, lambda *args: None,
+                {}, lambda *args: None, lambda *args: None, None,
+            )
+            self._server._lua_expected_globals.update(
+                self._server._lua_runtime.globals().keys()
+            )
+
+        lua_runtime = self._server._lua_runtime
+        set_globals = self._server._lua_set_globals
+        expected_globals = self._server._lua_expected_globals
+
+        set_globals(
+            lua_runtime.table_from(keys_and_args[:numkeys]),
+            lua_runtime.table_from(keys_and_args[numkeys:]),
+            functools.partial(self._lua_redis_call, lua_runtime, expected_globals),
+            functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals),
+            functools.partial(_lua_redis_log, lua_runtime, expected_globals),
+            LUA_MODULE.as_attrgetter(REDIS_LOG_LEVELS),
+            functools.partial(_lua_cjson_encode, lua_runtime, expected_globals),
+            functools.partial(_lua_cjson_decode, lua_runtime, expected_globals),
+            _lua_cjson_null,
+        )

         try:
             result = lua_runtime.execute(script)

Would you be interested in a PR for this? I've already implemented a workaround in docket via monkeypatching, but it would be great to have this fixed upstream.
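On the question of testing: one option is to track every runtime ever created with weak references and assert that only one stays alive after many eval calls. The sketch below uses illustrative names (DummyRuntime, create_runtime, the evaluator classes); a real test would instrument lupa.LuaRuntime inside fakeredis:

```python
import gc
import weakref

# Illustrative stand-in for lupa.LuaRuntime.
class DummyRuntime:
    pass

live = []  # weak references to every runtime ever handed out

def create_runtime():
    rt = DummyRuntime()
    live.append(weakref.ref(rt))
    return rt

class LeakyEvaluator:
    # Simulates the bug: every call creates a runtime that something
    # (in the real lupa case, whatever pins the runtime) keeps alive.
    def __init__(self):
        self._all = []

    def eval(self, script):
        rt = create_runtime()
        self._all.append(rt)  # stand-in for whatever pins the runtime
        return rt

class CachedEvaluator:
    # Simulates the fix: create once, reuse thereafter.
    def eval(self, script):
        if not hasattr(self, "_runtime"):
            self._runtime = create_runtime()
        return self._runtime

def alive_after(evaluator, calls=1000):
    live.clear()
    for _ in range(calls):
        evaluator.eval(b"return 1")
    gc.collect()
    return sum(1 for ref in live if ref() is not None)

print(alive_after(LeakyEvaluator()))   # 1000
print(alive_after(CachedEvaluator()))  # 1
```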

chrisguidry • Dec 19 '25

This is great, please create a PR, yes. I am wondering whether it is possible to write a test for it as well? Also, there is no need to create the runtime in the __init__

cunla • Dec 23 '25