Skip to content

Commit

Permalink
core: revamp two_columns to fix the phantom space
Browse files Browse the repository at this point in the history
  • Loading branch information
sorawee committed Feb 14, 2024
1 parent 434d56f commit 8baf30e
Show file tree
Hide file tree
Showing 8 changed files with 410 additions and 106 deletions.
6 changes: 6 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## 0.4 (2024-02-14)

* Fix a critical issue in `two_columns`: remove phantom spaces,
and adjust costs to ensure optimality.
* Mark `two_columns` as experimental.

## 0.3 (2024-02-11)

* Add the `two_columns` construct
Expand Down
4 changes: 2 additions & 2 deletions doc/index.mld
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ let my_cost_factory ~page_width ?computation_width () =
let string_of_cost (c, s) = Printf.sprintf "(%s %d)" (F.string_of_cost c) s

let debug_format = F.debug_format
end: Signature.CostFactory with type t = (int * int * int * int) * int)
end: Signature.CostFactory with type t = (int * int * int) * int)
]}

We now construct a function to convert an S-expression into a document,
Expand All @@ -387,7 +387,7 @@ let revised_print_sexp (s : sexp) (w : int) =
let xs_d = List.map pretty xs in
lparen <+>
(acat (x_d :: xs_d) <|> (* the horizontal style *)
(cost ((0, 0, 0, 0), 1) (vcat (x_d :: xs_d))) <|> (* the vertical style -- penalized *)
(cost ((0, 0, 0), 1) (vcat (x_d :: xs_d))) <|> (* the vertical style -- penalized *)
(x_d <+> space <+> vcat xs_d)) <+> (* the argument list style *)
rparen
in
Expand Down
2 changes: 1 addition & 1 deletion dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

(documentation https://sorawee.github.io/pretty-expressive-ocaml/)

(version 0.3)
(version 0.4)

(using mdx 0.4)

Expand Down
149 changes: 99 additions & 50 deletions lib/printer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ module Core (C : Signature.CostFactory) = struct
| Context of (int -> int -> doc)
(* invariant: the list length >= 2 *)
| TwoColumns of (doc * doc) list
| Blank of int
| Evaled of measure_set
| Fail

type cost = C.t
Expand Down Expand Up @@ -77,6 +79,12 @@ module Core (C : Signature.CostFactory) = struct

let text s = make_text (One s) (String.length s)

(* [blank i] allocates a fresh document node whose content is [Blank i].
   The node reports zero newlines, carries no table, and starts with
   [memo_w] set to [init_memo_w]. *)
let blank i =
  let id = next_id () in
  { dc = Blank i; id; nl_cnt = 0; table = None; memo_w = init_memo_w }

let rec cost c d =
match d.dc with
| Fail -> fail
Expand Down Expand Up @@ -141,13 +149,22 @@ module Core (C : Signature.CostFactory) = struct
nl_cnt = d.nl_cnt;
table = init_table memo_w }

(* Internal helper: do not export it.
   [context f nl_cnt] allocates a fresh node holding the callback [f] as a
   [Context], with the caller-supplied newline count, [memo_w] fixed at 0
   and a table obtained from [init_table 0]. *)
let context f nl_cnt =
  let id = next_id () in
  { dc = Context f; id; nl_cnt; memo_w = 0; table = init_table 0 }

(* Internal helper: do not export it.
   [evaled ms nl_cnt] allocates a fresh node holding the measure set [ms]
   as an [Evaled], with the caller-supplied newline count, [memo_w] fixed
   at 0 and a table obtained from [init_table 0]. *)
let evaled ms nl_cnt =
  let id = next_id () in
  { dc = Evaled ms; id; nl_cnt; memo_w = 0; table = init_table 0 }

let (<|>) d1 d2 =
if d1 == fail then d2
else if d2 == fail then d1
Expand Down Expand Up @@ -268,6 +285,68 @@ module Core (C : Signature.CostFactory) = struct
| MeasureSet (m :: _) -> m
| _ -> failwith "unreachable"

(* [do_two_columns self ds c] resolves a two-column construct into a
   measure set.
   - [self] is the resolver used for sub-documents; it is called here as
     [self d c c], i.e. with [c] as both of its integer arguments.
   - [ds] is the list of (left, right) document pairs, one pair per row.
   - [c] is the column at which the construct starts.
   The left document of every row is resolved first. Each [last] value
   found in the resulting measures is then tried as a candidate column
   separator. For every candidate, a document is built in which each row is
   left, then padding up to the separator, then right. The choice of all
   candidates is finally resolved through [self].
   If any left measure set is [Tainted], the whole computation is deferred
   inside a [Tainted] thunk and reduced with [choose_one]. *)
let do_two_columns self ds c =
(* Resolve the left document of every row at column [c]. *)
let left_ms = List.map (fun (d1, _) -> self d1 c c) ds in
(* True when at least one of the left measure sets is Tainted. *)
let left_any_tainted = List.exists
(fun ms ->
match ms with
| Tainted _ -> true
| _ -> false)
left_ms in
(* Walk over the rows. [before] holds the rows already visited, most recent
   first; [after_ms] and [after] are the remaining left measure sets and
   rows, kept in lockstep. The result is a choice document over the
   candidate separators, or [fail] once the rows are exhausted. *)
