2
2
# We don't use models from the `model_zoo.jl` file because they are subject to change
3
3
# These tests are meant to be stable and not change often
4
4
5
- @testitem " Simple model" begin
5
+ @testitem " Simple model #1 " begin
6
6
using Distributions
7
7
8
- import GraphPPL: create_model, getcontext, add_terminated_submodel !, factor_nodes, variable_nodes, is_constant, getproperties, as_node, as_variable
8
+ import GraphPPL: create_model, getcontext, add_toplevel_model !, factor_nodes, variable_nodes, is_constant, getproperties, as_node, as_variable
9
9
10
- @model function simple_model ()
10
+ @model function simple_model_1 ()
11
11
x ~ Normal (0 , 1 )
12
12
y ~ Gamma (1 , 1 )
13
13
z ~ Normal (x, y)
16
16
model = create_model ()
17
17
context = getcontext (model)
18
18
19
- add_terminated_submodel ! (model, context, simple_model , NamedTuple ())
19
+ add_toplevel_model ! (model, simple_model_1 , NamedTuple ())
20
20
21
21
flabels = collect (factor_nodes (model))
22
22
vlabels = collect (variable_nodes (model))
33
33
@test length (collect (filter (as_variable (:x ), model))) === 1
34
34
@test length (collect (filter (as_variable (:y ), model))) === 1
35
35
@test length (collect (filter (as_variable (:z ), model))) === 1
36
+ end
37
+
38
+ @testitem " Simple model #2" begin
39
+ using Distributions
40
+ using GraphPPL: create_model, getcontext, getorcreate!, add_toplevel_model!, as_node, NodeCreationOptions, prune!
41
+
42
+ @model function simple_model_2 (a, b, c)
43
+ x ~ Gamma (α = b, θ = sqrt (c))
44
+ a ~ Normal (μ = x, τ = 1 )
45
+ end
46
+
47
+ model = create_model ()
48
+ context = getcontext (model)
49
+
50
+ a = getorcreate! (model, context, NodeCreationOptions (datavar = true ), :a , nothing )
51
+ b = getorcreate! (model, context, NodeCreationOptions (datavar = true ), :b , nothing )
52
+ c = 1.0
53
+
54
+ add_toplevel_model! (model, simple_model_2, (a = a, b = b, c = c))
55
+
56
+ prune! (model)
57
+
58
+ @test length (collect (filter (as_node (Gamma), model))) === 1
59
+ @test length (collect (filter (as_node (Normal), model))) === 1
60
+ @test length (collect (filter (as_node (sqrt), model))) === 0 # should be compiled out, c is a constant
36
61
end
0 commit comments