# Owner(s): ["oncall: pt2"] import itertools import math import sys import sympy from typing import Callable, List, Tuple, Type from torch.testing._internal.common_device_type import skipIf from torch.testing._internal.common_utils import ( TEST_Z3, instantiate_parametrized_tests, parametrize, run_tests, TestCase, ) from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity from sympy.core.relational import is_ge, is_le, is_gt, is_lt import functools import torch.fx as fx UNARY_OPS = [ "reciprocal", "square", "abs", "neg", "exp", "log", "sqrt", "floor", "ceil", ] BINARY_OPS = [ "truediv", "floordiv", # "truncdiv", # TODO # NB: pow is float_pow "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" ] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"] # a mix of constants, powers of two, primes CONSTANTS = [ -1, 0, 1, 2, 3, 4, 5, 8, 16, 32, 64, 100, 101, 2**24, 2**32, 2**37 - 1, sys.maxsize - 1, sys.maxsize, ] # less constants for N^2 situations LESS_CONSTANTS = [-1, 0, 1, 2, 100] # SymPy relational types. RELATIONAL_TYPES = [sympy.Eq, sympy.Ne, sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le] def valid_unary(fn, v): if fn == "log" and v <= 0: return False elif fn == "reciprocal" and v == 0: return False elif fn == "sqrt" and v < 0: return False return True def valid_binary(fn, a, b): if fn == "pow" and ( # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 # no imaginary numbers or a <= 0 # 0**0 is undefined or (a == b == 0) ): return False elif fn == "pow_by_natural" and ( # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 or b < 0 or (a == b == 0) ): return False elif fn == "mod" and (a < 0 or b <= 0): return False elif (fn in ["div", "truediv", "floordiv"]) and b == 0: return False return True def generate_range(vals): for a1, a2 in itertools.product(vals, repeat=2): if a1 in [sympy.true, sympy.false]: if a1 == sympy.true and a2 == sympy.false: continue else: if a1 > a2: continue # ranges that only admit infinite values are not interesting if a1 == sympy.oo or a2 == -sympy.oo: continue yield ValueRanges(a1, a2) class TestNumbers(TestCase): def test_int_infinity(self): self.assertIsInstance(int_oo, IntInfinity) self.assertIsInstance(-int_oo, NegativeIntInfinity) self.assertTrue(int_oo.is_integer) # is tests here are for singleton-ness, don't use it for comparisons # against numbers self.assertIs(int_oo + int_oo, int_oo) self.assertIs(int_oo + 1, int_oo) self.assertIs(int_oo - 1, int_oo) self.assertIs(-int_oo - 1, -int_oo) self.assertIs(-int_oo + 1, -int_oo) self.assertIs(-int_oo + (-int_oo), -int_oo) self.assertIs(-int_oo - int_oo, -int_oo) self.assertIs(1 + int_oo, int_oo) self.assertIs(1 - int_oo, -int_oo) self.assertIs(int_oo * int_oo, int_oo) self.assertIs(2 * int_oo, int_oo) self.assertIs(int_oo * 2, int_oo) self.assertIs(-1 * int_oo, -int_oo) self.assertIs(-int_oo * int_oo, -int_oo) self.assertIs(2 * -int_oo, -int_oo) self.assertIs(-int_oo * 2, -int_oo) self.assertIs(-1 * -int_oo, int_oo) self.assertIs(int_oo / 2, sympy.oo) self.assertIs(-(-int_oo), int_oo) # noqa: B002 self.assertIs(abs(int_oo), int_oo) self.assertIs(abs(-int_oo), int_oo) self.assertIs(int_oo ** 2, int_oo) self.assertIs((-int_oo) ** 2, int_oo) self.assertIs((-int_oo) ** 3, -int_oo) self.assertEqual(int_oo ** -1, 0) self.assertEqual((-int_oo) ** -1, 0) self.assertIs(int_oo ** int_oo, int_oo) self.assertTrue(int_oo == int_oo) self.assertFalse(int_oo != int_oo) self.assertTrue(-int_oo == -int_oo) self.assertFalse(int_oo == 2) self.assertTrue(int_oo != 2) self.assertFalse(int_oo == sys.maxsize) self.assertTrue(int_oo >= sys.maxsize) self.assertTrue(int_oo >= 2) self.assertTrue(int_oo >= -int_oo) def test_relation(self): self.assertIs(sympy.Add(2, int_oo), int_oo) self.assertFalse(-int_oo > 2) def test_lt_self(self): self.assertFalse(int_oo < int_oo) self.assertIs(min(-int_oo, -4), -int_oo) self.assertIs(min(-int_oo, -int_oo), -int_oo) def test_float_cast(self): self.assertEqual(float(int_oo), math.inf) self.assertEqual(float(-int_oo), -math.inf) def test_mixed_oo_int_oo(self): # Arbitrary choice self.assertTrue(int_oo < sympy.oo) self.assertFalse(int_oo > sympy.oo) self.assertTrue(sympy.oo > int_oo) self.assertFalse(sympy.oo < int_oo) self.assertIs(max(int_oo, sympy.oo), sympy.oo) self.assertTrue(-int_oo > -sympy.oo) self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo) class TestValueRanges(TestCase): @parametrize("fn", UNARY_OPS) @parametrize("dtype", ("int", "float")) def test_unary_ref(self, fn, dtype): dtype = {"int": sympy.Integer, "float": sympy.Float}[dtype] for v in CONSTANTS: if not valid_unary(fn, v): continue with self.subTest(v=v): v = dtype(v) ref_r = getattr(ReferenceAnalysis, fn)(v) r = getattr(ValueRangeAnalysis, fn)(v) self.assertEqual(r.lower.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(ref_r, r.lower) def test_pow_half(self): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) @parametrize("dtype", ("int", "float")) def test_binary_ref(self, fn, dtype): to_dtype = {"int": sympy.Integer, "float": sympy.Float} # Don't test float on int only methods if dtype == "float" and fn in ["pow_by_natural", "mod"]: return dtype = to_dtype[dtype] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue a = dtype(a) b = dtype(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) self.assertEqual(r.lower.is_integer, r.upper.is_integer) self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) def test_mul_zero_unknown(self): self.assertEqual( ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()), ValueRanges.wrap(0), ) self.assertEqual( ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()), ValueRanges.wrap(0.0), ) @parametrize("fn", UNARY_BOOL_OPS) def test_unary_bool_ref_range(self, fn): vals = [sympy.false, sympy.true] for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) unique = set() for a0 in vals: if a0 not in a: continue with self.subTest(a0=a0): r = getattr(ReferenceAnalysis, fn)(a0) self.assertIn(r, ref_r) unique.add(r) if ref_r.lower == ref_r.upper: self.assertEqual(len(unique), 1) else: self.assertEqual(len(unique), 2) @parametrize("fn", BINARY_BOOL_OPS) def test_binary_bool_ref_range(self, fn): vals = [sympy.false, sympy.true] for a, b in itertools.product(generate_range(vals), repeat=2): with self.subTest(a=a, b=b): ref_r = getattr(ValueRangeAnalysis, fn)(a, b) unique = set() for a0, b0 in itertools.product(vals, repeat=2): if a0 not in a or b0 not in b: continue with self.subTest(a0=a0, b0=b0): r = getattr(ReferenceAnalysis, fn)(a0, b0) self.assertIn(r, ref_r) unique.add(r) if ref_r.lower == ref_r.upper: self.assertEqual(len(unique), 1) else: self.assertEqual(len(unique), 2) @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): # TODO: bring back sympy.oo testing for float unary fns vals = CONSTANTS for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) for a0 in CONSTANTS: if a0 not in a: continue if not valid_unary(fn, a0): continue with self.subTest(a0=a0): r = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0)) self.assertIn(r, ref_r) # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): # TODO: bring back sympy.oo testing for float unary fns vals = LESS_CONSTANTS for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) def test_interp(self, fn): # SymPy does not implement truncation for Expressions if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return is_integer = None if fn == "pow_by_natural": is_integer = True x = sympy.Dummy('x', integer=is_integer) y = sympy.Dummy('y', integer=is_integer) vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] arity = 1 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 symbols = [x] if arity == 2: symbols = [x, y] for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue elif arity == 2 and not valid_binary(fn, *args): continue with self.subTest(args=args): sargs = [sympy.sympify(a) for a in args] sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) ref_r = getattr(ReferenceAnalysis, fn)(*sargs) # Yes, I know this is a longwinded way of saying xreplace; the # point is to test sympy_interp r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr) self.assertEqual(ref_r, r) @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) def test_python_interp_fx(self, fn): # These never show up from symbolic_shapes if fn in ("log", "exp"): return # Sympy does not support truncation on symbolic shapes if fn in ("truncdiv", "mod"): return vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] arity = 1 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 is_integer = None if fn == "pow_by_natural": is_integer = True x = sympy.Dummy('x', integer=is_integer) y = sympy.Dummy('y', integer=is_integer) symbols = [x] if arity == 2: symbols = [x, y] for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue elif arity == 2 and not valid_binary(fn, *args): continue if fn == "truncdiv" and args[1] == 0: continue elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): # Workaround mpf from symbol error if fn == "minimum": sympy_expr = sympy.Min(x, y) elif fn == "maximum": sympy_expr = sympy.Max(x, y) else: sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) if arity == 1: def trace_f(px): return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) else: def trace_f(px, py): return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) gm = fx.symbolic_trace(trace_f) self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) ) def type_name_fn(type: Type) -> str: return type.__name__ def parametrize_relational_types(*types): def wrapper(f: Callable): return parametrize("op", types or RELATIONAL_TYPES, name_fn=type_name_fn)(f) return wrapper class TestSympySolve(TestCase): def _create_integer_symbols(self) -> List[sympy.Symbol]: return sympy.symbols("a b c", integer=True) def test_give_up(self): from sympy import Eq, Ne a, b, c = self._create_integer_symbols() cases = [ # Not a relational operation. a + b, # 'a' appears on both sides. Eq(a, a + 1), # 'a' doesn't appear on neither side. Eq(b, c + 1), # Result is a 'sympy.And'. Eq(FloorDiv(a, b), c), # Result is a 'sympy.Or'. Ne(FloorDiv(a, b), c), ] for case in cases: e = try_solve(case, a) self.assertEqual(e, None) @parametrize_relational_types() def test_noop(self, op): a, b, _ = self._create_integer_symbols() lhs, rhs = a, 42 * b expr = op(lhs, rhs) r = try_solve(expr, a) self.assertNotEqual(r, None) r_expr, r_rhs = r self.assertEqual(r_expr, expr) self.assertEqual(r_rhs, rhs) @parametrize_relational_types() def test_noop_rhs(self, op): a, b, _ = self._create_integer_symbols() lhs, rhs = 42 * b, a mirror = mirror_rel_op(op) self.assertNotEqual(mirror, None) expr = op(lhs, rhs) r = try_solve(expr, a) self.assertNotEqual(r, None) r_expr, r_rhs = r self.assertEqual(r_expr, mirror(rhs, lhs)) self.assertEqual(r_rhs, lhs) def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs): for source, expected in cases: r = try_solve(source, thing, **kwargs) self.assertTrue( (r is None and expected is None) or (r is not None and expected is not None) ) if r is not None: r_expr, r_rhs = r self.assertEqual(r_rhs, expected) self.assertEqual(r_expr, op(thing, expected)) def test_addition(self): from sympy import Eq a, b, c = self._create_integer_symbols() cases = [ (Eq(a + b, 0), -b), (Eq(a + 5, b - 5), b - 10), (Eq(a + c * b, 1), 1 - c * b), ] self._test_cases(cases, a, Eq) @parametrize_relational_types(sympy.Eq, sympy.Ne) def test_multiplication_division(self, op): a, b, c = self._create_integer_symbols() cases = [ (op(a * b, 1), 1 / b), (op(a * 5, b - 5), (b - 5) / 5), (op(a * b, c), c / b), ] self._test_cases(cases, a, op) @parametrize_relational_types(*INEQUALITY_TYPES) def test_multiplication_division_inequality(self, op): a, b, _ = self._create_integer_symbols() intneg = sympy.Symbol("neg", integer=True, negative=True) intpos = sympy.Symbol("pos", integer=True, positive=True) cases = [ # Divide/multiply both sides by positive number. (op(a * intpos, 1), 1 / intpos), (op(a / (5 * intpos), 1), 5 * intpos), (op(a * 5, b - 5), (b - 5) / 5), # 'b' is not strictly positive nor negative, so we can't # divide/multiply both sides by 'b'. (op(a * b, 1), None), (op(a / b, 1), None), (op(a * b * intpos, 1), None), ] mirror_cases = [ # Divide/multiply both sides by negative number. (op(a * intneg, 1), 1 / intneg), (op(a / (5 * intneg), 1), 5 * intneg), (op(a * -5, b - 5), -(b - 5) / 5), ] mirror_op = mirror_rel_op(op) assert mirror_op is not None self._test_cases(cases, a, op) self._test_cases(mirror_cases, a, mirror_op) @parametrize_relational_types() def test_floordiv(self, op): from sympy import Eq, Ne, Gt, Ge, Lt, Le a, b, c = sympy.symbols("a b c") pos = sympy.Symbol("pos", positive=True) integer = sympy.Symbol("integer", integer=True) # (Eq(FloorDiv(a, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))), # (Eq(FloorDiv(a + 5, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))), # (Ne(FloorDiv(a, pos), integer), Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))), special_case = { # 'FloorDiv' turns into 'And', which can't be simplified any further. Eq: (Eq(FloorDiv(a, pos), integer), None), # 'FloorDiv' turns into 'Or', which can't be simplified any further. Ne: (Ne(FloorDiv(a, pos), integer), None), Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos), Ge: (Ge(FloorDiv(a, pos), integer), integer * pos), Lt: (Lt(FloorDiv(a, pos), integer), integer * pos), Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos), }[op] cases: List[Tuple[sympy.Basic, sympy.Basic]] = [ # 'b' is not strictly positive (op(FloorDiv(a, b), integer), None), # 'c' is not strictly positive (op(FloorDiv(a, pos), c), None), ] # The result might change after 'FloorDiv' transformation. # Specifically: # - [Ge, Gt] => Ge # - [Le, Lt] => Lt if op in (sympy.Gt, sympy.Ge): r_op = sympy.Ge elif op in (sympy.Lt, sympy.Le): r_op = sympy.Lt else: r_op = op self._test_cases([special_case, *cases], a, r_op) self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False) def test_floordiv_eq_simplify(self): from sympy import Eq, Lt, Le a = sympy.Symbol("a", positive=True, integer=True) def check(expr, expected): r = try_solve(expr, a) self.assertNotEqual(r, None) r_expr, _ = r self.assertEqual(r_expr, expected) # (a + 10) // 3 == 3 # ===================================== # 3 * 3 <= a + 10 (always true) # a + 10 < 4 * 3 (not sure) check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10)) # (a + 10) // 2 == 4 # ===================================== # 4 * 2 <= 10 - a (not sure) # 10 - a < 5 * 2 (always true) check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10))) @skipIf(not TEST_Z3, "Z3 not installed") def test_z3_proof_floordiv_eq_simplify(self): import z3 from sympy import Eq, Lt a = sympy.Symbol("a", positive=True, integer=True) a_ = z3.Int("a") # (a + 10) // 3 == 3 # ===================================== # 3 * 3 <= a + 10 (always true) # a + 10 < 4 * 3 (not sure) solver = z3.SolverFor("QF_NRA") # Add assertions for 'a_'. solver.add(a_ > 0) expr = Eq(FloorDiv(a + 10, 3), 3) r_expr, _ = try_solve(expr, a) # Check 'try_solve' really returns the 'expected' below. expected = Lt(a, (3 + 1) * 3 - 10) self.assertEqual(r_expr, expected) # Check whether there is an integer 'a_' such that the # equation below is satisfied. solver.add( # expr (z3.ToInt((a_ + 10) / 3.0) == 3) != # expected (a_ < (3 + 1) * 3 - 10) ) # Assert that there's no such an integer. # i.e. the transformation is sound. r = solver.check() self.assertEqual(r, z3.unsat) def test_simple_floordiv_gcd(self): x, y, z = sympy.symbols("x y z") # positive tests self.assertEqual(simple_floordiv_gcd(x, x), x) self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128) self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128) self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128) self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x) self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x) self.assertEqual(simple_floordiv_gcd(x * y, x), x) self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y) self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x) # negative tests self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1) class TestSingletonInt(TestCase): def test_basic(self): j1 = SingletonInt(1, coeff=1) j1_copy = SingletonInt(1, coeff=1) j2 = SingletonInt(2, coeff=1) j1x2 = SingletonInt(1, coeff=2) def test_eq(a, b, expected): self.assertEqual(sympy.Eq(a, b), expected) self.assertEqual(sympy.Ne(b, a), not expected) # eq, ne test_eq(j1, j1, True) test_eq(j1, j1_copy, True) test_eq(j1, j2, False) test_eq(j1, j1x2, False) test_eq(j1, sympy.Integer(1), False) test_eq(j1, sympy.Integer(3), False) def test_ineq(a, b, expected, *, strict=True): greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge) less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le) if isinstance(expected, bool): # expected is always True for fn in greater: self.assertEqual(fn(a, b), expected) self.assertEqual(fn(b, a), not expected) for fn in less: self.assertEqual(fn(b, a), expected) self.assertEqual(fn(a, b), not expected) else: for fn in greater: with self.assertRaisesRegex(ValueError, expected): fn(a, b) for fn in less: with self.assertRaisesRegex(ValueError, expected): fn(b, a) # ge, le, gt, lt for strict in (True, False): _test_ineq = functools.partial(test_ineq, strict=strict) _test_ineq(j1, sympy.Integer(0), True) _test_ineq(j1, sympy.Integer(3), "indeterminate") _test_ineq(j1, j2, "indeterminate") _test_ineq(j1x2, j1, True) # Special cases for ge, le, gt, lt: for ge in (sympy.Ge, is_ge): self.assertTrue(ge(j1, j1)) self.assertTrue(ge(j1, sympy.Integer(2))) with self.assertRaisesRegex(ValueError, "indeterminate"): ge(sympy.Integer(2), j1) for le in (sympy.Le, is_le): self.assertTrue(le(j1, j1)) self.assertTrue(le(sympy.Integer(2), j1)) with self.assertRaisesRegex(ValueError, "indeterminate"): le(j1, sympy.Integer(2)) for gt in (sympy.Gt, is_gt): self.assertFalse(gt(j1, j1)) self.assertFalse(gt(sympy.Integer(2), j1)) # it is only known to be that j1 >= 2, j1 > 2 is indeterminate with self.assertRaisesRegex(ValueError, "indeterminate"): gt(j1, sympy.Integer(2)) for lt in (sympy.Lt, is_lt): self.assertFalse(lt(j1, j1)) self.assertFalse(lt(j1, sympy.Integer(2))) with self.assertRaisesRegex(ValueError, "indeterminate"): lt(sympy.Integer(2), j1) # mul self.assertEqual(j1 * 2, j1x2) # Unfortunately, this doesn't not automatically simplify to 2*j1 # since sympy.Mul doesn't trigger __mul__ unlike the above. self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul) with self.assertRaisesRegex(ValueError, "cannot be multiplied"): j1 * j2 self.assertEqual(j1.free_symbols, set()) instantiate_parametrized_tests(TestValueRanges) instantiate_parametrized_tests(TestSympyInterp) instantiate_parametrized_tests(TestSympySolve) if __name__ == "__main__": run_tests()