let rec loop_limit
(before : (doc * doc) list)
(after_ms : measure_set list)
(after : (doc * doc) list) =
match (after_ms, after) with
| ([], []) -> fail
| (ms :: after_ms, (left, right) :: after) ->
(* [build c_sep ds] lays out the rows [ds] against the separator column
   [c_sep]. The Context callback receives the column reached after the
   left document as [c_in]: when the separator is not yet passed, the gap
   is filled with a Blank of the missing width; otherwise an empty
   document is charged [C.two_columns_overflow] for the excess. The rows
   are stacked with [vcat] and the result is charged
   [C.two_columns_bias] for the distance between [c_sep] and [c]. *)
let build c_sep ds =
List.map (fun (d1, d2) ->
d1 ^^
context (fun c_in _ ->
if c_sep >= c_in then
blank (c_sep - c_in)
else
cost (C.two_columns_overflow (c_in - c_sep)) empty) 0 ^^
d2) ds |> vcat |> fun d -> cost (C.two_columns_bias (c_sep - c)) d
in
(* Rebuild the full row list in its original order, with the left document
   of the current row replaced by an Evaled node pinned to [ms]. *)
let build_choice c_sep ms =
build
c_sep
(List.rev_append before ((evaled ms left.nl_cnt, right) :: after))
in
(match ms with
(* A tainted row contributes one candidate, taken from its forced
   measure; this branch does not go on to the remaining rows. *)
| Tainted mt ->
let m = mt () in
build_choice m.last ms
| MeasureSet ms ->
(* One candidate per measure of this row, each pinned to that single
   measure, followed by the candidates of the remaining rows. *)
let rec loop_inner ms =
match ms with
| [] -> fail
| m :: ms -> build_choice m.last (MeasureSet [m]) <|> loop_inner ms
in
loop_inner ms <|> loop_limit ((left, right) :: before) after_ms after)
(* Reached only if the two lists differ in length; the caller builds them
   together with List.split, so they have equal length. *)
| _ -> failwith "unreachable"
in
(* NOTE: we can get the nl_cnt here to be precise with some tracking.
Do we want to do that? *)
(* Pair a left measure set with its row: the left document is replaced by
   an Evaled node holding the measure set, so it is not resolved again. *)
let make_doc ms (d1, d2) =
let ms = match ms with
(* force evaluation, so that we can share the outer shell freely *)
| Tainted mt -> let m = mt () in Tainted (fun () -> m)
| MeasureSet _ -> ms
in (ms, (evaled ms d1.nl_cnt, d2))
in
(* Build the choice over all candidate separators and resolve it at [c]. *)
let get_measure_set () =
let (after_ms, after) = List.split (List.map2 make_doc left_ms ds) in
let d = loop_limit [] after_ms after in
self d c c
in
(* With a tainted left column the work is delayed inside a thunk and the
   result is reduced to one measure; otherwise it is computed right away. *)
if left_any_tainted then
Tainted (fun () -> get_measure_set () |> choose_one)
else
get_measure_set ()

