|
11 | 11 | from dbt.flags import get_flags
|
12 | 12 | from dbt.adapters.factory import get_adapter
|
13 | 13 | from dbt.clients import jinja
|
| 14 | +from dbt.context.providers import ( |
| 15 | + generate_runtime_model_context, |
| 16 | + generate_runtime_unit_test_context, |
| 17 | +) |
14 | 18 | from dbt_common.clients.system import make_directory
|
15 |
| -from dbt.context.providers import generate_runtime_model_context |
16 | 19 | from dbt.contracts.graph.manifest import Manifest, UniqueID
|
17 | 20 | from dbt.contracts.graph.nodes import (
|
18 | 21 | ManifestNode,
|
|
21 | 24 | GraphMemberNode,
|
22 | 25 | InjectedCTE,
|
23 | 26 | SeedNode,
|
| 27 | + UnitTestNode, |
| 28 | + UnitTestDefinition, |
24 | 29 | )
|
25 | 30 | from dbt.exceptions import (
|
26 | 31 | GraphDependencyNotFoundError,
|
|
43 | 48 | def print_compile_stats(stats):
|
44 | 49 | names = {
|
45 | 50 | NodeType.Model: "model",
|
46 |
| - NodeType.Test: "test", |
| 51 | + NodeType.Test: "data test", |
| 52 | + NodeType.Unit: "unit test", |
47 | 53 | NodeType.Snapshot: "snapshot",
|
48 | 54 | NodeType.Analysis: "analysis",
|
49 | 55 | NodeType.Macro: "macro",
|
@@ -91,6 +97,7 @@ def _generate_stats(manifest: Manifest):
|
91 | 97 | stats[NodeType.Macro] += len(manifest.macros)
|
92 | 98 | stats[NodeType.Group] += len(manifest.groups)
|
93 | 99 | stats[NodeType.SemanticModel] += len(manifest.semantic_models)
|
| 100 | + stats[NodeType.Unit] += len(manifest.unit_tests) |
94 | 101 |
|
95 | 102 | # TODO: should we be counting dimensions + entities?
|
96 | 103 |
|
@@ -128,7 +135,7 @@ class Linker:
|
128 | 135 | def __init__(self, data=None) -> None:
|
129 | 136 | if data is None:
|
130 | 137 | data = {}
|
131 |
| - self.graph = nx.DiGraph(**data) |
| 138 | + self.graph: nx.DiGraph = nx.DiGraph(**data) |
132 | 139 |
|
133 | 140 | def edges(self):
|
134 | 141 | return self.graph.edges()
|
@@ -191,6 +198,8 @@ def link_graph(self, manifest: Manifest):
|
191 | 198 | self.link_node(exposure, manifest)
|
192 | 199 | for metric in manifest.metrics.values():
|
193 | 200 | self.link_node(metric, manifest)
|
| 201 | + for unit_test in manifest.unit_tests.values(): |
| 202 | + self.link_node(unit_test, manifest) |
194 | 203 | for saved_query in manifest.saved_queries.values():
|
195 | 204 | self.link_node(saved_query, manifest)
|
196 | 205 |
|
@@ -234,6 +243,7 @@ def add_test_edges(self, manifest: Manifest) -> None:
|
234 | 243 | # Get all tests that depend on any upstream nodes.
|
235 | 244 | upstream_tests = []
|
236 | 245 | for upstream_node in upstream_nodes:
|
| 246 | + # This gets tests with unique_ids starting with "test." |
237 | 247 | upstream_tests += _get_tests_for_node(manifest, upstream_node)
|
238 | 248 |
|
239 | 249 | for upstream_test in upstream_tests:
|
@@ -291,8 +301,10 @@ def _create_node_context(
|
291 | 301 | manifest: Manifest,
|
292 | 302 | extra_context: Dict[str, Any],
|
293 | 303 | ) -> Dict[str, Any]:
|
294 |
| - |
295 |
| - context = generate_runtime_model_context(node, self.config, manifest) |
| 304 | + if isinstance(node, UnitTestNode): |
| 305 | + context = generate_runtime_unit_test_context(node, self.config, manifest) |
| 306 | + else: |
| 307 | + context = generate_runtime_model_context(node, self.config, manifest) |
296 | 308 | context.update(extra_context)
|
297 | 309 |
|
298 | 310 | if isinstance(node, GenericTestNode):
|
@@ -460,6 +472,7 @@ def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph
|
460 | 472 | summaries["_invocation_id"] = get_invocation_id()
|
461 | 473 | summaries["linked"] = linker.get_graph_summary(manifest)
|
462 | 474 |
|
| 475 | + # This is only called for the "build" command |
463 | 476 | if add_test_edges:
|
464 | 477 | manifest.build_parent_and_child_maps()
|
465 | 478 | linker.add_test_edges(manifest)
|
@@ -526,6 +539,9 @@ def compile_node(
|
526 | 539 | the node's raw_code into compiled_code, and then calls the
|
527 | 540 | recursive method to "prepend" the ctes.
|
528 | 541 | """
|
| 542 | + if isinstance(node, UnitTestDefinition): |
| 543 | + return node |
| 544 | + |
529 | 545 | # Make sure Lexer for sqlparse 0.4.4 is initialized
|
530 | 546 | from sqlparse.lexer import Lexer # type: ignore
|
531 | 547 |
|
|
0 commit comments