Skip to content

Commit

Permalink
Merge branch 'main' into dump-rival
Browse files Browse the repository at this point in the history
  • Loading branch information
pavpanchekha authored Jan 26, 2025
2 parents 94c8839 + c73a45c commit 5c378e9
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 88 deletions.
4 changes: 2 additions & 2 deletions src/api/demo.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@
`(form
([action ,(url improve)] [method "post"] [id "formula"] [data-progress ,(url improve-start)])
(textarea ([name "formula"] [autofocus "true"]
[placeholder "(FPCore (x) (- (sqrt (+ x 1)) (sqrt x)))"]))
(input ([name "formula-math"] [placeholder "sqrt(x + 1) - sqrt(x)"]))
[placeholder "e.g. (FPCore (x) (- (sqrt (+ x 1)) (sqrt x)))"]))
(input ([name "formula-math"] [placeholder "e.g. sqrt(x + 1) - sqrt(x)"]))
(table ([id "input-ranges"]))
(ul ([id "errors"]))
(ul ([id "warnings"]))
Expand Down
4 changes: 3 additions & 1 deletion src/core/bsearch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@
(define pts
(for/list ([(pt ex) (in-pcontext pcontext)])
pt))
; new-sampler returns: (cons (cons val pts) hint)
; Since the sampler does not call rival-analyze, the hint is set to #f
(define (new-sampler)
(cons val (random-ref pts)))
(cons (cons val (random-ref pts)) #f))
(apply mk-pcontext (cdr (batch-prepare-points evaluator new-sampler))))

(define/reset *prepend-arguement-cache* (make-hash))
Expand Down
45 changes: 22 additions & 23 deletions src/core/localize.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,20 @@
(- a b))) ; This `if` statement handles `inf - inf`

; rank subexpressions by cost opportunity
(define localize-costss
(for/list ([subexprs (in-list subexprss)])
(sort (reap [sow]
(for ([subexpr (in-list subexprs)])
(match subexpr
[(? literal?) (void)]
[(? symbol?) (void)]
[(approx _ impl)
(define cost-opp (cost-opportunity subexpr (list impl)))
(sow (cons cost-opp subexpr))]
[(list _ args ...)
(define cost-opp (cost-opportunity subexpr args))
(sow (cons cost-opp subexpr))])))
>
#:key car)))

localize-costss)
(for/list ([subexprs (in-list subexprss)])
(sort (reap [sow]
(for ([subexpr (in-list subexprs)])
(match subexpr
[(? literal?) (void)]
[(? symbol?) (void)]
[(approx _ impl)
(define cost-opp (cost-opportunity subexpr (list impl)))
(sow (cons cost-opp subexpr))]
[(list _ args ...)
(define cost-opp (cost-opportunity subexpr args))
(sow (cons cost-opp subexpr))])))
>
#:key car)))

(define (batch-localize-errors exprs ctx)
(define subexprss (map all-subexpressions exprs))
Expand Down Expand Up @@ -364,12 +361,14 @@
(define exact-error (~s (translate-booleans (first (hash-ref data 'exact-values)))))
(define actual-error (~s (translate-booleans (first (hash-ref data 'approx-values)))))
(define percent-accurate
(if (nan? (first (hash-ref data 'absolute-error)))
'invalid ; HACK: should specify if invalid or unsamplable
(let* ([repr (repr-of expr ctx)]
[total-bits (representation-total-bits repr)]
[bits-error (ulps->bits (first (hash-ref data 'ulp-errs)))])
(* 100 (- 1 (/ bits-error total-bits))))))
(cond
[(nan? (first (hash-ref data 'absolute-error)))
'invalid] ; HACK: should specify if invalid or unsamplable
[else
(define repr (repr-of expr ctx))
(define total-bits (representation-total-bits repr))
(define bits-error (ulps->bits (first (hash-ref data 'ulp-errs))))
(* 100 (- 1 (/ bits-error total-bits)))]))
(hasheq 'ulps-error
ulp-error
'avg-error
Expand Down
15 changes: 8 additions & 7 deletions src/core/programs.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@
[else
(let loop ([a a]
[b b])
(if (null? a)
0
(let ([cmp (expr-cmp (car a) (car b))])
(if (zero? cmp)
(loop (cdr a) (cdr b))
cmp))))])]
(cond
[(null? a) 0]
[else
(define cmp (expr-cmp (car a) (car b)))
(if (zero? cmp)
(loop (cdr a) (cdr b))
cmp)]))])]
[((? list?) _) 1]
[(_ (? list?)) -1]
[((? approx?) (? approx?))
Expand All @@ -131,7 +132,7 @@
[else 1])]))

(define (expr<? a b)
(< (expr-cmp a b) 0))
(negative? (expr-cmp a b)))

;; Converting constants

Expand Down
20 changes: 11 additions & 9 deletions src/core/rival.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
[ctxs (es) (and/c unified-contexts? (lambda (ctxs) (= (length es) (length ctxs))))])
(#:pre [pre any/c])
[c real-compiler?])]
[real-apply (-> real-compiler? list? (values symbol? any/c))]
[real-apply
(->* (real-compiler? list?) ((or/c (vectorof any/c) boolean?)) (values symbol? any/c))]
[real-compiler-clear! (-> real-compiler-clear! void?)]
[real-compiler-analyze (-> real-compiler? (vectorof ival?) ival?)]))
[real-compiler-analyze
(->* (real-compiler? (vectorof ival?))
((or/c (vectorof any/c) boolean?))
(listof any/c))]))

(define (unified-contexts? ctxs)
(and ((non-empty-listof context?) ctxs)
Expand Down Expand Up @@ -86,7 +90,7 @@
(real-compiler pre vars var-reprs specs reprs machine dump-file))

;; Runs a Rival machine on an input point.
(define (real-apply compiler pt)
(define (real-apply compiler pt [hint #f])
(match-define (real-compiler _ vars var-reprs _ _ machine dump-file) compiler)
(define start (current-inexact-milliseconds))
(define pt*
Expand All @@ -103,7 +107,8 @@
[exn:rival:unsamplable? (lambda (e) (values 'exit #f))])
(parameterize ([*rival-max-precision* (*max-mpfr-prec*)]
[*rival-max-iterations* 5])
(values 'valid (rest (vector->list (rival-apply machine pt*))))))) ; rest = drop precondition
(define value (rest (vector->list (rival-apply machine pt* hint)))) ; rest = drop precondition
(values 'valid value))))
(when (> (rival-profile machine 'bumps) 0)
(warn 'ground-truth
"Could not converge on a ground truth"
Expand Down Expand Up @@ -134,8 +139,5 @@
;; Returns whether the machine is guaranteed to raise an exception
;; for the given inputs range. The result is an interval representing
;; how certain the result is: no, maybe, yes.
(define (real-compiler-analyze compiler input-ranges)
(define res (rival-analyze (real-compiler-machine compiler) input-ranges))
(if (list? res)
(car res)
res))
(define (real-compiler-analyze compiler input-ranges [hint #f])
(rival-analyze (real-compiler-machine compiler) input-ranges hint))
38 changes: 23 additions & 15 deletions src/core/sampling.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,24 @@
(module+ test
(define rand-list
(let loop ([current 0])
(if (> current 200)
empty
(let ([r (+ current (random-integer 1 10))]) (cons r (loop r))))))
(cond
[(> current 200) empty]
[else
(define r (+ current (random-integer 1 10)))
(cons r (loop r))])))
(define arr (list->vector rand-list))
(for ([i (range 0 20)])
(define max-num (vector-ref arr (- (vector-length arr) 1)))
(define search-for (random-integer 0 max-num))
(define search-result (binary-search arr search-for))
(check-true (> (vector-ref arr search-result) search-for))
(when (> search-result 0)
(when (positive? search-result)
(check-true (<= (vector-ref arr (- search-result 1)) search-for)))))

(define (make-hyperrect-sampler hyperrects* reprs)
(define (make-hyperrect-sampler hyperrects* hints* reprs)
(when (null? hyperrects*)
(raise-herbie-sampling-error "No valid values." #:url "faq.html#no-valid-values"))
(define hints (list->vector hints*))
(define hyperrects (list->vector hyperrects*))
(define lo-ends
(for/vector #:length (vector-length hyperrects)
Expand All @@ -98,15 +101,19 @@
((representation-bf->repr repr) (ival-hi interval)))))))
(define weights (partial-sums (vector-map (curryr hyperrect-weight reprs) hyperrects)))
(define weight-max (vector-ref weights (- (vector-length weights) 1)))

;; returns (cons (listof pts) hint)
(λ ()
(define rand-ordinal (random-integer 0 weight-max))
(define idx (binary-search weights rand-ordinal))
(define los (vector-ref lo-ends idx))
(define his (vector-ref hi-ends idx))
(for/list ([lo (in-list los)]
[hi (in-list his)]
[repr (in-list reprs)])
((representation-ordinal->repr repr) (random-integer lo hi)))))
(define hint (vector-ref hints idx))
(cons (for/list ([lo (in-list los)]
[hi (in-list his)]
[repr (in-list reprs)])
((representation-ordinal->repr repr) (random-integer lo hi)))
hint)))

#;(module+ test
(define two-point-hyperrects (list (list (ival (bf 0) (bf 0)) (ival (bf 1) (bf 1)))))
Expand All @@ -123,12 +130,14 @@
(equal? (representation-type repr) 'real)))
(timeline-push! 'method "search")
(define hyperrects-analysis (precondition->hyperrects pre vars var-reprs))
(match-define (cons hyperrects sampling-table)
; hints-hyperrects is a (listof '(hint hyperrect))
(match-define (list hyperrects hints sampling-table)
(find-intervals compiler hyperrects-analysis #:fuel (*max-find-range-depth*)))
(cons (make-hyperrect-sampler hyperrects var-reprs) sampling-table)]
(cons (make-hyperrect-sampler hyperrects hints var-reprs) sampling-table)]
[else
(timeline-push! 'method "random")
(cons (λ () (map random-generate var-reprs)) (hash 'unknown 1.0))]))
; sampler return false hint since rival-analyze has not been called in random method
(cons (λ () (cons (map random-generate var-reprs) #f)) (hash 'unknown 1.0))]))

;; Returns an evaluator for a list of expressions.
(define (eval-progs-real specs ctxs)
Expand Down Expand Up @@ -156,9 +165,8 @@
[skipped 0]
[points '()]
[exactss '()])
(define pt (sampler))

(define-values (status exs) (real-apply compiler pt))
(match-define (cons pt hint) (sampler))
(define-values (status exs) (real-apply compiler pt hint))
(case status
[(exit)
(warn 'ground-truth
Expand Down
56 changes: 25 additions & 31 deletions src/core/searchreals.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@
(provide find-intervals
hyperrect-weight)

(struct search-space (true false other))
(struct search-space (true false other true-hints other-hints))

(define (make-search-space . ranges)
(search-space '() '() ranges))

(define (repr-round repr dir point)
((representation-repr->bf repr) (parameterize ([bf-rounding-mode dir])
((representation-bf->repr repr) point))))
(search-space '() '() ranges '() (make-list (length ranges) #f)))

(define (total-weight reprs)
(expt 2 (apply + (map representation-total-bits reprs))))
Expand All @@ -33,28 +29,20 @@
(compose (representation-repr->ordinal repr) (representation-bf->repr repr)))
(+ 1 (- (->ordinal (ival-hi interval)) (->ordinal (ival-lo interval)))))))

(define (midpoint repr lo hi)
; Midpoint is taken in repr-space, but values are stored in bf
(define <-ordinal (compose (representation-repr->bf repr) (representation-ordinal->repr repr)))
(define ->ordinal (compose (representation-repr->ordinal repr) (representation-bf->repr repr)))

(define lower (<-ordinal (floor (/ (+ (->ordinal hi) (->ordinal lo)) 2))))
(define higher (repr-round repr 'up (bfnext lower))) ; repr-next

(and (bf>= lower lo)
(bf<= higher hi) ; False if lo and hi were already close together
(cons lower higher)))

(define (search-step compiler space split-var)
(define vars (real-compiler-vars compiler))
(define reprs (real-compiler-var-reprs compiler))
(match-define (search-space true false other) space)
(define-values (true* false* other*)
(match-define (search-space true false other true-hints other-hints) space)
(define-values (true* false* other* true-hints* other-hints*)
(for/fold ([true* true]
[false* false]
[other* '()])
([rect (in-list other)])
(match-define (ival err err?) (real-compiler-analyze compiler (list->vector rect)))
[other* '()]
[true-hints* true-hints]
[other-hints* '()])
([rect (in-list other)]
[hint (in-list other-hints)])
(match-define (list (ival err err?) hint* converged?)
(real-compiler-analyze compiler (list->vector rect) hint))
(when (eq? err 'unsamplable)
(warn 'ground-truth
#:url "faq.html#ground-truth"
Expand All @@ -68,18 +56,23 @@
repr))
(format "~a = ~a" var val))))
(cond
[err (values true* (cons rect false*) other*)]
[(not err?) (values (cons rect true*) false* other*)]
[err (values true* (cons rect false*) other* true-hints* other-hints*)]
[(and (not err?) converged?)
(values (cons rect true*) false* other* (cons hint* true-hints*) other-hints*)]
[else
(define range (list-ref rect split-var))
(define repr (list-ref reprs split-var))
(match (midpoint repr (ival-lo range) (ival-hi range))
(match (two-midpoints repr (ival-lo range) (ival-hi range))
[(cons midleft midright)
(define rect-lo (list-set rect split-var (ival (ival-lo range) midleft)))
(define rect-hi (list-set rect split-var (ival midright (ival-hi range))))
(values true* false* (list* rect-lo rect-hi other*))]
[#f (values true* false* (cons rect other*))])])))
(search-space true* false* other*))
(values true*
false*
(list* rect-lo rect-hi other*)
true-hints*
(list* hint* hint* other-hints*))]
[#f (values true* false* (cons rect other*) true-hints* (cons hint* other-hints*))])])))
(search-space true* false* other* true-hints* other-hints*))

(define (make-sampling-table reprs true false other)
(define denom (total-weight reprs))
Expand All @@ -100,11 +93,12 @@
(map (curryr cons 'other) rects)
(let loop ([space (apply make-search-space rects)]
[n 0])
(match-define (search-space true false other) space)
(match-define (search-space true false other true-hints other-hints) space)
(timeline-push! 'sampling n (make-sampling-table var-reprs true false other))

(define n* (remainder n (length (first rects))))
(if (or (>= n depth) (empty? (search-space-other space)) (>= (length other) (expt 2 depth)))
(cons (append (search-space-true space) (search-space-other space))
(list (append (search-space-true space) (search-space-other space))
(append (search-space-true-hints space) (search-space-other-hints space))
(make-sampling-table var-reprs true false other))
(loop (search-step compiler space n*) (+ n 1))))))
17 changes: 17 additions & 0 deletions src/utils/float.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
(provide ulp-difference
ulps->bits
midpoint
two-midpoints
random-generate
</total
<=/total
Expand All @@ -29,6 +30,22 @@
((representation-repr->ordinal repr) p2))
2))))

(define (repr-round repr dir point)
((representation-repr->bf repr) (parameterize ([bf-rounding-mode dir])
((representation-bf->repr repr) point))))

(define (two-midpoints repr lo hi)
; Midpoint is taken in repr-space, but values are stored in bf
(define <-ordinal (compose (representation-repr->bf repr) (representation-ordinal->repr repr)))
(define ->ordinal (compose (representation-repr->ordinal repr) (representation-bf->repr repr)))

(define lower (<-ordinal (floor (/ (+ (->ordinal hi) (->ordinal lo)) 2))))
(define higher (repr-round repr 'up (bfnext lower))) ; repr-next

(and (bf>= lower lo)
(bf<= higher hi) ; False if lo and hi were already close together
(cons lower higher)))

(define (ulps->bits x)
(real->double-flonum (log x 2)))

Expand Down

0 comments on commit 5c378e9

Please sign in to comment.