let pretty_print_info
?(init_c = 0)
(renderer : Signature.renderer)
Expand Down Expand Up @@ -299,45 +378,12 @@ module Core (C : Signature.CostFactory) = struct
| MeasureSet ms -> MeasureSet (List.map add_cost ms)
| Tainted mt -> Tainted (fun () -> add_cost (mt ())))
| Context f -> self (f c i) c i
| TwoColumns ds ->
let left_ms = List.map (fun (d1, _) -> self d1 c c) ds in
let left_any_tainted = List.exists
(fun ms ->
match ms with
| Tainted _ -> true
| _ -> false)
left_ms in
let get_positions () =
let left_lasts = List.map (fun ms ->
match ms with
| MeasureSet ms -> List.map (fun m -> m.last) ms
| Tainted mt -> let m = mt () in [m.last]) left_ms in
List.sort_uniq compare (List.flatten left_lasts)
in
let rec loop_limit (rank : int) (posns : int list) =
match posns with
| [] -> fail
| current_limit :: rest ->
let trans_ds = List.map (fun (d1, d2) ->
d1 ^^
context (fun c_in _ ->
if current_limit >= c_in then
text (String.make (current_limit - c_in) ' ')
else
cost (C.two_columns_overflow (c_in - current_limit)) empty) 0 ^^
d2) ds
in
cost (C.two_columns_bias rank) (vcat trans_ds) <|>
loop_limit (rank + 1) rest
in
let get_measure_set () =
let d = get_positions () |> loop_limit 0 |> align in
self d c i
in
if left_any_tainted then
Tainted (fun () -> get_measure_set () |> choose_one)
else
get_measure_set ()
| TwoColumns ds -> do_two_columns self ds c
| Blank i ->
MeasureSet [{ last = c + i;
cost = C.text 0 0;
layout = fun () -> renderer (String.make i ' ') }]
| Evaled ms -> ms
| Fail -> failwith "fails to render"
in
let exceeds = match dc with
Expand Down Expand Up @@ -402,7 +448,8 @@ module Make (C : Signature.CostFactory): (Signature.PrinterT with type cost = C.
if idp = id then d else cost c dp
(* There are at least two lines, so it can't be flattened *)
| TwoColumns _ -> fail
| Context _ -> failwith "unreachable"
| Blank _ -> d
| Context _ | Evaled _ -> failwith "unreachable"
in
Hashtbl.add cache id out;
out
Expand Down Expand Up @@ -465,7 +512,7 @@ let make_debug_format page_width content is_tainted cost =
(* $MDX part-begin=default_cost_factory *)
let default_cost_factory ~page_width ?computation_width () =
(module struct
type t = int * int * int * int
type t = int * int * int

let limit = match computation_width with
| None -> (float_of_int page_width) *. 1.2 |> int_of_float
Expand All @@ -477,23 +524,25 @@ let default_cost_factory ~page_width ?computation_width () =
let maxwc = max page_width pos in
let a = maxwc - page_width in
let b = stop - maxwc in
(b * (2*a + b), 0, 0, 0)
(b * (2*a + b), 0, 0)
else
(0, 0, 0, 0)
(0, 0, 0)

let newline _ = (0, 0, 1, 0)
let newline _ = (0, 0, 1)

let combine (o1, ot1, h1, bt1) (o2, ot2, h2, bt2) =
(o1 + o2, ot1 + ot2, h1 + h2, bt1 + bt2)
let combine (o1, ot1, h1) (o2, ot2, h2) =
(o1 + o2, ot1 + ot2, h1 + h2)

let le c1 c2 = c1 <= c2

let two_columns_overflow w = (0, w, 0, 0)
let two_columns_overflow w = (0, w, 0)

let two_columns_bias w = (0, 0, 0, w)
let two_columns_bias _ = (0, 0, 0)

let string_of_cost (o, ot, h, bt) = Printf.sprintf "(%d %d %d %d)" o ot h bt
let string_of_cost (o, ot, h) = Printf.sprintf "(%d %d %d)" o ot h

let debug_format = make_debug_format page_width
end: Signature.CostFactory with type t = int * int * int * int)
end: Signature.CostFactory with type t = int * int * int)
(* $MDX part-end *)

let version = "0.4"
29 changes: 15 additions & 14 deletions lib/printer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ val make_debug_format : int -> string -> bool -> string -> string
containing these parameters. *)

val default_cost_factory : page_width:int -> ?computation_width:int -> unit ->
(module Signature.CostFactory with type t = int * int * int * int)
(module Signature.CostFactory with type t = int * int * int)
(** The default cost factory, parameterized by the page width limit [page_width],
and optionally {{!page-index.complimit}the computation width limit}
[computation_width].
Expand All @@ -25,16 +25,14 @@ val default_cost_factory : page_width:int -> ?computation_width:int -> unit ->
{ul {- The first component is {i badness}, which is roughly speaking
the sum of squared overflows over the page width limit}
{- The second component is the sum of overflows over a column separator.}
{- The third component is the height (number of newlines).}
{- The fourth component is bias penalty to encourage toward choosing
a leftmost column separator.} }
{- The third component is the height (number of newlines).}}
Internally, [default_cost_factory] is defined as:
{@ocaml file=printer.ml,part=default_cost_factory[
let default_cost_factory ~page_width ?computation_width () =
(module struct
type t = int * int * int * int
type t = int * int * int
let limit = match computation_width with
| None -> (float_of_int page_width) *. 1.2 |> int_of_float
Expand All @@ -46,23 +44,26 @@ let default_cost_factory ~page_width ?computation_width () =
let maxwc = max page_width pos in
let a = maxwc - page_width in
let b = stop - maxwc in
(b * (2*a + b), 0, 0, 0)
(b * (2*a + b), 0, 0)
else
(0, 0, 0, 0)
(0, 0, 0)
let newline _ = (0, 0, 1, 0)
let newline _ = (0, 0, 1)
let combine (o1, ot1, h1, bt1) (o2, ot2, h2, bt2) =
(o1 + o2, ot1 + ot2, h1 + h2, bt1 + bt2)
let combine (o1, ot1, h1) (o2, ot2, h2) =
(o1 + o2, ot1 + ot2, h1 + h2)
let le c1 c2 = c1 <= c2
let two_columns_overflow w = (0, w, 0, 0)
let two_columns_overflow w = (0, w, 0)
let two_columns_bias w = (0, 0, 0, w)
let two_columns_bias _ = (0, 0, 0)
let string_of_cost (o, ot, h, bt) = Printf.sprintf "(%d %d %d %d)" o ot h bt
let string_of_cost (o, ot, h) = Printf.sprintf "(%d %d %d)" o ot h
let debug_format = make_debug_format page_width
end: Signature.CostFactory with type t = int * int * int * int)
end: Signature.CostFactory with type t = int * int * int)
]} *)

val version : string
(* a version string *)
Loading

0 comments on commit 8baf30e

Please sign in to comment.