diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index 6ae4f40ae3..ca5a35ca86 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -1415,6 +1415,51 @@ def test_augassign(self) -> None: self.assertIsInstance(inferred[0], nodes.Const) self.assertEqual(inferred[0].value, 3) + def test_augassign_multi(self) -> None: + code = """ + a = 1 + a += 1 + a += 1 + print (a) + """ + ast = parse(code, __name__) + inferred = list(test_utils.get_name_node(ast, "a").infer()) + + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], nodes.Const) + self.assertEqual(inferred[0].value, 3) + + def test_augassign_multi_expr(self) -> None: + code = """ + a = 1 + a += 1 + a += 1 + a + """ + ast = parse(code, __name__) + # No inference function for Expr + inferred = list(ast.body[-1].value.infer()) + + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], nodes.Const) + self.assertEqual(inferred[0].value, 3) + + def test_augassign_multi_list(self) -> None: + code = """ + a = [] + a += [1] + a += [1] + print (a) + """ + ast = parse(code, __name__) + inferred = list(test_utils.get_name_node(ast, "a").infer()) + + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], nodes.List) + self.assertEqual(len(inferred[0].elts), 2) + self.assertEqual(inferred[0].elts[1].value, 1) + self.assertEqual(inferred[0].elts[0].value, 1) + def test_nonregr_func_arg(self) -> None: code = """ def foo(self, bar): diff --git a/tests/unittest_lookup.py b/tests/unittest_lookup.py index 1cb2526188..f98ec10364 100644 --- a/tests/unittest_lookup.py +++ b/tests/unittest_lookup.py @@ -1009,6 +1009,63 @@ def test_except_assign_after_block_overwritten(self) -> None: self.assertEqual(len(stmts), 1) self.assertEqual(stmts[0].lineno, 8) + def test_except_assign_exclusive_branches_getattr(self) -> None: + """When a variable is assigned in exlcusive branches, both are returned""" + code = """ + try: + 1 / 0 + except ZeroDivisionError: + x = 10 + except NameError: + x = 100 + print(x) + """ + astroid = builder.parse(code) + stmts = astroid.getattr("x") + self.assertEqual(len(stmts), 2) + + self.assertEqual(stmts[0].lineno, 5) + self.assertEqual(stmts[1].lineno, 7) + + def test_except_assign_after_block_overwritten_getattr(self) -> None: + """When a variable is assigned in an except clause, it is not returned + when it is reassigned and used after the except block. + """ + code = """ + try: + 1 / 0 + except ZeroDivisionError: + x = 10 + except NameError: + x = 100 + x = 1000 + print(x) + """ + astroid = builder.parse(code) + stmts = astroid.getattr("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 8) + + def test_except_assign_after_block_overwritten_getattr_class(self) -> None: + """When a variable is assigned in an except clause, it is not returned + when it is reassigned and used after the except block. + """ + code = """ + class C: + try: + 1 / 0 + except ZeroDivisionError: + x = 10 + except NameError: + x = 100 + x = 1000 + print(x) + C.x + """ + astroid = builder.parse(code) + stmts = list(astroid.body[-1].value.infer()) + self.assertEqual(len(stmts), 1) + if __name__ == "__main__": unittest.main()