Skip to content

Commit fc86aab

Browse files
committed
Move factorization constraints to its own plugin
1 parent 576ba3a commit fc86aab

14 files changed

+465
-258
lines changed

.JuliaFormatter.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
style = "blue"
22
indent = 4
3-
margin = 180
3+
margin = 140
44
always_for_in = true
55
whitespace_typedefs = true
66
whitespace_ops_in_indices = true

benchmark/model_zoo.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,11 @@ function create_hgf(n::Int)
2222
ω = GraphPPL.getorcreate!(model, ctx, , nothing)
2323
θ = GraphPPL.getorcreate!(model, ctx, , nothing)
2424
x_begin = GraphPPL.getorcreate!(model, ctx, :x_begin, nothing)
25-
GraphPPL.add_terminated_submodel!(
25+
GraphPPL.add_toplevel_model!(
2626
model,
2727
ctx,
2828
hgf,
29-
= κ, ω = ω, θ = θ, x_begin = x_begin, depth = n);
30-
__debug__ = false,
29+
= κ, ω = ω, θ = θ, x_begin = x_begin, depth = n)
3130
)
3231
return model
3332
end

docs/src/developers_guide.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,11 @@ x = GraphPPL.getorcreate!(model, context, :x, nothing; options = GraphPPL.Variab
5757
y = GraphPPL.getorcreate!(model, context, :y, nothing; options = GraphPPL.VariableNodeOptions(datavar=true))
5858
5959
# Add the gcv model
60-
GraphPPL.add_terminated_submodel!(
60+
GraphPPL.add_toplevel_model!(
6161
model,
6262
context,
6363
gcv,
6464
(κ = κ, ω = ω, z = z, x = x, y = y);
65-
__debug__ = false,
6665
)
6766
6867
# Apply the constraints

src/graph_engine.jl

+191-77
Large diffs are not rendered by default.

src/plugins/node_created_by.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct NodeCreatedByPlugin end
88

99
const EmptyCreatedBy = Expr(:line, 0)
1010

11-
GraphPPL.plugin_type(::NodeCreatedByPlugin) = FactorNodePlugin()
11+
plugin_type(::NodeCreatedByPlugin) = FactorNodePlugin()
1212

1313
# The `created_by` field is used to track the expression that created the node.
1414
# The field can be a lambda function in which case it must be evaluated to get the expression.
@@ -21,7 +21,7 @@ Base.show(io::IO, createdby::CreatedBy) = show_createdby(io, createdby.created_b
2121
show_createdby(io::IO, created_by::Expr) = print(io, created_by)
2222
show_createdby(io::IO, created_by::Function) = show_createdby(io, created_by())
2323

24-
function process_plugin(::NodeCreatedByPlugin, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options::NodeCreationOptions)
24+
function preprocess_plugin(::NodeCreatedByPlugin, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options::NodeCreationOptions)
2525
created_by = get(options, :created_by, EmptyCreatedBy)
2626
setextra!(nodedata, :created_by, CreatedBy(created_by))
2727
return label, nodedata

src/plugins/variational_constraints/variational_constraints.jl

+40
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,43 @@ struct FullFactorization end
1515
include("variational_constraints_macro.jl")
1616
include("variational_constraints_engine.jl")
1717

18+
"""
19+
VariationalConstraintsPlugin(constraints)
20+
21+
A plugin that adds a VI related properties to the factor node for the variational inference procedure.
22+
"""
23+
struct VariationalConstraintsPlugin{C}
24+
constraints::C
25+
end
26+
27+
GraphPPL.plugin_type(::VariationalConstraintsPlugin) = FactorAndVariableNodesPlugin()
28+
29+
function preprocess_plugin(plugin::VariationalConstraintsPlugin, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options::NodeCreationOptions)
30+
preprocess_vi_plugin!(plugin, nodedata, getproperties(nodedata))
31+
return label, nodedata
32+
end
33+
34+
function preprocess_vi_plugin!(::VariationalConstraintsPlugin, nodedata::NodeData, nodeproperties::FactorNodeProperties)
35+
if hasextra(nodedata, :factorization_constraints)
36+
error("Factorizatiom constraints has been already defined for the node ", nodedata, ".")
37+
end
38+
return nothing
39+
end
40+
41+
function preprocess_vi_plugin!(::VariationalConstraintsPlugin, nodedata::NodeData, nodeproperties::VariableNodeProperties)
42+
# TODO bvdmitri: todo, add functional form constraints and messages constraints here
43+
return nothing
44+
end
45+
46+
## Applies the constraints in `constraints` to `model`. This function materializes the constraints in `constraints` and applies them to `model`.
47+
function postprocess_plugin(plugin::VariationalConstraintsPlugin, model::Model)
48+
# Attach `BitSetTuples` according to the number of neighbours of the factor node
49+
foreach(factor_nodes(model)) do flabel
50+
nodedata = model[flabel]
51+
nodeproperties = getproperties(nodedata)
52+
number_of_neighbours = length(neighbors(nodeproperties))
53+
setextra!(nodedata, :factorization_constraint_bitset, BitSetTuple(number_of_neighbours))
54+
end
55+
apply_constraints!(model, GraphPPL.get_principal_submodel(model), plugin.constraints, ConstraintStack())
56+
materialize_constraints!(model)
57+
end

src/plugins/variational_constraints/variational_constraints_engine.jl

+154-109
Large diffs are not rendered by default.

src/plugins_collection.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,17 @@ function Base.filter(::UnknownPluginType, collection::PluginsCollection)
5151
end
5252

5353
function Base.filter(trait::AbstractPluginTraitType, collection::PluginsCollection)
54-
return PluginsCollection(filter(plugin -> plugin_type(plugin) === trait, collection.collection))
54+
return PluginsCollection(filter(plugin -> isequal(plugin_type(plugin), trait), collection.collection))
5555
end
5656

57+
struct UnionPluginType{T, U} <: AbstractPluginTraitType
58+
trait1::T
59+
trait2::U
60+
end
61+
62+
function Base.isequal(type::AbstractPluginTraitType, union::UnionPluginType)
63+
return isequal(type, union.trait1) || isequal(type, union.trait2)
64+
end
5765

5866

5967
# # OLD STUFF BELOW

test/graph_construction_tests.jl

+29-4
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
# We don't use models from the `model_zoo.jl` file because they are subject to change
33
# These tests are meant to be stable and not change often
44

5-
@testitem "Simple model" begin
5+
@testitem "Simple model #1" begin
66
using Distributions
77

8-
import GraphPPL: create_model, getcontext, add_terminated_submodel!, factor_nodes, variable_nodes, is_constant, getproperties, as_node, as_variable
8+
import GraphPPL: create_model, getcontext, add_toplevel_model!, factor_nodes, variable_nodes, is_constant, getproperties, as_node, as_variable
99

10-
@model function simple_model()
10+
@model function simple_model_1()
1111
x ~ Normal(0, 1)
1212
y ~ Gamma(1, 1)
1313
z ~ Normal(x, y)
@@ -16,7 +16,7 @@
1616
model = create_model()
1717
context = getcontext(model)
1818

19-
add_terminated_submodel!(model, context, simple_model, NamedTuple())
19+
add_toplevel_model!(model, simple_model_1, NamedTuple())
2020

2121
flabels = collect(factor_nodes(model))
2222
vlabels = collect(variable_nodes(model))
@@ -33,4 +33,29 @@
3333
@test length(collect(filter(as_variable(:x), model))) === 1
3434
@test length(collect(filter(as_variable(:y), model))) === 1
3535
@test length(collect(filter(as_variable(:z), model))) === 1
36+
end
37+
38+
@testitem "Simple model #2" begin
39+
using Distributions
40+
using GraphPPL: create_model, getcontext, getorcreate!, add_toplevel_model!, as_node, NodeCreationOptions, prune!
41+
42+
@model function simple_model_2(a, b, c)
43+
x ~ Gamma= b, θ = sqrt(c))
44+
a ~ Normal= x, τ = 1)
45+
end
46+
47+
model = create_model()
48+
context = getcontext(model)
49+
50+
a = getorcreate!(model, context, NodeCreationOptions(datavar = true), :a, nothing)
51+
b = getorcreate!(model, context, NodeCreationOptions(datavar = true), :b, nothing)
52+
c = 1.0
53+
54+
add_toplevel_model!(model, simple_model_2, (a = a, b = b, c = c))
55+
56+
prune!(model)
57+
58+
@test length(collect(filter(as_node(Gamma), model))) === 1
59+
@test length(collect(filter(as_node(Normal), model))) === 1
60+
@test length(collect(filter(as_node(sqrt), model))) === 0 # should be compiled out, c is a constant
3661
end

test/graph_engine_tests.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ end
271271

272272
count = Ref(0)
273273

274-
function GraphPPL.process_plugin(::AnArbitraryPluginForTestUniqeness, model, context, label, nodedata, options)
274+
function GraphPPL.preprocess_plugin(::AnArbitraryPluginForTestUniqeness, model, context, label, nodedata, options)
275275
setextra!(nodedata, :count, count[])
276276
count[] = count[] + 1
277277
return label, nodedata
@@ -298,7 +298,7 @@ end
298298

299299
GraphPPL.plugin_type(::AnArbitraryPluginForChangingOptions) = GraphPPL.VariableNodePlugin()
300300

301-
function GraphPPL.process_plugin(::AnArbitraryPluginForChangingOptions, model, context, label, nodedata, options)
301+
function GraphPPL.preprocess_plugin(::AnArbitraryPluginForChangingOptions, model, context, label, nodedata, options)
302302
# Here we replace the original options entirely
303303
return label, NodeData(context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(constant = true, value = 1.0)))
304304
end
@@ -479,15 +479,18 @@ end
479479

