From 2c3866dee4c235c37c609c72e6049d3ed75df596 Mon Sep 17 00:00:00 2001
From: Simon Fayer
Date: Mon, 18 May 2026 16:36:26 +0100
Subject: [PATCH] feat: Add a new eval helper function

---
 src/DIRAC/Core/Utilities/SaferEval.py         |  38 +++++
 .../Core/Utilities/test/Test_SaferEval.py     | 148 ++++++++++++++++++
 2 files changed, 186 insertions(+)
 create mode 100644 src/DIRAC/Core/Utilities/SaferEval.py
 create mode 100644 src/DIRAC/Core/Utilities/test/Test_SaferEval.py

diff --git a/src/DIRAC/Core/Utilities/SaferEval.py b/src/DIRAC/Core/Utilities/SaferEval.py
new file mode 100644
index 00000000000..7c53ae1a623
--- /dev/null
+++ b/src/DIRAC/Core/Utilities/SaferEval.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+
+import ast
+
+
+def saferEval(obj_str, max_len=2048):
+    """This function adds an extra length check around literal_eval.
+    On python3.11 and above (which has a recursion guard), this should
+    be safe enough for use on general authenticated user input.
+
+    Note: This doesn't handle all of the cases of eval, such as
+          datetime as those are technically executing code to
+          instantiate the non-base objects.
+
+    :param obj_str: expression to evaluate; it is converted with str() first
+    :param int max_len: maximum accepted length of the string, in characters
+    :return: the Python object described by the literal expression
+    :raises ValueError: if the string is too long or cannot be evaluated
+    """
+    # Ensure input is a string
+    obj_str = str(obj_str)
+    if len(obj_str) > max_len:
+        # len() counts characters (code points), not encoded bytes
+        raise ValueError(f"Object string is too long (>{max_len} characters)")
+    try:
+        return ast.literal_eval(obj_str)
+    except (ValueError, TypeError, SyntaxError) as exc:
+        # This covers all of the cases where the string is wrong (unclosed brackets...)
+        # or contains disallowed items like function calls or non-expression.
+        raise ValueError("Syntax error processing object expression") from exc
+    except (MemoryError, RecursionError) as exc:
+        # This is encountered if the object is nested too deeply and other structures
+        # that are probably malicious.
+        raise ValueError("Object expression too large") from exc
+    except Exception as exc:
+        # There are no other possible exceptions at the time of writing,
+        # this is to catch any added in future python versions.
+        raise ValueError("Unknown error processing object expression") from exc
diff --git a/src/DIRAC/Core/Utilities/test/Test_SaferEval.py b/src/DIRAC/Core/Utilities/test/Test_SaferEval.py
new file mode 100644
index 00000000000..798cd741b67
--- /dev/null
+++ b/src/DIRAC/Core/Utilities/test/Test_SaferEval.py
@@ -0,0 +1,148 @@
+"""Tests for saferEval – uses pytest parametrize for conciseness."""
+
+import time
+
+import pytest
+
+from DIRAC.Core.Utilities.SaferEval import saferEval
+
+
+@pytest.mark.parametrize(
+    "value",
+    [
+        None,
+        True,
+        False,
+        0,
+        42,
+        -17,
+        0xFF,
+        0o77,
+        0b1010,
+        3.14,
+        1e10,
+        1j,
+        [],
+        [1, "two", True, None],
+        (),
+        (1,),
+        (1, 2, 3),
+        {},
+        {"a": 1, "b": 2},
+        {1, 2, 3},
+        [[1, 2], [3, 4]],
+        {"a": {"b": {"c": [1, 2]}}},
+    ],
+)
+def test_literal(value):
+    assert saferEval(str(value)) == value
+
+
+@pytest.mark.parametrize(
+    "input_str,expected",
+    [
+        ('"hello"', "hello"),
+        ("'hello'", "hello"),
+        ("'a\\nb'", "a\nb"),
+        (r"r'\n'", r"\n"),
+        ('"hello 🌍"', "hello 🌍"),
+        ("b'bytes'", b"bytes"),
+        ("b'\\xff'", b"\xff"),
+    ],
+)
+def test_string_literal(input_str, expected):
+    assert saferEval(input_str) == expected
+
+
+@pytest.mark.parametrize(
+    "input_str",
+    [
+        "__import__('os').system('id')",
+        "open('/etc/passwd').read()",
+        "list()",
+        "foo",
+        "datetime.datetime.now()",
+        "lambda x: x",
+        "{k: v for k, v in []}",
+        "(x for x in [])",
+        "x == y",
+        "1 + 2",
+        "().__class__",
+        "x[0]",
+        "[1,2][1:]",
+        "*1",
+        "builtins.open",
+        "object()",
+        "MyList()",
+        "f'{1+2}'",
+        "@decorator",
+        "assert True",
+        "return 42",
+        "x += 1",
+        "with open('x') as f: pass",
+        "for x in []: pass",
+        "try: pass\nexcept: pass",
+        "import os",
+        "from os import path",
+        "del x",
+        "raise ValueError('x')",
+        "yield 1",
+        "await something",
+        "(x := 1)",
+        "(lambda x, /: x)(1)",
+        "10**200",
+    ],
+)
+def test_rejected_inputs(input_str):
+    with pytest.raises(ValueError):
+        saferEval(input_str)
+
+
+def test_max_len_exceeded():
+    with pytest.raises(ValueError):
+        saferEval("1" * 2049, 2048)
+
+
+def test_max_len_custom_exceeded():
+    with pytest.raises(ValueError):
+        saferEval("[1, 2, 3]", 5)
+
+
+def test_max_len_custom_ok():
+    assert saferEval("[1, 2, 3]", 10) == [1, 2, 3]
+
+
+def test_max_len_boundary_default():
+    # Exactly 2048 characters (2046 + two quotes) must still be accepted
+    assert saferEval("'" + "a" * 2046 + "'") == "a" * 2046
+
+
+@pytest.mark.parametrize("depth", [2000, 500])
+def test_deep_nesting(depth):
+    # Raise max_len so the nesting itself is rejected, not just the length
+    with pytest.raises(ValueError):
+        saferEval("[" * depth + "1" + "]" * depth, 2 * depth + 1)
+
+
+def test_large_string_literal():
+    with pytest.raises(ValueError):
+        saferEval("'" + "a" * 3000 + "'", 2048)
+
+
+def test_large_list():
+    with pytest.raises(ValueError):
+        saferEval(str([1] * 3000), 2048)
+
+
+@pytest.mark.parametrize(
+    "s",
+    [
+        "{" + ", ".join(f'"k{i}": {i}' for i in range(50)) + "}",
+        "[" + ",".join(str(i) for i in range(50)) + "]",
+    ],
+)
+def test_performance(s):
+    # perf_counter is monotonic, unlike time.time
+    start = time.perf_counter()
+    saferEval(s, 2048)
+    assert time.perf_counter() - start < 0.1