diff --git a/.gitignore b/.gitignore index 634cdc0..92f5024 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ worker-* tmp* dask-worker-space __pycache__ +env.d/jenv/ \ No newline at end of file diff --git a/Manifest.toml b/Manifest.toml index 22b471e..00fd886 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,13 +2,13 @@ julia_version = "1.9.3" manifest_format = "2.0" -project_hash = "8ead7856fbb2572c3c0a941b94dca0886303c698" +project_hash = "87c0603f65be5921938c725fd12195c48bc0f93f" [[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "954634616d5846d8e216df1298be2298d55280b2" +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Test"] +git-tree-sha1 = "a7055b939deae2455aa8a67491e034f735dd08d3" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.32" +version = "0.1.33" [deps.Accessors.extensions] AccessorsAxisKeysExt = "AxisKeys" @@ -19,6 +19,7 @@ version = "0.1.32" [deps.Accessors.weakdeps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" @@ -101,6 +102,11 @@ git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" version = "0.12.10" +[[deps.Combinatorics]] +git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.2" + [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -356,7 +362,7 @@ version = "0.72.9+1" [[deps.Gen]] deps = ["Compat", "DataStructures", "Distributions", "ForwardDiff", "FunctionalCollections", "JSON", "LinearAlgebra", "MacroTools", "Parameters", "Random", "ReverseDiff", "SpecialFunctions"] -git-tree-sha1 = "9878ff4ab1990f5647e89b4228a3c9da5f0e69c7" +path = "/home/dg963/GalileoEvents/env.d/jenv/dev/Gen" uuid = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a" version = "0.4.6" @@ -732,8 +738,8 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.7.2" [[deps.PhyBullet]] -deps = ["Accessors", "Conda", "Distributions", "DocStringExtensions", "Gen", "Parameters", "PhySMC", "Plots", "PyCall", "Revise", "StaticArrays", "UnicodePlots"] -git-tree-sha1 = "f553bbf8cfdc3a291380ee17f600f818f6cad054" +deps = ["Accessors", "Conda", "Distributions", "DocStringExtensions", "Gen", "Parameters", "PhySMC", "PyCall", "Revise", "StaticArrays", "UnicodePlots"] +git-tree-sha1 = "9fddd996bd0e0fded73a3a15d7f579e1f76cc46f" repo-rev = "master" repo-url = "https://github.com/CNCLgithub/PhyBullet" uuid = "63daae69-5b14-439d-ac6f-096429ca839b" diff --git a/Project.toml b/Project.toml index 79b2d73..95a084f 100644 --- a/Project.toml +++ b/Project.toml @@ -4,13 +4,17 @@ authors = ["belledon"] version = "0.1.0" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Gen = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a" Gen_Compose = "c1ef4dca-b0a6-4a35-b24b-46cbf3979a16" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PhyBullet = "63daae69-5b14-439d-ac6f-096429ca839b" PhySMC = "79c1e2f5-7911-41a0-b248-4858717ddd79" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/src/GalileoEvents.jl b/src/GalileoEvents.jl index 77b4588..8edea2b 100644 --- a/src/GalileoEvents.jl +++ b/src/GalileoEvents.jl @@ -1,5 +1,6 @@ module GalileoEvents +using Accessors using Gen using Gen_Compose using PhySMC diff --git a/src/gms/cp_gm_pb.jl b/src/gms/cp_gm_pb.jl new file mode 100644 index 0000000..38b03a3 --- /dev/null +++ b/src/gms/cp_gm_pb.jl @@ -0,0 +1,289 @@ +using Revise + +export CPParams, + CPState, + cp_model, + EventRelation, + Collision + +using LinearAlgebra:norm +using Combinatorics + +## Changepoint Model + components + +""" +Event types +""" +abstract type EventRelation end + +struct Collision <: EventRelation + a::RigidBody + b::RigidBody +end + +struct NoEvent <: EventRelation end + + +""" +holds parameters for change point model +""" +struct CPParams <: GMParams + # prior + material_prior::MaterialPrior + physics_prior::PhysPrior + # event relations + event_concepts::Vector{Type{<:EventRelation}} + # simulation + sim::BulletSim + template::BulletState + n_objects::Int64 + obs_noise::Float64 + death_factor::Float64 +end + +""" +constructs parameter struct for change point model +""" +function CPParams(client::Int64, objs::Vector{Int64}, + mprior::MaterialPrior, pprior::PhysPrior, + event_concepts::Vector{Type{<:EventRelation}}, + obs_noise::Float64=0., + death_factor=10.) + # configure simulator with the provided + # client id + sim = BulletSim(;client=client) + # These are the objects of interest in the scene + rigid_bodies = RigidBody.(objs) + # Retrieve the default latents for the objects + # as well as their initial positions + # Note: alternative latents will be suggested by the `prior` + template = BulletState(sim, rigid_bodies) + + CPParams(mprior, pprior, event_concepts, sim, template, length(objs), obs_noise, death_factor) +end + +""" +Current state of the change point model, simulation state and event state +""" +struct CPState <: GMState + bullet_state::BulletState + active_events::Set{Int64} +end + +## PRIOR + +""" +initalizes prior beliefs about mass, friction and resitution of the given objects +""" +@gen function cp_object_prior(ls::RigidBodyLatents, gm::CPParams) + # sample material + mi = @trace(categorical(gm.material_prior.material_weights), :material) + + # sample physical properties + phys_params = gm.physics_prior + mass_mu, mass_sd = phys_params.mass + mass = @trace(trunc_norm(mass_mu, mass_sd, 0., Inf), :mass) + fric_mu, fric_sd = phys_params.friction + friction = @trace(trunc_norm(fric_mu,fric_sd, 0., 1.), :friction) + res_low, res_high = phys_params.restitution + restitution = @trace(uniform(res_low, res_high), :restitution) + + # package + new_ls = setproperties(ls.data; + mass = mass, + lateralFriction = friction, + restitution = restitution) + new_latents::RigidBodyLatents = RigidBodyLatents(new_ls) + return new_latents +end + +""" +initializes belief about all objects and events +""" +@gen function cp_prior(params::CPParams) + # initialize the kinematic state + latents = params.template.latents + params_filled = Fill(params, length(latents)) + new_latents = @trace(Gen.Map(cp_object_prior)(latents, params_filled), :objects) + bullet_state = setproperties(params.template; latents = new_latents) + + # initialize the event state + active_events = Set{Int64}() + + init_state = CPState(bullet_state, active_events) + return init_state +end + +""" +Bernoulli weight that event relation holds +""" +function predicate(t::Type{Collision}, a::RigidBodyState, b::RigidBodyState) + if norm(Vector(a.linear_vel)-Vector(b.linear_vel)) < 0.01 + return 0 + end + + a_dim = a.aabb[2] - a.aabb[1] + b_dim = b.aabb[2] - b.aabb[1] + d = norm(Vector(a.position)-Vector(b.position))-norm((a_dim+b_dim)/2) # l2 distance + clamp(exp(-15d), 1e-3, 1 - 1e-3) +end + + +""" +update latents of a single element +""" +@gen function update_latents(latents::BulletElemLatents) + new_mass = @trace(trunc_norm(latents.data.mass, .1, 0., Inf), :mass) + new_restitution = @trace(trunc_norm(latents.data.restitution, .1, 0., 1.), :restitution) + + new_latents = setproperties(latents.data; + mass = new_mass, + restitution = new_restitution) + new_latents = RigidBodyLatents(new_latents) + return new_latents +end + +""" +in case of collision: Gaussian drift update of mass and restitution +""" +@gen function _collision_clause(pair_idx::Vector{Int64}, latents::Vector{BulletElemLatents}) + latents[pair_idx[1]] = @trace(update_latents(latents[pair_idx[1]]), :new_latents_a) + latents[pair_idx[2]] = @trace(update_latents(latents[pair_idx[2]]), :new_latents_b) + return latents +end + +function clause(::Type{Collision}) + _collision_clause +end + +""" +in case of no event: no change +""" +@gen function _no_event_clause(pair_idx, latents::Vector{BulletElemLatents}) + return latents +end + +function clause(::Type{NoEvent}) + _no_event_clause +end + +event_concepts = Type{<:EventRelation}[NoEvent, Collision] +switch = Gen.Switch(map(clause, event_concepts)...) + +""" +TODO: this function was intended to check if some event relations are impossible to be created a certain time step +""" +function valid_relations(state::CPState, event_concepts::Vector{Type{EventRelation}}) + return event_concepts + # TODO: replace by map in the end + for EventRelation in event_concepts + # TODO: decide if valid + end +end + +""" +map possible events to weight vector for birth decision using the predicates +""" +function calculate_predicates(obj_states) + object_pairs = collect(combinations(obj_states, 2)) + pair_idx = repeat(collect(combinations(1:length(obj_states), 2)), length(event_concepts)) + pair_idx = [[0,0], pair_idx...] # [0,0] for no event + + # break up to two lines + predicates = [predicate(event_type, a, b) for event_type in event_concepts for (a, b) in object_pairs if event_type != NoEvent] # NoEvent excluded and added in weights + event_ids = vcat(1, repeat(2:length(event_concepts), inner=length(object_pairs))) # 1 for NoEvent + return predicates, event_ids, pair_idx +end + +""" +transform predicates for pairs of objects into a probability vector that adds to 1, including one weight for NoEvent at the first position +""" +function normalize_weights(weights, active_events) + for idx in active_events # active events should not be born again + weights[idx-1] = 0 + end + weights = [max(0, 1 - sum(weights)), weights...] # first element for NoEvent + # TODO: objects that are already involved in some events should not be involved in other event types as well + return weights ./ sum(weights) +end + +""" +similar to normalize_weights but for death of events +""" +function calculate_death_weights(predicates, active_events, start_event_idx, death_factor) + can_die(idx) = idx+1 in active_events && idx+1 != start_event_idx + # dying has a much lower chance of being born + get_weight(idx) = can_die(idx) ? max(1. - predicates[idx] * death_factor, 0.) : 0.0 + weights = [get_weight(idx) for idx in 1:length(predicates)] + weights = [max(0, 1-sum(weights)), weights...] # no event at index 1 + return weights ./ sum(weights) +end + +""" +updates active events in a functional form +add=True -> add event to set of active events +add=False -> remove event from set of active events +""" +function update_active_events(active_events::Set{Int64}, event_idx::Int64, add::Bool) + if event_idx == 1 + return active_events + end + if add + return union(active_events, Set([event_idx])) + else + return setdiff(active_events, Set([event_idx])) + end +end + + +""" +iterate over event concepts and evaluate predicates for newly activated events +""" +@gen function event_kernel(active_events, bullet_state, death_factor) + predicates, event_ids, pair_idx = calculate_predicates(bullet_state.kinematics) + weights = normalize_weights(copy(predicates), active_events) + start_event_idx = @trace(categorical(weights), :start_event_idx) # up to one event is born + + updated_latents = @trace(switch(event_ids[start_event_idx], pair_idx[start_event_idx], bullet_state.latents), :event) + bullet_state = setproperties(bullet_state; latents = updated_latents) + active_events = update_active_events(active_events, start_event_idx, true) + + weights = calculate_death_weights(predicates, active_events, start_event_idx, death_factor) + end_event_idx = @trace(categorical(weights), :end_event_idx) # up to one active event dies + active_events = update_active_events(active_events, end_event_idx, false) + + return active_events, bullet_state +end + +""" +for one object, observe the noisy position in every dimension +""" +@gen function observe_position(k::RigidBodyState, noise::Float64) + @trace(broadcasted_normal(k.position, noise), :positions) +end + +""" +run event and physics kernel for one time step and observe noisy positions +""" +@gen function kernel(t::Int, prev_state::CPState, params::CPParams) + active_events, bullet_state = @trace(event_kernel(prev_state.active_events, + prev_state.bullet_state, + params.death_factor), :events) + + bullet_state::BulletState = PhySMC.step(params.sim, bullet_state) + @trace(Gen.Map(observe_position)(bullet_state.kinematics, Fill(params.obs_noise, params.n_objects)), :observe) + + return CPState(bullet_state, active_events) +end + +""" +generate physical scene with changepoints in the belief state +""" +@gen function cp_model(t::Int, params::CPParams) + # initalize the kinematic and event state + init_state = @trace(cp_prior(params), :prior) + + # unfold the event and kinematic state over time + states = @trace(Gen.Unfold(kernel)(t, init_state, params), :kernel) + return states +end diff --git a/src/gms/gms.jl b/src/gms/gms.jl index 34e24fe..164d556 100644 --- a/src/gms/gms.jl +++ b/src/gms/gms.jl @@ -1,3 +1,5 @@ +using FillArrays + export GMParams, GMState, Material, @@ -56,7 +58,7 @@ $(TYPEDSIGNATURES) A uniform prior over given materials """ -function MaterialPrior(ms::Vector{Material}) +function MaterialPrior(ms::Vector{<: Material}) n = length(ms) ws = Fill(1.0 / n, n) MaterialPrior(ms, ws) @@ -78,4 +80,4 @@ struct PhysPrior end include("mc_gm.jl") -# include("cp_gm.jl") +include("cp_gm_pb.jl") diff --git a/src/gms/mc_gm.jl b/src/gms/mc_gm.jl index d139915..ab38f66 100644 --- a/src/gms/mc_gm.jl +++ b/src/gms/mc_gm.jl @@ -17,11 +17,12 @@ $(TYPEDFIELDS) struct MCParams <: GMParams # prior material_prior::MaterialPrior - physics_prior::Vector{PhysPrior} + physics_prior::PhysPrior # simulation sim::BulletSim template::BulletState n_objects::Int64 + obs_noise::Float64 end """ @@ -29,8 +30,9 @@ $(TYPEDSIGNATURES) Initializes `MCParams` from a constructed scene in pybullet. """ -function MCParams(client::PyObject, objs::Vector{PyObject}, - mprior::MaterialPrior, pprior::Vector{PhysPrior}) +function MCParams(client::Int64, objs::Vector{Int64}, + mprior::MaterialPrior, pprior::PhysPrior, + obs_noise::Float64=0.) # configure simulator with the provided # client id sim = BulletSim(;client=client) @@ -41,7 +43,7 @@ function MCParams(client::PyObject, objs::Vector{PyObject}, # Note: alternative latents will be suggested by the `prior` template = BulletState(sim, rigid_bodies) - MCParams(mprior, pprior, sim, template, length(objs)) + MCParams(mprior, pprior, sim, template, length(objs), obs_noise) end struct MCState <: GMState @@ -56,9 +58,9 @@ end @gen function mc_object_prior(ls::RigidBodyLatents, gm::MCParams) # sample material - mi = @trace(categorical(gm.material_weights), :material) + mi = @trace(categorical(gm.material_prior.material_weights), :material) # sample physical properties - phys_params = gm.phys_params[mi] + phys_params = gm.physics_prior mass_mu, mass_sd = phys_params.mass mass = @trace(trunc_norm(mass_mu, mass_sd, 0., Inf), :mass) @@ -94,7 +96,8 @@ end @gen function kernel(t::Int, prev_state::MCState, gm::MCParams) sim_step::BulletState = PhySMC.step(gm.sim, prev_state.bullet_state) - obs = @trace(Gen.Map(observe_position)(sim_step.kinematics), :observe) + noises = Fill(gm.obs_noise, length(sim_step.kinematics)) + obs = @trace(Gen.Map(observe_position)(sim_step.kinematics, noises), :observe) next_state = MCState(sim_step) return next_state end diff --git a/src/utils/distributions.jl b/src/utils/distributions.jl index c8dc7b2..c597f82 100644 --- a/src/utils/distributions.jl +++ b/src/utils/distributions.jl @@ -8,14 +8,14 @@ struct NoisyMatrix <: Gen.Distribution{Array{Float64}} end const mat_noise = NoisyMatrix() -function Gen.logpdf(::NoisyMatrix, x::Array{Float64}, mu::Array{U}, noise::T) where {U<:Real,T<:Real} +function Gen.logpdf(::NoisyMatrix, x::Array{Float64}, mu::Array{<:Real}, noise::T) where {T<:Real} var = noise * noise diff = x - mu vec = diff[:] return -(vec' * vec)/ (2.0 * var) - 0.5 * log(2.0 * pi * var) end; -function Gen.random(::NoisyMatrix, mu::Array{U}, noise::T) where {U<:Real,T<:Real} +function Gen.random(::NoisyMatrix, mu::Array{<:Real}, noise::T) where {T<:Real} mat = copy(mu) for i in CartesianIndices(mu) mat[i] = mu[i] + randn() * noise @@ -30,7 +30,7 @@ struct LogUniform <: Gen.Distribution{Float64} end const log_uniform = LogUniform() -function Gen.logpdf(::LogUniform, x::Float64, low::T, high::T) where {U<:Real,T<:Real} +function Gen.logpdf(::LogUniform, x::Float64, low::T, high::T) where {T<:Real} l = log(low) h = log(high) v = log(x) @@ -38,7 +38,7 @@ function Gen.logpdf(::LogUniform, x::Float64, low::T, high::T) where {U<:Real,T< return (v >= l && v <= h) ? -log(h-l) : -Inf end -function Gen.random(::LogUniform, low::T, high::T) where {U<:Real,T<:Real} +function Gen.random(::LogUniform, low::T, high::T) where {T<:Real} d = uniform(log(low), log(high)) exp(d) end diff --git a/src/utils/ramp.obj b/src/utils/ramp.obj new file mode 100644 index 0000000..59fd2ad --- /dev/null +++ b/src/utils/ramp.obj @@ -0,0 +1,15 @@ +v 0.000000 0.000000 0.000000 + +v 1.000000 0.000000 0.000000 + +v 1.000000 1.000000 0.000000 + +v 0.000000 1.000000 0.000000 + +v 0.000000 0.000000 1.000000 + +v 0.000000 1.000000 1.000000 + +f 1 2 3 4 + +f 1 2 5 6 \ No newline at end of file diff --git a/src/utils/scenes.jl b/src/utils/scenes.jl index 954d1c1..fc1f3dc 100644 --- a/src/utils/scenes.jl +++ b/src/utils/scenes.jl @@ -6,8 +6,11 @@ export ramp obj_positions::NTuple{2}; slope, ramp_intersection) """ function ramp( + mass_ratio::Float64, + obj_frictions::NTuple{2, Float64} = (.5, .5), + obj_positions::NTuple{2, Float64} = (0.5, 1.5), slope::Float64=2/3, - tableRampIntersection::Float64=0., + tableRampIntersection::Float64=0. ) # for debugging #client = @pycall pb.connect(pb.GUI)::Int64 @@ -56,7 +59,8 @@ function ramp( end # add a ramp - ramp_col_id = pb.createCollisionShape(pb.GEOM_MESH, fileName="examples/ramp/ramp.obj", physicsClientId=client, meshScale=[2, base_dims[2], slope*2]) + pb.setAdditionalSearchPath("/project") + ramp_col_id = pb.createCollisionShape(pb.GEOM_MESH, fileName="src/utils/ramp.obj", physicsClientId=client, meshScale=[2, base_dims[2], slope*2]) ramp_position = [-2+tableRampIntersection, -base_dims[2]/2, 0] ramp_obj_id = pb.createMultiBody(baseCollisionShapeIndex=ramp_col_id, basePosition=ramp_position, physicsClientId=client) pb.changeDynamics(ramp_obj_id, -1; mass=0.0, restitution=0.9, physicsClientId=client) @@ -90,18 +94,18 @@ function ramp( obj_on_ramp_col_id = pb.createCollisionShape(pb.GEOM_BOX, halfExtents=obj_ramp_dims/2, physicsClientId=client) lift = obj_ramp_dims[3]/2 position = [ - -1+tableRampIntersection+lift*cos(theta_radians), + -2+2*obj_positions[1]+tableRampIntersection+lift*cos(theta_radians), 0, - 1*slope-lift*sin(theta_radians) + (2-2*obj_positions[1])*slope-lift*sin(theta_radians) ] obj_on_ramp_obj_id = pb.createMultiBody(baseCollisionShapeIndex=obj_on_ramp_col_id, basePosition=position, baseOrientation=orientation, physicsClientId=client) - pb.changeDynamics(obj_on_ramp_obj_id, -1; mass=1.0, restitution=0.9, physicsClientId=client) - + pb.changeDynamics(obj_on_ramp_obj_id, -1; mass=mass_ratio, restitution=0.9, lateralFriction=obj_frictions[1], physicsClientId=client) + # add an object on the table that will collide with the object on the ramp as that one slides down obj_on_table_dims = [0.2, 0.2, 0.1] obj_on_table_col_id = pb.createCollisionShape(pb.GEOM_BOX, halfExtents=obj_on_table_dims/2, physicsClientId=client) - obj_on_table_obj_id = pb.createMultiBody(baseCollisionShapeIndex=obj_on_table_col_id, basePosition=[1, 0, obj_on_table_dims[3]/2], physicsClientId=client) - pb.changeDynamics(obj_on_table_obj_id, -1; mass=1.0, restitution=0.9, physicsClientId=client) + obj_on_table_obj_id = pb.createMultiBody(baseCollisionShapeIndex=obj_on_table_col_id, basePosition=[2.5*(obj_positions[2]-1), 0, obj_on_table_dims[3]/2], physicsClientId=client) + pb.changeDynamics(obj_on_table_obj_id, -1; mass=1.0, restitution=0.9, lateralFriction=obj_frictions[2], physicsClientId=client) (client, obj_on_ramp_obj_id, obj_on_table_obj_id) end diff --git a/test/gms/cp_gm.jl b/test/gms/cp_gm.jl new file mode 100644 index 0000000..930ecaf --- /dev/null +++ b/test/gms/cp_gm.jl @@ -0,0 +1,236 @@ +using Revise +using Gen +using GalileoEvents +using Plots +ENV["GKSwstype"]="160" # fixes some plotting warnings + +mass_ratio = 2.0 +obj_frictions = (0.3, 0.3) +obj_positions = (0.5, 1.2) + +mprior = MaterialPrior([unknown_material]) +pprior = PhysPrior((3.0, 10.0), # mass + (0.5, 10.0), # friction + (0.2, 1.0)) # restitution + +obs_noise = 0.05 +t = 80 + +fixed_prior_cm = Gen.choicemap() +fixed_prior_cm[:prior => :objects => 1 => :mass] = 2. +fixed_prior_cm[:prior => :objects => 2 => :mass] = 1. +fixed_prior_cm[:prior => :objects => 1 => :friction] = 0.5 +fixed_prior_cm[:prior => :objects => 2 => :friction] = 0.5 +fixed_prior_cm[:prior => :objects => 1 => :restitution] = 0.2 +fixed_prior_cm[:prior => :objects => 2 => :restitution] = 0.2 + +function forward_test() + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + trace, weight = Gen.generate(cp_model, (t, cp_params)); + @show weight + #display(get_choices(trace)) +end + +function add_rectangle!(plt, xstart, xend, y; height=0.8, color=:blue) + xvals = [xstart, xend, xend, xstart, xstart] + yvals = [y, y, y+height, y+height, y] + plot!(plt, xvals, yvals, fill=true, seriestype=:shape, fillcolor=color, linecolor=color) +end + +get_x2(trace, t) = get_retval(trace)[t].bullet_state.kinematics[2].position[1] + +function visualize_active_events() + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + num_traces = 50 + plt = plot(legend=false, xlim=(0, t), ylim=(1, num_traces+1), yrotation=90, ylabel="Trace", yticks=false, xlabel="Time step") + collision_t = nothing + for i in 1:num_traces + if i % 10 == 0 + @show i + end + trace, _ = Gen.generate(cp_model, (t, cp_params), fixed_prior_cm); + + start = nothing + first_x = i==1 ? get_x2(trace, 1) : nothing # only look for collision in first trace + for j in 1:t + if trace[:kernel=>j=>:events=>:start_event_idx]==2 + start = j + end + if trace[:kernel=>j=>:events=>:end_event_idx]==2 + finish = j + add_rectangle!(plt, start, finish, i) + end + if first_x !== nothing && abs(first_x - get_x2(trace, j)) > 0.001 + collision_t = j + first_x = nothing + end + end + + end + vline!(plt, [collision_t], linecolor=:red, linewidth=2, label="Vertical Line") + savefig(plt, "test/gms/plots/events.png") +end + +# constrained generation, event 2 must start at timestep 10 +function constrained_test() + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + #addr = 10 => :events => :start_event_idx + #cm = Gen.choicemap(addr => 2) + trace, weight = Gen.generate(cp_model, (t, cp_params), fixed_prior_cm) + @show weight + #display(get_choices(trace)) +end + +# update priors +function update_test() + t = 120 + + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + trace, _ = Gen.generate(cp_model, (t, cp_params)) + + addr = :prior => :objects => 1 => :mass + cm = Gen.choicemap(addr => trace[addr] + 3) + trace2, _ = Gen.update(trace, cm) + + # compare final positions + t=120 + pos1 = Vector(get_retval(trace)[t].bullet_state.kinematics[1].position) + pos2 = Vector(get_retval(trace2)[t].bullet_state.kinematics[1].position) + @assert pos1 != pos2 + + return trace, trace2 +end + +# change event start +function update_test_2() + + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + # generate initial trace + trace, ls = Gen.generate(cp_model, (t, cp_params), fixed_prior_cm) + + # find first collision in the trace + start_event_indices = [trace[:kernel=>i=>:events=>:start_event_idx] for i in 1:t] + t1 = findfirst(x -> x == 2, start_event_indices) + @show ls + choices = get_choices(trace) + display(get_submap(choices, :kernel => t1 => :events)) + + # TODO: validate existence of event + # move first collision five steps earlier + cm = choicemap() + cm[:kernel => t1 => :events => :start_event_idx] = 1 + cm[:kernel => t1 - 5 => :events => :start_event_idx] = 2 + trace2, ls2, _... = Gen.update(trace, cm) + #@show ls2 + choices = get_choices(trace2) + #display(get_submap(choices, :kernel => t1 => :events)) + #display(get_submap(choices, :kernel => t1 -5 => :events)) + + # the keys have to be enumerated, subsets do not work + trace3, delta_s, _... = Gen.regenerate(trace, select( + :kernel => t1 => :events => :event => :new_latents_a => :mass,)) + #:kernel => t1 => :events => :event => :new_latents_a => :restitution)) + + @show delta_s + choices2 = get_choices(trace3) + display(get_submap(choices2, :kernel => t1 => :events)) + + + for i in 1:t + if project(trace3, select(:kernel => i)) == -Inf + @show i + display(get_submap(choices2, :kernel => i => :events)) + end + end + @show t1 + @show project(trace3, select(:kernel)) + @assert delta_s != -Inf + @assert !isnan(delta_s) + + #return trace, trace2 +end + +# redraw latents at same event start +function update_test_3() + + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + # generate initial trace + trace, _ = Gen.generate(cp_model, (t, cp_params)) + + # find first collision in the trace + start_event_indices = [trace[:kernel=>i=>:events=>:start_event_idx] for i in 1:t] + t1 = findfirst(x -> x == 2, start_event_indices) + + # in future maybe gaussian rw + trace2, delta_s, _... = Gen.regenerate(trace, select(:kernel => t1 => :events => :event)) + + @assert delta_s != -Inf + @assert delta_s != NaN + + return trace, trace2 +end + + +# test switch combinator in terms of gen's reaction to proposed changes + +# toy model for dealing with complexing +# random walk with 2 delta functions (gaussian vs uniform) chosen by switch +# initial trace is changed by a mh proposal for switch index +# static first, unfold complexity second step + +@gen function function1() + v ~ normal(0., 1.) +end + +@gen function function2() + v ~ uniform(-1., 1.) +end + +switch = Gen.Switch(function1, function2) + +@gen function switch_model_static() + function_idx = @trace(categorical([0.5, 0.5]), :function) + x = @trace(switch(function_idx), :x) + y = @trace(normal(x, 1.), :y) +end + +function switch_test_static() + # unconstrained generation + trace, _ = Gen.generate(switch_model_static, ()) + display(get_choices(trace)) + + # constrained generation + cm = Gen.choicemap(:function => 1) + trace2, _ = Gen.generate(switch_model_static, (), cm) + display(get_choices(trace2)) + + # update and regenerate trace + trace3, _ = Gen.update(trace, cm) + trace4, _ = Gen.regenerate(trace3, select(:x)) + display(get_choices(trace4)) +end + +#forward_test() +#visualize_active_events() +#constrained_test() +#update_test() +update_test_2() +#update_test_3() +#switch_test_static() \ No newline at end of file diff --git a/test/gms/mc_gm.jl b/test/gms/mc_gm.jl index d00e39a..5257958 100644 --- a/test/gms/mc_gm.jl +++ b/test/gms/mc_gm.jl @@ -3,35 +3,46 @@ using GalileoEvents mass_ratio = 2.0 obj_frictions = (0.3, 0.3) -obj_positions = () # TODO fill in... +obj_positions = (0.5, 1.5) mprior = MaterialPrior([unknown_material]) pprior = PhysPrior((3.0, 10.0), # mass (0.5, 10.0), # friction (0.2, 1.0)) # restitution +obs_noise = 0.05 +t = 120 -t = 60 - -# TODO: this should evaluate without errors function forward_test() client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) - mc_params = MCParams(client, [a,b], mprior, pprior) - trace, _ = Gen.generate(t, mc_gm) - display(get_choices(trace)) + mc_params = MCParams(client, [a,b], mprior, pprior, obs_noise) + trace, _ = Gen.generate(mc_gm, (t, mc_params)) + #display(get_choices(trace)) end -# TODO: use `Gen.update` to change a traces physical latents and -# compare the final positions (they should be different) function update_test() client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) - mc_params = MCParams(client, [a,b], mprior, pprior) - trace, _ = Gen.generate(t, mc_gm) + mc_params = MCParams(client, [a,b], mprior, pprior, obs_noise) + trace, _ = Gen.generate(mc_gm, (t, mc_params)) addr = :prior => :objects => 1 => :mass - cm = Gen.choicemap(addr => trace[addr] + 3.0) - trace2 = Gen.update(trace, cm) - + cm = Gen.choicemap(addr => trace[addr] + 3) + trace2, _ = Gen.update(trace, cm) # compare final positions + t=120 + pos1 = Vector(get_retval(trace)[t].bullet_state.kinematics[1].position) + pos2 = Vector(get_retval(trace2)[t].bullet_state.kinematics[1].position) + @assert pos1 != pos2 + + return trace, trace2 end + +function main() + forward_test() + update_test() +end + +if abspath(PROGRAM_FILE) == @__FILE__ + main() +end \ No newline at end of file diff --git a/test/gms/plots/events.png b/test/gms/plots/events.png new file mode 100644 index 0000000..1924a8b Binary files /dev/null and b/test/gms/plots/events.png differ diff --git a/test/gms/plots/x_positions.png b/test/gms/plots/x_positions.png new file mode 100644 index 0000000..74b83e3 Binary files /dev/null and b/test/gms/plots/x_positions.png differ