Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let marginals return dict #61

Merged
merged 4 commits into from
Sep 5, 2023
Merged

Let marginals return dict #61

merged 4 commits into from
Sep 5, 2023

Conversation

GiggleLiu
Copy link
Member

The updated docstring

Query the marginals of the variables in a TensorNetworkModel.
The returned value is a dictionary of variables and their marginals, where a marginal is a joint probability distribution over the associated variables.
By default, the marginals of all individual variables are returned.
The marginal variables to query can be specified when constructing TensorNetworkModel as its field mars.
It will affect the contraction order of the tensor network.

Arguments

  • tn: the TensorNetworkModel to query.
  • usecuda: whether to use CUDA for tensor contraction.
  • rescale: whether to rescale the tensors during contraction.

Example

The following example is from examples/asia/main.jl.

julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia", "asia.uai"));

julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077

julia> marginals(tn)
Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
  [8] => [0.450138, 0.549863]
  [3] => [0.5, 0.5]
  [1] => [1.0]
  [5] => [0.45, 0.55]
  [4] => [0.055, 0.945]
  [6] => [0.10225, 0.89775]
  [7] => [0.145092, 0.854908]
  [2] => [0.05, 0.95]

julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443

julia> marginals(tn2)
Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
  [2, 3] => [0.025 0.025; 0.475 0.475]
  [3, 4] => [0.05 0.45; 0.005 0.495]

In this example, we first set the evidence of variable 1 to 0, then we query the marginals of all individual variables.
The returned values is a dictionary, the key are query variables, and the value are the corresponding marginals.
The marginals are vectors, with its entries corresponding to the probability of the variable taking the value 0 and 1, respectively.
For evidence variable 1, the marginal is always [1.0], since it is fixed to 0.

Then we set the marginal variables to query to be variable 2 and 3, and variable 3 and 4, respectively.
The joint marginals may or may not increase the contraction time and space.
Here, the contraction space complexity is increased from 2^2.0 to 2^5.0, and the contraction time complexity is increased from 2^5.977 to 2^7.781.
The output marginals are joint probabilities of the query variables represented by tensors.

@codecov
Copy link

codecov bot commented Sep 5, 2023

Codecov Report

Merging #61 (6266d7c) into main (a066560) will increase coverage by 3.00%.
The diff coverage is 66.66%.

@@            Coverage Diff             @@
##             main      #61      +/-   ##
==========================================
+ Coverage   81.61%   84.61%   +3.00%     
==========================================
  Files          10       10              
  Lines         533      533              
==========================================
+ Hits          435      451      +16     
+ Misses         98       82      -16     
Files Changed Coverage Δ
src/mar.jl 94.54% <66.66%> (ø)

... and 1 file with indirect coverage changes

@mroavi mroavi merged commit 4d5b452 into main Sep 5, 2023
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants