From f1f276040eeecbc4e23bc09613ed8d44ed42ab4a Mon Sep 17 00:00:00 2001 From: hdavid16 Date: Tue, 2 Aug 2022 09:07:49 -0500 Subject: [PATCH] Fix #160 make self-loop edges curved behavior is regardless of the linetype --- src/GraphPlot.jl | 2 ++ src/lines.jl | 46 ++++++++++++++++++++++++++++++++++++++++++++++ src/plot.jl | 29 ++++++++++------------------- 3 files changed, 58 insertions(+), 19 deletions(-) diff --git a/src/GraphPlot.jl b/src/GraphPlot.jl index c1738c1..13a0a0e 100644 --- a/src/GraphPlot.jl +++ b/src/GraphPlot.jl @@ -2,6 +2,8 @@ module GraphPlot using Compose # for plotting features using Graphs +using LinearAlgebra +using SparseArrays const gadflyjs = joinpath(dirname(Base.source_path()), "gadfly.js") diff --git a/src/lines.jl b/src/lines.jl index 4589641..e3d0a48 100644 --- a/src/lines.jl +++ b/src/lines.jl @@ -201,3 +201,49 @@ function curveedge(x1, y1, x2, y2, θ, outangle, d; k=0.5) return [(x1,y1) (xc1, yc1) (xc2, yc2) (x2, y2)] end + +function build_curved_edges(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset, outangle) + if arrowlengthfrac > 0.0 + curves_cord, arrows_cord = graphcurve(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset, outangle) + curves = curve(curves_cord[:,1], curves_cord[:,2], curves_cord[:,3], curves_cord[:,4]) + carrows = line(arrows_cord) + else + curves_cord = graphcurve(g, locs_x, locs_y, nodesize, outangle) + curves = curve(curves_cord[:,1], curves_cord[:,2], curves_cord[:,3], curves_cord[:,4]) + carrows = nothing + end + + return curves, carrows +end + +function build_straight_edges(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset) + if arrowlengthfrac > 0.0 + lines_cord, arrows_cord = graphline(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset) + lines = line(lines_cord) + larrows = line(arrows_cord) + else + lines_cord = graphline(g, locs_x, locs_y, nodesize) + lines = line(lines_cord) + larrows = nothing + end + + return lines, larrows +end + +function build_straight_curved_edges(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset, outangle) + A = adjacency_matrix(g) #adjacency matrix + B = spdiagm(diag(A)) #diagonal matrix (self-loops) + A[diagind(A)] .= 0 #set diagonal elements to 0 (remove self-loops) + if is_directed(g) + g1 = SimpleDiGraph(A) + g2 = SimpleDiGraph(B) + else + g1 = SimpleGraph(A) + g2 = SimpleGraph(B) + end + + lines, larrows = build_straight_edges(g1, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset) + curves, carrows = build_curved_edges(g2, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset, outangle) + + return lines, larrows, curves, carrows +end \ No newline at end of file diff --git a/src/plot.jl b/src/plot.jl index af438b7..e1b91e2 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -206,33 +206,24 @@ function gplot(g::AbstractGraph{T}, end # Create lines and arrow heads - lines, arrows = nothing, nothing + lines, larrows = nothing, nothing + curves, carrows = nothing, nothing if linetype == "curve" - if arrowlengthfrac > 0.0 - curves_cord, arrows_cord = graphcurve(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset, outangle) - lines = curve(curves_cord[:,1], curves_cord[:,2], curves_cord[:,3], curves_cord[:,4]) - arrows = line(arrows_cord) - else - curves_cord = graphcurve(g, locs_x, locs_y, nodesize, outangle) - lines = curve(curves_cord[:,1], curves_cord[:,2], curves_cord[:,3], curves_cord[:,4]) - end + curves, carrows = build_curved_edges(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset, outangle) + elseif has_self_loops(g) + lines, larrows, curves, carrows = build_straight_curved_edges(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset, outangle) else - if arrowlengthfrac > 0.0 - lines_cord, arrows_cord = graphline(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset) - lines = line(lines_cord) - arrows = line(arrows_cord) - else - lines_cord = graphline(g, locs_x, locs_y, nodesize) - lines = line(lines_cord) - end + lines, larrows = build_straight_edges(g, locs_x, locs_y, nodesize, arrowlengthfrac, arrowangleoffset) end compose(context(units=UnitBox(-1.2, -1.2, +2.4, +2.4)), compose(context(), texts, fill(nodelabelc), stroke(nothing), fontsize(nodelabelsize)), compose(context(), nodes, fill(nodefillc), stroke(nodestrokec), linewidth(nodestrokelw)), compose(context(), edgetexts, fill(edgelabelc), stroke(nothing), fontsize(edgelabelsize)), - compose(context(), arrows, stroke(edgestrokec), linewidth(edgelinewidth)), - compose(context(), lines, stroke(edgestrokec), fill(nothing), linewidth(edgelinewidth))) + compose(context(), larrows, stroke(edgestrokec), linewidth(edgelinewidth)), + compose(context(), carrows, stroke(edgestrokec), linewidth(edgelinewidth)), + compose(context(), lines, stroke(edgestrokec), fill(nothing), linewidth(edgelinewidth)), + compose(context(), curves, stroke(edgestrokec), fill(nothing), linewidth(edgelinewidth))) end function gplot(g; layout::Function=spring_layout, keyargs...)