Skip to content

Commit 52e5e6b

Browse files
committed
lifting rules clean uo
1 parent a44e644 commit 52e5e6b

File tree

2 files changed

+58
-46
lines changed

2 files changed

+58
-46
lines changed

src/core/egg-herbie.rkt

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,11 @@
471471
;; Translates a Herbie rule into an egg rule
472472
(define (rule->egg-rule ru)
473473
(define ru-vars (map car (rule-itypes ru)))
474+
(define lhs (expr->egg-pattern (rule-input ru) ru-vars))
475+
(define rhs (expr->egg-pattern (rule-output ru) ru-vars))
476+
(when (or (equal? (rule-tags ru) '(lowering)) (equal? (rule-tags ru) '(lifting)))
477+
(printf "~a -> ~a\n" lhs rhs)
478+
(sleep 2))
474479
(struct-copy rule
475480
ru
476481
[input (expr->egg-pattern (rule-input ru) ru-vars)]
@@ -510,15 +515,6 @@
510515
;; Uses a cache to only expand each rule once.
511516
(define (expand-rules rules)
512517
(reap [sow]
513-
(sow (cons #f (make-ffi-rule "lift-literal" "($literal ?repr ?a)" "($hole ?repr ?a)")))
514-
(sow (cons #f
515-
(make-ffi-rule "lift-approx"
516-
"($approx ?spec ($hole ?r ?t))"
517-
"($hole ?r ($approx ?spec ?t))")))
518-
(sow (cons #f
519-
(make-ffi-rule "lift-if"
520-
"(if ($hole bool ?c) ($hole ?r ?t) ($hole ?r ?f))"
521-
"($hole ?r (if ?c ?t ?f))")))
522518
(for ([rule (in-list rules)])
523519
(define egg&ffi-rules
524520
(hash-ref! (*egg-rule-cache*)

src/syntax/platform.rkt

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@
341341
(define spec (impl-info impl 'spec))
342342
(values vars spec (cons impl vars)))
343343

344-
;; Parses holes into spec-expr instead of vars
344+
;; Parses holes into expr instead of vars
345345
;; (_ x y) -> (_ ($hole 'binary64 x) ($hole 'binary64 y))
346346
(define (replace-vars-with-holes expr vars itypes)
347347
(define replacements
@@ -356,32 +356,60 @@
356356
[_ expr]))) ; it can be a literal or a number
357357

358358
;; Synthesizes lifting rules for a given platform.
359+
;; Lifting rule applies a rewrite like:
360+
;; (+.f64 ($hole 'binary64 x) ($hole 'binary64 y)) -> ($hole binary64 (+ x y))
361+
;; (hypot.f64 ($hole 'binary64 x) ($hole 'binary64 y)) -> ($hole binary64 (sqrt (* x x) (* y y))))
359362
(define (platform-lifting-rules [pform (*active-platform*)])
360363
;; every impl maps to a spec
361364
(define impls (platform-impls pform))
362365
(define impl-rules
363366
(for/list ([impl (in-list impls)])
364-
(hash-ref!
365-
(*lifting-rules*)
366-
(cons impl pform)
367-
(lambda ()
368-
(define name (sym-append 'lift- impl))
369-
(define itypes (impl-info impl 'itype))
370-
(define otype (impl-info impl 'otype))
371-
(define-values (vars spec-expr impl-expr) (impl->rule-parts impl))
372-
(define lhs (replace-vars-with-holes impl-expr vars itypes))
373-
(define rhs `($hole ,(representation-name otype) ,spec-expr))
374-
375-
; Lifting rule applies a rewrite like:
376-
; (+.f64 ($hole 'binary64 x) ($hole 'binary64 y)) -> ($hole binary64 (+ x y))
377-
; (hypot.f64 ($hole 'binary64 x) ($hole 'binary64 y)) -> ($hole binary64 (sqrt (* x x) (* y y))))
378-
(rule name lhs rhs (map cons vars itypes) otype '(lifting))))))
367+
(hash-ref! (*lifting-rules*)
368+
(cons impl pform)
369+
(lambda ()
370+
(define name (sym-append 'lift- impl))
371+
(define itypes (impl-info impl 'itype))
372+
(define otype (impl-info impl 'otype))
373+
(define-values (vars spec-expr impl-expr) (impl->rule-parts impl))
374+
(define lhs (replace-vars-with-holes impl-expr vars itypes))
375+
(define rhs `($hole ,(representation-name otype) ,spec-expr))
376+
(rule name lhs rhs (map cons vars itypes) otype '(lifting))))))
377+
378+
(define lift-literal-rule
379+
(rule 'lift-literal
380+
'($literal a repr)
381+
'($hole repr a)
382+
'((a . real) (repr . real))
383+
'real
384+
'(lifting)))
385+
386+
(define lift-approx-rule
387+
(rule 'lift-approx
388+
'($approx s ($hole r t))
389+
'($hole r ($approx s t))
390+
'((s . real) (r . real) (t . real))
391+
'real
392+
'(lifting)))
393+
394+
(define lift-if-rule
395+
(rule 'lift-if
396+
'(if ($hole bool c)
397+
($hole r t)
398+
($hole r f))
399+
'($hole r (if c t f))
400+
'((c . real) (r . real) (t . real) (f . real))
401+
'real
402+
'(lifting)))
403+
379404
;; special rule for approx nodes
380405
; (define approx-rule (rule 'lift-approx (approx 'a 'b) 'a '((a . real) (b . real)) 'real))
381406
; (cons approx-rule impl-rules))
382-
impl-rules)
407+
(list* lift-if-rule lift-approx-rule lift-literal-rule impl-rules))
383408

384409
;; Synthesizes lowering rules for a given platform.
410+
;; Lowering rules apply a rewrite like
411+
;; ($hole binary64 (+ x y)) -> (+.f64 ($hole binary64 x) ($hole binary64 y))
412+
;; question? what to do when we may end up with ($hole binary64 ($hole binary32 x)) ?
385413
(define (platform-lowering-rules [pform (*active-platform*)])
386414
(define impls (platform-impls pform))
387415
(for/list ([impl (in-list impls)])
@@ -392,26 +420,14 @@
392420
(define itypes (impl-info impl 'itype))
393421
(define otype (impl-info impl 'otype))
394422
(define-values (vars spec-expr impl-expr) (impl->rule-parts impl))
395-
396-
; shrinking lowering rules
397-
; ($hole binary64 (+ x y)) -> (+.f64 ($hole binary64 x) ($hole binary64 y))
398-
; question? what to do when we may end up with ($hole binary64 ($hole binary32 x)) ?
399-
(define name* (sym-append 'lower-shrink- impl))
400-
(define op (car spec-expr))
401-
(define op* (car impl-expr))
402-
(define lhs `($hole ,(representation-name otype) ,(list* op vars)))
403-
(define rhs
404-
`,(list* op* (map (λ (x y) `($hole ,(representation-name y) ,x)) vars itypes)))
405-
(define shrinking-lowering-rule
406-
(rule name*
407-
lhs
408-
rhs
409-
(map cons vars (map representation-type itypes))
410-
(representation-type otype)
411-
'(lowering)))
412-
413-
#;(rule name spec-expr impl-expr (map cons vars itypes) otype '(lowering))
414-
shrinking-lowering-rule))))
423+
(define lhs `($hole ,(representation-name otype) ,spec-expr))
424+
(define rhs (replace-vars-with-holes impl-expr vars itypes))
425+
(rule name
426+
lhs
427+
rhs
428+
(map cons vars (map representation-type itypes))
429+
(representation-type otype)
430+
'(lowering))))))
415431

416432
(define (expr-otype expr)
417433
(match expr

0 commit comments

Comments
 (0)