Skip to content

Commit 6735f41

Browse files
mbyrnepr2Pierre-SassoulasDanielNoord
authored
Fix order of overwritten attributes in inherited dataclasses (#1970) (#1972)
Co-authored-by: Pierre Sassoulas <[email protected]> Co-authored-by: Daniël van Noord <[email protected]>
1 parent 1c70358 commit 6735f41

File tree

3 files changed

+118
-28
lines changed

3 files changed

+118
-28
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ Release date: TBA
1616

1717
Closes #1958
1818

19+
* Fix overwritten attributes in inherited dataclasses not being ordered correctly.
20+
21+
Closes PyCQA/pylint#7881
22+
1923
* Fix a false positive when an attribute named ``Enum`` was confused with ``enum.Enum``.
2024
Calls to ``Enum`` are now inferred & the qualified name is checked.
2125

astroid/brain/brain_dataclasses.py

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,11 @@ def _check_generate_dataclass_init(node: nodes.ClassDef) -> bool:
167167

168168

169169
def _find_arguments_from_base_classes(
170-
node: nodes.ClassDef, skippable_names: set[str]
171-
) -> tuple[str, str]:
172-
"""Iterate through all bases and add them to the list of arguments to add to the
173-
init.
174-
"""
170+
node: nodes.ClassDef,
171+
) -> tuple[
172+
dict[str, tuple[str | None, str | None]], dict[str, tuple[str | None, str | None]]
173+
]:
174+
"""Iterate through all bases and get their typing and defaults."""
175175
pos_only_store: dict[str, tuple[str | None, str | None]] = {}
176176
kw_only_store: dict[str, tuple[str | None, str | None]] = {}
177177
# See TODO down below
@@ -187,8 +187,6 @@ def _find_arguments_from_base_classes(
187187

188188
pos_only, kw_only = base_init.args._get_arguments_data()
189189
for posarg, data in pos_only.items():
190-
if posarg in skippable_names:
191-
continue
192190
# if data[1] is None:
193191
# if all_have_defaults and pos_only_store:
194192
# # TODO: This should return an Uninferable as this would raise
@@ -199,10 +197,15 @@ def _find_arguments_from_base_classes(
199197
pos_only_store[posarg] = data
200198

201199
for kwarg, data in kw_only.items():
202-
if kwarg in skippable_names:
203-
continue
204200
kw_only_store[kwarg] = data
201+
return pos_only_store, kw_only_store
202+
205203

204+
def _parse_arguments_into_strings(
205+
pos_only_store: dict[str, tuple[str | None, str | None]],
206+
kw_only_store: dict[str, tuple[str | None, str | None]],
207+
) -> tuple[str, str]:
208+
"""Parse positional and keyword arguments into strings for an __init__ method."""
206209
pos_only, kw_only = "", ""
207210
for pos_arg, data in pos_only_store.items():
208211
pos_only += pos_arg
@@ -248,11 +251,11 @@ def _generate_dataclass_init( # pylint: disable=too-many-locals
248251
params: list[str] = []
249252
kw_only_params: list[str] = []
250253
assignments: list[str] = []
251-
assign_names: list[str] = []
254+
255+
prev_pos_only_store, prev_kw_only_store = _find_arguments_from_base_classes(node)
252256

253257
for assign in assigns:
254258
name, annotation, value = assign.target.name, assign.annotation, assign.value
255-
assign_names.append(name)
256259

257260
# Check whether this assign is overriden by a property assignment
258261
property_node: nodes.FunctionDef | None = None
@@ -275,6 +278,9 @@ def _generate_dataclass_init( # pylint: disable=too-many-locals
275278
keyword.arg == "init" and not keyword.value.bool_value()
276279
for keyword in value.keywords # type: ignore[union-attr] # value is never None
277280
):
281+
# Also remove the name from the previous arguments to be inserted later
282+
prev_pos_only_store.pop(name, None)
283+
prev_kw_only_store.pop(name, None)
278284
continue
279285

280286
if _is_init_var(annotation): # type: ignore[arg-type] # annotation is never None
@@ -289,32 +295,32 @@ def _generate_dataclass_init( # pylint: disable=too-many-locals
289295
init_var = False
290296
assignment_str = f"self.{name} = {name}"
291297

298+
ann_str, default_str = None, None
292299
if annotation is not None:
293-
param_str = f"{name}: {annotation.as_string()}"
294-
else:
295-
param_str = name
300+
ann_str = annotation.as_string()
296301

297302
if value:
298303
if is_field:
299304
result = _get_field_default(value) # type: ignore[arg-type]
300305
if result:
301306
default_type, default_node = result
302307
if default_type == "default":
303-
param_str += f" = {default_node.as_string()}"
308+
default_str = default_node.as_string()
304309
elif default_type == "default_factory":
305-
param_str += f" = {DEFAULT_FACTORY}"
310+
default_str = DEFAULT_FACTORY
306311
assignment_str = (
307312
f"self.{name} = {default_node.as_string()} "
308313
f"if {name} is {DEFAULT_FACTORY} else {name}"
309314
)
310315
else:
311-
param_str += f" = {value.as_string()}"
316+
default_str = value.as_string()
312317
elif property_node:
313318
# We set the result of the property call as default
314319
# This hides the fact that this would normally be a 'property object'
315320
# But we can't represent those as string
316321
try:
317-
param_str += f" = {next(property_node.infer_call_result()).as_string()}"
322+
# Call str to make sure also Uninferable gets stringified
323+
default_str = str(next(property_node.infer_call_result()).as_string())
318324
except (InferenceError, StopIteration):
319325
pass
320326
else:
@@ -323,7 +329,14 @@ def _generate_dataclass_init( # pylint: disable=too-many-locals
323329
# (self, a: str = 1) -> None
324330
previous_default = _get_previous_field_default(node, name)
325331
if previous_default:
326-
param_str += f" = {previous_default.as_string()}"
332+
default_str = previous_default.as_string()
333+
334+
# Construct the param string to add to the init if necessary
335+
param_str = name
336+
if ann_str is not None:
337+
param_str += f": {ann_str}"
338+
if default_str is not None:
339+
param_str += f" = {default_str}"
327340

328341
# If the field is a kw_only field, we need to add it to the kw_only_params
329342
# This overwrites whether or not the class is kw_only decorated
@@ -337,21 +350,33 @@ def _generate_dataclass_init( # pylint: disable=too-many-locals
337350
continue
338351
# If kw_only decorated, we need to add all parameters to the kw_only_params
339352
if kw_only_decorated:
340-
kw_only_params.append(param_str)
353+
if name in prev_kw_only_store:
354+
prev_kw_only_store[name] = (ann_str, default_str)
355+
else:
356+
kw_only_params.append(param_str)
341357
else:
342-
params.append(param_str)
358+
# If the name was previously seen, overwrite that data
359+
# pylint: disable-next=else-if-used
360+
if name in prev_pos_only_store:
361+
prev_pos_only_store[name] = (ann_str, default_str)
362+
elif name in prev_kw_only_store:
363+
params = [name] + params
364+
prev_kw_only_store.pop(name)
365+
else:
366+
params.append(param_str)
343367

344368
if not init_var:
345369
assignments.append(assignment_str)
346370

347-
prev_pos_only, prev_kw_only = _find_arguments_from_base_classes(
348-
node, set(assign_names + ["self"])
371+
prev_pos_only, prev_kw_only = _parse_arguments_into_strings(
372+
prev_pos_only_store, prev_kw_only_store
349373
)
350374

351375
# Construct the new init method paramter string
352376
# First we do the positional only parameters, making sure to add the
353377
# the self parameter and the comma to allow adding keyword only parameters
354-
params_string = f"self, {prev_pos_only}{', '.join(params)}"
378+
params_string = "" if "self" in prev_pos_only else "self, "
379+
params_string += prev_pos_only + ", ".join(params)
355380
if not params_string.endswith(", "):
356381
params_string += ", "
357382

tests/unittest_brain_dataclasses.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,67 @@ class B(A):
376376
assert inferred[1].name == "str"
377377

378378

379+
def test_dataclass_order_of_inherited_attributes():
380+
"""Test that an attribute in a child does not get put at the end of the init."""
381+
child, normal, keyword_only = astroid.extract_node(
382+
"""
383+
from dataclass import dataclass
384+
385+
386+
@dataclass
387+
class Parent:
388+
a: str
389+
b: str
390+
391+
392+
@dataclass
393+
class Child(Parent):
394+
c: str
395+
a: str
396+
397+
398+
@dataclass(kw_only=True)
399+
class KeywordOnlyParent:
400+
a: int
401+
b: str
402+
403+
404+
@dataclass
405+
class NormalChild(KeywordOnlyParent):
406+
c: str
407+
a: str
408+
409+
410+
@dataclass(kw_only=True)
411+
class KeywordOnlyChild(KeywordOnlyParent):
412+
c: str
413+
a: str
414+
415+
416+
Child.__init__ #@
417+
NormalChild.__init__ #@
418+
KeywordOnlyChild.__init__ #@
419+
"""
420+
)
421+
child_init: bases.UnboundMethod = next(child.infer())
422+
assert [a.name for a in child_init.args.args] == ["self", "a", "b", "c"]
423+
424+
normal_init: bases.UnboundMethod = next(normal.infer())
425+
if PY310_PLUS:
426+
assert [a.name for a in normal_init.args.args] == ["self", "a", "c"]
427+
assert [a.name for a in normal_init.args.kwonlyargs] == ["b"]
428+
else:
429+
assert [a.name for a in normal_init.args.args] == ["self", "a", "b", "c"]
430+
assert [a.name for a in normal_init.args.kwonlyargs] == []
431+
432+
keyword_only_init: bases.UnboundMethod = next(keyword_only.infer())
433+
if PY310_PLUS:
434+
assert [a.name for a in keyword_only_init.args.args] == ["self"]
435+
assert [a.name for a in keyword_only_init.args.kwonlyargs] == ["a", "b", "c"]
436+
else:
437+
assert [a.name for a in keyword_only_init.args.args] == ["self", "a", "b", "c"]
438+
439+
379440
def test_pydantic_field() -> None:
380441
"""Test that pydantic.Field attributes are currently Uninferable.
381442
@@ -628,12 +689,12 @@ class B(A):
628689
"""
629690
)
630691
init = next(node.infer())
631-
assert [a.name for a in init.args.args] == ["self", "arg0", "arg1", "arg2"]
692+
assert [a.name for a in init.args.args] == ["self", "arg0", "arg2", "arg1"]
632693
assert [a.as_string() if a else None for a in init.args.annotations] == [
633694
None,
634695
"float",
635-
"int",
636696
"list", # not str
697+
"int",
637698
]
638699

639700

@@ -1035,8 +1096,8 @@ class ChildWithMixedParents(BaseParent, NotADataclassParent):
10351096
assert [a.value for a in overwritten_init.args.defaults] == ["2"]
10361097

10371098
overwriting_init: bases.UnboundMethod = next(overwriting.infer())
1038-
assert [a.name for a in overwriting_init.args.args] == ["self", "_abc", "ef"]
1039-
assert [a.value for a in overwriting_init.args.defaults] == [1.0, 2.0]
1099+
assert [a.name for a in overwriting_init.args.args] == ["self", "ef", "_abc"]
1100+
assert [a.value for a in overwriting_init.args.defaults] == [2.0, 1.0]
10401101

10411102
mixed_init: bases.UnboundMethod = next(mixed.infer())
10421103
assert [a.name for a in mixed_init.args.args] == ["self", "_abc", "ghi"]

0 commit comments

Comments
 (0)