480480
model = create_model()
481481
ctx = getcontext(model)
482+
@test isempty(model)
482483
@test nv(model) == 0
483484
@test ne(model) == 0
484485

485486
model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
486487
model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing))
488+
@test !isempty(model)
487489
@test nv(model) == 2
488490
@test ne(model) == 0
489491

490492
model[NodeLabel(:a, 1), NodeLabel(:b, 2)] = EdgeLabel(:edge, 1)
493+
@test !isempty(model)
491494
@test nv(model) == 2
492495
@test ne(model) == 1
493496
end

test/model_zoo.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ GraphPPL.factor_alias(::Type{Gamma}, ::Val{(:α, :θ)}) = GammaShapeScale
4646
function create_terminated_model(fform; plugins = GraphPPL.PluginsCollection())
4747
__model__ = GraphPPL.create_model(; fform = fform, plugins = plugins)
4848
__context__ = GraphPPL.getcontext(__model__)
49-
GraphPPL.add_terminated_submodel!(__model__, __context__, fform, NamedTuple())
49+
GraphPPL.add_toplevel_model!(__model__, __context__, fform, NamedTuple())
5050
return __model__
5151
end
5252

test/plugins/node_created_by_tests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565
@testitem "Usage with the actual model" begin
6666
using Distributions
6767

68-
import GraphPPL: create_model, getcontext, add_terminated_submodel!, factor_nodes, as_node, hasextra, PluginsCollection, NodeCreatedByPlugin, getextra
68+
import GraphPPL: create_model, getcontext, add_toplevel_model!, factor_nodes, as_node, hasextra, PluginsCollection, NodeCreatedByPlugin, getextra
6969

7070
@model function simple_model()
7171
x ~ Normal(0, 1)
@@ -76,7 +76,7 @@ end
7676
model = create_model(plugins = PluginsCollection(NodeCreatedByPlugin()))
7777
context = getcontext(model)
7878

79-
add_terminated_submodel!(model, context, simple_model, NamedTuple())
79+
add_toplevel_model!(model, simple_model, NamedTuple())
8080

8181
fnormal = map(label -> model[label], filter(as_node(Normal), model))
8282
fgamma = map(label -> model[label], filter(as_node(Gamma), model))

0 commit comments

Comments
 (0)