diff --git a/tests/stages/generation/test.py b/tests/stages/generation/test.py index 94118b1..5158d78 100644 --- a/tests/stages/generation/test.py +++ b/tests/stages/generation/test.py @@ -37,8 +37,7 @@ def generate(N): def test(N): k = generate(N) k_correct = np.load(f"out_{N}.npy") - comparison = k == k_correct - return comparison.all() + return np.allclose(k, k_correct, rtol=1e-10, atol=1e-10) class TestGeneration(unittest.TestCase): def test_8(self):